1
0

feat: 删除不需要文件夹

This commit is contained in:
TsMask
2023-10-10 10:56:44 +08:00
parent ce7c3cae68
commit d173205528
154 changed files with 32276 additions and 1 deletions

64
lib/aes/aes.go Normal file
View File

@@ -0,0 +1,64 @@
package aes
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"encoding/base64"
)
func AesEncrypt(orig string, key string) string {
// 转成字节数组
origData := []byte(orig)
k := []byte(key)
// 分组秘钥
block, _ := aes.NewCipher(k)
// 获取秘钥块的长度
blockSize := block.BlockSize()
// 补全码
origData = PKCS7Padding(origData, blockSize)
// 加密模式
blockMode := cipher.NewCBCEncrypter(block, k[:blockSize])
// 创建数组
cryted := make([]byte, len(origData))
// 加密
blockMode.CryptBlocks(cryted, origData)
return base64.StdEncoding.EncodeToString(cryted)
}
func AesDecrypt(cryted string, key string) string {
// 转成字节数组
crytedByte, _ := base64.StdEncoding.DecodeString(cryted)
k := []byte(key)
// 分组秘钥
block, _ := aes.NewCipher(k)
// 获取秘钥块的长度
blockSize := block.BlockSize()
// 加密模式
blockMode := cipher.NewCBCDecrypter(block, k[:blockSize])
// 创建数组
orig := make([]byte, len(crytedByte))
// 解密
blockMode.CryptBlocks(orig, crytedByte)
// 去补全码
orig = PKCS7UnPadding(orig)
return string(orig)
}
// 补码
func PKCS7Padding(ciphertext []byte, blocksize int) []byte {
padding := blocksize - len(ciphertext)%blocksize
padtext := bytes.Repeat([]byte{byte(padding)}, padding)
return append(ciphertext, padtext...)
}
// 去码
func PKCS7UnPadding(origData []byte) []byte {
length := len(origData)
unpadding := int(origData[length-1])
return origData[:(length - unpadding)]
}

View File

@@ -0,0 +1,54 @@
package account
import (
"fmt"
"strconv"
"time"
sysMenuService "ems.agt/features/sys_menu/service"
sysRoleService "ems.agt/features/sys_role/service"
"ems.agt/lib/core/cache"
"ems.agt/lib/core/conf"
"ems.agt/lib/core/vo"
"ems.agt/lib/dborm"
)
// 登录缓存用户信息
func CacheLoginUser(user *dborm.User) {
// 过期时间
expiresStr, err := dborm.XormGetConfigValue("Security", "sessionExpires")
if err != nil {
expiresStr = "18000"
}
expiresValue, _ := strconv.Atoi(expiresStr)
expireTime := time.Duration(expiresValue) * time.Second
nowTime := time.Now().UnixMilli()
// 登录用户
loginUser := vo.LoginUser{
UserID: fmt.Sprint(user.Id),
UserName: user.Name,
ExpireTime: nowTime + expireTime.Milliseconds(),
LoginTime: nowTime,
User: *user,
}
// 是否管理员
if conf.IsAdmin(loginUser.UserID) {
loginUser.Permissions = []string{"*:*:*"}
} else {
// 获取权限标识
loginUser.Permissions = sysMenuService.NewRepoSysMenu.SelectMenuPermsByUserId(loginUser.UserID)
// 获取角色信息
loginUser.User.Roles = sysRoleService.NewRepoSysRole.SelectRoleListByUserId(loginUser.UserID)
}
// 缓存时间
cache.SetLocalTTL(user.AccountId, loginUser, time.Duration(expireTime))
}
// 清除缓存用户信息
func ClearLoginUser(accountId string) {
cache.DeleteLocalTTL(accountId)
}

56
lib/core/cache/lcoal.go vendored Normal file
View File

@@ -0,0 +1,56 @@
package cache
import (
"strings"
"time"
"github.com/patrickmn/go-cache"
)
// 创建一个全局的不过期缓存对象
var cNoExpiration = cache.New(cache.NoExpiration, cache.NoExpiration)
// 将数据放入缓存,并设置永不过期
func SetLocal(key string, value any) {
cNoExpiration.Set(key, value, cache.NoExpiration)
}
// 从缓存中获取数据
func GetLocal(key string) (any, bool) {
return cNoExpiration.Get(key)
}
// 从缓存中删除数据
func DeleteLocal(key string) {
cNoExpiration.Delete(key)
}
// 获取指定前缀的所有键
func GetLocalKeys(prefix string) []string {
prefix = strings.TrimSuffix(prefix, "*")
var keys []string
for key := range cNoExpiration.Items() {
if strings.HasPrefix(key, prefix) {
keys = append(keys, key)
}
}
return keys
}
// 创建一个全局的过期缓存对象
var cTTL = cache.New(6*time.Hour, 12*time.Hour)
// 设置具有过期时间的缓存项
func SetLocalTTL(key string, value any, expiration time.Duration) {
cTTL.Set(key, value, expiration)
}
// 从缓存中获取数据
func GetLocalTTL(key string) (any, bool) {
return cTTL.Get(key)
}
// 从缓存中删除数据
func DeleteLocalTTL(key string) {
cTTL.Delete(key)
}

201
lib/core/cmd/cmd.go Normal file
View File

@@ -0,0 +1,201 @@
package cmd
import (
"bytes"
"context"
"fmt"
"os/exec"
"strings"
"time"
)
func Exec(cmdStr string) (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
cmd := exec.Command("bash", "-c", cmdStr)
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
err := cmd.Run()
if ctx.Err() == context.DeadlineExceeded {
return "", fmt.Errorf("errCmdTimeout %v", err)
}
if err != nil {
errMsg := ""
if len(stderr.String()) != 0 {
errMsg = fmt.Sprintf("stderr: %s", stderr.String())
}
if len(stdout.String()) != 0 {
if len(errMsg) != 0 {
errMsg = fmt.Sprintf("%s; stdout: %s", errMsg, stdout.String())
} else {
errMsg = fmt.Sprintf("stdout: %s", stdout.String())
}
}
return errMsg, err
}
return stdout.String(), nil
}
func ExecWithTimeOut(cmdStr string, timeout time.Duration) (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
cmd := exec.Command("bash", "-c", cmdStr)
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
err := cmd.Run()
if ctx.Err() == context.DeadlineExceeded {
return "", fmt.Errorf("errCmdTimeout %v", err)
}
if err != nil {
errMsg := ""
if len(stderr.String()) != 0 {
errMsg = fmt.Sprintf("stderr: %s", stderr.String())
}
if len(stdout.String()) != 0 {
if len(errMsg) != 0 {
errMsg = fmt.Sprintf("%s; stdout: %s", errMsg, stdout.String())
} else {
errMsg = fmt.Sprintf("stdout: %s", stdout.String())
}
}
return errMsg, err
}
return stdout.String(), nil
}
func ExecCronjobWithTimeOut(cmdStr string, workdir string, timeout time.Duration) (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
cmd := exec.Command("bash", "-c", cmdStr)
cmd.Dir = workdir
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
err := cmd.Run()
if ctx.Err() == context.DeadlineExceeded {
return "", fmt.Errorf("errCmdTimeout %v", err)
}
errMsg := ""
if len(stderr.String()) != 0 {
errMsg = fmt.Sprintf("stderr:\n %s", stderr.String())
}
if len(stdout.String()) != 0 {
if len(errMsg) != 0 {
errMsg = fmt.Sprintf("%s \n\n; stdout:\n %s", errMsg, stdout.String())
} else {
errMsg = fmt.Sprintf("stdout:\n %s", stdout.String())
}
}
return errMsg, err
}
func Execf(cmdStr string, a ...interface{}) (string, error) {
cmd := exec.Command("bash", "-c", fmt.Sprintf(cmdStr, a...))
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
err := cmd.Run()
if err != nil {
errMsg := ""
if len(stderr.String()) != 0 {
errMsg = fmt.Sprintf("stderr: %s", stderr.String())
}
if len(stdout.String()) != 0 {
if len(errMsg) != 0 {
errMsg = fmt.Sprintf("%s; stdout: %s", errMsg, stdout.String())
} else {
errMsg = fmt.Sprintf("stdout: %s", stdout.String())
}
}
return errMsg, err
}
return stdout.String(), nil
}
func ExecWithCheck(name string, a ...string) (string, error) {
cmd := exec.Command(name, a...)
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
err := cmd.Run()
if err != nil {
errMsg := ""
if len(stderr.String()) != 0 {
errMsg = fmt.Sprintf("stderr: %s", stderr.String())
}
if len(stdout.String()) != 0 {
if len(errMsg) != 0 {
errMsg = fmt.Sprintf("%s; stdout: %s", errMsg, stdout.String())
} else {
errMsg = fmt.Sprintf("stdout: %s", stdout.String())
}
}
return errMsg, err
}
return stdout.String(), nil
}
func ExecScript(scriptPath, workDir string) (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel()
cmd := exec.Command("bash", scriptPath)
cmd.Dir = workDir
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
err := cmd.Run()
if ctx.Err() == context.DeadlineExceeded {
return "", fmt.Errorf("errCmdTimeout %v", err)
}
if err != nil {
errMsg := ""
if len(stderr.String()) != 0 {
errMsg = fmt.Sprintf("stderr: %s", stderr.String())
}
if len(stdout.String()) != 0 {
if len(errMsg) != 0 {
errMsg = fmt.Sprintf("%s; stdout: %s", errMsg, stdout.String())
} else {
errMsg = fmt.Sprintf("stdout: %s", stdout.String())
}
}
return errMsg, err
}
return stdout.String(), nil
}
func CheckIllegal(args ...string) bool {
if args == nil {
return false
}
for _, arg := range args {
if strings.Contains(arg, "&") || strings.Contains(arg, "|") || strings.Contains(arg, ";") ||
strings.Contains(arg, "$") || strings.Contains(arg, "'") || strings.Contains(arg, "`") ||
strings.Contains(arg, "(") || strings.Contains(arg, ")") || strings.Contains(arg, "\"") {
return true
}
}
return false
}
func HasNoPasswordSudo() bool {
cmd2 := exec.Command("sudo", "-n", "ls")
err2 := cmd2.Run()
return err2 == nil
}
func SudoHandleCmd() string {
cmd := exec.Command("sudo", "-n", "ls")
if err := cmd.Run(); err == nil {
return "sudo "
}
return ""
}
func Which(name string) bool {
_, err := exec.LookPath(name)
return err == nil
}

52
lib/core/conf/conf.go Normal file
View File

@@ -0,0 +1,52 @@
package conf
import (
"fmt"
"time"
"github.com/spf13/viper"
)
// 配置文件读取
func InitConfig(configFile string) {
// 设置配置文件路径
viper.SetConfigFile(configFile)
// 读取配置文件
err := viper.ReadInConfig()
if err != nil {
fmt.Printf("读取配置文件失败: %v \n", err)
return
}
// 记录程序开始运行的时间点
viper.Set("runTime", time.Now())
}
// RunTime 程序开始运行的时间
func RunTime() time.Time {
return viper.GetTime("runTime")
}
// Get 获取配置信息
//
// Get("framework.name")
func Get(key string) any {
return viper.Get(key)
}
// IsAdmin 用户是否为管理员
func IsAdmin(userID string) bool {
if userID == "" {
return false
}
// 从本地配置获取user信息
// admins := Get("user.adminList").([]any)
admins := []string{"1", "2", "3"}
for _, s := range admins {
if s == userID {
return true
}
}
return false
}

View File

@@ -0,0 +1,24 @@
package cachekey
// 缓存的key常量
// 登录用户
const LOGIN_TOKEN_KEY = "login_tokens:"
// 验证码
const CAPTCHA_CODE_KEY = "captcha_codes:"
// 参数管理
const SYS_CONFIG_KEY = "sys_config:"
// 字典管理
const SYS_DICT_KEY = "sys_dict:"
// 防重提交
const REPEAT_SUBMIT_KEY = "repeat_submit:"
// 限流
const RATE_LIMIT_KEY = "rate_limit:"
// 登录账户密码错误次数
const PWD_ERR_CNT_KEY = "pwd_err_cnt:"

View File

@@ -0,0 +1,49 @@
package datasource
import (
"database/sql"
"regexp"
"ems.agt/lib/dborm"
"xorm.io/xorm"
)
// 获取默认数据源
func DefaultDB() *xorm.Engine {
return dborm.DbClient.XEngine
}
// RawDB 原生查询语句
func RawDB(source string, sql string, parameters []any) ([]map[string]any, error) {
// 数据源
db := DefaultDB()
// 使用正则表达式替换连续的空白字符为单个空格
fmtSql := regexp.MustCompile(`\s+`).ReplaceAllString(sql, " ")
// log.Infof("sql=> %v", fmtSql)
// log.Infof("parameters=> %v", parameters)
// 查询结果
var rows []map[string]any
err := db.SQL(fmtSql, parameters...).Find(&rows)
if err != nil {
return nil, err
}
return rows, nil
}
// ExecDB 原生执行语句
func ExecDB(source string, sql string, parameters []any) (sql.Result, error) {
// 数据源
db := DefaultDB()
// 使用正则表达式替换连续的空白字符为单个空格
fmtSql := regexp.MustCompile(`\s+`).ReplaceAllString(sql, " ")
// 执行结果
res, err := db.Exec(append([]any{fmtSql}, parameters...)...)
if err != nil {
return nil, err
}
return res, err
}

135
lib/core/datasource/repo.go Normal file
View File

