style: 变更ws模块函数实例命名

This commit is contained in:
TsMask
2024-09-24 11:51:46 +08:00
parent 0287852470
commit 67caba4379
11 changed files with 558 additions and 618 deletions

View File

@@ -1,32 +1,220 @@
package service
import (
"encoding/json"
"net/http"
"sync"
"time"
"be.ems/src/framework/logger"
"be.ems/src/framework/utils/generate"
"be.ems/src/framework/vo/result"
"be.ems/src/modules/ws/model"
"github.com/gorilla/websocket"
)
// IWS WebSocket通信 服务层接口
type IWS interface {
// UpgraderWs http升级ws请求
UpgraderWs(w http.ResponseWriter, r *http.Request) *websocket.Conn
var (
wsClients sync.Map // ws客户端 [clientId: client]
wsUsers sync.Map // ws用户对应的多个客户端id [uid:clientIds]
wsGroup sync.Map // ws组对应的多个客户端id [groupId:clientIds]
)
// ClientCreate 客户端新建
//
// uid 登录用户ID
// groupIDs 用户订阅组
// conn ws连接实例
// childConn 子连接实例
ClientCreate(uid string, groupIDs []string, conn *websocket.Conn, childConn any) *model.WSClient
// NewWS 实例化服务层 WS 结构体
var NewWS = &WS{}
// ClientClose 客户端关闭
ClientClose(clientID string)
// WS WebSocket通信 服务层处理
type WS struct{}
// ClientReadListen 客户端读取消息监听
// receiveType 根据接收类型进行消息处理
ClientReadListen(wsClient *model.WSClient, receiveType int)
// ClientWriteListen 客户端写入消息监听
ClientWriteListen(wsClient *model.WSClient)
// UpgraderWs http升级ws请求
func (s *WS) UpgraderWs(w http.ResponseWriter, r *http.Request) *websocket.Conn {
wsUpgrader := websocket.Upgrader{
Subprotocols: []string{"omc-ws"},
// 设置消息发送缓冲区大小byte如果这个值设置得太小可能会导致服务端在发送大型消息时遇到问题
WriteBufferSize: 1024,
// 消息包启用压缩
EnableCompression: true,
// ws握手超时时间
HandshakeTimeout: 5 * time.Second,
// ws握手过程中允许跨域
CheckOrigin: func(r *http.Request) bool {
return true
},
}
conn, err := wsUpgrader.Upgrade(w, r, nil)
if err != nil {
logger.Errorf("ws Upgrade err: %s", err.Error())
}
return conn
}
// ClientCreate 客户端新建
//
// uid 登录用户ID
// groupIDs 用户订阅组
// conn ws连接实例
// childConn 子连接实例
func (s *WS) ClientCreate(uid string, groupIDs []string, conn *websocket.Conn, childConn any) *model.WSClient {
// clientID也可以用其他方式生成只要能保证在所有服务端中都能保证唯一即可
clientID := generate.Code(16)
wsClient := &model.WSClient{
ID: clientID,
Conn: conn,
LastHeartbeat: time.Now().UnixMilli(),
BindUid: uid,
SubGroup: groupIDs,
MsgChan: make(chan []byte, 100),
StopChan: make(chan struct{}, 1), // 卡死循环标记
ChildConn: childConn,
}
// 存入客户端
wsClients.Store(clientID, wsClient)
// 存入用户持有客户端
if uid != "" {
if v, ok := wsUsers.Load(uid); ok {
uidClientIds := v.(*[]string)
*uidClientIds = append(*uidClientIds, clientID)
} else {
wsUsers.Store(uid, &[]string{clientID})
}
}
// 存入用户订阅组
if uid != "" && len(groupIDs) > 0 {
for _, groupID := range groupIDs {
if v, ok := wsGroup.Load(groupID); ok {
groupClientIds := v.(*[]string)
*groupClientIds = append(*groupClientIds, clientID)
} else {
wsGroup.Store(groupID, &[]string{clientID})
}
}
}
return wsClient
}
// ClientClose 客户端关闭
func (s *WS) ClientClose(clientID string) {
v, ok := wsClients.Load(clientID)
if !ok {
return
}
client := v.(*model.WSClient)
defer func() {
client.MsgChan <- []byte("ws:close")
client.StopChan <- struct{}{}
client.Conn.Close()
wsClients.Delete(clientID)
}()
// 客户端断线时自动踢出Uid绑定列表
if client.BindUid != "" {
if v, ok := wsUsers.Load(client.BindUid); ok {
uidClientIds := v.(*[]string)
if len(*uidClientIds) > 0 {
tempClientIds := make([]string, 0, len(*uidClientIds))
for _, v := range *uidClientIds {
if v != client.ID {
tempClientIds = append(tempClientIds, v)
}
}
*uidClientIds = tempClientIds
}
}
}
// 客户端断线时自动踢出已加入的组
if len(client.SubGroup) > 0 {
for _, groupID := range client.SubGroup {
v, ok := wsGroup.Load(groupID)
if !ok {
continue
}
groupClientIds := v.(*[]string)
if len(*groupClientIds) > 0 {
tempClientIds := make([]string, 0, len(*groupClientIds))
for _, v := range *groupClientIds {
if v != client.ID {
tempClientIds = append(tempClientIds, v)
}
}
*groupClientIds = tempClientIds
}
}
}
}
// ClientReadListen 客户端读取消息监听
// receiveType 根据接收类型进行消息处理
func (s *WS) ClientReadListen(wsClient *model.WSClient, receiveType int) {
defer func() {
if err := recover(); err != nil {
logger.Errorf("ws ReadMessage Panic Error: %v", err)
}
}()
for {
// 读取消息
messageType, msg, err := wsClient.Conn.ReadMessage()
if err != nil {
logger.Warnf("ws ReadMessage UID %s err: %s", wsClient.BindUid, err.Error())
s.ClientClose(wsClient.ID)
return
}
// fmt.Println(messageType, string(msg))
// 文本 只处理文本json
if messageType == websocket.TextMessage {
var reqMsg model.WSRequest
if err := json.Unmarshal(msg, &reqMsg); err != nil {
msgByte, _ := json.Marshal(result.ErrMsg("message format json error"))
wsClient.MsgChan <- msgByte
continue
}
// 接收器处理
switch receiveType {
case ReceiveCommont:
go NewWSReceive.Commont(wsClient, reqMsg)
case ReceiveShell:
go NewWSReceive.Shell(wsClient, reqMsg)
case ReceiveShellView:
go NewWSReceive.ShellView(wsClient, reqMsg)
case ReceiveTelnet:
go NewWSReceive.Telnet(wsClient, reqMsg)
}
}
}
}
// ClientWriteListen 客户端写入消息监听
func (s *WS) ClientWriteListen(wsClient *model.WSClient) {
defer func() {
if err := recover(); err != nil {
logger.Errorf("ws WriteMessage Panic Error: %v", err)
}
}()
// 发客户端id确认是否连接
msgByte, _ := json.Marshal(result.OkData(map[string]string{
"clientId": wsClient.ID,
}))
wsClient.MsgChan <- msgByte
// 消息发送监听
for msg := range wsClient.MsgChan {
// 关闭句柄
if string(msg) == "ws:close" {
wsClient.Conn.WriteMessage(websocket.CloseMessage, []byte{})
return
}
// 发送消息
err := wsClient.Conn.WriteMessage(websocket.TextMessage, msg)
if err != nil {
logger.Warnf("ws WriteMessage UID %s err: %s", wsClient.BindUid, err.Error())
s.ClientClose(wsClient.ID)
return
}
wsClient.LastHeartbeat = time.Now().UnixMilli()
}
}