refactor: 删除冗余常量文件并整合常量定义

This commit is contained in:
TsMask
2025-02-20 09:50:29 +08:00
parent 1b435074cb
commit a1296b6fe6
63 changed files with 1823 additions and 1748 deletions

View File

@@ -1,254 +0,0 @@
package ctx
import (
"fmt"
"strings"
"be.ems/src/framework/config"
"be.ems/src/framework/constants/common"
"be.ems/src/framework/constants/roledatascope"
"be.ems/src/framework/constants/token"
"be.ems/src/framework/utils/ip2region"
"be.ems/src/framework/utils/ua"
"be.ems/src/framework/vo"
"golang.org/x/text/language"
"github.com/gin-gonic/gin"
)
// QueryMapString 查询参数转换Map
func QueryMapString(c *gin.Context) map[string]string {
queryValues := c.Request.URL.Query()
queryParams := make(map[string]string)
for key, values := range queryValues {
queryParams[key] = values[0]
}
return queryParams
}
// QueryMap 查询参数转换Map
func QueryMap(c *gin.Context) map[string]any {
queryValues := c.Request.URL.Query()
queryParams := make(map[string]any)
for key, values := range queryValues {
queryParams[key] = values[0]
}
return queryParams
}
// BodyJSONMap JSON参数转换Map
func BodyJSONMap(c *gin.Context) map[string]any {
params := make(map[string]any)
c.ShouldBindBodyWithJSON(&params)
return params
}
// RequestParamsMap 请求参数转换Map
func RequestParamsMap(c *gin.Context) map[string]any {
params := make(map[string]any)
// json
if strings.HasPrefix(c.ContentType(), "application/json") {
c.ShouldBindBodyWithJSON(&params)
}
// 表单
bodyParams := c.Request.PostForm
for key, value := range bodyParams {
params[key] = value[0]
}
// 查询
queryParams := c.Request.URL.Query()
for key, value := range queryParams {
params[key] = value[0]
}
return params
}
// IPAddrLocation 解析ip地址
func IPAddrLocation(c *gin.Context) (string, string) {
ip := ip2region.ClientIP(c.ClientIP())
location := ip2region.RealAddressByIp(ip)
return ip, location
}
// Authorization 解析请求头
func Authorization(c *gin.Context) string {
// Query请求查询
if authQuery, ok := c.GetQuery(token.ACCESS_TOKEN); ok && authQuery != "" {
return authQuery
}
// Header请求头
if authHeader := c.GetHeader(token.ACCESS_TOKEN); authHeader != "" {
return authHeader
}
// Query请求查询
if authQuery, ok := c.GetQuery(token.RESPONSE_FIELD); ok && authQuery != "" {
return authQuery
}
// Header请求头
authHeader := c.GetHeader(token.HEADER_KEY)
if authHeader == "" {
return ""
}
// 拆分 Authorization 请求头,提取 JWT 令牌部分
arr := strings.SplitN(authHeader, token.HEADER_PREFIX, 2)
if len(arr) < 2 {
return ""
}
return arr[1]
}
// UaOsBrowser 解析请求用户代理信息
func UaOsBrowser(c *gin.Context) (string, string) {
userAgent := c.GetHeader("user-agent")
uaInfo := ua.Info(userAgent)
browser := "app.common.noUaOsBrowser"
bName, bVersion := uaInfo.Browser()
if bName != "" && bVersion != "" {
browser = bName + " " + bVersion
}
os := "app.common.noUaOsBrowser"
bos := uaInfo.OS()
if bos != "" {
os = bos
}
return os, browser
}
// AcceptLanguage 解析客户端接收语言 zh中文 en: 英文
func AcceptLanguage(c *gin.Context) string {
preferredLanguage := language.English
// Query请求查询
if v, ok := c.GetQuery("language"); ok && v != "" {
tags, _, _ := language.ParseAcceptLanguage(v)
if len(tags) > 0 {
preferredLanguage = tags[0]
}
}
// Header请求头
if v := c.GetHeader("Accept-Language"); v != "" {
tags, _, _ := language.ParseAcceptLanguage(v)
if len(tags) > 0 {
preferredLanguage = tags[0]
}
}
// 只取前缀
lang := preferredLanguage.String()
arr := strings.Split(lang, "-")
return arr[0]
}
// LoginUser 登录用户信息
func LoginUser(c *gin.Context) (vo.LoginUser, error) {
value, exists := c.Get(common.CTX_LOGIN_USER)
if exists {
return value.(vo.LoginUser), nil
}
// 登录用户信息无效
return vo.LoginUser{}, fmt.Errorf("app.common.noLoginUser")
}
// LoginUserToUserID 登录用户信息-用户ID
func LoginUserToUserID(c *gin.Context) string {
value, exists := c.Get(common.CTX_LOGIN_USER)
if exists {
loginUser := value.(vo.LoginUser)
return loginUser.UserID
}
return ""
}
// LoginUserToUserName 登录用户信息-用户名称
func LoginUserToUserName(c *gin.Context) string {
value, exists := c.Get(common.CTX_LOGIN_USER)
if exists {
loginUser := value.(vo.LoginUser)
return loginUser.User.UserName
}
return ""
}
// LoginUserToDataScopeSQL 登录用户信息-角色数据范围过滤SQL字符串
func LoginUserToDataScopeSQL(c *gin.Context, deptAlias string, userAlias string) string {
dataScopeSQL := ""
// 登录用户信息
loginUser, err := LoginUser(c)
if err != nil {
return dataScopeSQL
}
userInfo := loginUser.User
// 如果是管理员,则不过滤数据
if config.IsAdmin(userInfo.UserID) {
return dataScopeSQL
}
// 无用户角色
if len(userInfo.Roles) <= 0 {
return dataScopeSQL
}
// 记录角色权限范围定义添加过, 非自定数据权限不需要重复拼接SQL
var scopeKeys []string
var conditions []string
for _, role := range userInfo.Roles {
dataScope := role.DataScope
if roledatascope.ALL == dataScope {
break
}
if roledatascope.CUSTOM != dataScope {
hasKey := false
for _, key := range scopeKeys {
if key == dataScope {
hasKey = true
break
}
}
if hasKey {
continue
}
}
if roledatascope.CUSTOM == dataScope {
sql := fmt.Sprintf(`%s.dept_id IN ( SELECT dept_id FROM sys_role_dept WHERE role_id = '%s' )`, deptAlias, role.RoleID)
conditions = append(conditions, sql)
}
if roledatascope.DEPT == dataScope {
sql := fmt.Sprintf(`%s.dept_id = '%s'`, deptAlias, userInfo.DeptID)
conditions = append(conditions, sql)
}
if roledatascope.DEPT_AND_CHILD == dataScope {
sql := fmt.Sprintf(`%s.dept_id IN ( SELECT dept_id FROM sys_dept WHERE dept_id = '%s' or find_in_set('%s' , ancestors ) )`, deptAlias, userInfo.DeptID, userInfo.DeptID)
conditions = append(conditions, sql)
}
if roledatascope.SELF == dataScope {
// 数据权限为仅本人且没有userAlias别名不查询任何数据
if userAlias == "" {
sql := fmt.Sprintf(`%s.parent_id = '0'`, deptAlias)
conditions = append(conditions, sql)
} else {
sql := fmt.Sprintf(`%s.user_id = '%s'`, userAlias, userInfo.UserID)
conditions = append(conditions, sql)
}
}
// 记录角色范围
scopeKeys = append(scopeKeys, dataScope)
}
// 构建查询条件语句
if len(conditions) > 0 {
dataScopeSQL = fmt.Sprintf(" AND ( %s ) ", strings.Join(conditions, " OR "))
}
return dataScopeSQL
}

