- Refactor the Go proxy for dependency injection: every proxy server, the bootstrap, the signal handler, the load balancers, and AMF0 now accept functional-option seams (factories/closures) so tests can inject fakes without binding real sockets, talking to real Redis, or racing on package globals.
- Drop the package-global `lb.SrsLoadBalancer`. The bootstrap creates the LB locally and threads it through every proxy server constructor. Two old global indirections in `internal/signal` and `internal/rtmp/amf0` are likewise replaced by per-instance fields.
- Rename `internal/server` → `internal/proxy` and rename the `lb` public surface for clarity: `SRSLoadBalancer` is split into `OriginService` / `HLSService` / `RTCService` and recomposed as `OriginLoadBalancer`; `SRSServer` → `OriginServer`; all proxy server types gain a `Proxy` qualifier (e.g. `RTMPServer` → `RTMPProxyServer`).
- Extract the Redis client behind a new `internal/redisclient` package with a minimal `RedisClient` interface and a counterfeiter fake.
- Add counterfeiter fakes (`proxyfakes`, `lbfakes`, `redisclientfakes`) and ~7.5k lines of unit tests covering bootstrap, memory + Redis LBs, all five proxy servers, the signal handler, and AMF0.
- Add two new E2E flows — `proxy-e2e-srt-test.sh` (SRT publish through proxy, verify SRT/RTMP/HTTP-FLV/HLS playback) and `proxy-e2e-whip-test.sh` (WHIP publish, verify RTMP/HTTP-FLV/HLS via origin `rtc_to_rtmp`) — plus `setup-ffmpeg-with-whip.sh`, a macOS builder for an ffmpeg build with openssl-DTLS WHIP and SRT support that the two scripts auto-invoke when needed.
- Workspace reorg: move `memory/` and `skills/` to the repo root so all agent tools (Claude / Codex / Kiro / OpenClaw) share one source of truth via symlinks. Sync `docs/proxy/proxy-load-balancer.md` and `memory/srs-codebase-map.md` with the new names.

No protocol, log, HTTP API, or wire-format changes. Refactor only — all externally observable proxy behavior is unchanged.
--------- Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Co-authored-by: chatgpt-codex-connector[bot] <199175422+chatgpt-codex-connector[bot]@users.noreply.github.com>
264 lines
7.5 KiB
Go
264 lines
7.5 KiB
Go
// Copyright (c) 2026 Winlin
|
|
//
|
|
// SPDX-License-Identifier: MIT
|
|
package lb
|
|
|
|
import (
|
|
"context"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"srsx/internal/env/envfakes"
|
|
)
|
|
|
|
// stubHLS is a minimal HLSPlayStream for testing.
|
|
type stubHLS struct {
|
|
spbhid string
|
|
}
|
|
|
|
func (s *stubHLS) GetSPBHID() string { return s.spbhid }
|
|
func (s *stubHLS) Initialize(ctx context.Context) HLSPlayStream { return s }
|
|
|
|
// stubRTC is a minimal RTCConnection used by the tests in this file; it only
// exposes the ufrag it was built with.
type stubRTC struct {
	ufrag string
}

