174 lines
4.1 KiB
Go
174 lines
4.1 KiB
Go
package config
|
||
|
||
import (
|
||
"bytes"
|
||
"embed"
|
||
"fmt"
|
||
"log"
|
||
"os"
|
||
"time"
|
||
|
||
libConfig "be.ems/src/lib_features/config"
|
||
"github.com/spf13/pflag"
|
||
"github.com/spf13/viper"
|
||
)
|
||
|
||
//go:embed config/*.yaml
|
||
var configFiles embed.FS
|
||
|
||
// 初始化程序配置
|
||
func InitConfig() {
|
||
initFlag()
|
||
initViper()
|
||
}
|
||
|
||
// 指定参数绑定
|
||
func initFlag() {
|
||
// --env prod
|
||
pflag.String("env", "prod", "Specify Run Environment Configuration local or prod")
|
||
// --c /etc/restconf.yaml
|
||
// -c /etc/restconf.yaml
|
||
pConfig := pflag.StringP("config", "c", "./etc/restconf.yaml", "Specify Configuration File")
|
||
// --version
|
||
// -V
|
||
pVersion := pflag.BoolP("version", "V", false, "Output program version")
|
||
// --help
|
||
pHelp := pflag.Bool("help", false, "Viewing Help Commands")
|
||
|
||
pflag.Parse()
|
||
|
||
// 参数固定输出
|
||
if *pVersion {
|
||
buildInfo := libConfig.BuildInfo()
|
||
fmt.Println(buildInfo)
|
||
os.Exit(1)
|
||
}
|
||
if *pHelp {
|
||
pflag.Usage()
|
||
os.Exit(1)
|
||
}
|
||
|
||
// 外层lib和features使用的配置
|
||
libConfig.ConfigRead(*pConfig)
|
||
|
||
viper.BindPFlags(pflag.CommandLine)
|
||
}
|
||
|
||
// 配置文件读取
|
||
func initViper() {
|
||
// 在当前工作目录中寻找配置
|
||
// viper.AddConfigPath("config")
|
||
// viper.AddConfigPath("src/config")
|
||
// 如果配置文件名中没有扩展名,则需要设置Type
|
||
viper.SetConfigType("yaml")
|
||
|
||
// 从 embed.FS 中读取默认配置文件内容
|
||
configDefault, err := configFiles.ReadFile("config/config.default.yaml")
|
||
if err != nil {
|
||
log.Fatalf("ReadFile config default file: %s", err)
|
||
return
|
||
}
|
||
// 设置默认配置文件内容到 viper
|
||
err = viper.ReadConfig(bytes.NewReader(configDefault))
|
||
if err != nil {
|
||
log.Fatalf("NewReader config default file: %s", err)
|
||
return
|
||
}
|
||
|
||
// // 配置文件的名称(无扩展名)
|
||
// viper.SetConfigName("config.default")
|
||
// // 读取默认配置文件
|
||
// if err := viper.ReadInConfig(); err != nil {
|
||
// log.Fatalf("fatal error config default file: %s", err)
|
||
// }
|
||
|
||
env := viper.GetString("env")
|
||
if env != "local" && env != "prod" {
|
||
log.Fatalf("fatal error config env for local or prod : %s", env)
|
||
}
|
||
log.Printf("Current service environment operation configuration => %s \n", env)
|
||
|
||
// 加载运行配置文件合并相同配置
|
||
if env == "prod" {
|
||
// viper.SetConfigName("config.prod")
|
||
// 从 embed.FS 中读取默认配置文件内容
|
||
configProd, err := configFiles.ReadFile("config/config.prod.yaml")
|
||
if err != nil {
|
||
log.Fatalf("ReadFile config prod file: %s", err)
|
||
return
|
||
}
|
||
// 设置默认配置文件内容到 viper
|
||
err = viper.MergeConfig(bytes.NewReader(configProd))
|
||
if err != nil {
|
||
log.Fatalf("NewReader config prod file: %s", err)
|
||
return
|
||
}
|
||
} else {
|
||
// viper.SetConfigName("config.local")
|
||
// 从 embed.FS 中读取默认配置文件内容
|
||
configLocal, err := configFiles.ReadFile("config/config.local.yaml")
|
||
if err != nil {
|
||
log.Fatalf("ReadFile config local file: %s", err)
|
||
return
|
||
}
|
||
// 设置默认配置文件内容到 viper
|
||
err = viper.MergeConfig(bytes.NewReader(configLocal))
|
||
if err != nil {
|
||
log.Fatalf("NewReader config local file: %s", err)
|
||
return
|
||
}
|
||
}
|
||
// if err := viper.MergeInConfig(); err != nil {
|
||
// log.Fatalf("fatal error config MergeInConfig: %s", err)
|
||
// }
|
||
|
||
// 合并外层lib和features使用配置
|
||
libConfig.ConfigInMerge()
|
||
|
||
// 记录程序开始运行的时间点
|
||
viper.Set("runTime", time.Now())
|
||
}
|
||
|
||
// Env 获取运行服务环境
|
||
// local prod
|
||
func Env() string {
|
||
return viper.GetString("env")
|
||
}
|
||
|
||
// RunTime 程序开始运行的时间
|
||
func RunTime() time.Time {
|
||
return viper.GetTime("runTime")
|
||
}
|
||
|
||
// Get 获取配置信息
|
||
//
|
||
// Get("framework.name")
|
||
func Get(key string) any {
|
||
return viper.Get(key)
|
||
}
|
||
|
||
// GetAssetsDirFS 访问程序内全局资源访问
|
||
func GetAssetsDirFS() embed.FS {
|
||
return viper.Get("AssetsDir").(embed.FS)
|
||
}
|
||
|
||
// SetAssetsDirFS 设置程序内全局资源访问
|
||
func SetAssetsDirFS(assetsDir embed.FS) {
|
||
viper.Set("AssetsDir", assetsDir)
|
||
}
|
||
|
||
// IsAdmin 用户是否为管理员
|
||
func IsAdmin(userID string) bool {
|
||
if userID == "" {
|
||
return false
|
||
}
|
||
// 从本地配置获取user信息
|
||
admins := Get("user.adminList").([]any)
|
||
for _, s := range admins {
|
||
if s.(string) == userID {
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}
|