Files
be.ems/sshsvc/sshsvc.go
2023-08-14 21:41:37 +08:00

302 lines
8.2 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
"errors"
"fmt"
"io"
"net"
"os"
"os/exec"
"strings"
"ems.agt/lib/dborm"
"ems.agt/lib/global"
"ems.agt/lib/log"
"ems.agt/lib/mmlp"
"ems.agt/sshsvc/config"
"ems.agt/sshsvc/logmml"
//"github.com/gliderlabs/ssh"
"golang.org/x/crypto/ssh"
"golang.org/x/term"
)
var (
	// connNum counts the session channels currently being served: it goes
	// up for each accepted session channel and down when that channel's
	// connection is closed.
	connNum int

	// conf holds the service configuration; it is loaded by init.
	conf *config.YamlConfig
)
// init loads the YAML configuration, initialises the service logger, prints
// and logs the build version, connects the database client, and finally
// initialises the MML command logger. If the database client cannot be
// initialised the process exits with status 1.
func init() {
	conf = config.GetYamlConfig()

	log.InitLogger(conf.Logger.File, conf.Logger.Duration, conf.Logger.Count, "omc:sshsvc", config.GetLogLevel())
	fmt.Printf("OMC sshsvc version: %s\n", global.Version)
	log.Infof("========================= OMC sshsvc startup =========================")
	log.Infof("OMC sshsvc version: %s %s %s", global.Version, global.BuildTime, global.GoVer)

	dbCfg := conf.Database
	if err := dborm.InitDbClient(dbCfg.Type, dbCfg.User, dbCfg.Password, dbCfg.Host, dbCfg.Port, dbCfg.Name); err != nil {
		fmt.Println("dborm.initDbClient err:", err)
		os.Exit(1)
	}

	logmml.InitMmlLogger(conf.Logmml.File, conf.Logmml.Duration, conf.Logmml.Count, "omc", config.GetLogMmlLevel())
}
func main() {
// 生成SSH密钥对
privateKeyBytes, err := os.ReadFile(conf.Sshd.PrivateKey)
if err != nil {
log.Fatal("Failed to ReadFile", err)
os.Exit(2)
}
privateKey, err := ssh.ParsePrivateKey(privateKeyBytes)
if err != nil {
log.Fatal("Failed to ParsePrivateKey", err)
os.Exit(3)
}
// 配置SSH服务器
serverConfig := &ssh.ServerConfig{
PasswordCallback: func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
// 这里可以进行密码验证逻辑,例如检查用户名和密码是否匹配
validUser, _, err := dborm.XormCheckLoginUser(conn.User(), string(password), conf.OMC.UserCrypt)
if err != nil {
return nil, err
}
if validUser == true {
sessionToken := fmt.Sprintf("%x", conn.SessionID()) // Generate new token to session ID
sourceAddr := conn.RemoteAddr().String()
timeOut := uint32(conf.Sshd.Timeout)
sessionMode := conf.Sshd.Session
log.Debugf("token:%s sourceAddr:%s", sessionToken, sourceAddr)
affected, err := dborm.XormInsertSession(conn.User(), sourceAddr, sessionToken, timeOut, sessionMode)
if err != nil {
log.Error("Failed to insert Session table:", err)
return nil, err
}
if affected == -1 {
err := errors.New("Failed to get session")
log.Error(err)
return nil, err
}
return nil, nil
}
// if conn.User() == "admin" && string(password) == "123456" {
// return nil, nil
// }
return nil, fmt.Errorf("invalid user or password")
},
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
// 这里可以进行公钥验证逻辑,例如检查用户的公钥是否在允许的公钥列表中
return nil, fmt.Errorf("public key authentication is failed")
},
}
serverConfig.AddHostKey(privateKey)
// 启动SSH服务器
hostUri := fmt.Sprintf("%s:%d", conf.Sshd.ListenAddr, conf.Sshd.ListenPort)
listener, err := net.Listen("tcp", hostUri)
if err != nil {
log.Fatal("Failed to Listen: ", err)
os.Exit(4)
}
fmt.Printf("MML SSH server startup, listen port%d\n", conf.Sshd.ListenPort)
for {
conn, err := listener.Accept()
if err != nil {
log.Fatal("Failed to Accept: ", err)
os.Exit(5)
}
go handleSSHConnection(conn, serverConfig)
}
}
// handleSSHConnection performs the SSH handshake on conn and then serves the
// channels the client opens until the connection ends.
//
// Only "session" channels are accepted. Each accepted session channel holds
// one connNum slot until handleSSHChannel is done with it. Once
// conf.Sshd.MaxConnNum slots are in use, further session channels are
// rejected and the TCP connection is closed.
func handleSSHConnection(conn net.Conn, serverConfig *ssh.ServerConfig) {
	// SSH handshake; authentication runs the callbacks in serverConfig.
	sshConn, chans, reqs, err := ssh.NewServerConn(conn, serverConfig)
	if err != nil {
		log.Error("Failed to NewServerConn: ", err)
		return
	}
	log.Infof("SSH connect accepted, client version: %s, user: %s", sshConn.ClientVersion(), sshConn.User())

	// Connection-level (global) requests are not supported.
	go ssh.DiscardRequests(reqs)

	for newChannel := range chans {
		if newChannel.ChannelType() != "session" {
			newChannel.Reject(ssh.UnknownChannelType, "unsupported channel type")
			continue
		}
		// Reserve the slot before accepting, so the limit check and the
		// increment are one atomic step and at most MaxConnNum channels are
		// served at the same time.
		if !acquireConnSlot(int(conf.Sshd.MaxConnNum)) {
			log.Error("Maximum number of connections exceeded")
			newChannel.Reject(ssh.ResourceShortage, "maximum number of connections exceeded")
			conn.Close()
			continue
		}
		channel, requests, err := newChannel.Accept()
		if err != nil {
			log.Error("Failed to accept channel: ", err)
			releaseConnSlot()
			continue
		}
		go handleSSHChannel(conn, sshConn, channel, requests)
	}
}