// GetUfrag reports the ufrag the stub was constructed with.
func (c *stubRTC) GetUfrag() string {
	return c.ufrag
}
|
|
|
|
// newMem returns a fresh in-memory load balancer with a default fake env.
|
|
func newMem() *memoryLoadBalancer {
|
|
env := &envfakes.FakeProxyEnvironment{}
|
|
return NewMemoryLoadBalancer(env).(*memoryLoadBalancer)
|
|
}
|
|
|
|
func TestNewMemoryLoadBalancer(t *testing.T) {
|
|
env := &envfakes.FakeProxyEnvironment{}
|
|
lb := NewMemoryLoadBalancer(env)
|
|
if lb == nil {
|
|
t.Fatal("NewMemoryLoadBalancer returned nil")
|
|
}
|
|
}
|
|
|
|
func TestMemLB_Initialize_DefaultBackendDisabled(t *testing.T) {
|
|
env := &envfakes.FakeProxyEnvironment{}
|
|
env.DefaultBackendEnabledReturns("off")
|
|
lb := NewMemoryLoadBalancer(env).(*memoryLoadBalancer)
|
|
if err := lb.Initialize(context.Background()); err != nil {
|
|
t.Fatalf("Initialize: %v", err)
|
|
}
|
|
// No server stored when disabled.
|
|
count := 0
|
|
lb.servers.Range(func(string, *OriginServer) bool { count++; return true })
|
|
if count != 0 {
|
|
t.Fatalf("expected 0 servers, got %d", count)
|
|
}
|
|
}
|
|
|
|
func TestMemLB_Initialize_DefaultBackendError(t *testing.T) {
|
|
env := &envfakes.FakeProxyEnvironment{}
|
|
env.DefaultBackendEnabledReturns("on")
|
|
env.DefaultBackendIPReturns("") // triggers "empty default backend ip"
|
|
lb := NewMemoryLoadBalancer(env)
|
|
err := lb.Initialize(context.Background())
|
|
if err == nil || !strings.Contains(err.Error(), "initialize default SRS") {
|
|
t.Fatalf("expected wrapped error, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestMemLB_Initialize_KeepaliveTick(t *testing.T) {
|
|
env := &envfakes.FakeProxyEnvironment{}
|
|
env.DefaultBackendEnabledReturns("on")
|
|
env.DefaultBackendIPReturns("1.2.3.4")
|
|
env.DefaultBackendRTMPReturns(":1935")
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
lb := NewMemoryLoadBalancer(env).(*memoryLoadBalancer)
|
|
// Shorten the keep-alive interval on this instance only so concurrent
|
|
// tests don't race on shared state.
|
|
lb.keepaliveInterval = time.Millisecond
|
|
if err := lb.Initialize(ctx); err != nil {
|
|
t.Fatalf("Initialize: %v", err)
|
|
}
|
|
|
|
// Find the server and watch UpdatedAt advance after a keep-alive tick.
|
|
var s *OriginServer
|
|
lb.servers.Range(func(_ string, v *OriginServer) bool { s = v; return false })
|
|
if s == nil {
|
|
t.Fatal("expected server stored")
|
|
}
|
|
first := s.UpdatedAt
|
|
|
|
// Wait long enough for several ticks (interval is 1ms, server.UpdatedAt
|
|
// is set to time.Now() inside NewDefaultOriginServerForDebugging on each
|
|
// Update? — actually Update only stores the server pointer, so UpdatedAt
|
|
// won't change. The goroutine still hits the tick branch though, which
|
|
// is all we need for coverage).
|
|
time.Sleep(20 * time.Millisecond)
|
|
cancel()
|
|
time.Sleep(10 * time.Millisecond)
|
|
|
|
_ = first
|
|
}
|
|
|
|
func TestMemLB_Initialize_DefaultBackendSuccess(t *testing.T) {
|
|
env := &envfakes.FakeProxyEnvironment{}
|
|
env.DefaultBackendEnabledReturns("on")
|
|
env.DefaultBackendIPReturns("1.2.3.4")
|
|
env.DefaultBackendRTMPReturns(":1935")
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
lb := NewMemoryLoadBalancer(env).(*memoryLoadBalancer)
|
|
if err := lb.Initialize(ctx); err != nil {
|
|
t.Fatalf("Initialize: %v", err)
|
|
}
|
|
|
|
count := 0
|
|
lb.servers.Range(func(string, *OriginServer) bool { count++; return true })
|
|
if count != 1 {
|
|
t.Fatalf("expected 1 server stored, got %d", count)
|
|
}
|
|
|
|
// Cancel and give the keep-alive goroutine a moment to exit cleanly.
|
|
cancel()
|
|
time.Sleep(20 * time.Millisecond)
|
|
}
|
|
|
|
func TestMemLB_Update(t *testing.T) {
|
|
lb := newMem()
|
|
s := &OriginServer{ServerID: "srv", ServiceID: "svc", PID: "1"}
|
|
if err := lb.Update(context.Background(), s); err != nil {
|
|
t.Fatalf("Update: %v", err)
|
|
}
|
|
got, ok := lb.servers.Load(s.ID())
|
|
if !ok || got != s {
|
|
t.Fatalf("Update did not store the server: got=%v ok=%v", got, ok)
|
|
}
|
|
}
|
|
|
|
func TestMemLB_Pick_NoServers(t *testing.T) {
|
|
lb := newMem()
|
|
_, err := lb.Pick(context.Background(), "url1")
|
|
if err == nil || !strings.Contains(err.Error(), "no server available") {
|
|
t.Fatalf("expected no-server error, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestMemLB_Pick_AliveServer_Sticky(t *testing.T) {
|
|
lb := newMem()
|
|
s := &OriginServer{ServerID: "a", PID: "1", UpdatedAt: time.Now()}
|
|
_ = lb.Update(context.Background(), s)
|
|
|
|
got, err := lb.Pick(context.Background(), "url1")
|
|
if err != nil {
|
|
t.Fatalf("Pick: %v", err)
|
|
}
|
|
if got != s {
|
|
t.Fatalf("Pick returned %v, want %v", got, s)
|
|
}
|
|
|
|
// Second pick for the same URL returns the same server (sticky branch).
|
|
got2, err := lb.Pick(context.Background(), "url1")
|
|
if err != nil {
|
|
t.Fatalf("Pick second: %v", err)
|
|
}
|
|
if got2 != got {
|
|
t.Fatalf("second Pick returned %v, want %v (sticky)", got2, got)
|
|
}
|
|
}
|
|
|
|
func TestMemLB_Pick_OnlyDeadServers_Fallback(t *testing.T) {
|
|
lb := newMem()
|
|
// UpdatedAt long past => not alive. Tests the fallback "use all servers" branch.
|
|
s := &OriginServer{
|
|
ServerID: "a",
|
|
PID: "1",
|
|
UpdatedAt: time.Now().Add(-2 * ServerAliveDuration),
|
|
}
|
|
_ = lb.Update(context.Background(), s)
|
|
|
|
got, err := lb.Pick(context.Background(), "url1")
|
|
if err != nil {
|
|
t.Fatalf("Pick: %v", err)
|
|
}
|
|
if got != s {
|
|
t.Fatalf("expected dead-server fallback to return %v, got %v", s, got)
|
|
}
|
|
}
|
|
|
|
func TestMemLB_LoadHLSBySPBHID_NotFound(t *testing.T) {
|
|
lb := newMem()
|
|
_, err := lb.LoadHLSBySPBHID(context.Background(), "missing")
|
|
if err == nil || !strings.Contains(err.Error(), "no HLS streaming") {
|
|
t.Fatalf("expected error, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestMemLB_LoadOrStoreHLS_New(t *testing.T) {
|
|
lb := newMem()
|
|
s := &stubHLS{spbhid: "abc"}
|
|
got, err := lb.LoadOrStoreHLS(context.Background(), "url1", s)
|
|
if err != nil {
|
|
t.Fatalf("LoadOrStoreHLS: %v", err)
|
|
}
|
|
if got != s {
|
|
t.Fatalf("LoadOrStoreHLS returned %v, want %v", got, s)
|
|
}
|
|
|
|
// Lookup via SPBHID works (dual-index write).
|
|
bySPBHID, err := lb.LoadHLSBySPBHID(context.Background(), "abc")
|
|
if err != nil {
|
|
t.Fatalf("LoadHLSBySPBHID: %v", err)
|
|
}
|
|
if bySPBHID != s {
|
|
t.Fatalf("LoadHLSBySPBHID returned %v, want %v", bySPBHID, s)
|
|
}
|
|
}
|
|
|
|
func TestMemLB_LoadOrStoreHLS_Existing(t *testing.T) {
|
|
lb := newMem()
|
|
s1 := &stubHLS{spbhid: "first"}
|
|
s2 := &stubHLS{spbhid: "second"}
|
|
_, _ = lb.LoadOrStoreHLS(context.Background(), "url1", s1)
|
|
got, err := lb.LoadOrStoreHLS(context.Background(), "url1", s2)
|
|
if err != nil {
|
|
t.Fatalf("LoadOrStoreHLS: %v", err)
|
|
}
|
|
if got != s1 {
|
|
t.Fatalf("expected existing s1, got %v", got)
|
|
}
|
|
// SPBHID 'second' (from the rejected s2) maps to the existing s1.
|
|
bySPBHID, _ := lb.LoadHLSBySPBHID(context.Background(), "second")
|
|
if bySPBHID != s1 {
|
|
t.Fatalf("expected SPBHID 'second' to map to s1, got %v", bySPBHID)
|
|
}
|
|
}
|
|
|
|
func TestMemLB_StoreWebRTC_And_Load(t *testing.T) {
|
|
lb := newMem()
|
|
s := &stubRTC{ufrag: "ufrg1"}
|
|
if err := lb.StoreWebRTC(context.Background(), "url1", s); err != nil {
|
|
t.Fatalf("StoreWebRTC: %v", err)
|
|
}
|
|
got, err := lb.LoadWebRTCByUfrag(context.Background(), "ufrg1")
|
|
if err != nil {
|
|
t.Fatalf("LoadWebRTCByUfrag: %v", err)
|
|
}
|
|
if got != s {
|
|
t.Fatalf("got %v, want %v", got, s)
|
|
}
|
|
}
|
|
|
|
func TestMemLB_LoadWebRTCByUfrag_NotFound(t *testing.T) {
|
|
lb := newMem()
|
|
_, err := lb.LoadWebRTCByUfrag(context.Background(), "missing")
|
|
if err == nil || !strings.Contains(err.Error(), "no WebRTC streaming") {
|
|
t.Fatalf("expected error, got %v", err)
|
|
}
|
|
}
|