diff --git a/src/modules/trace/controller/trace_task.go b/src/modules/trace/controller/trace_task.go new file mode 100644 index 00000000..7598c567 --- /dev/null +++ b/src/modules/trace/controller/trace_task.go @@ -0,0 +1,131 @@ +package controller + +import ( + "strings" + + "be.ems/src/framework/i18n" + "be.ems/src/framework/utils/ctx" + "be.ems/src/framework/utils/parse" + "be.ems/src/framework/vo/result" + "be.ems/src/modules/trace/model" + traceService "be.ems/src/modules/trace/service" + "github.com/gin-gonic/gin" + "github.com/gin-gonic/gin/binding" +) + +// 实例化控制层 TraceTaskController 结构体 +var NewTraceTask = &TraceTaskController{ + traceTaskService: traceService.NewTraceTask, +} + +// 跟踪任务 +// +// PATH /task +type TraceTaskController struct { + // 跟踪_任务信息服务 + traceTaskService *traceService.TraceTask +} + +// 跟踪任务列表 +// +// GET /list +func (s *TraceTaskController) List(c *gin.Context) { + query := ctx.QueryMap(c) + + // 查询数据 + data := s.traceTaskService.SelectPage(query) + c.JSON(200, result.Ok(data)) +} + +// 跟踪任务信息 +// +// GET /:id +func (s *TraceTaskController) Info(c *gin.Context) { + language := ctx.AcceptLanguage(c) + id := c.Param("id") + if id == "" { + c.JSON(400, result.CodeMsg(400, i18n.TKey(language, "app.common.err400"))) + return + } + + data := s.traceTaskService.SelectById(id) + if data.ID == id { + c.JSON(200, result.OkData(data)) + return + } + c.JSON(200, result.Err(nil)) +} + +// 跟踪任务新增 +// +// POST / +func (s *TraceTaskController) Add(c *gin.Context) { + language := ctx.AcceptLanguage(c) + var body model.TraceTask + err := c.ShouldBindBodyWith(&body, binding.JSON) + if err != nil || body.ID != "" { + c.JSON(400, result.CodeMsg(400, i18n.TKey(language, "app.common.err400"))) + return + } + + body.CreateBy = ctx.LoginUserToUserName(c) + if err = s.traceTaskService.Insert(body); err != nil { + c.JSON(200, result.ErrMsg(i18n.TKey(language, err.Error()))) + return + } + c.JSON(200, result.Ok(nil)) +} + +// 跟踪任务修改 +// +// PUT / +func (s *TraceTaskController) Edit(c *gin.Context) { + language := ctx.AcceptLanguage(c) + var body model.TraceTask + err := c.ShouldBindBodyWith(&body, binding.JSON) + if err != nil || body.ID == "" { + c.JSON(400, result.CodeMsg(400, i18n.TKey(language, "app.common.err400"))) + return + } + + // 检查是否存在 + taskInfo := s.traceTaskService.SelectById(body.ID) + if taskInfo.ID != body.ID { + // 没有可访问任务信息数据! + c.JSON(200, result.ErrMsg(i18n.TKey(language, "task.noData"))) + return + } + + body.UpdateBy = ctx.LoginUserToUserName(c) + if err = s.traceTaskService.Update(body); err != nil { + c.JSON(200, result.ErrMsg(i18n.TKey(language, err.Error()))) + return + } + c.JSON(200, result.Ok(nil)) +} + +// 跟踪任务删除 +// +// DELETE /:ids +func (s *TraceTaskController) Remove(c *gin.Context) { + language := ctx.AcceptLanguage(c) + rowIds := c.Param("ids") + if rowIds == "" { + c.JSON(400, result.CodeMsg(400, i18n.TKey(language, "app.common.err400"))) + return + } + // 处理字符转id数组后去重 + ids := strings.Split(rowIds, ",") + uniqueIDs := parse.RemoveDuplicates(ids) + if len(uniqueIDs) <= 0 { + c.JSON(200, result.Err(nil)) + return + } + rows, err := s.traceTaskService.DeleteByIds(uniqueIDs) + if err != nil { + c.JSON(200, result.ErrMsg(i18n.TKey(language, err.Error()))) + return + } + msg := i18n.TTemplate(language, "app.common.deleteSuccess", map[string]any{"num": rows}) + c.JSON(200, result.OkMsg(msg)) +} diff --git a/src/modules/trace/model/trace_task.go b/src/modules/trace/model/trace_task.go new file mode 100644 index 00000000..752d4f70 --- /dev/null +++ b/src/modules/trace/model/trace_task.go @@ -0,0 +1,31 @@ +package model + +// TraceTask 跟踪_任务 +type TraceTask struct { + ID string `json:"id" gorm:"column:id;primaryKey;autoIncrement"` // 跟踪任务ID + TraceId string `json:"traceId" gorm:"trace_id"` // 任务编号 + TraceType string `json:"traceType" gorm:"trace_type"` // 1-Interface,2-Device,3-User + StartTime int64 `json:"startTime" gorm:"start_time"` // 开始时间 毫秒 + EndTime int64 `json:"endTime" gorm:"end_time"` // 结束时间 毫秒 + Interfaces string `json:"interfaces" gorm:"interfaces"` // 接口跟踪必须 例如 N8,N10 + IMSI string `json:"imsi" gorm:"imsi"` // 用户跟踪必须 + MSISDN string `json:"msisdn" gorm:"msisdn"` // 用户跟踪可选 + UeIp string `json:"ueIp" gorm:"ue_ip"` // 设备跟踪必须 IP + SrcIp string `json:"srcIp" gorm:"src_ip"` // 源地址IP + DstIp string `json:"dstIp" gorm:"dst_ip"` // 目标地址IP + SignalPort int64 `json:"signalPort" gorm:"signal_port"` // 地址IP端口 + CreateBy string `json:"createBy" gorm:"create_by"` // 创建者 + CreateTime int64 `json:"createTime" gorm:"create_time"` // 创建时间 + UpdateBy string `json:"updateBy" gorm:"update_by"` // 更新者 + UpdateTime int64 `json:"updateTime" gorm:"update_time"` // 更新时间 + Remark string `json:"remark" gorm:"remark"` // 备注 + NeType string `json:"neType" gorm:"ne_type"` // 网元类型 + NeId string `json:"neId" gorm:"ne_id"` // 网元ID + NotifyUrl string `json:"notifyUrl" gorm:"notify_url"` // 信息数据通知回调地址UDP 例如udp:192.168.5.58:29500 + FetchMsg string `json:"fetchMsg" gorm:"fetch_msg"` // 任务下发请求响应消息 +} + +// TableName 表名称 +func (*TraceTask) TableName() string { + return "trace_task" +} diff --git a/src/modules/trace/repository/trace_task.go b/src/modules/trace/repository/trace_task.go new file mode 100644 index 00000000..39f3aa30 --- /dev/null +++ b/src/modules/trace/repository/trace_task.go @@ -0,0 +1,358 @@ +package repository + +import ( + "fmt" + "strings" + "time" + + "be.ems/src/framework/datasource" + "be.ems/src/framework/logger" + "be.ems/src/framework/utils/parse" + "be.ems/src/framework/utils/repo" + "be.ems/src/modules/trace/model" +) + +// 实例化数据层 TraceTask 结构体 +var NewTraceTask = &TraceTask{ + selectSql: `select id, trace_id, trace_type, start_time, end_time, + interfaces, imsi, msisdn, + ue_ip, src_ip, dst_ip, signal_port, + create_by, create_time, update_by, update_time, remark, + ne_type, ne_id, notify_url, fetch_msg + from trace_task`, + + resultMap: map[string]string{ + "id": "ID", + "trace_id": "TraceId", + "trace_type": "TraceType", + "start_time": "StartTime", + "end_time": "EndTime", + "interfaces": "Interfaces", + "imsi": "IMSI", + "msisdn": "MSISDN", + "ue_ip": "UeIp", + "src_ip": "SrcIp", + "dst_ip": "DstIp", + "signal_port": "SignalPort", + "create_by": "CreateBy", + "create_time": "CreateTime", + "update_by": "UpdateBy", + "update_time": "UpdateTime", + "remark": "Remark", + "ne_type": "NeType", + "ne_id": "NeId", + "notify_url": "NotifyUrl", + "fetch_msg": "FetchMsg", + }, +} + +// TraceTask 跟踪_任务 数据层处理 +type TraceTask struct { + // 查询视图对象SQL + selectSql string + // 结果字段与实体映射 + resultMap map[string]string +} + +// convertResultRows 将结果记录转实体结果组 +func (r *TraceTask) convertResultRows(rows []map[string]any) []model.TraceTask { + arr := make([]model.TraceTask, 0) + for _, row := range rows { + item := model.TraceTask{} + for key, value := range row { + if keyMapper, ok := r.resultMap[key]; ok { + repo.SetFieldValue(&item, keyMapper, value) + } + } + arr = append(arr, item) + } + return arr +} + +// SelectPage 根据条件分页查询 +func (r *TraceTask) SelectPage(query map[string]any) map[string]any { + // 查询条件拼接 + var conditions []string + var params []any + if v, ok := query["neType"]; ok && v != "" { + conditions = append(conditions, "ne_type = ?") + params = append(params, v) + } + if v, ok := query["imsi"]; ok && v != "" { + conditions = append(conditions, "imsi like concat(?, '%')") + params = append(params, v) + } + if v, ok := query["msisdn"]; ok && v != "" { + conditions = append(conditions, "msisdn like concat(?, '%')") + params = append(params, v) + } + if v, ok := query["startTime"]; ok && v != "" { + conditions = append(conditions, "start_time >= ?") + params = append(params, v) + } + if v, ok := query["endTime"]; ok && v != "" { + conditions = append(conditions, "end_time <= ?") + params = append(params, v) + } + + // 构建查询条件语句 + whereSql := "" + if len(conditions) > 0 { + whereSql += " where " + strings.Join(conditions, " and ") + } + + result := map[string]any{ + "total": 0, + "rows": []model.TraceTask{}, + } + + // 查询数量 长度为0直接返回 + totalSql := "select count(1) as 'total' from trace_task" + totalRows, err := datasource.RawDB("", totalSql+whereSql, params) + if err != nil { + logger.Errorf("total err => %v", err) + return result + } + total := parse.Number(totalRows[0]["total"]) + if total == 0 { + return result + } else { + result["total"] = total + } + + // 分页 + pageNum, pageSize := repo.PageNumSize(query["pageNum"], query["pageSize"]) + pageSql := " limit ?,? " + params = append(params, pageNum*pageSize) + params = append(params, pageSize) + + // 排序 + orderSql := "" + if v, ok := query["sortField"]; ok && v != "" { + sortSql := v.(string) + if v, ok := query["sortOrder"]; ok && v != "" { + if v.(string) == "desc" { + sortSql += " desc " + } else { + sortSql += " asc " + } + } + orderSql = fmt.Sprintf(" order by %s ", sortSql) + } + + // 查询数据 + querySql := r.selectSql + whereSql + orderSql + pageSql + results, err := datasource.RawDB("", querySql, params) + if err != nil { + logger.Errorf("query err => %v", err) + } + + // 转换实体 + result["rows"] = r.convertResultRows(results) + return result +} + +// SelectList 根据实体查询 +func (r *TraceTask) SelectList(task model.TraceTask) []model.TraceTask { + // 查询条件拼接 + var conditions []string + var params []any + if task.IMSI != "" { + conditions = append(conditions, "imsi = ?") + params = append(params, task.IMSI) + } + if task.SrcIp != "" { + conditions = append(conditions, "src_ip = ?") + params = append(params, task.SrcIp) + } + if task.DstIp != "" { + conditions = append(conditions, "dst_ip = ?") + params = append(params, task.DstIp) + } + + // 构建查询条件语句 + whereSql := "" + if len(conditions) > 0 { + whereSql += " where " + strings.Join(conditions, " and ") + } + + // 查询数据 + querySql := r.selectSql + whereSql + " order by id desc " + results, err := datasource.RawDB("", querySql, params) + if err != nil { + logger.Errorf("query err => %v", err) + } + + // 转换实体 + return r.convertResultRows(results) +} + +// SelectByIds 通过ID查询 +func (r *TraceTask) SelectByIds(ids []string) []model.TraceTask { + placeholder := repo.KeyPlaceholderByQuery(len(ids)) + querySql := r.selectSql + " where id in (" + placeholder + ")" + parameters := repo.ConvertIdsSlice(ids) + results, err := datasource.RawDB("", querySql, parameters) + if err != nil { + logger.Errorf("query err => %v", err) + return []model.TraceTask{} + } + // 转换实体 + return r.convertResultRows(results) +} + +// Insert 新增信息 +func (r *TraceTask) Insert(task model.TraceTask) string { + // 参数拼接 + params := make(map[string]any) + if task.TraceId != "" { + params["trace_id"] = task.TraceId + } + if task.TraceType != "" { + params["trace_type"] = task.TraceType + } + if task.StartTime > 0 { + params["start_time"] = task.StartTime + } + if task.EndTime > 0 { + params["end_time"] = task.EndTime + } + if task.Interfaces != "" { + params["interfaces"] = task.Interfaces + } + if task.IMSI != "" { + params["imsi"] = task.IMSI + } + if task.MSISDN != "" { + params["msisdn"] = task.MSISDN + } + if task.UeIp != "" { + params["ue_ip"] = task.UeIp + } + if task.SrcIp != "" { + params["src_ip"] = task.SrcIp + } + if task.DstIp != "" { + params["dst_ip"] = task.DstIp + } + if task.SignalPort != 0 { + params["signal_port"] = task.SignalPort + } + if task.NeType != "" { + params["ne_type"] = task.NeType + } + if task.NeId != "" { + params["ne_id"] = task.NeId + } + if task.NotifyUrl != "" { + params["notify_url"] = task.NotifyUrl + } + if task.FetchMsg != "" { + params["fetch_msg"] = task.FetchMsg + } + if task.Remark != "" { + params["remark"] = task.Remark + } + if task.CreateBy != "" { + params["create_by"] = task.CreateBy + params["create_time"] = time.Now().UnixMilli() + } + + // 构建执行语句 + keys, placeholder, values := repo.KeyPlaceholderValueByInsert(params) + sql := "insert into trace_task (" + strings.Join(keys, ",") + ")values(" + placeholder + ")" + + db := datasource.DefaultDB() + // 开启事务 + tx := db.Begin() + // 执行插入 + err := tx.Exec(sql, values...).Error + if err != nil { + logger.Errorf("insert row : %v", err.Error()) + tx.Rollback() + return "" + } + // 获取生成的自增 ID + var insertedID string + err = tx.Raw("select last_insert_id()").Row().Scan(&insertedID) + if err != nil { + logger.Errorf("insert last id : %v", err.Error()) + tx.Rollback() + return "" + } + // 提交事务 + tx.Commit() + return insertedID +} + +// Update 修改信息 +func (r *TraceTask) Update(task model.TraceTask) int64 { + // 参数拼接 + params := make(map[string]any) + params["trace_id"] = task.TraceId + params["trace_type"] = task.TraceType + params["ne_type"] = task.NeType + params["ne_id"] = task.NeId + params["notify_url"] = task.NotifyUrl + + params["start_time"] = task.StartTime + params["end_time"] = task.EndTime + params["fetch_msg"] = task.FetchMsg + params["remark"] = task.Remark + + params["interfaces"] = task.Interfaces + + params["imsi"] = task.IMSI + params["msisdn"] = task.MSISDN + + params["ue_ip"] = task.UeIp + params["src_ip"] = task.SrcIp + params["dst_ip"] = task.DstIp + params["signal_port"] = task.SignalPort + + if task.UpdateBy != "" { + params["update_by"] = task.UpdateBy + params["update_time"] = time.Now().UnixMilli() + } + + // 构建执行语句 + keys, values := repo.KeyValueByUpdate(params) + sql := "update trace_task set " + strings.Join(keys, ",") + " where id = ?" + + // 执行更新 + values = append(values, task.ID) + rows, err := datasource.ExecDB("", sql, values) + if err != nil { + logger.Errorf("update row : %v", err.Error()) + return 0 + } + return rows +} + +// DeleteByIds 批量删除信息 +func (r *TraceTask) DeleteByIds(ids []string) int64 { + placeholder := repo.KeyPlaceholderByQuery(len(ids)) + sql := "delete from trace_task where id in (" + placeholder + ")" + parameters := repo.ConvertIdsSlice(ids) + results, err := datasource.ExecDB("", sql, parameters) + if err != nil { + logger.Errorf("delete err => %v", err) + return 0 + } + return results +} + +// LastID 最后一条ID +func (r *TraceTask) LastID() int64 { + // 查询数据 + querySql := "SELECT id as 'str' FROM trace_task ORDER BY id DESC LIMIT 1" + results, err := datasource.RawDB("", querySql, nil) + if err != nil { + logger.Errorf("query err %v", err) + return 0 + } + if len(results) > 0 { + return parse.Number(results[0]["str"]) + } + return 0 +} diff --git a/src/modules/trace/service/trace_task.go b/src/modules/trace/service/trace_task.go new file mode 100644 index 00000000..e9891b28 --- /dev/null +++ b/src/modules/trace/service/trace_task.go @@ -0,0 +1,294 @@ +package service + +import ( + "encoding/json" + "fmt" + "strings" + + "be.ems/src/framework/config" + "be.ems/src/framework/logger" + "be.ems/src/framework/socket" + "be.ems/src/framework/utils/date" + "be.ems/src/framework/utils/parse" + neFetchlink "be.ems/src/modules/network_element/fetch_link" + neService "be.ems/src/modules/network_element/service" + "be.ems/src/modules/trace/model" + "be.ems/src/modules/trace/repository" +) + +// 实例化数据层 TraceTask 结构体 +var NewTraceTask = &TraceTask{ + udpService: socket.SocketUDP{}, + traceTaskRepository: repository.NewTraceTask, + traceDataRepository: repository.NewTraceData, +} + +// TraceTask 跟踪任务 服务层处理 +type TraceTask struct { + // UDP服务对象 + udpService socket.SocketUDP + // 跟踪_任务数据信息 + traceTaskRepository *repository.TraceTask + // 跟踪_数据信息 + traceDataRepository *repository.TraceData +} + +// CreateUDP 创建UDP数据通道 +func (r *TraceTask) CreateUDP() error { + // 跟踪配置是否开启 + if v := config.Get("trace.enabled"); v != nil { + if !v.(bool) { + return nil + } + } + host := "127.0.0.1" + if v := config.Get("trace.host"); v != nil { + host = v.(string) + } + var port int64 = 33033 + if v := config.Get("trace.port"); v != nil { + port = parse.Number(v) + } + + // 初始化UDP服务 + r.udpService = socket.SocketUDP{Addr: host, Port: port} + if _, err := r.udpService.New(); err != nil { + return err + } + + // 接收处理UDP数据 + go r.udpService.Resolve(2048, func(data []byte, n int) { + logger.Infof("socket UDP: %s", string(data)) + mData, err := UDPDataHandler(data, n) + if err != nil { + logger.Errorf("udp resolve data fail: %s", err.Error()) + return + } + // 插入数据库做记录 + r.traceDataRepository.Insert(model.TraceData{ + TaskId: parse.Number(mData["taskId"]), + IMSI: mData["imsi"].(string), + SrcAddr: mData["srcAddr"].(string), + DstAddr: mData["dstAddr"].(string), + IfType: parse.Number(mData["ifType"]), + MsgType: parse.Number(mData["msgType"]), + MsgDirect: parse.Number(mData["msgDirect"]), + Length: parse.Number(mData["dataLen"]), + RawMsg: mData["dataInfo"].(string), + Timestamp: parse.Number(mData["timestamp"]), + DecMsg: mData["decMsg"].(string), + }) + + // 推送文件 + if v, ok := mData["pcapFile"]; ok && v != "" { + logger.Infof("pcapFile: %s", v) + } + }) + + // ============ 测试接收网元UDP发过来的数据 + // 初始化TCP服务 后续调整TODO + tcpService := socket.SocketTCP{Addr: host, Port: port + 1} + if _, err := tcpService.New(); err != nil { + return err + } + // 接收处理TCP数据 + go tcpService.Resolve(1024, func(data []byte, n int) { + logger.Infof("socket TCP: %s", string(data)) + mData, err := UDPDataHandler(data, n) + if err != nil { + logger.Errorf("tcp resolve data fail: %s", err.Error()) + return + } + // 插入数据库做记录 + r.traceDataRepository.Insert(model.TraceData{ + TaskId: parse.Number(mData["taskId"]), + IMSI: mData["imsi"].(string), + SrcAddr: mData["srcAddr"].(string), + DstAddr: mData["dstAddr"].(string), + IfType: parse.Number(mData["ifType"]), + MsgType: parse.Number(mData["msgType"]), + MsgDirect: parse.Number(mData["msgDirect"]), + Length: parse.Number(mData["dataLen"]), + RawMsg: mData["dataInfo"].(string), + Timestamp: parse.Number(mData["timestamp"]), + DecMsg: mData["decMsg"].(string), + }) + + // 推送文件 + if v, ok := mData["pcapFile"]; ok && v != "" { + logger.Infof("pcapFile: %s", v) + } + }) + return nil +} + +// CloseUDP 关闭UDP数据通道 +func (r *TraceTask) CloseUDP() { + r.udpService.Close() +} + +// SelectPage 根据条件分页查询 +func (r *TraceTask) SelectPage(query map[string]any) map[string]any { + return r.traceTaskRepository.SelectPage(query) +} + +// SelectById 通过ID查询 +func (r *TraceTask) SelectById(id string) model.TraceTask { + tasks := r.traceTaskRepository.SelectByIds([]string{id}) + if len(tasks) > 0 { + return tasks[0] + } + return model.TraceTask{} +} + +// Insert 新增信息 +func (r *TraceTask) Insert(task model.TraceTask) error { + // 跟踪配置是否开启 + if v := config.Get("trace.enabled"); v != nil { + if !v.(bool) { + return fmt.Errorf("tracking is not enabled") + } + } + host := "127.0.0.1" + if v := config.Get("trace.host"); v != nil { + host = v.(string) + } + var port int64 = 33033 + if v := config.Get("trace.port"); v != nil { + port = parse.Number(v) + } + task.NotifyUrl = fmt.Sprintf("udp:%s:%d", host, port) + + // 查询网元获取IP + neInfo := neService.NewNeInfoImpl.SelectNeInfoByNeTypeAndNeID(task.NeType, task.NeId) + if neInfo.NeId != task.NeId || neInfo.IP == "" { + return fmt.Errorf("app.common.noNEInfo") + } + traceId := r.traceTaskRepository.LastID() + 1 // 生成任务ID < 65535 + task.TraceId = fmt.Sprint(traceId) + + // 发送任务给网元 + data := map[string]any{ + "neType": neInfo.NeType, + "neId": neInfo.NeId, + "notifyUrl": task.NotifyUrl, + "id": traceId, + "startTime": date.ParseDateToStr(task.StartTime, date.YYYY_MM_DD_HH_MM_SS), + "endTime": date.ParseDateToStr(task.EndTime, date.YYYY_MM_DD_HH_MM_SS), + } + switch task.TraceType { + case "1": // Interface + data["traceType"] = "Interface" + data["interfaces"] = strings.Split(task.Interfaces, ",") + case "2": // Device + data["traceType"] = "Device" + data["ueIp"] = task.UeIp + data["srcIp"] = task.SrcIp + data["dstIp"] = task.DstIp + data["signalPort"] = task.SignalPort + task.UeIp = neInfo.IP + case "3": // UE + data["traceType"] = "UE" + data["imsi"] = task.IMSI + data["msisdn"] = task.MSISDN + default: + return fmt.Errorf("trace type is not disabled") + } + msg, err := neFetchlink.NeTraceAdd(neInfo, data) + if err != nil { + return err + } + s, _ := json.Marshal(msg) + task.FetchMsg = string(s) + + // 插入数据库 + r.traceTaskRepository.Insert(task) + return nil +} + +// Update 修改信息 +func (r *TraceTask) Update(task model.TraceTask) error { + // 跟踪配置是否开启 + if v := config.Get("trace.enabled"); v != nil { + if !v.(bool) { + return fmt.Errorf("tracking is not enabled") + } + } + host := "127.0.0.1" + if v := config.Get("trace.host"); v != nil { + host = v.(string) + } + var port int64 = 33033 + if v := config.Get("trace.port"); v != nil { + port = parse.Number(v) + } + task.NotifyUrl = fmt.Sprintf("udp:%s:%d", host, port) + + // 查询网元获取IP + neInfo := neService.NewNeInfoImpl.SelectNeInfoByNeTypeAndNeID(task.NeType, task.NeId) + if neInfo.NeId != task.NeId || neInfo.IP == "" { + return fmt.Errorf("app.common.noNEInfo") + } + + // 查询网元任务信息 + if msg, err := neFetchlink.NeTraceInfo(neInfo, task.TraceId); err == nil { + s, _ := json.Marshal(msg) + task.FetchMsg = string(s) + // 修改任务信息 + data := map[string]any{ + "neType": neInfo.NeType, + "neId": neInfo.NeId, + "notifyUrl": task.NotifyUrl, + "id": parse.Number(task.TraceId), + "startTime": date.ParseDateToStr(task.StartTime, date.YYYY_MM_DD_HH_MM_SS), + "endTime": date.ParseDateToStr(task.EndTime, date.YYYY_MM_DD_HH_MM_SS), + } + switch task.TraceType { + case "1": // Interface + data["traceType"] = "Interface" + data["interfaces"] = strings.Split(task.Interfaces, ",") + case "2": // Device + task.UeIp = neInfo.IP + data["traceType"] = "Device" + data["ueIp"] = task.UeIp + data["srcIp"] = task.SrcIp + data["dstIp"] = task.DstIp + data["signalPort"] = task.SignalPort + case "3": // UE + data["traceType"] = "UE" + data["imsi"] = task.IMSI + data["msisdn"] = task.MSISDN + default: + return fmt.Errorf("trace type is not disabled") + } + neFetchlink.NeTraceEdit(neInfo, data) + } + + // 更新数据库 + r.traceTaskRepository.Update(task) + return nil +} + +// DeleteByIds 批量删除信息 +func (r *TraceTask) DeleteByIds(ids []string) (int64, error) { + // 检查是否存在 + rows := r.traceTaskRepository.SelectByIds(ids) + if len(rows) <= 0 { + return 0, fmt.Errorf("not data") + } + + if len(rows) == len(ids) { + // 停止任务 + for _, v := range rows { + neInfo := neService.NewNeInfoImpl.SelectNeInfoByNeTypeAndNeID(v.NeType, v.NeId) + if neInfo.NeId != v.NeId || neInfo.IP == "" { + continue + } + neFetchlink.NeTraceDelete(neInfo, v.TraceId) + } + num := r.traceTaskRepository.DeleteByIds(ids) + return num, nil + } + // 删除信息失败! + return 0, fmt.Errorf("delete fail") +} diff --git a/src/modules/trace/service/trace_task_udp_data.go b/src/modules/trace/service/trace_task_udp_data.go new file mode 100644 index 00000000..057ddf41 --- /dev/null +++ b/src/modules/trace/service/trace_task_udp_data.go @@ -0,0 +1,330 @@ +package service + +import ( + "encoding/base64" + "encoding/binary" + "fmt" + "os" + "runtime" + "strings" + "time" + + "golang.org/x/net/http/httpguts" + "golang.org/x/net/http2/hpack" +) + +const ( + GTPU_V1_VERSION = 1 << 5 + GTPU_VER_MASK = 7 << 5 + GTPU_PT_GTP = 1 << 4 + GTPU_HEADER_LEN = 12 + GTPU_E_S_PB_BIT = 7 + GTPU_E_BI = 1 << 2 +) + +const ( + GTPU_HEADER_VERSION_INDEX = 0 + GTPU_HEADER_MSG_TYPE_INDEX = 1 + GTPU_HEADER_LENGTH_INDEX = 2 + GTPU_HEADER_TEID_INDEX = 4 +) + +type ExtHeader struct { + TaskId uint32 + IMSI string + IfType int + MsgType int + MsgDirect int // 0-recv,1-send + TimeStamp int64 + SrcIP string + DstIP string + SrcPort uint16 + DstPort uint16 + Proto int // Protocol + PPI int // only for SCTP + DataLen uint16 + DataInfo []byte +} + +// parseUDPData 解析UDP数据 +func parseUDPData(rvMsg []byte, rvLen int) (ExtHeader, error) { + var extHdr ExtHeader + // var tr dborm.TraceData + var off int + msg := rvMsg + + verFlags := msg[GTPU_HEADER_VERSION_INDEX] + + gtpuHdrLen := GTPU_HEADER_LEN + + localTeid := binary.BigEndian.Uint32(msg[GTPU_HEADER_TEID_INDEX:]) + + extHdr.TaskId = localTeid + + if (verFlags & GTPU_E_S_PB_BIT) != 0 { + if (verFlags & GTPU_E_BI) != 0 { + extTypeIndex := GTPU_HEADER_LEN - 1 + + extType := msg[extTypeIndex] + + if extType == 0xFE { + extHdr.IMSI = string(msg[extTypeIndex+2 : extTypeIndex+17]) + extHdr.IfType = int(msg[extTypeIndex+17]) + extHdr.MsgType = int(msg[extTypeIndex+18]) + extHdr.MsgDirect = int(msg[extTypeIndex+19]) + + extHdr.TimeStamp = time.Now().UTC().UnixMilli() + // extHdr.TimeStamp = int64(binary.BigEndian.Uint64(msg[extTypeIndex+19:])) + // fmt.Printf("ext info %v %s %d %d %d \n", msg[(extTypeIndex+2):(extTypeIndex+20)], extHdr.IMSI, extHdr.IfType, extHdr.MsgType, extHdr.MsgDirect) + // set offset of IP Packet + off = 40 + 4 + //src ip: msg+40+12 + extHdr.SrcIP = fmt.Sprintf("%d.%d.%d.%d", msg[off+12], msg[off+13], msg[off+14], msg[off+15]) + //dst ip: msg+40+12+4 + extHdr.DstIP = fmt.Sprintf("%d.%d.%d.%d", msg[off+16], msg[off+17], msg[off+18], msg[off+19]) + extHdr.SrcPort = uint16(binary.BigEndian.Uint16(msg[off+20:])) + extHdr.DstPort = uint16(binary.BigEndian.Uint16(msg[off+22:])) + // fmt.Printf("info %s:%d %s:%d \n", extHdr.SrcIP, extHdr.SrcPort, extHdr.DstIP, extHdr.DstPort) + // ip header start msg+40 + extHdr.DataLen = uint16(rvLen - off) + extHdr.DataInfo = make([]byte, int(rvLen-off)) + copy(extHdr.DataInfo, []byte(msg[off:])) + + // 132 SCTP + // 6 TCP + // 17 UDP + extHdr.Proto = int(msg[off+9]) + if extHdr.Proto == 132 { + extHdr.PPI = int(msg[off+47]) + extHdr.DataLen = uint16(binary.BigEndian.Uint16(msg[(off+34):]) - 16) + // fmt.Printf("dat len %d %d \n", extHdr.DataLen, extHdr.PPI) + } + } + + for extType != 0 && extTypeIndex < rvLen { + extLen := msg[extTypeIndex+1] << 2 + if extLen == 0 { + return extHdr, fmt.Errorf("error, extLen is zero") + } + + gtpuHdrLen += int(extLen) + extTypeIndex += int(extLen) + extType = msg[extTypeIndex] + } + } + } else { + gtpuHdrLen -= 4 + } + return extHdr, nil +} + +// UDPDataHandler UDP数据处理 +func UDPDataHandler(data []byte, n int) (map[string]any, error) { + extHdr, err := parseUDPData(data, n) + if err != nil { + return nil, err + } + if extHdr.TaskId == 0 || extHdr.DataLen < 1 { + return nil, fmt.Errorf("data error") + } + + m := map[string]any{ + "taskId": extHdr.TaskId, + "imsi": extHdr.IMSI, + "ifType": extHdr.IfType, + "srcAddr": fmt.Sprintf("%s:%d", extHdr.SrcIP, extHdr.SrcPort), + "dstAddr": fmt.Sprintf("%s:%d", extHdr.DstIP, extHdr.DstPort), + "msgType": extHdr.MsgType, + "msgDirect": extHdr.MsgDirect, + "timestamp": extHdr.TimeStamp, + "dataLen": extHdr.DataLen, + // "dataInfo": extHdr.DataInfo, + "decMsg": "", + } + // Base64 编码 + m["dataInfo"] = base64.StdEncoding.EncodeToString(extHdr.DataInfo) + + if extHdr.Proto == 6 { // TCP + // 取响应数据 + iplen := uint16(binary.BigEndian.Uint16(extHdr.DataInfo[2:])) + tcplen := uint16(iplen - 32 - 20) + hdrlen := uint16(binary.BigEndian.Uint16(extHdr.DataInfo[20+32+1:])) + offset := uint16(52) + // fmt.Printf("HTTP %d %d %d \n", iplen, tcplen, hdrlen) + if tcplen > (hdrlen + 9) { // has data + doffset := uint16(offset + hdrlen + 9) + datlen := uint16(binary.BigEndian.Uint16(extHdr.DataInfo[doffset+1:])) + // fmt.Printf("HTTP datlen %d \n", datlen) + m["decMsg"], _ = httpDataMsg(extHdr.DataInfo[offset+9:offset+9+hdrlen], extHdr.DataInfo[doffset+9:doffset+datlen+9]) + } else { + m["decMsg"], _ = httpDataMsg(extHdr.DataInfo[offset+9:hdrlen], nil) + } + } + + // pcap文件 + m["pcapFile"] = writePcap(extHdr) + return m, nil +} + +// =========== TCP协议Body =========== + +// httpDataMsg Http数据信息处理 +func httpDataMsg(header []byte, data []byte) (string, error) { + var remainSize = uint32(16 << 20) + var sawRegular bool + var invalid bool // pseudo header field errors + var Fields []hpack.HeaderField + + invalid = false + hdec := hpack.NewDecoder(4096, nil) + hdec.SetEmitEnabled(true) + hdec.SetMaxStringLength(int(16 << 20)) + hdec.SetEmitFunc(func(hf hpack.HeaderField) { + if !httpguts.ValidHeaderFieldValue(hf.Value) { + // Don't include the value in the error, because it may be sensitive. + invalid = true + } + isPseudo := strings.HasPrefix(hf.Name, ":") + if isPseudo { + if sawRegular { + invalid = true + } + } else { + sawRegular = true + if !validWireHeaderFieldName(hf.Name) { + invalid = true + } + } + + if invalid { + hdec.SetEmitEnabled(false) + return + } + + size := hf.Size() + if size > remainSize { + hdec.SetEmitEnabled(false) + //mh.Truncated = true + return + } + remainSize -= size + + Fields = append(Fields, hf) + }) + + // defer hdec.SetEmitFunc(func(hf hpack.HeaderField) {}) + + frag := header + if _, err := hdec.Write(frag); err != nil { + return "", err + } + + if err := hdec.Close(); err != nil { + return "", err + } + + // hdec.SetEmitFunc(func(hf hpack.HeaderField) {}) + + var headers []byte + var line string + for i := range Fields { + line = fmt.Sprintf("\"%s\":\"%s\",", Fields[i].Name, Fields[i].Value) + headers = append(headers, []byte(line)...) + } + + if len(data) > 0 { + return fmt.Sprintf("{ %s \"content\":%s }", string(headers), string(data)), nil + } else { + return fmt.Sprintf("{ %s }", string(headers)), nil + } +} + +// validWireHeaderFieldName 校验报文头字段名称 +func validWireHeaderFieldName(v string) bool { + if len(v) == 0 { + return false + } + for _, r := range v { + if !httpguts.IsTokenRune(r) { + return false + } + if 'A' <= r && r <= 'Z' { + return false + } + } + return true +} + +// =========== writePcap 写Pcap文件 =========== + +const magicMicroseconds = 0xA1B2C3D4 +const versionMajor = 2 +const versionMinor = 4 + +func writeEmptyPcap(filename string, timeStamp int64, length int, data []byte) error { + var err error + var file *os.File + if _, err = os.Stat(filename); os.IsNotExist(err) { + file, err = os.Create(filename) + // File Header + var fileHeaderBuf [24]byte + binary.LittleEndian.PutUint32(fileHeaderBuf[0:4], magicMicroseconds) + binary.LittleEndian.PutUint16(fileHeaderBuf[4:6], versionMajor) + binary.LittleEndian.PutUint16(fileHeaderBuf[6:8], versionMinor) + // bytes 8:12 stay 0 (timezone = UTC) + // bytes 12:16 stay 0 (sigfigs is always set to zero, according to + // http://wiki.wireshark.org/Development/LibpcapFileFormat + binary.LittleEndian.PutUint32(fileHeaderBuf[16:20], 0x00040000) + binary.LittleEndian.PutUint32(fileHeaderBuf[20:24], 0x00000071) + if _, err := file.Write(fileHeaderBuf[:]); err != nil { + return err + } + } else { + file, err = os.OpenFile(filename, os.O_WRONLY|os.O_APPEND, 0666) + } + if err != nil { + return err + } + defer file.Close() + + // Packet Header + var packetHeaderBuf [24]byte + t := time.UnixMilli(timeStamp) + if t.IsZero() { + t = time.Now() + } + secs := t.Unix() + usecs := t.Nanosecond() / 1000 + binary.LittleEndian.PutUint32(packetHeaderBuf[0:4], uint32(secs)) + binary.LittleEndian.PutUint32(packetHeaderBuf[4:8], uint32(usecs)) + binary.LittleEndian.PutUint32(packetHeaderBuf[8:12], uint32(length+16)) + binary.LittleEndian.PutUint32(packetHeaderBuf[12:16], uint32(length+16)) + if _, err := file.Write(packetHeaderBuf[:]); err != nil { + return err + } + + // 数据包内容的定义 + cooked := [...]byte{0x00, 0x00, 0x03, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00} + if _, err := file.Write(cooked[:]); err != nil { + return err + } + + // Packet Data + if _, err := file.Write(data); err != nil { + return err + } + return nil +} + +// writePcap 写Pcap文件并返回文件路径 +func writePcap(extHdr ExtHeader) string { + filePath := fmt.Sprintf("/tmp/trace_%d .pcap", extHdr.TaskId) + if runtime.GOOS == "windows" { + filePath = fmt.Sprintf("C:%s", filePath) + } + err := writeEmptyPcap(filePath, extHdr.TimeStamp, int(extHdr.DataLen), extHdr.DataInfo) + if err != nil { + return "" + } + return filePath +}