observer/pkg/gorm_tracing/middleware.go
2024-07-24 15:40:18 +02:00

208 lines
5.3 KiB
Go

package gorm_tracing
import (
"fmt"
"git.ma-al.com/maal-libraries/observer/pkg/fiber_tracing"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
semconv "go.opentelemetry.io/otel/semconv/v1.25.0"
"go.opentelemetry.io/otel/trace"
"gorm.io/gorm"
)
// dbRowsAffected is the attribute key ("db.rows_affected") under which the
// after hook reports tx.Statement.RowsAffected on the span.
var dbRowsAffected = attribute.Key("db.rows_affected")
// Option mutates a gormPlugin while it is being built. NewGormPlugin applies
// the supplied options in the order they are passed.
type Option func(p *gormPlugin)
// WithTracerProvider returns an Option that stores provider on the plugin.
// NewGormPlugin derives the plugin's tracer from this provider and falls back
// to the global OpenTelemetry provider when none (or nil) was supplied.
func WithTracerProvider(provider trace.TracerProvider) Option {
	setProvider := func(plugin *gormPlugin) {
		plugin.provider = provider
	}
	return setProvider
}
// WithAttributes returns an Option that appends attrs to the plugin's static
// attribute list. Those attributes are copied onto every span annotated by
// the after hook. Repeated use accumulates attributes; nothing is replaced.
func WithAttributes(attrs ...attribute.KeyValue) Option {
	addAttrs := func(plugin *gormPlugin) {
		plugin.attrs = append(plugin.attrs, attrs...)
	}
	return addAttrs
}
// WithDBName returns an Option that adds a db.name attribute carrying name to
// the plugin's static attribute list.
func WithDBName(name string) Option {
	return func(plugin *gormPlugin) {
		dbName := semconv.DBNameKey.String(name)
		plugin.attrs = append(plugin.attrs, dbName)
	}
}
// WithoutQueryVariables returns an Option that makes the after hook substitute
// every bound query variable with "?" before building the db.statement
// attribute, so actual values are not written to the span.
func WithoutQueryVariables() Option {
	maskVars := func(plugin *gormPlugin) {
		plugin.excludeQueryVars = true
	}
	return maskVars
}
// WithQueryFormatter returns an Option that installs queryFormatter as the
// function applied to the rendered SQL text before it is stored in the
// db.statement attribute. A nil formatter leaves the query unchanged.
func WithQueryFormatter(queryFormatter func(query string) string) Option {
	setFormatter := func(plugin *gormPlugin) {
		plugin.queryFormatter = queryFormatter
	}
	return setFormatter
}
// WithoutMetrics returns an Option that sets the plugin's excludeMetrics flag,
// which is meant to suppress DBStats metrics reporting.
//
// NOTE(review): the only code in this file that reads the flag is the
// commented-out metrics block in Initialize, so the option currently has no
// visible effect here.
func WithoutMetrics() Option {
	disableMetrics := func(plugin *gormPlugin) {
		plugin.excludeMetrics = true
	}
	return disableMetrics
}
// gormPlugin is the gorm.Plugin implementation returned by NewGormPlugin. It
// holds the tracing configuration collected from the Option functions.
type gormPlugin struct {
// provider is the tracer provider set by WithTracerProvider, or the global
// OpenTelemetry provider when none was configured.
provider trace.TracerProvider
// tracer is obtained from provider in NewGormPlugin.
tracer trace.Tracer
// attrs are static attributes added to every annotated span.
attrs []attribute.KeyValue
// excludeQueryVars masks bound query variables with "?" in db.statement.
excludeQueryVars bool
// excludeMetrics is set by WithoutMetrics; its only reader in this file is
// commented-out code in Initialize.
excludeMetrics bool
// queryFormatter, when non-nil, post-processes the rendered SQL text.
queryFormatter func(query string) string
}
// NewGormPlugin builds the tracing plugin for gorm.
//
// The options are applied in the order given. If no tracer provider was
// configured by them, the global OpenTelemetry provider is used. The plugin's
// tracer is then created from that provider under this package's
// instrumentation name.
func NewGormPlugin(opts ...Option) gorm.Plugin {
	plugin := new(gormPlugin)
	for i := range opts {
		opts[i](plugin)
	}

	// Fall back to the globally registered provider when none was supplied.
	if plugin.provider == nil {
		plugin.provider = otel.GetTracerProvider()
	}

	const instrumentationName = "git.ma-al.com/maal-libraries/observer/pkg/gorm_tracing"
	plugin.tracer = plugin.provider.Tracer(instrumentationName)

	return plugin
}
// Name returns the identifier under which this plugin is known to gorm.
func (p gormPlugin) Name() string {
	const pluginName = "observerGorm"
	return pluginName
}
// gormHookFunc is the signature of the callbacks this plugin registers with
// gorm; the before and after methods each produce one.
type gormHookFunc func(tx *gorm.DB)
// gormRegister is the minimal interface Initialize needs from the values
// returned by the gorm callback processors' Before(...) / After(...) calls:
// the ability to register a named hook.
type gormRegister interface {
Register(name string, fn func(*gorm.DB)) error
}
// Initialize installs the tracing hooks on db.
//
// For each of the create, query, delete, update, row and raw callback chains
// it registers a hook positioned before the corresponding "gorm:<op>" callback
// (starting a span) and one positioned after it (annotating and ending the
// span). Hooks are registered under the name "observer:<label>".
//
// Every registration is attempted even if an earlier one fails. The returned
// error wraps the first failure encountered, or is nil when all succeed.
func (p gormPlugin) Initialize(db *gorm.DB) (err error) {
	// DBStats metrics reporting is currently disabled:
	// if !p.excludeMetrics {
	// 	if db, ok := db.ConnPool.(*sql.DB); ok {
	// 		metrics.ReportDBStatsMetrics(db)
	// 	}
	// }

	type registration struct {
		target gormRegister
		fn     gormHookFunc
		label  string
	}

	callbacks := db.Callback()
	registrations := []registration{
		{callbacks.Create().Before("gorm:create"), p.before("gorm.Create"), "before:create"},
		{callbacks.Create().After("gorm:create"), p.after(), "after:create"},

		{callbacks.Query().Before("gorm:query"), p.before("gorm.Select"), "before:select"},
		{callbacks.Query().After("gorm:query"), p.after(), "after:select"},

		{callbacks.Delete().Before("gorm:delete"), p.before("gorm.Delete"), "before:delete"},
		{callbacks.Delete().After("gorm:delete"), p.after(), "after:delete"},

		{callbacks.Update().Before("gorm:update"), p.before("gorm.Update"), "before:update"},
		{callbacks.Update().After("gorm:update"), p.after(), "after:update"},

		{callbacks.Row().Before("gorm:row"), p.before("gorm.Row"), "before:row"},
		{callbacks.Row().After("gorm:row"), p.after(), "after:row"},

		{callbacks.Raw().Before("gorm:raw"), p.before("gorm.Raw"), "before:raw"},
		{callbacks.Raw().After("gorm:raw"), p.after(), "after:raw"},
	}

	for i := range registrations {
		reg := &registrations[i]
		regErr := reg.target.Register("observer:"+reg.label, reg.fn)
		// Keep only the first failure, but still attempt the remaining hooks.
		if regErr == nil || err != nil {
			continue
		}
		err = fmt.Errorf("callback register %s failed: %w", reg.label, regErr)
	}
	return err
}
// before returns the hook that runs ahead of a gorm operation. It starts a
// client-kind span named spanName from the statement's current context and
// replaces the statement context with the one carrying that span. The span
// value itself is discarded here; the after hook retrieves it from the
// context.
//
// NOTE(review): the span is started through fiber_tracing.Start, not through
// p.tracer, so the tracer configured on the plugin is not used on this path.
// Confirm this is intended.
func (p *gormPlugin) before(spanName string) gormHookFunc {
	return func(tx *gorm.DB) {
		stmt := tx.Statement
		stmt.Context, _ = fiber_tracing.Start(
			stmt.Context,
			spanName,
			trace.WithSpanKind(trace.SpanKindClient),
		)
	}
}
// after returns the hook that runs once a gorm operation has finished. It
// takes the span from the statement context and, if that span is recording,
// attaches the database attributes and ends it.
//
// Attributes are set in this order: the plugin's static attributes, db.system
// (when the dialector is recognised), db.statement, db.sql.table (when a table
// is known) and db.rows_affected (unless RowsAffected is -1).
//
// NOTE(review): tx.Error is not inspected, so a failed statement is not
// flagged on the span. Confirm this is intentional.
func (p *gormPlugin) after() gormHookFunc {
	return func(tx *gorm.DB) {
		span := trace.SpanFromContext(tx.Statement.Context)
		// Nothing to annotate on a non-recording span; End is not called on
		// this path.
		if !span.IsRecording() {
			return
		}
		defer span.End()

		stmt := tx.Statement

		// Room for the static attributes plus the four optional ones below.
		spanAttrs := make([]attribute.KeyValue, 0, len(p.attrs)+4)
		spanAttrs = append(spanAttrs, p.attrs...)

		system := dbSystem(tx)
		if system.Valid() {
			spanAttrs = append(spanAttrs, system)
		}

		queryVars := stmt.Vars
		if p.excludeQueryVars {
			// Substitute every bound value with "?" so no data reaches the span.
			queryVars = make([]any, len(stmt.Vars))
			for i := range queryVars {
				queryVars[i] = "?"
			}
		}
		rendered := tx.Dialector.Explain(stmt.SQL.String(), queryVars...)
		spanAttrs = append(spanAttrs, semconv.DBStatementKey.String(p.formatQuery(rendered)))

		if table := stmt.Table; table != "" {
			spanAttrs = append(spanAttrs, semconv.DBSQLTableKey.String(table))
		}

		if rows := stmt.RowsAffected; rows != -1 {
			spanAttrs = append(spanAttrs, dbRowsAffected.Int64(rows))
		}

		span.SetAttributes(spanAttrs...)
	}
}
// formatQuery passes query through the configured query formatter. When no
// formatter is configured, query is returned unchanged.
func (p *gormPlugin) formatQuery(query string) string {
	format := p.queryFormatter
	if format == nil {
		return query
	}
	return format(query)
}
// dbSystem maps the name reported by tx's dialector to the db.system span
// attribute.
//
// It returns the zero attribute.KeyValue for dialectors it does not
// recognise; callers are expected to check Valid() before using the result.
func dbSystem(tx *gorm.DB) attribute.KeyValue {
	switch tx.Dialector.Name() {
	case "mysql":
		return semconv.DBSystemMySQL
	case "mssql", "sqlserver":
		// Both names denote Microsoft SQL Server. Report the well-known
		// semantic-convention value for both rather than emitting the
		// non-standard "sqlserver" for one of them.
		return semconv.DBSystemMSSQL
	case "postgres", "postgresql":
		return semconv.DBSystemPostgreSQL
	case "sqlite":
		return semconv.DBSystemSqlite
	case "clickhouse":
		return semconv.DBSystemKey.String("clickhouse")
	default:
		return attribute.KeyValue{}
	}
}