Proxy: Switch internal/logger to slog JSON output and add unit tests.

Replaces the stdlib log.Logger with log/slog JSON handlers (UTC timestamps,
semantic level labels via custom slog.Level values), hides withContextID
since it has no external callers, and adds unit tests reaching 100%
statement coverage for the package.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
winlin 2026-04-19 20:45:38 -04:00
parent 98d09138b8
commit 6406ad23b0
4 changed files with 324 additions and 33 deletions

View File

@ -14,7 +14,7 @@ type key string
var cidKey key = "cid.srsx.ossrs.org"
// generateContextID generates a random context id in string.
// GenerateContextID generates a random context id in string.
func GenerateContextID() string {
randomBytes := make([]byte, 32)
_, _ = rand.Read(randomBytes)
@ -26,11 +26,11 @@ func GenerateContextID() string {
// WithContext creates a new context with cid, which will be used for log.
func WithContext(ctx context.Context) context.Context {
return WithContextID(ctx, GenerateContextID())
return withContextID(ctx, GenerateContextID())
}
// WithContextID creates a new context with cid, which will be used for log.
func WithContextID(ctx context.Context, cid string) context.Context {
// withContextID creates a new context with cid, which will be used for log.
func withContextID(ctx context.Context, cid string) context.Context {
return context.WithValue(ctx, cidKey, cid)
}

View File

@ -0,0 +1,82 @@
// Copyright (c) 2026 Winlin
//
// SPDX-License-Identifier: MIT
package logger
import (
"context"
"encoding/hex"
"testing"
)
func TestGenerateContextID_LengthAndHex(t *testing.T) {
cid := GenerateContextID()
if len(cid) != 7 {
t.Fatalf("len(cid) = %d, want 7", len(cid))
}
if _, err := hex.DecodeString(cid + "0"); err != nil {
t.Fatalf("cid %q is not hex: %v", cid, err)
}
}
func TestGenerateContextID_Unique(t *testing.T) {
seen := make(map[string]struct{}, 1000)
for i := range 1000 {
cid := GenerateContextID()
if _, dup := seen[cid]; dup {
t.Fatalf("duplicate cid %q at iteration %d", cid, i)
}
seen[cid] = struct{}{}
}
}
func TestWithContext_AttachesCID(t *testing.T) {
ctx := WithContext(context.Background())
cid := ContextID(ctx)
if len(cid) != 7 {
t.Fatalf("ContextID length = %d, want 7", len(cid))
}
}
func TestWithContext_IndependentCIDs(t *testing.T) {
c1 := WithContext(context.Background())
c2 := WithContext(context.Background())
if ContextID(c1) == ContextID(c2) {
t.Fatalf("expected distinct cids, got %q twice", ContextID(c1))
}
}
func TestContextID_Missing(t *testing.T) {
if got := ContextID(context.Background()); got != "" {
t.Fatalf("ContextID on empty ctx = %q, want \"\"", got)
}
}
func TestContextID_WrongTypeReturnsEmpty(t *testing.T) {
ctx := context.WithValue(context.Background(), cidKey, 42)
if got := ContextID(ctx); got != "" {
t.Fatalf("ContextID with int value = %q, want \"\"", got)
}
}
func TestWithContextID_RoundTrip(t *testing.T) {
ctx := withContextID(context.Background(), "abcdef1")
if got := ContextID(ctx); got != "abcdef1" {
t.Fatalf("ContextID = %q, want %q", got, "abcdef1")
}
}
func TestWithContextID_Overwrite(t *testing.T) {
ctx := withContextID(context.Background(), "first00")
ctx = withContextID(ctx, "second1")
if got := ContextID(ctx); got != "second1" {
t.Fatalf("ContextID after overwrite = %q, want %q", got, "second1")
}
}
func TestCIDKey_NotCollidingWithPlainString(t *testing.T) {
ctx := context.WithValue(context.Background(), string(cidKey), "plain")
if got := ContextID(ctx); got != "" {
t.Fatalf("ContextID leaked through string key = %q, want \"\"", got)
}
}

View File

@ -5,8 +5,9 @@ package logger
import (
"context"
"io/ioutil"
stdLog "log"
"fmt"
"io"
"log/slog"
"os"
)
@ -15,8 +16,8 @@ type logger interface {
}
type loggerPlus struct {
logger *stdLog.Logger
level string
logger *slog.Logger
level slog.Level
}
func newLoggerPlus(opts ...func(*loggerPlus)) *loggerPlus {
@ -27,61 +28,95 @@ func newLoggerPlus(opts ...func(*loggerPlus)) *loggerPlus {
return v
}
func (v *loggerPlus) Printf(ctx context.Context, f string, a ...interface{}) {
format, args := f, a
func (v *loggerPlus) Printf(ctx context.Context, f string, a ...any) {
attrs := []slog.Attr{slog.Int("pid", os.Getpid())}
if cid := ContextID(ctx); cid != "" {
format, args = "[%v][%v][%v] "+format, append([]interface{}{v.level, os.Getpid(), cid}, a...)
attrs = append(attrs, slog.String("cid", cid))
}
v.logger.Printf(format, args...)
v.logger.LogAttrs(ctx, v.level, fmt.Sprintf(f, a...), attrs...)
}
var verboseLogger logger
func Vf(ctx context.Context, format string, a ...interface{}) {
func Vf(ctx context.Context, format string, a ...any) {
verboseLogger.Printf(ctx, format, a...)
}
var debugLogger logger
func Df(ctx context.Context, format string, a ...interface{}) {
func Df(ctx context.Context, format string, a ...any) {
debugLogger.Printf(ctx, format, a...)
}
var warnLogger logger
func Wf(ctx context.Context, format string, a ...interface{}) {
func Wf(ctx context.Context, format string, a ...any) {
warnLogger.Printf(ctx, format, a...)
}
var errorLogger logger
func Ef(ctx context.Context, format string, a ...interface{}) {
func Ef(ctx context.Context, format string, a ...any) {
errorLogger.Printf(ctx, format, a...)
}
const (
logVerboseLabel = "verb"
logDebugLabel = "debug"
logWarnLabel = "warn"
logErrorLabel = "error"
levelVerb slog.Level = slog.LevelDebug - 4
levelDebug slog.Level = slog.LevelDebug
levelWarn slog.Level = slog.LevelWarn
levelError slog.Level = slog.LevelError
)
// newJSONLogger builds a slog.Logger that writes JSON records to w, renders the
// time in UTC, and maps our custom levels to short lowercase labels.
func newJSONLogger(w io.Writer) *slog.Logger {
h := slog.NewJSONHandler(w, &slog.HandlerOptions{
Level: levelVerb,
ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr {
if len(groups) != 0 {
return a
}
switch a.Key {
case slog.TimeKey:
return slog.Time(slog.TimeKey, a.Value.Time().UTC())
case slog.LevelKey:
return slog.String(slog.LevelKey, levelLabel(a.Value.Any().(slog.Level)))
}
return a
},
})
return slog.New(h)
}
func levelLabel(l slog.Level) string {
switch l {
case levelVerb:
return "verb"
case levelDebug:
return "debug"
case levelWarn:
return "warn"
case levelError:
return "error"
}
return l.String()
}
func init() {
verboseLogger = newLoggerPlus(func(logger *loggerPlus) {
logger.logger = stdLog.New(ioutil.Discard, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds)
logger.level = logVerboseLabel
verboseLogger = newLoggerPlus(func(l *loggerPlus) {
l.logger = newJSONLogger(io.Discard)
l.level = levelVerb
})
debugLogger = newLoggerPlus(func(logger *loggerPlus) {
logger.logger = stdLog.New(os.Stdout, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds)
logger.level = logDebugLabel
debugLogger = newLoggerPlus(func(l *loggerPlus) {
l.logger = newJSONLogger(os.Stdout)
l.level = levelDebug
})
warnLogger = newLoggerPlus(func(logger *loggerPlus) {
logger.logger = stdLog.New(os.Stderr, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds)
logger.level = logWarnLabel
warnLogger = newLoggerPlus(func(l *loggerPlus) {
l.logger = newJSONLogger(os.Stderr)
l.level = levelWarn
})
errorLogger = newLoggerPlus(func(logger *loggerPlus) {
logger.logger = stdLog.New(os.Stderr, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds)
logger.level = logErrorLabel
errorLogger = newLoggerPlus(func(l *loggerPlus) {
l.logger = newJSONLogger(os.Stderr)
l.level = levelError
})
}

174
internal/logger/log_test.go Normal file
View File

@ -0,0 +1,174 @@
// Copyright (c) 2026 Winlin
//
// SPDX-License-Identifier: MIT
package logger
import (
"bytes"
"context"
"encoding/json"
"io"
"log/slog"
"os"
"strings"
"testing"
"time"
)
func decodeLine(t *testing.T, line []byte) map[string]any {
t.Helper()
var m map[string]any
if err := json.Unmarshal(bytes.TrimSpace(line), &m); err != nil {
t.Fatalf("decode %q: %v", line, err)
}
return m
}
func bufLoggerPlus(w io.Writer, level slog.Level) *loggerPlus {
return newLoggerPlus(func(l *loggerPlus) {
l.logger = newJSONLogger(w)
l.level = level
})
}
func TestLevelLabel_Known(t *testing.T) {
cases := map[slog.Level]string{
levelVerb: "verb",
levelDebug: "debug",
levelWarn: "warn",
levelError: "error",
}
for lvl, want := range cases {
if got := levelLabel(lvl); got != want {
t.Errorf("levelLabel(%v) = %q, want %q", lvl, got, want)
}
}
}
func TestLevelLabel_UnknownFallsBackToString(t *testing.T) {
got := levelLabel(slog.Level(99))
if got == "" {
t.Fatalf("levelLabel(99) returned empty")
}
if got == "verb" || got == "debug" || got == "warn" || got == "error" {
t.Fatalf("levelLabel(99) = %q, want slog.Level.String() form", got)
}
}
func TestPrintf_EmitsAllFields(t *testing.T) {
var buf bytes.Buffer
lp := bufLoggerPlus(&buf, levelDebug)
ctx := withContextID(context.Background(), "abc1234")
lp.Printf(ctx, "hello %s %d", "world", 42)
m := decodeLine(t, buf.Bytes())
if m["level"] != "debug" {
t.Errorf("level = %v, want debug", m["level"])
}
if m["msg"] != "hello world 42" {
t.Errorf("msg = %v, want %q", m["msg"], "hello world 42")
}
if m["cid"] != "abc1234" {
t.Errorf("cid = %v, want abc1234", m["cid"])
}
pid, ok := m["pid"].(float64)
if !ok || int(pid) != os.Getpid() {
t.Errorf("pid = %v, want %d", m["pid"], os.Getpid())
}
ts, ok := m["time"].(string)
if !ok || !strings.HasSuffix(ts, "Z") {
t.Errorf("time = %v, want UTC suffix Z", m["time"])
}
if _, err := time.Parse(time.RFC3339Nano, ts); err != nil {
t.Errorf("time %q not RFC3339Nano: %v", ts, err)
}
}
func TestPrintf_OmitsCIDWhenAbsent(t *testing.T) {
var buf bytes.Buffer
bufLoggerPlus(&buf, levelWarn).Printf(context.Background(), "no cid here")
m := decodeLine(t, buf.Bytes())
if v, present := m["cid"]; present {
t.Errorf("cid should be absent, got %v", v)
}
if m["level"] != "warn" {
t.Errorf("level = %v, want warn", m["level"])
}
}
func TestPrintf_AllLevelsMapToLabel(t *testing.T) {
cases := []struct {
level slog.Level
label string
}{
{levelVerb, "verb"},
{levelDebug, "debug"},
{levelWarn, "warn"},
{levelError, "error"},
}
for _, tc := range cases {
var buf bytes.Buffer
bufLoggerPlus(&buf, tc.level).Printf(context.Background(), "hi")
m := decodeLine(t, buf.Bytes())
if m["level"] != tc.label {
t.Errorf("level(%v) rendered as %v, want %q", tc.level, m["level"], tc.label)
}
}
}
func TestNewJSONLogger_GroupedAttrsPassThrough(t *testing.T) {
var buf bytes.Buffer
lg := newJSONLogger(&buf)
lg.LogAttrs(context.Background(), levelDebug, "grouped",
slog.Group("meta", slog.String("inner", "v")))
m := decodeLine(t, buf.Bytes())
meta, ok := m["meta"].(map[string]any)
if !ok {
t.Fatalf("meta not an object: %v", m["meta"])
}
if meta["inner"] != "v" {
t.Errorf("meta.inner = %v, want v", meta["inner"])
}
}
func TestPackageWrappers_RouteToRightLogger(t *testing.T) {
origV, origD, origW, origE := verboseLogger, debugLogger, warnLogger, errorLogger
t.Cleanup(func() {
verboseLogger, debugLogger, warnLogger, errorLogger = origV, origD, origW, origE
})
vBuf, dBuf, wBuf, eBuf := &bytes.Buffer{}, &bytes.Buffer{}, &bytes.Buffer{}, &bytes.Buffer{}
verboseLogger = bufLoggerPlus(vBuf, levelVerb)
debugLogger = bufLoggerPlus(dBuf, levelDebug)
warnLogger = bufLoggerPlus(wBuf, levelWarn)
errorLogger = bufLoggerPlus(eBuf, levelError)
ctx := context.Background()
Vf(ctx, "v=%d", 1)
Df(ctx, "d=%d", 2)
Wf(ctx, "w=%d", 3)
Ef(ctx, "e=%d", 4)
checks := []struct {
name string
buf *bytes.Buffer
label string
msg string
}{
{"Vf", vBuf, "verb", "v=1"},
{"Df", dBuf, "debug", "d=2"},
{"Wf", wBuf, "warn", "w=3"},
{"Ef", eBuf, "error", "e=4"},
}
for _, c := range checks {
m := decodeLine(t, c.buf.Bytes())
if m["level"] != c.label {
t.Errorf("%s level = %v, want %v", c.name, m["level"], c.label)
}
if m["msg"] != c.msg {
t.Errorf("%s msg = %v, want %v", c.name, m["msg"], c.msg)
}
}
}