@@ -0,0 +1,135 @@
package datasource
import (
"fmt"
"reflect"
"strconv"
"strings"
)
// PageNumSize 分页页码记录数
func PageNumSize(pageNum, pageSize any) (int, int) {
// 记录起始索引
pageNumStr := fmt.Sprintf("%v", pageNum)
num := 1
if v, err := strconv.Atoi(pageNumStr); err == nil && v > 0 {
if num > 5000 {
num = 5000
}
num = v
}
// 显示记录数
pageSizeStr := fmt.Sprintf("%v", pageSize)
size := 10
if v, err := strconv.Atoi(pageSizeStr); err == nil && v > 0 {
if size < 0 {
size = 10
} else if size > 1000 {
size = 1000
} else {
size = v
}
}
return num - 1, size
}
// SetFieldValue 判断结构体内是否存在指定字段并设置值
func SetFieldValue(obj any, fieldName string, value any) {
// 获取结构体的反射值
userValue := reflect.ValueOf(obj)
// 获取字段的反射值
fieldValue := userValue.Elem().FieldByName(fieldName)
// 检查字段是否存在
if fieldValue.IsValid() && fieldValue.CanSet() {
// 获取字段的类型
fieldType := fieldValue.Type()
// 转换传入的值类型为字段类型
switch fieldType.Kind() {
case reflect.String:
if value == nil {
fieldValue.SetString("")
} else {
fieldValue.SetString(fmt.Sprintf("%v", value))
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
intValue, err := strconv.ParseInt(fmt.Sprintf("%v", value), 10, 64)
if err != nil {
intValue = 0
}
fieldValue.SetInt(intValue)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
uintValue, err := strconv.ParseUint(fmt.Sprintf("%v", value), 10, 64)
if err != nil {
uintValue = 0
}
fieldValue.SetUint(uintValue)
case reflect.Float32, reflect.Float64:
floatValue, err := strconv.ParseFloat(fmt.Sprintf("%v", value), 64)
if err != nil {
floatValue = 0
}
fieldValue.SetFloat(floatValue)
default:
// 设置字段的值
fieldValue.Set(reflect.ValueOf(value).Convert(fieldValue.Type()))
}
}
}
// ConvertIdsSlice 将 []string 转换为 []any
func ConvertIdsSlice(ids []string) []any {
// 将 []string 转换为 []any
arr := make([]any, len(ids))
for i, v := range ids {
arr[i] = v
}
return arr
}
// 查询-参数值的占位符
func KeyPlaceholderByQuery(sum int) string {
placeholders := make([]string, sum)
for i := 0; i < sum; i++ {
placeholders[i] = "?"
}
return strings.Join(placeholders, ",")
}
// 插入-参数映射键值占位符 keys, placeholder, values
func KeyPlaceholderValueByInsert(params map[string]any) ([]string, string, []any) {
// 参数映射的键
keys := make([]string, len(params))
// 参数映射的值
values := make([]any, len(params))
sum := 0
for k, v := range params {
keys[sum] = k
values[sum] = v
sum++
}
// 参数值的占位符
placeholders := make([]string, sum)
for i := 0; i < sum; i++ {
placeholders[i] = "?"
}
return keys, strings.Join(placeholders, ","), values
}
// 更新-参数映射键值占位符 keys, values
func KeyValueByUpdate(params map[string]any) ([]string, []any) {
// 参数映射的键
keys := make([]string, len(params))
// 参数映射的值
values := make([]any, len(params))
sum := 0
for k, v := range params {
keys[sum] = k + "=?"
values[sum] = v
sum++
}
return keys, values
}

88
lib/core/file/csv.go Normal file
View File

@@ -0,0 +1,88 @@
package file
import (
"encoding/csv"
"os"
"path/filepath"
"strings"
"ems.agt/lib/log"
)
// 写入CSV文件需要转换数据
// 例如:
// data := [][]string{}
// data = append(data, []string{"姓名", "年龄", "城市"})
// data = append(data, []string{"1", "2", "3"})
// err := file.WriterCSVFile(data, filePath)
func WriterCSVFile(data [][]string, filePath string) error {
// 获取文件所在的目录路径
dirPath := filepath.Dir(filePath)
// 确保文件夹路径存在
err := os.MkdirAll(dirPath, os.ModePerm)
if err != nil {
log.Errorf("创建文件夹失败 CreateFile %v", err)
}
// 创建或打开文件
file, err := os.Create(filePath)
if err != nil {
return err
}
defer file.Close()
// 创建CSV编写器
writer := csv.NewWriter(file)
defer writer.Flush()
// 写入数据
for _, row := range data {
writer.Write(row)
}
return nil
}
// 读取CSV文件转换map数据
func ReadCSVFile(filePath string) []map[string]string {
// 创建 map 存储 CSV 数据
arr := make([]map[string]string, 0)
// 打开 CSV 文件
file, err := os.Open(filePath)
if err != nil {
log.Fatal("无法打开 CSV 文件:", err)
return arr
}
defer file.Close()
// 创建 CSV Reader
reader := csv.NewReader(file)
// 读取 CSV 头部行
header, err := reader.Read()
if err != nil {
log.Fatal("无法读取 CSV 头部行:", err)
return arr
}
// 遍历 CSV 数据行
for {
// 读取一行数据
record, err := reader.Read()
if err != nil {
// 到达文件末尾或遇到错误时退出循环
break
}
// 将 CSV 数据插入到 map 中
data := make(map[string]string)
for i, value := range record {
key := strings.ToLower(header[i])
data[key] = value
}
arr = append(arr, data)
}
return arr
}

49
lib/core/file/ssh.go Normal file
View File

@@ -0,0 +1,49 @@
package file
import (
"fmt"
"os"
"os/exec"
"path/filepath"
"ems.agt/lib/core/conf"
"ems.agt/lib/log"
)
// 网元NE 文件复制到远程文件
func FileSCPLocalToNe(neIp, localPath, nePath string) error {
usernameNe := conf.Get("ne.user").(string)
// scp /path/to/local/file.txt user@remote-server:/path/to/remote/directory/
neDir := fmt.Sprintf("%s@%s:%s", usernameNe, neIp, nePath)
cmd := exec.Command("scp", "-r", localPath, neDir)
out, err := cmd.CombinedOutput()
if err != nil {
return err
}
log.Infof("FileSCPLocalToNe %s", string(out))
return nil
}
// 网元NE 远程文件复制到本地文件
func FileSCPNeToLocal(neIp, nePath, localPath string) error {
// 获取文件所在的目录路径
dirPath := filepath.Dir(localPath)
// 确保文件夹路径存在
err := os.MkdirAll(dirPath, os.ModePerm)
if err != nil {
log.Errorf("创建文件夹失败 CreateFile %v", err)
return err
}
usernameNe := conf.Get("ne.user").(string)
// scp user@remote-server:/path/to/remote/directory/ /path/to/local/file.txt
neDir := fmt.Sprintf("%s@%s:%s", usernameNe, neIp, nePath)
cmd := exec.Command("scp", "-r", neDir, localPath)
out, err := cmd.CombinedOutput()
if err != nil {
return err
}
log.Infof("FileSCPNeToLocal %s", string(out))
return nil
}

79
lib/core/file/txt.go Normal file
View File

@@ -0,0 +1,79 @@
package file
import (
"bufio"
"fmt"
"os"
"path/filepath"
"strings"
"ems.agt/lib/log"
)
// 写入Txt文件用,号分割 需要转换数据
// 例如:
// data := [][]string{}
// data = append(data, []string{"姓名", "年龄", "城市"})
// data = append(data, []string{"1", "2", "3"})
// err := file.WriterCSVFile(data, filePath)
func WriterTxtFile(data [][]string, filePath string) error {
// 获取文件所在的目录路径
dirPath := filepath.Dir(filePath)
// 确保文件夹路径存在
err := os.MkdirAll(dirPath, os.ModePerm)
if err != nil {
log.Errorf("创建文件夹失败 CreateFile %v", err)
}
// 创建或打开文件
file, err := os.Create(filePath)
if err != nil {
return err
}
defer file.Close()
// 创建一个 Writer 对象,用于将数据写入文件
writer := bufio.NewWriter(file)
for _, row := range data {
line := strings.Join(row, ",")
fmt.Fprintln(writer, line)
}
// 将缓冲区中的数据刷新到文件中
err = writer.Flush()
if err != nil {
log.Errorf("刷新缓冲区时发生错误:", err)
return err
}
return nil
}
// 读取Txt文件用,号分割 转换数组数据
func ReadTxtFile(filePath string) [][]string {
// 创建 map 存储 CSV 数据
arr := make([][]string, 0)
// 打开文本文件
file, err := os.Open(filePath)
if err != nil {
log.Fatal("无法打开文件:", err)
return arr
}
defer file.Close()
// 创建一个 Scanner 对象,用于逐行读取文件内容
scanner := bufio.NewScanner(file)
if scanner.Err() != nil {
log.Fatal("读取文件时出错:", scanner.Err())
return arr
}
for scanner.Scan() {
line := scanner.Text()
fields := strings.Split(line, ",")
arr = append(arr, fields)
}
return arr
}

View File

@@ -0,0 +1,88 @@
package mmlclient
import (
"bufio"
"fmt"
"io"
"net"
"time"
"ems.agt/lib/core/conf"
)
// 定义MMLClient结构体
type MMLClient struct {
awaitTime time.Duration // 等待时间
conn net.Conn
reader *bufio.Reader
size int // 包含字符
}
// 封装NewMMLClient函数用于创建MMLClient实例
// 网元UDM的IP地址 "198.51.100.1"
func NewMMLClient(ip string) (*MMLClient, error) {
// 创建TCP连接
portMML := conf.Get("mml.port").(int)
hostMML := fmt.Sprintf("%s:%d", ip, portMML)
conn, err := net.Dial("tcp", hostMML)
if err != nil {
return nil, err
}
// 进行登录
usernameMML := conf.Get("mml.user").(string)
passwordMML := conf.Get("mml.password").(string)
fmt.Fprintln(conn, usernameMML)
fmt.Fprintln(conn, passwordMML)
// 发送后等待
sleepTime := conf.Get("mml.sleep").(int)
awaitTime := time.Duration(sleepTime) * time.Millisecond
time.Sleep(awaitTime)
// 读取内容
buf := make([]byte, 1024*8)
n, err := conn.Read(buf)
if err != nil {
return nil, err
}
// 创建MMLClient实例
client := &MMLClient{
conn: conn,
reader: bufio.NewReader(conn),
awaitTime: awaitTime,
size: n,
}
return client, nil
}
// 封装Send函数用于向TCP连接发送数据
func (c *MMLClient) Send(msg string) error {
_, err := fmt.Fprintln(c.conn, msg)
if err != nil {
return err
}
time.Sleep(c.awaitTime)
return nil
}
// 封装Receive函数用于从TCP连接中接收数据
func (c *MMLClient) Receive() (string, error) {
buf := make([]byte, 1024*8)
n, err := c.reader.Read(buf)
if err != nil {
if err == io.EOF {
return "", fmt.Errorf("server closed the connection")
}
return "", err
}
return string(buf[0:n]), nil
}
// 封装Close函数用于关闭TCP连接
func (c *MMLClient) Close() error {
return c.conn.Close()
}

104
lib/core/mml_client/send.go Normal file
View File

@@ -0,0 +1,104 @@
package mmlclient
import (
"fmt"
"strings"
)
// 发送MML原始消息
// ip 网元IP地址
// msg 指令
func MMLSendMsg(ip, msg string) (string, error) {
// 创建MMLClient实例
client, err := NewMMLClient(ip)
if err != nil {
return "", fmt.Errorf("创建MMLClient实例失败%v", err)
}
defer client.Close()
// 发送数据
err = client.Send(msg)
if err != nil {
return "", fmt.Errorf("发送数据失败:%v", err)
}
// 接收数据
data, err := client.Receive()
if err != nil {
return "", fmt.Errorf("接收数据失败:%v", err)
}
return data, nil
}
// 发送MML
// ip 网元IP地址
// msg 指令
func MMLSendMsgToString(ip, msg string) (string, error) {
// 发送获取数据
str, err := MMLSendMsg(ip, msg)
if err != nil {
return "", err
}
// 截断
index := strings.Index(str, "\n")
if index != -1 {
str = str[:index]
str = strings.ToLower(str)
}
// 命令成功
if strings.Contains(str, "ok") || strings.Contains(str, "OK") {
return str, nil
}
return "", fmt.Errorf(str)
}
// 发送MML
// ip 网元IP地址
// msg 指令
func MMLSendMsgToMap(ip, msg string) (map[string]string, error) {
// 发送获取数据
str, err := MMLSendMsg(ip, msg)
if err != nil {
return nil, err
}
// 无数据
if strings.HasPrefix(str, "No Auth Data") {
return nil, fmt.Errorf("no auth data")
}
// 初始化一个map用于存储拆分后的键值对
m := make(map[string]string)
var items []string
if strings.Contains(str, "\r\n") {
// 按照分隔符"\r\n"进行拆分
items = strings.Split(str, "\r\n")
} else if strings.Contains(str, "\n") {
// 按照分隔符"\n"进行拆分
items = strings.Split(str, "\n")
}
// 遍历拆分后的结果
for _, item := range items {
var pair []string
if strings.Contains(item, "=") {
// 按照分隔符"="进行拆分键值对
pair = strings.Split(item, "=")
} else if strings.Contains(item, ":") {
// 按照分隔符":"进行拆分键值对
pair = strings.Split(item, ":")
}
if len(pair) == 2 {
// 将键值对存入map中
m[pair[0]] = pair[1]
}
}
return m, err
}

358
lib/core/redis/redis.go Normal file
View File

@@ -0,0 +1,358 @@
package redis
import (
"context"
"fmt"
"strings"
"time"
"ems.agt/lib/core/conf"
"ems.agt/lib/log"
"github.com/redis/go-redis/v9"
)
// Redis连接实例
var rdbMap = make(map[string]*redis.Client)
// 声明定义限流脚本命令
var rateLimitCommand = redis.NewScript(`
local key = KEYS[1]
local time = tonumber(ARGV[1])
local count = tonumber(ARGV[2])
local current = redis.call('get', key);
if current and tonumber(current) >= count then
return tonumber(current);
end
current = redis.call('incr', key)
if tonumber(current) == 1 then
redis.call('expire', key, time)
end
return tonumber(current);`)
// 连接Redis实例
func Connect() {
ctx := context.Background()
// 读取数据源配置
datasource := conf.Get("redis.dataSource").(map[string]any)
for k, v := range datasource {
client := v.(map[string]any)
// 创建连接
address := fmt.Sprintf("%s:%d", client["host"], client["port"])
rdb := redis.NewClient(&redis.Options{
Addr: address,
Password: client["password"].(string),
DB: client["db"].(int),
})
// 测试数据库连接
pong, err := rdb.Ping(ctx).Result()
if err != nil {
log.Fatalf("failed error ping redis %s %d is %v", client["host"], client["db"], err)
continue
}
log.Infof("redis %s %d %s connection is successful.", client["host"], client["db"], pong)
rdbMap[k] = rdb
}
}
// 关闭Redis实例
func Close() {
for _, rdb := range rdbMap {
if err := rdb.Close(); err != nil {
log.Errorf("fatal error db close: %s", err)
}
}
}
// 获取默认实例
func DefaultRDB() *redis.Client {
source := conf.Get("redis.defaultDataSourceName").(string)
return rdbMap[source]
}
// 获取实例
func RDB(source string) *redis.Client {
return rdbMap[source]
}
// Info 获取redis服务信息
func Info(source string) map[string]map[string]string {
// 数据源
rdb := DefaultRDB()
if source != "" {
rdb = RDB(source)
}
ctx := context.Background()
info, err := rdb.Info(ctx).Result()
if err != nil {
return map[string]map[string]string{}
}
infoObj := make(map[string]map[string]string)
lines := strings.Split(info, "\r\n")
label := ""
for _, line := range lines {
if strings.Contains(line, "#") {
label = strings.Fields(line)[len(strings.Fields(line))-1]
label = strings.ToLower(label)
infoObj[label] = make(map[string]string)
continue
}
kvArr := strings.Split(line, ":")
if len(kvArr) >= 2 {
key := strings.TrimSpace(kvArr[0])
value := strings.TrimSpace(kvArr[len(kvArr)-1])
infoObj[label][key] = value
}
}
return infoObj
}
// KeySize 获取redis当前连接可用键Key总数信息
func KeySize(source string) int64 {
// 数据源
rdb := DefaultRDB()
if source != "" {
rdb = RDB(source)
}
ctx := context.Background()
size, err := rdb.DBSize(ctx).Result()
if err != nil {
return 0
}
return size
}
// CommandStats 获取redis命令状态信息
func CommandStats(source string) []map[string]string {
// 数据源
rdb := DefaultRDB()
if source != "" {
rdb = RDB(source)
}
ctx := context.Background()
commandstats, err := rdb.Info(ctx, "commandstats").Result()
if err != nil {
return []map[string]string{}
}
statsObjArr := make([]map[string]string, 0)
lines := strings.Split(commandstats, "\r\n")
for _, line := range lines {
if !strings.HasPrefix(line, "cmdstat_") {
continue
}
kvArr := strings.Split(line, ":")
key := kvArr[0]
valueStr := kvArr[len(kvArr)-1]
statsObj := make(map[string]string)
statsObj["name"] = key[8:]
statsObj["value"] = valueStr[6:strings.Index(valueStr, ",usec=")]
statsObjArr = append(statsObjArr, statsObj)
}
return statsObjArr
}
// 获取键的剩余有效时间(秒)
func GetExpire(source string, key string) (float64, error) {
// 数据源
rdb := DefaultRDB()
if source != "" {
rdb = RDB(source)
}
ctx := context.Background()
ttl, err := rdb.TTL(ctx, key).Result()
if err != nil {
return 0, err
}
return ttl.Seconds(), nil
}
// 获得缓存数据的key列表
func GetKeys(source string, pattern string) ([]string, error) {
// 数据源
rdb := DefaultRDB()
if source != "" {
rdb = RDB(source)
}
// 初始化变量
var keys []string
var cursor uint64 = 0
ctx := context.Background()
// 循环遍历获取匹配的键
for {
// 使用 SCAN 命令获取匹配的键
batchKeys, nextCursor, err := rdb.Scan(ctx, cursor, pattern, 100).Result()
if err != nil {
log.Errorf("Failed to scan keys: %v", err)
return keys, err
}
cursor = nextCursor
keys = append(keys, batchKeys...)
// 当 cursor 为 0表示遍历完成
if cursor == 0 {
break
}
}
return keys, nil
}
// 批量获得缓存数据
func GetBatch(source string, keys []string) ([]any, error) {
if len(keys) == 0 {
return []any{}, fmt.Errorf("not keys")
}
// 数据源
rdb := DefaultRDB()
if source != "" {
rdb = RDB(source)
}
// 获取缓存数据
result, err := rdb.MGet(context.Background(), keys...).Result()
if err != nil {
log.Errorf("Failed to get batch data: %v", err)
return []any{}, err
}
return result, nil
}
// 获得缓存数据
func Get(source, key string) (string, error) {
// 数据源
rdb := DefaultRDB()
if source != "" {
rdb = RDB(source)
}
ctx := context.Background()
value, err := rdb.Get(ctx, key).Result()
if err == redis.Nil || err != nil {
return "", err
}
return value, nil
}
// 获得缓存数据Hash
func GetHash(source, key string) (map[string]string, error) {
// 数据源
rdb := DefaultRDB()
if source != "" {
rdb = RDB(source)
}
ctx := context.Background()
value, err := rdb.HGetAll(ctx, key).Result()
if err == redis.Nil || err != nil {
return map[string]string{}, err
}
return value, nil
}
// 判断是否存在
func Has(source string, keys ...string) (bool, error) {
// 数据源
rdb := DefaultRDB()
if source != "" {
rdb = RDB(source)
}
ctx := context.Background()
exists, err := rdb.Exists(ctx, keys...).Result()
if err != nil {
return false, err
}
return exists >= 1, nil
}
// 设置缓存数据
func Set(source, key string, value any) (bool, error) {
// 数据源
rdb := DefaultRDB()
if source != "" {
rdb = RDB(source)
}
ctx := context.Background()
err := rdb.Set(ctx, key, value, 0).Err()
if err != nil {
log.Errorf("redis lua script err %v", err)
return false, err
}
return true, nil
}
// 设置缓存数据与过期时间
func SetByExpire(source, key string, value any, expiration time.Duration) (bool, error) {
// 数据源
rdb := DefaultRDB()
if source != "" {
rdb = RDB(source)
}
ctx := context.Background()
err := rdb.Set(ctx, key, value, expiration).Err()
if err != nil {
log.Errorf("redis lua script err %v", err)
return false, err
}
return true, nil
}
// 删除单个
func Del(source string, key string) (bool, error) {
// 数据源
rdb := DefaultRDB()
if source != "" {
rdb = RDB(source)
}
ctx := context.Background()
err := rdb.Del(ctx, key).Err()
if err != nil {
log.Errorf("redis lua script err %v", err)
return false, err
}
return true, nil
}
// 删除多个
func DelKeys(source string, keys []string) (bool, error) {
if len(keys) == 0 {
return false, fmt.Errorf("no keys")
}
// 数据源
rdb := DefaultRDB()
if source != "" {
rdb = RDB(source)
}
ctx := context.Background()
err := rdb.Del(ctx, keys...).Err()
if err != nil {
log.Errorf("redis lua script err %v", err)
return false, err
}
return true, nil
}
// 限流查询并记录
func RateLimit(source, limitKey string, time, count int64) (int64, error) {
// 数据源
rdb := DefaultRDB()
if source != "" {
rdb = RDB(source)
}
ctx := context.Background()
result, err := rateLimitCommand.Run(ctx, rdb, []string{limitKey}, time, count).Result()
if err != nil {
log.Errorf("redis lua script err %v", err)
return 0, err
}
return result.(int64), err
}

View File

@@ -0,0 +1,20 @@
package crypto
import (
"golang.org/x/crypto/bcrypt"
)
// BcryptHash Bcrypt密码加密
func BcryptHash(originStr string) string {
hash, err := bcrypt.GenerateFromPassword([]byte(originStr), bcrypt.DefaultCost)
if err != nil {
return ""
}
return string(hash)
}
// BcryptCompare Bcrypt密码匹配检查
func BcryptCompare(originStr, hashStr string) bool {
err := bcrypt.CompareHashAndPassword([]byte(hashStr), []byte(originStr))
return err == nil
}

133
lib/core/utils/ctx/ctx.go Normal file
View File

@@ -0,0 +1,133 @@
package ctx
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"path/filepath"
"ems.agt/lib/core/vo"
"github.com/gorilla/mux"
)
// Param 地址栏参数{id}
func Param(r *http.Request, key string) string {
vars := mux.Vars(r)
v, ok := vars[key]
if ok {
return v
}
return ""
}
// GetQuery 查询参数
func GetQuery(r *http.Request, key string) string {
return r.URL.Query().Get(key)
}
// QueryMap 查询参数转换Map
func QueryMap(r *http.Request) map[string]any {
queryValues := r.URL.Query()
queryParams := make(map[string]any)
for key, values := range queryValues {
queryParams[key] = values[0]
}
return queryParams
}
// 读取json请求结构团体
func ShouldBindJSON(r *http.Request, args any) error {
body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20)) // 设置较大的长度,例如 1<<20 (1MB)
if err != nil {
return err
}
err = json.Unmarshal(body, args)
return err
}
// JSON 相应json数据
func JSON(w http.ResponseWriter, code int, data any) {
w.Header().Set("Content-Type", "application/json;charset=UTF-8")
response, err := json.Marshal(data)
if err != nil {
w.WriteHeader(500)
w.Write([]byte(err.Error()))
} else {
w.WriteHeader(code)
w.Write(response)
}
}
// 将文件导出到外部下载
func FileAttachment(w http.ResponseWriter, r *http.Request, filepath, filename string) {
w.Header().Set("Content-Disposition", `attachment; filename=`+url.QueryEscape(filename))
w.Header().Set("Content-Type", "application/octet-stream")
http.ServeFile(w, r, filepath)
}
// 将文件上传保存到指定目录
// file, handler, err := r.FormFile("file")
// SaveUploadedFile uploads the form file to specific dst.
func SaveUploadedFile(r *http.Request, dst string) error {
// 解析请求中的文件
_, handler, err := r.FormFile("file")
if err != nil {
return err
}
src, err := handler.Open()
if err != nil {
return err
}
defer src.Close()
if err = os.MkdirAll(filepath.Dir(dst), 0750); err != nil {
return err
}
out, err := os.Create(dst)
if err != nil {
return err
}
defer out.Close()
_, err = io.Copy(out, src)
return err
}
/// ==== 登录用户信息, 通过中间件后预置入
// 定义自定义类型作为键
type ContextKey string
// LoginUser 登录用户信息需要Authorize中间件
func LoginUser(r *http.Request) (vo.LoginUser, error) {
// 上下文
v := r.Context().Value(ContextKey("LoginUser"))
if v != nil {
return v.(vo.LoginUser), nil
}
return vo.LoginUser{}, fmt.Errorf("无用户信息")
}
// LoginUserToUserID 登录用户信息-用户ID
func LoginUserToUserID(r *http.Request) string {
loginUser, err := LoginUser(r)
if err != nil {
return ""
}
return loginUser.UserID
}
// LoginUserToUserName 登录用户信息-用户名称
func LoginUserToUserName(r *http.Request) string {
loginUser, err := LoginUser(r)
if err != nil {
return ""
}
return loginUser.UserName
}

View File

@@ -0,0 +1,70 @@
package date
import (
"fmt"
"time"
"ems.agt/lib/log"
)
const (
// 年 列如2022
YYYY = "2006"
// 年-月 列如2022-12
YYYY_MM = "2006-01"
// 年-月-日 列如2022-12-30
YYYY_MM_DD = "2006-01-02"
// 年月日时分秒 列如20221230010159
YYYYMMDDHHMMSS = "20060102150405"
// 年-月-日 时:分:秒 列如2022-12-30 01:01:59
YYYY_MM_DD_HH_MM_SS = "2006-01-02 15:04:05"
)
// 格式时间字符串
//
// dateStr 时间字符串
//
// formatStr 时间格式 默认YYYY-MM-DD HH:mm:ss
func ParseStrToDate(dateStr, formatStr string) time.Time {
t, err := time.Parse(formatStr, dateStr)
if err != nil {
log.Infof("utils ParseStrToDate err %v", err)
return time.Time{}
}
return t
}
// 格式时间
//
// date 可转的Date对象
//
// formatStr 时间格式 默认YYYY-MM-DD HH:mm:ss
func ParseDateToStr(date any, formatStr string) string {
t, ok := date.(time.Time)
if !ok {
switch v := date.(type) {
case int64:
if v == 0 {
return ""
}
t = time.UnixMilli(v)
case string:
parsedTime, err := time.Parse(formatStr, v)
if err != nil {
fmt.Printf("utils ParseDateToStr err %v \n", err)
return ""
}
t = parsedTime
default:
return ""
}
}
return t.Format(formatStr)
}
// 格式时间成日期路径
//
// 年/月 列如2022/12
func ParseDatePath(date time.Time) string {
return date.Format("2006/01")
}

View File

@@ -0,0 +1,34 @@
package firewall
import (
"errors"
"os"
"ems.agt/lib/core/utils/firewall/client"
)
type FirewallClient interface {
Name() string // ufw firewalld
Start() error
Stop() error
Reload() error
Status() (string, error) // running not running
Version() (string, error)
ListPort() ([]client.FireInfo, error)
ListAddress() ([]client.FireInfo, error)
Port(port client.FireInfo, operation string) error
RichRules(rule client.FireInfo, operation string) error
PortForward(info client.Forward, operation string) error
}
func NewFirewallClient() (FirewallClient, error) {
if _, err := os.Stat("/usr/sbin/firewalld"); err == nil {
return client.NewFirewalld()
}
if _, err := os.Stat("/usr/sbin/ufw"); err == nil {
return client.NewUfw()
}
return nil, errors.New("no such type")
}

