package gorm_tracing

import (
	"database/sql"
	"errors"
	"fmt"

	"git.ma-al.com/maal-libraries/observer/pkg/fiber_tracing"
	"go.opentelemetry.io/otel"
	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/codes"
	semconv "go.opentelemetry.io/otel/semconv/v1.25.0"
	"go.opentelemetry.io/otel/trace"
	"gorm.io/gorm"
)

// dbRowsAffected is a custom span attribute key carrying the number of rows
// affected by a statement.
var dbRowsAffected = attribute.Key("db.rows_affected")

// Option configures the plugin returned by NewGormPlugin.
type Option func(p *gormPlugin)

// WithTracerProvider configures a tracer provider that is used to create a tracer.
func WithTracerProvider(provider trace.TracerProvider) Option {
	return func(p *gormPlugin) {
		p.provider = provider
	}
}

// WithAttributes configures attributes that are added to every span.
func WithAttributes(attrs ...attribute.KeyValue) Option {
	return func(p *gormPlugin) {
		p.attrs = append(p.attrs, attrs...)
	}
}

// WithDBName configures a db.name attribute.
func WithDBName(name string) Option {
	return func(p *gormPlugin) {
		p.attrs = append(p.attrs, semconv.DBNameKey.String(name))
	}
}

// WithoutQueryVariables configures the db.statement attribute to exclude
// query variables; each variable is replaced by "?".
func WithoutQueryVariables() Option {
	return func(p *gormPlugin) {
		p.excludeQueryVars = true
	}
}

// WithQueryFormatter configures a function applied to the rendered query
// before it is stored in the db.statement attribute.
func WithQueryFormatter(queryFormatter func(query string) string) Option {
	return func(p *gormPlugin) {
		p.queryFormatter = queryFormatter
	}
}

// WithoutMetrics sets a flag meant to prevent DBStats metrics from being
// reported.
//
// NOTE: metrics reporting is currently commented out in Initialize, so this
// option has no observable effect at present.
func WithoutMetrics() Option {
	return func(p *gormPlugin) {
		p.excludeMetrics = true
	}
}

// gormPlugin is the gorm.Plugin implementation returned by NewGormPlugin.
type gormPlugin struct {
	// provider is used by NewGormPlugin to build tracer; defaults to the
	// global provider.
	provider trace.TracerProvider
	// tracer is built from provider in NewGormPlugin.
	// NOTE(review): it is not read anywhere in this file — before() starts
	// spans via fiber_tracing.Start instead.
	tracer trace.Tracer
	// attrs are appended to every span in after().
	attrs []attribute.KeyValue
	// excludeQueryVars masks statement variables with "?" in db.statement.
	excludeQueryVars bool
	// excludeMetrics is set by WithoutMetrics; currently unused because
	// metrics reporting is commented out.
	excludeMetrics bool
	// queryFormatter, when non-nil, rewrites the query stored in db.statement.
	queryFormatter func(query string) string
}

// NewGormPlugin returns a gorm.Plugin configured by opts. If no tracer
// provider is supplied, otel.GetTracerProvider() is used.
func NewGormPlugin(opts ...Option) gorm.Plugin {
	p := &gormPlugin{}
	for _, opt := range opts {
		opt(p)
	}

	if p.provider == nil {
		p.provider = otel.GetTracerProvider()
	}
	p.tracer = p.provider.Tracer("git.ma-al.com/maal-libraries/observer/pkg/gorm_tracing")

	return p
}

// Name returns the name under which GORM registers the plugin.
func (p gormPlugin) Name() string {
	return "observerGorm"
}

// gormHookFunc is the signature of a GORM callback.
type gormHookFunc func(tx *gorm.DB)

// gormRegister is the subset of GORM's callback processor used to register
// a named hook.
type gormRegister interface {
	Register(name string, fn func(*gorm.DB)) error
}

