Merge remote-tracking branch 'origin/main' into multi-tenant

This commit is contained in:
TsMask
2024-09-20 20:23:39 +08:00
142 changed files with 7687 additions and 2096 deletions

View File

@@ -48,8 +48,42 @@ type YamlConfig struct {
Timeout uint16 `yaml:"timeout"`
Session string `yaml:"session"`
MmlHome string `yaml:"mmlHome"`
UserName string `yaml:"userName"`
Password string `yaml:"password"`
AuthType string `yaml:"authType"`
TagNE string `yaml:"tagNE"`
} `yaml:"sshd"`
TelnetServer struct {
ListenAddr string `yaml:"listenAddr"`
ListenPort uint16 `yaml:"listenPort"`
MaxConnNum uint8 `yaml:"maxConnNum"`
Timeout uint16 `yaml:"timeout"`
Session string `yaml:"session"`
MmlHome string `yaml:"mmlHome"`
UserName string `yaml:"userName"`
Password string `yaml:"password"`
AuthType string `yaml:"authType"`
TagNE string `yaml:"tagNE"`
} `yaml:"telnetServer"`
SNMPServer struct {
ListenAddr string `yaml:"listenAddr"`
ListenPort uint16 `yaml:"listenPort"`
UserName string `yaml:"userName"`
AuthPass string `yaml:"authPass"`
AuthProto string `yaml:"authProto"`
PrivPass string `yaml:"privPass"`
PrivProto string `yaml:"privProto"`
EngineID string `yaml:"engineID"`
TrapPort uint16 `yaml:"trapPort"`
TrapListen bool `yaml:"trapListen"`
TrapBool bool `yaml:"trapBool"`
TrapTick uint16 `yaml:"trapTick"`
TimeOut uint16 `yaml:"timeOut"`
TrapTarget string `yaml:"trapTarget"`
} `yaml:"snmpServer"`
Database DbConfig `yaml:"database"`
OMC struct {

View File

@@ -21,15 +21,52 @@ logmml:
# ssh service listen ipv4/v6 and port, support multiple routines
# ip: 0.0.0.0 or ::0, support IPv4/v6
# session: single/multiple session for one user
# authType: local/radius
sshd:
listenAddr: 0.0.0.0
listenPort: 2222
listenPort: 32222
privateKey: ./.ssh/id_rsa
maxConnNum: 20
maxConnNum: 2
timeout: 1800
session: multiple
mmlHome: ./mmlhome
userName: manager
password: pass123
authType: local
tagNE: hlr
# authType: local/omc
telnetServer:
listenAddr: 0.0.0.0
listenPort: 32323
maxConnNum: 2
timeout: 1800
session: multiple
mmlHome: ./mmlhome
userName: manager
password: pass123
authType: local
tagNE: hlr
# authproto: NoAuth/MD5/SHA
# privProto: NoPriv/DES/AES/AES192/AES256
snmpServer:
listenAddr: '[::]'
listenPort: 34957
userName: manager
authPass: pass123
authproto: MD5
privPass: "3F2A1B4C5D6E7F8A9B0C1D2E3F4A5B6C7D8E9F0A1B2C3D4E"
privProto: DES
#engineID: "800007db03360102101101"
engineID: "8000000004323030313a6462383a3a39313636"
trapPort: 34958
trapListen: true
trapBool: true
trapTick: 60
timeOut: 5
trapTarget: "2001:db8::9219"
database:
type: mysql
user: root

View File

@@ -1,7 +1,7 @@
# Makefile for OMC-OMC-crontask project
PROJECT = OMC
VERSION = 2.2408.1
VERSION = 2.2409.3
LIBDIR = be.ems/lib
BINNAME = sshsvc

439
sshsvc/snmp/snmp.go Normal file
View File

@@ -0,0 +1,439 @@
package snmp
import (
"flag"
"fmt"
"log"
"net"
"os"
"path/filepath"
"strings"
"time"
g "github.com/gosnmp/gosnmp"
"github.com/slayercat/GoSNMPServer"
"github.com/slayercat/GoSNMPServer/mibImps"
)
type SNMPService struct {
ListenAddr string
ListenPort uint16
UserName string
AuthPass string
AuthProto string
PrivPass string
PrivProto string
EngineID string
TrapPort uint16
TrapListen bool
TrapBool bool
TrapTick uint16
TimeOut uint16
TrapTarget string
ListenHost string
TrapHost string
SysDescr string
SysService int
}
func (s *SNMPService) getAuthProto() g.SnmpV3AuthProtocol {
switch s.AuthProto {
case "NoAuth":
return g.NoAuth
case "MD5":
return g.MD5
case "SHA":
return g.SHA
default:
}
return g.MD5
}
func (s *SNMPService) getPrivProto() g.SnmpV3PrivProtocol {
switch s.PrivProto {
case "NoPriv":
return g.NoPriv
case "DES":
return g.DES
case "AES":
return g.AES
case "AES192":
return g.AES192
case "AES256":
return g.AES256
default:
}
return g.DES
}
func (s *SNMPService) setSecParamsList() []g.UsmSecurityParameters {
var secParamsList = []g.UsmSecurityParameters{
{
UserName: s.UserName,
AuthenticationProtocol: s.getAuthProto(),
AuthenticationPassphrase: s.AuthPass,
PrivacyProtocol: s.getPrivProto(),
PrivacyPassphrase: s.PrivPass,
AuthoritativeEngineID: s.EngineID,
},
// {
// UserName: "myuser2",
// AuthenticationProtocol: g.SHA,
// AuthenticationPassphrase: "mypassword2",
// PrivacyProtocol: g.DES,
// PrivacyPassphrase: "myprivacy2",
// AuthoritativeEngineID: s.EngineID,
// },
// {
// UserName: "myuser2",
// AuthenticationProtocol: g.MD5,
// AuthenticationPassphrase: "mypassword2",
// PrivacyProtocol: g.AES,
// PrivacyPassphrase: "myprivacy2",
// AuthoritativeEngineID: s.EngineID,
// },
}
return secParamsList
}
func (s *SNMPService) StartSNMPServer() {
// 设置引擎启动次数和引varvar
var engineBoots uint32 = 1
//var engineTime uint32 = uint32(time.Now().Unix() % 2147483647) // 使用当前时间初始化
//var engineTime uint32 = 3600 // 使用当前时间初始化
master := GoSNMPServer.MasterAgent{
Logger: GoSNMPServer.NewDefaultLogger(),
SecurityConfig: GoSNMPServer.SecurityConfig{
NoSecurity: true,
AuthoritativeEngineBoots: engineBoots,
// OnGetAuthoritativeEngineTime: func() uint32 {
// return engineTime
// },
//AuthoritativeEngineID: GoSNMPServer.SNMPEngineID{EngineIDData: "0x800007DB03360102101100"},
Users: s.setSecParamsList(),
},
SubAgents: []*GoSNMPServer.SubAgent{
{
UserErrorMarkPacket: false,
CommunityIDs: []string{"public", "private"}, // SNMPv1 and SNMPv2c community strings
OIDs: s.handleOIDs(),
//OIDs: mibImps.All(),
},
},
}
server := GoSNMPServer.NewSNMPServer(master)
err := server.ListenUDP("udp", s.ListenHost)
if err != nil {
log.Fatalf("Error in listen: %+v", err)
}
server.ServeForever()
}
func (s *SNMPService) handleOIDs() []*GoSNMPServer.PDUValueControlItem {
customOIDs := []*GoSNMPServer.PDUValueControlItem{
{
OID: "1.3.6.1.2.1.1.1.0",
Type: g.OctetString,
OnGet: func() (value interface{}, err error) {
return s.SysDescr, nil
},
OnSet: func(value interface{}) error {
// 将[]uint8转换为string
if v, ok := value.([]uint8); ok {
s.SysDescr = string(v)
log.Printf("Set request for OID 1.3.6.1.2.1.1.1.0 with value %v", s.SysDescr)
return nil
}
return nil
},
},
{
OID: "1.3.6.1.2.1.1.3.0",
Type: g.TimeTicks,
OnGet: func() (value interface{}, err error) {
return uint32(time.Now().Unix()), nil
},
},
{
OID: "1.3.6.1.2.1.1.7.0",
Type: g.Integer,
OnGet: func() (value interface{}, err error) {
return s.SysService, nil
},
OnSet: func(value interface{}) error {
// 将[]uint8转换为string
if v, ok := value.(int); ok {
s.SysService = v
log.Printf("Set request for OID 1.3.6.1.2.1.1.7.0 with value %v", s.SysService)
return nil
}
return nil
},
},
}
// 获取mibImps.All()返回的OID列表
mibOIDs := mibImps.All()
// 使用Map来检测并移除重复的OID
oidMap := make(map[string]*GoSNMPServer.PDUValueControlItem)
for _, oid := range customOIDs {
oidMap[oid.OID] = oid
}
for _, oid := range mibOIDs {
if _, exists := oidMap[oid.OID]; !exists {
oidMap[oid.OID] = oid
} else {
log.Printf("Duplicate OID found: %s", oid.OID)
}
}
// 将Map转换为Slice
allOIDs := make([]*GoSNMPServer.PDUValueControlItem, 0, len(oidMap))
for _, oid := range oidMap {
allOIDs = append(allOIDs, oid)
}
return allOIDs
}
func (s *SNMPService) StartTrapServer() {
flag.Usage = func() {
fmt.Printf("Usage:\n")
fmt.Printf(" %s\n", filepath.Base(os.Args[0]))
flag.PrintDefaults()
}
tl := g.NewTrapListener()
tl.OnNewTrap = s.MyTrapHandler
usmTable := g.NewSnmpV3SecurityParametersTable(g.NewLogger(log.New(os.Stdout, "", 0)))
for i := range s.setSecParamsList() {
sp := &s.setSecParamsList()[i] // 使用指针
err := usmTable.Add(sp.UserName, sp)
if err != nil {
usmTable.Logger.Print(err)
}
}
// 设置引擎启动次数和引varvar
//var engineBoots uint32 = 1
// var engineTime uint32 = uint32(time.Now().Unix() % 2147483647) // 使用当前时间初始化
//var engineTime uint32 = 3600 // 使用当前时间初始化
gs := &g.GoSNMP{
Target: s.TrapTarget,
Port: s.TrapPort,
Transport: "udp",
Timeout: time.Duration(s.TimeOut) * time.Second, // 设置超时时间为x秒
Version: g.Version3, // Always using version3 for traps, only option that works with all SNMP versions simultaneously
MsgFlags: g.NoAuthNoPriv,
SecurityModel: g.UserSecurityModel,
SecurityParameters: &g.UsmSecurityParameters{
UserName: s.UserName,
AuthoritativeEngineID: s.EngineID,
AuthoritativeEngineBoots: 1,
//AuthoritativeEngineTime: 3600,
AuthenticationProtocol: s.getAuthProto(),
AuthenticationPassphrase: s.AuthPass,
PrivacyProtocol: s.getPrivProto(),
PrivacyPassphrase: s.PrivPass,
},
//TrapSecurityParametersTable: usmTable,
ContextEngineID: s.EngineID,
ContextName: "v3test",
}
tl.Params = gs
tl.Params.Logger = g.NewLogger(log.New(os.Stdout, "", 0))
// 定时发送Trap
if s.TrapBool {
go s.SendPeriodicTraps(gs)
}
go s.monitorNetwork(gs)
if s.TrapListen {
err := tl.Listen(s.TrapHost)
if err != nil {
log.Panicf("error in listen: %s", err)
}
}
}
func (s *SNMPService) MyTrapHandler(packet *g.SnmpPacket, addr *net.UDPAddr) {
log.Printf("got trapdata from %s\n", addr.IP)
for _, v := range packet.Variables {
switch v.Type {
case g.OctetString:
b := v.Value.([]byte)
fmt.Printf("OID: %s, string: %x\n", v.Name, b)
default:
log.Printf("trap: %+v\n", v)
}
}
}
func (s *SNMPService) SendPeriodicTraps(gs *g.GoSNMP) {
err := gs.Connect()
if err != nil {
log.Fatalf("Connect() err: %v", err)
}
defer gs.Conn.Close()
ticker := time.NewTicker(time.Duration(s.TrapTick) * time.Second) // 每10秒发送一次Trap
defer ticker.Stop()
for range ticker.C { // 每x秒发送一次Trap
trap := g.SnmpTrap{
Variables: []g.SnmpPDU{
{
Name: ".1.3.6.1.2.1.1.3.0",
Type: g.TimeTicks,
Value: uint32(time.Now().Unix()),
},
{
Name: ".1.3.6.1.6.3.1.1.4.1.0",
Type: g.ObjectIdentifier,
Value: ".1.3.6.1.6.3.1.1.5.1",
},
},
}
_, err = gs.SendTrap(trap)
if err != nil {
log.Printf("error sending trap: %s", err)
} else {
log.Printf("trap sent successfully")
}
}
}
// 1. 设备链路连接失败时发送Trap (LinkDown)
func (s *SNMPService) sendLinkDownTrap(gs *g.GoSNMP, ifIndex int, ifDescr string) {
trap := g.SnmpTrap{
Variables: []g.SnmpPDU{
{
Name: ".1.3.6.1.2.1.2.2.1.1", // ifIndex
Type: g.Integer,
Value: ifIndex,
},
{
Name: ".1.3.6.1.2.1.2.2.1.2", // ifDescr
Type: g.OctetString,
Value: ifDescr,
},
{
Name: ".1.3.6.1.6.3.1.1.5.3", // linkDown
Type: g.ObjectIdentifier,
Value: ".1.3.6.1.6.3.1.1.5.3",
},
},
}
_, err := gs.SendTrap(trap)
if err != nil {
log.Printf("error sending LinkDown trap: %s", err)
} else {
log.Printf("LinkDown trap sent successfully")
}
}
// 2. 设备链路恢复正常时发送Trap (LinkUp)
func (s *SNMPService) sendLinkUpTrap(gs *g.GoSNMP, ifIndex int, ifDescr string) {
trap := g.SnmpTrap{
Variables: []g.SnmpPDU{
{
Name: ".1.3.6.1.2.1.2.2.1.1", // ifIndex
Type: g.Integer,
Value: ifIndex,
},
{
Name: ".1.3.6.1.2.1.2.2.1.2", // ifDescr
Type: g.OctetString,
Value: ifDescr,
},
{
Name: ".1.3.6.1.6.3.1.1.5.4", // linkUp
Type: g.ObjectIdentifier,
Value: ".1.3.6.1.6.3.1.1.5.4",
},
},
}
_, err := gs.SendTrap(trap)
if err != nil {
log.Printf("error sending LinkUp trap: %s", err)
} else {
log.Printf("LinkUp trap sent successfully")
}
}
// 3. 设备鉴权失败时发送Trap (AuthenticationFailure)
func (s *SNMPService) sendAuthFailureTrap(gs *g.GoSNMP, username, descr string) {
trap := g.SnmpTrap{
Variables: []g.SnmpPDU{
{
Name: ".1.3.6.1.6.3.1.1.5.5", // authenticationFailure
Type: g.ObjectIdentifier,
Value: ".1.3.6.1.6.3.1.1.5.5",
},
{
Name: ".1.3.6.1.4.1.2021.251.1", // 自定义OID用于记录失败的用户名
Type: g.OctetString,
Value: username,
},
{
Name: ".1.3.6.1.4.1.2021.252.1", // 自定义OID用于记录描述
Type: g.OctetString,
Value: descr,
},
},
}
_, err := gs.SendTrap(trap)
if err != nil {
log.Printf("error sending AuthenticationFailure trap: %s", err)
} else {
log.Printf("AuthenticationFailure trap sent successfully")
}
}
func (s *SNMPService) monitorNetwork(gs *g.GoSNMP) {
// 假设有一个函数 checkLinkStatus 返回链路状态
for {
serviceStatus := s.checkServiceStatus()
switch strings.ToUpper(serviceStatus) {
case "LINK_DOWN":
index := 1
ifDescr := fmt.Sprintf("Link(index=%d) DOWN", index)
s.sendLinkDownTrap(gs, index, ifDescr) // 假设接口索引为1
s.SysService = 0
case "LINK_UP":
index := 1
ifDescr := fmt.Sprintf("Link(index=%d) UP", index)
s.sendLinkUpTrap(gs, index, ifDescr) // 假设接口索引为1
s.SysService = 0
case "AUTH_FAILURE":
descr := "Authentication Failure"
s.sendAuthFailureTrap(gs, s.UserName, descr)
s.SysService = 0
default:
}
time.Sleep(10 * time.Second) // 每10秒检查一次
}
}
func (s *SNMPService) checkServiceStatus() string {
switch s.SysService {
case 1:
return "LINK_DOWN"
case 2:
return "LINK_UP"
case 3:
return "AUTH_FAILURE"
default:
}
return "NORMAL"
}

View File

@@ -1,13 +1,16 @@
package main
import (
"errors"
"bufio"
"fmt"
"io"
"net"
"os"
"os/exec"
"strconv"
"strings"
"sync"
"time"
"be.ems/lib/dborm"
"be.ems/lib/global"
@@ -15,15 +18,23 @@ import (
"be.ems/lib/mmlp"
"be.ems/sshsvc/config"
"be.ems/sshsvc/logmml"
"be.ems/sshsvc/snmp"
omctelnet "be.ems/sshsvc/telnet"
//"github.com/gliderlabs/ssh"
"golang.org/x/crypto/ssh"
"golang.org/x/term"
)
var connNum int = 0
var conf *config.YamlConfig
var (
telnetCC int
sshCC int
telnetMu sync.Mutex
sshMu sync.Mutex
)
func init() {
conf = config.GetYamlConfig()
log.InitLogger(conf.Logger.File, conf.Logger.Duration, conf.Logger.Count, "omc:sshsvc", config.GetLogLevel())
@@ -57,33 +68,32 @@ func main() {
serverConfig := &ssh.ServerConfig{
PasswordCallback: func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
// 这里可以进行密码验证逻辑,例如检查用户名和密码是否匹配
validUser, _, err := dborm.XormCheckLoginUser(conn.User(), string(password), conf.OMC.UserCrypt)
if err != nil {
return nil, err
}
if validUser == true {
sessionToken := fmt.Sprintf("%x", conn.SessionID()) // Generate new token to session ID
sourceAddr := conn.RemoteAddr().String()
timeOut := uint32(conf.Sshd.Timeout)
sessionMode := conf.Sshd.Session
log.Debugf("token:%s sourceAddr:%s", sessionToken, sourceAddr)
affected, err := dborm.XormInsertSession(conn.User(), sourceAddr, sessionToken, timeOut, sessionMode)
if err != nil {
log.Error("Failed to insert Session table:", err)
return nil, err
}
if affected == -1 {
err := errors.New("Failed to get session")
log.Error(err)
return nil, err
}
// validUser, _, err := dborm.XormCheckLoginUser(conn.User(), string(password), conf.OMC.UserCrypt)
// if err != nil {
// return nil, err
// }
// if validUser == true {
// sessionToken := fmt.Sprintf("%x", conn.SessionID()) // Generate new token to session ID
// sourceAddr := conn.RemoteAddr().String()
// timeOut := uint32(conf.Sshd.Timeout)
// sessionMode := conf.Sshd.Session
// log.Debugf("token:%s sourceAddr:%s", sessionToken, sourceAddr)
// affected, err := dborm.XormInsertSession(conn.User(), sourceAddr, sessionToken, timeOut, sessionMode)
// if err != nil {
// log.Error("Failed to insert Session table:", err)
// return nil, err
// }
// if affected == -1 {
// err := errors.New("Failed to get session")
// log.Error(err)
// return nil, err
// }
return nil, nil
}
// if conn.User() == "admin" && string(password) == "123456" {
// return nil, nil
// }
if handleAuth(conf.Sshd.AuthType, conn.User(), string(password)) {
return nil, nil
}
return nil, fmt.Errorf("invalid user or password")
},
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
@@ -101,8 +111,58 @@ func main() {
log.Fatal("Failed to Listen: ", err)
os.Exit(4)
}
//fmt.Printf("MML SSH server startup, listen port%d\n", conf.Sshd.ListenPort)
// 启动telnet服务器
//telnetUri := fmt.Sprintf("%s:%d", conf.TelnetServer.ListenAddr, conf.TelnetServer.ListenPort)
// telnetListener, err := net.Listen("tcp", telnetUri)
// if err != nil {
// log.Fatal("Failed to Listen: ", err)
// os.Exit(4)
// }
//fmt.Printf("MML Telnet server startup, listen port%d\n", conf.TelnetServer.ListenPort)
// telnetconn, err := telnetListener.Accept()
// if err != nil {
// log.Fatal("Failed to accept telnet connection: ", err)
// os.Exit(6)
// }
fmt.Printf("MML SSH server startup, listen port%d\n", conf.Sshd.ListenPort)
telnetSvc := omctelnet.TelnetHandler{
ListenAddr: conf.TelnetServer.ListenAddr,
ListenPort: conf.TelnetServer.ListenPort,
UserName: conf.TelnetServer.UserName,
Password: conf.TelnetServer.Password,
AuthType: conf.TelnetServer.AuthType,
MaxConnNum: conf.TelnetServer.MaxConnNum,
TagNE: conf.TelnetServer.TagNE,
ListenHost: conf.TelnetServer.ListenAddr + ":" + strconv.Itoa(int(conf.TelnetServer.ListenPort)),
}
go telnetSvc.StartTelnetServer()
// go StartTelnetServer(telnetSvc.ListenHost)
snmpSvc := snmp.SNMPService{
ListenAddr: conf.SNMPServer.ListenAddr,
ListenPort: conf.SNMPServer.ListenPort,
UserName: conf.SNMPServer.UserName,
AuthPass: conf.SNMPServer.AuthPass,
AuthProto: conf.SNMPServer.AuthProto,
PrivPass: conf.SNMPServer.PrivPass,
PrivProto: conf.SNMPServer.PrivProto,
EngineID: conf.SNMPServer.EngineID,
TrapPort: conf.SNMPServer.TrapPort,
TrapListen: conf.SNMPServer.TrapListen,
TrapBool: conf.SNMPServer.TrapBool,
TrapTick: conf.SNMPServer.TrapTick,
TimeOut: conf.SNMPServer.TimeOut,
TrapTarget: conf.SNMPServer.TrapTarget,
ListenHost: conf.SNMPServer.ListenAddr + ":" + strconv.Itoa(int(conf.SNMPServer.ListenPort)),
TrapHost: conf.SNMPServer.ListenAddr + ":" + strconv.Itoa(int(conf.SNMPServer.TrapPort)),
SysDescr: "HLR server",
SysService: 0,
}
go snmpSvc.StartSNMPServer()
go snmpSvc.StartTrapServer()
for {
conn, err := listener.Accept()
@@ -115,6 +175,175 @@ func main() {
}
}
func handleAuth(authType, userName, password string) bool {
switch authType {
case "local":
if userName == conf.Sshd.UserName && password == conf.Sshd.Password {
return true
}
return false
case "radius":
exist, err := dborm.XEngDB().Table("OMC_PUB.sysUser").Where("userName=? AND password=md5(?)", userName, password).Exist()
if err != nil {
return false
}
return exist
case "omc":
default:
}
return false
}
func StartTelnetServer(addr string) {
listener, err := net.Listen("tcp", addr)
if err != nil {
fmt.Println("Error starting Telnet server:", err)
return
}
defer listener.Close()
fmt.Println("Telnet server started on", addr)
for {
conn, err := listener.Accept()
if err != nil {
fmt.Println("Error accepting Telnet connection:", err)
continue
}
telnetMu.Lock()
if telnetCC >= int(conf.TelnetServer.MaxConnNum) {
telnetMu.Unlock()
io.WriteString(conn, "Connection limit reached. Try again later.\r\n")
conn.Close()
continue
}
telnetCC++
telnetMu.Unlock()
go handleTelnetConnection(conn)
}
}
func handleTelnetConnection(conn net.Conn) {
defer func() {
telnetMu.Lock()
telnetCC--
telnetMu.Unlock()
}()
defer conn.Close()
reader := bufio.NewReader(conn)
writer := bufio.NewWriter(conn)
// 发送欢迎信息
writer.WriteString("Welcome to the Telnet server!\r\n")
writer.Flush()
// 请求用户名
writer.WriteString("Username: ")
writer.Flush()
user, _ := reader.ReadString('\n')
user = strings.TrimSpace(user)
// 关闭回显模式
writer.Write([]byte{255, 251, 1}) // IAC WILL ECHO
writer.Flush()
// 请求密码
writer.WriteString("Password: ")
writer.Flush()
// 读取密码并清除控制序列
var passBuilder strings.Builder
for {
b, err := reader.ReadByte()
if err != nil {
return
}
if b == '\n' || b == '\r' {
break
}
if b == 255 { // IAC
reader.ReadByte() // 忽略下一个字节
reader.ReadByte() // 忽略下一个字节
} else {
passBuilder.WriteByte(b)
}
}
pass := passBuilder.String()
// 恢复回显模式
writer.Write([]byte{255, 252, 1}) // IAC WONT ECHO
writer.Flush()
if handleAuth(conf.TelnetServer.AuthType, user, pass) {
writer.WriteString("\r\nAuthentication successful!\r\n")
writer.Flush()
HandleCommands(user, conf.TelnetServer.TagNE, reader, writer)
} else {
writer.WriteString("\r\nAuthentication failed!\r\n")
writer.Flush()
}
}
// 处理命令输
func HandleCommands(user, tag string, reader *bufio.Reader, writer *bufio.Writer) {
header := fmt.Sprintf("[%s@%s]> ", user, tag)
clearLine := "\033[2K\r" // ANSI 转义序列,用于清除当前行
for {
var commandBuilder strings.Builder
for {
b, err := reader.ReadByte()
if err != nil {
return
}
if b == '\n' || b == '\r' {
break
}
if b == '\xff' || b == '\xfe' || b == '\x01' {
continue
}
if b == 127 { // 处理退格键
if commandBuilder.Len() > 0 {
// 手动截断字符串
command := commandBuilder.String()
command = command[:len(command)-1]
commandBuilder.Reset()
commandBuilder.WriteString(command)
writer.WriteString("\b \b") // 回显退格
writer.Flush()
}
} else {
// 回显用户输入的字符
writer.WriteByte(b)
writer.Flush()
commandBuilder.WriteByte(b)
}
}
command := strings.TrimSpace(commandBuilder.String())
// 处理其他命令
switch command {
case "hello":
writer.WriteString("\r\nHello, world!\r\n")
case "time":
writer.WriteString(fmt.Sprintf("\r\nCurrent time: %s\r\n", time.Now().Format(time.RFC1123)))
case "exit", "quit":
writer.WriteString("\r\nGoodbye!\r\n")
writer.Flush()
return
case "":
default:
writer.WriteString("\r\nUnknown command\r\n")
writer.Flush()
}
writer.WriteString(clearLine + header)
writer.Flush()
}
}
func handleSSHConnection(conn net.Conn, serverConfig *ssh.ServerConfig) {
// SSH握手
sshConn, chans, reqs, err := ssh.NewServerConn(conn, serverConfig)
@@ -141,13 +370,16 @@ func handleSSHConnection(conn net.Conn, serverConfig *ssh.ServerConfig) {
continue
}
connNum++
if connNum > int(conf.Sshd.MaxConnNum) {
sshMu.Lock()
sshCC++
if sshCC > int(conf.Sshd.MaxConnNum) {
sshMu.Unlock()
log.Error("Maximum number of connections exceeded")
//conn.Write([]byte("Reach max connections"))
conn.Close()
continue
}
sshMu.Unlock()
go handleSSHChannel(conn, sshConn, channel, requests)
}
@@ -196,8 +428,10 @@ func handleSSHChannel(conn net.Conn, sshConn *ssh.ServerConn, channel ssh.Channe
}
func closeConnection(conn net.Conn) {
sshMu.Lock()
conn.Close()
connNum--
sshCC--
sshMu.Unlock()
}
func handleSSHShell(sshConn *ssh.ServerConn, channel ssh.Channel) {
@@ -205,7 +439,7 @@ func handleSSHShell(sshConn *ssh.ServerConn, channel ssh.Channel) {
// 检查通道是否支持终端
omcMmlVar := &mmlp.MmlVar{
Version: "16.1.1",
Version: global.Version,
Output: mmlp.DefaultFormatType,
MmlHome: conf.Sshd.MmlHome,
Limit: 50,
@@ -213,9 +447,10 @@ func handleSSHShell(sshConn *ssh.ServerConn, channel ssh.Channel) {
SessionToken: fmt.Sprintf("%x", sshConn.SessionID()),
HttpUri: conf.OMC.HttpUri,
UserAgent: config.GetDefaultUserAgent(),
TagNE: conf.Sshd.TagNE,
}
term := term.NewTerminal(channel, fmt.Sprintf("[%s@omc]> ", omcMmlVar.User))
term := term.NewTerminal(channel, fmt.Sprintf("[%s@%s]> ", omcMmlVar.User, omcMmlVar.TagNE))
// 启动交互式shell会话
for {
line, err := term.ReadLine()

195
sshsvc/telnet/telnet.go Normal file
View File

@@ -0,0 +1,195 @@
package omctelnet
import (
"bufio"
"fmt"
"io"
"net"
"strings"
"sync"
"time"
"be.ems/lib/dborm"
)
type TelnetHandler struct {
ListenAddr string
ListenPort uint16
UserName string
Password string
AuthType string
MaxConnNum uint8
TagNE string
ListenHost string
connCount int
mu sync.Mutex
}
func (t *TelnetHandler) handleTelnetAuth(authType, userName, password string) bool {
switch authType {
case "local":
if userName == t.UserName && password == t.Password {
return true
}
return false
case "radius", "omc":
exist, err := dborm.XEngDB().Table("OMC_PUB.sysUser").Where("userName=? AND password=md5(?)", userName, password).Exist()
if err != nil {
return false
}
return exist
default:
}
return false
}
func (t *TelnetHandler) StartTelnetServer() {
listener, err := net.Listen("tcp", t.ListenHost)
if err != nil {
fmt.Println("Error starting Telnet server:", err)
return
}
defer listener.Close()
fmt.Println("Telnet server started on", t.ListenHost)
for {
conn, err := listener.Accept()
if err != nil {
fmt.Println("Error accepting Telnet connection:", err)
continue
}
t.mu.Lock()
if t.connCount >= int(t.MaxConnNum) {
t.mu.Unlock()
io.WriteString(conn, "Connection limit reached. Try again later.\r\n")
conn.Close()
continue
}
t.connCount++
t.mu.Unlock()
go t.handleTelnetConnection(conn)
}
}
func (t *TelnetHandler) handleTelnetConnection(conn net.Conn) {
defer func() {
t.mu.Lock()
t.connCount--
t.mu.Unlock()
}()
defer conn.Close()
reader := bufio.NewReader(conn)
writer := bufio.NewWriter(conn)
// 发送欢迎信息
writer.WriteString("Welcome to the Telnet server!\r\n")
writer.Flush()
// 请求用户名
writer.WriteString("Username: ")
writer.Flush()
user, _ := reader.ReadString('\n')
user = strings.TrimSpace(user)
// 关闭回显模式
writer.Write([]byte{255, 251, 1}) // IAC WILL ECHO
writer.Flush()
// 请求密码
writer.WriteString("Password: ")
writer.Flush()
// 读取密码并清除控制序列
var passBuilder strings.Builder
for {
b, err := reader.ReadByte()
if err != nil {
return
}
if b == '\n' || b == '\r' {
break
}
if b == 255 { // IAC
reader.ReadByte() // 忽略下一个字节
reader.ReadByte() // 忽略下一个字节
} else {
passBuilder.WriteByte(b)
}
}
pass := passBuilder.String()
// 恢复回显模式
writer.Write([]byte{255, 252, 1}) // IAC WONT ECHO
writer.Flush()
if t.handleTelnetAuth(t.AuthType, user, pass) {
writer.WriteString("\r\nAuthentication successful!\r\n")
writer.Flush()
t.HandleCommands(user, t.TagNE, reader, writer)
} else {
writer.WriteString("\r\nAuthentication failed!\r\n")
writer.Flush()
}
}
// 处理命令输
func (t *TelnetHandler) HandleCommands(user, tag string, reader *bufio.Reader, writer *bufio.Writer) {
header := fmt.Sprintf("[%s@%s]> ", user, tag)
clearLine := "\033[2K\r" // ANSI 转义序列,用于清除当前行
for {
var commandBuilder strings.Builder
for {
b, err := reader.ReadByte()
if err != nil {
return
}
if b == '\n' || b == '\r' {
break
}
if b == '\xff' || b == '\xfe' || b == '\x01' {
continue
}
if b == 127 { // 处理退格键
if commandBuilder.Len() > 0 {
// 手动截断字符串
command := commandBuilder.String()
command = command[:len(command)-1]
commandBuilder.Reset()
commandBuilder.WriteString(command)
writer.WriteString("\b \b") // 回显退格
writer.Flush()
}
} else {
// 回显用户输入的字符
writer.WriteByte(b)
writer.Flush()
commandBuilder.WriteByte(b)
}
}
command := strings.TrimSpace(commandBuilder.String())
// 处理其他命令
switch command {
case "hello":
writer.WriteString("\r\nHello, world!\r\n")
case "time":
writer.WriteString(fmt.Sprintf("\r\nCurrent time: %s\r\n", time.Now().Format(time.RFC1123)))
case "exit", "quit":
writer.WriteString("\r\nGoodbye!\r\n")
writer.Flush()
return
case "":
default:
writer.WriteString("\r\nUnknown command\r\n")
writer.Flush()
}
writer.WriteString(clearLine + header)
writer.Flush()
}
}