// Package gormcol provides GORM model generation with type-safe column descriptors.
//
// This file contains the core generation logic including:
//   - Model generation from database tables
//   - Configuration management
//   - Output directory handling and file cleanup
package gormcol
|
|
|
|
import (
	"context"
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"regexp"
	"strings"

	"gorm.io/driver/mysql"
	"gorm.io/gen"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"
)
|
|
|
|
// GenConfig collects the settings that control model generation.
type GenConfig struct {
	// OutputDir is the directory that receives the generated files.
	OutputDir string

	// PkgName is the Go package name used for the generated files.
	PkgName string

	// TableFilter is a regular expression used to select tables by name.
	// Example: "(ps_|b2b_).*" matches tables starting with ps_ or b2b_.
	TableFilter string

	// SelectedTables lists the exact table names to generate.
	// When it is non-empty, TableFilter is ignored.
	SelectedTables []string

	// Clean controls whether existing model files are removed before
	// generation. When false, models are only created or updated and no
	// existing file is deleted.
	Clean bool
}
|
|
|
|
// defaultConfig returns the default configuration values.
|
|
func defaultConfig() GenConfig {
|
|
return GenConfig{
|
|
OutputDir: "./app/model/dbmodel",
|
|
PkgName: "dbmodel",
|
|
TableFilter: "ps_.*",
|
|
Clean: true,
|
|
}
|
|
}
|
|
|
|
// GormGen handles GORM model generation with column descriptors.
|
|
type GormGen struct {
|
|
db *gorm.DB
|
|
cfg GenConfig
|
|
}
|
|
|
|
// New creates a new GormGen with default configuration.
|
|
func New(db *gorm.DB) *GormGen {
|
|
return &GormGen{db: db, cfg: defaultConfig()}
|
|
}
|
|
|
|
// NewWithConfig creates a new GormGen with custom configuration.
|
|
func NewWithConfig(db *gorm.DB, cfg GenConfig) *GormGen {
|
|
d := defaultConfig()
|
|
if cfg.OutputDir != "" {
|
|
d.OutputDir = cfg.OutputDir
|
|
}
|
|
if cfg.PkgName != "" {
|
|
d.PkgName = cfg.PkgName
|
|
}
|
|
if cfg.TableFilter != "" {
|
|
d.TableFilter = cfg.TableFilter
|
|
}
|
|
if len(cfg.SelectedTables) > 0 {
|
|
d.SelectedTables = cfg.SelectedTables
|
|
}
|
|
d.Clean = cfg.Clean
|
|
return &GormGen{db: db, cfg: d}
|
|
}
|
|
|
|
// ConnectDSN opens a MySQL/MariaDB connection from a DSN string.
|
|
func ConnectDSN(dsn string) (*gorm.DB, error) {
|
|
db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{
|
|
Logger: logger.Default.LogMode(logger.Error),
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to connect with dsn: %w", err)
|
|
}
|
|
return db, nil
|
|
}
|
|
|
|
// GenModels generates GORM model files and column descriptors for matched tables.
|
|
// It cleans the output directory (if Clean is true), generates models using gorm.io/gen,
|
|
// and appends <Model>Cols variables with type-safe Field descriptors.
|
|
func (m *GormGen) GenModels(ctx context.Context) error {
|
|
dir := m.cfg.OutputDir
|
|
if !strings.HasPrefix(dir, "./") {
|
|
dir = "./" + dir
|
|
}
|
|
absDir, err := filepath.Abs(dir)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// When Clean is false, backup existing files first.
|
|
var backupDir string
|
|
if !m.cfg.Clean {
|
|
backupDir, err = m.backupDir(absDir)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to backup existing files: %w", err)
|
|
}
|
|
defer os.RemoveAll(backupDir)
|
|
}
|
|
|
|
if err := m.cleanOutputDir(); err != nil {
|
|
return fmt.Errorf("failed to clean output dir: %w", err)
|
|
}
|
|
|
|
g := gen.NewGenerator(gen.Config{
|
|
OutPath: m.cfg.OutputDir,
|
|
ModelPkgPath: m.cfg.PkgName,
|
|
FieldNullable: true,
|
|
FieldWithIndexTag: true,
|
|
})
|
|
|
|
g.UseDB(m.db)
|
|
|
|
tableNames, err := m.db.Migrator().GetTables()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get table list: %w", err)
|
|
}
|
|
|
|
var matched int
|
|
var re *regexp.Regexp
|
|
|
|
if len(m.cfg.SelectedTables) > 0 {
|
|
tableSet := make(map[string]bool)
|
|
for _, t := range m.cfg.SelectedTables {
|
|
tableSet[t] = true
|
|
}
|
|
for _, tableName := range tableNames {
|
|
if tableSet[tableName] {
|
|
g.GenerateModel(tableName)
|
|
matched++
|
|
}
|
|
}
|
|
fmt.Printf("Selected %d tables\n", matched)
|
|
} else {
|
|
re, err = regexp.Compile("^" + m.cfg.TableFilter + "$")
|
|
if err != nil {
|
|
return fmt.Errorf("invalid table filter regex %q: %w", m.cfg.TableFilter, err)
|
|
}
|
|
for _, tableName := range tableNames {
|
|
if re.MatchString(tableName) {
|
|
g.GenerateModel(tableName)
|
|
matched++
|
|
}
|
|
}
|
|
fmt.Printf("Matched %d tables with filter %q\n", matched, m.cfg.TableFilter)
|
|
}
|
|
|
|
g.Execute()
|
|
|
|
if err := m.cleanupGeneratedFiles(); err != nil {
|
|
return fmt.Errorf("failed to cleanup generated files: %w", err)
|
|
}
|
|
|
|
// Restore backup if Clean was false.
|
|
if backupDir != "" {
|
|
if err := m.restoreBackup(absDir, backupDir); err != nil {
|
|
return fmt.Errorf("failed to restore backup: %w", err)
|
|
}
|
|
}
|
|
|
|
if err := m.generateCols(); err != nil {
|
|
return fmt.Errorf("failed to generate column descriptors: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// backupDir creates a backup copy of the directory and returns the backup path.
|
|
func (m *GormGen) backupDir(dir string) (string, error) {
|
|
backupPath := dir + ".backup"
|
|
if _, err := os.Stat(backupPath); err == nil {
|
|
os.RemoveAll(backupPath)
|
|
}
|
|
|
|
if err := copyDir(dir, backupPath); err != nil {
|
|
return "", err
|
|
}
|
|
fmt.Printf("Backed up: %s -> %s\n", dir, backupPath)
|
|
return backupPath, nil
|
|
}
|
|
|
|
// restoreBackup restores files from backup that don't exist in the target directory.
|
|
func (m *GormGen) restoreBackup(dir, backupDir string) error {
|
|
entries, err := os.ReadDir(backupDir)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
restored := 0
|
|
for _, entry := range entries {
|
|
if entry.IsDir() {
|
|
continue
|
|
}
|
|
targetPath := filepath.Join(dir, entry.Name())
|
|
if _, err := os.Stat(targetPath); os.IsNotExist(err) {
|
|
srcPath := filepath.Join(backupDir, entry.Name())
|
|
data, err := os.ReadFile(srcPath)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
if err := os.WriteFile(targetPath, data, 0644); err != nil {
|
|
return err
|
|
}
|
|
fmt.Printf("Restored: %s\n", targetPath)
|
|
restored++
|
|
}
|
|
}
|
|
if restored > 0 {
|
|
fmt.Printf("Restored %d files from backup\n", restored)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// copyDir copies a directory recursively.
|
|
func copyDir(src, dst string) error {
|
|
return filepath.Walk(src, func(path string, info os.FileInfo, err error) error {
|
|
if err != nil {
|
|
return err
|
|
}
|
|
rel, _ := filepath.Rel(src, path)
|
|
dstPath := filepath.Join(dst, rel)
|
|
|
|
if info.IsDir() {
|
|
return os.MkdirAll(dstPath, info.Mode())
|
|
}
|
|
return copyFile(path, dstPath)
|
|
})
|
|
}
|
|
|
|
// copyFile reads src completely and writes its contents to dst.
//
// dst is created or truncated with mode 0644. A read failure is returned
// without touching dst.
func copyFile(src, dst string) error {
	contents, readErr := os.ReadFile(src)
	if readErr == nil {
		return os.WriteFile(dst, contents, 0644)
	}
	return readErr
}
|
|
|
|
// cleanOutputDir removes existing .go files from the output directory
|
|
// or creates it if it doesn't exist.
|
|
func (m *GormGen) cleanOutputDir() error {
|
|
dir := m.cfg.OutputDir
|
|
if !strings.HasPrefix(dir, "./") {
|
|
dir = "./" + dir
|
|
}
|
|
|
|
absDir, err := filepath.Abs(dir)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if _, err := os.Stat(absDir); os.IsNotExist(err) {
|
|
if err := os.MkdirAll(absDir, 0755); err != nil {
|
|
return fmt.Errorf("failed to create output dir: %w", err)
|
|
}
|
|
fmt.Printf("Created: %s\n", absDir)
|
|
return nil
|
|
}
|
|
|
|
// Skip cleaning if Clean is false.
|
|
if !m.cfg.Clean {
|
|
return nil
|
|
}
|
|
|
|
entries, err := os.ReadDir(absDir)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, entry := range entries {
|
|
if entry.IsDir() {
|
|
continue
|
|
}
|
|
if strings.HasSuffix(entry.Name(), ".go") {
|
|
path := filepath.Join(absDir, entry.Name())
|
|
if err := os.Remove(path); err != nil {
|
|
return fmt.Errorf("failed to remove %s: %w", path, err)
|
|
}
|
|
fmt.Printf("Removed: %s\n", path)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// cleanupGeneratedFiles removes gorm.io/gen helper files and renames
|
|
// .gen.go files to .go with cleaned content.
|
|
func (m *GormGen) cleanupGeneratedFiles() error {
|
|
filesToRemove := []string{"gen.go", "do.go", "_gen.go"}
|
|
|
|
dir := m.cfg.OutputDir
|
|
if !strings.HasPrefix(dir, "./") {
|
|
dir = "./" + dir
|
|
}
|
|
|
|
absDir, err := filepath.Abs(dir)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, fileName := range filesToRemove {
|
|
filePath := filepath.Join(absDir, fileName)
|
|
if _, err := os.Stat(filePath); err == nil {
|
|
if err := os.Remove(filePath); err != nil {
|
|
return fmt.Errorf("failed to remove %s: %w", filePath, err)
|
|
}
|
|
fmt.Printf("Removed: %s\n", filePath)
|
|
}
|
|
}
|
|
|
|
files, err := os.ReadDir(absDir)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var re *regexp.Regexp
|
|
if len(m.cfg.SelectedTables) > 0 {
|
|
pattern := "^(" + strings.Join(m.cfg.SelectedTables, "|") + ")\\.gen\\.go$"
|
|
re, err = regexp.Compile(pattern)
|
|
} else {
|
|
re, err = regexp.Compile("^(" + m.cfg.TableFilter + ")\\.gen\\.go$")
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, file := range files {
|
|
name := file.Name()
|
|
if re.MatchString(name) {
|
|
oldPath := filepath.Join(absDir, name)
|
|
baseName := strings.TrimSuffix(name, ".gen.go")
|
|
newPath := filepath.Join(absDir, baseName+".go")
|
|
|
|
content, err := os.ReadFile(oldPath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
content = m.cleanModelContent(content)
|
|
|
|
if err := os.WriteFile(newPath, content, 0644); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := os.Remove(oldPath); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// cleanModelContent removes gorm.io/gen-specific imports and type declarations
|
|
// from the generated model content.
|
|
func (m *GormGen) cleanModelContent(content []byte) []byte {
|
|
result := string(content)
|
|
|
|
lines := strings.Split(result, "\n")
|
|
var newLines []string
|
|
importStarted := false
|
|
importEnded := false
|
|
|
|
for _, line := range lines {
|
|
trimmed := strings.TrimSpace(line)
|
|
|
|
if trimmed == "import (" {
|
|
importStarted = true
|
|
newLines = append(newLines, line)
|
|
continue
|
|
}
|
|
if importStarted && trimmed == ")" {
|
|
importEnded = true
|
|
importStarted = false
|
|
newLines = append(newLines, line)
|
|
continue
|
|
}
|
|
|
|
if importStarted && !importEnded {
|
|
if strings.Contains(trimmed, "\"gorm.io/gen\"") ||
|
|
strings.Contains(trimmed, "\"gorm.io/gen/field\"") ||
|
|
strings.Contains(trimmed, "\"gorm.io/plugin/dbresolver\"") ||
|
|
strings.Contains(trimmed, "gen.DO") {
|
|
continue
|
|
}
|
|
}
|
|
|
|
if strings.Contains(trimmed, "type psCategoryDo struct") ||
|
|
strings.HasPrefix(trimmed, "func (p psCategoryDo)") {
|
|
continue
|
|
}
|
|
|
|
newLines = append(newLines, line)
|
|
}
|
|
|
|
result = strings.Join(newLines, "\n")
|
|
result = strings.ReplaceAll(result, "psCategoryDo", "")
|
|
|
|
return []byte(result)
|
|
}
|