// Initialize registers a before/after hook pair around GORM's create, query,
// delete, update, row and raw callbacks. Registration continues past a
// failure; every registration error is joined into the returned error
// (nil when all succeed).
func (p gormPlugin) Initialize(db *gorm.DB) error {
	// Metrics reporting is disabled for now (see WithoutMetrics):
	// if !p.excludeMetrics {
	// 	if db, ok := db.ConnPool.(*sql.DB); ok {
	// 		metrics.ReportDBStatsMetrics(db)
	// 	}
	// }

	cb := db.Callback()
	hooks := []struct {
		callback gormRegister
		hook     gormHookFunc
		name     string
	}{
		{cb.Create().Before("gorm:create"), p.before("gorm.Create"), "before:create"},
		{cb.Create().After("gorm:create"), p.after(), "after:create"},

		{cb.Query().Before("gorm:query"), p.before("gorm.Select"), "before:select"},
		{cb.Query().After("gorm:query"), p.after(), "after:select"},

		{cb.Delete().Before("gorm:delete"), p.before("gorm.Delete"), "before:delete"},
		{cb.Delete().After("gorm:delete"), p.after(), "after:delete"},

		{cb.Update().Before("gorm:update"), p.before("gorm.Update"), "before:update"},
		{cb.Update().After("gorm:update"), p.after(), "after:update"},

		{cb.Row().Before("gorm:row"), p.before("gorm.Row"), "before:row"},
		{cb.Row().After("gorm:row"), p.after(), "after:row"},

		{cb.Raw().Before("gorm:raw"), p.before("gorm.Raw"), "before:raw"},
		{cb.Raw().After("gorm:raw"), p.after(), "after:raw"},
	}

	var errs []error
	for _, h := range hooks {
		if err := h.callback.Register("observer:"+h.name, h.hook); err != nil {
			errs = append(errs, fmt.Errorf("callback register %s failed: %w", h.name, err))
		}
	}
	// errors.Join returns nil when errs is empty.
	return errors.Join(errs...)
}

// before returns a hook that starts a client-kind span named spanName and
// stores the resulting context on the statement so after() can find it.
//
// NOTE(review): spans are started through fiber_tracing.Start rather than
// p.tracer, so the provider passed via WithTracerProvider does not affect
// span creation here — confirm whether that is intended.
func (p *gormPlugin) before(spanName string) gormHookFunc {
	return func(tx *gorm.DB) {
		ctx := tx.Statement.Context
		tx.Statement.Context, _ = fiber_tracing.Start(ctx, spanName, trace.WithSpanKind(trace.SpanKindClient))
	}
}

// after returns a hook that decorates the span found in the statement
// context with database attributes, records the statement error (if any),
// and ends the span. Non-recording spans are left untouched.
func (p *gormPlugin) after() gormHookFunc {
	return func(tx *gorm.DB) {
		span := trace.SpanFromContext(tx.Statement.Context)
		if !span.IsRecording() {
			return
		}
		defer span.End()

		// Up to four attributes are added below: system, statement, table,
		// rows affected.
		attrs := make([]attribute.KeyValue, 0, len(p.attrs)+4)
		attrs = append(attrs, p.attrs...)

		if sys := dbSystem(tx); sys.Valid() {
			attrs = append(attrs, sys)
		}

		vars := tx.Statement.Vars
		if p.excludeQueryVars {
			// Replace query variables with '?' to mask them.
			vars = make([]any, len(tx.Statement.Vars))
			for i := range vars {
				vars[i] = "?"
			}
		}

		query := tx.Dialector.Explain(tx.Statement.SQL.String(), vars...)
		attrs = append(attrs, semconv.DBStatementKey.String(p.formatQuery(query)))

		if tx.Statement.Table != "" {
			attrs = append(attrs, semconv.DBSQLTableKey.String(tx.Statement.Table))
		}
		if tx.Statement.RowsAffected != -1 {
			attrs = append(attrs, dbRowsAffected.Int64(tx.Statement.RowsAffected))
		}

		span.SetAttributes(attrs...)

		// Without this, spans for failed statements would end with an unset
		// status and look successful.
		if err := tx.Error; err != nil && !isExpectedDBError(err) {
			span.RecordError(err)
			span.SetStatus(codes.Error, err.Error())
		}
	}
}

// isExpectedDBError reports whether err is a "no rows" outcome, which is
// treated here as a normal result rather than a span failure.
func isExpectedDBError(err error) bool {
	return errors.Is(err, gorm.ErrRecordNotFound) || errors.Is(err, sql.ErrNoRows)
}

// formatQuery applies the configured query formatter, if any.
func (p *gormPlugin) formatQuery(query string) string {
	if p.queryFormatter != nil {
		return p.queryFormatter(query)
	}
	return query
}

// dbSystem maps the GORM dialector name to a db.system attribute. It returns
// an invalid (zero) KeyValue for unknown dialectors; callers check Valid().
func dbSystem(tx *gorm.DB) attribute.KeyValue {
	switch tx.Dialector.Name() {
	case "mysql":
		return semconv.DBSystemMySQL
	case "mssql":
		return semconv.DBSystemMSSQL
	case "postgres", "postgresql":
		return semconv.DBSystemPostgreSQL
	case "sqlite":
		return semconv.DBSystemSqlite
	case "sqlserver":
		return semconv.DBSystemKey.String("sqlserver")
	case "clickhouse":
		return semconv.DBSystemKey.String("clickhouse")
	default:
		return attribute.KeyValue{}
	}
}