From 9c66018e47ef218936e516a7336dcc2a165ee9b1 Mon Sep 17 00:00:00 2001 From: TsMask <340112800@qq.com> Date: Sat, 31 Aug 2024 12:09:05 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=90=8C=E6=AD=A5ws=E6=A8=A1=E5=9D=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/modules/ws/controller/ws.go | 231 +++++++++++++++------- src/modules/ws/model/net_connect.go | 2 +- src/modules/ws/processor/cdr_connect.go | 40 ++++ src/modules/ws/processor/net_connect.go | 4 +- src/modules/ws/processor/ps_process.go | 2 +- src/modules/ws/processor/shell_command.go | 71 +++++++ src/modules/ws/service/ws.go | 15 +- src/modules/ws/service/ws.impl.go | 205 ++++++++++--------- src/modules/ws/service/ws_receive.go | 20 +- src/modules/ws/service/ws_receive.impl.go | 224 +++++++++++++++++---- src/modules/ws/service/ws_send.go | 2 +- src/modules/ws/service/ws_send.impl.go | 60 +++--- src/modules/ws/ws.go | 9 +- 13 files changed, 630 insertions(+), 255 deletions(-) create mode 100644 src/modules/ws/processor/shell_command.go diff --git a/src/modules/ws/controller/ws.go b/src/modules/ws/controller/ws.go index fa3d042..dc2fa5e 100644 --- a/src/modules/ws/controller/ws.go +++ b/src/modules/ws/controller/ws.go @@ -3,7 +3,6 @@ package controller import ( "encoding/json" "fmt" - "strconv" "strings" "time" @@ -21,14 +20,15 @@ import ( "github.com/gin-gonic/gin" ) -// 实例化控制层 WSController 结构体 +// NewWSController 实例化控制层 WSController 结构体 var NewWSController = &WSController{ wsService: service.NewWSImpl, wsSendService: service.NewWSSendImpl, neHostService: neService.NewNeHostImpl, + neInfoService: neService.NewNeInfoImpl, } -// WebSocket通信 +// WSController WebSocket通信 // // PATH /ws type WSController struct { @@ -38,9 +38,11 @@ type WSController struct { wsSendService service.IWSSend // 网元主机连接服务 neHostService neService.INeHost + // 网元信息服务 + neInfoService neService.INeInfo } -// 通用 +// WS 通用 // // GET /?subGroupIDs=0 func (s *WSController) WS(c *gin.Context) { @@ -72,19 +74,21 @@ func (s *WSController) WS(c *gin.Context) { } defer conn.Close() - wsClient := s.wsService.NewClient(loginUser.UserID, subGroupIDs, conn, nil) + wsClient := s.wsService.ClientCreate(loginUser.UserID, subGroupIDs, conn, nil) + go s.wsService.ClientWriteListen(wsClient) + go s.wsService.ClientReadListen(wsClient, service.ReceiveCommont) // 等待停止信号 for value := range wsClient.StopChan { - s.wsService.CloseClient(wsClient.ID) + s.wsService.ClientClose(wsClient.ID) logger.Infof("ws Stop Client UID %s %s", wsClient.BindUid, value) return } } -// 测试 +// Test 测试 // -// GET /test?clientId=&groupID= +// GET /test?clientId=xxx&groupID=xxx func (s *WSController) Test(c *gin.Context) { language := ctx.AcceptLanguage(c) @@ -116,11 +120,26 @@ func (s *WSController) Test(c *gin.Context) { c.JSON(200, result.OkData(errMsgArr)) } -// SSH终端 +// SSH 终端 // // GET /ssh?hostId=1&cols=80&rows=40 func (s *WSController) SSH(c *gin.Context) { language := ctx.AcceptLanguage(c) + var query struct { + HostId string `form:"hostId" binding:"required"` // 连接主机ID + Cols int `form:"cols"` // 终端单行字符数 + Rows int `form:"rows"` // 终端显示行数 + } + if err := c.ShouldBindQuery(&query); err != nil { + c.JSON(400, result.CodeMsg(400, i18n.TKey(language, "app.common.err400"))) + return + } + if query.Cols < 80 || query.Cols > 400 { + query.Cols = 80 + } + if query.Rows < 40 || query.Rows > 1200 { + query.Rows = 40 + } // 登录用户信息 loginUser, err := ctx.LoginUser(c) @@ -129,14 +148,8 @@ func (s *WSController) SSH(c *gin.Context) { return } - // 连接主机ID - hostId := c.Query("hostId") - if hostId == "" { - c.JSON(400, result.CodeMsg(400, i18n.TKey(language, "app.common.err400"))) - return - } - neHost := s.neHostService.SelectById(hostId) - if neHost.HostID != hostId || neHost.HostType != "ssh" { + neHost := s.neHostService.SelectById(query.HostId) + if neHost.HostID != query.HostId || neHost.HostType != "ssh" { // 没有可访问主机信息数据! c.JSON(200, result.ErrMsg(i18n.TKey(language, "neHost.noData"))) return @@ -159,19 +172,8 @@ func (s *WSController) SSH(c *gin.Context) { } defer client.Close() - // 终端单行字符数 - cols, err := strconv.Atoi(c.Query("cols")) - if err != nil { - cols = 80 - } - // 终端显示行数 - rows, err := strconv.Atoi(c.Query("rows")) - if err != nil { - rows = 40 - } - // 创建SSH客户端会话 - clientSession, err := client.NewClientSession(cols, rows) + clientSession, err := client.NewClientSession(query.Cols, query.Rows) if err != nil { // 连接主机失败,请检查连接参数后重试 c.JSON(200, result.ErrMsg(i18n.TKey(language, "neHost.errByHostInfo"))) @@ -186,18 +188,21 @@ func (s *WSController) SSH(c *gin.Context) { } defer wsConn.Close() - wsClient := s.wsService.NewClient(loginUser.UserID, nil, wsConn, clientSession) + wsClient := s.wsService.ClientCreate(loginUser.UserID, nil, wsConn, clientSession) + go s.wsService.ClientWriteListen(wsClient) + go s.wsService.ClientReadListen(wsClient, service.ReceiveShell) // 实时读取SSH消息直接输出 msTicker := time.NewTicker(100 * time.Millisecond) defer msTicker.Stop() - go func() { - for ms := range msTicker.C { + for { + select { + case ms := <-msTicker.C: outputByte := clientSession.Read() if len(outputByte) > 0 { outputStr := string(outputByte) msgByte, _ := json.Marshal(result.Ok(map[string]any{ - "requestId": fmt.Sprintf("ssh_%s_%d", hostId, ms.UnixMilli()), + "requestId": fmt.Sprintf("ssh_%s_%d", neHost.HostID, ms.UnixMilli()), "data": outputStr, })) wsClient.MsgChan <- msgByte @@ -209,22 +214,34 @@ func (s *WSController) SSH(c *gin.Context) { // return // } } + case <-wsClient.StopChan: // 等待停止信号 + s.wsService.ClientClose(wsClient.ID) + logger.Infof("ws Stop Client UID %s", wsClient.BindUid) + return } - }() - - // 等待停止信号 - for value := range wsClient.StopChan { - s.wsService.CloseClient(wsClient.ID) - logger.Infof("ws Stop Client UID %s %s", wsClient.BindUid, value) - return } } -// Telnet终端 +// Telnet 终端 // // GET /telnet?hostId=1 func (s *WSController) Telnet(c *gin.Context) { language := ctx.AcceptLanguage(c) + var query struct { + HostId string `form:"hostId" binding:"required"` // 连接主机ID + Cols int `form:"cols"` // 终端单行字符数 + Rows int `form:"rows"` // 终端显示行数 + } + if err := c.ShouldBindQuery(&query); err != nil { + c.JSON(400, result.CodeMsg(400, i18n.TKey(language, "app.common.err400"))) + return + } + if query.Cols < 120 || query.Cols > 400 { + query.Cols = 120 + } + if query.Rows < 128 || query.Rows > 1200 { + query.Rows = 128 + } // 登录用户信息 loginUser, err := ctx.LoginUser(c) @@ -233,14 +250,8 @@ func (s *WSController) Telnet(c *gin.Context) { return } - // 连接主机ID - hostId := c.Query("hostId") - if hostId == "" { - c.JSON(400, result.CodeMsg(400, i18n.TKey(language, "app.common.err400"))) - return - } - neHost := s.neHostService.SelectById(hostId) - if neHost.HostID != hostId || neHost.HostType != "telnet" { + neHost := s.neHostService.SelectById(query.HostId) + if neHost.HostID != query.HostId || neHost.HostType != "telnet" { // 没有可访问主机信息数据! c.JSON(200, result.ErrMsg(i18n.TKey(language, "neHost.noData"))) return @@ -256,20 +267,8 @@ func (s *WSController) Telnet(c *gin.Context) { return } defer client.Close() - - // 终端单行字符数 - cols, err := strconv.Atoi(c.DefaultQuery("cols", "120")) - if err != nil { - cols = 120 - } - // 终端显示行数 - rows, err := strconv.Atoi(c.DefaultQuery("rows", "128")) - if err != nil { - rows = 128 - } - // 创建Telnet客户端会话 - clientSession, err := client.NewClientSession(cols, rows) + clientSession, err := client.NewClientSession(query.Cols, query.Rows) if err != nil { // 连接主机失败,请检查连接参数后重试 c.JSON(200, result.ErrMsg(i18n.TKey(language, "neHost.errByHostInfo"))) @@ -284,18 +283,25 @@ func (s *WSController) Telnet(c *gin.Context) { } defer wsConn.Close() - wsClient := s.wsService.NewClient(loginUser.UserID, nil, wsConn, clientSession) + wsClient := s.wsService.ClientCreate(loginUser.UserID, nil, wsConn, clientSession) + go s.wsService.ClientWriteListen(wsClient) + go s.wsService.ClientReadListen(wsClient, service.ReceiveTelnet) + + // 等待1秒,排空首次消息 + time.Sleep(1 * time.Second) + _ = clientSession.Read() // 实时读取Telnet消息直接输出 msTicker := time.NewTicker(100 * time.Millisecond) defer msTicker.Stop() - go func() { - for ms := range msTicker.C { + for { + select { + case ms := <-msTicker.C: outputByte := clientSession.Read() if len(outputByte) > 0 { outputStr := strings.TrimRight(string(outputByte), "\x00") msgByte, _ := json.Marshal(result.Ok(map[string]any{ - "requestId": fmt.Sprintf("telnet_%s_%d", hostId, ms.UnixMilli()), + "requestId": fmt.Sprintf("telnet_%s_%d", neHost.HostID, ms.UnixMilli()), "data": outputStr, })) wsClient.MsgChan <- msgByte @@ -307,13 +313,92 @@ func (s *WSController) Telnet(c *gin.Context) { // return // } } + case <-wsClient.StopChan: // 等待停止信号 + s.wsService.ClientClose(wsClient.ID) + logger.Infof("ws Stop Client UID %s", wsClient.BindUid) + return + } + } +} + +// ShellView 终端交互式文件内容查看 +// +// GET /view +func (s *WSController) ShellView(c *gin.Context) { + language := ctx.AcceptLanguage(c) + var query struct { + NeType string `form:"neType" binding:"required"` + NeId string `form:"neId" binding:"required"` + Cols int `form:"cols"` // 终端单行字符数 + Rows int `form:"rows"` // 终端显示行数 + } + if err := c.ShouldBindQuery(&query); err != nil { + c.JSON(400, result.CodeMsg(400, i18n.TKey(language, "app.common.err400"))) + return + } + if query.Cols < 120 || query.Cols > 400 { + query.Cols = 120 + } + if query.Rows < 40 || query.Rows > 1200 { + query.Rows = 40 + } + + // 登录用户信息 + loginUser, err := ctx.LoginUser(c) + if err != nil { + c.JSON(401, result.CodeMsg(401, i18n.TKey(language, err.Error()))) + return + } + + // 网元主机的SSH客户端 + sshClient, err := s.neInfoService.NeRunSSHClient(query.NeType, query.NeId) + if err != nil { + c.JSON(200, result.ErrMsg(err.Error())) + return + } + defer sshClient.Close() + // ssh连接会话 + clientSession, err := sshClient.NewClientSession(query.Cols, query.Rows) + if err != nil { + c.JSON(200, result.ErrMsg("neinfo ssh client session new err")) + return + } + defer clientSession.Close() + + // 将 HTTP 连接升级为 WebSocket 连接 + wsConn := s.wsService.UpgraderWs(c.Writer, c.Request) + if wsConn == nil { + return + } + defer wsConn.Close() + + wsClient := s.wsService.ClientCreate(loginUser.UserID, nil, wsConn, clientSession) + go s.wsService.ClientWriteListen(wsClient) + go s.wsService.ClientReadListen(wsClient, service.ReceiveShellView) + + // 等待1秒,排空首次消息 + time.Sleep(1 * time.Second) + _ = clientSession.Read() + + // 实时读取SSH消息直接输出 + msTicker := time.NewTicker(100 * time.Millisecond) + defer msTicker.Stop() + for { + select { + case ms := <-msTicker.C: + outputByte := clientSession.Read() + if len(outputByte) > 0 { + outputStr := string(outputByte) + msgByte, _ := json.Marshal(result.Ok(map[string]any{ + "requestId": fmt.Sprintf("view_%d", ms.UnixMilli()), + "data": outputStr, + })) + wsClient.MsgChan <- msgByte + } + case <-wsClient.StopChan: // 等待停止信号 + s.wsService.ClientClose(wsClient.ID) + logger.Infof("ws Stop Client UID %s", wsClient.BindUid) + return } - }() - - // 等待停止信号 - for value := range wsClient.StopChan { - s.wsService.CloseClient(wsClient.ID) - logger.Infof("ws Stop Client UID %s %s", wsClient.BindUid, value) - return } } diff --git a/src/modules/ws/model/net_connect.go b/src/modules/ws/model/net_connect.go index 8116c28..945ede5 100644 --- a/src/modules/ws/model/net_connect.go +++ b/src/modules/ws/model/net_connect.go @@ -1,6 +1,6 @@ package model -import "github.com/shirou/gopsutil/v3/net" +import "github.com/shirou/gopsutil/v4/net" // NetConnectData 网络连接进程数据 type NetConnectData struct { diff --git a/src/modules/ws/processor/cdr_connect.go b/src/modules/ws/processor/cdr_connect.go index f5a1eab..e96c03f 100644 --- a/src/modules/ws/processor/cdr_connect.go +++ b/src/modules/ws/processor/cdr_connect.go @@ -8,6 +8,7 @@ import ( "nms_cxy/src/framework/vo/result" neDataModel "nms_cxy/src/modules/network_data/model" neDataService "nms_cxy/src/modules/network_data/service" + neInfoService "nms_cxy/src/modules/network_element/service" ) // GetCDRConnectByIMS 获取CDR会话事件-IMS @@ -20,6 +21,13 @@ func GetCDRConnectByIMS(requestID string, data any) ([]byte, error) { return nil, fmt.Errorf("query data structure error") } + // 查询网元信息 rmUID + neInfo := neInfoService.NewNeInfoImpl.SelectNeInfoByNeTypeAndNeID(query.NeType, query.NeID) + if neInfo.NeId != query.NeID || neInfo.IP == "" { + return nil, fmt.Errorf("query neinfo not found") + } + query.RmUID = neInfo.RmUID + dataMap := neDataService.NewCDREventIMSImpl.SelectPage(query) resultByte, err := json.Marshal(result.Ok(map[string]any{ "requestId": requestID, @@ -38,6 +46,13 @@ func GetCDRConnectBySMF(requestID string, data any) ([]byte, error) { return nil, fmt.Errorf("query data structure error") } + // 查询网元信息 rmUID + neInfo := neInfoService.NewNeInfoImpl.SelectNeInfoByNeTypeAndNeID(query.NeType, query.NeID) + if neInfo.NeId != query.NeID || neInfo.IP == "" { + return nil, fmt.Errorf("query neinfo not found") + } + query.RmUID = neInfo.RmUID + dataMap := neDataService.NewCDREventSMFImpl.SelectPage(query) resultByte, err := json.Marshal(result.Ok(map[string]any{ "requestId": requestID, @@ -45,3 +60,28 @@ func GetCDRConnectBySMF(requestID string, data any) ([]byte, error) { })) return resultByte, err } + +// GetCDRConnectBySMSC 获取CDR会话事件-SMSC +func GetCDRConnectBySMSC(requestID string, data any) ([]byte, error) { + msgByte, _ := json.Marshal(data) + var query neDataModel.CDREventSMSCQuery + err := json.Unmarshal(msgByte, &query) + if err != nil { + logger.Warnf("ws processor GetCDRConnect err: %s", err.Error()) + return nil, fmt.Errorf("query data structure error") + } + + // 查询网元信息 rmUID + neInfo := neInfoService.NewNeInfoImpl.SelectNeInfoByNeTypeAndNeID(query.NeType, query.NeID) + if neInfo.NeId != query.NeID || neInfo.IP == "" { + return nil, fmt.Errorf("query neinfo not found") + } + query.RmUID = neInfo.RmUID + + dataMap := neDataService.NewCDREventSMSCImpl.SelectPage(query) + resultByte, err := json.Marshal(result.Ok(map[string]any{ + "requestId": requestID, + "data": dataMap, + })) + return resultByte, err +} diff --git a/src/modules/ws/processor/net_connect.go b/src/modules/ws/processor/net_connect.go index 383619a..2a11a04 100644 --- a/src/modules/ws/processor/net_connect.go +++ b/src/modules/ws/processor/net_connect.go @@ -9,8 +9,8 @@ import ( "nms_cxy/src/framework/vo/result" "nms_cxy/src/modules/ws/model" - "github.com/shirou/gopsutil/v3/net" - "github.com/shirou/gopsutil/v3/process" + "github.com/shirou/gopsutil/v4/net" + "github.com/shirou/gopsutil/v4/process" ) // GetNetConnections 获取网络连接进程 diff --git a/src/modules/ws/processor/ps_process.go b/src/modules/ws/processor/ps_process.go index 26c5edc..335b19a 100644 --- a/src/modules/ws/processor/ps_process.go +++ b/src/modules/ws/processor/ps_process.go @@ -13,7 +13,7 @@ import ( "nms_cxy/src/framework/vo/result" "nms_cxy/src/modules/ws/model" - "github.com/shirou/gopsutil/v3/process" + "github.com/shirou/gopsutil/v4/process" ) // GetProcessData 获取进程数据 diff --git a/src/modules/ws/processor/shell_command.go b/src/modules/ws/processor/shell_command.go new file mode 100644 index 0000000..d356bba --- /dev/null +++ b/src/modules/ws/processor/shell_command.go @@ -0,0 +1,71 @@ +package processor + +import ( + "encoding/json" + "fmt" + "strings" + + "nms_cxy/src/framework/logger" +) + +// ParseCat 解析拼装cat命令 +func ParseCat(reqData any) (string, error) { + msgByte, _ := json.Marshal(reqData) + var data struct { + FilePath string `json:"filePath"` // 文件地址 + ShowNumber bool `json:"showNumber"` // 显示文件的行号,从 1 开始 + ShowAll bool `json:"showAll"` // 结合 -vET 参数,显示所有特殊字符,包括行尾符、制表符等 + } + if err := json.Unmarshal(msgByte, &data); err != nil { + logger.Warnf("ws processor ParseCat err: %s", err.Error()) + return "", fmt.Errorf("query data structure error") + } + if data.FilePath == "" { + return "", fmt.Errorf("query data filePath empty") + } + + command := []string{"cat"} + if data.ShowNumber { + command = append(command, "-n") + } + if data.ShowAll { + command = append(command, "-A") + } + + command = append(command, data.FilePath) + command = append(command, "\n") + return strings.Join(command, " "), nil +} + +// ParseTail 解析拼装tail命令 +func ParseTail(reqData any) (string, error) { + msgByte, _ := json.Marshal(reqData) + var data struct { + FilePath string `json:"filePath"` // 文件地址 + Lines int `json:"lines"` // 显示文件末尾的指定行数 + Char int `json:"char"` // 显示文件末尾的指定字数 + Follow bool `json:"follow"` // 输出文件末尾的内容,并继续监视文件的新增内容 + } + if err := json.Unmarshal(msgByte, &data); err != nil { + logger.Warnf("ws processor ParseTail err: %s", err.Error()) + return "", fmt.Errorf("query data structure error") + } + if data.FilePath == "" { + return "", fmt.Errorf("query data filePath empty") + } + + command := []string{"tail"} + if data.Follow { + command = append(command, "-f") + } + if data.Lines > 0 { + command = append(command, fmt.Sprintf("-n %d", data.Lines)) + } + if data.Char > 0 { + command = append(command, fmt.Sprintf("-c %d", data.Char)) + } + + command = append(command, data.FilePath) + command = append(command, "\n") + return strings.Join(command, " "), nil +} diff --git a/src/modules/ws/service/ws.go b/src/modules/ws/service/ws.go index f88ea92..5e34eb0 100644 --- a/src/modules/ws/service/ws.go +++ b/src/modules/ws/service/ws.go @@ -13,14 +13,21 @@ type IWS interface { // UpgraderWs http升级ws请求 UpgraderWs(w http.ResponseWriter, r *http.Request) *websocket.Conn - // NewClient 新建客户端 + // ClientCreate 客户端新建 // // uid 登录用户ID // groupIDs 用户订阅组 // conn ws连接实例 // childConn 子连接实例 - NewClient(uid string, groupIDs []string, conn *websocket.Conn, childConn any) *model.WSClient + ClientCreate(uid string, groupIDs []string, conn *websocket.Conn, childConn any) *model.WSClient - // CloseClient 关闭客户端 - CloseClient(clientID string) + // ClientClose 客户端关闭 + ClientClose(clientID string) + + // ClientReadListen 客户端读取消息监听 + // receiveType 根据接收类型进行消息处理 + ClientReadListen(wsClient *model.WSClient, receiveType int) + + // ClientWriteListen 客户端写入消息监听 + ClientWriteListen(wsClient *model.WSClient) } diff --git a/src/modules/ws/service/ws.impl.go b/src/modules/ws/service/ws.impl.go index eecc24f..6d20b14 100644 --- a/src/modules/ws/service/ws.impl.go +++ b/src/modules/ws/service/ws.impl.go @@ -15,15 +15,12 @@ import ( ) var ( - // ws客户端 [clientId: client] - WsClients sync.Map - // ws用户对应的多个客户端id [uid:clientIds] - WsUsers sync.Map - // ws组对应的多个用户id [groupID:uids] - WsGroup sync.Map + wsClients sync.Map // ws客户端 [clientId: client] + wsUsers sync.Map // ws用户对应的多个客户端id [uid:clientIds] + wsGroup sync.Map // ws组对应的多个客户端id [groupId:clientIds] ) -// 实例化服务层 WSImpl 结构体 +// NewWSImpl 实例化服务层 WSImpl 结构体 var NewWSImpl = &WSImpl{} // WSImpl WebSocket通信 服务层处理 @@ -51,13 +48,13 @@ func (s *WSImpl) UpgraderWs(w http.ResponseWriter, r *http.Request) *websocket.C return conn } -// NewClient 新建客户端 +// ClientCreate 客户端新建 // // uid 登录用户ID // groupIDs 用户订阅组 // conn ws连接实例 // childConn 子连接实例 -func (s *WSImpl) NewClient(uid string, groupIDs []string, conn *websocket.Conn, childConn any) *model.WSClient { +func (s *WSImpl) ClientCreate(uid string, groupIDs []string, conn *websocket.Conn, childConn any) *model.WSClient { // clientID也可以用其他方式生成,只要能保证在所有服务端中都能保证唯一即可 clientID := generate.Code(16) @@ -73,122 +70,52 @@ func (s *WSImpl) NewClient(uid string, groupIDs []string, conn *websocket.Conn, } // 存入客户端 - WsClients.Store(clientID, wsClient) + wsClients.Store(clientID, wsClient) // 存入用户持有客户端 if uid != "" { - if v, ok := WsUsers.Load(uid); ok { + if v, ok := wsUsers.Load(uid); ok { uidClientIds := v.(*[]string) *uidClientIds = append(*uidClientIds, clientID) } else { - WsUsers.Store(uid, &[]string{clientID}) + wsUsers.Store(uid, &[]string{clientID}) } } // 存入用户订阅组 if uid != "" && len(groupIDs) > 0 { for _, groupID := range groupIDs { - if v, ok := WsGroup.Load(groupID); ok { - groupUIDs := v.(*[]string) - // 避免同组内相同用户 - hasUid := false - for _, uidv := range *groupUIDs { - if uidv == uid { - hasUid = true - break - } - } - if !hasUid { - *groupUIDs = append(*groupUIDs, uid) - } + if v, ok := wsGroup.Load(groupID); ok { + groupClientIds := v.(*[]string) + *groupClientIds = append(*groupClientIds, clientID) } else { - WsGroup.Store(groupID, &[]string{uid}) + wsGroup.Store(groupID, &[]string{clientID}) } } } - go s.clientRead(wsClient) - go s.clientWrite(wsClient) - - // 发客户端id确认是否连接 - msgByte, _ := json.Marshal(result.OkData(map[string]string{ - "clientId": clientID, - })) - wsClient.MsgChan <- msgByte - return wsClient } -// clientRead 客户端读取消息 -func (s *WSImpl) clientRead(wsClient *model.WSClient) { - defer func() { - if err := recover(); err != nil { - logger.Errorf("ws ReadMessage Panic Error: %v", err) - } - }() - for { - // 读取消息 - messageType, msg, err := wsClient.Conn.ReadMessage() - if err != nil { - logger.Warnf("ws ReadMessage UID %s err: %s", wsClient.BindUid, err.Error()) - s.CloseClient(wsClient.ID) - return - } - // fmt.Println(messageType, string(msg)) - - // 文本和二进制类型,只处理文本json - if messageType == websocket.TextMessage { - var reqMsg model.WSRequest - err := json.Unmarshal(msg, &reqMsg) - if err != nil { - msgByte, _ := json.Marshal(result.ErrMsg("message format not supported")) - wsClient.MsgChan <- msgByte - } else { - // 协程异步处理 - go NewWSReceiveImpl.AsyncReceive(wsClient, reqMsg) - } - } - } -} - -// clientWrite 客户端写入消息 -func (s *WSImpl) clientWrite(wsClient *model.WSClient) { - defer func() { - if err := recover(); err != nil { - logger.Errorf("ws WriteMessage Panic Error: %v", err) - } - }() - for msg := range wsClient.MsgChan { - // 发送消息 - err := wsClient.Conn.WriteMessage(websocket.TextMessage, msg) - if err != nil { - logger.Warnf("ws WriteMessage UID %s err: %s", wsClient.BindUid, err.Error()) - s.CloseClient(wsClient.ID) - return - } - wsClient.LastHeartbeat = time.Now().UnixMilli() - } -} - -// CloseClient 客户端关闭 -func (s *WSImpl) CloseClient(clientID string) { - v, ok := WsClients.Load(clientID) +// ClientClose 客户端关闭 +func (s *WSImpl) ClientClose(clientID string) { + v, ok := wsClients.Load(clientID) if !ok { return } client := v.(*model.WSClient) defer func() { - client.Conn.WriteMessage(websocket.CloseMessage, []byte{}) - client.Conn.Close() - WsClients.Delete(clientID) + client.MsgChan <- []byte("ws:close") client.StopChan <- struct{}{} + client.Conn.Close() + wsClients.Delete(clientID) }() // 客户端断线时自动踢出Uid绑定列表 if client.BindUid != "" { - if clientIds, ok := WsUsers.Load(client.BindUid); ok { - uidClientIds := clientIds.(*[]string) + if v, ok := wsUsers.Load(client.BindUid); ok { + uidClientIds := v.(*[]string) if len(*uidClientIds) > 0 { tempClientIds := make([]string, 0, len(*uidClientIds)) for _, v := range *uidClientIds { @@ -202,23 +129,93 @@ func (s *WSImpl) CloseClient(clientID string) { } // 客户端断线时自动踢出已加入的组 - if client.BindUid != "" && len(client.SubGroup) > 0 { + if len(client.SubGroup) > 0 { for _, groupID := range client.SubGroup { - uids, ok := WsGroup.Load(groupID) + v, ok := wsGroup.Load(groupID) if !ok { continue } - - groupUIDs := uids.(*[]string) - if len(*groupUIDs) > 0 { - tempUIDs := make([]string, 0, len(*groupUIDs)) - for _, v := range *groupUIDs { - if v != client.BindUid { - tempUIDs = append(tempUIDs, v) + groupClientIds := v.(*[]string) + if len(*groupClientIds) > 0 { + tempClientIds := make([]string, 0, len(*groupClientIds)) + for _, v := range *groupClientIds { + if v != client.ID { + tempClientIds = append(tempClientIds, v) } } - *groupUIDs = tempUIDs + *groupClientIds = tempClientIds } } } } + +// ClientReadListen 客户端读取消息监听 +// receiveType 根据接收类型进行消息处理 +func (s *WSImpl) ClientReadListen(wsClient *model.WSClient, receiveType int) { + defer func() { + if err := recover(); err != nil { + logger.Errorf("ws ReadMessage Panic Error: %v", err) + } + }() + for { + // 读取消息 + messageType, msg, err := wsClient.Conn.ReadMessage() + if err != nil { + logger.Warnf("ws ReadMessage UID %s err: %s", wsClient.BindUid, err.Error()) + s.ClientClose(wsClient.ID) + return + } + // fmt.Println(messageType, string(msg)) + + // 文本 只处理文本json + if messageType == websocket.TextMessage { + var reqMsg model.WSRequest + if err := json.Unmarshal(msg, &reqMsg); err != nil { + msgByte, _ := json.Marshal(result.ErrMsg("message format json error")) + wsClient.MsgChan <- msgByte + continue + } + // 接收器处理 + switch receiveType { + case ReceiveCommont: + go NewWSReceiveImpl.Commont(wsClient, reqMsg) + case ReceiveShell: + go NewWSReceiveImpl.Shell(wsClient, reqMsg) + case ReceiveShellView: + go NewWSReceiveImpl.ShellView(wsClient, reqMsg) + case ReceiveTelnet: + go NewWSReceiveImpl.Telnet(wsClient, reqMsg) + } + } + } +} + +// ClientWriteListen 客户端写入消息监听 +func (s *WSImpl) ClientWriteListen(wsClient *model.WSClient) { + defer func() { + if err := recover(); err != nil { + logger.Errorf("ws WriteMessage Panic Error: %v", err) + } + }() + // 发客户端id确认是否连接 + msgByte, _ := json.Marshal(result.OkData(map[string]string{ + "clientId": wsClient.ID, + })) + wsClient.MsgChan <- msgByte + // 消息发送监听 + for msg := range wsClient.MsgChan { + // 关闭句柄 + if string(msg) == "ws:close" { + wsClient.Conn.WriteMessage(websocket.CloseMessage, []byte{}) + return + } + // 发送消息 + err := wsClient.Conn.WriteMessage(websocket.TextMessage, msg) + if err != nil { + logger.Warnf("ws WriteMessage UID %s err: %s", wsClient.BindUid, err.Error()) + s.ClientClose(wsClient.ID) + return + } + wsClient.LastHeartbeat = time.Now().UnixMilli() + } +} diff --git a/src/modules/ws/service/ws_receive.go b/src/modules/ws/service/ws_receive.go index 4f11d6a..1f5d242 100644 --- a/src/modules/ws/service/ws_receive.go +++ b/src/modules/ws/service/ws_receive.go @@ -2,8 +2,24 @@ package service import "nms_cxy/src/modules/ws/model" +const ( + ReceiveCommont = iota // Commont 接收通用业务处理 + ReceiveShell // Shell 接收终端交互业务处理 + ReceiveShellView // ShellView 接收查看文件终端交互业务处理 + ReceiveTelnet // Telnet 接收终端交互业务处理 +) + // IWSReceive WebSocket消息接收处理 服务层接口 type IWSReceive interface { - // AsyncReceive 接收业务异步处理 - AsyncReceive(client *model.WSClient, reqMsg model.WSRequest) + // Commont 接收通用业务处理 + Commont(client *model.WSClient, reqMsg model.WSRequest) + + // Shell 接收终端交互业务处理 + Shell(client *model.WSClient, reqMsg model.WSRequest) + + // ShellView 接收查看文件终端交互业务处理 + ShellView(client *model.WSClient, reqMsg model.WSRequest) + + // Telnet 接收终端交互业务处理 + Telnet(client *model.WSClient, reqMsg model.WSRequest) } diff --git a/src/modules/ws/service/ws_receive.impl.go b/src/modules/ws/service/ws_receive.impl.go index 3f6dfb0..d036268 100644 --- a/src/modules/ws/service/ws_receive.impl.go +++ b/src/modules/ws/service/ws_receive.impl.go @@ -20,12 +20,22 @@ var NewWSReceiveImpl = &WSReceiveImpl{} // WSReceiveImpl WebSocket消息接收处理 服务层处理 type WSReceiveImpl struct{} -// AsyncReceive 接收业务异步处理 -func (s *WSReceiveImpl) AsyncReceive(client *model.WSClient, reqMsg model.WSRequest) { +// Commont 接收通用业务处理 +func (s *WSReceiveImpl) close(client *model.WSClient) { + // 主动关闭 + resultByte, _ := json.Marshal(result.OkMsg("user initiated closure")) + client.MsgChan <- resultByte + // 等待1s后关闭连接 + time.Sleep(1 * time.Second) + NewWSImpl.ClientClose(client.ID) +} + +// Commont 接收通用业务处理 +func (s *WSReceiveImpl) Commont(client *model.WSClient, reqMsg model.WSRequest) { // 必传requestId确认消息 if reqMsg.RequestID == "" { msg := "message requestId is required" - logger.Infof("ws AsyncReceive UID %s err: %s", client.BindUid, msg) + logger.Infof("ws Commont UID %s err: %s", client.BindUid, msg) msgByte, _ := json.Marshal(result.ErrMsg(msg)) client.MsgChan <- msgByte return @@ -36,14 +46,61 @@ func (s *WSReceiveImpl) AsyncReceive(client *model.WSClient, reqMsg model.WSRequ switch reqMsg.Type { case "close": - // 主动关闭 - resultByte, _ := json.Marshal(result.OkMsg("user initiated closure")) - client.MsgChan <- resultByte - // 等待1s后关闭连接 - time.Sleep(1 * time.Second) - client.StopChan <- struct{}{} + s.close(client) + return + case "ps": + resByte, err = processor.GetProcessData(reqMsg.RequestID, reqMsg.Data) + case "net": + resByte, err = processor.GetNetConnections(reqMsg.RequestID, reqMsg.Data) + case "ims_cdr": + resByte, err = processor.GetCDRConnectByIMS(reqMsg.RequestID, reqMsg.Data) + case "smf_cdr": + resByte, err = processor.GetCDRConnectBySMF(reqMsg.RequestID, reqMsg.Data) + case "smsc_cdr": + resByte, err = processor.GetCDRConnectBySMSC(reqMsg.RequestID, reqMsg.Data) + case "amf_ue": + resByte, err = processor.GetUEConnectByAMF(reqMsg.RequestID, reqMsg.Data) + case "mme_ue": + resByte, err = processor.GetUEConnectByMME(reqMsg.RequestID, reqMsg.Data) + case "upf_tf": + resByte, err = processor.GetUPFTotalFlow(reqMsg.RequestID, reqMsg.Data) + case "ne_state": + resByte, err = processor.GetNeState(reqMsg.RequestID, reqMsg.Data) + default: + err = fmt.Errorf("message type %s not supported", reqMsg.Type) + } + + if err != nil { + logger.Warnf("ws Commont UID %s err: %s", client.BindUid, err.Error()) + msgByte, _ := json.Marshal(result.ErrMsg(err.Error())) + client.MsgChan <- msgByte + return + } + if len(resByte) > 0 { + client.MsgChan <- resByte + } +} + +// Shell 接收终端交互业务处理 +func (s *WSReceiveImpl) Shell(client *model.WSClient, reqMsg model.WSRequest) { + // 必传requestId确认消息 + if reqMsg.RequestID == "" { + msg := "message requestId is required" + logger.Infof("ws Shell UID %s err: %s", client.BindUid, msg) + msgByte, _ := json.Marshal(result.ErrMsg(msg)) + client.MsgChan <- msgByte + return + } + + var resByte []byte + var err error + + switch reqMsg.Type { + case "close": + s.close(client) + return case "ssh": - // SSH会话消息接收直接写入会话 + // SSH会话消息接收写入会话 command := reqMsg.Data.(string) sshClientSession := client.ChildConn.(*ssh.SSHClientSession) _, err = sshClientSession.Write(command) @@ -59,33 +116,134 @@ func (s *WSReceiveImpl) AsyncReceive(client *model.WSClient, reqMsg model.WSRequ sshClientSession := client.ChildConn.(*ssh.SSHClientSession) err = sshClientSession.Session.WindowChange(data.Rows, data.Cols) } - case "telnet": - // Telnet会话消息接收直接写入会话 - command := reqMsg.Data.(string) - telnetClientSession := client.ChildConn.(*telnet.TelnetClientSession) - _, err = telnetClientSession.Write(command) - case "ps": - resByte, err = processor.GetProcessData(reqMsg.RequestID, reqMsg.Data) - case "net": - resByte, err = processor.GetNetConnections(reqMsg.RequestID, reqMsg.Data) - case "ims_cdr": - resByte, err = processor.GetCDRConnectByIMS(reqMsg.RequestID, reqMsg.Data) - case "smf_cdr": - resByte, err = processor.GetCDRConnectBySMF(reqMsg.RequestID, reqMsg.Data) - case "amf_ue": - resByte, err = processor.GetUEConnectByAMF(reqMsg.RequestID, reqMsg.Data) - case "mme_ue": - resByte, err = processor.GetUEConnectByMME(reqMsg.RequestID, reqMsg.Data) - case "upf_tf": - resByte, err = processor.GetUPFTotalFlow(reqMsg.RequestID, reqMsg.Data) - case "ne_state": - resByte, err = processor.GetNeState(reqMsg.RequestID, reqMsg.Data) default: - err = fmt.Errorf("message type not supported") + err = fmt.Errorf("message type %s not supported", reqMsg.Type) } if err != nil { - logger.Warnf("ws AsyncReceive UID %s err: %s", client.BindUid, err.Error()) + logger.Warnf("ws Shell UID %s err: %s", client.BindUid, err.Error()) + msgByte, _ := json.Marshal(result.ErrMsg(err.Error())) + client.MsgChan <- msgByte + if err == io.EOF { + // 等待1s后关闭连接 + time.Sleep(1 * time.Second) + client.StopChan <- struct{}{} + } + return + } + if len(resByte) > 0 { + client.MsgChan <- resByte + } +} + +// ShellView 接收查看文件终端交互业务处理 +func (s *WSReceiveImpl) ShellView(client *model.WSClient, reqMsg model.WSRequest) { + // 必传requestId确认消息 + if reqMsg.RequestID == "" { + msg := "message requestId is required" + logger.Infof("ws ShellView UID %s err: %s", client.BindUid, msg) + msgByte, _ := json.Marshal(result.ErrMsg(msg)) + client.MsgChan <- msgByte + return + } + + var resByte []byte + var err error + + switch reqMsg.Type { + case "close": + s.close(client) + return + case "cat", "tail": + var command string + if reqMsg.Type == "cat" { + command, err = processor.ParseCat(reqMsg.Data) + } + if reqMsg.Type == "tail" { + command, err = processor.ParseTail(reqMsg.Data) + } + if command != "" && err == nil { + sshClientSession := client.ChildConn.(*ssh.SSHClientSession) + _, err = sshClientSession.Write(command) + } + case "ctrl-c": + // 模拟按下 Ctrl+C + sshClientSession := client.ChildConn.(*ssh.SSHClientSession) + _, err = sshClientSession.Write("\u0003\n") + case "resize": + // 会话窗口重置 + msgByte, _ := json.Marshal(reqMsg.Data) + var data struct { + Cols int `json:"cols"` + Rows int `json:"rows"` + } + err = json.Unmarshal(msgByte, &data) + if err == nil { + sshClientSession := client.ChildConn.(*ssh.SSHClientSession) + err = sshClientSession.Session.WindowChange(data.Rows, data.Cols) + } + default: + err = fmt.Errorf("message type %s not supported", reqMsg.Type) + } + + if err != nil { + logger.Warnf("ws ShellView UID %s err: %s", client.BindUid, err.Error()) + msgByte, _ := json.Marshal(result.ErrMsg(err.Error())) + client.MsgChan <- msgByte + if err == io.EOF { + // 等待1s后关闭连接 + time.Sleep(1 * time.Second) + client.StopChan <- struct{}{} + } + return + } + if len(resByte) > 0 { + client.MsgChan <- resByte + } +} + +// Telnet 接收终端交互业务处理 +func (s *WSReceiveImpl) Telnet(client *model.WSClient, reqMsg model.WSRequest) { + // 必传requestId确认消息 + if reqMsg.RequestID == "" { + msg := "message requestId is required" + logger.Infof("ws Shell UID %s err: %s", client.BindUid, msg) + msgByte, _ := json.Marshal(result.ErrMsg(msg)) + client.MsgChan <- msgByte + return + } + + var resByte []byte + var err error + + switch reqMsg.Type { + case "close": + s.close(client) + return + case "telnet": + // Telnet会话消息接收写入会话 + command := reqMsg.Data.(string) + telnetClientSession := client.ChildConn.(*telnet.TelnetClientSession) + _, err = telnetClientSession.Write(command) + case "telnet_resize": + // Telnet会话窗口重置 + msgByte, _ := json.Marshal(reqMsg.Data) + var data struct { + Cols int `json:"cols"` + Rows int `json:"rows"` + } + err = json.Unmarshal(msgByte, &data) + if err == nil { + telnetClientSession := client.ChildConn.(*telnet.TelnetClientSession) + err = telnetClientSession.WindowChange(data.Rows, data.Cols) + _ = telnetClientSession.Read() + } + default: + err = fmt.Errorf("message type %s not supported", reqMsg.Type) + } + + if err != nil { + logger.Warnf("ws Shell UID %s err: %s", client.BindUid, err.Error()) msgByte, _ := json.Marshal(result.ErrMsg(err.Error())) client.MsgChan <- msgByte if err == io.EOF { diff --git a/src/modules/ws/service/ws_send.go b/src/modules/ws/service/ws_send.go index 020d022..91e48b1 100644 --- a/src/modules/ws/service/ws_send.go +++ b/src/modules/ws/service/ws_send.go @@ -5,6 +5,6 @@ type IWSSend interface { // ByClientID 给已知客户端发消息 ByClientID(clientID string, data any) error - // ByGroupID 给订阅组的用户发送消息 + // ByGroupID 给订阅组的客户端发送消息 ByGroupID(gid string, data any) error } diff --git a/src/modules/ws/service/ws_send.impl.go b/src/modules/ws/service/ws_send.impl.go index b4f7629..9be1243 100644 --- a/src/modules/ws/service/ws_send.impl.go +++ b/src/modules/ws/service/ws_send.impl.go @@ -12,18 +12,22 @@ import ( const ( // 组号-其他 GROUP_OTHER = "0" - // 组号-指标 - GROUP_KPI = "10" - // 组号-指标UPF - GROUP_KPI_UPF = "12" - // 组号-IMS_CDR会话事件 - GROUP_IMS_CDR = "1005" - // 组号-SMF_CDR会话事件 - GROUP_SMF_CDR = "1006" + // 组号-指标通用 10_neType_neId + GROUP_KPI = "10_" + // 组号-指标UPF 12_neId + GROUP_KPI_UPF = "12_" + // 组号-自定义KPI指标20_neType_neId + GROUP_KPI_C = "20_" + // 组号-IMS_CDR会话事件 1005_neId + GROUP_IMS_CDR = "1005_" + // 组号-SMF_CDR会话事件 1006_neId + GROUP_SMF_CDR = "1006_" + // 组号-SMSC_CDR会话事件 1007_neId + GROUP_SMSC_CDR = "1007_" // 组号-AMF_UE会话事件 GROUP_AMF_UE = "1010" - // 组号-MME_UE会话事件 - GROUP_MME_UE = "1011" + // 组号-MME_UE会话事件 1011_neId + GROUP_MME_UE = "1011_" ) // 实例化服务层 WSSendImpl 结构体 @@ -34,7 +38,7 @@ type WSSendImpl struct{} // ByClientID 给已知客户端发消息 func (s *WSSendImpl) ByClientID(clientID string, data any) error { - v, ok := WsClients.Load(clientID) + v, ok := wsClients.Load(clientID) if !ok { return fmt.Errorf("no fount client ID: %s", clientID) } @@ -46,43 +50,35 @@ func (s *WSSendImpl) ByClientID(clientID string, data any) error { client := v.(*model.WSClient) if len(client.MsgChan) > 90 { - NewWSImpl.CloseClient(client.ID) + NewWSImpl.ClientClose(client.ID) return fmt.Errorf("msg chan over 90 will close client ID: %s", clientID) } client.MsgChan <- dataByte return nil } -// ByGroupID 给订阅组的用户发送消息 +// ByGroupID 给订阅组的客户端发送消息 func (s *WSSendImpl) ByGroupID(groupID string, data any) error { - uids, ok := WsGroup.Load(groupID) + clientIds, ok := wsGroup.Load(groupID) if !ok { return fmt.Errorf("no fount Group ID: %s", groupID) } - groupUids := uids.(*[]string) - // 群组中没有成员 - if len(*groupUids) == 0 { + // 检查组内是否有客户端 + ids := clientIds.(*[]string) + if len(*ids) == 0 { return fmt.Errorf("no members in the group") } - // 在群组中找到对应的 uid - for _, uid := range *groupUids { - clientIds, ok := WsUsers.Load(uid) - if !ok { + // 遍历给客户端发消息 + for _, clientId := range *ids { + err := s.ByClientID(clientId, map[string]any{ + "groupId": groupID, + "data": data, + }) + if err != nil { continue } - // 在用户中找到客户端并发送 - uidClientIds := clientIds.(*[]string) - for _, clientId := range *uidClientIds { - err := s.ByClientID(clientId, map[string]any{ - "groupId": groupID, - "data": data, - }) - if err != nil { - continue - } - } } return nil diff --git a/src/modules/ws/ws.go b/src/modules/ws/ws.go index 8fc6270..c7f566c 100644 --- a/src/modules/ws/ws.go +++ b/src/modules/ws/ws.go @@ -21,6 +21,10 @@ func Setup(router *gin.Engine) { collectlogs.OperateLog(collectlogs.OptionNew("log.operate.title.ws", collectlogs.BUSINESS_TYPE_OTHER)), controller.NewWSController.WS, ) + wsGroup.GET("/test", + middleware.PreAuthorize(nil), + controller.NewWSController.Test, + ) wsGroup.GET("/ssh", middleware.PreAuthorize(nil), collectlogs.OperateLog(collectlogs.OptionNew("log.operate.title.ws", collectlogs.BUSINESS_TYPE_OTHER)), @@ -31,9 +35,10 @@ func Setup(router *gin.Engine) { collectlogs.OperateLog(collectlogs.OptionNew("log.operate.title.ws", collectlogs.BUSINESS_TYPE_OTHER)), controller.NewWSController.Telnet, ) - wsGroup.GET("/test", + wsGroup.GET("/view", middleware.PreAuthorize(nil), - controller.NewWSController.Test, + collectlogs.OperateLog(collectlogs.OptionNew("log.operate.title.ws", collectlogs.BUSINESS_TYPE_OTHER)), + controller.NewWSController.ShellView, ) } }