// handleSSHChannel serves the requests sent on one accepted session channel:
//
//   - "exec" runs the command through `cmd /C`, with the channel as its
//     stdin/stdout/stderr, and reports the command's exit status;
//   - "pty-req" starts the interactive MML shell (handleSSHShell);
//   - any other request is refused.
//
// After "exec" or "pty-req" has been handled, the channel and the whole TCP
// connection are closed. The deferred cleanup also runs if the client goes
// away before sending either request. This guarantees that the connNum slot
// taken in handleSSHConnection is released exactly once.
func handleSSHChannel(conn net.Conn, sshConn *ssh.ServerConn, channel ssh.Channel, requests <-chan *ssh.Request) {
	defer func() {
		channel.Close()
		closeConnection(conn)
	}()

	for req := range requests {
		switch req.Type {
		case "exec":
			// The payload is one SSH string (uint32 length + bytes). It is
			// decoded rather than sliced, so a short or malformed payload is
			// refused instead of panicking the whole server.
			var payload struct{ Command string }
			if err := ssh.Unmarshal(req.Payload, &payload); err != nil {
				log.Error("Failed to parse exec payload: ", err)
				req.Reply(false, nil)
				continue
			}
			command := strings.TrimSpace(payload.Command)
			req.Reply(true, nil)

			// NOTE(review): this runs an arbitrary client-supplied command
			// through the Windows shell with this service's privileges; any
			// authenticated user can execute anything. Confirm this is intended.
			cmd := exec.Command("cmd", "/C", command)
			cmd.Stdin = channel
			cmd.Stdout = channel
			cmd.Stderr = channel.Stderr()
			log.Trace("command:", command)
			err := cmd.Run()
			if err != nil {
				log.Error("Failed to cmd.Run: ", err)
			}
			status := struct{ Status uint32 }{exitStatus(err)}
			channel.SendRequest("exit-status", false, ssh.Marshal(&status))
			return
		case "pty-req":
			log.Info("A pty-req processing...")
			req.Reply(true, nil)
			// Keep answering the requests that follow (env, shell,
			// window-change, ...) while the shell blocks this goroutine.
			// Otherwise they queue up unread; presumably that can stall the
			// connection once x/crypto/ssh's request buffer is full.
			go serviceShellRequests(requests)
			handleSSHShell(sshConn, channel)
			channel.SendRequest("exit-status", false, []byte{0, 0, 0, 0})
			// The deferred cleanup closes the channel immediately after this.
			log.Info("Channel closed")
			return
		default:
			req.Reply(false, nil)
		}
	}
}

// closeConnection closes conn and releases one connNum slot. It runs once
// per served session channel; the error from closing an already-closed conn
// is deliberately ignored.
func closeConnection(conn net.Conn) {
	conn.Close()
	releaseConnSlot()
}

// connNumLock serializes access to connNum, which is updated from many
// goroutines. It is a one-slot buffered channel used as a mutex: sending
// acquires the lock and receiving releases it.
var connNumLock = make(chan struct{}, 1)

// acquireConnSlot reserves one connNum slot unless limit slots are already
// taken, and reports whether it did.
func acquireConnSlot(limit int) bool {
	connNumLock <- struct{}{}
	defer func() { <-connNumLock }()
	if connNum >= limit {
		return false
	}
	connNum++
	return true
}

// releaseConnSlot gives back a slot reserved by acquireConnSlot.
func releaseConnSlot() {
	connNumLock <- struct{}{}
	connNum--
	<-connNumLock
}

// exitStatus maps the error returned by exec.Cmd.Run to the value sent in an
// SSH "exit-status" request:
//   - 0 on success;
//   - the process exit code when one is available;
//   - 1 otherwise, e.g. when the command could not be started or was killed
//     by a signal.
func exitStatus(err error) uint32 {
	if err == nil {
		return 0
	}
	var exitErr *exec.ExitError
	if errors.As(err, &exitErr) && exitErr.ExitCode() >= 0 {
		return uint32(exitErr.ExitCode())
	}
	return 1
}

