initial
This commit is contained in:
@@ -0,0 +1,604 @@
|
||||
package gormcol
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/gen"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
type GenConfig struct {
|
||||
OutputDir string
|
||||
PkgName string
|
||||
TableFilter string // regex pattern, e.g. "(ps_|b2b_).*"
|
||||
}
|
||||
|
||||
func defaultConfig() GenConfig {
|
||||
return GenConfig{
|
||||
OutputDir: "./app/model/dbmodel",
|
||||
PkgName: "dbmodel",
|
||||
TableFilter: "ps_.*",
|
||||
}
|
||||
}
|
||||
|
||||
type GormGen struct {
|
||||
db *gorm.DB
|
||||
cfg GenConfig
|
||||
}
|
||||
|
||||
func New(db *gorm.DB) *GormGen {
|
||||
return &GormGen{db: db, cfg: defaultConfig()}
|
||||
}
|
||||
|
||||
func NewWithConfig(db *gorm.DB, cfg GenConfig) *GormGen {
|
||||
d := defaultConfig()
|
||||
if cfg.OutputDir != "" {
|
||||
d.OutputDir = cfg.OutputDir
|
||||
}
|
||||
if cfg.PkgName != "" {
|
||||
d.PkgName = cfg.PkgName
|
||||
}
|
||||
if cfg.TableFilter != "" {
|
||||
d.TableFilter = cfg.TableFilter
|
||||
}
|
||||
return &GormGen{db: db, cfg: d}
|
||||
}
|
||||
|
||||
// ConnectDSN opens a MySQL/MariaDB connection from a DSN string.
|
||||
func ConnectDSN(dsn string) (*gorm.DB, error) {
|
||||
db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Error),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect with dsn: %w", err)
|
||||
}
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func (m *GormGen) GenModels(ctx context.Context) error {
|
||||
re, err := regexp.Compile("^" + m.cfg.TableFilter + "$")
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid table filter regex %q: %w", m.cfg.TableFilter, err)
|
||||
}
|
||||
|
||||
if err := m.cleanOutputDir(); err != nil {
|
||||
return fmt.Errorf("failed to clean output dir: %w", err)
|
||||
}
|
||||
|
||||
g := gen.NewGenerator(gen.Config{
|
||||
OutPath: m.cfg.OutputDir,
|
||||
ModelPkgPath: m.cfg.PkgName,
|
||||
FieldNullable: true,
|
||||
FieldWithIndexTag: true,
|
||||
})
|
||||
|
||||
g.UseDB(m.db)
|
||||
|
||||
tableNames, err := m.db.Migrator().GetTables()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get table list: %w", err)
|
||||
}
|
||||
|
||||
matched := 0
|
||||
for _, tableName := range tableNames {
|
||||
if re.MatchString(tableName) {
|
||||
g.GenerateModel(tableName)
|
||||
matched++
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Printf("Matched %d tables with filter %q\n", matched, m.cfg.TableFilter)
|
||||
|
||||
g.Execute()
|
||||
|
||||
if err := m.cleanupGeneratedFiles(); err != nil {
|
||||
return fmt.Errorf("failed to cleanup generated files: %w", err)
|
||||
}
|
||||
|
||||
if err := m.generateCols(re); err != nil {
|
||||
return fmt.Errorf("failed to generate column descriptors: %w", err)
|
||||
}
|
||||
|
||||
belongsTo, many2many, err := m.fetchRelations()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch relations: %w", err)
|
||||
}
|
||||
fmt.Printf("Found %d belongs_to and %d many2many relations\n", len(belongsTo), len(many2many))
|
||||
|
||||
if err := m.updateModelsWithAssociations(belongsTo, many2many, re); err != nil {
|
||||
return fmt.Errorf("failed to update models with associations: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *GormGen) cleanOutputDir() error {
|
||||
dir := m.cfg.OutputDir
|
||||
if !strings.HasPrefix(dir, "./") {
|
||||
dir = "./" + dir
|
||||
}
|
||||
|
||||
absDir, err := filepath.Abs(dir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := os.Stat(absDir); os.IsNotExist(err) {
|
||||
if err := os.MkdirAll(absDir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create output dir: %w", err)
|
||||
}
|
||||
fmt.Printf("Created: %s\n", absDir)
|
||||
return nil
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(absDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
if strings.HasSuffix(entry.Name(), ".go") {
|
||||
path := filepath.Join(absDir, entry.Name())
|
||||
if err := os.Remove(path); err != nil {
|
||||
return fmt.Errorf("failed to remove %s: %w", path, err)
|
||||
}
|
||||
fmt.Printf("Removed: %s\n", path)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *GormGen) cleanupGeneratedFiles() error {
|
||||
filesToRemove := []string{"gen.go", "do.go", "_gen.go"}
|
||||
|
||||
dir := m.cfg.OutputDir
|
||||
if !strings.HasPrefix(dir, "./") {
|
||||
dir = "./" + dir
|
||||
}
|
||||
|
||||
absDir, err := filepath.Abs(dir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, fileName := range filesToRemove {
|
||||
filePath := filepath.Join(absDir, fileName)
|
||||
if _, err := os.Stat(filePath); err == nil {
|
||||
if err := os.Remove(filePath); err != nil {
|
||||
return fmt.Errorf("failed to remove %s: %w", filePath, err)
|
||||
}
|
||||
fmt.Printf("Removed: %s\n", filePath)
|
||||
}
|
||||
}
|
||||
|
||||
files, err := os.ReadDir(absDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
re, err := regexp.Compile("^(" + m.cfg.TableFilter + ")\\.gen\\.go$")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, file := range files {
|
||||
name := file.Name()
|
||||
if re.MatchString(name) {
|
||||
oldPath := filepath.Join(absDir, name)
|
||||
baseName := strings.TrimSuffix(name, ".gen.go")
|
||||
newPath := filepath.Join(absDir, baseName+".go")
|
||||
|
||||
content, err := os.ReadFile(oldPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
content = m.cleanModelContent(content)
|
||||
|
||||
if err := os.WriteFile(newPath, content, 0644); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := os.Remove(oldPath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Printf("Renamed: %s -> %s\n", oldPath, newPath)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type associationInfo struct {
|
||||
TableName string
|
||||
FKColumn string
|
||||
RefTable string
|
||||
RefColumn string
|
||||
AssociationName string
|
||||
AssociationType string
|
||||
JoinTableName string
|
||||
}
|
||||
|
||||
func (m *GormGen) fetchRelations() (map[string][]associationInfo, map[string][]associationInfo, error) {
|
||||
belongsTo := make(map[string][]associationInfo)
|
||||
many2many := make(map[string][]associationInfo)
|
||||
|
||||
query := `
|
||||
SELECT
|
||||
kcu.TABLE_NAME,
|
||||
kcu.COLUMN_NAME,
|
||||
kcu.REFERENCED_TABLE_NAME,
|
||||
kcu.REFERENCED_COLUMN_NAME,
|
||||
kcu.CONSTRAINT_NAME
|
||||
FROM information_schema.KEY_COLUMN_USAGE kcu
|
||||
JOIN information_schema.TABLE_CONSTRAINTS tc
|
||||
ON kcu.CONSTRAINT_NAME = tc.CONSTRAINT_NAME
|
||||
AND kcu.TABLE_SCHEMA = tc.TABLE_SCHEMA
|
||||
AND kcu.TABLE_NAME = tc.TABLE_NAME
|
||||
WHERE tc.CONSTRAINT_TYPE = 'FOREIGN KEY'
|
||||
AND kcu.TABLE_SCHEMA = DATABASE()
|
||||
`
|
||||
|
||||
rows, err := m.db.Raw(query).Rows()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to query foreign keys: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
tableFKs := make(map[string][]associationInfo)
|
||||
for rows.Next() {
|
||||
var tableName, columnName, refTable, refCol, constraintName string
|
||||
if err := rows.Scan(&tableName, &columnName, &refTable, &refCol, &constraintName); err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to scan foreign key row: %w", err)
|
||||
}
|
||||
|
||||
refModelName := tableNameToModelName(refTable)
|
||||
assocName := singularize(refModelName)
|
||||
|
||||
info := associationInfo{
|
||||
TableName: tableName,
|
||||
FKColumn: columnName,
|
||||
RefTable: refTable,
|
||||
RefColumn: refCol,
|
||||
AssociationName: assocName,
|
||||
AssociationType: "belongs_to",
|
||||
}
|
||||
tableFKs[tableName] = append(tableFKs[tableName], info)
|
||||
}
|
||||
|
||||
for tableName, fks := range tableFKs {
|
||||
if isMany2Many(tableName, fks) {
|
||||
for _, fk := range fks {
|
||||
refModelName := tableNameToModelName(fk.RefTable)
|
||||
assocName := pluralize(refModelName)
|
||||
joinTable := fk.TableName + "_" + fk.RefTable
|
||||
if fk.TableName > fk.RefTable {
|
||||
joinTable = fk.RefTable + "_" + fk.TableName
|
||||
}
|
||||
m2mInfo := associationInfo{
|
||||
TableName: tableName,
|
||||
FKColumn: fk.FKColumn,
|
||||
RefTable: fk.RefTable,
|
||||
RefColumn: fk.RefColumn,
|
||||
AssociationName: assocName,
|
||||
AssociationType: "many2many",
|
||||
JoinTableName: joinTable,
|
||||
}
|
||||
many2many[fk.RefTable] = append(many2many[fk.RefTable], m2mInfo)
|
||||
many2many[tableName] = append(many2many[tableName], m2mInfo)
|
||||
}
|
||||
} else {
|
||||
belongsTo[tableName] = fks
|
||||
}
|
||||
}
|
||||
|
||||
return belongsTo, many2many, nil
|
||||
}
|
||||
|
||||
func isMany2Many(tableName string, fks []associationInfo) bool {
|
||||
if len(fks) != 2 {
|
||||
return false
|
||||
}
|
||||
|
||||
refTables := make(map[string]bool)
|
||||
for _, fk := range fks {
|
||||
refTables[fk.RefTable] = true
|
||||
}
|
||||
|
||||
if len(refTables) != 2 {
|
||||
return false
|
||||
}
|
||||
|
||||
tableLower := strings.ToLower(tableName)
|
||||
|
||||
if strings.HasSuffix(tableLower, "_product") && strings.Contains(tableLower, "_") {
|
||||
return true
|
||||
}
|
||||
|
||||
if strings.HasSuffix(tableLower, "_category") && strings.Contains(tableLower, "_") {
|
||||
return true
|
||||
}
|
||||
|
||||
if strings.HasSuffix(tableLower, "_cart") && strings.Contains(tableLower, "_") {
|
||||
return true
|
||||
}
|
||||
|
||||
joinTableSuffixes := []string{
|
||||
"_rel", "_link", "_map", "_junction", "_join",
|
||||
"_products", "_categories", "_carts",
|
||||
}
|
||||
for _, suffix := range joinTableSuffixes {
|
||||
if strings.HasSuffix(tableLower, suffix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
parts := strings.Split(tableLower, "_")
|
||||
if len(parts) >= 3 {
|
||||
hasIdColumn := false
|
||||
for _, fk := range fks {
|
||||
if strings.Contains(strings.ToLower(fk.FKColumn), "id") {
|
||||
hasIdColumn = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if hasIdColumn && len(fks) == 2 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func pluralize(word string) string {
|
||||
wordLower := strings.ToLower(word)
|
||||
if strings.HasSuffix(wordLower, "y") && len(word) > 1 {
|
||||
return word[:len(word)-1] + "ies"
|
||||
}
|
||||
if strings.HasSuffix(wordLower, "s") || strings.HasSuffix(wordLower, "x") ||
|
||||
strings.HasSuffix(wordLower, "z") || strings.HasSuffix(wordLower, "ch") ||
|
||||
strings.HasSuffix(wordLower, "sh") {
|
||||
return word + "es"
|
||||
}
|
||||
return word + "s"
|
||||
}
|
||||
|
||||
func (m *GormGen) updateModelsWithAssociations(belongsTo, many2many map[string][]associationInfo, fileFilter *regexp.Regexp) error {
|
||||
dir := m.cfg.OutputDir
|
||||
if !strings.HasPrefix(dir, "./") {
|
||||
dir = "./" + dir
|
||||
}
|
||||
|
||||
absDir, err := filepath.Abs(dir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(absDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
allAssocs := make(map[string][]associationInfo)
|
||||
for table, assocs := range belongsTo {
|
||||
allAssocs[table] = append(allAssocs[table], assocs...)
|
||||
}
|
||||
for table, assocs := range many2many {
|
||||
allAssocs[table] = append(allAssocs[table], assocs...)
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".go") {
|
||||
continue
|
||||
}
|
||||
|
||||
fileBase := strings.TrimSuffix(entry.Name(), ".go")
|
||||
if !fileFilter.MatchString(fileBase) {
|
||||
continue
|
||||
}
|
||||
|
||||
tableAssocs, ok := allAssocs[fileBase]
|
||||
if !ok || len(tableAssocs) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
path := filepath.Join(absDir, entry.Name())
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read %s: %w", entry.Name(), err)
|
||||
}
|
||||
|
||||
contentStr := string(content)
|
||||
|
||||
var newFields []string
|
||||
for _, assoc := range tableAssocs {
|
||||
refModelType := tableNameToModelName(assoc.RefTable)
|
||||
fieldName := assoc.AssociationName
|
||||
|
||||
if assoc.AssociationType == "many2many" {
|
||||
newFields = append(newFields, fmt.Sprintf("\t%s []*%s `gorm:\"many2many:%s\"`", fieldName, refModelType, assoc.JoinTableName))
|
||||
} else {
|
||||
fkColumn := snakeToPascal(assoc.FKColumn)
|
||||
refColumn := snakeToPascal(assoc.RefColumn)
|
||||
newFields = append(newFields, fmt.Sprintf("\t%s *%s `gorm:\"foreignKey:%s;references:%s\"`", fieldName, refModelType, fkColumn, refColumn))
|
||||
}
|
||||
}
|
||||
|
||||
if len(newFields) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
insertPos := findStructEndPos(contentStr)
|
||||
if insertPos < 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
newContent := contentStr[:insertPos] + strings.Join(newFields, "\n") + "\n" + contentStr[insertPos:]
|
||||
|
||||
if err := os.WriteFile(path, []byte(newContent), 0644); err != nil {
|
||||
return fmt.Errorf("failed to write %s: %w", entry.Name(), err)
|
||||
}
|
||||
fmt.Printf("Added associations: %s\n", entry.Name())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func tableNameToJoinTableName(table1, table2 string) string {
|
||||
tables := []string{table1, table2}
|
||||
sort.Strings(tables)
|
||||
return tables[0] + "_" + tables[1]
|
||||
}
|
||||
|
||||
func snakeToPascal(s string) string {
|
||||
parts := strings.Split(s, "_")
|
||||
var result []string
|
||||
for _, part := range parts {
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
lower := strings.ToLower(part)
|
||||
if lower == "id" {
|
||||
result = append(result, "ID")
|
||||
} else if lower == "url" || lower == "api" || lower == "http" || lower == "tcp" || lower == "udp" {
|
||||
result = append(result, strings.ToUpper(part))
|
||||
} else {
|
||||
result = append(result, strings.ToUpper(part[:1])+part[1:])
|
||||
}
|
||||
}
|
||||
return strings.Join(result, "")
|
||||
}
|
||||
|
||||
func findStructEndPos(content string) int {
|
||||
lines := strings.Split(content, "\n")
|
||||
structStart := -1
|
||||
|
||||
for i, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if strings.HasPrefix(trimmed, "type ") && strings.Contains(trimmed, " struct {") {
|
||||
parts := strings.Split(trimmed, " ")
|
||||
if len(parts) >= 3 && parts[2] == "struct" {
|
||||
structStart = i
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if structStart < 0 {
|
||||
return -1
|
||||
}
|
||||
|
||||
depth := 0
|
||||
for i := structStart; i < len(lines); i++ {
|
||||
line := lines[i]
|
||||
for _, ch := range line {
|
||||
if ch == '{' {
|
||||
depth++
|
||||
} else if ch == '}' {
|
||||
depth--
|
||||
if depth == 0 {
|
||||
lineStart := 0
|
||||
for j := 0; j < i; j++ {
|
||||
lineStart += len(lines[j]) + 1
|
||||
}
|
||||
return lineStart + strings.Index(line, "}")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return -1
|
||||
}
|
||||
|
||||
func tableNameToModelName(tableName string) string {
|
||||
parts := strings.Split(tableName, "_")
|
||||
var result []string
|
||||
for i, part := range parts {
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
if i == len(parts)-1 {
|
||||
part = singularize(part)
|
||||
}
|
||||
result = append(result, strings.ToUpper(part[:1])+part[1:])
|
||||
}
|
||||
return strings.Join(result, "")
|
||||
}
|
||||
|
||||
func singularize(word string) string {
|
||||
if strings.HasSuffix(word, "ies") {
|
||||
return word[:len(word)-3] + "Y"
|
||||
}
|
||||
if strings.HasSuffix(word, "es") && len(word) > 2 {
|
||||
if strings.HasSuffix(word, "ses") {
|
||||
return word[:len(word)-2]
|
||||
}
|
||||
return word[:len(word)-1]
|
||||
}
|
||||
if strings.HasSuffix(word, "s") && len(word) > 1 && !strings.HasSuffix(word, "ss") {
|
||||
if strings.HasSuffix(word, "us") {
|
||||
return word
|
||||
}
|
||||
return word[:len(word)-1]
|
||||
}
|
||||
return word
|
||||
}
|
||||
|
||||
func (m *GormGen) cleanModelContent(content []byte) []byte {
|
||||
result := string(content)
|
||||
|
||||
lines := strings.Split(result, "\n")
|
||||
var newLines []string
|
||||
importStarted := false
|
||||
importEnded := false
|
||||
|
||||
for _, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
|
||||
if trimmed == "import (" {
|
||||
importStarted = true
|
||||
newLines = append(newLines, line)
|
||||
continue
|
||||
}
|
||||
if importStarted && trimmed == ")" {
|
||||
importEnded = true
|
||||
importStarted = false
|
||||
newLines = append(newLines, line)
|
||||
continue
|
||||
}
|
||||
|
||||
if importStarted && !importEnded {
|
||||
if strings.Contains(trimmed, "\"gorm.io/gen\"") ||
|
||||
strings.Contains(trimmed, "\"gorm.io/gen/field\"") ||
|
||||
strings.Contains(trimmed, "\"gorm.io/plugin/dbresolver\"") ||
|
||||
strings.Contains(trimmed, "gen.DO") {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if strings.Contains(trimmed, "type psCategoryDo struct") ||
|
||||
strings.HasPrefix(trimmed, "func (p psCategoryDo)") {
|
||||
continue
|
||||
}
|
||||
|
||||
newLines = append(newLines, line)
|
||||
}
|
||||
|
||||
result = strings.Join(newLines, "\n")
|
||||
result = strings.ReplaceAll(result, "psCategoryDo", "")
|
||||
|
||||
return []byte(result)
|
||||
}
|
||||
Reference in New Issue
Block a user