package gormcol import ( "context" "fmt" "os" "path/filepath" "regexp" "sort" "strings" "gorm.io/driver/mysql" "gorm.io/gen" "gorm.io/gorm" "gorm.io/gorm/logger" ) type GenConfig struct { OutputDir string PkgName string TableFilter string // regex pattern, e.g. "(ps_|b2b_).*" } func defaultConfig() GenConfig { return GenConfig{ OutputDir: "./app/model/dbmodel", PkgName: "dbmodel", TableFilter: "ps_.*", } } type GormGen struct { db *gorm.DB cfg GenConfig } func New(db *gorm.DB) *GormGen { return &GormGen{db: db, cfg: defaultConfig()} } func NewWithConfig(db *gorm.DB, cfg GenConfig) *GormGen { d := defaultConfig() if cfg.OutputDir != "" { d.OutputDir = cfg.OutputDir } if cfg.PkgName != "" { d.PkgName = cfg.PkgName } if cfg.TableFilter != "" { d.TableFilter = cfg.TableFilter } return &GormGen{db: db, cfg: d} } // ConnectDSN opens a MySQL/MariaDB connection from a DSN string. func ConnectDSN(dsn string) (*gorm.DB, error) { db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{ Logger: logger.Default.LogMode(logger.Error), }) if err != nil { return nil, fmt.Errorf("failed to connect with dsn: %w", err) } return db, nil } func (m *GormGen) GenModels(ctx context.Context) error { re, err := regexp.Compile("^" + m.cfg.TableFilter + "$") if err != nil { return fmt.Errorf("invalid table filter regex %q: %w", m.cfg.TableFilter, err) } if err := m.cleanOutputDir(); err != nil { return fmt.Errorf("failed to clean output dir: %w", err) } g := gen.NewGenerator(gen.Config{ OutPath: m.cfg.OutputDir, ModelPkgPath: m.cfg.PkgName, FieldNullable: true, FieldWithIndexTag: true, }) g.UseDB(m.db) tableNames, err := m.db.Migrator().GetTables() if err != nil { return fmt.Errorf("failed to get table list: %w", err) } matched := 0 for _, tableName := range tableNames { if re.MatchString(tableName) { g.GenerateModel(tableName) matched++ } } fmt.Printf("Matched %d tables with filter %q\n", matched, m.cfg.TableFilter) g.Execute() if err := m.cleanupGeneratedFiles(); err != nil { return fmt.Errorf("failed to cleanup generated files: %w", err) } if err := m.generateCols(re); err != nil { return fmt.Errorf("failed to generate column descriptors: %w", err) } belongsTo, many2many, err := m.fetchRelations() if err != nil { return fmt.Errorf("failed to fetch relations: %w", err) } fmt.Printf("Found %d belongs_to and %d many2many relations\n", len(belongsTo), len(many2many)) if err := m.updateModelsWithAssociations(belongsTo, many2many, re); err != nil { return fmt.Errorf("failed to update models with associations: %w", err) } return nil } func (m *GormGen) cleanOutputDir() error { dir := m.cfg.OutputDir if !strings.HasPrefix(dir, "./") { dir = "./" + dir } absDir, err := filepath.Abs(dir) if err != nil { return err } if _, err := os.Stat(absDir); os.IsNotExist(err) { if err := os.MkdirAll(absDir, 0755); err != nil { return fmt.Errorf("failed to create output dir: %w", err) } fmt.Printf("Created: %s\n", absDir) return nil } entries, err := os.ReadDir(absDir) if err != nil { return err } for _, entry := range entries { if entry.IsDir() { continue } if strings.HasSuffix(entry.Name(), ".go") { path := filepath.Join(absDir, entry.Name()) if err := os.Remove(path); err != nil { return fmt.Errorf("failed to remove %s: %w", path, err) } fmt.Printf("Removed: %s\n", path) } } return nil } func (m *GormGen) cleanupGeneratedFiles() error { filesToRemove := []string{"gen.go", "do.go", "_gen.go"} dir := m.cfg.OutputDir if !strings.HasPrefix(dir, "./") { dir = "./" + dir } absDir, err := filepath.Abs(dir) if err != nil { return err } for _, fileName := range filesToRemove { filePath := filepath.Join(absDir, fileName) if _, err := os.Stat(filePath); err == nil { if err := os.Remove(filePath); err != nil { return fmt.Errorf("failed to remove %s: %w", filePath, err) } fmt.Printf("Removed: %s\n", filePath) } } files, err := os.ReadDir(absDir) if err != nil { return err } re, err := regexp.Compile("^(" + m.cfg.TableFilter + ")\\.gen\\.go$") if err != nil { return err } for _, file := range files { name := file.Name() if re.MatchString(name) { oldPath := filepath.Join(absDir, name) baseName := strings.TrimSuffix(name, ".gen.go") newPath := filepath.Join(absDir, baseName+".go") content, err := os.ReadFile(oldPath) if err != nil { return err } content = m.cleanModelContent(content) if err := os.WriteFile(newPath, content, 0644); err != nil { return err } if err := os.Remove(oldPath); err != nil { return err } fmt.Printf("Renamed: %s -> %s\n", oldPath, newPath) } } return nil } type associationInfo struct { TableName string FKColumn string RefTable string RefColumn string AssociationName string AssociationType string JoinTableName string } func (m *GormGen) fetchRelations() (map[string][]associationInfo, map[string][]associationInfo, error) { belongsTo := make(map[string][]associationInfo) many2many := make(map[string][]associationInfo) query := ` SELECT kcu.TABLE_NAME, kcu.COLUMN_NAME, kcu.REFERENCED_TABLE_NAME, kcu.REFERENCED_COLUMN_NAME, kcu.CONSTRAINT_NAME FROM information_schema.KEY_COLUMN_USAGE kcu JOIN information_schema.TABLE_CONSTRAINTS tc ON kcu.CONSTRAINT_NAME = tc.CONSTRAINT_NAME AND kcu.TABLE_SCHEMA = tc.TABLE_SCHEMA AND kcu.TABLE_NAME = tc.TABLE_NAME WHERE tc.CONSTRAINT_TYPE = 'FOREIGN KEY' AND kcu.TABLE_SCHEMA = DATABASE() ` rows, err := m.db.Raw(query).Rows() if err != nil { return nil, nil, fmt.Errorf("failed to query foreign keys: %w", err) } defer rows.Close() tableFKs := make(map[string][]associationInfo) for rows.Next() { var tableName, columnName, refTable, refCol, constraintName string if err := rows.Scan(&tableName, &columnName, &refTable, &refCol, &constraintName); err != nil { return nil, nil, fmt.Errorf("failed to scan foreign key row: %w", err) } refModelName := tableNameToModelName(refTable) assocName := singularize(refModelName) info := associationInfo{ TableName: tableName, FKColumn: columnName, RefTable: refTable, RefColumn: refCol, AssociationName: assocName, AssociationType: "belongs_to", } tableFKs[tableName] = append(tableFKs[tableName], info) } for tableName, fks := range tableFKs { if isMany2Many(tableName, fks) { for _, fk := range fks { refModelName := tableNameToModelName(fk.RefTable) assocName := pluralize(refModelName) joinTable := fk.TableName + "_" + fk.RefTable if fk.TableName > fk.RefTable { joinTable = fk.RefTable + "_" + fk.TableName } m2mInfo := associationInfo{ TableName: tableName, FKColumn: fk.FKColumn, RefTable: fk.RefTable, RefColumn: fk.RefColumn, AssociationName: assocName, AssociationType: "many2many", JoinTableName: joinTable, } many2many[fk.RefTable] = append(many2many[fk.RefTable], m2mInfo) many2many[tableName] = append(many2many[tableName], m2mInfo) } } else { belongsTo[tableName] = fks } } return belongsTo, many2many, nil } func isMany2Many(tableName string, fks []associationInfo) bool { if len(fks) != 2 { return false } refTables := make(map[string]bool) for _, fk := range fks { refTables[fk.RefTable] = true } if len(refTables) != 2 { return false } tableLower := strings.ToLower(tableName) if strings.HasSuffix(tableLower, "_product") && strings.Contains(tableLower, "_") { return true } if strings.HasSuffix(tableLower, "_category") && strings.Contains(tableLower, "_") { return true } if strings.HasSuffix(tableLower, "_cart") && strings.Contains(tableLower, "_") { return true } joinTableSuffixes := []string{ "_rel", "_link", "_map", "_junction", "_join", "_products", "_categories", "_carts", } for _, suffix := range joinTableSuffixes { if strings.HasSuffix(tableLower, suffix) { return true } } parts := strings.Split(tableLower, "_") if len(parts) >= 3 { hasIdColumn := false for _, fk := range fks { if strings.Contains(strings.ToLower(fk.FKColumn), "id") { hasIdColumn = true break } } if hasIdColumn && len(fks) == 2 { return true } } return false } func pluralize(word string) string { wordLower := strings.ToLower(word) if strings.HasSuffix(wordLower, "y") && len(word) > 1 { return word[:len(word)-1] + "ies" } if strings.HasSuffix(wordLower, "s") || strings.HasSuffix(wordLower, "x") || strings.HasSuffix(wordLower, "z") || strings.HasSuffix(wordLower, "ch") || strings.HasSuffix(wordLower, "sh") { return word + "es" } return word + "s" } func (m *GormGen) updateModelsWithAssociations(belongsTo, many2many map[string][]associationInfo, fileFilter *regexp.Regexp) error { dir := m.cfg.OutputDir if !strings.HasPrefix(dir, "./") { dir = "./" + dir } absDir, err := filepath.Abs(dir) if err != nil { return err } entries, err := os.ReadDir(absDir) if err != nil { return err } allAssocs := make(map[string][]associationInfo) for table, assocs := range belongsTo { allAssocs[table] = append(allAssocs[table], assocs...) } for table, assocs := range many2many { allAssocs[table] = append(allAssocs[table], assocs...) } for _, entry := range entries { if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".go") { continue } fileBase := strings.TrimSuffix(entry.Name(), ".go") if !fileFilter.MatchString(fileBase) { continue } tableAssocs, ok := allAssocs[fileBase] if !ok || len(tableAssocs) == 0 { continue } path := filepath.Join(absDir, entry.Name()) content, err := os.ReadFile(path) if err != nil { return fmt.Errorf("failed to read %s: %w", entry.Name(), err) } contentStr := string(content) var newFields []string for _, assoc := range tableAssocs { refModelType := tableNameToModelName(assoc.RefTable) fieldName := assoc.AssociationName if assoc.AssociationType == "many2many" { newFields = append(newFields, fmt.Sprintf("\t%s []*%s `gorm:\"many2many:%s\"`", fieldName, refModelType, assoc.JoinTableName)) } else { fkColumn := snakeToPascal(assoc.FKColumn) refColumn := snakeToPascal(assoc.RefColumn) newFields = append(newFields, fmt.Sprintf("\t%s *%s `gorm:\"foreignKey:%s;references:%s\"`", fieldName, refModelType, fkColumn, refColumn)) } } if len(newFields) == 0 { continue } insertPos := findStructEndPos(contentStr) if insertPos < 0 { continue } newContent := contentStr[:insertPos] + strings.Join(newFields, "\n") + "\n" + contentStr[insertPos:] if err := os.WriteFile(path, []byte(newContent), 0644); err != nil { return fmt.Errorf("failed to write %s: %w", entry.Name(), err) } fmt.Printf("Added associations: %s\n", entry.Name()) } return nil } func tableNameToJoinTableName(table1, table2 string) string { tables := []string{table1, table2} sort.Strings(tables) return tables[0] + "_" + tables[1] } func snakeToPascal(s string) string { parts := strings.Split(s, "_") var result []string for _, part := range parts { if part == "" { continue } lower := strings.ToLower(part) if lower == "id" { result = append(result, "ID") } else if lower == "url" || lower == "api" || lower == "http" || lower == "tcp" || lower == "udp" { result = append(result, strings.ToUpper(part)) } else { result = append(result, strings.ToUpper(part[:1])+part[1:]) } } return strings.Join(result, "") } func findStructEndPos(content string) int { lines := strings.Split(content, "\n") structStart := -1 for i, line := range lines { trimmed := strings.TrimSpace(line) if strings.HasPrefix(trimmed, "type ") && strings.Contains(trimmed, " struct {") { parts := strings.Split(trimmed, " ") if len(parts) >= 3 && parts[2] == "struct" { structStart = i break } } } if structStart < 0 { return -1 } depth := 0 for i := structStart; i < len(lines); i++ { line := lines[i] for _, ch := range line { if ch == '{' { depth++ } else if ch == '}' { depth-- if depth == 0 { lineStart := 0 for j := 0; j < i; j++ { lineStart += len(lines[j]) + 1 } return lineStart + strings.Index(line, "}") } } } } return -1 } func tableNameToModelName(tableName string) string { parts := strings.Split(tableName, "_") var result []string for i, part := range parts { if part == "" { continue } if i == len(parts)-1 { part = singularize(part) } result = append(result, strings.ToUpper(part[:1])+part[1:]) } return strings.Join(result, "") } func singularize(word string) string { if strings.HasSuffix(word, "ies") { return word[:len(word)-3] + "Y" } if strings.HasSuffix(word, "es") && len(word) > 2 { if strings.HasSuffix(word, "ses") { return word[:len(word)-2] } return word[:len(word)-1] } if strings.HasSuffix(word, "s") && len(word) > 1 && !strings.HasSuffix(word, "ss") { if strings.HasSuffix(word, "us") { return word } return word[:len(word)-1] } return word } func (m *GormGen) cleanModelContent(content []byte) []byte { result := string(content) lines := strings.Split(result, "\n") var newLines []string importStarted := false importEnded := false for _, line := range lines { trimmed := strings.TrimSpace(line) if trimmed == "import (" { importStarted = true newLines = append(newLines, line) continue } if importStarted && trimmed == ")" { importEnded = true importStarted = false newLines = append(newLines, line) continue } if importStarted && !importEnded { if strings.Contains(trimmed, "\"gorm.io/gen\"") || strings.Contains(trimmed, "\"gorm.io/gen/field\"") || strings.Contains(trimmed, "\"gorm.io/plugin/dbresolver\"") || strings.Contains(trimmed, "gen.DO") { continue } } if strings.Contains(trimmed, "type psCategoryDo struct") || strings.HasPrefix(trimmed, "func (p psCategoryDo)") { continue } newLines = append(newLines, line) } result = strings.Join(newLines, "\n") result = strings.ReplaceAll(result, "psCategoryDo", "") return []byte(result) }