View File

@@ -27,6 +27,12 @@ const (
//
// formatStr 时间格式 默认YYYY-MM-DD HH:mm:ss
func ParseStrToDate(dateStr, formatStr string) time.Time {
if dateStr == "" || dateStr == "<nil>" {
return time.Time{}
}
if formatStr == "" {
formatStr = YYYY_MM_DD_HH_MM_SS
}
t, err := time.Parse(formatStr, dateStr)
if err != nil {
logger.Infof("utils ParseStrToDate err %v", err)

View File

@@ -7,7 +7,7 @@ import (
"path/filepath"
"time"
"be.ems/src/framework/constants/uploadsubpath"
"be.ems/src/framework/constants"
"be.ems/src/framework/logger"
"be.ems/src/framework/utils/date"
@@ -26,7 +26,7 @@ func TransferExeclUploadFile(file *multipart.FileHeader) (string, error) {
// 上传资源路径
_, dir := resourceUpload()
// 新文件名称并组装文件地址
filePath := filepath.Join(uploadsubpath.IMPORT, date.ParseDatePath(time.Now()))
filePath := filepath.Join(constants.UPLOAD_IMPORT, date.ParseDatePath(time.Now()))
fileName := generateFileName(file.Filename)
writePathFile := filepath.Join(dir, filePath, fileName)
// 存入新文件路径
@@ -138,7 +138,7 @@ func WriteSheet(headerCells map[string]string, dataCells []map[string]any, fileN
// 上传资源路径
_, dir := resourceUpload()
filePath := filepath.Join(uploadsubpath.EXPORT, date.ParseDatePath(time.Now()))
filePath := filepath.Join(constants.UPLOAD_EXPORT, date.ParseDatePath(time.Now()))
saveFilePath := filepath.Join(dir, filePath, fileName)
// 创建文件目录

View File

@@ -12,7 +12,7 @@ import (
"time"
"be.ems/src/framework/config"
"be.ems/src/framework/constants/uploadsubpath"
"be.ems/src/framework/constants"
"be.ems/src/framework/logger"
"be.ems/src/framework/utils/date"
"be.ems/src/framework/utils/generate"
@@ -237,7 +237,7 @@ func TransferChunkUploadFile(file *multipart.FileHeader, index, identifier strin
// 上传资源路径
prefix, dir := resourceUpload()
// 新文件名称并组装文件地址
filePath := filepath.Join(uploadsubpath.CHUNK, date.ParseDatePath(time.Now()), identifier)
filePath := filepath.Join(constants.UPLOAD_CHUNK, date.ParseDatePath(time.Now()), identifier)
writePathFile := filepath.Join(dir, filePath, index)
// 存入新文件路径
err = transferToNewFile(file, writePathFile)
@@ -261,7 +261,7 @@ func ChunkCheckFile(identifier, originalFileName string) ([]string, error) {
}
// 上传资源路径
_, dir := resourceUpload()
dirPath := path.Join(uploadsubpath.CHUNK, date.ParseDatePath(time.Now()), identifier)
dirPath := path.Join(constants.UPLOAD_CHUNK, date.ParseDatePath(time.Now()), identifier)
readPath := path.Join(dir, dirPath)
fileList, err := getDirFileNameList(readPath)
if err != nil {
@@ -286,7 +286,7 @@ func ChunkMergeFile(identifier, originalFileName, subPath string) (string, error
// 上传资源路径
prefix, dir := resourceUpload()
// 切片存放目录
dirPath := path.Join(uploadsubpath.CHUNK, date.ParseDatePath(time.Now()), identifier)
dirPath := path.Join(constants.UPLOAD_CHUNK, date.ParseDatePath(time.Now()), identifier)
readPath := path.Join(dir, dirPath)
// 组合存放文件路径
fileName := generateFileName(originalFileName)
@@ -305,7 +305,7 @@ func ChunkMergeFile(identifier, originalFileName, subPath string) (string, error
// filePath 上传得到的文件路径 /upload....
// dst 新文件路径 /a/xx.pdf
func CopyUploadFile(filePath, dst string) error {
srcPath := ParseUploadFilePath(filePath)
srcPath := ParseUploadFileAbsPath(filePath)
src, err := os.Open(srcPath)
if err != nil {
return err
@@ -346,10 +346,10 @@ func ParseUploadFileDir(subPath string) string {
return filepath.Join(dir, filePath)
}
// ParseUploadFilePath 上传资源本地绝对资源路径
// ParseUploadFileAbsPath 上传资源本地绝对资源路径
//
// filePath 上传文件路径
func ParseUploadFilePath(filePath string) string {
func ParseUploadFileAbsPath(filePath string) string {
prefix, dir := resourceUpload()
absPath := strings.Replace(filePath, prefix, dir, 1)
return filepath.ToSlash(absPath)

View File

@@ -1,238 +0,0 @@
package ip2region
import (
"encoding/binary"
"fmt"
"os"
)
const (
HeaderInfoLength = 256
VectorIndexRows = 256
VectorIndexCols = 256
VectorIndexSize = 8
SegmentIndexBlockSize = 14
)
// --- Index policy define
type IndexPolicy int
const (
VectorIndexPolicy IndexPolicy = 1
BTreeIndexPolicy IndexPolicy = 2
)
func (i IndexPolicy) String() string {
switch i {
case VectorIndexPolicy:
return "VectorIndex"
case BTreeIndexPolicy:
return "BtreeIndex"
default:
return "unknown"
}
}
// --- Header define
type Header struct {
// data []byte
Version uint16
IndexPolicy IndexPolicy
CreatedAt uint32
StartIndexPtr uint32
EndIndexPtr uint32
}
func NewHeader(input []byte) (*Header, error) {
if len(input) < 16 {
return nil, fmt.Errorf("invalid input buffer")
}
return &Header{
Version: binary.LittleEndian.Uint16(input),
IndexPolicy: IndexPolicy(binary.LittleEndian.Uint16(input[2:])),
CreatedAt: binary.LittleEndian.Uint32(input[4:]),
StartIndexPtr: binary.LittleEndian.Uint32(input[8:]),
EndIndexPtr: binary.LittleEndian.Uint32(input[12:]),
}, nil
}
// --- searcher implementation
type Searcher struct {
handle *os.File
ioCount int
// use it only when this feature enabled.
// Preload the vector index will reduce the number of IO operations
// thus speedup the search process
vectorIndex []byte
// content buffer.
// running with the whole xdb file cached
contentBuff []byte
}
func baseNew(dbFile string, vIndex []byte, cBuff []byte) (*Searcher, error) {
var err error
// content buff first
if cBuff != nil {
return &Searcher{
vectorIndex: nil,
contentBuff: cBuff,
}, nil
}
// open the xdb binary file
handle, err := os.OpenFile(dbFile, os.O_RDONLY, 0600)
if err != nil {
return nil, err
}
return &Searcher{
handle: handle,
vectorIndex: vIndex,
}, nil
}
func NewWithFileOnly(dbFile string) (*Searcher, error) {
return baseNew(dbFile, nil, nil)
}
func NewWithVectorIndex(dbFile string, vIndex []byte) (*Searcher, error) {
return baseNew(dbFile, vIndex, nil)
}
func NewWithBuffer(cBuff []byte) (*Searcher, error) {
return baseNew("", nil, cBuff)
}
func (s *Searcher) Close() {
if s.handle != nil {
err := s.handle.Close()
if err != nil {
return
}
}
}
// GetIOCount return the global io count for the last search
func (s *Searcher) GetIOCount() int {
return s.ioCount
}
// SearchByStr find the region for the specified ip string
func (s *Searcher) SearchByStr(str string) (string, error) {
ip, err := CheckIP(str)
if err != nil {
return "", err
}
return s.Search(ip)
}
// Search find the region for the specified long ip
func (s *Searcher) Search(ip uint32) (string, error) {
// reset the global ioCount
s.ioCount = 0
// locate the segment index block based on the vector index
var il0 = (ip >> 24) & 0xFF
var il1 = (ip >> 16) & 0xFF
var idx = il0*VectorIndexCols*VectorIndexSize + il1*VectorIndexSize
var sPtr, ePtr = uint32(0), uint32(0)
if s.vectorIndex != nil {
sPtr = binary.LittleEndian.Uint32(s.vectorIndex[idx:])
ePtr = binary.LittleEndian.Uint32(s.vectorIndex[idx+4:])
} else if s.contentBuff != nil {
sPtr = binary.LittleEndian.Uint32(s.contentBuff[HeaderInfoLength+idx:])
ePtr = binary.LittleEndian.Uint32(s.contentBuff[HeaderInfoLength+idx+4:])
} else {
// read the vector index block
var buff = make([]byte, VectorIndexSize)
err := s.read(int64(HeaderInfoLength+idx), buff)
if err != nil {
return "", fmt.Errorf("read vector index block at %d: %w", HeaderInfoLength+idx, err)
}
sPtr = binary.LittleEndian.Uint32(buff)
ePtr = binary.LittleEndian.Uint32(buff[4:])
}
// fmt.Printf("sPtr=%d, ePtr=%d", sPtr, ePtr)
// binary search the segment index to get the region
var dataLen, dataPtr = 0, uint32(0)
var buff = make([]byte, SegmentIndexBlockSize)
var l, h = 0, int((ePtr - sPtr) / SegmentIndexBlockSize)
for l <= h {
m := (l + h) >> 1
p := sPtr + uint32(m*SegmentIndexBlockSize)
err := s.read(int64(p), buff)
if err != nil {
return "", fmt.Errorf("read segment index at %d: %w", p, err)
}
// decode the data step by step to reduce the unnecessary operations
sip := binary.LittleEndian.Uint32(buff)
if ip < sip {
h = m - 1
} else {
eip := binary.LittleEndian.Uint32(buff[4:])
if ip > eip {
l = m + 1
} else {
dataLen = int(binary.LittleEndian.Uint16(buff[8:]))
dataPtr = binary.LittleEndian.Uint32(buff[10:])
break
}
}
}
//fmt.Printf("dataLen: %d, dataPtr: %d", dataLen, dataPtr)
if dataLen == 0 {
return "", nil
}
// load and return the region data
var regionBuff = make([]byte, dataLen)
err := s.read(int64(dataPtr), regionBuff)
if err != nil {
return "", fmt.Errorf("read region at %d: %w", dataPtr, err)
}
return string(regionBuff), nil
}
// do the data read operation based on the setting.
// content buffer first or will read from the file.
// this operation will invoke the Seek for file based read.
func (s *Searcher) read(offset int64, buff []byte) error {
if s.contentBuff != nil {
cLen := copy(buff, s.contentBuff[offset:])
if cLen != len(buff) {
return fmt.Errorf("incomplete read: readed bytes should be %d", len(buff))
}
} else {
_, err := s.handle.Seek(offset, 0)
if err != nil {
return fmt.Errorf("seek to %d: %w", offset, err)
}
s.ioCount++
rLen, err := s.handle.Read(buff)
if err != nil {
return fmt.Errorf("handle read: %w", err)
}
if rLen != len(buff) {
return fmt.Errorf("incomplete read: readed bytes should be %d", len(buff))
}
}
return nil
}

View File

@@ -1,93 +0,0 @@
package ip2region
import (
"embed"
"strings"
"time"
"be.ems/src/framework/logger"
)
// 网络地址(内网)
const LOCAT_HOST = "127.0.0.1"
// 全局查询对象
var searcher *Searcher
//go:embed ip2region.xdb
var ip2regionDB embed.FS
func init() {
// 从 dbPath 加载整个 xdb 到内存
buf, err := ip2regionDB.ReadFile("ip2region.xdb")
if err != nil {
logger.Fatalf("failed error load xdb from : %s\n", err)
return
}
// 用全局的 cBuff 创建完全基于内存的查询对象。
base, err := NewWithBuffer(buf)
if err != nil {
logger.Errorf("failed error create searcher with content: %s\n", err)
return
}
// 赋值到全局查询对象
searcher = base
}
// RegionSearchByIp 查询IP所在地
//
// 国家|区域|省份|城市|ISP
func RegionSearchByIp(ip string) (string, int, int64) {
ip = ClientIP(ip)
if ip == LOCAT_HOST {
// "0|0|0|内网IP|内网IP"
return "0|0|0|app.common.noIPregion|app.common.noIPregion", 0, 0
}
tStart := time.Now()
region, err := searcher.SearchByStr(ip)
if err != nil {
logger.Errorf("failed to SearchIP(%s): %s\n", ip, err)
return "0|0|0|0|0", 0, 0
}
return region, 0, time.Since(tStart).Milliseconds()
}
// RealAddressByIp 地址IP所在地
//
// 218.4.167.70 江苏省 苏州市
func RealAddressByIp(ip string) string {
ip = ClientIP(ip)
if ip == LOCAT_HOST {
return "app.common.noIPregion" // 内网IP
}
region, err := searcher.SearchByStr(ip)
if err != nil {
logger.Errorf("failed to SearchIP(%s): %s\n", ip, err)
return "app.common.unknown" // 未知
}
parts := strings.Split(region, "|")
province := parts[2]
city := parts[3]
if province == "0" && city != "0" {
if city == "内网IP" {
return "app.common.noIPregion" // 内网IP
}
return city
}
return province + " " + city
}
// ClientIP 处理客户端IP地址显示iPv4
//
// 转换 ip2region.ClientIP(c.ClientIP())
func ClientIP(ip string) string {
if strings.HasPrefix(ip, "::ffff:") {
ip = strings.Replace(ip, "::ffff:", "", 1)
}
if ip == LOCAT_HOST || ip == "::1" {
return LOCAT_HOST
}
return ip
}

View File

@@ -1,175 +0,0 @@
package ip2region
import (
"fmt"
"os"
"strconv"
"strings"
)
var shiftIndex = []int{24, 16, 8, 0}
func CheckIP(ip string) (uint32, error) {
var ps = strings.Split(strings.TrimSpace(ip), ".")
if len(ps) != 4 {
return 0, fmt.Errorf("invalid ip address `%s`", ip)
}
var val = uint32(0)
for i, s := range ps {
d, err := strconv.Atoi(s)
if err != nil {
return 0, fmt.Errorf("the %dth part `%s` is not an integer", i, s)
}
if d < 0 || d > 255 {
return 0, fmt.Errorf("the %dth part `%s` should be an integer bettween 0 and 255", i, s)
}
val |= uint32(d) << shiftIndex[i]
}
// convert the ip to integer
return val, nil
}
func Long2IP(ip uint32) string {
return fmt.Sprintf("%d.%d.%d.%d", (ip>>24)&0xFF, (ip>>16)&0xFF, (ip>>8)&0xFF, ip&0xFF)
}
func MidIP(sip uint32, eip uint32) uint32 {
return uint32((uint64(sip) + uint64(eip)) >> 1)
}
// LoadHeader load the header info from the specified handle
func LoadHeader(handle *os.File) (*Header, error) {
_, err := handle.Seek(0, 0)
if err != nil {
return nil, fmt.Errorf("seek to the header: %w", err)
}
var buff = make([]byte, HeaderInfoLength)
rLen, err := handle.Read(buff)
if err != nil {
return nil, err
}
if rLen != len(buff) {
return nil, fmt.Errorf("incomplete read: readed bytes should be %d", len(buff))
}
return NewHeader(buff)
}
// LoadHeaderFromFile load header info from the specified db file path
func LoadHeaderFromFile(dbFile string) (*Header, error) {
handle, err := os.OpenFile(dbFile, os.O_RDONLY, 0600)
if err != nil {
return nil, fmt.Errorf("open xdb file `%s`: %w", dbFile, err)
}
defer func(handle *os.File) {
_ = handle.Close()
}(handle)
header, err := LoadHeader(handle)
if err != nil {
return nil, err
}
return header, nil
}
// LoadHeaderFromBuff wrap the header info from the content buffer
func LoadHeaderFromBuff(cBuff []byte) (*Header, error) {
return NewHeader(cBuff[0:256])
}
// LoadVectorIndex util function to load the vector index from the specified file handle
func LoadVectorIndex(handle *os.File) ([]byte, error) {
// load all the vector index block
_, err := handle.Seek(HeaderInfoLength, 0)
if err != nil {
return nil, fmt.Errorf("seek to vector index: %w", err)
}
var buff = make([]byte, VectorIndexRows*VectorIndexCols*VectorIndexSize)
rLen, err := handle.Read(buff)
if err != nil {
return nil, err
}
if rLen != len(buff) {
return nil, fmt.Errorf("incomplete read: readed bytes should be %d", len(buff))
}
return buff, nil
}
// LoadVectorIndexFromFile load vector index from a specified file path
func LoadVectorIndexFromFile(dbFile string) ([]byte, error) {
handle, err := os.OpenFile(dbFile, os.O_RDONLY, 0600)
if err != nil {
return nil, fmt.Errorf("open xdb file `%s`: %w", dbFile, err)
}
defer func() {
_ = handle.Close()
}()
vIndex, err := LoadVectorIndex(handle)
if err != nil {
return nil, err
}
return vIndex, nil
}
// LoadContent load the whole xdb content from the specified file handle
func LoadContent(handle *os.File) ([]byte, error) {
// get file size
fi, err := handle.Stat()
if err != nil {
return nil, fmt.Errorf("stat: %w", err)
}
size := fi.Size()
// seek to the head of the file
_, err = handle.Seek(0, 0)
if err != nil {
return nil, fmt.Errorf("seek to get xdb file length: %w", err)
}
var buff = make([]byte, size)
rLen, err := handle.Read(buff)
if err != nil {
return nil, err
}
if rLen != len(buff) {
return nil, fmt.Errorf("incomplete read: readed bytes should be %d", len(buff))
}
return buff, nil
}
// LoadContentFromFile load the whole xdb content from the specified db file path
func LoadContentFromFile(dbFile string) ([]byte, error) {
handle, err := os.OpenFile(dbFile, os.O_RDONLY, 0600)
if err != nil {
return nil, fmt.Errorf("open xdb file `%s`: %w", dbFile, err)
}
defer func() {
_ = handle.Close()
}()
cBuff, err := LoadContent(handle)
if err != nil {
return nil, err
}
return cBuff, nil
}

View File

@@ -9,7 +9,7 @@ import (
"time"
"be.ems/src/framework/config"
"be.ems/src/framework/constants/common"
"be.ems/src/framework/constants"
"be.ems/src/framework/logger"
"be.ems/src/framework/utils/cmd"
"be.ems/src/framework/utils/crypto"
@@ -106,8 +106,8 @@ func Launch() {
"code": Code, // 机器码
"useTime": time.Now().UnixMilli(), // 首次使用时间
common.LAUNCH_BOOTLOADER: true, // 启动引导
common.LAUNCH_BOOTLOADER + "Time": 0, // 引导完成时间
constants.LAUNCH_BOOTLOADER: true, // 启动引导
constants.LAUNCH_BOOTLOADER + "Time": 0, // 引导完成时间
}
codeFileWrite(LaunchInfo)
} else {
@@ -151,8 +151,8 @@ func SetLaunchInfo(info map[string]any) error {
// Bootloader 启动引导标记
func Bootloader(flag bool) error {
return SetLaunchInfo(map[string]any{
common.LAUNCH_BOOTLOADER: flag, // 启动引导 true开 false关
common.LAUNCH_BOOTLOADER + "Time": time.Now().UnixMilli(), // 引导完成时间
constants.LAUNCH_BOOTLOADER: flag, // 启动引导 true开 false关
constants.LAUNCH_BOOTLOADER + "Time": time.Now().UnixMilli(), // 引导完成时间
})
}

View File

@@ -91,31 +91,16 @@ func ConvertToCamelCase(str string) string {
return strings.Join(words, "")
}
// Bit 比特位为单位
// Bit 比特位为单位 1023.00 B --> 1.00 KB
func Bit(bit float64) string {
var GB, MB, KB string
if bit > float64(1<<30) {
GB = fmt.Sprintf("%0.2f", bit/(1<<30))
}
if bit > float64(1<<20) && bit < (1<<30) {
MB = fmt.Sprintf("%.2f", bit/(1<<20))
}
if bit > float64(1<<10) && bit < (1<<20) {
KB = fmt.Sprintf("%.2f", bit/(1<<10))
}
if GB != "" {
return GB + "GB"
} else if MB != "" {
return MB + "MB"
} else if KB != "" {
return KB + "KB"
} else {
return fmt.Sprintf("%vB", bit)
units := []string{"B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"}
for i := 0; i < len(units); i++ {
if bit < 1024 || i == len(units)-1 {
return fmt.Sprintf("%.2f %s", bit, units[i])
}
bit /= 1024
}
return ""
}
// CronExpression 解析 Cron 表达式,返回下一次执行的时间戳(毫秒)
@@ -146,11 +131,11 @@ func SafeContent(value string) string {
}
// RemoveDuplicates 数组内字符串去重
func RemoveDuplicates(ids []string) []string {
func RemoveDuplicates(arr []string) []string {
uniqueIDs := make(map[string]bool)
uniqueIDSlice := make([]string, 0)
for _, id := range ids {
for _, id := range arr {
_, ok := uniqueIDs[id]
if !ok && id != "" {
uniqueIDs[id] = true
@@ -161,6 +146,29 @@ func RemoveDuplicates(ids []string) []string {
return uniqueIDSlice
}
// RemoveDuplicatesToArray 数组内字符串分隔去重转为字符数组
func RemoveDuplicatesToArray(keyStr, sep string) []string {
arr := make([]string, 0)
if keyStr == "" {
return arr
}
if strings.Contains(keyStr, sep) {
// 处理字符转数组后去重
strArr := strings.Split(keyStr, sep)
uniqueKeys := make(map[string]bool)
for _, str := range strArr {
_, ok := uniqueKeys[str]
if !ok && str != "" {
uniqueKeys[str] = true
arr = append(arr, str)
}
}
} else {
arr = append(arr, keyStr)
}
return arr
}
// Color 解析颜色 #fafafa
func Color(colorStr string) *color.RGBA {
// 去除 # 号

View File

@@ -1,139 +0,0 @@
package repo
import (
"fmt"
"reflect"
"strconv"
"strings"
"time"
"be.ems/src/framework/utils/parse"
)
// PageNumSize 分页页码记录数
func PageNumSize(pageNum, pageSize any) (int64, int64) {
// 记录起始索引
num := parse.Number(pageNum)
if num < 1 {
num = 1
}
// 显示记录数
size := parse.Number(pageSize)
if size < 1 {
size = 10
}
return num - 1, size
}
// SetFieldValue 判断结构体内是否存在指定字段并设置值
func SetFieldValue(obj any, fieldName string, value any) {
// 获取结构体的反射值
userValue := reflect.ValueOf(obj)
// 获取字段的反射值
fieldValue := userValue.Elem().FieldByName(fieldName)
// 检查字段是否存在
if fieldValue.IsValid() && fieldValue.CanSet() {
// 获取字段的类型
fieldType := fieldValue.Type()
// 转换传入的值类型为字段类型
switch fieldType.Kind() {
case reflect.String:
if value == nil {
fieldValue.SetString("")
} else {
fieldValue.SetString(fmt.Sprintf("%v", value))
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
intValue, err := strconv.ParseInt(fmt.Sprintf("%v", value), 10, 64)
if err != nil {
intValue = 0
}
fieldValue.SetInt(intValue)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
uintValue, err := strconv.ParseUint(fmt.Sprintf("%v", value), 10, 64)
if err != nil {
uintValue = 0
}
fieldValue.SetUint(uintValue)
case reflect.Float32, reflect.Float64:
floatValue, err := strconv.ParseFloat(fmt.Sprintf("%v", value), 64)
if err != nil {
floatValue = 0
}
fieldValue.SetFloat(floatValue)
case reflect.Struct:
fmt.Printf("%s time resolution %s %v \n", fieldName, fieldValue.Type(), value)
if fieldValue.Type() == reflect.TypeOf(time.Time{}) && value != nil {
// 解析 value 并转换为 time.Time 类型
parsedTime, err := time.Parse("2006-01-02 15:04:05 +0800 CST", fmt.Sprintf("%v", value))
if err != nil {
fmt.Println("Time resolution error:", err)
} else {
// 设置字段的值
fieldValue.Set(reflect.ValueOf(parsedTime))
}
}
default:
// 设置字段的值
fieldValue.Set(reflect.ValueOf(value).Convert(fieldValue.Type()))
}
}
}
// ConvertIdsSlice 将 []string 转换为 []any
func ConvertIdsSlice(ids []string) []any {
// 将 []string 转换为 []any
arr := make([]any, len(ids))
for i, v := range ids {
arr[i] = v
}
return arr
}
// 查询-参数值的占位符
func KeyPlaceholderByQuery(sum int) string {
placeholders := make([]string, sum)
for i := 0; i < sum; i++ {
placeholders[i] = "?"
}
return strings.Join(placeholders, ",")
}
// 插入-参数映射键值占位符 keys, placeholder, values
func KeyPlaceholderValueByInsert(params map[string]any) ([]string, string, []any) {
// 参数映射的键
keys := make([]string, len(params))
// 参数映射的值
values := make([]any, len(params))
sum := 0
for k, v := range params {
keys[sum] = k
values[sum] = v
sum++
}
// 参数值的占位符
placeholders := make([]string, sum)
for i := 0; i < sum; i++ {
placeholders[i] = "?"
}
return keys, strings.Join(placeholders, ","), values
}
// 更新-参数映射键值占位符 keys, values
func KeyValueByUpdate(params map[string]any) ([]string, []any) {
// 参数映射的键
keys := make([]string, len(params))
// 参数映射的值
values := make([]any, len(params))
sum := 0
for k, v := range params {
keys[sum] = k + "=?"
values[sum] = v
sum++
}
return keys, values
}

View File

@@ -1,103 +0,0 @@
package ssh
import (
"fmt"
"strings"
"be.ems/src/framework/logger"
"be.ems/src/framework/utils/cmd"
"be.ems/src/framework/utils/parse"
)
// FileListRow 文件列表行数据
type FileListRow struct {
FileType string `json:"fileType"` // 文件类型 dir, file, symlink
FileMode string `json:"fileMode"` // 文件的权限
LinkCount int64 `json:"linkCount"` // 硬链接数目
Owner string `json:"owner"` // 所属用户
Group string `json:"group"` // 所属组
Size string `json:"size"` // 文件的大小
ModifiedTime int64 `json:"modifiedTime"` // 最后修改时间,单位为秒
FileName string `json:"fileName"` // 文件的名称
}
// 文件列表
// search 文件名后模糊*
//
// return 行记录,异常
func FileList(sshClient *ConnSSH, path, search string) ([]FileListRow, error) {
var rows []FileListRow
rowStr := ""
// 发送命令
searchStr := "*"
if search != "" {
searchStr = search + searchStr
}
// cd /var/log && find. -maxdepth 1 -name'mme*' -exec ls -lthd --time-style=+%s {} +
cmdStr := fmt.Sprintf("cd %s && find . -maxdepth 1 -name '%s' -exec ls -lthd --time-style=+%%s {} +", path, searchStr)
// cd /var/log && ls -lthd --time-style=+%s mme*
// cmdStr := fmt.Sprintf("cd %s && ls -lthd --time-style=+%%s %s", path, searchStr)
// 是否远程客户端读取
if sshClient == nil {
resultStr, err := cmd.Exec(cmdStr)
if err != nil {
logger.Errorf("Ne FileList Path: %s, Search: %s, Error:%s", path, search, err.Error())
return rows, err
}
rowStr = resultStr
} else {
resultStr, err := sshClient.RunCMD(cmdStr)
if err != nil {
logger.Errorf("Ne FileList Path: %s, Search: %s, Error:%s", path, search, err.Error())
return rows, err
}
rowStr = resultStr
}
// 遍历组装
rowStrList := strings.Split(rowStr, "\n")
for _, rowStr := range rowStrList {
if rowStr == "" {
continue
}
// 使用空格对字符串进行切割
fields := strings.Fields(rowStr)
// 拆分不足7位跳过
if len(fields) != 7 {
continue
}
// 文件类型
fileMode := fields[0]
fileType := "file"
if fileMode[0] == 'd' {
fileType = "dir"
} else if fileMode[0] == 'l' {
fileType = "symlink"
}
// 文件名
fileName := fields[6]
if fileName == "." {
continue
} else if strings.HasPrefix(fileName, "./") {
fileName = strings.TrimPrefix(fileName, "./")
}
// 提取各个字段的值
rows = append(rows, FileListRow{
FileMode: fileMode,
FileType: fileType,
LinkCount: parse.Number(fields[1]),
Owner: fields[2],
Group: fields[3],
Size: fields[4],
ModifiedTime: parse.Number(fields[5]),
FileName: fileName,
})
}
return rows, nil
}

View File

@@ -1,157 +0,0 @@
package ssh
import (
"io"
"os"
"path/filepath"
"be.ems/src/framework/logger"
gosftp "github.com/pkg/sftp"
)
// SSHClientSFTP SSH客户端SFTP对象
type SSHClientSFTP struct {
Client *gosftp.Client
}
// Close 关闭会话
func (s *SSHClientSFTP) Close() {
if s.Client != nil {
s.Client.Close()
}
}
// CopyDirRemoteToLocal 复制目录-远程到本地
func (s *SSHClientSFTP) CopyDirRemoteToLocal(remoteDir, localDir string) error {
// 创建本地目录
err := os.MkdirAll(localDir, 0775)
if err != nil {
logger.Errorf("CopyDirRemoteToLocal failed to creating local directory %s: => %s", localDir, err.Error())
return err
}
// 列出远程目录中的文件和子目录
remoteFiles, err := s.Client.ReadDir(remoteDir)
if err != nil {
logger.Errorf("CopyDirRemoteToLocal failed to reading remote directory %s: => %s", remoteDir, err.Error())
return err
}
// 遍历远程文件和子目录并复制到本地
for _, remoteFile := range remoteFiles {
remotePath := filepath.ToSlash(filepath.Join(remoteDir, remoteFile.Name()))
localPath := filepath.ToSlash(filepath.Join(localDir, remoteFile.Name()))
if remoteFile.IsDir() {
// 如果是子目录,则递归复制子目录
err = s.CopyDirRemoteToLocal(remotePath, localPath)
if err != nil {
logger.Errorf("CopyDirRemoteToLocal failed to copying remote directory %s: => %s", remotePath, err.Error())
continue
}
} else {
// 如果是文件,则复制文件内容
if err := s.CopyFileRemoteToLocal(remotePath, localPath); err != nil {
logger.Errorf("CopyDirRemoteToLocal failed to opening remote file %s: => %s", remotePath, err.Error())
continue
}
}
}
return nil
}
// CopyDirLocalToRemote 复制目录-本地到远程
func (s *SSHClientSFTP) CopyDirLocalToRemote(localDir, remoteDir string) error {
// 遍历本地目录中的文件和子目录并复制到远程
err := filepath.Walk(localDir, func(localPath string, info os.FileInfo, err error) error {
if err != nil {
return err
}
// 生成远程路径
remotePath := filepath.ToSlash(filepath.Join(remoteDir, localPath[len(localDir):]))
if info.IsDir() {
// 如果是子目录,则创建远程目录
if err := s.Client.MkdirAll(remotePath); err != nil {
logger.Errorf("CopyDirLocalToRemote failed to creating remote directory %s: => %s", remotePath, err.Error())
return err
}
} else {
// 如果是文件,则复制文件内容
if err := s.CopyFileLocalToRemote(localPath, remotePath); err != nil {
logger.Errorf("CopyDirLocalToRemote failed to copying remote file %s: => %s", localPath, err.Error())
return err
}
}
return nil
})
if err != nil {
logger.Errorf("CopyDirLocalToRemote failed to walking remote directory: => %s", err.Error())
return err
}
return nil
}
// CopyFileRemoteToLocal 复制文件-远程到本地
func (s *SSHClientSFTP) CopyFileRemoteToLocal(remotePath, localPath string) error {
if err := os.MkdirAll(filepath.Dir(localPath), 0775); err != nil {
return err
}
// 打开远程文件
remoteFile, err := s.Client.Open(remotePath)
if err != nil {
logger.Errorf("CopyFileRemoteToLocal failed to opening remote file: => %s", err.Error())
return err
}
defer remoteFile.Close()
// 创建本地文件
localFile, err := os.Create(localPath)
if err != nil {
logger.Errorf("CopyFileRemoteToLocal failed to creating local file: => %s", err.Error())
return err
}
defer localFile.Close()
// 复制文件内容
if _, err = io.Copy(localFile, remoteFile); err != nil {
logger.Errorf("CopyFileRemoteToLocal failed to copying contents: => %s", err.Error())
return err
}
return nil
}
// CopyFileLocalToRemote 复制文件-本地到远程
func (s *SSHClientSFTP) CopyFileLocalToRemote(localPath, remotePath string) error {
// 打开本地文件
localFile, err := os.Open(localPath)
if err != nil {
logger.Errorf("CopyFileLocalToRemote failed to opening local file: => %s", err.Error())
return err
}
defer localFile.Close()
// 创建远程目录
if err := s.Client.MkdirAll(filepath.Dir(remotePath)); err != nil {
logger.Errorf("CopyFileLocalToRemote failed to creating remote directory %s: => %s", remotePath, err.Error())
return err
}
// 创建远程文件
remoteFile, err := s.Client.Create(remotePath)
if err != nil {
logger.Errorf("CopyFileLocalToRemote failed to creating remote file: => %s", err.Error())
return err
}
defer remoteFile.Close()
// 复制文件内容
if _, err = io.Copy(remoteFile, localFile); err != nil {
logger.Errorf("CopyFileLocalToRemote failed to copying contents: => %s", err.Error())
return err
}
return nil
}

View File

@@ -1,229 +0,0 @@
package ssh
import (
"fmt"
"os"
"os/user"
"strings"
"time"
"be.ems/src/framework/logger"
"be.ems/src/framework/utils/cmd"
gosftp "github.com/pkg/sftp"
gossh "golang.org/x/crypto/ssh"
)
// ConnSSH 连接SSH对象
type ConnSSH struct {
User string `json:"user"` // 主机用户名
Addr string `json:"addr"` // 主机地址
Port int64 `json:"port"` // SSH端口
AuthMode string `json:"authMode"` // 认证模式0密码 1主机私钥
Password string `json:"password"` // 认证密码
PrivateKey string `json:"privateKey"` // 认证私钥
PassPhrase string `json:"passPhrase"` // 认证私钥密码
DialTimeOut time.Duration `json:"dialTimeOut"` // 连接超时断开
Client *gossh.Client `json:"client"`
LastResult string `json:"lastResult"` // 记最后一次执行命令的结果
}
// NewClient 创建SSH客户端
func (c *ConnSSH) NewClient() (*ConnSSH, error) {
// IPV6地址协议
proto := "tcp"
if strings.Contains(c.Addr, ":") {
proto = "tcp6"
c.Addr = fmt.Sprintf("[%s]", c.Addr)
}
addr := fmt.Sprintf("%s:%d", c.Addr, c.Port)
// ssh客户端配置
config := &gossh.ClientConfig{}
config.SetDefaults()
config.HostKeyCallback = gossh.InsecureIgnoreHostKey()
config.User = c.User
// 默认等待5s
if c.DialTimeOut == 0 {
c.DialTimeOut = 5 * time.Second
}
config.Timeout = c.DialTimeOut
// 认证模式-0密码 1私钥
if c.AuthMode == "1" {
var signer gossh.Signer
var err error
if len(c.PassPhrase) != 0 {
signer, err = gossh.ParsePrivateKeyWithPassphrase([]byte(c.PrivateKey), []byte(c.PassPhrase))
} else {
signer, err = gossh.ParsePrivateKey([]byte(c.PrivateKey))
}
if err != nil {
logger.Errorf("NewClient parse private key => %s", err.Error())
return nil, err
}
config.Auth = []gossh.AuthMethod{gossh.PublicKeys(signer)}
} else {
config.Auth = []gossh.AuthMethod{gossh.Password(c.Password)}
}
client, err := gossh.Dial(proto, addr, config)
if nil != err {
logger.Errorf("NewClient dial => %s", err.Error())
return c, err
}
c.Client = client
return c, nil
}
// Close 关闭当前SSH客户端
func (c *ConnSSH) Close() {
if c.Client != nil {
c.Client.Close()
}
}
// RunCMD 执行单次命令
func (c *ConnSSH) RunCMD(cmd string) (string, error) {
if c.Client == nil {
if _, err := c.NewClient(); err != nil {
return "", err
}
}
session, err := c.Client.NewSession()
if err != nil {
logger.Errorf("RunCMD failed to create session: => %s", err.Error())
return "", err
}
defer session.Close()
buf, err := session.CombinedOutput(cmd)
if err != nil {
logger.Infof("RunCMD failed run command: => %s", cmd)
logger.Errorf("RunCMD failed run command: => %s", err.Error())
}
c.LastResult = string(buf)
return c.LastResult, err
}
// NewClientSession 创建SSH客户端会话对象
func (c *ConnSSH) NewClientSession(cols, rows int) (*SSHClientSession, error) {
sshSession, err := c.Client.NewSession()
if err != nil {
return nil, err
}
stdin, err := sshSession.StdinPipe()
if err != nil {
return nil, err
}
comboWriter := new(singleWriter)
sshSession.Stdout = comboWriter
sshSession.Stderr = comboWriter
modes := gossh.TerminalModes{
gossh.ECHO: 1,
gossh.TTY_OP_ISPEED: 14400,
gossh.TTY_OP_OSPEED: 14400,
}
if err := sshSession.RequestPty("xterm", rows, cols, modes); err != nil {
return nil, err
}
if err := sshSession.Shell(); err != nil {
return nil, err
}
return &SSHClientSession{
Stdin: stdin,
Stdout: comboWriter,
Session: sshSession,
}, nil
}
// NewClientSFTP 创建SSH客户端SFTP对象
func (c *ConnSSH) NewClientSFTP() (*SSHClientSFTP, error) {
sftpClient, err := gosftp.NewClient(c.Client)
if err != nil {
logger.Errorf("NewClientSFTP failed to create sftp: => %s", err.Error())
return nil, err
}
return &SSHClientSFTP{
Client: sftpClient,
}, nil
}
// NewClientByLocalPrivate 创建SSH客户端-本地私钥(~/.ssh/id_rsa)直连
//
// ssh.ConnSSH{
// User: "user",
// Addr: "192.168.x.x",
// Port: body.Port,
// }
func (c *ConnSSH) NewClientByLocalPrivate() (*ConnSSH, error) {
c.AuthMode = "1"
privateKey, err := c.CurrentUserRsaKey(false)
if err != nil {
return nil, err
}
c.PrivateKey = privateKey
return c.NewClient()
}
// CurrentUserRsaKey 当前用户OMC使用的RSA私钥
// 默认读取私钥
// ssh-keygen -t rsa -P "" -f ~/.ssh/id_rsa
// ssh-keygen -y -f ~/.ssh/id_rsa > ~/.ssh/id_rsa.pub
func (c *ConnSSH) CurrentUserRsaKey(publicKey bool) (string, error) {
usr, err := user.Current()
if err != nil {
logger.Errorf("CurrentUserRsaKey get => %s", err.Error())
return "", err
}
// 是否存在私钥并创建
keyPath := fmt.Sprintf("%s/.ssh/id_rsa", usr.HomeDir)
if _, err := os.Stat(keyPath); err != nil {
if _, err := cmd.ExecWithCheck("ssh-keygen", "-t", "rsa", "-P", "", "-f", keyPath); err != nil {
logger.Errorf("CurrentUserPrivateKey ssh-keygen [%s] rsa => %s", usr.Username, err.Error())
}
}
// 读取用户默认的文件
if publicKey {
keyPath = keyPath + ".pub"
}
key, err := os.ReadFile(keyPath)
if err != nil {
logger.Errorf("CurrentUserRsaKey [%s] read => %s", usr.Username, err.Error())
return "", fmt.Errorf("read file %s fail", keyPath)
}
return string(key), nil
}
// SendToAuthorizedKeys 发送当前用户私钥到远程服务器进行授权密钥
func (c *ConnSSH) SendToAuthorizedKeys() error {
publicKey, err := c.CurrentUserRsaKey(true)
if err != nil {
return err
}
// "sudo mkdir -p ~/.ssh && sudo chown omcuser:omcuser ~/.ssh && sudo chmod 700 ~/.ssh"
// "sudo touch ~/.ssh/authorized_keys && sudo chown omcuser:omcuser ~/.ssh/authorized_keys && sudo chmod 600 ~/.ssh/authorized_keys"
// "echo 'ssh-rsa AAAAB3= pc-host\n' | sudo tee -a ~/.ssh/authorized_keys"
authorizedKeysEntry := fmt.Sprintln(strings.TrimSpace(publicKey))
cmdStrArr := []string{
fmt.Sprintf("sudo mkdir -p ~/.ssh && sudo chown %s:%s ~/.ssh && sudo chmod 700 ~/.ssh", c.User, c.User),
fmt.Sprintf("sudo touch ~/.ssh/authorized_keys && sudo chown %s:%s ~/.ssh/authorized_keys && sudo chmod 600 ~/.ssh/authorized_keys", c.User, c.User),
fmt.Sprintf("echo '%s' | sudo tee -a ~/.ssh/authorized_keys", authorizedKeysEntry),
}
_, err = c.RunCMD(strings.Join(cmdStrArr, " && "))
if err != nil {
logger.Errorf("SendAuthorizedKeys echo err %s", err.Error())
return err
}
return nil
}

View File

@@ -1,73 +0,0 @@
package ssh
import (
"bytes"
"fmt"
"io"
"sync"
gossh "golang.org/x/crypto/ssh"
)
// SSHClientSession SSH客户端会话对象
type SSHClientSession struct {
Stdin io.WriteCloser
Stdout *singleWriter
Session *gossh.Session
}
// Close 关闭会话
func (s *SSHClientSession) Close() {
if s.Stdin != nil {
s.Stdin.Close()
}
if s.Stdout != nil {
s.Stdout = nil
}
if s.Session != nil {
s.Session.Close()
}
}
// Write 写入命令 回车(\n)才会执行
func (s *SSHClientSession) Write(cmd string) (int, error) {
if s.Stdin == nil {
return 0, fmt.Errorf("ssh client session is nil to content write failed")
}
return s.Stdin.Write([]byte(cmd))
}
// Read 读取结果
func (s *SSHClientSession) Read() []byte {
if s.Stdout == nil {
return []byte{}
}
bs := s.Stdout.Bytes()
if len(bs) > 0 {
s.Stdout.Reset()
return bs
}
return []byte{}
}
// singleWriter SSH客户端会话消息
type singleWriter struct {
b bytes.Buffer
mu sync.Mutex
}
func (w *singleWriter) Write(p []byte) (int, error) {
w.mu.Lock()
defer w.mu.Unlock()
return w.b.Write(p)
}
func (w *singleWriter) Bytes() []byte {
w.mu.Lock()
defer w.mu.Unlock()
return w.b.Bytes()
}
func (w *singleWriter) Reset() {
w.mu.Lock()
defer w.mu.Unlock()
w.b.Reset()
}

View File

@@ -1,157 +0,0 @@
package token
import (
"encoding/json"
"fmt"
"time"
"be.ems/src/framework/config"
cachekeyConstants "be.ems/src/framework/constants/cachekey"
tokenConstants "be.ems/src/framework/constants/token"
"be.ems/src/framework/logger"
redisCahe "be.ems/src/framework/redis"
"be.ems/src/framework/utils/generate"
"be.ems/src/framework/utils/machine"
"be.ems/src/framework/vo"
jwt "github.com/golang-jwt/jwt/v5"
)
// Remove 清除登录用户信息UUID
func Remove(tokenStr string) string {
claims, err := Verify(tokenStr)
if err != nil {
logger.Errorf("token verify err %v", err)
return ""
}
// 清除缓存KEY
uuid := claims[tokenConstants.JWT_UUID].(string)
tokenKey := cachekeyConstants.LOGIN_TOKEN_KEY + uuid
hasKey, _ := redisCahe.Has("", tokenKey)
if hasKey {
redisCahe.Del("", tokenKey)
}
return claims[tokenConstants.JWT_NAME].(string)
}
// Create 令牌生成
func Create(loginUser *vo.LoginUser, ilobArgs ...string) string {
// 生成用户唯一tokne16位
loginUser.UUID = generate.Code(16)
// 设置请求用户登录客户端
loginUser.IPAddr = ilobArgs[0]
loginUser.LoginLocation = ilobArgs[1]
loginUser.OS = ilobArgs[2]
loginUser.Browser = ilobArgs[3]
// 设置用户令牌有效期并存入缓存
Cache(loginUser)
// 设置登录IP和登录时间
loginUser.User.LoginIP = loginUser.IPAddr
loginUser.User.LoginDate = loginUser.LoginTime
// 令牌算法 HS256 HS384 HS512
algorithm := config.Get("jwt.algorithm").(string)
var method *jwt.SigningMethodHMAC
switch algorithm {
case "HS512":
method = jwt.SigningMethodHS512
case "HS384":
method = jwt.SigningMethodHS384
case "HS256":
default:
method = jwt.SigningMethodHS256
}
// 生成令牌负荷绑定uuid标识
jwtToken := jwt.NewWithClaims(method, jwt.MapClaims{
tokenConstants.JWT_UUID: loginUser.UUID,
tokenConstants.JWT_KEY: loginUser.UserID,
tokenConstants.JWT_NAME: loginUser.User.UserName,
"exp": loginUser.ExpireTime,
"ait": loginUser.LoginTime,
})
// 生成令牌设置密钥
secret := config.Get("jwt.secret").(string)
tokenStr, err := jwtToken.SignedString([]byte(machine.Code + "@" + secret))
if err != nil {
logger.Infof("jwt sign err : %v", err)
return ""
}
return tokenStr
}
// Cache 缓存登录用户信息
func Cache(loginUser *vo.LoginUser) {
// 计算配置的有效期
expTime := config.Get("jwt.expiresIn").(int)
expTimestamp := time.Duration(expTime) * time.Minute
iatTimestamp := time.Now().UnixMilli()
loginUser.LoginTime = iatTimestamp
loginUser.ExpireTime = iatTimestamp + expTimestamp.Milliseconds()
// 根据登录标识将loginUser缓存
tokenKey := cachekeyConstants.LOGIN_TOKEN_KEY + loginUser.UUID
jsonBytes, err := json.Marshal(loginUser)
if err != nil {
return
}
redisCahe.SetByExpire("", tokenKey, string(jsonBytes), expTimestamp)
}
// RefreshIn 验证令牌有效期相差不足xx分钟自动刷新缓存
func RefreshIn(loginUser *vo.LoginUser) {
// 相差不足xx分钟自动刷新缓存
refreshTime := config.Get("jwt.refreshIn").(int)
refreshTimestamp := time.Duration(refreshTime) * time.Minute
// 过期时间
expireTimestamp := loginUser.ExpireTime
currentTimestamp := time.Now().UnixMilli()
if expireTimestamp-currentTimestamp <= refreshTimestamp.Milliseconds() {
Cache(loginUser)
}
}
// Verify 校验令牌是否有效
func Verify(tokenString string) (jwt.MapClaims, error) {
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (any, error) {
// 判断加密算法是预期的加密算法
if _, ok := token.Method.(*jwt.SigningMethodHMAC); ok {
secret := config.Get("jwt.secret").(string)
return []byte(machine.Code + "@" + secret), nil
}
return nil, jwt.ErrSignatureInvalid
})
if err != nil {
logger.Errorf("token String Verify : %v", err)
// 无效身份授权
return nil, fmt.Errorf("invalid identity authorization")
}
// 如果解析负荷成功并通过签名校验
if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
return claims, nil
}
return nil, fmt.Errorf("token valid error")
}
// LoginUser 缓存的登录用户信息
func LoginUser(claims jwt.MapClaims) vo.LoginUser {
uuid := claims[tokenConstants.JWT_UUID].(string)
tokenKey := cachekeyConstants.LOGIN_TOKEN_KEY + uuid
hasKey, _ := redisCahe.Has("", tokenKey)
var loginUser vo.LoginUser
if hasKey {
loginUserStr, _ := redisCahe.Get("", tokenKey)
if loginUserStr == "" {
return loginUser
}
err := json.Unmarshal([]byte(loginUserStr), &loginUser)
if err != nil {
logger.Errorf("loginuser info json err : %v", err)
return loginUser
}
return loginUser
}
return loginUser
}