169 lines
4.3 KiB
Go
169 lines
4.3 KiB
Go
package replica
|
|
|
|
import (
	"fmt"
	"strings"
	"unicode/utf8"

	"github.com/go-mysql-org/go-mysql/replication"
)
|
|
|
|
// SQLBuilder renders row images from replication events as MySQL
// INSERT/UPDATE/DELETE statements. It has no fields.
type SQLBuilder struct{}

// NewSQLBuilder returns a ready-to-use SQLBuilder.
func NewSQLBuilder() *SQLBuilder {
	return new(SQLBuilder)
}
|
|
|
|
// BuildInsert builds an INSERT statement
|
|
func (sb *SQLBuilder) BuildInsert(schema, table string, tableMap *replication.TableMapEvent, row []interface{}) string {
|
|
if len(row) == 0 {
|
|
return fmt.Sprintf("INSERT INTO `%s`.`%s` VALUES ()", schema, table)
|
|
}
|
|
|
|
var columns []string
|
|
var values []string
|
|
|
|
for i, col := range row {
|
|
colName := sb.getColumnName(tableMap, i)
|
|
columns = append(columns, colName)
|
|
values = append(values, formatValue(col))
|
|
}
|
|
|
|
return fmt.Sprintf("INSERT INTO `%s`.`%s` (%s) VALUES (%s)",
|
|
schema, table, strings.Join(columns, ", "), strings.Join(values, ", "))
|
|
}
|
|
|
|
// BuildUpdate builds an UPDATE statement
|
|
func (sb *SQLBuilder) BuildUpdate(schema, table string, tableMap *replication.TableMapEvent, before, after []interface{}) string {
|
|
if len(before) == 0 || len(after) == 0 {
|
|
return fmt.Sprintf("UPDATE `%s`.`%s` SET id = id", schema, table)
|
|
}
|
|
|
|
var setClauses []string
|
|
var whereClauses []string
|
|
|
|
for i := range before {
|
|
colName := sb.getColumnName(tableMap, i)
|
|
if !valuesEqual(before[i], after[i]) {
|
|
setClauses = append(setClauses, fmt.Sprintf("%s = %s", colName, formatValue(after[i])))
|
|
}
|
|
if i == 0 {
|
|
whereClauses = append(whereClauses, fmt.Sprintf("%s = %s", colName, formatValue(before[i])))
|
|
}
|
|
}
|
|
|
|
return fmt.Sprintf("UPDATE `%s`.`%s` SET %s WHERE %s",
|
|
schema, table, strings.Join(setClauses, ", "), strings.Join(whereClauses, " AND "))
|
|
}
|
|
|
|
// BuildDelete builds a DELETE statement
|
|
func (sb *SQLBuilder) BuildDelete(schema, table string, tableMap *replication.TableMapEvent, row []interface{}) string {
|
|
if len(row) == 0 {
|
|
return fmt.Sprintf("DELETE FROM `%s`.`%s` WHERE 1=0", schema, table)
|
|
}
|
|
|
|
colName := sb.getColumnName(tableMap, 0)
|
|
whereClause := fmt.Sprintf("%s = %s", colName, formatValue(row[0]))
|
|
return fmt.Sprintf("DELETE FROM `%s`.`%s` WHERE %s", schema, table, whereClause)
|
|
}
|
|
|
|
// getColumnName returns the column name at the given index
|
|
func (sb *SQLBuilder) getColumnName(tableMap *replication.TableMapEvent, index int) string {
|
|
if tableMap == nil || index >= len(tableMap.ColumnName) {
|
|
return fmt.Sprintf("`col_%d`", index)
|
|
}
|
|
return fmt.Sprintf("`%s`", tableMap.ColumnName[index])
|
|
}
|
|
|
|
// formatValue formats a value for SQL
|
|
func formatValue(col interface{}) string {
|
|
switch v := col.(type) {
|
|
case []byte:
|
|
if str := string(v); validUTF8(v) {
|
|
return fmt.Sprintf("'%s'", strings.ReplaceAll(str, "'", "''"))
|
|
}
|
|
return fmt.Sprintf("X'%s'", hexEncode(v))
|
|
case string:
|
|
return fmt.Sprintf("'%s'", strings.ReplaceAll(v, "'", "''"))
|
|
case nil:
|
|
return "NULL"
|
|
default:
|
|
return fmt.Sprintf("'%v'", v)
|
|
}
|
|
}
|
|
|
|
// validUTF8 reports whether b is well-formed UTF-8. formatValue uses it to
// decide between a text literal and a hex literal.
//
// It delegates to utf8.Valid, which also rejects the following:
//   - overlong encodings (e.g. C0 80)
//   - UTF-16 surrogate halves (e.g. ED A0 80)
//   - code points above U+10FFFF (e.g. F4 90 80 80, or F5..F7 lead bytes)
//
// A lead/continuation bit-pattern check accepts all of these. It would then
// send malformed text to the server instead of a hex literal.
func validUTF8(b []byte) bool {
	return utf8.Valid(b)
}
|
|
|
|
// hexEncode returns the upper-case hexadecimal form of b, two digits per byte
// with no prefix. It is used for the body of a MySQL X'...' literal.
func hexEncode(b []byte) string {
	const digits = "0123456789ABCDEF"
	out := make([]byte, 0, 2*len(b))
	for _, octet := range b {
		out = append(out, digits[octet>>4], digits[octet&0x0F])
	}
	return string(out)
}
|
|
|
|
// valuesEqual compares two values, handling slices properly
|
|
func valuesEqual(a, b interface{}) bool {
|
|
// Handle nil cases
|
|
if a == nil && b == nil {
|
|
return true
|
|
}
|
|
if a == nil || b == nil {
|
|
return false
|
|
}
|
|
|
|
// Handle byte slices specially
|
|
if aBytes, ok := a.([]byte); ok {
|
|
if bBytes, ok := b.([]byte); ok {
|
|
return bytesEqual(aBytes, bBytes)
|
|
}
|
|
return false
|
|
}
|
|
if _, ok := b.([]byte); ok {
|
|
return false
|
|
}
|
|
|
|
// For other types, use fmt.Sprintf comparison
|
|
return fmt.Sprintf("%v", a) == fmt.Sprintf("%v", b)
|
|
}
|
|
|
|
// bytesEqual reports whether a and b hold the same bytes. A nil slice and an
// empty slice compare equal.
func bytesEqual(a, b []byte) bool {
	// Comparing the string conversions checks the length and then every byte.
	// The compiler recognises this pattern and does not copy either slice.
	return string(a) == string(b)
}
|