From ab7f2c969ad8054aa4d5eba6a3b75cca1d747c05 Mon Sep 17 00:00:00 2001 From: Natalia Goc Date: Wed, 24 Jul 2024 15:40:18 +0200 Subject: [PATCH] feat: add gorm tracing middleware --- go.mod | 3 + go.sum | 6 + pkg/attr/attr.go | 8 + pkg/gorm_tracing/middleware.go | 281 +++++++++++++++++++++++++++++++++ 4 files changed, 298 insertions(+) create mode 100644 pkg/gorm_tracing/middleware.go diff --git a/go.mod b/go.mod index bac4b2b..6b5b4bc 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,8 @@ require ( github.com/go-logr/stdr v1.2.2 // indirect github.com/google/uuid v1.6.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.1 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect github.com/klauspost/compress v1.17.8 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect @@ -37,4 +39,5 @@ require ( google.golang.org/genproto/googleapis/rpc v0.0.0-20240401170217-c3f982113cda // indirect google.golang.org/grpc v1.63.2 // indirect google.golang.org/protobuf v1.33.0 // indirect + gorm.io/gorm v1.25.11 // indirect ) diff --git a/go.sum b/go.sum index 2124972..27ef01b 100644 --- a/go.sum +++ b/go.sum @@ -22,6 +22,10 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.1 h1:/c3QmbOGMGTOumP2iT/rCwB7b0QDGLKzqOmktBjT+Is= github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.1/go.mod h1:5SN9VR2LTsRFsrEC6FHgRbTWrTHu6tqPeKxEQv15giM= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/klauspost/compress v1.15.0/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= github.com/klauspost/compress v1.17.8 h1:YcnTYrq7MikUT7k0Yb5eceMmALQPYBW/Xltxn0NAMnU= github.com/klauspost/compress v1.17.8/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= @@ -104,3 +108,5 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/gorm v1.25.11 h1:/Wfyg1B/je1hnDx3sMkX+gAlxrlZpn6X0BXRlwXlvHg= +gorm.io/gorm v1.25.11/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ= diff --git a/pkg/attr/attr.go b/pkg/attr/attr.go index b17cffb..a4090c5 100644 --- a/pkg/attr/attr.go +++ b/pkg/attr/attr.go @@ -58,6 +58,7 @@ const ( ServiceLayerKey = attribute.Key("service.layer") ServiceLayerNameKey = attribute.Key("service.layer_name") DBExecutionTimeMsKey = attribute.Key("db.execution_time_ms") + DBRowsAffectedKey = attribute.Key("db.rows_affected") ) type ServiceArchitectureLayer string @@ -266,3 +267,10 @@ func DBExecutionTimeMs(duration time.Duration) attribute.KeyValue { Value: attribute.Int64Value(duration.Milliseconds()), } } + +func DBRowsAffected(rows int64) attribute.KeyValue { + return attribute.KeyValue{ + Key: DBRowsAffectedKey, + Value: attribute.Int64Value(rows), + } +} diff --git a/pkg/gorm_tracing/middleware.go b/pkg/gorm_tracing/middleware.go new file mode 100644 index 0000000..ce13b06 --- /dev/null +++ b/pkg/gorm_tracing/middleware.go @@ -0,0 +1,281 @@ +package gorm_tracing + +import ( + "errors" + "fmt" + "strings" + "time" + + "git.ma-al.com/maal-libraries/observer/pkg/attr" + "git.ma-al.com/maal-libraries/observer/pkg/event" + "git.ma-al.com/maal-libraries/observer/pkg/level" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + semconv "go.opentelemetry.io/otel/semconv/v1.25.0" + "go.opentelemetry.io/otel/trace" + "gorm.io/gorm" +) + +type Option func(p *gormPlugin) + +func defaultSqlQueryFormatter(query string) string { + return strings.Join(strings.Fields(strings.TrimSpace(query)), " ") +} + +// WithEveryStatementAsEvent configures the plugin to log all statements as distinct events, not just errors and warnings. +func WithEveryStatementAsEvent() Option { + return func(p *gormPlugin) { + p.everyStatementAsEvent = true + } +} + +// WithSlowQueryDuration configures the duration at which the query will be considered slow an logged with warning. +func WithSlowQueryDuration(duration time.Duration) Option { + return func(p *gormPlugin) { + p.slowQueryDuration = duration + } +} + +// WithTracerProvider configures a tracer provider that is used to create a tracer. +func WithTracerProvider(provider trace.TracerProvider) Option { + return func(p *gormPlugin) { + p.provider = provider + } +} + +// WithAttributes configures attributes that are used to create a span. +func WithAttributes(attrs ...attribute.KeyValue) Option { + return func(p *gormPlugin) { + p.attrs = append(p.attrs, attrs...) + } +} + +// WithDBName configures a db.name attribute. +func WithDBName(name string) Option { + return func(p *gormPlugin) { + p.attrs = append(p.attrs, semconv.DBNameKey.String(name)) + } +} + +// WithoutQueryVariables configures the db.statement attribute to exclude query variables +func WithoutQueryVariables() Option { + return func(p *gormPlugin) { + p.excludeQueryVars = true + } +} + +// WithQueryFormatter configures a query formatter +func WithQueryFormatter(queryFormatter func(query string) string) Option { + return func(p *gormPlugin) { + p.queryFormatter = queryFormatter + } +} + +// WithDefaultQueryFormatter adds a simple formatter that trims any excess whitespaces from the query +func WithDefaultQueryFormatter() Option { + return func(p *gormPlugin) { + p.queryFormatter = defaultSqlQueryFormatter + } +} + +type gormPlugin struct { + provider trace.TracerProvider + tracer trace.Tracer + attrs []attribute.KeyValue + excludeQueryVars bool + queryFormatter func(query string) string + slowQueryDuration time.Duration + everyStatementAsEvent bool +} + +// Overrides and sets some options with recommended defaults +func DefaultGormPlugin(opts ...Option) gorm.Plugin { + p := &gormPlugin{} + for _, opt := range opts { + opt(p) + } + WithDefaultQueryFormatter()(p) + WithoutQueryVariables()(p) + WithEveryStatementAsEvent()(p) + p.provider = otel.GetTracerProvider() + p.tracer = p.provider.Tracer("git.ma-al.com/maal-libraries/observer/pkg/gorm_tracing") + return p +} + +func NewGormPlugin(opts ...Option) gorm.Plugin { + p := &gormPlugin{} + for _, opt := range opts { + opt(p) + } + + if p.provider == nil { + p.provider = otel.GetTracerProvider() + } + p.tracer = p.provider.Tracer("git.ma-al.com/maal-libraries/observer/pkg/gorm_tracing") + + return p +} + +func (p gormPlugin) Name() string { + return "observerGorm" +} + +type gormHookFunc func(tx *gorm.DB) + +type gormRegister interface { + Register(name string, fn func(*gorm.DB)) error +} + +func (p gormPlugin) Initialize(db *gorm.DB) (err error) { + cb := db.Callback() + hooks := []struct { + callback gormRegister + hook gormHookFunc + name string + }{ + {cb.Create().Before("gorm:create"), p.before("gorm.Create"), "before:create"}, + {cb.Create().After("gorm:create"), p.after(), "after:create"}, + + {cb.Query().Before("gorm:query"), p.before("gorm.Query"), "before:select"}, + {cb.Query().After("gorm:query"), p.after(), "after:select"}, + + {cb.Delete().Before("gorm:delete"), p.before("gorm.Delete"), "before:delete"}, + {cb.Delete().After("gorm:delete"), p.after(), "after:delete"}, + + {cb.Update().Before("gorm:update"), p.before("gorm.Update"), "before:update"}, + {cb.Update().After("gorm:update"), p.after(), "after:update"}, + + {cb.Row().Before("gorm:row"), p.before("gorm.Row"), "before:row"}, + {cb.Row().After("gorm:row"), p.after(), "after:row"}, + + {cb.Raw().Before("gorm:raw"), p.before("gorm.Raw"), "before:raw"}, + {cb.Raw().After("gorm:raw"), p.after(), "after:raw"}, + } + + var firstErr error + + for _, h := range hooks { + if err := h.callback.Register("observer:"+h.name, h.hook); err != nil && firstErr == nil { + firstErr = fmt.Errorf("callback register %s failed: %w", h.name, err) + } + } + + return firstErr +} + +func (p *gormPlugin) before(spanName string) gormHookFunc { + return func(tx *gorm.DB) { + ctx := tx.Statement.Context + + tx.Statement.Set("observer:statement_start", time.Now()) + tx.Statement.Context, _ = p.tracer.Start(ctx, spanName, trace.WithSpanKind(trace.SpanKindClient)) + } +} + +func (p *gormPlugin) after() gormHookFunc { + return func(tx *gorm.DB) { + span := trace.SpanFromContext(tx.Statement.Context) + if !span.IsRecording() { + return + } + defer span.End() + + attrs := make([]attribute.KeyValue, 0, len(p.attrs)+4) + attrs = append(attrs, p.attrs...) + + if sys := dbSystem(tx); sys.Valid() { + attrs = append(attrs, sys) + } + + vars := tx.Statement.Vars + if p.excludeQueryVars { + // Replace query variables with '?' to mask them + vars = make([]interface{}, len(tx.Statement.Vars)) + + for i := 0; i < len(vars); i++ { + vars[i] = "?" + } + } + + query := tx.Dialector.Explain(tx.Statement.SQL.String(), vars...) + + attrs = append(attrs, semconv.DBStatementKey.String(p.formatQuery(query))) + if tx.Statement.Table != "" { + attrs = append(attrs, semconv.DBSQLTableKey.String(tx.Statement.Table)) + } + if tx.Statement.RowsAffected != -1 { + attrs = append(attrs, attr.DBRowsAffected(tx.Statement.RowsAffected)) + } + + slowQuery := false + if statementStart, ok := tx.Statement.Get("observer:statement_start"); ok { + start := statementStart.(time.Time) + duration := time.Now().Sub(start) + + attrs = append(attrs, attr.DBExecutionTimeMs(duration)) + + if p.slowQueryDuration != time.Duration(0) { + if duration >= p.slowQueryDuration { + slowQuery = true + + event.NewErrInSpan(event.Error{ + Level: level.WARN, + Err: errors.New("slow query execution"), + Attributes: attrs, + }.SkipMoreInCallStack(3), span) + + attrs = append(attrs, attr.SeverityLevel(level.WARN)) + } + } + } + errQuery := false + if tx.Statement.Error != nil { + errQuery = true + + event.NewErrInSpan(event.Error{ + Level: level.ERR, + Err: tx.Statement.Error, + Attributes: attrs, + }.SkipMoreInCallStack(3), span) + + attrs = append(attrs, attr.SeverityLevel(level.ERR)) + } + + if !slowQuery && !errQuery && p.everyStatementAsEvent { + attrs = append(attrs, attr.SeverityLevel(level.DEBUG)) + event.NewInSpan(event.Event{ + Level: level.DEBUG, + ShortMessage: "executed an sql query with gorm", + Attributes: attrs, + }.SkipMoreInCallStack(3), span) + } + + span.SetAttributes(attrs...) + } +} + +func (p *gormPlugin) formatQuery(query string) string { + if p.queryFormatter != nil { + return p.queryFormatter(query) + } + return query +} + +func dbSystem(tx *gorm.DB) attribute.KeyValue { + switch tx.Dialector.Name() { + case "mysql": + return semconv.DBSystemMySQL + case "mssql": + return semconv.DBSystemMSSQL + case "postgres", "postgresql": + return semconv.DBSystemPostgreSQL + case "sqlite": + return semconv.DBSystemSqlite + case "sqlserver": + return semconv.DBSystemKey.String("sqlserver") + case "clickhouse": + return semconv.DBSystemKey.String("clickhouse") + default: + return attribute.KeyValue{} + } +}