// serviceShellRequests answers channel requests that arrive while the
// interactive shell is running. "shell" is acknowledged, since the shell is
// already running, and every other request is refused. (Reply is a no-op for
// requests that do not want a reply.) It returns once requests is closed.
func serviceShellRequests(requests <-chan *ssh.Request) {
	for req := range requests {
		req.Reply(req.Type == "shell", nil)
	}
}
// handleSSHShell runs the interactive MML shell for the user authenticated on
// sshConn, reading command lines from channel through a terminal.
//
// Built-in commands:
//   - "exit", "quit": end the session;
//   - "help": echo a usage line;
//   - "dsp variables": show the MML version and output format;
//   - "set mml output=json" / "set mml output=table": change the output format.
//
// Any other non-empty line is handed to runMmlLine. Every non-empty line is
// logged with logmml.Cmd, and every non-empty response with logmml.Ret.
//
// However the loop ends (exit/quit, client EOF, or a read error), the login
// session keyed by the hex-encoded SSH session ID is marked as logged out via
// dborm.XormLogoutUpdateSession.
func handleSSHShell(sshConn *ssh.ServerConn, channel ssh.Channel) {
	sessionToken := fmt.Sprintf("%x", sshConn.SessionID())
	// Per-session MML variables passed to the MML-to-HTTP translation.
	omcMmlVar := &mmlp.MmlVar{
		Version:      "16.1.1",
		Output:       mmlp.DefaultFormatType,
		Limit:        50,
		User:         sshConn.User(),
		SessionToken: sessionToken,
		HttpUri:      conf.OMC.HttpUri,
		UserAgent:    config.GetDefaultUserAgent(),
	}
	// Named "terminal" so that the term package is not shadowed.
	terminal := term.NewTerminal(channel, fmt.Sprintf("[%s@omc]> ", omcMmlVar.User))

	// Log the session out on every exit path, not only on "exit"/"quit", so a
	// client that simply disconnects does not leave its session active.
	defer func() {
		if _, err := dborm.XormLogoutUpdateSession(sessionToken); err != nil {
			log.Error("Failed to XormLogoutUpdateSession:", err)
		}
	}()

	for {
		line, err := terminal.ReadLine()
		if err != nil {
			if !errors.Is(err, io.EOF) {
				log.Error("Failed to read line: ", err)
			}
			return
		}
		cmdline := strings.TrimSpace(line)
		if cmdline == "" {
			continue
		}
		logmml.Cmd(cmdline)

		switch cmdline {
		case "exit", "quit":
			return
		case "help":
			writeShellResponse(terminal, fmt.Sprintf("Usage: %s\n", line))
		case "dsp variables":
			writeShellResponse(terminal, fmt.Sprintf("version: %s\n Output: %s\n", omcMmlVar.Version, omcMmlVar.Output))
		case "set mml output=json":
			omcMmlVar.Output = "json"
			writeShellResponse(terminal, fmt.Sprintf("set ok, mmlVar.output = %s\n", omcMmlVar.Output))
		case "set mml output=table":
			omcMmlVar.Output = "table"
			writeShellResponse(terminal, fmt.Sprintf("set ok, mmlVar.output = %s\n", omcMmlVar.Output))
		default:
			runMmlLine(terminal, omcMmlVar, cmdline)
		}
	}
}

// runMmlLine parses line into MML commands with mmlp.ParseMMLCommand and
// translates each one with mmlp.TransMml2HttpReq. Each command's output is
// written to the terminal. Processing stops at the first parse or
// translation error, which is reported to the terminal instead.
func runMmlLine(terminal *term.Terminal, mmlVar *mmlp.MmlVar, line string) {
	var mmlCmds []mmlp.MmlCommand
	if err := mmlp.ParseMMLCommand(line, &mmlCmds); err != nil {
		writeShellResponse(terminal, fmt.Sprintf("parse command error: %v\n", err))
		return
	}
	for i := range mmlCmds {
		output, err := mmlp.TransMml2HttpReq(mmlVar, &mmlCmds[i])
		if err != nil {
			writeShellResponse(terminal, fmt.Sprintf("translate MML command error: %v\n", err))
			return
		}
		if output != nil {
			writeShellResponse(terminal, string(*output))
		}
	}
}

// writeShellResponse writes response to the terminal and records it with
// logmml.Ret. An empty response is neither written nor logged.
//
// NOTE(review): write errors are ignored; presumably a dead channel is
// detected by the next ReadLine in handleSSHShell — confirm.
func writeShellResponse(terminal *term.Terminal, response string) {
	if response == "" {
		return
	}
	terminal.Write([]byte(response))
	logmml.Ret(response)
}