feat: add gorm tracing middleware
This commit is contained in:
@ -58,6 +58,7 @@ const (
|
||||
ServiceLayerKey = attribute.Key("service.layer")
|
||||
ServiceLayerNameKey = attribute.Key("service.layer_name")
|
||||
DBExecutionTimeMsKey = attribute.Key("db.execution_time_ms")
|
||||
DBRowsAffectedKey = attribute.Key("db.rows_affected")
|
||||
)
|
||||
|
||||
type ServiceArchitectureLayer string
|
||||
@ -266,3 +267,10 @@ func DBExecutionTimeMs(duration time.Duration) attribute.KeyValue {
|
||||
Value: attribute.Int64Value(duration.Milliseconds()),
|
||||
}
|
||||
}
|
||||
|
||||
func DBRowsAffected(rows int64) attribute.KeyValue {
|
||||
return attribute.KeyValue{
|
||||
Key: DBRowsAffectedKey,
|
||||
Value: attribute.Int64Value(rows),
|
||||
}
|
||||
}
|
||||
|
282
pkg/gorm_tracing/middleware.go
Normal file
282
pkg/gorm_tracing/middleware.go
Normal file
@ -0,0 +1,282 @@
|
||||
package gorm_tracing
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.ma-al.com/maal-libraries/observer/pkg/attr"
|
||||
"git.ma-al.com/maal-libraries/observer/pkg/event"
|
||||
"git.ma-al.com/maal-libraries/observer/pkg/level"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
semconv "go.opentelemetry.io/otel/semconv/v1.25.0"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Option func(p *gormPlugin)
|
||||
|
||||
func defaultSqlQueryFormatter(query string) string {
|
||||
return strings.Join(strings.Fields(strings.TrimSpace(query)), " ")
|
||||
}
|
||||
|
||||
// WithEveryStatementAsEvent configures the plugin to log all statements as distinct events, not just errors and warnings.
|
||||
func WithEveryStatementAsEvent() Option {
|
||||
return func(p *gormPlugin) {
|
||||
p.everyStatementAsEvent = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithSlowQueryDuration configures the duration at which the query will be considered slow an logged with warning.
|
||||
func WithSlowQueryDuration(duration time.Duration) Option {
|
||||
return func(p *gormPlugin) {
|
||||
p.slowQueryDuration = duration
|
||||
}
|
||||
}
|
||||
|
||||
// WithTracerProvider configures a tracer provider that is used to create a tracer.
|
||||
func WithTracerProvider(provider trace.TracerProvider) Option {
|
||||
return func(p *gormPlugin) {
|
||||
p.provider = provider
|
||||
}
|
||||
}
|
||||
|
||||
// WithAttributes configures attributes that are used to create a span.
|
||||
func WithAttributes(attrs ...attribute.KeyValue) Option {
|
||||
return func(p *gormPlugin) {
|
||||
p.attrs = append(p.attrs, attrs...)
|
||||
}
|
||||
}
|
||||
|
||||
// WithDBName configures a db.name attribute.
|
||||
func WithDBName(name string) Option {
|
||||
return func(p *gormPlugin) {
|
||||
p.attrs = append(p.attrs, semconv.DBNameKey.String(name))
|
||||
}
|
||||
}
|
||||
|
||||
// WithoutQueryVariables configures the db.statement attribute to exclude query variables
|
||||
func WithoutQueryVariables() Option {
|
||||
return func(p *gormPlugin) {
|
||||
p.excludeQueryVars = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithQueryFormatter configures a query formatter
|
||||
func WithQueryFormatter(queryFormatter func(query string) string) Option {
|
||||
return func(p *gormPlugin) {
|
||||
p.queryFormatter = queryFormatter
|
||||
}
|
||||
}
|
||||
|
||||
// WithDefaultQueryFormatter adds a simple formatter that trims any excess whitespaces from the query
|
||||
func WithDefaultQueryFormatter() Option {
|
||||
return func(p *gormPlugin) {
|
||||
p.queryFormatter = defaultSqlQueryFormatter
|
||||
}
|
||||
}
|
||||
|
||||
type gormPlugin struct {
|
||||
provider trace.TracerProvider
|
||||
tracer trace.Tracer
|
||||
attrs []attribute.KeyValue
|
||||
excludeQueryVars bool
|
||||
queryFormatter func(query string) string
|
||||
slowQueryDuration time.Duration
|
||||
everyStatementAsEvent bool
|
||||
}
|
||||
|
||||
// Overrides and sets some options with recommended defaults
|
||||
func DefaultGormPlugin(opts ...Option) gorm.Plugin {
|
||||
p := &gormPlugin{}
|
||||
for _, opt := range opts {
|
||||
opt(p)
|
||||
}
|
||||
WithDefaultQueryFormatter()(p)
|
||||
WithoutQueryVariables()(p)
|
||||
WithEveryStatementAsEvent()(p)
|
||||
p.provider = otel.GetTracerProvider()
|
||||
p.tracer = p.provider.Tracer("git.ma-al.com/maal-libraries/observer/pkg/gorm_tracing")
|
||||
return p
|
||||
}
|
||||
|
||||
func NewGormPlugin(opts ...Option) gorm.Plugin {
|
||||
p := &gormPlugin{}
|
||||
for _, opt := range opts {
|
||||
opt(p)
|
||||
}
|
||||
|
||||
if p.provider == nil {
|
||||
p.provider = otel.GetTracerProvider()
|
||||
}
|
||||
p.tracer = p.provider.Tracer("git.ma-al.com/maal-libraries/observer/pkg/gorm_tracing")
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
func (p gormPlugin) Name() string {
|
||||
return "observerGorm"
|
||||
}
|
||||
|
||||
type gormHookFunc func(tx *gorm.DB)
|
||||
|
||||
type gormRegister interface {
|
||||
Register(name string, fn func(*gorm.DB)) error
|
||||
}
|
||||
|
||||
func (p gormPlugin) Initialize(db *gorm.DB) (err error) {
|
||||
cb := db.Callback()
|
||||
hooks := []struct {
|
||||
callback gormRegister
|
||||
hook gormHookFunc
|
||||
name string
|
||||
}{
|
||||
{cb.Create().Before("gorm:create"), p.before("gorm.Create"), "before:create"},
|
||||
{cb.Create().After("gorm:create"), p.after(), "after:create"},
|
||||
|
||||
{cb.Query().Before("gorm:query"), p.before("gorm.Query"), "before:select"},
|
||||
{cb.Query().After("gorm:query"), p.after(), "after:select"},
|
||||
|
||||
{cb.Delete().Before("gorm:delete"), p.before("gorm.Delete"), "before:delete"},
|
||||
{cb.Delete().After("gorm:delete"), p.after(), "after:delete"},
|
||||
|
||||
{cb.Update().Before("gorm:update"), p.before("gorm.Update"), "before:update"},
|
||||
{cb.Update().After("gorm:update"), p.after(), "after:update"},
|
||||
|
||||
{cb.Row().Before("gorm:row"), p.before("gorm.Row"), "before:row"},
|
||||
{cb.Row().After("gorm:row"), p.after(), "after:row"},
|
||||
|
||||
{cb.Raw().Before("gorm:raw"), p.before("gorm.Raw"), "before:raw"},
|
||||
{cb.Raw().After("gorm:raw"), p.after(), "after:raw"},
|
||||
}
|
||||
|
||||
var firstErr error
|
||||
|
||||
for _, h := range hooks {
|
||||
if err := h.callback.Register("observer:"+h.name, h.hook); err != nil && firstErr == nil {
|
||||
firstErr = fmt.Errorf("callback register %s failed: %w", h.name, err)
|
||||
}
|
||||
}
|
||||
|
||||
return firstErr
|
||||
}
|
||||
|
||||
func (p *gormPlugin) before(spanName string) gormHookFunc {
|
||||
return func(tx *gorm.DB) {
|
||||
ctx := tx.Statement.Context
|
||||
|
||||
tx.Statement.Set("observer:statement_start", time.Now())
|
||||
tx.Statement.Context, _ = p.tracer.Start(ctx, spanName, trace.WithSpanKind(trace.SpanKindClient))
|
||||
}
|
||||
}
|
||||
|
||||
func (p *gormPlugin) after() gormHookFunc {
|
||||
return func(tx *gorm.DB) {
|
||||
span := trace.SpanFromContext(tx.Statement.Context)
|
||||
if !span.IsRecording() {
|
||||
return
|
||||
}
|
||||
defer span.End()
|
||||
|
||||
attrs := make([]attribute.KeyValue, 0, len(p.attrs)+4)
|
||||
attrs = append(attrs, p.attrs...)
|
||||
|
||||
if sys := dbSystem(tx); sys.Valid() {
|
||||
attrs = append(attrs, sys)
|
||||
}
|
||||
|
||||
vars := tx.Statement.Vars
|
||||
if p.excludeQueryVars {
|
||||
// Replace query variables with '?' to mask them
|
||||
vars = make([]interface{}, len(tx.Statement.Vars))
|
||||
|
||||
for i := 0; i < len(vars); i++ {
|
||||
vars[i] = "?"
|
||||
}
|
||||
}
|
||||
|
||||
query := tx.Dialector.Explain(tx.Statement.SQL.String(), vars...)
|
||||
|
||||
attrs = append(attrs, semconv.DBStatementKey.String(p.formatQuery(query)))
|
||||
if tx.Statement.Table != "" {
|
||||
attrs = append(attrs, semconv.DBSQLTableKey.String(tx.Statement.Table))
|
||||
}
|
||||
if tx.Statement.RowsAffected != -1 {
|
||||
attrs = append(attrs, attr.DBRowsAffected(tx.Statement.RowsAffected))
|
||||
}
|
||||
|
||||
slowQuery := false
|
||||
if statementStart, ok := tx.Statement.Get("observer:statement_start"); ok {
|
||||
start := statementStart.(time.Time)
|
||||
duration := time.Now().Sub(start)
|
||||
|
||||
attrs = append(attrs, attr.DBExecutionTimeMs(duration))
|
||||
|
||||
if p.slowQueryDuration != time.Duration(0) {
|
||||
if duration >= p.slowQueryDuration {
|
||||
slowQuery = true
|
||||
|
||||
event.NewErrInSpan(event.Error{
|
||||
Level: level.WARN,
|
||||
Err: errors.New("slow query execution"),
|
||||
Attributes: attrs,
|
||||
}.SkipMoreInCallStack(3), span)
|
||||
|
||||
attrs = append(attrs, attr.SeverityLevel(level.WARN))
|
||||
}
|
||||
}
|
||||
}
|
||||
errQuery := false
|
||||
if tx.Statement.Error != nil {
|
||||
errQuery = true
|
||||
|
||||
event.NewErrInSpan(event.Error{
|
||||
Level: level.ERR,
|
||||
Err: tx.Statement.Error,
|
||||
Attributes: attrs,
|
||||
}.SkipMoreInCallStack(3), span)
|
||||
|
||||
attrs = append(attrs, attr.SeverityLevel(level.ERR))
|
||||
}
|
||||
|
||||
if !slowQuery && !errQuery && p.everyStatementAsEvent {
|
||||
event.NewInSpan(event.Event{
|
||||
Level: level.DEBUG,
|
||||
ShortMessage: "executed an sql query with gorm",
|
||||
Attributes: attrs,
|
||||
}.SkipMoreInCallStack(3), span)
|
||||
|
||||
attrs = append(attrs, attr.SeverityLevel(level.DEBUG))
|
||||
}
|
||||
|
||||
span.SetAttributes(attrs...)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *gormPlugin) formatQuery(query string) string {
|
||||
if p.queryFormatter != nil {
|
||||
return p.queryFormatter(query)
|
||||
}
|
||||
return query
|
||||
}
|
||||
|
||||
func dbSystem(tx *gorm.DB) attribute.KeyValue {
|
||||
switch tx.Dialector.Name() {
|
||||
case "mysql":
|
||||
return semconv.DBSystemMySQL
|
||||
case "mssql":
|
||||
return semconv.DBSystemMSSQL
|
||||
case "postgres", "postgresql":
|
||||
return semconv.DBSystemPostgreSQL
|
||||
case "sqlite":
|
||||
return semconv.DBSystemSqlite
|
||||
case "sqlserver":
|
||||
return semconv.DBSystemKey.String("sqlserver")
|
||||
case "clickhouse":
|
||||
return semconv.DBSystemKey.String("clickhouse")
|
||||
default:
|
||||
return attribute.KeyValue{}
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user