package genmodels

import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"go/ast"
	"go/format"
	"go/parser"
	"go/token"
	"os"
	"path/filepath"
	"strconv"
	"strings"

	"gorm.io/gen"
	"gorm.io/gorm"
)

const (
	// GENERATED_FILES_FOLDER is the output directory handed to gorm/gen
	// (gen.Config.OutPath); it is also the directory post-processed by
	// cleanupGeneratedFiles.
	GENERATED_FILES_FOLDER = "./app/model/prestadb"
	// GENERATED_MODEL_PKG is handed to gorm/gen as gen.Config.ModelPkgPath.
	GENERATED_MODEL_PKG = "prestadb"
)

// generatedTablePrefix selects the tables that get a model and is stripped
// from the generated model file names.
const generatedTablePrefix = "ps_"

// GormGenModels generates GORM model files from the tables visible through db.
type GormGenModels struct {
	db *gorm.DB
}

// New returns a GormGenModels that reads the database schema through db.
func New(db *gorm.DB) *GormGenModels {
	return &GormGenModels{
		db: db,
	}
}

// GormGenModels generates a model struct for every table whose name starts
// with "ps_", writes the files under GENERATED_FILES_FOLDER and then
// post-processes them with cleanupGeneratedFiles.
//
// ctx is attached to the database session used to list the tables and read
// their columns. Panics raised by gorm/gen while generating are recovered and
// returned as errors.
func (m *GormGenModels) GormGenModels(ctx context.Context) error {
	if err := ctx.Err(); err != nil {
		return err
	}
	db := m.db.WithContext(ctx)

	g := gen.NewGenerator(gen.Config{
		OutPath:      GENERATED_FILES_FOLDER,
		ModelPkgPath: GENERATED_MODEL_PKG,
		Mode:         gen.WithoutContext,
	})
	g.UseDB(db)

	tableNames, err := db.Migrator().GetTables()
	if err != nil {
		return fmt.Errorf("failed to get table list: %w", err)
	}

	err = runRecovering(func() {
		for _, tableName := range tableNames {
			if !strings.HasPrefix(tableName, generatedTablePrefix) {
				continue
			}
			// Plain column mapping only: gen.FieldRelateModel needs a concrete
			// related model struct and a real field name, neither of which is
			// available here, so no relation fields are generated.
			g.GenerateModel(tableName)
		}
		g.Execute()
	})
	if err != nil {
		return fmt.Errorf("failed to generate models: %w", err)
	}

	// Post-process: remove query/DO files and keep only model files.
	if err := m.cleanupGeneratedFiles(); err != nil {
		return fmt.Errorf("failed to cleanup generated files: %w", err)
	}

	return nil
}

// runRecovering calls fn and converts a panic raised inside it into an
// error. gorm/gen's Generator reports failures by panicking instead of
// returning errors, which would otherwise crash the caller of GormGenModels.
func runRecovering(fn func()) (err error) {
	defer func() {
		if r := recover(); r != nil {
			err = fmt.Errorf("gorm gen panicked: %v", r)
		}
	}()
	fn()
	return nil
}

// cleanupGeneratedFiles post-processes GENERATED_FILES_FOLDER: it deletes the
// query/DO files listed below when present, then rewrites every
// "ps_<name>.gen.go" model file as "<name>.go" (see renameModelFile).
func (m *GormGenModels) cleanupGeneratedFiles() error {
	// filepath.Abs resolves relative folders against the working directory
	// and leaves absolute ones untouched.
	absDir, err := filepath.Abs(GENERATED_FILES_FOLDER)
	if err != nil {
		return fmt.Errorf("failed to resolve %s: %w", GENERATED_FILES_FOLDER, err)
	}

	// Files to remove (query/DO code); a missing file is not an error.
	for _, fileName := range []string{"gen.go", "do.go", "_gen.go"} {
		filePath := filepath.Join(absDir, fileName)
		err := os.Remove(filePath)
		switch {
		case err == nil:
			fmt.Printf("Removed: %s\n", filePath)
		case errors.Is(err, os.ErrNotExist):
			// Not generated this run; nothing to remove.
		default:
			return fmt.Errorf("failed to remove %s: %w", filePath, err)
		}
	}

	entries, err := os.ReadDir(absDir)
	if err != nil {
		return fmt.Errorf("failed to list %s: %w", absDir, err)
	}
	for _, entry := range entries {
		if err := m.renameModelFile(absDir, entry); err != nil {
			return err
		}
	}
	return nil
}