View File

@@ -0,0 +1,209 @@
package client
import (
"fmt"
"strings"
"sync"
"ems.agt/lib/core/cmd"
)
type Firewall struct{}
func NewFirewalld() (*Firewall, error) {
return &Firewall{}, nil
}
func (f *Firewall) Name() string {
return "firewalld"
}
func (f *Firewall) Status() (string, error) {
stdout, _ := cmd.Exec("firewall-cmd --state")
if stdout == "running\n" {
return "running", nil
}
return "not running", nil
}
func (f *Firewall) Version() (string, error) {
stdout, err := cmd.Exec("firewall-cmd --version")
if err != nil {
return "", fmt.Errorf("load the firewall version failed, err: %s", stdout)
}
return strings.ReplaceAll(stdout, "\n ", ""), nil
}
func (f *Firewall) Start() error {
stdout, err := cmd.Exec("systemctl start firewalld")
if err != nil {
return fmt.Errorf("enable the firewall failed, err: %s", stdout)
}
return nil
}
func (f *Firewall) Stop() error {
stdout, err := cmd.Exec("systemctl stop firewalld")
if err != nil {
return fmt.Errorf("stop the firewall failed, err: %s", stdout)
}
return nil
}
func (f *Firewall) Reload() error {
stdout, err := cmd.Exec("firewall-cmd --reload")
if err != nil {
return fmt.Errorf("reload firewall failed, err: %s", stdout)
}
return nil
}
func (f *Firewall) ListPort() ([]FireInfo, error) {
var wg sync.WaitGroup
var datas []FireInfo
wg.Add(2)
go func() {
defer wg.Done()
stdout, err := cmd.Exec("firewall-cmd --zone=public --list-ports")
if err != nil {
return
}
ports := strings.Split(strings.ReplaceAll(stdout, "\n", ""), " ")
for _, port := range ports {
if len(port) == 0 {
continue
}
var itemPort FireInfo
if strings.Contains(port, "/") {
itemPort.Port = strings.Split(port, "/")[0]
itemPort.Protocol = strings.Split(port, "/")[1]
}
itemPort.Strategy = "accept"
datas = append(datas, itemPort)
}
}()
go func() {
defer wg.Done()
stdout1, err := cmd.Exec("firewall-cmd --zone=public --list-rich-rules")
if err != nil {
return
}
rules := strings.Split(stdout1, "\n")
for _, rule := range rules {
if len(rule) == 0 {
continue
}
itemRule := f.loadInfo(rule)
if len(itemRule.Port) != 0 && itemRule.Family == "ipv4" {
datas = append(datas, itemRule)
}
}
}()
wg.Wait()
return datas, nil
}
func (f *Firewall) ListAddress() ([]FireInfo, error) {
stdout, err := cmd.Exec("firewall-cmd --zone=public --list-rich-rules")
if err != nil {
return nil, err
}
var datas []FireInfo
rules := strings.Split(stdout, "\n")
for _, rule := range rules {
if len(rule) == 0 {
continue
}
itemRule := f.loadInfo(rule)
if len(itemRule.Port) == 0 && len(itemRule.Address) != 0 {
datas = append(datas, itemRule)
}
}
return datas, nil
}
func (f *Firewall) Port(port FireInfo, operation string) error {
if cmd.CheckIllegal(operation, port.Protocol, port.Port) {
return fmt.Errorf("errCmdIllegal %v", port)
}
stdout, err := cmd.Execf("firewall-cmd --zone=public --%s-port=%s/%s --permanent", operation, port.Port, port.Protocol)
if err != nil {
return fmt.Errorf("%s port failed, err: %s", operation, stdout)
}
return nil
}
func (f *Firewall) RichRules(rule FireInfo, operation string) error {
if cmd.CheckIllegal(operation, rule.Address, rule.Protocol, rule.Port, rule.Strategy) {
return fmt.Errorf("errCmdIllegal %v", rule)
}
ruleStr := ""
if strings.Contains(rule.Address, "-") {
std, err := cmd.Execf("firewall-cmd --permanent --new-ipset=%s --type=hash:ip", rule.Address)
if err != nil {
return fmt.Errorf("add new ipset failed, err: %s", std)
}
std2, err := cmd.Execf("firewall-cmd --permanent --ipset=%s --add-entry=%s", rule.Address, rule.Address)
if err != nil {
return fmt.Errorf("add entry to ipset failed, err: %s", std2)
}
if err := f.Reload(); err != nil {
return err
}
ruleStr = fmt.Sprintf("rule source ipset=%s %s", rule.Address, rule.Strategy)
} else {
ruleStr = "rule family=ipv4 "
if len(rule.Address) != 0 {
ruleStr += fmt.Sprintf("source address=%s ", rule.Address)
}
if len(rule.Port) != 0 {
ruleStr += fmt.Sprintf("port port=%s ", rule.Port)
}
if len(rule.Protocol) != 0 {
ruleStr += fmt.Sprintf("protocol=%s ", rule.Protocol)
}
ruleStr += rule.Strategy
}
stdout, err := cmd.Execf("firewall-cmd --zone=public --%s-rich-rule '%s' --permanent", operation, ruleStr)
if err != nil {
return fmt.Errorf("%s rich rules failed, err: %s", operation, stdout)
}
return nil
}
func (f *Firewall) PortForward(info Forward, operation string) error {
ruleStr := fmt.Sprintf("firewall-cmd --%s-forward-port=port=%s:proto=%s:toport=%s --permanent", operation, info.Port, info.Protocol, info.Target)
if len(info.Address) != 0 {
ruleStr = fmt.Sprintf("firewall-cmd --%s-forward-port=port=%s:proto=%s:toaddr=%s:toport=%s --permanent", operation, info.Port, info.Protocol, info.Address, info.Target)
}
stdout, err := cmd.Exec(ruleStr)
if err != nil {
return fmt.Errorf("%s port forward failed, err: %s", operation, stdout)
}
return nil
}
func (f *Firewall) loadInfo(line string) FireInfo {
var itemRule FireInfo
ruleInfo := strings.Split(strings.ReplaceAll(line, "\"", ""), " ")
for _, item := range ruleInfo {
switch {
case strings.Contains(item, "family="):
itemRule.Family = strings.ReplaceAll(item, "family=", "")
case strings.Contains(item, "ipset="):
itemRule.Address = strings.ReplaceAll(item, "ipset=", "")
case strings.Contains(item, "address="):
itemRule.Address = strings.ReplaceAll(item, "address=", "")
case strings.Contains(item, "port="):
itemRule.Port = strings.ReplaceAll(item, "port=", "")
case strings.Contains(item, "protocol="):
itemRule.Protocol = strings.ReplaceAll(item, "protocol=", "")
case item == "accept" || item == "drop" || item == "reject":
itemRule.Strategy = item
}
}
return itemRule
}

View File

@@ -0,0 +1,20 @@
package client
type FireInfo struct {
Family string `json:"family"` // ipv4 ipv6
Address string `json:"address"` // Anywhere
Port string `json:"port"`
Protocol string `json:"protocol"` // tcp udp tcp/udp
Strategy string `json:"strategy"` // accept drop
APPName string `json:"appName"`
IsUsed bool `json:"isUsed"`
Description string `json:"description"`
}
type Forward struct {
Protocol string `json:"protocol"`
Address string `json:"address"`
Port string `json:"port"`
Target string `json:"target"`
}

View File

@@ -0,0 +1,238 @@
package client
import (
"fmt"
"strings"
"ems.agt/lib/core/cmd"
)
type Ufw struct {
CmdStr string
}
func NewUfw() (*Ufw, error) {
var ufw Ufw
if cmd.HasNoPasswordSudo() {
ufw.CmdStr = "sudo ufw"
} else {
ufw.CmdStr = "ufw"
}
return &ufw, nil
}
func (f *Ufw) Name() string {
return "ufw"
}
func (f *Ufw) Status() (string, error) {
stdout, _ := cmd.Execf("%s status | grep Status", f.CmdStr)
if stdout == "Status: active\n" {
return "running", nil
}
stdout1, _ := cmd.Execf("%s status | grep 状态", f.CmdStr)
if stdout1 == "状态: 激活\n" {
return "running", nil
}
return "not running", nil
}
func (f *Ufw) Version() (string, error) {
stdout, err := cmd.Execf("%s version | grep ufw", f.CmdStr)
if err != nil {
return "", fmt.Errorf("load the firewall status failed, err: %s", stdout)
}
info := strings.ReplaceAll(stdout, "\n", "")
return strings.ReplaceAll(info, "ufw ", ""), nil
}
func (f *Ufw) Start() error {
stdout, err := cmd.Execf("echo y | %s enable", f.CmdStr)
if err != nil {
return fmt.Errorf("enable the firewall failed, err: %s", stdout)
}
return nil
}
func (f *Ufw) Stop() error {
stdout, err := cmd.Execf("%s disable", f.CmdStr)
if err != nil {
return fmt.Errorf("stop the firewall failed, err: %s", stdout)
}
return nil
}
func (f *Ufw) Reload() error {
return nil
}
func (f *Ufw) ListPort() ([]FireInfo, error) {
stdout, err := cmd.Execf("%s status verbose", f.CmdStr)
if err != nil {
return nil, err
}
portInfos := strings.Split(stdout, "\n")
var datas []FireInfo
isStart := false
for _, line := range portInfos {
if strings.HasPrefix(line, "-") {
isStart = true
continue
}
if !isStart {
continue
}
itemFire := f.loadInfo(line, "port")
if len(itemFire.Port) != 0 && itemFire.Port != "Anywhere" && !strings.Contains(itemFire.Port, ".") {
itemFire.Port = strings.ReplaceAll(itemFire.Port, ":", "-")
datas = append(datas, itemFire)
}
}
return datas, nil
}
func (f *Ufw) ListAddress() ([]FireInfo, error) {
stdout, err := cmd.Execf("%s status verbose", f.CmdStr)
if err != nil {
return nil, err
}
portInfos := strings.Split(stdout, "\n")
var datas []FireInfo
isStart := false
for _, line := range portInfos {
if strings.HasPrefix(line, "-") {
isStart = true
continue
}
if !isStart {
continue
}
if !strings.Contains(line, " IN") {
continue
}
itemFire := f.loadInfo(line, "address")
if strings.Contains(itemFire.Port, ".") {
itemFire.Address += ("-" + itemFire.Port)
itemFire.Port = ""
}
if len(itemFire.Port) == 0 && len(itemFire.Address) != 0 {
datas = append(datas, itemFire)
}
}
return datas, nil
}
func (f *Ufw) Port(port FireInfo, operation string) error {
switch port.Strategy {
case "accept":
port.Strategy = "allow"
case "drop":
port.Strategy = "deny"
default:
return fmt.Errorf("unsupport strategy %s", port.Strategy)
}
if cmd.CheckIllegal(port.Protocol, port.Port) {
return fmt.Errorf("errCmdIllegal %v", port)
}
command := fmt.Sprintf("%s %s %s", f.CmdStr, port.Strategy, port.Port)
if operation == "remove" {
command = fmt.Sprintf("%s delete %s %s", f.CmdStr, port.Strategy, port.Port)
}
if len(port.Protocol) != 0 {
command += fmt.Sprintf("/%s", port.Protocol)
}
stdout, err := cmd.Exec(command)
if err != nil {
return fmt.Errorf("%s port failed, err: %s", operation, stdout)
}
return nil
}
func (f *Ufw) RichRules(rule FireInfo, operation string) error {
switch rule.Strategy {
case "accept":
rule.Strategy = "allow"
case "drop":
rule.Strategy = "deny"
default:
return fmt.Errorf("unsupport strategy %s", rule.Strategy)
}
if cmd.CheckIllegal(operation, rule.Protocol, rule.Address, rule.Port) {
return fmt.Errorf("errCmdIllegal %v", rule)
}
ruleStr := fmt.Sprintf("%s %s ", f.CmdStr, rule.Strategy)
if operation == "remove" {
ruleStr = fmt.Sprintf("%s delete %s ", f.CmdStr, rule.Strategy)
}
if len(rule.Protocol) != 0 {
ruleStr += fmt.Sprintf("proto %s ", rule.Protocol)
}
if strings.Contains(rule.Address, "-") {
ruleStr += fmt.Sprintf("from %s to %s ", strings.Split(rule.Address, "-")[0], strings.Split(rule.Address, "-")[1])
} else {
ruleStr += fmt.Sprintf("from %s ", rule.Address)
}
if len(rule.Port) != 0 {
ruleStr += fmt.Sprintf("to any port %s ", rule.Port)
}
stdout, err := cmd.Exec(ruleStr)
if err != nil {
return fmt.Errorf("%s rich rules failed, err: %s", operation, stdout)
}
return nil
}
func (f *Ufw) PortForward(info Forward, operation string) error {
ruleStr := fmt.Sprintf("firewall-cmd --%s-forward-port=port=%s:proto=%s:toport=%s --permanent", operation, info.Port, info.Protocol, info.Target)
if len(info.Address) != 0 {
ruleStr = fmt.Sprintf("firewall-cmd --%s-forward-port=port=%s:proto=%s:toaddr=%s:toport=%s --permanent", operation, info.Port, info.Protocol, info.Address, info.Target)
}
stdout, err := cmd.Exec(ruleStr)
if err != nil {
return fmt.Errorf("%s port forward failed, err: %s", operation, stdout)
}
if err := f.Reload(); err != nil {
return err
}
return nil
}
func (f *Ufw) loadInfo(line string, fireType string) FireInfo {
fields := strings.Fields(line)
var itemInfo FireInfo
if len(fields) < 4 {
return itemInfo
}
if fields[1] == "(v6)" {
return itemInfo
}
if fields[0] == "Anywhere" && fireType != "port" {
itemInfo.Strategy = "drop"
if fields[1] == "ALLOW" {
itemInfo.Strategy = "accept"
}
itemInfo.Address = fields[3]
return itemInfo
}
if strings.Contains(fields[0], "/") {
itemInfo.Port = strings.Split(fields[0], "/")[0]
itemInfo.Protocol = strings.Split(fields[0], "/")[1]
} else {
itemInfo.Port = fields[0]
itemInfo.Protocol = "tcp/udp"
}
itemInfo.Family = "ipv4"
if fields[1] == "ALLOW" {
itemInfo.Strategy = "accept"
} else {
itemInfo.Strategy = "drop"
}
itemInfo.Address = fields[3]
return itemInfo
}

View File

@@ -0,0 +1,139 @@
package parse
import (
"fmt"
"reflect"
"regexp"
"strconv"
"strings"
"time"
"github.com/robfig/cron/v3"
)
// Number 解析数值型
func Number(str any) int64 {
switch str := str.(type) {
case string:
if str == "" {
return 0
}
num, err := strconv.ParseInt(str, 10, 64)
if err != nil {
return 0
}
return num
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
return reflect.ValueOf(str).Int()
case float32, float64:
return int64(reflect.ValueOf(str).Float())
default:
return 0
}
}
// Boolean 解析布尔型
func Boolean(str any) bool {
switch str := str.(type) {
case string:
if str == "" || str == "false" || str == "0" {
return false
}
// 尝试将字符串解析为数字
if num, err := strconv.ParseFloat(str, 64); err == nil {
return num != 0
}
return true
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
num := reflect.ValueOf(str).Int()
return num != 0
case float32, float64:
num := reflect.ValueOf(str).Float()
return num != 0
default:
return false
}
}
// FirstUpper 首字母转大写
//
// 字符串 abc_123!@# 结果 Abc_123
func FirstUpper(str string) string {
if len(str) == 0 {
return str
}
reg := regexp.MustCompile(`[^_\w]+`)
str = reg.ReplaceAllString(str, "")
return strings.ToUpper(str[:1]) + str[1:]
}
// Bit 比特位为单位
func Bit(bit float64) string {
var GB, MB, KB string
if bit > float64(1<<30) {
GB = fmt.Sprintf("%0.2f", bit/(1<<30))
}
if bit > float64(1<<20) && bit < (1<<30) {
MB = fmt.Sprintf("%.2f", bit/(1<<20))
}
if bit > float64(1<<10) && bit < (1<<20) {
KB = fmt.Sprintf("%.2f", bit/(1<<10))
}
if GB != "" {
return GB + "GB"
} else if MB != "" {
return MB + "MB"
} else if KB != "" {
return KB + "KB"
} else {
return fmt.Sprintf("%vB", bit)
}
}
// CronExpression 解析 Cron 表达式,返回下一次执行的时间戳(毫秒)
//
// 【*/5 * * * * ?】 6个参数
func CronExpression(expression string) int64 {
specParser := cron.NewParser(cron.Second | cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor)
schedule, err := specParser.Parse(expression)
if err != nil {
fmt.Println(err)
return 0
}
return schedule.Next(time.Now()).UnixMilli()
}
// SafeContent 内容值进行安全掩码
func SafeContent(value string) string {
if len(value) < 3 {
return strings.Repeat("*", len(value))
} else if len(value) < 6 {
return string(value[0]) + strings.Repeat("*", len(value)-1)
} else if len(value) < 10 {
return string(value[0]) + strings.Repeat("*", len(value)-2) + string(value[len(value)-1])
} else if len(value) < 15 {
return value[:2] + strings.Repeat("*", len(value)-4) + value[len(value)-2:]
} else {
return value[:3] + strings.Repeat("*", len(value)-6) + value[len(value)-3:]
}
}
// RemoveDuplicates 数组内字符串去重
func RemoveDuplicates(ids []string) []string {
uniqueIDs := make(map[string]bool)
uniqueIDSlice := make([]string, 0)
for _, id := range ids {
_, ok := uniqueIDs[id]
if !ok && id != "" {
uniqueIDs[id] = true
uniqueIDSlice = append(uniqueIDSlice, id)
}
}
return uniqueIDSlice
}

View File

@@ -0,0 +1,54 @@
package regular
import (
"regexp"
)
// Replace 正则替换
func Replace(originStr, pattern, repStr string) string {
regex := regexp.MustCompile(pattern)
return regex.ReplaceAllString(originStr, repStr)
}
// 判断是否为有效用户名格式
//
// 用户名不能以数字开头可包含大写小写字母数字且不少于5位
func ValidUsername(username string) bool {
if username == "" {
return false
}
pattern := `^[a-zA-Z][a-z0-9A-Z]{5,}`
match, err := regexp.MatchString(pattern, username)
if err != nil {
return false
}
return match
}
// 判断是否为有效手机号格式1开头的11位手机号
func ValidMobile(mobile string) bool {
if mobile == "" {
return false
}
pattern := `^1[3|4|5|6|7|8|9][0-9]\d{8}$`
match, err := regexp.MatchString(pattern, mobile)
if err != nil {
return false
}
return match
}
// 判断是否为http(s)://开头
//
// link 网络链接
func ValidHttp(link string) bool {
if link == "" {
return false
}
pattern := `^http(s)?:\/\/+`
match, err := regexp.MatchString(pattern, link)
if err != nil {
return false
}
return match
}

View File

@@ -0,0 +1,31 @@
package scan
import (
"net"
"strconv"
)
func ScanPort(port int) bool {
ln, err := net.Listen("tcp", ":"+strconv.Itoa(port))
if err != nil {
return true
}
defer ln.Close()
return false
}
func ScanUDPPort(port int) bool {
ln, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: port})
if err != nil {
return true
}
defer ln.Close()
return false
}
func ScanPortWithProto(port int, proto string) bool {
if proto == "udp" {
return ScanUDPPort(port)
}
return ScanPort(port)
}

28
lib/core/vo/login_user.go Normal file
View File

@@ -0,0 +1,28 @@
package vo
import (
"ems.agt/lib/dborm"
)
// LoginUser 登录用户身份权限信息对象
type LoginUser struct {
// UserID 用户ID
UserID string `json:"userId"`
// UserName 用户名
UserName string `json:"userName"`
// LoginTime 登录时间时间戳
LoginTime int64 `json:"loginTime"`
// ExpireTime 过期时间时间戳
ExpireTime int64 `json:"expireTime"`
// Permissions 权限列表
Permissions []string `json:"permissions"`
// User 用户信息
User dborm.User `json:"user"`
Session dborm.Session `json:"-"`
}

View File

