package main import ( "bufio" "fmt" "io" "net" "os" "os/exec" "strconv" "strings" "sync" "time" "be.ems/lib/dborm" "be.ems/lib/global" "be.ems/lib/log" "be.ems/lib/mmlp" "be.ems/sshsvc/config" "be.ems/sshsvc/logmml" "be.ems/sshsvc/snmp" omctelnet "be.ems/sshsvc/telnet" //"github.com/gliderlabs/ssh" "golang.org/x/crypto/ssh" "golang.org/x/term" ) var conf *config.YamlConfig var ( telnetCC int sshCC int telnetMu sync.Mutex sshMu sync.Mutex ) func init() { conf = config.GetYamlConfig() log.InitLogger(conf.Logger.File, conf.Logger.Duration, conf.Logger.Count, "omc:sshsvc", config.GetLogLevel()) fmt.Printf("OMC sshsvc version: %s\n", global.Version) log.Infof("========================= OMC sshsvc startup =========================") log.Infof("OMC sshsvc version: %s %s %s", global.Version, global.BuildTime, global.GoVer) db := conf.Database err := dborm.InitDbClient(db.Type, db.User, db.Password, db.Host, db.Port, db.Name, db.ConnParam) if err != nil { fmt.Println("dborm.initDbClient err:", err) os.Exit(1) } logmml.InitMmlLogger(conf.Logmml.File, conf.Logmml.Duration, conf.Logmml.Count, "omc", config.GetLogMmlLevel()) } func main() { // 生成SSH密钥对 privateKeyBytes, err := os.ReadFile(conf.Sshd.PrivateKey) if err != nil { log.Fatal("Failed to ReadFile:", err) os.Exit(2) } privateKey, err := ssh.ParsePrivateKey(privateKeyBytes) if err != nil { log.Fatal("Failed to ParsePrivateKey", err) os.Exit(3) } // 配置SSH服务器 serverConfig := &ssh.ServerConfig{ PasswordCallback: func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { // 这里可以进行密码验证逻辑,例如检查用户名和密码是否匹配 // validUser, _, err := dborm.XormCheckLoginUser(conn.User(), string(password), conf.OMC.UserCrypt) // if err != nil { // return nil, err // } // if validUser == true { // sessionToken := fmt.Sprintf("%x", conn.SessionID()) // Generate new token to session ID // sourceAddr := conn.RemoteAddr().String() // timeOut := uint32(conf.Sshd.Timeout) // sessionMode := conf.Sshd.Session // log.Debugf("token:%s sourceAddr:%s", sessionToken, sourceAddr) // affected, err := dborm.XormInsertSession(conn.User(), sourceAddr, sessionToken, timeOut, sessionMode) // if err != nil { // log.Error("Failed to insert Session table:", err) // return nil, err // } // if affected == -1 { // err := errors.New("Failed to get session") // log.Error(err) // return nil, err // } // return nil, nil // } if handleAuth(conf.Sshd.AuthType, conn.User(), string(password)) { return nil, nil } return nil, fmt.Errorf("invalid user or password") }, PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { // 这里可以进行公钥验证逻辑,例如检查用户的公钥是否在允许的公钥列表中 return nil, fmt.Errorf("public key authentication is failed") }, } serverConfig.AddHostKey(privateKey) // 启动SSH服务器 hostUri := fmt.Sprintf("%s:%d", conf.Sshd.ListenAddr, conf.Sshd.ListenPort) listener, err := net.Listen("tcp", hostUri) if err != nil { log.Fatal("Failed to Listen: ", err) os.Exit(4) } //fmt.Printf("MML SSH server startup, listen port:%d\n", conf.Sshd.ListenPort) // 启动telnet服务器 //telnetUri := fmt.Sprintf("%s:%d", conf.TelnetServer.ListenAddr, conf.TelnetServer.ListenPort) // telnetListener, err := net.Listen("tcp", telnetUri) // if err != nil { // log.Fatal("Failed to Listen: ", err) // os.Exit(4) // } //fmt.Printf("MML Telnet server startup, listen port:%d\n", conf.TelnetServer.ListenPort) // telnetconn, err := telnetListener.Accept() // if err != nil { // log.Fatal("Failed to accept telnet connection: ", err) // os.Exit(6) // } telnetSvc := omctelnet.TelnetHandler{ ListenAddr: conf.TelnetServer.ListenAddr, ListenPort: conf.TelnetServer.ListenPort, UserName: conf.TelnetServer.UserName, Password: conf.TelnetServer.Password, AuthType: conf.TelnetServer.AuthType, MaxConnNum: conf.TelnetServer.MaxConnNum, TagNE: conf.TelnetServer.TagNE, ListenHost: conf.TelnetServer.ListenAddr + ":" + strconv.Itoa(int(conf.TelnetServer.ListenPort)), } go telnetSvc.StartTelnetServer() // go StartTelnetServer(telnetSvc.ListenHost) snmpSvc := snmp.SNMPService{ ListenAddr: conf.SNMPServer.ListenAddr, ListenPort: conf.SNMPServer.ListenPort, UserName: conf.SNMPServer.UserName, AuthPass: conf.SNMPServer.AuthPass, AuthProto: conf.SNMPServer.AuthProto, PrivPass: conf.SNMPServer.PrivPass, PrivProto: conf.SNMPServer.PrivProto, EngineID: conf.SNMPServer.EngineID, TrapPort: conf.SNMPServer.TrapPort, TrapListen: conf.SNMPServer.TrapListen, TrapBool: conf.SNMPServer.TrapBool, TrapTick: conf.SNMPServer.TrapTick, TimeOut: conf.SNMPServer.TimeOut, TrapTarget: conf.SNMPServer.TrapTarget, ListenHost: conf.SNMPServer.ListenAddr + ":" + strconv.Itoa(int(conf.SNMPServer.ListenPort)), TrapHost: conf.SNMPServer.ListenAddr + ":" + strconv.Itoa(int(conf.SNMPServer.TrapPort)), SysName: "HLR-0", SysStatus: "Normal", SysDescr: "HLR server(sysNO=0)", SysLocation: "Shanghai", SysContact: "", SysService: 0, } go snmpSvc.StartSNMPServer() go snmpSvc.StartTrapServer() for { conn, err := listener.Accept() if err != nil { log.Fatal("Failed to Accept: ", err) os.Exit(5) } go handleSSHConnection(conn, serverConfig) } } func handleAuth(authType, userName, password string) bool { switch authType { case "local": if userName == conf.Sshd.UserName && password == conf.Sshd.Password { return true } return false case "radius": exist, err := dborm.XEngDB().Table("OMC_PUB.sysUser").Where("userName=? AND password=md5(?)", userName, password).Exist() if err != nil { return false } return exist case "omc": default: } return false } func StartTelnetServer(addr string) { listener, err := net.Listen("tcp", addr) if err != nil { fmt.Println("Error starting Telnet server:", err) return } defer listener.Close() fmt.Println("Telnet server started on", addr) for { conn, err := listener.Accept() if err != nil { fmt.Println("Error accepting Telnet connection:", err) continue } telnetMu.Lock() if telnetCC >= int(conf.TelnetServer.MaxConnNum) { telnetMu.Unlock() io.WriteString(conn, "Connection limit reached. Try again later.\r\n") conn.Close() continue } telnetCC++ telnetMu.Unlock() go handleTelnetConnection(conn) } } func handleTelnetConnection(conn net.Conn) { defer func() { telnetMu.Lock() telnetCC-- telnetMu.Unlock() }() defer conn.Close() reader := bufio.NewReader(conn) writer := bufio.NewWriter(conn) // 发送欢迎信息 writer.WriteString("Welcome to the Telnet server!\r\n") writer.Flush() // 请求用户名 writer.WriteString("Username: ") writer.Flush() user, _ := reader.ReadString('\n') user = strings.TrimSpace(user) // 关闭回显模式 writer.Write([]byte{255, 251, 1}) // IAC WILL ECHO writer.Flush() // 请求密码 writer.WriteString("Password: ") writer.Flush() // 读取密码并清除控制序列 var passBuilder strings.Builder for { b, err := reader.ReadByte() if err != nil { return } if b == '\n' || b == '\r' { break } if b == 255 { // IAC reader.ReadByte() // 忽略下一个字节 reader.ReadByte() // 忽略下一个字节 } else { passBuilder.WriteByte(b) } } pass := passBuilder.String() // 恢复回显模式 writer.Write([]byte{255, 252, 1}) // IAC WONT ECHO writer.Flush() if handleAuth(conf.TelnetServer.AuthType, user, pass) { writer.WriteString("\r\nAuthentication successful!\r\n") writer.Flush() HandleCommands(user, conf.TelnetServer.TagNE, reader, writer) } else { writer.WriteString("\r\nAuthentication failed!\r\n") writer.Flush() } } // 处理命令输 func HandleCommands(user, tag string, reader *bufio.Reader, writer *bufio.Writer) { header := fmt.Sprintf("[%s@%s]> ", user, tag) clearLine := "\033[2K\r" // ANSI 转义序列,用于清除当前行 for { var commandBuilder strings.Builder for { b, err := reader.ReadByte() if err != nil { return } if b == '\n' || b == '\r' { break } if b == '\xff' || b == '\xfe' || b == '\x01' { continue } if b == 127 { // 处理退格键 if commandBuilder.Len() > 0 { // 手动截断字符串 command := commandBuilder.String() command = command[:len(command)-1] commandBuilder.Reset() commandBuilder.WriteString(command) writer.WriteString("\b \b") // 回显退格 writer.Flush() } } else { // 回显用户输入的字符 writer.WriteByte(b) writer.Flush() commandBuilder.WriteByte(b) } } command := strings.TrimSpace(commandBuilder.String()) // 处理其他命令 switch command { case "hello": writer.WriteString("\r\nHello, world!\r\n") case "time": writer.WriteString(fmt.Sprintf("\r\nCurrent time: %s\r\n", time.Now().Format(time.RFC1123))) case "exit", "quit": writer.WriteString("\r\nGoodbye!\r\n") writer.Flush() return case "": default: writer.WriteString("\r\nUnknown command\r\n") writer.Flush() } writer.WriteString(clearLine + header) writer.Flush() } } func handleSSHConnection(conn net.Conn, serverConfig *ssh.ServerConfig) { // SSH握手 sshConn, chans, reqs, err := ssh.NewServerConn(conn, serverConfig) if err != nil { log.Error("Failed to NewServerConn: ", err) return } log.Infof("SSH connect accepted,client version:%s,user:%s", sshConn.ClientVersion(), sshConn.User()) // 处理SSH请求 go ssh.DiscardRequests(reqs) // 处理SSH通道 for newChannel := range chans { if newChannel.ChannelType() != "session" { newChannel.Reject(ssh.UnknownChannelType, "unsupported channel type") continue } channel, requests, err := newChannel.Accept() if err != nil { log.Error("Failed to NewServerConn: ", err) continue } sshMu.Lock() sshCC++ if sshCC > int(conf.Sshd.MaxConnNum) { sshMu.Unlock() log.Error("Maximum number of connections exceeded") channel.Write([]byte(fmt.Sprintf("Connection limit reached (limit=%d). Try again later.\r\n", conf.Sshd.MaxConnNum))) conn.Close() continue } sshMu.Unlock() go handleSSHChannel(conn, sshConn, channel, requests) } } func handleSSHChannel(conn net.Conn, sshConn *ssh.ServerConn, channel ssh.Channel, requests <-chan *ssh.Request) { for req := range requests { switch req.Type { case "exec": // 执行远程命令 command := strings.TrimSpace(string(req.Payload))[4:] //cmd := exec.Command( cmd := exec.Command("cmd", "/C", command) cmd.Stdin = channel cmd.Stdout = channel cmd.Stderr = channel.Stderr() log.Trace("command:", command) err := cmd.Run() if err != nil { log.Error("Failed to cmd.Run: ", err) } channel.SendRequest("exit-status", false, []byte{0, 0, 0, 0}) channel.Close() closeConnection(conn) // case "shell": // // 处理交互式shell会话 // // 在这里添加您的处理逻辑,例如启动一个shell进程并将其连接到channel // // 请注意,处理交互式shell会话需要更复杂的逻辑,您可能需要使用类似于pty包来处理终端相关的操作 // channel.Write([]byte("交互式shell会话已启动\n")) // channel.Close() // handleSSHShell(user, channel) case "pty-req": log.Info("A pty-req processing...") req.Reply(true, nil) handleSSHShell(sshConn, channel) channel.SendRequest("exit-status", false, []byte{0, 0, 0, 0}) channel.Close() closeConnection(conn) log.Info("Channel closed") default: req.Reply(false, nil) } } } func closeConnection(conn net.Conn) { sshMu.Lock() conn.Close() sshCC-- sshMu.Unlock() } func handleSSHShell(sshConn *ssh.ServerConn, channel ssh.Channel) { //conf = config.GetYamlConfig() // 检查通道是否支持终端 omcMmlVar := &mmlp.MmlVar{ Version: global.Version, Output: mmlp.DefaultFormatType, MmlHome: conf.Sshd.MmlHome, Limit: conf.Sshd.MaxConnNum, User: sshConn.User(), SessionToken: fmt.Sprintf("%x", sshConn.SessionID()), HttpUri: conf.OMC.HttpUri, UserAgent: config.GetDefaultUserAgent(), TagNE: conf.Sshd.TagNE, } term := term.NewTerminal(channel, fmt.Sprintf("[%s@%s]> ", omcMmlVar.User, omcMmlVar.TagNE)) msg := fmt.Sprintf("\r\nWelcome to the %s server!\r\n", strings.ToUpper(omcMmlVar.TagNE)) term.Write([]byte(msg)) msg = fmt.Sprintf("Last login: %s from %s \r\n\r\n", time.Now().Format(time.RFC1123), sshConn.RemoteAddr()) term.Write([]byte(msg)) // 启动交互式shell会话 for { line, err := term.ReadLine() if err != nil { if err == io.EOF { break } log.Error("Failed to read line: ", err) break } cmdline := strings.TrimSpace(line) if cmdline != "" { logmml.Cmd(cmdline) } var response string switch cmdline { case "hello": term.Write([]byte("Hello, world!\r\n")) goto continueLoop case "time": response = fmt.Sprintf("Current time: %s\r\n", time.Now().Format(time.RFC1123)) term.Write([]byte(response)) goto continueLoop case "exit", "quit": goto exitLoop case "": goto continueLoop case "help": response = fmt.Sprintf("Usage: %s\n", line) term.Write([]byte(response)) goto continueLoop case "dsp variables": response = fmt.Sprintf("version: %s\n Output: %s\n", omcMmlVar.Version, omcMmlVar.Output) term.Write([]byte(response)) goto continueLoop case "set mml output=json": // mmlp.SetOmcMmlVarOutput("json") omcMmlVar.Output = "json" response = fmt.Sprintf("set ok, mmlVar.output = %s\n", omcMmlVar.Output) term.Write([]byte(response)) goto continueLoop case "set mml output=table": // mmlp.SetOmcMmlVarOutput("table") omcMmlVar.Output = "table" response = fmt.Sprintf("set ok, mmlVar.output = %s\n", omcMmlVar.Output) term.Write([]byte(response)) goto continueLoop default: var mmlCmds []mmlp.MmlCommand mmlLine := strings.TrimSpace(line) if err = mmlp.ParseMMLCommand(mmlLine, &mmlCmds); err != nil { response = fmt.Sprintf("parse command error: %v\n", err) term.Write([]byte(response)) goto continueLoop } // if err = mmlp.ParseMMLParams(&mmlCmds); err != nil { // response := fmt.Sprintf("#2 parse command error: %v\n", err) // term.Write([]byte(response)) // } for _, mmlCmd := range mmlCmds { output, err := mmlp.TransMml2HttpReq(omcMmlVar, &mmlCmd) if err != nil { response = fmt.Sprintf("translate MML command error: %v\n", err) term.Write([]byte(response)) goto continueLoop } response = string(*output) term.Write(*output) } goto continueLoop } continueLoop: if response != "" { logmml.Ret(response) } continue exitLoop: token := fmt.Sprintf("%x", sshConn.SessionID()) _, err = dborm.XormLogoutUpdateSession(token) if err != nil { log.Error("Failed to XormLogoutUpdateSession:", err) } break } }