Files
be.ems/sshsvc/sshsvc.go
2024-09-21 13:54:14 +08:00

553 lines
15 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
"bufio"
"fmt"
"io"
"net"
"os"
"os/exec"
"strconv"
"strings"
"sync"
"time"
"be.ems/lib/dborm"
"be.ems/lib/global"
"be.ems/lib/log"
"be.ems/lib/mmlp"
"be.ems/sshsvc/config"
"be.ems/sshsvc/logmml"
"be.ems/sshsvc/snmp"
omctelnet "be.ems/sshsvc/telnet"
//"github.com/gliderlabs/ssh"
"golang.org/x/crypto/ssh"
"golang.org/x/term"
)
var (
	// conf is the service configuration, loaded once in init.
	conf *config.YamlConfig

	// telnetCC counts telnet connections admitted by StartTelnetServer and
	// still being served; it is only touched while holding telnetMu.
	telnetCC int
	telnetMu sync.Mutex

	// sshCC counts SSH session channels admitted by handleSSHConnection and
	// not yet released by closeConnection; it is only touched while holding sshMu.
	sshCC int
	sshMu sync.Mutex
)
// init loads the configuration, starts the application logger, connects to
// the database (exiting with status 1 if that fails) and finally sets up the
// MML logger used by handleSSHShell to record commands and responses.
func init() {
	conf = config.GetYamlConfig()

	// Application logger and startup banner.
	log.InitLogger(conf.Logger.File, conf.Logger.Duration, conf.Logger.Count, "omc:sshsvc", config.GetLogLevel())
	fmt.Printf("OMC sshsvc version: %s\n", global.Version)
	log.Infof("========================= OMC sshsvc startup =========================")
	log.Infof("OMC sshsvc version: %s %s %s", global.Version, global.BuildTime, global.GoVer)

	// The service cannot work without its database, so fail hard here.
	dbCfg := conf.Database
	if err := dborm.InitDbClient(dbCfg.Type, dbCfg.User, dbCfg.Password, dbCfg.Host, dbCfg.Port, dbCfg.Name, dbCfg.ConnParam); err != nil {
		fmt.Println("dborm.initDbClient err:", err)
		os.Exit(1)
	}

	// MML command/response logger.
	logmml.InitMmlLogger(conf.Logmml.File, conf.Logmml.Duration, conf.Logmml.Count, "omc", config.GetLogMmlLevel())
}
// main loads the SSH host key, configures password authentication, starts the
// telnet and SNMP services in background goroutines and then serves SSH
// connections on conf.Sshd.ListenAddr:conf.Sshd.ListenPort forever.
//
// Startup failures terminate the process: 2 = host key file unreadable,
// 3 = host key unparsable, 4 = SSH listen failed. (The explicit os.Exit calls
// only matter if log.Fatal does not already exit — NOTE(review): confirm the
// behavior of be.ems/lib/log.Fatal.)
func main() {
	// Load the SSH host key from disk (it is read, not generated).
	privateKeyBytes, err := os.ReadFile(conf.Sshd.PrivateKey)
	if err != nil {
		log.Fatal("Failed to ReadFile", err)
		os.Exit(2)
	}
	privateKey, err := ssh.ParsePrivateKey(privateKeyBytes)
	if err != nil {
		log.Fatal("Failed to ParsePrivateKey", err)
		os.Exit(3)
	}

	// SSH server configuration: passwords are checked by handleAuth, public-key
	// authentication is always refused. A DB-backed login that also recorded a
	// session (dborm.XormCheckLoginUser + dborm.XormInsertSession) used to live
	// in PasswordCallback and is currently disabled.
	serverConfig := &ssh.ServerConfig{
		PasswordCallback: func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
			if handleAuth(conf.Sshd.AuthType, conn.User(), string(password)) {
				return nil, nil
			}
			return nil, fmt.Errorf("invalid user or password")
		},
		PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
			// Public-key verification (e.g. checking the key against an allow
			// list) would go here; for now every key is rejected.
			return nil, fmt.Errorf("public key authentication is failed")
		},
	}
	serverConfig.AddHostKey(privateKey)

	// Open the SSH listener before starting the other services.
	hostURI := fmt.Sprintf("%s:%d", conf.Sshd.ListenAddr, conf.Sshd.ListenPort)
	listener, err := net.Listen("tcp", hostURI)
	if err != nil {
		log.Fatal("Failed to Listen: ", err)
		os.Exit(4)
	}

	// Telnet front-end, served by the omctelnet package.
	telnetSvc := omctelnet.TelnetHandler{
		ListenAddr: conf.TelnetServer.ListenAddr,
		ListenPort: conf.TelnetServer.ListenPort,
		UserName:   conf.TelnetServer.UserName,
		Password:   conf.TelnetServer.Password,
		AuthType:   conf.TelnetServer.AuthType,
		MaxConnNum: conf.TelnetServer.MaxConnNum,
		TagNE:      conf.TelnetServer.TagNE,
		ListenHost: conf.TelnetServer.ListenAddr + ":" + strconv.Itoa(int(conf.TelnetServer.ListenPort)),
	}
	go telnetSvc.StartTelnetServer()

	// SNMP agent and trap service.
	snmpSvc := snmp.SNMPService{
		ListenAddr:  conf.SNMPServer.ListenAddr,
		ListenPort:  conf.SNMPServer.ListenPort,
		UserName:    conf.SNMPServer.UserName,
		AuthPass:    conf.SNMPServer.AuthPass,
		AuthProto:   conf.SNMPServer.AuthProto,
		PrivPass:    conf.SNMPServer.PrivPass,
		PrivProto:   conf.SNMPServer.PrivProto,
		EngineID:    conf.SNMPServer.EngineID,
		TrapPort:    conf.SNMPServer.TrapPort,
		TrapListen:  conf.SNMPServer.TrapListen,
		TrapBool:    conf.SNMPServer.TrapBool,
		TrapTick:    conf.SNMPServer.TrapTick,
		TimeOut:     conf.SNMPServer.TimeOut,
		TrapTarget:  conf.SNMPServer.TrapTarget,
		ListenHost:  conf.SNMPServer.ListenAddr + ":" + strconv.Itoa(int(conf.SNMPServer.ListenPort)),
		TrapHost:    conf.SNMPServer.ListenAddr + ":" + strconv.Itoa(int(conf.SNMPServer.TrapPort)),
		SysName:     "HLR-0",
		SysStatus:   "Normal",
		SysDescr:    "HLR server(sysNO=0)",
		SysLocation: "Shanghai",
		SysContact:  "",
		SysService:  0,
	}
	go snmpSvc.StartSNMPServer()
	go snmpSvc.StartTrapServer()

	// Serve SSH. The listener is never closed, so Accept failures are
	// resource problems such as running out of file descriptors; those are
	// transient. Back off and retry instead of exiting, which would also
	// kill the telnet and SNMP services running in this process.
	var retryDelay time.Duration
	for {
		conn, err := listener.Accept()
		if err != nil {
			if retryDelay == 0 {
				retryDelay = 5 * time.Millisecond
			} else {
				retryDelay *= 2
			}
			if retryDelay > time.Second {
				retryDelay = time.Second
			}
			log.Error("Failed to Accept: ", err)
			time.Sleep(retryDelay)
			continue
		}
		retryDelay = 0
		go handleSSHConnection(conn, serverConfig)
	}
}
// handleAuth reports whether userName/password are valid credentials for the
// given authentication type:
//   - "local":  both must equal conf.Sshd.UserName / conf.Sshd.Password.
//   - "radius": a row must exist in OMC_PUB.sysUser whose userName matches and
//     whose password column equals md5(password) as computed by the database.
//     A query error is logged and treated as a rejection.
//   - "omc" and any other value: always rejected (not implemented).
//
// NOTE(review): the "local" branch always compares against the SSHD
// credentials, even when called by handleTelnetConnection with
// conf.TelnetServer.AuthType; conf.TelnetServer.UserName/Password are not
// consulted here — confirm this is intended.
// NOTE(review): the "local" comparison uses ==, which is not constant-time.
func handleAuth(authType, userName, password string) bool {
	switch authType {
	case "local":
		return userName == conf.Sshd.UserName && password == conf.Sshd.Password
	case "radius":
		exist, err := dborm.XEngDB().Table("OMC_PUB.sysUser").Where("userName=? AND password=md5(?)", userName, password).Exist()
		if err != nil {
			// Fail closed, but leave a trace: a DB outage otherwise looks like
			// every user suddenly typing the wrong password.
			log.Error("Failed to query OMC_PUB.sysUser for authentication: ", err)
			return false
		}
		return exist
	default:
		// Includes "omc", which has no implementation yet.
		return false
	}
}
// StartTelnetServer listens on addr and serves each telnet client with
// handleTelnetConnection in its own goroutine. At most
// conf.TelnetServer.MaxConnNum clients are served at once; extra clients get
// a short notice and are disconnected. The function only returns if the
// listener cannot be opened.
func StartTelnetServer(addr string) {
	listener, err := net.Listen("tcp", addr)
	if err != nil {
		fmt.Println("Error starting Telnet server:", err)
		return
	}
	defer listener.Close()
	fmt.Println("Telnet server started on", addr)

	// Accept errors (e.g. file-descriptor exhaustion) tend to repeat
	// immediately; retrying without a pause would spin a CPU at 100%.
	var retryDelay time.Duration
	for {
		conn, err := listener.Accept()
		if err != nil {
			fmt.Println("Error accepting Telnet connection:", err)
			if retryDelay == 0 {
				retryDelay = 5 * time.Millisecond
			} else {
				retryDelay *= 2
			}
			if retryDelay > time.Second {
				retryDelay = time.Second
			}
			time.Sleep(retryDelay)
			continue
		}
		retryDelay = 0

		// Reserve a connection slot; handleTelnetConnection releases it.
		telnetMu.Lock()
		if telnetCC >= int(conf.TelnetServer.MaxConnNum) {
			telnetMu.Unlock()
			// Best effort: the client is being dropped anyway.
			io.WriteString(conn, "Connection limit reached. Try again later.\r\n")
			conn.Close()
			continue
		}
		telnetCC++
		telnetMu.Unlock()

		go handleTelnetConnection(conn)
	}
}
// handleTelnetConnection serves one telnet client. It prompts for a user name
// and a password, checks them with handleAuth using
// conf.TelnetServer.AuthType, and on success passes the connection to
// HandleCommands. When it returns, the connection is closed and the telnetCC
// slot reserved by StartTelnetServer is released.
func handleTelnetConnection(conn net.Conn) {
	defer func() {
		telnetMu.Lock()
		telnetCC--
		telnetMu.Unlock()
	}()
	defer conn.Close()

	// Telnet protocol bytes (RFC 854 / RFC 857).
	const (
		telnetSE   = 240 // end of subnegotiation
		telnetSB   = 250 // start of subnegotiation
		telnetWILL = 251 // WILL, WONT, DO, DONT are 251..254
		telnetWONT = 252
		telnetDONT = 254
		telnetIAC  = 255 // "interpret as command" escape
		telnetECHO = 1   // ECHO option
	)

	reader := bufio.NewReader(conn)
	writer := bufio.NewWriter(conn)

	// Greeting.
	writer.WriteString("Welcome to the Telnet server!\r\n")
	writer.Flush()

	// User name (one line).
	writer.WriteString("Username: ")
	writer.Flush()
	user, err := reader.ReadString('\n')
	if err != nil {
		return // client went away before finishing the user name
	}
	user = strings.TrimSpace(user)

	// Announce that the server does the echoing (IAC WILL ECHO). The client
	// then stops echoing locally, and since the server does not echo the
	// password it stays hidden.
	writer.Write([]byte{telnetIAC, telnetWILL, telnetECHO})
	writer.Flush()

	writer.WriteString("Password: ")
	writer.Flush()

	// skipCommand consumes the remainder of a telnet command whose leading
	// IAC has already been read. literal is true for IAC IAC, which encodes a
	// single 0xFF data byte.
	skipCommand := func() (literal bool, err error) {
		cmd, err := reader.ReadByte()
		if err != nil {
			return false, err
		}
		switch {
		case cmd == telnetIAC:
			return true, nil
		case cmd >= telnetWILL && cmd <= telnetDONT:
			// Option negotiation carries exactly one option byte.
			_, err = reader.ReadByte()
			return false, err
		case cmd == telnetSB:
			// Subnegotiation data runs until IAC SE; IAC IAC inside it is data.
			for {
				c, err := reader.ReadByte()
				if err != nil {
					return false, err
				}
				if c != telnetIAC {
					continue
				}
				if c, err = reader.ReadByte(); err != nil {
					return false, err
				}
				if c == telnetSE {
					return false, nil
				}
			}
		default:
			// Remaining commands (NOP, GA, AYT, ...) are two bytes long.
			return false, nil
		}
	}

	// Read the password up to CR or LF, stripping telnet command sequences
	// such as the client's reply to WILL ECHO.
	var passBuilder strings.Builder
	for {
		b, err := reader.ReadByte()
		if err != nil {
			return
		}
		if b == '\n' || b == '\r' {
			break
		}
		if b != telnetIAC {
			passBuilder.WriteByte(b)
			continue
		}
		literal, err := skipCommand()
		if err != nil {
			return
		}
		if literal {
			passBuilder.WriteByte(telnetIAC)
		}
	}
	pass := passBuilder.String()

	// Hand echoing back to the client (IAC WONT ECHO).
	writer.Write([]byte{telnetIAC, telnetWONT, telnetECHO})
	writer.Flush()

	if handleAuth(conf.TelnetServer.AuthType, user, pass) {
		writer.WriteString("\r\nAuthentication successful!\r\n")
		writer.Flush()
		HandleCommands(user, conf.TelnetServer.TagNE, reader, writer)
	} else {
		writer.WriteString("\r\nAuthentication failed!\r\n")
		writer.Flush()
	}
}
// HandleCommands runs a small line-oriented command loop for an authenticated
// telnet user.
//
// It shows the prompt "[user@tag]> ", reads input byte by byte from reader,
// echoes accepted bytes through writer and executes the pending line when CR
// or LF arrives. The recognised commands are "hello", "time" and "exit"/"quit".
// "exit"/"quit" ends the loop; anything else non-empty prints
// "Unknown command". The function returns on "exit"/"quit" or on any read
// error (including EOF).
//
// Input handling:
//   - 0xFF, 0xFE and 0x01 are dropped, so leftover telnet option replies
//     (e.g. IAC DONT ECHO = FF FE 01) do not end up in the command.
//   - NUL is dropped: telnet clients may send Enter as CR NUL (RFC 854).
//   - DEL (127) and BS (8) erase the last byte of the pending command.
func HandleCommands(user, tag string, reader *bufio.Reader, writer *bufio.Writer) {
	header := fmt.Sprintf("[%s@%s]> ", user, tag)
	clearLine := "\033[2K\r" // ANSI: erase the current line and return to column 0

	// Show the prompt right away. For clients ending lines with CR LF, the LF
	// left over from the previous input yields an empty command, which just
	// redraws this prompt over itself via clearLine.
	writer.WriteString(header)
	writer.Flush()

	for {
		var line []byte
		for {
			b, err := reader.ReadByte()
			if err != nil {
				return
			}
			if b == '\n' || b == '\r' {
				break
			}
			switch b {
			case '\xff', '\xfe', '\x01', 0:
				// Telnet negotiation residue or the NUL of a CR NUL pair.
			case 127, '\b':
				if len(line) > 0 {
					line = line[:len(line)-1]
					writer.WriteString("\b \b") // erase the character on screen
					writer.Flush()
				}
			default:
				// Echo the typed character back to the user.
				writer.WriteByte(b)
				writer.Flush()
				line = append(line, b)
			}
		}

		command := strings.TrimSpace(string(line))
		switch command {
		case "hello":
			writer.WriteString("\r\nHello, world!\r\n")
		case "time":
			writer.WriteString(fmt.Sprintf("\r\nCurrent time: %s\r\n", time.Now().Format(time.RFC1123)))
		case "exit", "quit":
			writer.WriteString("\r\nGoodbye!\r\n")
			writer.Flush()
			return
		case "":
			// Empty line: just show the prompt again.
		default:
			writer.WriteString("\r\nUnknown command\r\n")
			writer.Flush()
		}
		writer.WriteString(clearLine + header)
		writer.Flush()
	}
}
// handleSSHConnection performs the SSH handshake on conn and then accepts
// "session" channels, handing each one to handleSSHChannel. Global requests
// are discarded and other channel types are rejected.
//
// Every admitted session takes one sshCC slot, which is released by
// closeConnection. Once conf.Sshd.MaxConnNum slots are in use, a new session
// receives a notice and the whole connection is closed.
func handleSSHConnection(conn net.Conn, serverConfig *ssh.ServerConfig) {
	// SSH handshake (NewServerConn closes conn itself on failure).
	sshConn, chans, reqs, err := ssh.NewServerConn(conn, serverConfig)
	if err != nil {
		log.Error("Failed to NewServerConn: ", err)
		return
	}
	log.Infof("SSH connect accepted, client version: %s, user: %s", sshConn.ClientVersion(), sshConn.User())

	go ssh.DiscardRequests(reqs)

	for newChannel := range chans {
		if newChannel.ChannelType() != "session" {
			newChannel.Reject(ssh.UnknownChannelType, "unsupported channel type")
			continue
		}
		channel, requests, err := newChannel.Accept()
		if err != nil {
			log.Error("Failed to Accept channel: ", err)
			continue
		}

		// Check the limit before taking a slot. The previous
		// increment-then-check code never gave the slot back on rejection,
		// so every refused attempt permanently shrank the pool.
		sshMu.Lock()
		if sshCC >= int(conf.Sshd.MaxConnNum) {
			sshMu.Unlock()
			log.Error("Maximum number of connections exceeded")
			channel.Write([]byte(fmt.Sprintf("Connection limit reached (limit=%d). Try again later.\r\n", conf.Sshd.MaxConnNum)))
			channel.Close()
			conn.Close()
			continue
		}
		sshCC++
		sshMu.Unlock()

		go handleSSHChannel(conn, sshConn, channel, requests)
	}
}
// handleSSHChannel serves the requests arriving on one accepted "session"
// channel:
//   - "exec": runs the requested command via "cmd /C", wired to the channel's
//     stdin/stdout/stderr. It then sends the command's exit status and closes
//     the channel and the whole SSH connection.
//   - "pty-req": accepted. The interactive MML shell (handleSSHShell) runs
//     until it ends, then the channel and the connection are closed.
//   - any other request is refused.
//
// Closing goes through closeConnection, which also releases the sshCC slot
// taken by handleSSHConnection.
func handleSSHChannel(conn net.Conn, sshConn *ssh.ServerConn, channel ssh.Channel, requests <-chan *ssh.Request) {
	for req := range requests {
		switch req.Type {
		case "exec":
			// The payload is one SSH string: a uint32 length followed by the
			// command (RFC 4254 §6.5). Decode it rather than slicing off 4
			// bytes, so a short or malformed payload cannot panic the process.
			var execReq struct{ Command string }
			if err := ssh.Unmarshal(req.Payload, &execReq); err != nil {
				log.Error("Failed to parse exec request: ", err)
				req.Reply(false, nil)
				continue
			}
			req.Reply(true, nil)
			command := strings.TrimSpace(execReq.Command)

			// NOTE(review): the client-supplied command is run verbatim by the
			// Windows command interpreter with this service's privileges, so
			// any authenticated user can execute arbitrary commands — confirm
			// this is intended. ("cmd" also does not exist on Linux hosts.)
			cmd := exec.Command("cmd", "/C", command)
			cmd.Stdin = channel
			cmd.Stdout = channel
			cmd.Stderr = channel.Stderr()
			log.Trace("command:", command)
			runErr := cmd.Run()
			if runErr != nil {
				log.Error("Failed to cmd.Run: ", runErr)
			}

			// Report the real exit status (RFC 4254 §6.10) instead of always 0.
			var status uint32
			switch {
			case cmd.ProcessState != nil && cmd.ProcessState.ExitCode() >= 0:
				status = uint32(cmd.ProcessState.ExitCode())
			case runErr != nil:
				// Could not be started, or was terminated by a signal.
				status = 255
			}
			channel.SendRequest("exit-status", false, ssh.Marshal(struct{ Status uint32 }{status}))
			channel.Close()
			closeConnection(conn)
		case "pty-req":
			log.Info("A pty-req processing...")
			req.Reply(true, nil)
			handleSSHShell(sshConn, channel)
			channel.SendRequest("exit-status", false, []byte{0, 0, 0, 0})
			channel.Close()
			closeConnection(conn)
			log.Info("Channel closed")
		default:
			req.Reply(false, nil)
		}
	}
}
// closeConnection closes the SSH transport connection and gives back one sshCC
// slot. Both happen while holding sshMu.
func closeConnection(conn net.Conn) {
	sshMu.Lock()
	defer sshMu.Unlock()

	conn.Close()
	sshCC--
}
// handleSSHShell runs the interactive MML shell on an SSH session channel.
//
// After a welcome banner it reads lines from a terminal bound to channel, with
// the prompt "[user@tagNE]> ". Non-empty command lines are recorded with
// logmml.Cmd. Built-in commands: hello, time, help, dsp variables,
// set mml output=json|table, exit/quit. Everything else is parsed as MML
// (mmlp.ParseMMLCommand) and each parsed command is sent through
// mmlp.TransMml2HttpReq, with its output written back to the terminal.
// Text recorded as a response is passed to logmml.Ret.
//
// On "exit"/"quit" the session identified by the SSH session ID is marked as
// logged out via dborm.XormLogoutUpdateSession. The loop also ends, without
// that update, on any terminal read error (EOF included).
func handleSSHShell(sshConn *ssh.ServerConn, channel ssh.Channel) {
	omcMmlVar := &mmlp.MmlVar{
		Version:      global.Version,
		Output:       mmlp.DefaultFormatType,
		MmlHome:      conf.Sshd.MmlHome,
		Limit:        conf.Sshd.MaxConnNum,
		User:         sshConn.User(),
		SessionToken: fmt.Sprintf("%x", sshConn.SessionID()),
		HttpUri:      conf.OMC.HttpUri,
		UserAgent:    config.GetDefaultUserAgent(),
		TagNE:        conf.Sshd.TagNE,
	}
	terminal := term.NewTerminal(channel, fmt.Sprintf("[%s@%s]> ", omcMmlVar.User, omcMmlVar.TagNE))

	banner := fmt.Sprintf("\r\nWelcome to the %s server!\r\n", strings.ToUpper(omcMmlVar.TagNE))
	terminal.Write([]byte(banner))
	lastLogin := fmt.Sprintf("Last login: %s from %s \r\n\r\n", time.Now().Format(time.RFC1123), sshConn.RemoteAddr())
	terminal.Write([]byte(lastLogin))

	// reply writes text to the terminal and returns it so it gets recorded.
	reply := func(text string) string {
		terminal.Write([]byte(text))
		return text
	}

	// dispatch executes one non-exit command and returns the text to record
	// with logmml.Ret ("" records nothing).
	dispatch := func(line, cmdline string) string {
		switch cmdline {
		case "":
			return ""
		case "hello":
			// Written but deliberately not recorded as a response.
			terminal.Write([]byte("Hello, world!\r\n"))
			return ""
		case "time":
			return reply(fmt.Sprintf("Current time: %s\r\n", time.Now().Format(time.RFC1123)))
		case "help":
			return reply(fmt.Sprintf("Usage: %s\n", line))
		case "dsp variables":
			return reply(fmt.Sprintf("version: %s\n Output: %s\n", omcMmlVar.Version, omcMmlVar.Output))
		case "set mml output=json":
			omcMmlVar.Output = "json"
			return reply(fmt.Sprintf("set ok, mmlVar.output = %s\n", omcMmlVar.Output))
		case "set mml output=table":
			omcMmlVar.Output = "table"
			return reply(fmt.Sprintf("set ok, mmlVar.output = %s\n", omcMmlVar.Output))
		}

		// Anything else is treated as one or more MML commands.
		var mmlCmds []mmlp.MmlCommand
		if err := mmlp.ParseMMLCommand(strings.TrimSpace(line), &mmlCmds); err != nil {
			return reply(fmt.Sprintf("parse command error: %v\n", err))
		}
		recorded := ""
		for _, mmlCmd := range mmlCmds {
			output, err := mmlp.TransMml2HttpReq(omcMmlVar, &mmlCmd)
			if err != nil {
				// Stop at the first failing command; its error is what gets recorded.
				return reply(fmt.Sprintf("translate MML command error: %v\n", err))
			}
			recorded = string(*output)
			terminal.Write(*output)
		}
		// Only the output of the last command is recorded.
		return recorded
	}

	for {
		line, err := terminal.ReadLine()
		if err != nil {
			if err != io.EOF {
				log.Error("Failed to read line: ", err)
			}
			break
		}

		cmdline := strings.TrimSpace(line)
		if cmdline != "" {
			logmml.Cmd(cmdline)
		}

		if cmdline == "exit" || cmdline == "quit" {
			token := fmt.Sprintf("%x", sshConn.SessionID())
			if _, err := dborm.XormLogoutUpdateSession(token); err != nil {
				log.Error("Failed to XormLogoutUpdateSession:", err)
			}
			break
		}

		if response := dispatch(line, cmdline); response != "" {
			logmml.Ret(response)
		}
	}
}