From 5a971feed43cc3d8357d08293d0129d866af47d9 Mon Sep 17 00:00:00 2001 From: winlin Date: Sun, 17 May 2026 11:48:12 -0400 Subject: [PATCH] Claude: Add proxy bootstrap seams and unit tests. --- internal/bootstrap/proxy.go | 142 ++++++- internal/bootstrap/proxy_test.go | 643 +++++++++++++++++++++++++++++++ 2 files changed, 771 insertions(+), 14 deletions(-) create mode 100644 internal/bootstrap/proxy_test.go diff --git a/internal/bootstrap/proxy.go b/internal/bootstrap/proxy.go index bb2cf4d6f..5d7e09782 100644 --- a/internal/bootstrap/proxy.go +++ b/internal/bootstrap/proxy.go @@ -18,12 +18,126 @@ import ( ) // NewProxyBootstrap creates a new Bootstrap instance for the proxy server. -func NewProxyBootstrap() Bootstrap { - return &proxyBootstrap{} +func NewProxyBootstrap(opts ...func(*proxyBootstrap)) Bootstrap { + v := &proxyBootstrap{} + + // Default newEnvironment: read the real process env / .env file. + v.newEnvironment = func(ctx context.Context) (env.ProxyEnvironment, error) { + return env.NewProxyEnvironment(ctx) + } + // Default newSignalHandler: construct a real OS signal handler. + v.newSignalHandler = func() signalHandler { + return signal.NewHandler() + } + // Default newRedisLoadBalancer: construct a real Redis-backed load balancer. + v.newRedisLoadBalancer = func(environment env.ProxyEnvironment) lb.OriginLoadBalancer { + return lb.NewRedisLoadBalancer(environment) + } + // Default newMemoryLoadBalancer: construct a real in-memory load balancer. + v.newMemoryLoadBalancer = func(environment env.ProxyEnvironment) lb.OriginLoadBalancer { + return lb.NewMemoryLoadBalancer(environment) + } + // Default newRTMPProxyServer: construct a real RTMP proxy server. + v.newRTMPProxyServer = func(environment env.ProxyEnvironment, loadBalancer lb.OriginLoadBalancer) proxy.RTMPProxyServer { + return proxy.NewRTMPProxyServer(environment, loadBalancer) + } + // Default newWebRTCProxyServer: construct a real WebRTC proxy server. + v.newWebRTCProxyServer = func(environment env.ProxyEnvironment, loadBalancer lb.OriginLoadBalancer) proxy.WebRTCProxyServer { + return proxy.NewWebRTCProxyServer(environment, loadBalancer) + } + // Default newHTTPAPIProxyServer: construct a real HTTP API proxy server. + v.newHTTPAPIProxyServer = func(environment env.ProxyEnvironment, gracefulQuitTimeout time.Duration, rtc proxy.WebRTCProxyServer) proxy.HTTPAPIProxyServer { + return proxy.NewHTTPAPIProxyServer(environment, gracefulQuitTimeout, rtc) + } + // Default newSRSSRTProxyServer: construct a real SRT proxy server. + v.newSRSSRTProxyServer = func(environment env.ProxyEnvironment, loadBalancer lb.OriginLoadBalancer) proxyServer { + return proxy.NewSRSSRTProxyServer(environment, loadBalancer) + } + // Default newSystemAPI: construct a real system API server. + v.newSystemAPI = func(environment env.ProxyEnvironment, loadBalancer lb.OriginLoadBalancer, gracefulQuitTimeout time.Duration) proxyServer { + return proxy.NewSystemAPI(environment, loadBalancer, gracefulQuitTimeout) + } + // Default newHTTPStreamProxyServer: construct a real HTTP stream proxy server. + v.newHTTPStreamProxyServer = func(environment env.ProxyEnvironment, loadBalancer lb.OriginLoadBalancer, gracefulQuitTimeout time.Duration) proxy.HTTPStreamProxyServer { + return proxy.NewHTTPStreamProxyServer(environment, loadBalancer, gracefulQuitTimeout) + } + + for _, opt := range opts { + opt(v) + } + return v } // proxyBootstrap implements the Bootstrap interface for the proxy server. -type proxyBootstrap struct{} +type proxyBootstrap struct { + // newEnvironment constructs the proxy environment. Defaults to + // env.NewProxyEnvironment; tests may override via a functional option to + // supply a fake environment without reading the real process env or .env file. + newEnvironment func(ctx context.Context) (env.ProxyEnvironment, error) + // newSignalHandler constructs the OS signal handler used to install + // signal listeners and the force-quit timer. Defaults to signal.NewHandler; + // tests may override via a functional option to supply a fake handler that + // does not install real OS signal handlers or a real force-quit timer. + newSignalHandler func() signalHandler + // newRedisLoadBalancer constructs the Redis-backed load balancer used when + // environment.LoadBalancerType() == "redis". Defaults to lb.NewRedisLoadBalancer; + // tests may override via a functional option to supply a fake load balancer + // that does not connect to a real Redis instance. + newRedisLoadBalancer func(environment env.ProxyEnvironment) lb.OriginLoadBalancer + // newMemoryLoadBalancer constructs the in-memory load balancer used when + // environment.LoadBalancerType() is anything other than "redis". Defaults to + // lb.NewMemoryLoadBalancer; tests may override via a functional option to + // supply a fake load balancer for assertions on the default branch. + newMemoryLoadBalancer func(environment env.ProxyEnvironment) lb.OriginLoadBalancer + // newRTMPProxyServer constructs the RTMP proxy server. Defaults to + // proxy.NewRTMPProxyServer; tests may override via a functional option to + // supply a fake (e.g. proxyfakes.FakeRTMPProxyServer) that does not bind a + // real TCP port. + newRTMPProxyServer func(environment env.ProxyEnvironment, loadBalancer lb.OriginLoadBalancer) proxy.RTMPProxyServer + // newWebRTCProxyServer constructs the WebRTC proxy server. Defaults to + // proxy.NewWebRTCProxyServer; tests may override via a functional option to + // supply a fake (e.g. proxyfakes.FakeWebRTCProxyServer) that does not bind + // a real UDP port. + newWebRTCProxyServer func(environment env.ProxyEnvironment, loadBalancer lb.OriginLoadBalancer) proxy.WebRTCProxyServer + // newHTTPAPIProxyServer constructs the HTTP API proxy server. Defaults to + // proxy.NewHTTPAPIProxyServer; tests may override via a functional option + // to supply a fake (e.g. proxyfakes.FakeHTTPAPIProxyServer) that does not + // bind a real HTTP port. + newHTTPAPIProxyServer func(environment env.ProxyEnvironment, gracefulQuitTimeout time.Duration, rtc proxy.WebRTCProxyServer) proxy.HTTPAPIProxyServer + // newSRSSRTProxyServer constructs the SRT proxy server. Defaults to + // proxy.NewSRSSRTProxyServer; tests may override via a functional option + // to supply a fake that does not bind a real UDP port. Returned as the + // local proxyServer interface because proxy.NewSRSSRTProxyServer currently + // returns an unexported concrete type. + newSRSSRTProxyServer func(environment env.ProxyEnvironment, loadBalancer lb.OriginLoadBalancer) proxyServer + // newSystemAPI constructs the system API server. Defaults to proxy.NewSystemAPI; + // tests may override via a functional option to supply a fake that does not + // bind a real HTTP port. Returned as the local proxyServer interface because + // proxy.NewSystemAPI currently returns an unexported concrete type. + newSystemAPI func(environment env.ProxyEnvironment, loadBalancer lb.OriginLoadBalancer, gracefulQuitTimeout time.Duration) proxyServer + // newHTTPStreamProxyServer constructs the HTTP stream proxy server. Defaults + // to proxy.NewHTTPStreamProxyServer; tests may override via a functional + // option to supply a fake (e.g. proxyfakes.FakeHTTPStreamProxyServer) that + // does not bind a real HTTP port. + newHTTPStreamProxyServer func(environment env.ProxyEnvironment, loadBalancer lb.OriginLoadBalancer, gracefulQuitTimeout time.Duration) proxy.HTTPStreamProxyServer +} + +// signalHandler is the minimal contract of a signal handler that proxyBootstrap +// drives. *signal.Handler satisfies it. Tests may supply a fake that does not +// install real OS signal handlers or a real force-quit timer. +type signalHandler interface { + InstallSignals(ctx context.Context, cancel context.CancelFunc) + InstallForceQuit(ctx context.Context, environment env.ProxyEnvironment) error +} + +// proxyServer is the minimal Run/Close contract used by proxyBootstrap for the +// SRT proxy and system API. proxy.NewSRSSRTProxyServer and proxy.NewSystemAPI +// currently return unexported concrete types which bootstrap cannot name; their +// values satisfy this interface structurally so tests can still inject fakes. +type proxyServer interface { + Run(ctx context.Context) error + Close() error +} // Start initializes the context with logger and signal handlers, then runs the bootstrap. // Returns any error encountered during startup. @@ -33,7 +147,7 @@ func (b *proxyBootstrap) Start(ctx context.Context) error { // Install signals. ctx, cancel := context.WithCancel(ctx) - signal.NewHandler().InstallSignals(ctx, cancel) + b.newSignalHandler().InstallSignals(ctx, cancel) // Run the main loop, ignore the user cancel error. err := b.run(ctx) @@ -50,7 +164,7 @@ func (b *proxyBootstrap) Start(ctx context.Context) error { // It blocks until the context is cancelled. func (b *proxyBootstrap) run(ctx context.Context) error { // Setup the environment variables. - environment, err := env.NewProxyEnvironment(ctx) + environment, err := b.newEnvironment(ctx) if err != nil { return errors.Wrapf(err, "create environment") } @@ -58,7 +172,7 @@ func (b *proxyBootstrap) run(ctx context.Context) error { // When cancelled, the program is forced to exit due to a timeout. Normally, this doesn't occur // because the main thread exits after the context is cancelled. However, sometimes the main thread // may be blocked for some reason, so a forced exit is necessary to ensure the program terminates. - if err := signal.NewHandler().InstallForceQuit(ctx, environment); err != nil { + if err := b.newSignalHandler().InstallForceQuit(ctx, environment); err != nil { return errors.Wrapf(err, "install force quit") } @@ -86,9 +200,9 @@ func (b *proxyBootstrap) initializeLoadBalancer(ctx context.Context, environment var loadBalancer lb.OriginLoadBalancer switch environment.LoadBalancerType() { case "redis": - loadBalancer = lb.NewRedisLoadBalancer(environment) + loadBalancer = b.newRedisLoadBalancer(environment) default: - loadBalancer = lb.NewMemoryLoadBalancer(environment) + loadBalancer = b.newMemoryLoadBalancer(environment) } if err := loadBalancer.Initialize(ctx); err != nil { @@ -101,42 +215,42 @@ func (b *proxyBootstrap) initializeLoadBalancer(ctx context.Context, environment // startServers initializes and starts all protocol servers. func (b *proxyBootstrap) startServers(ctx context.Context, environment env.ProxyEnvironment, loadBalancer lb.OriginLoadBalancer, gracefulQuitTimeout time.Duration) error { // Start the RTMP server. - rtmpProxyServer := proxy.NewRTMPProxyServer(environment, loadBalancer) + rtmpProxyServer := b.newRTMPProxyServer(environment, loadBalancer) if err := rtmpProxyServer.Run(ctx); err != nil { return errors.Wrapf(err, "rtmp server") } defer rtmpProxyServer.Close() // Start the WebRTC server. - webRTCProxyServer := proxy.NewWebRTCProxyServer(environment, loadBalancer) + webRTCProxyServer := b.newWebRTCProxyServer(environment, loadBalancer) if err := webRTCProxyServer.Run(ctx); err != nil { return errors.Wrapf(err, "rtc server") } defer webRTCProxyServer.Close() // Start the HTTP API server. - httpAPIProxyServer := proxy.NewHTTPAPIProxyServer(environment, gracefulQuitTimeout, webRTCProxyServer) + httpAPIProxyServer := b.newHTTPAPIProxyServer(environment, gracefulQuitTimeout, webRTCProxyServer) if err := httpAPIProxyServer.Run(ctx); err != nil { return errors.Wrapf(err, "http api server") } defer httpAPIProxyServer.Close() // Start the SRT server. - srsSRTProxyServer := proxy.NewSRSSRTProxyServer(environment, loadBalancer) + srsSRTProxyServer := b.newSRSSRTProxyServer(environment, loadBalancer) if err := srsSRTProxyServer.Run(ctx); err != nil { return errors.Wrapf(err, "srt server") } defer srsSRTProxyServer.Close() // Start the System API server. - systemAPI := proxy.NewSystemAPI(environment, loadBalancer, gracefulQuitTimeout) + systemAPI := b.newSystemAPI(environment, loadBalancer, gracefulQuitTimeout) if err := systemAPI.Run(ctx); err != nil { return errors.Wrapf(err, "system api server") } defer systemAPI.Close() // Start the HTTP web server. - httpStreamProxyServer := proxy.NewHTTPStreamProxyServer(environment, loadBalancer, gracefulQuitTimeout) + httpStreamProxyServer := b.newHTTPStreamProxyServer(environment, loadBalancer, gracefulQuitTimeout) if err := httpStreamProxyServer.Run(ctx); err != nil { return errors.Wrapf(err, "http server") } diff --git a/internal/bootstrap/proxy_test.go b/internal/bootstrap/proxy_test.go new file mode 100644 index 000000000..d47550ade --- /dev/null +++ b/internal/bootstrap/proxy_test.go @@ -0,0 +1,643 @@ +// Copyright (c) 2026 Winlin +// +// SPDX-License-Identifier: MIT +package bootstrap + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + "srsx/internal/env" + "srsx/internal/env/envfakes" + "srsx/internal/lb" + "srsx/internal/lb/lbfakes" + "srsx/internal/proxy" + "srsx/internal/proxy/proxyfakes" +) + +// ============================================================================= +// Local fakes +// ============================================================================= + +// fakeSignalHandler implements signalHandler without touching real OS signals. +// InstallSignalsCancels, when true, cancels the supplied cancel func immediately +// so callers can drive the run/Start "ctx already cancelled" branch. +type fakeSignalHandler struct { + installSignalsCalls atomic.Int32 + installForceQuitCalls atomic.Int32 + installForceQuitReturn error + installSignalsCancels bool + lastInstallSignalsCtx context.Context + lastInstallForceQuitCtx context.Context +} + +func (f *fakeSignalHandler) InstallSignals(ctx context.Context, cancel context.CancelFunc) { + f.installSignalsCalls.Add(1) + f.lastInstallSignalsCtx = ctx + if f.installSignalsCancels { + cancel() + } +} + +func (f *fakeSignalHandler) InstallForceQuit(ctx context.Context, environment env.ProxyEnvironment) error { + f.installForceQuitCalls.Add(1) + f.lastInstallForceQuitCtx = ctx + return f.installForceQuitReturn +} + +// fakeProxyServer implements the local proxyServer interface for the SRT proxy +// and system API seams. +type fakeProxyServer struct { + runCalls atomic.Int32 + closeCalls atomic.Int32 + runReturn error + closeReturn error + lastRunCtx context.Context +} + +func (f *fakeProxyServer) Run(ctx context.Context) error { + f.runCalls.Add(1) + f.lastRunCtx = ctx + return f.runReturn +} + +func (f *fakeProxyServer) Close() error { + f.closeCalls.Add(1) + return f.closeReturn +} + +// ============================================================================= +// Helpers +// ============================================================================= + +// fakeEnvWithDefaults returns a FakeProxyEnvironment with reasonable defaults +// so run() can reach all stages without being short-circuited by a parse error. +func fakeEnvWithDefaults() *envfakes.FakeProxyEnvironment { + e := &envfakes.FakeProxyEnvironment{} + e.LoadBalancerTypeReturns("memory") + e.GraceQuitTimeoutReturns("1s") + e.ForceQuitTimeoutReturns("1s") + return e +} + +// bootstrapFakes bundles the fakes installed by withAllFakes for assertions. +type bootstrapFakes struct { + env *envfakes.FakeProxyEnvironment + signal *fakeSignalHandler + lbMemory *lbfakes.FakeOriginLoadBalancer + lbRedis *lbfakes.FakeOriginLoadBalancer + rtmp *proxyfakes.FakeRTMPProxyServer + webrtc *proxyfakes.FakeWebRTCProxyServer + httpAPI *proxyfakes.FakeHTTPAPIProxyServer + srt *fakeProxyServer + systemAPI *fakeProxyServer + httpStream *proxyfakes.FakeHTTPStreamProxyServer + memoryCalls atomic.Int32 + redisCalls atomic.Int32 + rtcInHTTPAPI atomic.Value // proxy.WebRTCProxyServer instance passed to newHTTPAPIProxyServer +} + +// withAllFakes returns a functional option that swaps every seam for a fake. +// The returned bootstrapFakes lets tests inspect calls and arguments. +func withAllFakes(e *envfakes.FakeProxyEnvironment) (func(*proxyBootstrap), *bootstrapFakes) { + f := &bootstrapFakes{ + env: e, + signal: &fakeSignalHandler{}, + lbMemory: &lbfakes.FakeOriginLoadBalancer{}, + lbRedis: &lbfakes.FakeOriginLoadBalancer{}, + rtmp: &proxyfakes.FakeRTMPProxyServer{}, + webrtc: &proxyfakes.FakeWebRTCProxyServer{}, + httpAPI: &proxyfakes.FakeHTTPAPIProxyServer{}, + srt: &fakeProxyServer{}, + systemAPI: &fakeProxyServer{}, + httpStream: &proxyfakes.FakeHTTPStreamProxyServer{}, + } + opt := func(b *proxyBootstrap) { + b.newEnvironment = func(context.Context) (env.ProxyEnvironment, error) { return f.env, nil } + b.newSignalHandler = func() signalHandler { return f.signal } + b.newRedisLoadBalancer = func(env.ProxyEnvironment) lb.OriginLoadBalancer { + f.redisCalls.Add(1) + return f.lbRedis + } + b.newMemoryLoadBalancer = func(env.ProxyEnvironment) lb.OriginLoadBalancer { + f.memoryCalls.Add(1) + return f.lbMemory + } + b.newRTMPProxyServer = func(env.ProxyEnvironment, lb.OriginLoadBalancer) proxy.RTMPProxyServer { return f.rtmp } + b.newWebRTCProxyServer = func(env.ProxyEnvironment, lb.OriginLoadBalancer) proxy.WebRTCProxyServer { return f.webrtc } + b.newHTTPAPIProxyServer = func(_ env.ProxyEnvironment, _ time.Duration, rtc proxy.WebRTCProxyServer) proxy.HTTPAPIProxyServer { + f.rtcInHTTPAPI.Store(rtc) + return f.httpAPI + } + b.newSRSSRTProxyServer = func(env.ProxyEnvironment, lb.OriginLoadBalancer) proxyServer { return f.srt } + b.newSystemAPI = func(env.ProxyEnvironment, lb.OriginLoadBalancer, time.Duration) proxyServer { return f.systemAPI } + b.newHTTPStreamProxyServer = func(env.ProxyEnvironment, lb.OriginLoadBalancer, time.Duration) proxy.HTTPStreamProxyServer { + return f.httpStream + } + } + return opt, f +} + +// ============================================================================= +// NewProxyBootstrap +// ============================================================================= + +func TestNewProxyBootstrap_DefaultsAllSeams(t *testing.T) { + b := NewProxyBootstrap().(*proxyBootstrap) + + if b.newEnvironment == nil { + t.Error("newEnvironment seam should default to non-nil") + } + if b.newSignalHandler == nil { + t.Error("newSignalHandler seam should default to non-nil") + } + if b.newRedisLoadBalancer == nil { + t.Error("newRedisLoadBalancer seam should default to non-nil") + } + if b.newMemoryLoadBalancer == nil { + t.Error("newMemoryLoadBalancer seam should default to non-nil") + } + if b.newRTMPProxyServer == nil { + t.Error("newRTMPProxyServer seam should default to non-nil") + } + if b.newWebRTCProxyServer == nil { + t.Error("newWebRTCProxyServer seam should default to non-nil") + } + if b.newHTTPAPIProxyServer == nil { + t.Error("newHTTPAPIProxyServer seam should default to non-nil") + } + if b.newSRSSRTProxyServer == nil { + t.Error("newSRSSRTProxyServer seam should default to non-nil") + } + if b.newSystemAPI == nil { + t.Error("newSystemAPI seam should default to non-nil") + } + if b.newHTTPStreamProxyServer == nil { + t.Error("newHTTPStreamProxyServer seam should default to non-nil") + } +} + +func TestNewProxyBootstrap_AppliesOpts(t *testing.T) { + var called bool + NewProxyBootstrap(func(b *proxyBootstrap) { called = true }) + if !called { + t.Fatal("opt was not invoked") + } +} + +// TestNewProxyBootstrap_DefaultsConstructRealInstances exercises every default +// closure that is safe to call in a unit test (i.e. does not touch real +// network/filesystem state). newEnvironment is excluded because env.NewProxyEnvironment +// loads a .env file and mutates process env vars. +func TestNewProxyBootstrap_DefaultsConstructRealInstances(t *testing.T) { + b := NewProxyBootstrap().(*proxyBootstrap) + e := fakeEnvWithDefaults() + loadBalancer := &lbfakes.FakeOriginLoadBalancer{} + + if got := b.newSignalHandler(); got == nil { + t.Error("newSignalHandler default returned nil") + } + if got := b.newRedisLoadBalancer(e); got == nil { + t.Error("newRedisLoadBalancer default returned nil") + } + if got := b.newMemoryLoadBalancer(e); got == nil { + t.Error("newMemoryLoadBalancer default returned nil") + } + if got := b.newRTMPProxyServer(e, loadBalancer); got == nil { + t.Error("newRTMPProxyServer default returned nil") + } + rtc := b.newWebRTCProxyServer(e, loadBalancer) + if rtc == nil { + t.Error("newWebRTCProxyServer default returned nil") + } + if got := b.newHTTPAPIProxyServer(e, time.Second, rtc); got == nil { + t.Error("newHTTPAPIProxyServer default returned nil") + } + if got := b.newSRSSRTProxyServer(e, loadBalancer); got == nil { + t.Error("newSRSSRTProxyServer default returned nil") + } + if got := b.newSystemAPI(e, loadBalancer, time.Second); got == nil { + t.Error("newSystemAPI default returned nil") + } + if got := b.newHTTPStreamProxyServer(e, loadBalancer, time.Second); got == nil { + t.Error("newHTTPStreamProxyServer default returned nil") + } +} + +func TestNewProxyBootstrap_OptCanOverrideSeam(t *testing.T) { + customErr := errors.New("custom") + b := NewProxyBootstrap(func(b *proxyBootstrap) { + b.newEnvironment = func(context.Context) (env.ProxyEnvironment, error) { return nil, customErr } + }).(*proxyBootstrap) + + _, err := b.newEnvironment(context.Background()) + if !errors.Is(err, customErr) { + t.Errorf("custom newEnvironment not applied: %v", err) + } +} + +// ============================================================================= +// initializeLoadBalancer +// ============================================================================= + +func TestInitializeLoadBalancer_Redis(t *testing.T) { + e := fakeEnvWithDefaults() + e.LoadBalancerTypeReturns("redis") + opt, f := withAllFakes(e) + b := NewProxyBootstrap(opt).(*proxyBootstrap) + + got, err := b.initializeLoadBalancer(context.Background(), f.env) + if err != nil { + t.Fatalf("initializeLoadBalancer: %v", err) + } + if got != f.lbRedis { + t.Error("expected the redis load balancer") + } + if f.redisCalls.Load() != 1 { + t.Errorf("newRedisLoadBalancer calls = %d, want 1", f.redisCalls.Load()) + } + if f.memoryCalls.Load() != 0 { + t.Errorf("newMemoryLoadBalancer calls = %d, want 0", f.memoryCalls.Load()) + } + if f.lbRedis.InitializeCallCount() != 1 { + t.Errorf("Initialize calls = %d, want 1", f.lbRedis.InitializeCallCount()) + } +} + +func TestInitializeLoadBalancer_Memory(t *testing.T) { + e := fakeEnvWithDefaults() + e.LoadBalancerTypeReturns("memory") + opt, f := withAllFakes(e) + b := NewProxyBootstrap(opt).(*proxyBootstrap) + + got, err := b.initializeLoadBalancer(context.Background(), f.env) + if err != nil { + t.Fatalf("initializeLoadBalancer: %v", err) + } + if got != f.lbMemory { + t.Error("expected the memory load balancer") + } + if f.memoryCalls.Load() != 1 { + t.Errorf("newMemoryLoadBalancer calls = %d, want 1", f.memoryCalls.Load()) + } + if f.redisCalls.Load() != 0 { + t.Errorf("newRedisLoadBalancer calls = %d, want 0", f.redisCalls.Load()) + } +} + +func TestInitializeLoadBalancer_DefaultBranchUsesMemory(t *testing.T) { + e := fakeEnvWithDefaults() + e.LoadBalancerTypeReturns("anything-else") + opt, f := withAllFakes(e) + b := NewProxyBootstrap(opt).(*proxyBootstrap) + + if _, err := b.initializeLoadBalancer(context.Background(), f.env); err != nil { + t.Fatalf("initializeLoadBalancer: %v", err) + } + if f.memoryCalls.Load() != 1 { + t.Error("unknown LoadBalancerType should fall through to memory") + } +} + +func TestInitializeLoadBalancer_InitializeErrorIsWrapped(t *testing.T) { + initErr := errors.New("boom") + e := fakeEnvWithDefaults() + opt, f := withAllFakes(e) + f.lbMemory.InitializeReturns(initErr) + b := NewProxyBootstrap(opt).(*proxyBootstrap) + + _, err := b.initializeLoadBalancer(context.Background(), f.env) + if err == nil { + t.Fatal("expected an error") + } + if !errors.Is(err, initErr) { + t.Errorf("error chain missing initErr: %v", err) + } +} + +// ============================================================================= +// startServers +// ============================================================================= + +// runStartServersUntilCancel runs startServers in a goroutine, cancels the ctx +// once the test has observed all servers running, and returns the result. +func runStartServersUntilCancel(t *testing.T, b *proxyBootstrap, env env.ProxyEnvironment, lb lb.OriginLoadBalancer) error { + t.Helper() + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan error, 1) + go func() { done <- b.startServers(ctx, env, lb, 50*time.Millisecond) }() + // Give startServers time to invoke all six constructors and block on <-ctx.Done(). + time.Sleep(20 * time.Millisecond) + cancel() + select { + case err := <-done: + return err + case <-time.After(2 * time.Second): + t.Fatal("startServers did not return after ctx cancel") + return nil + } +} + +func TestStartServers_HappyPath_StartsAndClosesAllSix(t *testing.T) { + opt, f := withAllFakes(fakeEnvWithDefaults()) + b := NewProxyBootstrap(opt).(*proxyBootstrap) + + if err := runStartServersUntilCancel(t, b, f.env, f.lbMemory); err != nil { + t.Fatalf("startServers: %v", err) + } + + if got := f.rtmp.RunCallCount(); got != 1 { + t.Errorf("rtmp Run = %d, want 1", got) + } + if got := f.webrtc.RunCallCount(); got != 1 { + t.Errorf("webrtc Run = %d, want 1", got) + } + if got := f.httpAPI.RunCallCount(); got != 1 { + t.Errorf("httpAPI Run = %d, want 1", got) + } + if got := f.srt.runCalls.Load(); got != 1 { + t.Errorf("srt Run = %d, want 1", got) + } + if got := f.systemAPI.runCalls.Load(); got != 1 { + t.Errorf("systemAPI Run = %d, want 1", got) + } + if got := f.httpStream.RunCallCount(); got != 1 { + t.Errorf("httpStream Run = %d, want 1", got) + } + + if got := f.rtmp.CloseCallCount(); got != 1 { + t.Errorf("rtmp Close = %d, want 1", got) + } + if got := f.webrtc.CloseCallCount(); got != 1 { + t.Errorf("webrtc Close = %d, want 1", got) + } + if got := f.httpAPI.CloseCallCount(); got != 1 { + t.Errorf("httpAPI Close = %d, want 1", got) + } + if got := f.srt.closeCalls.Load(); got != 1 { + t.Errorf("srt Close = %d, want 1", got) + } + if got := f.systemAPI.closeCalls.Load(); got != 1 { + t.Errorf("systemAPI Close = %d, want 1", got) + } + if got := f.httpStream.CloseCallCount(); got != 1 { + t.Errorf("httpStream Close = %d, want 1", got) + } +} + +func TestStartServers_HTTPAPIReceivesWebRTCInstance(t *testing.T) { + opt, f := withAllFakes(fakeEnvWithDefaults()) + b := NewProxyBootstrap(opt).(*proxyBootstrap) + + if err := runStartServersUntilCancel(t, b, f.env, f.lbMemory); err != nil { + t.Fatalf("startServers: %v", err) + } + + rtc := f.rtcInHTTPAPI.Load() + if rtc == nil { + t.Fatal("newHTTPAPIProxyServer was not invoked with a WebRTC instance") + } + if rtc.(proxy.WebRTCProxyServer) != f.webrtc { + t.Error("HTTPAPI received a different WebRTC instance than newWebRTCProxyServer returned") + } +} + +func TestStartServers_RunErrorsAreWrappedAndShortCircuit(t *testing.T) { + tests := []struct { + name string + install func(f *bootstrapFakes, err error) + wantWrap string + earlierStarted func(f *bootstrapFakes) bool + }{ + { + name: "rtmp", + install: func(f *bootstrapFakes, err error) { f.rtmp.RunReturns(err) }, + wantWrap: "rtmp server", + earlierStarted: func(f *bootstrapFakes) bool { + return f.webrtc.RunCallCount() == 0 && f.httpAPI.RunCallCount() == 0 + }, + }, + { + name: "webrtc", + install: func(f *bootstrapFakes, err error) { f.webrtc.RunReturns(err) }, + wantWrap: "rtc server", + earlierStarted: func(f *bootstrapFakes) bool { + return f.rtmp.RunCallCount() == 1 && f.httpAPI.RunCallCount() == 0 + }, + }, + { + name: "httpAPI", + install: func(f *bootstrapFakes, err error) { f.httpAPI.RunReturns(err) }, + wantWrap: "http api server", + earlierStarted: func(f *bootstrapFakes) bool { + return f.webrtc.RunCallCount() == 1 && f.srt.runCalls.Load() == 0 + }, + }, + { + name: "srt", + install: func(f *bootstrapFakes, err error) { f.srt.runReturn = err }, + wantWrap: "srt server", + earlierStarted: func(f *bootstrapFakes) bool { + return f.httpAPI.RunCallCount() == 1 && f.systemAPI.runCalls.Load() == 0 + }, + }, + { + name: "systemAPI", + install: func(f *bootstrapFakes, err error) { f.systemAPI.runReturn = err }, + wantWrap: "system api server", + earlierStarted: func(f *bootstrapFakes) bool { + return f.srt.runCalls.Load() == 1 && f.httpStream.RunCallCount() == 0 + }, + }, + { + name: "httpStream", + install: func(f *bootstrapFakes, err error) { f.httpStream.RunReturns(err) }, + wantWrap: "http server", + earlierStarted: func(f *bootstrapFakes) bool { + return f.systemAPI.runCalls.Load() == 1 + }, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + runErr := errors.New("boom-" + tc.name) + opt, f := withAllFakes(fakeEnvWithDefaults()) + tc.install(f, runErr) + b := NewProxyBootstrap(opt).(*proxyBootstrap) + + err := b.startServers(context.Background(), f.env, f.lbMemory, 50*time.Millisecond) + if err == nil { + t.Fatalf("%s: expected error", tc.name) + } + if !errors.Is(err, runErr) { + t.Errorf("%s: error chain missing runErr: %v", tc.name, err) + } + if !contains(err.Error(), tc.wantWrap) { + t.Errorf("%s: error %q does not contain wrap %q", tc.name, err.Error(), tc.wantWrap) + } + if !tc.earlierStarted(f) { + t.Errorf("%s: short-circuit invariant violated", tc.name) + } + }) + } +} + +// contains is a tiny helper so the table-driven test doesn't pull in strings +// just for substring matching. +func contains(haystack, needle string) bool { + for i := 0; i+len(needle) <= len(haystack); i++ { + if haystack[i:i+len(needle)] == needle { + return true + } + } + return false +} + +// ============================================================================= +// run +// ============================================================================= + +func TestRun_NewEnvironmentErrorIsWrapped(t *testing.T) { + envErr := errors.New("env-boom") + opt, _ := withAllFakes(fakeEnvWithDefaults()) + b := NewProxyBootstrap(opt, func(b *proxyBootstrap) { + b.newEnvironment = func(context.Context) (env.ProxyEnvironment, error) { return nil, envErr } + }).(*proxyBootstrap) + + err := b.run(context.Background()) + if err == nil { + t.Fatal("expected error") + } + if !errors.Is(err, envErr) { + t.Errorf("error chain missing envErr: %v", err) + } + if !contains(err.Error(), "create environment") { + t.Errorf("expected wrap %q, got %q", "create environment", err.Error()) + } +} + +func TestRun_InstallForceQuitErrorIsWrapped(t *testing.T) { + fqErr := errors.New("force-quit-boom") + opt, f := withAllFakes(fakeEnvWithDefaults()) + f.signal.installForceQuitReturn = fqErr + b := NewProxyBootstrap(opt).(*proxyBootstrap) + + err := b.run(context.Background()) + if err == nil { + t.Fatal("expected error") + } + if !errors.Is(err, fqErr) { + t.Errorf("error chain missing fqErr: %v", err) + } + if !contains(err.Error(), "install force quit") { + t.Errorf("expected wrap %q, got %q", "install force quit", err.Error()) + } +} + +func TestRun_BadGraceQuitDurationIsWrapped(t *testing.T) { + e := fakeEnvWithDefaults() + e.GraceQuitTimeoutReturns("not-a-duration") + opt, _ := withAllFakes(e) + b := NewProxyBootstrap(opt).(*proxyBootstrap) + + err := b.run(context.Background()) + if err == nil { + t.Fatal("expected error") + } + if !contains(err.Error(), "parse gracefully quit timeout") { + t.Errorf("expected wrap %q, got %q", "parse gracefully quit timeout", err.Error()) + } +} + +func TestRun_LoadBalancerInitializeErrorIsWrapped(t *testing.T) { + initErr := errors.New("init-boom") + opt, f := withAllFakes(fakeEnvWithDefaults()) + f.lbMemory.InitializeReturns(initErr) + b := NewProxyBootstrap(opt).(*proxyBootstrap) + + err := b.run(context.Background()) + if err == nil { + t.Fatal("expected error") + } + if !errors.Is(err, initErr) { + t.Errorf("error chain missing initErr: %v", err) + } + if !contains(err.Error(), "initialize srs load balancer") { + t.Errorf("expected wrap %q, got %q", "initialize srs load balancer", err.Error()) + } +} + +func TestRun_HappyPath_BlocksUntilCancelThenReturnsNil(t *testing.T) { + opt, _ := withAllFakes(fakeEnvWithDefaults()) + b := NewProxyBootstrap(opt).(*proxyBootstrap) + + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan error, 1) + go func() { done <- b.run(ctx) }() + time.Sleep(20 * time.Millisecond) + cancel() + select { + case err := <-done: + if err != nil { + t.Errorf("run: %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("run did not return after ctx cancel") + } +} + +// ============================================================================= +// Start +// ============================================================================= + +func TestStart_HappyPath_InstallsSignalsAndReturnsNil(t *testing.T) { + opt, f := withAllFakes(fakeEnvWithDefaults()) + f.signal.installSignalsCancels = true // cancel the inner ctx immediately + b := NewProxyBootstrap(opt) + + err := b.Start(context.Background()) + if err != nil { + t.Fatalf("Start: %v", err) + } + if f.signal.installSignalsCalls.Load() != 1 { + t.Errorf("InstallSignals calls = %d, want 1", f.signal.installSignalsCalls.Load()) + } + if f.signal.installForceQuitCalls.Load() != 1 { + t.Errorf("InstallForceQuit calls = %d, want 1", f.signal.installForceQuitCalls.Load()) + } +} + +func TestStart_PropagatesNonCancelError(t *testing.T) { + envErr := errors.New("env-boom") + opt, _ := withAllFakes(fakeEnvWithDefaults()) + b := NewProxyBootstrap(opt, func(b *proxyBootstrap) { + b.newEnvironment = func(context.Context) (env.ProxyEnvironment, error) { return nil, envErr } + }) + + err := b.Start(context.Background()) + if err == nil { + t.Fatal("expected error") + } + if !errors.Is(err, envErr) { + t.Errorf("error chain missing envErr: %v", err) + } +} + +func TestStart_AbsorbsErrorWhenContextCancelled(t *testing.T) { + // When InstallSignals cancels the inner ctx and run returns an error, Start + // should swallow the error (treating it as a graceful shutdown). + envErr := errors.New("post-cancel-boom") + opt, f := withAllFakes(fakeEnvWithDefaults()) + f.signal.installSignalsCancels = true + b := NewProxyBootstrap(opt, func(b *proxyBootstrap) { + b.newEnvironment = func(context.Context) (env.ProxyEnvironment, error) { return nil, envErr } + }) + + err := b.Start(context.Background()) + if err != nil { + t.Errorf("Start should swallow error after ctx cancel, got: %v", err) + } +}