From 89499c9d2876d5ac43163f0485d0bb2de66ed913 Mon Sep 17 00:00:00 2001 From: TsMask <340112800@qq.com> Date: Tue, 23 Jan 2024 18:06:44 +0800 Subject: [PATCH 1/3] =?UTF-8?q?faet:=20=E6=96=B0=E5=A2=9EWS=E6=A8=A1?= =?UTF-8?q?=E5=9D=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/app.go | 3 + src/modules/ws/controller/ws.go | 98 ++++++++++ src/modules/ws/model/ps_process.go | 38 ++++ src/modules/ws/model/ws.go | 20 +++ src/modules/ws/processor/ps_process.go | 136 ++++++++++++++ src/modules/ws/service/ws.go | 20 +++ src/modules/ws/service/ws.impl.go | 207 ++++++++++++++++++++++ src/modules/ws/service/ws_receive.go | 9 + src/modules/ws/service/ws_receive.impl.go | 30 ++++ src/modules/ws/service/ws_send.go | 10 ++ src/modules/ws/service/ws_send.impl.go | 72 ++++++++ src/modules/ws/ws.go | 30 ++++ 12 files changed, 673 insertions(+) create mode 100644 src/modules/ws/controller/ws.go create mode 100644 src/modules/ws/model/ps_process.go create mode 100644 src/modules/ws/model/ws.go create mode 100644 src/modules/ws/processor/ps_process.go create mode 100644 src/modules/ws/service/ws.go create mode 100644 src/modules/ws/service/ws.impl.go create mode 100644 src/modules/ws/service/ws_receive.go create mode 100644 src/modules/ws/service/ws_receive.impl.go create mode 100644 src/modules/ws/service/ws_send.go create mode 100644 src/modules/ws/service/ws_send.impl.go create mode 100644 src/modules/ws/ws.go diff --git a/src/app.go b/src/app.go index f422066f..713fe00c 100644 --- a/src/app.go +++ b/src/app.go @@ -15,6 +15,7 @@ import ( networkelement "ems.agt/src/modules/network_element" "ems.agt/src/modules/system" "ems.agt/src/modules/trace" + "ems.agt/src/modules/ws" "github.com/gin-gonic/gin" ) @@ -123,6 +124,8 @@ func initModulesRoute(app *gin.Engine) { trace.Setup(app) // 图表模块 chart.Setup(app) + // ws 模块 + ws.Setup(app) // 调度任务模块--暂无接口 crontask.Setup(app) // 监控模块 - 含调度处理加入队列,放最后 diff --git a/src/modules/ws/controller/ws.go b/src/modules/ws/controller/ws.go new file mode 100644 index 00000000..b93342de --- /dev/null +++ b/src/modules/ws/controller/ws.go @@ -0,0 +1,98 @@ +package controller + +import ( + "strings" + + "ems.agt/src/framework/i18n" + "ems.agt/src/framework/logger" + "ems.agt/src/framework/utils/ctx" + "ems.agt/src/framework/utils/parse" + "ems.agt/src/framework/vo/result" + "ems.agt/src/modules/ws/service" + "github.com/gin-gonic/gin" +) + +// 实例化控制层 WSController 结构体 +var NewWSController = &WSController{ + wsService: service.NewWSImpl, + wsSendService: service.NewWSSendImpl, +} + +// WebSocket通信 +// +// PATH /ws +type WSController struct { + // WebSocket 服务 + wsService service.IWS + // WebSocket消息发送 服务 + wsSendService service.IWSSend +} + +// 通用 +// +// GET /?subGroupIDs=0 +func (s *WSController) WS(c *gin.Context) { + language := ctx.AcceptLanguage(c) + + // 登录用户信息 + loginUser, err := ctx.LoginUser(c) + if err != nil { + c.JSON(401, result.CodeMsg(401, i18n.TKey(language, err.Error()))) + return + } + + // 订阅消息组 + var subGroupIDs []string + subGroupIDStr := c.Query("subGroupID") + if subGroupIDStr != "" { + // 处理字符转id数组后去重 + ids := strings.Split(subGroupIDStr, ",") + uniqueIDs := parse.RemoveDuplicates(ids) + if len(uniqueIDs) > 0 { + subGroupIDs = uniqueIDs + } + } + + // 将 HTTP 连接升级为 WebSocket 连接 + conn := s.wsService.UpgraderWs(c.Writer, c.Request) + if conn == nil { + return + } + defer conn.Close() + + wsClient := s.wsService.NewClient(loginUser.UserID, subGroupIDs, conn) + + // 等待停止信号 + for value := range wsClient.StopChan { + logger.Infof("ws Stop Client UID %s %s", wsClient.BindUid, value) + return + } +} + +// 测试 +// +// GET /test?clientId=&groupID= +func (s *WSController) Test(c *gin.Context) { + language := ctx.AcceptLanguage(c) + + // 登录用户信息 + loginUser, err := ctx.LoginUser(c) + if err != nil { + c.JSON(401, result.CodeMsg(401, i18n.TKey(language, err.Error()))) + return + } + + // err = s.wsSendService.ByClientID(c.Query("clientId"), loginUser) + // if err != nil { + // c.JSON(200, result.ErrMsg(err.Error())) + // return + // } + + err = s.wsSendService.ByGroupID(c.Query("groupID"), loginUser) + if err != nil { + c.JSON(200, result.ErrMsg(err.Error())) + return + } + + c.JSON(200, result.Ok(nil)) +} diff --git a/src/modules/ws/model/ps_process.go b/src/modules/ws/model/ps_process.go new file mode 100644 index 00000000..e93247a8 --- /dev/null +++ b/src/modules/ws/model/ps_process.go @@ -0,0 +1,38 @@ +package model + +// PsProcessData 进程数据 +type PsProcessData struct { + PID int32 `json:"PID"` + Name string `json:"name"` + PPID int32 `json:"PPID"` + Username string `json:"username"` + Status string `json:"status"` + StartTime string `json:"startTime"` + NumThreads int32 `json:"numThreads"` + NumConnections int `json:"numConnections"` + CpuPercent string `json:"cpuPercent"` + + DiskRead string `json:"diskRead"` + DiskWrite string `json:"diskWrite"` + CmdLine string `json:"cmdLine"` + + Rss string `json:"rss"` + VMS string `json:"vms"` + HWM string `json:"hwm"` + Data string `json:"data"` + Stack string `json:"stack"` + Locked string `json:"locked"` + Swap string `json:"swap"` + + CpuValue float64 `json:"cpuValue"` + RssValue uint64 `json:"rssValue"` + + Envs []string `json:"envs"` +} + +// PsProcessQuery 进程查询 +type PsProcessQuery struct { + Pid int32 `json:"pid"` + Name string `json:"name"` + Username string `json:"username"` +} diff --git a/src/modules/ws/model/ws.go b/src/modules/ws/model/ws.go new file mode 100644 index 00000000..7a0eb5bf --- /dev/null +++ b/src/modules/ws/model/ws.go @@ -0,0 +1,20 @@ +package model + +import "github.com/gorilla/websocket" + +// WSClient ws客户端 +type WSClient struct { + ID string // 连接ID-随机字符串16位 + Conn *websocket.Conn // 连接实例 + LastHeartbeat int64 // 最近一次心跳消息(毫秒) + BindUid string // 绑定登录用户ID + SubGroup []string // 订阅组ID + MsgChan chan []byte // 消息通道 + StopChan chan struct{} // 停止信号-退出协程 +} + +// WSRequest ws消息接收 +type WSRequest struct { + Type string `json:"type"` + Data any `json:"data"` +} diff --git a/src/modules/ws/processor/ps_process.go b/src/modules/ws/processor/ps_process.go new file mode 100644 index 00000000..d15d2830 --- /dev/null +++ b/src/modules/ws/processor/ps_process.go @@ -0,0 +1,136 @@ +package processor + +import ( + "encoding/json" + "fmt" + "sort" + "strings" + "sync" + + "ems.agt/src/framework/utils/date" + "ems.agt/src/framework/utils/parse" + "ems.agt/src/modules/ws/model" + "github.com/shirou/gopsutil/v3/process" +) + +// GetProcessData 获取进程数据 +func GetProcessData(data any) ([]byte, error) { + msgByte, _ := json.Marshal(data) + var query model.PsProcessQuery + err := json.Unmarshal(msgByte, &query) + if err != nil { + return nil, err + } + + var processes []*process.Process + processes, err = process.Processes() + if err != nil { + return nil, err + } + + var ( + result = []model.PsProcessData{} + resultMutex sync.Mutex + wg sync.WaitGroup + numWorkers = 4 + ) + + handleData := func(proc *process.Process) { + procData := model.PsProcessData{ + PID: proc.Pid, + } + if query.Pid > 0 && query.Pid != proc.Pid { + return + } + procName, err := proc.Name() + if procName == "" || err != nil { + return + } else { + procData.Name = procName + } + if query.Name != "" && !strings.Contains(procData.Name, query.Name) { + return + } + if username, err := proc.Username(); err == nil { + procData.Username = username + } + if query.Username != "" && !strings.Contains(procData.Username, query.Username) { + return + } + + procData.PPID, _ = proc.Ppid() + statusArray, _ := proc.Status() + if len(statusArray) > 0 { + procData.Status = strings.Join(statusArray, ",") + } + createTime, procErr := proc.CreateTime() + if procErr == nil { + procData.StartTime = date.ParseDateToStr(createTime, date.YYYY_MM_DD_HH_MM_SS) + } + procData.NumThreads, _ = proc.NumThreads() + procData.CpuValue, _ = proc.CPUPercent() + procData.CpuPercent = fmt.Sprintf("%.2f", procData.CpuValue) + "%" + menInfo, procErr := proc.MemoryInfo() + if procErr == nil { + procData.Rss = parse.Bit(float64(menInfo.RSS)) + procData.Data = parse.Bit(float64(menInfo.Data)) + procData.VMS = parse.Bit(float64(menInfo.VMS)) + procData.HWM = parse.Bit(float64(menInfo.HWM)) + procData.Stack = parse.Bit(float64(menInfo.Stack)) + procData.Locked = parse.Bit(float64(menInfo.Locked)) + procData.Swap = parse.Bit(float64(menInfo.Swap)) + + procData.RssValue = menInfo.RSS + } else { + procData.Rss = "--" + procData.Data = "--" + procData.VMS = "--" + procData.HWM = "--" + procData.Stack = "--" + procData.Locked = "--" + procData.Swap = "--" + + procData.RssValue = 0 + } + ioStat, procErr := proc.IOCounters() + if procErr == nil { + procData.DiskWrite = parse.Bit(float64(ioStat.WriteBytes)) + procData.DiskRead = parse.Bit(float64(ioStat.ReadBytes)) + } else { + procData.DiskWrite = "--" + procData.DiskRead = "--" + } + procData.CmdLine, _ = proc.Cmdline() + procData.Envs, _ = proc.Environ() + + resultMutex.Lock() + result = append(result, procData) + resultMutex.Unlock() + } + + chunkSize := (len(processes) + numWorkers - 1) / numWorkers + for i := 0; i < numWorkers; i++ { + wg.Add(1) + start := i * chunkSize + end := (i + 1) * chunkSize + if end > len(processes) { + end = len(processes) + } + + go func(start, end int) { + defer wg.Done() + for j := start; j < end; j++ { + handleData(processes[j]) + } + }(start, end) + } + + wg.Wait() + + sort.Slice(result, func(i, j int) bool { + return result[i].PID < result[j].PID + }) + + resultByte, err := json.Marshal(result) + return resultByte, err +} diff --git a/src/modules/ws/service/ws.go b/src/modules/ws/service/ws.go new file mode 100644 index 00000000..f404c145 --- /dev/null +++ b/src/modules/ws/service/ws.go @@ -0,0 +1,20 @@ +package service + +import ( + "net/http" + + "ems.agt/src/modules/ws/model" + "github.com/gorilla/websocket" +) + +// IWS WebSocket通信 服务层接口 +type IWS interface { + // UpgraderWs http升级ws请求 + UpgraderWs(w http.ResponseWriter, r *http.Request) *websocket.Conn + + // NewClient 新建客户端 uid 登录用户ID + NewClient(uid string, gids []string, conn *websocket.Conn) *model.WSClient + + // CloseClient 客户端关闭 + CloseClient(clientID string) +} diff --git a/src/modules/ws/service/ws.impl.go b/src/modules/ws/service/ws.impl.go new file mode 100644 index 00000000..5add5595 --- /dev/null +++ b/src/modules/ws/service/ws.impl.go @@ -0,0 +1,207 @@ +package service + +import ( + "encoding/json" + "net/http" + "sync" + "time" + + "ems.agt/src/framework/logger" + "ems.agt/src/framework/utils/generate" + "ems.agt/src/framework/vo/result" + "ems.agt/src/modules/ws/model" + "github.com/gorilla/websocket" +) + +var ( + // ws客户端 [clientId: client] + WsClients = sync.Map{} + // ws用户对应的多个客户端id [uid:clientIds] + WsUsers = sync.Map{} + // ws组对应的多个用户id [groupID:uids] + WsGroup = sync.Map{} +) + +// 实例化服务层 WSImpl 结构体 +var NewWSImpl = &WSImpl{} + +// WSImpl WebSocket通信 服务层处理 +type WSImpl struct{} + +// UpgraderWs http升级ws请求 +func (s *WSImpl) UpgraderWs(w http.ResponseWriter, r *http.Request) *websocket.Conn { + wsUpgrader := websocket.Upgrader{ + // 设置消息发送缓冲区大小(byte),如果这个值设置得太小,可能会导致服务端在发送大型消息时遇到问题 + WriteBufferSize: 1024, + // 消息包启用压缩 + EnableCompression: true, + // ws握手超时时间 + HandshakeTimeout: 5 * time.Second, + // ws握手过程中允许跨域 + CheckOrigin: func(r *http.Request) bool { + return true + }, + } + conn, err := wsUpgrader.Upgrade(w, r, nil) + if err != nil { + logger.Errorf("ws Upgrade err: %s", err.Error()) + } + return conn +} + +// NewClient 新建客户端 uid 登录用户ID +func (s *WSImpl) NewClient(uid string, groupIDs []string, conn *websocket.Conn) *model.WSClient { + // clientID也可以用其他方式生成,只要能保证在所有服务端中都能保证唯一即可 + clientID := generate.Code(16) + + wsClient := &model.WSClient{ + ID: clientID, + Conn: conn, + LastHeartbeat: time.Now().UnixMilli(), + BindUid: uid, + SubGroup: groupIDs, + MsgChan: make(chan []byte, 100), + StopChan: make(chan struct{}, 1), // 请求卡死循环标记 + } + + // 存入客户端 + WsClients.Store(clientID, wsClient) + + // 存入用户持有客户端 + if uid != "" { + if v, ok := WsUsers.Load(uid); ok { + uidClientIds := v.(*[]string) + *uidClientIds = append(*uidClientIds, clientID) + } else { + WsUsers.Store(uid, &[]string{clientID}) + } + } + + // 存入用户订阅组 + if uid != "" && len(groupIDs) > 0 { + for _, groupID := range groupIDs { + if v, ok := WsGroup.Load(groupID); ok { + groupUIDs := v.(*[]string) + *groupUIDs = append(*groupUIDs, uid) + } else { + WsGroup.Store(groupID, &[]string{uid}) + } + } + } + + go s.clientRead(wsClient) + go s.clientWrite(wsClient) + + // 发客户端id确认是否连接 + msgByte, _ := json.Marshal(result.OkData(map[string]string{ + "clientId": clientID, + })) + wsClient.MsgChan <- msgByte + + return wsClient +} + +// clientRead 客户端读取消息 +func (s *WSImpl) clientRead(wsClient *model.WSClient) { + for { + // 读取消息 + messageType, msg, err := wsClient.Conn.ReadMessage() + if err != nil { + logger.Warnf("ws ReadMessage UID %s err: %s", wsClient.BindUid, err.Error()) + s.CloseClient(wsClient.ID) + return + } + // 文本和二进制类型,只处理文本json + if messageType == websocket.TextMessage { + var reqMsg model.WSRequest + err := json.Unmarshal(msg, &reqMsg) + if err != nil { + msgByte, _ := json.Marshal(result.ErrMsg("message format not supported")) + wsClient.MsgChan <- msgByte + } else { + err := NewWSReceiveImpl.Receive(wsClient, reqMsg) + if err != nil { + logger.Warnf("ws ReceiveMessage UID %s err: %s", wsClient.BindUid, err.Error()) + msgByte, _ := json.Marshal(result.ErrMsg(err.Error())) + wsClient.MsgChan <- msgByte + } + } + } + } +} + +// clientWrite 客户端写入消息 +func (s *WSImpl) clientWrite(wsClient *model.WSClient) { + ticker := time.NewTicker(time.Second * 5) // 设置心跳间隔为 5 秒钟 + defer ticker.Stop() + for { + select { + case <-ticker.C: + wsClient.LastHeartbeat = time.Now().UnixMilli() + // 发送 Ping 消息 + err := wsClient.Conn.WriteMessage(websocket.PingMessage, []byte{}) + if err != nil { + logger.Warnf("ws PingMessage UID %s err: %s", wsClient.BindUid, err.Error()) + s.CloseClient(wsClient.ID) + return + } + case msg := <-wsClient.MsgChan: + // 发送消息 + err := wsClient.Conn.WriteMessage(websocket.TextMessage, msg) + if err != nil { + logger.Warnf("ws WriteMessage UID %s err: %s", wsClient.BindUid, err.Error()) + s.CloseClient(wsClient.ID) + return + } + } + } +} + +// CloseClient 客户端关闭 +func (s *WSImpl) CloseClient(clientID string) { + v, ok := WsClients.Load(clientID) + if !ok { + return + } + + client := v.(*model.WSClient) + defer func() { + client.Conn.WriteMessage(websocket.CloseMessage, []byte{}) + client.Conn.Close() + client.StopChan <- struct{}{} + WsClients.Delete(clientID) + }() + + // 客户端断线时自动踢出Uid绑定列表 + if client.BindUid != "" { + if clientIds, ok := WsUsers.Load(client.BindUid); ok { + uidClientIds := clientIds.(*[]string) + if len(*uidClientIds) > 0 { + for i, clientId := range *uidClientIds { + if clientId == client.ID { + *uidClientIds = append((*uidClientIds)[:i], (*uidClientIds)[i+1:]...) + } + } + } + } + } + + // 客户端断线时自动踢出已加入的组 + if client.BindUid != "" && len(client.SubGroup) > 0 { + for _, groupID := range client.SubGroup { + uids, ok := WsGroup.Load(groupID) + if !ok { + continue + } + + groupUIDs := uids.(*[]string) + if len(*groupUIDs) > 0 { + for i, v := range *groupUIDs { + if v == client.BindUid { + *groupUIDs = append((*groupUIDs)[:i], (*groupUIDs)[i+1:]...) + } + } + } + } + } +} diff --git a/src/modules/ws/service/ws_receive.go b/src/modules/ws/service/ws_receive.go new file mode 100644 index 00000000..0a1d8e02 --- /dev/null +++ b/src/modules/ws/service/ws_receive.go @@ -0,0 +1,9 @@ +package service + +import "ems.agt/src/modules/ws/model" + +// IWSReceive WebSocket消息接收处理 服务层接口 +type IWSReceive interface { + // Receive 接收处理 + Receive(client *model.WSClient, reqMsg model.WSRequest) error +} diff --git a/src/modules/ws/service/ws_receive.impl.go b/src/modules/ws/service/ws_receive.impl.go new file mode 100644 index 00000000..bc548c2b --- /dev/null +++ b/src/modules/ws/service/ws_receive.impl.go @@ -0,0 +1,30 @@ +package service + +import ( + "fmt" + + "ems.agt/src/modules/ws/model" + "ems.agt/src/modules/ws/processor" +) + +// 实例化服务层 WSReceiveImpl 结构体 +var NewWSReceiveImpl = &WSReceiveImpl{} + +// WSReceiveImpl WebSocket消息接收处理 服务层处理 +type WSReceiveImpl struct{} + +// Receive 接收处理 +func (s *WSReceiveImpl) Receive(client *model.WSClient, reqMsg model.WSRequest) error { + fmt.Println(client.ID, reqMsg) + switch reqMsg.Type { + case "ps": + res, err := processor.GetProcessData(reqMsg.Data) + if err != nil { + return err + } + client.MsgChan <- res + default: + return fmt.Errorf("message type not supported") + } + return nil +} diff --git a/src/modules/ws/service/ws_send.go b/src/modules/ws/service/ws_send.go new file mode 100644 index 00000000..020d022a --- /dev/null +++ b/src/modules/ws/service/ws_send.go @@ -0,0 +1,10 @@ +package service + +// IWSSend WebSocket消息发送处理 服务层接口 +type IWSSend interface { + // ByClientID 给已知客户端发消息 + ByClientID(clientID string, data any) error + + // ByGroupID 给订阅组的用户发送消息 + ByGroupID(gid string, data any) error +} diff --git a/src/modules/ws/service/ws_send.impl.go b/src/modules/ws/service/ws_send.impl.go new file mode 100644 index 00000000..942e6585 --- /dev/null +++ b/src/modules/ws/service/ws_send.impl.go @@ -0,0 +1,72 @@ +package service + +import ( + "encoding/json" + "fmt" + + "ems.agt/src/modules/ws/model" +) + +const ( + // 组号-其他 + GROUP_OTHER = "0" + // 组号-指标 + GROUP_KPI = "1000" + // 组号-会话记录 + GROUP_CDR = "1005" +) + +// 实例化服务层 WSSendImpl 结构体 +var NewWSSendImpl = &WSSendImpl{} + +// IWSSend WebSocket消息发送处理 服务层处理 +type WSSendImpl struct{} + +// ByClientID 给已知客户端发消息 +func (s *WSSendImpl) ByClientID(clientID string, data any) error { + v, ok := WsClients.Load(clientID) + if !ok { + return fmt.Errorf("no fount client ID: %s", clientID) + } + + dataByte, err := json.Marshal(data) + if err != nil { + return err + } + + client := v.(*model.WSClient) + client.MsgChan <- dataByte + return nil +} + +// ByGroupID 给订阅组的用户发送消息 +func (s *WSSendImpl) ByGroupID(groupID string, data any) error { + uids, ok := WsGroup.Load(groupID) + if !ok { + return fmt.Errorf("no fount Group ID: %s", groupID) + } + + groupUids := uids.(*[]string) + // 群组中没有成员 + if len(*groupUids) == 0 { + return fmt.Errorf("no members in the group") + } + + // 在群组中找到对应的 uid + for _, uid := range *groupUids { + clientIds, ok := WsUsers.Load(uid) + if !ok { + continue + } + // 在用户中找到客户端并发送 + uidClientIds := clientIds.(*[]string) + for _, clientId := range *uidClientIds { + err := s.ByClientID(clientId, data) + if err != nil { + continue + } + } + } + + return nil +} diff --git a/src/modules/ws/ws.go b/src/modules/ws/ws.go new file mode 100644 index 00000000..e0ad3d1f --- /dev/null +++ b/src/modules/ws/ws.go @@ -0,0 +1,30 @@ +package ws + +import ( + "ems.agt/src/framework/logger" + "ems.agt/src/framework/middleware" + "ems.agt/src/framework/middleware/collectlogs" + "ems.agt/src/modules/ws/controller" + + "github.com/gin-gonic/gin" +) + +// 模块路由注册 +func Setup(router *gin.Engine) { + logger.Infof("开始加载 ====> ws 模块路由") + + // WebSocket 协议 + wsGroup := router.Group("/ws") + { + wsGroup.GET("", + middleware.PreAuthorize(nil), + collectlogs.OperateLog(collectlogs.OptionNew("WS 订阅", collectlogs.BUSINESS_TYPE_OTHER)), + controller.NewWSController.WS, + ) + + wsGroup.GET("/test", + middleware.PreAuthorize(nil), + controller.NewWSController.Test, + ) + } +} From ec9a30d78c557ac740976c5a629b4651b0c04b81 Mon Sep 17 00:00:00 2001 From: TsMask <340112800@qq.com> Date: Tue, 23 Jan 2024 18:43:26 +0800 Subject: [PATCH 2/3] =?UTF-8?q?feat:=20ws=E8=8E=B7=E5=8F=96=E7=BD=91?= =?UTF-8?q?=E7=BB=9C=E8=BF=9E=E6=8E=A5=E8=BF=9B=E7=A8=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/modules/ws/model/net_connect.go | 20 +++++++++ src/modules/ws/processor/net_connect.go | 53 +++++++++++++++++++++++ src/modules/ws/service/ws_receive.impl.go | 6 +++ 3 files changed, 79 insertions(+) create mode 100644 src/modules/ws/model/net_connect.go create mode 100644 src/modules/ws/processor/net_connect.go diff --git a/src/modules/ws/model/net_connect.go b/src/modules/ws/model/net_connect.go new file mode 100644 index 00000000..8116c285 --- /dev/null +++ b/src/modules/ws/model/net_connect.go @@ -0,0 +1,20 @@ +package model + +import "github.com/shirou/gopsutil/v3/net" + +// NetConnectData 网络连接进程数据 +type NetConnectData struct { + Type string `json:"type"` + Status string `json:"status"` + Laddr net.Addr `json:"localaddr"` + Raddr net.Addr `json:"remoteaddr"` + PID int32 `json:"PID"` + Name string `json:"name"` +} + +// NetConnectQuery 网络连接进程查询 +type NetConnectQuery struct { + Port int32 `json:"port"` + ProcessName string `json:"processName"` + ProcessID int32 `json:"processID"` +} diff --git a/src/modules/ws/processor/net_connect.go b/src/modules/ws/processor/net_connect.go new file mode 100644 index 00000000..16b0c6a5 --- /dev/null +++ b/src/modules/ws/processor/net_connect.go @@ -0,0 +1,53 @@ +package processor + +import ( + "encoding/json" + "strings" + + "ems.agt/src/modules/ws/model" + "github.com/shirou/gopsutil/v3/net" + "github.com/shirou/gopsutil/v3/process" +) + +// GetNetConnections 获取网络连接进程 +func GetNetConnections(data any) ([]byte, error) { + msgByte, _ := json.Marshal(data) + var query model.NetConnectQuery + err := json.Unmarshal(msgByte, &query) + if err != nil { + return nil, err + } + + result := []model.NetConnectData{} + for _, netType := range [...]string{"tcp", "udp"} { + connections, _ := net.Connections(netType) + if err == nil { + for _, conn := range connections { + if query.ProcessID > 0 && query.ProcessID != conn.Pid { + continue + } + proc, err := process.NewProcess(conn.Pid) + if err == nil { + name, _ := proc.Name() + if name != "" && query.ProcessName != "" && !strings.Contains(name, query.ProcessName) { + continue + } + if query.Port > 0 && query.Port != int32(conn.Laddr.Port) && query.Port != int32(conn.Raddr.Port) { + continue + } + result = append(result, model.NetConnectData{ + Type: netType, + Status: conn.Status, + Laddr: conn.Laddr, + Raddr: conn.Raddr, + PID: conn.Pid, + Name: name, + }) + } + + } + } + } + resultByte, err := json.Marshal(result) + return resultByte, err +} diff --git a/src/modules/ws/service/ws_receive.impl.go b/src/modules/ws/service/ws_receive.impl.go index bc548c2b..c6dda62c 100644 --- a/src/modules/ws/service/ws_receive.impl.go +++ b/src/modules/ws/service/ws_receive.impl.go @@ -23,6 +23,12 @@ func (s *WSReceiveImpl) Receive(client *model.WSClient, reqMsg model.WSRequest) return err } client.MsgChan <- res + case "net": + res, err := processor.GetNetConnections(reqMsg.Data) + if err != nil { + return err + } + client.MsgChan <- res default: return fmt.Errorf("message type not supported") } From aba5e48005ca316f98a299db5e3a7dc998d20b98 Mon Sep 17 00:00:00 2001 From: TsMask <340112800@qq.com> Date: Tue, 23 Jan 2024 20:00:05 +0800 Subject: [PATCH 3/3] =?UTF-8?q?feat:=20ws=20=E8=AF=B7=E6=B1=82=E6=B6=88?= =?UTF-8?q?=E6=81=AF=E4=BD=93=E5=88=A4=E6=96=ADreqid=E5=92=8C=E5=92=8C?= =?UTF-8?q?=E7=BB=93=E6=9E=84=E5=BA=8F=E5=88=97=E5=8C=96=E5=BC=82=E5=B8=B8?= =?UTF-8?q?=E4=BF=A1=E6=81=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/modules/ws/model/ws.go | 5 +- src/modules/ws/processor/net_connect.go | 60 +++++++++++++---------- src/modules/ws/processor/ps_process.go | 20 +++++--- src/modules/ws/service/ws_receive.impl.go | 8 +-- 4 files changed, 55 insertions(+), 38 deletions(-) diff --git a/src/modules/ws/model/ws.go b/src/modules/ws/model/ws.go index 7a0eb5bf..130419c5 100644 --- a/src/modules/ws/model/ws.go +++ b/src/modules/ws/model/ws.go @@ -15,6 +15,7 @@ type WSClient struct { // WSRequest ws消息接收 type WSRequest struct { - Type string `json:"type"` - Data any `json:"data"` + RequestID string `json:"requestId"` // 请求ID + Type string `json:"type"` // 业务类型 + Data any `json:"data"` // 查询结构 } diff --git a/src/modules/ws/processor/net_connect.go b/src/modules/ws/processor/net_connect.go index 16b0c6a5..6d2ed20d 100644 --- a/src/modules/ws/processor/net_connect.go +++ b/src/modules/ws/processor/net_connect.go @@ -2,52 +2,60 @@ package processor import ( "encoding/json" + "fmt" "strings" + "ems.agt/src/framework/logger" + "ems.agt/src/framework/vo/result" "ems.agt/src/modules/ws/model" "github.com/shirou/gopsutil/v3/net" "github.com/shirou/gopsutil/v3/process" ) // GetNetConnections 获取网络连接进程 -func GetNetConnections(data any) ([]byte, error) { +func GetNetConnections(requestID string, data any) ([]byte, error) { msgByte, _ := json.Marshal(data) var query model.NetConnectQuery err := json.Unmarshal(msgByte, &query) if err != nil { - return nil, err + logger.Warnf("ws processor GetNetConnections err: %s", err.Error()) + return nil, fmt.Errorf("query data structure error") } - result := []model.NetConnectData{} + dataArr := []model.NetConnectData{} for _, netType := range [...]string{"tcp", "udp"} { - connections, _ := net.Connections(netType) - if err == nil { - for _, conn := range connections { - if query.ProcessID > 0 && query.ProcessID != conn.Pid { + connections, err := net.Connections(netType) + if err != nil { + continue + } + for _, conn := range connections { + if query.ProcessID > 0 && query.ProcessID != conn.Pid { + continue + } + proc, err := process.NewProcess(conn.Pid) + if err == nil { + name, _ := proc.Name() + if name != "" && query.ProcessName != "" && !strings.Contains(name, query.ProcessName) { continue } - proc, err := process.NewProcess(conn.Pid) - if err == nil { - name, _ := proc.Name() - if name != "" && query.ProcessName != "" && !strings.Contains(name, query.ProcessName) { - continue - } - if query.Port > 0 && query.Port != int32(conn.Laddr.Port) && query.Port != int32(conn.Raddr.Port) { - continue - } - result = append(result, model.NetConnectData{ - Type: netType, - Status: conn.Status, - Laddr: conn.Laddr, - Raddr: conn.Raddr, - PID: conn.Pid, - Name: name, - }) + if query.Port > 0 && query.Port != int32(conn.Laddr.Port) && query.Port != int32(conn.Raddr.Port) { + continue } - + dataArr = append(dataArr, model.NetConnectData{ + Type: netType, + Status: conn.Status, + Laddr: conn.Laddr, + Raddr: conn.Raddr, + PID: conn.Pid, + Name: name, + }) } } } - resultByte, err := json.Marshal(result) + + resultByte, err := json.Marshal(result.Ok(map[string]any{ + "requestID": requestID, + "data": dataArr, + })) return resultByte, err } diff --git a/src/modules/ws/processor/ps_process.go b/src/modules/ws/processor/ps_process.go index d15d2830..55e509ee 100644 --- a/src/modules/ws/processor/ps_process.go +++ b/src/modules/ws/processor/ps_process.go @@ -7,19 +7,22 @@ import ( "strings" "sync" + "ems.agt/src/framework/logger" "ems.agt/src/framework/utils/date" "ems.agt/src/framework/utils/parse" + "ems.agt/src/framework/vo/result" "ems.agt/src/modules/ws/model" "github.com/shirou/gopsutil/v3/process" ) // GetProcessData 获取进程数据 -func GetProcessData(data any) ([]byte, error) { +func GetProcessData(requestID string, data any) ([]byte, error) { msgByte, _ := json.Marshal(data) var query model.PsProcessQuery err := json.Unmarshal(msgByte, &query) if err != nil { - return nil, err + logger.Warnf("ws processor GetNetConnections err: %s", err.Error()) + return nil, fmt.Errorf("query data structure error") } var processes []*process.Process @@ -29,7 +32,7 @@ func GetProcessData(data any) ([]byte, error) { } var ( - result = []model.PsProcessData{} + dataArr = []model.PsProcessData{} resultMutex sync.Mutex wg sync.WaitGroup numWorkers = 4 @@ -104,7 +107,7 @@ func GetProcessData(data any) ([]byte, error) { procData.Envs, _ = proc.Environ() resultMutex.Lock() - result = append(result, procData) + dataArr = append(dataArr, procData) resultMutex.Unlock() } @@ -127,10 +130,13 @@ func GetProcessData(data any) ([]byte, error) { wg.Wait() - sort.Slice(result, func(i, j int) bool { - return result[i].PID < result[j].PID + sort.Slice(dataArr, func(i, j int) bool { + return dataArr[i].PID < dataArr[j].PID }) - resultByte, err := json.Marshal(result) + resultByte, err := json.Marshal(result.Ok(map[string]any{ + "requestID": requestID, + "data": dataArr, + })) return resultByte, err } diff --git a/src/modules/ws/service/ws_receive.impl.go b/src/modules/ws/service/ws_receive.impl.go index c6dda62c..acdfd23f 100644 --- a/src/modules/ws/service/ws_receive.impl.go +++ b/src/modules/ws/service/ws_receive.impl.go @@ -15,16 +15,18 @@ type WSReceiveImpl struct{} // Receive 接收处理 func (s *WSReceiveImpl) Receive(client *model.WSClient, reqMsg model.WSRequest) error { - fmt.Println(client.ID, reqMsg) + if reqMsg.RequestID == "" { + return fmt.Errorf("message requestId is required") + } switch reqMsg.Type { case "ps": - res, err := processor.GetProcessData(reqMsg.Data) + res, err := processor.GetProcessData(reqMsg.RequestID, reqMsg.Data) if err != nil { return err } client.MsgChan <- res case "net": - res, err := processor.GetNetConnections(reqMsg.Data) + res, err := processor.GetNetConnections(reqMsg.RequestID, reqMsg.Data) if err != nil { return err }