Files
replica/pkg/replica/sqlbuilder.go
2026-02-12 04:55:32 +01:00

169 lines
4.3 KiB
Go

package replica
import (
	"fmt"
	"strings"
	"unicode/utf8"

	"github.com/go-mysql-org/go-mysql/replication"
)
// SQLBuilder renders row data as literal SQL statement text (INSERT, UPDATE
// and DELETE). It has no fields, so it carries no state between calls.
type SQLBuilder struct{}
// NewSQLBuilder allocates a zero-value SQLBuilder and returns a pointer to it.
func NewSQLBuilder() *SQLBuilder {
	return new(SQLBuilder)
}
// BuildInsert renders an INSERT statement for a single row.
//
// Column names are resolved by position through getColumnName and every value
// is rendered as a SQL literal by formatValue. An empty row produces
// "INSERT INTO `schema`.`table` VALUES ()".
//
// Backticks embedded in schema or table are doubled so that an unusual
// identifier cannot terminate the quoted name early.
func (sb *SQLBuilder) BuildInsert(schema, table string, tableMap *replication.TableMapEvent, row []interface{}) string {
	schema = strings.ReplaceAll(schema, "`", "``")
	table = strings.ReplaceAll(table, "`", "``")

	if len(row) == 0 {
		return fmt.Sprintf("INSERT INTO `%s`.`%s` VALUES ()", schema, table)
	}

	// Both lists end up with exactly one entry per row value.
	columns := make([]string, 0, len(row))
	values := make([]string, 0, len(row))
	for i, col := range row {
		columns = append(columns, sb.getColumnName(tableMap, i))
		values = append(values, formatValue(col))
	}

	return fmt.Sprintf("INSERT INTO `%s`.`%s` (%s) VALUES (%s)",
		schema, table, strings.Join(columns, ", "), strings.Join(values, ", "))
}
// BuildUpdate renders an UPDATE statement from the before and after images of
// a single row.
//
// Every position whose value differs between before and after (as judged by
// valuesEqual) contributes one assignment to the SET list. The row is located
// by the first column only, using its value from the before image.
// NOTE(review): this presumes column 0 uniquely identifies the row — confirm
// against the tables being replicated.
//
// Edge cases:
//   - If either image is empty, the placeholder "SET id = id" statement is
//     returned (it has no WHERE clause and assumes a column named id).
//   - If the images differ in length, only positions present in both are
//     compared, instead of indexing past the end of the shorter one.
//   - If no column changed, the first column is assigned to itself so the
//     statement stays syntactically valid while changing nothing.
//   - If the first column of the before image is nil, the WHERE clause uses
//     IS NULL, because "col = NULL" never evaluates to true.
//
// Backticks embedded in schema or table are doubled so that an unusual
// identifier cannot terminate the quoted name early.
func (sb *SQLBuilder) BuildUpdate(schema, table string, tableMap *replication.TableMapEvent, before, after []interface{}) string {
	schema = strings.ReplaceAll(schema, "`", "``")
	table = strings.ReplaceAll(table, "`", "``")

	if len(before) == 0 || len(after) == 0 {
		return fmt.Sprintf("UPDATE `%s`.`%s` SET id = id", schema, table)
	}

	// Compare only the positions both images have.
	n := len(before)
	if len(after) < n {
		n = len(after)
	}

	setClauses := make([]string, 0, n)
	for i := 0; i < n; i++ {
		if valuesEqual(before[i], after[i]) {
			continue
		}
		setClauses = append(setClauses,
			fmt.Sprintf("%s = %s", sb.getColumnName(tableMap, i), formatValue(after[i])))
	}

	keyCol := sb.getColumnName(tableMap, 0)
	if len(setClauses) == 0 {
		// Nothing changed: emit a harmless self-assignment rather than an
		// empty SET list, which would be a syntax error.
		setClauses = append(setClauses, fmt.Sprintf("%s = %s", keyCol, keyCol))
	}

	var where string
	if before[0] == nil {
		where = keyCol + " IS NULL"
	} else {
		where = fmt.Sprintf("%s = %s", keyCol, formatValue(before[0]))
	}

	return fmt.Sprintf("UPDATE `%s`.`%s` SET %s WHERE %s",
		schema, table, strings.Join(setClauses, ", "), where)
}
// BuildDelete renders a DELETE statement for a single row.
//
// The row is located by its first column only.
// NOTE(review): this presumes column 0 uniquely identifies the row — confirm
// against the tables being replicated.
//
// An empty row produces a statement with "WHERE 1=0", which deletes nothing.
// When the first column is nil the predicate is written as IS NULL, because
// "col = NULL" never evaluates to true and the delete would silently match no
// rows.
//
// Backticks embedded in schema or table are doubled so that an unusual
// identifier cannot terminate the quoted name early.
func (sb *SQLBuilder) BuildDelete(schema, table string, tableMap *replication.TableMapEvent, row []interface{}) string {
	schema = strings.ReplaceAll(schema, "`", "``")
	table = strings.ReplaceAll(table, "`", "``")

	if len(row) == 0 {
		return fmt.Sprintf("DELETE FROM `%s`.`%s` WHERE 1=0", schema, table)
	}

	keyCol := sb.getColumnName(tableMap, 0)

	var where string
	if row[0] == nil {
		where = keyCol + " IS NULL"
	} else {
		where = fmt.Sprintf("%s = %s", keyCol, formatValue(row[0]))
	}

	return fmt.Sprintf("DELETE FROM `%s`.`%s` WHERE %s", schema, table, where)
}
// getColumnName returns the backtick-quoted name of the column at index.
//
// When tableMap is nil or carries no name for that position, the positional
// placeholder `col_<index>` is returned instead.
// NOTE(review): presumably ColumnName is only populated when the server sends
// full row metadata — confirm against the go-mysql documentation.
//
// Backticks inside a real column name are doubled so the name cannot end the
// quoted identifier early.
func (sb *SQLBuilder) getColumnName(tableMap *replication.TableMapEvent, index int) string {
	if tableMap == nil || index >= len(tableMap.ColumnName) {
		return fmt.Sprintf("`col_%d`", index)
	}
	name := strings.ReplaceAll(string(tableMap.ColumnName[index]), "`", "``")
	return "`" + name + "`"
}
// formatValue renders col as a SQL literal.
//
//   - nil becomes NULL.
//   - string, and []byte that validUTF8 accepts, become a single-quoted string.
//   - any other []byte becomes a hexadecimal literal X'..'.
//   - every other type is formatted with %v and single-quoted.
//
// Inside quoted strings both the single quote (doubled) and the backslash
// (doubled) are escaped. Without the backslash escape, a value ending in "\"
// would swallow the closing quote on a server that treats backslash as an
// escape character, which is MySQL's default.
// NOTE(review): if the target runs with NO_BACKSLASH_ESCAPES, the doubled
// backslash is stored literally — confirm the target's SQL mode.
func formatValue(col interface{}) string {
	// quote wraps s in single quotes after escaping the two characters that
	// are special inside a string literal.
	quote := func(s string) string {
		s = strings.ReplaceAll(s, `\`, `\\`)
		s = strings.ReplaceAll(s, "'", "''")
		return "'" + s + "'"
	}

	switch v := col.(type) {
	case nil:
		return "NULL"
	case string:
		return quote(v)
	case []byte:
		if validUTF8(v) {
			return quote(string(v))
		}
		// Not text: emit the raw bytes as a hex literal.
		return "X'" + hexEncode(v) + "'"
	default:
		return quote(fmt.Sprintf("%v", v))
	}
}
// validUTF8 reports whether b consists entirely of valid UTF-8.
//
// It delegates to the standard library, which also rejects overlong
// encodings, UTF-16 surrogate halves and code points above U+10FFFF. A check
// that looks only at lead-byte and continuation-byte bit patterns lets those
// sequences through. An empty or nil slice is valid.
func validUTF8(b []byte) bool {
	return utf8.Valid(b)
}
// hexEncode renders b as uppercase hexadecimal, two digits per input byte.
// A nil or empty slice yields the empty string.
func hexEncode(b []byte) string {
	const digits = "0123456789ABCDEF"

	out := make([]byte, 0, 2*len(b))
	for _, c := range b {
		// High nibble first, then low nibble.
		out = append(out, digits[c>>4], digits[c&0x0F])
	}
	return string(out)
}
// valuesEqual reports whether a and b hold the same value.
//
// Two nils are equal; a nil never equals a non-nil. Byte slices are compared
// by content through bytesEqual, and a byte slice never equals a value of any
// other type. Everything else is compared by its %v rendering.
func valuesEqual(a, b interface{}) bool {
	if a == nil || b == nil {
		// Equal only when both sides are nil.
		return a == nil && b == nil
	}

	left, leftIsBytes := a.([]byte)
	right, rightIsBytes := b.([]byte)
	switch {
	case leftIsBytes && rightIsBytes:
		return bytesEqual(left, right)
	case leftIsBytes || rightIsBytes:
		// Exactly one side is a byte slice.
		return false
	}

	return fmt.Sprintf("%v", a) == fmt.Sprintf("%v", b)
}
// bytesEqual reports whether a and b have the same length and contents.
// A nil slice and an empty slice are considered equal.
func bytesEqual(a, b []byte) bool {
	// Converting to string purely for an == comparison does not copy the data.
	return string(a) == string(b)
}