package main

import (
	"errors"
	"fmt"
	"io"
	"net"
	"os"
	"os/exec"
	"strings"
	"sync"

	"be.ems/lib/dborm"
	"be.ems/lib/global"
	"be.ems/lib/log"
	"be.ems/lib/mmlp"
	"be.ems/sshsvc/config"
	"be.ems/sshsvc/logmml"

	//"github.com/gliderlabs/ssh"
	"golang.org/x/crypto/ssh"
	"golang.org/x/term"
)

var (
	// connNum is the number of SSH connections currently holding a slot.
	// Every connection is served by its own goroutine, so it is guarded by connMu.
	connNum int
	connMu  sync.Mutex

	conf *config.YamlConfig
)

// init loads the YAML configuration, initializes the service logger and the
// MML command logger, and connects the database client. The process exits
// with status 1 if the database client cannot be initialized.
func init() {
	conf = config.GetYamlConfig()
	log.InitLogger(conf.Logger.File, conf.Logger.Duration, conf.Logger.Count, "omc:sshsvc", config.GetLogLevel())
	fmt.Printf("OMC sshsvc version: %s\n", global.Version)
	log.Infof("========================= OMC sshsvc startup =========================")
	log.Infof("OMC sshsvc version: %s %s %s", global.Version, global.BuildTime, global.GoVer)

	db := conf.Database
	err := dborm.InitDbClient(db.Type, db.User, db.Password, db.Host, db.Port, db.Name, db.ConnParam)
	if err != nil {
		fmt.Println("dborm.initDbClient err:", err)
		os.Exit(1)
	}

	logmml.InitMmlLogger(conf.Logmml.File, conf.Logmml.Duration, conf.Logmml.Count, "omc", config.GetLogMmlLevel())
}

// main loads the SSH host key, listens on conf.Sshd.ListenAddr:ListenPort and
// serves every accepted TCP connection in its own goroutine.
func main() {
	// Load the SSH host private key from disk (it is read, not generated).
	privateKeyBytes, err := os.ReadFile(conf.Sshd.PrivateKey)
	if err != nil {
		log.Fatal("Failed to ReadFile:", err)
		os.Exit(2)
	}
	privateKey, err := ssh.ParsePrivateKey(privateKeyBytes)
	if err != nil {
		log.Fatal("Failed to ParsePrivateKey", err)
		os.Exit(3)
	}

	serverConfig := newServerConfig()
	serverConfig.AddHostKey(privateKey)

	hostUri := fmt.Sprintf("%s:%d", conf.Sshd.ListenAddr, conf.Sshd.ListenPort)
	listener, err := net.Listen("tcp", hostUri)
	if err != nil {
		log.Fatal("Failed to Listen: ", err)
		os.Exit(4)
	}
	fmt.Printf("MML SSH server startup, listen port:%d\n", conf.Sshd.ListenPort)

	for {
		conn, err := listener.Accept()
		if err != nil {
			log.Fatal("Failed to Accept: ", err)
			os.Exit(5)
		}
		go handleSSHConnection(conn, serverConfig)
	}
}

// newServerConfig builds the SSH server configuration: password logins are
// checked against the OMC user table, public-key logins are always refused.
func newServerConfig() *ssh.ServerConfig {
	return &ssh.ServerConfig{
		PasswordCallback: authenticatePassword,
		PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
			// Public-key authentication is not supported by this service.
			return nil, fmt.Errorf("public key authentication is failed")
		},
	}
}

// authenticatePassword validates the user/password pair with
// dborm.XormCheckLoginUser and, on success, records a session row whose token
// is the hex-encoded SSH session ID. That same token is used later for the
// MML HTTP requests and for the logout update when the connection ends.
func authenticatePassword(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
	validUser, _, err := dborm.XormCheckLoginUser(conn.User(), string(password), conf.OMC.UserCrypt)
	if err != nil {
		return nil, err
	}
	if !validUser {
		return nil, fmt.Errorf("invalid user or password")
	}

	sessionToken := fmt.Sprintf("%x", conn.SessionID())
	sourceAddr := conn.RemoteAddr().String()
	timeOut := uint32(conf.Sshd.Timeout)
	sessionMode := conf.Sshd.Session
	log.Debugf("token:%s sourceAddr:%s", sessionToken, sourceAddr)

	affected, err := dborm.XormInsertSession(conn.User(), sourceAddr, sessionToken, timeOut, sessionMode)
	if err != nil {
		log.Error("Failed to insert Session table:", err)
		return nil, err
	}
	if affected == -1 {
		err := errors.New("Failed to get session")
		log.Error(err)
		return nil, err
	}
	return nil, nil
}