@@ -0,0 +1,72 @@
package result
const CODE_ERROR = 0
const MSG_ERROR = "error"
const CODE_SUCCESS = 1
const MSG_SUCCESS = "success"
// CodeMsg 响应结果
func CodeMsg(code int, msg string) map[string]any {
args := make(map[string]any)
args["code"] = code
args["msg"] = msg
return args
}
// 响应成功结果 map[string]any{}
func Ok(v map[string]any) map[string]any {
args := make(map[string]any)
args["code"] = CODE_SUCCESS
args["msg"] = MSG_SUCCESS
// v合并到args
for key, value := range v {
args[key] = value
}
return args
}
// 响应成功结果信息
func OkMsg(msg string) map[string]any {
args := make(map[string]any)
args["code"] = CODE_SUCCESS
args["msg"] = msg
return args
}
// 响应成功结果数据
func OkData(data any) map[string]any {
args := make(map[string]any)
args["code"] = CODE_SUCCESS
args["msg"] = MSG_SUCCESS
args["data"] = data
return args
}
// 响应失败结果 map[string]any{}
func Err(v map[string]any) map[string]any {
args := make(map[string]any)
args["code"] = CODE_ERROR
args["msg"] = MSG_ERROR
// v合并到args
for key, value := range v {
args[key] = value
}
return args
}
// 响应失败结果信息
func ErrMsg(msg string) map[string]any {
args := make(map[string]any)
args["code"] = CODE_ERROR
args["msg"] = msg
return args
}
// 响应失败结果数据
func ErrData(data any) map[string]any {
args := make(map[string]any)
args["code"] = CODE_ERROR
args["msg"] = MSG_ERROR
args["data"] = data
return args
}

17
lib/core/vo/router.go Normal file
View File

@@ -0,0 +1,17 @@
package vo
// Router 路由信息对象
type Router struct {
// 路由名字 英文首字母大写
Name string `json:"name"`
// 路由地址
Path string `json:"path"`
// 其他元素
Meta RouterMeta `json:"meta"`
// 组件地址
Component string `json:"component"`
// 重定向地址
Redirect string `json:"redirect"`
// 子路由
Children []Router `json:"children,omitempty"`
}

View File

@@ -0,0 +1,17 @@
package vo
// RouterMeta 路由元信息对象
type RouterMeta struct {
// 设置该菜单在侧边栏和面包屑中展示的名字
Title string `json:"title"`
// 设置该菜单的图标
Icon string `json:"icon"`
// 设置为true则不会被 <keep-alive>缓存
Cache bool `json:"cache"`
// 内链地址http(s)://开头), 打开目标位置 '_blank' | '_self' | ''
Target string `json:"target"`
// 在菜单中隐藏子节点
HideChildInMenu bool `json:"hideChildInMenu"`
// 在菜单中隐藏自己和子节点
HideInMenu bool `json:"hideInMenu"`
}

36
lib/core/vo/treeselect.go Normal file
View File

@@ -0,0 +1,36 @@
package vo
// import sysmenu "ems.agt/features/sys_menu"
// TreeSelect 树结构实体类
type TreeSelect struct {
// ID 节点ID
ID string `json:"id"`
// Label 节点名称
Label string `json:"label"`
// Title 节点名称旧版本layui
Title string `json:"title"`
// Children 子节点
Children []TreeSelect `json:"children"`
}
// // SysMenuTreeSelect 使用给定的 SysMenu 对象解析为 TreeSelect 对象
// func SysMenuTreeSelect(sysMenu sysmenu.SysMenu) TreeSelect {
// t := TreeSelect{}
// t.ID = sysMenu.MenuID
// t.Label = sysMenu.MenuName
// if len(sysMenu.Children) > 0 {
// for _, menu := range sysMenu.Children {
// child := SysMenuTreeSelect(menu)
// t.Children = append(t.Children, child)
// }
// } else {
// t.Children = []TreeSelect{}
// }
// return t
// }

1809
lib/dborm/dborm.go Normal file

File diff suppressed because it is too large Load Diff

161
lib/file/file.go Normal file
View File

@@ -0,0 +1,161 @@
package file
import (
"fmt"
"net/http"
"os"
)
// const (
// //经过测试linux下延时需要大于100ms
// TIME_DELAY_AFTER_WRITE = 200
// )
// type Response struct {
// Data []string `json:"data"`
// }
// type MMLRequest struct {
// MML []string `json:"mml"`
// }
// func GetFile(w http.ResponseWriter, r *http.Request) {
// log.Debug("PostMMLToNF processing... ")
// vars := mux.Vars(r)
// neType := vars["elementTypeValue"]
// params := r.URL.Query()
// neId := params["ne_id"]
// log.Debug("neType:", neType, "neId", neId)
// neInfo := new(dborm.NeInfo)
// var err error
// if len(neId) == 0 {
// log.Error("ne_id NOT FOUND")
// services.ResponseBadRequest400WrongParamValue(w)
// return
// }
// neInfo, err = dborm.XormGetNeInfo(neType, neId[0])
// if err != nil {
// log.Error("dborm.XormGetNeInfo is failed:", err)
// services.ResponseInternalServerError500DatabaseOperationFailed(w)
// return
// }
// var buf [8192]byte
// var n int
// var mmlResult []string
// if neInfo != nil {
// hostMML := fmt.Sprintf("%s:%d", neInfo.Ip, config.GetYamlConfig().MML.Port)
// conn, err := net.Dial("tcp", hostMML)
// if err != nil {
// errMsg := fmt.Sprintf("Failed to dial %s: %v", hostMML, err)
// log.Error(errMsg)
// mmlResult = append(mmlResult, errMsg)
// response := Response{mmlResult}
// services.ResponseWithJson(w, http.StatusOK, response)
// return
// }
// loginStr := fmt.Sprintf("%s\n%s\n", config.GetYamlConfig().MML.User, config.GetYamlConfig().MML.Password)
// n, err = conn.Write([]byte(loginStr))
// if err != nil {
// log.Errorf("Error: %s", err.Error())
// return
// }
// time.Sleep(time.Millisecond * TIME_DELAY_AFTER_WRITE)
// n, err = conn.Read(buf[0:])
// if err != nil {
// log.Errorf("Error: %s", err.Error())
// return
// }
// log.Debug(string(buf[0:n]))
// body, err := io.ReadAll(io.LimitReader(r.Body, global.RequestBodyMaxLen))
// if err != nil {
// log.Error("io.ReadAll is failed:", err)
// services.ResponseNotFound404UriNotExist(w, r)
// return
// }
// log.Debug("Body:", string(body))
// mmlRequest := new(MMLRequest)
// _ = json.Unmarshal(body, mmlRequest)
// for _, mml := range mmlRequest.MML {
// mmlCommand := fmt.Sprintf("%s\n", mml)
// log.Debug("mml command:", mmlCommand)
// n, err = conn.Write([]byte(mmlCommand))
// if err != nil {
// log.Errorf("Error: %s", err.Error())
// return
// }
// time.Sleep(time.Millisecond * TIME_DELAY_AFTER_WRITE)
// n, err = conn.Read(buf[0:])
// if err != nil {
// log.Errorf("Error: %s", err.Error())
// return
// }
// log.Debug(string(buf[0 : n-len(neType)-2]))
// mmlResult = append(mmlResult, string(buf[0:n-len(neType)-2]))
// }
// }
// response := Response{mmlResult}
// services.ResponseWithJson(w, http.StatusOK, response)
// }
// 格式文件大小单位
func FormatFileSize(fileSize float64) (size string) {
if fileSize < 1024 {
return fmt.Sprintf("%.2fB", fileSize/float64(1))
} else if fileSize < (1024 * 1024) {
return fmt.Sprintf("%.2fKB", fileSize/float64(1024))
} else if fileSize < (1024 * 1024 * 1024) {
return fmt.Sprintf("%.2fMB", fileSize/float64(1024*1024))
} else if fileSize < (1024 * 1024 * 1024 * 1024) {
return fmt.Sprintf("%.2fGB", fileSize/float64(1024*1024*1024))
} else if fileSize < (1024 * 1024 * 1024 * 1024 * 1024) {
return fmt.Sprintf("%.2fTB", fileSize/float64(1024*1024*1024*1024))
} else {
return fmt.Sprintf("%.2fEB", fileSize/float64(1024*1024*1024*1024*1024))
}
}
func IsSymlink(mode os.FileMode) bool {
return mode&os.ModeSymlink != 0
}
const dotCharacter = 46
func IsHidden(path string) bool {
return path[0] == dotCharacter
}
func GetMimeType(path string) string {
file, err := os.Open(path)
if err != nil {
return ""
}
defer file.Close()
buffer := make([]byte, 512)
_, err = file.Read(buffer)
if err != nil {
return ""
}
mimeType := http.DetectContentType(buffer)
return mimeType
}
func GetSymlink(path string) string {
linkPath, err := os.Readlink(path)
if err != nil {
return ""
}
return linkPath
}

65
lib/global/global.go Normal file
View File

@@ -0,0 +1,65 @@
package global
import "errors"
// 跨package引用的首字母大写
const (
RequestBodyMaxLen = 2000000
ApiVersionV1 = "v1"
ApiVersionV2 = "v2"
LineBreak = "\n"
)
const (
DateTime = "2006-01-02 15:04:05"
DateData = "20060102150405"
DateHour = "2006010215"
DateZone = "2006-01-02 15:04:05 +0000 UTC"
)
const (
MaxInt32Number = 2147483647
)
const (
MaxLimitData = 1000
)
var (
Version string
BuildTime string
GoVer string
)
var (
DefaultUriPrefix = "/api/rest"
)
var (
ErrParamsNotAdapted = errors.New("the number of params is not adapted")
// PM module error message
ErrPMNotFoundData = errors.New("not found PM data")
// CM module error message
ErrCMNotFoundTargetNE = errors.New("not found target NE")
ErrCMCannotDeleteActiveNE = errors.New("can not delete an active NE")
ErrCMInvalidBackupFile = errors.New("invalid backup file")
ErrCMNotMatchMD5File = errors.New("md5 not match between file and url")
ErrCMNotMatchSignFile = errors.New("digests signatures not match in the file")
ErrCMExistSoftwareFile = errors.New("exist the same software package file")
ErrCMNotFoundTargetSoftware = errors.New("not found the target software package")
ErrCMNotFoundTargetNeVersion = errors.New("not found the target NE version")
ErrCMNotFoundRollbackNeVersion = errors.New("not found the rollback NE version")
ErrCMUnknownServiceAction = errors.New("unknown service action")
ErrCMUnknownInstanceAction = errors.New("unknown instance action")
ErrCMNotFoundTargetBackupFile = errors.New("not found the target NE backup")
ErrCMUnknownSoftwareFormat = errors.New("unknown software package format") // 未知软件包格式
// TRACE module error message
ErrTraceFailedDistributeToNEs = errors.New("failed to distribute trace task to target NEs")
ErrTraceNotCarriedTaskID = errors.New("not carried task id in request url")
// MML module error define
ErrMmlInvalidCommandFormat = errors.New("invalid mml command format")
)

677
lib/global/kits.go Normal file
View File

@@ -0,0 +1,677 @@
package global
import (
"archive/zip"
"bytes"
"crypto/md5"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"os"
"path/filepath"
"reflect"
"regexp"
"sort"
"strings"
"time"
)
const (
IsIPv4 = "IPv4"
IsIPv6 = "IPv6"
NonIP = "NonIp"
)
type em struct{}
func GetPkgName() string {
return reflect.TypeOf(em{}).PkgPath()
}
// interface{} change to map[string]interface{}
// interface{} data is []interface{}
func ListToMap(list interface{}, key string) map[string]interface{} {
res := make(map[string]interface{})
arr := ToSlice(list)
for _, row := range arr {
immutable := reflect.ValueOf(row)
val := immutable.FieldByName(key).String()
res[val] = row
}
return res
}
// interface{} change to []interface{}
func ToSlice(arr interface{}) []interface{} {
ret := make([]interface{}, 0)
v := reflect.ValueOf(arr)
if v.Kind() != reflect.Slice {
ret = append(ret, arr)
return ret
}
l := v.Len()
for i := 0; i < l; i++ {
ret = append(ret, v.Index(i).Interface())
}
return ret
}
var TodoList []Todo
type Todo struct {
Id int64
Item string
}
// JSON序列化方式
func jsonStructToMap(TodoList Todo) (map[string]interface{}, error) {
// 结构体转json
strRet, err := json.Marshal(TodoList)
if err != nil {
return nil, err
}
// json转map
var mRet map[string]interface{}
err1 := json.Unmarshal(strRet, &mRet)
if err1 != nil {
return nil, err1
}
return mRet, nil
}
func IsContain(item string, items []string) bool {
for _, e := range items {
if e == item {
return true
}
}
return false
}
func IsContainP(item string, items *[]string, size int) bool {
for i := 0; i < size; i++ {
if (*items)[i] == item {
return true
}
}
return false
}
// 将字符串 分割成 字符串数组
// @s分割符
func SplitString(str string, s string) []string {
sa := strings.Split(str, s)
return sa
}
//  合并字符串数组
func MergeStringArr(a, b []string) []string {
var arr []string
for _, i := range a {
arr = append(arr, i)
}
for _, j := range b {
arr = append(arr, j)
}
return arr
}
// 数组去重
func UniqueStringArr(m []string) []string {
d := make([]string, 0)
tempMap := make(map[string]bool, len(m))
for _, v := range m { // 以值作为键名
if tempMap[v] == false {
tempMap[v] = true
d = append(d, v)
}
}
return d
}
//  合并整型数组
func MergeArr(a, b []int) []int {
var arr []int
for _, i := range a {
arr = append(arr, i)
}
for _, j := range b {
arr = append(arr, j)
}
return arr
}
// 数组去重
func UniqueArr(m []int) []int {
d := make([]int, 0)
tempMap := make(map[int]bool, len(m))
for _, v := range m { // 以值作为键名
if tempMap[v] == false {
tempMap[v] = true
d = append(d, v)
}
}
return d
}
// 升序
func AscArr(e []int) []int {
sort.Ints(e[:])
return e
}
// 降序
func DescArr(e []int) []int {
sort.Sort(sort.Reverse(sort.IntSlice(e)))
return e
}
func MatchRmUID(p string, s string) bool {
match, _ := regexp.MatchString(p, s)
return match
}
type OrderedMap struct {
Order []string
Map map[string]interface{}
}
func (om *OrderedMap) UnmarshalJson(b []byte) error {
json.Unmarshal(b, &om.Map)
index := make(map[string]int)
for key := range om.Map {
om.Order = append(om.Order, key)
esc, _ := json.Marshal(key) //Escape the key
index[key] = bytes.Index(b, esc)
}
sort.Slice(om.Order, func(i, j int) bool { return index[om.Order[i]] < index[om.Order[j]] })
return nil
}
func (om OrderedMap) MarshalJson() ([]byte, error) {
var b []byte
buf := bytes.NewBuffer(b)
buf.WriteRune('{')
l := len(om.Order)
for i, key := range om.Order {
km, err := json.Marshal(key)
if err != nil {
return nil, err
}
buf.Write(km)
buf.WriteRune(':')
vm, err := json.Marshal(om.Map[key])
if err != nil {
return nil, err
}
buf.Write(vm)
if i != l-1 {
buf.WriteRune(',')
}
fmt.Println(buf.String())
}
buf.WriteRune('}')
fmt.Println(buf.String())
return buf.Bytes(), nil
}
func GetBodyCopy(r *http.Request) (*bytes.Buffer, error) {
// If r.bodyBuf present, return the copy
// if r.bodyBuf != nil {
// return bytes.NewBuffer(r.bodyBuf.Bytes()), nil
// }
// Maybe body is `io.Reader`.
// Note: Resty user have to watchout for large body size of `io.Reader`
if r.Body != nil {
b, err := io.ReadAll(r.Body)
if err != nil {
return nil, err
}
// Restore the Body
// close(r.Body)
r.Body = io.NopCloser(bytes.NewReader(b))
// Return the Body bytes
return bytes.NewBuffer(b), nil
}
return nil, nil
}
func UnmarshalBody(r *http.Request, v *interface{}, maxLen int64) error {
body, err := io.ReadAll(io.LimitReader(r.Body, maxLen))
if err != nil {
return err
}
return json.Unmarshal(body, v)
}
func SetNotifyUrl(ip string, port uint16, uri string) string {
return fmt.Sprintf("http://%s:%d%s", ip, port, uri)
}
func GetIps() (ips []string, err error) {
interfaceAddr, err := net.InterfaceAddrs()
if err != nil {
return ips, err
}
for _, address := range interfaceAddr {
ipNet, isVailIpNet := address.(*net.IPNet)
// 检查ip地址判断是否回环地址
if isVailIpNet && !ipNet.IP.IsLoopback() {
if ipNet.IP.To4() != nil {
ips = append(ips, ipNet.IP.String())
}
}
}
return ips, nil
}
func GetCurrentTimeSliceIndexByPeriod(t time.Time, period int) int {
index := int((t.Hour()*60+t.Minute())/period) - 1
if index < 0 {
return int(24*60/period) - 1
}
return index
}
var (
cst *time.Location
)
// RFC3339ToDateTime convert rfc3339 value to china standard time layout
func RFC3339ToDateTime(value string) (string, error) {
ts, err := time.Parse(time.RFC3339, value)
if err != nil {
return "", err
}
return ts.In(cst).Format("2006-01-02 15:04:05"), nil
}
// CreateTimeDir 根据当前时间格式来创建文件夹
func CreateTimeDir(fmt string, path string) string {
folderName := time.Now().Format(fmt)
folderPath := filepath.Join(path, folderName)
if _, err := os.Stat(folderPath); os.IsNotExist(err) {
// 必须分成两步:先创建文件夹、再修改权限
os.Mkdir(folderPath, 0664) //0644也可以os.ModePerm
os.Chmod(folderPath, 0664)
}
return folderPath
}
// CreateDir 根据传入的目录名和路径来创建文件夹
func CreateDir(folderName string, path string) string {
folderPath := filepath.Join(path, folderName)
if _, err := os.Stat(folderPath); os.IsNotExist(err) {
// 必须分成两步:先创建文件夹、再修改权限
os.MkdirAll(folderPath, 0664) //0644也可以os.ModePerm
os.Chmod(folderPath, 0664)
}
return folderPath
}
func GetFmtTimeString(srcFmt string, timeString string, dstFmt string) string {
t, _ := time.ParseInLocation(srcFmt, timeString, time.Local)
return t.Format(dstFmt)
}
func GetFileMD5Sum(filePath string) (string, error) {
file, err := os.Open(filePath)
if err != nil {
return "", err
}
defer file.Close()
md5 := md5.New()
_, err = io.Copy(md5, file)
if err != nil {
return "", err
}
md5str := hex.EncodeToString(md5.Sum(nil))
return md5str, nil
}
// PathExists check path is exist or no
func PathExists(path string) (bool, error) {
_, err := os.Stat(path)
if err == nil { //文件或者目录存在
return true, nil
}
if os.IsNotExist(err) {
return false, nil
}
return false, err
}
// PathExists check path is exist or no
func FilePathExists(filePath string) (bool, error) {
_, err := os.Stat(filePath)
if err == nil { //文件或者目录存在
return true, nil
}
if os.IsNotExist(err) {
return false, nil
}
return false, err
}
func GetDayDuration(d1, d2 string) int64 {
a, _ := time.Parse("2006-01-02", d1)
b, _ := time.Parse("2006-01-02", d2)
d := a.Sub(b)
return (int64)(d.Hours() / 24)
}
func GetSecondsSinceDatetime(datetimeStr string) (int64, error) {
loc1, _ := time.LoadLocation("Local")
// 解析日期时间字符串为时间对象
datetime, err := time.ParseInLocation(time.DateTime, datetimeStr, loc1)
if err != nil {
return 0, err
}
// 计算时间差
duration := time.Since(datetime)
// 获取时间差的秒数
seconds := int64(duration.Seconds())
return seconds, nil
}
// 0: invalid ip
// 4: IPv4
// 6: IPv6
func ParseIP(s string) (net.IP, int) {
ip := net.ParseIP(s)
if ip == nil {
return nil, 0
}
for i := 0; i < len(s); i++ {
switch s[i] {
case '.':
return ip, 4
case ':':
return ip, 6
}
}
return nil, 0
}
func BytesCombine1(pBytes ...[]byte) []byte {
return bytes.Join(pBytes, []byte(""))
}
func BytesCombine(pBytes ...[]byte) []byte {
length := len(pBytes)
s := make([][]byte, length)
for index := 0; index < length; index++ {
s[index] = pBytes[index]
}
sep := []byte("")
return bytes.Join(s, sep)
}
func ParseIPAddr(ip string) string {
ipAddr := net.ParseIP(ip)
if ipAddr != nil {
if ipAddr.To4() != nil {
return IsIPv4
} else {
return IsIPv6
}
}
return NonIP
}
func CombineHostUri(ip string, port string) string {
var hostUri string = ""
ipType := ParseIPAddr(ip)
if ipType == IsIPv4 {
hostUri = fmt.Sprintf("http://%s:%v", ip, port)
} else {
hostUri = fmt.Sprintf("http://[%s]:%v", ip, port)
}
return hostUri
}
func StructToMap(obj interface{}) map[string]interface{} {
objValue := reflect.ValueOf(obj)
objType := objValue.Type()
m := make(map[string]interface{})
for i := 0; i < objValue.NumField(); i++ {
field := objValue.Field(i)
fieldName := objType.Field(i).Name
m[fieldName] = field.Interface()
}
return m
}
// ToMap 结构体转为Map[string]interface{}
func ToMap(in interface{}, tagName string) (map[string]interface{}, error) {
out := make(map[string]interface{})
v := reflect.ValueOf(in)
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
if v.Kind() != reflect.Struct { // 非结构体返回错误提示
return nil, fmt.Errorf("ToMap only accepts struct or struct pointer; got %T", v)
}
t := v.Type()
// 遍历结构体字段
// 指定tagName值为map中key;字段值为map中value
for i := 0; i < v.NumField(); i++ {
fi := t.Field(i)
if tagValue := fi.Tag.Get(tagName); tagValue != "" {
out[tagValue] = v.Field(i).Interface()
}
}
return out, nil
}
func ZipOneFile(srcFile, dstZip string, pathFlag bool) error {
zipFile, err := os.Create(dstZip)
if err != nil {
return err
}
defer zipFile.Close()
zipWriter := zip.NewWriter(zipFile)
defer zipWriter.Close()
fileToCompress, err := os.Open(srcFile)
if err != nil {
return err
}
defer fileToCompress.Close()
var fileInZip io.Writer
if pathFlag {
fileInZip, err = zipWriter.Create(srcFile)
if err != nil {
return err
}
} else {
// 获取文件的基本名称
fileName := filepath.Base(fileToCompress.Name())
fileInZip, err = zipWriter.Create(fileName)
if err != nil {
return err
}
}
_, err = io.Copy(fileInZip, fileToCompress)
if err != nil {
return err
}
return nil
}
func ZipDirectoryFile(srcDir, dstZip string) error {
// Create a new zip file
zipfileWriter, err := os.Create(dstZip)
if err != nil {
return err
}
defer zipfileWriter.Close()
// Create a new zip archive
zipWriter := zip.NewWriter(zipfileWriter)
defer zipWriter.Close()
// Walk through the directory and add files to the zip archive
err = filepath.Walk(srcDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
// Create a new file header for the current file
header, err := zip.FileInfoHeader(info)
if err != nil {
return err
}
// Set the name of the file within the zip archive
header.Name = filepath.Join(filepath.Base(srcDir), path[len(srcDir):])
// If the current file is a directory, skip it
if info.IsDir() {
return nil
}
// Create a new file in the zip archive
fileWriter, err := zipWriter.CreateHeader(header)
if err != nil {
return err
}
// Open the current file
file, err := os.Open(path)
if err != nil {
return err
}
defer file.Close()
// Copy the contents of the current file to the zip archive
_, err = io.Copy(fileWriter, file)
if err != nil {
return err
}
return nil
})
return err
}
// 判断软件包是rpm或者deb, 1:rpm, 2:deb, 0:unknown format
func JudgeRpmOrDebPackage(filePath string) (int, error) {
var fileType int = 0
file, err := os.Open(filePath)
if err != nil {
return fileType, err
}
defer file.Close()
// Read the first 6 bytes of the file
header := make([]byte, 6)
_, err = file.Read(header)
if err != nil {
return fileType, err
}
// Check the magic numbers to determine the package format
if string(header) == "!<arch>" {
fileType = 1
} else if string(header) == "!<arch\n" || string(header) == "!<arch\r" {
fileType = 2
} else {
fileType = 0
}
return fileType, nil
}
func isRpmPackage(file *os.File) bool {
// RPM packages start with the magic number "EDABEEDB"
magic := []byte{0xED, 0xAB, 0xEE, 0xDB}
buffer := make([]byte, len(magic))
_, err := file.Read(buffer)
if err != nil && err != io.EOF {
return false
}
return string(buffer) == string(magic)
}
func isDebPackage(file *os.File) bool {
// DEB packages start with the magic number "!<arch>\n"
magic := []byte("!<arch>\n")
buffer := make([]byte, len(magic))
_, err := file.Read(buffer)
if err != nil && err != io.EOF {
return false
}
return string(buffer) == string(magic)
}
func CheckRpmOrDebPackage(filePath string) (int, error) {
var fileType int = 0
file, err := os.Open(filePath)
if err != nil {
return fileType, err
}
defer file.Close()
isRpm := isRpmPackage(file)
isDeb := isDebPackage(file)
if isRpm {
fileType = 1
} else if isDeb {
fileType = 2
} else {
fileType = 0
}
return fileType, nil
}
func IsRpmOrDebPackage(filePath string) int {
var fileType int = 0
if strings.Contains(filePath, ".rpm") {
fileType = 1
} else if strings.Contains(filePath, ".deb") {
fileType = 2
} else {
fileType = 0
}
return fileType
}

