215 lines
5.5 KiB
Go
215 lines
5.5 KiB
Go
package service
|
||
|
||
import (
|
||
"encoding/json"
|
||
"net/http"
|
||
"sync"
|
||
"time"
|
||
|
||
"ems.agt/src/framework/logger"
|
||
"ems.agt/src/framework/utils/generate"
|
||
"ems.agt/src/framework/vo/result"
|
||
"ems.agt/src/modules/ws/model"
|
||
"github.com/gorilla/websocket"
|
||
)
|
||
|
||
// Package-level registries shared by every WebSocket connection.
var (
	// WsClients holds the connected ws clients [clientId: client].
	WsClients sync.Map

	// WsUsers holds, per user, a pointer to the list of client IDs that
	// belong to that user [uid: clientIds].
	WsUsers sync.Map

	// WsGroup holds, per group, a pointer to the list of user IDs that are
	// subscribed to that group [groupID: uids].
	WsGroup sync.Map
)
|
||
|
||
// WSImpl is the service layer that handles WebSocket communication.
type WSImpl struct{}

// NewWSImpl is the package-level instance of the WSImpl service.
var NewWSImpl = new(WSImpl)
|
||
|
||
// UpgraderWs http升级ws请求
|
||
func (s *WSImpl) UpgraderWs(w http.ResponseWriter, r *http.Request) *websocket.Conn {
|
||
wsUpgrader := websocket.Upgrader{
|
||
Subprotocols: []string{"omc-ws"},
|
||
// 设置消息发送缓冲区大小(byte),如果这个值设置得太小,可能会导致服务端在发送大型消息时遇到问题
|
||
WriteBufferSize: 1024,
|
||
// 消息包启用压缩
|
||
EnableCompression: true,
|
||
// ws握手超时时间
|
||
HandshakeTimeout: 5 * time.Second,
|
||
// ws握手过程中允许跨域
|
||
CheckOrigin: func(r *http.Request) bool {
|
||
return true
|
||
},
|
||
}
|
||
conn, err := wsUpgrader.Upgrade(w, r, nil)
|
||
if err != nil {
|
||
logger.Errorf("ws Upgrade err: %s", err.Error())
|
||
}
|
||
return conn
|
||
}
|
||
|
||
// NewClient 新建客户端 uid 登录用户ID
|
||
func (s *WSImpl) NewClient(uid string, groupIDs []string, conn *websocket.Conn) *model.WSClient {
|
||
// clientID也可以用其他方式生成,只要能保证在所有服务端中都能保证唯一即可
|
||
clientID := generate.Code(16)
|
||
|
||
wsClient := &model.WSClient{
|
||
ID: clientID,
|
||
Conn: conn,
|
||
LastHeartbeat: time.Now().UnixMilli(),
|
||
BindUid: uid,
|
||
SubGroup: groupIDs,
|
||
MsgChan: make(chan []byte, 100),
|
||
StopChan: make(chan struct{}, 1), // 卡死循环标记
|
||
}
|
||
|
||
// 存入客户端
|
||
WsClients.Store(clientID, wsClient)
|
||
|
||
// 存入用户持有客户端
|
||
if uid != "" {
|
||
if v, ok := WsUsers.Load(uid); ok {
|
||
uidClientIds := v.(*[]string)
|
||
*uidClientIds = append(*uidClientIds, clientID)
|
||
} else {
|
||
WsUsers.Store(uid, &[]string{clientID})
|
||
}
|
||
}
|
||
|
||
// 存入用户订阅组
|
||
if uid != "" && len(groupIDs) > 0 {
|
||
for _, groupID := range groupIDs {
|
||
if v, ok := WsGroup.Load(groupID); ok {
|
||
groupUIDs := v.(*[]string)
|
||
*groupUIDs = append(*groupUIDs, uid)
|
||
} else {
|
||
WsGroup.Store(groupID, &[]string{uid})
|
||
}
|
||
}
|
||
}
|
||
|
||
go s.clientRead(wsClient)
|
||
go s.clientWrite(wsClient)
|
||
|
||
// 发客户端id确认是否连接
|
||
msgByte, _ := json.Marshal(result.OkData(map[string]string{
|
||
"clientId": clientID,
|
||
}))
|
||
wsClient.MsgChan <- msgByte
|
||
|
||
return wsClient
|
||
}
|
||
|
||
// clientRead 客户端读取消息
|
||
func (s *WSImpl) clientRead(wsClient *model.WSClient) {
|
||
for {
|
||
// 读取消息
|
||
messageType, msg, err := wsClient.Conn.ReadMessage()
|
||
if err != nil {
|
||
logger.Warnf("ws ReadMessage UID %s err: %s", wsClient.BindUid, err.Error())
|
||
s.CloseClient(wsClient.ID)
|
||
return
|
||
}
|
||
|
||
// 文本和二进制类型,只处理文本json
|
||
if messageType == websocket.TextMessage {
|
||
var reqMsg model.WSRequest
|
||
err := json.Unmarshal(msg, &reqMsg)
|
||
// fmt.Println(messageType, string(msg))
|
||
if err != nil {
|
||
msgByte, _ := json.Marshal(result.ErrMsg("message format not supported"))
|
||
wsClient.MsgChan <- msgByte
|
||
} else {
|
||
err := NewWSReceiveImpl.Receive(wsClient, reqMsg)
|
||
if err != nil {
|
||
logger.Warnf("ws ReceiveMessage UID %s err: %s", wsClient.BindUid, err.Error())
|
||
msgByte, _ := json.Marshal(result.ErrMsg(err.Error()))
|
||
wsClient.MsgChan <- msgByte
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// clientWrite 客户端写入消息
|
||
func (s *WSImpl) clientWrite(wsClient *model.WSClient) {
|
||
ticker := time.NewTicker(time.Second * 5) // 设置心跳间隔为 5 秒钟
|
||
defer ticker.Stop()
|
||
for {
|
||
select {
|
||
case <-ticker.C:
|
||
wsClient.LastHeartbeat = time.Now().UnixMilli()
|
||
// 发送 Ping 消息
|
||
err := wsClient.Conn.WriteMessage(websocket.PingMessage, []byte{})
|
||
if err != nil {
|
||
logger.Warnf("ws PingMessage UID %s err: %s", wsClient.BindUid, err.Error())
|
||
s.CloseClient(wsClient.ID)
|
||
return
|
||
}
|
||
case msg := <-wsClient.MsgChan:
|
||
// 发送消息
|
||
err := wsClient.Conn.WriteMessage(websocket.TextMessage, msg)
|
||
if err != nil {
|
||
logger.Warnf("ws WriteMessage UID %s err: %s", wsClient.BindUid, err.Error())
|
||
s.CloseClient(wsClient.ID)
|
||
return
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// CloseClient 客户端关闭
|
||
func (s *WSImpl) CloseClient(clientID string) {
|
||
v, ok := WsClients.Load(clientID)
|
||
if !ok {
|
||
return
|
||
}
|
||
|
||
client := v.(*model.WSClient)
|
||
defer func() {
|
||
client.Conn.WriteMessage(websocket.CloseMessage, []byte{})
|
||
client.Conn.Close()
|
||
client.StopChan <- struct{}{}
|
||
WsClients.Delete(clientID)
|
||
}()
|
||
|
||
// 客户端断线时自动踢出Uid绑定列表
|
||
if client.BindUid != "" {
|
||
if clientIds, ok := WsUsers.Load(client.BindUid); ok {
|
||
uidClientIds := clientIds.(*[]string)
|
||
if len(*uidClientIds) > 0 {
|
||
tempClientIds := make([]string, 0, len(*uidClientIds))
|
||
for _, v := range *uidClientIds {
|
||
if v != client.ID {
|
||
tempClientIds = append(tempClientIds, v)
|
||
}
|
||
}
|
||
*uidClientIds = tempClientIds
|
||
}
|
||
}
|
||
}
|
||
|
||
// 客户端断线时自动踢出已加入的组
|
||
if client.BindUid != "" && len(client.SubGroup) > 0 {
|
||
for _, groupID := range client.SubGroup {
|
||
uids, ok := WsGroup.Load(groupID)
|
||
if !ok {
|
||
continue
|
||
}
|
||
|
||
groupUIDs := uids.(*[]string)
|
||
if len(*groupUIDs) > 0 {
|
||
tempUIDs := make([]string, 0, len(*groupUIDs))
|
||
for _, v := range *groupUIDs {
|
||
if v != client.BindUid {
|
||
tempUIDs = append(tempUIDs, v)
|
||
}
|
||
}
|
||
*groupUIDs = tempUIDs
|
||
}
|
||
}
|
||
}
|
||
}
|