// acquireConnSlot reserves one of the conf.Sshd.MaxConnNum connection slots.
// It returns false (reserving nothing) when all slots are in use.
func acquireConnSlot() bool {
	connMu.Lock()
	defer connMu.Unlock()
	if connNum >= int(conf.Sshd.MaxConnNum) {
		return false
	}
	connNum++
	return true
}

// releaseConnSlot returns a slot previously obtained from acquireConnSlot.
func releaseConnSlot() {
	connMu.Lock()
	connNum--
	connMu.Unlock()
}

// logoutSession marks the session row created by authenticatePassword as
// logged out. Errors are logged, not returned.
func logoutSession(sshConn *ssh.ServerConn) {
	token := fmt.Sprintf("%x", sshConn.SessionID())
	if _, err := dborm.XormLogoutUpdateSession(token); err != nil {
		log.Error("Failed to XormLogoutUpdateSession:", err)
	}
}

// handleSSHConnection performs the SSH handshake on conn and serves its
// "session" channels until the connection is closed. It enforces the
// connection limit and always releases its slot and logs the DB session out
// when it returns, however the connection ended.
func handleSSHConnection(conn net.Conn, serverConfig *ssh.ServerConfig) {
	sshConn, chans, reqs, err := ssh.NewServerConn(conn, serverConfig)
	if err != nil {
		log.Error("Failed to NewServerConn: ", err)
		return
	}
	// Deferred in this order so the socket is closed before the (slower) DB
	// logout runs.
	defer logoutSession(sshConn)
	defer conn.Close()

	log.Infof("SSH connect accepted,client version:%s,user:%s", sshConn.ClientVersion(), sshConn.User())

	// Global (connection-level) requests are not supported.
	go ssh.DiscardRequests(reqs)

	if !acquireConnSlot() {
		log.Error("Maximum number of connections exceeded")
		return
	}
	defer releaseConnSlot()

	// chans is closed by the ssh package once the connection terminates.
	for newChannel := range chans {
		if newChannel.ChannelType() != "session" {
			newChannel.Reject(ssh.UnknownChannelType, "unsupported channel type")
			continue
		}
		channel, requests, err := newChannel.Accept()
		if err != nil {
			log.Error("Failed to accept channel: ", err)
			continue
		}
		go handleSSHChannel(conn, sshConn, channel, requests)
	}
}

// handleSSHChannel services the requests of one session channel. An "exec"
// request runs a command and a "pty-req" starts the interactive MML shell;
// either one runs in its own goroutine so this loop keeps draining requests
// (the ssh package stalls the whole connection once a channel's request
// buffer is full). Only one command/shell is started per channel.
func handleSSHChannel(conn net.Conn, sshConn *ssh.ServerConn, channel ssh.Channel, requests <-chan *ssh.Request) {
	started := false // only touched by this goroutine

	for req := range requests {
		switch req.Type {
		case "exec":
			if started {
				req.Reply(false, nil)
				continue
			}
			// RFC 4254 §6.5: the payload is a single SSH string holding the command.
			var payload struct{ Command string }
			if err := ssh.Unmarshal(req.Payload, &payload); err != nil {
				log.Error("Failed to parse exec payload: ", err)
				req.Reply(false, nil)
				continue
			}
			command := strings.TrimSpace(payload.Command)
			if command == "" {
				req.Reply(false, nil)
				continue
			}
			started = true
			req.Reply(true, nil)
			go runRemoteCommand(conn, channel, command)

		case "pty-req":
			log.Info("A pty-req processing...")
			req.Reply(true, nil)
			if started {
				continue
			}
			started = true
			go func() {
				handleSSHShell(sshConn, channel)
				finishSession(conn, channel, 0)
				log.Info("Channel closed")
			}()

		case "shell":
			// The shell is started by pty-req; acknowledge it only if one is
			// running, otherwise refuse it as before.
			req.Reply(started, nil)

		default:
			req.Reply(false, nil)
		}
	}
}

// runRemoteCommand runs command with the channel as stdin/stdout/stderr, then
// reports its exit status and tears the connection down.
//
// NOTE(review): this executes arbitrary commands through `cmd /C` for any
// authenticated user, and only works on Windows hosts — confirm both are
// intended.
func runRemoteCommand(conn net.Conn, channel ssh.Channel, command string) {
	cmd := exec.Command("cmd", "/C", command)
	cmd.Stdin = channel
	cmd.Stdout = channel
	cmd.Stderr = channel.Stderr()
	log.Trace("command:", command)

	err := cmd.Run()
	if err != nil {
		log.Error("Failed to cmd.Run: ", err)
	}
	finishSession(conn, channel, exitStatusOf(err))
}

