diff --git a/src/modules/ws/model/ws.go b/src/modules/ws/model/ws.go index 7a0eb5bf..130419c5 100644 --- a/src/modules/ws/model/ws.go +++ b/src/modules/ws/model/ws.go @@ -15,6 +15,7 @@ type WSClient struct { // WSRequest ws消息接收 type WSRequest struct { - Type string `json:"type"` - Data any `json:"data"` + RequestID string `json:"requestId"` // 请求ID + Type string `json:"type"` // 业务类型 + Data any `json:"data"` // 查询结构 } diff --git a/src/modules/ws/processor/net_connect.go b/src/modules/ws/processor/net_connect.go index 16b0c6a5..6d2ed20d 100644 --- a/src/modules/ws/processor/net_connect.go +++ b/src/modules/ws/processor/net_connect.go @@ -2,52 +2,60 @@ package processor import ( "encoding/json" + "fmt" "strings" + "ems.agt/src/framework/logger" + "ems.agt/src/framework/vo/result" "ems.agt/src/modules/ws/model" "github.com/shirou/gopsutil/v3/net" "github.com/shirou/gopsutil/v3/process" ) // GetNetConnections 获取网络连接进程 -func GetNetConnections(data any) ([]byte, error) { +func GetNetConnections(requestID string, data any) ([]byte, error) { msgByte, _ := json.Marshal(data) var query model.NetConnectQuery err := json.Unmarshal(msgByte, &query) if err != nil { - return nil, err + logger.Warnf("ws processor GetNetConnections err: %s", err.Error()) + return nil, fmt.Errorf("query data structure error") } - result := []model.NetConnectData{} + dataArr := []model.NetConnectData{} for _, netType := range [...]string{"tcp", "udp"} { - connections, _ := net.Connections(netType) - if err == nil { - for _, conn := range connections { - if query.ProcessID > 0 && query.ProcessID != conn.Pid { + connections, err := net.Connections(netType) + if err != nil { + continue + } + for _, conn := range connections { + if query.ProcessID > 0 && query.ProcessID != conn.Pid { + continue + } + proc, err := process.NewProcess(conn.Pid) + if err == nil { + name, _ := proc.Name() + if name != "" && query.ProcessName != "" && !strings.Contains(name, query.ProcessName) { continue } - proc, err := process.NewProcess(conn.Pid) - if err == nil { - name, _ := proc.Name() - if name != "" && query.ProcessName != "" && !strings.Contains(name, query.ProcessName) { - continue - } - if query.Port > 0 && query.Port != int32(conn.Laddr.Port) && query.Port != int32(conn.Raddr.Port) { - continue - } - result = append(result, model.NetConnectData{ - Type: netType, - Status: conn.Status, - Laddr: conn.Laddr, - Raddr: conn.Raddr, - PID: conn.Pid, - Name: name, - }) + if query.Port > 0 && query.Port != int32(conn.Laddr.Port) && query.Port != int32(conn.Raddr.Port) { + continue } - + dataArr = append(dataArr, model.NetConnectData{ + Type: netType, + Status: conn.Status, + Laddr: conn.Laddr, + Raddr: conn.Raddr, + PID: conn.Pid, + Name: name, + }) } } } - resultByte, err := json.Marshal(result) + + resultByte, err := json.Marshal(result.Ok(map[string]any{ + "requestID": requestID, + "data": dataArr, + })) return resultByte, err } diff --git a/src/modules/ws/processor/ps_process.go b/src/modules/ws/processor/ps_process.go index d15d2830..55e509ee 100644 --- a/src/modules/ws/processor/ps_process.go +++ b/src/modules/ws/processor/ps_process.go @@ -7,19 +7,22 @@ import ( "strings" "sync" + "ems.agt/src/framework/logger" "ems.agt/src/framework/utils/date" "ems.agt/src/framework/utils/parse" + "ems.agt/src/framework/vo/result" "ems.agt/src/modules/ws/model" "github.com/shirou/gopsutil/v3/process" ) // GetProcessData 获取进程数据 -func GetProcessData(data any) ([]byte, error) { +func GetProcessData(requestID string, data any) ([]byte, error) { msgByte, _ := json.Marshal(data) var query model.PsProcessQuery err := json.Unmarshal(msgByte, &query) if err != nil { - return nil, err + logger.Warnf("ws processor GetNetConnections err: %s", err.Error()) + return nil, fmt.Errorf("query data structure error") } var processes []*process.Process @@ -29,7 +32,7 @@ func GetProcessData(data any) ([]byte, error) { } var ( - result = []model.PsProcessData{} + dataArr = []model.PsProcessData{} resultMutex sync.Mutex wg sync.WaitGroup numWorkers = 4 @@ -104,7 +107,7 @@ func GetProcessData(data any) ([]byte, error) { procData.Envs, _ = proc.Environ() resultMutex.Lock() - result = append(result, procData) + dataArr = append(dataArr, procData) resultMutex.Unlock() } @@ -127,10 +130,13 @@ func GetProcessData(data any) ([]byte, error) { wg.Wait() - sort.Slice(result, func(i, j int) bool { - return result[i].PID < result[j].PID + sort.Slice(dataArr, func(i, j int) bool { + return dataArr[i].PID < dataArr[j].PID }) - resultByte, err := json.Marshal(result) + resultByte, err := json.Marshal(result.Ok(map[string]any{ + "requestID": requestID, + "data": dataArr, + })) return resultByte, err } diff --git a/src/modules/ws/service/ws_receive.impl.go b/src/modules/ws/service/ws_receive.impl.go index c6dda62c..acdfd23f 100644 --- a/src/modules/ws/service/ws_receive.impl.go +++ b/src/modules/ws/service/ws_receive.impl.go @@ -15,16 +15,18 @@ type WSReceiveImpl struct{} // Receive 接收处理 func (s *WSReceiveImpl) Receive(client *model.WSClient, reqMsg model.WSRequest) error { - fmt.Println(client.ID, reqMsg) + if reqMsg.RequestID == "" { + return fmt.Errorf("message requestId is required") + } switch reqMsg.Type { case "ps": - res, err := processor.GetProcessData(reqMsg.Data) + res, err := processor.GetProcessData(reqMsg.RequestID, reqMsg.Data) if err != nil { return err } client.MsgChan <- res case "net": - res, err := processor.GetNetConnections(reqMsg.Data) + res, err := processor.GetNetConnections(reqMsg.RequestID, reqMsg.Data) if err != nil { return err }