337
lib/log/logger.go Normal file
View File

@@ -0,0 +1,337 @@
// logger for omc/ems
package log
import (
"fmt"
"io"
"log"
)
// LogLevel defines a log level
type LogLevel int
// enum all LogLevels
const (
// following level also match syslog.Priority value
LOG_TRACE LogLevel = iota
LOG_DEBUG
LOG_INFO
LOG_WARN
LOG_ERROR
LOG_FATAL
LOG_OFF
LOG_NODEF
)
// default log options
const (
DEFAULT_LOG_PREFIX = "omc"
DEFAULT_LOG_FLAG = log.Lshortfile | log.Ldate | log.Lmicroseconds
DEFAULT_LOG_LEVEL = LOG_DEBUG
DEFAULT_CALL_DEPTH = 3
)
// Logger is a logger interface
type Logger interface {
Fatal(v ...interface{})
Fatalf(format string, v ...interface{})
Error(v ...interface{})
Errorf(format string, v ...interface{})
Warn(v ...interface{})
Warnf(format string, v ...interface{})
Info(v ...interface{})
Infof(format string, v ...interface{})
Debug(v ...interface{})
Debugf(format string, v ...interface{})
Trace(v ...interface{})
Tracef(format string, v ...interface{})
Level() LogLevel
LevelString() string
SetLevel(l LogLevel)
}
var _ Logger = DiscardLogger{}
// DiscardLogger don't log implementation for ILogger
type DiscardLogger struct{}
// Trace empty implementation
func (DiscardLogger) Trace(v ...interface{}) {}
// Tracef empty implementation
func (DiscardLogger) Tracef(format string, v ...interface{}) {}
// Debug empty implementation
func (DiscardLogger) Debug(v ...interface{}) {}
// Debugf empty implementation
func (DiscardLogger) Debugf(format string, v ...interface{}) {}
// Info empty implementation
func (DiscardLogger) Info(v ...interface{}) {}
// Infof empty implementation
func (DiscardLogger) Infof(format string, v ...interface{}) {}
// Warn empty implementation
func (DiscardLogger) Warn(v ...interface{}) {}
// Warnf empty implementation
func (DiscardLogger) Warnf(format string, v ...interface{}) {}
// Error empty implementation
func (DiscardLogger) Error(v ...interface{}) {}
// Errorf empty implementation
func (DiscardLogger) Errorf(format string, v ...interface{}) {}
// Fatal empty implementation
func (DiscardLogger) Fatal(v ...interface{}) {}
// Fatalf empty implementation
func (DiscardLogger) Fatalf(format string, v ...interface{}) {}
// Level empty implementation
func (DiscardLogger) Level() LogLevel {
return LOG_NODEF
}
// Level empty implementation
func (DiscardLogger) LevelString() string {
return ""
}
// SetLevel empty implementation
func (DiscardLogger) SetLevel(l LogLevel) {}
// EmsLogger is the default implment of ILogger
type EmsLogger struct {
TRACE *log.Logger
DEBUG *log.Logger
INFO *log.Logger
WARN *log.Logger
ERROR *log.Logger
FATAL *log.Logger
level LogLevel
levelString []string
depth int
}
var _ Logger = &EmsLogger{}
// NewEmsLogger use a special io.Writer as logger output
func NewEmsLogger(out io.Writer) *EmsLogger {
return NewEmsLogger2(out, DEFAULT_LOG_PREFIX, DEFAULT_LOG_FLAG)
}
// NewEmsLogger2 let you customrize your logger prefix and flag
func NewEmsLogger2(out io.Writer, prefix string, flag int) *EmsLogger {
return NewEmsLogger3(out, prefix, flag, DEFAULT_LOG_LEVEL)
}
// NewEmsLogger3 let you customrize your logger prefix and flag and logLevel
func NewEmsLogger3(out io.Writer, prefix string, flag int, l LogLevel) *EmsLogger {
return &EmsLogger{
TRACE: log.New(out, fmt.Sprintf("[%s] [trace] ", prefix), flag),
DEBUG: log.New(out, fmt.Sprintf("[%s] [debug] ", prefix), flag),
INFO: log.New(out, fmt.Sprintf("[%s] [info ] ", prefix), flag),
WARN: log.New(out, fmt.Sprintf("[%s] [warn ] ", prefix), flag),
ERROR: log.New(out, fmt.Sprintf("[%s] [error] ", prefix), flag),
FATAL: log.New(out, fmt.Sprintf("[%s] [fatal] ", prefix), flag),
level: l,
levelString: []string{"trace", "debug", "info", "warn", "error", "fatal"},
depth: DEFAULT_CALL_DEPTH,
}
}
// Trace implement ILogger
func (s *EmsLogger) Trace(v ...interface{}) {
if s.level <= LOG_TRACE {
_ = s.TRACE.Output(s.depth, fmt.Sprintln(v...))
}
}
// Tracef implement ILogger
func (s *EmsLogger) Tracef(format string, v ...interface{}) {
if s.level <= LOG_TRACE {
_ = s.TRACE.Output(s.depth, fmt.Sprintf(format, v...))
}
}
// Debug implement ILogger
func (s *EmsLogger) Debug(v ...interface{}) {
if s.level <= LOG_DEBUG {
_ = s.DEBUG.Output(s.depth, fmt.Sprintln(v...))
}
}
// Debugf implement ILogger
func (s *EmsLogger) Debugf(format string, v ...interface{}) {
if s.level <= LOG_DEBUG {
_ = s.DEBUG.Output(s.depth, fmt.Sprintf(format, v...))
}
}
// Info implement ILogger
func (s *EmsLogger) Info(v ...interface{}) {
if s.level <= LOG_INFO {
_ = s.INFO.Output(s.depth, fmt.Sprintln(v...))
}
}
// Infof implement ILogger
func (s *EmsLogger) Infof(format string, v ...interface{}) {
if s.level <= LOG_INFO {
_ = s.INFO.Output(s.depth, fmt.Sprintf(format, v...))
}
}
// Warn implement ILogger
func (s *EmsLogger) Warn(v ...interface{}) {
if s.level <= LOG_WARN {
_ = s.WARN.Output(s.depth, fmt.Sprintln(v...))
}
}
// Warnf implement ILogger
func (s *EmsLogger) Warnf(format string, v ...interface{}) {
if s.level <= LOG_WARN {
_ = s.WARN.Output(s.depth, fmt.Sprintf(format, v...))
}
}
// Error implement ILogger
func (s *EmsLogger) Error(v ...interface{}) {
if s.level <= LOG_ERROR {
_ = s.ERROR.Output(s.depth, fmt.Sprintln(v...))
}
}
// Errorf implement ILogger
func (s *EmsLogger) Errorf(format string, v ...interface{}) {
if s.level <= LOG_ERROR {
_ = s.ERROR.Output(s.depth, fmt.Sprintf(format, v...))
}
}
// Warn implement ILogger
func (s *EmsLogger) Fatal(v ...interface{}) {
if s.level <= LOG_FATAL {
_ = s.FATAL.Output(s.depth, fmt.Sprintln(v...))
}
}
// Warnf implement ILogger
func (s *EmsLogger) Fatalf(format string, v ...interface{}) {
if s.level <= LOG_FATAL {
_ = s.FATAL.Output(s.depth, fmt.Sprintf(format, v...))
}
}
// Level implement ILogger
func (s *EmsLogger) Level() LogLevel {
return s.level
}
// Level implement ILogger
func (s *EmsLogger) LevelString() string {
return s.levelString[s.level]
}
// SetLevel implement ILogger
func (s *EmsLogger) SetLevel(l LogLevel) {
s.level = l
}
var Elogger Logger
func InitLogger(logFile string, period int, count int, prefix string, logLevel LogLevel) {
/*
logFile, err := os.OpenFile(file, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0766)
if err != nil {
panic(err)
}
*/
logWriter := getLogWriter(logFile, period, count)
Elogger = NewEmsLogger3(logWriter, prefix, DEFAULT_LOG_FLAG, logLevel)
fmt.Printf("logFile=%s, period=%d, count=%d, prefix=%s, logLevel=%s\n", logFile, period, count, prefix, GetLevelString())
}
// Trace implement ILogger
func Trace(v ...interface{}) {
Elogger.Trace(v...)
}
// Tracef implement ILogger
func Tracef(format string, v ...interface{}) {
Elogger.Tracef(format, v...)
}
// Debug implement ILogger
func Debug(v ...interface{}) {
Elogger.Debug(v...)
}
// Debugf implement ILogger
func Debugf(format string, v ...interface{}) {
Elogger.Debugf(format, v...)
}
// Info implement ILogger
func Info(v ...interface{}) {
Elogger.Info(v...)
}
// Infof implement ILogger
func Infof(format string, v ...interface{}) {
Elogger.Infof(format, v...)
}
// Warn implement ILogger
func Warn(v ...interface{}) {
Elogger.Warn(v...)
}
// Warnf implement ILogger
func Warnf(format string, v ...interface{}) {
Elogger.Warnf(format, v...)
}
// Error implement ILogger
func Error(v ...interface{}) {
Elogger.Error(v...)
}
// Errorf implement ILogger
func Errorf(format string, v ...interface{}) {
Elogger.Errorf(format, v...)
}
// Warn implement ILogger
func Fatal(v ...interface{}) {
Elogger.Fatal(v...)
}
// Warnf implement ILogger
func Fatalf(format string, v ...interface{}) {
Elogger.Fatalf(format, v...)
}
// Level implement ILogger
func GetLevel() LogLevel {
return Elogger.Level()
}
// Level implement ILogger
func GetLevelString() string {
return Elogger.LevelString()
}
// SetLevel implement ILogger
func SetLevel(l LogLevel) {
Elogger.SetLevel(l)
}

71
lib/log/partition.go Normal file
View File

@@ -0,0 +1,71 @@
package log
import (
"io"
"time"
rotatelogs "github.com/lestrrat/go-file-rotatelogs"
)
type WriteSyncer interface {
io.Writer
Sync() error
}
// 得到LogWriter
func getLogWriter(filePath string, period, count int) WriteSyncer {
warnIoWriter := getWriter(filePath, period, count)
return addSync(warnIoWriter)
}
// 日志文件切割
func getWriter(filename string, period, count int) io.Writer {
// 保存日志count天每period小时分割一次日志
duration := time.Hour * time.Duration(period)
var logfile string
if period >= 24 {
logfile = filename + "-%Y%m%d"
} else {
logfile = filename + "-%Y%m%d%H"
}
hook, err := rotatelogs.New(
logfile,
rotatelogs.WithLinkName(filename),
// rotatelogs.WithMaxAge(duration),
rotatelogs.WithRotationCount(count),
rotatelogs.WithRotationTime(duration),
rotatelogs.WithLocation(time.Local),
)
//保存日志30天每1分钟分割一次日志
/*
hook, err := rotatelogs.New(
filename+"_%Y%m%d%H%M.log",
rotatelogs.WithLinkName(filename),
rotatelogs.WithMaxAge(time.Hour*24*30),
rotatelogs.WithRotationTime(time.Minute*1),
)
*/
if err != nil {
panic(err)
}
return hook
}
func addSync(w io.Writer) WriteSyncer {
switch w := w.(type) {
case WriteSyncer:
return w
default:
return writerWrapper{w}
}
}
type writerWrapper struct {
io.Writer
}
func (w writerWrapper) Sync() error {
return nil
}

View File

@@ -0,0 +1,77 @@
package midware
import (
"encoding/json"
"net/http"
"strings"
"time"
"ems.agt/lib/dborm"
"ems.agt/lib/services"
)
// 登录策略限制登录时间和访问ip范围
func ArrowIPAddr(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ipAddr := strings.Split(r.RemoteAddr, ":")[0]
// 读取配置信息 登录策略设置
result, err := dborm.XormGetConfig("Security", "loginSecurity")
if err != nil {
next.ServeHTTP(w, r)
return
}
data := make(map[string]any)
err = json.Unmarshal([]byte(result["value_json"].(string)), &data)
if err != nil {
next.ServeHTTP(w, r)
return
}
// 开关
switchStr := data["switch"].(string)
if switchStr == "0" {
next.ServeHTTP(w, r)
return
}
ipRange := data["ipRange"].(string)
logintimeRange := data["logintime_range"].(string)
// 检查ip
ips := strings.Split(ipRange, "/")
hasIP := false
for _, ip := range ips {
if ipAddr == ip {
hasIP = true
}
}
if !hasIP {
services.ResponseErrorWithJson(w, 502, "网关登录策略-IP限制: "+ipAddr)
return
}
// 检查开放时间
logintimeRangeArr := strings.Split(logintimeRange, " - ")
// 加载中国时区
loc, _ := time.LoadLocation("Asia/Shanghai")
// 获取当前时间
currentTime := time.Now().In(loc)
// 获取当前日期
currentDate := time.Date(currentTime.Year(), currentTime.Month(), currentTime.Day(), 0, 0, 0, 0, currentTime.Location())
ymd := currentDate.Format("2006-01-02")
// 定义开始时间和结束时间
startTime, _ := time.ParseInLocation("2006-01-02 15:04:05", ymd+" "+logintimeRangeArr[0], loc)
endTime, _ := time.ParseInLocation("2006-01-02 15:04:05", ymd+" "+logintimeRangeArr[1], loc)
// 判断当前时间是否在范围内
if currentTime.After(startTime) && currentTime.Before(endTime) {
next.ServeHTTP(w, r)
} else {
services.ResponseErrorWithJson(w, 502, "网关登录策略-不在开放时间范围内")
}
})
}

182
lib/midware/authorize.go Normal file
View File

@@ -0,0 +1,182 @@
package midware
import (
"context"
"fmt"
"net/http"
"ems.agt/lib/core/cache"
"ems.agt/lib/core/utils/ctx"
"ems.agt/lib/core/vo"
"ems.agt/lib/core/vo/result"
"ems.agt/lib/dborm"
)
// Authorize 用户身份授权认证校验
//
// 只需含有其中角色 "hasRoles": {"xxx"},
//
// 只需含有其中权限 "hasPerms": {"xxx"},
//
// 同时匹配其中角色 "matchRoles": {"xxx"},
//
// 同时匹配其中权限 "matchPerms": {"xxx"},
func Authorize(options map[string][]string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 获取请求头标识信息
accessToken := r.Header.Get("AccessToken")
if accessToken == "" {
ctx.JSON(w, 401, result.CodeMsg(401, "token error 无效身份授权"))
return
}
// 验证令牌 == 这里直接查数据库session
if !dborm.XormExistValidToken(accessToken, 0) {
ctx.JSON(w, 401, result.CodeMsg(401, "valid error 无效身份授权"))
return
}
se, err := dborm.XormUpdateSessionShakeTime(accessToken)
if err != nil {
ctx.JSON(w, 401, result.CodeMsg(401, "shake error 无效身份授权"))
return
}
// 获取缓存的用户信息
data, ok := cache.GetLocalTTL(se.AccountId)
if data == nil || !ok {
ctx.JSON(w, 401, result.CodeMsg(401, "info error 无效身份授权"))
return
}
loginUser := data.(vo.LoginUser)
// 登录用户角色权限校验
if options != nil {
var roles []string
for _, item := range loginUser.User.Roles {
roles = append(roles, item.RoleKey)
}
perms := loginUser.Permissions
verifyOk := verifyRolePermission(roles, perms, options)
if !verifyOk {
msg := fmt.Sprintf("无权访问 %s %s", r.Method, r.RequestURI)
ctx.JSON(w, 403, result.CodeMsg(403, msg))
return
}
}
// 在请求的 Context 中存储数据
rContext := r.Context()
rContext = context.WithValue(rContext, ctx.ContextKey("LoginUser"), loginUser)
// 继续处理请求
next.ServeHTTP(w, r.WithContext(rContext))
})
}
}
// verifyRolePermission 校验角色权限是否满足
//
// roles 角色字符数组
//
// perms 权限字符数组
//
// options 参数
func verifyRolePermission(roles, perms []string, options map[string][]string) bool {
// 直接放行 管理员角色或任意权限
if contains(roles, "admin") || contains(perms, "*:*:*") {
return true
}
opts := make([]bool, 4)
// 只需含有其中角色
hasRole := false
if arr, ok := options["hasRoles"]; ok && len(arr) > 0 {
hasRole = some(roles, arr)
opts[0] = true
}
// 只需含有其中权限
hasPerms := false
if arr, ok := options["hasPerms"]; ok && len(arr) > 0 {
hasPerms = some(perms, arr)
opts[1] = true
}
// 同时匹配其中角色
matchRoles := false
if arr, ok := options["matchRoles"]; ok && len(arr) > 0 {
matchRoles = every(roles, arr)
opts[2] = true
}
// 同时匹配其中权限
matchPerms := false
if arr, ok := options["matchPerms"]; ok && len(arr) > 0 {
matchPerms = every(perms, arr)
opts[3] = true
}
// 同时判断 含有其中
if opts[0] && opts[1] {
return hasRole || hasPerms
}
// 同时判断 匹配其中
if opts[2] && opts[3] {
return matchRoles && matchPerms
}
// 同时判断 含有其中且匹配其中
if opts[0] && opts[3] {
return hasRole && matchPerms
}
if opts[1] && opts[2] {
return hasPerms && matchRoles
}
return hasRole || hasPerms || matchRoles || matchPerms
}
// contains 检查字符串数组中是否包含指定的字符串
func contains(arr []string, target string) bool {
for _, str := range arr {
if str == target {
return true
}
}
return false
}
// some 检查字符串数组中含有其中一项
func some(origin []string, target []string) bool {
has := false
for _, t := range target {
for _, o := range origin {
if t == o {
has = true
break
}
}
if has {
break
}
}
return has
}
// every 检查字符串数组中同时包含所有项
func every(origin []string, target []string) bool {
match := true
for _, t := range target {
found := false
for _, o := range origin {
if t == o {
found = true
break
}
}
if !found {
match = false
break
}
}
return match
}

