This commit is contained in:
2026-03-28 17:45:22 +01:00
commit 73605da712
8 changed files with 1177 additions and 0 deletions

306
gen.go Normal file
View File

@@ -0,0 +1,306 @@
// Package gormcol provides GORM model generation with type-safe column descriptors.
//
// This file contains the core generation logic including:
// - Model generation from database tables
// - Configuration management
// - Output directory handling and file cleanup
package gormcol
import (
"context"
"fmt"
"os"
"path/filepath"
"regexp"
"strings"
"gorm.io/driver/mysql"
"gorm.io/gen"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
// GenConfig holds configuration for model generation.
type GenConfig struct {
// OutputDir is the directory where generated files are written.
OutputDir string
// PkgName is the Go package name for generated files.
PkgName string
// TableFilter is a regex pattern to match table names.
// Example: "(ps_|b2b_).*" matches tables starting with ps_ or b2b_.
TableFilter string
// SelectedTables is a list of specific table names to generate.
// When set, TableFilter is ignored.
SelectedTables []string
}
// defaultConfig returns the default configuration values.
func defaultConfig() GenConfig {
return GenConfig{
OutputDir: "./app/model/dbmodel",
PkgName: "dbmodel",
TableFilter: "ps_.*",
}
}
// GormGen handles GORM model generation with column descriptors.
type GormGen struct {
db *gorm.DB
cfg GenConfig
}
// New creates a new GormGen with default configuration.
func New(db *gorm.DB) *GormGen {
return &GormGen{db: db, cfg: defaultConfig()}
}
// NewWithConfig creates a new GormGen with custom configuration.
func NewWithConfig(db *gorm.DB, cfg GenConfig) *GormGen {
d := defaultConfig()
if cfg.OutputDir != "" {
d.OutputDir = cfg.OutputDir
}
if cfg.PkgName != "" {
d.PkgName = cfg.PkgName
}
if cfg.TableFilter != "" {
d.TableFilter = cfg.TableFilter
}
if len(cfg.SelectedTables) > 0 {
d.SelectedTables = cfg.SelectedTables
}
return &GormGen{db: db, cfg: d}
}
// ConnectDSN opens a MySQL/MariaDB connection from a DSN string.
func ConnectDSN(dsn string) (*gorm.DB, error) {
db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{
Logger: logger.Default.LogMode(logger.Error),
})
if err != nil {
return nil, fmt.Errorf("failed to connect with dsn: %w", err)
}
return db, nil
}
// GenModels generates GORM model files and column descriptors for matched tables.
// It cleans the output directory, generates models using gorm.io/gen,
// and appends <Model>Cols variables with type-safe Field descriptors.
func (m *GormGen) GenModels(ctx context.Context) error {
if err := m.cleanOutputDir(); err != nil {
return fmt.Errorf("failed to clean output dir: %w", err)
}
g := gen.NewGenerator(gen.Config{
OutPath: m.cfg.OutputDir,
ModelPkgPath: m.cfg.PkgName,
FieldNullable: true,
FieldWithIndexTag: true,
})
g.UseDB(m.db)
tableNames, err := m.db.Migrator().GetTables()
if err != nil {
return fmt.Errorf("failed to get table list: %w", err)
}
var matched int
var re *regexp.Regexp
if len(m.cfg.SelectedTables) > 0 {
tableSet := make(map[string]bool)
for _, t := range m.cfg.SelectedTables {
tableSet[t] = true
}
for _, tableName := range tableNames {
if tableSet[tableName] {
g.GenerateModel(tableName)
matched++
}
}
fmt.Printf("Selected %d tables\n", matched)
} else {
re, err = regexp.Compile("^" + m.cfg.TableFilter + "$")
if err != nil {
return fmt.Errorf("invalid table filter regex %q: %w", m.cfg.TableFilter, err)
}
for _, tableName := range tableNames {
if re.MatchString(tableName) {
g.GenerateModel(tableName)
matched++
}
}
fmt.Printf("Matched %d tables with filter %q\n", matched, m.cfg.TableFilter)
}
g.Execute()
if err := m.cleanupGeneratedFiles(); err != nil {
return fmt.Errorf("failed to cleanup generated files: %w", err)
}
if err := m.generateCols(); err != nil {
return fmt.Errorf("failed to generate column descriptors: %w", err)
}
return nil
}
// cleanOutputDir removes existing .go files from the output directory
// or creates it if it doesn't exist.
func (m *GormGen) cleanOutputDir() error {
dir := m.cfg.OutputDir
if !strings.HasPrefix(dir, "./") {
dir = "./" + dir
}
absDir, err := filepath.Abs(dir)
if err != nil {
return err
}
if _, err := os.Stat(absDir); os.IsNotExist(err) {
if err := os.MkdirAll(absDir, 0755); err != nil {
return fmt.Errorf("failed to create output dir: %w", err)
}
fmt.Printf("Created: %s\n", absDir)
return nil
}
entries, err := os.ReadDir(absDir)
if err != nil {
return err
}
for _, entry := range entries {
if entry.IsDir() {
continue
}
if strings.HasSuffix(entry.Name(), ".go") {
path := filepath.Join(absDir, entry.Name())
if err := os.Remove(path); err != nil {
return fmt.Errorf("failed to remove %s: %w", path, err)
}
fmt.Printf("Removed: %s\n", path)
}
}
return nil
}
// cleanupGeneratedFiles removes gorm.io/gen helper files and renames
// .gen.go files to .go with cleaned content.
func (m *GormGen) cleanupGeneratedFiles() error {
filesToRemove := []string{"gen.go", "do.go", "_gen.go"}
dir := m.cfg.OutputDir
if !strings.HasPrefix(dir, "./") {
dir = "./" + dir
}
absDir, err := filepath.Abs(dir)
if err != nil {
return err
}
for _, fileName := range filesToRemove {
filePath := filepath.Join(absDir, fileName)
if _, err := os.Stat(filePath); err == nil {
if err := os.Remove(filePath); err != nil {
return fmt.Errorf("failed to remove %s: %w", filePath, err)
}
fmt.Printf("Removed: %s\n", filePath)
}
}
files, err := os.ReadDir(absDir)
if err != nil {
return err
}
var re *regexp.Regexp
if len(m.cfg.SelectedTables) > 0 {
pattern := "^(" + strings.Join(m.cfg.SelectedTables, "|") + ")\\.gen\\.go$"
re, err = regexp.Compile(pattern)
} else {
re, err = regexp.Compile("^(" + m.cfg.TableFilter + ")\\.gen\\.go$")
}
if err != nil {
return err
}
for _, file := range files {
name := file.Name()
if re.MatchString(name) {
oldPath := filepath.Join(absDir, name)
baseName := strings.TrimSuffix(name, ".gen.go")
newPath := filepath.Join(absDir, baseName+".go")
content, err := os.ReadFile(oldPath)
if err != nil {
return err
}
content = m.cleanModelContent(content)
if err := os.WriteFile(newPath, content, 0644); err != nil {
return err
}
if err := os.Remove(oldPath); err != nil {
return err
}
}
}
return nil
}
// cleanModelContent removes gorm.io/gen-specific imports and type declarations
// from the generated model content.
func (m *GormGen) cleanModelContent(content []byte) []byte {
result := string(content)
lines := strings.Split(result, "\n")
var newLines []string
importStarted := false
importEnded := false
for _, line := range lines {
trimmed := strings.TrimSpace(line)
if trimmed == "import (" {
importStarted = true
newLines = append(newLines, line)
continue
}
if importStarted && trimmed == ")" {
importEnded = true
importStarted = false
newLines = append(newLines, line)
continue
}
if importStarted && !importEnded {
if strings.Contains(trimmed, "\"gorm.io/gen\"") ||
strings.Contains(trimmed, "\"gorm.io/gen/field\"") ||
strings.Contains(trimmed, "\"gorm.io/plugin/dbresolver\"") ||
strings.Contains(trimmed, "gen.DO") {
continue
}
}
if strings.Contains(trimmed, "type psCategoryDo struct") ||
strings.HasPrefix(trimmed, "func (p psCategoryDo)") {
continue
}
newLines = append(newLines, line)
}
result = strings.Join(newLines, "\n")
result = strings.ReplaceAll(result, "psCategoryDo", "")
return []byte(result)
}