diff --git a/.openclaw/TOOLS.md b/.openclaw/TOOLS.md index 75f99b24e..8952283ae 100644 --- a/.openclaw/TOOLS.md +++ b/.openclaw/TOOLS.md @@ -58,8 +58,8 @@ Skills are shared. Your setup is yours. Keeping them apart means you can update - **Never `git add`** — William stages files himself - **Never `git push`** — William pushes himself -- **Commit workflow:** `git diff --cached` → understand the changes → write title/description → `git commit -m "OpenClaw: ..."` -- Title prefix: `OpenClaw:` +- **Commit workflow:** `git diff --cached` → understand the changes → write title/description → `git commit -m "OpenClaw: ..."` or `"Claude: ..."` +- Title prefix: `OpenClaw:` or `Claude:` - **Co-author for ACP Claude Code:** If Claude Code (ACP) was used to make the changes, add: `Co-authored-by: Claude Opus 4.6 ` - **Co-author for ACP Codex:** If Codex (ACP) was used to make the changes, add: diff --git a/.openclaw/memory/srs-codebase-map.md b/.openclaw/memory/srs-codebase-map.md index 0e4ef6a50..e7c6988a9 100644 --- a/.openclaw/memory/srs-codebase-map.md +++ b/.openclaw/memory/srs-codebase-map.md @@ -221,15 +221,15 @@ The next-generation server (`cmd/` + `internal/`) is written in Go and maintaine `internal/lb` — Load balancer abstraction and two implementations. Defines `SRSLoadBalancer` interface (Initialize, Update, Pick, HLS/WebRTC state management) and `SRSServer` struct representing a backend origin (IP, listen endpoints for RTMP/HTTP/API/SRT/RTC, heartbeat tracking). **Memory LB** — in-memory using `sync.Map`, sticky random pick per stream URL, single-proxy deployment. **Redis LB** — Redis-backed shared state with TTL-based expiration, enables multi-proxy horizontal scaling behind a network load balancer. Also includes a debug helper that creates a fake backend from env vars when `PROXY_DEFAULT_BACKEND_ENABLED=on` for development without real SRS registration. -`internal/logger` — Structured logging with context IDs. Four log levels: Verbose (discarded), Debug (stdout), Warn (stderr), Error (stderr). Format: `[level][pid][cid] message`. Each connection/request gets a unique 7-char hex context ID for log correlation, stored in `context.Context`. +`internal/logger` — Structured logging with context IDs. Four log levels: Verbose (discarded), Debug (stdout), Warn (stderr), Error (stderr). Emits JSON via `log/slog` with `pid` and `cid` attributes. Each connection/request gets a unique 7-char hex context ID for log correlation, stored in `context.Context`. -`internal/env` — Environment-based configuration. All settings via env vars (or `.env` file via godotenv). Exposes an `Environment` interface with methods for each config value. Default ports: RTMP=11935, HTTP API=11985, HTTP Stream=18080, WebRTC=18000, SRT=20080, System API=12025. Timeouts: grace=20s, force=30s. Supports Redis config and default backend config for debugging. +`internal/env` — Environment-based configuration. All settings via env vars (or `.env` file parsed by an in-tree custom parser — no third-party dep; supports comments, `export` prefix, quoted values, escape sequences, and inline comments). Exposes an `Environment` interface (with a counterfeiter-generated fake in `envfakes/` for downstream tests) with methods for each config value. Default ports: RTMP=11935, HTTP API=11985, HTTP Stream=18080, WebRTC=18000, SRT=20080, System API=12025. Timeouts: grace=20s, force=30s. Supports Redis config and default backend config for debugging. -`internal/version` — Version constants. `SRSProxy` v1.5.0. Used in HTTP API responses and startup logging. +`internal/version` — Version constants. Signature `SRSX`, version tracks the SRS project version (currently 7.0.x). Used in HTTP API responses and startup logging. -`internal/errors` — Error handling with stack traces, forked from `github.com/pkg/errors`. Provides `New`, `Errorf`, `Wrap`, `Wrapf`, `WithMessage`, `WithStack`, `Cause`. Every error captures a stack trace at creation; `%+v` prints the full trace. `Cause()` walks the error chain to find the root error. +`internal/errors` — Error handling with stack traces, thin wrapper over stdlib `errors`. Provides `New`, `Errorf`, `Wrap`, `Wrapf`, `WithMessage`, `WithStack`, `Cause`, and re-exports `Is`/`As`/`Unwrap`/`Join`. Every error captures a stack trace at creation; `%+v` prints the full trace. `Cause()` walks the error chain to find the root error. -`internal/sync` — Generic sync primitives. `Map[K, V]`: type-safe generic wrapper around `sync.Map` with proper Go generics typing. Used throughout the codebase to avoid raw type assertions. +`internal/sync` — Generic sync primitives. `Map[K, V]`: type-safe generic interface over `sync.Map`, constructed via `NewMap[K, V]()`. Used throughout the codebase to avoid raw type assertions. `internal/signal` — OS signal handling. Listens for SIGINT/SIGTERM, cancels the root context. Installs a force-quit timer (default 30s) as a safety net if graceful shutdown hangs. diff --git a/.openclaw/skills/srs-develop/SKILL.md b/.openclaw/skills/srs-develop/SKILL.md index d6c840d39..9733c6740 100644 --- a/.openclaw/skills/srs-develop/SKILL.md +++ b/.openclaw/skills/srs-develop/SKILL.md @@ -22,7 +22,7 @@ Route the user's request to exactly ONE task type. Follow that task only. Do not | **Develop Code** | User wants to add, modify, refactor code, or update docs — any planned change | → [Develop Code](#task-develop-code) | ✅ Supported | | **Fix a Bug** | User reports something broken, unexpected behavior, or an error | → [Fix a Bug](#task-fix-a-bug) | ❌ Not yet supported | | **Learn Code** | User wants to understand how code works — no changes intended | → [Learn Code](#task-learn-code) | ❌ Not yet supported | -| **Review a PR** | User wants to review an existing pull request | → [Review a PR](#task-review-a-pr) | ❌ Not yet supported | +| **Review a PR** | User wants to review an existing pull request | → [Review a PR](#task-review-a-pr) | ✅ Supported | **If the routed task is not yet supported**, stop and tell the user: - What task type you routed to @@ -53,7 +53,37 @@ Do NOT attempt unsupported tasks. **Prerequisite:** You must arrive here via the [Task Router](#task-router). Do not execute this task directly — always complete the Task Router first to confirm this is the correct task type. -**Not yet supported.** Will be added in a future update. +**Scope:** Walk the pending changes on the current branch (relative to `develop`), summarize them, sync any stale navigation docs, then bump the version and add a changelog entry once the user supplies the PR number. + +**Guiding rules** +- **The user drives staging.** Never `git add` on your own. After each step, stop and wait for the user to review and stage the files they approve. Only run `git commit` when they say so. +- **Docs are navigation, not tutorials.** When a code change makes an entry stale, *correct* it — don't expand it. Only *add* a new entry when a new file or module was introduced; never to describe a refactor inside an existing module. + +**Step 1: Survey the changes** + +1. Run `git diff develop --stat` and `git log develop..HEAD --oneline` to get the shape of the branch. +2. Drill into non-test source diffs with `git diff develop -- ` to understand what actually changed. +3. Summarize back to the user: refactors, new files, and anything that could break downstream consumers (log format, public API, wire format, etc.). +4. Pause and let the user redirect or ask for more detail. + +**Step 2: Correct stale navigation docs** + +1. Check `.openclaw/memory/srs-codebase-map.md` for entries covering any module touched in this PR. +2. For each entry whose description is no longer accurate, make the **smallest** correction needed to match the new code. Keep the one-line summary style; do not expand into implementation detail. +3. Stop. Let the user review. When they `git add` the files they accept, commit with a short message in the existing style, e.g. `Claude: Sync srs-codebase-map with internal/.`. + +**Step 3: Bump the version and update the changelog** + +1. Ask the user for the PR number if they haven't given it. +2. Bump revision by one in **both** version files, keeping them in sync: + - `internal/version/version.go` — `VersionRevision()` + - `trunk/src/core/srs_core_version7.hpp` — `VERSION_REVISION` +3. Add a new top entry to `trunk/doc/CHANGELOG.md` under `## SRS 7.0 Changelog`, matching the existing format: + ``` + * v7.0, YYYY-MM-DD, Merge [#PR](URL): : . v7.0. (#PR) + ``` + Propose the summary to the user; don't invent one unilaterally. +4. Stop. Let the user review. When they `git add` the version files and changelog, commit with a short message like `Proxy: Bump to v7.0. for #.`. --- diff --git a/internal/errors/errors.go b/internal/errors/errors.go index d64470404..ce87e86f4 100644 --- a/internal/errors/errors.go +++ b/internal/errors/errors.go @@ -1,270 +1,153 @@ -// Package errors provides simple error handling primitives. +// Package errors provides error handling primitives with stack traces. // -// The traditional error handling idiom in Go is roughly akin to -// -// if err != nil { -// return err -// } -// -// which applied recursively up the call stack results in error reports -// without context or debugging information. The errors package allows -// programmers to add context to the failure path in their code in a way -// that does not destroy the original value of the error. +// It is a thin layer over the standard library's errors package, adding a +// stack trace at the point an error is created or wrapped. The wrapping +// chain is fully compatible with errors.Is, errors.As, and errors.Unwrap. // // # Adding context to an error // -// The errors.Wrap function returns a new error that adds context to the -// original error by recording a stack trace at the point Wrap is called, -// and the supplied message. For example -// -// _, err := ioutil.ReadAll(r) +// _, err := io.ReadAll(r) // if err != nil { // return errors.Wrap(err, "read failed") // } // -// If additional control is required the errors.WithStack and errors.WithMessage -// functions destructure errors.Wrap into its component operations of annotating -// an error with a stack trace and an a message, respectively. -// -// # Retrieving the cause of an error -// -// Using errors.Wrap constructs a stack of errors, adding context to the -// preceding error. Depending on the nature of the error it may be necessary -// to reverse the operation of errors.Wrap to retrieve the original error -// for inspection. Any error value which implements this interface -// -// type causer interface { -// Cause() error -// } -// -// can be inspected by errors.Cause. errors.Cause will recursively retrieve -// the topmost error which does not implement causer, which is assumed to be -// the original cause. For example: -// -// switch err := errors.Cause(err).(type) { -// case *MyError: -// // handle specifically -// default: -// // unknown error -// } -// -// causer interface is not exported by this package, but is considered a part -// of stable public API. -// // # Formatted printing of errors // -// All error values returned from this package implement fmt.Formatter and can -// be formatted by the fmt package. The following verbs are supported +// %s the error message (full wrap chain) +// %v same as %s +// %+v the error message followed by the captured stack trace +// %q the error message, quoted // -// %s print the error. If the error has a Cause it will be -// printed recursively -// %v see %s -// %+v extended format. Each Frame of the error's StackTrace will -// be printed in detail. +// # Retrieving the stack trace // -// # Retrieving the stack trace of an error or wrapper -// -// New, Errorf, Wrap, and Wrapf record a stack trace at the point they are -// invoked. This information can be retrieved with the following interface. +// Errors returned by this package satisfy the following interface: // // type stackTracer interface { -// StackTrace() errors.StackTrace +// StackTrace() []uintptr // } -// -// Where errors.StackTrace is defined as -// -// type StackTrace []Frame -// -// The Frame type represents a call site in the stack trace. Frame supports -// the fmt.Formatter interface that can be used for printing information about -// the stack trace of this error. For example: -// -// if err, ok := err.(stackTracer); ok { -// for _, f := range err.StackTrace() { -// fmt.Printf("%+s:%d", f) -// } -// } -// -// stackTracer interface is not exported by this package, but is considered a part -// of stable public API. -// -// See the documentation for Frame.Format for more details. -// Fork from https://github.com/pkg/errors package errors import ( + "errors" "fmt" - "io" + "runtime" ) -// New returns an error with the supplied message. -// New also records the stack trace at the point it was called. -func New(message string) error { - return &fundamental{ - msg: message, - stack: callers(), - } +// Re-exported stdlib primitives so callers can use a single import. +var ( + Is = errors.Is + As = errors.As + Unwrap = errors.Unwrap + Join = errors.Join +) + +// withStack wraps an error with a captured stack trace. +type withStack struct { + err error + pcs []uintptr } -// Errorf formats according to a format specifier and returns the string -// as a value that satisfies error. -// Errorf also records the stack trace at the point it was called. -func Errorf(format string, args ...interface{}) error { - return &fundamental{ - msg: fmt.Sprintf(format, args...), - stack: callers(), - } +func (e *withStack) Error() string { + return e.err.Error() } -// fundamental is an error that has a message and a stack, but no caller. -type fundamental struct { - msg string - *stack +func (e *withStack) Unwrap() error { + return e.err } -func (f *fundamental) Error() string { return f.msg } +func (e *withStack) StackTrace() []uintptr { + return e.pcs +} -func (f *fundamental) Format(s fmt.State, verb rune) { +func (e *withStack) Format(s fmt.State, verb rune) { switch verb { case 'v': if s.Flag('+') { - io.WriteString(s, f.msg) - f.stack.Format(s, verb) + fmt.Fprint(s, e.err.Error()) + frames := runtime.CallersFrames(e.pcs) + for { + f, more := frames.Next() + fmt.Fprintf(s, "\n%s\n\t%s:%d", f.Function, f.File, f.Line) + if !more { + break + } + } return } fallthrough case 's': - io.WriteString(s, f.msg) + fmt.Fprint(s, e.err.Error()) case 'q': - fmt.Fprintf(s, "%q", f.msg) + fmt.Fprintf(s, "%q", e.err.Error()) } } +func callers() []uintptr { + var pcs [32]uintptr + n := runtime.Callers(3, pcs[:]) + return pcs[:n] +} + +func attach(err error) error { + return &withStack{err: err, pcs: callers()} +} + +// New returns an error with the supplied message and a captured stack trace. +func New(message string) error { + return attach(errors.New(message)) +} + +// Errorf formats according to a format specifier and returns a new error with +// a captured stack trace. It supports %w for wrapping an existing error. +func Errorf(format string, args ...any) error { + return attach(fmt.Errorf(format, args...)) +} + // WithStack annotates err with a stack trace at the point WithStack was called. // If err is nil, WithStack returns nil. func WithStack(err error) error { if err == nil { return nil } - return &withStack{ - err, - callers(), - } + return attach(err) } -type withStack struct { - error - *stack -} - -func (w *withStack) Cause() error { return w.error } - -func (w *withStack) Format(s fmt.State, verb rune) { - switch verb { - case 'v': - if s.Flag('+') { - fmt.Fprintf(s, "%+v", w.Cause()) - w.stack.Format(s, verb) - return - } - fallthrough - case 's': - io.WriteString(s, w.Error()) - case 'q': - fmt.Fprintf(s, "%q", w.Error()) - } -} - -// Wrap returns an error annotating err with a stack trace -// at the point Wrap is called, and the supplied message. -// If err is nil, Wrap returns nil. -func Wrap(err error, message string) error { - if err == nil { - return nil - } - err = &withMessage{ - cause: err, - msg: message, - } - return &withStack{ - err, - callers(), - } -} - -// Wrapf returns an error annotating err with a stack trace -// at the point Wrapf is call, and the format specifier. -// If err is nil, Wrapf returns nil. -func Wrapf(err error, format string, args ...interface{}) error { - if err == nil { - return nil - } - err = &withMessage{ - cause: err, - msg: fmt.Sprintf(format, args...), - } - return &withStack{ - err, - callers(), - } -} - -// WithMessage annotates err with a new message. +// WithMessage annotates err with a new message, without capturing a stack. // If err is nil, WithMessage returns nil. func WithMessage(err error, message string) error { if err == nil { return nil } - return &withMessage{ - cause: err, - msg: message, + return fmt.Errorf("%s: %w", message, err) +} + +// Wrap returns an error annotating err with a message and a captured stack. +// If err is nil, Wrap returns nil. +func Wrap(err error, message string) error { + if err == nil { + return nil } + return attach(fmt.Errorf("%s: %w", message, err)) } -type withMessage struct { - cause error - msg string -} - -func (w *withMessage) Error() string { return w.msg + ": " + w.cause.Error() } -func (w *withMessage) Cause() error { return w.cause } - -func (w *withMessage) Format(s fmt.State, verb rune) { - switch verb { - case 'v': - if s.Flag('+') { - fmt.Fprintf(s, "%+v\n", w.Cause()) - io.WriteString(s, w.msg) - return - } - fallthrough - case 's', 'q': - io.WriteString(s, w.Error()) +// Wrapf is the formatting variant of Wrap. +// If err is nil, Wrapf returns nil. +func Wrapf(err error, format string, args ...any) error { + if err == nil { + return nil } + return attach(fmt.Errorf(fmt.Sprintf(format, args...)+": %w", err)) } -// Cause returns the underlying cause of the error, if possible. -// An error value has a cause if it implements the following -// interface: -// -// type causer interface { -// Cause() error -// } -// -// If the error does not implement Cause, the original error will -// be returned. If the error is nil, nil will be returned without further -// investigation. +// Cause walks the error's Unwrap chain and returns the root error. +// New code should prefer errors.Is or errors.As. func Cause(err error) error { - type causer interface { - Cause() error - } - for err != nil { - cause, ok := err.(causer) - if !ok { - break + u := errors.Unwrap(err) + if u == nil { + return err } - err = cause.Cause() + err = u } - return err + return nil } diff --git a/internal/errors/errors_test.go b/internal/errors/errors_test.go new file mode 100644 index 000000000..a348c2705 --- /dev/null +++ b/internal/errors/errors_test.go @@ -0,0 +1,233 @@ +// Copyright (c) 2026 Winlin +// +// SPDX-License-Identifier: MIT +package errors + +import ( + stderrors "errors" + "fmt" + "strings" + "testing" +) + +func TestNew_MessageAndStack(t *testing.T) { + err := New("boom") + if err == nil { + t.Fatal("New returned nil") + } + if err.Error() != "boom" { + t.Fatalf("Error() = %q, want %q", err.Error(), "boom") + } + ws, ok := err.(*withStack) + if !ok { + t.Fatalf("New did not return *withStack, got %T", err) + } + if len(ws.StackTrace()) == 0 { + t.Fatal("StackTrace is empty") + } +} + +func TestErrorf_FormatsMessage(t *testing.T) { + err := Errorf("code=%d reason=%s", 42, "oops") + if err.Error() != "code=42 reason=oops" { + t.Fatalf("Error() = %q", err.Error()) + } +} + +func TestErrorf_SupportsWrapVerb(t *testing.T) { + root := stderrors.New("root") + err := Errorf("ctx: %w", root) + if !stderrors.Is(err, root) { + t.Fatal("errors.Is did not find root through Errorf(%w)") + } +} + +func TestWithStack_NilReturnsNil(t *testing.T) { + if got := WithStack(nil); got != nil { + t.Fatalf("WithStack(nil) = %v, want nil", got) + } +} + +func TestWithStack_PreservesMessage(t *testing.T) { + inner := stderrors.New("plain") + err := WithStack(inner) + if err.Error() != "plain" { + t.Fatalf("Error() = %q, want %q", err.Error(), "plain") + } + if !stderrors.Is(err, inner) { + t.Fatal("errors.Is did not find inner through WithStack") + } +} + +func TestWithMessage_NilReturnsNil(t *testing.T) { + if got := WithMessage(nil, "ignored"); got != nil { + t.Fatalf("WithMessage(nil) = %v, want nil", got) + } +} + +func TestWithMessage_PrependsAndWraps(t *testing.T) { + inner := stderrors.New("root") + err := WithMessage(inner, "ctx") + if err.Error() != "ctx: root" { + t.Fatalf("Error() = %q", err.Error()) + } + if !stderrors.Is(err, inner) { + t.Fatal("errors.Is did not traverse WithMessage") + } + // WithMessage must not capture a stack — verify the result is not a *withStack. + if _, ok := err.(*withStack); ok { + t.Fatal("WithMessage should not attach a stack") + } +} + +func TestWrap_NilReturnsNil(t *testing.T) { + if got := Wrap(nil, "ignored"); got != nil { + t.Fatalf("Wrap(nil) = %v, want nil", got) + } +} + +func TestWrap_MessageAndStackAndChain(t *testing.T) { + inner := stderrors.New("root") + err := Wrap(inner, "ctx") + if err.Error() != "ctx: root" { + t.Fatalf("Error() = %q", err.Error()) + } + ws, ok := err.(*withStack) + if !ok { + t.Fatalf("Wrap did not return *withStack, got %T", err) + } + if len(ws.StackTrace()) == 0 { + t.Fatal("StackTrace is empty") + } + if !stderrors.Is(err, inner) { + t.Fatal("errors.Is did not traverse Wrap") + } +} + +func TestWrapf_NilReturnsNil(t *testing.T) { + if got := Wrapf(nil, "ignored %d", 1); got != nil { + t.Fatalf("Wrapf(nil) = %v, want nil", got) + } +} + +func TestWrapf_FormatsAndChains(t *testing.T) { + inner := stderrors.New("root") + err := Wrapf(inner, "ctx=%d op=%s", 7, "read") + if err.Error() != "ctx=7 op=read: root" { + t.Fatalf("Error() = %q", err.Error()) + } + if !stderrors.Is(err, inner) { + t.Fatal("errors.Is did not traverse Wrapf") + } +} + +func TestCause_NilReturnsNil(t *testing.T) { + if got := Cause(nil); got != nil { + t.Fatalf("Cause(nil) = %v, want nil", got) + } +} + +func TestCause_NoUnwrapReturnsSelf(t *testing.T) { + root := stderrors.New("root") + if got := Cause(root); got != root { + t.Fatalf("Cause(root) = %v, want root", got) + } +} + +func TestCause_WalksToRoot(t *testing.T) { + root := stderrors.New("root") + err := Wrap(Wrap(WithMessage(root, "a"), "b"), "c") + if got := Cause(err); got != root { + t.Fatalf("Cause = %v, want root", got) + } +} + +func TestUnwrap_ReturnsInner(t *testing.T) { + inner := stderrors.New("inner") + err := WithStack(inner) + if got := stderrors.Unwrap(err); got != inner { + t.Fatalf("Unwrap = %v, want inner", got) + } +} + +func TestFormat_S(t *testing.T) { + err := New("msg") + got := fmt.Sprintf("%s", err) + if got != "msg" { + t.Fatalf("%%s = %q, want %q", got, "msg") + } +} + +func TestFormat_VFallsThroughToS(t *testing.T) { + err := New("msg") + got := fmt.Sprintf("%v", err) + if got != "msg" { + t.Fatalf("%%v = %q, want %q", got, "msg") + } +} + +func TestFormat_VPlusIncludesStack(t *testing.T) { + err := New("msg") + got := fmt.Sprintf("%+v", err) + if !strings.HasPrefix(got, "msg") { + t.Fatalf("%%+v output does not start with message: %q", got) + } + // Must include this test function in the captured stack. + if !strings.Contains(got, "TestFormat_VPlusIncludesStack") { + t.Fatalf("%%+v output missing caller frame:\n%s", got) + } + // Must include a file:line reference. + if !strings.Contains(got, "errors_test.go:") { + t.Fatalf("%%+v output missing file:line:\n%s", got) + } +} + +func TestFormat_Q(t *testing.T) { + err := New("msg") + got := fmt.Sprintf("%q", err) + if got != `"msg"` { + t.Fatalf("%%q = %q, want %q", got, `"msg"`) + } +} + +func TestIs_ThroughWrapChain(t *testing.T) { + sentinel := stderrors.New("sentinel") + err := Wrap(WithMessage(WithStack(sentinel), "mid"), "outer") + if !stderrors.Is(err, sentinel) { + t.Fatal("errors.Is failed to traverse Wrap/WithMessage/WithStack chain") + } +} + +type typedErr struct{ code int } + +func (t *typedErr) Error() string { return fmt.Sprintf("typed(%d)", t.code) } + +func TestAs_ThroughWrapChain(t *testing.T) { + target := &typedErr{code: 7} + err := Wrap(WithStack(target), "ctx") + var got *typedErr + if !stderrors.As(err, &got) { + t.Fatal("errors.As failed to find *typedErr in chain") + } + if got.code != 7 { + t.Fatalf("As returned code=%d, want 7", got.code) + } +} + +func TestReExports_AreStdlib(t *testing.T) { + // Sanity: the re-exports must actually be the stdlib functions. + a := stderrors.New("a") + b := stderrors.New("b") + joined := Join(a, b) + if !Is(joined, a) || !Is(joined, b) { + t.Fatal("Join/Is re-exports do not match stdlib behavior") + } + if Unwrap(WithStack(a)) != a { + t.Fatal("Unwrap re-export does not match stdlib behavior") + } + var target *typedErr + te := &typedErr{code: 1} + if !As(WithStack(te), &target) { + t.Fatal("As re-export does not match stdlib behavior") + } +} diff --git a/internal/errors/stack.go b/internal/errors/stack.go deleted file mode 100644 index 7e5aacc48..000000000 --- a/internal/errors/stack.go +++ /dev/null @@ -1,187 +0,0 @@ -// Fork from https://github.com/pkg/errors -package errors - -import ( - "fmt" - "io" - "path" - "runtime" - "strings" -) - -// Frame represents a program counter inside a stack frame. -type Frame uintptr - -// pc returns the program counter for this frame; -// multiple frames may have the same PC value. -func (f Frame) pc() uintptr { return uintptr(f) - 1 } - -// file returns the full path to the file that contains the -// function for this Frame's pc. -func (f Frame) file() string { - fn := runtime.FuncForPC(f.pc()) - if fn == nil { - return "unknown" - } - file, _ := fn.FileLine(f.pc()) - return file -} - -// line returns the line number of source code of the -// function for this Frame's pc. -func (f Frame) line() int { - fn := runtime.FuncForPC(f.pc()) - if fn == nil { - return 0 - } - _, line := fn.FileLine(f.pc()) - return line -} - -// Format formats the frame according to the fmt.Formatter interface. -// -// %s source file -// %d source line -// %n function name -// %v equivalent to %s:%d -// -// Format accepts flags that alter the printing of some verbs, as follows: -// -// %+s path of source file relative to the compile time GOPATH -// %+v equivalent to %+s:%d -func (f Frame) Format(s fmt.State, verb rune) { - switch verb { - case 's': - switch { - case s.Flag('+'): - pc := f.pc() - fn := runtime.FuncForPC(pc) - if fn == nil { - io.WriteString(s, "unknown") - } else { - file, _ := fn.FileLine(pc) - fmt.Fprintf(s, "%s\n\t%s", fn.Name(), file) - } - default: - io.WriteString(s, path.Base(f.file())) - } - case 'd': - fmt.Fprintf(s, "%d", f.line()) - case 'n': - name := runtime.FuncForPC(f.pc()).Name() - io.WriteString(s, funcname(name)) - case 'v': - f.Format(s, 's') - io.WriteString(s, ":") - f.Format(s, 'd') - } -} - -// StackTrace is stack of Frames from innermost (newest) to outermost (oldest). -type StackTrace []Frame - -// Format formats the stack of Frames according to the fmt.Formatter interface. -// -// %s lists source files for each Frame in the stack -// %v lists the source file and line number for each Frame in the stack -// -// Format accepts flags that alter the printing of some verbs, as follows: -// -// %+v Prints filename, function, and line number for each Frame in the stack. -func (st StackTrace) Format(s fmt.State, verb rune) { - switch verb { - case 'v': - switch { - case s.Flag('+'): - for _, f := range st { - fmt.Fprintf(s, "\n%+v", f) - } - case s.Flag('#'): - fmt.Fprintf(s, "%#v", []Frame(st)) - default: - fmt.Fprintf(s, "%v", []Frame(st)) - } - case 's': - fmt.Fprintf(s, "%s", []Frame(st)) - } -} - -// stack represents a stack of program counters. -type stack []uintptr - -func (s *stack) Format(st fmt.State, verb rune) { - switch verb { - case 'v': - switch { - case st.Flag('+'): - for _, pc := range *s { - f := Frame(pc) - fmt.Fprintf(st, "\n%+v", f) - } - } - } -} - -func (s *stack) StackTrace() StackTrace { - f := make([]Frame, len(*s)) - for i := 0; i < len(f); i++ { - f[i] = Frame((*s)[i]) - } - return f -} - -func callers() *stack { - const depth = 32 - var pcs [depth]uintptr - n := runtime.Callers(3, pcs[:]) - var st stack = pcs[0:n] - return &st -} - -// funcname removes the path prefix component of a function's name reported by func.Name(). -func funcname(name string) string { - i := strings.LastIndex(name, "/") - name = name[i+1:] - i = strings.Index(name, ".") - return name[i+1:] -} - -func trimGOPATH(name, file string) string { - // Here we want to get the source file path relative to the compile time - // GOPATH. As of Go 1.6.x there is no direct way to know the compiled - // GOPATH at runtime, but we can infer the number of path segments in the - // GOPATH. We note that fn.Name() returns the function name qualified by - // the import path, which does not include the GOPATH. Thus we can trim - // segments from the beginning of the file path until the number of path - // separators remaining is one more than the number of path separators in - // the function name. For example, given: - // - // GOPATH /home/user - // file /home/user/src/pkg/sub/file.go - // fn.Name() pkg/sub.Type.Method - // - // We want to produce: - // - // pkg/sub/file.go - // - // From this we can easily see that fn.Name() has one less path separator - // than our desired output. We count separators from the end of the file - // path until it finds two more than in the function name and then move - // one character forward to preserve the initial path segment without a - // leading separator. - const sep = "/" - goal := strings.Count(name, sep) + 2 - i := len(file) - for n := 0; n < goal; n++ { - i = strings.LastIndex(file[:i], sep) - if i == -1 { - // not enough separators found, set i so that the slice expression - // below leaves file unmodified - i = -len(sep) - break - } - } - // get back to 0 or trim the leading separator - file = file[i+len(sep):] - return file -} diff --git a/internal/lb/mem.go b/internal/lb/mem.go index 3901ed93b..8fe3602a5 100644 --- a/internal/lb/mem.go +++ b/internal/lb/mem.go @@ -36,7 +36,13 @@ type MemoryLoadBalancer struct { // NewMemoryLoadBalancer creates a new memory-based load balancer. func NewMemoryLoadBalancer(environment env.Environment) SRSLoadBalancer { return &MemoryLoadBalancer{ - environment: environment, + environment: environment, + servers: sync.NewMap[string, *SRSServer](), + picked: sync.NewMap[string, *SRSServer](), + hlsStreamURL: sync.NewMap[string, HLSPlayStream](), + hlsSPBHID: sync.NewMap[string, HLSPlayStream](), + rtcStreamURL: sync.NewMap[string, RTCConnection](), + rtcUfrag: sync.NewMap[string, RTCConnection](), } } diff --git a/internal/logger/context.go b/internal/logger/context.go index a50bf2a1a..292981645 100644 --- a/internal/logger/context.go +++ b/internal/logger/context.go @@ -14,7 +14,7 @@ type key string var cidKey key = "cid.srsx.ossrs.org" -// generateContextID generates a random context id in string. +// GenerateContextID generates a random context id in string. func GenerateContextID() string { randomBytes := make([]byte, 32) _, _ = rand.Read(randomBytes) @@ -26,11 +26,11 @@ func GenerateContextID() string { // WithContext creates a new context with cid, which will be used for log. func WithContext(ctx context.Context) context.Context { - return WithContextID(ctx, GenerateContextID()) + return withContextID(ctx, GenerateContextID()) } -// WithContextID creates a new context with cid, which will be used for log. -func WithContextID(ctx context.Context, cid string) context.Context { +// withContextID creates a new context with cid, which will be used for log. +func withContextID(ctx context.Context, cid string) context.Context { return context.WithValue(ctx, cidKey, cid) } diff --git a/internal/logger/context_test.go b/internal/logger/context_test.go new file mode 100644 index 000000000..2adda9538 --- /dev/null +++ b/internal/logger/context_test.go @@ -0,0 +1,82 @@ +// Copyright (c) 2026 Winlin +// +// SPDX-License-Identifier: MIT +package logger + +import ( + "context" + "encoding/hex" + "testing" +) + +func TestGenerateContextID_LengthAndHex(t *testing.T) { + cid := GenerateContextID() + if len(cid) != 7 { + t.Fatalf("len(cid) = %d, want 7", len(cid)) + } + if _, err := hex.DecodeString(cid + "0"); err != nil { + t.Fatalf("cid %q is not hex: %v", cid, err) + } +} + +func TestGenerateContextID_Unique(t *testing.T) { + seen := make(map[string]struct{}, 1000) + for i := range 1000 { + cid := GenerateContextID() + if _, dup := seen[cid]; dup { + t.Fatalf("duplicate cid %q at iteration %d", cid, i) + } + seen[cid] = struct{}{} + } +} + +func TestWithContext_AttachesCID(t *testing.T) { + ctx := WithContext(context.Background()) + cid := ContextID(ctx) + if len(cid) != 7 { + t.Fatalf("ContextID length = %d, want 7", len(cid)) + } +} + +func TestWithContext_IndependentCIDs(t *testing.T) { + c1 := WithContext(context.Background()) + c2 := WithContext(context.Background()) + if ContextID(c1) == ContextID(c2) { + t.Fatalf("expected distinct cids, got %q twice", ContextID(c1)) + } +} + +func TestContextID_Missing(t *testing.T) { + if got := ContextID(context.Background()); got != "" { + t.Fatalf("ContextID on empty ctx = %q, want \"\"", got) + } +} + +func TestContextID_WrongTypeReturnsEmpty(t *testing.T) { + ctx := context.WithValue(context.Background(), cidKey, 42) + if got := ContextID(ctx); got != "" { + t.Fatalf("ContextID with int value = %q, want \"\"", got) + } +} + +func TestWithContextID_RoundTrip(t *testing.T) { + ctx := withContextID(context.Background(), "abcdef1") + if got := ContextID(ctx); got != "abcdef1" { + t.Fatalf("ContextID = %q, want %q", got, "abcdef1") + } +} + +func TestWithContextID_Overwrite(t *testing.T) { + ctx := withContextID(context.Background(), "first00") + ctx = withContextID(ctx, "second1") + if got := ContextID(ctx); got != "second1" { + t.Fatalf("ContextID after overwrite = %q, want %q", got, "second1") + } +} + +func TestCIDKey_NotCollidingWithPlainString(t *testing.T) { + ctx := context.WithValue(context.Background(), string(cidKey), "plain") + if got := ContextID(ctx); got != "" { + t.Fatalf("ContextID leaked through string key = %q, want \"\"", got) + } +} diff --git a/internal/logger/log.go b/internal/logger/log.go index 9653c0846..f710653e5 100644 --- a/internal/logger/log.go +++ b/internal/logger/log.go @@ -5,8 +5,9 @@ package logger import ( "context" - "io/ioutil" - stdLog "log" + "fmt" + "io" + "log/slog" "os" ) @@ -15,8 +16,8 @@ type logger interface { } type loggerPlus struct { - logger *stdLog.Logger - level string + logger *slog.Logger + level slog.Level } func newLoggerPlus(opts ...func(*loggerPlus)) *loggerPlus { @@ -27,61 +28,95 @@ func newLoggerPlus(opts ...func(*loggerPlus)) *loggerPlus { return v } -func (v *loggerPlus) Printf(ctx context.Context, f string, a ...interface{}) { - format, args := f, a +func (v *loggerPlus) Printf(ctx context.Context, f string, a ...any) { + attrs := []slog.Attr{slog.Int("pid", os.Getpid())} if cid := ContextID(ctx); cid != "" { - format, args = "[%v][%v][%v] "+format, append([]interface{}{v.level, os.Getpid(), cid}, a...) + attrs = append(attrs, slog.String("cid", cid)) } - - v.logger.Printf(format, args...) + v.logger.LogAttrs(ctx, v.level, fmt.Sprintf(f, a...), attrs...) } var verboseLogger logger -func Vf(ctx context.Context, format string, a ...interface{}) { +func Vf(ctx context.Context, format string, a ...any) { verboseLogger.Printf(ctx, format, a...) } var debugLogger logger -func Df(ctx context.Context, format string, a ...interface{}) { +func Df(ctx context.Context, format string, a ...any) { debugLogger.Printf(ctx, format, a...) } var warnLogger logger -func Wf(ctx context.Context, format string, a ...interface{}) { +func Wf(ctx context.Context, format string, a ...any) { warnLogger.Printf(ctx, format, a...) } var errorLogger logger -func Ef(ctx context.Context, format string, a ...interface{}) { +func Ef(ctx context.Context, format string, a ...any) { errorLogger.Printf(ctx, format, a...) } const ( - logVerboseLabel = "verb" - logDebugLabel = "debug" - logWarnLabel = "warn" - logErrorLabel = "error" + levelVerb slog.Level = slog.LevelDebug - 4 + levelDebug slog.Level = slog.LevelDebug + levelWarn slog.Level = slog.LevelWarn + levelError slog.Level = slog.LevelError ) +// newJSONLogger builds a slog.Logger that writes JSON records to w, renders the +// time in UTC, and maps our custom levels to short lowercase labels. +func newJSONLogger(w io.Writer) *slog.Logger { + h := slog.NewJSONHandler(w, &slog.HandlerOptions{ + Level: levelVerb, + ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr { + if len(groups) != 0 { + return a + } + switch a.Key { + case slog.TimeKey: + return slog.Time(slog.TimeKey, a.Value.Time().UTC()) + case slog.LevelKey: + return slog.String(slog.LevelKey, levelLabel(a.Value.Any().(slog.Level))) + } + return a + }, + }) + return slog.New(h) +} + +func levelLabel(l slog.Level) string { + switch l { + case levelVerb: + return "verb" + case levelDebug: + return "debug" + case levelWarn: + return "warn" + case levelError: + return "error" + } + return l.String() +} + func init() { - verboseLogger = newLoggerPlus(func(logger *loggerPlus) { - logger.logger = stdLog.New(ioutil.Discard, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds) - logger.level = logVerboseLabel + verboseLogger = newLoggerPlus(func(l *loggerPlus) { + l.logger = newJSONLogger(io.Discard) + l.level = levelVerb }) - debugLogger = newLoggerPlus(func(logger *loggerPlus) { - logger.logger = stdLog.New(os.Stdout, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds) - logger.level = logDebugLabel + debugLogger = newLoggerPlus(func(l *loggerPlus) { + l.logger = newJSONLogger(os.Stdout) + l.level = levelDebug }) - warnLogger = newLoggerPlus(func(logger *loggerPlus) { - logger.logger = stdLog.New(os.Stderr, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds) - logger.level = logWarnLabel + warnLogger = newLoggerPlus(func(l *loggerPlus) { + l.logger = newJSONLogger(os.Stderr) + l.level = levelWarn }) - errorLogger = newLoggerPlus(func(logger *loggerPlus) { - logger.logger = stdLog.New(os.Stderr, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds) - logger.level = logErrorLabel + errorLogger = newLoggerPlus(func(l *loggerPlus) { + l.logger = newJSONLogger(os.Stderr) + l.level = levelError }) } diff --git a/internal/logger/log_test.go b/internal/logger/log_test.go new file mode 100644 index 000000000..626cb73f2 --- /dev/null +++ b/internal/logger/log_test.go @@ -0,0 +1,174 @@ +// Copyright (c) 2026 Winlin +// +// SPDX-License-Identifier: MIT +package logger + +import ( + "bytes" + "context" + "encoding/json" + "io" + "log/slog" + "os" + "strings" + "testing" + "time" +) + +func decodeLine(t *testing.T, line []byte) map[string]any { + t.Helper() + var m map[string]any + if err := json.Unmarshal(bytes.TrimSpace(line), &m); err != nil { + t.Fatalf("decode %q: %v", line, err) + } + return m +} + +func bufLoggerPlus(w io.Writer, level slog.Level) *loggerPlus { + return newLoggerPlus(func(l *loggerPlus) { + l.logger = newJSONLogger(w) + l.level = level + }) +} + +func TestLevelLabel_Known(t *testing.T) { + cases := map[slog.Level]string{ + levelVerb: "verb", + levelDebug: "debug", + levelWarn: "warn", + levelError: "error", + } + for lvl, want := range cases { + if got := levelLabel(lvl); got != want { + t.Errorf("levelLabel(%v) = %q, want %q", lvl, got, want) + } + } +} + +func TestLevelLabel_UnknownFallsBackToString(t *testing.T) { + got := levelLabel(slog.Level(99)) + if got == "" { + t.Fatalf("levelLabel(99) returned empty") + } + if got == "verb" || got == "debug" || got == "warn" || got == "error" { + t.Fatalf("levelLabel(99) = %q, want slog.Level.String() form", got) + } +} + +func TestPrintf_EmitsAllFields(t *testing.T) { + var buf bytes.Buffer + lp := bufLoggerPlus(&buf, levelDebug) + ctx := withContextID(context.Background(), "abc1234") + lp.Printf(ctx, "hello %s %d", "world", 42) + + m := decodeLine(t, buf.Bytes()) + if m["level"] != "debug" { + t.Errorf("level = %v, want debug", m["level"]) + } + if m["msg"] != "hello world 42" { + t.Errorf("msg = %v, want %q", m["msg"], "hello world 42") + } + if m["cid"] != "abc1234" { + t.Errorf("cid = %v, want abc1234", m["cid"]) + } + pid, ok := m["pid"].(float64) + if !ok || int(pid) != os.Getpid() { + t.Errorf("pid = %v, want %d", m["pid"], os.Getpid()) + } + ts, ok := m["time"].(string) + if !ok || !strings.HasSuffix(ts, "Z") { + t.Errorf("time = %v, want UTC suffix Z", m["time"]) + } + if _, err := time.Parse(time.RFC3339Nano, ts); err != nil { + t.Errorf("time %q not RFC3339Nano: %v", ts, err) + } +} + +func TestPrintf_OmitsCIDWhenAbsent(t *testing.T) { + var buf bytes.Buffer + bufLoggerPlus(&buf, levelWarn).Printf(context.Background(), "no cid here") + + m := decodeLine(t, buf.Bytes()) + if v, present := m["cid"]; present { + t.Errorf("cid should be absent, got %v", v) + } + if m["level"] != "warn" { + t.Errorf("level = %v, want warn", m["level"]) + } +} + +func TestPrintf_AllLevelsMapToLabel(t *testing.T) { + cases := []struct { + level slog.Level + label string + }{ + {levelVerb, "verb"}, + {levelDebug, "debug"}, + {levelWarn, "warn"}, + {levelError, "error"}, + } + for _, tc := range cases { + var buf bytes.Buffer + bufLoggerPlus(&buf, tc.level).Printf(context.Background(), "hi") + m := decodeLine(t, buf.Bytes()) + if m["level"] != tc.label { + t.Errorf("level(%v) rendered as %v, want %q", tc.level, m["level"], tc.label) + } + } +} + +func TestNewJSONLogger_GroupedAttrsPassThrough(t *testing.T) { + var buf bytes.Buffer + lg := newJSONLogger(&buf) + lg.LogAttrs(context.Background(), levelDebug, "grouped", + slog.Group("meta", slog.String("inner", "v"))) + + m := decodeLine(t, buf.Bytes()) + meta, ok := m["meta"].(map[string]any) + if !ok { + t.Fatalf("meta not an object: %v", m["meta"]) + } + if meta["inner"] != "v" { + t.Errorf("meta.inner = %v, want v", meta["inner"]) + } +} + +func TestPackageWrappers_RouteToRightLogger(t *testing.T) { + origV, origD, origW, origE := verboseLogger, debugLogger, warnLogger, errorLogger + t.Cleanup(func() { + verboseLogger, debugLogger, warnLogger, errorLogger = origV, origD, origW, origE + }) + + vBuf, dBuf, wBuf, eBuf := &bytes.Buffer{}, &bytes.Buffer{}, &bytes.Buffer{}, &bytes.Buffer{} + verboseLogger = bufLoggerPlus(vBuf, levelVerb) + debugLogger = bufLoggerPlus(dBuf, levelDebug) + warnLogger = bufLoggerPlus(wBuf, levelWarn) + errorLogger = bufLoggerPlus(eBuf, levelError) + + ctx := context.Background() + Vf(ctx, "v=%d", 1) + Df(ctx, "d=%d", 2) + Wf(ctx, "w=%d", 3) + Ef(ctx, "e=%d", 4) + + checks := []struct { + name string + buf *bytes.Buffer + label string + msg string + }{ + {"Vf", vBuf, "verb", "v=1"}, + {"Df", dBuf, "debug", "d=2"}, + {"Wf", wBuf, "warn", "w=3"}, + {"Ef", eBuf, "error", "e=4"}, + } + for _, c := range checks { + m := decodeLine(t, c.buf.Bytes()) + if m["level"] != c.label { + t.Errorf("%s level = %v, want %v", c.name, m["level"], c.label) + } + if m["msg"] != c.msg { + t.Errorf("%s msg = %v, want %v", c.name, m["msg"], c.msg) + } + } +} diff --git a/internal/protocol/rtc.go b/internal/protocol/rtc.go index b1f43bce2..add8bdf00 100644 --- a/internal/protocol/rtc.go +++ b/internal/protocol/rtc.go @@ -45,7 +45,11 @@ type srsWebRTCServer struct { } func NewSRSWebRTCServer(environment env.Environment, opts ...func(*srsWebRTCServer)) *srsWebRTCServer { - v := &srsWebRTCServer{environment: environment} + v := &srsWebRTCServer{ + environment: environment, + usernames: sync.NewMap[string, *RTCConnection](), + addresses: sync.NewMap[string, *RTCConnection](), + } for _, opt := range opts { opt(v) } diff --git a/internal/protocol/srt.go b/internal/protocol/srt.go index f51724c2a..ced994ef6 100644 --- a/internal/protocol/srt.go +++ b/internal/protocol/srt.go @@ -43,6 +43,7 @@ func NewSRSSRTServer(environment env.Environment, opts ...func(*srsSRTServer)) * v := &srsSRTServer{ environment: environment, start: time.Now(), + sockets: sync.NewMap[uint32, *SRTConnection](), } for _, opt := range opts { diff --git a/internal/signal/signal.go b/internal/signal/signal.go index 2dae9d23c..c794ec8bb 100644 --- a/internal/signal/signal.go +++ b/internal/signal/signal.go @@ -15,9 +15,15 @@ import ( "srsx/internal/logger" ) +// Indirections so tests can substitute signal delivery and process exit. +var ( + signalNotify = signal.Notify + osExit = os.Exit +) + func InstallSignals(ctx context.Context, cancel context.CancelFunc) { sc := make(chan os.Signal, 1) - signal.Notify(sc, syscall.SIGINT, syscall.SIGTERM, os.Interrupt) + signalNotify(sc, syscall.SIGINT, syscall.SIGTERM, os.Interrupt) go func() { for s := range sc { @@ -40,7 +46,7 @@ func InstallForceQuit(ctx context.Context, environment env.Environment) error { <-ctx.Done() time.Sleep(forceTimeout) logger.Wf(ctx, "Force to exit by timeout") - os.Exit(1) + osExit(1) }() return nil } diff --git a/internal/signal/signal_test.go b/internal/signal/signal_test.go new file mode 100644 index 000000000..207f78aee --- /dev/null +++ b/internal/signal/signal_test.go @@ -0,0 +1,170 @@ +// Copyright (c) 2026 Winlin +// +// SPDX-License-Identifier: MIT +package signal + +import ( + "context" + "os" + "strings" + "sync" + "sync/atomic" + "syscall" + "testing" + "time" + + "srsx/internal/env/envfakes" +) + +// swapNotify replaces signalNotify with a capturing fake and returns a getter +// for the channel registered by the code under test plus a restore func. +func swapNotify(t *testing.T) (func() chan<- os.Signal, func()) { + t.Helper() + orig := signalNotify + var ( + mu sync.Mutex + ch chan<- os.Signal + ) + signalNotify = func(c chan<- os.Signal, _ ...os.Signal) { + mu.Lock() + defer mu.Unlock() + ch = c + } + return func() chan<- os.Signal { + mu.Lock() + defer mu.Unlock() + return ch + }, func() { + signalNotify = orig + } +} + +func swapExit(t *testing.T) (*int32, chan int, func()) { + t.Helper() + orig := osExit + var called int32 + done := make(chan int, 1) + osExit = func(code int) { + atomic.StoreInt32(&called, 1) + select { + case done <- code: + default: + } + // Block to mimic os.Exit never returning; the goroutine holding us + // here is abandoned when the test ends. + select {} + } + return &called, done, func() { osExit = orig } +} + +func TestInstallSignals_CancelsOnSignal(t *testing.T) { + getCh, restore := swapNotify(t) + defer restore() + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + InstallSignals(ctx, cancel) + + ch := getCh() + if ch == nil { + t.Fatal("signalNotify was not called") + } + ch <- syscall.SIGINT + + select { + case <-ctx.Done(): + case <-time.After(time.Second): + t.Fatal("ctx was not canceled after signal") + } +} + +func TestInstallSignals_HandlesRepeatedSignals(t *testing.T) { + getCh, restore := swapNotify(t) + defer restore() + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + InstallSignals(ctx, cancel) + ch := getCh() + + // Multiple signals must not panic; cancel() is idempotent. + ch <- syscall.SIGINT + ch <- syscall.SIGTERM + ch <- os.Interrupt + + select { + case <-ctx.Done(): + case <-time.After(time.Second): + t.Fatal("ctx was not canceled") + } +} + +func TestInstallForceQuit_InvalidDurationReturnsError(t *testing.T) { + fakeEnv := &envfakes.FakeEnvironment{} + fakeEnv.ForceQuitTimeoutReturns("not-a-duration") + + err := InstallForceQuit(t.Context(), fakeEnv) + if err == nil { + t.Fatal("want error for bad duration") + } + if !strings.Contains(err.Error(), "parse force timeout") { + t.Fatalf("err = %v", err) + } + if !strings.Contains(err.Error(), "not-a-duration") { + t.Fatalf("err missing input: %v", err) + } +} + +func TestInstallForceQuit_ExitsAfterTimeout(t *testing.T) { + called, done, restore := swapExit(t) + defer restore() + + fakeEnv := &envfakes.FakeEnvironment{} + fakeEnv.ForceQuitTimeoutReturns("1ms") + + ctx, cancel := context.WithCancel(t.Context()) + if err := InstallForceQuit(ctx, fakeEnv); err != nil { + t.Fatalf("unexpected err: %v", err) + } + + // Before cancel, the goroutine is blocked and exit must not fire. + if atomic.LoadInt32(called) != 0 { + t.Fatal("osExit called before ctx cancel") + } + cancel() + + select { + case code := <-done: + if code != 1 { + t.Fatalf("exit code = %d, want 1", code) + } + case <-time.After(time.Second): + t.Fatal("osExit not called after cancel + timeout") + } +} + +func TestInstallForceQuit_WaitsForCancelBeforeSleeping(t *testing.T) { + called, done, restore := swapExit(t) + defer restore() + + fakeEnv := &envfakes.FakeEnvironment{} + fakeEnv.ForceQuitTimeoutReturns("10ms") + + // Intentionally use a never-canceled context and leak the goroutine: + // if we canceled at test end, the goroutine would wake and race with + // restore() writing osExit. + if err := InstallForceQuit(context.Background(), fakeEnv); err != nil { + t.Fatalf("unexpected err: %v", err) + } + + select { + case <-done: + t.Fatal("osExit fired without ctx cancel") + case <-time.After(30 * time.Millisecond): + } + if atomic.LoadInt32(called) != 0 { + t.Fatal("osExit called unexpectedly") + } +} diff --git a/internal/sync/map.go b/internal/sync/map.go index 05f628a44..16387ec03 100644 --- a/internal/sync/map.go +++ b/internal/sync/map.go @@ -5,15 +5,28 @@ package sync import "sync" -type Map[K comparable, V any] struct { +type Map[K comparable, V any] interface { + Delete(key K) + Load(key K) (value V, ok bool) + LoadAndDelete(key K) (value V, loaded bool) + LoadOrStore(key K, value V) (actual V, loaded bool) + Range(f func(key K, value V) bool) + Store(key K, value V) +} + +func NewMap[K comparable, V any]() Map[K, V] { + return &mapImpl[K, V]{} +} + +type mapImpl[K comparable, V any] struct { m sync.Map } -func (m *Map[K, V]) Delete(key K) { +func (m *mapImpl[K, V]) Delete(key K) { m.m.Delete(key) } -func (m *Map[K, V]) Load(key K) (value V, ok bool) { +func (m *mapImpl[K, V]) Load(key K) (value V, ok bool) { v, ok := m.m.Load(key) if !ok { return value, ok @@ -21,7 +34,7 @@ func (m *Map[K, V]) Load(key K) (value V, ok bool) { return v.(V), ok } -func (m *Map[K, V]) LoadAndDelete(key K) (value V, loaded bool) { +func (m *mapImpl[K, V]) LoadAndDelete(key K) (value V, loaded bool) { v, loaded := m.m.LoadAndDelete(key) if !loaded { return value, loaded @@ -29,17 +42,17 @@ func (m *Map[K, V]) LoadAndDelete(key K) (value V, loaded bool) { return v.(V), loaded } -func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) { +func (m *mapImpl[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) { a, loaded := m.m.LoadOrStore(key, value) return a.(V), loaded } -func (m *Map[K, V]) Range(f func(key K, value V) bool) { +func (m *mapImpl[K, V]) Range(f func(key K, value V) bool) { m.m.Range(func(key, value any) bool { return f(key.(K), value.(V)) }) } -func (m *Map[K, V]) Store(key K, value V) { +func (m *mapImpl[K, V]) Store(key K, value V) { m.m.Store(key, value) } diff --git a/internal/sync/map_test.go b/internal/sync/map_test.go new file mode 100644 index 000000000..e23d0d698 --- /dev/null +++ b/internal/sync/map_test.go @@ -0,0 +1,182 @@ +// Copyright (c) 2026 Winlin +// +// SPDX-License-Identifier: MIT +package sync + +import ( + "sort" + "testing" +) + +func TestNewMap_ReturnsEmpty(t *testing.T) { + m := NewMap[string, int]() + if m == nil { + t.Fatal("NewMap returned nil") + } + if v, ok := m.Load("missing"); ok || v != 0 { + t.Fatalf("Load(missing) = (%v, %v), want (0, false)", v, ok) + } +} + +func TestStore_AndLoad(t *testing.T) { + m := NewMap[string, int]() + m.Store("a", 1) + + v, ok := m.Load("a") + if !ok || v != 1 { + t.Fatalf("Load(a) = (%v, %v), want (1, true)", v, ok) + } +} + +func TestLoad_MissingReturnsZero(t *testing.T) { + m := NewMap[string, int]() + v, ok := m.Load("nope") + if ok { + t.Fatal("Load on missing key returned ok=true") + } + if v != 0 { + t.Fatalf("Load on missing key returned %v, want zero", v) + } +} + +func TestDelete_RemovesKey(t *testing.T) { + m := NewMap[string, int]() + m.Store("a", 1) + m.Delete("a") + + if _, ok := m.Load("a"); ok { + t.Fatal("Load(a) returned ok=true after Delete") + } +} + +func TestDelete_MissingIsNoop(t *testing.T) { + m := NewMap[string, int]() + m.Delete("never-stored") +} + +func TestLoadAndDelete_Present(t *testing.T) { + m := NewMap[string, int]() + m.Store("a", 42) + + v, loaded := m.LoadAndDelete("a") + if !loaded { + t.Fatal("LoadAndDelete returned loaded=false for present key") + } + if v != 42 { + t.Fatalf("LoadAndDelete returned %v, want 42", v) + } + if _, ok := m.Load("a"); ok { + t.Fatal("key still present after LoadAndDelete") + } +} + +func TestLoadAndDelete_Absent(t *testing.T) { + m := NewMap[string, int]() + v, loaded := m.LoadAndDelete("nope") + if loaded { + t.Fatal("LoadAndDelete returned loaded=true for absent key") + } + if v != 0 { + t.Fatalf("LoadAndDelete on absent key returned %v, want zero", v) + } +} + +func TestLoadOrStore_StoresWhenAbsent(t *testing.T) { + m := NewMap[string, int]() + actual, loaded := m.LoadOrStore("a", 7) + if loaded { + t.Fatal("LoadOrStore returned loaded=true for absent key") + } + if actual != 7 { + t.Fatalf("LoadOrStore returned %v, want 7", actual) + } + + v, ok := m.Load("a") + if !ok || v != 7 { + t.Fatalf("Load after LoadOrStore = (%v, %v), want (7, true)", v, ok) + } +} + +func TestLoadOrStore_LoadsWhenPresent(t *testing.T) { + m := NewMap[string, int]() + m.Store("a", 1) + + actual, loaded := m.LoadOrStore("a", 999) + if !loaded { + t.Fatal("LoadOrStore returned loaded=false for present key") + } + if actual != 1 { + t.Fatalf("LoadOrStore returned %v, want existing value 1", actual) + } + + v, _ := m.Load("a") + if v != 1 { + t.Fatalf("LoadOrStore overwrote existing value: got %v, want 1", v) + } +} + +func TestRange_VisitsAllEntries(t *testing.T) { + m := NewMap[string, int]() + want := map[string]int{"a": 1, "b": 2, "c": 3} + for k, v := range want { + m.Store(k, v) + } + + got := map[string]int{} + m.Range(func(key string, value int) bool { + got[key] = value + return true + }) + + if len(got) != len(want) { + t.Fatalf("Range visited %d entries, want %d", len(got), len(want)) + } + for k, v := range want { + if got[k] != v { + t.Fatalf("Range got[%q] = %v, want %v", k, got[k], v) + } + } +} + +func TestRange_EarlyStop(t *testing.T) { + m := NewMap[string, int]() + m.Store("a", 1) + m.Store("b", 2) + m.Store("c", 3) + + visited := 0 + m.Range(func(key string, value int) bool { + visited++ + return false + }) + + if visited != 1 { + t.Fatalf("Range visited %d entries after returning false, want 1", visited) + } +} + +func TestMap_PointerValueType(t *testing.T) { + type entry struct{ n int } + m := NewMap[string, *entry]() + + e := &entry{n: 5} + m.Store("k", e) + + got, ok := m.Load("k") + if !ok { + t.Fatal("Load returned ok=false") + } + if got != e { + t.Fatalf("Load returned different pointer: %p vs %p", got, e) + } + + keys := []string{} + m.Range(func(key string, value *entry) bool { + keys = append(keys, key) + return true + }) + sort.Strings(keys) + if len(keys) != 1 || keys[0] != "k" { + t.Fatalf("Range keys = %v, want [k]", keys) + } +} diff --git a/internal/utils/utils.go b/internal/utils/utils.go index 40baf28c0..4b86edda0 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -7,10 +7,8 @@ import ( "context" "encoding/binary" "encoding/json" - stdErr "errors" "fmt" "io" - "io/ioutil" "net" "net/http" "net/url" @@ -71,7 +69,7 @@ func ApiCORS(ctx context.Context, w http.ResponseWriter, r *http.Request) bool { // ParseBody read the body from r, and unmarshal JSON to v. func ParseBody(r io.ReadCloser, v interface{}) error { - b, err := ioutil.ReadAll(r) + b, err := io.ReadAll(r) if err != nil { return errors.Wrapf(err, "read body") } @@ -115,17 +113,17 @@ func BuildStreamURL(r string) (string, error) { func IsPeerClosedError(err error) bool { causeErr := errors.Cause(err) - if stdErr.Is(causeErr, io.EOF) { + if errors.Is(causeErr, io.EOF) { return true } - if stdErr.Is(causeErr, syscall.EPIPE) { + if errors.Is(causeErr, syscall.EPIPE) { return true } if netErr, ok := causeErr.(*net.OpError); ok { if sysErr, ok := netErr.Err.(*os.SyscallError); ok { - if stdErr.Is(sysErr.Err, syscall.ECONNRESET) { + if errors.Is(sysErr.Err, syscall.ECONNRESET) { return true } } diff --git a/internal/utils/utils_test.go b/internal/utils/utils_test.go new file mode 100644 index 000000000..2977bee5c --- /dev/null +++ b/internal/utils/utils_test.go @@ -0,0 +1,414 @@ +// Copyright (c) 2026 Winlin +// +// SPDX-License-Identifier: MIT +package utils + +import ( + "context" + "crypto/tls" + "io" + "net" + "net/http" + "net/http/httptest" + "os" + "strings" + "syscall" + "testing" + + "srsx/internal/errors" +) + +// errReadCloser always fails on Read. +type errReadCloser struct{ closed bool } + +func (e *errReadCloser) Read(p []byte) (int, error) { return 0, io.ErrUnexpectedEOF } +func (e *errReadCloser) Close() error { e.closed = true; return nil } + +func TestApiResponse_EncodesJSON(t *testing.T) { + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + ApiResponse(context.Background(), rec, req, map[string]int{"a": 1}) + + if rec.Code != http.StatusOK { + t.Fatalf("code = %d, want 200", rec.Code) + } + if got := rec.Header().Get("Content-Type"); got != "application/json" { + t.Fatalf("Content-Type = %q", got) + } + if rec.Header().Get("Server") == "" { + t.Fatal("Server header empty") + } + if got := strings.TrimSpace(rec.Body.String()); got != `{"a":1}` { + t.Fatalf("body = %q", got) + } +} + +func TestApiResponse_MarshalErrorFallsBackToApiError(t *testing.T) { + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + // Channels are not JSON-marshalable. + ApiResponse(context.Background(), rec, req, make(chan int)) + + if rec.Code != http.StatusInternalServerError { + t.Fatalf("code = %d, want 500", rec.Code) + } + if ct := rec.Header().Get("Content-Type"); !strings.HasPrefix(ct, "text/plain") { + t.Fatalf("Content-Type = %q", ct) + } +} + +func TestApiError_WritesPlainText500(t *testing.T) { + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + ApiError(context.Background(), rec, req, errors.New("boom")) + + if rec.Code != http.StatusInternalServerError { + t.Fatalf("code = %d", rec.Code) + } + if got := strings.TrimSpace(rec.Body.String()); got != "boom" { + t.Fatalf("body = %q", got) + } +} + +func TestApiCORS_OptionsPreflightReturnsTrue(t *testing.T) { + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodOptions, "/", nil) + if !ApiCORS(context.Background(), rec, req) { + t.Fatal("OPTIONS should return true") + } + if rec.Code != http.StatusOK { + t.Fatalf("code = %d", rec.Code) + } + if rec.Header().Get("Access-Control-Allow-Origin") != "*" { + t.Fatal("missing Allow-Origin") + } +} + +func TestApiCORS_NonOptionsReturnsFalse(t *testing.T) { + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + if ApiCORS(context.Background(), rec, req) { + t.Fatal("GET should return false") + } + if rec.Header().Get("Access-Control-Allow-Methods") != "*" { + t.Fatal("missing Allow-Methods") + } +} + +func TestParseBody_Success(t *testing.T) { + var v struct { + Name string `json:"name"` + } + body := io.NopCloser(strings.NewReader(`{"name":"alice"}`)) + if err := ParseBody(body, &v); err != nil { + t.Fatalf("unexpected err: %v", err) + } + if v.Name != "alice" { + t.Fatalf("name = %q", v.Name) + } +} + +func TestParseBody_EmptyBodyIsNoOp(t *testing.T) { + var v struct{ Name string } + if err := ParseBody(io.NopCloser(strings.NewReader("")), &v); err != nil { + t.Fatalf("unexpected err: %v", err) + } +} + +func TestParseBody_ReadError(t *testing.T) { + if err := ParseBody(&errReadCloser{}, &struct{}{}); err == nil { + t.Fatal("want error") + } +} + +func TestParseBody_UnmarshalError(t *testing.T) { + var v struct{ Name string } + err := ParseBody(io.NopCloser(strings.NewReader("not json")), &v) + if err == nil { + t.Fatal("want error") + } + if !strings.Contains(err.Error(), "json unmarshal") { + t.Fatalf("err = %v", err) + } +} + +func TestBuildStreamURL(t *testing.T) { + cases := []struct { + in, want string + }{ + {"rtmp://example.com/live/stream", "example.com/live/stream"}, + {"rtmp://example.com:1935/live/stream", "example.com/live/stream"}, + {"rtmp://127.0.0.1/live/stream", "__defaultVhost__/live/stream"}, + {"rtmp://localhost/live/stream", "__defaultVhost__/live/stream"}, + {"rtmp://localhost:1935/live/stream", "__defaultVhost__/live/stream"}, + } + for _, c := range cases { + got, err := BuildStreamURL(c.in) + if err != nil { + t.Fatalf("%s: err = %v", c.in, err) + } + if got != c.want { + t.Fatalf("%s: got %q want %q", c.in, got, c.want) + } + } +} + +func TestBuildStreamURL_ParseError(t *testing.T) { + if _, err := BuildStreamURL("http://%zz"); err == nil { + t.Fatal("want error") + } +} + +func TestIsPeerClosedError(t *testing.T) { + cases := []struct { + name string + err error + want bool + }{ + {"nil", nil, false}, + {"EOF", io.EOF, true}, + {"wrapped-EOF", errors.Wrap(io.EOF, "ctx"), true}, + {"EPIPE", syscall.EPIPE, true}, + // errors.Cause fully unwraps OpError → SyscallError → Errno, so the + // OpError branch inside IsPeerClosedError is not reachable for the + // canonical wrapping shape. We still exercise these constructions to + // lock in the current behavior. + {"ECONNRESET-wrapped-in-OpError", &net.OpError{Err: &os.SyscallError{Err: syscall.ECONNRESET}}, false}, + {"OpError-with-other-syscall", &net.OpError{Err: &os.SyscallError{Err: syscall.EINVAL}}, false}, + {"OpError-not-SyscallError", &net.OpError{Err: errors.New("boom")}, false}, + {"unrelated", errors.New("other"), false}, + } + for _, c := range cases { + if got := IsPeerClosedError(c.err); got != c.want { + t.Fatalf("%s: got %v want %v", c.name, got, c.want) + } + } +} + +func TestIsClosedNetworkError(t *testing.T) { + cases := []struct { + name string + err error + want bool + }{ + {"nil", nil, false}, + {"OpError-matching", &net.OpError{Err: errors.New("use of closed network connection")}, true}, + {"OpError-other", &net.OpError{Err: errors.New("other")}, false}, + {"plain-with-substring", errors.New("wrap: use of closed network connection"), true}, + {"plain-unrelated", errors.New("other thing"), false}, + } + for _, c := range cases { + if got := IsClosedNetworkError(c.err); got != c.want { + t.Fatalf("%s: got %v want %v", c.name, got, c.want) + } + } +} + +func TestConvertURLToStreamURL_PathForm(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "http://example.com:8080/live/stream.flv", nil) + unified, full := ConvertURLToStreamURL(req) + if unified != "http://example.com/live/stream" { + t.Fatalf("unified = %q", unified) + } + if full != "http://example.com/live/stream.flv" { + t.Fatalf("full = %q", full) + } +} + +func TestConvertURLToStreamURL_HostWithoutPort(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "http://example.com/a/b.m3u8", nil) + req.Host = "example.com" + unified, full := ConvertURLToStreamURL(req) + if unified != "http://__defaultVhost__/a/b" { + t.Fatalf("unified = %q", unified) + } + if full != "http://__defaultVhost__/a/b.m3u8" { + t.Fatalf("full = %q", full) + } +} + +func TestConvertURLToStreamURL_BadHostWithColonFallsBack(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "http://example.com/a/b.ts", nil) + req.Host = "a:b:c" + unified, _ := ConvertURLToStreamURL(req) + if !strings.Contains(unified, "__defaultVhost__") { + t.Fatalf("unified = %q", unified) + } +} + +func TestConvertURLToStreamURL_QueryForm(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "http://example.com:8080/?app=live&stream=foo", nil) + unified, full := ConvertURLToStreamURL(req) + if unified != "http://example.com/live/foo" { + t.Fatalf("unified = %q", unified) + } + if full != "http://example.com/live/foo" { + t.Fatalf("full = %q", full) + } +} + +func TestConvertURLToStreamURL_TLS(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "https://example.com:443/a/b.flv", nil) + req.TLS = &tls.ConnectionState{} + unified, _ := ConvertURLToStreamURL(req) + if !strings.HasPrefix(unified, "https://") { + t.Fatalf("unified = %q", unified) + } +} + +func TestRtcIsSTUN(t *testing.T) { + cases := []struct { + data []byte + want bool + }{ + {nil, false}, + {[]byte{}, false}, + {[]byte{0x00, 0x01}, true}, + {[]byte{0x01}, true}, + {[]byte{0x02}, false}, + } + for i, c := range cases { + if got := RtcIsSTUN(c.data); got != c.want { + t.Fatalf("case %d: got %v want %v", i, got, c.want) + } + } +} + +func TestRtcIsRTPOrRTCP(t *testing.T) { + short := make([]byte, 11) + valid := make([]byte, 12) + valid[0] = 0x80 + badFirstByte := make([]byte, 12) + badFirstByte[0] = 0xC0 + + if RtcIsRTPOrRTCP(short) { + t.Fatal("short should be false") + } + if !RtcIsRTPOrRTCP(valid) { + t.Fatal("valid should be true") + } + if RtcIsRTPOrRTCP(badFirstByte) { + t.Fatal("0xC0 should be false") + } +} + +func TestSrtIsHandshake(t *testing.T) { + if SrtIsHandshake([]byte{0x80, 0x00, 0x00}) { + t.Fatal("short should be false") + } + if !SrtIsHandshake([]byte{0x80, 0x00, 0x00, 0x00}) { + t.Fatal("handshake magic should be true") + } + if SrtIsHandshake([]byte{0x00, 0x00, 0x00, 0x01}) { + t.Fatal("non-handshake should be false") + } +} + +func TestSrtParseSocketID(t *testing.T) { + if SrtParseSocketID(make([]byte, 15)) != 0 { + t.Fatal("short should be 0") + } + data := make([]byte, 16) + data[12], data[13], data[14], data[15] = 0x00, 0x00, 0x00, 0x42 + if got := SrtParseSocketID(data); got != 0x42 { + t.Fatalf("got %#x", got) + } +} + +func TestParseIceUfragPwd(t *testing.T) { + sdp := "v=0\r\na=ice-ufrag:abc\r\na=ice-pwd:secret\r\n" + ufrag, pwd, err := ParseIceUfragPwd(sdp) + if err != nil { + t.Fatalf("err = %v", err) + } + if ufrag != "abc" || pwd != "secret" { + t.Fatalf("ufrag=%q pwd=%q", ufrag, pwd) + } +} + +func TestParseIceUfragPwd_MissingUfrag(t *testing.T) { + if _, _, err := ParseIceUfragPwd("a=ice-pwd:secret"); err == nil { + t.Fatal("want error") + } +} + +func TestParseIceUfragPwd_MissingPwd(t *testing.T) { + if _, _, err := ParseIceUfragPwd("a=ice-ufrag:abc"); err == nil { + t.Fatal("want error") + } +} + +func TestParseSRTStreamID_WithHost(t *testing.T) { + host, resource, err := ParseSRTStreamID("h=example.com,r=live/stream") + if err != nil { + t.Fatalf("err = %v", err) + } + if host != "example.com" || resource != "live/stream" { + t.Fatalf("host=%q resource=%q", host, resource) + } +} + +func TestParseSRTStreamID_WithoutHost(t *testing.T) { + host, resource, err := ParseSRTStreamID("r=live/stream") + if err != nil { + t.Fatalf("err = %v", err) + } + if host != "" || resource != "live/stream" { + t.Fatalf("host=%q resource=%q", host, resource) + } +} + +func TestParseSRTStreamID_MissingResource(t *testing.T) { + if _, _, err := ParseSRTStreamID("h=example.com"); err == nil { + t.Fatal("want error") + } +} + +func TestParseListenEndpoint(t *testing.T) { + cases := []struct { + name string + in string + wantErr bool + protocol string + ip string // "" means nil + port uint16 + }{ + {"bare-port", "1935", false, "tcp", "", 1935}, + {"bare-port-bad", "abc", true, "", "", 0}, + {"url-host-port", "tcp://0.0.0.0:1935", false, "tcp", "0.0.0.0", 1935}, + {"url-empty-host", "tcp://:1935", false, "tcp", "", 1935}, + {"url-port-only", "udp://1935", false, "udp", "", 1935}, + {"url-port-only-bad", "udp://abc", true, "", "", 0}, + {"url-split-fail", "tcp://a:b:c:d", true, "", "", 0}, + {"url-bad-port", "tcp://host:bad", true, "", "", 0}, + {"legacy", "tcp:1.2.3.4:1935", false, "tcp", "1.2.3.4", 1935}, + {"legacy-bad-port", "tcp:1.2.3.4:bad", true, "", "", 0}, + {"legacy-wrong-parts", "a:b", true, "", "", 0}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + proto, ip, port, err := ParseListenEndpoint(c.in) + if (err != nil) != c.wantErr { + t.Fatalf("err = %v wantErr = %v", err, c.wantErr) + } + if c.wantErr { + return + } + if proto != c.protocol { + t.Fatalf("protocol = %q want %q", proto, c.protocol) + } + if port != c.port { + t.Fatalf("port = %d want %d", port, c.port) + } + if c.ip == "" { + if ip != nil { + t.Fatalf("ip = %v want nil", ip) + } + } else { + if ip == nil || ip.String() != c.ip { + t.Fatalf("ip = %v want %s", ip, c.ip) + } + } + }) + } +} diff --git a/internal/version/version.go b/internal/version/version.go index 99de863b3..4d444115a 100644 --- a/internal/version/version.go +++ b/internal/version/version.go @@ -15,7 +15,7 @@ func VersionMinor() int { } func VersionRevision() int { - return 144 + return 145 } func Version() string { diff --git a/internal/version/version_test.go b/internal/version/version_test.go new file mode 100644 index 000000000..786131059 --- /dev/null +++ b/internal/version/version_test.go @@ -0,0 +1,34 @@ +// Copyright (c) 2026 Winlin +// +// SPDX-License-Identifier: MIT +package version + +import ( + "fmt" + "testing" +) + +func TestVersionComponents(t *testing.T) { + if got := VersionMajor(); got != 7 { + t.Fatalf("VersionMajor = %d, want 7", got) + } + if got := VersionMinor(); got != 0 { + t.Fatalf("VersionMinor = %d, want 0", got) + } + if got := VersionRevision(); got <= 0 { + t.Fatalf("VersionRevision = %d, want > 0", got) + } +} + +func TestVersion_FormatsMajorMinorRevision(t *testing.T) { + want := fmt.Sprintf("%d.%d.%d", VersionMajor(), VersionMinor(), VersionRevision()) + if got := Version(); got != want { + t.Fatalf("Version = %q, want %q", got, want) + } +} + +func TestSignature(t *testing.T) { + if got := Signature(); got != "SRSX" { + t.Fatalf("Signature = %q, want SRSX", got) + } +} diff --git a/trunk/doc/CHANGELOG.md b/trunk/doc/CHANGELOG.md index b3117fc16..1972001f9 100644 --- a/trunk/doc/CHANGELOG.md +++ b/trunk/doc/CHANGELOG.md @@ -7,6 +7,7 @@ The changelog for SRS. ## SRS 7.0 Changelog +* v7.0, 2026-04-23, Merge [#4667](https://github.com/ossrs/srs/pull/4667): Proxy: Refactor internal/errors and internal/sync, and add unit tests across internal/*. v7.0.145 (#4667) * v7.0, 2026-04-18, Merge [#4665](https://github.com/ossrs/srs/pull/4665): Proxy: Harden internal/env tests and add counterfeiter fake generation. v7.0.144 (#4665) * v7.0, 2026-04-12, Merge [#4661](https://github.com/ossrs/srs/pull/4661): Proxy: Move build output to bin/, replace godotenv with custom .env parser, and update docs. v7.0.143 (#4661) * v7.0, 2026-04-06, Merge [#4657](https://github.com/ossrs/srs/pull/4657): Proxy: Refactor bootstrap for multi-server support and rebrand to SRSX. v7.0.142 (#4657) diff --git a/trunk/src/core/srs_core_version7.hpp b/trunk/src/core/srs_core_version7.hpp index 12344c8e1..c31dab668 100644 --- a/trunk/src/core/srs_core_version7.hpp +++ b/trunk/src/core/srs_core_version7.hpp @@ -9,6 +9,6 @@ #define VERSION_MAJOR 7 #define VERSION_MINOR 0 -#define VERSION_REVISION 144 +#define VERSION_REVISION 145 #endif