// Package gormcol provides GORM model generation with type-safe column descriptors. // // This file handles the generation of Cols variables that provide // type-safe column references for use in GORM queries. package gormcol import ( "fmt" "go/ast" "go/parser" "go/token" "os" "path/filepath" "regexp" "strings" "unicode" ) // fieldInfo holds information about a struct field. type fieldInfo struct { GoName string // Go struct field name ColName string // database column name from gorm tag } // structInfo holds information about a parsed Go struct. type structInfo struct { Name string // struct name Table string // table name (derived from struct name using GORM convention) TableConst string // constant name (e.g., "TableNamePsAccess") Fields []fieldInfo // list of fields FilePath string // source file path } // toSnakeCase converts a CamelCase string to snake_case. // E.g., "PsProductShop" -> "ps_product_shop" func toSnakeCase(s string) string { var result strings.Builder for i, r := range s { if unicode.IsUpper(r) { if i > 0 { result.WriteRune('_') } result.WriteRune(unicode.ToLower(r)) } else { result.WriteRune(r) } } return result.String() } // parseGormColumn extracts the column name value from a gorm tag string. // The tag is expected to be in the format "column:name;otherTag" or similar. // Returns the column name if found, empty string otherwise. func parseGormColumn(tag string) string { for _, part := range strings.Split(tag, ";") { part = strings.TrimSpace(part) if strings.HasPrefix(part, "column:") { return strings.TrimPrefix(part, "column:") } } return "" } // parseModelFile parses a Go model source file and extracts struct information. // It uses go/ast to parse the file and extract: // - Struct name from the type declaration // - Table name from the TableName constant // - Field names and their database column names from gorm tags // // Returns nil if no valid struct is found in the file. func parseModelFile(path string) (*structInfo, error) { fset := token.NewFileSet() f, err := parser.ParseFile(fset, path, nil, parser.ParseComments) if err != nil { return nil, err } var si structInfo si.FilePath = path for _, decl := range f.Decls { gd, ok := decl.(*ast.GenDecl) if !ok { continue } if gd.Tok == token.CONST { for _, spec := range gd.Specs { vs, ok := spec.(*ast.ValueSpec) if !ok { continue } for i, name := range vs.Names { if strings.HasPrefix(name.Name, "TableName") { si.TableConst = name.Name if i < len(vs.Values) { if bl, ok := vs.Values[i].(*ast.BasicLit); ok { si.Table = strings.Trim(bl.Value, "\"") } } } } } } if gd.Tok != token.TYPE { continue } for _, spec := range gd.Specs { ts, ok := spec.(*ast.TypeSpec) if !ok { continue } st, ok := ts.Type.(*ast.StructType) if !ok { continue } si.Name = ts.Name.Name for _, field := range st.Fields.List { if len(field.Names) == 0 || field.Tag == nil { continue } goName := field.Names[0].Name tag := strings.Trim(field.Tag.Value, "`") var gormTag string for _, t := range strings.Split(tag, " ") { t = strings.TrimSpace(t) if strings.HasPrefix(t, "gorm:") { gormTag = strings.TrimPrefix(t, "gorm:") gormTag = strings.Trim(gormTag, "\"") break } } colName := parseGormColumn(gormTag) if colName == "" { continue } si.Fields = append(si.Fields, fieldInfo{ GoName: goName, ColName: colName, }) } } } if si.Name == "" { return nil, nil } // If no TableName constant was found, derive table name from struct name using GORM convention. if si.Table == "" { si.Table = toSnakeCase(si.Name) } return &si, nil } // generateColsVarBlock generates the Go source code for a Cols variable. // The generated code defines a struct with Field typed members for each // database column, providing type-safe column references. func generateColsVarBlock(si *structInfo) string { if len(si.Fields) == 0 { return "" } var b strings.Builder b.WriteString(fmt.Sprintf("\nvar %sCols = struct {\n", si.Name)) for _, f := range si.Fields { b.WriteString(fmt.Sprintf("\t%s gormcol.Field\n", f.GoName)) } b.WriteString("}{\n") for _, f := range si.Fields { b.WriteString(fmt.Sprintf("\t%s: gormcol.Field{Column: %q},\n", f.GoName, f.ColName)) } b.WriteString("}\n") return b.String() } // findGoMod searches upward from startDir for a go.mod file. func findGoMod(startDir string) (string, error) { dir := startDir for { path := filepath.Join(dir, "go.mod") if _, err := os.Stat(path); err == nil { return path, nil } parent := filepath.Dir(dir) if parent == dir { return "", fmt.Errorf("go.mod not found from %s", startDir) } dir = parent } } // readModulePath extracts the module path from a go.mod file. func readModulePath(goModPath string) (string, error) { content, err := os.ReadFile(goModPath) if err != nil { return "", err } for _, line := range strings.Split(string(content), "\n") { line = strings.TrimSpace(line) if strings.HasPrefix(line, "module ") { return strings.TrimSpace(strings.TrimPrefix(line, "module ")), nil } } return "", fmt.Errorf("module directive not found in %s", goModPath) } // generateCols appends Cols variables to generated model files. // It parses each .go file in the output directory, extracts struct fields // with gorm column tags, and generates type-safe Field descriptors. func (m *GormGen) generateCols() error { dir := m.cfg.OutputDir if !strings.HasPrefix(dir, "./") { dir = "./" + dir } absDir, err := filepath.Abs(dir) if err != nil { return err } entries, err := os.ReadDir(absDir) if err != nil { return err } goModPath, err := findGoMod(absDir) if err != nil { return err } modulePath, err := readModulePath(goModPath) if err != nil { return err } gormcolImport := fmt.Sprintf("%q", modulePath+"") var fileFilter *regexp.Regexp if len(m.cfg.SelectedTables) > 0 { pattern := "^(" + strings.Join(m.cfg.SelectedTables, "|") + ")$" fileFilter, err = regexp.Compile(pattern) if err != nil { return fmt.Errorf("invalid selected tables pattern: %w", err) } } else { fileFilter, err = regexp.Compile("^" + m.cfg.TableFilter + "$") if err != nil { return fmt.Errorf("invalid table filter regex %q: %w", m.cfg.TableFilter, err) } } for _, entry := range entries { if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".go") { continue } // Match files by the table filter regex (strip .go suffix for matching) fileBase := strings.TrimSuffix(entry.Name(), ".go") if !fileFilter.MatchString(fileBase) { continue } path := filepath.Join(absDir, entry.Name()) si, err := parseModelFile(path) if err != nil { fmt.Fprintf(os.Stderr, "warning: skipping cols for %s: %v\n", entry.Name(), err) continue } if si == nil || len(si.Fields) == 0 { continue } content, err := os.ReadFile(path) if err != nil { return err } fileContent := string(content) if strings.Contains(fileContent, "gormcol.Field{") { continue } if !strings.Contains(fileContent, gormcolImport) { if strings.Contains(fileContent, "import (") { fileContent = strings.Replace(fileContent, "import (", "import (\n\t"+gormcolImport, 1) } else if strings.Contains(fileContent, "package dbmodel") { fileContent = strings.Replace(fileContent, "package dbmodel", "package dbmodel\n\nimport "+gormcolImport, 1) } } colsBlock := generateColsVarBlock(si) fileContent += colsBlock if err := os.WriteFile(path, []byte(fileContent), 0644); err != nil { return err } fmt.Printf("Cols: %s\n", entry.Name()) } return nil }