312 lines
7.8 KiB
Go
312 lines
7.8 KiB
Go
// Package gormcol provides GORM model generation with type-safe column descriptors.
|
|
//
|
|
// This file handles the generation of <Model>Cols variables that provide
|
|
// type-safe column references for use in GORM queries.
|
|
package gormcol
|
|
|
|
import (
|
|
"fmt"
|
|
"go/ast"
|
|
"go/parser"
|
|
"go/token"
|
|
"os"
|
|
"path/filepath"
|
|
"regexp"
|
|
"strings"
|
|
"unicode"
|
|
)
|
|
|
|
// fieldInfo holds information about a struct field.
|
|
type fieldInfo struct {
|
|
GoName string // Go struct field name
|
|
ColName string // database column name from gorm tag
|
|
}
|
|
|
|
// structInfo holds information about a parsed Go struct.
|
|
type structInfo struct {
|
|
Name string // struct name
|
|
Table string // table name (derived from struct name using GORM convention)
|
|
TableConst string // constant name (e.g., "TableNamePsAccess")
|
|
Fields []fieldInfo // list of fields
|
|
FilePath string // source file path
|
|
}
|
|
|
|
// toSnakeCase converts a CamelCase string to snake_case.
|
|
// E.g., "PsProductShop" -> "ps_product_shop"
|
|
func toSnakeCase(s string) string {
|
|
var result strings.Builder
|
|
for i, r := range s {
|
|
if unicode.IsUpper(r) {
|
|
if i > 0 {
|
|
result.WriteRune('_')
|
|
}
|
|
result.WriteRune(unicode.ToLower(r))
|
|
} else {
|
|
result.WriteRune(r)
|
|
}
|
|
}
|
|
return result.String()
|
|
}
|
|
|
|
// parseGormColumn extracts the column name value from a gorm tag string.
|
|
// The tag is expected to be in the format "column:name;otherTag" or similar.
|
|
// Returns the column name if found, empty string otherwise.
|
|
func parseGormColumn(tag string) string {
|
|
for _, part := range strings.Split(tag, ";") {
|
|
part = strings.TrimSpace(part)
|
|
if strings.HasPrefix(part, "column:") {
|
|
return strings.TrimPrefix(part, "column:")
|
|
}
|
|
}
|
|
return ""
|
|
}
|
|
|
|
// parseModelFile parses a Go model source file and extracts struct information.
|
|
// It uses go/ast to parse the file and extract:
|
|
// - Struct name from the type declaration
|
|
// - Table name from the TableName constant
|
|
// - Field names and their database column names from gorm tags
|
|
//
|
|
// Returns nil if no valid struct is found in the file.
|
|
func parseModelFile(path string) (*structInfo, error) {
|
|
fset := token.NewFileSet()
|
|
f, err := parser.ParseFile(fset, path, nil, parser.ParseComments)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var si structInfo
|
|
si.FilePath = path
|
|
|
|
for _, decl := range f.Decls {
|
|
gd, ok := decl.(*ast.GenDecl)
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
if gd.Tok == token.CONST {
|
|
for _, spec := range gd.Specs {
|
|
vs, ok := spec.(*ast.ValueSpec)
|
|
if !ok {
|
|
continue
|
|
}
|
|
for i, name := range vs.Names {
|
|
if strings.HasPrefix(name.Name, "TableName") {
|
|
si.TableConst = name.Name
|
|
if i < len(vs.Values) {
|
|
if bl, ok := vs.Values[i].(*ast.BasicLit); ok {
|
|
si.Table = strings.Trim(bl.Value, "\"")
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if gd.Tok != token.TYPE {
|
|
continue
|
|
}
|
|
for _, spec := range gd.Specs {
|
|
ts, ok := spec.(*ast.TypeSpec)
|
|
if !ok {
|
|
continue
|
|
}
|
|
st, ok := ts.Type.(*ast.StructType)
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
si.Name = ts.Name.Name
|
|
|
|
for _, field := range st.Fields.List {
|
|
if len(field.Names) == 0 || field.Tag == nil {
|
|
continue
|
|
}
|
|
goName := field.Names[0].Name
|
|
tag := strings.Trim(field.Tag.Value, "`")
|
|
|
|
var gormTag string
|
|
for _, t := range strings.Split(tag, " ") {
|
|
t = strings.TrimSpace(t)
|
|
if strings.HasPrefix(t, "gorm:") {
|
|
gormTag = strings.TrimPrefix(t, "gorm:")
|
|
gormTag = strings.Trim(gormTag, "\"")
|
|
break
|
|
}
|
|
}
|
|
|
|
colName := parseGormColumn(gormTag)
|
|
if colName == "" {
|
|
continue
|
|
}
|
|
|
|
si.Fields = append(si.Fields, fieldInfo{
|
|
GoName: goName,
|
|
ColName: colName,
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
if si.Name == "" {
|
|
return nil, nil
|
|
}
|
|
|
|
// If no TableName constant was found, derive table name from struct name using GORM convention.
|
|
if si.Table == "" {
|
|
si.Table = toSnakeCase(si.Name)
|
|
}
|
|
|
|
return &si, nil
|
|
}
|
|
|
|
// generateColsVarBlock generates the Go source code for a <Model>Cols variable.
|
|
// The generated code defines a struct with Field typed members for each
|
|
// database column, providing type-safe column references.
|
|
func generateColsVarBlock(si *structInfo) string {
|
|
if len(si.Fields) == 0 {
|
|
return ""
|
|
}
|
|
|
|
var b strings.Builder
|
|
b.WriteString(fmt.Sprintf("\nvar %sCols = struct {\n", si.Name))
|
|
for _, f := range si.Fields {
|
|
b.WriteString(fmt.Sprintf("\t%s gormcol.Field\n", f.GoName))
|
|
}
|
|
b.WriteString("}{\n")
|
|
for _, f := range si.Fields {
|
|
b.WriteString(fmt.Sprintf("\t%s: gormcol.Field{Table: (%s{}).TableName(), Column: %q},\n", f.GoName, si.Name, f.ColName))
|
|
}
|
|
b.WriteString("}\n")
|
|
return b.String()
|
|
}
|
|
|
|
// findGoMod searches upward from startDir for a go.mod file.
|
|
func findGoMod(startDir string) (string, error) {
|
|
dir := startDir
|
|
for {
|
|
path := filepath.Join(dir, "go.mod")
|
|
if _, err := os.Stat(path); err == nil {
|
|
return path, nil
|
|
}
|
|
parent := filepath.Dir(dir)
|
|
if parent == dir {
|
|
return "", fmt.Errorf("go.mod not found from %s", startDir)
|
|
}
|
|
dir = parent
|
|
}
|
|
}
|
|
|
|
// readModulePath extracts the module path from a go.mod file.
|
|
func readModulePath(goModPath string) (string, error) {
|
|
content, err := os.ReadFile(goModPath)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
for _, line := range strings.Split(string(content), "\n") {
|
|
line = strings.TrimSpace(line)
|
|
if strings.HasPrefix(line, "module ") {
|
|
return strings.TrimSpace(strings.TrimPrefix(line, "module ")), nil
|
|
}
|
|
}
|
|
return "", fmt.Errorf("module directive not found in %s", goModPath)
|
|
}
|
|
|
|
// generateCols appends <Model>Cols variables to generated model files.
|
|
// It parses each .go file in the output directory, extracts struct fields
|
|
// with gorm column tags, and generates type-safe Field descriptors.
|
|
func (m *GormGen) generateCols() error {
|
|
dir := m.cfg.OutputDir
|
|
if !strings.HasPrefix(dir, "./") {
|
|
dir = "./" + dir
|
|
}
|
|
|
|
absDir, err := filepath.Abs(dir)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
entries, err := os.ReadDir(absDir)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
goModPath, err := findGoMod(absDir)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
modulePath, err := readModulePath(goModPath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
gormcolImport := fmt.Sprintf("%q", modulePath+"")
|
|
|
|
var fileFilter *regexp.Regexp
|
|
if len(m.cfg.SelectedTables) > 0 {
|
|
pattern := "^(" + strings.Join(m.cfg.SelectedTables, "|") + ")$"
|
|
fileFilter, err = regexp.Compile(pattern)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid selected tables pattern: %w", err)
|
|
}
|
|
} else {
|
|
fileFilter, err = regexp.Compile("^" + m.cfg.TableFilter + "$")
|
|
if err != nil {
|
|
return fmt.Errorf("invalid table filter regex %q: %w", m.cfg.TableFilter, err)
|
|
}
|
|
}
|
|
|
|
for _, entry := range entries {
|
|
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".go") {
|
|
continue
|
|
}
|
|
|
|
// Match files by the table filter regex (strip .go suffix for matching)
|
|
fileBase := strings.TrimSuffix(entry.Name(), ".go")
|
|
if !fileFilter.MatchString(fileBase) {
|
|
continue
|
|
}
|
|
|
|
path := filepath.Join(absDir, entry.Name())
|
|
|
|
si, err := parseModelFile(path)
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "warning: skipping cols for %s: %v\n", entry.Name(), err)
|
|
continue
|
|
}
|
|
if si == nil || len(si.Fields) == 0 {
|
|
continue
|
|
}
|
|
|
|
content, err := os.ReadFile(path)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
fileContent := string(content)
|
|
|
|
if strings.Contains(fileContent, "gormcol.Field{") {
|
|
continue
|
|
}
|
|
|
|
if !strings.Contains(fileContent, gormcolImport) {
|
|
if strings.Contains(fileContent, "import (") {
|
|
fileContent = strings.Replace(fileContent, "import (", "import (\n\t"+gormcolImport, 1)
|
|
} else if strings.Contains(fileContent, "package dbmodel") {
|
|
fileContent = strings.Replace(fileContent, "package dbmodel",
|
|
"package dbmodel\n\nimport "+gormcolImport, 1)
|
|
}
|
|
}
|
|
|
|
colsBlock := generateColsVarBlock(si)
|
|
fileContent += colsBlock
|
|
|
|
if err := os.WriteFile(path, []byte(fileContent), 0644); err != nil {
|
|
return err
|
|
}
|
|
|
|
fmt.Printf("Cols: %s\n", entry.Name())
|
|
}
|
|
|
|
return nil
|
|
}
|