From 0e53b5c61d0eb1e06dfa2132c629de45f04cf26f Mon Sep 17 00:00:00 2001 From: TsMask <340112800@qq.com> Date: Thu, 6 Mar 2025 18:03:05 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E6=9B=B4=E6=96=B0=E9=85=8D?= =?UTF-8?q?=E7=BD=AE=E9=80=89=E9=A1=B9=E6=94=AF=E6=8C=81=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?SQL=E5=AF=BC=E5=85=A5=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/configuration.go | 1 + src/framework/config/config.go | 34 +++++---- src/framework/database/db/expand.go | 104 ++++++++++++++++++++++++++++ 3 files changed, 124 insertions(+), 15 deletions(-) create mode 100644 src/framework/database/db/expand.go diff --git a/src/configuration.go b/src/configuration.go index 9cd57054..9654d441 100644 --- a/src/configuration.go +++ b/src/configuration.go @@ -25,6 +25,7 @@ func ConfigurationInit() { logger.InitLogger() // 连接数据库实例 db.Connect() + db.ImportSQL() // 连接Redis实例 redis.Connect() // 启动调度任务实例 diff --git a/src/framework/config/config.go b/src/framework/config/config.go index adad1415..c161a6f5 100644 --- a/src/framework/config/config.go +++ b/src/framework/config/config.go @@ -31,6 +31,10 @@ func initFlag() { // --c /etc/restconf.yaml // -c /etc/restconf.yaml pflag.StringP("config", "c", "./etc/restconf.yaml", "Specify Configuration File") + // --sqlPath ./sql/20250228.sql + pflag.String("sqlPath", "", "Execution SQL File Path") + // --sqlSource default + pflag.String("sqlSource", "default", "Execution SQL Database Source") // --version // -V pVersion := pflag.BoolP("version", "V", false, "Output program version") @@ -96,6 +100,21 @@ func initViper(configDir *embed.FS) { conf.Set("runTime", time.Now()) } +// readExternalConfig 读取外部文件配置(放弃旧的配置序列化时候才用) +func readExternalConfig(configPaht string) { + f, err := os.Open(configPaht) + if err != nil { + log.Fatalf("config external file read error: %s", err) + return + } + defer f.Close() + + if err = conf.MergeConfig(f); err != nil { + log.Fatalf("config external file read error: %s", err) + return + } +} + // 配置文件读取进行内部参数合并 func configInMerge(configFile string) { // 指定配置文件读取序列化 @@ -156,21 +175,6 @@ func SetAssetsDirFS(assetsDir *embed.FS) { conf.Set("AssetsDir", assetsDir) } -// readExternalConfig 读取外部文件配置 -func readExternalConfig(configPaht string) { - f, err := os.Open(configPaht) - if err != nil { - log.Fatalf("config external file read error: %s", err) - return - } - defer f.Close() - - if err = conf.MergeConfig(f); err != nil { - log.Fatalf("config external file read error: %s", err) - return - } -} - // IsSystemUser 用户是否为系统管理员 func IsSystemUser(userId int64) bool { if userId <= 0 { diff --git a/src/framework/database/db/expand.go b/src/framework/database/db/expand.go new file mode 100644 index 00000000..7aaa51b7 --- /dev/null +++ b/src/framework/database/db/expand.go @@ -0,0 +1,104 @@ +package db + +import ( + "bufio" + "log" + "os" + "path/filepath" + "strings" + + "gorm.io/gorm" + + "be.ems/src/framework/config" +) + +// ImportSQL 导入SQL +func ImportSQL() { + sqlPath := config.Get("sqlPath").(string) + if sqlPath == "" { + return + } + sqlSource := config.Get("sqlSource").(string) + if sqlSource == "" { + sqlSource = config.Get("database.defaultDataSourceName").(string) + } + + // 数据源 + db := DB(sqlSource) + if db == nil { + log.Fatalln("not database source") + return + } + + // 获取路径信息 + fileInfo, err := os.Stat(sqlPath) + if err != nil { + log.Fatalln(err.Error()) + return + } + + // 处理目录或文件 + if fileInfo.IsDir() { + // 处理目录 + files, err := os.ReadDir(sqlPath) + if err != nil { + log.Fatalln(err.Error()) + return + } + + for _, file := range files { + if file.IsDir() { + continue + } + if !strings.HasSuffix(file.Name(), ".sql") { + continue + } + processSQLFile(db, filepath.Join(sqlPath, file.Name())) + } + } else { + // 处理单个文件 + processSQLFile(db, sqlPath) + } + + log.Println("Import SQL End") + os.Exit(1) +} + +// 处理单个SQL文件的通用函数 +func processSQLFile(db *gorm.DB, filePath string) { + file, err := os.Open(filePath) + if err != nil { + log.Fatalln(err.Error()) + return + } + defer file.Close() + + // 逐行读取 SQL 文件 + scanner := bufio.NewScanner(file) + var sqlBuilder strings.Builder + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + // 跳过注释和空行 + if strings.HasPrefix(line, "--") || strings.TrimSpace(line) == "" { + continue + } + // 跳过配置语句 + if strings.HasPrefix(line, "/*!") { + continue + } + + sqlBuilder.WriteString(line + "\n") + + // 当遇到分号时,执行 SQL 语句 + if strings.HasSuffix(line, ";") { + // 执行 SQL 语句 + if err := db.Exec(sqlBuilder.String()).Error; err != nil { + log.Fatalln(err.Error()) + return + } + + sqlBuilder.Reset() + continue + } + } +}