576 lines
16 KiB
Go
576 lines
16 KiB
Go
package main
|
||
|
||
import (
|
||
"bufio"
|
||
"fmt"
|
||
"io"
|
||
"net"
|
||
"os"
|
||
"os/exec"
|
||
"path/filepath"
|
||
"strconv"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
"be.ems/lib/global"
|
||
"be.ems/lib/log"
|
||
"be.ems/lib/mmlp"
|
||
"be.ems/sshsvc/config"
|
||
"be.ems/sshsvc/dborm"
|
||
"be.ems/sshsvc/logmml"
|
||
"be.ems/sshsvc/snmp"
|
||
omctelnet "be.ems/sshsvc/telnet"
|
||
|
||
//"github.com/gliderlabs/ssh"
|
||
"golang.org/x/crypto/ssh"
|
||
"golang.org/x/term"
|
||
)
|
||
|
||
var conf *config.YamlConfig
|
||
|
||
var (
|
||
telnetCC int
|
||
sshCC int
|
||
telnetMu sync.Mutex
|
||
sshMu sync.Mutex
|
||
)
|
||
|
||
func init() {
|
||
conf = config.GetYamlConfig()
|
||
log.InitLogger(conf.Logger.File, conf.Logger.Duration, conf.Logger.Count, "omc:sshsvc", config.GetLogLevel())
|
||
fmt.Printf("OMC sshsvc version: %s\n", global.Version)
|
||
log.Infof("========================= OMC sshsvc startup =========================")
|
||
log.Infof("OMC sshsvc version: %s %s %s", global.Version, global.BuildTime, global.GoVer)
|
||
db := conf.Database
|
||
err := dborm.InitDbClient(db.Type, db.User, db.Password, db.Host, db.Port, db.Name, db.ConnParam)
|
||
if err != nil {
|
||
fmt.Println("dborm.initDbClient err:", err)
|
||
os.Exit(1)
|
||
}
|
||
logmml.InitMmlLogger(conf.Logmml.File, conf.Logmml.Duration, conf.Logmml.Count, "omc", config.GetLogMmlLevel())
|
||
}
|
||
|
||
// readPrivateKey 读取SSH私钥,如果不存在则生成新的密钥对
|
||
func readPrivateKey() ssh.Signer {
|
||
// 检查私钥文件是否存在
|
||
if _, err := os.Stat(conf.Sshd.PrivateKey); os.IsNotExist(err) {
|
||
// 如果文件不存在,创建目录并生成密钥
|
||
dir := filepath.Dir(conf.Sshd.PrivateKey)
|
||
if err := os.MkdirAll(dir, 0700); err != nil {
|
||
log.Fatal("Failed to create .ssh directory:", err)
|
||
os.Exit(2)
|
||
}
|
||
|
||
// 使用ssh-keygen命令生成密钥对
|
||
cmd := exec.Command("ssh-keygen", "-t", "rsa", "-P", "", "-f", conf.Sshd.PrivateKey)
|
||
if err := cmd.Run(); err != nil {
|
||
log.Fatal("Failed to generate SSH key:", err)
|
||
os.Exit(2)
|
||
}
|
||
}
|
||
|
||
// 读取SSH密钥对
|
||
privateKeyBytes, err := os.ReadFile(conf.Sshd.PrivateKey)
|
||
if err != nil {
|
||
log.Fatal("Failed to ReadFile", err)
|
||
os.Exit(2)
|
||
}
|
||
|
||
privateKey, err := ssh.ParsePrivateKey(privateKeyBytes)
|
||
if err != nil {
|
||
log.Fatal("Failed to ParsePrivateKey", err)
|
||
os.Exit(3)
|
||
}
|
||
return privateKey
|
||
}
|
||
|
||
func main() {
|
||
// 配置SSH服务器
|
||
serverConfig := &ssh.ServerConfig{
|
||
PasswordCallback: func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
||
// 这里可以进行密码验证逻辑,例如检查用户名和密码是否匹配
|
||
// validUser, _, err := dborm.XormCheckLoginUser(conn.User(), string(password), conf.OMC.UserCrypt)
|
||
// if err != nil {
|
||
// return nil, err
|
||
// }
|
||
// if validUser == true {
|
||
// sessionToken := fmt.Sprintf("%x", conn.SessionID()) // Generate new token to session ID
|
||
// sourceAddr := conn.RemoteAddr().String()
|
||
// timeOut := uint32(conf.Sshd.Timeout)
|
||
// sessionMode := conf.Sshd.Session
|
||
// log.Debugf("token:%s sourceAddr:%s", sessionToken, sourceAddr)
|
||
// affected, err := dborm.XormInsertSession(conn.User(), sourceAddr, sessionToken, timeOut, sessionMode)
|
||
// if err != nil {
|
||
// log.Error("Failed to insert Session table:", err)
|
||
// return nil, err
|
||
// }
|
||
// if affected == -1 {
|
||
// err := errors.New("Failed to get session")
|
||
// log.Error(err)
|
||
// return nil, err
|
||
// }
|
||
|
||
// return nil, nil
|
||
// }
|
||
if handleAuth(conf.Sshd.AuthType, conn.User(), string(password)) {
|
||
return nil, nil
|
||
}
|
||
return nil, fmt.Errorf("invalid user or password")
|
||
},
|
||
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
|
||
// 这里可以进行公钥验证逻辑,例如检查用户的公钥是否在允许的公钥列表中
|
||
return nil, fmt.Errorf("public key authentication is failed")
|
||
},
|
||
}
|
||
|
||
privateKey := readPrivateKey()
|
||
serverConfig.AddHostKey(privateKey)
|
||
|
||
// 启动SSH服务器
|
||
hostUri := fmt.Sprintf("%s:%d", conf.Sshd.ListenAddr, conf.Sshd.ListenPort)
|
||
listener, err := net.Listen("tcp", hostUri)
|
||
if err != nil {
|
||
log.Fatal("Failed to Listen: ", err)
|
||
os.Exit(4)
|
||
}
|
||
//fmt.Printf("MML SSH server startup, listen port:%d\n", conf.Sshd.ListenPort)
|
||
// 启动telnet服务器
|
||
//telnetUri := fmt.Sprintf("%s:%d", conf.TelnetServer.ListenAddr, conf.TelnetServer.ListenPort)
|
||
// telnetListener, err := net.Listen("tcp", telnetUri)
|
||
// if err != nil {
|
||
// log.Fatal("Failed to Listen: ", err)
|
||
// os.Exit(4)
|
||
// }
|
||
//fmt.Printf("MML Telnet server startup, listen port:%d\n", conf.TelnetServer.ListenPort)
|
||
// telnetconn, err := telnetListener.Accept()
|
||
// if err != nil {
|
||
// log.Fatal("Failed to accept telnet connection: ", err)
|
||
// os.Exit(6)
|
||
// }
|
||
|
||
telnetSvc := omctelnet.TelnetHandler{
|
||
ListenAddr: conf.TelnetServer.ListenAddr,
|
||
ListenPort: conf.TelnetServer.ListenPort,
|
||
UserName: conf.TelnetServer.UserName,
|
||
Password: conf.TelnetServer.Password,
|
||
AuthType: conf.TelnetServer.AuthType,
|
||
MaxConnNum: conf.TelnetServer.MaxConnNum,
|
||
TagNE: conf.TelnetServer.TagNE,
|
||
ListenHost: conf.TelnetServer.ListenAddr + ":" + strconv.Itoa(int(conf.TelnetServer.ListenPort)),
|
||
}
|
||
go telnetSvc.StartTelnetServer()
|
||
// go StartTelnetServer(telnetSvc.ListenHost)
|
||
|
||
snmpSvc := snmp.SNMPService{
|
||
ListenAddr: conf.SNMPServer.ListenAddr,
|
||
ListenPort: conf.SNMPServer.ListenPort,
|
||
UserName: conf.SNMPServer.UserName,
|
||
AuthPass: conf.SNMPServer.AuthPass,
|
||
AuthProto: conf.SNMPServer.AuthProto,
|
||
PrivPass: conf.SNMPServer.PrivPass,
|
||
PrivProto: conf.SNMPServer.PrivProto,
|
||
EngineID: conf.SNMPServer.EngineID,
|
||
TrapPort: conf.SNMPServer.TrapPort,
|
||
TrapListen: conf.SNMPServer.TrapListen,
|
||
TrapBool: conf.SNMPServer.TrapBool,
|
||
TrapTick: conf.SNMPServer.TrapTick,
|
||
TimeOut: conf.SNMPServer.TimeOut,
|
||
TrapTarget: conf.SNMPServer.TrapTarget,
|
||
|
||
ListenHost: conf.SNMPServer.ListenAddr + ":" + strconv.Itoa(int(conf.SNMPServer.ListenPort)),
|
||
TrapHost: conf.SNMPServer.ListenAddr + ":" + strconv.Itoa(int(conf.SNMPServer.TrapPort)),
|
||
SysName: "HLR-0",
|
||
SysStatus: "Normal",
|
||
SysDescr: "HLR server(sysNO=0)",
|
||
SysLocation: "Shanghai",
|
||
SysContact: "",
|
||
SysService: 0,
|
||
}
|
||
|
||
go snmpSvc.StartSNMPServer()
|
||
go snmpSvc.StartTrapServer()
|
||
|
||
for {
|
||
conn, err := listener.Accept()
|
||
if err != nil {
|
||
log.Fatal("Failed to Accept: ", err)
|
||
os.Exit(5)
|
||
}
|
||
|
||
go handleSSHConnection(conn, serverConfig)
|
||
}
|
||
}
|
||
|
||
func handleAuth(authType, userName, password string) bool {
|
||
switch authType {
|
||
case "local":
|
||
if userName == conf.Sshd.UserName && password == conf.Sshd.Password {
|
||
return true
|
||
}
|
||
return false
|
||
case "radius":
|
||
exist, err := dborm.XEngDB().Table("OMC_PUB.sysUser").Where("userName=? AND password=md5(?)", userName, password).Exist()
|
||
if err != nil {
|
||
return false
|
||
}
|
||
return exist
|
||
case "omc":
|
||
|
||
default:
|
||
}
|
||
|
||
return false
|
||
}
|
||
|
||
func StartTelnetServer(addr string) {
|
||
listener, err := net.Listen("tcp", addr)
|
||
if err != nil {
|
||
fmt.Println("Error starting Telnet server:", err)
|
||
return
|
||
}
|
||
defer listener.Close()
|
||
fmt.Println("Telnet server started on", addr)
|
||
|
||
for {
|
||
conn, err := listener.Accept()
|
||
if err != nil {
|
||
fmt.Println("Error accepting Telnet connection:", err)
|
||
continue
|
||
}
|
||
|
||
telnetMu.Lock()
|
||
if telnetCC >= int(conf.TelnetServer.MaxConnNum) {
|
||
telnetMu.Unlock()
|
||
io.WriteString(conn, "Connection limit reached. Try again later.\r\n")
|
||
conn.Close()
|
||
continue
|
||
}
|
||
telnetCC++
|
||
telnetMu.Unlock()
|
||
|
||
go handleTelnetConnection(conn)
|
||
}
|
||
}
|
||
|
||
func handleTelnetConnection(conn net.Conn) {
|
||
defer func() {
|
||
telnetMu.Lock()
|
||
telnetCC--
|
||
telnetMu.Unlock()
|
||
}()
|
||
defer conn.Close()
|
||
|
||
reader := bufio.NewReader(conn)
|
||
writer := bufio.NewWriter(conn)
|
||
|
||
// 发送欢迎信息
|
||
writer.WriteString("Welcome to the Telnet server!\r\n")
|
||
writer.Flush()
|
||
|
||
// 请求用户名
|
||
writer.WriteString("Username: ")
|
||
writer.Flush()
|
||
user, _ := reader.ReadString('\n')
|
||
user = strings.TrimSpace(user)
|
||
|
||
// 关闭回显模式
|
||
writer.Write([]byte{255, 251, 1}) // IAC WILL ECHO
|
||
writer.Flush()
|
||
|
||
// 请求密码
|
||
writer.WriteString("Password: ")
|
||
writer.Flush()
|
||
|
||
// 读取密码并清除控制序列
|
||
var passBuilder strings.Builder
|
||
for {
|
||
b, err := reader.ReadByte()
|
||
if err != nil {
|
||
return
|
||
}
|
||
if b == '\n' || b == '\r' {
|
||
break
|
||
}
|
||
if b == 255 { // IAC
|
||
reader.ReadByte() // 忽略下一个字节
|
||
reader.ReadByte() // 忽略下一个字节
|
||
} else {
|
||
passBuilder.WriteByte(b)
|
||
}
|
||
}
|
||
pass := passBuilder.String()
|
||
|
||
// 恢复回显模式
|
||
writer.Write([]byte{255, 252, 1}) // IAC WONT ECHO
|
||
writer.Flush()
|
||
|
||
if handleAuth(conf.TelnetServer.AuthType, user, pass) {
|
||
writer.WriteString("\r\nAuthentication successful!\r\n")
|
||
writer.Flush()
|
||
HandleCommands(user, conf.TelnetServer.TagNE, reader, writer)
|
||
} else {
|
||
writer.WriteString("\r\nAuthentication failed!\r\n")
|
||
writer.Flush()
|
||
}
|
||
}
|
||
|
||
// HandleCommands runs the interactive command loop of an authenticated
// telnet session.
//
// Input is read from reader one byte at a time and echoed back through
// writer. CR or LF ends a line; the trimmed line is then executed:
//
//	hello       prints a greeting
//	time        prints the current time (RFC 1123)
//	exit, quit  prints a farewell and returns
//	(empty)     does nothing
//	other       prints "Unknown command"
//
// After every line except exit/quit the current terminal line is cleared and
// the prompt "[user@tag]> " is printed. The function also returns, silently,
// as soon as reader reports an error (including EOF).
//
// Both BS (0x08, Ctrl-H) and DEL (0x7f) erase the previous character, since
// telnet clients send either for the Backspace key. The bytes 0xff, 0xfe and
// 0x01 are dropped so that telnet option negotiation does not leak into the
// command text.
func HandleCommands(user, tag string, reader *bufio.Reader, writer *bufio.Writer) {
	const (
		clearLine = "\033[2K\r" // ANSI: erase the whole line, then carriage return
		keyBS     = 0x08
		keyDEL    = 0x7f
	)
	header := fmt.Sprintf("[%s@%s]> ", user, tag)

	var line []byte
	for {
		line = line[:0]

	readLine:
		for {
			b, err := reader.ReadByte()
			if err != nil {
				return
			}
			switch b {
			case '\n', '\r':
				break readLine
			case 0xff, 0xfe, 0x01:
				// Telnet negotiation residue: neither echoed nor stored.
			case keyBS, keyDEL:
				if len(line) > 0 {
					line = line[:len(line)-1]
					// Step back, overwrite with a space, step back again.
					writer.WriteString("\b \b")
					writer.Flush()
				}
			default:
				// Echo is the server's job in this mode.
				writer.WriteByte(b)
				writer.Flush()
				line = append(line, b)
			}
		}

		switch strings.TrimSpace(string(line)) {
		case "hello":
			writer.WriteString("\r\nHello, world!\r\n")
		case "time":
			fmt.Fprintf(writer, "\r\nCurrent time: %s\r\n", time.Now().Format(time.RFC1123))
		case "exit", "quit":
			writer.WriteString("\r\nGoodbye!\r\n")
			writer.Flush()
			return
		case "":
			// Nothing to run; just show the prompt again.
		default:
			writer.WriteString("\r\nUnknown command\r\n")
		}

		writer.WriteString(clearLine + header)
		writer.Flush()
	}
}
|
||
|
||
func handleSSHConnection(conn net.Conn, serverConfig *ssh.ServerConfig) {
|
||
// SSH握手
|
||
sshConn, chans, reqs, err := ssh.NewServerConn(conn, serverConfig)
|
||
if err != nil {
|
||
log.Error("Failed to NewServerConn: ", err)
|
||
return
|
||
}
|
||
|
||
log.Infof("SSH connect accepted,client version:%s,user:%s", sshConn.ClientVersion(), sshConn.User())
|
||
|
||
// 处理SSH请求
|
||
go ssh.DiscardRequests(reqs)
|
||
|
||
// 处理SSH通道
|
||
for newChannel := range chans {
|
||
if newChannel.ChannelType() != "session" {
|
||
newChannel.Reject(ssh.UnknownChannelType, "unsupported channel type")
|
||
continue
|
||
}
|
||
|
||
channel, requests, err := newChannel.Accept()
|
||
if err != nil {
|
||
log.Error("Failed to NewServerConn: ", err)
|
||
continue
|
||
}
|
||
|
||
sshMu.Lock()
|
||
sshCC++
|
||
if sshCC > int(conf.Sshd.MaxConnNum) {
|
||
sshMu.Unlock()
|
||
log.Error("Maximum number of connections exceeded")
|
||
channel.Write([]byte(fmt.Sprintf("Connection limit reached (limit=%d). Try again later.\r\n", conf.Sshd.MaxConnNum)))
|
||
conn.Close()
|
||
continue
|
||
}
|
||
sshMu.Unlock()
|
||
|
||
go handleSSHChannel(conn, sshConn, channel, requests)
|
||
}
|
||
}
|
||
|
||
func handleSSHChannel(conn net.Conn, sshConn *ssh.ServerConn, channel ssh.Channel, requests <-chan *ssh.Request) {
|
||
for req := range requests {
|
||
switch req.Type {
|
||
case "exec":
|
||
// 执行远程命令
|
||
command := strings.TrimSpace(string(req.Payload))[4:]
|
||
//cmd := exec.Command(
|
||
cmd := exec.Command("cmd", "/C", command)
|
||
cmd.Stdin = channel
|
||
cmd.Stdout = channel
|
||
cmd.Stderr = channel.Stderr()
|
||
log.Trace("command:", command)
|
||
|
||
err := cmd.Run()
|
||
if err != nil {
|
||
log.Error("Failed to cmd.Run: ", err)
|
||
}
|
||
|
||
channel.SendRequest("exit-status", false, []byte{0, 0, 0, 0})
|
||
channel.Close()
|
||
closeConnection(conn)
|
||
// case "shell":
|
||
// // 处理交互式shell会话
|
||
// // 在这里添加您的处理逻辑,例如启动一个shell进程并将其连接到channel
|
||
// // 请注意,处理交互式shell会话需要更复杂的逻辑,您可能需要使用类似于pty包来处理终端相关的操作
|
||
// channel.Write([]byte("交互式shell会话已启动\n"))
|
||
// channel.Close()
|
||
// handleSSHShell(user, channel)
|
||
case "pty-req":
|
||
log.Info("A pty-req processing...")
|
||
req.Reply(true, nil)
|
||
handleSSHShell(sshConn, channel)
|
||
channel.SendRequest("exit-status", false, []byte{0, 0, 0, 0})
|
||
channel.Close()
|
||
closeConnection(conn)
|
||
log.Info("Channel closed")
|
||
default:
|
||
req.Reply(false, nil)
|
||
}
|
||
}
|
||
}
|
||
|
||
func closeConnection(conn net.Conn) {
|
||
sshMu.Lock()
|
||
conn.Close()
|
||
sshCC--
|
||
sshMu.Unlock()
|
||
}
|
||
|
||
func handleSSHShell(sshConn *ssh.ServerConn, channel ssh.Channel) {
|
||
//conf = config.GetYamlConfig()
|
||
// 检查通道是否支持终端
|
||
|
||
omcMmlVar := &mmlp.MmlVar{
|
||
Version: global.Version,
|
||
Output: mmlp.DefaultFormatType,
|
||
MmlHome: conf.Sshd.MmlHome,
|
||
Limit: conf.Sshd.MaxConnNum,
|
||
User: sshConn.User(),
|
||
SessionToken: fmt.Sprintf("%x", sshConn.SessionID()),
|
||
HttpUri: conf.OMC.HttpUri,
|
||
UserAgent: config.GetDefaultUserAgent(),
|
||
TagNE: conf.Sshd.TagNE,
|
||
}
|
||
term := term.NewTerminal(channel, fmt.Sprintf("[%s@%s]> ", omcMmlVar.User, omcMmlVar.TagNE))
|
||
msg := fmt.Sprintf("\r\nWelcome to the %s server!\r\n", strings.ToUpper(omcMmlVar.TagNE))
|
||
term.Write([]byte(msg))
|
||
msg = fmt.Sprintf("Last login: %s from %s \r\n\r\n", time.Now().Format(time.RFC1123), sshConn.RemoteAddr())
|
||
term.Write([]byte(msg))
|
||
|
||
// 启动交互式shell会话
|
||
for {
|
||
line, err := term.ReadLine()
|
||
if err != nil {
|
||
if err == io.EOF {
|
||
break
|
||
}
|
||
log.Error("Failed to read line: ", err)
|
||
break
|
||
}
|
||
|
||
cmdline := strings.TrimSpace(line)
|
||
if cmdline != "" {
|
||
logmml.Cmd(cmdline)
|
||
}
|
||
var response string
|
||
switch cmdline {
|
||
case "hello":
|
||
term.Write([]byte("Hello, world!\r\n"))
|
||
goto continueLoop
|
||
case "time":
|
||
response = fmt.Sprintf("Current time: %s\r\n", time.Now().Format(time.RFC1123))
|
||
term.Write([]byte(response))
|
||
goto continueLoop
|
||
case "exit", "quit":
|
||
goto exitLoop
|
||
case "":
|
||
goto continueLoop
|
||
case "help":
|
||
response = fmt.Sprintf("Usage: %s\n", line)
|
||
term.Write([]byte(response))
|
||
goto continueLoop
|
||
|
||
case "dsp variables":
|
||
response = fmt.Sprintf("version: %s\n Output: %s\n", omcMmlVar.Version, omcMmlVar.Output)
|
||
term.Write([]byte(response))
|
||
goto continueLoop
|
||
|
||
case "set mml output=json":
|
||
// mmlp.SetOmcMmlVarOutput("json")
|
||
omcMmlVar.Output = "json"
|
||
response = fmt.Sprintf("set ok, mmlVar.output = %s\n", omcMmlVar.Output)
|
||
term.Write([]byte(response))
|
||
goto continueLoop
|
||
|
||
case "set mml output=table":
|
||
// mmlp.SetOmcMmlVarOutput("table")
|
||
omcMmlVar.Output = "table"
|
||
response = fmt.Sprintf("set ok, mmlVar.output = %s\n", omcMmlVar.Output)
|
||
term.Write([]byte(response))
|
||
goto continueLoop
|
||
|
||
default:
|
||
var mmlCmds []mmlp.MmlCommand
|
||
mmlLine := strings.TrimSpace(line)
|
||
if err = mmlp.ParseMMLCommand(mmlLine, &mmlCmds); err != nil {
|
||
response = fmt.Sprintf("parse command error: %v\n", err)
|
||
term.Write([]byte(response))
|
||
goto continueLoop
|
||
}
|
||
// if err = mmlp.ParseMMLParams(&mmlCmds); err != nil {
|
||
// response := fmt.Sprintf("#2 parse command error: %v\n", err)
|
||
// term.Write([]byte(response))
|
||
// }
|
||
|
||
for _, mmlCmd := range mmlCmds {
|
||
output, err := mmlp.TransMml2HttpReq(omcMmlVar, &mmlCmd)
|
||
if err != nil {
|
||
response = fmt.Sprintf("translate MML command error: %v\n", err)
|
||
term.Write([]byte(response))
|
||
goto continueLoop
|
||
}
|
||
response = string(*output)
|
||
term.Write(*output)
|
||
}
|
||
goto continueLoop
|
||
}
|
||
continueLoop:
|
||
if response != "" {
|
||
logmml.Ret(response)
|
||
}
|
||
continue
|
||
exitLoop:
|
||
token := fmt.Sprintf("%x", sshConn.SessionID())
|
||
_, err = dborm.XormLogoutUpdateSession(token)
|
||
if err != nil {
|
||
log.Error("Failed to XormLogoutUpdateSession:", err)
|
||
}
|
||
break
|
||
}
|
||
}
|