package gorm_tracing import ( "errors" "fmt" "strings" "time" "git.ma-al.com/maal-libraries/observer/pkg/attr" "git.ma-al.com/maal-libraries/observer/pkg/event" "git.ma-al.com/maal-libraries/observer/pkg/level" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" semconv "go.opentelemetry.io/otel/semconv/v1.25.0" "go.opentelemetry.io/otel/trace" "gorm.io/gorm" ) type Option func(p *gormPlugin) func defaultSqlQueryFormatter(query string) string { return strings.Join(strings.Fields(strings.TrimSpace(query)), " ") } // WithEveryStatementAsEvent configures the plugin to log all statements as distinct events, not just errors and warnings. func WithEveryStatementAsEvent() Option { return func(p *gormPlugin) { p.everyStatementAsEvent = true } } // WithSlowQueryDuration configures the duration at which the query will be considered slow an logged with warning. func WithSlowQueryDuration(duration time.Duration) Option { return func(p *gormPlugin) { p.slowQueryDuration = duration } } // WithTracerProvider configures a tracer provider that is used to create a tracer. func WithTracerProvider(provider trace.TracerProvider) Option { return func(p *gormPlugin) { p.provider = provider } } // WithAttributes configures attributes that are used to create a span. func WithAttributes(attrs ...attribute.KeyValue) Option { return func(p *gormPlugin) { p.attrs = append(p.attrs, attrs...) } } // WithDBName configures a db.name attribute. func WithDBName(name string) Option { return func(p *gormPlugin) { p.attrs = append(p.attrs, semconv.DBNameKey.String(name)) } } // WithoutQueryVariables configures the db.statement attribute to exclude query variables func WithoutQueryVariables() Option { return func(p *gormPlugin) { p.excludeQueryVars = true } } // WithQueryFormatter configures a query formatter func WithQueryFormatter(queryFormatter func(query string) string) Option { return func(p *gormPlugin) { p.queryFormatter = queryFormatter } } // WithDefaultQueryFormatter adds a simple formatter that trims any excess whitespaces from the query func WithDefaultQueryFormatter() Option { return func(p *gormPlugin) { p.queryFormatter = defaultSqlQueryFormatter } } type gormPlugin struct { provider trace.TracerProvider tracer trace.Tracer attrs []attribute.KeyValue excludeQueryVars bool queryFormatter func(query string) string slowQueryDuration time.Duration everyStatementAsEvent bool } // Overrides and sets some options with recommended defaults func DefaultGormPlugin(opts ...Option) gorm.Plugin { p := &gormPlugin{} for _, opt := range opts { opt(p) } WithDefaultQueryFormatter()(p) WithoutQueryVariables()(p) WithEveryStatementAsEvent()(p) p.provider = otel.GetTracerProvider() p.tracer = p.provider.Tracer("git.ma-al.com/maal-libraries/observer/pkg/gorm_tracing") return p } func NewGormPlugin(opts ...Option) gorm.Plugin { p := &gormPlugin{} for _, opt := range opts { opt(p) } if p.provider == nil { p.provider = otel.GetTracerProvider() } p.tracer = p.provider.Tracer("git.ma-al.com/maal-libraries/observer/pkg/gorm_tracing") return p } func (p gormPlugin) Name() string { return "observerGorm" } type gormHookFunc func(tx *gorm.DB) type gormRegister interface { Register(name string, fn func(*gorm.DB)) error } func (p gormPlugin) Initialize(db *gorm.DB) (err error) { cb := db.Callback() hooks := []struct { callback gormRegister hook gormHookFunc name string }{ {cb.Create().Before("gorm:create"), p.before("gorm.Create"), "before:create"}, {cb.Create().After("gorm:create"), p.after(), "after:create"}, {cb.Query().Before("gorm:query"), p.before("gorm.Query"), "before:select"}, {cb.Query().After("gorm:query"), p.after(), "after:select"}, {cb.Delete().Before("gorm:delete"), p.before("gorm.Delete"), "before:delete"}, {cb.Delete().After("gorm:delete"), p.after(), "after:delete"}, {cb.Update().Before("gorm:update"), p.before("gorm.Update"), "before:update"}, {cb.Update().After("gorm:update"), p.after(), "after:update"}, {cb.Row().Before("gorm:row"), p.before("gorm.Row"), "before:row"}, {cb.Row().After("gorm:row"), p.after(), "after:row"}, {cb.Raw().Before("gorm:raw"), p.before("gorm.Raw"), "before:raw"}, {cb.Raw().After("gorm:raw"), p.after(), "after:raw"}, } var firstErr error for _, h := range hooks { if err := h.callback.Register("observer:"+h.name, h.hook); err != nil && firstErr == nil { firstErr = fmt.Errorf("callback register %s failed: %w", h.name, err) } } return firstErr } func (p *gormPlugin) before(spanName string) gormHookFunc { return func(tx *gorm.DB) { ctx := tx.Statement.Context tx.Statement.Set("observer:statement_start", time.Now()) tx.Statement.Context, _ = p.tracer.Start(ctx, spanName, trace.WithSpanKind(trace.SpanKindClient)) } } func (p *gormPlugin) after() gormHookFunc { return func(tx *gorm.DB) { span := trace.SpanFromContext(tx.Statement.Context) if !span.IsRecording() { return } defer span.End() attrs := make([]attribute.KeyValue, 0, len(p.attrs)+4) attrs = append(attrs, p.attrs...) if sys := dbSystem(tx); sys.Valid() { attrs = append(attrs, sys) } vars := tx.Statement.Vars if p.excludeQueryVars { // Replace query variables with '?' to mask them vars = make([]interface{}, len(tx.Statement.Vars)) for i := 0; i < len(vars); i++ { vars[i] = "?" } } query := tx.Dialector.Explain(tx.Statement.SQL.String(), vars...) attrs = append(attrs, semconv.DBStatementKey.String(p.formatQuery(query))) if tx.Statement.Table != "" { attrs = append(attrs, semconv.DBSQLTableKey.String(tx.Statement.Table)) } if tx.Statement.RowsAffected != -1 { attrs = append(attrs, attr.DBRowsAffected(tx.Statement.RowsAffected)) } slowQuery := false if statementStart, ok := tx.Statement.Get("observer:statement_start"); ok { start := statementStart.(time.Time) duration := time.Now().Sub(start) attrs = append(attrs, attr.DBExecutionTimeMs(duration)) if p.slowQueryDuration != time.Duration(0) { if duration >= p.slowQueryDuration { slowQuery = true event.NewErrInSpan(event.Error{ Level: level.WARN, Err: errors.New("slow query execution"), Attributes: attrs, }.SkipMoreInCallStack(3), span) attrs = append(attrs, attr.SeverityLevel(level.WARN)) } } } errQuery := false if tx.Statement.Error != nil { errQuery = true event.NewErrInSpan(event.Error{ Level: level.ERR, Err: tx.Statement.Error, Attributes: attrs, }.SkipMoreInCallStack(3), span) attrs = append(attrs, attr.SeverityLevel(level.ERR)) } if !slowQuery && !errQuery && p.everyStatementAsEvent { event.NewInSpan(event.Event{ Level: level.DEBUG, ShortMessage: "executed an sql query with gorm", Attributes: attrs, }.SkipMoreInCallStack(3), span) attrs = append(attrs, attr.SeverityLevel(level.DEBUG)) } span.SetAttributes(attrs...) } } func (p *gormPlugin) formatQuery(query string) string { if p.queryFormatter != nil { return p.queryFormatter(query) } return query } func dbSystem(tx *gorm.DB) attribute.KeyValue { switch tx.Dialector.Name() { case "mysql": return semconv.DBSystemMySQL case "mssql": return semconv.DBSystemMSSQL case "postgres", "postgresql": return semconv.DBSystemPostgreSQL case "sqlite": return semconv.DBSystemSqlite case "sqlserver": return semconv.DBSystemKey.String("sqlserver") case "clickhouse": return semconv.DBSystemKey.String("clickhouse") default: return attribute.KeyValue{} } }