From 2c1f8c75fa5e6965c1f031525c88baf30c69c501 Mon Sep 17 00:00:00 2001 From: TsMask <340112800@qq.com> Date: Tue, 6 Aug 2024 15:06:57 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20ws=20=E8=BF=9E=E6=8E=A5write=E9=87=8A?= =?UTF-8?q?=E6=94=BEgoroutune?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/modules/ws/service/ws.impl.go | 33 +++++++++++++++----------- src/modules/ws/service/ws_send.impl.go | 6 ++--- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/src/modules/ws/service/ws.impl.go b/src/modules/ws/service/ws.impl.go index 71c2615d..9a8aa5da 100644 --- a/src/modules/ws/service/ws.impl.go +++ b/src/modules/ws/service/ws.impl.go @@ -15,11 +15,11 @@ import ( var ( // ws客户端 [clientId: client] - WsClients sync.Map + wsClients sync.Map // ws用户对应的多个客户端id [uid:clientIds] - WsUsers sync.Map + wsUsers sync.Map // ws组对应的多个用户id [groupID:uids] - WsGroup sync.Map + wsGroup sync.Map ) // 实例化服务层 WSImpl 结构体 @@ -72,22 +72,22 @@ func (s *WSImpl) NewClient(uid string, groupIDs []string, conn *websocket.Conn, } // 存入客户端 - WsClients.Store(clientID, wsClient) + wsClients.Store(clientID, wsClient) // 存入用户持有客户端 if uid != "" { - if v, ok := WsUsers.Load(uid); ok { + if v, ok := wsUsers.Load(uid); ok { uidClientIds := v.(*[]string) *uidClientIds = append(*uidClientIds, clientID) } else { - WsUsers.Store(uid, &[]string{clientID}) + wsUsers.Store(uid, &[]string{clientID}) } } // 存入用户订阅组 if uid != "" && len(groupIDs) > 0 { for _, groupID := range groupIDs { - if v, ok := WsGroup.Load(groupID); ok { + if v, ok := wsGroup.Load(groupID); ok { groupUIDs := v.(*[]string) // 避免同组内相同用户 hasUid := false @@ -101,7 +101,7 @@ func (s *WSImpl) NewClient(uid string, groupIDs []string, conn *websocket.Conn, *groupUIDs = append(*groupUIDs, uid) } } else { - WsGroup.Store(groupID, &[]string{uid}) + wsGroup.Store(groupID, &[]string{uid}) } } } @@ -158,6 +158,11 @@ func (s *WSImpl) clientWrite(wsClient *model.WSClient) { } }() for msg := range wsClient.MsgChan { + // 关闭句柄 + if string(msg) == "ws:close" { + wsClient.Conn.WriteMessage(websocket.CloseMessage, []byte{}) + return + } // 发送消息 err := wsClient.Conn.WriteMessage(websocket.TextMessage, msg) if err != nil { @@ -171,22 +176,22 @@ func (s *WSImpl) clientWrite(wsClient *model.WSClient) { // CloseClient 客户端关闭 func (s *WSImpl) CloseClient(clientID string) { - v, ok := WsClients.Load(clientID) + v, ok := wsClients.Load(clientID) if !ok { return } client := v.(*model.WSClient) defer func() { - client.Conn.WriteMessage(websocket.CloseMessage, []byte{}) - client.Conn.Close() - WsClients.Delete(clientID) + client.MsgChan <- []byte("ws:close") client.StopChan <- struct{}{} + client.Conn.Close() + wsClients.Delete(clientID) }() // 客户端断线时自动踢出Uid绑定列表 if client.BindUid != "" { - if clientIds, ok := WsUsers.Load(client.BindUid); ok { + if clientIds, ok := wsUsers.Load(client.BindUid); ok { uidClientIds := clientIds.(*[]string) if len(*uidClientIds) > 0 { tempClientIds := make([]string, 0, len(*uidClientIds)) @@ -203,7 +208,7 @@ func (s *WSImpl) CloseClient(clientID string) { // 客户端断线时自动踢出已加入的组 if client.BindUid != "" && len(client.SubGroup) > 0 { for _, groupID := range client.SubGroup { - uids, ok := WsGroup.Load(groupID) + uids, ok := wsGroup.Load(groupID) if !ok { continue } diff --git a/src/modules/ws/service/ws_send.impl.go b/src/modules/ws/service/ws_send.impl.go index 2eea685b..da71fe18 100644 --- a/src/modules/ws/service/ws_send.impl.go +++ b/src/modules/ws/service/ws_send.impl.go @@ -34,7 +34,7 @@ type WSSendImpl struct{} // ByClientID 给已知客户端发消息 func (s *WSSendImpl) ByClientID(clientID string, data any) error { - v, ok := WsClients.Load(clientID) + v, ok := wsClients.Load(clientID) if !ok { return fmt.Errorf("no fount client ID: %s", clientID) } @@ -55,7 +55,7 @@ func (s *WSSendImpl) ByClientID(clientID string, data any) error { // ByGroupID 给订阅组的用户发送消息 func (s *WSSendImpl) ByGroupID(groupID string, data any) error { - uids, ok := WsGroup.Load(groupID) + uids, ok := wsGroup.Load(groupID) if !ok { return fmt.Errorf("no fount Group ID: %s", groupID) } @@ -68,7 +68,7 @@ func (s *WSSendImpl) ByGroupID(groupID string, data any) error { // 在群组中找到对应的 uid for _, uid := range *groupUids { - clientIds, ok := WsUsers.Load(uid) + clientIds, ok := wsUsers.Load(uid) if !ok { continue }