// renameModelFile rewrites dir/ps_<name>.gen.go as dir/<name>.go, passing the
// content through cleanModelContent, and removes the original. Entries that
// are directories or do not match that pattern are ignored.
//
// NOTE(review): a table such as "ps_foo_test" would become "foo_test.go",
// which the Go tool treats as a test file; likewise OS/arch suffixes act as
// build constraints. Confirm no such table names exist.
func (m *GormGenModels) renameModelFile(dir string, entry os.DirEntry) error {
	name := entry.Name()
	if entry.IsDir() ||
		!strings.HasPrefix(name, generatedTablePrefix) ||
		!strings.HasSuffix(name, ".gen.go") {
		return nil
	}

	// ps_category.gen.go -> category
	baseName := strings.TrimSuffix(strings.TrimPrefix(name, generatedTablePrefix), ".gen.go")
	if baseName == "" {
		// "ps_.gen.go" would become ".go", which the Go tool ignores.
		return nil
	}

	oldPath := filepath.Join(dir, name)
	newPath := filepath.Join(dir, baseName+".go")

	content, err := os.ReadFile(oldPath)
	if err != nil {
		return fmt.Errorf("failed to read %s: %w", oldPath, err)
	}
	content = m.cleanModelContent(content)

	if err := os.WriteFile(newPath, content, 0644); err != nil {
		return fmt.Errorf("failed to write %s: %w", newPath, err)
	}
	if err := os.Remove(oldPath); err != nil {
		return fmt.Errorf("failed to remove %s: %w", oldPath, err)
	}
	fmt.Printf("Renamed: %s -> %s\n", oldPath, newPath)
	return nil
}

// cleanModelContent strips gen query/DO code from a generated model file.
// Every type whose struct embeds gen.DO is removed together with all methods
// declared on it. The query-only imports (gorm.io/gen, gorm.io/gen/field,
// gorm.io/plugin/dbresolver) are removed once nothing in the file references
// them.
//
// The file is edited through go/ast so whole declarations are removed (never
// a lone header line), and the result is gofmt-formatted. Content that does
// not parse, fails to print, or has nothing to strip is returned unchanged.
func (m *GormGenModels) cleanModelContent(content []byte) []byte {
	fset := token.NewFileSet()
	file, err := parser.ParseFile(fset, "", content, parser.ParseComments)
	if err != nil {
		// Best effort: never risk corrupting a file we cannot parse.
		return content
	}
	// Built before editing so comments of removed nodes can be dropped.
	cmap := ast.NewCommentMap(fset, file, file.Comments)

	// DO code goes first so imports used only by it are seen as unused.
	removedDecls := removeDOCode(file)
	removedImports := removeUnusedQueryImports(file)
	if !removedDecls && !removedImports {
		return content
	}
	file.Comments = cmap.Filter(file).Comments()

	var buf bytes.Buffer
	if err := format.Node(&buf, fset, file); err != nil {
		return content
	}
	return buf.Bytes()
}

// removeDOCode deletes from file every type whose struct embeds gen.DO and
// every method whose receiver is one of those types. It reports whether
// anything was removed.
func removeDOCode(file *ast.File) bool {
	doTypes := make(map[string]bool)
	for _, decl := range file.Decls {
		gd, ok := decl.(*ast.GenDecl)
		if !ok || gd.Tok != token.TYPE {
			continue
		}
		for _, spec := range gd.Specs {
			if ts, ok := spec.(*ast.TypeSpec); ok && embedsGenDO(ts) {
				doTypes[ts.Name.Name] = true
			}
		}
	}
	if len(doTypes) == 0 {
		return false
	}

	kept := file.Decls[:0]
	for _, decl := range file.Decls {
		switch d := decl.(type) {
		case *ast.FuncDecl:
			if doTypes[receiverTypeName(d.Recv)] {
				continue
			}
		case *ast.GenDecl:
			if d.Tok == token.TYPE {
				specs := d.Specs[:0]
				for _, spec := range d.Specs {
					if ts, ok := spec.(*ast.TypeSpec); ok && doTypes[ts.Name.Name] {
						continue
					}
					specs = append(specs, spec)
				}
				d.Specs = specs
				if len(specs) == 0 {
					continue
				}
			}
		}
		kept = append(kept, decl)
	}
	file.Decls = kept
	return true
}

