diff --git a/internal/logger/context.go b/internal/logger/context.go index a50bf2a1a..292981645 100644 --- a/internal/logger/context.go +++ b/internal/logger/context.go @@ -14,7 +14,7 @@ type key string var cidKey key = "cid.srsx.ossrs.org" -// generateContextID generates a random context id in string. +// GenerateContextID generates a random context id in string. func GenerateContextID() string { randomBytes := make([]byte, 32) _, _ = rand.Read(randomBytes) @@ -26,11 +26,11 @@ func GenerateContextID() string { // WithContext creates a new context with cid, which will be used for log. func WithContext(ctx context.Context) context.Context { - return WithContextID(ctx, GenerateContextID()) + return withContextID(ctx, GenerateContextID()) } -// WithContextID creates a new context with cid, which will be used for log. -func WithContextID(ctx context.Context, cid string) context.Context { +// withContextID creates a new context with cid, which will be used for log. +func withContextID(ctx context.Context, cid string) context.Context { return context.WithValue(ctx, cidKey, cid) } diff --git a/internal/logger/context_test.go b/internal/logger/context_test.go new file mode 100644 index 000000000..2adda9538 --- /dev/null +++ b/internal/logger/context_test.go @@ -0,0 +1,82 @@ +// Copyright (c) 2026 Winlin +// +// SPDX-License-Identifier: MIT +package logger + +import ( + "context" + "encoding/hex" + "testing" +) + +func TestGenerateContextID_LengthAndHex(t *testing.T) { + cid := GenerateContextID() + if len(cid) != 7 { + t.Fatalf("len(cid) = %d, want 7", len(cid)) + } + if _, err := hex.DecodeString(cid + "0"); err != nil { + t.Fatalf("cid %q is not hex: %v", cid, err) + } +} + +func TestGenerateContextID_Unique(t *testing.T) { + seen := make(map[string]struct{}, 1000) + for i := range 1000 { + cid := GenerateContextID() + if _, dup := seen[cid]; dup { + t.Fatalf("duplicate cid %q at iteration %d", cid, i) + } + seen[cid] = struct{}{} + } +} + +func TestWithContext_AttachesCID(t *testing.T) { + ctx := WithContext(context.Background()) + cid := ContextID(ctx) + if len(cid) != 7 { + t.Fatalf("ContextID length = %d, want 7", len(cid)) + } +} + +func TestWithContext_IndependentCIDs(t *testing.T) { + c1 := WithContext(context.Background()) + c2 := WithContext(context.Background()) + if ContextID(c1) == ContextID(c2) { + t.Fatalf("expected distinct cids, got %q twice", ContextID(c1)) + } +} + +func TestContextID_Missing(t *testing.T) { + if got := ContextID(context.Background()); got != "" { + t.Fatalf("ContextID on empty ctx = %q, want \"\"", got) + } +} + +func TestContextID_WrongTypeReturnsEmpty(t *testing.T) { + ctx := context.WithValue(context.Background(), cidKey, 42) + if got := ContextID(ctx); got != "" { + t.Fatalf("ContextID with int value = %q, want \"\"", got) + } +} + +func TestWithContextID_RoundTrip(t *testing.T) { + ctx := withContextID(context.Background(), "abcdef1") + if got := ContextID(ctx); got != "abcdef1" { + t.Fatalf("ContextID = %q, want %q", got, "abcdef1") + } +} + +func TestWithContextID_Overwrite(t *testing.T) { + ctx := withContextID(context.Background(), "first00") + ctx = withContextID(ctx, "second1") + if got := ContextID(ctx); got != "second1" { + t.Fatalf("ContextID after overwrite = %q, want %q", got, "second1") + } +} + +func TestCIDKey_NotCollidingWithPlainString(t *testing.T) { + ctx := context.WithValue(context.Background(), string(cidKey), "plain") + if got := ContextID(ctx); got != "" { + t.Fatalf("ContextID leaked through string key = %q, want \"\"", got) + } +} diff --git a/internal/logger/log.go b/internal/logger/log.go index 9653c0846..f710653e5 100644 --- a/internal/logger/log.go +++ b/internal/logger/log.go @@ -5,8 +5,9 @@ package logger import ( "context" - "io/ioutil" - stdLog "log" + "fmt" + "io" + "log/slog" "os" ) @@ -15,8 +16,8 @@ type logger interface { } type loggerPlus struct { - logger *stdLog.Logger - level string + logger *slog.Logger + level slog.Level } func newLoggerPlus(opts ...func(*loggerPlus)) *loggerPlus { @@ -27,61 +28,95 @@ func newLoggerPlus(opts ...func(*loggerPlus)) *loggerPlus { return v } -func (v *loggerPlus) Printf(ctx context.Context, f string, a ...interface{}) { - format, args := f, a +func (v *loggerPlus) Printf(ctx context.Context, f string, a ...any) { + attrs := []slog.Attr{slog.Int("pid", os.Getpid())} if cid := ContextID(ctx); cid != "" { - format, args = "[%v][%v][%v] "+format, append([]interface{}{v.level, os.Getpid(), cid}, a...) + attrs = append(attrs, slog.String("cid", cid)) } - - v.logger.Printf(format, args...) + v.logger.LogAttrs(ctx, v.level, fmt.Sprintf(f, a...), attrs...) } var verboseLogger logger -func Vf(ctx context.Context, format string, a ...interface{}) { +func Vf(ctx context.Context, format string, a ...any) { verboseLogger.Printf(ctx, format, a...) } var debugLogger logger -func Df(ctx context.Context, format string, a ...interface{}) { +func Df(ctx context.Context, format string, a ...any) { debugLogger.Printf(ctx, format, a...) } var warnLogger logger -func Wf(ctx context.Context, format string, a ...interface{}) { +func Wf(ctx context.Context, format string, a ...any) { warnLogger.Printf(ctx, format, a...) } var errorLogger logger -func Ef(ctx context.Context, format string, a ...interface{}) { +func Ef(ctx context.Context, format string, a ...any) { errorLogger.Printf(ctx, format, a...) } const ( - logVerboseLabel = "verb" - logDebugLabel = "debug" - logWarnLabel = "warn" - logErrorLabel = "error" + levelVerb slog.Level = slog.LevelDebug - 4 + levelDebug slog.Level = slog.LevelDebug + levelWarn slog.Level = slog.LevelWarn + levelError slog.Level = slog.LevelError ) +// newJSONLogger builds a slog.Logger that writes JSON records to w, renders the +// time in UTC, and maps our custom levels to short lowercase labels. +func newJSONLogger(w io.Writer) *slog.Logger { + h := slog.NewJSONHandler(w, &slog.HandlerOptions{ + Level: levelVerb, + ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr { + if len(groups) != 0 { + return a + } + switch a.Key { + case slog.TimeKey: + return slog.Time(slog.TimeKey, a.Value.Time().UTC()) + case slog.LevelKey: + return slog.String(slog.LevelKey, levelLabel(a.Value.Any().(slog.Level))) + } + return a + }, + }) + return slog.New(h) +} + +func levelLabel(l slog.Level) string { + switch l { + case levelVerb: + return "verb" + case levelDebug: + return "debug" + case levelWarn: + return "warn" + case levelError: + return "error" + } + return l.String() +} + func init() { - verboseLogger = newLoggerPlus(func(logger *loggerPlus) { - logger.logger = stdLog.New(ioutil.Discard, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds) - logger.level = logVerboseLabel + verboseLogger = newLoggerPlus(func(l *loggerPlus) { + l.logger = newJSONLogger(io.Discard) + l.level = levelVerb }) - debugLogger = newLoggerPlus(func(logger *loggerPlus) { - logger.logger = stdLog.New(os.Stdout, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds) - logger.level = logDebugLabel + debugLogger = newLoggerPlus(func(l *loggerPlus) { + l.logger = newJSONLogger(os.Stdout) + l.level = levelDebug }) - warnLogger = newLoggerPlus(func(logger *loggerPlus) { - logger.logger = stdLog.New(os.Stderr, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds) - logger.level = logWarnLabel + warnLogger = newLoggerPlus(func(l *loggerPlus) { + l.logger = newJSONLogger(os.Stderr) + l.level = levelWarn }) - errorLogger = newLoggerPlus(func(logger *loggerPlus) { - logger.logger = stdLog.New(os.Stderr, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds) - logger.level = logErrorLabel + errorLogger = newLoggerPlus(func(l *loggerPlus) { + l.logger = newJSONLogger(os.Stderr) + l.level = levelError }) } diff --git a/internal/logger/log_test.go b/internal/logger/log_test.go new file mode 100644 index 000000000..626cb73f2 --- /dev/null +++ b/internal/logger/log_test.go @@ -0,0 +1,174 @@ +// Copyright (c) 2026 Winlin +// +// SPDX-License-Identifier: MIT +package logger + +import ( + "bytes" + "context" + "encoding/json" + "io" + "log/slog" + "os" + "strings" + "testing" + "time" +) + +func decodeLine(t *testing.T, line []byte) map[string]any { + t.Helper() + var m map[string]any + if err := json.Unmarshal(bytes.TrimSpace(line), &m); err != nil { + t.Fatalf("decode %q: %v", line, err) + } + return m +} + +func bufLoggerPlus(w io.Writer, level slog.Level) *loggerPlus { + return newLoggerPlus(func(l *loggerPlus) { + l.logger = newJSONLogger(w) + l.level = level + }) +} + +func TestLevelLabel_Known(t *testing.T) { + cases := map[slog.Level]string{ + levelVerb: "verb", + levelDebug: "debug", + levelWarn: "warn", + levelError: "error", + } + for lvl, want := range cases { + if got := levelLabel(lvl); got != want { + t.Errorf("levelLabel(%v) = %q, want %q", lvl, got, want) + } + } +} + +func TestLevelLabel_UnknownFallsBackToString(t *testing.T) { + got := levelLabel(slog.Level(99)) + if got == "" { + t.Fatalf("levelLabel(99) returned empty") + } + if got == "verb" || got == "debug" || got == "warn" || got == "error" { + t.Fatalf("levelLabel(99) = %q, want slog.Level.String() form", got) + } +} + +func TestPrintf_EmitsAllFields(t *testing.T) { + var buf bytes.Buffer + lp := bufLoggerPlus(&buf, levelDebug) + ctx := withContextID(context.Background(), "abc1234") + lp.Printf(ctx, "hello %s %d", "world", 42) + + m := decodeLine(t, buf.Bytes()) + if m["level"] != "debug" { + t.Errorf("level = %v, want debug", m["level"]) + } + if m["msg"] != "hello world 42" { + t.Errorf("msg = %v, want %q", m["msg"], "hello world 42") + } + if m["cid"] != "abc1234" { + t.Errorf("cid = %v, want abc1234", m["cid"]) + } + pid, ok := m["pid"].(float64) + if !ok || int(pid) != os.Getpid() { + t.Errorf("pid = %v, want %d", m["pid"], os.Getpid()) + } + ts, ok := m["time"].(string) + if !ok || !strings.HasSuffix(ts, "Z") { + t.Errorf("time = %v, want UTC suffix Z", m["time"]) + } + if _, err := time.Parse(time.RFC3339Nano, ts); err != nil { + t.Errorf("time %q not RFC3339Nano: %v", ts, err) + } +} + +func TestPrintf_OmitsCIDWhenAbsent(t *testing.T) { + var buf bytes.Buffer + bufLoggerPlus(&buf, levelWarn).Printf(context.Background(), "no cid here") + + m := decodeLine(t, buf.Bytes()) + if v, present := m["cid"]; present { + t.Errorf("cid should be absent, got %v", v) + } + if m["level"] != "warn" { + t.Errorf("level = %v, want warn", m["level"]) + } +} + +func TestPrintf_AllLevelsMapToLabel(t *testing.T) { + cases := []struct { + level slog.Level + label string + }{ + {levelVerb, "verb"}, + {levelDebug, "debug"}, + {levelWarn, "warn"}, + {levelError, "error"}, + } + for _, tc := range cases { + var buf bytes.Buffer + bufLoggerPlus(&buf, tc.level).Printf(context.Background(), "hi") + m := decodeLine(t, buf.Bytes()) + if m["level"] != tc.label { + t.Errorf("level(%v) rendered as %v, want %q", tc.level, m["level"], tc.label) + } + } +} + +func TestNewJSONLogger_GroupedAttrsPassThrough(t *testing.T) { + var buf bytes.Buffer + lg := newJSONLogger(&buf) + lg.LogAttrs(context.Background(), levelDebug, "grouped", + slog.Group("meta", slog.String("inner", "v"))) + + m := decodeLine(t, buf.Bytes()) + meta, ok := m["meta"].(map[string]any) + if !ok { + t.Fatalf("meta not an object: %v", m["meta"]) + } + if meta["inner"] != "v" { + t.Errorf("meta.inner = %v, want v", meta["inner"]) + } +} + +func TestPackageWrappers_RouteToRightLogger(t *testing.T) { + origV, origD, origW, origE := verboseLogger, debugLogger, warnLogger, errorLogger + t.Cleanup(func() { + verboseLogger, debugLogger, warnLogger, errorLogger = origV, origD, origW, origE + }) + + vBuf, dBuf, wBuf, eBuf := &bytes.Buffer{}, &bytes.Buffer{}, &bytes.Buffer{}, &bytes.Buffer{} + verboseLogger = bufLoggerPlus(vBuf, levelVerb) + debugLogger = bufLoggerPlus(dBuf, levelDebug) + warnLogger = bufLoggerPlus(wBuf, levelWarn) + errorLogger = bufLoggerPlus(eBuf, levelError) + + ctx := context.Background() + Vf(ctx, "v=%d", 1) + Df(ctx, "d=%d", 2) + Wf(ctx, "w=%d", 3) + Ef(ctx, "e=%d", 4) + + checks := []struct { + name string + buf *bytes.Buffer + label string + msg string + }{ + {"Vf", vBuf, "verb", "v=1"}, + {"Df", dBuf, "debug", "d=2"}, + {"Wf", wBuf, "warn", "w=3"}, + {"Ef", eBuf, "error", "e=4"}, + } + for _, c := range checks { + m := decodeLine(t, c.buf.Bytes()) + if m["level"] != c.label { + t.Errorf("%s level = %v, want %v", c.name, m["level"], c.label) + } + if m["msg"] != c.msg { + t.Errorf("%s msg = %v, want %v", c.name, m["msg"], c.msg) + } + } +}