This commit is contained in:
2024-01-23 20:15:47 +08:00
14 changed files with 769 additions and 0 deletions

View File

@@ -15,6 +15,7 @@ import (
networkelement "ems.agt/src/modules/network_element"
"ems.agt/src/modules/system"
"ems.agt/src/modules/trace"
"ems.agt/src/modules/ws"
"github.com/gin-gonic/gin"
)
@@ -123,6 +124,8 @@ func initModulesRoute(app *gin.Engine) {
trace.Setup(app)
// 图表模块
chart.Setup(app)
// ws 模块
ws.Setup(app)
// 调度任务模块--暂无接口
crontask.Setup(app)
// 监控模块 - 含调度处理加入队列,放最后

View File

@@ -0,0 +1,98 @@
package controller
import (
"strings"
"ems.agt/src/framework/i18n"
"ems.agt/src/framework/logger"
"ems.agt/src/framework/utils/ctx"
"ems.agt/src/framework/utils/parse"
"ems.agt/src/framework/vo/result"
"ems.agt/src/modules/ws/service"
"github.com/gin-gonic/gin"
)
// 实例化控制层 WSController 结构体
var NewWSController = &WSController{
wsService: service.NewWSImpl,
wsSendService: service.NewWSSendImpl,
}
// WebSocket通信
//
// PATH /ws
type WSController struct {
// WebSocket 服务
wsService service.IWS
// WebSocket消息发送 服务
wsSendService service.IWSSend
}
// 通用
//
// GET /?subGroupIDs=0
func (s *WSController) WS(c *gin.Context) {
language := ctx.AcceptLanguage(c)
// 登录用户信息
loginUser, err := ctx.LoginUser(c)
if err != nil {
c.JSON(401, result.CodeMsg(401, i18n.TKey(language, err.Error())))
return
}
// 订阅消息组
var subGroupIDs []string
subGroupIDStr := c.Query("subGroupID")
if subGroupIDStr != "" {
// 处理字符转id数组后去重
ids := strings.Split(subGroupIDStr, ",")
uniqueIDs := parse.RemoveDuplicates(ids)
if len(uniqueIDs) > 0 {
subGroupIDs = uniqueIDs
}
}
// 将 HTTP 连接升级为 WebSocket 连接
conn := s.wsService.UpgraderWs(c.Writer, c.Request)
if conn == nil {
return
}
defer conn.Close()
wsClient := s.wsService.NewClient(loginUser.UserID, subGroupIDs, conn)
// 等待停止信号
for value := range wsClient.StopChan {
logger.Infof("ws Stop Client UID %s %s", wsClient.BindUid, value)
return
}
}
// 测试
//
// GET /test?clientId=&groupID=
func (s *WSController) Test(c *gin.Context) {
language := ctx.AcceptLanguage(c)
// 登录用户信息
loginUser, err := ctx.LoginUser(c)
if err != nil {
c.JSON(401, result.CodeMsg(401, i18n.TKey(language, err.Error())))
return
}
// err = s.wsSendService.ByClientID(c.Query("clientId"), loginUser)
// if err != nil {
// c.JSON(200, result.ErrMsg(err.Error()))
// return
// }
err = s.wsSendService.ByGroupID(c.Query("groupID"), loginUser)
if err != nil {
c.JSON(200, result.ErrMsg(err.Error()))
return
}
c.JSON(200, result.Ok(nil))
}

View File

@@ -0,0 +1,20 @@
package model
import "github.com/shirou/gopsutil/v3/net"
// NetConnectData 网络连接进程数据
type NetConnectData struct {
Type string `json:"type"`
Status string `json:"status"`
Laddr net.Addr `json:"localaddr"`
Raddr net.Addr `json:"remoteaddr"`
PID int32 `json:"PID"`
Name string `json:"name"`
}
// NetConnectQuery 网络连接进程查询
type NetConnectQuery struct {
Port int32 `json:"port"`
ProcessName string `json:"processName"`
ProcessID int32 `json:"processID"`
}

View File

@@ -0,0 +1,38 @@
package model
// PsProcessData 进程数据
type PsProcessData struct {
PID int32 `json:"PID"`
Name string `json:"name"`
PPID int32 `json:"PPID"`
Username string `json:"username"`
Status string `json:"status"`
StartTime string `json:"startTime"`
NumThreads int32 `json:"numThreads"`
NumConnections int `json:"numConnections"`
CpuPercent string `json:"cpuPercent"`
DiskRead string `json:"diskRead"`
DiskWrite string `json:"diskWrite"`
CmdLine string `json:"cmdLine"`
Rss string `json:"rss"`
VMS string `json:"vms"`
HWM string `json:"hwm"`
Data string `json:"data"`
Stack string `json:"stack"`
Locked string `json:"locked"`
Swap string `json:"swap"`
CpuValue float64 `json:"cpuValue"`
RssValue uint64 `json:"rssValue"`
Envs []string `json:"envs"`
}
// PsProcessQuery 进程查询
type PsProcessQuery struct {
Pid int32 `json:"pid"`
Name string `json:"name"`
Username string `json:"username"`
}

View File

@@ -0,0 +1,21 @@
package model
import "github.com/gorilla/websocket"
// WSClient ws客户端
type WSClient struct {
ID string // 连接ID-随机字符串16位
Conn *websocket.Conn // 连接实例
LastHeartbeat int64 // 最近一次心跳消息(毫秒)
BindUid string // 绑定登录用户ID
SubGroup []string // 订阅组ID
MsgChan chan []byte // 消息通道
StopChan chan struct{} // 停止信号-退出协程
}
// WSRequest ws消息接收
type WSRequest struct {
RequestID string `json:"requestId"` // 请求ID
Type string `json:"type"` // 业务类型
Data any `json:"data"` // 查询结构
}

View File

@@ -0,0 +1,61 @@
package processor
import (
"encoding/json"
"fmt"
"strings"
"ems.agt/src/framework/logger"
"ems.agt/src/framework/vo/result"
"ems.agt/src/modules/ws/model"
"github.com/shirou/gopsutil/v3/net"
"github.com/shirou/gopsutil/v3/process"
)
// GetNetConnections 获取网络连接进程
func GetNetConnections(requestID string, data any) ([]byte, error) {
msgByte, _ := json.Marshal(data)
var query model.NetConnectQuery
err := json.Unmarshal(msgByte, &query)
if err != nil {
logger.Warnf("ws processor GetNetConnections err: %s", err.Error())
return nil, fmt.Errorf("query data structure error")
}
dataArr := []model.NetConnectData{}
for _, netType := range [...]string{"tcp", "udp"} {
connections, err := net.Connections(netType)
if err != nil {
continue
}
for _, conn := range connections {
if query.ProcessID > 0 && query.ProcessID != conn.Pid {
continue
}
proc, err := process.NewProcess(conn.Pid)
if err == nil {
name, _ := proc.Name()
if name != "" && query.ProcessName != "" && !strings.Contains(name, query.ProcessName) {
continue
}
if query.Port > 0 && query.Port != int32(conn.Laddr.Port) && query.Port != int32(conn.Raddr.Port) {
continue
}
dataArr = append(dataArr, model.NetConnectData{
Type: netType,
Status: conn.Status,
Laddr: conn.Laddr,
Raddr: conn.Raddr,
PID: conn.Pid,
Name: name,
})
}
}
}
resultByte, err := json.Marshal(result.Ok(map[string]any{
"requestID": requestID,
"data": dataArr,
}))
return resultByte, err
}

View File

@@ -0,0 +1,142 @@
package processor
import (
"encoding/json"
"fmt"
"sort"
"strings"
"sync"
"ems.agt/src/framework/logger"
"ems.agt/src/framework/utils/date"
"ems.agt/src/framework/utils/parse"
"ems.agt/src/framework/vo/result"
"ems.agt/src/modules/ws/model"
"github.com/shirou/gopsutil/v3/process"
)
// GetProcessData 获取进程数据
func GetProcessData(requestID string, data any) ([]byte, error) {
msgByte, _ := json.Marshal(data)
var query model.PsProcessQuery
err := json.Unmarshal(msgByte, &query)
if err != nil {
logger.Warnf("ws processor GetNetConnections err: %s", err.Error())
return nil, fmt.Errorf("query data structure error")
}
var processes []*process.Process
processes, err = process.Processes()
if err != nil {
return nil, err
}
var (
dataArr = []model.PsProcessData{}
resultMutex sync.Mutex
wg sync.WaitGroup
numWorkers = 4
)
handleData := func(proc *process.Process) {
procData := model.PsProcessData{
PID: proc.Pid,
}
if query.Pid > 0 && query.Pid != proc.Pid {
return
}
procName, err := proc.Name()
if procName == "" || err != nil {
return
} else {
procData.Name = procName
}
if query.Name != "" && !strings.Contains(procData.Name, query.Name) {
return
}
if username, err := proc.Username(); err == nil {
procData.Username = username
}
if query.Username != "" && !strings.Contains(procData.Username, query.Username) {
return
}
procData.PPID, _ = proc.Ppid()
statusArray, _ := proc.Status()
if len(statusArray) > 0 {
procData.Status = strings.Join(statusArray, ",")
}
createTime, procErr := proc.CreateTime()
if procErr == nil {
procData.StartTime = date.ParseDateToStr(createTime, date.YYYY_MM_DD_HH_MM_SS)
}
procData.NumThreads, _ = proc.NumThreads()
procData.CpuValue, _ = proc.CPUPercent()
procData.CpuPercent = fmt.Sprintf("%.2f", procData.CpuValue) + "%"
menInfo, procErr := proc.MemoryInfo()
if procErr == nil {
procData.Rss = parse.Bit(float64(menInfo.RSS))
procData.Data = parse.Bit(float64(menInfo.Data))
procData.VMS = parse.Bit(float64(menInfo.VMS))
procData.HWM = parse.Bit(float64(menInfo.HWM))
procData.Stack = parse.Bit(float64(menInfo.Stack))
procData.Locked = parse.Bit(float64(menInfo.Locked))
procData.Swap = parse.Bit(float64(menInfo.Swap))
procData.RssValue = menInfo.RSS
} else {
procData.Rss = "--"
procData.Data = "--"
procData.VMS = "--"
procData.HWM = "--"
procData.Stack = "--"
procData.Locked = "--"
procData.Swap = "--"
procData.RssValue = 0
}
ioStat, procErr := proc.IOCounters()
if procErr == nil {
procData.DiskWrite = parse.Bit(float64(ioStat.WriteBytes))
procData.DiskRead = parse.Bit(float64(ioStat.ReadBytes))
} else {
procData.DiskWrite = "--"
procData.DiskRead = "--"
}
procData.CmdLine, _ = proc.Cmdline()
procData.Envs, _ = proc.Environ()
resultMutex.Lock()
dataArr = append(dataArr, procData)
resultMutex.Unlock()
}
chunkSize := (len(processes) + numWorkers - 1) / numWorkers
for i := 0; i < numWorkers; i++ {
wg.Add(1)
start := i * chunkSize
end := (i + 1) * chunkSize
if end > len(processes) {
end = len(processes)
}
go func(start, end int) {
defer wg.Done()
for j := start; j < end; j++ {
handleData(processes[j])
}
}(start, end)
}
wg.Wait()
sort.Slice(dataArr, func(i, j int) bool {
return dataArr[i].PID < dataArr[j].PID
})
resultByte, err := json.Marshal(result.Ok(map[string]any{
"requestID": requestID,
"data": dataArr,
}))
return resultByte, err
}

View File

@@ -0,0 +1,20 @@
package service
import (
"net/http"
"ems.agt/src/modules/ws/model"
"github.com/gorilla/websocket"
)
// IWS WebSocket通信 服务层接口
type IWS interface {
// UpgraderWs http升级ws请求
UpgraderWs(w http.ResponseWriter, r *http.Request) *websocket.Conn
// NewClient 新建客户端 uid 登录用户ID
NewClient(uid string, gids []string, conn *websocket.Conn) *model.WSClient
// CloseClient 客户端关闭
CloseClient(clientID string)
}

View File

@@ -0,0 +1,207 @@
package service
import (
"encoding/json"
"net/http"
"sync"
"time"
"ems.agt/src/framework/logger"
"ems.agt/src/framework/utils/generate"
"ems.agt/src/framework/vo/result"
"ems.agt/src/modules/ws/model"
"github.com/gorilla/websocket"
)
var (
// ws客户端 [clientId: client]
WsClients = sync.Map{}
// ws用户对应的多个客户端id [uid:clientIds]
WsUsers = sync.Map{}
// ws组对应的多个用户id [groupID:uids]
WsGroup = sync.Map{}
)
// 实例化服务层 WSImpl 结构体
var NewWSImpl = &WSImpl{}
// WSImpl WebSocket通信 服务层处理
type WSImpl struct{}
// UpgraderWs http升级ws请求
func (s *WSImpl) UpgraderWs(w http.ResponseWriter, r *http.Request) *websocket.Conn {
wsUpgrader := websocket.Upgrader{
// 设置消息发送缓冲区大小byte如果这个值设置得太小可能会导致服务端在发送大型消息时遇到问题
WriteBufferSize: 1024,
// 消息包启用压缩
EnableCompression: true,
// ws握手超时时间
HandshakeTimeout: 5 * time.Second,
// ws握手过程中允许跨域
CheckOrigin: func(r *http.Request) bool {
return true
},
}
conn, err := wsUpgrader.Upgrade(w, r, nil)
if err != nil {
logger.Errorf("ws Upgrade err: %s", err.Error())
}
return conn
}
// NewClient 新建客户端 uid 登录用户ID
func (s *WSImpl) NewClient(uid string, groupIDs []string, conn *websocket.Conn) *model.WSClient {
// clientID也可以用其他方式生成只要能保证在所有服务端中都能保证唯一即可
clientID := generate.Code(16)
wsClient := &model.WSClient{
ID: clientID,
Conn: conn,
LastHeartbeat: time.Now().UnixMilli(),
BindUid: uid,
SubGroup: groupIDs,
MsgChan: make(chan []byte, 100),
StopChan: make(chan struct{}, 1), // 请求卡死循环标记
}
// 存入客户端
WsClients.Store(clientID, wsClient)
// 存入用户持有客户端
if uid != "" {
if v, ok := WsUsers.Load(uid); ok {
uidClientIds := v.(*[]string)
*uidClientIds = append(*uidClientIds, clientID)
} else {
WsUsers.Store(uid, &[]string{clientID})
}
}
// 存入用户订阅组
if uid != "" && len(groupIDs) > 0 {
for _, groupID := range groupIDs {
if v, ok := WsGroup.Load(groupID); ok {
groupUIDs := v.(*[]string)
*groupUIDs = append(*groupUIDs, uid)
} else {
WsGroup.Store(groupID, &[]string{uid})
}
}
}
go s.clientRead(wsClient)
go s.clientWrite(wsClient)
// 发客户端id确认是否连接
msgByte, _ := json.Marshal(result.OkData(map[string]string{
"clientId": clientID,
}))
wsClient.MsgChan <- msgByte
return wsClient
}
// clientRead 客户端读取消息
func (s *WSImpl) clientRead(wsClient *model.WSClient) {
for {
// 读取消息
messageType, msg, err := wsClient.Conn.ReadMessage()
if err != nil {
logger.Warnf("ws ReadMessage UID %s err: %s", wsClient.BindUid, err.Error())
s.CloseClient(wsClient.ID)
return
}
// 文本和二进制类型只处理文本json
if messageType == websocket.TextMessage {
var reqMsg model.WSRequest
err := json.Unmarshal(msg, &reqMsg)
if err != nil {
msgByte, _ := json.Marshal(result.ErrMsg("message format not supported"))
wsClient.MsgChan <- msgByte
} else {
err := NewWSReceiveImpl.Receive(wsClient, reqMsg)
if err != nil {
logger.Warnf("ws ReceiveMessage UID %s err: %s", wsClient.BindUid, err.Error())
msgByte, _ := json.Marshal(result.ErrMsg(err.Error()))
wsClient.MsgChan <- msgByte
}
}
}
}
}
// clientWrite 客户端写入消息
func (s *WSImpl) clientWrite(wsClient *model.WSClient) {
ticker := time.NewTicker(time.Second * 5) // 设置心跳间隔为 5 秒钟
defer ticker.Stop()
for {
select {
case <-ticker.C:
wsClient.LastHeartbeat = time.Now().UnixMilli()
// 发送 Ping 消息
err := wsClient.Conn.WriteMessage(websocket.PingMessage, []byte{})
if err != nil {
logger.Warnf("ws PingMessage UID %s err: %s", wsClient.BindUid, err.Error())
s.CloseClient(wsClient.ID)
return
}
case msg := <-wsClient.MsgChan:
// 发送消息
err := wsClient.Conn.WriteMessage(websocket.TextMessage, msg)
if err != nil {
logger.Warnf("ws WriteMessage UID %s err: %s", wsClient.BindUid, err.Error())
s.CloseClient(wsClient.ID)
return
}
}
}
}
// CloseClient 客户端关闭
func (s *WSImpl) CloseClient(clientID string) {
v, ok := WsClients.Load(clientID)
if !ok {
return
}
client := v.(*model.WSClient)
defer func() {
client.Conn.WriteMessage(websocket.CloseMessage, []byte{})
client.Conn.Close()
client.StopChan <- struct{}{}
WsClients.Delete(clientID)
}()
// 客户端断线时自动踢出Uid绑定列表
if client.BindUid != "" {
if clientIds, ok := WsUsers.Load(client.BindUid); ok {
uidClientIds := clientIds.(*[]string)
if len(*uidClientIds) > 0 {
for i, clientId := range *uidClientIds {
if clientId == client.ID {
*uidClientIds = append((*uidClientIds)[:i], (*uidClientIds)[i+1:]...)
}
}
}
}
}
// 客户端断线时自动踢出已加入的组
if client.BindUid != "" && len(client.SubGroup) > 0 {
for _, groupID := range client.SubGroup {
uids, ok := WsGroup.Load(groupID)
if !ok {
continue
}
groupUIDs := uids.(*[]string)
if len(*groupUIDs) > 0 {
for i, v := range *groupUIDs {
if v == client.BindUid {
*groupUIDs = append((*groupUIDs)[:i], (*groupUIDs)[i+1:]...)
}
}
}
}
}
}

View File

@@ -0,0 +1,9 @@
package service
import "ems.agt/src/modules/ws/model"
// IWSReceive WebSocket消息接收处理 服务层接口
type IWSReceive interface {
// Receive 接收处理
Receive(client *model.WSClient, reqMsg model.WSRequest) error
}

View File

@@ -0,0 +1,38 @@
package service
import (
"fmt"
"ems.agt/src/modules/ws/model"
"ems.agt/src/modules/ws/processor"
)
// 实例化服务层 WSReceiveImpl 结构体
var NewWSReceiveImpl = &WSReceiveImpl{}
// WSReceiveImpl WebSocket消息接收处理 服务层处理
type WSReceiveImpl struct{}
// Receive 接收处理
func (s *WSReceiveImpl) Receive(client *model.WSClient, reqMsg model.WSRequest) error {
if reqMsg.RequestID == "" {
return fmt.Errorf("message requestId is required")
}
switch reqMsg.Type {
case "ps":
res, err := processor.GetProcessData(reqMsg.RequestID, reqMsg.Data)
if err != nil {
return err
}
client.MsgChan <- res
case "net":
res, err := processor.GetNetConnections(reqMsg.RequestID, reqMsg.Data)
if err != nil {
return err
}
client.MsgChan <- res
default:
return fmt.Errorf("message type not supported")
}
return nil
}

View File

@@ -0,0 +1,10 @@
package service
// IWSSend WebSocket消息发送处理 服务层接口
type IWSSend interface {
// ByClientID 给已知客户端发消息
ByClientID(clientID string, data any) error
// ByGroupID 给订阅组的用户发送消息
ByGroupID(gid string, data any) error
}

View File

@@ -0,0 +1,72 @@
package service
import (
"encoding/json"
"fmt"
"ems.agt/src/modules/ws/model"
)
const (
// 组号-其他
GROUP_OTHER = "0"
// 组号-指标
GROUP_KPI = "1000"
// 组号-会话记录
GROUP_CDR = "1005"
)
// 实例化服务层 WSSendImpl 结构体
var NewWSSendImpl = &WSSendImpl{}
// IWSSend WebSocket消息发送处理 服务层处理
type WSSendImpl struct{}
// ByClientID 给已知客户端发消息
func (s *WSSendImpl) ByClientID(clientID string, data any) error {
v, ok := WsClients.Load(clientID)
if !ok {
return fmt.Errorf("no fount client ID: %s", clientID)
}
dataByte, err := json.Marshal(data)
if err != nil {
return err
}
client := v.(*model.WSClient)
client.MsgChan <- dataByte
return nil
}
// ByGroupID 给订阅组的用户发送消息
func (s *WSSendImpl) ByGroupID(groupID string, data any) error {
uids, ok := WsGroup.Load(groupID)
if !ok {
return fmt.Errorf("no fount Group ID: %s", groupID)
}
groupUids := uids.(*[]string)
// 群组中没有成员
if len(*groupUids) == 0 {
return fmt.Errorf("no members in the group")
}
// 在群组中找到对应的 uid
for _, uid := range *groupUids {
clientIds, ok := WsUsers.Load(uid)
if !ok {
continue
}
// 在用户中找到客户端并发送
uidClientIds := clientIds.(*[]string)
for _, clientId := range *uidClientIds {
err := s.ByClientID(clientId, data)
if err != nil {
continue
}
}
}
return nil
}

30
src/modules/ws/ws.go Normal file
View File

@@ -0,0 +1,30 @@
package ws
import (
"ems.agt/src/framework/logger"
"ems.agt/src/framework/middleware"
"ems.agt/src/framework/middleware/collectlogs"
"ems.agt/src/modules/ws/controller"
"github.com/gin-gonic/gin"
)
// 模块路由注册
func Setup(router *gin.Engine) {
logger.Infof("开始加载 ====> ws 模块路由")
// WebSocket 协议
wsGroup := router.Group("/ws")
{
wsGroup.GET("",
middleware.PreAuthorize(nil),
collectlogs.OperateLog(collectlogs.OptionNew("WS 订阅", collectlogs.BUSINESS_TYPE_OTHER)),
controller.NewWSController.WS,
)
wsGroup.GET("/test",
middleware.PreAuthorize(nil),
controller.NewWSController.Test,
)
}
}