220 lines
5.6 KiB
Go
220 lines
5.6 KiB
Go
package service
|
||
|
||
import (
|
||
"encoding/json"
|
||
"net/http"
|
||
"sync"
|
||
"time"
|
||
|
||
"be.ems/src/framework/logger"
|
||
"be.ems/src/framework/utils/generate"
|
||
"be.ems/src/framework/vo/result"
|
||
"be.ems/src/modules/ws/model"
|
||
|
||
"github.com/gorilla/websocket"
|
||
)
|
||
|
||
var (
|
||
wsClients sync.Map // ws客户端 [clientId: client]
|
||
wsUsers sync.Map // ws用户对应的多个客户端id [uid:clientIds]
|
||
wsGroup sync.Map // ws组对应的多个客户端id [groupId:clientIds]
|
||
)
|
||
|
||
// NewWS 实例化服务层 WS 结构体
|
||
var NewWS = &WS{}
|
||
|
||
// WS WebSocket通信 服务层处理
|
||
type WS struct{}
|
||
|
||
// UpgraderWs http升级ws请求
|
||
func (s *WS) UpgraderWs(w http.ResponseWriter, r *http.Request) *websocket.Conn {
|
||
wsUpgrader := websocket.Upgrader{
|
||
Subprotocols: []string{"omc-ws"},
|
||
// 设置消息发送缓冲区大小(byte),如果这个值设置得太小,可能会导致服务端在发送大型消息时遇到问题
|
||
WriteBufferSize: 4096,
|
||
// 消息包启用压缩
|
||
EnableCompression: true,
|
||
// ws握手超时时间
|
||
HandshakeTimeout: 5 * time.Second,
|
||
// ws握手过程中允许跨域
|
||
CheckOrigin: func(r *http.Request) bool {
|
||
return true
|
||
},
|
||
}
|
||
conn, err := wsUpgrader.Upgrade(w, r, nil)
|
||
if err != nil {
|
||
logger.Errorf("ws Upgrade err: %s", err.Error())
|
||
}
|
||
return conn
|
||
}
|
||
|
||
// ClientCreate 客户端新建
|
||
//
|
||
// uid 登录用户ID
|
||
// groupIDs 用户订阅组
|
||
// conn ws连接实例
|
||
// childConn 子连接实例
|
||
func (s *WS) ClientCreate(uid string, groupIDs []string, conn *websocket.Conn, childConn any) *model.WSClient {
|
||
// clientID也可以用其他方式生成,只要能保证在所有服务端中都能保证唯一即可
|
||
clientID := generate.Code(16)
|
||
|
||
wsClient := &model.WSClient{
|
||
ID: clientID,
|
||
Conn: conn,
|
||
LastHeartbeat: time.Now().UnixMilli(),
|
||
BindUid: uid,
|
||
SubGroup: groupIDs,
|
||
MsgChan: make(chan []byte, 100),
|
||
StopChan: make(chan struct{}, 1), // 卡死循环标记
|
||
ChildConn: childConn,
|
||
}
|
||
|
||
// 存入客户端
|
||
wsClients.Store(clientID, wsClient)
|
||
|
||
// 存入用户持有客户端
|
||
if uid != "" {
|
||
if v, ok := wsUsers.Load(uid); ok {
|
||
uidClientIds := v.(*[]string)
|
||
*uidClientIds = append(*uidClientIds, clientID)
|
||
} else {
|
||
wsUsers.Store(uid, &[]string{clientID})
|
||
}
|
||
}
|
||
|
||
// 存入用户订阅组
|
||
if uid != "" && len(groupIDs) > 0 {
|
||
for _, groupID := range groupIDs {
|
||
if v, ok := wsGroup.Load(groupID); ok {
|
||
groupClientIds := v.(*[]string)
|
||
*groupClientIds = append(*groupClientIds, clientID)
|
||
} else {
|
||
wsGroup.Store(groupID, &[]string{clientID})
|
||
}
|
||
}
|
||
}
|
||
|
||
return wsClient
|
||
}
|
||
|
||
// ClientClose 客户端关闭
|
||
func (s *WS) ClientClose(clientID string) {
|
||
v, ok := wsClients.Load(clientID)
|
||
if !ok {
|
||
return
|
||
}
|
||
|
||
client := v.(*model.WSClient)
|
||
defer func() {
|
||
client.MsgChan <- []byte("ws:close")
|
||
client.StopChan <- struct{}{}
|
||
client.Conn.Close()
|
||
wsClients.Delete(clientID)
|
||
}()
|
||
|
||
// 客户端断线时自动踢出Uid绑定列表
|
||
if client.BindUid != "" {
|
||
if v, ok := wsUsers.Load(client.BindUid); ok {
|
||
uidClientIds := v.(*[]string)
|
||
if len(*uidClientIds) > 0 {
|
||
tempClientIds := make([]string, 0, len(*uidClientIds))
|
||
for _, v := range *uidClientIds {
|
||
if v != client.ID {
|
||
tempClientIds = append(tempClientIds, v)
|
||
}
|
||
}
|
||
*uidClientIds = tempClientIds
|
||
}
|
||
}
|
||
}
|
||
|
||
// 客户端断线时自动踢出已加入的组
|
||
if len(client.SubGroup) > 0 {
|
||
for _, groupID := range client.SubGroup {
|
||
v, ok := wsGroup.Load(groupID)
|
||
if !ok {
|
||
continue
|
||
}
|
||
groupClientIds := v.(*[]string)
|
||
if len(*groupClientIds) > 0 {
|
||
tempClientIds := make([]string, 0, len(*groupClientIds))
|
||
for _, v := range *groupClientIds {
|
||
if v != client.ID {
|
||
tempClientIds = append(tempClientIds, v)
|
||
}
|
||
}
|
||
*groupClientIds = tempClientIds
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// ClientReadListen 客户端读取消息监听
|
||
// receiveFn 接收函数进行消息处理
|
||
func (s *WS) ClientReadListen(wsClient *model.WSClient, receiveFn func(*model.WSClient, model.WSRequest)) {
|
||
defer func() {
|
||
if err := recover(); err != nil {
|
||
logger.Errorf("ws ReadMessage Panic Error: %v", err)
|
||
}
|
||
}()
|
||
for {
|
||
// 读取消息
|
||
messageType, msg, err := wsClient.Conn.ReadMessage()
|
||
if err != nil {
|
||
logger.Warnf("ws ReadMessage UID %s err: %s", wsClient.BindUid, err.Error())
|
||
s.ClientClose(wsClient.ID)
|
||
return
|
||
}
|
||
// fmt.Println(messageType, string(msg))
|
||
|
||
// 文本 只处理文本json
|
||
if messageType == websocket.TextMessage {
|
||
var reqMsg model.WSRequest
|
||
if err := json.Unmarshal(msg, &reqMsg); err != nil {
|
||
msgByte, _ := json.Marshal(result.ErrMsg("message format json error"))
|
||
wsClient.MsgChan <- msgByte
|
||
continue
|
||
}
|
||
// 接收器处理
|
||
go receiveFn(wsClient, reqMsg)
|
||
}
|
||
}
|
||
}
|
||
|
||
// ClientWriteListen 客户端写入消息监听
|
||
// wsClient.MsgChan <- msgByte 写入消息
|
||
func (s *WS) ClientWriteListen(wsClient *model.WSClient) {
|
||
defer func() {
|
||
if err := recover(); err != nil {
|
||
logger.Errorf("ws WriteMessage Panic Error: %v", err)
|
||
}
|
||
}()
|
||
// 发客户端id确认是否连接
|
||
msgByte, _ := json.Marshal(result.OkData(map[string]string{
|
||
"clientId": wsClient.ID,
|
||
}))
|
||
wsClient.MsgChan <- msgByte
|
||
// 消息发送监听
|
||
for msg := range wsClient.MsgChan {
|
||
// PONG句柄
|
||
if string(msg) == "ws:pong" {
|
||
wsClient.LastHeartbeat = time.Now().UnixMilli()
|
||
wsClient.Conn.WriteMessage(websocket.PongMessage, []byte{})
|
||
continue
|
||
}
|
||
// 关闭句柄
|
||
if string(msg) == "ws:close" {
|
||
wsClient.Conn.WriteMessage(websocket.CloseMessage, []byte{})
|
||
return
|
||
}
|
||
|
||
// 发送消息
|
||
err := wsClient.Conn.WriteMessage(websocket.TextMessage, msg)
|
||
if err != nil {
|
||
logger.Warnf("ws WriteMessage UID %s err: %s", wsClient.BindUid, err.Error())
|
||
s.ClientClose(wsClient.ID)
|
||
return
|
||
}
|
||
}
|
||
}
|