package service import ( "encoding/json" "net/http" "sync" "time" "ems.agt/src/framework/logger" "ems.agt/src/framework/utils/generate" "ems.agt/src/framework/vo/result" "ems.agt/src/modules/ws/model" "github.com/gorilla/websocket" ) var ( // ws客户端 [clientId: client] WsClients = sync.Map{} // ws用户对应的多个客户端id [uid:clientIds] WsUsers = sync.Map{} // ws组对应的多个用户id [groupID:uids] WsGroup = sync.Map{} ) // 实例化服务层 WSImpl 结构体 var NewWSImpl = &WSImpl{} // WSImpl WebSocket通信 服务层处理 type WSImpl struct{} // UpgraderWs http升级ws请求 func (s *WSImpl) UpgraderWs(w http.ResponseWriter, r *http.Request) *websocket.Conn { wsUpgrader := websocket.Upgrader{ Subprotocols: []string{"omc-ws"}, // 设置消息发送缓冲区大小(byte),如果这个值设置得太小,可能会导致服务端在发送大型消息时遇到问题 WriteBufferSize: 1024, // 消息包启用压缩 EnableCompression: true, // ws握手超时时间 HandshakeTimeout: 5 * time.Second, // ws握手过程中允许跨域 CheckOrigin: func(r *http.Request) bool { return true }, } conn, err := wsUpgrader.Upgrade(w, r, nil) if err != nil { logger.Errorf("ws Upgrade err: %s", err.Error()) } return conn } // NewClient 新建客户端 uid 登录用户ID func (s *WSImpl) NewClient(uid string, groupIDs []string, conn *websocket.Conn) *model.WSClient { // clientID也可以用其他方式生成,只要能保证在所有服务端中都能保证唯一即可 clientID := generate.Code(16) wsClient := &model.WSClient{ ID: clientID, Conn: conn, LastHeartbeat: time.Now().UnixMilli(), BindUid: uid, SubGroup: groupIDs, MsgChan: make(chan []byte, 100), StopChan: make(chan struct{}, 1), // 卡死循环标记 } // 存入客户端 WsClients.Store(clientID, wsClient) // 存入用户持有客户端 if uid != "" { if v, ok := WsUsers.Load(uid); ok { uidClientIds := v.(*[]string) *uidClientIds = append(*uidClientIds, clientID) } else { WsUsers.Store(uid, &[]string{clientID}) } } // 存入用户订阅组 if uid != "" && len(groupIDs) > 0 { for _, groupID := range groupIDs { if v, ok := WsGroup.Load(groupID); ok { groupUIDs := v.(*[]string) *groupUIDs = append(*groupUIDs, uid) } else { WsGroup.Store(groupID, &[]string{uid}) } } } go s.clientRead(wsClient) go s.clientWrite(wsClient) // 发客户端id确认是否连接 msgByte, _ := json.Marshal(result.OkData(map[string]string{ "clientId": clientID, })) wsClient.MsgChan <- msgByte return wsClient } // clientRead 客户端读取消息 func (s *WSImpl) clientRead(wsClient *model.WSClient) { for { // 读取消息 messageType, msg, err := wsClient.Conn.ReadMessage() if err != nil { logger.Warnf("ws ReadMessage UID %s err: %s", wsClient.BindUid, err.Error()) s.CloseClient(wsClient.ID) return } // 文本和二进制类型,只处理文本json if messageType == websocket.TextMessage { var reqMsg model.WSRequest err := json.Unmarshal(msg, &reqMsg) // fmt.Println(messageType, string(msg)) if err != nil { msgByte, _ := json.Marshal(result.ErrMsg("message format not supported")) wsClient.MsgChan <- msgByte } else { err := NewWSReceiveImpl.Receive(wsClient, reqMsg) if err != nil { logger.Warnf("ws ReceiveMessage UID %s err: %s", wsClient.BindUid, err.Error()) msgByte, _ := json.Marshal(result.ErrMsg(err.Error())) wsClient.MsgChan <- msgByte } } } } } // clientWrite 客户端写入消息 func (s *WSImpl) clientWrite(wsClient *model.WSClient) { ticker := time.NewTicker(time.Second * 5) // 设置心跳间隔为 5 秒钟 defer ticker.Stop() for { select { case <-ticker.C: wsClient.LastHeartbeat = time.Now().UnixMilli() // 发送 Ping 消息 err := wsClient.Conn.WriteMessage(websocket.PingMessage, []byte{}) if err != nil { logger.Warnf("ws PingMessage UID %s err: %s", wsClient.BindUid, err.Error()) s.CloseClient(wsClient.ID) return } case msg := <-wsClient.MsgChan: // 发送消息 err := wsClient.Conn.WriteMessage(websocket.TextMessage, msg) if err != nil { logger.Warnf("ws WriteMessage UID %s err: %s", wsClient.BindUid, err.Error()) s.CloseClient(wsClient.ID) return } } } } // CloseClient 客户端关闭 func (s *WSImpl) CloseClient(clientID string) { v, ok := WsClients.Load(clientID) if !ok { return } client := v.(*model.WSClient) defer func() { client.Conn.WriteMessage(websocket.CloseMessage, []byte{}) client.Conn.Close() client.StopChan <- struct{}{} WsClients.Delete(clientID) }() // 客户端断线时自动踢出Uid绑定列表 if client.BindUid != "" { if clientIds, ok := WsUsers.Load(client.BindUid); ok { uidClientIds := clientIds.(*[]string) if len(*uidClientIds) > 0 { tempClientIds := make([]string, 0, len(*uidClientIds)) for _, v := range *uidClientIds { if v != client.ID { tempClientIds = append(tempClientIds, v) } } *uidClientIds = tempClientIds } } } // 客户端断线时自动踢出已加入的组 if client.BindUid != "" && len(client.SubGroup) > 0 { for _, groupID := range client.SubGroup { uids, ok := WsGroup.Load(groupID) if !ok { continue } groupUIDs := uids.(*[]string) if len(*groupUIDs) > 0 { tempUIDs := make([]string, 0, len(*groupUIDs)) for _, v := range *groupUIDs { if v != client.BindUid { tempUIDs = append(tempUIDs, v) } } *groupUIDs = tempUIDs } } } }