diff --git a/logger.go b/logger.go index 1445beb..bdac9a5 100644 --- a/logger.go +++ b/logger.go @@ -121,6 +121,27 @@ func (l logger) log(ctx context.Context, level slog.Level, format string, args . _ = l.sloggerHandler.Handle(ctx, r) } +// log adds context attributes and logs a message with the given slog level +func (l logger) logAttrs(ctx context.Context, level slog.Level, msg string, attrs ...any) { + if ctx == nil { + ctx = context.Background() + } + if !l.sloggerHandler.Enabled(ctx, level) { + return + } + + // Properly handle the PC for the caller + var pc uintptr + var pcs [1]uintptr + // skip [runtime.Callers, this function, this function's caller] + runtime.Callers(3, pcs[:]) + pc = pcs[0] + r := slog.NewRecord(time.Now(), level, msg, pc) + r.Add(attrs...) + + _ = l.sloggerHandler.Handle(ctx, r) +} + // Trace logs sql message func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { if l.ignoreTrace { @@ -141,7 +162,7 @@ func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (sql strin slog.String(l.sourceField, utils.FileWithLineNum()), }) - l.log(ctx, l.logLevel[ErrorLogType], err.Error(), attributes...) + l.logAttrs(ctx, l.logLevel[ErrorLogType], err.Error(), attributes...) case l.slowThreshold != 0 && elapsed > l.slowThreshold: sql, rows := fc() @@ -154,7 +175,7 @@ func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (sql strin slog.Int64(RowsField, rows), slog.String(l.sourceField, utils.FileWithLineNum()), }) - l.log(ctx, l.logLevel[SlowQueryLogType], fmt.Sprintf("slow sql query [%s >= %v]", elapsed, l.slowThreshold), attributes...) + l.logAttrs(ctx, l.logLevel[SlowQueryLogType], fmt.Sprintf("slow sql query [%s >= %v]", elapsed, l.slowThreshold), attributes...) case l.traceAll || l.gormLevel == gormlogger.Info: sql, rows := fc() @@ -167,7 +188,7 @@ func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (sql strin slog.String(l.sourceField, utils.FileWithLineNum()), }) - l.log(ctx, l.logLevel[DefaultLogType], fmt.Sprintf("SQL query executed [%s]", elapsed), attributes...) + l.logAttrs(ctx, l.logLevel[DefaultLogType], fmt.Sprintf("SQL query executed [%s]", elapsed), attributes...) } } diff --git a/logger_test.go b/logger_test.go index cb4e2c8..473a79a 100644 --- a/logger_test.go +++ b/logger_test.go @@ -48,6 +48,9 @@ func Test_logger_Enabled(t *testing.T) { l := New(WithHandler(slog.NewTextHandler(buffer, &slog.HandlerOptions{Level: leveler}))) leveler.Set(slog.LevelWarn) + l.Info(nil, "an info message") + assert.Equal(t, 0, buffer.Len()) + l.Info(context.Background(), "an info message") assert.Equal(t, 0, buffer.Len()) @@ -55,6 +58,32 @@ func Test_logger_Enabled(t *testing.T) { assert.Greater(t, buffer.Len(), 0) } +func Test_logger_Trace_Enabled(t *testing.T) { + buffer := bytes.NewBuffer(nil) + leveler := &slog.LevelVar{} + l := New( + WithHandler(slog.NewTextHandler(buffer, &slog.HandlerOptions{Level: leveler})), + WithSlowThreshold(10*time.Second), + WithTraceAll(), + ) + + fc := func() (string, int64) { + return "SELECT * FROM user", 1 + } + + leveler.Set(slog.LevelWarn) + l.Trace(context.Background(), time.Now().Add(-1*time.Second), fc, nil) + assert.Equal(t, 0, buffer.Len()) + + leveler.Set(slog.LevelError) + l.Trace(context.Background(), time.Now().Add(-1*time.Minute), fc, nil) + assert.Equal(t, 0, buffer.Len()) + + leveler.Set(slog.Level(42)) + l.Trace(context.Background(), time.Now().Add(-1*time.Minute), fc, fmt.Errorf("awesome error")) + assert.Equal(t, 0, buffer.Len()) +} + func Test_logger_LogMode(t *testing.T) { l := logger{gormLevel: gormlogger.Info} actual := l.LogMode(gormlogger.Info)