update at 2023/08/14

This commit is contained in:
2023-08-14 21:41:37 +08:00
parent a039a664f1
commit 44e8cbee2c
255 changed files with 20426 additions and 233 deletions

301
sshsvc/sshsvc.go Normal file
View File

@@ -0,0 +1,301 @@
package main
import (
"errors"
"fmt"
"io"
"net"
"os"
"os/exec"
"strings"
"ems.agt/lib/dborm"
"ems.agt/lib/global"
"ems.agt/lib/log"
"ems.agt/lib/mmlp"
"ems.agt/sshsvc/config"
"ems.agt/sshsvc/logmml"
//"github.com/gliderlabs/ssh"
"golang.org/x/crypto/ssh"
"golang.org/x/term"
)
var connNum int = 0
var conf *config.YamlConfig
func init() {
conf = config.GetYamlConfig()
log.InitLogger(conf.Logger.File, conf.Logger.Duration, conf.Logger.Count, "omc:sshsvc", config.GetLogLevel())
fmt.Printf("OMC sshsvc version: %s\n", global.Version)
log.Infof("========================= OMC sshsvc startup =========================")
log.Infof("OMC sshsvc version: %s %s %s", global.Version, global.BuildTime, global.GoVer)
db := conf.Database
err := dborm.InitDbClient(db.Type, db.User, db.Password, db.Host, db.Port, db.Name)
if err != nil {
fmt.Println("dborm.initDbClient err:", err)
os.Exit(1)
}
logmml.InitMmlLogger(conf.Logmml.File, conf.Logmml.Duration, conf.Logmml.Count, "omc", config.GetLogMmlLevel())
}
func main() {
// 生成SSH密钥对
privateKeyBytes, err := os.ReadFile(conf.Sshd.PrivateKey)
if err != nil {
log.Fatal("Failed to ReadFile", err)
os.Exit(2)
}
privateKey, err := ssh.ParsePrivateKey(privateKeyBytes)
if err != nil {
log.Fatal("Failed to ParsePrivateKey", err)
os.Exit(3)
}
// 配置SSH服务器
serverConfig := &ssh.ServerConfig{
PasswordCallback: func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
// 这里可以进行密码验证逻辑,例如检查用户名和密码是否匹配
validUser, _, err := dborm.XormCheckLoginUser(conn.User(), string(password), conf.OMC.UserCrypt)
if err != nil {
return nil, err
}
if validUser == true {
sessionToken := fmt.Sprintf("%x", conn.SessionID()) // Generate new token to session ID
sourceAddr := conn.RemoteAddr().String()
timeOut := uint32(conf.Sshd.Timeout)
sessionMode := conf.Sshd.Session
log.Debugf("token:%s sourceAddr:%s", sessionToken, sourceAddr)
affected, err := dborm.XormInsertSession(conn.User(), sourceAddr, sessionToken, timeOut, sessionMode)
if err != nil {
log.Error("Failed to insert Session table:", err)
return nil, err
}
if affected == -1 {
err := errors.New("Failed to get session")
log.Error(err)
return nil, err
}
return nil, nil
}
// if conn.User() == "admin" && string(password) == "123456" {
// return nil, nil
// }
return nil, fmt.Errorf("invalid user or password")
},
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
// 这里可以进行公钥验证逻辑,例如检查用户的公钥是否在允许的公钥列表中
return nil, fmt.Errorf("public key authentication is failed")
},
}
serverConfig.AddHostKey(privateKey)
// 启动SSH服务器
hostUri := fmt.Sprintf("%s:%d", conf.Sshd.ListenAddr, conf.Sshd.ListenPort)
listener, err := net.Listen("tcp", hostUri)
if err != nil {
log.Fatal("Failed to Listen: ", err)
os.Exit(4)
}
fmt.Printf("MML SSH server startup, listen port%d\n", conf.Sshd.ListenPort)
for {
conn, err := listener.Accept()
if err != nil {
log.Fatal("Failed to Accept: ", err)
os.Exit(5)
}
go handleSSHConnection(conn, serverConfig)
}
}
func handleSSHConnection(conn net.Conn, serverConfig *ssh.ServerConfig) {
// SSH握手
sshConn, chans, reqs, err := ssh.NewServerConn(conn, serverConfig)
if err != nil {
log.Error("Failed to NewServerConn: ", err)
return
}
log.Infof("SSH connect acceptedclient version%suser%s", sshConn.ClientVersion(), sshConn.User())
// 处理SSH请求
go ssh.DiscardRequests(reqs)
// 处理SSH通道
for newChannel := range chans {
if newChannel.ChannelType() != "session" {
newChannel.Reject(ssh.UnknownChannelType, "unsupported channel type")
continue
}
channel, requests, err := newChannel.Accept()
if err != nil {
log.Error("Failed to NewServerConn: ", err)
continue
}
if connNum > int(conf.Sshd.MaxConnNum) {
log.Error("Maximum number of connections exceeded")
//conn.Write([]byte("Reach max connections"))
conn.Close()
continue
}
connNum++
go handleSSHChannel(conn, sshConn, channel, requests)
}
}
func handleSSHChannel(conn net.Conn, sshConn *ssh.ServerConn, channel ssh.Channel, requests <-chan *ssh.Request) {
for req := range requests {
switch req.Type {
case "exec":
// 执行远程命令
command := strings.TrimSpace(string(req.Payload))[4:]
//cmd := exec.Command(
cmd := exec.Command("cmd", "/C", command)
cmd.Stdin = channel
cmd.Stdout = channel
cmd.Stderr = channel.Stderr()
log.Trace("command:", command)
err := cmd.Run()
if err != nil {
log.Error("Failed to cmd.Run: ", err)
}
channel.SendRequest("exit-status", false, []byte{0, 0, 0, 0})
channel.Close()
closeConnection(conn)
// case "shell":
// // 处理交互式shell会话
// // 在这里添加您的处理逻辑例如启动一个shell进程并将其连接到channel
// // 请注意处理交互式shell会话需要更复杂的逻辑您可能需要使用类似于pty包来处理终端相关的操作
// channel.Write([]byte("交互式shell会话已启动\n"))
// channel.Close()
// handleSSHShell(user, channel)
case "pty-req":
log.Info("A pty-req processing...")
req.Reply(true, nil)
handleSSHShell(sshConn, channel)
channel.SendRequest("exit-status", false, []byte{0, 0, 0, 0})
channel.Close()
closeConnection(conn)
log.Info("Channel closed")
default:
req.Reply(false, nil)
}
}
}
func closeConnection(conn net.Conn) {
conn.Close()
connNum--
}
func handleSSHShell(sshConn *ssh.ServerConn, channel ssh.Channel) {
//conf = config.GetYamlConfig()
// 检查通道是否支持终端
omcMmlVar := &mmlp.MmlVar{
Version: "16.1.1",
Output: mmlp.DefaultFormatType,
Limit: 50,
User: sshConn.User(),
SessionToken: fmt.Sprintf("%x", sshConn.SessionID()),
HttpUri: conf.OMC.HttpUri,
UserAgent: config.GetDefaultUserAgent(),
}
term := term.NewTerminal(channel, fmt.Sprintf("[%s@omc]> ", omcMmlVar.User))
// 启动交互式shell会话
for {
line, err := term.ReadLine()
if err != nil {
if err == io.EOF {
break
}
log.Error("Failed to read line: ", err)
break
}
cmdline := strings.TrimSpace(line)
if cmdline != "" {
logmml.Cmd(cmdline)
}
var response string
switch cmdline {
case "exit", "quit":
goto exitLoop
case "":
goto continueLoop
case "help":
response = fmt.Sprintf("Usage: %s\n", line)
term.Write([]byte(response))
goto continueLoop
case "dsp variables":
response = fmt.Sprintf("version: %s\n Output: %s\n", omcMmlVar.Version, omcMmlVar.Output)
term.Write([]byte(response))
goto continueLoop
case "set mml output=json":
// mmlp.SetOmcMmlVarOutput("json")
omcMmlVar.Output = "json"
response = fmt.Sprintf("set ok, mmlVar.output = %s\n", omcMmlVar.Output)
term.Write([]byte(response))
goto continueLoop
case "set mml output=table":
// mmlp.SetOmcMmlVarOutput("table")
omcMmlVar.Output = "table"
response = fmt.Sprintf("set ok, mmlVar.output = %s\n", omcMmlVar.Output)
term.Write([]byte(response))
goto continueLoop
default:
var mmlCmds []mmlp.MmlCommand
mmlLine := strings.TrimSpace(line)
if err = mmlp.ParseMMLCommand(mmlLine, &mmlCmds); err != nil {
response = fmt.Sprintf("parse command error: %v\n", err)
term.Write([]byte(response))
goto continueLoop
}
// if err = mmlp.ParseMMLParams(&mmlCmds); err != nil {
// response := fmt.Sprintf("#2 parse command error: %v\n", err)
// term.Write([]byte(response))
// }
for _, mmlCmd := range mmlCmds {
output, err := mmlp.TransMml2HttpReq(omcMmlVar, &mmlCmd)
if err != nil {
response = fmt.Sprintf("translate MML command error: %v\n", err)
term.Write([]byte(response))
goto continueLoop
}
response = string(*output)
term.Write(*output)
}
goto continueLoop
}
continueLoop:
if response != "" {
logmml.Ret(response)
}
continue
exitLoop:
token := fmt.Sprintf("%x", sshConn.SessionID())
_, err = dborm.XormLogoutUpdateSession(token)
if err != nil {
log.Error("Failed to XormLogoutUpdateSession:", err)
}
break
}
}