replica ready
This commit is contained in:
168
pkg/replica/sqlbuilder.go
Normal file
168
pkg/replica/sqlbuilder.go
Normal file
@@ -0,0 +1,168 @@
|
||||
package replica
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/go-mysql-org/go-mysql/replication"
|
||||
)
|
||||
|
||||
// SQLBuilder handles SQL statement building
|
||||
type SQLBuilder struct{}
|
||||
|
||||
// NewSQLBuilder creates a new SQL builder
|
||||
func NewSQLBuilder() *SQLBuilder {
|
||||
return &SQLBuilder{}
|
||||
}
|
||||
|
||||
// BuildInsert builds an INSERT statement
|
||||
func (sb *SQLBuilder) BuildInsert(schema, table string, tableMap *replication.TableMapEvent, row []interface{}) string {
|
||||
if len(row) == 0 {
|
||||
return fmt.Sprintf("INSERT INTO `%s`.`%s` VALUES ()", schema, table)
|
||||
}
|
||||
|
||||
var columns []string
|
||||
var values []string
|
||||
|
||||
for i, col := range row {
|
||||
colName := sb.getColumnName(tableMap, i)
|
||||
columns = append(columns, colName)
|
||||
values = append(values, formatValue(col))
|
||||
}
|
||||
|
||||
return fmt.Sprintf("INSERT INTO `%s`.`%s` (%s) VALUES (%s)",
|
||||
schema, table, strings.Join(columns, ", "), strings.Join(values, ", "))
|
||||
}
|
||||
|
||||
// BuildUpdate builds an UPDATE statement
|
||||
func (sb *SQLBuilder) BuildUpdate(schema, table string, tableMap *replication.TableMapEvent, before, after []interface{}) string {
|
||||
if len(before) == 0 || len(after) == 0 {
|
||||
return fmt.Sprintf("UPDATE `%s`.`%s` SET id = id", schema, table)
|
||||
}
|
||||
|
||||
var setClauses []string
|
||||
var whereClauses []string
|
||||
|
||||
for i := range before {
|
||||
colName := sb.getColumnName(tableMap, i)
|
||||
if !valuesEqual(before[i], after[i]) {
|
||||
setClauses = append(setClauses, fmt.Sprintf("%s = %s", colName, formatValue(after[i])))
|
||||
}
|
||||
if i == 0 {
|
||||
whereClauses = append(whereClauses, fmt.Sprintf("%s = %s", colName, formatValue(before[i])))
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf("UPDATE `%s`.`%s` SET %s WHERE %s",
|
||||
schema, table, strings.Join(setClauses, ", "), strings.Join(whereClauses, " AND "))
|
||||
}
|
||||
|
||||
// BuildDelete builds a DELETE statement
|
||||
func (sb *SQLBuilder) BuildDelete(schema, table string, tableMap *replication.TableMapEvent, row []interface{}) string {
|
||||
if len(row) == 0 {
|
||||
return fmt.Sprintf("DELETE FROM `%s`.`%s` WHERE 1=0", schema, table)
|
||||
}
|
||||
|
||||
colName := sb.getColumnName(tableMap, 0)
|
||||
whereClause := fmt.Sprintf("%s = %s", colName, formatValue(row[0]))
|
||||
return fmt.Sprintf("DELETE FROM `%s`.`%s` WHERE %s", schema, table, whereClause)
|
||||
}
|
||||
|
||||
// getColumnName returns the column name at the given index
|
||||
func (sb *SQLBuilder) getColumnName(tableMap *replication.TableMapEvent, index int) string {
|
||||
if tableMap == nil || index >= len(tableMap.ColumnName) {
|
||||
return fmt.Sprintf("`col_%d`", index)
|
||||
}
|
||||
return fmt.Sprintf("`%s`", tableMap.ColumnName[index])
|
||||
}
|
||||
|
||||
// formatValue formats a value for SQL
|
||||
func formatValue(col interface{}) string {
|
||||
switch v := col.(type) {
|
||||
case []byte:
|
||||
if str := string(v); validUTF8(v) {
|
||||
return fmt.Sprintf("'%s'", strings.ReplaceAll(str, "'", "''"))
|
||||
}
|
||||
return fmt.Sprintf("X'%s'", hexEncode(v))
|
||||
case string:
|
||||
return fmt.Sprintf("'%s'", strings.ReplaceAll(v, "'", "''"))
|
||||
case nil:
|
||||
return "NULL"
|
||||
default:
|
||||
return fmt.Sprintf("'%v'", v)
|
||||
}
|
||||
}
|
||||
|
||||
func validUTF8(b []byte) bool {
|
||||
for i := 0; i < len(b); {
|
||||
if b[i] < 0x80 {
|
||||
i++
|
||||
} else if (b[i] & 0xE0) == 0xC0 {
|
||||
if i+1 >= len(b) || (b[i+1]&0xC0) != 0x80 {
|
||||
return false
|
||||
}
|
||||
i += 2
|
||||
} else if (b[i] & 0xF0) == 0xE0 {
|
||||
if i+2 >= len(b) || (b[i+1]&0xC0) != 0x80 || (b[i+2]&0xC0) != 0x80 {
|
||||
return false
|
||||
}
|
||||
i += 3
|
||||
} else if (b[i] & 0xF8) == 0xF0 {
|
||||
if i+3 >= len(b) || (b[i+1]&0xC0) != 0x80 || (b[i+2]&0xC0) != 0x80 || (b[i+3]&0xC0) != 0x80 {
|
||||
return false
|
||||
}
|
||||
i += 4
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func hexEncode(b []byte) string {
|
||||
hexChars := "0123456789ABCDEF"
|
||||
result := make([]byte, len(b)*2)
|
||||
for i, byteVal := range b {
|
||||
result[i*2] = hexChars[byteVal>>4]
|
||||
result[i*2+1] = hexChars[byteVal&0x0F]
|
||||
}
|
||||
return string(result)
|
||||
}
|
||||
|
||||
// valuesEqual compares two values, handling slices properly
|
||||
func valuesEqual(a, b interface{}) bool {
|
||||
// Handle nil cases
|
||||
if a == nil && b == nil {
|
||||
return true
|
||||
}
|
||||
if a == nil || b == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Handle byte slices specially
|
||||
if aBytes, ok := a.([]byte); ok {
|
||||
if bBytes, ok := b.([]byte); ok {
|
||||
return bytesEqual(aBytes, bBytes)
|
||||
}
|
||||
return false
|
||||
}
|
||||
if _, ok := b.([]byte); ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// For other types, use fmt.Sprintf comparison
|
||||
return fmt.Sprintf("%v", a) == fmt.Sprintf("%v", b)
|
||||
}
|
||||
|
||||
// bytesEqual compares two byte slices
|
||||
func bytesEqual(a, b []byte) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if a[i] != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
Reference in New Issue
Block a user