Files
be.ems/src/modules/ws/service/ws.impl.go
2024-01-25 10:40:00 +08:00

215 lines
5.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
"encoding/json"
"net/http"
"sync"
"time"
"ems.agt/src/framework/logger"
"ems.agt/src/framework/utils/generate"
"ems.agt/src/framework/vo/result"
"ems.agt/src/modules/ws/model"
"github.com/gorilla/websocket"
)
// Package-level connection indexes shared by the WebSocket service.
var (
	// WsClients maps a client ID to its client (NewClient stores *model.WSClient).
	WsClients sync.Map
	// WsUsers maps a user ID to the IDs of that user's clients (stored as *[]string).
	WsUsers sync.Map
	// WsGroup maps a group ID to the user IDs subscribed to it (stored as *[]string).
	WsGroup sync.Map
)
// WSImpl is the service layer for WebSocket communication.
type WSImpl struct{}

// NewWSImpl is the shared, ready-to-use instance of the WSImpl service layer.
var NewWSImpl = new(WSImpl)
// UpgraderWs upgrades the HTTP request r to a WebSocket connection.
//
// The upgrader offers the "omc-ws" subprotocol, attempts to negotiate
// per-message compression and uses a 5-second handshake timeout. On failure
// the error is logged and the returned connection is nil (gorilla/websocket
// returns a nil *Conn with the error), so callers must check for nil.
func (s *WSImpl) UpgraderWs(w http.ResponseWriter, r *http.Request) *websocket.Conn {
	var upgrader websocket.Upgrader
	upgrader.HandshakeTimeout = 5 * time.Second // WebSocket handshake timeout
	upgrader.Subprotocols = []string{"omc-ws"}
	upgrader.WriteBufferSize = 1024 // write buffer size in bytes
	upgrader.EnableCompression = true
	// Accept any Origin, i.e. cross-origin handshakes are allowed.
	// NOTE(review): if authentication relies on cookies this permits
	// cross-site WebSocket hijacking — confirm this is intended.
	upgrader.CheckOrigin = func(*http.Request) bool { return true }

	conn, err := upgrader.Upgrade(w, r, nil)
	if err == nil {
		return conn
	}
	logger.Errorf("ws Upgrade err: %s", err.Error())
	return conn
}
// NewClient registers a new WebSocket client and starts its read and write loops.
//
// Parameters:
//   - uid: ID of the logged-in user. When empty, the client is stored only in
//     WsClients and is not indexed in WsUsers or WsGroup.
//   - groupIDs: groups the client subscribes to; kept as the client's SubGroup.
//   - conn: the already-upgraded WebSocket connection.
//
// The generated client ID is queued as the first outgoing message so the peer
// can confirm the connection. The returned client is already live.
func (s *WSImpl) NewClient(uid string, groupIDs []string, conn *websocket.Conn) *model.WSClient {
	// Any ID scheme works as long as it is unique across all server instances.
	clientID := generate.Code(16)
	wsClient := &model.WSClient{
		ID:            clientID,
		Conn:          conn,
		LastHeartbeat: time.Now().UnixMilli(),
		BindUid:       uid,
		SubGroup:      groupIDs,
		MsgChan:       make(chan []byte, 100),
		StopChan:      make(chan struct{}, 1), // stop-loop signal
	}
	WsClients.Store(clientID, wsClient)

	if uid != "" {
		// LoadOrStore removes the check-then-act window of a separate Load
		// followed by Store. Without it, two concurrent first connections of the
		// same user (or first subscribers of the same group) could both Store a
		// fresh list, and one entry would be silently lost.
		//
		// NOTE(review): the appends below still mutate a shared slice in place
		// without a lock, so concurrent NewClient/CloseClient calls touching the
		// same uid or group can still race; a mutex around these indexes would
		// close that gap.
		if v, loaded := WsUsers.LoadOrStore(uid, &[]string{clientID}); loaded {
			clientIDs := v.(*[]string)
			*clientIDs = append(*clientIDs, clientID)
		}
		// The uid is appended once per subscribing client, so a user with
		// several clients in the same group appears once per client.
		for _, groupID := range groupIDs {
			if v, loaded := WsGroup.LoadOrStore(groupID, &[]string{uid}); loaded {
				groupUIDs := v.(*[]string)
				*groupUIDs = append(*groupUIDs, uid)
			}
		}
	}

	go s.clientRead(wsClient)
	go s.clientWrite(wsClient)

	// Send the client ID so the peer can confirm it is connected. The channel
	// was just created with capacity 100, so this send does not block.
	msgByte, err := json.Marshal(result.OkData(map[string]string{
		"clientId": clientID,
	}))
	if err != nil {
		// Do not queue a nil payload (it would go out as an empty frame).
		logger.Errorf("ws NewClient marshal clientId err: %s", err.Error())
	} else {
		wsClient.MsgChan <- msgByte
	}
	return wsClient
}
// clientRead is the read loop for wsClient's connection.
//
// Each text frame is decoded as a JSON model.WSRequest and handed to
// NewWSReceiveImpl.Receive. Decode or handler errors are reported back to the
// peer via MsgChan. Non-text frames are ignored. On a read error the loop logs
// it, closes the client and returns.
func (s *WSImpl) clientRead(wsClient *model.WSClient) {
	for {
		frameType, payload, readErr := wsClient.Conn.ReadMessage()
		if readErr != nil {
			logger.Warnf("ws ReadMessage UID %s err: %s", wsClient.BindUid, readErr.Error())
			s.CloseClient(wsClient.ID)
			return
		}

		// Of text and binary frames, only text (JSON) is handled.
		if frameType != websocket.TextMessage {
			continue
		}

		var req model.WSRequest
		if decodeErr := json.Unmarshal(payload, &req); decodeErr != nil {
			reply, _ := json.Marshal(result.ErrMsg("message format not supported"))
			wsClient.MsgChan <- reply
			continue
		}

		handleErr := NewWSReceiveImpl.Receive(wsClient, req)
		if handleErr == nil {
			continue
		}
		logger.Warnf("ws ReceiveMessage UID %s err: %s", wsClient.BindUid, handleErr.Error())
		reply, _ := json.Marshal(result.ErrMsg(handleErr.Error()))
		wsClient.MsgChan <- reply
	}
}
// clientWrite is the write loop for wsClient's connection.
//
// Every 5 seconds it sends a Ping frame and records the send time in
// LastHeartbeat. It also forwards each payload received on MsgChan as a text
// frame. On any write error it logs it, closes the client and returns.
func (s *WSImpl) clientWrite(wsClient *model.WSClient) {
	heartbeat := time.NewTicker(5 * time.Second) // heartbeat interval: 5 seconds
	defer heartbeat.Stop()

	for {
		var (
			frameType int
			payload   []byte
			logFormat string
		)
		select {
		case <-heartbeat.C:
			// LastHeartbeat tracks when the last ping was sent.
			wsClient.LastHeartbeat = time.Now().UnixMilli()
			frameType, payload = websocket.PingMessage, []byte{}
			logFormat = "ws PingMessage UID %s err: %s"
		case msg := <-wsClient.MsgChan:
			frameType, payload = websocket.TextMessage, msg
			logFormat = "ws WriteMessage UID %s err: %s"
		}

		if err := wsClient.Conn.WriteMessage(frameType, payload); err != nil {
			logger.Warnf(logFormat, wsClient.BindUid, err.Error())
			s.CloseClient(wsClient.ID)
			return
		}
	}
}
// CloseClient tears down the client identified by clientID.
//
// It removes the client from WsClients and drops its ID from the user's list
// in WsUsers. It also removes one occurrence of its uid from each subscribed
// group in WsGroup. Finally it sends a close frame, closes the connection and
// signals StopChan.
//
// Both the read and write loops call it on error, possibly at the same time.
// Only the first call for a given clientID does any work; later calls return
// immediately.
func (s *WSImpl) CloseClient(clientID string) {
	// LoadAndDelete makes teardown happen exactly once. With a separate Load
	// and a deferred Delete, concurrent callers could both proceed, and the
	// second send on the 1-slot StopChan could block that goroutine forever.
	entry, ok := WsClients.LoadAndDelete(clientID)
	if !ok {
		return
	}
	client := entry.(*model.WSClient)

	defer func() {
		// WriteControl may be called concurrently with the write loop's
		// WriteMessage; a plain WriteMessage here would be a concurrent write,
		// which gorilla/websocket does not allow. Best effort: the peer or the
		// connection may already be gone, so the error is ignored.
		_ = client.Conn.WriteControl(
			websocket.CloseMessage,
			websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""),
			time.Now().Add(time.Second),
		)
		client.Conn.Close()
		// Non-blocking: if a stop signal is already pending, another is redundant.
		select {
		case client.StopChan <- struct{}{}:
		default:
		}
	}()

	// Anonymous clients are not indexed by user or group.
	if client.BindUid == "" {
		return
	}

	// Drop this client's ID from the user's client list. A fresh slice is
	// built so readers holding the old slice are unaffected.
	if userEntry, found := WsUsers.Load(client.BindUid); found {
		clientIDs := userEntry.(*[]string)
		kept := make([]string, 0, len(*clientIDs))
		for _, id := range *clientIDs {
			if id != client.ID {
				kept = append(kept, id)
			}
		}
		*clientIDs = kept
	}

	// NewClient appends the uid to a group once per subscribing client. Remove
	// exactly one occurrence here, so the user's other live clients in the
	// same group keep receiving group messages.
	for _, groupID := range client.SubGroup {
		groupEntry, found := WsGroup.Load(groupID)
		if !found {
			continue
		}
		groupUIDs := groupEntry.(*[]string)
		for i, uid := range *groupUIDs {
			if uid != client.BindUid {
				continue
			}
			kept := make([]string, 0, len(*groupUIDs)-1)
			kept = append(kept, (*groupUIDs)[:i]...)
			kept = append(kept, (*groupUIDs)[i+1:]...)
			*groupUIDs = kept
			break
		}
	}
}