66
lib/midware/cors.go Normal file
View File

@@ -0,0 +1,66 @@
package midware
import (
"net/http"
"strings"
)
// Cors 跨域
func Cors(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 设置Vary头部
w.Header().Set("Vary", "Origin")
w.Header().Set("Keep-Alive", "timeout=5")
requestOrigin := r.Header.Get("Origin")
if requestOrigin == "" {
next.ServeHTTP(w, r)
return
}
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Credentials", "true")
// OPTIONS
if r.Method == "OPTIONS" {
requestMethod := r.Header.Get("Access-Control-Request-Method")
if requestMethod == "" {
next.ServeHTTP(w, r)
return
}
// 响应最大时间值
w.Header().Set("Access-Control-Max-Age", "31536000")
// 允许方法
allowMethods := []string{
"OPTIONS",
"HEAD",
"GET",
"POST",
"PUT",
"DELETE",
"PATCH",
}
w.Header().Set("Access-Control-Allow-Methods", strings.Join(allowMethods, ","))
// 允许请求头
allowHeaders := []string{
"Accesstoken",
"Content-Type",
"operationtype",
}
w.Header().Set("Access-Control-Allow-Headers", strings.Join(allowHeaders, ","))
w.WriteHeader(204)
return
}
// 暴露请求头
exposeHeaders := []string{"X-RepeatSubmit-Rest", "AccessToken"}
w.Header().Set("Access-Control-Expose-Headers", strings.Join(exposeHeaders, ","))
next.ServeHTTP(w, r)
})
}

77
lib/midware/midhandle.go Normal file
View File

@@ -0,0 +1,77 @@
package midware
import (
"net/http"
"strings"
"ems.agt/lib/log"
"ems.agt/lib/services"
"github.com/gorilla/mux"
)
func LoggerTrace(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Do stuff here
log.Trace("Http Trace Info:")
log.Trace(" From Host:", r.RemoteAddr)
log.Trace(" To Host:", r.Host)
log.Debug(" RequestUri:", r.RequestURI)
log.Trace(" Method:", r.Method)
log.Trace(" Proto:", r.Proto)
log.Trace(" ContentLength:", r.ContentLength)
log.Trace(" User-Agent:", r.Header.Get("User-Agent"))
log.Trace(" Content-Type:", r.Header.Get("Content-Type"))
log.Trace(" AccessToken:", r.Header.Get("AccessToken"))
log.Trace("Trace End=====")
//body, _ := io.ReadAll(io.LimitReader(r.Body, global.RequestBodyMaxLen))
// nop-close to ready r.Body !!!
//r.Body = ioutil.NopCloser(bytes.NewReader(body))
//log.Trace("Body:", string(body))
// Call the next handler, which can be another middleware in the chain, or the final handler.
// if r.Method == "OPTIONS" {
// services.ResponseStatusOK201Accepted(w)
// return
// }
next.ServeHTTP(w, r)
})
}
// 已禁用
func OptionProcess(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == "OPTIONS" {
services.ResponseStatusOK201Accepted(w)
return
}
next.ServeHTTP(w, r)
})
}
// 已禁用
func CheckPermission(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := r.Header.Get("AccessToken")
vars := mux.Vars(r)
management := vars["managedType"]
element := vars["elementTypeValue"]
object := vars["objectTypeValue"]
pack := "*"
if token != "" && element != "oauth" {
log.Debugf("token:%s, method:%s, management:%s, element:%s, object:%s, pack:%s", token, r.Method, management, element, object, pack)
exist, err := services.CheckUserPermission(token, strings.ToLower(r.Method), management, element, object, pack)
if err != nil {
log.Error("Failed to get permission:", err)
services.ResponseForbidden403NotPermission(w)
return
}
if !exist {
log.Error("Not permission!")
services.ResponseForbidden403NotPermission(w)
return
}
}
next.ServeHTTP(w, r)
})
}

1043
lib/mmlp/parse.go Normal file

File diff suppressed because it is too large Load Diff

184
lib/oauth/oauth.go Normal file
View File

@@ -0,0 +1,184 @@
package oauth
import (
"crypto/sha256"
"crypto/sha512"
"encoding/hex"
"fmt"
"math/rand"
"net/http"
"strings"
"time"
"ems.agt/lib/log"
"github.com/dgrijalva/jwt-go"
"golang.org/x/crypto/bcrypt"
)
// GenToken 生成Token值
func GenToken(mapClaims jwt.MapClaims) (string, error) {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, mapClaims)
var nowDate = time.Now()
var secret = fmt.Sprintf("%v%v", nowDate, "xxxx")
return token.SignedString([]byte(secret))
}
// GenerateToken 生成Token值
func GenerateToken(mapClaims jwt.MapClaims, key string) (string, error) {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, mapClaims)
return token.SignedString([]byte(key))
}
// ParseToken: "解析token"
func ParseToken(token string, secret string) (string, error) {
claim, err := jwt.Parse(token, func(token *jwt.Token) (interface{}, error) {
return []byte(secret), nil
})
if err != nil {
return "", err
}
return claim.Claims.(jwt.MapClaims)["cmd"].(string), nil
}
func RandAccessToken(n int) (ret string) {
allString := "52661fbd-6b84-4fc2-aa1e-17879a5c6c9b"
ret = ""
for i := 0; i < n; i++ {
r := rand.Intn(len(allString))
ret = ret + allString[r:r+1]
}
return ret
}
const letterBytes = "abcdef0123456789"
const (
letterIdxBits = 6 // 6 bits to represent a letter index
letterIdxMask = 1<<letterIdxBits - 1 // All 1-bits, as many as letterIdxBits
letterIdxMax = 63 / letterIdxBits // # of letter indices fitting in 63 bits
)
var src = rand.NewSource(time.Now().UnixNano())
func RandStringBytes(n int) string {
b := make([]byte, n)
// A src.Int63() generates 63 random bits, enough for letterIdxMax characters!
for i, cache, remain := n-1, src.Int63(), letterIdxMax; i >= 0; {
if remain == 0 {
cache, remain = src.Int63(), letterIdxMax
}
if idx := int(cache & letterIdxMask); idx < len(letterBytes) {
b[i] = letterBytes[idx]
i--
}
cache >>= letterIdxBits
remain--
}
return string(b)
}
func GenRandToken(prefix string) string {
if prefix == "" {
return RandStringBytes(8) + "-" + RandStringBytes(4) + "-" +
RandStringBytes(4) + "-" + RandStringBytes(4) + "-" + RandStringBytes(12)
} else {
return prefix + "-" + RandStringBytes(8) + "-" + RandStringBytes(4) + "-" +
RandStringBytes(4) + "-" + RandStringBytes(4) + "-" + RandStringBytes(12)
}
}
type OAuthBody struct {
GrantType string
UserName string
Value string
}
/*
func IsValidOAuthInfo(oAuthBody OAuthBody) bool {
log.Debug("IsValidOAuthInfo processing... ")
conf := config.GetYamlConfig()
for _, o := range conf.Auth {
if oAuthBody.GrantType == o.Type && oAuthBody.UserName == o.User && oAuthBody.Value == o.Password {
return true
}
}
return false
}
*/
func IsWrongOAuthInfo(oAuthBody OAuthBody) bool {
log.Debug("IsWrongOAuthInfo processing... ")
if oAuthBody.GrantType == "" || strings.ToLower(oAuthBody.GrantType) != "password" ||
oAuthBody.UserName == "" || oAuthBody.Value == "" {
return true
}
return false
}
func GetTokenFromHttpRequest(r *http.Request) string {
for k, v := range r.Header {
log.Tracef("k:%s, v:%s", k, v)
if strings.ToLower(k) == "accesstoken" && len(v) != 0 {
log.Trace("AccessToken:", v[0])
return v[0]
}
}
return ""
}
// IsCarriedToken check token is carried
func IsCarriedToken(r *http.Request) (string, bool) {
token := GetTokenFromHttpRequest(r)
if token == "" {
return "", false
}
return token, true
}
// Bcrypt Encrypt 加密明文密码
func BcryptEncrypt(password string) (string, error) {
hashedBytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
return string(hashedBytes), err
}
// Bcrypt Compare 密文校验
func BcryptCompare(hashedPassword, password string) error {
return bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password))
}
// sha256 crypt
func GetSHA256HashCode(stringMessage string) string {
message := []byte(stringMessage) //字符串转化字节数组
//创建一个基于SHA256算法的hash.Hash接口的对象
hash := sha256.New() //sha-256加密
//输入数据
hash.Write(message)
//计算哈希值
bytes := hash.Sum(nil)
//将字符串编码为16进制格式,返回字符串
hashCode := hex.EncodeToString(bytes)
//返回哈希值
return hashCode
}
// sha512 crypt
func GetSHA512HashCode(stringMessage string) string {
message := []byte(stringMessage) //字符串转化字节数组
//创建一个基于SHA256算法的hash.Hash接口的对象
hash := sha512.New() //SHA-512加密
//输入数据
hash.Write(message)
//计算哈希值
bytes := hash.Sum(nil)
//将字符串编码为16进制格式,返回字符串
hashCode := hex.EncodeToString(bytes)
//返回哈希值
return hashCode
}

37
lib/pair/pair.go Normal file
View File

@@ -0,0 +1,37 @@
package pair
type Pair struct {
Key int
Value int
}
type PairList []Pair
type Interface interface {
// Len is the number of elements in the collection.
Len() int
// Less reports whether the element with index i
// must sort before the element with index j.
//
// If both Less(i, j) and Less(j, i) are false,
// then the elements at index i and j are considered equal.
// Sort may place equal elements in any order in the final result,
// while Stable preserves the original input order of equal elements.
//
// Less must describe a transitive ordering:
// - if both Less(i, j) and Less(j, k) are true, then Less(i, k) must be true as well.
// - if both Less(i, j) and Less(j, k) are false, then Less(i, k) must be false as well.
//
// Note that floating-point comparison (the < operator on float32 or float64 values)
// is not a transitive ordering when not-a-number (NaN) values are involved.
// See Float64Slice.Less for a correct implementation for floating-point values.
Less(i, j int) bool
// Swap swaps the elements with indexes i and j.
Swap(i, j int)
}
func (p PairList) Len() int { return len(p) }
func (p PairList) Less(i, j int) bool { return p[i].Value < p[j].Value }
func (p PairList) Swap(i, j int) { p[i], p[j] = p[j], p[i] }

395
lib/routes/routes.go Normal file
View File

