package main import ( "bufio" "fmt" "io" "net" "os" "os/exec" "strconv" "strings" "sync" "time" "be.ems/lib/dborm" "be.ems/lib/global" "be.ems/lib/log" "be.ems/lib/mmlp" "be.ems/sshsvc/config" "be.ems/sshsvc/logmml" "be.ems/sshsvc/snmp" //"github.com/gliderlabs/ssh" "golang.org/x/crypto/ssh" "golang.org/x/term" ) var conf *config.YamlConfig var ( telnetCC int sshCC int telnetMu sync.Mutex sshMu sync.Mutex ) func init() { conf = config.GetYamlConfig() log.InitLogger(conf.Logger.File, conf.Logger.Duration, conf.Logger.Count, "omc:sshsvc", config.GetLogLevel()) fmt.Printf("OMC sshsvc version: %s\n", global.Version) log.Infof("========================= OMC sshsvc startup =========================") log.Infof("OMC sshsvc version: %s %s %s", global.Version, global.BuildTime, global.GoVer) db := conf.Database err := dborm.InitDbClient(db.Type, db.User, db.Password, db.Host, db.Port, db.Name, db.ConnParam) if err != nil { fmt.Println("dborm.initDbClient err:", err) os.Exit(1) } logmml.InitMmlLogger(conf.Logmml.File, conf.Logmml.Duration, conf.Logmml.Count, "omc", config.GetLogMmlLevel()) } func main() { // 生成SSH密钥对 privateKeyBytes, err := os.ReadFile(conf.Sshd.PrivateKey) if err != nil { log.Fatal("Failed to ReadFile:", err) os.Exit(2) } privateKey, err := ssh.ParsePrivateKey(privateKeyBytes) if err != nil { log.Fatal("Failed to ParsePrivateKey", err) os.Exit(3) } // 配置SSH服务器 serverConfig := &ssh.ServerConfig{ PasswordCallback: func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { // 这里可以进行密码验证逻辑,例如检查用户名和密码是否匹配 // validUser, _, err := dborm.XormCheckLoginUser(conn.User(), string(password), conf.OMC.UserCrypt) // if err != nil { // return nil, err // } // if validUser == true { // sessionToken := fmt.Sprintf("%x", conn.SessionID()) // Generate new token to session ID // sourceAddr := conn.RemoteAddr().String() // timeOut := uint32(conf.Sshd.Timeout) // sessionMode := conf.Sshd.Session // log.Debugf("token:%s sourceAddr:%s", sessionToken, sourceAddr) // affected, err := dborm.XormInsertSession(conn.User(), sourceAddr, sessionToken, timeOut, sessionMode) // if err != nil { // log.Error("Failed to insert Session table:", err) // return nil, err // } // if affected == -1 { // err := errors.New("Failed to get session") // log.Error(err) // return nil, err // } // return nil, nil // } if handleAuth(conf.Sshd.AuthType, conn.User(), string(password)) { return nil, nil } return nil, fmt.Errorf("invalid user or password") }, PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { // 这里可以进行公钥验证逻辑,例如检查用户的公钥是否在允许的公钥列表中 return nil, fmt.Errorf("public key authentication is failed") }, } serverConfig.AddHostKey(privateKey) // 启动SSH服务器 hostUri := fmt.Sprintf("%s:%d", conf.Sshd.ListenAddr, conf.Sshd.ListenPort) listener, err := net.Listen("tcp", hostUri) if err != nil { log.Fatal("Failed to Listen: ", err) os.Exit(4) } fmt.Printf("MML SSH server startup, listen port:%d\n", conf.Sshd.ListenPort) // 启动telnet服务器 telnetUri := fmt.Sprintf("%s:%d", conf.TelnetServer.ListenAddr, conf.TelnetServer.ListenPort) // telnetListener, err := net.Listen("tcp", telnetUri) // if err != nil { // log.Fatal("Failed to Listen: ", err) // os.Exit(4) // } fmt.Printf("MML Telnet server startup, listen port:%d\n", conf.TelnetServer.ListenPort) // telnetconn, err := telnetListener.Accept() // if err != nil { // log.Fatal("Failed to accept telnet connection: ", err) // os.Exit(6) // } go startTelnetServer(telnetUri) snmpSvc := snmp.SNMPService{ ListenAddr: conf.SNMPServer.ListenAddr, ListenPort: conf.SNMPServer.ListenPort, UserName: conf.SNMPServer.UserName, AuthPass: conf.SNMPServer.AuthPass, AuthProto: conf.SNMPServer.AuthProto, PrivPass: conf.SNMPServer.PrivPass, PrivProto: conf.SNMPServer.PrivProto, EngineID: conf.SNMPServer.EngineID, TrapPort: conf.SNMPServer.TrapPort, TrapTick: conf.SNMPServer.TrapTick, TimeOut: conf.SNMPServer.TimeOut, TrapTarget: conf.SNMPServer.TrapTarget, ListenHost: conf.SNMPServer.ListenAddr + ":" + strconv.Itoa(int(conf.SNMPServer.ListenPort)), TrapHost: conf.SNMPServer.ListenAddr + ":" + strconv.Itoa(int(conf.SNMPServer.TrapPort)), SysDescr: "HLR server", SysService: 0, } go snmpSvc.StartSNMPServer() go snmpSvc.StartTrapServer() for { conn, err := listener.Accept() if err != nil { log.Fatal("Failed to Accept: ", err) os.Exit(5) } go handleSSHConnection(conn, serverConfig) } } func handleAuth(authType, userName, password string) bool { switch authType { case "local": if userName == conf.Sshd.UserName && password == conf.Sshd.Password { return true } return false case "radius": exist, err := dborm.XEngDB().Table("OMC_PUB.sysUser").Where("userName=? AND password=md5(?)", userName, password).Exist() if err != nil { return false } return exist case "omc": default: } return false } func startTelnetServer(addr string) { listener, err := net.Listen("tcp", addr) if err != nil { fmt.Println("Error starting Telnet server:", err) return } defer listener.Close() fmt.Println("Telnet server started on", addr) for { conn, err := listener.Accept() if err != nil { fmt.Println("Error accepting Telnet connection:", err) continue } telnetMu.Lock() if telnetCC >= int(conf.TelnetServer.MaxConnNum) { telnetMu.Unlock() io.WriteString(conn, "Connection limit reached. Try again later.\n") conn.Close() continue } telnetCC++ telnetMu.Unlock() go handleTelnetConnection(conn) } } func handleTelnetConnection(conn net.Conn) { defer func() { telnetMu.Lock() telnetCC-- telnetMu.Unlock() }() defer conn.Close() reader := bufio.NewReader(conn) writer := bufio.NewWriter(conn) // 发送欢迎信息 writer.WriteString("Welcome to the Telnet server!\n") writer.Flush() // 请求用户名 writer.WriteString("Username: ") writer.Flush() user, _ := reader.ReadString('\n') user = strings.TrimSpace(user) // 关闭回显模式 writer.Write([]byte{255, 251, 1}) // IAC WILL ECHO writer.Flush() // 请求密码 writer.WriteString("Password: ") writer.Flush() // 读取密码并清除控制序列 var passBuilder strings.Builder for { b, err := reader.ReadByte() if err != nil { return } if b == '\n' || b == '\r' { break } if b == 255 { // IAC reader.ReadByte() // 忽略下一个字节 reader.ReadByte() // 忽略下一个字节 } else { passBuilder.WriteByte(b) } } pass := passBuilder.String() // 恢复回显模式 writer.Write([]byte{255, 252, 1}) // IAC WONT ECHO writer.Flush() if handleAuth(conf.TelnetServer.AuthType, user, pass) { writer.WriteString("\nAuthentication successful!\n") writer.Flush() handleCommands(user, conf.TelnetServer.TagNE, reader, writer) } else { writer.WriteString("\nAuthentication failed!\n") writer.Flush() } } // 处理命令输入 func handleCommands(user, tag string, reader *bufio.Reader, writer *bufio.Writer) { header := fmt.Sprintf("[%s@%s]> ", user, tag) for { command, err := reader.ReadString('\n') if err != nil { return } command = strings.TrimSpace(command) // 处理其他命令 switch command { case "hello": writer.WriteString("Hello, world!\n") case "time": writer.WriteString(fmt.Sprintf("Current time: %s\n", time.Now().Format(time.RFC1123))) case "exit", "quit": writer.WriteString("Goodbye!\n") writer.Flush() return case "": case "\n": case "\xff\xfe\x01": default: writer.WriteString("Unknown command\n") } writer.WriteString(header) writer.Flush() } } func handleSSHConnection(conn net.Conn, serverConfig *ssh.ServerConfig) { // SSH握手 sshConn, chans, reqs, err := ssh.NewServerConn(conn, serverConfig) if err != nil { log.Error("Failed to NewServerConn: ", err) return } log.Infof("SSH connect accepted,client version:%s,user:%s", sshConn.ClientVersion(), sshConn.User()) // 处理SSH请求 go ssh.DiscardRequests(reqs) // 处理SSH通道 for newChannel := range chans { if newChannel.ChannelType() != "session" { newChannel.Reject(ssh.UnknownChannelType, "unsupported channel type") continue } channel, requests, err := newChannel.Accept() if err != nil { log.Error("Failed to NewServerConn: ", err) continue } sshMu.Lock() sshCC++ if sshCC > int(conf.Sshd.MaxConnNum) { sshMu.Unlock() log.Error("Maximum number of connections exceeded") //conn.Write([]byte("Reach max connections")) conn.Close() continue } sshMu.Unlock() go handleSSHChannel(conn, sshConn, channel, requests) } } func handleSSHChannel(conn net.Conn, sshConn *ssh.ServerConn, channel ssh.Channel, requests <-chan *ssh.Request) { for req := range requests { switch req.Type { case "exec": // 执行远程命令 command := strings.TrimSpace(string(req.Payload))[4:] //cmd := exec.Command( cmd := exec.Command("cmd", "/C", command) cmd.Stdin = channel cmd.Stdout = channel cmd.Stderr = channel.Stderr() log.Trace("command:", command) err := cmd.Run() if err != nil { log.Error("Failed to cmd.Run: ", err) } channel.SendRequest("exit-status", false, []byte{0, 0, 0, 0}) channel.Close() closeConnection(conn) // case "shell": // // 处理交互式shell会话 // // 在这里添加您的处理逻辑,例如启动一个shell进程并将其连接到channel // // 请注意,处理交互式shell会话需要更复杂的逻辑,您可能需要使用类似于pty包来处理终端相关的操作 // channel.Write([]byte("交互式shell会话已启动\n")) // channel.Close() // handleSSHShell(user, channel) case "pty-req": log.Info("A pty-req processing...") req.Reply(true, nil) handleSSHShell(sshConn, channel) channel.SendRequest("exit-status", false, []byte{0, 0, 0, 0}) channel.Close() closeConnection(conn) log.Info("Channel closed") default: req.Reply(false, nil) } } } func closeConnection(conn net.Conn) { sshMu.Lock() conn.Close() sshCC-- sshMu.Unlock() } func handleSSHShell(sshConn *ssh.ServerConn, channel ssh.Channel) { //conf = config.GetYamlConfig() // 检查通道是否支持终端 omcMmlVar := &mmlp.MmlVar{ Version: global.Version, Output: mmlp.DefaultFormatType, MmlHome: conf.Sshd.MmlHome, Limit: 50, User: sshConn.User(), SessionToken: fmt.Sprintf("%x", sshConn.SessionID()), HttpUri: conf.OMC.HttpUri, UserAgent: config.GetDefaultUserAgent(), TagNE: conf.Sshd.TagNE, } term := term.NewTerminal(channel, fmt.Sprintf("[%s@%s]> ", omcMmlVar.User, omcMmlVar.TagNE)) // 启动交互式shell会话 for { line, err := term.ReadLine() if err != nil { if err == io.EOF { break } log.Error("Failed to read line: ", err) break } cmdline := strings.TrimSpace(line) if cmdline != "" { logmml.Cmd(cmdline) } var response string switch cmdline { case "exit", "quit": goto exitLoop case "": goto continueLoop case "help": response = fmt.Sprintf("Usage: %s\n", line) term.Write([]byte(response)) goto continueLoop case "dsp variables": response = fmt.Sprintf("version: %s\n Output: %s\n", omcMmlVar.Version, omcMmlVar.Output) term.Write([]byte(response)) goto continueLoop case "set mml output=json": // mmlp.SetOmcMmlVarOutput("json") omcMmlVar.Output = "json" response = fmt.Sprintf("set ok, mmlVar.output = %s\n", omcMmlVar.Output) term.Write([]byte(response)) goto continueLoop case "set mml output=table": // mmlp.SetOmcMmlVarOutput("table") omcMmlVar.Output = "table" response = fmt.Sprintf("set ok, mmlVar.output = %s\n", omcMmlVar.Output) term.Write([]byte(response)) goto continueLoop default: var mmlCmds []mmlp.MmlCommand mmlLine := strings.TrimSpace(line) if err = mmlp.ParseMMLCommand(mmlLine, &mmlCmds); err != nil { response = fmt.Sprintf("parse command error: %v\n", err) term.Write([]byte(response)) goto continueLoop } // if err = mmlp.ParseMMLParams(&mmlCmds); err != nil { // response := fmt.Sprintf("#2 parse command error: %v\n", err) // term.Write([]byte(response)) // } for _, mmlCmd := range mmlCmds { output, err := mmlp.TransMml2HttpReq(omcMmlVar, &mmlCmd) if err != nil { response = fmt.Sprintf("translate MML command error: %v\n", err) term.Write([]byte(response)) goto continueLoop } response = string(*output) term.Write(*output) } goto continueLoop } continueLoop: if response != "" { logmml.Ret(response) } continue exitLoop: token := fmt.Sprintf("%x", sshConn.SessionID()) _, err = dborm.XormLogoutUpdateSession(token) if err != nil { log.Error("Failed to XormLogoutUpdateSession:", err) } break } }