// exitStatusOf maps the error returned by exec.Cmd.Run to an SSH exit status:
// 0 on success, the process exit code when available, otherwise 1.
func exitStatusOf(err error) uint32 {
	if err == nil {
		return 0
	}
	var exitErr *exec.ExitError
	if errors.As(err, &exitErr) {
		if code := exitErr.ExitCode(); code >= 0 {
			return uint32(code)
		}
	}
	return 1
}

// finishSession sends the "exit-status" channel request (RFC 4254 §6.10, a
// big-endian uint32), closes the channel and closes the underlying connection.
func finishSession(conn net.Conn, channel ssh.Channel, status uint32) {
	channel.SendRequest("exit-status", false, ssh.Marshal(struct{ Status uint32 }{status}))
	channel.Close()
	closeConnection(conn)
}

// closeConnection closes the underlying TCP connection. Connection-slot
// accounting is done by handleSSHConnection, which observes the close.
func closeConnection(conn net.Conn) {
	conn.Close()
}

// handleSSHShell runs the interactive MML shell on channel until the user
// types "exit"/"quit", the client sends EOF, or reading fails. Each non-empty
// line is recorded with logmml.Cmd and its output with logmml.Ret.
func handleSSHShell(sshConn *ssh.ServerConn, channel ssh.Channel) {
	omcMmlVar := &mmlp.MmlVar{
		Version:      "16.1.1",
		Output:       mmlp.DefaultFormatType,
		MmlHome:      conf.Sshd.MmlHome,
		Limit:        50,
		User:         sshConn.User(),
		SessionToken: fmt.Sprintf("%x", sshConn.SessionID()),
		HttpUri:      conf.OMC.HttpUri,
		UserAgent:    config.GetDefaultUserAgent(),
	}

	terminal := term.NewTerminal(channel, fmt.Sprintf("[%s@omc]> ", omcMmlVar.User))

	for {
		line, err := terminal.ReadLine()
		if err != nil {
			if !errors.Is(err, io.EOF) {
				log.Error("Failed to read line: ", err)
			}
			return
		}

		cmdline := strings.TrimSpace(line)
		if cmdline == "" {
			continue
		}
		logmml.Cmd(cmdline)

		response, exit := runShellCommand(terminal, omcMmlVar, line, cmdline)
		if exit {
			return
		}
		if response != "" {
			logmml.Ret(response)
		}
	}
}

// runShellCommand executes one non-empty shell line (cmdline is line with
// surrounding whitespace trimmed). It writes any output to w and returns the
// text to record in the MML log, plus whether the shell should terminate.
func runShellCommand(w io.Writer, omcMmlVar *mmlp.MmlVar, line, cmdline string) (response string, exit bool) {
	switch cmdline {
	case "exit", "quit":
		return "", true
	case "help":
		response = fmt.Sprintf("Usage: %s\n", line)
	case "dsp variables":
		response = fmt.Sprintf("version: %s\n Output: %s\n", omcMmlVar.Version, omcMmlVar.Output)
	case "set mml output=json":
		omcMmlVar.Output = "json"
		response = fmt.Sprintf("set ok, mmlVar.output = %s\n", omcMmlVar.Output)
	case "set mml output=table":
		omcMmlVar.Output = "table"
		response = fmt.Sprintf("set ok, mmlVar.output = %s\n", omcMmlVar.Output)
	default:
		return runMmlCommands(w, omcMmlVar, cmdline), false
	}
	w.Write([]byte(response))
	return response, false
}

// runMmlCommands parses mmlLine into MML commands and translates each into an
// HTTP request via mmlp.TransMml2HttpReq, writing every result to w as soon as
// it is available. Processing stops at the first translation error. The
// returned string is everything written, for the MML log.
func runMmlCommands(w io.Writer, omcMmlVar *mmlp.MmlVar, mmlLine string) string {
	var mmlCmds []mmlp.MmlCommand
	if err := mmlp.ParseMMLCommand(mmlLine, &mmlCmds); err != nil {
		response := fmt.Sprintf("parse command error: %v\n", err)
		w.Write([]byte(response))
		return response
	}

	var logged strings.Builder
	for i := range mmlCmds {
		output, err := mmlp.TransMml2HttpReq(omcMmlVar, &mmlCmds[i])
		if err != nil {
			response := fmt.Sprintf("translate MML command error: %v\n", err)
			w.Write([]byte(response))
			logged.WriteString(response)
			break
		}
		if output == nil {
			continue
		}
		w.Write(*output)
		logged.Write(*output)
	}
	return logged.String()
}