@@ -0,0 +1,395 @@
package routes
import (
"net/http"
// "log"
"ems.agt/features/aaaa"
"ems.agt/features/cm"
"ems.agt/features/dbrest"
"ems.agt/features/file"
"ems.agt/features/fm"
"ems.agt/features/lm"
"ems.agt/features/mml"
"ems.agt/features/monitor/monitor"
"ems.agt/features/monitor/psnet"
"ems.agt/features/nbi"
"ems.agt/features/pm"
"ems.agt/features/security"
"ems.agt/features/state"
sysconfig "ems.agt/features/sys_config"
sysdictdata "ems.agt/features/sys_dict_data"
sysdicttype "ems.agt/features/sys_dict_type"
sysmenu "ems.agt/features/sys_menu"
sysrole "ems.agt/features/sys_role"
sysuser "ems.agt/features/sys_user"
"ems.agt/features/trace"
udmuser "ems.agt/features/udm_user"
"ems.agt/features/ue"
"ems.agt/lib/midware"
"ems.agt/lib/services"
"github.com/gorilla/mux"
)
type Router struct {
Method string
Pattern string
Handler http.HandlerFunc
Middleware mux.MiddlewareFunc
}
var routers []Router
func init() {
Register("POST", security.UriOauthToken, security.LoginFromOMC, nil)
Register("POST", security.UriOauthHandshake, security.HandshakeFromOMC, nil)
Register("DELETE", security.UriOauthToken, security.LogoutFromOMC, nil)
Register("POST", security.CustomUriOauthToken, security.LoginFromOMC, nil)
Register("DELETE", security.CustomUriOauthToken, security.LogoutFromOMC, nil)
Register("POST", security.CustomUriOauthHandshake, security.HandshakeFromOMC, nil)
// System state
Register("GET", state.UriSysState, state.GetStateFromNF, nil)
Register("GET", state.UriSysState2, state.GetStateFromNF, nil)
Register("GET", state.UriSysInfoAll, state.GetAllSysinfoFromNF, nil)
Register("GET", state.UriSysInfoOne, state.GetOneSysinfoFromNF, nil)
Register("GET", state.UriLicenseInfoAll, state.GetAllLicenseInfoFromNF, nil)
Register("GET", state.UriLicenseInfoOne, state.GetOneLicenseInfoFromNF, nil)
Register("GET", state.CustomUriSysState, state.GetStateFromNF, nil)
Register("GET", state.CustomUriSysState2, state.GetStateFromNF, nil)
Register("GET", state.CustomUriSysInfoAll, state.GetAllSysinfoFromNF, nil)
Register("GET", state.CustomUriSysInfoOne, state.GetOneSysinfoFromNF, nil)
Register("GET", state.CustomUriLicenseInfoAll, state.GetAllLicenseInfoFromNF, nil)
Register("GET", state.CustomUriLicenseInfoOne, state.GetOneLicenseInfoFromNF, nil)
// 数据库直连操作权限
selectPermission := midware.Authorize(map[string][]string{
"hasRoles": {"dba"},
"hasPerms": {"db:select"},
})
updatePermission := midware.Authorize(map[string][]string{
"hasRoles": {"dba"},
"hasPerms": {"db:update"},
})
insertPermission := midware.Authorize(map[string][]string{
"hasRoles": {"dba"},
"hasPerms": {"db:insert"},
})
deletePermission := midware.Authorize(map[string][]string{
"hasRoles": {"dba"},
"hasPerms": {"db:delete"},
})
// database management
Register("GET", dbrest.XormGetDataUri, dbrest.DatabaseGetData, selectPermission)
Register("GET", dbrest.XormSelectDataUri, dbrest.DatabaseGetData, selectPermission)
Register("POST", dbrest.XormInsertDataUri, dbrest.DatabaseInsertData, insertPermission)
Register("PUT", dbrest.XormUpdateDataUri, dbrest.DatabaseUpdateData, updatePermission)
Register("DELETE", dbrest.XormDeleteDataUri, dbrest.DatabaseDeleteData, deletePermission)
Register("GET", dbrest.CustomXormGetDataUri, dbrest.DatabaseGetData, selectPermission)
Register("GET", dbrest.CustomXormSelectDataUri, dbrest.DatabaseGetData, selectPermission)
Register("POST", dbrest.CustomXormInsertDataUri, dbrest.DatabaseInsertData, insertPermission)
Register("PUT", dbrest.CustomXormUpdateDataUri, dbrest.DatabaseUpdateData, updatePermission)
Register("DELETE", dbrest.CustomXormDeleteDataUri, dbrest.DatabaseDeleteData, deletePermission)
Register("GET", dbrest.XormCommonUri, dbrest.DatabaseGetData, selectPermission)
Register("POST", dbrest.XormCommonUri, dbrest.DatabaseInsertData, insertPermission)
Register("PUT", dbrest.XormCommonUri, dbrest.DatabaseUpdateData, updatePermission)
Register("DELETE", dbrest.XormCommonUri, dbrest.DatabaseDeleteData, deletePermission)
Register("GET", dbrest.XormDatabaseUri, dbrest.TaskDatabaseGetData, selectPermission)
Register("POST", dbrest.XormDatabaseUri, dbrest.TaskDatabaseInsertData, insertPermission)
Register("PUT", dbrest.XormDatabaseUri, dbrest.TaskDatabaseUpdateData, updatePermission)
Register("DELETE", dbrest.XormDatabaseUri, dbrest.TaskDatabaseDeleteData, deletePermission)
Register("GET", dbrest.CustomXormCommonUri, dbrest.DatabaseGetData, selectPermission)
Register("POST", dbrest.CustomXormCommonUri, dbrest.DatabaseInsertData, insertPermission)
Register("PUT", dbrest.CustomXormCommonUri, dbrest.DatabaseUpdateData, updatePermission)
Register("DELETE", dbrest.CustomXormCommonUri, dbrest.DatabaseDeleteData, deletePermission)
Register("GET", dbrest.XormExtDataUri, dbrest.ExtDatabaseGetData, selectPermission)
Register("POST", dbrest.XormExtDataUri, dbrest.ExtDatabaseInsertData, insertPermission)
Register("PUT", dbrest.XormExtDataUri, dbrest.ExtDatabaseUpdateData, updatePermission)
Register("DELETE", dbrest.XormExtDataUri, dbrest.ExtDatabaseDeleteData, deletePermission)
Register("GET", dbrest.CustomXormExtDataUri, dbrest.ExtDatabaseGetData, selectPermission)
Register("POST", dbrest.CustomXormExtDataUri, dbrest.ExtDatabaseInsertData, insertPermission)
Register("PUT", dbrest.CustomXormExtDataUri, dbrest.ExtDatabaseUpdateData, updatePermission)
Register("DELETE", dbrest.CustomXormExtDataUri, dbrest.ExtDatabaseDeleteData, deletePermission)
// alarm restful Register
Register("POST", fm.UriAlarms, fm.PostAlarmFromNF, nil)
Register("Get", fm.UriAlarms, fm.GetAlarmFromNF, nil)
Register("POST", fm.CustomUriAlarms, fm.PostAlarmFromNF, nil)
Register("Get", fm.CustomUriAlarms, fm.GetAlarmFromNF, nil)
// performance restful Register
Register("POST", pm.PerformanceUri, pm.PostKPIReportFromNF, nil)
Register("POST", pm.MeasureTaskUri, pm.PostMeasureTaskToNF, nil)
Register("PUT", pm.MeasureTaskUri, pm.PutMeasureTaskToNF, nil)
Register("DELETE", pm.MeasureTaskUri, pm.DeleteMeasureTaskToNF, nil)
Register("PATCH", pm.MeasureTaskUri, pm.PatchMeasureTaskToNF, nil)
Register("POST", pm.MeasureReportUri, pm.PostMeasureReportFromNF, nil)
Register("POST", pm.MeasurementUri, pm.PostMeasurementFromNF, nil)
Register("GET", pm.MeasurementUri, pm.GetMeasurementFromNF, nil)
Register("POST", pm.CustomPerformanceUri, pm.PostKPIReportFromNF, nil)
Register("POST", pm.CustomMeasureTaskUri, pm.PostMeasureTaskToNF, nil)
Register("PUT", pm.CustomMeasureTaskUri, pm.PutMeasureTaskToNF, nil)
Register("DELETE", pm.CustomMeasureTaskUri, pm.DeleteMeasureTaskToNF, nil)
Register("PATCH", pm.CustomMeasureTaskUri, pm.PatchMeasureTaskToNF, nil)
Register("POST", pm.CustomMeasureReportUri, pm.PostMeasureReportFromNF, nil)
Register("POST", pm.CustomMeasurementUri, pm.PostMeasurementFromNF, nil)
Register("GET", pm.CustomMeasurementUri, pm.GetMeasurementFromNF, nil)
// parameter config management
Register("GET", cm.ParamConfigUri, cm.GetParamConfigFromNF, nil)
Register("POST", cm.ParamConfigUri, cm.PostParamConfigToNF, nil)
Register("PUT", cm.ParamConfigUri, cm.PutParamConfigToNF, nil)
Register("DELETE", cm.ParamConfigUri, cm.DeleteParamConfigToNF, nil)
Register("GET", cm.CustomParamConfigUri, cm.GetParamConfigFromNF, nil)
Register("POST", cm.CustomParamConfigUri, cm.PostParamConfigToNF, nil)
Register("PUT", cm.CustomParamConfigUri, cm.PutParamConfigToNF, nil)
Register("DELETE", cm.CustomParamConfigUri, cm.DeleteParamConfigToNF, nil)
// Get/Create/Modify/Delete NE info
Register("GET", cm.UriNeInfo, cm.GetNeInfo, nil)
Register("POST", cm.UriNeInfo, cm.PostNeInfo, nil)
Register("PUT", cm.UriNeInfo, cm.PutNeInfo, nil)
Register("DELETE", cm.UriNeInfo, cm.DeleteNeInfo, nil)
// Get/Create/Modify/Delete NE info
Register("GET", cm.CustomUriNeInfo, cm.GetNeInfo, nil)
Register("POST", cm.CustomUriNeInfo, cm.PostNeInfo, nil)
Register("PUT", cm.CustomUriNeInfo, cm.PutNeInfo, nil)
Register("DELETE", cm.CustomUriNeInfo, cm.DeleteNeInfo, nil)
//ne service action handle
Register("POST", cm.UriNeService, cm.PostNeServiceAction, nil)
//ne service action handle
Register("POST", cm.UriNeInstance, cm.PostNeInstanceAction, nil)
// Post MML command to NF
Register("POST", mml.UriMML, mml.PostMMLToNF, nil)
Register("POST", mml.UriMMLDiscard, mml.PostMMLToNF, nil)
Register("POST", mml.UriOmMmlExt, mml.PostMMLToOMC, nil)
Register("POST", mml.CustomUriMML, mml.PostMMLToNF, nil)
Register("POST", mml.CustomUriOmMmlExt, mml.PostMMLToOMC, nil)
// Northbound Get NRM
Register("GET", nbi.GetNRMUri, nbi.NBIGetNRMFromNF, nil)
Register("GET", nbi.CustomGetNRMUri, nbi.NBIGetNRMFromNF, nil)
// Import/Export NF CM
Register("GET", cm.NeCmUri, cm.ExportCmFromNF, nil)
Register("POST", cm.NeCmUri, cm.ImportCmToNF, nil)
Register("GET", cm.UriNeCmFile, cm.DownloadNeBackupFile, nil)
Register("DELETE", cm.UriNeCmFile, cm.DeleteNeBackupFile, nil)
Register("GET", cm.CustomNeCmUri, cm.ExportCmFromNF, nil)
Register("POST", cm.CustomNeCmUri, cm.ImportCmToNF, nil)
Register("GET", cm.CustomUriNeCmFile, cm.DownloadNeBackupFile, nil)
Register("DELETE", cm.CustomUriNeCmFile, cm.DeleteNeBackupFile, nil)
// Software management
Register("GET", cm.UriSoftware, cm.DownloadSoftwareFile, nil)
//Register("POST", cm.UriSoftware, cm.UploadSoftwareFile, nil)
Register("POST", cm.UriSoftware, cm.UploadSoftwareMultiFile, nil)
Register("DELETE", cm.UriSoftware, cm.DeleteSoftwareFile, nil)
Register("POST", cm.UriSoftwareNE, cm.DistributeSoftwareToNF, nil)
Register("PUT", cm.UriSoftwareNE, cm.ActiveSoftwareToNF, nil)
Register("PATCH", cm.UriSoftwareNE, cm.RollBackSoftwareToNF, nil)
Register("GET", cm.CustomUriSoftware, cm.DownloadSoftwareFile, nil)
Register("POST", cm.CustomUriSoftware, cm.UploadSoftwareFile, nil)
Register("DELETE", cm.CustomUriSoftware, cm.DeleteSoftwareFile, nil)
Register("POST", cm.CustomUriSoftwareNE, cm.DistributeSoftwareToNF, nil)
Register("PUT", cm.CustomUriSoftwareNE, cm.ActiveSoftwareToNF, nil)
Register("PATCH", cm.CustomUriSoftwareNE, cm.RollBackSoftwareToNF, nil)
// License management
Register("GET", cm.LicenseUri, cm.ExportCmFromNF, nil)
Register("POST", cm.LicenseUri, cm.ImportCmToNF, nil)
Register("DELETE", cm.LicenseUri, cm.ImportCmToNF, nil)
Register("POST", cm.NeLicenseUri, cm.ExportCmFromNF, nil)
Register("PUT", cm.NeLicenseUri, cm.ImportCmToNF, nil)
Register("PATCH", cm.NeLicenseUri, cm.ImportCmToNF, nil)
Register("POST", cm.CustomNeLicenseUri, cm.ExportCmFromNF, nil)
Register("PUT", cm.CustomNeLicenseUri, cm.ImportCmToNF, nil)
Register("PATCH", cm.CustomNeLicenseUri, cm.ImportCmToNF, nil)
// Trace management
Register("POST", trace.UriTraceTask, trace.PostTraceTaskToNF, nil)
Register("PUT", trace.UriTraceTask, trace.PutTraceTaskToNF, nil)
Register("DELETE", trace.UriTraceTask, trace.DeleteTraceTaskToNF, nil)
Register("GET", trace.UriTraceDecMsg, trace.ParseRawMsg2Html, nil)
Register("POST", trace.CustomUriTraceTask, trace.PostTraceTaskToNF, nil)
Register("PUT", trace.CustomUriTraceTask, trace.PutTraceTaskToNF, nil)
Register("DELETE", trace.CustomUriTraceTask, trace.DeleteTraceTaskToNF, nil)
// 网元发送执行 pcap抓包
Register("POST", trace.UriTcpdumpTask, trace.TcpdumpNeTask, midware.Authorize(nil))
Register("POST", trace.CustomUriTcpdumpTask, trace.TcpdumpNeTask, midware.Authorize(nil))
// 网元发送执行 抓包下载pcap文件
Register("POST", trace.UriTcpdumpPcapDownload, trace.TcpdumpPcapDownload, midware.Authorize(nil))
Register("POST", trace.CustomUriTcpdumpPcapDownload, trace.TcpdumpPcapDownload, midware.Authorize(nil))
// 网元发送执行UPF pcap抓包
Register("POST", trace.UriTcpdumpNeUPFTask, trace.TcpdumpNeUPFTask, nil)
Register("POST", trace.CustomUriTcpdumpNeUPFTask, trace.TcpdumpNeUPFTask, nil)
// file management
Register("POST", file.UriFile, file.UploadFile, nil)
Register("GET", file.UriFile, file.DownloadFile, nil)
Register("DELETE", file.UriFile, file.DeleteFile, nil)
Register("POST", file.CustomUriFile, file.UploadFile, nil)
Register("GET", file.CustomUriFile, file.DownloadFile, nil)
Register("DELETE", file.CustomUriFile, file.DeleteFile, nil)
// AAAA
Register("GET", aaaa.UriAAAASSO, aaaa.GetSSOFromAAAA, nil)
// AAAA
Register("GET", aaaa.CustomUriAAAASSO, aaaa.GetSSOFromAAAA, nil)
// UEInfo
Register("GET", ue.UriUEInfo, ue.GetUEInfoFromNF, nil)
Register("GET", ue.CustomUriUEInfo, ue.GetUEInfoFromNF, nil)
// UEInfo
Register("GET", ue.UriUENum, ue.GetUENumFromNF, nil)
Register("GET", ue.CustomUriUENum, ue.GetUENumFromNF, nil)
// NBInfo
Register("GET", ue.UriNBInfo, ue.GetNBInfoFromNF, nil)
Register("GET", ue.CustomUriNBInfo, ue.GetNBInfoFromNF, nil)
// 进程网络
Register("GET", psnet.UriWs, psnet.ProcessWs, nil)
Register("POST", psnet.UriStop, psnet.StopProcess, nil)
Register("POST", psnet.UriPing, psnet.Ping, nil)
// 主机CPU内存监控
Register("POST", monitor.UriLoad, monitor.LoadMonitor, nil)
Register("GET", monitor.UriNetOpt, monitor.Netoptions, nil)
Register("GET", monitor.UriIPAddr, monitor.IPAddr, nil)
Register("GET", monitor.UriIoOpt, monitor.Iooptions, nil)
// 文件资源
Register("GET", file.UriDiskList, file.DiskList, nil)
Register("POST", file.UriListFiles, file.ListFiles, nil)
// 数据库连接情况
Register("GET", dbrest.UriDbConnection, dbrest.DbConnection, nil)
Register("GET", dbrest.CustomUriDbConnection, dbrest.DbConnection, nil)
Register("POST", dbrest.UriDbStop, dbrest.DbStop, nil)
Register("POST", dbrest.CustomUriDbStop, dbrest.DbStop, nil)
// 系统备份
Register("POST", dbrest.UriDbBackup, dbrest.DbBackup, nil)
Register("POST", dbrest.CustomUriDbBackup, dbrest.DbBackup, nil)
Register("POST", dbrest.UriConfBackup, dbrest.ConfBackup, nil)
Register("POST", dbrest.CustomUriConfBackup, dbrest.ConfBackup, nil)
// 日志表备份
Register("POST", lm.ExtBackupDataUri, lm.ExtDatabaseBackupData, nil)
Register("POST", lm.CustomExtBackupDataUri, lm.ExtDatabaseBackupData, nil)
// 系统登录
Register("POST", security.UriLogin, security.LoginOMC, nil)
Register("POST", security.CustomUriLogin, security.LoginOMC, nil)
// 获取验证码
Register("GET", security.UriCaptchaImage, security.CaptchaImage, nil)
Register("GET", security.CustomUriCaptchaImage, security.CaptchaImage, nil)
// 登录用户信息
Register("GET", security.UriUserInfo, security.UserInfo, midware.Authorize(nil))
Register("GET", security.CustomUriUserInfo, security.UserInfo, midware.Authorize(nil))
// 登录用户路由信息
Register("GET", security.UriRouters, security.Routers, midware.Authorize(nil))
Register("GET", security.CustomUriRouters, security.Routers, midware.Authorize(nil))
// 参数配置信息接口添加到路由
for _, v := range sysconfig.Routers() {
Register(v.Method, v.Pattern, v.Handler, v.Middleware)
}
// 字典类型信息接口添加到路由
for _, v := range sysdicttype.Routers() {
Register(v.Method, v.Pattern, v.Handler, v.Middleware)
}
// 字典类型对应的字典数据信息接口添加到路由
for _, v := range sysdictdata.Routers() {
Register(v.Method, v.Pattern, v.Handler, v.Middleware)
}
// 菜单接口添加到路由
for _, v := range sysmenu.Routers() {
Register(v.Method, v.Pattern, v.Handler, v.Middleware)
}
// 角色接口添加到路由
for _, v := range sysrole.Routers() {
Register(v.Method, v.Pattern, v.Handler, v.Middleware)
}
// 用户接口添加到路由
for _, v := range sysuser.Routers() {
Register(v.Method, v.Pattern, v.Handler, v.Middleware)
}
// UDM 用户信息接口添加到路由
for _, v := range udmuser.Routers() {
Register(v.Method, v.Pattern, v.Handler, v.Middleware)
}
}
// To resolv rest POST/PUT/DELETE/PATCH cross domain
func OptionsProc(w http.ResponseWriter, r *http.Request) {
services.ResponseStatusOK204NoContent(w)
}
func NewRouter() *mux.Router {
r := mux.NewRouter()
// set custom handle for status 404/405
r.NotFoundHandler = services.CustomResponseNotFound404Handler()
r.MethodNotAllowedHandler = services.CustomResponseMethodNotAllowed405Handler()
r.Use(midware.LoggerTrace)
r.Use(midware.Cors)
//r.Use(midware.OptionProcess)
// r.Use(midware.ArrowIPAddr)
for _, router := range routers {
rt := r.Methods(router.Method).Subrouter()
rt.HandleFunc(router.Pattern, router.Handler)
if router.Middleware != nil {
rt.Use(router.Middleware)
}
}
return r
}
func Register(method, pattern string, handler http.HandlerFunc, middleware mux.MiddlewareFunc) {
routers = append(routers, Router{method, pattern, handler, middleware})
}

56
lib/run/exec_linux.go Normal file
View File

@@ -0,0 +1,56 @@
//go:build linux
// +build linux
package run
import (
"bytes"
"os/exec"
"ems.agt/lib/log"
)
func ExecCmd(command, path string) ([]byte, error) {
log.Debug("Exec command:", command)
cmd := exec.Command("/bin/bash", "-c", command)
cmd.Dir = path
out, err := cmd.CombinedOutput()
if err != nil {
return out, err
}
return out, nil
}
func ExecShell(command string) error {
in := bytes.NewBuffer(nil)
cmd := exec.Command("sh")
cmd.Stdin = in
in.WriteString(command)
in.WriteString("exit\n")
if err := cmd.Start(); err != nil {
return err
}
return nil
}
func ExecOsCmd(command, os string) error {
log.Debugf("Exec %s command:%s", os, command)
var cmd *exec.Cmd
switch os {
case "Linux":
cmd = exec.Command(command)
case "Windows":
cmd = exec.Command("cmd", "/C", command)
}
out, err := cmd.CombinedOutput()
log.Tracef("Exec output: %v", string(out))
if err != nil {
log.Error("exe cmd error: ", err)
return err
}
return nil
}

45
lib/run/exec_wasm.go Normal file
View File

@@ -0,0 +1,45 @@
//go:build wasm
// +build wasm
package run
import (
"os/exec"
"ems.agt/lib/log"
)
func ExecCmd(command, path string) ([]byte, error) {
log.Debug("Exec command:", command)
cmd := exec.Command("cmd", "/C", command)
cmd.Dir = path
out, err := cmd.CombinedOutput()
log.Tracef("Exec output: %v", string(out))
if err != nil {
log.Error("exe cmd error: ", err)
return out, err
}
return out, nil
}
func ExecOsCmd(command, os string) error {
log.Debugf("Exec %s command:%s", os, command)
var cmd *exec.Cmd
switch os {
case "Linux":
cmd = exec.Command(command)
case "Windows":
cmd = exec.Command("cmd", "/C", command)
}
out, err := cmd.CombinedOutput()
log.Tracef("Exec output: %v", string(out))
if err != nil {
log.Error("exe cmd error: ", err)
return err
}
return nil
}

45
lib/run/exec_windows.go Normal file
View File

@@ -0,0 +1,45 @@
//go:build windows
// +build windows
package run
import (
"os/exec"
"ems.agt/lib/log"
)
func ExecCmd(command, path string) ([]byte, error) {
log.Debug("Exec command:", command)
cmd := exec.Command("cmd", "/C", command)
cmd.Dir = path
out, err := cmd.CombinedOutput()
log.Tracef("Exec output: %v", string(out))
if err != nil {
log.Error("exe cmd error: ", err)
return out, err
}
return out, nil
}
func ExecOsCmd(command, os string) error {
log.Debugf("Exec %s command:%s", os, command)
var cmd *exec.Cmd
switch os {
case "Linux":
cmd = exec.Command(command)
case "Windows":
cmd = exec.Command("cmd", "/C", command)
}
out, err := cmd.CombinedOutput()
log.Tracef("Exec output: %v", string(out))
if err != nil {
log.Error("exe cmd error: ", err)
return err
}
return nil
}

423
lib/services/file.go Normal file
View File

