168 lines
3.8 KiB
Go
168 lines
3.8 KiB
Go
package config
|
||
|
||
import (
|
||
"bytes"
|
||
"embed"
|
||
"fmt"
|
||
"log"
|
||
"os"
|
||
"time"
|
||
|
||
"github.com/spf13/pflag"
|
||
"github.com/spf13/viper"
|
||
)
|
||
|
||
var (
|
||
Version string = "-"
|
||
BuildTime string = "-"
|
||
GoVer string = "-"
|
||
)
|
||
|
||
// 程序配置
|
||
var conf *viper.Viper = viper.New()
|
||
|
||
// 初始化程序配置
|
||
func InitConfig(configDir *embed.FS) {
|
||
initFlag()
|
||
initViper(configDir)
|
||
}
|
||
|
||
// 指定参数绑定
|
||
func initFlag() {
|
||
// --env prod
|
||
pflag.String("env", "prod", "Specify Run Environment Configuration local or prod")
|
||
// --c /usr/local/etc/omc/omc.yaml
|
||
// -c /usr/local/etc/omc/omc.yaml
|
||
pflag.StringP("config", "c", "/usr/local/etc/omc/omc.yaml", "Specify Configuration File")
|
||
// --sqlPath ./sql/20250228.sql
|
||
pflag.String("sqlPath", "", "Execution SQL File Path")
|
||
// --sqlSource default
|
||
pflag.String("sqlSource", "default", "Execution SQL Database Source")
|
||
// --version
|
||
// -V
|
||
pVersion := pflag.BoolP("version", "V", false, "Output program version")
|
||
// --help
|
||
pHelp := pflag.Bool("help", false, "Viewing Help Commands")
|
||
|
||
pflag.Parse()
|
||
|
||
// 参数固定输出
|
||
if *pVersion {
|
||
buildInfo := fmt.Sprintf("OMC \nBuildVer: %s\nBuildTime: %s\nBuildEnv: %s\n", Version, BuildTime, GoVer)
|
||
fmt.Println(buildInfo)
|
||
os.Exit(0)
|
||
}
|
||
if *pHelp {
|
||
pflag.Usage()
|
||
os.Exit(0)
|
||
}
|
||
|
||
conf.BindPFlags(pflag.CommandLine)
|
||
}
|
||
|
||
// 配置文件读取
|
||
func initViper(configDir *embed.FS) {
|
||
// 如果配置文件名中没有扩展名,则需要设置Type
|
||
conf.SetConfigType("yaml")
|
||
// 读取默认配置文件
|
||
configDefaultByte, err := configDir.ReadFile("config/config.default.yaml")
|
||
if err != nil {
|
||
log.Fatalf("config default file read error: %s", err)
|
||
return
|
||
}
|
||
if err = conf.ReadConfig(bytes.NewReader(configDefaultByte)); err != nil {
|
||
log.Fatalf("config default file read error: %s", err)
|
||
return
|
||
}
|
||
|
||
// 当期服务环境运行配置 => local
|
||
env := conf.GetString("env")
|
||
// log.Printf("current service environment configuration => %s \n", env)
|
||
|
||
// 加载运行配置文件合并相同配置
|
||
envConfigPath := fmt.Sprintf("config/config.%s.yaml", env)
|
||
configEnvByte, err := configDir.ReadFile(envConfigPath)
|
||
if err != nil {
|
||
log.Fatalf("config env %s file read error: %s", env, err)
|
||
return
|
||
}
|
||
if err = conf.MergeConfig(bytes.NewReader(configEnvByte)); err != nil {
|
||
log.Fatalf("config env %s file read error: %s", env, err)
|
||
return
|
||
}
|
||
|
||
// 外部文件配置
|
||
externalConfig := conf.GetString("config")
|
||
if externalConfig != "" {
|
||
readExternalConfig(externalConfig)
|
||
}
|
||
|
||
// 记录程序开始运行的时间点
|
||
conf.Set("runTime", time.Now())
|
||
}
|
||
|
||
// readExternalConfig 读取外部文件配置
|
||
func readExternalConfig(configPaht string) {
|
||
f, err := os.Open(configPaht)
|
||
if err != nil {
|
||
log.Fatalf("config external file read error: %s", err)
|
||
return
|
||
}
|
||
defer f.Close()
|
||
|
||
if err = conf.MergeConfig(f); err != nil {
|
||
log.Fatalf("config external file read error: %s", err)
|
||
return
|
||
}
|
||
}
|
||
|
||
// Env 获取运行服务环境
|
||
// local prod
|
||
func Env() string {
|
||
return conf.GetString("env")
|
||
}
|
||
|
||
// RunTime 程序开始运行的时间
|
||
func RunTime() time.Time {
|
||
return conf.GetTime("runTime")
|
||
}
|
||
|
||
// Get 获取配置信息
|
||
//
|
||
// Get("redis.defaultDataSourceName")
|
||
func Get(key string) any {
|
||
return conf.Get(key)
|
||
}
|
||
|
||
// Set 设置配置信息
|
||
//
|
||
// Set("redis.defaultDataSourceName", "std")
|
||
func Set(key string, value any) {
|
||
conf.Set(key, value)
|
||
}
|
||
|
||
// GetAssetsDirFS 访问程序内全局资源访问
|
||
func GetAssetsDirFS() *embed.FS {
|
||
return conf.Get("AssetsDir").(*embed.FS)
|
||
}
|
||
|
||
// SetAssetsDirFS 设置程序内全局资源访问
|
||
func SetAssetsDirFS(assetsDir *embed.FS) {
|
||
conf.Set("AssetsDir", assetsDir)
|
||
}
|
||
|
||
// IsSystemUser 用户是否为系统管理员
|
||
func IsSystemUser(userId int64) bool {
|
||
if userId <= 0 {
|
||
return false
|
||
}
|
||
// 从配置中获取系统管理员ID列表
|
||
arr := Get("systemUser").([]any)
|
||
for _, v := range arr {
|
||
if fmt.Sprint(v) == fmt.Sprint(userId) {
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}
|