// embedsGenDO reports whether ts is a struct type with an embedded gen.DO
// (or *gen.DO) field.
func embedsGenDO(ts *ast.TypeSpec) bool {
	st, ok := ts.Type.(*ast.StructType)
	if !ok || st.Fields == nil {
		return false
	}
	for _, f := range st.Fields.List {
		if len(f.Names) != 0 {
			continue // named field, not embedded
		}
		typ := f.Type
		if star, ok := typ.(*ast.StarExpr); ok {
			typ = star.X
		}
		sel, ok := typ.(*ast.SelectorExpr)
		if !ok {
			continue
		}
		if pkg, ok := sel.X.(*ast.Ident); ok && pkg.Name == "gen" && sel.Sel.Name == "DO" {
			return true
		}
	}
	return false
}

// receiverTypeName returns the base type name of a method receiver (T for
// both T and *T), or "" for plain functions and unsupported receiver forms.
func receiverTypeName(recv *ast.FieldList) string {
	if recv == nil || len(recv.List) == 0 {
		return ""
	}
	expr := recv.List[0].Type
	if star, ok := expr.(*ast.StarExpr); ok {
		expr = star.X
	}
	if ident, ok := expr.(*ast.Ident); ok {
		return ident.Name
	}
	return ""
}

// removeUnusedQueryImports removes query-only gen imports that file no longer
// references, dropping import declarations left empty. It reports whether
// any import was removed.
func removeUnusedQueryImports(file *ast.File) bool {
	queryOnly := map[string]bool{
		"gorm.io/gen":               true,
		"gorm.io/gen/field":         true,
		"gorm.io/plugin/dbresolver": true,
	}

	// Every identifier used as X in an X.Sel expression: a superset of the
	// package names the file still references, so it can only keep imports.
	qualifiers := make(map[string]bool)
	ast.Inspect(file, func(n ast.Node) bool {
		if sel, ok := n.(*ast.SelectorExpr); ok {
			if id, ok := sel.X.(*ast.Ident); ok {
				qualifiers[id.Name] = true
			}
		}
		return true
	})

	removed := false
	kept := file.Decls[:0]
	for _, decl := range file.Decls {
		gd, ok := decl.(*ast.GenDecl)
		if !ok || gd.Tok != token.IMPORT {
			kept = append(kept, decl)
			continue
		}
		specs := gd.Specs[:0]
		for _, spec := range gd.Specs {
			if is, ok := spec.(*ast.ImportSpec); ok && isUnusedQueryImport(is, queryOnly, qualifiers) {
				removed = true
				continue
			}
			specs = append(specs, spec)
		}
		gd.Specs = specs
		if len(specs) > 0 {
			kept = append(kept, decl)
		}
	}
	file.Decls = kept

	if removed {
		// Keep file.Imports consistent with the remaining import specs.
		file.Imports = file.Imports[:0]
		for _, decl := range file.Decls {
			gd, ok := decl.(*ast.GenDecl)
			if !ok || gd.Tok != token.IMPORT {
				continue
			}
			for _, spec := range gd.Specs {
				if is, ok := spec.(*ast.ImportSpec); ok {
					file.Imports = append(file.Imports, is)
				}
			}
		}
	}
	return removed
}

// isUnusedQueryImport reports whether is imports one of the queryOnly paths
// under a name that does not appear in qualifiers. Dot imports are always
// kept because their uses carry no qualifier; blank imports are removed.
func isUnusedQueryImport(is *ast.ImportSpec, queryOnly, qualifiers map[string]bool) bool {
	importPath, err := strconv.Unquote(is.Path.Value)
	if err != nil || !queryOnly[importPath] {
		return false
	}
	name := importPath[strings.LastIndex(importPath, "/")+1:]
	if is.Name != nil {
		name = is.Name.Name
	}
	switch name {
	case ".":
		return false
	case "_":
		return true
	}
	return !qualifiers[name]
}