@@ -0,0 +1,423 @@
package services
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"mime/multipart"
"net/http"
"os"
"strconv"
"strings"
"ems.agt/lib/log"
)
const (
RootPath = "uploads/"
ChunkRootPath = "uploads_tmp/"
)
var (
// FilesMax 限制上传文件的大小为7 MB
FilesMax int64 = 32 << 20
// ValuesMax 限制POST字段内容的大小
ValuesMax int64 = 512
)
func GetPostFile(w http.ResponseWriter, r *http.Request) {
//获取文件流,第三个返回值是错误对象
file, header, _ := r.FormFile("file")
//读取文件流为[]byte
b, err := io.ReadAll(file)
if err != nil {
log.Error("Failed to ReadAll:", err)
ResponseInternalServerError500ProcessError(w, err)
return
}
//把文件保存到指定位置
err = os.WriteFile("./upload/test.zip", b, 0644)
if err != nil {
log.Error("Failed to WriteFile:", err)
ResponseInternalServerError500ProcessError(w, err)
return
}
//输出上传时文件名
log.Debug("filename:", header.Filename)
}
func GetUploadFile(w http.ResponseWriter, r *http.Request) {
log.Debug("GetUploadFile processing...")
file, err := os.Create("./test.zip")
if err != nil {
log.Error("Failed to Create:", err)
ResponseInternalServerError500ProcessError(w, err)
return
}
_, err = io.Copy(file, r.Body)
if err != nil {
log.Error("Failed to Copy:", err)
ResponseInternalServerError500ProcessError(w, err)
return
}
}
func GetUploadFormFile(w http.ResponseWriter, r *http.Request) {
// 设置最大的内存限制为32MB
r.ParseMultipartForm(32 << 20)
file, handler, err := r.FormFile("file")
if err != nil {
log.Error("Failed to FormFile:", err)
ResponseInternalServerError500ProcessError(w, err)
return
}
defer file.Close()
log.Debug("Header:%v", handler.Header)
f, err := os.OpenFile("./"+handler.Filename, os.O_WRONLY|os.O_CREATE, 0666)
if err != nil {
log.Error("Failed to OpenFile:", err)
ResponseInternalServerError500ProcessError(w, err)
return
}
defer f.Close()
_, err = io.Copy(f, file)
if err != nil {
log.Error("Failed to Copy:", err)
ResponseInternalServerError500ProcessError(w, err)
return
}
log.Debug("File uploaded successfully:", handler.Filename)
}
func HandleUploadFile(r *http.Request, path, newFileName string) (string, error) {
var filePath, fileName string
reader, err := r.MultipartReader()
if err != nil {
log.Error("Failed to MultipartReader:", err)
return "", err
}
for {
part, err := reader.NextPart()
if err == io.EOF {
break
} else if err != nil {
log.Error("Failed to NextPart:", err)
return "", err
}
log.Debugf("FileName=[%s], FormName=[%s]", part.FileName(), part.FormName())
if part.FileName() == "" { // this is FormData
data, _ := io.ReadAll(part)
log.Debugf("FormData=[%s]", string(data))
} else { // This is FileData
if newFileName != "" {
fileName = newFileName
} else {
fileName = part.FileName()
}
err := os.MkdirAll(path, os.ModePerm)
if err != nil {
log.Error("Failed to Mkdir:", err)
return "", err
}
filePath = path + "/" + fileName
file, err := os.Create(filePath)
if err != nil {
log.Error("Failed to Create:", err)
return "", err
}
defer file.Close()
_, err = io.Copy(file, part)
if err != nil {
log.Error("Failed to Copy:", err)
return "", err
}
}
}
return fileName, nil
}
type UploadMultiFileData struct {
SoftwareFileName string `json:"softwareFileName"`
CmsFileName string `json:"cmsFileName"`
Datas map[string][]string `json:"datas"`
}
func HandleUploadMultiFile(r *http.Request, path, newFileName string) (*UploadMultiFileData, error) {
fileData := new(UploadMultiFileData)
// 解析multipart/form-data请求
err := r.ParseMultipartForm(100 << 20) // 100MB
if err != nil {
return fileData, err
}
// 获取文件和数据
softwareFile := r.MultipartForm.File["file"]
cmsFile := r.MultipartForm.File["cms"]
fileData.Datas = r.MultipartForm.Value
// 处理文件
if len(softwareFile) > 0 {
file := softwareFile[0]
// 打开文件
f, err := file.Open()
if err != nil {
return fileData, err
}
defer f.Close()
// 创建本地文件
dst, err := os.Create(path + "/" + file.Filename)
if err != nil {
return fileData, err
}
defer dst.Close()
fileData.SoftwareFileName = file.Filename
// 将文件内容拷贝到本地文件
_, err = io.Copy(dst, f)
if err != nil {
return fileData, err
}
}
// 处理文件
if len(cmsFile) > 0 {
file := cmsFile[0]
// 打开文件
f, err := file.Open()
if err != nil {
return fileData, err
}
defer f.Close()
// 创建本地文件
dst, err := os.Create(path + "/" + file.Filename)
if err != nil {
return fileData, err
}
defer dst.Close()
fileData.CmsFileName = file.Filename
// 将文件内容拷贝到本地文件
_, err = io.Copy(dst, f)
if err != nil {
return fileData, err
}
}
return fileData, nil
}
func HandleUploadFormFile(w http.ResponseWriter, r *http.Request) {
r.ParseMultipartForm(32 << 20)
//mForm := r.MultipartForm
for k, _ := range r.MultipartForm.File {
// k is the key of file part
file, fileHeader, err := r.FormFile(k)
if err != nil {
fmt.Println("inovke FormFile error:", err)
return
}
defer file.Close()
fmt.Printf("the uploaded file: name[%s], size[%d], header[%#v]\n",
fileHeader.Filename, fileHeader.Size, fileHeader.Header)
// store uploaded file into local path
localFileName := "./upload/" + fileHeader.Filename
out, err := os.Create(localFileName)
if err != nil {
fmt.Printf("failed to open the file %s for writing", localFileName)
return
}
defer out.Close()
_, err = io.Copy(out, file)
if err != nil {
fmt.Printf("copy file err:%s\n", err)
return
}
fmt.Printf("file %s uploaded ok\n", fileHeader.Filename)
}
}
func PostFileHandler(w http.ResponseWriter, r *http.Request) {
fmt.Println("PostFileHandler processing... ")
if !strings.Contains(r.Header.Get("Content-Type"), "multipart/form-data") {
// 不支持的 Content-Type 类型
fmt.Println("Invalid Content-Type: ", r.Header.Get("Content-Type"))
http.Error(w, " 不支持的 Content-Type 类型", http.StatusBadRequest)
return
}
// 整个请求的主体大小设置为7.5Mb
r.Body = http.MaxBytesReader(w, r.Body, FilesMax+ValuesMax)
reader, err := r.MultipartReader()
if err != nil {
fmt.Println(err)
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
for {
// A Part represents a single part in a multipart body.
part, err := reader.NextPart()
if err != nil {
if err == io.EOF {
break
}
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
fileName := part.FileName()
formName := part.FormName()
var buf = &bytes.Buffer{}
// 非文件字段部分大小限制验证非文件字段go中filename会是空
if fileName == "" {
var limitError = "请求主体中非文件字段" + formName + "超出大小限制"
err = uploadSizeLimit(buf, part, ValuesMax, limitError)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
continue
}
// 文件字段部分大小限制验证
var limitError = "请求主体中文件字段" + fileName + "超出大小限制"
err = uploadSizeLimit(buf, part, FilesMax, limitError)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
// 文件创建部分
if err := uploadFileHandle(r.Header, fileName, buf); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
// 非逻辑内容,仅为测试使用
var chunkNumber = r.Header.Get("chunk-number")
if chunkNumber == "" {
http.Error(w, "文件"+fileName+"上传成功", http.StatusOK)
} else {
http.Error(w, "分片文件"+fileName+chunkNumber+"上传成功", http.StatusOK)
}
}
}
// 上传内容大小限制
func uploadSizeLimit(buf *bytes.Buffer, part *multipart.Part, maxLimit int64, limitError string) error {
n, err := io.CopyN(buf, part, maxLimit+1)
if err != nil && err != io.EOF {
fmt.Println("PostFileHandler:", err)
return err
}
maxLimit -= n
if maxLimit < 0 {
return errors.New(limitError)
}
return nil
}
// uploadFileHandle handle upload file
func uploadFileHandle(header http.Header, fileName string, buf *bytes.Buffer) error {
var chunkNumberStr = header.Get("chunk-number")
// 1.普通文件上传处理
if chunkNumberStr == "" {
//创建文件并写入文件内容
return createFile(RootPath+fileName, buf.Bytes())
}
// 2.分片文件上传处理
//2.1读取分片编号
chunkNumber, err := strconv.Atoi(chunkNumberStr)
if err != nil {
return err
}
//2.2创建分片文件并写入分片内容
if err := createFile(fmt.Sprintf(ChunkRootPath+fileName+"%d.chunk", chunkNumber), buf.Bytes()); err != nil {
return err
}
//2.3确认是否上传完毕
if header.Get("chunk-final") == "true" {
//2.4合并文件
if err := mergeChunkFiles(fileName); err != nil {
return err
}
//2.5删除分片
for i := 0; ; i++ {
chunFileName := fmt.Sprintf(ChunkRootPath+fileName+"%d.chunk", i)
err := os.Remove(chunFileName)
if err != nil {
if os.IsNotExist(err) {
break
}
return err
}
}
}
return nil
}
// 创建文件并写入内容
func createFile(fileName string, res []byte) error {
newFile, err := os.Create(fileName)
if err != nil {
return err
}
defer func() {
_ = newFile.Close()
}()
bufferedWriter := bufio.NewWriter(newFile)
_, err = bufferedWriter.Write(res)
if err != nil && err != io.EOF {
return err
}
return bufferedWriter.Flush()
}
// 合并分片文件
func mergeChunkFiles(fileName string) error {
var (
n int64
err error
)
finalFile, err := os.Create(RootPath + fileName)
if err != nil {
return err
}
defer finalFile.Close()
// 将分片内容写入最终文件
for i := 0; ; i++ {
chunFile, err := os.Open(fmt.Sprintf(ChunkRootPath+fileName+"%d.chunk", i))
if err != nil {
if os.IsNotExist(err) {
break
}
return err
}
n, err = io.Copy(finalFile, chunFile)
if err != nil {
return err
}
err = chunFile.Close()
if err != nil {
return err
}
if n < 1 {
break
}
}
return nil
}

1000
lib/services/services.go Normal file

File diff suppressed because it is too large Load Diff

169
lib/session/session.go Normal file
View File

@@ -0,0 +1,169 @@
package session
import (
"crypto/rand"
"encoding/base64"
"errors"
"io"
"net/http"
"strconv"
"strings"
"sync"
"time"
"ems.agt/lib/log"
"ems.agt/lib/oauth"
"ems.agt/restagent/config"
)
// SessionMgr session manager
type SessManager struct {
name string
expires int64
lock sync.RWMutex
sessions map[string]*Session
}
// Session
type Session struct {
token string
time time.Time
permission []bool
values map[interface{}]interface{}
}
// NewSessionMgr create session manager
func NewSessManager(name string) *SessManager {
smgr := &SessManager{name: name, expires: (int64)(config.GetExpiresFromConfig()), sessions: make(map[string]*Session)}
go smgr.SessionGC()
return smgr
}
// NewSession create session
func (smgr *SessManager) NewSession(w http.ResponseWriter, r *http.Request, plist []bool) string {
smgr.lock.Lock()
defer smgr.lock.Unlock()
token := oauth.GenRandToken("omc") // Generate new token to session ID
session := &Session{token: token, time: time.Now(), permission: plist, values: make(map[interface{}]interface{})}
smgr.sessions[token] = session
return token
}
// EndSession
func (smgr *SessManager) EndSession(w http.ResponseWriter, r *http.Request) {
token := smgr.GetTokenFromHttpRequest(r)
smgr.lock.Lock()
defer smgr.lock.Unlock()
delete(smgr.sessions, token)
}
// Handshake session, restart session
func (smgr *SessManager) ShakeSession(token string) bool {
smgr.lock.Lock()
defer smgr.lock.Unlock()
for _, s := range smgr.sessions {
if token == s.token {
log.Debug("session time:", s.time)
s.time = time.Now()
return true
}
}
return false
}
// EndSessionByID end the session by session ID
func (smgr *SessManager) DeleteSession(token string) {
smgr.lock.Lock()
defer smgr.lock.Unlock()
delete(smgr.sessions, token)
}
// SetSessionValue set value fo session
func (smgr *SessManager) SetSessionValue(token string, key interface{}, value interface{}) error {
smgr.lock.Lock()
defer smgr.lock.Unlock()
if session, ok := smgr.sessions[token]; ok {
session.values[key] = value
return nil
}
return errors.New("invalid session ID")
}
// GetSessionValue get value fo session
func (smgr *SessManager) GetSessionValue(token string, key interface{}) (interface{}, error) {
smgr.lock.RLock()
defer smgr.lock.RUnlock()
if session, ok := smgr.sessions[token]; ok {
if val, ok := session.values[key]; ok {
return val, nil
}
}
return nil, errors.New("invalid session ID")
}
func (smgr *SessManager) GetTokenFromHttpRequest(r *http.Request) string {
for k, v := range r.Header {
if strings.ToLower(k) == "accesstoken" && len(v) != 0 {
log.Debug("AccessToken:", v[0])
return v[0]
}
}
return ""
}
// IsValidToken check token is valid or not
func (smgr *SessManager) IsValidToken(token string) bool {
smgr.lock.Lock()
defer smgr.lock.Unlock()
if _, ok := smgr.sessions[token]; ok {
return true
}
return false
}
// IsCarriedToken check token is carried
func (smgr *SessManager) IsCarriedToken(r *http.Request) (string, bool) {
token := smgr.GetTokenFromHttpRequest(r)
if token == "" {
return "", false
}
return token, true
}
// GetPermissionFromSession get permission from session by token
func (smgr *SessManager) GetPermissionFromSession(token string) []bool {
if s, ok := smgr.sessions[token]; ok {
return s.permission
}
return nil
}
// SessionGC maintain session
func (smgr *SessManager) SessionGC() {
smgr.lock.Lock()
defer smgr.lock.Unlock()
for token, session := range smgr.sessions {
if session.time.Unix()+smgr.expires < time.Now().Unix() {
delete(smgr.sessions, token)
}
}
time.AfterFunc(time.Duration(smgr.expires)*time.Second, func() { smgr.SessionGC() })
}
// NewSessionID generate unique ID
func (smgr *SessManager) NewSessionID() string {
b := make([]byte, 32)
if _, err := io.ReadFull(rand.Reader, b); err != nil {
nano := time.Now().UnixNano()
return strconv.FormatInt(nano, 10)
}
return base64.URLEncoding.EncodeToString(b)
}

45
lib/wsinfo/client.go Normal file
View File

@@ -0,0 +1,45 @@
package wsinfo
import (
"github.com/gorilla/websocket"
)
type Client struct {
ID string
Socket *websocket.Conn
Msg chan []byte
}
func NewWsClient(ID string, socket *websocket.Conn) *Client {
return &Client{
ID: ID,
Socket: socket,
Msg: make(chan []byte, 100),
}
}
func (c *Client) Read() {
defer func() {
close(c.Msg)
}()
for {
_, message, err := c.Socket.ReadMessage()
if err != nil {
return
}
ProcessData(c, message)
}
}
func (c *Client) Write() {
defer func() {
c.Socket.Close()
}()
for {
message, ok := <-c.Msg
if !ok {
return
}
_ = c.Socket.WriteMessage(websocket.TextMessage, message)
}
}

382
lib/wsinfo/process_data.go Normal file
View File

@@ -0,0 +1,382 @@
package wsinfo
import (
"encoding/json"
"fmt"
"sort"
"strings"
"sync"
"time"
"ems.agt/lib/log"
"github.com/shirou/gopsutil/v3/host"
"github.com/shirou/gopsutil/v3/net"
"github.com/shirou/gopsutil/v3/process"
)
type WsInput struct {
Type string `json:"type"`
DownloadProgress
PsProcessConfig
SSHSessionConfig
NetConfig
}
type DownloadProgress struct {
Keys []string `json:"keys"`
}
type PsProcessConfig struct {
Pid int32 `json:"pid"`
Name string `json:"name"`
Username string `json:"username"`
}
type SSHSessionConfig struct {
LoginUser string `json:"loginUser"`
LoginIP string `json:"loginIP"`
}
type NetConfig struct {
Port uint32 `json:"port"`
ProcessName string `json:"processName"`
ProcessID int32 `json:"processID"`
}
type PsProcessData struct {
PID int32 `json:"PID"`
Name string `json:"name"`
PPID int32 `json:"PPID"`
Username string `json:"username"`
Status string `json:"status"`
StartTime string `json:"startTime"`
NumThreads int32 `json:"numThreads"`
NumConnections int `json:"numConnections"`
CpuPercent string `json:"cpuPercent"`
DiskRead string `json:"diskRead"`
DiskWrite string `json:"diskWrite"`
CmdLine string `json:"cmdLine"`
Rss string `json:"rss"`
VMS string `json:"vms"`
HWM string `json:"hwm"`
Data string `json:"data"`
Stack string `json:"stack"`
Locked string `json:"locked"`
Swap string `json:"swap"`
CpuValue float64 `json:"cpuValue"`
RssValue uint64 `json:"rssValue"`
Envs []string `json:"envs"`
OpenFiles []process.OpenFilesStat `json:"openFiles"`
Connects []processConnect `json:"connects"`
}
type processConnect struct {
Type string `json:"type"`
Status string `json:"status"`
Laddr net.Addr `json:"localaddr"`
Raddr net.Addr `json:"remoteaddr"`
PID int32 `json:"PID"`
Name string `json:"name"`
}
type ProcessConnects []processConnect
func (p ProcessConnects) Len() int {
return len(p)
}
func (p ProcessConnects) Less(i, j int) bool {
return p[i].PID < p[j].PID
}
func (p ProcessConnects) Swap(i, j int) {
p[i], p[j] = p[j], p[i]
}
type sshSession struct {
Username string `json:"username"`
PID int32 `json:"PID"`
Terminal string `json:"terminal"`
Host string `json:"host"`
LoginTime string `json:"loginTime"`
}
func ProcessData(c *Client, inputMsg []byte) {
wsInput := &WsInput{}
err := json.Unmarshal(inputMsg, wsInput)
if err != nil {
log.Errorf("unmarshal wsInput error,err %s", err.Error())
return
}
switch wsInput.Type {
case "ps":
res, err := getProcessData(wsInput.PsProcessConfig)
if err != nil {
return
}
c.Msg <- res
case "ssh":
res, err := getSSHSessions(wsInput.SSHSessionConfig)
if err != nil {
return
}
c.Msg <- res
case "net":
res, err := getNetConnections(wsInput.NetConfig)
if err != nil {
return
}
c.Msg <- res
}
}
type Process struct {
Total uint64 `json:"total"`
Written uint64 `json:"written"`
Percent float64 `json:"percent"`
Name string `json:"name"`
}
const (
b = uint64(1)
kb = 1024 * b
mb = 1024 * kb
gb = 1024 * mb
)
func formatBytes(bytes uint64) string {
switch {
case bytes < kb:
return fmt.Sprintf("%dB", bytes)
case bytes < mb:
return fmt.Sprintf("%.2fKB", float64(bytes)/float64(kb))
case bytes < gb:
return fmt.Sprintf("%.2fMB", float64(bytes)/float64(mb))
default:
return fmt.Sprintf("%.2fGB", float64(bytes)/float64(gb))
}
}
func getProcessData(processConfig PsProcessConfig) (res []byte, err error) {
var processes []*process.Process
processes, err = process.Processes()
if err != nil {
return
}
var (
result []PsProcessData
resultMutex sync.Mutex
wg sync.WaitGroup
numWorkers = 4
)
handleData := func(proc *process.Process) {
procData := PsProcessData{
PID: proc.Pid,
}
if processConfig.Pid > 0 && processConfig.Pid != proc.Pid {
return
}
if procName, err := proc.Name(); err == nil {
procData.Name = procName
} else {
procData.Name = "<UNKNOWN>"
}
if processConfig.Name != "" && !strings.Contains(procData.Name, processConfig.Name) {
return
}
if username, err := proc.Username(); err == nil {
procData.Username = username
}
if processConfig.Username != "" && !strings.Contains(procData.Username, processConfig.Username) {
return
}
procData.PPID, _ = proc.Ppid()
statusArray, _ := proc.Status()
if len(statusArray) > 0 {
procData.Status = strings.Join(statusArray, ",")
}
createTime, procErr := proc.CreateTime()
if procErr == nil {
t := time.Unix(createTime/1000, 0)
procData.StartTime = t.Format("2006-1-2 15:04:05")
}
procData.NumThreads, _ = proc.NumThreads()
connections, procErr := proc.Connections()
if procErr == nil {
procData.NumConnections = len(connections)
for _, conn := range connections {
if conn.Laddr.IP != "" || conn.Raddr.IP != "" {
procData.Connects = append(procData.Connects, processConnect{
Status: conn.Status,
Laddr: conn.Laddr,
Raddr: conn.Raddr,
})
}
}
}
procData.CpuValue, _ = proc.CPUPercent()
procData.CpuPercent = fmt.Sprintf("%.2f", procData.CpuValue) + "%"
menInfo, procErr := proc.MemoryInfo()
if procErr == nil {
procData.Rss = formatBytes(menInfo.RSS)
procData.RssValue = menInfo.RSS
procData.Data = formatBytes(menInfo.Data)
procData.VMS = formatBytes(menInfo.VMS)
procData.HWM = formatBytes(menInfo.HWM)
procData.Stack = formatBytes(menInfo.Stack)
procData.Locked = formatBytes(menInfo.Locked)
procData.Swap = formatBytes(menInfo.Swap)
} else {
procData.Rss = "--"
procData.Data = "--"
procData.VMS = "--"
procData.HWM = "--"
procData.Stack = "--"
procData.Locked = "--"
procData.Swap = "--"
procData.RssValue = 0
}
ioStat, procErr := proc.IOCounters()
if procErr == nil {
procData.DiskWrite = formatBytes(ioStat.WriteBytes)
procData.DiskRead = formatBytes(ioStat.ReadBytes)
} else {
procData.DiskWrite = "--"
procData.DiskRead = "--"
}
procData.CmdLine, _ = proc.Cmdline()
procData.OpenFiles, _ = proc.OpenFiles()
procData.Envs, _ = proc.Environ()
resultMutex.Lock()
result = append(result, procData)
resultMutex.Unlock()
}
chunkSize := (len(processes) + numWorkers - 1) / numWorkers
for i := 0; i < numWorkers; i++ {
wg.Add(1)
start := i * chunkSize
end := (i + 1) * chunkSize
if end > len(processes) {
end = len(processes)
}
go func(start, end int) {
defer wg.Done()
for j := start; j < end; j++ {
handleData(processes[j])
}
}(start, end)
}
wg.Wait()
sort.Slice(result, func(i, j int) bool {
return result[i].PID < result[j].PID
})
res, err = json.Marshal(result)
return
}
func getSSHSessions(config SSHSessionConfig) (res []byte, err error) {
var (
result []sshSession
users []host.UserStat
processes []*process.Process
)
processes, err = process.Processes()
if err != nil {
return
}
users, err = host.Users()
if err != nil {
return
}
for _, proc := range processes {
name, _ := proc.Name()
if name != "sshd" || proc.Pid == 0 {
continue
}
connections, _ := proc.Connections()
for _, conn := range connections {
for _, user := range users {
if user.Host == "" {
continue
}
if conn.Raddr.IP == user.Host {
if config.LoginUser != "" && !strings.Contains(user.User, config.LoginUser) {
continue
}
if config.LoginIP != "" && !strings.Contains(user.Host, config.LoginIP) {
continue
}
if terminal, err := proc.Cmdline(); err == nil {
if strings.Contains(terminal, user.Terminal) {
session := sshSession{
Username: user.User,
Host: user.Host,
Terminal: user.Terminal,
PID: proc.Pid,
}
t := time.Unix(int64(user.Started), 0)
session.LoginTime = t.Format("2006-1-2 15:04:05")
result = append(result, session)
}
}
}
}
}
}
res, err = json.Marshal(result)
return
}
var netTypes = [...]string{"tcp", "udp"}
func getNetConnections(config NetConfig) (res []byte, err error) {
var (
result []processConnect
proc *process.Process
)
for _, netType := range netTypes {
connections, _ := net.Connections(netType)
if err == nil {
for _, conn := range connections {
if config.ProcessID > 0 && config.ProcessID != conn.Pid {
continue
}
proc, err = process.NewProcess(conn.Pid)
if err == nil {
name, _ := proc.Name()
if name != "" && config.ProcessName != "" && !strings.Contains(name, config.ProcessName) {
continue
}
if config.Port > 0 && config.Port != conn.Laddr.Port && config.Port != conn.Raddr.Port {
continue
}
result = append(result, processConnect{
Type: netType,
Status: conn.Status,
Laddr: conn.Laddr,
Raddr: conn.Raddr,
PID: conn.Pid,
Name: name,
})
}
}
}
}
res, err = json.Marshal(result)
return
}