// Package gormcol provides GORM model generation with type-safe column descriptors. // // This file contains the core generation logic including: // - Model generation from database tables // - Configuration management // - Output directory handling and file cleanup package gormcol import ( "context" "fmt" "os" "path/filepath" "regexp" "strings" "gorm.io/driver/mysql" "gorm.io/gen" "gorm.io/gorm" "gorm.io/gorm/logger" ) // GenConfig holds configuration for model generation. type GenConfig struct { // OutputDir is the directory where generated files are written. OutputDir string // PkgName is the Go package name for generated files. PkgName string // TableFilter is a regex pattern to match table names. // Example: "(ps_|b2b_).*" matches tables starting with ps_ or b2b_. TableFilter string // SelectedTables is a list of specific table names to generate. // When set, TableFilter is ignored. SelectedTables []string // Clean determines whether to remove existing model files before generation. // When false, only updates or creates models without deleting existing files. Clean bool } // defaultConfig returns the default configuration values. func defaultConfig() GenConfig { return GenConfig{ OutputDir: "./app/model/dbmodel", PkgName: "dbmodel", TableFilter: "ps_.*", Clean: true, } } // GormGen handles GORM model generation with column descriptors. type GormGen struct { db *gorm.DB cfg GenConfig } // New creates a new GormGen with default configuration. func New(db *gorm.DB) *GormGen { return &GormGen{db: db, cfg: defaultConfig()} } // NewWithConfig creates a new GormGen with custom configuration. func NewWithConfig(db *gorm.DB, cfg GenConfig) *GormGen { d := defaultConfig() if cfg.OutputDir != "" { d.OutputDir = cfg.OutputDir } if cfg.PkgName != "" { d.PkgName = cfg.PkgName } if cfg.TableFilter != "" { d.TableFilter = cfg.TableFilter } if len(cfg.SelectedTables) > 0 { d.SelectedTables = cfg.SelectedTables } d.Clean = cfg.Clean return &GormGen{db: db, cfg: d} } // ConnectDSN opens a MySQL/MariaDB connection from a DSN string. func ConnectDSN(dsn string) (*gorm.DB, error) { db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{ Logger: logger.Default.LogMode(logger.Error), }) if err != nil { return nil, fmt.Errorf("failed to connect with dsn: %w", err) } return db, nil } // GenModels generates GORM model files and column descriptors for matched tables. // It cleans the output directory (if Clean is true), generates models using gorm.io/gen, // and appends Cols variables with type-safe Field descriptors. func (m *GormGen) GenModels(ctx context.Context) error { // Clean output directory if Clean is true. if m.cfg.Clean { if err := m.cleanOutputDir(); err != nil { return fmt.Errorf("failed to clean output dir: %w", err) } } // When Clean is false, save existing files to restore later. var existingFiles map[string][]byte if !m.cfg.Clean { existingFiles = m.saveExistingFiles() } g := gen.NewGenerator(gen.Config{ OutPath: m.cfg.OutputDir, ModelPkgPath: m.cfg.PkgName, FieldNullable: true, FieldWithIndexTag: true, }) g.UseDB(m.db) tableNames, err := m.db.Migrator().GetTables() if err != nil { return fmt.Errorf("failed to get table list: %w", err) } var matched int var re *regexp.Regexp if len(m.cfg.SelectedTables) > 0 { tableSet := make(map[string]bool) for _, t := range m.cfg.SelectedTables { tableSet[t] = true } for _, tableName := range tableNames { if tableSet[tableName] { g.GenerateModel(tableName) matched++ } } fmt.Printf("Selected %d tables\n", matched) } else { re, err = regexp.Compile("^" + m.cfg.TableFilter + "$") if err != nil { return fmt.Errorf("invalid table filter regex %q: %w", m.cfg.TableFilter, err) } for _, tableName := range tableNames { if re.MatchString(tableName) { g.GenerateModel(tableName) matched++ } } fmt.Printf("Matched %d tables with filter %q\n", matched, m.cfg.TableFilter) } g.Execute() if err := m.cleanupGeneratedFiles(); err != nil { return fmt.Errorf("failed to cleanup generated files: %w", err) } // Restore existing files that were not regenerated. if !m.cfg.Clean { m.restoreExistingFiles(existingFiles) } if err := m.generateCols(); err != nil { return fmt.Errorf("failed to generate column descriptors: %w", err) } return nil } // saveExistingFiles saves the content of existing .go files in the output directory. func (m *GormGen) saveExistingFiles() map[string][]byte { dir := m.cfg.OutputDir if !strings.HasPrefix(dir, "./") { dir = "./" + dir } files := make(map[string][]byte) entries, err := os.ReadDir(dir) if err != nil { return files } for _, entry := range entries { if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".go") { continue } path := filepath.Join(dir, entry.Name()) if data, err := os.ReadFile(path); err == nil { files[entry.Name()] = data } } return files } // restoreExistingFiles restores files that weren't regenerated. func (m *GormGen) restoreExistingFiles(saved map[string][]byte) { dir := m.cfg.OutputDir if !strings.HasPrefix(dir, "./") { dir = "./" + dir } restored := 0 for name, data := range saved { path := filepath.Join(dir, name) if _, err := os.Stat(path); os.IsNotExist(err) { if err := os.WriteFile(path, data, 0644); err == nil { fmt.Printf("Restored: %s\n", path) restored++ } } } if restored > 0 { fmt.Printf("Restored %d existing files\n", restored) } } // cleanOutputDir removes existing .go files from the output directory // or creates it if it doesn't exist. func (m *GormGen) cleanOutputDir() error { dir := m.cfg.OutputDir if !strings.HasPrefix(dir, "./") { dir = "./" + dir } absDir, err := filepath.Abs(dir) if err != nil { return err } if _, err := os.Stat(absDir); os.IsNotExist(err) { if err := os.MkdirAll(absDir, 0755); err != nil { return fmt.Errorf("failed to create output dir: %w", err) } fmt.Printf("Created: %s\n", absDir) return nil } entries, err := os.ReadDir(absDir) if err != nil { return err } for _, entry := range entries { if entry.IsDir() { continue } if strings.HasSuffix(entry.Name(), ".go") { path := filepath.Join(absDir, entry.Name()) if err := os.Remove(path); err != nil { return fmt.Errorf("failed to remove %s: %w", path, err) } fmt.Printf("Removed: %s\n", path) } } return nil } // cleanupGeneratedFiles removes gorm.io/gen helper files and renames // .gen.go files to .go with cleaned content. func (m *GormGen) cleanupGeneratedFiles() error { filesToRemove := []string{"gen.go", "do.go", "_gen.go"} dir := m.cfg.OutputDir if !strings.HasPrefix(dir, "./") { dir = "./" + dir } absDir, err := filepath.Abs(dir) if err != nil { return err } for _, fileName := range filesToRemove { filePath := filepath.Join(absDir, fileName) if _, err := os.Stat(filePath); err == nil { if err := os.Remove(filePath); err != nil { return fmt.Errorf("failed to remove %s: %w", filePath, err) } fmt.Printf("Removed: %s\n", filePath) } } files, err := os.ReadDir(absDir) if err != nil { return err } var re *regexp.Regexp if len(m.cfg.SelectedTables) > 0 { pattern := "^(" + strings.Join(m.cfg.SelectedTables, "|") + ")\\.gen\\.go$" re, err = regexp.Compile(pattern) } else { re, err = regexp.Compile("^(" + m.cfg.TableFilter + ")\\.gen\\.go$") } if err != nil { return err } for _, file := range files { name := file.Name() if re.MatchString(name) { oldPath := filepath.Join(absDir, name) baseName := strings.TrimSuffix(name, ".gen.go") newPath := filepath.Join(absDir, baseName+".go") content, err := os.ReadFile(oldPath) if err != nil { return err } content = m.cleanModelContent(content) if err := os.WriteFile(newPath, content, 0644); err != nil { return err } if err := os.Remove(oldPath); err != nil { return err } } } return nil } // cleanModelContent removes gorm.io/gen-specific imports and type declarations // from the generated model content. func (m *GormGen) cleanModelContent(content []byte) []byte { result := string(content) lines := strings.Split(result, "\n") var newLines []string importStarted := false importEnded := false for _, line := range lines { trimmed := strings.TrimSpace(line) if trimmed == "import (" { importStarted = true newLines = append(newLines, line) continue } if importStarted && trimmed == ")" { importEnded = true importStarted = false newLines = append(newLines, line) continue } if importStarted && !importEnded { if strings.Contains(trimmed, "\"gorm.io/gen\"") || strings.Contains(trimmed, "\"gorm.io/gen/field\"") || strings.Contains(trimmed, "\"gorm.io/plugin/dbresolver\"") || strings.Contains(trimmed, "gen.DO") { continue } } if strings.Contains(trimmed, "type psCategoryDo struct") || strings.HasPrefix(trimmed, "func (p psCategoryDo)") { continue } newLines = append(newLines, line) } result = strings.Join(newLines, "\n") result = strings.ReplaceAll(result, "psCategoryDo", "") return []byte(result) }