package config import ( "bytes" "embed" "fmt" "log" "os" "time" "github.com/spf13/pflag" "github.com/spf13/viper" ) var ( Version string = "-" BuildTime string = "-" GoVer string = "-" ) // 程序配置 var conf *viper.Viper = viper.New() // 初始化程序配置 func InitConfig(configDir *embed.FS) { initFlag() initViper(configDir) } // 指定参数绑定 func initFlag() { // --env prod pflag.String("env", "prod", "Specify Run Environment Configuration local or prod") // --c /usr/local/etc/omc/omc.yaml // -c /usr/local/etc/omc/omc.yaml pflag.StringP("config", "c", "/usr/local/etc/omc/omc.yaml", "Specify Configuration File") // --sqlPath ./sql/20250228.sql pflag.String("sqlPath", "", "Execution SQL File Path") // --sqlSource default pflag.String("sqlSource", "default", "Execution SQL Database Source") // --version // -V pVersion := pflag.BoolP("version", "V", false, "Output program version") // --help pHelp := pflag.Bool("help", false, "Viewing Help Commands") pflag.Parse() // 参数固定输出 if *pVersion { buildInfo := fmt.Sprintf("OMC \nBuildVer: %s\nBuildTime: %s\nBuildEnv: %s\n", Version, BuildTime, GoVer) fmt.Println(buildInfo) os.Exit(0) } if *pHelp { pflag.Usage() os.Exit(0) } conf.BindPFlags(pflag.CommandLine) } // 配置文件读取 func initViper(configDir *embed.FS) { // 如果配置文件名中没有扩展名,则需要设置Type conf.SetConfigType("yaml") // 读取默认配置文件 configDefaultByte, err := configDir.ReadFile("config/config.default.yaml") if err != nil { log.Fatalf("config default file read error: %s", err) return } if err = conf.ReadConfig(bytes.NewReader(configDefaultByte)); err != nil { log.Fatalf("config default file read error: %s", err) return } // 当期服务环境运行配置 => local env := conf.GetString("env") // log.Printf("current service environment configuration => %s \n", env) // 加载运行配置文件合并相同配置 envConfigPath := fmt.Sprintf("config/config.%s.yaml", env) configEnvByte, err := configDir.ReadFile(envConfigPath) if err != nil { log.Fatalf("config env %s file read error: %s", env, err) return } if err = conf.MergeConfig(bytes.NewReader(configEnvByte)); err != nil { log.Fatalf("config env %s file read error: %s", env, err) return } // 外部文件配置 externalConfig := conf.GetString("config") if externalConfig != "" { readExternalConfig(externalConfig) } // 记录程序开始运行的时间点 conf.Set("runTime", time.Now()) } // readExternalConfig 读取外部文件配置 func readExternalConfig(configPaht string) { f, err := os.Open(configPaht) if err != nil { log.Fatalf("config external file read error: %s", err) return } defer f.Close() if err = conf.MergeConfig(f); err != nil { log.Fatalf("config external file read error: %s", err) return } } // Env 获取运行服务环境 // local prod func Env() string { return conf.GetString("env") } // RunTime 程序开始运行的时间 func RunTime() time.Time { return conf.GetTime("runTime") } // Get 获取配置信息 // // Get("redis.defaultDataSourceName") func Get(key string) any { return conf.Get(key) } // Set 设置配置信息 // // Set("redis.defaultDataSourceName", "std") func Set(key string, value any) { conf.Set(key, value) } // GetAssetsDirFS 访问程序内全局资源访问 func GetAssetsDirFS() *embed.FS { return conf.Get("AssetsDir").(*embed.FS) } // SetAssetsDirFS 设置程序内全局资源访问 func SetAssetsDirFS(assetsDir *embed.FS) { conf.Set("AssetsDir", assetsDir) } // IsSystemUser 用户是否为系统管理员 func IsSystemUser(userId int64) bool { if userId <= 0 { return false } // 从配置中获取系统管理员ID列表 arr := Get("systemUser").([]any) for _, v := range arr { if fmt.Sprint(v) == fmt.Sprint(userId) { return true } } return false }