diff --git a/.claude/memory b/.claude/memory index 1c128f07d..c45bdff2b 120000 --- a/.claude/memory +++ b/.claude/memory @@ -1 +1 @@ -../.openclaw/memory \ No newline at end of file +../memory \ No newline at end of file diff --git a/.claude/skills b/.claude/skills index a8b71e9a5..42c5394a1 120000 --- a/.claude/skills +++ b/.claude/skills @@ -1 +1 @@ -../.openclaw/skills \ No newline at end of file +../skills \ No newline at end of file diff --git a/.codex/memory b/.codex/memory index 1c128f07d..c45bdff2b 120000 --- a/.codex/memory +++ b/.codex/memory @@ -1 +1 @@ -../.openclaw/memory \ No newline at end of file +../memory \ No newline at end of file diff --git a/.codex/skills b/.codex/skills index a8b71e9a5..42c5394a1 120000 --- a/.codex/skills +++ b/.codex/skills @@ -1 +1 @@ -../.openclaw/skills \ No newline at end of file +../skills \ No newline at end of file diff --git a/.gitignore b/.gitignore index 33e3cd817..0a30c5b80 100644 --- a/.gitignore +++ b/.gitignore @@ -47,4 +47,9 @@ cmake-build-debug # For AI /*personal* +/support* +/*srs-consults* /*workspace* +/skills/llm-switcher +/skills/*workspace* +/memory/202*.md diff --git a/.kiro/memory b/.kiro/memory index 1c128f07d..c45bdff2b 120000 --- a/.kiro/memory +++ b/.kiro/memory @@ -1 +1 @@ -../.openclaw/memory \ No newline at end of file +../memory \ No newline at end of file diff --git a/.kiro/skills b/.kiro/skills index a8b71e9a5..42c5394a1 120000 --- a/.kiro/skills +++ b/.kiro/skills @@ -1 +1 @@ -../.openclaw/skills \ No newline at end of file +../skills \ No newline at end of file diff --git a/.openclaw/memory b/.openclaw/memory new file mode 120000 index 000000000..c45bdff2b --- /dev/null +++ b/.openclaw/memory @@ -0,0 +1 @@ +../memory \ No newline at end of file diff --git a/.openclaw/skills b/.openclaw/skills new file mode 120000 index 000000000..42c5394a1 --- /dev/null +++ b/.openclaw/skills @@ -0,0 +1 @@ +../skills \ No newline at end of file diff --git a/README.md b/README.md index f4d0b4401..b8106c942 100755 --- a/README.md +++ b/README.md @@ -202,6 +202,3 @@ Please read [MIRRORS](trunk/doc/Resources.md#mirrors). Please read [DOCKERS](trunk/doc/Dockers.md). -Beijing, 2013.10
-Winlin - diff --git a/docs/proxy/proxy-load-balancer.md b/docs/proxy/proxy-load-balancer.md index 4ec3c9d66..a0266af78 100644 --- a/docs/proxy/proxy-load-balancer.md +++ b/docs/proxy/proxy-load-balancer.md @@ -53,14 +53,14 @@ Both implementations maintain stream-to-server mappings to ensure stream consist The load balancer uses a clean interface-based architecture: -**Core Interface**: `SRSLoadBalancer` +**Core Interface**: `OriginLoadBalancer` - Initialization and lifecycle management - Server registration and updates - Stream routing (Pick operation) - Protocol-specific state management (HLS, WebRTC) **Data Models**: -- `SRSServer`: Backend origin server representation +- `OriginServer`: Backend origin server representation - `HLSPlayStream`: Interface for HLS streaming sessions - `RTCConnection`: Interface for WebRTC connections diff --git a/internal/bootstrap/proxy.go b/internal/bootstrap/proxy.go index f59522b6d..5d7e09782 100644 --- a/internal/bootstrap/proxy.go +++ b/internal/bootstrap/proxy.go @@ -12,18 +12,132 @@ import ( "srsx/internal/errors" "srsx/internal/lb" "srsx/internal/logger" - "srsx/internal/server" + "srsx/internal/proxy" "srsx/internal/signal" "srsx/internal/version" ) // NewProxyBootstrap creates a new Bootstrap instance for the proxy server. -func NewProxyBootstrap() Bootstrap { - return &proxyBootstrap{} +func NewProxyBootstrap(opts ...func(*proxyBootstrap)) Bootstrap { + v := &proxyBootstrap{} + + // Default newEnvironment: read the real process env / .env file. + v.newEnvironment = func(ctx context.Context) (env.ProxyEnvironment, error) { + return env.NewProxyEnvironment(ctx) + } + // Default newSignalHandler: construct a real OS signal handler. + v.newSignalHandler = func() signalHandler { + return signal.NewHandler() + } + // Default newRedisLoadBalancer: construct a real Redis-backed load balancer. + v.newRedisLoadBalancer = func(environment env.ProxyEnvironment) lb.OriginLoadBalancer { + return lb.NewRedisLoadBalancer(environment) + } + // Default newMemoryLoadBalancer: construct a real in-memory load balancer. + v.newMemoryLoadBalancer = func(environment env.ProxyEnvironment) lb.OriginLoadBalancer { + return lb.NewMemoryLoadBalancer(environment) + } + // Default newRTMPProxyServer: construct a real RTMP proxy server. + v.newRTMPProxyServer = func(environment env.ProxyEnvironment, loadBalancer lb.OriginLoadBalancer) proxy.RTMPProxyServer { + return proxy.NewRTMPProxyServer(environment, loadBalancer) + } + // Default newWebRTCProxyServer: construct a real WebRTC proxy server. + v.newWebRTCProxyServer = func(environment env.ProxyEnvironment, loadBalancer lb.OriginLoadBalancer) proxy.WebRTCProxyServer { + return proxy.NewWebRTCProxyServer(environment, loadBalancer) + } + // Default newHTTPAPIProxyServer: construct a real HTTP API proxy server. + v.newHTTPAPIProxyServer = func(environment env.ProxyEnvironment, gracefulQuitTimeout time.Duration, rtc proxy.WebRTCProxyServer) proxy.HTTPAPIProxyServer { + return proxy.NewHTTPAPIProxyServer(environment, gracefulQuitTimeout, rtc) + } + // Default newSRSSRTProxyServer: construct a real SRT proxy server. + v.newSRSSRTProxyServer = func(environment env.ProxyEnvironment, loadBalancer lb.OriginLoadBalancer) proxyServer { + return proxy.NewSRSSRTProxyServer(environment, loadBalancer) + } + // Default newSystemAPI: construct a real system API server. + v.newSystemAPI = func(environment env.ProxyEnvironment, loadBalancer lb.OriginLoadBalancer, gracefulQuitTimeout time.Duration) proxyServer { + return proxy.NewSystemAPI(environment, loadBalancer, gracefulQuitTimeout) + } + // Default newHTTPStreamProxyServer: construct a real HTTP stream proxy server. + v.newHTTPStreamProxyServer = func(environment env.ProxyEnvironment, loadBalancer lb.OriginLoadBalancer, gracefulQuitTimeout time.Duration) proxy.HTTPStreamProxyServer { + return proxy.NewHTTPStreamProxyServer(environment, loadBalancer, gracefulQuitTimeout) + } + + for _, opt := range opts { + opt(v) + } + return v } // proxyBootstrap implements the Bootstrap interface for the proxy server. -type proxyBootstrap struct{} +type proxyBootstrap struct { + // newEnvironment constructs the proxy environment. Defaults to + // env.NewProxyEnvironment; tests may override via a functional option to + // supply a fake environment without reading the real process env or .env file. + newEnvironment func(ctx context.Context) (env.ProxyEnvironment, error) + // newSignalHandler constructs the OS signal handler used to install + // signal listeners and the force-quit timer. Defaults to signal.NewHandler; + // tests may override via a functional option to supply a fake handler that + // does not install real OS signal handlers or a real force-quit timer. + newSignalHandler func() signalHandler + // newRedisLoadBalancer constructs the Redis-backed load balancer used when + // environment.LoadBalancerType() == "redis". Defaults to lb.NewRedisLoadBalancer; + // tests may override via a functional option to supply a fake load balancer + // that does not connect to a real Redis instance. + newRedisLoadBalancer func(environment env.ProxyEnvironment) lb.OriginLoadBalancer + // newMemoryLoadBalancer constructs the in-memory load balancer used when + // environment.LoadBalancerType() is anything other than "redis". Defaults to + // lb.NewMemoryLoadBalancer; tests may override via a functional option to + // supply a fake load balancer for assertions on the default branch. + newMemoryLoadBalancer func(environment env.ProxyEnvironment) lb.OriginLoadBalancer + // newRTMPProxyServer constructs the RTMP proxy server. Defaults to + // proxy.NewRTMPProxyServer; tests may override via a functional option to + // supply a fake (e.g. proxyfakes.FakeRTMPProxyServer) that does not bind a + // real TCP port. + newRTMPProxyServer func(environment env.ProxyEnvironment, loadBalancer lb.OriginLoadBalancer) proxy.RTMPProxyServer + // newWebRTCProxyServer constructs the WebRTC proxy server. Defaults to + // proxy.NewWebRTCProxyServer; tests may override via a functional option to + // supply a fake (e.g. proxyfakes.FakeWebRTCProxyServer) that does not bind + // a real UDP port. + newWebRTCProxyServer func(environment env.ProxyEnvironment, loadBalancer lb.OriginLoadBalancer) proxy.WebRTCProxyServer + // newHTTPAPIProxyServer constructs the HTTP API proxy server. Defaults to + // proxy.NewHTTPAPIProxyServer; tests may override via a functional option + // to supply a fake (e.g. proxyfakes.FakeHTTPAPIProxyServer) that does not + // bind a real HTTP port. + newHTTPAPIProxyServer func(environment env.ProxyEnvironment, gracefulQuitTimeout time.Duration, rtc proxy.WebRTCProxyServer) proxy.HTTPAPIProxyServer + // newSRSSRTProxyServer constructs the SRT proxy server. Defaults to + // proxy.NewSRSSRTProxyServer; tests may override via a functional option + // to supply a fake that does not bind a real UDP port. Returned as the + // local proxyServer interface because proxy.NewSRSSRTProxyServer currently + // returns an unexported concrete type. + newSRSSRTProxyServer func(environment env.ProxyEnvironment, loadBalancer lb.OriginLoadBalancer) proxyServer + // newSystemAPI constructs the system API server. Defaults to proxy.NewSystemAPI; + // tests may override via a functional option to supply a fake that does not + // bind a real HTTP port. Returned as the local proxyServer interface because + // proxy.NewSystemAPI currently returns an unexported concrete type. + newSystemAPI func(environment env.ProxyEnvironment, loadBalancer lb.OriginLoadBalancer, gracefulQuitTimeout time.Duration) proxyServer + // newHTTPStreamProxyServer constructs the HTTP stream proxy server. Defaults + // to proxy.NewHTTPStreamProxyServer; tests may override via a functional + // option to supply a fake (e.g. proxyfakes.FakeHTTPStreamProxyServer) that + // does not bind a real HTTP port. + newHTTPStreamProxyServer func(environment env.ProxyEnvironment, loadBalancer lb.OriginLoadBalancer, gracefulQuitTimeout time.Duration) proxy.HTTPStreamProxyServer +} + +// signalHandler is the minimal contract of a signal handler that proxyBootstrap +// drives. *signal.Handler satisfies it. Tests may supply a fake that does not +// install real OS signal handlers or a real force-quit timer. +type signalHandler interface { + InstallSignals(ctx context.Context, cancel context.CancelFunc) + InstallForceQuit(ctx context.Context, environment env.ProxyEnvironment) error +} + +// proxyServer is the minimal Run/Close contract used by proxyBootstrap for the +// SRT proxy and system API. proxy.NewSRSSRTProxyServer and proxy.NewSystemAPI +// currently return unexported concrete types which bootstrap cannot name; their +// values satisfy this interface structurally so tests can still inject fakes. +type proxyServer interface { + Run(ctx context.Context) error + Close() error +} // Start initializes the context with logger and signal handlers, then runs the bootstrap. // Returns any error encountered during startup. @@ -33,7 +147,7 @@ func (b *proxyBootstrap) Start(ctx context.Context) error { // Install signals. ctx, cancel := context.WithCancel(ctx) - signal.InstallSignals(ctx, cancel) + b.newSignalHandler().InstallSignals(ctx, cancel) // Run the main loop, ignore the user cancel error. err := b.run(ctx) @@ -50,7 +164,7 @@ func (b *proxyBootstrap) Start(ctx context.Context) error { // It blocks until the context is cancelled. func (b *proxyBootstrap) run(ctx context.Context) error { // Setup the environment variables. - environment, err := env.NewProxyEnvironment(ctx) + environment, err := b.newEnvironment(ctx) if err != nil { return errors.Wrapf(err, "create environment") } @@ -58,15 +172,16 @@ func (b *proxyBootstrap) run(ctx context.Context) error { // When cancelled, the program is forced to exit due to a timeout. Normally, this doesn't occur // because the main thread exits after the context is cancelled. However, sometimes the main thread // may be blocked for some reason, so a forced exit is necessary to ensure the program terminates. - if err := signal.InstallForceQuit(ctx, environment); err != nil { + if err := b.newSignalHandler().InstallForceQuit(ctx, environment); err != nil { return errors.Wrapf(err, "install force quit") } // Start the Go pprof if enabled. debug.HandleGoPprof(ctx, environment) - // Initialize the load balancer. - if err := b.initializeLoadBalancer(ctx, environment); err != nil { + // Create and initialize the load balancer. + loadBalancer, err := b.initializeLoadBalancer(ctx, environment) + if err != nil { return err } @@ -77,68 +192,69 @@ func (b *proxyBootstrap) run(ctx context.Context) error { } // Start all servers and block until context is cancelled. - return b.startServers(ctx, environment, gracefulQuitTimeout) + return b.startServers(ctx, environment, loadBalancer, gracefulQuitTimeout) } // initializeLoadBalancer sets up the load balancer based on configuration. -func (b *proxyBootstrap) initializeLoadBalancer(ctx context.Context, environment env.ProxyEnvironment) error { +func (b *proxyBootstrap) initializeLoadBalancer(ctx context.Context, environment env.ProxyEnvironment) (lb.OriginLoadBalancer, error) { + var loadBalancer lb.OriginLoadBalancer switch environment.LoadBalancerType() { case "redis": - lb.SrsLoadBalancer = lb.NewRedisLoadBalancer(environment) + loadBalancer = b.newRedisLoadBalancer(environment) default: - lb.SrsLoadBalancer = lb.NewMemoryLoadBalancer(environment) + loadBalancer = b.newMemoryLoadBalancer(environment) } - if err := lb.SrsLoadBalancer.Initialize(ctx); err != nil { - return errors.Wrapf(err, "initialize srs load balancer") + if err := loadBalancer.Initialize(ctx); err != nil { + return nil, errors.Wrapf(err, "initialize srs load balancer") } - return nil + return loadBalancer, nil } // startServers initializes and starts all protocol servers. -func (b *proxyBootstrap) startServers(ctx context.Context, environment env.ProxyEnvironment, gracefulQuitTimeout time.Duration) error { +func (b *proxyBootstrap) startServers(ctx context.Context, environment env.ProxyEnvironment, loadBalancer lb.OriginLoadBalancer, gracefulQuitTimeout time.Duration) error { // Start the RTMP server. - rtmpServer := server.NewRTMPServer(environment) - if err := rtmpServer.Run(ctx); err != nil { + rtmpProxyServer := b.newRTMPProxyServer(environment, loadBalancer) + if err := rtmpProxyServer.Run(ctx); err != nil { return errors.Wrapf(err, "rtmp server") } - defer rtmpServer.Close() + defer rtmpProxyServer.Close() // Start the WebRTC server. - webRTCServer := server.NewWebRTCServer(environment) - if err := webRTCServer.Run(ctx); err != nil { + webRTCProxyServer := b.newWebRTCProxyServer(environment, loadBalancer) + if err := webRTCProxyServer.Run(ctx); err != nil { return errors.Wrapf(err, "rtc server") } - defer webRTCServer.Close() + defer webRTCProxyServer.Close() // Start the HTTP API server. - httpAPIServer := server.NewHTTPAPIServer(environment, gracefulQuitTimeout, webRTCServer) - if err := httpAPIServer.Run(ctx); err != nil { + httpAPIProxyServer := b.newHTTPAPIProxyServer(environment, gracefulQuitTimeout, webRTCProxyServer) + if err := httpAPIProxyServer.Run(ctx); err != nil { return errors.Wrapf(err, "http api server") } - defer httpAPIServer.Close() + defer httpAPIProxyServer.Close() // Start the SRT server. - srsSRTServer := server.NewSRSSRTServer(environment) - if err := srsSRTServer.Run(ctx); err != nil { + srsSRTProxyServer := b.newSRSSRTProxyServer(environment, loadBalancer) + if err := srsSRTProxyServer.Run(ctx); err != nil { return errors.Wrapf(err, "srt server") } - defer srsSRTServer.Close() + defer srsSRTProxyServer.Close() // Start the System API server. - systemAPI := server.NewSystemAPI(environment, gracefulQuitTimeout) + systemAPI := b.newSystemAPI(environment, loadBalancer, gracefulQuitTimeout) if err := systemAPI.Run(ctx); err != nil { return errors.Wrapf(err, "system api server") } defer systemAPI.Close() // Start the HTTP web server. - httpStreamServer := server.NewHTTPStreamServer(environment, gracefulQuitTimeout) - if err := httpStreamServer.Run(ctx); err != nil { + httpStreamProxyServer := b.newHTTPStreamProxyServer(environment, loadBalancer, gracefulQuitTimeout) + if err := httpStreamProxyServer.Run(ctx); err != nil { return errors.Wrapf(err, "http server") } - defer httpStreamServer.Close() + defer httpStreamProxyServer.Close() // Wait for the main loop to quit. <-ctx.Done() diff --git a/internal/bootstrap/proxy_test.go b/internal/bootstrap/proxy_test.go new file mode 100644 index 000000000..d47550ade --- /dev/null +++ b/internal/bootstrap/proxy_test.go @@ -0,0 +1,643 @@ +// Copyright (c) 2026 Winlin +// +// SPDX-License-Identifier: MIT +package bootstrap + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + "srsx/internal/env" + "srsx/internal/env/envfakes" + "srsx/internal/lb" + "srsx/internal/lb/lbfakes" + "srsx/internal/proxy" + "srsx/internal/proxy/proxyfakes" +) + +// ============================================================================= +// Local fakes +// ============================================================================= + +// fakeSignalHandler implements signalHandler without touching real OS signals. +// InstallSignalsCancels, when true, cancels the supplied cancel func immediately +// so callers can drive the run/Start "ctx already cancelled" branch. +type fakeSignalHandler struct { + installSignalsCalls atomic.Int32 + installForceQuitCalls atomic.Int32 + installForceQuitReturn error + installSignalsCancels bool + lastInstallSignalsCtx context.Context + lastInstallForceQuitCtx context.Context +} + +func (f *fakeSignalHandler) InstallSignals(ctx context.Context, cancel context.CancelFunc) { + f.installSignalsCalls.Add(1) + f.lastInstallSignalsCtx = ctx + if f.installSignalsCancels { + cancel() + } +} + +func (f *fakeSignalHandler) InstallForceQuit(ctx context.Context, environment env.ProxyEnvironment) error { + f.installForceQuitCalls.Add(1) + f.lastInstallForceQuitCtx = ctx + return f.installForceQuitReturn +} + +// fakeProxyServer implements the local proxyServer interface for the SRT proxy +// and system API seams. +type fakeProxyServer struct { + runCalls atomic.Int32 + closeCalls atomic.Int32 + runReturn error + closeReturn error + lastRunCtx context.Context +} + +func (f *fakeProxyServer) Run(ctx context.Context) error { + f.runCalls.Add(1) + f.lastRunCtx = ctx + return f.runReturn +} + +func (f *fakeProxyServer) Close() error { + f.closeCalls.Add(1) + return f.closeReturn +} + +// ============================================================================= +// Helpers +// ============================================================================= + +// fakeEnvWithDefaults returns a FakeProxyEnvironment with reasonable defaults +// so run() can reach all stages without being short-circuited by a parse error. +func fakeEnvWithDefaults() *envfakes.FakeProxyEnvironment { + e := &envfakes.FakeProxyEnvironment{} + e.LoadBalancerTypeReturns("memory") + e.GraceQuitTimeoutReturns("1s") + e.ForceQuitTimeoutReturns("1s") + return e +} + +// bootstrapFakes bundles the fakes installed by withAllFakes for assertions. +type bootstrapFakes struct { + env *envfakes.FakeProxyEnvironment + signal *fakeSignalHandler + lbMemory *lbfakes.FakeOriginLoadBalancer + lbRedis *lbfakes.FakeOriginLoadBalancer + rtmp *proxyfakes.FakeRTMPProxyServer + webrtc *proxyfakes.FakeWebRTCProxyServer + httpAPI *proxyfakes.FakeHTTPAPIProxyServer + srt *fakeProxyServer + systemAPI *fakeProxyServer + httpStream *proxyfakes.FakeHTTPStreamProxyServer + memoryCalls atomic.Int32 + redisCalls atomic.Int32 + rtcInHTTPAPI atomic.Value // proxy.WebRTCProxyServer instance passed to newHTTPAPIProxyServer +} + +// withAllFakes returns a functional option that swaps every seam for a fake. +// The returned bootstrapFakes lets tests inspect calls and arguments. +func withAllFakes(e *envfakes.FakeProxyEnvironment) (func(*proxyBootstrap), *bootstrapFakes) { + f := &bootstrapFakes{ + env: e, + signal: &fakeSignalHandler{}, + lbMemory: &lbfakes.FakeOriginLoadBalancer{}, + lbRedis: &lbfakes.FakeOriginLoadBalancer{}, + rtmp: &proxyfakes.FakeRTMPProxyServer{}, + webrtc: &proxyfakes.FakeWebRTCProxyServer{}, + httpAPI: &proxyfakes.FakeHTTPAPIProxyServer{}, + srt: &fakeProxyServer{}, + systemAPI: &fakeProxyServer{}, + httpStream: &proxyfakes.FakeHTTPStreamProxyServer{}, + } + opt := func(b *proxyBootstrap) { + b.newEnvironment = func(context.Context) (env.ProxyEnvironment, error) { return f.env, nil } + b.newSignalHandler = func() signalHandler { return f.signal } + b.newRedisLoadBalancer = func(env.ProxyEnvironment) lb.OriginLoadBalancer { + f.redisCalls.Add(1) + return f.lbRedis + } + b.newMemoryLoadBalancer = func(env.ProxyEnvironment) lb.OriginLoadBalancer { + f.memoryCalls.Add(1) + return f.lbMemory + } + b.newRTMPProxyServer = func(env.ProxyEnvironment, lb.OriginLoadBalancer) proxy.RTMPProxyServer { return f.rtmp } + b.newWebRTCProxyServer = func(env.ProxyEnvironment, lb.OriginLoadBalancer) proxy.WebRTCProxyServer { return f.webrtc } + b.newHTTPAPIProxyServer = func(_ env.ProxyEnvironment, _ time.Duration, rtc proxy.WebRTCProxyServer) proxy.HTTPAPIProxyServer { + f.rtcInHTTPAPI.Store(rtc) + return f.httpAPI + } + b.newSRSSRTProxyServer = func(env.ProxyEnvironment, lb.OriginLoadBalancer) proxyServer { return f.srt } + b.newSystemAPI = func(env.ProxyEnvironment, lb.OriginLoadBalancer, time.Duration) proxyServer { return f.systemAPI } + b.newHTTPStreamProxyServer = func(env.ProxyEnvironment, lb.OriginLoadBalancer, time.Duration) proxy.HTTPStreamProxyServer { + return f.httpStream + } + } + return opt, f +} + +// ============================================================================= +// NewProxyBootstrap +// ============================================================================= + +func TestNewProxyBootstrap_DefaultsAllSeams(t *testing.T) { + b := NewProxyBootstrap().(*proxyBootstrap) + + if b.newEnvironment == nil { + t.Error("newEnvironment seam should default to non-nil") + } + if b.newSignalHandler == nil { + t.Error("newSignalHandler seam should default to non-nil") + } + if b.newRedisLoadBalancer == nil { + t.Error("newRedisLoadBalancer seam should default to non-nil") + } + if b.newMemoryLoadBalancer == nil { + t.Error("newMemoryLoadBalancer seam should default to non-nil") + } + if b.newRTMPProxyServer == nil { + t.Error("newRTMPProxyServer seam should default to non-nil") + } + if b.newWebRTCProxyServer == nil { + t.Error("newWebRTCProxyServer seam should default to non-nil") + } + if b.newHTTPAPIProxyServer == nil { + t.Error("newHTTPAPIProxyServer seam should default to non-nil") + } + if b.newSRSSRTProxyServer == nil { + t.Error("newSRSSRTProxyServer seam should default to non-nil") + } + if b.newSystemAPI == nil { + t.Error("newSystemAPI seam should default to non-nil") + } + if b.newHTTPStreamProxyServer == nil { + t.Error("newHTTPStreamProxyServer seam should default to non-nil") + } +} + +func TestNewProxyBootstrap_AppliesOpts(t *testing.T) { + var called bool + NewProxyBootstrap(func(b *proxyBootstrap) { called = true }) + if !called { + t.Fatal("opt was not invoked") + } +} + +// TestNewProxyBootstrap_DefaultsConstructRealInstances exercises every default +// closure that is safe to call in a unit test (i.e. does not touch real +// network/filesystem state). newEnvironment is excluded because env.NewProxyEnvironment +// loads a .env file and mutates process env vars. +func TestNewProxyBootstrap_DefaultsConstructRealInstances(t *testing.T) { + b := NewProxyBootstrap().(*proxyBootstrap) + e := fakeEnvWithDefaults() + loadBalancer := &lbfakes.FakeOriginLoadBalancer{} + + if got := b.newSignalHandler(); got == nil { + t.Error("newSignalHandler default returned nil") + } + if got := b.newRedisLoadBalancer(e); got == nil { + t.Error("newRedisLoadBalancer default returned nil") + } + if got := b.newMemoryLoadBalancer(e); got == nil { + t.Error("newMemoryLoadBalancer default returned nil") + } + if got := b.newRTMPProxyServer(e, loadBalancer); got == nil { + t.Error("newRTMPProxyServer default returned nil") + } + rtc := b.newWebRTCProxyServer(e, loadBalancer) + if rtc == nil { + t.Error("newWebRTCProxyServer default returned nil") + } + if got := b.newHTTPAPIProxyServer(e, time.Second, rtc); got == nil { + t.Error("newHTTPAPIProxyServer default returned nil") + } + if got := b.newSRSSRTProxyServer(e, loadBalancer); got == nil { + t.Error("newSRSSRTProxyServer default returned nil") + } + if got := b.newSystemAPI(e, loadBalancer, time.Second); got == nil { + t.Error("newSystemAPI default returned nil") + } + if got := b.newHTTPStreamProxyServer(e, loadBalancer, time.Second); got == nil { + t.Error("newHTTPStreamProxyServer default returned nil") + } +} + +func TestNewProxyBootstrap_OptCanOverrideSeam(t *testing.T) { + customErr := errors.New("custom") + b := NewProxyBootstrap(func(b *proxyBootstrap) { + b.newEnvironment = func(context.Context) (env.ProxyEnvironment, error) { return nil, customErr } + }).(*proxyBootstrap) + + _, err := b.newEnvironment(context.Background()) + if !errors.Is(err, customErr) { + t.Errorf("custom newEnvironment not applied: %v", err) + } +} + +// ============================================================================= +// initializeLoadBalancer +// ============================================================================= + +func TestInitializeLoadBalancer_Redis(t *testing.T) { + e := fakeEnvWithDefaults() + e.LoadBalancerTypeReturns("redis") + opt, f := withAllFakes(e) + b := NewProxyBootstrap(opt).(*proxyBootstrap) + + got, err := b.initializeLoadBalancer(context.Background(), f.env) + if err != nil { + t.Fatalf("initializeLoadBalancer: %v", err) + } + if got != f.lbRedis { + t.Error("expected the redis load balancer") + } + if f.redisCalls.Load() != 1 { + t.Errorf("newRedisLoadBalancer calls = %d, want 1", f.redisCalls.Load()) + } + if f.memoryCalls.Load() != 0 { + t.Errorf("newMemoryLoadBalancer calls = %d, want 0", f.memoryCalls.Load()) + } + if f.lbRedis.InitializeCallCount() != 1 { + t.Errorf("Initialize calls = %d, want 1", f.lbRedis.InitializeCallCount()) + } +} + +func TestInitializeLoadBalancer_Memory(t *testing.T) { + e := fakeEnvWithDefaults() + e.LoadBalancerTypeReturns("memory") + opt, f := withAllFakes(e) + b := NewProxyBootstrap(opt).(*proxyBootstrap) + + got, err := b.initializeLoadBalancer(context.Background(), f.env) + if err != nil { + t.Fatalf("initializeLoadBalancer: %v", err) + } + if got != f.lbMemory { + t.Error("expected the memory load balancer") + } + if f.memoryCalls.Load() != 1 { + t.Errorf("newMemoryLoadBalancer calls = %d, want 1", f.memoryCalls.Load()) + } + if f.redisCalls.Load() != 0 { + t.Errorf("newRedisLoadBalancer calls = %d, want 0", f.redisCalls.Load()) + } +} + +func TestInitializeLoadBalancer_DefaultBranchUsesMemory(t *testing.T) { + e := fakeEnvWithDefaults() + e.LoadBalancerTypeReturns("anything-else") + opt, f := withAllFakes(e) + b := NewProxyBootstrap(opt).(*proxyBootstrap) + + if _, err := b.initializeLoadBalancer(context.Background(), f.env); err != nil { + t.Fatalf("initializeLoadBalancer: %v", err) + } + if f.memoryCalls.Load() != 1 { + t.Error("unknown LoadBalancerType should fall through to memory") + } +} + +func TestInitializeLoadBalancer_InitializeErrorIsWrapped(t *testing.T) { + initErr := errors.New("boom") + e := fakeEnvWithDefaults() + opt, f := withAllFakes(e) + f.lbMemory.InitializeReturns(initErr) + b := NewProxyBootstrap(opt).(*proxyBootstrap) + + _, err := b.initializeLoadBalancer(context.Background(), f.env) + if err == nil { + t.Fatal("expected an error") + } + if !errors.Is(err, initErr) { + t.Errorf("error chain missing initErr: %v", err) + } +} + +// ============================================================================= +// startServers +// ============================================================================= + +// runStartServersUntilCancel runs startServers in a goroutine, cancels the ctx +// once the test has observed all servers running, and returns the result. +func runStartServersUntilCancel(t *testing.T, b *proxyBootstrap, env env.ProxyEnvironment, lb lb.OriginLoadBalancer) error { + t.Helper() + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan error, 1) + go func() { done <- b.startServers(ctx, env, lb, 50*time.Millisecond) }() + // Give startServers time to invoke all six constructors and block on <-ctx.Done(). + time.Sleep(20 * time.Millisecond) + cancel() + select { + case err := <-done: + return err + case <-time.After(2 * time.Second): + t.Fatal("startServers did not return after ctx cancel") + return nil + } +} + +func TestStartServers_HappyPath_StartsAndClosesAllSix(t *testing.T) { + opt, f := withAllFakes(fakeEnvWithDefaults()) + b := NewProxyBootstrap(opt).(*proxyBootstrap) + + if err := runStartServersUntilCancel(t, b, f.env, f.lbMemory); err != nil { + t.Fatalf("startServers: %v", err) + } + + if got := f.rtmp.RunCallCount(); got != 1 { + t.Errorf("rtmp Run = %d, want 1", got) + } + if got := f.webrtc.RunCallCount(); got != 1 { + t.Errorf("webrtc Run = %d, want 1", got) + } + if got := f.httpAPI.RunCallCount(); got != 1 { + t.Errorf("httpAPI Run = %d, want 1", got) + } + if got := f.srt.runCalls.Load(); got != 1 { + t.Errorf("srt Run = %d, want 1", got) + } + if got := f.systemAPI.runCalls.Load(); got != 1 { + t.Errorf("systemAPI Run = %d, want 1", got) + } + if got := f.httpStream.RunCallCount(); got != 1 { + t.Errorf("httpStream Run = %d, want 1", got) + } + + if got := f.rtmp.CloseCallCount(); got != 1 { + t.Errorf("rtmp Close = %d, want 1", got) + } + if got := f.webrtc.CloseCallCount(); got != 1 { + t.Errorf("webrtc Close = %d, want 1", got) + } + if got := f.httpAPI.CloseCallCount(); got != 1 { + t.Errorf("httpAPI Close = %d, want 1", got) + } + if got := f.srt.closeCalls.Load(); got != 1 { + t.Errorf("srt Close = %d, want 1", got) + } + if got := f.systemAPI.closeCalls.Load(); got != 1 { + t.Errorf("systemAPI Close = %d, want 1", got) + } + if got := f.httpStream.CloseCallCount(); got != 1 { + t.Errorf("httpStream Close = %d, want 1", got) + } +} + +func TestStartServers_HTTPAPIReceivesWebRTCInstance(t *testing.T) { + opt, f := withAllFakes(fakeEnvWithDefaults()) + b := NewProxyBootstrap(opt).(*proxyBootstrap) + + if err := runStartServersUntilCancel(t, b, f.env, f.lbMemory); err != nil { + t.Fatalf("startServers: %v", err) + } + + rtc := f.rtcInHTTPAPI.Load() + if rtc == nil { + t.Fatal("newHTTPAPIProxyServer was not invoked with a WebRTC instance") + } + if rtc.(proxy.WebRTCProxyServer) != f.webrtc { + t.Error("HTTPAPI received a different WebRTC instance than newWebRTCProxyServer returned") + } +} + +func TestStartServers_RunErrorsAreWrappedAndShortCircuit(t *testing.T) { + tests := []struct { + name string + install func(f *bootstrapFakes, err error) + wantWrap string + earlierStarted func(f *bootstrapFakes) bool + }{ + { + name: "rtmp", + install: func(f *bootstrapFakes, err error) { f.rtmp.RunReturns(err) }, + wantWrap: "rtmp server", + earlierStarted: func(f *bootstrapFakes) bool { + return f.webrtc.RunCallCount() == 0 && f.httpAPI.RunCallCount() == 0 + }, + }, + { + name: "webrtc", + install: func(f *bootstrapFakes, err error) { f.webrtc.RunReturns(err) }, + wantWrap: "rtc server", + earlierStarted: func(f *bootstrapFakes) bool { + return f.rtmp.RunCallCount() == 1 && f.httpAPI.RunCallCount() == 0 + }, + }, + { + name: "httpAPI", + install: func(f *bootstrapFakes, err error) { f.httpAPI.RunReturns(err) }, + wantWrap: "http api server", + earlierStarted: func(f *bootstrapFakes) bool { + return f.webrtc.RunCallCount() == 1 && f.srt.runCalls.Load() == 0 + }, + }, + { + name: "srt", + install: func(f *bootstrapFakes, err error) { f.srt.runReturn = err }, + wantWrap: "srt server", + earlierStarted: func(f *bootstrapFakes) bool { + return f.httpAPI.RunCallCount() == 1 && f.systemAPI.runCalls.Load() == 0 + }, + }, + { + name: "systemAPI", + install: func(f *bootstrapFakes, err error) { f.systemAPI.runReturn = err }, + wantWrap: "system api server", + earlierStarted: func(f *bootstrapFakes) bool { + return f.srt.runCalls.Load() == 1 && f.httpStream.RunCallCount() == 0 + }, + }, + { + name: "httpStream", + install: func(f *bootstrapFakes, err error) { f.httpStream.RunReturns(err) }, + wantWrap: "http server", + earlierStarted: func(f *bootstrapFakes) bool { + return f.systemAPI.runCalls.Load() == 1 + }, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + runErr := errors.New("boom-" + tc.name) + opt, f := withAllFakes(fakeEnvWithDefaults()) + tc.install(f, runErr) + b := NewProxyBootstrap(opt).(*proxyBootstrap) + + err := b.startServers(context.Background(), f.env, f.lbMemory, 50*time.Millisecond) + if err == nil { + t.Fatalf("%s: expected error", tc.name) + } + if !errors.Is(err, runErr) { + t.Errorf("%s: error chain missing runErr: %v", tc.name, err) + } + if !contains(err.Error(), tc.wantWrap) { + t.Errorf("%s: error %q does not contain wrap %q", tc.name, err.Error(), tc.wantWrap) + } + if !tc.earlierStarted(f) { + t.Errorf("%s: short-circuit invariant violated", tc.name) + } + }) + } +} + +// contains is a tiny helper so the table-driven test doesn't pull in strings +// just for substring matching. +func contains(haystack, needle string) bool { + for i := 0; i+len(needle) <= len(haystack); i++ { + if haystack[i:i+len(needle)] == needle { + return true + } + } + return false +} + +// ============================================================================= +// run +// ============================================================================= + +func TestRun_NewEnvironmentErrorIsWrapped(t *testing.T) { + envErr := errors.New("env-boom") + opt, _ := withAllFakes(fakeEnvWithDefaults()) + b := NewProxyBootstrap(opt, func(b *proxyBootstrap) { + b.newEnvironment = func(context.Context) (env.ProxyEnvironment, error) { return nil, envErr } + }).(*proxyBootstrap) + + err := b.run(context.Background()) + if err == nil { + t.Fatal("expected error") + } + if !errors.Is(err, envErr) { + t.Errorf("error chain missing envErr: %v", err) + } + if !contains(err.Error(), "create environment") { + t.Errorf("expected wrap %q, got %q", "create environment", err.Error()) + } +} + +func TestRun_InstallForceQuitErrorIsWrapped(t *testing.T) { + fqErr := errors.New("force-quit-boom") + opt, f := withAllFakes(fakeEnvWithDefaults()) + f.signal.installForceQuitReturn = fqErr + b := NewProxyBootstrap(opt).(*proxyBootstrap) + + err := b.run(context.Background()) + if err == nil { + t.Fatal("expected error") + } + if !errors.Is(err, fqErr) { + t.Errorf("error chain missing fqErr: %v", err) + } + if !contains(err.Error(), "install force quit") { + t.Errorf("expected wrap %q, got %q", "install force quit", err.Error()) + } +} + +func TestRun_BadGraceQuitDurationIsWrapped(t *testing.T) { + e := fakeEnvWithDefaults() + e.GraceQuitTimeoutReturns("not-a-duration") + opt, _ := withAllFakes(e) + b := NewProxyBootstrap(opt).(*proxyBootstrap) + + err := b.run(context.Background()) + if err == nil { + t.Fatal("expected error") + } + if !contains(err.Error(), "parse gracefully quit timeout") { + t.Errorf("expected wrap %q, got %q", "parse gracefully quit timeout", err.Error()) + } +} + +func TestRun_LoadBalancerInitializeErrorIsWrapped(t *testing.T) { + initErr := errors.New("init-boom") + opt, f := withAllFakes(fakeEnvWithDefaults()) + f.lbMemory.InitializeReturns(initErr) + b := NewProxyBootstrap(opt).(*proxyBootstrap) + + err := b.run(context.Background()) + if err == nil { + t.Fatal("expected error") + } + if !errors.Is(err, initErr) { + t.Errorf("error chain missing initErr: %v", err) + } + if !contains(err.Error(), "initialize srs load balancer") { + t.Errorf("expected wrap %q, got %q", "initialize srs load balancer", err.Error()) + } +} + +func TestRun_HappyPath_BlocksUntilCancelThenReturnsNil(t *testing.T) { + opt, _ := withAllFakes(fakeEnvWithDefaults()) + b := NewProxyBootstrap(opt).(*proxyBootstrap) + + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan error, 1) + go func() { done <- b.run(ctx) }() + time.Sleep(20 * time.Millisecond) + cancel() + select { + case err := <-done: + if err != nil { + t.Errorf("run: %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("run did not return after ctx cancel") + } +} + +// ============================================================================= +// Start +// ============================================================================= + +func TestStart_HappyPath_InstallsSignalsAndReturnsNil(t *testing.T) { + opt, f := withAllFakes(fakeEnvWithDefaults()) + f.signal.installSignalsCancels = true // cancel the inner ctx immediately + b := NewProxyBootstrap(opt) + + err := b.Start(context.Background()) + if err != nil { + t.Fatalf("Start: %v", err) + } + if f.signal.installSignalsCalls.Load() != 1 { + t.Errorf("InstallSignals calls = %d, want 1", f.signal.installSignalsCalls.Load()) + } + if f.signal.installForceQuitCalls.Load() != 1 { + t.Errorf("InstallForceQuit calls = %d, want 1", f.signal.installForceQuitCalls.Load()) + } +} + +func TestStart_PropagatesNonCancelError(t *testing.T) { + envErr := errors.New("env-boom") + opt, _ := withAllFakes(fakeEnvWithDefaults()) + b := NewProxyBootstrap(opt, func(b *proxyBootstrap) { + b.newEnvironment = func(context.Context) (env.ProxyEnvironment, error) { return nil, envErr } + }) + + err := b.Start(context.Background()) + if err == nil { + t.Fatal("expected error") + } + if !errors.Is(err, envErr) { + t.Errorf("error chain missing envErr: %v", err) + } +} + +func TestStart_AbsorbsErrorWhenContextCancelled(t *testing.T) { + // When InstallSignals cancels the inner ctx and run returns an error, Start + // should swallow the error (treating it as a graceful shutdown). + envErr := errors.New("post-cancel-boom") + opt, f := withAllFakes(fakeEnvWithDefaults()) + f.signal.installSignalsCancels = true + b := NewProxyBootstrap(opt, func(b *proxyBootstrap) { + b.newEnvironment = func(context.Context) (env.ProxyEnvironment, error) { return nil, envErr } + }) + + err := b.Start(context.Background()) + if err != nil { + t.Errorf("Start should swallow error after ctx cancel, got: %v", err) + } +} diff --git a/internal/lb/debug.go b/internal/lb/debug.go index 9dab03d82..24cd60dee 100644 --- a/internal/lb/debug.go +++ b/internal/lb/debug.go @@ -12,8 +12,8 @@ import ( "srsx/internal/logger" ) -// NewDefaultSRSForDebugging initialize the default SRS media server, for debugging only. -func NewDefaultSRSForDebugging(environment env.ProxyEnvironment) (*SRSServer, error) { +// NewDefaultOriginServerForDebugging initializes the default origin server, for debugging only. +func NewDefaultOriginServerForDebugging(environment env.ProxyEnvironment) (*OriginServer, error) { if environment.DefaultBackendEnabled() != "on" { return nil, nil } @@ -25,7 +25,7 @@ func NewDefaultSRSForDebugging(environment env.ProxyEnvironment) (*SRSServer, er return nil, fmt.Errorf("empty default backend rtmp") } - server := NewSRSServer(func(srs *SRSServer) { + server := NewOriginServer(func(srs *OriginServer) { srs.IP = environment.DefaultBackendIP() srs.RTMP = []string{environment.DefaultBackendRTMP()} srs.ServerID = fmt.Sprintf("default-%v", logger.GenerateContextID()) diff --git a/internal/lb/gen.go b/internal/lb/gen.go new file mode 100644 index 000000000..f9822e11f --- /dev/null +++ b/internal/lb/gen.go @@ -0,0 +1,9 @@ +// Copyright (c) 2026 Winlin +// +// SPDX-License-Identifier: MIT +package lb + +//go:generate go tool counterfeiter -o lbfakes/fake_origin_load_balancer.go . OriginLoadBalancer +//go:generate go tool counterfeiter -o lbfakes/fake_origin_service.go . OriginService +//go:generate go tool counterfeiter -o lbfakes/fake_hls_service.go . HLSService +//go:generate go tool counterfeiter -o lbfakes/fake_rtc_service.go . RTCService diff --git a/internal/lb/lb.go b/internal/lb/lb.go index 3c097c7f1..46bb3498e 100644 --- a/internal/lb/lb.go +++ b/internal/lb/lb.go @@ -19,8 +19,8 @@ const HLSAliveDuration = 120 * time.Second // If WebRTC streaming update in this duration, it's alive. const RTCAliveDuration = 120 * time.Second -// SRSServer represents a backend origin server. -type SRSServer struct { +// OriginServer represents a backend origin server. +type OriginServer struct { // The server IP. IP string `json:"ip,omitempty"` // The server device ID, configured by user. @@ -45,15 +45,15 @@ type SRSServer struct { UpdatedAt time.Time `json:"update_at,omitempty"` } -func (v *SRSServer) ID() string { +func (v *OriginServer) ID() string { return fmt.Sprintf("%v-%v-%v", v.ServerID, v.ServiceID, v.PID) } -func (v *SRSServer) String() string { +func (v *OriginServer) String() string { return fmt.Sprintf("%v", v) } -func (v *SRSServer) Format(f fmt.State, c rune) { +func (v *OriginServer) Format(f fmt.State, c rune) { switch c { case 'v', 's': if f.Flag('+') { @@ -87,8 +87,8 @@ func (v *SRSServer) Format(f fmt.State, c rune) { } } -func NewSRSServer(opts ...func(*SRSServer)) *SRSServer { - v := &SRSServer{} +func NewOriginServer(opts ...func(*OriginServer)) *OriginServer { + v := &OriginServer{} for _, opt := range opts { opt(v) } @@ -109,23 +109,35 @@ type RTCConnection interface { GetUfrag() string } -// SRSLoadBalancer is the interface to load balance the SRS servers. -type SRSLoadBalancer interface { - // Initialize the load balancer. - Initialize(ctx context.Context) error - // Update the backend server. - Update(ctx context.Context, server *SRSServer) error +// OriginService is the interface for origin-server registry and stream routing. +type OriginService interface { + // Update records the latest registration or heartbeat for an origin server. + Update(ctx context.Context, server *OriginServer) error // Pick a backend server for the specified stream URL. - Pick(ctx context.Context, streamURL string) (*SRSServer, error) + Pick(ctx context.Context, streamURL string) (*OriginServer, error) +} + +// HLSService is the interface for HLS session state, indexed by stream URL and SPBHID. +type HLSService interface { // Load or store the HLS streaming for the specified stream URL. LoadOrStoreHLS(ctx context.Context, streamURL string, value HLSPlayStream) (HLSPlayStream, error) // Load the HLS streaming by SPBHID, the SRS Proxy Backend HLS ID. LoadHLSBySPBHID(ctx context.Context, spbhid string) (HLSPlayStream, error) +} + +// RTCService is the interface for WebRTC session state, indexed by stream URL and ICE ufrag. +type RTCService interface { // Store the WebRTC streaming for the specified stream URL. StoreWebRTC(ctx context.Context, streamURL string, value RTCConnection) error // Load the WebRTC streaming by ufrag, the ICE username. LoadWebRTCByUfrag(ctx context.Context, ufrag string) (RTCConnection, error) } -// SrsLoadBalancer is the global SRS load balancer instance. -var SrsLoadBalancer SRSLoadBalancer +// OriginLoadBalancer is the interface to load balance the SRS servers. +type OriginLoadBalancer interface { + OriginService + HLSService + RTCService + // Initialize the load balancer. + Initialize(ctx context.Context) error +} diff --git a/internal/lb/lb_test.go b/internal/lb/lb_test.go new file mode 100644 index 000000000..74adaa6cd --- /dev/null +++ b/internal/lb/lb_test.go @@ -0,0 +1,141 @@ +// Copyright (c) 2026 Winlin +// +// SPDX-License-Identifier: MIT +package lb + +import ( + "fmt" + "strings" + "testing" + "time" +) + +func TestOriginServerID(t *testing.T) { + for _, tt := range []struct { + name string + v *OriginServer + want string + }{ + {"populated", &OriginServer{ServerID: "srv", ServiceID: "svc", PID: "1234"}, "srv-svc-1234"}, + {"empty", &OriginServer{}, "--"}, + } { + t.Run(tt.name, func(t *testing.T) { + if got := tt.v.ID(); got != tt.want { + t.Fatalf("ID()=%q, want %q", got, tt.want) + } + }) + } +} + +func TestOriginServerString(t *testing.T) { + // String() routes through Format with the %v default branch. + v := &OriginServer{IP: "1.2.3.4", ServerID: "srv", ServiceID: "svc", PID: "p"} + got := v.String() + if want := "SRS ip=1.2.3.4, id=srv-svc-p"; got != want { + t.Fatalf("String()=%q, want %q", got, want) + } +} + +func TestOriginServerFormat_ShortVerbs(t *testing.T) { + v := &OriginServer{IP: "10.0.0.1", ServerID: "srv", ServiceID: "svc", PID: "9"} + want := "SRS ip=10.0.0.1, id=srv-svc-9" + for _, verb := range []string{"%v", "%s"} { + got := fmt.Sprintf(verb, v) + if got != want { + t.Fatalf("Sprintf(%q)=%q, want %q", verb, got, want) + } + } +} + +func TestOriginServerFormat_PlusVerbsAllFields(t *testing.T) { + ts := time.Date(2026, 5, 16, 10, 30, 45, 123_000_000, time.UTC) + v := &OriginServer{ + IP: "10.0.0.1", DeviceID: "dev1", + ServerID: "srv", ServiceID: "svc", PID: "9", + RTMP: []string{":1935", ":1936"}, + HTTP: []string{":8080"}, + API: []string{":1985"}, + SRT: []string{":10080"}, + RTC: []string{":8000"}, + UpdatedAt: ts, + } + + for _, verb := range []string{"%+v", "%+s"} { + got := fmt.Sprintf(verb, v) + for _, sub := range []string{ + "SRS ip=10.0.0.1", + "id=srv-svc-9", + "pid=9, server=srv, service=svc", + "device=dev1", + "rtmp=[:1935,:1936]", + "http=[:8080]", + "api=[:1985]", + "srt=[:10080]", + "rtc=[:8000]", + "update=2026-05-16 10:30:45.123", + } { + if !strings.Contains(got, sub) { + t.Fatalf("Sprintf(%q)=%q missing %q", verb, got, sub) + } + } + } +} + +func TestOriginServerFormat_PlusVerbMinimal(t *testing.T) { + // Plus verb with no optional fields populated exercises the false + // branches of every "if len(X) > 0 / X != \"\"" guard in Format. + v := &OriginServer{ServerID: "srv", ServiceID: "svc", PID: "9"} + got := fmt.Sprintf("%+v", v) + + if !strings.Contains(got, "pid=9, server=srv, service=svc") { + t.Fatalf("%%+v output %q missing core ids", got) + } + if !strings.Contains(got, "update=") { + t.Fatalf("%%+v output %q missing update timestamp", got) + } + for _, sub := range []string{"device=", "rtmp=", "http=", "api=", "srt=", "rtc="} { + if strings.Contains(got, sub) { + t.Fatalf("%%+v output %q should not contain %q for an empty field", got, sub) + } + } +} + +func TestOriginServerFormat_OtherVerb(t *testing.T) { + // A non-v/s verb falls through to the default branch, which recursively + // formats with %v and appends ", fmt=%". + v := &OriginServer{IP: "1.2.3.4", ServerID: "srv", ServiceID: "svc", PID: "p"} + got := fmt.Sprintf("%d", v) + want := "SRS ip=1.2.3.4, id=srv-svc-p, fmt=%d" + if got != want { + t.Fatalf("%%d output %q, want %q", got, want) + } +} + +func TestNewOriginServer(t *testing.T) { + t.Run("no opts", func(t *testing.T) { + v := NewOriginServer() + if v == nil { + t.Fatal("NewOriginServer() returned nil") + } + if v.IP != "" || v.DeviceID != "" || v.ServerID != "" || v.ServiceID != "" || v.PID != "" { + t.Fatalf("expected zero value, got %+v", v) + } + if len(v.RTMP)+len(v.HTTP)+len(v.API)+len(v.SRT)+len(v.RTC) != 0 { + t.Fatalf("expected empty endpoints, got %+v", v) + } + if !v.UpdatedAt.IsZero() { + t.Fatalf("expected zero UpdatedAt, got %v", v.UpdatedAt) + } + }) + + t.Run("with opts", func(t *testing.T) { + v := NewOriginServer( + func(s *OriginServer) { s.IP = "9.9.9.9" }, + func(s *OriginServer) { s.ServerID = "abc" }, + func(s *OriginServer) { s.RTMP = []string{":1935"} }, + ) + if v.IP != "9.9.9.9" || v.ServerID != "abc" || len(v.RTMP) != 1 || v.RTMP[0] != ":1935" { + t.Fatalf("opts not applied: got %+v", v) + } + }) +} diff --git a/internal/lb/lbfakes/fake_hls_service.go b/internal/lb/lbfakes/fake_hls_service.go new file mode 100644 index 000000000..8aa7a8340 --- /dev/null +++ b/internal/lb/lbfakes/fake_hls_service.go @@ -0,0 +1,197 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package lbfakes + +import ( + "context" + "srsx/internal/lb" + "sync" +) + +type FakeHLSService struct { + LoadHLSBySPBHIDStub func(context.Context, string) (lb.HLSPlayStream, error) + loadHLSBySPBHIDMutex sync.RWMutex + loadHLSBySPBHIDArgsForCall []struct { + arg1 context.Context + arg2 string + } + loadHLSBySPBHIDReturns struct { + result1 lb.HLSPlayStream + result2 error + } + loadHLSBySPBHIDReturnsOnCall map[int]struct { + result1 lb.HLSPlayStream + result2 error + } + LoadOrStoreHLSStub func(context.Context, string, lb.HLSPlayStream) (lb.HLSPlayStream, error) + loadOrStoreHLSMutex sync.RWMutex + loadOrStoreHLSArgsForCall []struct { + arg1 context.Context + arg2 string + arg3 lb.HLSPlayStream + } + loadOrStoreHLSReturns struct { + result1 lb.HLSPlayStream + result2 error + } + loadOrStoreHLSReturnsOnCall map[int]struct { + result1 lb.HLSPlayStream + result2 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeHLSService) LoadHLSBySPBHID(arg1 context.Context, arg2 string) (lb.HLSPlayStream, error) { + fake.loadHLSBySPBHIDMutex.Lock() + ret, specificReturn := fake.loadHLSBySPBHIDReturnsOnCall[len(fake.loadHLSBySPBHIDArgsForCall)] + fake.loadHLSBySPBHIDArgsForCall = append(fake.loadHLSBySPBHIDArgsForCall, struct { + arg1 context.Context + arg2 string + }{arg1, arg2}) + stub := fake.LoadHLSBySPBHIDStub + fakeReturns := fake.loadHLSBySPBHIDReturns + fake.recordInvocation("LoadHLSBySPBHID", []interface{}{arg1, arg2}) + fake.loadHLSBySPBHIDMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeHLSService) LoadHLSBySPBHIDCallCount() int { + fake.loadHLSBySPBHIDMutex.RLock() + defer fake.loadHLSBySPBHIDMutex.RUnlock() + return len(fake.loadHLSBySPBHIDArgsForCall) +} + +func (fake *FakeHLSService) LoadHLSBySPBHIDCalls(stub func(context.Context, string) (lb.HLSPlayStream, error)) { + fake.loadHLSBySPBHIDMutex.Lock() + defer fake.loadHLSBySPBHIDMutex.Unlock() + fake.LoadHLSBySPBHIDStub = stub +} + +func (fake *FakeHLSService) LoadHLSBySPBHIDArgsForCall(i int) (context.Context, string) { + fake.loadHLSBySPBHIDMutex.RLock() + defer fake.loadHLSBySPBHIDMutex.RUnlock() + argsForCall := fake.loadHLSBySPBHIDArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeHLSService) LoadHLSBySPBHIDReturns(result1 lb.HLSPlayStream, result2 error) { + fake.loadHLSBySPBHIDMutex.Lock() + defer fake.loadHLSBySPBHIDMutex.Unlock() + fake.LoadHLSBySPBHIDStub = nil + fake.loadHLSBySPBHIDReturns = struct { + result1 lb.HLSPlayStream + result2 error + }{result1, result2} +} + +func (fake *FakeHLSService) LoadHLSBySPBHIDReturnsOnCall(i int, result1 lb.HLSPlayStream, result2 error) { + fake.loadHLSBySPBHIDMutex.Lock() + defer fake.loadHLSBySPBHIDMutex.Unlock() + fake.LoadHLSBySPBHIDStub = nil + if fake.loadHLSBySPBHIDReturnsOnCall == nil { + fake.loadHLSBySPBHIDReturnsOnCall = make(map[int]struct { + result1 lb.HLSPlayStream + result2 error + }) + } + fake.loadHLSBySPBHIDReturnsOnCall[i] = struct { + result1 lb.HLSPlayStream + result2 error + }{result1, result2} +} + +func (fake *FakeHLSService) LoadOrStoreHLS(arg1 context.Context, arg2 string, arg3 lb.HLSPlayStream) (lb.HLSPlayStream, error) { + fake.loadOrStoreHLSMutex.Lock() + ret, specificReturn := fake.loadOrStoreHLSReturnsOnCall[len(fake.loadOrStoreHLSArgsForCall)] + fake.loadOrStoreHLSArgsForCall = append(fake.loadOrStoreHLSArgsForCall, struct { + arg1 context.Context + arg2 string + arg3 lb.HLSPlayStream + }{arg1, arg2, arg3}) + stub := fake.LoadOrStoreHLSStub + fakeReturns := fake.loadOrStoreHLSReturns + fake.recordInvocation("LoadOrStoreHLS", []interface{}{arg1, arg2, arg3}) + fake.loadOrStoreHLSMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeHLSService) LoadOrStoreHLSCallCount() int { + fake.loadOrStoreHLSMutex.RLock() + defer fake.loadOrStoreHLSMutex.RUnlock() + return len(fake.loadOrStoreHLSArgsForCall) +} + +func (fake *FakeHLSService) LoadOrStoreHLSCalls(stub func(context.Context, string, lb.HLSPlayStream) (lb.HLSPlayStream, error)) { + fake.loadOrStoreHLSMutex.Lock() + defer fake.loadOrStoreHLSMutex.Unlock() + fake.LoadOrStoreHLSStub = stub +} + +func (fake *FakeHLSService) LoadOrStoreHLSArgsForCall(i int) (context.Context, string, lb.HLSPlayStream) { + fake.loadOrStoreHLSMutex.RLock() + defer fake.loadOrStoreHLSMutex.RUnlock() + argsForCall := fake.loadOrStoreHLSArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeHLSService) LoadOrStoreHLSReturns(result1 lb.HLSPlayStream, result2 error) { + fake.loadOrStoreHLSMutex.Lock() + defer fake.loadOrStoreHLSMutex.Unlock() + fake.LoadOrStoreHLSStub = nil + fake.loadOrStoreHLSReturns = struct { + result1 lb.HLSPlayStream + result2 error + }{result1, result2} +} + +func (fake *FakeHLSService) LoadOrStoreHLSReturnsOnCall(i int, result1 lb.HLSPlayStream, result2 error) { + fake.loadOrStoreHLSMutex.Lock() + defer fake.loadOrStoreHLSMutex.Unlock() + fake.LoadOrStoreHLSStub = nil + if fake.loadOrStoreHLSReturnsOnCall == nil { + fake.loadOrStoreHLSReturnsOnCall = make(map[int]struct { + result1 lb.HLSPlayStream + result2 error + }) + } + fake.loadOrStoreHLSReturnsOnCall[i] = struct { + result1 lb.HLSPlayStream + result2 error + }{result1, result2} +} + +func (fake *FakeHLSService) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeHLSService) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ lb.HLSService = new(FakeHLSService) diff --git a/internal/lb/lbfakes/fake_origin_load_balancer.go b/internal/lb/lbfakes/fake_origin_load_balancer.go new file mode 100644 index 000000000..ab16a6628 --- /dev/null +++ b/internal/lb/lbfakes/fake_origin_load_balancer.go @@ -0,0 +1,577 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package lbfakes + +import ( + "context" + "srsx/internal/lb" + "sync" +) + +type FakeOriginLoadBalancer struct { + InitializeStub func(context.Context) error + initializeMutex sync.RWMutex + initializeArgsForCall []struct { + arg1 context.Context + } + initializeReturns struct { + result1 error + } + initializeReturnsOnCall map[int]struct { + result1 error + } + LoadHLSBySPBHIDStub func(context.Context, string) (lb.HLSPlayStream, error) + loadHLSBySPBHIDMutex sync.RWMutex + loadHLSBySPBHIDArgsForCall []struct { + arg1 context.Context + arg2 string + } + loadHLSBySPBHIDReturns struct { + result1 lb.HLSPlayStream + result2 error + } + loadHLSBySPBHIDReturnsOnCall map[int]struct { + result1 lb.HLSPlayStream + result2 error + } + LoadOrStoreHLSStub func(context.Context, string, lb.HLSPlayStream) (lb.HLSPlayStream, error) + loadOrStoreHLSMutex sync.RWMutex + loadOrStoreHLSArgsForCall []struct { + arg1 context.Context + arg2 string + arg3 lb.HLSPlayStream + } + loadOrStoreHLSReturns struct { + result1 lb.HLSPlayStream + result2 error + } + loadOrStoreHLSReturnsOnCall map[int]struct { + result1 lb.HLSPlayStream + result2 error + } + LoadWebRTCByUfragStub func(context.Context, string) (lb.RTCConnection, error) + loadWebRTCByUfragMutex sync.RWMutex + loadWebRTCByUfragArgsForCall []struct { + arg1 context.Context + arg2 string + } + loadWebRTCByUfragReturns struct { + result1 lb.RTCConnection + result2 error + } + loadWebRTCByUfragReturnsOnCall map[int]struct { + result1 lb.RTCConnection + result2 error + } + PickStub func(context.Context, string) (*lb.OriginServer, error) + pickMutex sync.RWMutex + pickArgsForCall []struct { + arg1 context.Context + arg2 string + } + pickReturns struct { + result1 *lb.OriginServer + result2 error + } + pickReturnsOnCall map[int]struct { + result1 *lb.OriginServer + result2 error + } + StoreWebRTCStub func(context.Context, string, lb.RTCConnection) error + storeWebRTCMutex sync.RWMutex + storeWebRTCArgsForCall []struct { + arg1 context.Context + arg2 string + arg3 lb.RTCConnection + } + storeWebRTCReturns struct { + result1 error + } + storeWebRTCReturnsOnCall map[int]struct { + result1 error + } + UpdateStub func(context.Context, *lb.OriginServer) error + updateMutex sync.RWMutex + updateArgsForCall []struct { + arg1 context.Context + arg2 *lb.OriginServer + } + updateReturns struct { + result1 error + } + updateReturnsOnCall map[int]struct { + result1 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeOriginLoadBalancer) Initialize(arg1 context.Context) error { + fake.initializeMutex.Lock() + ret, specificReturn := fake.initializeReturnsOnCall[len(fake.initializeArgsForCall)] + fake.initializeArgsForCall = append(fake.initializeArgsForCall, struct { + arg1 context.Context + }{arg1}) + stub := fake.InitializeStub + fakeReturns := fake.initializeReturns + fake.recordInvocation("Initialize", []interface{}{arg1}) + fake.initializeMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeOriginLoadBalancer) InitializeCallCount() int { + fake.initializeMutex.RLock() + defer fake.initializeMutex.RUnlock() + return len(fake.initializeArgsForCall) +} + +func (fake *FakeOriginLoadBalancer) InitializeCalls(stub func(context.Context) error) { + fake.initializeMutex.Lock() + defer fake.initializeMutex.Unlock() + fake.InitializeStub = stub +} + +func (fake *FakeOriginLoadBalancer) InitializeArgsForCall(i int) context.Context { + fake.initializeMutex.RLock() + defer fake.initializeMutex.RUnlock() + argsForCall := fake.initializeArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeOriginLoadBalancer) InitializeReturns(result1 error) { + fake.initializeMutex.Lock() + defer fake.initializeMutex.Unlock() + fake.InitializeStub = nil + fake.initializeReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeOriginLoadBalancer) InitializeReturnsOnCall(i int, result1 error) { + fake.initializeMutex.Lock() + defer fake.initializeMutex.Unlock() + fake.InitializeStub = nil + if fake.initializeReturnsOnCall == nil { + fake.initializeReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.initializeReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeOriginLoadBalancer) LoadHLSBySPBHID(arg1 context.Context, arg2 string) (lb.HLSPlayStream, error) { + fake.loadHLSBySPBHIDMutex.Lock() + ret, specificReturn := fake.loadHLSBySPBHIDReturnsOnCall[len(fake.loadHLSBySPBHIDArgsForCall)] + fake.loadHLSBySPBHIDArgsForCall = append(fake.loadHLSBySPBHIDArgsForCall, struct { + arg1 context.Context + arg2 string + }{arg1, arg2}) + stub := fake.LoadHLSBySPBHIDStub + fakeReturns := fake.loadHLSBySPBHIDReturns + fake.recordInvocation("LoadHLSBySPBHID", []interface{}{arg1, arg2}) + fake.loadHLSBySPBHIDMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeOriginLoadBalancer) LoadHLSBySPBHIDCallCount() int { + fake.loadHLSBySPBHIDMutex.RLock() + defer fake.loadHLSBySPBHIDMutex.RUnlock() + return len(fake.loadHLSBySPBHIDArgsForCall) +} + +func (fake *FakeOriginLoadBalancer) LoadHLSBySPBHIDCalls(stub func(context.Context, string) (lb.HLSPlayStream, error)) { + fake.loadHLSBySPBHIDMutex.Lock() + defer fake.loadHLSBySPBHIDMutex.Unlock() + fake.LoadHLSBySPBHIDStub = stub +} + +func (fake *FakeOriginLoadBalancer) LoadHLSBySPBHIDArgsForCall(i int) (context.Context, string) { + fake.loadHLSBySPBHIDMutex.RLock() + defer fake.loadHLSBySPBHIDMutex.RUnlock() + argsForCall := fake.loadHLSBySPBHIDArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeOriginLoadBalancer) LoadHLSBySPBHIDReturns(result1 lb.HLSPlayStream, result2 error) { + fake.loadHLSBySPBHIDMutex.Lock() + defer fake.loadHLSBySPBHIDMutex.Unlock() + fake.LoadHLSBySPBHIDStub = nil + fake.loadHLSBySPBHIDReturns = struct { + result1 lb.HLSPlayStream + result2 error + }{result1, result2} +} + +func (fake *FakeOriginLoadBalancer) LoadHLSBySPBHIDReturnsOnCall(i int, result1 lb.HLSPlayStream, result2 error) { + fake.loadHLSBySPBHIDMutex.Lock() + defer fake.loadHLSBySPBHIDMutex.Unlock() + fake.LoadHLSBySPBHIDStub = nil + if fake.loadHLSBySPBHIDReturnsOnCall == nil { + fake.loadHLSBySPBHIDReturnsOnCall = make(map[int]struct { + result1 lb.HLSPlayStream + result2 error + }) + } + fake.loadHLSBySPBHIDReturnsOnCall[i] = struct { + result1 lb.HLSPlayStream + result2 error + }{result1, result2} +} + +func (fake *FakeOriginLoadBalancer) LoadOrStoreHLS(arg1 context.Context, arg2 string, arg3 lb.HLSPlayStream) (lb.HLSPlayStream, error) { + fake.loadOrStoreHLSMutex.Lock() + ret, specificReturn := fake.loadOrStoreHLSReturnsOnCall[len(fake.loadOrStoreHLSArgsForCall)] + fake.loadOrStoreHLSArgsForCall = append(fake.loadOrStoreHLSArgsForCall, struct { + arg1 context.Context + arg2 string + arg3 lb.HLSPlayStream + }{arg1, arg2, arg3}) + stub := fake.LoadOrStoreHLSStub + fakeReturns := fake.loadOrStoreHLSReturns + fake.recordInvocation("LoadOrStoreHLS", []interface{}{arg1, arg2, arg3}) + fake.loadOrStoreHLSMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeOriginLoadBalancer) LoadOrStoreHLSCallCount() int { + fake.loadOrStoreHLSMutex.RLock() + defer fake.loadOrStoreHLSMutex.RUnlock() + return len(fake.loadOrStoreHLSArgsForCall) +} + +func (fake *FakeOriginLoadBalancer) LoadOrStoreHLSCalls(stub func(context.Context, string, lb.HLSPlayStream) (lb.HLSPlayStream, error)) { + fake.loadOrStoreHLSMutex.Lock() + defer fake.loadOrStoreHLSMutex.Unlock() + fake.LoadOrStoreHLSStub = stub +} + +func (fake *FakeOriginLoadBalancer) LoadOrStoreHLSArgsForCall(i int) (context.Context, string, lb.HLSPlayStream) { + fake.loadOrStoreHLSMutex.RLock() + defer fake.loadOrStoreHLSMutex.RUnlock() + argsForCall := fake.loadOrStoreHLSArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeOriginLoadBalancer) LoadOrStoreHLSReturns(result1 lb.HLSPlayStream, result2 error) { + fake.loadOrStoreHLSMutex.Lock() + defer fake.loadOrStoreHLSMutex.Unlock() + fake.LoadOrStoreHLSStub = nil + fake.loadOrStoreHLSReturns = struct { + result1 lb.HLSPlayStream + result2 error + }{result1, result2} +} + +func (fake *FakeOriginLoadBalancer) LoadOrStoreHLSReturnsOnCall(i int, result1 lb.HLSPlayStream, result2 error) { + fake.loadOrStoreHLSMutex.Lock() + defer fake.loadOrStoreHLSMutex.Unlock() + fake.LoadOrStoreHLSStub = nil + if fake.loadOrStoreHLSReturnsOnCall == nil { + fake.loadOrStoreHLSReturnsOnCall = make(map[int]struct { + result1 lb.HLSPlayStream + result2 error + }) + } + fake.loadOrStoreHLSReturnsOnCall[i] = struct { + result1 lb.HLSPlayStream + result2 error + }{result1, result2} +} + +func (fake *FakeOriginLoadBalancer) LoadWebRTCByUfrag(arg1 context.Context, arg2 string) (lb.RTCConnection, error) { + fake.loadWebRTCByUfragMutex.Lock() + ret, specificReturn := fake.loadWebRTCByUfragReturnsOnCall[len(fake.loadWebRTCByUfragArgsForCall)] + fake.loadWebRTCByUfragArgsForCall = append(fake.loadWebRTCByUfragArgsForCall, struct { + arg1 context.Context + arg2 string + }{arg1, arg2}) + stub := fake.LoadWebRTCByUfragStub + fakeReturns := fake.loadWebRTCByUfragReturns + fake.recordInvocation("LoadWebRTCByUfrag", []interface{}{arg1, arg2}) + fake.loadWebRTCByUfragMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeOriginLoadBalancer) LoadWebRTCByUfragCallCount() int { + fake.loadWebRTCByUfragMutex.RLock() + defer fake.loadWebRTCByUfragMutex.RUnlock() + return len(fake.loadWebRTCByUfragArgsForCall) +} + +func (fake *FakeOriginLoadBalancer) LoadWebRTCByUfragCalls(stub func(context.Context, string) (lb.RTCConnection, error)) { + fake.loadWebRTCByUfragMutex.Lock() + defer fake.loadWebRTCByUfragMutex.Unlock() + fake.LoadWebRTCByUfragStub = stub +} + +func (fake *FakeOriginLoadBalancer) LoadWebRTCByUfragArgsForCall(i int) (context.Context, string) { + fake.loadWebRTCByUfragMutex.RLock() + defer fake.loadWebRTCByUfragMutex.RUnlock() + argsForCall := fake.loadWebRTCByUfragArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeOriginLoadBalancer) LoadWebRTCByUfragReturns(result1 lb.RTCConnection, result2 error) { + fake.loadWebRTCByUfragMutex.Lock() + defer fake.loadWebRTCByUfragMutex.Unlock() + fake.LoadWebRTCByUfragStub = nil + fake.loadWebRTCByUfragReturns = struct { + result1 lb.RTCConnection + result2 error + }{result1, result2} +} + +func (fake *FakeOriginLoadBalancer) LoadWebRTCByUfragReturnsOnCall(i int, result1 lb.RTCConnection, result2 error) { + fake.loadWebRTCByUfragMutex.Lock() + defer fake.loadWebRTCByUfragMutex.Unlock() + fake.LoadWebRTCByUfragStub = nil + if fake.loadWebRTCByUfragReturnsOnCall == nil { + fake.loadWebRTCByUfragReturnsOnCall = make(map[int]struct { + result1 lb.RTCConnection + result2 error + }) + } + fake.loadWebRTCByUfragReturnsOnCall[i] = struct { + result1 lb.RTCConnection + result2 error + }{result1, result2} +} + +func (fake *FakeOriginLoadBalancer) Pick(arg1 context.Context, arg2 string) (*lb.OriginServer, error) { + fake.pickMutex.Lock() + ret, specificReturn := fake.pickReturnsOnCall[len(fake.pickArgsForCall)] + fake.pickArgsForCall = append(fake.pickArgsForCall, struct { + arg1 context.Context + arg2 string + }{arg1, arg2}) + stub := fake.PickStub + fakeReturns := fake.pickReturns + fake.recordInvocation("Pick", []interface{}{arg1, arg2}) + fake.pickMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeOriginLoadBalancer) PickCallCount() int { + fake.pickMutex.RLock() + defer fake.pickMutex.RUnlock() + return len(fake.pickArgsForCall) +} + +func (fake *FakeOriginLoadBalancer) PickCalls(stub func(context.Context, string) (*lb.OriginServer, error)) { + fake.pickMutex.Lock() + defer fake.pickMutex.Unlock() + fake.PickStub = stub +} + +func (fake *FakeOriginLoadBalancer) PickArgsForCall(i int) (context.Context, string) { + fake.pickMutex.RLock() + defer fake.pickMutex.RUnlock() + argsForCall := fake.pickArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeOriginLoadBalancer) PickReturns(result1 *lb.OriginServer, result2 error) { + fake.pickMutex.Lock() + defer fake.pickMutex.Unlock() + fake.PickStub = nil + fake.pickReturns = struct { + result1 *lb.OriginServer + result2 error + }{result1, result2} +} + +func (fake *FakeOriginLoadBalancer) PickReturnsOnCall(i int, result1 *lb.OriginServer, result2 error) { + fake.pickMutex.Lock() + defer fake.pickMutex.Unlock() + fake.PickStub = nil + if fake.pickReturnsOnCall == nil { + fake.pickReturnsOnCall = make(map[int]struct { + result1 *lb.OriginServer + result2 error + }) + } + fake.pickReturnsOnCall[i] = struct { + result1 *lb.OriginServer + result2 error + }{result1, result2} +} + +func (fake *FakeOriginLoadBalancer) StoreWebRTC(arg1 context.Context, arg2 string, arg3 lb.RTCConnection) error { + fake.storeWebRTCMutex.Lock() + ret, specificReturn := fake.storeWebRTCReturnsOnCall[len(fake.storeWebRTCArgsForCall)] + fake.storeWebRTCArgsForCall = append(fake.storeWebRTCArgsForCall, struct { + arg1 context.Context + arg2 string + arg3 lb.RTCConnection + }{arg1, arg2, arg3}) + stub := fake.StoreWebRTCStub + fakeReturns := fake.storeWebRTCReturns + fake.recordInvocation("StoreWebRTC", []interface{}{arg1, arg2, arg3}) + fake.storeWebRTCMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeOriginLoadBalancer) StoreWebRTCCallCount() int { + fake.storeWebRTCMutex.RLock() + defer fake.storeWebRTCMutex.RUnlock() + return len(fake.storeWebRTCArgsForCall) +} + +func (fake *FakeOriginLoadBalancer) StoreWebRTCCalls(stub func(context.Context, string, lb.RTCConnection) error) { + fake.storeWebRTCMutex.Lock() + defer fake.storeWebRTCMutex.Unlock() + fake.StoreWebRTCStub = stub +} + +func (fake *FakeOriginLoadBalancer) StoreWebRTCArgsForCall(i int) (context.Context, string, lb.RTCConnection) { + fake.storeWebRTCMutex.RLock() + defer fake.storeWebRTCMutex.RUnlock() + argsForCall := fake.storeWebRTCArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeOriginLoadBalancer) StoreWebRTCReturns(result1 error) { + fake.storeWebRTCMutex.Lock() + defer fake.storeWebRTCMutex.Unlock() + fake.StoreWebRTCStub = nil + fake.storeWebRTCReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeOriginLoadBalancer) StoreWebRTCReturnsOnCall(i int, result1 error) { + fake.storeWebRTCMutex.Lock() + defer fake.storeWebRTCMutex.Unlock() + fake.StoreWebRTCStub = nil + if fake.storeWebRTCReturnsOnCall == nil { + fake.storeWebRTCReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.storeWebRTCReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeOriginLoadBalancer) Update(arg1 context.Context, arg2 *lb.OriginServer) error { + fake.updateMutex.Lock() + ret, specificReturn := fake.updateReturnsOnCall[len(fake.updateArgsForCall)] + fake.updateArgsForCall = append(fake.updateArgsForCall, struct { + arg1 context.Context + arg2 *lb.OriginServer + }{arg1, arg2}) + stub := fake.UpdateStub + fakeReturns := fake.updateReturns + fake.recordInvocation("Update", []interface{}{arg1, arg2}) + fake.updateMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeOriginLoadBalancer) UpdateCallCount() int { + fake.updateMutex.RLock() + defer fake.updateMutex.RUnlock() + return len(fake.updateArgsForCall) +} + +func (fake *FakeOriginLoadBalancer) UpdateCalls(stub func(context.Context, *lb.OriginServer) error) { + fake.updateMutex.Lock() + defer fake.updateMutex.Unlock() + fake.UpdateStub = stub +} + +func (fake *FakeOriginLoadBalancer) UpdateArgsForCall(i int) (context.Context, *lb.OriginServer) { + fake.updateMutex.RLock() + defer fake.updateMutex.RUnlock() + argsForCall := fake.updateArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeOriginLoadBalancer) UpdateReturns(result1 error) { + fake.updateMutex.Lock() + defer fake.updateMutex.Unlock() + fake.UpdateStub = nil + fake.updateReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeOriginLoadBalancer) UpdateReturnsOnCall(i int, result1 error) { + fake.updateMutex.Lock() + defer fake.updateMutex.Unlock() + fake.UpdateStub = nil + if fake.updateReturnsOnCall == nil { + fake.updateReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.updateReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeOriginLoadBalancer) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeOriginLoadBalancer) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ lb.OriginLoadBalancer = new(FakeOriginLoadBalancer) diff --git a/internal/lb/lbfakes/fake_origin_service.go b/internal/lb/lbfakes/fake_origin_service.go new file mode 100644 index 000000000..9ffaaa877 --- /dev/null +++ b/internal/lb/lbfakes/fake_origin_service.go @@ -0,0 +1,190 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package lbfakes + +import ( + "context" + "srsx/internal/lb" + "sync" +) + +type FakeOriginService struct { + PickStub func(context.Context, string) (*lb.OriginServer, error) + pickMutex sync.RWMutex + pickArgsForCall []struct { + arg1 context.Context + arg2 string + } + pickReturns struct { + result1 *lb.OriginServer + result2 error + } + pickReturnsOnCall map[int]struct { + result1 *lb.OriginServer + result2 error + } + UpdateStub func(context.Context, *lb.OriginServer) error + updateMutex sync.RWMutex + updateArgsForCall []struct { + arg1 context.Context + arg2 *lb.OriginServer + } + updateReturns struct { + result1 error + } + updateReturnsOnCall map[int]struct { + result1 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeOriginService) Pick(arg1 context.Context, arg2 string) (*lb.OriginServer, error) { + fake.pickMutex.Lock() + ret, specificReturn := fake.pickReturnsOnCall[len(fake.pickArgsForCall)] + fake.pickArgsForCall = append(fake.pickArgsForCall, struct { + arg1 context.Context + arg2 string + }{arg1, arg2}) + stub := fake.PickStub + fakeReturns := fake.pickReturns + fake.recordInvocation("Pick", []interface{}{arg1, arg2}) + fake.pickMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeOriginService) PickCallCount() int { + fake.pickMutex.RLock() + defer fake.pickMutex.RUnlock() + return len(fake.pickArgsForCall) +} + +func (fake *FakeOriginService) PickCalls(stub func(context.Context, string) (*lb.OriginServer, error)) { + fake.pickMutex.Lock() + defer fake.pickMutex.Unlock() + fake.PickStub = stub +} + +func (fake *FakeOriginService) PickArgsForCall(i int) (context.Context, string) { + fake.pickMutex.RLock() + defer fake.pickMutex.RUnlock() + argsForCall := fake.pickArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeOriginService) PickReturns(result1 *lb.OriginServer, result2 error) { + fake.pickMutex.Lock() + defer fake.pickMutex.Unlock() + fake.PickStub = nil + fake.pickReturns = struct { + result1 *lb.OriginServer + result2 error + }{result1, result2} +} + +func (fake *FakeOriginService) PickReturnsOnCall(i int, result1 *lb.OriginServer, result2 error) { + fake.pickMutex.Lock() + defer fake.pickMutex.Unlock() + fake.PickStub = nil + if fake.pickReturnsOnCall == nil { + fake.pickReturnsOnCall = make(map[int]struct { + result1 *lb.OriginServer + result2 error + }) + } + fake.pickReturnsOnCall[i] = struct { + result1 *lb.OriginServer + result2 error + }{result1, result2} +} + +func (fake *FakeOriginService) Update(arg1 context.Context, arg2 *lb.OriginServer) error { + fake.updateMutex.Lock() + ret, specificReturn := fake.updateReturnsOnCall[len(fake.updateArgsForCall)] + fake.updateArgsForCall = append(fake.updateArgsForCall, struct { + arg1 context.Context + arg2 *lb.OriginServer + }{arg1, arg2}) + stub := fake.UpdateStub + fakeReturns := fake.updateReturns + fake.recordInvocation("Update", []interface{}{arg1, arg2}) + fake.updateMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeOriginService) UpdateCallCount() int { + fake.updateMutex.RLock() + defer fake.updateMutex.RUnlock() + return len(fake.updateArgsForCall) +} + +func (fake *FakeOriginService) UpdateCalls(stub func(context.Context, *lb.OriginServer) error) { + fake.updateMutex.Lock() + defer fake.updateMutex.Unlock() + fake.UpdateStub = stub +} + +func (fake *FakeOriginService) UpdateArgsForCall(i int) (context.Context, *lb.OriginServer) { + fake.updateMutex.RLock() + defer fake.updateMutex.RUnlock() + argsForCall := fake.updateArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeOriginService) UpdateReturns(result1 error) { + fake.updateMutex.Lock() + defer fake.updateMutex.Unlock() + fake.UpdateStub = nil + fake.updateReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeOriginService) UpdateReturnsOnCall(i int, result1 error) { + fake.updateMutex.Lock() + defer fake.updateMutex.Unlock() + fake.UpdateStub = nil + if fake.updateReturnsOnCall == nil { + fake.updateReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.updateReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeOriginService) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeOriginService) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ lb.OriginService = new(FakeOriginService) diff --git a/internal/lb/lbfakes/fake_rtc_service.go b/internal/lb/lbfakes/fake_rtc_service.go new file mode 100644 index 000000000..73772d666 --- /dev/null +++ b/internal/lb/lbfakes/fake_rtc_service.go @@ -0,0 +1,192 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package lbfakes + +import ( + "context" + "srsx/internal/lb" + "sync" +) + +type FakeRTCService struct { + LoadWebRTCByUfragStub func(context.Context, string) (lb.RTCConnection, error) + loadWebRTCByUfragMutex sync.RWMutex + loadWebRTCByUfragArgsForCall []struct { + arg1 context.Context + arg2 string + } + loadWebRTCByUfragReturns struct { + result1 lb.RTCConnection + result2 error + } + loadWebRTCByUfragReturnsOnCall map[int]struct { + result1 lb.RTCConnection + result2 error + } + StoreWebRTCStub func(context.Context, string, lb.RTCConnection) error + storeWebRTCMutex sync.RWMutex + storeWebRTCArgsForCall []struct { + arg1 context.Context + arg2 string + arg3 lb.RTCConnection + } + storeWebRTCReturns struct { + result1 error + } + storeWebRTCReturnsOnCall map[int]struct { + result1 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeRTCService) LoadWebRTCByUfrag(arg1 context.Context, arg2 string) (lb.RTCConnection, error) { + fake.loadWebRTCByUfragMutex.Lock() + ret, specificReturn := fake.loadWebRTCByUfragReturnsOnCall[len(fake.loadWebRTCByUfragArgsForCall)] + fake.loadWebRTCByUfragArgsForCall = append(fake.loadWebRTCByUfragArgsForCall, struct { + arg1 context.Context + arg2 string + }{arg1, arg2}) + stub := fake.LoadWebRTCByUfragStub + fakeReturns := fake.loadWebRTCByUfragReturns + fake.recordInvocation("LoadWebRTCByUfrag", []interface{}{arg1, arg2}) + fake.loadWebRTCByUfragMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeRTCService) LoadWebRTCByUfragCallCount() int { + fake.loadWebRTCByUfragMutex.RLock() + defer fake.loadWebRTCByUfragMutex.RUnlock() + return len(fake.loadWebRTCByUfragArgsForCall) +} + +func (fake *FakeRTCService) LoadWebRTCByUfragCalls(stub func(context.Context, string) (lb.RTCConnection, error)) { + fake.loadWebRTCByUfragMutex.Lock() + defer fake.loadWebRTCByUfragMutex.Unlock() + fake.LoadWebRTCByUfragStub = stub +} + +func (fake *FakeRTCService) LoadWebRTCByUfragArgsForCall(i int) (context.Context, string) { + fake.loadWebRTCByUfragMutex.RLock() + defer fake.loadWebRTCByUfragMutex.RUnlock() + argsForCall := fake.loadWebRTCByUfragArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeRTCService) LoadWebRTCByUfragReturns(result1 lb.RTCConnection, result2 error) { + fake.loadWebRTCByUfragMutex.Lock() + defer fake.loadWebRTCByUfragMutex.Unlock() + fake.LoadWebRTCByUfragStub = nil + fake.loadWebRTCByUfragReturns = struct { + result1 lb.RTCConnection + result2 error + }{result1, result2} +} + +func (fake *FakeRTCService) LoadWebRTCByUfragReturnsOnCall(i int, result1 lb.RTCConnection, result2 error) { + fake.loadWebRTCByUfragMutex.Lock() + defer fake.loadWebRTCByUfragMutex.Unlock() + fake.LoadWebRTCByUfragStub = nil + if fake.loadWebRTCByUfragReturnsOnCall == nil { + fake.loadWebRTCByUfragReturnsOnCall = make(map[int]struct { + result1 lb.RTCConnection + result2 error + }) + } + fake.loadWebRTCByUfragReturnsOnCall[i] = struct { + result1 lb.RTCConnection + result2 error + }{result1, result2} +} + +func (fake *FakeRTCService) StoreWebRTC(arg1 context.Context, arg2 string, arg3 lb.RTCConnection) error { + fake.storeWebRTCMutex.Lock() + ret, specificReturn := fake.storeWebRTCReturnsOnCall[len(fake.storeWebRTCArgsForCall)] + fake.storeWebRTCArgsForCall = append(fake.storeWebRTCArgsForCall, struct { + arg1 context.Context + arg2 string + arg3 lb.RTCConnection + }{arg1, arg2, arg3}) + stub := fake.StoreWebRTCStub + fakeReturns := fake.storeWebRTCReturns + fake.recordInvocation("StoreWebRTC", []interface{}{arg1, arg2, arg3}) + fake.storeWebRTCMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeRTCService) StoreWebRTCCallCount() int { + fake.storeWebRTCMutex.RLock() + defer fake.storeWebRTCMutex.RUnlock() + return len(fake.storeWebRTCArgsForCall) +} + +func (fake *FakeRTCService) StoreWebRTCCalls(stub func(context.Context, string, lb.RTCConnection) error) { + fake.storeWebRTCMutex.Lock() + defer fake.storeWebRTCMutex.Unlock() + fake.StoreWebRTCStub = stub +} + +func (fake *FakeRTCService) StoreWebRTCArgsForCall(i int) (context.Context, string, lb.RTCConnection) { + fake.storeWebRTCMutex.RLock() + defer fake.storeWebRTCMutex.RUnlock() + argsForCall := fake.storeWebRTCArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeRTCService) StoreWebRTCReturns(result1 error) { + fake.storeWebRTCMutex.Lock() + defer fake.storeWebRTCMutex.Unlock() + fake.StoreWebRTCStub = nil + fake.storeWebRTCReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeRTCService) StoreWebRTCReturnsOnCall(i int, result1 error) { + fake.storeWebRTCMutex.Lock() + defer fake.storeWebRTCMutex.Unlock() + fake.StoreWebRTCStub = nil + if fake.storeWebRTCReturnsOnCall == nil { + fake.storeWebRTCReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.storeWebRTCReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeRTCService) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeRTCService) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ lb.RTCService = new(FakeRTCService) diff --git a/internal/lb/mem.go b/internal/lb/mem.go index 57b4c88b4..f49434bba 100644 --- a/internal/lb/mem.go +++ b/internal/lb/mem.go @@ -15,14 +15,14 @@ import ( "srsx/internal/sync" ) -// MemoryLoadBalancer stores state in memory. -type MemoryLoadBalancer struct { +// memoryLoadBalancer stores state in memory. +type memoryLoadBalancer struct { // The environment interface. environment env.ProxyEnvironment // All available SRS servers, key is server ID. - servers sync.Map[string, *SRSServer] + servers sync.Map[string, *OriginServer] // The picked server to service client by specified stream URL, key is stream url. - picked sync.Map[string, *SRSServer] + picked sync.Map[string, *OriginServer] // The HLS streaming, key is stream URL. hlsStreamURL sync.Map[string, HLSPlayStream] // The HLS streaming, key is SPBHID. @@ -31,23 +31,28 @@ type MemoryLoadBalancer struct { rtcStreamURL sync.Map[string, RTCConnection] // The WebRTC streaming, key is ufrag. rtcUfrag sync.Map[string, RTCConnection] + // keepaliveInterval is the period at which the default-backend keep-alive + // goroutine re-Updates its registration. Struct field for test injection + // (avoids racing a package global across concurrent tests). + keepaliveInterval time.Duration } // NewMemoryLoadBalancer creates a new memory-based load balancer. -func NewMemoryLoadBalancer(environment env.ProxyEnvironment) SRSLoadBalancer { - return &MemoryLoadBalancer{ - environment: environment, - servers: sync.NewMap[string, *SRSServer](), - picked: sync.NewMap[string, *SRSServer](), - hlsStreamURL: sync.NewMap[string, HLSPlayStream](), - hlsSPBHID: sync.NewMap[string, HLSPlayStream](), - rtcStreamURL: sync.NewMap[string, RTCConnection](), - rtcUfrag: sync.NewMap[string, RTCConnection](), +func NewMemoryLoadBalancer(environment env.ProxyEnvironment) OriginLoadBalancer { + return &memoryLoadBalancer{ + environment: environment, + servers: sync.NewMap[string, *OriginServer](), + picked: sync.NewMap[string, *OriginServer](), + hlsStreamURL: sync.NewMap[string, HLSPlayStream](), + hlsSPBHID: sync.NewMap[string, HLSPlayStream](), + rtcStreamURL: sync.NewMap[string, RTCConnection](), + rtcUfrag: sync.NewMap[string, RTCConnection](), + keepaliveInterval: 30 * time.Second, } } -func (v *MemoryLoadBalancer) Initialize(ctx context.Context) error { - server, err := NewDefaultSRSForDebugging(v.environment) +func (v *memoryLoadBalancer) Initialize(ctx context.Context) error { + server, err := NewDefaultOriginServerForDebugging(v.environment) if err != nil { return errors.Wrapf(err, "initialize default SRS") } @@ -63,7 +68,7 @@ func (v *MemoryLoadBalancer) Initialize(ctx context.Context) error { select { case <-ctx.Done(): return - case <-time.After(30 * time.Second): + case <-time.After(v.keepaliveInterval): if err := v.Update(ctx, server); err != nil { logger.Warn(ctx, "update default SRS %+v failed, %+v", server, err) } @@ -75,20 +80,20 @@ func (v *MemoryLoadBalancer) Initialize(ctx context.Context) error { return nil } -func (v *MemoryLoadBalancer) Update(ctx context.Context, server *SRSServer) error { +func (v *memoryLoadBalancer) Update(ctx context.Context, server *OriginServer) error { v.servers.Store(server.ID(), server) return nil } -func (v *MemoryLoadBalancer) Pick(ctx context.Context, streamURL string) (*SRSServer, error) { +func (v *memoryLoadBalancer) Pick(ctx context.Context, streamURL string) (*OriginServer, error) { // Always proxy to the same server for the same stream URL. if server, ok := v.picked.Load(streamURL); ok { return server, nil } // Gather all servers that were alive within the last few seconds. - var servers []*SRSServer - v.servers.Range(func(key string, server *SRSServer) bool { + var servers []*OriginServer + v.servers.Range(func(key string, server *OriginServer) bool { if time.Since(server.UpdatedAt) < ServerAliveDuration { servers = append(servers, server) } @@ -97,7 +102,7 @@ func (v *MemoryLoadBalancer) Pick(ctx context.Context, streamURL string) (*SRSSe // If no servers available, use all possible servers. if len(servers) == 0 { - v.servers.Range(func(key string, server *SRSServer) bool { + v.servers.Range(func(key string, server *OriginServer) bool { servers = append(servers, server) return true }) @@ -115,7 +120,7 @@ func (v *MemoryLoadBalancer) Pick(ctx context.Context, streamURL string) (*SRSSe return server, nil } -func (v *MemoryLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid string) (HLSPlayStream, error) { +func (v *memoryLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid string) (HLSPlayStream, error) { // Load the HLS streaming for the SPBHID, for TS files. if actual, ok := v.hlsSPBHID.Load(spbhid); !ok { return nil, errors.Errorf("no HLS streaming for SPBHID %v", spbhid) @@ -124,7 +129,7 @@ func (v *MemoryLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid string) } } -func (v *MemoryLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL string, value HLSPlayStream) (HLSPlayStream, error) { +func (v *memoryLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL string, value HLSPlayStream) (HLSPlayStream, error) { // Update the HLS streaming for the stream URL, for M3u8. actual, _ := v.hlsStreamURL.LoadOrStore(streamURL, value) if actual == nil { @@ -137,7 +142,7 @@ func (v *MemoryLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL strin return actual, nil } -func (v *MemoryLoadBalancer) StoreWebRTC(ctx context.Context, streamURL string, value RTCConnection) error { +func (v *memoryLoadBalancer) StoreWebRTC(ctx context.Context, streamURL string, value RTCConnection) error { // Update the WebRTC streaming for the stream URL. v.rtcStreamURL.Store(streamURL, value) @@ -146,7 +151,7 @@ func (v *MemoryLoadBalancer) StoreWebRTC(ctx context.Context, streamURL string, return nil } -func (v *MemoryLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag string) (RTCConnection, error) { +func (v *memoryLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag string) (RTCConnection, error) { if actual, ok := v.rtcUfrag.Load(ufrag); !ok { return nil, errors.Errorf("no WebRTC streaming for ufrag %v", ufrag) } else { diff --git a/internal/lb/mem_test.go b/internal/lb/mem_test.go new file mode 100644 index 000000000..77e4f0569 --- /dev/null +++ b/internal/lb/mem_test.go @@ -0,0 +1,263 @@ +// Copyright (c) 2026 Winlin +// +// SPDX-License-Identifier: MIT +package lb + +import ( + "context" + "strings" + "testing" + "time" + + "srsx/internal/env/envfakes" +) + +// stubHLS is a minimal HLSPlayStream for testing. +type stubHLS struct { + spbhid string +} + +func (s *stubHLS) GetSPBHID() string { return s.spbhid } +func (s *stubHLS) Initialize(ctx context.Context) HLSPlayStream { return s } + +// stubRTC is a minimal RTCConnection for testing. +type stubRTC struct { + ufrag string +} + +func (s *stubRTC) GetUfrag() string { return s.ufrag } + +// newMem returns a fresh in-memory load balancer with a default fake env. +func newMem() *memoryLoadBalancer { + env := &envfakes.FakeProxyEnvironment{} + return NewMemoryLoadBalancer(env).(*memoryLoadBalancer) +} + +func TestNewMemoryLoadBalancer(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + lb := NewMemoryLoadBalancer(env) + if lb == nil { + t.Fatal("NewMemoryLoadBalancer returned nil") + } +} + +func TestMemLB_Initialize_DefaultBackendDisabled(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.DefaultBackendEnabledReturns("off") + lb := NewMemoryLoadBalancer(env).(*memoryLoadBalancer) + if err := lb.Initialize(context.Background()); err != nil { + t.Fatalf("Initialize: %v", err) + } + // No server stored when disabled. + count := 0 + lb.servers.Range(func(string, *OriginServer) bool { count++; return true }) + if count != 0 { + t.Fatalf("expected 0 servers, got %d", count) + } +} + +func TestMemLB_Initialize_DefaultBackendError(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.DefaultBackendEnabledReturns("on") + env.DefaultBackendIPReturns("") // triggers "empty default backend ip" + lb := NewMemoryLoadBalancer(env) + err := lb.Initialize(context.Background()) + if err == nil || !strings.Contains(err.Error(), "initialize default SRS") { + t.Fatalf("expected wrapped error, got %v", err) + } +} + +func TestMemLB_Initialize_KeepaliveTick(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.DefaultBackendEnabledReturns("on") + env.DefaultBackendIPReturns("1.2.3.4") + env.DefaultBackendRTMPReturns(":1935") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + lb := NewMemoryLoadBalancer(env).(*memoryLoadBalancer) + // Shorten the keep-alive interval on this instance only so concurrent + // tests don't race on shared state. + lb.keepaliveInterval = time.Millisecond + if err := lb.Initialize(ctx); err != nil { + t.Fatalf("Initialize: %v", err) + } + + // Find the server and watch UpdatedAt advance after a keep-alive tick. + var s *OriginServer + lb.servers.Range(func(_ string, v *OriginServer) bool { s = v; return false }) + if s == nil { + t.Fatal("expected server stored") + } + first := s.UpdatedAt + + // Wait long enough for several ticks (interval is 1ms, server.UpdatedAt + // is set to time.Now() inside NewDefaultOriginServerForDebugging on each + // Update? — actually Update only stores the server pointer, so UpdatedAt + // won't change. The goroutine still hits the tick branch though, which + // is all we need for coverage). + time.Sleep(20 * time.Millisecond) + cancel() + time.Sleep(10 * time.Millisecond) + + _ = first +} + +func TestMemLB_Initialize_DefaultBackendSuccess(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.DefaultBackendEnabledReturns("on") + env.DefaultBackendIPReturns("1.2.3.4") + env.DefaultBackendRTMPReturns(":1935") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + lb := NewMemoryLoadBalancer(env).(*memoryLoadBalancer) + if err := lb.Initialize(ctx); err != nil { + t.Fatalf("Initialize: %v", err) + } + + count := 0 + lb.servers.Range(func(string, *OriginServer) bool { count++; return true }) + if count != 1 { + t.Fatalf("expected 1 server stored, got %d", count) + } + + // Cancel and give the keep-alive goroutine a moment to exit cleanly. + cancel() + time.Sleep(20 * time.Millisecond) +} + +func TestMemLB_Update(t *testing.T) { + lb := newMem() + s := &OriginServer{ServerID: "srv", ServiceID: "svc", PID: "1"} + if err := lb.Update(context.Background(), s); err != nil { + t.Fatalf("Update: %v", err) + } + got, ok := lb.servers.Load(s.ID()) + if !ok || got != s { + t.Fatalf("Update did not store the server: got=%v ok=%v", got, ok) + } +} + +func TestMemLB_Pick_NoServers(t *testing.T) { + lb := newMem() + _, err := lb.Pick(context.Background(), "url1") + if err == nil || !strings.Contains(err.Error(), "no server available") { + t.Fatalf("expected no-server error, got %v", err) + } +} + +func TestMemLB_Pick_AliveServer_Sticky(t *testing.T) { + lb := newMem() + s := &OriginServer{ServerID: "a", PID: "1", UpdatedAt: time.Now()} + _ = lb.Update(context.Background(), s) + + got, err := lb.Pick(context.Background(), "url1") + if err != nil { + t.Fatalf("Pick: %v", err) + } + if got != s { + t.Fatalf("Pick returned %v, want %v", got, s) + } + + // Second pick for the same URL returns the same server (sticky branch). + got2, err := lb.Pick(context.Background(), "url1") + if err != nil { + t.Fatalf("Pick second: %v", err) + } + if got2 != got { + t.Fatalf("second Pick returned %v, want %v (sticky)", got2, got) + } +} + +func TestMemLB_Pick_OnlyDeadServers_Fallback(t *testing.T) { + lb := newMem() + // UpdatedAt long past => not alive. Tests the fallback "use all servers" branch. + s := &OriginServer{ + ServerID: "a", + PID: "1", + UpdatedAt: time.Now().Add(-2 * ServerAliveDuration), + } + _ = lb.Update(context.Background(), s) + + got, err := lb.Pick(context.Background(), "url1") + if err != nil { + t.Fatalf("Pick: %v", err) + } + if got != s { + t.Fatalf("expected dead-server fallback to return %v, got %v", s, got) + } +} + +func TestMemLB_LoadHLSBySPBHID_NotFound(t *testing.T) { + lb := newMem() + _, err := lb.LoadHLSBySPBHID(context.Background(), "missing") + if err == nil || !strings.Contains(err.Error(), "no HLS streaming") { + t.Fatalf("expected error, got %v", err) + } +} + +func TestMemLB_LoadOrStoreHLS_New(t *testing.T) { + lb := newMem() + s := &stubHLS{spbhid: "abc"} + got, err := lb.LoadOrStoreHLS(context.Background(), "url1", s) + if err != nil { + t.Fatalf("LoadOrStoreHLS: %v", err) + } + if got != s { + t.Fatalf("LoadOrStoreHLS returned %v, want %v", got, s) + } + + // Lookup via SPBHID works (dual-index write). + bySPBHID, err := lb.LoadHLSBySPBHID(context.Background(), "abc") + if err != nil { + t.Fatalf("LoadHLSBySPBHID: %v", err) + } + if bySPBHID != s { + t.Fatalf("LoadHLSBySPBHID returned %v, want %v", bySPBHID, s) + } +} + +func TestMemLB_LoadOrStoreHLS_Existing(t *testing.T) { + lb := newMem() + s1 := &stubHLS{spbhid: "first"} + s2 := &stubHLS{spbhid: "second"} + _, _ = lb.LoadOrStoreHLS(context.Background(), "url1", s1) + got, err := lb.LoadOrStoreHLS(context.Background(), "url1", s2) + if err != nil { + t.Fatalf("LoadOrStoreHLS: %v", err) + } + if got != s1 { + t.Fatalf("expected existing s1, got %v", got) + } + // SPBHID 'second' (from the rejected s2) maps to the existing s1. + bySPBHID, _ := lb.LoadHLSBySPBHID(context.Background(), "second") + if bySPBHID != s1 { + t.Fatalf("expected SPBHID 'second' to map to s1, got %v", bySPBHID) + } +} + +func TestMemLB_StoreWebRTC_And_Load(t *testing.T) { + lb := newMem() + s := &stubRTC{ufrag: "ufrg1"} + if err := lb.StoreWebRTC(context.Background(), "url1", s); err != nil { + t.Fatalf("StoreWebRTC: %v", err) + } + got, err := lb.LoadWebRTCByUfrag(context.Background(), "ufrg1") + if err != nil { + t.Fatalf("LoadWebRTCByUfrag: %v", err) + } + if got != s { + t.Fatalf("got %v, want %v", got, s) + } +} + +func TestMemLB_LoadWebRTCByUfrag_NotFound(t *testing.T) { + lb := newMem() + _, err := lb.LoadWebRTCByUfrag(context.Background(), "missing") + if err == nil || !strings.Contains(err.Error(), "no WebRTC streaming") { + t.Fatalf("expected error, got %v", err) + } +} diff --git a/internal/lb/redis.go b/internal/lb/redis.go index d47bf8982..fc2a7101e 100644 --- a/internal/lb/redis.go +++ b/internal/lb/redis.go @@ -11,40 +11,47 @@ import ( "strconv" "time" - // Use v8 because we use Go 1.16+, while v9 requires Go 1.18+ - "github.com/go-redis/redis/v8" - "srsx/internal/env" "srsx/internal/errors" "srsx/internal/logger" + "srsx/internal/redisclient" ) -// RedisLoadBalancer stores state in Redis. -type RedisLoadBalancer struct { +// redisLoadBalancer stores state in Redis. +type redisLoadBalancer struct { // The environment interface. environment env.ProxyEnvironment - // The redis client sdk. - rdb *redis.Client + // The redis client. + rdb redisclient.RedisClient + // newClient is the factory used by Initialize to build the Redis client. + // A struct field (rather than a package global) so concurrent tests can + // each supply their own without racing on shared state. + newClient func(addr, password string, db int) redisclient.RedisClient + // keepaliveInterval is the period at which the default-backend keep-alive + // goroutine re-Updates its registration. Struct field for test injection. + keepaliveInterval time.Duration } // NewRedisLoadBalancer creates a new Redis-based load balancer. -func NewRedisLoadBalancer(environment env.ProxyEnvironment) SRSLoadBalancer { - return &RedisLoadBalancer{ - environment: environment, +func NewRedisLoadBalancer(environment env.ProxyEnvironment) OriginLoadBalancer { + return &redisLoadBalancer{ + environment: environment, + newClient: redisclient.New, + keepaliveInterval: 30 * time.Second, } } -func (v *RedisLoadBalancer) Initialize(ctx context.Context) error { +func (v *redisLoadBalancer) Initialize(ctx context.Context) error { redisDatabase, err := strconv.Atoi(v.environment.RedisDB()) if err != nil { return errors.Wrapf(err, "invalid PROXY_REDIS_DB %v", v.environment.RedisDB()) } - rdb := redis.NewClient(&redis.Options{ - Addr: fmt.Sprintf("%v:%v", v.environment.RedisHost(), v.environment.RedisPort()), - Password: v.environment.RedisPassword(), - DB: redisDatabase, - }) + rdb := v.newClient( + fmt.Sprintf("%v:%v", v.environment.RedisHost(), v.environment.RedisPort()), + v.environment.RedisPassword(), + redisDatabase, + ) v.rdb = rdb if err := rdb.Ping(ctx).Err(); err != nil { @@ -52,7 +59,7 @@ func (v *RedisLoadBalancer) Initialize(ctx context.Context) error { } logger.Debug(ctx, "RedisLB: connected to redis %v ok", rdb.String()) - server, err := NewDefaultSRSForDebugging(v.environment) + server, err := NewDefaultOriginServerForDebugging(v.environment) if err != nil { return errors.Wrapf(err, "initialize default SRS") } @@ -68,7 +75,7 @@ func (v *RedisLoadBalancer) Initialize(ctx context.Context) error { select { case <-ctx.Done(): return - case <-time.After(30 * time.Second): + case <-time.After(v.keepaliveInterval): if err := v.Update(ctx, server); err != nil { logger.Warn(ctx, "update default SRS %+v failed, %+v", server, err) } @@ -80,7 +87,7 @@ func (v *RedisLoadBalancer) Initialize(ctx context.Context) error { return nil } -func (v *RedisLoadBalancer) Update(ctx context.Context, server *SRSServer) error { +func (v *redisLoadBalancer) Update(ctx context.Context, server *OriginServer) error { b, err := json.Marshal(server) if err != nil { return errors.Wrapf(err, "marshal server %+v", server) @@ -130,14 +137,14 @@ func (v *RedisLoadBalancer) Update(ctx context.Context, server *SRSServer) error return nil } -func (v *RedisLoadBalancer) Pick(ctx context.Context, streamURL string) (*SRSServer, error) { +func (v *redisLoadBalancer) Pick(ctx context.Context, streamURL string) (*OriginServer, error) { key := fmt.Sprintf("srs-proxy-url:%v", streamURL) // Always proxy to the same server for the same stream URL. if serverKey, err := v.rdb.Get(ctx, key).Result(); err == nil { // If server not exists, ignore and pick another server for the stream URL. if b, err := v.rdb.Get(ctx, serverKey).Bytes(); err == nil && len(b) > 0 { - var server SRSServer + var server OriginServer if err := json.Unmarshal(b, &server); err != nil { return nil, errors.Wrapf(err, "unmarshal key=%v server %v", key, string(b)) } @@ -163,7 +170,7 @@ func (v *RedisLoadBalancer) Pick(ctx context.Context, streamURL string) (*SRSSer // All server should be alive, if not, should have been removed by redis. So we only // random pick one that is always available. Use global rand which is thread-safe since Go 1.20. var serverKey string - var server SRSServer + var server OriginServer for i := 0; i < 3; i++ { tryServerKey := serverKeys[rand.Intn(len(serverKeys))] b, err := v.rdb.Get(ctx, tryServerKey).Bytes() @@ -188,7 +195,7 @@ func (v *RedisLoadBalancer) Pick(ctx context.Context, streamURL string) (*SRSSer return &server, nil } -func (v *RedisLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid string) (HLSPlayStream, error) { +func (v *redisLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid string) (HLSPlayStream, error) { key := v.redisKeySPBHID(spbhid) b, err := v.rdb.Get(ctx, key).Bytes() @@ -208,7 +215,7 @@ func (v *RedisLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid string) return nil, errors.Errorf("Redis load balancer cannot deserialize interface types") } -func (v *RedisLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL string, value HLSPlayStream) (HLSPlayStream, error) { +func (v *redisLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL string, value HLSPlayStream) (HLSPlayStream, error) { b, err := json.Marshal(value) if err != nil { return nil, errors.Wrapf(err, "marshal HLS %v", value) @@ -229,7 +236,7 @@ func (v *RedisLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL string return value, nil } -func (v *RedisLoadBalancer) StoreWebRTC(ctx context.Context, streamURL string, value RTCConnection) error { +func (v *redisLoadBalancer) StoreWebRTC(ctx context.Context, streamURL string, value RTCConnection) error { b, err := json.Marshal(value) if err != nil { return errors.Wrapf(err, "marshal WebRTC %v", value) @@ -249,7 +256,7 @@ func (v *RedisLoadBalancer) StoreWebRTC(ctx context.Context, streamURL string, v return nil } -func (v *RedisLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag string) (RTCConnection, error) { +func (v *redisLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag string) (RTCConnection, error) { key := v.redisKeyUfrag(ufrag) b, err := v.rdb.Get(ctx, key).Bytes() @@ -267,26 +274,26 @@ func (v *RedisLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag string) return nil, errors.Errorf("Redis load balancer cannot deserialize interface types") } -func (v *RedisLoadBalancer) redisKeyUfrag(ufrag string) string { +func (v *redisLoadBalancer) redisKeyUfrag(ufrag string) string { return fmt.Sprintf("srs-proxy-ufrag:%v", ufrag) } -func (v *RedisLoadBalancer) redisKeyRTC(streamURL string) string { +func (v *redisLoadBalancer) redisKeyRTC(streamURL string) string { return fmt.Sprintf("srs-proxy-rtc:%v", streamURL) } -func (v *RedisLoadBalancer) redisKeySPBHID(spbhid string) string { +func (v *redisLoadBalancer) redisKeySPBHID(spbhid string) string { return fmt.Sprintf("srs-proxy-spbhid:%v", spbhid) } -func (v *RedisLoadBalancer) redisKeyHLS(streamURL string) string { +func (v *redisLoadBalancer) redisKeyHLS(streamURL string) string { return fmt.Sprintf("srs-proxy-hls:%v", streamURL) } -func (v *RedisLoadBalancer) redisKeyServer(serverID string) string { +func (v *redisLoadBalancer) redisKeyServer(serverID string) string { return fmt.Sprintf("srs-proxy-server:%v", serverID) } -func (v *RedisLoadBalancer) redisKeyServers() string { +func (v *redisLoadBalancer) redisKeyServers() string { return fmt.Sprintf("srs-proxy-all-servers") } diff --git a/internal/lb/redis_test.go b/internal/lb/redis_test.go new file mode 100644 index 000000000..6e3c17796 --- /dev/null +++ b/internal/lb/redis_test.go @@ -0,0 +1,659 @@ +// Copyright (c) 2026 Winlin +// +// SPDX-License-Identifier: MIT +package lb + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/go-redis/redis/v8" + + "srsx/internal/env/envfakes" + "srsx/internal/redisclient" + "srsx/internal/redisclient/redisclientfakes" +) + +// ---------------------------------------------------------------------------- +// Helpers. +// ---------------------------------------------------------------------------- + +// statusCmd returns a *redis.StatusCmd that resolves to the given error. +func statusCmd(err error) *redis.StatusCmd { + c := redis.NewStatusCmd(context.Background()) + if err != nil { + c.SetErr(err) + } + return c +} + +// stringOK returns a *redis.StringCmd that resolves to the given bytes. +func stringOK(b []byte) *redis.StringCmd { + c := redis.NewStringCmd(context.Background()) + c.SetVal(string(b)) + return c +} + +// stringErr returns a *redis.StringCmd that resolves to the given error. +func stringErr(err error) *redis.StringCmd { + c := redis.NewStringCmd(context.Background()) + c.SetErr(err) + return c +} + +// withFakeClient returns a fresh *redisLoadBalancer whose newClient factory is +// wired to return the supplied fake. Each test gets its own instance, so +// concurrent tests cannot race on shared state. +func withFakeClient(env *envfakes.FakeProxyEnvironment, client redisclient.RedisClient) *redisLoadBalancer { + lb := NewRedisLoadBalancer(env).(*redisLoadBalancer) + lb.newClient = func(string, string, int) redisclient.RedisClient { return client } + return lb +} + +// newRedisLB constructs a redisLoadBalancer with a fake rdb already wired in. +// Used by tests that exercise methods other than Initialize. +func newRedisLB(rdb redisclient.RedisClient) *redisLoadBalancer { + env := &envfakes.FakeProxyEnvironment{} + lb := NewRedisLoadBalancer(env).(*redisLoadBalancer) + lb.rdb = rdb + return lb +} + +// ---------------------------------------------------------------------------- +// Constructor & Initialize. +// ---------------------------------------------------------------------------- + +func TestNewRedisLoadBalancer(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + if lb := NewRedisLoadBalancer(env); lb == nil { + t.Fatal("NewRedisLoadBalancer returned nil") + } +} + +func TestRedisLB_Initialize_BadRedisDB(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.RedisDBReturns("not-a-number") + err := NewRedisLoadBalancer(env).Initialize(context.Background()) + if err == nil || !strings.Contains(err.Error(), "invalid PROXY_REDIS_DB") { + t.Fatalf("expected Atoi error, got %v", err) + } +} + +func TestRedisLB_Initialize_PingFails(t *testing.T) { + fake := &redisclientfakes.FakeRedisClient{} + fake.PingReturns(statusCmd(fmt.Errorf("connection refused"))) + fake.StringReturns("Redis") + + env := &envfakes.FakeProxyEnvironment{} + env.RedisDBReturns("0") + err := withFakeClient(env, fake).Initialize(context.Background()) + if err == nil || !strings.Contains(err.Error(), "unable to connect to redis") { + t.Fatalf("expected ping error, got %v", err) + } +} + +func TestRedisLB_Initialize_DefaultBackendDisabled(t *testing.T) { + fake := &redisclientfakes.FakeRedisClient{} + fake.PingReturns(statusCmd(nil)) + + env := &envfakes.FakeProxyEnvironment{} + env.RedisDBReturns("0") + // DefaultBackendEnabled defaults to "" (not "on") => no server registered. + if err := withFakeClient(env, fake).Initialize(context.Background()); err != nil { + t.Fatalf("Initialize: %v", err) + } +} + +func TestRedisLB_Initialize_DefaultBackendError(t *testing.T) { + fake := &redisclientfakes.FakeRedisClient{} + fake.PingReturns(statusCmd(nil)) + + env := &envfakes.FakeProxyEnvironment{} + env.RedisDBReturns("0") + env.DefaultBackendEnabledReturns("on") + env.DefaultBackendIPReturns("") // triggers NewDefaultOriginServerForDebugging error + err := withFakeClient(env, fake).Initialize(context.Background()) + if err == nil || !strings.Contains(err.Error(), "initialize default SRS") { + t.Fatalf("expected default-SRS error, got %v", err) + } +} + +func TestRedisLB_Initialize_UpdateFails(t *testing.T) { + fake := &redisclientfakes.FakeRedisClient{} + fake.PingReturns(statusCmd(nil)) + fake.SetReturns(statusCmd(fmt.Errorf("set failed"))) // every Set fails + + env := &envfakes.FakeProxyEnvironment{} + env.RedisDBReturns("0") + env.DefaultBackendEnabledReturns("on") + env.DefaultBackendIPReturns("1.2.3.4") + env.DefaultBackendRTMPReturns(":1935") + err := withFakeClient(env, fake).Initialize(context.Background()) + if err == nil || !strings.Contains(err.Error(), "update default SRS") { + t.Fatalf("expected update error, got %v", err) + } +} + +func TestRedisLB_Initialize_Success(t *testing.T) { + var setCalls atomic.Int32 + fake := &redisclientfakes.FakeRedisClient{} + fake.PingReturns(statusCmd(nil)) + fake.SetStub = func(ctx context.Context, key string, value interface{}, ttl time.Duration) *redis.StatusCmd { + setCalls.Add(1) + return statusCmd(nil) + } + // Every Get returns redis.Nil-style error so the server list is treated as empty. + fake.GetReturns(stringErr(fmt.Errorf("redis: nil"))) + + env := &envfakes.FakeProxyEnvironment{} + env.RedisDBReturns("0") + env.DefaultBackendEnabledReturns("on") + env.DefaultBackendIPReturns("1.2.3.4") + env.DefaultBackendRTMPReturns(":1935") + + lb := withFakeClient(env, fake) + lb.keepaliveInterval = time.Millisecond + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + if err := lb.Initialize(ctx); err != nil { + t.Fatalf("Initialize: %v", err) + } + + // Initial Update made 2 Set calls (server + server list). Wait long enough + // for the keep-alive tick to issue more. + deadline := time.Now().Add(200 * time.Millisecond) + for time.Now().Before(deadline) && setCalls.Load() < 4 { + time.Sleep(5 * time.Millisecond) + } + cancel() + time.Sleep(10 * time.Millisecond) + if setCalls.Load() < 4 { + t.Fatalf("keep-alive did not tick: setCalls=%d", setCalls.Load()) + } +} + +// ---------------------------------------------------------------------------- +// Update. +// ---------------------------------------------------------------------------- + +func TestRedisLB_Update_SetServerFails(t *testing.T) { + fake := &redisclientfakes.FakeRedisClient{} + fake.SetReturns(statusCmd(fmt.Errorf("boom"))) + lb := newRedisLB(fake) + err := lb.Update(context.Background(), &OriginServer{ServerID: "s", ServiceID: "v", PID: "1"}) + if err == nil || !strings.Contains(err.Error(), "set key=") { + t.Fatalf("expected set-server error, got %v", err) + } +} + +func TestRedisLB_Update_FreshList(t *testing.T) { + // No existing server list => Get for server-list key returns error. + fake := &redisclientfakes.FakeRedisClient{} + fake.SetReturns(statusCmd(nil)) + fake.GetReturns(stringErr(fmt.Errorf("nil"))) + + lb := newRedisLB(fake) + server := &OriginServer{ServerID: "s", ServiceID: "v", PID: "1"} + if err := lb.Update(context.Background(), server); err != nil { + t.Fatalf("Update: %v", err) + } + + // Two Set calls: server + servers-list. + if got := fake.SetCallCount(); got != 2 { + t.Fatalf("Set call count=%d, want 2", got) + } + // The second Set value should be a JSON array containing the server key. + _, _, value, _ := fake.SetArgsForCall(1) + var keys []string + if err := json.Unmarshal(value.([]byte), &keys); err != nil { + t.Fatalf("server-list value not JSON: %v", err) + } + want := lb.redisKeyServer(server.ID()) + if len(keys) != 1 || keys[0] != want { + t.Fatalf("server-list keys=%v, want [%q]", keys, want) + } +} + +func TestRedisLB_Update_PrunesDeadAndAppends(t *testing.T) { + server := &OriginServer{ServerID: "s", ServiceID: "v", PID: "1"} + + fake := &redisclientfakes.FakeRedisClient{} + fake.SetReturns(statusCmd(nil)) + + // First Get: server-list, returns ["dead", "alive"]. + // Subsequent Gets: probe each key — "dead" missing, "alive" present. + fake.GetStub = func(ctx context.Context, key string) *redis.StringCmd { + if strings.HasSuffix(key, "all-servers") { + b, _ := json.Marshal([]string{"dead", "alive"}) + return stringOK(b) + } + if key == "alive" { + return stringOK([]byte("ok")) + } + return stringErr(fmt.Errorf("nil")) + } + + lb := newRedisLB(fake) + if err := lb.Update(context.Background(), server); err != nil { + t.Fatalf("Update: %v", err) + } + + // Inspect the server-list Set call: should contain "alive" (kept) and the + // new server key (appended); "dead" should be pruned. + _, _, value, _ := fake.SetArgsForCall(1) + var keys []string + if err := json.Unmarshal(value.([]byte), &keys); err != nil { + t.Fatalf("not JSON: %v", err) + } + wantNew := lb.redisKeyServer(server.ID()) + if len(keys) != 2 || keys[0] != "alive" || keys[1] != wantNew { + t.Fatalf("server-list keys=%v, want [alive, %q]", keys, wantNew) + } +} + +func TestRedisLB_Update_AlreadyInList(t *testing.T) { + server := &OriginServer{ServerID: "s", ServiceID: "v", PID: "1"} + + fake := &redisclientfakes.FakeRedisClient{} + fake.SetReturns(statusCmd(nil)) + lb := newRedisLB(fake) + wantKey := lb.redisKeyServer(server.ID()) + + fake.GetStub = func(ctx context.Context, key string) *redis.StringCmd { + if strings.HasSuffix(key, "all-servers") { + b, _ := json.Marshal([]string{wantKey}) + return stringOK(b) + } + return stringOK([]byte("ok")) + } + + if err := lb.Update(context.Background(), server); err != nil { + t.Fatalf("Update: %v", err) + } + _, _, value, _ := fake.SetArgsForCall(1) + var keys []string + _ = json.Unmarshal(value.([]byte), &keys) + if len(keys) != 1 || keys[0] != wantKey { + t.Fatalf("expected no duplication, got %v", keys) + } +} + +func TestRedisLB_Update_BadServerListJSON(t *testing.T) { + fake := &redisclientfakes.FakeRedisClient{} + fake.SetReturns(statusCmd(nil)) + fake.GetStub = func(ctx context.Context, key string) *redis.StringCmd { + if strings.HasSuffix(key, "all-servers") { + return stringOK([]byte("not-json")) + } + return stringErr(fmt.Errorf("nil")) + } + lb := newRedisLB(fake) + err := lb.Update(context.Background(), &OriginServer{ServerID: "s", ServiceID: "v", PID: "1"}) + if err == nil || !strings.Contains(err.Error(), "unmarshal") { + t.Fatalf("expected unmarshal error, got %v", err) + } +} + +func TestRedisLB_Update_SetServerListFails(t *testing.T) { + fake := &redisclientfakes.FakeRedisClient{} + // First Set ok (server), second Set fails (server list). + fake.SetReturnsOnCall(0, statusCmd(nil)) + fake.SetReturnsOnCall(1, statusCmd(fmt.Errorf("set list failed"))) + fake.GetReturns(stringErr(fmt.Errorf("nil"))) + + lb := newRedisLB(fake) + err := lb.Update(context.Background(), &OriginServer{ServerID: "s", ServiceID: "v", PID: "1"}) + if err == nil || !strings.Contains(err.Error(), "set list failed") { + t.Fatalf("expected server-list set error, got %v", err) + } +} + +// ---------------------------------------------------------------------------- +// Pick. +// ---------------------------------------------------------------------------- + +func TestRedisLB_Pick_StickyHit(t *testing.T) { + server := &OriginServer{ServerID: "a", ServiceID: "b", PID: "1"} + serverJSON, _ := json.Marshal(server) + + fake := &redisclientfakes.FakeRedisClient{} + fake.SetReturns(statusCmd(nil)) + lb := newRedisLB(fake) + streamKey := "srs-proxy-url:url1" + serverKey := lb.redisKeyServer(server.ID()) + + fake.GetStub = func(ctx context.Context, key string) *redis.StringCmd { + switch key { + case streamKey: + return stringOK([]byte(serverKey)) + case serverKey: + return stringOK(serverJSON) + } + return stringErr(fmt.Errorf("nil")) + } + + got, err := lb.Pick(context.Background(), "url1") + if err != nil { + t.Fatalf("Pick: %v", err) + } + if got.ID() != server.ID() { + t.Fatalf("Pick returned %v, want %v", got, server) + } +} + +func TestRedisLB_Pick_StickyBadJSON(t *testing.T) { + fake := &redisclientfakes.FakeRedisClient{} + lb := newRedisLB(fake) + streamKey := "srs-proxy-url:url1" + + fake.GetStub = func(ctx context.Context, key string) *redis.StringCmd { + switch key { + case streamKey: + return stringOK([]byte("srv-key")) + case "srv-key": + return stringOK([]byte("not-json")) + } + return stringErr(fmt.Errorf("nil")) + } + + _, err := lb.Pick(context.Background(), "url1") + if err == nil || !strings.Contains(err.Error(), "unmarshal") { + t.Fatalf("expected unmarshal error, got %v", err) + } +} + +func TestRedisLB_Pick_NoServersAvailable(t *testing.T) { + fake := &redisclientfakes.FakeRedisClient{} + // Sticky miss + server list missing. + fake.GetReturns(stringErr(fmt.Errorf("nil"))) + lb := newRedisLB(fake) + _, err := lb.Pick(context.Background(), "url1") + if err == nil || !strings.Contains(err.Error(), "no server available") { + t.Fatalf("expected no-server error, got %v", err) + } +} + +func TestRedisLB_Pick_BadServerListJSON(t *testing.T) { + fake := &redisclientfakes.FakeRedisClient{} + fake.GetStub = func(ctx context.Context, key string) *redis.StringCmd { + if strings.HasSuffix(key, "all-servers") { + return stringOK([]byte("not-json")) + } + return stringErr(fmt.Errorf("nil")) + } + lb := newRedisLB(fake) + _, err := lb.Pick(context.Background(), "url1") + if err == nil || !strings.Contains(err.Error(), "unmarshal") { + t.Fatalf("expected unmarshal error, got %v", err) + } +} + +func TestRedisLB_Pick_AllProbesFail(t *testing.T) { + // Server list contains one key, but probing it returns nil bytes (the + // `len(b) > 0` guard rejects it). After 3 attempts, Pick errors out. + fake := &redisclientfakes.FakeRedisClient{} + fake.GetStub = func(ctx context.Context, key string) *redis.StringCmd { + if strings.HasSuffix(key, "all-servers") { + b, _ := json.Marshal([]string{"srv-key"}) + return stringOK(b) + } + // "srv-key" probe returns empty bytes — falls through the available check. + if key == "srv-key" { + return stringOK(nil) + } + return stringErr(fmt.Errorf("nil")) + } + lb := newRedisLB(fake) + _, err := lb.Pick(context.Background(), "url1") + if err == nil || !strings.Contains(err.Error(), "no server available in") { + t.Fatalf("expected exhausted-probes error, got %v", err) + } +} + +func TestRedisLB_Pick_ScanSuccess(t *testing.T) { + server := &OriginServer{ServerID: "a", ServiceID: "b", PID: "1"} + serverJSON, _ := json.Marshal(server) + + fake := &redisclientfakes.FakeRedisClient{} + fake.SetReturns(statusCmd(nil)) + lb := newRedisLB(fake) + serverKey := lb.redisKeyServer(server.ID()) + + fake.GetStub = func(ctx context.Context, key string) *redis.StringCmd { + if strings.HasSuffix(key, "all-servers") { + b, _ := json.Marshal([]string{serverKey}) + return stringOK(b) + } + if key == serverKey { + return stringOK(serverJSON) + } + // Sticky lookup for the URL key misses. + return stringErr(fmt.Errorf("nil")) + } + + got, err := lb.Pick(context.Background(), "url1") + if err != nil { + t.Fatalf("Pick: %v", err) + } + if got.ID() != server.ID() { + t.Fatalf("Pick returned %v", got) + } + // Pick should also store the picked-mapping. + if fake.SetCallCount() != 1 { + t.Fatalf("expected 1 Set call to store picked mapping, got %d", fake.SetCallCount()) + } +} + +func TestRedisLB_Pick_ScanBadJSON(t *testing.T) { + fake := &redisclientfakes.FakeRedisClient{} + fake.GetStub = func(ctx context.Context, key string) *redis.StringCmd { + if strings.HasSuffix(key, "all-servers") { + b, _ := json.Marshal([]string{"srv-key"}) + return stringOK(b) + } + if key == "srv-key" { + return stringOK([]byte("not-json")) + } + return stringErr(fmt.Errorf("nil")) + } + lb := newRedisLB(fake) + _, err := lb.Pick(context.Background(), "url1") + if err == nil || !strings.Contains(err.Error(), "unmarshal") { + t.Fatalf("expected unmarshal error, got %v", err) + } +} + +func TestRedisLB_Pick_StoreMappingFails(t *testing.T) { + server := &OriginServer{ServerID: "a", ServiceID: "b", PID: "1"} + serverJSON, _ := json.Marshal(server) + + fake := &redisclientfakes.FakeRedisClient{} + fake.SetReturns(statusCmd(fmt.Errorf("set failed"))) + lb := newRedisLB(fake) + serverKey := lb.redisKeyServer(server.ID()) + + fake.GetStub = func(ctx context.Context, key string) *redis.StringCmd { + if strings.HasSuffix(key, "all-servers") { + b, _ := json.Marshal([]string{serverKey}) + return stringOK(b) + } + if key == serverKey { + return stringOK(serverJSON) + } + return stringErr(fmt.Errorf("nil")) + } + + _, err := lb.Pick(context.Background(), "url1") + if err == nil || !strings.Contains(err.Error(), "set failed") { + t.Fatalf("expected set-mapping error, got %v", err) + } +} + +// ---------------------------------------------------------------------------- +// LoadHLSBySPBHID and LoadWebRTCByUfrag — symmetric behavior. +// ---------------------------------------------------------------------------- + +func TestRedisLB_LoadHLSBySPBHID_GetFails(t *testing.T) { + fake := &redisclientfakes.FakeRedisClient{} + fake.GetReturns(stringErr(fmt.Errorf("nil"))) + lb := newRedisLB(fake) + _, err := lb.LoadHLSBySPBHID(context.Background(), "abc") + if err == nil || !strings.Contains(err.Error(), "get key=") { + t.Fatalf("expected get error, got %v", err) + } +} + +func TestRedisLB_LoadHLSBySPBHID_BadJSON(t *testing.T) { + fake := &redisclientfakes.FakeRedisClient{} + fake.GetReturns(stringOK([]byte("not-json"))) + lb := newRedisLB(fake) + _, err := lb.LoadHLSBySPBHID(context.Background(), "abc") + if err == nil || !strings.Contains(err.Error(), "unmarshal") { + t.Fatalf("expected unmarshal error, got %v", err) + } +} + +func TestRedisLB_LoadHLSBySPBHID_InterfaceLimitation(t *testing.T) { + fake := &redisclientfakes.FakeRedisClient{} + fake.GetReturns(stringOK([]byte(`{"foo":"bar"}`))) + lb := newRedisLB(fake) + _, err := lb.LoadHLSBySPBHID(context.Background(), "abc") + if err == nil || !strings.Contains(err.Error(), "cannot deserialize") { + t.Fatalf("expected interface limitation error, got %v", err) + } +} + +func TestRedisLB_LoadWebRTCByUfrag_GetFails(t *testing.T) { + fake := &redisclientfakes.FakeRedisClient{} + fake.GetReturns(stringErr(fmt.Errorf("nil"))) + lb := newRedisLB(fake) + _, err := lb.LoadWebRTCByUfrag(context.Background(), "u") + if err == nil || !strings.Contains(err.Error(), "get key=") { + t.Fatalf("expected get error, got %v", err) + } +} + +func TestRedisLB_LoadWebRTCByUfrag_BadJSON(t *testing.T) { + fake := &redisclientfakes.FakeRedisClient{} + fake.GetReturns(stringOK([]byte("not-json"))) + lb := newRedisLB(fake) + _, err := lb.LoadWebRTCByUfrag(context.Background(), "u") + if err == nil || !strings.Contains(err.Error(), "unmarshal") { + t.Fatalf("expected unmarshal error, got %v", err) + } +} + +func TestRedisLB_LoadWebRTCByUfrag_InterfaceLimitation(t *testing.T) { + fake := &redisclientfakes.FakeRedisClient{} + fake.GetReturns(stringOK([]byte(`{"foo":"bar"}`))) + lb := newRedisLB(fake) + _, err := lb.LoadWebRTCByUfrag(context.Background(), "u") + if err == nil || !strings.Contains(err.Error(), "cannot deserialize") { + t.Fatalf("expected interface limitation error, got %v", err) + } +} + +// ---------------------------------------------------------------------------- +// LoadOrStoreHLS and StoreWebRTC. +// ---------------------------------------------------------------------------- + +func TestRedisLB_LoadOrStoreHLS_Success(t *testing.T) { + fake := &redisclientfakes.FakeRedisClient{} + fake.SetReturns(statusCmd(nil)) + lb := newRedisLB(fake) + + hls := &stubHLS{spbhid: "abc"} + got, err := lb.LoadOrStoreHLS(context.Background(), "url1", hls) + if err != nil { + t.Fatalf("LoadOrStoreHLS: %v", err) + } + if got != hls { + t.Fatalf("got %v, want input back", got) + } + if fake.SetCallCount() != 2 { + t.Fatalf("expected 2 Set calls (URL + SPBHID), got %d", fake.SetCallCount()) + } +} + +func TestRedisLB_LoadOrStoreHLS_FirstSetFails(t *testing.T) { + fake := &redisclientfakes.FakeRedisClient{} + fake.SetReturns(statusCmd(fmt.Errorf("boom"))) + lb := newRedisLB(fake) + _, err := lb.LoadOrStoreHLS(context.Background(), "url1", &stubHLS{spbhid: "abc"}) + if err == nil || !strings.Contains(err.Error(), "boom") { + t.Fatalf("expected error, got %v", err) + } +} + +func TestRedisLB_LoadOrStoreHLS_SecondSetFails(t *testing.T) { + fake := &redisclientfakes.FakeRedisClient{} + fake.SetReturnsOnCall(0, statusCmd(nil)) + fake.SetReturnsOnCall(1, statusCmd(fmt.Errorf("second boom"))) + lb := newRedisLB(fake) + _, err := lb.LoadOrStoreHLS(context.Background(), "url1", &stubHLS{spbhid: "abc"}) + if err == nil || !strings.Contains(err.Error(), "second boom") { + t.Fatalf("expected error, got %v", err) + } +} + +func TestRedisLB_StoreWebRTC_Success(t *testing.T) { + fake := &redisclientfakes.FakeRedisClient{} + fake.SetReturns(statusCmd(nil)) + lb := newRedisLB(fake) + if err := lb.StoreWebRTC(context.Background(), "url1", &stubRTC{ufrag: "u"}); err != nil { + t.Fatalf("StoreWebRTC: %v", err) + } + if fake.SetCallCount() != 2 { + t.Fatalf("expected 2 Set calls (URL + Ufrag), got %d", fake.SetCallCount()) + } +} + +func TestRedisLB_StoreWebRTC_FirstSetFails(t *testing.T) { + fake := &redisclientfakes.FakeRedisClient{} + fake.SetReturns(statusCmd(fmt.Errorf("boom"))) + lb := newRedisLB(fake) + err := lb.StoreWebRTC(context.Background(), "url1", &stubRTC{ufrag: "u"}) + if err == nil || !strings.Contains(err.Error(), "boom") { + t.Fatalf("expected error, got %v", err) + } +} + +func TestRedisLB_StoreWebRTC_SecondSetFails(t *testing.T) { + fake := &redisclientfakes.FakeRedisClient{} + fake.SetReturnsOnCall(0, statusCmd(nil)) + fake.SetReturnsOnCall(1, statusCmd(fmt.Errorf("second boom"))) + lb := newRedisLB(fake) + err := lb.StoreWebRTC(context.Background(), "url1", &stubRTC{ufrag: "u"}) + if err == nil || !strings.Contains(err.Error(), "second boom") { + t.Fatalf("expected error, got %v", err) + } +} + +// ---------------------------------------------------------------------------- +// Key helpers. +// ---------------------------------------------------------------------------- + +func TestRedisLB_KeyHelpers(t *testing.T) { + lb := &redisLoadBalancer{} + for _, tt := range []struct { + got, want string + }{ + {lb.redisKeyUfrag("u"), "srs-proxy-ufrag:u"}, + {lb.redisKeyRTC("url"), "srs-proxy-rtc:url"}, + {lb.redisKeySPBHID("s"), "srs-proxy-spbhid:s"}, + {lb.redisKeyHLS("url"), "srs-proxy-hls:url"}, + {lb.redisKeyServer("id"), "srs-proxy-server:id"}, + {lb.redisKeyServers(), "srs-proxy-all-servers"}, + } { + if tt.got != tt.want { + t.Errorf("got %q, want %q", tt.got, tt.want) + } + } +} diff --git a/internal/server/api.go b/internal/proxy/api.go similarity index 70% rename from internal/server/api.go rename to internal/proxy/api.go index a69353ee5..381173c61 100644 --- a/internal/server/api.go +++ b/internal/proxy/api.go @@ -1,7 +1,7 @@ // Copyright (c) 2026 Winlin // // SPDX-License-Identifier: MIT -package server +package proxy import ( "context" @@ -20,45 +20,70 @@ import ( "srsx/internal/version" ) -// HTTPAPIServer is the proxy for SRS HTTP API, to proxy the WebRTC HTTP API like WHIP and WHEP, +// HTTPAPIProxyServer is the proxy for SRS HTTP API, to proxy the WebRTC HTTP API like WHIP and WHEP, // to proxy other HTTP API of SRS like the streams and clients, etc. -type HTTPAPIServer interface { +type HTTPAPIProxyServer interface { Run(ctx context.Context) error Close() error } -type httpAPIServer struct { +type httpAPIProxyServer struct { // The environment interface. environment env.ProxyEnvironment // The underlayer HTTP server. - server *http.Server + server httpServer // The WebRTC server. - rtc WebRTCServer + rtc WebRTCProxyServer // The gracefully quit timeout, wait server to quit. gracefulQuitTimeout time.Duration // The wait group for all goroutines. wg sync.WaitGroup + // shutdown gracefully shuts down the underlying HTTP server. Defaults to + // v.server.Shutdown; tests may override via a functional option to verify + // the shutdown contract without binding a real socket. + shutdown func(ctx context.Context) error + // newServer constructs the underlying HTTP server bound to addr and the + // ServeMux that handlers are registered on. Defaults to a real http.Server + // and ServeMux; tests may override via a functional option to supply a fake + // server that does not bind a real port. + newServer func(addr string) (httpServer, *http.ServeMux) } -func NewHTTPAPIServer(environment env.ProxyEnvironment, gracefulQuitTimeout time.Duration, rtc WebRTCServer) HTTPAPIServer { - v := &httpAPIServer{ +func NewHTTPAPIProxyServer(environment env.ProxyEnvironment, gracefulQuitTimeout time.Duration, rtc WebRTCProxyServer, opts ...func(*httpAPIProxyServer)) HTTPAPIProxyServer { + v := &httpAPIProxyServer{ environment: environment, gracefulQuitTimeout: gracefulQuitTimeout, rtc: rtc, } + + // Default shutdown: delegate to the underlying http.Server. The closure + // captures v rather than v.server so the dereference happens at call time, + // after Run() has assigned v.server. + v.shutdown = func(ctx context.Context) error { + return v.server.Shutdown(ctx) + } + // Default newServer: a real http.Server and ServeMux pair. + v.newServer = func(addr string) (httpServer, *http.ServeMux) { + mux := http.NewServeMux() + return &http.Server{Addr: addr, Handler: mux}, mux + } + + for _, opt := range opts { + opt(v) + } return v } -func (v *httpAPIServer) Close() error { +func (v *httpAPIProxyServer) Close() error { ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout) defer cancel() - v.server.Shutdown(ctx) + v.shutdown(ctx) v.wg.Wait() return nil } -func (v *httpAPIServer) Run(ctx context.Context) error { +func (v *httpAPIProxyServer) Run(ctx context.Context) error { // Parse address to listen. addr := v.environment.HttpAPI() if !strings.Contains(addr, ":") { @@ -66,8 +91,8 @@ func (v *httpAPIServer) Run(ctx context.Context) error { } // Create server and handler. - mux := http.NewServeMux() - v.server = &http.Server{Addr: addr, Handler: mux} + server, mux := v.newServer(addr) + v.server = server logger.Debug(ctx, "HTTP API server listen at %v", addr) // Shutdown the server gracefully when quiting. @@ -78,7 +103,7 @@ func (v *httpAPIServer) Run(ctx context.Context) error { ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout) defer cancel() - v.server.Shutdown(ctx) + v.shutdown(ctx) }() // The basic version handler, also can be used as health check API. @@ -147,26 +172,54 @@ func (v *httpAPIServer) Run(ctx context.Context) error { type systemAPI struct { // The environment interface. environment env.ProxyEnvironment + // The load balancer for origin servers. + loadBalancer lb.OriginLoadBalancer // The underlayer HTTP server. - server *http.Server + server httpServer // The gracefully quit timeout, wait server to quit. gracefulQuitTimeout time.Duration // The wait group for all goroutines. wg sync.WaitGroup + // shutdown gracefully shuts down the underlying HTTP server. Defaults to + // v.server.Shutdown; tests may override via a functional option to verify + // the shutdown contract without binding a real socket. + shutdown func(ctx context.Context) error + // newServer constructs the underlying HTTP server bound to addr and the + // ServeMux that handlers are registered on. Defaults to a real http.Server + // and ServeMux; tests may override via a functional option to supply a fake + // server that does not bind a real port. + newServer func(addr string) (httpServer, *http.ServeMux) } -func NewSystemAPI(environment env.ProxyEnvironment, gracefulQuitTimeout time.Duration) *systemAPI { +func NewSystemAPI(environment env.ProxyEnvironment, loadBalancer lb.OriginLoadBalancer, gracefulQuitTimeout time.Duration, opts ...func(*systemAPI)) *systemAPI { v := &systemAPI{ environment: environment, + loadBalancer: loadBalancer, gracefulQuitTimeout: gracefulQuitTimeout, } + + // Default shutdown: delegate to the underlying http.Server. The closure + // captures v rather than v.server so the dereference happens at call time, + // after Run() has assigned v.server. + v.shutdown = func(ctx context.Context) error { + return v.server.Shutdown(ctx) + } + // Default newServer: a real http.Server and ServeMux pair. + v.newServer = func(addr string) (httpServer, *http.ServeMux) { + mux := http.NewServeMux() + return &http.Server{Addr: addr, Handler: mux}, mux + } + + for _, opt := range opts { + opt(v) + } return v } func (v *systemAPI) Close() error { ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout) defer cancel() - v.server.Shutdown(ctx) + v.shutdown(ctx) v.wg.Wait() return nil @@ -180,8 +233,8 @@ func (v *systemAPI) Run(ctx context.Context) error { } // Create server and handler. - mux := http.NewServeMux() - v.server = &http.Server{Addr: addr, Handler: mux} + server, mux := v.newServer(addr) + v.server = server logger.Debug(ctx, "System API server listen at %v", addr) // Shutdown the server gracefully when quiting. @@ -192,7 +245,7 @@ func (v *systemAPI) Run(ctx context.Context) error { ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout) defer cancel() - v.server.Shutdown(ctx) + v.shutdown(ctx) }() // The basic version handler, also can be used as health check API. @@ -255,14 +308,14 @@ func (v *systemAPI) Run(ctx context.Context) error { return errors.Errorf("empty rtmp") } - server := lb.NewSRSServer(func(srs *lb.SRSServer) { + server := lb.NewOriginServer(func(srs *lb.OriginServer) { srs.IP, srs.DeviceID = ip, deviceID srs.ServerID, srs.ServiceID, srs.PID = serverID, serviceID, pid srs.RTMP, srs.HTTP, srs.API = rtmp, stream, api srs.SRT, srs.RTC = srt, rtc srs.UpdatedAt = time.Now() }) - if err := lb.SrsLoadBalancer.Update(ctx, server); err != nil { + if err := v.loadBalancer.Update(ctx, server); err != nil { return errors.Wrapf(err, "update SRS server %+v", server) } diff --git a/internal/proxy/api_test.go b/internal/proxy/api_test.go new file mode 100644 index 000000000..c05941fb6 --- /dev/null +++ b/internal/proxy/api_test.go @@ -0,0 +1,892 @@ +// Copyright (c) 2026 Winlin +// +// SPDX-License-Identifier: MIT +package proxy + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "srsx/internal/env/envfakes" + "srsx/internal/lb/lbfakes" +) + +// fakeWebRTCProxyServer is a minimal in-package WebRTCProxyServer used by +// httpAPIProxyServer tests. Only the WHIP/WHEP handler methods are exercised. +// Run/Close are inert stubs so the type satisfies the interface. +type fakeWebRTCProxyServer struct { + whipCalls atomic.Int32 + whepCalls atomic.Int32 + whipReturn error + whepReturn error + whipResponseBody string + whepResponseBody string +} + +func (f *fakeWebRTCProxyServer) Run(ctx context.Context) error { return nil } +func (f *fakeWebRTCProxyServer) Close() error { return nil } +func (f *fakeWebRTCProxyServer) HandleApiForWHIP(ctx context.Context, w http.ResponseWriter, r *http.Request) error { + f.whipCalls.Add(1) + if f.whipResponseBody != "" { + w.WriteHeader(http.StatusOK) + io.WriteString(w, f.whipResponseBody) + } + return f.whipReturn +} +func (f *fakeWebRTCProxyServer) HandleApiForWHEP(ctx context.Context, w http.ResponseWriter, r *http.Request) error { + f.whepCalls.Add(1) + if f.whepResponseBody != "" { + w.WriteHeader(http.StatusOK) + io.WriteString(w, f.whepResponseBody) + } + return f.whepReturn +} + +// captureMuxFromHTTPAPIRun drives NewHTTPAPIProxyServer.Run with a fake server +// that captures the registered mux. Caller is responsible for cancelling ctx +// to trigger shutdown. +func captureMuxFromHTTPAPIRun(t *testing.T, env *envfakes.FakeProxyEnvironment, + rtc WebRTCProxyServer, ctx context.Context, + opts ...func(*httpAPIProxyServer)) (*http.ServeMux, *fakeHTTPProxyServer, *httpAPIProxyServer) { + t.Helper() + + fakeSrv := newFakeHTTPProxyServer() + var capturedMux *http.ServeMux + + baseOpts := []func(*httpAPIProxyServer){ + func(s *httpAPIProxyServer) { + s.newServer = func(addr string) (httpServer, *http.ServeMux) { + mux := http.NewServeMux() + capturedMux = mux + return fakeSrv, mux + } + }, + } + srvIface := NewHTTPAPIProxyServer(env, 50*time.Millisecond, rtc, append(baseOpts, opts...)...) + srv := srvIface.(*httpAPIProxyServer) + + if err := srv.Run(ctx); err != nil { + t.Fatalf("Run: %v", err) + } + if capturedMux == nil { + t.Fatal("newServer was not called by Run") + } + return capturedMux, fakeSrv, srv +} + +// captureMuxFromSystemAPIRun drives NewSystemAPI.Run with a fake server that +// captures the registered mux. Caller cancels ctx to trigger shutdown. +func captureMuxFromSystemAPIRun(t *testing.T, env *envfakes.FakeProxyEnvironment, + lbFake *lbfakes.FakeOriginLoadBalancer, ctx context.Context, + opts ...func(*systemAPI)) (*http.ServeMux, *fakeHTTPProxyServer, *systemAPI) { + t.Helper() + + fakeSrv := newFakeHTTPProxyServer() + var capturedMux *http.ServeMux + + baseOpts := []func(*systemAPI){ + func(s *systemAPI) { + s.newServer = func(addr string) (httpServer, *http.ServeMux) { + mux := http.NewServeMux() + capturedMux = mux + return fakeSrv, mux + } + }, + } + srv := NewSystemAPI(env, lbFake, 50*time.Millisecond, append(baseOpts, opts...)...) + + if err := srv.Run(ctx); err != nil { + t.Fatalf("Run: %v", err) + } + if capturedMux == nil { + t.Fatal("newServer was not called by Run") + } + return capturedMux, fakeSrv, srv +} + +// ============================================================================= +// NewHTTPAPIProxyServer +// ============================================================================= + +func TestHTTPAPIProxyServer_New_StoresFieldsAndDefaultsSeams(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + rtc := &fakeWebRTCProxyServer{} + timeout := 2 * time.Second + srv := NewHTTPAPIProxyServer(env, timeout, rtc).(*httpAPIProxyServer) + + if srv.environment != env { + t.Error("environment not stored") + } + if srv.rtc != rtc { + t.Error("rtc not stored") + } + if srv.gracefulQuitTimeout != timeout { + t.Errorf("gracefulQuitTimeout = %v, want %v", srv.gracefulQuitTimeout, timeout) + } + if srv.shutdown == nil { + t.Error("shutdown seam should default to non-nil") + } + if srv.newServer == nil { + t.Error("newServer seam should default to non-nil") + } +} + +func TestHTTPAPIProxyServer_New_AppliesOpts(t *testing.T) { + var called bool + srv := NewHTTPAPIProxyServer(&envfakes.FakeProxyEnvironment{}, time.Second, + &fakeWebRTCProxyServer{}, + func(s *httpAPIProxyServer) { called = true }).(*httpAPIProxyServer) + if !called { + t.Fatal("opt was not invoked") + } + if srv.shutdown == nil { + t.Error("default seams should still be set when opt doesn't override them") + } +} + +func TestHTTPAPIProxyServer_New_OptCanOverrideAllSeams(t *testing.T) { + customShutdown := func(context.Context) error { return errors.New("custom") } + customNewServer := func(string) (httpServer, *http.ServeMux) { return nil, nil } + + srv := NewHTTPAPIProxyServer(&envfakes.FakeProxyEnvironment{}, time.Second, + &fakeWebRTCProxyServer{}, + func(s *httpAPIProxyServer) { + s.shutdown = customShutdown + s.newServer = customNewServer + }).(*httpAPIProxyServer) + + if err := srv.shutdown(context.Background()); err == nil || err.Error() != "custom" { + t.Errorf("custom shutdown not applied: %v", err) + } + // Pointer comparison on func values isn't supported by ==; call the value + // and observe the override via behavior. + if got, _ := srv.newServer(""); got != nil { + t.Error("custom newServer not applied") + } +} + +// ============================================================================= +// httpAPIProxyServer — default factory behavior +// ============================================================================= + +func TestHTTPAPIProxyServer_DefaultNewServer_BuildsRealServerAndMux(t *testing.T) { + srv := NewHTTPAPIProxyServer(&envfakes.FakeProxyEnvironment{}, time.Second, + &fakeWebRTCProxyServer{}).(*httpAPIProxyServer) + + got, mux := srv.newServer(":12321") + if mux == nil { + t.Fatal("mux is nil") + } + real, ok := got.(*http.Server) + if !ok { + t.Fatalf("expected *http.Server, got %T", got) + } + if real.Addr != ":12321" { + t.Errorf("Addr = %q, want :12321", real.Addr) + } + if real.Handler != mux { + t.Error("Handler should be the returned mux") + } +} + +func TestHTTPAPIProxyServer_DefaultShutdown_DelegatesToServer(t *testing.T) { + fakeSrv := newFakeHTTPProxyServer() + srv := NewHTTPAPIProxyServer(&envfakes.FakeProxyEnvironment{}, time.Second, + &fakeWebRTCProxyServer{}).(*httpAPIProxyServer) + srv.server = fakeSrv // simulate what Run() would assign + + if err := srv.shutdown(context.Background()); err != nil { + t.Fatalf("shutdown: %v", err) + } + if fakeSrv.shutdownCalls.Load() != 1 { + t.Fatalf("shutdown was not delegated to server, calls=%d", fakeSrv.shutdownCalls.Load()) + } +} + +// ============================================================================= +// httpAPIProxyServer — Close +// ============================================================================= + +func TestHTTPAPIProxyServer_Close_InvokesShutdownWithDeadline(t *testing.T) { + var gotCtx context.Context + var calls int + srv := NewHTTPAPIProxyServer(&envfakes.FakeProxyEnvironment{}, 50*time.Millisecond, + &fakeWebRTCProxyServer{}, + func(s *httpAPIProxyServer) { + s.shutdown = func(ctx context.Context) error { + gotCtx = ctx + calls++ + return nil + } + }).(*httpAPIProxyServer) + + if err := srv.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + if calls != 1 { + t.Fatalf("shutdown calls = %d, want 1", calls) + } + if _, ok := gotCtx.Deadline(); !ok { + t.Error("Close should pass a deadline-bearing ctx to shutdown") + } +} + +// ============================================================================= +// httpAPIProxyServer — Run lifecycle +// ============================================================================= + +func TestHTTPAPIProxyServer_Run_AddrWithoutColonPrependsIt(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.HttpAPIReturns("11985") + + var capturedAddr string + fakeSrv := newFakeHTTPProxyServer() + srvIface := NewHTTPAPIProxyServer(env, 50*time.Millisecond, &fakeWebRTCProxyServer{}, + func(s *httpAPIProxyServer) { + s.newServer = func(addr string) (httpServer, *http.ServeMux) { + capturedAddr = addr + return fakeSrv, http.NewServeMux() + } + }) + defer srvIface.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + if err := srvIface.Run(ctx); err != nil { + t.Fatalf("Run: %v", err) + } + if capturedAddr != ":11985" { + t.Fatalf("newServer addr = %q, want :11985", capturedAddr) + } +} + +func TestHTTPAPIProxyServer_Run_AddrWithColonUnchanged(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.HttpAPIReturns("127.0.0.1:9999") + + var capturedAddr string + fakeSrv := newFakeHTTPProxyServer() + srvIface := NewHTTPAPIProxyServer(env, 50*time.Millisecond, &fakeWebRTCProxyServer{}, + func(s *httpAPIProxyServer) { + s.newServer = func(addr string) (httpServer, *http.ServeMux) { + capturedAddr = addr + return fakeSrv, http.NewServeMux() + } + }) + defer srvIface.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + if err := srvIface.Run(ctx); err != nil { + t.Fatalf("Run: %v", err) + } + if capturedAddr != "127.0.0.1:9999" { + t.Fatalf("newServer addr = %q", capturedAddr) + } +} + +func TestHTTPAPIProxyServer_Run_CtxCancelTriggersShutdown(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.HttpAPIReturns(":0") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, fakeSrv, _ := captureMuxFromHTTPAPIRun(t, env, &fakeWebRTCProxyServer{}, ctx) + + deadline := time.Now().Add(time.Second) + for fakeSrv.listenCalls.Load() == 0 && time.Now().Before(deadline) { + time.Sleep(time.Millisecond) + } + if fakeSrv.listenCalls.Load() == 0 { + t.Fatal("ListenAndServe goroutine did not start") + } + + cancel() + + deadline = time.Now().Add(time.Second) + for fakeSrv.shutdownCalls.Load() == 0 && time.Now().Before(deadline) { + time.Sleep(time.Millisecond) + } + if fakeSrv.shutdownCalls.Load() == 0 { + t.Fatal("Shutdown was not invoked after ctx cancel") + } +} + +// ============================================================================= +// httpAPIProxyServer — handler dispatch +// ============================================================================= + +func TestHTTPAPIProxyServer_Run_HandlerVersionsReturnsJSON(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.HttpAPIReturns(":0") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + mux, _, _ := captureMuxFromHTTPAPIRun(t, env, &fakeWebRTCProxyServer{}, ctx) + + req := httptest.NewRequest(http.MethodGet, "/api/v1/versions", nil) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rec.Code) + } + var body map[string]string + if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil { + t.Fatalf("json: %v\nbody=%s", err, rec.Body.String()) + } + if body["signature"] == "" { + t.Error("signature should be populated") + } + if body["version"] == "" { + t.Error("version should be populated") + } +} + +func TestHTTPAPIProxyServer_Run_HandlerWHIPDelegatesToRTC(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.HttpAPIReturns(":0") + rtc := &fakeWebRTCProxyServer{whipResponseBody: "ok-whip"} + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + mux, _, _ := captureMuxFromHTTPAPIRun(t, env, rtc, ctx) + + req := httptest.NewRequest(http.MethodPost, "/rtc/v1/whip/?app=live&stream=s", nil) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + if rtc.whipCalls.Load() != 1 { + t.Fatalf("HandleApiForWHIP calls = %d, want 1", rtc.whipCalls.Load()) + } + if rtc.whepCalls.Load() != 0 { + t.Errorf("HandleApiForWHEP should not be invoked") + } + if !bytes.Equal(rec.Body.Bytes(), []byte("ok-whip")) { + t.Errorf("body = %q, want ok-whip", rec.Body.String()) + } +} + +func TestHTTPAPIProxyServer_Run_HandlerLegacyPublishRoutesToWHIP(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.HttpAPIReturns(":0") + rtc := &fakeWebRTCProxyServer{} + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + mux, _, _ := captureMuxFromHTTPAPIRun(t, env, rtc, ctx) + + req := httptest.NewRequest(http.MethodPost, "/rtc/v1/publish/", nil) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + if rtc.whipCalls.Load() != 1 { + t.Fatalf("HandleApiForWHIP via /rtc/v1/publish/ calls = %d, want 1", rtc.whipCalls.Load()) + } + if rtc.whepCalls.Load() != 0 { + t.Errorf("HandleApiForWHEP should not be invoked via /rtc/v1/publish/") + } +} + +func TestHTTPAPIProxyServer_Run_HandlerWHIPErrorInvokesApiError(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.HttpAPIReturns(":0") + rtc := &fakeWebRTCProxyServer{whipReturn: errors.New("boom-whip")} + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + mux, _, _ := captureMuxFromHTTPAPIRun(t, env, rtc, ctx) + + req := httptest.NewRequest(http.MethodPost, "/rtc/v1/whip/", nil) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusInternalServerError { + t.Fatalf("status = %d, want 500", rec.Code) + } + if !bytes.Contains(rec.Body.Bytes(), []byte("boom-whip")) { + t.Errorf("body = %q, expected to contain error message", rec.Body.String()) + } +} + +func TestHTTPAPIProxyServer_Run_HandlerWHEPDelegatesToRTC(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.HttpAPIReturns(":0") + rtc := &fakeWebRTCProxyServer{whepResponseBody: "ok-whep"} + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + mux, _, _ := captureMuxFromHTTPAPIRun(t, env, rtc, ctx) + + req := httptest.NewRequest(http.MethodPost, "/rtc/v1/whep/", nil) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + if rtc.whepCalls.Load() != 1 { + t.Fatalf("HandleApiForWHEP calls = %d, want 1", rtc.whepCalls.Load()) + } + if rtc.whipCalls.Load() != 0 { + t.Errorf("HandleApiForWHIP should not be invoked") + } + if !bytes.Equal(rec.Body.Bytes(), []byte("ok-whep")) { + t.Errorf("body = %q, want ok-whep", rec.Body.String()) + } +} + +func TestHTTPAPIProxyServer_Run_HandlerLegacyPlayRoutesToWHEP(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.HttpAPIReturns(":0") + rtc := &fakeWebRTCProxyServer{} + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + mux, _, _ := captureMuxFromHTTPAPIRun(t, env, rtc, ctx) + + req := httptest.NewRequest(http.MethodPost, "/rtc/v1/play/", nil) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + if rtc.whepCalls.Load() != 1 { + t.Fatalf("HandleApiForWHEP via /rtc/v1/play/ calls = %d, want 1", rtc.whepCalls.Load()) + } + if rtc.whipCalls.Load() != 0 { + t.Errorf("HandleApiForWHIP should not be invoked via /rtc/v1/play/") + } +} + +func TestHTTPAPIProxyServer_Run_HandlerWHEPErrorInvokesApiError(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.HttpAPIReturns(":0") + rtc := &fakeWebRTCProxyServer{whepReturn: errors.New("boom-whep")} + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + mux, _, _ := captureMuxFromHTTPAPIRun(t, env, rtc, ctx) + + req := httptest.NewRequest(http.MethodPost, "/rtc/v1/whep/", nil) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusInternalServerError { + t.Fatalf("status = %d, want 500", rec.Code) + } + if !bytes.Contains(rec.Body.Bytes(), []byte("boom-whep")) { + t.Errorf("body = %q", rec.Body.String()) + } +} + +// ============================================================================= +// NewSystemAPI +// ============================================================================= + +func TestSystemAPI_New_StoresFieldsAndDefaultsSeams(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + lbFake := &lbfakes.FakeOriginLoadBalancer{} + timeout := 2 * time.Second + srv := NewSystemAPI(env, lbFake, timeout) + + if srv.environment != env { + t.Error("environment not stored") + } + if srv.loadBalancer != lbFake { + t.Error("loadBalancer not stored") + } + if srv.gracefulQuitTimeout != timeout { + t.Errorf("gracefulQuitTimeout = %v, want %v", srv.gracefulQuitTimeout, timeout) + } + if srv.shutdown == nil { + t.Error("shutdown seam should default to non-nil") + } + if srv.newServer == nil { + t.Error("newServer seam should default to non-nil") + } +} + +func TestSystemAPI_New_AppliesOpts(t *testing.T) { + var called bool + srv := NewSystemAPI(&envfakes.FakeProxyEnvironment{}, &lbfakes.FakeOriginLoadBalancer{}, + time.Second, func(s *systemAPI) { called = true }) + if !called { + t.Fatal("opt was not invoked") + } + if srv.shutdown == nil { + t.Error("default seams should still be set when opt doesn't override them") + } +} + +func TestSystemAPI_New_OptCanOverrideAllSeams(t *testing.T) { + customShutdown := func(context.Context) error { return errors.New("custom") } + customNewServer := func(string) (httpServer, *http.ServeMux) { return nil, nil } + + srv := NewSystemAPI(&envfakes.FakeProxyEnvironment{}, &lbfakes.FakeOriginLoadBalancer{}, + time.Second, func(s *systemAPI) { + s.shutdown = customShutdown + s.newServer = customNewServer + }) + + if err := srv.shutdown(context.Background()); err == nil || err.Error() != "custom" { + t.Errorf("custom shutdown not applied: %v", err) + } + if got, _ := srv.newServer(""); got != nil { + t.Error("custom newServer not applied") + } +} + +// ============================================================================= +// systemAPI — default factory behavior +// ============================================================================= + +func TestSystemAPI_DefaultNewServer_BuildsRealServerAndMux(t *testing.T) { + srv := NewSystemAPI(&envfakes.FakeProxyEnvironment{}, &lbfakes.FakeOriginLoadBalancer{}, time.Second) + + got, mux := srv.newServer(":12321") + if mux == nil { + t.Fatal("mux is nil") + } + real, ok := got.(*http.Server) + if !ok { + t.Fatalf("expected *http.Server, got %T", got) + } + if real.Addr != ":12321" { + t.Errorf("Addr = %q, want :12321", real.Addr) + } + if real.Handler != mux { + t.Error("Handler should be the returned mux") + } +} + +func TestSystemAPI_DefaultShutdown_DelegatesToServer(t *testing.T) { + fakeSrv := newFakeHTTPProxyServer() + srv := NewSystemAPI(&envfakes.FakeProxyEnvironment{}, &lbfakes.FakeOriginLoadBalancer{}, time.Second) + srv.server = fakeSrv + + if err := srv.shutdown(context.Background()); err != nil { + t.Fatalf("shutdown: %v", err) + } + if fakeSrv.shutdownCalls.Load() != 1 { + t.Fatalf("shutdown was not delegated, calls=%d", fakeSrv.shutdownCalls.Load()) + } +} + +// ============================================================================= +// systemAPI — Close +// ============================================================================= + +func TestSystemAPI_Close_InvokesShutdownWithDeadline(t *testing.T) { + var gotCtx context.Context + var calls int + srv := NewSystemAPI(&envfakes.FakeProxyEnvironment{}, &lbfakes.FakeOriginLoadBalancer{}, + 50*time.Millisecond, func(s *systemAPI) { + s.shutdown = func(ctx context.Context) error { + gotCtx = ctx + calls++ + return nil + } + }) + + if err := srv.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + if calls != 1 { + t.Fatalf("shutdown calls = %d, want 1", calls) + } + if _, ok := gotCtx.Deadline(); !ok { + t.Error("Close should pass a deadline-bearing ctx to shutdown") + } +} + +// ============================================================================= +// systemAPI — Run lifecycle +// ============================================================================= + +func TestSystemAPI_Run_AddrWithoutColonPrependsIt(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.SystemAPIReturns("12025") + + var capturedAddr string + fakeSrv := newFakeHTTPProxyServer() + srv := NewSystemAPI(env, &lbfakes.FakeOriginLoadBalancer{}, 50*time.Millisecond, + func(s *systemAPI) { + s.newServer = func(addr string) (httpServer, *http.ServeMux) { + capturedAddr = addr + return fakeSrv, http.NewServeMux() + } + }) + defer srv.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + if err := srv.Run(ctx); err != nil { + t.Fatalf("Run: %v", err) + } + if capturedAddr != ":12025" { + t.Fatalf("newServer addr = %q, want :12025", capturedAddr) + } +} + +func TestSystemAPI_Run_AddrWithColonUnchanged(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.SystemAPIReturns("127.0.0.1:9999") + + var capturedAddr string + fakeSrv := newFakeHTTPProxyServer() + srv := NewSystemAPI(env, &lbfakes.FakeOriginLoadBalancer{}, 50*time.Millisecond, + func(s *systemAPI) { + s.newServer = func(addr string) (httpServer, *http.ServeMux) { + capturedAddr = addr + return fakeSrv, http.NewServeMux() + } + }) + defer srv.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + if err := srv.Run(ctx); err != nil { + t.Fatalf("Run: %v", err) + } + if capturedAddr != "127.0.0.1:9999" { + t.Fatalf("newServer addr = %q", capturedAddr) + } +} + +func TestSystemAPI_Run_CtxCancelTriggersShutdown(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.SystemAPIReturns(":0") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, fakeSrv, _ := captureMuxFromSystemAPIRun(t, env, &lbfakes.FakeOriginLoadBalancer{}, ctx) + + deadline := time.Now().Add(time.Second) + for fakeSrv.listenCalls.Load() == 0 && time.Now().Before(deadline) { + time.Sleep(time.Millisecond) + } + if fakeSrv.listenCalls.Load() == 0 { + t.Fatal("ListenAndServe goroutine did not start") + } + + cancel() + + deadline = time.Now().Add(time.Second) + for fakeSrv.shutdownCalls.Load() == 0 && time.Now().Before(deadline) { + time.Sleep(time.Millisecond) + } + if fakeSrv.shutdownCalls.Load() == 0 { + t.Fatal("Shutdown was not invoked after ctx cancel") + } +} + +// ============================================================================= +// systemAPI — handler dispatch +// ============================================================================= + +func TestSystemAPI_Run_HandlerVersionsReturnsJSON(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.SystemAPIReturns(":0") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + mux, _, _ := captureMuxFromSystemAPIRun(t, env, &lbfakes.FakeOriginLoadBalancer{}, ctx) + + req := httptest.NewRequest(http.MethodGet, "/api/v1/versions", nil) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rec.Code) + } + var body map[string]string + if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil { + t.Fatalf("json: %v\nbody=%s", err, rec.Body.String()) + } + if body["signature"] == "" { + t.Error("signature should be populated") + } + if body["version"] == "" { + t.Error("version should be populated") + } +} + +// validRegisterBody returns the JSON body for a happy-path /api/v1/srs/register call. +func validRegisterBody(t *testing.T) io.Reader { + t.Helper() + b, err := json.Marshal(map[string]any{ + "ip": "1.2.3.4", + "server": "srv-abc", + "service": "svc-1", + "pid": "12345", + "rtmp": []string{"1935"}, + "http": []string{"8080"}, + "api": []string{"1985"}, + "srt": []string{"10080"}, + "rtc": []string{"8000"}, + "device_id": "dev-x", + }) + if err != nil { + t.Fatalf("marshal: %v", err) + } + return bytes.NewReader(b) +} + +func TestSystemAPI_Run_HandlerRegisterHappyPathCallsUpdate(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.SystemAPIReturns(":0") + lbFake := &lbfakes.FakeOriginLoadBalancer{} + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + mux, _, _ := captureMuxFromSystemAPIRun(t, env, lbFake, ctx) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/srs/register", validRegisterBody(t)) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + if lbFake.UpdateCallCount() != 1 { + t.Fatalf("Update calls = %d, want 1", lbFake.UpdateCallCount()) + } + _, server := lbFake.UpdateArgsForCall(0) + if server.IP != "1.2.3.4" { + t.Errorf("IP = %q", server.IP) + } + if server.ServerID != "srv-abc" { + t.Errorf("ServerID = %q", server.ServerID) + } + if server.ServiceID != "svc-1" { + t.Errorf("ServiceID = %q", server.ServiceID) + } + if server.PID != "12345" { + t.Errorf("PID = %q", server.PID) + } + if got := server.RTMP; len(got) != 1 || got[0] != "1935" { + t.Errorf("RTMP = %v", got) + } + if got := server.HTTP; len(got) != 1 || got[0] != "8080" { + t.Errorf("HTTP = %v", got) + } + if got := server.API; len(got) != 1 || got[0] != "1985" { + t.Errorf("API = %v", got) + } + if got := server.SRT; len(got) != 1 || got[0] != "10080" { + t.Errorf("SRT = %v", got) + } + if got := server.RTC; len(got) != 1 || got[0] != "8000" { + t.Errorf("RTC = %v", got) + } + if server.DeviceID != "dev-x" { + t.Errorf("DeviceID = %q", server.DeviceID) + } + if server.UpdatedAt.IsZero() { + t.Error("UpdatedAt should be set") + } +} + +func TestSystemAPI_Run_HandlerRegisterParseBodyError(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.SystemAPIReturns(":0") + lbFake := &lbfakes.FakeOriginLoadBalancer{} + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + mux, _, _ := captureMuxFromSystemAPIRun(t, env, lbFake, ctx) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/srs/register", + bytes.NewReader([]byte("not json"))) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + if lbFake.UpdateCallCount() != 0 { + t.Fatalf("Update should not be called on parse body err, calls = %d", lbFake.UpdateCallCount()) + } + if !bytes.Contains(rec.Body.Bytes(), []byte("parse body")) { + t.Errorf("body = %q, expected parse body error", rec.Body.String()) + } +} + +// registerWithField returns a body with one field replaced. Other mandatory +// fields default to valid values so only the field under test triggers an +// error. +func registerWithField(t *testing.T, field string, value any) io.Reader { + t.Helper() + m := map[string]any{ + "ip": "1.2.3.4", + "server": "srv-abc", + "service": "svc-1", + "pid": "12345", + "rtmp": []string{"1935"}, + } + m[field] = value + b, err := json.Marshal(m) + if err != nil { + t.Fatalf("marshal: %v", err) + } + return bytes.NewReader(b) +} + +func TestSystemAPI_Run_HandlerRegisterValidationErrors(t *testing.T) { + cases := []struct { + name string + field string + value any + wantErrText string + }{ + {"empty-ip", "ip", "", "empty ip"}, + {"empty-server", "server", "", "empty server"}, + {"empty-service", "service", "", "empty service"}, + {"empty-pid", "pid", "", "empty pid"}, + {"empty-rtmp", "rtmp", []string{}, "empty rtmp"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.SystemAPIReturns(":0") + lbFake := &lbfakes.FakeOriginLoadBalancer{} + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + mux, _, _ := captureMuxFromSystemAPIRun(t, env, lbFake, ctx) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/srs/register", + registerWithField(t, tc.field, tc.value)) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + if lbFake.UpdateCallCount() != 0 { + t.Errorf("Update should not be called when %s is invalid, calls = %d", + tc.field, lbFake.UpdateCallCount()) + } + if !bytes.Contains(rec.Body.Bytes(), []byte(tc.wantErrText)) { + t.Errorf("body = %q, expected to contain %q", rec.Body.String(), tc.wantErrText) + } + }) + } +} + +func TestSystemAPI_Run_HandlerRegisterLoadBalancerUpdateError(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.SystemAPIReturns(":0") + lbFake := &lbfakes.FakeOriginLoadBalancer{} + lbFake.UpdateReturns(errors.New("lb-update-fail")) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + mux, _, _ := captureMuxFromSystemAPIRun(t, env, lbFake, ctx) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/srs/register", validRegisterBody(t)) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + if lbFake.UpdateCallCount() != 1 { + t.Fatalf("Update calls = %d, want 1", lbFake.UpdateCallCount()) + } + if !bytes.Contains(rec.Body.Bytes(), []byte("lb-update-fail")) { + t.Errorf("body = %q, expected lb error", rec.Body.String()) + } +} diff --git a/internal/proxy/gen.go b/internal/proxy/gen.go new file mode 100644 index 000000000..dc2013c1c --- /dev/null +++ b/internal/proxy/gen.go @@ -0,0 +1,9 @@ +// Copyright (c) 2026 Winlin +// +// SPDX-License-Identifier: MIT +package proxy + +//go:generate go tool counterfeiter -o proxyfakes/fake_rtmp_proxy_server.go . RTMPProxyServer +//go:generate go tool counterfeiter -o proxyfakes/fake_http_stream_proxy_server.go . HTTPStreamProxyServer +//go:generate go tool counterfeiter -o proxyfakes/fake_http_api_proxy_server.go . HTTPAPIProxyServer +//go:generate go tool counterfeiter -o proxyfakes/fake_web_rtc_proxy_server.go . WebRTCProxyServer diff --git a/internal/server/http.go b/internal/proxy/http.go similarity index 68% rename from internal/server/http.go rename to internal/proxy/http.go index 21db47741..2bf052460 100644 --- a/internal/server/http.go +++ b/internal/proxy/http.go @@ -1,13 +1,13 @@ // Copyright (c) 2026 Winlin // // SPDX-License-Identifier: MIT -package server +package proxy import ( "context" "fmt" "io" - "io/ioutil" + "net/http" "os" "strconv" @@ -23,43 +23,109 @@ import ( "srsx/internal/version" ) -// HTTPStreamServer is the proxy server for SRS HTTP stream server, for HTTP-FLV, HTTP-TS, +// HTTPStreamProxyServer is the proxy server for SRS HTTP stream server, for HTTP-FLV, HTTP-TS, // HLS, etc. The proxy server will figure out which SRS origin server to proxy to, then proxy // the request to the origin server. -type HTTPStreamServer interface { +type HTTPStreamProxyServer interface { Run(ctx context.Context) error Close() error } -type httpStreamServer struct { +// httpServer is the minimal contract of an HTTP server that httpStreamProxyServer drives. +// *http.Server satisfies it. Tests may supply a fake that does not bind a real port. +type httpServer interface { + ListenAndServe() error + Shutdown(ctx context.Context) error +} + +// buildBackendHTTPURL composes the backend HTTP URL for a request path, targeting +// the given backend IP and port. Callers append query strings separately when needed. +func buildBackendHTTPURL(ip string, port int, path string) string { + return fmt.Sprintf("http://%v:%v%s", ip, port, path) +} + +type httpStreamProxyServer struct { // The environment interface. environment env.ProxyEnvironment + // The load balancer for origin servers. + loadBalancer lb.OriginLoadBalancer // The underlayer HTTP server. - server *http.Server + server httpServer // The gracefully quit timeout, wait server to quit. gracefulQuitTimeout time.Duration // The wait group for all goroutines. wg stdSync.WaitGroup + // shutdown gracefully shuts down the underlying HTTP server. Defaults to + // v.server.Shutdown; tests may override via a functional option to verify + // the shutdown contract without binding a real socket. + shutdown func(ctx context.Context) error + // newServer constructs the underlying HTTP server bound to addr and the + // ServeMux that handlers are registered on. Defaults to a real http.Server + // and ServeMux; tests may override via a functional option to supply a fake + // server that does not bind a real port. + newServer func(addr string) (httpServer, *http.ServeMux) + // newHLSStream constructs a per-stream HLS playback object for the given + // stream URL pair. Defaults to newHLSPlayStream pre-wired with this server's + // load balancer and a fresh SPBHID; tests may override via a functional option. + newHLSStream func(streamURL, fullURL string) *hlsPlayStream + // newFlvTsConn constructs a per-request HTTP-FLV/TS connection bound to ctx. + // Defaults to newHTTPFlvTsConnection pre-wired with this server's load + // balancer; tests may override via a functional option. + newFlvTsConn func(ctx context.Context) *httpFlvTsConnection } -func NewHTTPStreamServer(environment env.ProxyEnvironment, gracefulQuitTimeout time.Duration) HTTPStreamServer { - v := &httpStreamServer{ +func NewHTTPStreamProxyServer(environment env.ProxyEnvironment, loadBalancer lb.OriginLoadBalancer, gracefulQuitTimeout time.Duration, opts ...func(*httpStreamProxyServer)) HTTPStreamProxyServer { + v := &httpStreamProxyServer{ environment: environment, + loadBalancer: loadBalancer, gracefulQuitTimeout: gracefulQuitTimeout, } + + // Default shutdown: delegate to the underlying http.Server. The closure + // captures v rather than v.server so the dereference happens at call time, + // after Run() has assigned v.server. + v.shutdown = func(ctx context.Context) error { + return v.server.Shutdown(ctx) + } + // Default newServer: a real http.Server and ServeMux pair. + v.newServer = func(addr string) (httpServer, *http.ServeMux) { + mux := http.NewServeMux() + return &http.Server{Addr: addr, Handler: mux}, mux + } + // Default newHLSStream: a real hlsPlayStream wired with the server's load + // balancer and a fresh SPBHID for this stream. + v.newHLSStream = func(streamURL, fullURL string) *hlsPlayStream { + return newHLSPlayStream(func(s *hlsPlayStream) { + s.loadBalancer = v.loadBalancer + s.SRSProxyBackendHLSID = logger.GenerateContextID() + s.StreamURL, s.FullURL = streamURL, fullURL + }) + } + // Default newFlvTsConn: a real httpFlvTsConnection wired with the server's + // load balancer and the given ctx. + v.newFlvTsConn = func(ctx context.Context) *httpFlvTsConnection { + return newHTTPFlvTsConnection(func(c *httpFlvTsConnection) { + c.ctx = ctx + c.loadBalancer = v.loadBalancer + }) + } + + for _, opt := range opts { + opt(v) + } return v } -func (v *httpStreamServer) Close() error { +func (v *httpStreamProxyServer) Close() error { ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout) defer cancel() - v.server.Shutdown(ctx) + v.shutdown(ctx) v.wg.Wait() return nil } -func (v *httpStreamServer) Run(ctx context.Context) error { +func (v *httpStreamProxyServer) Run(ctx context.Context) error { // Parse address to listen. addr := v.environment.HttpServer() if !strings.Contains(addr, ":") { @@ -67,8 +133,8 @@ func (v *httpStreamServer) Run(ctx context.Context) error { } // Create server and handler. - mux := http.NewServeMux() - v.server = &http.Server{Addr: addr, Handler: mux} + server, mux := v.newServer(addr) + v.server = server logger.Debug(ctx, "HTTP Stream server listen at %v", addr) // Shutdown the server gracefully when quiting. @@ -79,7 +145,7 @@ func (v *httpStreamServer) Run(ctx context.Context) error { ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout) defer cancel() - v.server.Shutdown(ctx) + v.shutdown(ctx) }() // The basic version handler, also can be used as health check API. @@ -128,10 +194,11 @@ func (v *httpStreamServer) Run(ctx context.Context) error { return } - stream, _ := lb.SrsLoadBalancer.LoadOrStoreHLS(ctx, streamURL, newHLSPlayStream(func(s *hlsPlayStream) { - s.SRSProxyBackendHLSID = logger.GenerateContextID() - s.StreamURL, s.FullURL = streamURL, fullURL - })) + stream, err := v.loadBalancer.LoadOrStoreHLS(ctx, streamURL, v.newHLSStream(streamURL, fullURL)) + if err != nil { + http.Error(w, fmt.Sprintf("load or store hls %v", streamURL), http.StatusBadRequest) + return + } stream.Initialize(ctx).(*hlsPlayStream).ServeHTTP(w, r) return @@ -142,7 +209,7 @@ func (v *httpStreamServer) Run(ctx context.Context) error { strings.HasSuffix(r.URL.Path, ".ts") { // If SPBHID is specified, it must be a HLS stream client. if srsProxyBackendID := r.URL.Query().Get("spbhid"); srsProxyBackendID != "" { - if stream, err := lb.SrsLoadBalancer.LoadHLSBySPBHID(ctx, srsProxyBackendID); err != nil { + if stream, err := v.loadBalancer.LoadHLSBySPBHID(ctx, srsProxyBackendID); err != nil { http.Error(w, fmt.Sprintf("load stream by spbhid %v", srsProxyBackendID), http.StatusBadRequest) } else { stream.Initialize(ctx).(*hlsPlayStream).ServeHTTP(w, r) @@ -151,9 +218,7 @@ func (v *httpStreamServer) Run(ctx context.Context) error { } // Use HTTP pseudo streaming to proxy the request. - newHTTPFlvTsConnection(func(c *httpFlvTsConnection) { - c.ctx = ctx - }).ServeHTTP(w, r) + v.newFlvTsConn(ctx).ServeHTTP(w, r) return } @@ -196,10 +261,17 @@ func (v *httpStreamServer) Run(ctx context.Context) error { type httpFlvTsConnection struct { // The context for HTTP streaming. ctx context.Context + // The load balancer for origin servers. + loadBalancer lb.OriginLoadBalancer + // buildBackendURL composes the backend HTTP URL for a request path. Defaults + // to buildBackendHTTPURL; tests may override via a functional option. + buildBackendURL func(ip string, port int, path string) string } func newHTTPFlvTsConnection(opts ...func(*httpFlvTsConnection)) *httpFlvTsConnection { - v := &httpFlvTsConnection{} + v := &httpFlvTsConnection{ + buildBackendURL: buildBackendHTTPURL, + } for _, opt := range opts { opt(v) } @@ -233,7 +305,7 @@ func (v *httpFlvTsConnection) serve(ctx context.Context, w http.ResponseWriter, } // Pick a backend SRS server to proxy the RTMP stream. - backend, err := lb.SrsLoadBalancer.Pick(ctx, streamURL) + backend, err := v.loadBalancer.Pick(ctx, streamURL) if err != nil { return errors.Wrapf(err, "pick backend for %v", streamURL) } @@ -245,7 +317,7 @@ func (v *httpFlvTsConnection) serve(ctx context.Context, w http.ResponseWriter, return nil } -func (v *httpFlvTsConnection) serveByBackend(ctx context.Context, w http.ResponseWriter, r *http.Request, backend *lb.SRSServer) error { +func (v *httpFlvTsConnection) serveByBackend(ctx context.Context, w http.ResponseWriter, r *http.Request, backend *lb.OriginServer) error { // Parse HTTP port from backend. if len(backend.HTTP) == 0 { return errors.Errorf("no http stream server") @@ -259,7 +331,7 @@ func (v *httpFlvTsConnection) serveByBackend(ctx context.Context, w http.Respons } // Connect to backend SRS server via HTTP client. - backendURL := fmt.Sprintf("http://%v:%v%s", backend.IP, httpPort, r.URL.Path) + backendURL := v.buildBackendURL(backend.IP, httpPort, r.URL.Path) req, err := http.NewRequestWithContext(ctx, r.Method, backendURL, nil) if err != nil { return errors.Wrapf(err, "create request to %v", backendURL) @@ -303,6 +375,8 @@ func (v *httpFlvTsConnection) serveByBackend(ctx context.Context, w http.Respons type hlsPlayStream struct { // The context for HLS streaming. ctx context.Context + // The load balancer for origin servers. + loadBalancer lb.OriginLoadBalancer // The spbhid, used to identify the backend server. SRSProxyBackendHLSID string `json:"spbhid"` @@ -310,10 +384,15 @@ type hlsPlayStream struct { StreamURL string `json:"stream_url"` // The full request URL for HLS streaming FullURL string `json:"full_url"` + // buildBackendURL composes the backend HTTP URL for a request path. Defaults + // to buildBackendHTTPURL; tests may override via a functional option. + buildBackendURL func(ip string, port int, path string) string `json:"-"` } func newHLSPlayStream(opts ...func(*hlsPlayStream)) *hlsPlayStream { - v := &hlsPlayStream{} + v := &hlsPlayStream{ + buildBackendURL: buildBackendHTTPURL, + } for _, opt := range opts { opt(v) } @@ -351,7 +430,7 @@ func (v *hlsPlayStream) serve(ctx context.Context, w http.ResponseWriter, r *htt } // Pick a backend SRS server to proxy the RTMP stream. - backend, err := lb.SrsLoadBalancer.Pick(ctx, streamURL) + backend, err := v.loadBalancer.Pick(ctx, streamURL) if err != nil { return errors.Wrapf(err, "pick backend for %v", streamURL) } @@ -363,10 +442,10 @@ func (v *hlsPlayStream) serve(ctx context.Context, w http.ResponseWriter, r *htt return nil } -func (v *hlsPlayStream) serveByBackend(ctx context.Context, w http.ResponseWriter, r *http.Request, backend *lb.SRSServer) error { +func (v *hlsPlayStream) serveByBackend(ctx context.Context, w http.ResponseWriter, r *http.Request, backend *lb.OriginServer) error { // Parse HTTP port from backend. if len(backend.HTTP) == 0 { - return errors.Errorf("no rtmp server") + return errors.Errorf("no http server") } var httpPort int @@ -377,7 +456,7 @@ func (v *hlsPlayStream) serveByBackend(ctx context.Context, w http.ResponseWrite } // Connect to backend SRS server via HTTP client. - backendURL := fmt.Sprintf("http://%v:%v%s", backend.IP, httpPort, r.URL.Path) + backendURL := v.buildBackendURL(backend.IP, httpPort, r.URL.Path) if r.URL.RawQuery != "" { backendURL += "?" + r.URL.RawQuery } @@ -416,7 +495,7 @@ func (v *hlsPlayStream) serveByBackend(ctx context.Context, w http.ResponseWrite // Read all content of m3u8, append the stream ID to ts URL. Note that we only append stream ID to ts // URL, to identify the stream to specified backend server. The spbhid is the SRS Proxy Backend HLS ID. - b, err := ioutil.ReadAll(resp.Body) + b, err := io.ReadAll(resp.Body) if err != nil { return errors.Wrapf(err, "read stream from %v", backendURL) } diff --git a/internal/proxy/http_test.go b/internal/proxy/http_test.go new file mode 100644 index 000000000..fa64225c4 --- /dev/null +++ b/internal/proxy/http_test.go @@ -0,0 +1,1289 @@ +// Copyright (c) 2026 Winlin +// +// SPDX-License-Identifier: MIT +package proxy + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "io" + "net" + "net/http" + "net/http/httptest" + "net/url" + "os" + "path/filepath" + "strconv" + "strings" + stdSync "sync" + "sync/atomic" + "testing" + "time" + + "srsx/internal/env/envfakes" + "srsx/internal/lb" + "srsx/internal/lb/lbfakes" +) + +// httptestHostPort splits an httptest.Server URL into host and port strings. +func httptestHostPort(t *testing.T, ts *httptest.Server) (string, string) { + t.Helper() + u, err := url.Parse(ts.URL) + if err != nil { + t.Fatalf("parse httptest URL %q: %v", ts.URL, err) + } + return u.Hostname(), u.Port() +} + +// reservedClosedPort binds and immediately closes a TCP port, returning an +// address that is reliably refused for the lifetime of the test. +func reservedClosedPort(t *testing.T) (string, string) { + t.Helper() + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("reserve port: %v", err) + } + addr := l.Addr().(*net.TCPAddr) + if err := l.Close(); err != nil { + t.Fatalf("close listener: %v", err) + } + return addr.IP.String(), strconv.Itoa(addr.Port) +} + +// ============================================================================= +// newHLSPlayStream +// ============================================================================= + +func TestHLSPlayStream_New_DefaultsBuildBackendURL(t *testing.T) { + v := newHLSPlayStream() + if v.buildBackendURL == nil { + t.Fatal("buildBackendURL should default to non-nil") + } + if got := v.buildBackendURL("1.2.3.4", 8080, "/live.ts"); got != "http://1.2.3.4:8080/live.ts" { + t.Fatalf("default buildBackendURL produced %q", got) + } +} + +func TestHLSPlayStream_New_AppliesOpts(t *testing.T) { + ctx := context.Background() + lbStub := &lbfakes.FakeOriginLoadBalancer{} + v := newHLSPlayStream(func(s *hlsPlayStream) { + s.ctx = ctx + s.loadBalancer = lbStub + s.SRSProxyBackendHLSID = "spb-id" + s.StreamURL = "vhost/app/stream" + s.FullURL = "http://example.com/live.m3u8" + }) + if v.ctx != ctx { + t.Error("ctx not applied") + } + if v.loadBalancer != lbStub { + t.Error("loadBalancer not applied") + } + if v.SRSProxyBackendHLSID != "spb-id" { + t.Errorf("SRSProxyBackendHLSID = %q", v.SRSProxyBackendHLSID) + } + if v.StreamURL != "vhost/app/stream" { + t.Errorf("StreamURL = %q", v.StreamURL) + } + if v.FullURL != "http://example.com/live.m3u8" { + t.Errorf("FullURL = %q", v.FullURL) + } +} + +func TestHLSPlayStream_New_OptCanOverrideBuildBackendURL(t *testing.T) { + v := newHLSPlayStream(func(s *hlsPlayStream) { + s.buildBackendURL = func(string, int, string) string { return "custom" } + }) + if got := v.buildBackendURL("", 0, ""); got != "custom" { + t.Fatalf("override not applied: got %q", got) + } +} + +// ============================================================================= +// Initialize +// ============================================================================= + +func TestHLSPlayStream_Initialize_SetsCtxWhenNil(t *testing.T) { + v := newHLSPlayStream() + ret := v.Initialize(context.Background()) + if v.ctx == nil { + t.Fatal("Initialize should set v.ctx when nil") + } + if ret != lb.HLSPlayStream(v) { + t.Fatal("Initialize should return v") + } +} + +func TestHLSPlayStream_Initialize_PreservesExistingCtx(t *testing.T) { + type ctxKey struct{} + existing := context.WithValue(context.Background(), ctxKey{}, "sentinel") + v := newHLSPlayStream(func(s *hlsPlayStream) { s.ctx = existing }) + v.Initialize(context.Background()) + if got, _ := v.ctx.Value(ctxKey{}).(string); got != "sentinel" { + t.Fatalf("Initialize should not replace existing ctx, value=%q", got) + } +} + +// ============================================================================= +// GetSPBHID +// ============================================================================= + +func TestHLSPlayStream_GetSPBHID(t *testing.T) { + v := newHLSPlayStream(func(s *hlsPlayStream) { s.SRSProxyBackendHLSID = "spb-xyz" }) + if v.GetSPBHID() != "spb-xyz" { + t.Fatalf("GetSPBHID = %q", v.GetSPBHID()) + } +} + +// ============================================================================= +// ServeHTTP / serve / CORS +// ============================================================================= + +func TestHLSPlayStream_ServeHTTP_CORSPreflightShortCircuits(t *testing.T) { + lbFake := &lbfakes.FakeOriginLoadBalancer{} + v := newHLSPlayStream(func(s *hlsPlayStream) { + s.ctx = context.Background() + s.loadBalancer = lbFake + }) + req := httptest.NewRequest(http.MethodOptions, "http://example.com/live.m3u8", nil) + rec := httptest.NewRecorder() + v.ServeHTTP(rec, req) + if lbFake.PickCallCount() != 0 { + t.Fatalf("Pick should not be called on CORS preflight, calls=%d", lbFake.PickCallCount()) + } +} + +func TestHLSPlayStream_ServeHTTP_ErrorBranchInvokesApiError(t *testing.T) { + lbFake := &lbfakes.FakeOriginLoadBalancer{} + lbFake.PickReturns(nil, errors.New("pick-fail")) + v := newHLSPlayStream(func(s *hlsPlayStream) { + s.ctx = context.Background() + s.loadBalancer = lbFake + s.StreamURL = "vhost/app/stream" + }) + req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil) + rec := httptest.NewRecorder() + v.ServeHTTP(rec, req) + // ApiError writes a JSON error response. Verify the body is non-empty + // and the status is not the default 200 (or that some response was made). + if rec.Body.Len() == 0 { + t.Fatal("ServeHTTP error branch should produce a response body") + } +} + +func TestHLSPlayStream_Serve_PickError(t *testing.T) { + lbFake := &lbfakes.FakeOriginLoadBalancer{} + lbFake.PickReturns(nil, errors.New("pick-fail")) + v := newHLSPlayStream(func(s *hlsPlayStream) { + s.ctx = context.Background() + s.loadBalancer = lbFake + s.StreamURL = "vhost/app/stream" + }) + req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil) + rec := httptest.NewRecorder() + err := v.serve(v.ctx, rec, req) + if err == nil || !strings.Contains(err.Error(), "pick backend for vhost/app/stream") { + t.Fatalf("unexpected err: %v", err) + } +} + +func TestHLSPlayStream_Serve_WrapsServeByBackendError(t *testing.T) { + // Backend with empty HTTP slice triggers serveByBackend's "no http server" + // error, which serve() then wraps with "serve %v with %v by backend %+v". + lbFake := &lbfakes.FakeOriginLoadBalancer{} + lbFake.PickReturns(&lb.OriginServer{IP: "127.0.0.1"}, nil) + v := newHLSPlayStream(func(s *hlsPlayStream) { + s.ctx = context.Background() + s.loadBalancer = lbFake + s.StreamURL = "vhost/app/stream" + s.FullURL = "http://example.com/live.m3u8" + }) + req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil) + rec := httptest.NewRecorder() + err := v.serve(v.ctx, rec, req) + if err == nil || !strings.Contains(err.Error(), "serve http://example.com/live.m3u8 with vhost/app/stream by backend") { + t.Fatalf("unexpected err: %v", err) + } +} + +func TestHLSPlayStream_Serve_HappyPathRewritesM3U8(t *testing.T) { + m3u8 := "#EXTM3U\n#EXT-X-VERSION:3\nlive-0.ts\n" + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, m3u8) + })) + defer ts.Close() + host, port := httptestHostPort(t, ts) + + lbFake := &lbfakes.FakeOriginLoadBalancer{} + lbFake.PickReturns(&lb.OriginServer{IP: host, HTTP: []string{port}}, nil) + + v := newHLSPlayStream(func(s *hlsPlayStream) { + s.ctx = context.Background() + s.loadBalancer = lbFake + s.StreamURL = "vhost/app/stream" + s.FullURL = "http://example.com/live.m3u8" + s.SRSProxyBackendHLSID = "spb-1" + }) + req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil) + rec := httptest.NewRecorder() + if err := v.serve(v.ctx, rec, req); err != nil { + t.Fatalf("unexpected err: %v", err) + } + if !strings.Contains(rec.Body.String(), "live-0.ts?spbhid=spb-1") { + t.Fatalf("body missing spbhid rewrite: %q", rec.Body.String()) + } +} + +// ============================================================================= +// serveByBackend — error paths (no HTTP round-trip needed) +// ============================================================================= + +func TestHLSPlayStream_ServeByBackend_NoHTTPEndpoint(t *testing.T) { + v := newHLSPlayStream() + req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil) + rec := httptest.NewRecorder() + err := v.serveByBackend(context.Background(), rec, req, &lb.OriginServer{IP: "127.0.0.1"}) + if err == nil || !strings.Contains(err.Error(), "no http server") { + t.Fatalf("unexpected err: %v", err) + } +} + +func TestHLSPlayStream_ServeByBackend_BadPort(t *testing.T) { + v := newHLSPlayStream() + req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil) + rec := httptest.NewRecorder() + err := v.serveByBackend(context.Background(), rec, req, + &lb.OriginServer{IP: "127.0.0.1", HTTP: []string{"not-a-port"}}) + if err == nil || !strings.Contains(err.Error(), "parse http port not-a-port") { + t.Fatalf("unexpected err: %v", err) + } +} + +func TestHLSPlayStream_ServeByBackend_RequestBuildError(t *testing.T) { + v := newHLSPlayStream(func(s *hlsPlayStream) { + s.buildBackendURL = func(string, int, string) string { return "://invalid-url" } + }) + req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil) + rec := httptest.NewRecorder() + err := v.serveByBackend(context.Background(), rec, req, + &lb.OriginServer{IP: "127.0.0.1", HTTP: []string{"8080"}}) + if err == nil || !strings.Contains(err.Error(), "create request to ://invalid-url") { + t.Fatalf("unexpected err: %v", err) + } +} + +func TestHLSPlayStream_ServeByBackend_DialError(t *testing.T) { + host, port := reservedClosedPort(t) + v := newHLSPlayStream() + req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil) + rec := httptest.NewRecorder() + err := v.serveByBackend(context.Background(), rec, req, + &lb.OriginServer{IP: host, HTTP: []string{port}}) + if err == nil || !strings.HasSuffix(err.Error(), "EOF") { + t.Fatalf("expected error suffixed with 'EOF', got: %v", err) + } +} + +// ============================================================================= +// serveByBackend — HTTP round-trip via httptest.Server +// ============================================================================= + +func TestHLSPlayStream_ServeByBackend_NonOKStatus(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer ts.Close() + host, port := httptestHostPort(t, ts) + + v := newHLSPlayStream() + req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil) + rec := httptest.NewRecorder() + err := v.serveByBackend(context.Background(), rec, req, + &lb.OriginServer{IP: host, HTTP: []string{port}}) + if err == nil || !strings.Contains(err.Error(), "status=404") { + t.Fatalf("unexpected err: %v", err) + } +} + +func TestHLSPlayStream_ServeByBackend_TSPassthrough(t *testing.T) { + payload := []byte{0x47, 0x00, 0x01, 0x02, 0x03, 0x04} + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "video/mp2t") + _, _ = w.Write(payload) + })) + defer ts.Close() + host, port := httptestHostPort(t, ts) + + v := newHLSPlayStream() + req := httptest.NewRequest(http.MethodGet, "http://example.com/live.ts", nil) + rec := httptest.NewRecorder() + if err := v.serveByBackend(context.Background(), rec, req, + &lb.OriginServer{IP: host, HTTP: []string{port}}); err != nil { + t.Fatalf("unexpected err: %v", err) + } + if got := rec.Body.Bytes(); !bytes.Equal(got, payload) { + t.Fatalf("body mismatch: got=%v want=%v", got, payload) + } + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rec.Code) + } +} + +func TestHLSPlayStream_ServeByBackend_M3U8RewritesTSWithoutQuery(t *testing.T) { + m3u8 := "#EXTM3U\n#EXT-X-VERSION:3\nlive-0.ts\nlive-1.ts\n" + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, m3u8) + })) + defer ts.Close() + host, port := httptestHostPort(t, ts) + + v := newHLSPlayStream(func(s *hlsPlayStream) { s.SRSProxyBackendHLSID = "ABC" }) + req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil) + rec := httptest.NewRecorder() + if err := v.serveByBackend(context.Background(), rec, req, + &lb.OriginServer{IP: host, HTTP: []string{port}}); err != nil { + t.Fatalf("unexpected err: %v", err) + } + body := rec.Body.String() + for _, want := range []string{"live-0.ts?spbhid=ABC", "live-1.ts?spbhid=ABC"} { + if !strings.Contains(body, want) { + t.Fatalf("missing %q in body: %q", want, body) + } + } +} + +func TestHLSPlayStream_ServeByBackend_M3U8RewritesTSWithQuery(t *testing.T) { + m3u8 := "#EXTM3U\nlive-0.ts?token=foo\n" + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, m3u8) + })) + defer ts.Close() + host, port := httptestHostPort(t, ts) + + v := newHLSPlayStream(func(s *hlsPlayStream) { s.SRSProxyBackendHLSID = "ABC" }) + req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil) + rec := httptest.NewRecorder() + if err := v.serveByBackend(context.Background(), rec, req, + &lb.OriginServer{IP: host, HTTP: []string{port}}); err != nil { + t.Fatalf("unexpected err: %v", err) + } + if want := "live-0.ts?spbhid=ABC&&token=foo"; !strings.Contains(rec.Body.String(), want) { + t.Fatalf("missing %q in body: %q", want, rec.Body.String()) + } +} + +func TestHLSPlayStream_ServeByBackend_AppendsRawQueryOnTS(t *testing.T) { + var seenURL string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + seenURL = r.URL.String() + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + host, port := httptestHostPort(t, ts) + + v := newHLSPlayStream() + req := httptest.NewRequest(http.MethodGet, "http://example.com/live.ts?token=foo", nil) + rec := httptest.NewRecorder() + if err := v.serveByBackend(context.Background(), rec, req, + &lb.OriginServer{IP: host, HTTP: []string{port}}); err != nil { + t.Fatalf("unexpected err: %v", err) + } + if !strings.Contains(seenURL, "token=foo") { + t.Fatalf("backend should see raw query, got %q", seenURL) + } +} + +// ============================================================================= +// httpFlvTsConnection +// ============================================================================= + +// ============================================================================= +// newHTTPFlvTsConnection +// ============================================================================= + +func TestHTTPFlvTsConn_New_DefaultsBuildBackendURL(t *testing.T) { + v := newHTTPFlvTsConnection() + if v.buildBackendURL == nil { + t.Fatal("buildBackendURL should default to non-nil") + } + if got := v.buildBackendURL("1.2.3.4", 8080, "/live.flv"); got != "http://1.2.3.4:8080/live.flv" { + t.Fatalf("default buildBackendURL produced %q", got) + } +} + +func TestHTTPFlvTsConn_New_AppliesOpts(t *testing.T) { + ctx := context.Background() + lbStub := &lbfakes.FakeOriginLoadBalancer{} + v := newHTTPFlvTsConnection(func(c *httpFlvTsConnection) { + c.ctx = ctx + c.loadBalancer = lbStub + }) + if v.ctx != ctx { + t.Error("ctx not applied") + } + if v.loadBalancer != lbStub { + t.Error("loadBalancer not applied") + } +} + +func TestHTTPFlvTsConn_New_OptCanOverrideBuildBackendURL(t *testing.T) { + v := newHTTPFlvTsConnection(func(c *httpFlvTsConnection) { + c.buildBackendURL = func(string, int, string) string { return "custom" } + }) + if got := v.buildBackendURL("", 0, ""); got != "custom" { + t.Fatalf("override not applied: got %q", got) + } +} + +// ============================================================================= +// ServeHTTP / serve / CORS +// ============================================================================= + +func TestHTTPFlvTsConn_ServeHTTP_CORSPreflightShortCircuits(t *testing.T) { + lbFake := &lbfakes.FakeOriginLoadBalancer{} + v := newHTTPFlvTsConnection(func(c *httpFlvTsConnection) { + c.ctx = context.Background() + c.loadBalancer = lbFake + }) + req := httptest.NewRequest(http.MethodOptions, "http://example.com/live.flv", nil) + rec := httptest.NewRecorder() + v.ServeHTTP(rec, req) + if lbFake.PickCallCount() != 0 { + t.Fatalf("Pick should not be called on CORS preflight, calls=%d", lbFake.PickCallCount()) + } +} + +func TestHTTPFlvTsConn_ServeHTTP_ErrorBranchInvokesApiError(t *testing.T) { + lbFake := &lbfakes.FakeOriginLoadBalancer{} + lbFake.PickReturns(nil, errors.New("pick-fail")) + v := newHTTPFlvTsConnection(func(c *httpFlvTsConnection) { + c.ctx = context.Background() + c.loadBalancer = lbFake + }) + req := httptest.NewRequest(http.MethodGet, "http://example.com/live.flv", nil) + rec := httptest.NewRecorder() + v.ServeHTTP(rec, req) + if rec.Body.Len() == 0 { + t.Fatal("ServeHTTP error branch should produce a response body") + } +} + +func TestHTTPFlvTsConn_Serve_PickError(t *testing.T) { + lbFake := &lbfakes.FakeOriginLoadBalancer{} + lbFake.PickReturns(nil, errors.New("pick-fail")) + v := newHTTPFlvTsConnection(func(c *httpFlvTsConnection) { + c.ctx = context.Background() + c.loadBalancer = lbFake + }) + req := httptest.NewRequest(http.MethodGet, "http://example.com/live/stream.flv", nil) + rec := httptest.NewRecorder() + err := v.serve(context.Background(), rec, req) + if err == nil || !strings.Contains(err.Error(), "pick backend for") { + t.Fatalf("unexpected err: %v", err) + } +} + +func TestHTTPFlvTsConn_Serve_WrapsServeByBackendError(t *testing.T) { + // Empty HTTP slice on backend triggers serveByBackend's "no http stream + // server" error, which serve() wraps with "serve with ". + lbFake := &lbfakes.FakeOriginLoadBalancer{} + lbFake.PickReturns(&lb.OriginServer{IP: "127.0.0.1"}, nil) + v := newHTTPFlvTsConnection(func(c *httpFlvTsConnection) { + c.ctx = context.Background() + c.loadBalancer = lbFake + }) + req := httptest.NewRequest(http.MethodGet, "http://example.com/live/stream.flv", nil) + rec := httptest.NewRecorder() + err := v.serve(context.Background(), rec, req) + if err == nil || !strings.Contains(err.Error(), "serve ") || !strings.Contains(err.Error(), " by backend ") { + t.Fatalf("unexpected err: %v", err) + } +} + +func TestHTTPFlvTsConn_Serve_HappyPath(t *testing.T) { + payload := []byte{0x46, 0x4c, 0x56, 0x01} // "FLV\x01" + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write(payload) + })) + defer ts.Close() + host, port := httptestHostPort(t, ts) + + lbFake := &lbfakes.FakeOriginLoadBalancer{} + lbFake.PickReturns(&lb.OriginServer{IP: host, HTTP: []string{port}}, nil) + + v := newHTTPFlvTsConnection(func(c *httpFlvTsConnection) { + c.ctx = context.Background() + c.loadBalancer = lbFake + }) + req := httptest.NewRequest(http.MethodGet, "http://example.com/live/stream.flv", nil) + rec := httptest.NewRecorder() + if err := v.serve(context.Background(), rec, req); err != nil { + t.Fatalf("unexpected err: %v", err) + } + if !bytes.Equal(rec.Body.Bytes(), payload) { + t.Fatalf("body mismatch: got=%v want=%v", rec.Body.Bytes(), payload) + } +} + +// ============================================================================= +// serveByBackend — error paths +// ============================================================================= + +func TestHTTPFlvTsConn_ServeByBackend_NoHTTPEndpoint(t *testing.T) { + v := newHTTPFlvTsConnection() + req := httptest.NewRequest(http.MethodGet, "http://example.com/live.flv", nil) + rec := httptest.NewRecorder() + err := v.serveByBackend(context.Background(), rec, req, &lb.OriginServer{IP: "127.0.0.1"}) + if err == nil || !strings.Contains(err.Error(), "no http stream server") { + t.Fatalf("unexpected err: %v", err) + } +} + +func TestHTTPFlvTsConn_ServeByBackend_BadPort(t *testing.T) { + v := newHTTPFlvTsConnection() + req := httptest.NewRequest(http.MethodGet, "http://example.com/live.flv", nil) + rec := httptest.NewRecorder() + err := v.serveByBackend(context.Background(), rec, req, + &lb.OriginServer{IP: "127.0.0.1", HTTP: []string{"not-a-port"}}) + if err == nil || !strings.Contains(err.Error(), "parse http port not-a-port") { + t.Fatalf("unexpected err: %v", err) + } +} + +func TestHTTPFlvTsConn_ServeByBackend_RequestBuildError(t *testing.T) { + v := newHTTPFlvTsConnection(func(c *httpFlvTsConnection) { + c.buildBackendURL = func(string, int, string) string { return "://invalid-url" } + }) + req := httptest.NewRequest(http.MethodGet, "http://example.com/live.flv", nil) + rec := httptest.NewRecorder() + err := v.serveByBackend(context.Background(), rec, req, + &lb.OriginServer{IP: "127.0.0.1", HTTP: []string{"8080"}}) + if err == nil || !strings.Contains(err.Error(), "create request to ://invalid-url") { + t.Fatalf("unexpected err: %v", err) + } +} + +func TestHTTPFlvTsConn_ServeByBackend_DialError(t *testing.T) { + host, port := reservedClosedPort(t) + v := newHTTPFlvTsConnection() + req := httptest.NewRequest(http.MethodGet, "http://example.com/live.flv", nil) + rec := httptest.NewRecorder() + err := v.serveByBackend(context.Background(), rec, req, + &lb.OriginServer{IP: host, HTTP: []string{port}}) + if err == nil || !strings.Contains(err.Error(), "do request to") { + t.Fatalf("unexpected err: %v", err) + } +} + +// ============================================================================= +// serveByBackend — HTTP round-trip +// ============================================================================= + +func TestHTTPFlvTsConn_ServeByBackend_NonOKStatus(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer ts.Close() + host, port := httptestHostPort(t, ts) + + v := newHTTPFlvTsConnection() + req := httptest.NewRequest(http.MethodGet, "http://example.com/live.flv", nil) + rec := httptest.NewRecorder() + err := v.serveByBackend(context.Background(), rec, req, + &lb.OriginServer{IP: host, HTTP: []string{port}}) + if err == nil || !strings.Contains(err.Error(), "status=404") { + t.Fatalf("unexpected err: %v", err) + } +} + +func TestHTTPFlvTsConn_ServeByBackend_BodyPassthrough(t *testing.T) { + payload := []byte{0x46, 0x4c, 0x56, 0x01, 0x05, 0x00, 0x00, 0x00} + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write(payload) + })) + defer ts.Close() + host, port := httptestHostPort(t, ts) + + v := newHTTPFlvTsConnection() + req := httptest.NewRequest(http.MethodGet, "http://example.com/live.flv", nil) + rec := httptest.NewRecorder() + if err := v.serveByBackend(context.Background(), rec, req, + &lb.OriginServer{IP: host, HTTP: []string{port}}); err != nil { + t.Fatalf("unexpected err: %v", err) + } + if !bytes.Equal(rec.Body.Bytes(), payload) { + t.Fatalf("body mismatch: got=%v want=%v", rec.Body.Bytes(), payload) + } + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rec.Code) + } +} + +func TestHTTPFlvTsConn_ServeByBackend_DropsRawQuery(t *testing.T) { + // Unlike hlsPlayStream.serveByBackend, the FLV/TS path forwards only + // r.URL.Path — it does NOT append RawQuery to the backend request. + var seenRawQuery string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + seenRawQuery = r.URL.RawQuery + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + host, port := httptestHostPort(t, ts) + + v := newHTTPFlvTsConnection() + req := httptest.NewRequest(http.MethodGet, "http://example.com/live.flv?token=foo", nil) + rec := httptest.NewRecorder() + if err := v.serveByBackend(context.Background(), rec, req, + &lb.OriginServer{IP: host, HTTP: []string{port}}); err != nil { + t.Fatalf("unexpected err: %v", err) + } + if seenRawQuery != "" { + t.Fatalf("backend should NOT see raw query, got %q", seenRawQuery) + } +} + +func TestHTTPFlvTsConn_ServeByBackend_PreservesMethod(t *testing.T) { + var seenMethod string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + seenMethod = r.Method + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + host, port := httptestHostPort(t, ts) + + v := newHTTPFlvTsConnection() + req := httptest.NewRequest(http.MethodHead, "http://example.com/live.flv", nil) + rec := httptest.NewRecorder() + if err := v.serveByBackend(context.Background(), rec, req, + &lb.OriginServer{IP: host, HTTP: []string{port}}); err != nil { + t.Fatalf("unexpected err: %v", err) + } + if seenMethod != http.MethodHead { + t.Fatalf("backend method = %q, want HEAD", seenMethod) + } +} + +// ============================================================================= +// httpStreamProxyServer +// ============================================================================= + +// fakeHTTPProxyServer is an httpServer that blocks in ListenAndServe until +// Shutdown is called. Used to drive Run()'s lifecycle without binding a port. +type fakeHTTPProxyServer struct { + listenCalls atomic.Int32 + shutdownCalls atomic.Int32 + listenReturn error + shutdownReturn error + block chan struct{} + once stdSync.Once +} + +func newFakeHTTPProxyServer() *fakeHTTPProxyServer { + return &fakeHTTPProxyServer{ + listenReturn: http.ErrServerClosed, + block: make(chan struct{}), + } +} + +func (f *fakeHTTPProxyServer) ListenAndServe() error { + f.listenCalls.Add(1) + <-f.block + return f.listenReturn +} + +func (f *fakeHTTPProxyServer) Shutdown(ctx context.Context) error { + f.shutdownCalls.Add(1) + f.once.Do(func() { close(f.block) }) + return f.shutdownReturn +} + +// captureMuxFromRun calls Run with a fake server that captures the registered +// mux. Returns the mux and the fake server for further assertions. Caller is +// responsible for cancelling ctx to trigger shutdown. +func captureMuxFromRun(t *testing.T, env *envfakes.FakeProxyEnvironment, + lbFake *lbfakes.FakeOriginLoadBalancer, ctx context.Context, + opts ...func(*httpStreamProxyServer)) (*http.ServeMux, *fakeHTTPProxyServer, *httpStreamProxyServer) { + t.Helper() + + fakeSrv := newFakeHTTPProxyServer() + var capturedMux *http.ServeMux + + baseOpts := []func(*httpStreamProxyServer){ + func(s *httpStreamProxyServer) { + s.newServer = func(addr string) (httpServer, *http.ServeMux) { + mux := http.NewServeMux() + capturedMux = mux + return fakeSrv, mux + } + }, + } + srvIface := NewHTTPStreamProxyServer(env, lbFake, 50*time.Millisecond, append(baseOpts, opts...)...) + srv := srvIface.(*httpStreamProxyServer) + + if err := srv.Run(ctx); err != nil { + t.Fatalf("Run: %v", err) + } + if capturedMux == nil { + t.Fatal("newServer was not called by Run") + } + return capturedMux, fakeSrv, srv +} + +// ============================================================================= +// NewHTTPStreamProxyServer +// ============================================================================= + +func TestHTTPStreamProxyServer_New_StoresFieldsAndDefaultsSeams(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + lbFake := &lbfakes.FakeOriginLoadBalancer{} + timeout := 2 * time.Second + srv := NewHTTPStreamProxyServer(env, lbFake, timeout).(*httpStreamProxyServer) + + if srv.environment != env { + t.Error("environment not stored") + } + if srv.loadBalancer != lbFake { + t.Error("loadBalancer not stored") + } + if srv.gracefulQuitTimeout != timeout { + t.Errorf("gracefulQuitTimeout = %v, want %v", srv.gracefulQuitTimeout, timeout) + } + if srv.shutdown == nil { + t.Error("shutdown seam should default to non-nil") + } + if srv.newServer == nil { + t.Error("newServer seam should default to non-nil") + } + if srv.newHLSStream == nil { + t.Error("newHLSStream seam should default to non-nil") + } + if srv.newFlvTsConn == nil { + t.Error("newFlvTsConn seam should default to non-nil") + } +} + +func TestHTTPStreamProxyServer_New_AppliesOpts(t *testing.T) { + var optCalled bool + srv := NewHTTPStreamProxyServer( + &envfakes.FakeProxyEnvironment{}, + &lbfakes.FakeOriginLoadBalancer{}, + time.Second, + func(s *httpStreamProxyServer) { optCalled = true }, + ).(*httpStreamProxyServer) + if !optCalled { + t.Fatal("opt was not invoked") + } + if srv.shutdown == nil { + t.Error("default seams should still be set when opts don't override them") + } +} + +func TestHTTPStreamProxyServer_New_OptCanOverrideAllSeams(t *testing.T) { + customShutdown := func(context.Context) error { return errors.New("custom") } + customNewServer := func(string) (httpServer, *http.ServeMux) { return nil, nil } + customNewHLS := func(string, string) *hlsPlayStream { return nil } + customNewFlv := func(context.Context) *httpFlvTsConnection { return nil } + + srv := NewHTTPStreamProxyServer( + &envfakes.FakeProxyEnvironment{}, + &lbfakes.FakeOriginLoadBalancer{}, + time.Second, + func(s *httpStreamProxyServer) { + s.shutdown = customShutdown + s.newServer = customNewServer + s.newHLSStream = customNewHLS + s.newFlvTsConn = customNewFlv + }, + ).(*httpStreamProxyServer) + + if err := srv.shutdown(context.Background()); err == nil || err.Error() != "custom" { + t.Errorf("custom shutdown not applied: %v", err) + } + // Pointer comparison on func values isn't supported by ==; call them and + // check the override took effect via observable behavior. + if got, _ := srv.newServer(""); got != nil { + t.Error("custom newServer not applied") + } + if srv.newHLSStream("", "") != nil { + t.Error("custom newHLSStream not applied") + } + if srv.newFlvTsConn(context.Background()) != nil { + t.Error("custom newFlvTsConn not applied") + } +} + +// ============================================================================= +// Default factory behavior +// ============================================================================= + +func TestHTTPStreamProxyServer_DefaultNewServer_BuildsRealServerAndMux(t *testing.T) { + srv := NewHTTPStreamProxyServer(&envfakes.FakeProxyEnvironment{}, + &lbfakes.FakeOriginLoadBalancer{}, time.Second).(*httpStreamProxyServer) + + got, mux := srv.newServer(":12345") + if mux == nil { + t.Fatal("mux is nil") + } + real, ok := got.(*http.Server) + if !ok { + t.Fatalf("expected *http.Server, got %T", got) + } + if real.Addr != ":12345" { + t.Errorf("Addr = %q, want :12345", real.Addr) + } + if real.Handler != mux { + t.Error("Handler should be the returned mux") + } +} + +func TestHTTPStreamProxyServer_DefaultNewHLSStream_WiresFields(t *testing.T) { + lbFake := &lbfakes.FakeOriginLoadBalancer{} + srv := NewHTTPStreamProxyServer(&envfakes.FakeProxyEnvironment{}, lbFake, + time.Second).(*httpStreamProxyServer) + + got := srv.newHLSStream("vhost/app/stream", "http://example.com/live.m3u8") + if got.loadBalancer != lbFake { + t.Error("loadBalancer not wired") + } + if got.StreamURL != "vhost/app/stream" { + t.Errorf("StreamURL = %q", got.StreamURL) + } + if got.FullURL != "http://example.com/live.m3u8" { + t.Errorf("FullURL = %q", got.FullURL) + } + if got.SRSProxyBackendHLSID == "" { + t.Error("SRSProxyBackendHLSID should be auto-generated") + } + if got.buildBackendURL == nil { + t.Error("buildBackendURL default should be propagated") + } +} + +func TestHTTPStreamProxyServer_DefaultNewFlvTsConn_WiresFields(t *testing.T) { + lbFake := &lbfakes.FakeOriginLoadBalancer{} + srv := NewHTTPStreamProxyServer(&envfakes.FakeProxyEnvironment{}, lbFake, + time.Second).(*httpStreamProxyServer) + + type ctxKey struct{} + ctx := context.WithValue(context.Background(), ctxKey{}, "sentinel") + got := srv.newFlvTsConn(ctx) + if got.ctx != ctx { + t.Error("ctx not wired") + } + if got.loadBalancer != lbFake { + t.Error("loadBalancer not wired") + } +} + +func TestHTTPStreamProxyServer_DefaultShutdown_DelegatesToServer(t *testing.T) { + fakeSrv := newFakeHTTPProxyServer() + srv := NewHTTPStreamProxyServer(&envfakes.FakeProxyEnvironment{}, + &lbfakes.FakeOriginLoadBalancer{}, time.Second).(*httpStreamProxyServer) + srv.server = fakeSrv // simulate what Run() would assign + if err := srv.shutdown(context.Background()); err != nil { + t.Fatalf("shutdown: %v", err) + } + if fakeSrv.shutdownCalls.Load() != 1 { + t.Fatalf("shutdown was not delegated to server, calls=%d", fakeSrv.shutdownCalls.Load()) + } +} + +// ============================================================================= +// Close +// ============================================================================= + +func TestHTTPStreamProxyServer_Close_InvokesShutdownWithDeadline(t *testing.T) { + var gotCtx context.Context + var calls int + srv := NewHTTPStreamProxyServer( + &envfakes.FakeProxyEnvironment{}, + &lbfakes.FakeOriginLoadBalancer{}, + 50*time.Millisecond, + func(s *httpStreamProxyServer) { + s.shutdown = func(ctx context.Context) error { + gotCtx = ctx + calls++ + return nil + } + }, + ).(*httpStreamProxyServer) + + if err := srv.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + if calls != 1 { + t.Fatalf("shutdown calls = %d, want 1", calls) + } + if _, ok := gotCtx.Deadline(); !ok { + t.Error("Close should pass a deadline-bearing ctx to shutdown") + } +} + +// ============================================================================= +// Run — lifecycle +// ============================================================================= + +func TestHTTPStreamProxyServer_Run_AddrWithoutColonPrependsIt(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.HttpServerReturns("8080") + + var capturedAddr string + fakeSrv := newFakeHTTPProxyServer() + srvIface := NewHTTPStreamProxyServer(env, &lbfakes.FakeOriginLoadBalancer{}, + 50*time.Millisecond, func(s *httpStreamProxyServer) { + s.newServer = func(addr string) (httpServer, *http.ServeMux) { + capturedAddr = addr + return fakeSrv, http.NewServeMux() + } + }) + defer srvIface.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + if err := srvIface.Run(ctx); err != nil { + t.Fatalf("Run: %v", err) + } + if capturedAddr != ":8080" { + t.Fatalf("newServer addr = %q, want :8080", capturedAddr) + } +} + +func TestHTTPStreamProxyServer_Run_AddrWithColonUnchanged(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.HttpServerReturns("127.0.0.1:9999") + + var capturedAddr string + fakeSrv := newFakeHTTPProxyServer() + srvIface := NewHTTPStreamProxyServer(env, &lbfakes.FakeOriginLoadBalancer{}, + 50*time.Millisecond, func(s *httpStreamProxyServer) { + s.newServer = func(addr string) (httpServer, *http.ServeMux) { + capturedAddr = addr + return fakeSrv, http.NewServeMux() + } + }) + defer srvIface.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + if err := srvIface.Run(ctx); err != nil { + t.Fatalf("Run: %v", err) + } + if capturedAddr != "127.0.0.1:9999" { + t.Fatalf("newServer addr = %q", capturedAddr) + } +} + +func TestHTTPStreamProxyServer_Run_StaticFilesInvalidPath(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.HttpServerReturns(":0") + env.StaticFilesReturns("/no/such/path/exists/__srsbot_test__") + + fakeSrv := newFakeHTTPProxyServer() + srvIface := NewHTTPStreamProxyServer(env, &lbfakes.FakeOriginLoadBalancer{}, + 50*time.Millisecond, func(s *httpStreamProxyServer) { + s.newServer = func(addr string) (httpServer, *http.ServeMux) { + return fakeSrv, http.NewServeMux() + } + }) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + err := srvIface.Run(ctx) + if err == nil || !strings.Contains(err.Error(), "invalid static files") { + t.Fatalf("unexpected err: %v", err) + } +} + +func TestHTTPStreamProxyServer_Run_CtxCancelTriggersShutdown(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.HttpServerReturns(":0") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, fakeSrv, _ := captureMuxFromRun(t, env, &lbfakes.FakeOriginLoadBalancer{}, ctx) + + // Wait briefly for ListenAndServe goroutine to be running. + deadline := time.Now().Add(time.Second) + for fakeSrv.listenCalls.Load() == 0 && time.Now().Before(deadline) { + time.Sleep(time.Millisecond) + } + if fakeSrv.listenCalls.Load() == 0 { + t.Fatal("ListenAndServe goroutine did not start") + } + + cancel() + + // Wait for Shutdown to be invoked by the watcher goroutine. + deadline = time.Now().Add(time.Second) + for fakeSrv.shutdownCalls.Load() == 0 && time.Now().Before(deadline) { + time.Sleep(time.Millisecond) + } + if fakeSrv.shutdownCalls.Load() == 0 { + t.Fatal("Shutdown was not invoked after ctx cancel") + } +} + +// ============================================================================= +// Run — handler dispatch +// ============================================================================= + +func TestHTTPStreamProxyServer_Run_HandlerVersionsReturnsJSON(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.HttpServerReturns(":0") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + mux, _, _ := captureMuxFromRun(t, env, &lbfakes.FakeOriginLoadBalancer{}, ctx) + + req := httptest.NewRequest(http.MethodGet, "/api/v1/versions", nil) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rec.Code) + } + var body struct { + Code int `json:"code"` + PID string `json:"pid"` + Data struct { + Major int `json:"major"` + Minor int `json:"minor"` + Revision int `json:"revision"` + Version string `json:"version"` + } `json:"data"` + } + if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil { + t.Fatalf("json: %v\nbody=%s", err, rec.Body.String()) + } + if body.Code != 0 { + t.Errorf("Code = %d, want 0", body.Code) + } + if body.PID == "" { + t.Error("PID should be populated") + } + if body.Data.Version == "" { + t.Error("Version should be populated") + } +} + +func TestHTTPStreamProxyServer_Run_HandlerM3U8InvokesNewHLSStream(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.HttpServerReturns(":0") + lbFake := &lbfakes.FakeOriginLoadBalancer{} + + // Make LoadOrStoreHLS return whatever was passed in (the stream from newHLSStream). + lbFake.LoadOrStoreHLSStub = func(_ context.Context, _ string, s lb.HLSPlayStream) (lb.HLSPlayStream, error) { + return s, nil + } + + var capturedStreamURL, capturedFullURL string + var newHLSCalls int + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + mux, _, _ := captureMuxFromRun(t, env, lbFake, ctx, func(s *httpStreamProxyServer) { + // Wrap default newHLSStream to capture args, but return a real + // hlsPlayStream so the .(*hlsPlayStream) cast inside Run's handler works. + // The returned stream has a fake loadBalancer; ServeHTTP will short-circuit + // on the OPTIONS preflight we send below. + s.newHLSStream = func(streamURL, fullURL string) *hlsPlayStream { + newHLSCalls++ + capturedStreamURL, capturedFullURL = streamURL, fullURL + return newHLSPlayStream(func(h *hlsPlayStream) { + h.loadBalancer = lbFake + h.StreamURL, h.FullURL = streamURL, fullURL + }) + } + }) + + req := httptest.NewRequest(http.MethodOptions, "http://example.com/live.m3u8", nil) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + if newHLSCalls != 1 { + t.Fatalf("newHLSStream calls = %d, want 1", newHLSCalls) + } + if !strings.HasSuffix(capturedStreamURL, "/live") { + t.Errorf("captured streamURL %q should end with /live", capturedStreamURL) + } + if !strings.Contains(capturedFullURL, "live.m3u8") { + t.Errorf("captured fullURL %q should contain live.m3u8", capturedFullURL) + } +} + +func TestHTTPStreamProxyServer_Run_HandlerM3U8LoadOrStoreErrorReturns400(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.HttpServerReturns(":0") + lbFake := &lbfakes.FakeOriginLoadBalancer{} + lbFake.LoadOrStoreHLSReturns(nil, errors.New("redis down")) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + mux, _, _ := captureMuxFromRun(t, env, lbFake, ctx) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want 400; body=%s", rec.Code, rec.Body.String()) + } + if !strings.Contains(rec.Body.String(), "load or store hls") { + t.Errorf("body should mention 'load or store hls', got %q", rec.Body.String()) + } +} + +func TestHTTPStreamProxyServer_Run_HandlerFlvInvokesNewFlvTsConn(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.HttpServerReturns(":0") + lbFake := &lbfakes.FakeOriginLoadBalancer{} + + var newFlvCalls int + var capturedCtx context.Context + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + mux, _, _ := captureMuxFromRun(t, env, lbFake, ctx, func(s *httpStreamProxyServer) { + s.newFlvTsConn = func(reqCtx context.Context) *httpFlvTsConnection { + newFlvCalls++ + capturedCtx = reqCtx + return newHTTPFlvTsConnection(func(c *httpFlvTsConnection) { + c.ctx = reqCtx + c.loadBalancer = lbFake + }) + } + }) + + req := httptest.NewRequest(http.MethodOptions, "http://example.com/live.flv", nil) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + if newFlvCalls != 1 { + t.Fatalf("newFlvTsConn calls = %d, want 1", newFlvCalls) + } + if capturedCtx == nil { + t.Error("captured ctx should be non-nil") + } +} + +func TestHTTPStreamProxyServer_Run_HandlerTsInvokesNewFlvTsConn(t *testing.T) { + // Same dispatch as .flv but for .ts (without spbhid). + env := &envfakes.FakeProxyEnvironment{} + env.HttpServerReturns(":0") + lbFake := &lbfakes.FakeOriginLoadBalancer{} + + var newFlvCalls int + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + mux, _, _ := captureMuxFromRun(t, env, lbFake, ctx, func(s *httpStreamProxyServer) { + s.newFlvTsConn = func(reqCtx context.Context) *httpFlvTsConnection { + newFlvCalls++ + return newHTTPFlvTsConnection(func(c *httpFlvTsConnection) { + c.ctx = reqCtx + c.loadBalancer = lbFake + }) + } + }) + + req := httptest.NewRequest(http.MethodOptions, "http://example.com/live.ts", nil) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + if newFlvCalls != 1 { + t.Fatalf("newFlvTsConn calls = %d, want 1", newFlvCalls) + } +} + +func TestHTTPStreamProxyServer_Run_HandlerTsWithSPBHIDLoadsByID(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.HttpServerReturns(":0") + lbFake := &lbfakes.FakeOriginLoadBalancer{} + + stub := newHLSPlayStream(func(h *hlsPlayStream) { + h.loadBalancer = lbFake + h.SRSProxyBackendHLSID = "ABC" + }) + lbFake.LoadHLSBySPBHIDReturns(stub, nil) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + mux, _, _ := captureMuxFromRun(t, env, lbFake, ctx) + + req := httptest.NewRequest(http.MethodOptions, "http://example.com/live-0.ts?spbhid=ABC", nil) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + if lbFake.LoadHLSBySPBHIDCallCount() != 1 { + t.Fatalf("LoadHLSBySPBHID calls = %d, want 1", lbFake.LoadHLSBySPBHIDCallCount()) + } + _, gotID := lbFake.LoadHLSBySPBHIDArgsForCall(0) + if gotID != "ABC" { + t.Errorf("LoadHLSBySPBHID id = %q, want ABC", gotID) + } +} + +func TestHTTPStreamProxyServer_Run_HandlerTsWithSPBHIDErrorReturns400(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.HttpServerReturns(":0") + lbFake := &lbfakes.FakeOriginLoadBalancer{} + lbFake.LoadHLSBySPBHIDReturns(nil, errors.New("not found")) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + mux, _, _ := captureMuxFromRun(t, env, lbFake, ctx) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/live-0.ts?spbhid=missing", nil) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want 400; body=%s", rec.Code, rec.Body.String()) + } +} + +func TestHTTPStreamProxyServer_Run_HandlerUnmatchedReturns404(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.HttpServerReturns(":0") + // StaticFiles unset, no .m3u8/.flv/.ts suffix → 404. + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + mux, _, _ := captureMuxFromRun(t, env, &lbfakes.FakeOriginLoadBalancer{}, ctx) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/random/path", nil) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusNotFound { + t.Fatalf("status = %d, want 404", rec.Code) + } +} + +func TestHTTPStreamProxyServer_Run_HandlerServesStaticFiles(t *testing.T) { + dir := t.TempDir() + if err := os.WriteFile(filepath.Join(dir, "hello.txt"), []byte("hi"), 0644); err != nil { + t.Fatalf("write: %v", err) + } + + env := &envfakes.FakeProxyEnvironment{} + env.HttpServerReturns(":0") + env.StaticFilesReturns(dir) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + mux, _, _ := captureMuxFromRun(t, env, &lbfakes.FakeOriginLoadBalancer{}, ctx) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/hello.txt", nil) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rec.Code) + } + if rec.Body.String() != "hi" { + t.Errorf("body = %q, want hi", rec.Body.String()) + } +} diff --git a/internal/proxy/proxyfakes/fake_http_api_proxy_server.go b/internal/proxy/proxyfakes/fake_http_api_proxy_server.go new file mode 100644 index 000000000..a16710ff9 --- /dev/null +++ b/internal/proxy/proxyfakes/fake_http_api_proxy_server.go @@ -0,0 +1,172 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package proxyfakes + +import ( + "context" + "srsx/internal/proxy" + "sync" +) + +type FakeHTTPAPIProxyServer struct { + CloseStub func() error + closeMutex sync.RWMutex + closeArgsForCall []struct { + } + closeReturns struct { + result1 error + } + closeReturnsOnCall map[int]struct { + result1 error + } + RunStub func(context.Context) error + runMutex sync.RWMutex + runArgsForCall []struct { + arg1 context.Context + } + runReturns struct { + result1 error + } + runReturnsOnCall map[int]struct { + result1 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeHTTPAPIProxyServer) Close() error { + fake.closeMutex.Lock() + ret, specificReturn := fake.closeReturnsOnCall[len(fake.closeArgsForCall)] + fake.closeArgsForCall = append(fake.closeArgsForCall, struct { + }{}) + stub := fake.CloseStub + fakeReturns := fake.closeReturns + fake.recordInvocation("Close", []interface{}{}) + fake.closeMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeHTTPAPIProxyServer) CloseCallCount() int { + fake.closeMutex.RLock() + defer fake.closeMutex.RUnlock() + return len(fake.closeArgsForCall) +} + +func (fake *FakeHTTPAPIProxyServer) CloseCalls(stub func() error) { + fake.closeMutex.Lock() + defer fake.closeMutex.Unlock() + fake.CloseStub = stub +} + +func (fake *FakeHTTPAPIProxyServer) CloseReturns(result1 error) { + fake.closeMutex.Lock() + defer fake.closeMutex.Unlock() + fake.CloseStub = nil + fake.closeReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeHTTPAPIProxyServer) CloseReturnsOnCall(i int, result1 error) { + fake.closeMutex.Lock() + defer fake.closeMutex.Unlock() + fake.CloseStub = nil + if fake.closeReturnsOnCall == nil { + fake.closeReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.closeReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeHTTPAPIProxyServer) Run(arg1 context.Context) error { + fake.runMutex.Lock() + ret, specificReturn := fake.runReturnsOnCall[len(fake.runArgsForCall)] + fake.runArgsForCall = append(fake.runArgsForCall, struct { + arg1 context.Context + }{arg1}) + stub := fake.RunStub + fakeReturns := fake.runReturns + fake.recordInvocation("Run", []interface{}{arg1}) + fake.runMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeHTTPAPIProxyServer) RunCallCount() int { + fake.runMutex.RLock() + defer fake.runMutex.RUnlock() + return len(fake.runArgsForCall) +} + +func (fake *FakeHTTPAPIProxyServer) RunCalls(stub func(context.Context) error) { + fake.runMutex.Lock() + defer fake.runMutex.Unlock() + fake.RunStub = stub +} + +func (fake *FakeHTTPAPIProxyServer) RunArgsForCall(i int) context.Context { + fake.runMutex.RLock() + defer fake.runMutex.RUnlock() + argsForCall := fake.runArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeHTTPAPIProxyServer) RunReturns(result1 error) { + fake.runMutex.Lock() + defer fake.runMutex.Unlock() + fake.RunStub = nil + fake.runReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeHTTPAPIProxyServer) RunReturnsOnCall(i int, result1 error) { + fake.runMutex.Lock() + defer fake.runMutex.Unlock() + fake.RunStub = nil + if fake.runReturnsOnCall == nil { + fake.runReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.runReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeHTTPAPIProxyServer) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeHTTPAPIProxyServer) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ proxy.HTTPAPIProxyServer = new(FakeHTTPAPIProxyServer) diff --git a/internal/proxy/proxyfakes/fake_http_stream_proxy_server.go b/internal/proxy/proxyfakes/fake_http_stream_proxy_server.go new file mode 100644 index 000000000..7b6f71010 --- /dev/null +++ b/internal/proxy/proxyfakes/fake_http_stream_proxy_server.go @@ -0,0 +1,172 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package proxyfakes + +import ( + "context" + "srsx/internal/proxy" + "sync" +) + +type FakeHTTPStreamProxyServer struct { + CloseStub func() error + closeMutex sync.RWMutex + closeArgsForCall []struct { + } + closeReturns struct { + result1 error + } + closeReturnsOnCall map[int]struct { + result1 error + } + RunStub func(context.Context) error + runMutex sync.RWMutex + runArgsForCall []struct { + arg1 context.Context + } + runReturns struct { + result1 error + } + runReturnsOnCall map[int]struct { + result1 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeHTTPStreamProxyServer) Close() error { + fake.closeMutex.Lock() + ret, specificReturn := fake.closeReturnsOnCall[len(fake.closeArgsForCall)] + fake.closeArgsForCall = append(fake.closeArgsForCall, struct { + }{}) + stub := fake.CloseStub + fakeReturns := fake.closeReturns + fake.recordInvocation("Close", []interface{}{}) + fake.closeMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeHTTPStreamProxyServer) CloseCallCount() int { + fake.closeMutex.RLock() + defer fake.closeMutex.RUnlock() + return len(fake.closeArgsForCall) +} + +func (fake *FakeHTTPStreamProxyServer) CloseCalls(stub func() error) { + fake.closeMutex.Lock() + defer fake.closeMutex.Unlock() + fake.CloseStub = stub +} + +func (fake *FakeHTTPStreamProxyServer) CloseReturns(result1 error) { + fake.closeMutex.Lock() + defer fake.closeMutex.Unlock() + fake.CloseStub = nil + fake.closeReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeHTTPStreamProxyServer) CloseReturnsOnCall(i int, result1 error) { + fake.closeMutex.Lock() + defer fake.closeMutex.Unlock() + fake.CloseStub = nil + if fake.closeReturnsOnCall == nil { + fake.closeReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.closeReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeHTTPStreamProxyServer) Run(arg1 context.Context) error { + fake.runMutex.Lock() + ret, specificReturn := fake.runReturnsOnCall[len(fake.runArgsForCall)] + fake.runArgsForCall = append(fake.runArgsForCall, struct { + arg1 context.Context + }{arg1}) + stub := fake.RunStub + fakeReturns := fake.runReturns + fake.recordInvocation("Run", []interface{}{arg1}) + fake.runMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeHTTPStreamProxyServer) RunCallCount() int { + fake.runMutex.RLock() + defer fake.runMutex.RUnlock() + return len(fake.runArgsForCall) +} + +func (fake *FakeHTTPStreamProxyServer) RunCalls(stub func(context.Context) error) { + fake.runMutex.Lock() + defer fake.runMutex.Unlock() + fake.RunStub = stub +} + +func (fake *FakeHTTPStreamProxyServer) RunArgsForCall(i int) context.Context { + fake.runMutex.RLock() + defer fake.runMutex.RUnlock() + argsForCall := fake.runArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeHTTPStreamProxyServer) RunReturns(result1 error) { + fake.runMutex.Lock() + defer fake.runMutex.Unlock() + fake.RunStub = nil + fake.runReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeHTTPStreamProxyServer) RunReturnsOnCall(i int, result1 error) { + fake.runMutex.Lock() + defer fake.runMutex.Unlock() + fake.RunStub = nil + if fake.runReturnsOnCall == nil { + fake.runReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.runReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeHTTPStreamProxyServer) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeHTTPStreamProxyServer) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ proxy.HTTPStreamProxyServer = new(FakeHTTPStreamProxyServer) diff --git a/internal/proxy/proxyfakes/fake_rtmp_proxy_server.go b/internal/proxy/proxyfakes/fake_rtmp_proxy_server.go new file mode 100644 index 000000000..44d20294d --- /dev/null +++ b/internal/proxy/proxyfakes/fake_rtmp_proxy_server.go @@ -0,0 +1,172 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package proxyfakes + +import ( + "context" + "srsx/internal/proxy" + "sync" +) + +type FakeRTMPProxyServer struct { + CloseStub func() error + closeMutex sync.RWMutex + closeArgsForCall []struct { + } + closeReturns struct { + result1 error + } + closeReturnsOnCall map[int]struct { + result1 error + } + RunStub func(context.Context) error + runMutex sync.RWMutex + runArgsForCall []struct { + arg1 context.Context + } + runReturns struct { + result1 error + } + runReturnsOnCall map[int]struct { + result1 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeRTMPProxyServer) Close() error { + fake.closeMutex.Lock() + ret, specificReturn := fake.closeReturnsOnCall[len(fake.closeArgsForCall)] + fake.closeArgsForCall = append(fake.closeArgsForCall, struct { + }{}) + stub := fake.CloseStub + fakeReturns := fake.closeReturns + fake.recordInvocation("Close", []interface{}{}) + fake.closeMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeRTMPProxyServer) CloseCallCount() int { + fake.closeMutex.RLock() + defer fake.closeMutex.RUnlock() + return len(fake.closeArgsForCall) +} + +func (fake *FakeRTMPProxyServer) CloseCalls(stub func() error) { + fake.closeMutex.Lock() + defer fake.closeMutex.Unlock() + fake.CloseStub = stub +} + +func (fake *FakeRTMPProxyServer) CloseReturns(result1 error) { + fake.closeMutex.Lock() + defer fake.closeMutex.Unlock() + fake.CloseStub = nil + fake.closeReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeRTMPProxyServer) CloseReturnsOnCall(i int, result1 error) { + fake.closeMutex.Lock() + defer fake.closeMutex.Unlock() + fake.CloseStub = nil + if fake.closeReturnsOnCall == nil { + fake.closeReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.closeReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeRTMPProxyServer) Run(arg1 context.Context) error { + fake.runMutex.Lock() + ret, specificReturn := fake.runReturnsOnCall[len(fake.runArgsForCall)] + fake.runArgsForCall = append(fake.runArgsForCall, struct { + arg1 context.Context + }{arg1}) + stub := fake.RunStub + fakeReturns := fake.runReturns + fake.recordInvocation("Run", []interface{}{arg1}) + fake.runMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeRTMPProxyServer) RunCallCount() int { + fake.runMutex.RLock() + defer fake.runMutex.RUnlock() + return len(fake.runArgsForCall) +} + +func (fake *FakeRTMPProxyServer) RunCalls(stub func(context.Context) error) { + fake.runMutex.Lock() + defer fake.runMutex.Unlock() + fake.RunStub = stub +} + +func (fake *FakeRTMPProxyServer) RunArgsForCall(i int) context.Context { + fake.runMutex.RLock() + defer fake.runMutex.RUnlock() + argsForCall := fake.runArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeRTMPProxyServer) RunReturns(result1 error) { + fake.runMutex.Lock() + defer fake.runMutex.Unlock() + fake.RunStub = nil + fake.runReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeRTMPProxyServer) RunReturnsOnCall(i int, result1 error) { + fake.runMutex.Lock() + defer fake.runMutex.Unlock() + fake.RunStub = nil + if fake.runReturnsOnCall == nil { + fake.runReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.runReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeRTMPProxyServer) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeRTMPProxyServer) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ proxy.RTMPProxyServer = new(FakeRTMPProxyServer) diff --git a/internal/proxy/proxyfakes/fake_web_rtc_proxy_server.go b/internal/proxy/proxyfakes/fake_web_rtc_proxy_server.go new file mode 100644 index 000000000..5401d5fed --- /dev/null +++ b/internal/proxy/proxyfakes/fake_web_rtc_proxy_server.go @@ -0,0 +1,325 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package proxyfakes + +import ( + "context" + "net/http" + "srsx/internal/proxy" + "sync" +) + +type FakeWebRTCProxyServer struct { + CloseStub func() error + closeMutex sync.RWMutex + closeArgsForCall []struct { + } + closeReturns struct { + result1 error + } + closeReturnsOnCall map[int]struct { + result1 error + } + HandleApiForWHEPStub func(context.Context, http.ResponseWriter, *http.Request) error + handleApiForWHEPMutex sync.RWMutex + handleApiForWHEPArgsForCall []struct { + arg1 context.Context + arg2 http.ResponseWriter + arg3 *http.Request + } + handleApiForWHEPReturns struct { + result1 error + } + handleApiForWHEPReturnsOnCall map[int]struct { + result1 error + } + HandleApiForWHIPStub func(context.Context, http.ResponseWriter, *http.Request) error + handleApiForWHIPMutex sync.RWMutex + handleApiForWHIPArgsForCall []struct { + arg1 context.Context + arg2 http.ResponseWriter + arg3 *http.Request + } + handleApiForWHIPReturns struct { + result1 error + } + handleApiForWHIPReturnsOnCall map[int]struct { + result1 error + } + RunStub func(context.Context) error + runMutex sync.RWMutex + runArgsForCall []struct { + arg1 context.Context + } + runReturns struct { + result1 error + } + runReturnsOnCall map[int]struct { + result1 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeWebRTCProxyServer) Close() error { + fake.closeMutex.Lock() + ret, specificReturn := fake.closeReturnsOnCall[len(fake.closeArgsForCall)] + fake.closeArgsForCall = append(fake.closeArgsForCall, struct { + }{}) + stub := fake.CloseStub + fakeReturns := fake.closeReturns + fake.recordInvocation("Close", []interface{}{}) + fake.closeMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeWebRTCProxyServer) CloseCallCount() int { + fake.closeMutex.RLock() + defer fake.closeMutex.RUnlock() + return len(fake.closeArgsForCall) +} + +func (fake *FakeWebRTCProxyServer) CloseCalls(stub func() error) { + fake.closeMutex.Lock() + defer fake.closeMutex.Unlock() + fake.CloseStub = stub +} + +func (fake *FakeWebRTCProxyServer) CloseReturns(result1 error) { + fake.closeMutex.Lock() + defer fake.closeMutex.Unlock() + fake.CloseStub = nil + fake.closeReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeWebRTCProxyServer) CloseReturnsOnCall(i int, result1 error) { + fake.closeMutex.Lock() + defer fake.closeMutex.Unlock() + fake.CloseStub = nil + if fake.closeReturnsOnCall == nil { + fake.closeReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.closeReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeWebRTCProxyServer) HandleApiForWHEP(arg1 context.Context, arg2 http.ResponseWriter, arg3 *http.Request) error { + fake.handleApiForWHEPMutex.Lock() + ret, specificReturn := fake.handleApiForWHEPReturnsOnCall[len(fake.handleApiForWHEPArgsForCall)] + fake.handleApiForWHEPArgsForCall = append(fake.handleApiForWHEPArgsForCall, struct { + arg1 context.Context + arg2 http.ResponseWriter + arg3 *http.Request + }{arg1, arg2, arg3}) + stub := fake.HandleApiForWHEPStub + fakeReturns := fake.handleApiForWHEPReturns + fake.recordInvocation("HandleApiForWHEP", []interface{}{arg1, arg2, arg3}) + fake.handleApiForWHEPMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeWebRTCProxyServer) HandleApiForWHEPCallCount() int { + fake.handleApiForWHEPMutex.RLock() + defer fake.handleApiForWHEPMutex.RUnlock() + return len(fake.handleApiForWHEPArgsForCall) +} + +func (fake *FakeWebRTCProxyServer) HandleApiForWHEPCalls(stub func(context.Context, http.ResponseWriter, *http.Request) error) { + fake.handleApiForWHEPMutex.Lock() + defer fake.handleApiForWHEPMutex.Unlock() + fake.HandleApiForWHEPStub = stub +} + +func (fake *FakeWebRTCProxyServer) HandleApiForWHEPArgsForCall(i int) (context.Context, http.ResponseWriter, *http.Request) { + fake.handleApiForWHEPMutex.RLock() + defer fake.handleApiForWHEPMutex.RUnlock() + argsForCall := fake.handleApiForWHEPArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeWebRTCProxyServer) HandleApiForWHEPReturns(result1 error) { + fake.handleApiForWHEPMutex.Lock() + defer fake.handleApiForWHEPMutex.Unlock() + fake.HandleApiForWHEPStub = nil + fake.handleApiForWHEPReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeWebRTCProxyServer) HandleApiForWHEPReturnsOnCall(i int, result1 error) { + fake.handleApiForWHEPMutex.Lock() + defer fake.handleApiForWHEPMutex.Unlock() + fake.HandleApiForWHEPStub = nil + if fake.handleApiForWHEPReturnsOnCall == nil { + fake.handleApiForWHEPReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.handleApiForWHEPReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeWebRTCProxyServer) HandleApiForWHIP(arg1 context.Context, arg2 http.ResponseWriter, arg3 *http.Request) error { + fake.handleApiForWHIPMutex.Lock() + ret, specificReturn := fake.handleApiForWHIPReturnsOnCall[len(fake.handleApiForWHIPArgsForCall)] + fake.handleApiForWHIPArgsForCall = append(fake.handleApiForWHIPArgsForCall, struct { + arg1 context.Context + arg2 http.ResponseWriter + arg3 *http.Request + }{arg1, arg2, arg3}) + stub := fake.HandleApiForWHIPStub + fakeReturns := fake.handleApiForWHIPReturns + fake.recordInvocation("HandleApiForWHIP", []interface{}{arg1, arg2, arg3}) + fake.handleApiForWHIPMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeWebRTCProxyServer) HandleApiForWHIPCallCount() int { + fake.handleApiForWHIPMutex.RLock() + defer fake.handleApiForWHIPMutex.RUnlock() + return len(fake.handleApiForWHIPArgsForCall) +} + +func (fake *FakeWebRTCProxyServer) HandleApiForWHIPCalls(stub func(context.Context, http.ResponseWriter, *http.Request) error) { + fake.handleApiForWHIPMutex.Lock() + defer fake.handleApiForWHIPMutex.Unlock() + fake.HandleApiForWHIPStub = stub +} + +func (fake *FakeWebRTCProxyServer) HandleApiForWHIPArgsForCall(i int) (context.Context, http.ResponseWriter, *http.Request) { + fake.handleApiForWHIPMutex.RLock() + defer fake.handleApiForWHIPMutex.RUnlock() + argsForCall := fake.handleApiForWHIPArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeWebRTCProxyServer) HandleApiForWHIPReturns(result1 error) { + fake.handleApiForWHIPMutex.Lock() + defer fake.handleApiForWHIPMutex.Unlock() + fake.HandleApiForWHIPStub = nil + fake.handleApiForWHIPReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeWebRTCProxyServer) HandleApiForWHIPReturnsOnCall(i int, result1 error) { + fake.handleApiForWHIPMutex.Lock() + defer fake.handleApiForWHIPMutex.Unlock() + fake.HandleApiForWHIPStub = nil + if fake.handleApiForWHIPReturnsOnCall == nil { + fake.handleApiForWHIPReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.handleApiForWHIPReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeWebRTCProxyServer) Run(arg1 context.Context) error { + fake.runMutex.Lock() + ret, specificReturn := fake.runReturnsOnCall[len(fake.runArgsForCall)] + fake.runArgsForCall = append(fake.runArgsForCall, struct { + arg1 context.Context + }{arg1}) + stub := fake.RunStub + fakeReturns := fake.runReturns + fake.recordInvocation("Run", []interface{}{arg1}) + fake.runMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeWebRTCProxyServer) RunCallCount() int { + fake.runMutex.RLock() + defer fake.runMutex.RUnlock() + return len(fake.runArgsForCall) +} + +func (fake *FakeWebRTCProxyServer) RunCalls(stub func(context.Context) error) { + fake.runMutex.Lock() + defer fake.runMutex.Unlock() + fake.RunStub = stub +} + +func (fake *FakeWebRTCProxyServer) RunArgsForCall(i int) context.Context { + fake.runMutex.RLock() + defer fake.runMutex.RUnlock() + argsForCall := fake.runArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeWebRTCProxyServer) RunReturns(result1 error) { + fake.runMutex.Lock() + defer fake.runMutex.Unlock() + fake.RunStub = nil + fake.runReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeWebRTCProxyServer) RunReturnsOnCall(i int, result1 error) { + fake.runMutex.Lock() + defer fake.runMutex.Unlock() + fake.RunStub = nil + if fake.runReturnsOnCall == nil { + fake.runReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.runReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeWebRTCProxyServer) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeWebRTCProxyServer) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ proxy.WebRTCProxyServer = new(FakeWebRTCProxyServer) diff --git a/internal/server/rtc.go b/internal/proxy/rtc.go similarity index 71% rename from internal/server/rtc.go rename to internal/proxy/rtc.go index 7a85e0bbb..48a3e1e8f 100644 --- a/internal/server/rtc.go +++ b/internal/proxy/rtc.go @@ -1,12 +1,13 @@ // Copyright (c) 2026 Winlin // // SPDX-License-Identifier: MIT -package server +package proxy import ( "context" "encoding/binary" "fmt" + "io" "io/ioutil" "net" "net/http" @@ -23,21 +24,24 @@ import ( "srsx/internal/utils" ) -// WebRTCServer is the proxy for SRS WebRTC server via WHIP or WHEP protocol. It will figure out +// WebRTCProxyServer is the proxy for SRS WebRTC server via WHIP or WHEP protocol. It will figure out // which backend server to proxy to. It will also replace the UDP port to the proxy server's in the // SDP answer. -type WebRTCServer interface { +type WebRTCProxyServer interface { Run(ctx context.Context) error Close() error HandleApiForWHIP(ctx context.Context, w http.ResponseWriter, r *http.Request) error HandleApiForWHEP(ctx context.Context, w http.ResponseWriter, r *http.Request) error } -type webRTCServer struct { +type webRTCProxyServer struct { // The environment interface. environment env.ProxyEnvironment - // The UDP listener for WebRTC server. - listener *net.UDPConn + // The load balancer for origin servers. + loadBalancer lb.OriginLoadBalancer + // The UDP listener for WebRTC server. Stored as net.PacketConn so tests + // can inject a fake listener via listenUDP. + listener net.PacketConn // Fast cache for the username to identify the connection. // The key is username, the value is the UDP address. @@ -49,21 +53,59 @@ type webRTCServer struct { // The wait group for server. wg stdSync.WaitGroup + + // backendURL builds the URL to forward a WHIP/WHEP SDP exchange to a backend + // SRS server. Defaults to "http://:?"; tests may + // override to redirect requests to an httptest.Server. + backendURL func(backend *lb.OriginServer, r *http.Request) (string, error) + + // listenUDP opens the UDP listener for the WebRTC server. Defaults to a real + // net.ListenUDP on the resolved endpoint; tests may override via a functional + // option to supply a fake listener. + listenUDP func(ctx context.Context, endpoint string) (net.PacketConn, error) } -func NewWebRTCServer(environment env.ProxyEnvironment, opts ...func(*webRTCServer)) WebRTCServer { - v := &webRTCServer{ - environment: environment, - usernames: sync.NewMap[string, *rtcConnection](), - addresses: sync.NewMap[string, *rtcConnection](), +func NewWebRTCProxyServer(environment env.ProxyEnvironment, loadBalancer lb.OriginLoadBalancer, opts ...func(*webRTCProxyServer)) WebRTCProxyServer { + v := &webRTCProxyServer{ + environment: environment, + loadBalancer: loadBalancer, + usernames: sync.NewMap[string, *rtcConnection](), + addresses: sync.NewMap[string, *rtcConnection](), } + + // Default listenUDP: resolve the endpoint and open a real UDP socket. + v.listenUDP = func(ctx context.Context, endpoint string) (net.PacketConn, error) { + saddr, err := net.ResolveUDPAddr("udp", endpoint) + if err != nil { + return nil, errors.Wrapf(err, "resolve udp addr %v", endpoint) + } + return net.ListenUDP("udp", saddr) + } + + // Default backendURL: validate API endpoint, parse port, format URL preserving + // the inbound request's path and raw query. + v.backendURL = func(backend *lb.OriginServer, r *http.Request) (string, error) { + if len(backend.API) == 0 { + return "", errors.Errorf("no http api server") + } + apiPort, err := strconv.ParseInt(backend.API[0], 10, 64) + if err != nil { + return "", errors.Wrapf(err, "parse http port %v", backend.API[0]) + } + u := fmt.Sprintf("http://%v:%v%s", backend.IP, apiPort, r.URL.Path) + if r.URL.RawQuery != "" { + u += "?" + r.URL.RawQuery + } + return u, nil + } + for _, opt := range opts { opt(v) } return v } -func (v *webRTCServer) Close() error { +func (v *webRTCProxyServer) Close() error { if v.listener != nil { _ = v.listener.Close() } @@ -72,7 +114,7 @@ func (v *webRTCServer) Close() error { return nil } -func (v *webRTCServer) HandleApiForWHIP(ctx context.Context, w http.ResponseWriter, r *http.Request) error { +func (v *webRTCProxyServer) HandleApiForWHIP(ctx context.Context, w http.ResponseWriter, r *http.Request) error { defer r.Body.Close() ctx = logger.WithContext(ctx) @@ -82,7 +124,7 @@ func (v *webRTCServer) HandleApiForWHIP(ctx context.Context, w http.ResponseWrit } // Read remote SDP offer from body. - remoteSDPOffer, err := ioutil.ReadAll(r.Body) + remoteSDPOffer, err := io.ReadAll(r.Body) if err != nil { return errors.Wrapf(err, "read remote sdp offer") } @@ -97,7 +139,7 @@ func (v *webRTCServer) HandleApiForWHIP(ctx context.Context, w http.ResponseWrit } // Pick a backend SRS server to proxy the RTMP stream. - backend, err := lb.SrsLoadBalancer.Pick(ctx, streamURL) + backend, err := v.loadBalancer.Pick(ctx, streamURL) if err != nil { return errors.Wrapf(err, "pick backend for %v", streamURL) } @@ -109,7 +151,7 @@ func (v *webRTCServer) HandleApiForWHIP(ctx context.Context, w http.ResponseWrit return nil } -func (v *webRTCServer) HandleApiForWHEP(ctx context.Context, w http.ResponseWriter, r *http.Request) error { +func (v *webRTCProxyServer) HandleApiForWHEP(ctx context.Context, w http.ResponseWriter, r *http.Request) error { defer r.Body.Close() ctx = logger.WithContext(ctx) @@ -119,7 +161,7 @@ func (v *webRTCServer) HandleApiForWHEP(ctx context.Context, w http.ResponseWrit } // Read remote SDP offer from body. - remoteSDPOffer, err := ioutil.ReadAll(r.Body) + remoteSDPOffer, err := io.ReadAll(r.Body) if err != nil { return errors.Wrapf(err, "read remote sdp offer") } @@ -134,7 +176,7 @@ func (v *webRTCServer) HandleApiForWHEP(ctx context.Context, w http.ResponseWrit } // Pick a backend SRS server to proxy the RTMP stream. - backend, err := lb.SrsLoadBalancer.Pick(ctx, streamURL) + backend, err := v.loadBalancer.Pick(ctx, streamURL) if err != nil { return errors.Wrapf(err, "pick backend for %v", streamURL) } @@ -146,26 +188,15 @@ func (v *webRTCServer) HandleApiForWHEP(ctx context.Context, w http.ResponseWrit return nil } -func (v *webRTCServer) proxyApiToBackend( - ctx context.Context, w http.ResponseWriter, r *http.Request, backend *lb.SRSServer, +func (v *webRTCProxyServer) proxyApiToBackend( + ctx context.Context, w http.ResponseWriter, r *http.Request, backend *lb.OriginServer, remoteSDPOffer string, streamURL string, ) error { - // Parse HTTP port from backend. - if len(backend.API) == 0 { - return errors.Errorf("no http api server") - } - - var apiPort int - if iv, err := strconv.ParseInt(backend.API[0], 10, 64); err != nil { - return errors.Wrapf(err, "parse http port %v", backend.API[0]) - } else { - apiPort = int(iv) - } - - // Connect to backend SRS server via HTTP client. - backendURL := fmt.Sprintf("http://%v:%v%s", backend.IP, apiPort, r.URL.Path) - if r.URL.RawQuery != "" { - backendURL += "?" + r.URL.RawQuery + // Resolve the backend URL via the configurable seam (so tests can redirect to + // an httptest.Server). + backendURL, err := v.backendURL(backend, r) + if err != nil { + return errors.Wrapf(err, "build backend url") } req, err := http.NewRequestWithContext(ctx, r.Method, backendURL, strings.NewReader(remoteSDPOffer)) @@ -226,7 +257,8 @@ func (v *webRTCServer) proxyApiToBackend( RemoteICEUfrag: remoteICEUfrag, RemoteICEPwd: remoteICEPwd, LocalICEUfrag: localICEUfrag, LocalICEPwd: localICEPwd, } - if err := lb.SrsLoadBalancer.StoreWebRTC(ctx, streamURL, newRTCConnection(func(c *rtcConnection) { + if err := v.loadBalancer.StoreWebRTC(ctx, streamURL, newRTCConnection(func(c *rtcConnection) { + c.loadBalancer = v.loadBalancer c.StreamURL, c.Ufrag = streamURL, icePair.Ufrag() c.Initialize(ctx, v.listener) @@ -246,24 +278,19 @@ func (v *webRTCServer) proxyApiToBackend( return nil } -func (v *webRTCServer) Run(ctx context.Context) error { +func (v *webRTCProxyServer) Run(ctx context.Context) error { // Parse address to listen. endpoint := v.environment.WebRTCServer() if !strings.Contains(endpoint, ":") { endpoint = fmt.Sprintf(":%v", endpoint) } - saddr, err := net.ResolveUDPAddr("udp", endpoint) + listener, err := v.listenUDP(ctx, endpoint) if err != nil { - return errors.Wrapf(err, "resolve udp addr %v", endpoint) - } - - listener, err := net.ListenUDP("udp", saddr) - if err != nil { - return errors.Wrapf(err, "listen udp %v", saddr) + return errors.Wrapf(err, "listen udp %v", endpoint) } v.listener = listener - logger.Debug(ctx, "WebRTC server listen at %v", saddr) + logger.Debug(ctx, "WebRTC server listen at %v", listener.LocalAddr()) // Consume all messages from UDP media transport. v.wg.Add(1) @@ -272,7 +299,7 @@ func (v *webRTCServer) Run(ctx context.Context) error { for ctx.Err() == nil { buf := make([]byte, 4096) - n, caddr, err := listener.ReadFromUDP(buf) + n, addr, err := listener.ReadFrom(buf) if err != nil { // If context is canceled or connection is closed, exit gracefully without logging error. if ctx.Err() != nil || utils.IsClosedNetworkError(err) { @@ -285,8 +312,8 @@ func (v *webRTCServer) Run(ctx context.Context) error { continue } - if err := v.handleClientUDP(ctx, caddr, buf[:n]); err != nil { - logger.Warn(ctx, "WebRTC handle udp %vB failed, addr=%v, err=%+v", n, caddr, err) + if err := v.handleClientUDP(ctx, addr, buf[:n]); err != nil { + logger.Warn(ctx, "WebRTC handle udp %vB failed, addr=%v, err=%+v", n, addr, err) } } }() @@ -294,7 +321,7 @@ func (v *webRTCServer) Run(ctx context.Context) error { return nil } -func (v *webRTCServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr, data []byte) error { +func (v *webRTCProxyServer) handleClientUDP(ctx context.Context, addr net.Addr, data []byte) error { var connection *rtcConnection // If STUN binding request, parse the ufrag and identify the connection. @@ -315,10 +342,11 @@ func (v *webRTCServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr, d } // Load connection by username. - if s, err := lb.SrsLoadBalancer.LoadWebRTCByUfrag(ctx, pkt.Username); err != nil { + if s, err := v.loadBalancer.LoadWebRTCByUfrag(ctx, pkt.Username); err != nil { return errors.Wrapf(err, "load webrtc by ufrag %v", pkt.Username) } else { connection = s.(*rtcConnection).Initialize(ctx, v.listener) + connection.loadBalancer = v.loadBalancer logger.Debug(ctx, "Create WebRTC connection by ufrag=%v, stream=%v", pkt.Username, connection.StreamURL) } @@ -366,29 +394,45 @@ func (v *webRTCServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr, d type rtcConnection struct { // The stream context for WebRTC streaming. ctx context.Context + // The load balancer for origin servers. + loadBalancer lb.OriginLoadBalancer // The stream URL in vhost/app/stream schema. StreamURL string `json:"stream_url"` // The ufrag for this WebRTC connection. Ufrag string `json:"ufrag"` - // The UDP connection proxy to backend. - backendUDP *net.UDPConn + // The UDP connection proxy to backend. Stored as io.ReadWriteCloser so tests + // can inject a fake connection by overriding dialBackendUDP. + backendUDP io.ReadWriteCloser // The client UDP address. Note that it may change. - clientUDP *net.UDPAddr - // The listener UDP connection, used to send messages to client. - listenerUDP *net.UDPConn + clientUDP net.Addr + // The listener UDP connection, used to send messages to client. Stored as + // net.PacketConn so tests can inject a fake listener. + listenerUDP net.PacketConn + + // dialBackendUDP opens a UDP connection to a backend SRS server. Defaults to a real + // UDP dial; tests may override via a functional option to supply a fake connection. + dialBackendUDP func(ctx context.Context, ip string, port int) (io.ReadWriteCloser, error) } func newRTCConnection(opts ...func(*rtcConnection)) *rtcConnection { v := &rtcConnection{} + + // Default dial: a real UDP connection to the backend. Uses Dialer.DialContext + // so ctx cancellation/deadline aborts DNS resolution (UDP itself has no handshake). + v.dialBackendUDP = func(ctx context.Context, ip string, port int) (io.ReadWriteCloser, error) { + var d net.Dialer + return d.DialContext(ctx, "udp", net.JoinHostPort(ip, strconv.Itoa(port))) + } + for _, opt := range opts { opt(v) } return v } -func (v *rtcConnection) Initialize(ctx context.Context, listener *net.UDPConn) *rtcConnection { +func (v *rtcConnection) Initialize(ctx context.Context, listener net.PacketConn) *rtcConnection { if v.ctx == nil { v.ctx = logger.WithContext(ctx) } @@ -402,7 +446,7 @@ func (v *rtcConnection) GetUfrag() string { return v.Ufrag } -func (v *rtcConnection) HandlePacket(addr *net.UDPAddr, data []byte) error { +func (v *rtcConnection) HandlePacket(addr net.Addr, data []byte) error { ctx := v.ctx // Update the current UDP address. @@ -422,14 +466,14 @@ func (v *rtcConnection) HandlePacket(addr *net.UDPAddr, data []byte) error { go func() { for ctx.Err() == nil { buf := make([]byte, 4096) - n, _, err := v.backendUDP.ReadFromUDP(buf) + n, err := v.backendUDP.Read(buf) if err != nil { // TODO: If backend server closed unexpectedly, we should notice the stream to quit. logger.Warn(ctx, "read from backend failed, err=%v", err) break } - if _, err = v.listenerUDP.WriteToUDP(buf[:n], v.clientUDP); err != nil { + if _, err = v.listenerUDP.WriteTo(buf[:n], v.clientUDP); err != nil { // TODO: If backend server closed unexpectedly, we should notice the stream to quit. logger.Warn(ctx, "write to client failed, err=%v", err) break @@ -450,7 +494,7 @@ func (v *rtcConnection) connectBackend(ctx context.Context) error { } // Pick a backend SRS server to proxy the RTC stream. - backend, err := lb.SrsLoadBalancer.Pick(ctx, v.StreamURL) + backend, err := v.loadBalancer.Pick(ctx, v.StreamURL) if err != nil { return errors.Wrapf(err, "pick backend") } @@ -467,12 +511,11 @@ func (v *rtcConnection) connectBackend(ctx context.Context) error { // Connect to backend SRS server via UDP client. // TODO: FIXME: Support close the connection when timeout or DTLS alert. - backendAddr := net.UDPAddr{IP: net.ParseIP(backend.IP), Port: int(udpPort)} - if backendUDP, err := net.DialUDP("udp", nil, &backendAddr); err != nil { - return errors.Wrapf(err, "dial udp to %v", backendAddr) - } else { - v.backendUDP = backendUDP + backendUDP, err := v.dialBackendUDP(ctx, backend.IP, int(udpPort)) + if err != nil { + return errors.Wrapf(err, "dial udp to %v:%v", backend.IP, udpPort) } + v.backendUDP = backendUDP return nil } diff --git a/internal/proxy/rtc_test.go b/internal/proxy/rtc_test.go new file mode 100644 index 000000000..64f3ac8cc --- /dev/null +++ b/internal/proxy/rtc_test.go @@ -0,0 +1,1111 @@ +// Copyright (c) 2026 Winlin +// +// SPDX-License-Identifier: MIT +package proxy + +import ( + "context" + "encoding/binary" + "errors" + "io" + "net" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" + + "srsx/internal/env/envfakes" + "srsx/internal/lb" + "srsx/internal/lb/lbfakes" +) + +func TestRtcICEPair_Ufrag(t *testing.T) { + cases := []struct { + name string + pair rtcICEPair + want string + }{ + { + name: "typical", + pair: rtcICEPair{ + RemoteICEUfrag: "remote-ufrag", + RemoteICEPwd: "remote-pwd", + LocalICEUfrag: "local-ufrag", + LocalICEPwd: "local-pwd", + }, + want: "local-ufrag:remote-ufrag", + }, + { + name: "both empty", + pair: rtcICEPair{}, + want: ":", + }, + { + name: "only local", + pair: rtcICEPair{LocalICEUfrag: "L"}, + want: "L:", + }, + { + name: "only remote", + pair: rtcICEPair{RemoteICEUfrag: "R"}, + want: ":R", + }, + { + name: "pwd fields do not affect ufrag", + pair: rtcICEPair{ + RemoteICEUfrag: "r", + RemoteICEPwd: "should-be-ignored", + LocalICEUfrag: "l", + LocalICEPwd: "should-be-ignored", + }, + want: "l:r", + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + if got := c.pair.Ufrag(); got != c.want { + t.Fatalf("Ufrag()=%q, want %q", got, c.want) + } + }) + } +} + +// fakeBackendUDP is an in-memory io.ReadWriteCloser standing in for the dialed +// UDP socket. Writes are captured on a channel; reads block until reads is fed +// or closed (in which case Read returns io.EOF). +type fakeBackendUDP struct { + writes chan []byte + reads chan []byte + closed atomic.Bool + writeErr error + readErr error + readOnce atomic.Bool // when set, second Read returns io.EOF to terminate the goroutine + bytesRead atomic.Int64 +} + +func newFakeBackendUDP() *fakeBackendUDP { + return &fakeBackendUDP{ + writes: make(chan []byte, 16), + reads: make(chan []byte, 16), + } +} + +func (f *fakeBackendUDP) Read(p []byte) (int, error) { + if f.readErr != nil { + return 0, f.readErr + } + data, ok := <-f.reads + if !ok { + return 0, io.EOF + } + n := copy(p, data) + f.bytesRead.Add(int64(n)) + return n, nil +} + +func (f *fakeBackendUDP) Write(p []byte) (int, error) { + if f.writeErr != nil { + return 0, f.writeErr + } + cp := make([]byte, len(p)) + copy(cp, p) + f.writes <- cp + return len(p), nil +} + +func (f *fakeBackendUDP) Close() error { + if f.closed.CompareAndSwap(false, true) { + close(f.reads) + } + return nil +} + +// fakePacketConn is an in-memory net.PacketConn standing in for the proxy's +// UDP listener. Only WriteTo is exercised; the other methods are no-ops. +type fakePacketConn struct { + writes chan packetWrite + writeErr error +} + +type packetWrite struct { + data []byte + addr net.Addr +} + +func newFakePacketConn() *fakePacketConn { + return &fakePacketConn{writes: make(chan packetWrite, 16)} +} + +func (f *fakePacketConn) WriteTo(p []byte, addr net.Addr) (int, error) { + if f.writeErr != nil { + return 0, f.writeErr + } + cp := make([]byte, len(p)) + copy(cp, p) + f.writes <- packetWrite{data: cp, addr: addr} + return len(p), nil +} + +func (f *fakePacketConn) ReadFrom(p []byte) (int, net.Addr, error) { return 0, nil, io.EOF } +func (f *fakePacketConn) Close() error { return nil } +func (f *fakePacketConn) LocalAddr() net.Addr { return nil } +func (f *fakePacketConn) SetDeadline(time.Time) error { return nil } +func (f *fakePacketConn) SetReadDeadline(time.Time) error { return nil } +func (f *fakePacketConn) SetWriteDeadline(time.Time) error { return nil } + +func TestNewRTCConnection(t *testing.T) { + t.Run("defaults dialBackendUDP", func(t *testing.T) { + c := newRTCConnection() + if c.dialBackendUDP == nil { + t.Fatal("expected dialBackendUDP to be defaulted") + } + }) + + t.Run("applies functional options", func(t *testing.T) { + c := newRTCConnection(func(c *rtcConnection) { + c.StreamURL = "vhost/app/stream" + c.Ufrag = "L:R" + }) + if c.StreamURL != "vhost/app/stream" { + t.Fatalf("StreamURL=%q", c.StreamURL) + } + if c.Ufrag != "L:R" { + t.Fatalf("Ufrag=%q", c.Ufrag) + } + }) + + t.Run("options override default dialBackendUDP", func(t *testing.T) { + called := false + dial := func(ctx context.Context, ip string, port int) (io.ReadWriteCloser, error) { + called = true + return nil, nil + } + c := newRTCConnection(func(c *rtcConnection) { c.dialBackendUDP = dial }) + _, _ = c.dialBackendUDP(context.Background(), "", 0) + if !called { + t.Fatal("expected overridden dialBackendUDP to be invoked") + } + }) +} + +func TestRtcConnection_Initialize(t *testing.T) { + t.Run("sets ctx when nil", func(t *testing.T) { + c := newRTCConnection() + listener := newFakePacketConn() + ret := c.Initialize(context.Background(), listener) + if c.ctx == nil { + t.Fatal("expected ctx to be set") + } + if c.listenerUDP != listener { + t.Fatal("expected listenerUDP to be set") + } + if ret != c { + t.Fatal("expected Initialize to return receiver") + } + }) + + t.Run("does not overwrite existing ctx", func(t *testing.T) { + type ctxKey struct{} + original := context.WithValue(context.Background(), ctxKey{}, "marker") + c := newRTCConnection(func(c *rtcConnection) { c.ctx = original }) + c.Initialize(context.Background(), nil) + if got := c.ctx.Value(ctxKey{}); got != "marker" { + t.Fatalf("ctx was overwritten; got value=%v", got) + } + }) + + t.Run("nil listener does not overwrite existing", func(t *testing.T) { + existing := newFakePacketConn() + c := newRTCConnection(func(c *rtcConnection) { c.listenerUDP = existing }) + c.Initialize(context.Background(), nil) + if c.listenerUDP != existing { + t.Fatal("nil listener overwrote existing listenerUDP") + } + }) +} + +func TestRtcConnection_GetUfrag(t *testing.T) { + c := newRTCConnection(func(c *rtcConnection) { c.Ufrag = "abc:def" }) + if got := c.GetUfrag(); got != "abc:def" { + t.Fatalf("GetUfrag()=%q", got) + } +} + +// rtcConnFixture wires an rtcConnection with fakes for the load balancer, +// listener, and backend dial seam. +type rtcConnFixture struct { + conn *rtcConnection + lb *lbfakes.FakeOriginLoadBalancer + listener *fakePacketConn + backend *fakeBackendUDP + dialErr error + dialIP string + dialPort int +} + +func newRtcConnFixture() *rtcConnFixture { + f := &rtcConnFixture{ + lb: &lbfakes.FakeOriginLoadBalancer{}, + listener: newFakePacketConn(), + backend: newFakeBackendUDP(), + } + f.conn = newRTCConnection(func(c *rtcConnection) { + c.loadBalancer = f.lb + c.StreamURL = "vhost/app/stream" + c.Ufrag = "L:R" + c.listenerUDP = f.listener + c.dialBackendUDP = func(ctx context.Context, ip string, port int) (io.ReadWriteCloser, error) { + f.dialIP, f.dialPort = ip, port + if f.dialErr != nil { + return nil, f.dialErr + } + return f.backend, nil + } + }) + return f +} + +func TestRtcConnection_ConnectBackend(t *testing.T) { + t.Run("noop when already connected", func(t *testing.T) { + f := newRtcConnFixture() + f.conn.backendUDP = f.backend + if err := f.conn.connectBackend(context.Background()); err != nil { + t.Fatalf("unexpected err=%v", err) + } + if f.lb.PickCallCount() != 0 { + t.Fatal("expected Pick not to be called when already connected") + } + }) + + t.Run("propagates Pick error", func(t *testing.T) { + f := newRtcConnFixture() + f.lb.PickReturns(nil, errors.New("boom")) + err := f.conn.connectBackend(context.Background()) + if err == nil || !strings.Contains(err.Error(), "boom") { + t.Fatalf("expected pick err, got %v", err) + } + }) + + t.Run("errors when backend has no RTC endpoints", func(t *testing.T) { + f := newRtcConnFixture() + f.lb.PickReturns(&lb.OriginServer{IP: "127.0.0.1"}, nil) + err := f.conn.connectBackend(context.Background()) + if err == nil || !strings.Contains(err.Error(), "no udp server") { + t.Fatalf("expected no-udp-server err, got %v", err) + } + }) + + t.Run("propagates dial error", func(t *testing.T) { + f := newRtcConnFixture() + f.lb.PickReturns(&lb.OriginServer{IP: "127.0.0.1", RTC: []string{"18000"}}, nil) + f.dialErr = errors.New("dial-failed") + err := f.conn.connectBackend(context.Background()) + if err == nil || !strings.Contains(err.Error(), "dial-failed") { + t.Fatalf("expected dial err, got %v", err) + } + if f.conn.backendUDP != nil { + t.Fatal("backendUDP should remain nil on dial failure") + } + }) + + t.Run("success sets backendUDP and forwards ip/port", func(t *testing.T) { + f := newRtcConnFixture() + f.lb.PickReturns(&lb.OriginServer{IP: "10.0.0.5", RTC: []string{"18000"}}, nil) + if err := f.conn.connectBackend(context.Background()); err != nil { + t.Fatalf("unexpected err=%v", err) + } + if f.conn.backendUDP != f.backend { + t.Fatal("backendUDP not set") + } + if f.dialIP != "10.0.0.5" || f.dialPort != 18000 { + t.Fatalf("dial got ip=%q port=%d", f.dialIP, f.dialPort) + } + }) +} + +func TestRtcConnection_HandlePacket(t *testing.T) { + t.Run("writes data to backend and stores client addr", func(t *testing.T) { + f := newRtcConnFixture() + f.lb.PickReturns(&lb.OriginServer{IP: "127.0.0.1", RTC: []string{"18000"}}, nil) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + f.conn.Initialize(ctx, f.listener) + + clientAddr := &net.UDPAddr{IP: net.ParseIP("192.168.1.5"), Port: 5000} + payload := []byte("hello-backend") + if err := f.conn.HandlePacket(clientAddr, payload); err != nil { + t.Fatalf("HandlePacket err=%v", err) + } + + if f.conn.clientUDP != clientAddr { + t.Fatal("clientUDP not updated") + } + + select { + case got := <-f.backend.writes: + if string(got) != string(payload) { + t.Fatalf("backend got %q, want %q", got, payload) + } + case <-time.After(time.Second): + t.Fatal("timeout waiting for backend write") + } + }) + + t.Run("propagates connectBackend error", func(t *testing.T) { + f := newRtcConnFixture() + f.lb.PickReturns(nil, errors.New("pick-fail")) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + f.conn.Initialize(ctx, f.listener) + + err := f.conn.HandlePacket(&net.UDPAddr{}, []byte("x")) + if err == nil || !strings.Contains(err.Error(), "pick-fail") { + t.Fatalf("expected propagated pick err, got %v", err) + } + }) + + t.Run("propagates backend write error", func(t *testing.T) { + f := newRtcConnFixture() + f.lb.PickReturns(&lb.OriginServer{IP: "127.0.0.1", RTC: []string{"18000"}}, nil) + f.backend.writeErr = errors.New("write-fail") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + f.conn.Initialize(ctx, f.listener) + + err := f.conn.HandlePacket(&net.UDPAddr{}, []byte("x")) + if err == nil || !strings.Contains(err.Error(), "write-fail") { + t.Fatalf("expected propagated write err, got %v", err) + } + }) + + t.Run("backend reads are forwarded to listener", func(t *testing.T) { + f := newRtcConnFixture() + f.lb.PickReturns(&lb.OriginServer{IP: "127.0.0.1", RTC: []string{"18000"}}, nil) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + f.conn.Initialize(ctx, f.listener) + + clientAddr := &net.UDPAddr{IP: net.ParseIP("192.168.1.5"), Port: 5000} + if err := f.conn.HandlePacket(clientAddr, []byte("trigger")); err != nil { + t.Fatalf("HandlePacket err=%v", err) + } + // drain the trigger packet sent to backend + <-f.backend.writes + + // Feed a packet from the backend; expect it forwarded to the listener. + f.backend.reads <- []byte("from-backend") + + select { + case got := <-f.listener.writes: + if string(got.data) != "from-backend" { + t.Fatalf("listener got %q, want %q", got.data, "from-backend") + } + if got.addr != clientAddr { + t.Fatalf("listener addr=%v, want %v", got.addr, clientAddr) + } + case <-time.After(time.Second): + t.Fatal("timeout waiting for listener write") + } + + // Cleanly terminate the read loop. + _ = f.backend.Close() + }) +} + +// --------------------------------------------------------------------------- +// webRTCProxyServer: fakes, helpers, and fixtures +// --------------------------------------------------------------------------- + +// blockingUDPListener stands in for the WebRTC UDP listener used by Run(). +// ReadFrom blocks until packets are pushed via push(); Close unblocks the +// reader with a "use of closed network connection" error so the accept loop +// hits utils.IsClosedNetworkError and exits gracefully. +type blockingUDPListener struct { + packets chan udpPacket + writes chan packetWrite + closed atomic.Bool +} + +type udpPacket struct { + data []byte + addr net.Addr +} + +func newBlockingUDPListener() *blockingUDPListener { + return &blockingUDPListener{ + packets: make(chan udpPacket, 8), + writes: make(chan packetWrite, 16), + } +} + +func (l *blockingUDPListener) push(p udpPacket) { l.packets <- p } + +func (l *blockingUDPListener) ReadFrom(buf []byte) (int, net.Addr, error) { + p, ok := <-l.packets + if !ok { + return 0, nil, errors.New("use of closed network connection") + } + n := copy(buf, p.data) + return n, p.addr, nil +} + +func (l *blockingUDPListener) WriteTo(p []byte, addr net.Addr) (int, error) { + cp := make([]byte, len(p)) + copy(cp, p) + l.writes <- packetWrite{data: cp, addr: addr} + return len(p), nil +} + +func (l *blockingUDPListener) Close() error { + if l.closed.CompareAndSwap(false, true) { + close(l.packets) + } + return nil +} + +func (l *blockingUDPListener) LocalAddr() net.Addr { return fakeAddr{} } +func (l *blockingUDPListener) SetDeadline(time.Time) error { return nil } +func (l *blockingUDPListener) SetReadDeadline(time.Time) error { return nil } +func (l *blockingUDPListener) SetWriteDeadline(time.Time) error { return nil } + +// newStunBindingRequest builds a minimal STUN binding request packet whose +// USERNAME attribute (type 0x0006) carries the given ufrag. The first byte is +// 0x00 so utils.RtcIsSTUN returns true; the header's message-length field +// matches the attribute body so rtcStunPacket.UnmarshalBinary succeeds. +func newStunBindingRequest(ufrag string) []byte { + body := make([]byte, 0, 4+len(ufrag)+3) + body = append(body, 0x00, 0x06) + body = append(body, byte(len(ufrag)>>8), byte(len(ufrag))) + body = append(body, []byte(ufrag)...) + for len(body)%4 != 0 { + body = append(body, 0) + } + + hdr := make([]byte, 20) + binary.BigEndian.PutUint16(hdr[0:2], 0x0001) + binary.BigEndian.PutUint16(hdr[2:4], uint16(len(body))) + return append(hdr, body...) +} + +// fakeNonStunPacket builds a UDP payload whose first byte is neither 0/1 (so +// utils.RtcIsSTUN returns false) nor a valid RTP marker, so handleClientUDP +// treats it as "unknown" and skips parsing. +func fakeNonStunPacket() []byte { return []byte{0x42, 0x00, 0x00, 0x00} } + +// fakeRTPPacket builds a minimal payload that satisfies utils.RtcIsRTPOrRTCP +// (len >= 12, first byte 0x80) so handleClientUDP's STUN parser is skipped. +func fakeRTPPacket() []byte { + p := make([]byte, 12) + p[0] = 0x80 + return p +} + +// webRTCFixture bundles fakes plus a webRTCProxyServer wired against them. +// The default listenUDP returns the fixture's blocking listener; tests can +// either drive Run() through it or call handler methods directly without +// starting Run() at all. +type webRTCFixture struct { + env *envfakes.FakeProxyEnvironment + lb *lbfakes.FakeOriginLoadBalancer + listener *blockingUDPListener + server *webRTCProxyServer +} + +func newWebRTCFixture() *webRTCFixture { + f := &webRTCFixture{ + env: &envfakes.FakeProxyEnvironment{}, + lb: &lbfakes.FakeOriginLoadBalancer{}, + listener: newBlockingUDPListener(), + } + f.env.WebRTCServerReturns("18000") + + srv := NewWebRTCProxyServer(f.env, f.lb, func(v *webRTCProxyServer) { + v.listenUDP = func(ctx context.Context, endpoint string) (net.PacketConn, error) { + return f.listener, nil + } + }) + f.server = srv.(*webRTCProxyServer) + return f +} + +// sampleSDPOffer is a minimal valid SDP offer with the ICE attributes +// ParseIceUfragPwd looks for. Used as the WHIP/WHEP request body. +const sampleSDPOffer = "v=0\r\n" + + "a=ice-ufrag:remote-ufrag\r\n" + + "a=ice-pwd:remote-pwd-very-long-value-32xx\r\n" + +// sampleSDPAnswer returns an SDP answer where the backend's RTC port appears +// in a candidate line so the proxy's port-rewrite path can be exercised. +func sampleSDPAnswer(port string) string { + return "v=0\r\n" + + "a=ice-ufrag:local-ufrag\r\n" + + "a=ice-pwd:local-pwd-very-long-value-32xxxx\r\n" + + "a=candidate:1 1 udp 1 1.2.3.4 " + port + " typ host\r\n" +} + +// --------------------------------------------------------------------------- +// NewWebRTCProxyServer: constructor & defaults +// --------------------------------------------------------------------------- + +func TestNewWebRTCProxyServer_SetsDefaults(t *testing.T) { + srv := NewWebRTCProxyServer(&envfakes.FakeProxyEnvironment{}, &lbfakes.FakeOriginLoadBalancer{}) + v := srv.(*webRTCProxyServer) + if v.listenUDP == nil { + t.Fatal("listenUDP should default to a non-nil factory") + } + if v.backendURL == nil { + t.Fatal("backendURL should default to a non-nil factory") + } +} + +func TestNewWebRTCProxyServer_DefaultBackendURL_NoAPI(t *testing.T) { + srv := NewWebRTCProxyServer(&envfakes.FakeProxyEnvironment{}, &lbfakes.FakeOriginLoadBalancer{}) + v := srv.(*webRTCProxyServer) + + req := httptest.NewRequest(http.MethodPost, "http://example.com/rtc/v1/whip/", strings.NewReader("")) + _, err := v.backendURL(&lb.OriginServer{IP: "10.0.0.1"}, req) + if err == nil || !strings.Contains(err.Error(), "no http api server") { + t.Fatalf("expected no-api error, got %v", err) + } +} + +func TestNewWebRTCProxyServer_DefaultBackendURL_BadPort(t *testing.T) { + srv := NewWebRTCProxyServer(&envfakes.FakeProxyEnvironment{}, &lbfakes.FakeOriginLoadBalancer{}) + v := srv.(*webRTCProxyServer) + + req := httptest.NewRequest(http.MethodPost, "http://example.com/x", strings.NewReader("")) + _, err := v.backendURL(&lb.OriginServer{IP: "10.0.0.1", API: []string{"not-a-port"}}, req) + if err == nil || !strings.Contains(err.Error(), "parse http port") { + t.Fatalf("expected parse-port error, got %v", err) + } +} + +func TestNewWebRTCProxyServer_DefaultBackendURL_Success(t *testing.T) { + srv := NewWebRTCProxyServer(&envfakes.FakeProxyEnvironment{}, &lbfakes.FakeOriginLoadBalancer{}) + v := srv.(*webRTCProxyServer) + + req := httptest.NewRequest(http.MethodPost, "http://example.com/rtc/v1/whip/?app=live&stream=demo", strings.NewReader("")) + got, err := v.backendURL(&lb.OriginServer{IP: "10.0.0.1", API: []string{"1985"}}, req) + if err != nil { + t.Fatalf("backendURL: %v", err) + } + want := "http://10.0.0.1:1985/rtc/v1/whip/?app=live&stream=demo" + if got != want { + t.Fatalf("backendURL=%q, want %q", got, want) + } +} + +func TestNewWebRTCProxyServer_DefaultBackendURL_NoQuery(t *testing.T) { + // When the inbound request has no raw query, the URL must not get a + // dangling "?" appended. + srv := NewWebRTCProxyServer(&envfakes.FakeProxyEnvironment{}, &lbfakes.FakeOriginLoadBalancer{}) + v := srv.(*webRTCProxyServer) + + req := httptest.NewRequest(http.MethodPost, "http://example.com/rtc/v1/whep/", strings.NewReader("")) + got, err := v.backendURL(&lb.OriginServer{IP: "10.0.0.1", API: []string{"1985"}}, req) + if err != nil { + t.Fatalf("backendURL: %v", err) + } + want := "http://10.0.0.1:1985/rtc/v1/whep/" + if got != want { + t.Fatalf("backendURL=%q, want %q", got, want) + } +} + +func TestNewWebRTCProxyServer_AppliesOptions(t *testing.T) { + var listenCalls, backendCalls atomic.Int32 + srv := NewWebRTCProxyServer( + &envfakes.FakeProxyEnvironment{}, + &lbfakes.FakeOriginLoadBalancer{}, + func(v *webRTCProxyServer) { + v.listenUDP = func(ctx context.Context, endpoint string) (net.PacketConn, error) { + listenCalls.Add(1) + return nil, errors.New("unused") + } + v.backendURL = func(backend *lb.OriginServer, r *http.Request) (string, error) { + backendCalls.Add(1) + return "http://example.test", nil + } + }, + ) + v := srv.(*webRTCProxyServer) + _, _ = v.listenUDP(context.Background(), ":0") + _, _ = v.backendURL(&lb.OriginServer{}, httptest.NewRequest(http.MethodGet, "/", nil)) + if got := listenCalls.Load(); got != 1 { + t.Fatalf("custom listenUDP called %d times, want 1", got) + } + if got := backendCalls.Load(); got != 1 { + t.Fatalf("custom backendURL called %d times, want 1", got) + } +} + +// --------------------------------------------------------------------------- +// webRTCProxyServer.Close +// --------------------------------------------------------------------------- + +func TestWebRTCProxyServer_Close_NilListener(t *testing.T) { + // Close before Run must not panic, must not hang, and must not error. + srv := NewWebRTCProxyServer(&envfakes.FakeProxyEnvironment{}, &lbfakes.FakeOriginLoadBalancer{}) + done := make(chan error, 1) + go func() { done <- srv.Close() }() + select { + case err := <-done: + if err != nil { + t.Fatalf("Close: %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("Close hung with no listener") + } +} + +// --------------------------------------------------------------------------- +// webRTCProxyServer.Run +// --------------------------------------------------------------------------- + +func TestWebRTCProxyServer_Run_ListenError(t *testing.T) { + envFake := &envfakes.FakeProxyEnvironment{} + envFake.WebRTCServerReturns("18000") + srv := NewWebRTCProxyServer(envFake, &lbfakes.FakeOriginLoadBalancer{}, func(v *webRTCProxyServer) { + v.listenUDP = func(ctx context.Context, endpoint string) (net.PacketConn, error) { + return nil, errors.New("permission denied") + } + }) + + err := srv.Run(context.Background()) + if err == nil { + t.Fatal("expected error from Run when listenUDP fails") + } + if !strings.Contains(err.Error(), "listen udp") { + t.Fatalf("unexpected error %v", err) + } +} + +func TestWebRTCProxyServer_Run_EndpointWithoutColon(t *testing.T) { + // A bare port like "18000" must be normalized to ":18000". + envFake := &envfakes.FakeProxyEnvironment{} + envFake.WebRTCServerReturns("18000") + listener := newBlockingUDPListener() + var captured atomic.Value + srv := NewWebRTCProxyServer(envFake, &lbfakes.FakeOriginLoadBalancer{}, func(v *webRTCProxyServer) { + v.listenUDP = func(ctx context.Context, endpoint string) (net.PacketConn, error) { + captured.Store(endpoint) + return listener, nil + } + }) + + if err := srv.Run(context.Background()); err != nil { + t.Fatalf("Run: %v", err) + } + defer srv.Close() + + if got := captured.Load(); got != ":18000" { + t.Fatalf("listenUDP endpoint=%v, want :18000", got) + } +} + +func TestWebRTCProxyServer_Run_EndpointWithColon(t *testing.T) { + // An endpoint that already contains ":" must be passed through unchanged. + envFake := &envfakes.FakeProxyEnvironment{} + envFake.WebRTCServerReturns("127.0.0.1:18000") + listener := newBlockingUDPListener() + var captured atomic.Value + srv := NewWebRTCProxyServer(envFake, &lbfakes.FakeOriginLoadBalancer{}, func(v *webRTCProxyServer) { + v.listenUDP = func(ctx context.Context, endpoint string) (net.PacketConn, error) { + captured.Store(endpoint) + return listener, nil + } + }) + + if err := srv.Run(context.Background()); err != nil { + t.Fatalf("Run: %v", err) + } + defer srv.Close() + + if got := captured.Load(); got != "127.0.0.1:18000" { + t.Fatalf("listenUDP endpoint=%v, want 127.0.0.1:18000", got) + } +} + +func TestWebRTCProxyServer_Run_CloseStopsReadLoop(t *testing.T) { + // Start Run with an idle listener (no packets queued). The read goroutine + // blocks in ReadFrom. Close must unblock it via the "closed network + // connection" error and allow the wait group to drain. + f := newWebRTCFixture() + if err := f.server.Run(context.Background()); err != nil { + t.Fatalf("Run: %v", err) + } + + done := make(chan error, 1) + go func() { done <- f.server.Close() }() + select { + case err := <-done: + if err != nil { + t.Fatalf("Close: %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("Close hung — read loop did not exit on listener close") + } +} + +// --------------------------------------------------------------------------- +// webRTCProxyServer.HandleApiForWHIP / HandleApiForWHEP +// --------------------------------------------------------------------------- + +func TestWebRTCProxyServer_HandleApiForWHIP_CORSPreflight(t *testing.T) { + // OPTIONS short-circuits before reading the body, so the LB is untouched. + f := newWebRTCFixture() + req := httptest.NewRequest(http.MethodOptions, "http://example.com/rtc/v1/whip/", nil) + rec := httptest.NewRecorder() + + if err := f.server.HandleApiForWHIP(context.Background(), rec, req); err != nil { + t.Fatalf("WHIP: %v", err) + } + if rec.Code != http.StatusOK { + t.Fatalf("status=%d, want 200", rec.Code) + } + if f.lb.PickCallCount() != 0 { + t.Fatal("LB.Pick should not be called for CORS preflight") + } +} + +func TestWebRTCProxyServer_HandleApiForWHEP_CORSPreflight(t *testing.T) { + f := newWebRTCFixture() + req := httptest.NewRequest(http.MethodOptions, "http://example.com/rtc/v1/whep/", nil) + rec := httptest.NewRecorder() + + if err := f.server.HandleApiForWHEP(context.Background(), rec, req); err != nil { + t.Fatalf("WHEP: %v", err) + } + if rec.Code != http.StatusOK { + t.Fatalf("status=%d, want 200", rec.Code) + } + if f.lb.PickCallCount() != 0 { + t.Fatal("LB.Pick should not be called for CORS preflight") + } +} + +func TestWebRTCProxyServer_HandleApiForWHIP_PickError(t *testing.T) { + f := newWebRTCFixture() + f.lb.PickReturns(nil, errors.New("no backend")) + + req := httptest.NewRequest(http.MethodPost, "http://example.com/rtc/v1/whip/?app=live&stream=demo", strings.NewReader(sampleSDPOffer)) + rec := httptest.NewRecorder() + + err := f.server.HandleApiForWHIP(context.Background(), rec, req) + if err == nil || !strings.Contains(err.Error(), "pick backend") { + t.Fatalf("expected pick-backend error, got %v", err) + } +} + +func TestWebRTCProxyServer_HandleApiForWHEP_PickError(t *testing.T) { + f := newWebRTCFixture() + f.lb.PickReturns(nil, errors.New("no backend")) + + req := httptest.NewRequest(http.MethodPost, "http://example.com/rtc/v1/whep/?app=live&stream=demo", strings.NewReader(sampleSDPOffer)) + rec := httptest.NewRecorder() + + err := f.server.HandleApiForWHEP(context.Background(), rec, req) + if err == nil || !strings.Contains(err.Error(), "pick backend") { + t.Fatalf("expected pick-backend error, got %v", err) + } +} + +func TestWebRTCProxyServer_HandleApiForWHIP_HappyPath(t *testing.T) { + // Drive a full WHIP exchange: the proxy forwards the offer to an httptest + // backend, rewrites the UDP port in the answer, and calls StoreWebRTC. + f := newWebRTCFixture() + f.env.WebRTCServerReturns("19000") + + const backendRTCPort = "18000" + var backendSawOffer atomic.Bool + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + if string(body) == sampleSDPOffer { + backendSawOffer.Store(true) + } + w.WriteHeader(http.StatusCreated) + _, _ = w.Write([]byte(sampleSDPAnswer(backendRTCPort))) + })) + defer backend.Close() + + // Override backendURL so the proxy talks to the httptest server. + f.server.backendURL = func(b *lb.OriginServer, r *http.Request) (string, error) { + return backend.URL + r.URL.Path, nil + } + + f.lb.PickReturns(&lb.OriginServer{IP: "10.0.0.1", API: []string{"1985"}, RTC: []string{backendRTCPort}}, nil) + + req := httptest.NewRequest(http.MethodPost, "http://example.com/rtc/v1/whip/?app=live&stream=demo", strings.NewReader(sampleSDPOffer)) + rec := httptest.NewRecorder() + + if err := f.server.HandleApiForWHIP(context.Background(), rec, req); err != nil { + t.Fatalf("WHIP: %v", err) + } + if !backendSawOffer.Load() { + t.Fatal("backend did not receive the SDP offer body") + } + if rec.Code != http.StatusCreated { + t.Fatalf("client status=%d, want 201", rec.Code) + } + body := rec.Body.String() + if !strings.Contains(body, " 19000 typ host") { + t.Fatalf("answer did not rewrite backend port; got %q", body) + } + if strings.Contains(body, " "+backendRTCPort+" typ host") { + t.Fatalf("answer still contains original backend port; got %q", body) + } + if f.lb.StoreWebRTCCallCount() != 1 { + t.Fatalf("StoreWebRTC called %d times, want 1", f.lb.StoreWebRTCCallCount()) + } + _, streamURL, stored := f.lb.StoreWebRTCArgsForCall(0) + if !strings.HasSuffix(streamURL, "/live/demo") { + t.Fatalf("StoreWebRTC streamURL=%q, want suffix /live/demo", streamURL) + } + if got := stored.GetUfrag(); got != "local-ufrag:remote-ufrag" { + t.Fatalf("stored ufrag=%q, want local-ufrag:remote-ufrag", got) + } +} + +func TestWebRTCProxyServer_HandleApiForWHEP_HappyPath(t *testing.T) { + f := newWebRTCFixture() + f.env.WebRTCServerReturns("19000") + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.ReadAll(r.Body) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(sampleSDPAnswer("18000"))) + })) + defer backend.Close() + + f.server.backendURL = func(b *lb.OriginServer, r *http.Request) (string, error) { + return backend.URL + r.URL.Path, nil + } + f.lb.PickReturns(&lb.OriginServer{IP: "10.0.0.1", API: []string{"1985"}, RTC: []string{"18000"}}, nil) + + req := httptest.NewRequest(http.MethodPost, "http://example.com/rtc/v1/whep/?app=live&stream=demo", strings.NewReader(sampleSDPOffer)) + rec := httptest.NewRecorder() + + if err := f.server.HandleApiForWHEP(context.Background(), rec, req); err != nil { + t.Fatalf("WHEP: %v", err) + } + if f.lb.StoreWebRTCCallCount() != 1 { + t.Fatalf("StoreWebRTC called %d times, want 1", f.lb.StoreWebRTCCallCount()) + } +} + +// --------------------------------------------------------------------------- +// webRTCProxyServer.proxyApiToBackend: error paths +// --------------------------------------------------------------------------- + +func TestWebRTCProxyServer_ProxyApiToBackend_BackendURLError(t *testing.T) { + f := newWebRTCFixture() + f.server.backendURL = func(b *lb.OriginServer, r *http.Request) (string, error) { + return "", errors.New("build err") + } + f.lb.PickReturns(&lb.OriginServer{IP: "10.0.0.1", API: []string{"1985"}, RTC: []string{"18000"}}, nil) + + req := httptest.NewRequest(http.MethodPost, "http://example.com/rtc/v1/whip/?app=live&stream=demo", strings.NewReader(sampleSDPOffer)) + rec := httptest.NewRecorder() + + err := f.server.HandleApiForWHIP(context.Background(), rec, req) + if err == nil || !strings.Contains(err.Error(), "build err") { + t.Fatalf("expected build err, got %v", err) + } +} + +func TestWebRTCProxyServer_ProxyApiToBackend_BackendNon200(t *testing.T) { + f := newWebRTCFixture() + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadGateway) + })) + defer backend.Close() + + f.server.backendURL = func(b *lb.OriginServer, r *http.Request) (string, error) { + return backend.URL + r.URL.Path, nil + } + f.lb.PickReturns(&lb.OriginServer{IP: "10.0.0.1", API: []string{"1985"}, RTC: []string{"18000"}}, nil) + + req := httptest.NewRequest(http.MethodPost, "http://example.com/rtc/v1/whip/?app=live&stream=demo", strings.NewReader(sampleSDPOffer)) + rec := httptest.NewRecorder() + + err := f.server.HandleApiForWHIP(context.Background(), rec, req) + if err == nil || !strings.Contains(err.Error(), "proxy api to") { + t.Fatalf("expected proxy-api error, got %v", err) + } +} + +func TestWebRTCProxyServer_ProxyApiToBackend_BadAnswerNoIceUfrag(t *testing.T) { + // Backend returns an answer missing the ice-ufrag/pwd attributes; the + // proxy must surface the ParseIceUfragPwd error rather than calling + // StoreWebRTC. + f := newWebRTCFixture() + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("v=0\r\n")) + })) + defer backend.Close() + + f.server.backendURL = func(b *lb.OriginServer, r *http.Request) (string, error) { + return backend.URL + r.URL.Path, nil + } + f.lb.PickReturns(&lb.OriginServer{IP: "10.0.0.1", API: []string{"1985"}, RTC: []string{"18000"}}, nil) + + req := httptest.NewRequest(http.MethodPost, "http://example.com/rtc/v1/whip/?app=live&stream=demo", strings.NewReader(sampleSDPOffer)) + rec := httptest.NewRecorder() + + err := f.server.HandleApiForWHIP(context.Background(), rec, req) + if err == nil || !strings.Contains(err.Error(), "parse local sdp answer") { + t.Fatalf("expected parse-answer error, got %v", err) + } + if f.lb.StoreWebRTCCallCount() != 0 { + t.Fatal("StoreWebRTC should not be called when answer is malformed") + } +} + +// --------------------------------------------------------------------------- +// webRTCProxyServer.handleClientUDP +// --------------------------------------------------------------------------- + +func TestWebRTCProxyServer_HandleClientUDP_NonStunIgnored(t *testing.T) { + // A non-STUN, non-RTP/RTCP packet with no cached connection must return + // without touching the LB. + f := newWebRTCFixture() + addr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 7000} + + if err := f.server.handleClientUDP(context.Background(), addr, fakeNonStunPacket()); err != nil { + t.Fatalf("handleClientUDP: %v", err) + } + if f.lb.LoadWebRTCByUfragCallCount() != 0 { + t.Fatal("LB.LoadWebRTCByUfrag should not be called for non-STUN packet") + } +} + +func TestWebRTCProxyServer_HandleClientUDP_RTPLikeIgnored(t *testing.T) { + // An RTP-like packet (first byte 0x80) skips STUN parsing entirely; the + // LB must not be consulted because no connection lookup happens. + f := newWebRTCFixture() + addr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 7001} + + if err := f.server.handleClientUDP(context.Background(), addr, fakeRTPPacket()); err != nil { + t.Fatalf("handleClientUDP: %v", err) + } + if f.lb.LoadWebRTCByUfragCallCount() != 0 { + t.Fatal("LB.LoadWebRTCByUfrag should not be called for RTP-like packet") + } +} + +func TestWebRTCProxyServer_HandleClientUDP_StunBadPacket(t *testing.T) { + // A short payload that satisfies utils.RtcIsSTUN (first byte 0x00) but + // is shorter than the 20-byte STUN header should surface the + // unmarshaler's "too short" error. + f := newWebRTCFixture() + addr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 7002} + + err := f.server.handleClientUDP(context.Background(), addr, []byte{0x00, 0x00, 0x00}) + if err == nil || !strings.Contains(err.Error(), "stun packet too short") { + t.Fatalf("expected too-short err, got %v", err) + } +} + +func TestWebRTCProxyServer_HandleClientUDP_StunCachedUsername(t *testing.T) { + // A STUN packet whose USERNAME matches a connection already in the + // username cache must route directly to that connection. We pre-wire + // the connection so its load balancer fails Pick, so HandlePacket exits + // quickly with a recognizable error and we can assert routing. + f := newWebRTCFixture() + + cachedLB := &lbfakes.FakeOriginLoadBalancer{} + cachedLB.PickReturns(nil, errors.New("test terminate")) + cached := newRTCConnection(func(c *rtcConnection) { + c.loadBalancer = cachedLB + c.StreamURL = "vhost/app/stream" + c.Ufrag = "L:R" + }) + cached.Initialize(context.Background(), f.listener) + f.server.usernames.Store("L:R", cached) + + addr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 7003} + err := f.server.handleClientUDP(context.Background(), addr, newStunBindingRequest("L:R")) + if err == nil || !strings.Contains(err.Error(), "test terminate") { + t.Fatalf("expected terminate err, got %v", err) + } + // The address cache must have learned this addr. + if _, ok := f.server.addresses.Load(addr.String()); !ok { + t.Fatal("expected addr to be cached after routing via username") + } + if f.lb.LoadWebRTCByUfragCallCount() != 0 { + t.Fatal("LB.LoadWebRTCByUfrag should not be called when cached") + } +} + +func TestWebRTCProxyServer_HandleClientUDP_StunLoadsFromLB(t *testing.T) { + // STUN packet whose USERNAME is not in the cache: the proxy must consult + // the load balancer, cache the returned connection by username, and then + // dispatch to it. handleClientUDP rewires the loaded connection's + // loadBalancer to the server's LB, so we make f.lb.Pick fail to keep the + // HandlePacket call deterministic. + f := newWebRTCFixture() + f.lb.PickReturns(nil, errors.New("test terminate")) + + loaded := newRTCConnection(func(c *rtcConnection) { + c.StreamURL = "vhost/app/stream" + c.Ufrag = "L:R" + }) + f.lb.LoadWebRTCByUfragReturns(loaded, nil) + + addr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 7004} + err := f.server.handleClientUDP(context.Background(), addr, newStunBindingRequest("L:R")) + if err == nil || !strings.Contains(err.Error(), "test terminate") { + t.Fatalf("expected terminate err, got %v", err) + } + if got := f.lb.LoadWebRTCByUfragCallCount(); got != 1 { + t.Fatalf("LoadWebRTCByUfrag called %d times, want 1", got) + } + if _, ok := f.server.usernames.Load("L:R"); !ok { + t.Fatal("expected username to be cached after LB load") + } + // The loaded connection should have been rewired to use the server's LB. + if loaded.loadBalancer != f.lb { + t.Fatal("loaded connection should adopt the server's load balancer") + } +} + +func TestWebRTCProxyServer_HandleClientUDP_StunLBError(t *testing.T) { + // LB.LoadWebRTCByUfrag failure must surface as a wrapped error. + f := newWebRTCFixture() + f.lb.LoadWebRTCByUfragReturns(nil, errors.New("lookup failed")) + + addr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 7005} + err := f.server.handleClientUDP(context.Background(), addr, newStunBindingRequest("missing")) + if err == nil || !strings.Contains(err.Error(), "load webrtc by ufrag") { + t.Fatalf("expected load-webrtc err, got %v", err) + } +} + +func TestWebRTCProxyServer_HandleClientUDP_UsesCachedAddress(t *testing.T) { + // A non-STUN packet from an address already in the address cache must be + // dispatched to the cached connection without consulting the LB. + f := newWebRTCFixture() + cachedLB := &lbfakes.FakeOriginLoadBalancer{} + cachedLB.PickReturns(nil, errors.New("test terminate")) + cached := newRTCConnection(func(c *rtcConnection) { + c.loadBalancer = cachedLB + c.StreamURL = "vhost/app/stream" + c.Ufrag = "L:R" + }) + cached.Initialize(context.Background(), f.listener) + + addr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 7006} + f.server.addresses.Store(addr.String(), cached) + + err := f.server.handleClientUDP(context.Background(), addr, fakeRTPPacket()) + if err == nil || !strings.Contains(err.Error(), "test terminate") { + t.Fatalf("expected terminate err, got %v", err) + } + if f.lb.LoadWebRTCByUfragCallCount() != 0 { + t.Fatal("LB.LoadWebRTCByUfrag should not be called when address is cached") + } +} diff --git a/internal/server/rtmp.go b/internal/proxy/rtmp.go similarity index 76% rename from internal/server/rtmp.go rename to internal/proxy/rtmp.go index b787e99c8..96cadddf7 100644 --- a/internal/server/rtmp.go +++ b/internal/proxy/rtmp.go @@ -1,11 +1,12 @@ // Copyright (c) 2026 Winlin // // SPDX-License-Identifier: MIT -package server +package proxy import ( "context" "fmt" + "io" "net" "strconv" "strings" @@ -20,32 +21,58 @@ import ( "srsx/internal/version" ) -// RTMPServer is the proxy for SRS RTMP server, to proxy the RTMP stream to backend SRS +// RTMPProxyServer is the proxy for SRS RTMP server, to proxy the RTMP stream to backend SRS // server. It will figure out the backend server to proxy to. Unlike the edge server, it will // not cache the stream, but just proxy the stream to backend. -type RTMPServer interface { +type RTMPProxyServer interface { Run(ctx context.Context) error Close() error } -type rtmpServer struct { +type rtmpProxyServer struct { // The environment interface. environment env.ProxyEnvironment - // The TCP listener for RTMP server. - listener *net.TCPListener + // The load balancer for origin servers. + loadBalancer lb.OriginLoadBalancer + // The listener for RTMP server. Stored as net.Listener so tests can inject + // a fake listener by overriding listen. + listener net.Listener // The wait group for all goroutines. wg sync.WaitGroup + // listen opens a listener on the given address. Defaults to a real TCP listener; + // tests may override via a functional option to supply a fake listener. + listen func(ctx context.Context, addr string) (net.Listener, error) + // newConnection creates a fresh rtmpConnection wired up with this server's + // load balancer. Defaults to a real rtmpConnection; tests may override via + // a functional option to supply a fake. + newConnection func() *rtmpConnection } -func NewRTMPServer(environment env.ProxyEnvironment, opts ...func(*rtmpServer)) RTMPServer { - v := &rtmpServer{environment: environment} +func NewRTMPProxyServer(environment env.ProxyEnvironment, loadBalancer lb.OriginLoadBalancer, opts ...func(*rtmpProxyServer)) RTMPProxyServer { + v := &rtmpProxyServer{environment: environment, loadBalancer: loadBalancer} + + // Default listen: a real TCP listener. Uses ListenConfig.Listen so ctx is + // consulted during setup (mainly address resolution); the listener itself + // is still torn down via Close(), not ctx cancellation. + v.listen = func(ctx context.Context, addr string) (net.Listener, error) { + var lc net.ListenConfig + return lc.Listen(ctx, "tcp", addr) + } + // Default connection factory: a real rtmpConnection wired up with the + // server's load balancer. + v.newConnection = func() *rtmpConnection { + return newRTMPConnection(func(c *rtmpConnection) { + c.loadBalancer = v.loadBalancer + }) + } + for _, opt := range opts { opt(v) } return v } -func (v *rtmpServer) Close() error { +func (v *rtmpProxyServer) Close() error { if v.listener != nil { v.listener.Close() } @@ -54,30 +81,25 @@ func (v *rtmpServer) Close() error { return nil } -func (v *rtmpServer) Run(ctx context.Context) error { +func (v *rtmpProxyServer) Run(ctx context.Context) error { endpoint := v.environment.RtmpServer() if !strings.Contains(endpoint, ":") { endpoint = ":" + endpoint } - addr, err := net.ResolveTCPAddr("tcp", endpoint) + listener, err := v.listen(ctx, endpoint) if err != nil { - return errors.Wrapf(err, "resolve rtmp addr %v", endpoint) - } - - listener, err := net.ListenTCP("tcp", addr) - if err != nil { - return errors.Wrapf(err, "listen rtmp addr %v", addr) + return errors.Wrapf(err, "listen rtmp addr %v", endpoint) } v.listener = listener - logger.Debug(ctx, "RTMP server listen at %v", addr) + logger.Debug(ctx, "RTMP server listen at %v", listener.Addr()) v.wg.Add(1) go func() { defer v.wg.Done() for { - conn, err := v.listener.AcceptTCP() + conn, err := v.listener.Accept() if err != nil { // If context is canceled or connection is closed, exit gracefully without logging error. if ctx.Err() != nil || utils.IsClosedNetworkError(err) { @@ -90,7 +112,7 @@ func (v *rtmpServer) Run(ctx context.Context) error { } v.wg.Add(1) - go func(ctx context.Context, conn *net.TCPConn) { + go func(ctx context.Context, conn net.Conn) { defer v.wg.Done() defer conn.Close() @@ -102,7 +124,7 @@ func (v *rtmpServer) Run(ctx context.Context) error { } } - rc := newRTMPConnection() + rc := v.newConnection() if err := rc.serve(ctx, conn); err != nil { handleErr(err) } else { @@ -122,17 +144,43 @@ func (v *rtmpServer) Run(ctx context.Context) error { // then proxy to the corresponding backend server. All state is in the RTMP request, so this // connection is stateless. type rtmpConnection struct { + // The load balancer for origin servers. + loadBalancer lb.OriginLoadBalancer + // newHandshake creates a fresh RTMP handshake instance. Defaults to a real handshake; + // tests may override via a functional option to supply a fake. + newHandshake func() rtmp.Handshake + // newProtocol creates a fresh RTMP protocol instance over the given stream. Defaults to + // a real protocol; tests may override via a functional option to supply a fake. + newProtocol func(rw io.ReadWriter) rtmp.Protocol + // newBackend creates a fresh backend client wired up with the given clientType and the + // connection's load balancer. Defaults to a real rtmpClientToBackend; tests may override + // via a functional option to supply a fake. + newBackend func(clientType RTMPClientType) *rtmpClientToBackend } func newRTMPConnection(opts ...func(*rtmpConnection)) *rtmpConnection { v := &rtmpConnection{} + + // Default handshake factory: a real RTMP handshake. + v.newHandshake = rtmp.NewHandshake + // Default protocol factory: a real RTMP protocol. + v.newProtocol = rtmp.NewProtocol + // Default backend factory: a real rtmpClientToBackend wired up with the connection's + // load balancer and the given clientType. + v.newBackend = func(clientType RTMPClientType) *rtmpClientToBackend { + return newRTMPClientToBackend(func(client *rtmpClientToBackend) { + client.typ = clientType + client.loadBalancer = v.loadBalancer + }) + } + for _, opt := range opts { opt(v) } return v } -func (v *rtmpConnection) serve(ctx context.Context, conn *net.TCPConn) error { +func (v *rtmpConnection) serve(ctx context.Context, conn net.Conn) error { logger.Debug(ctx, "Got RTMP client from %v", conn.RemoteAddr()) // If any goroutine quit, cancel another one. @@ -152,7 +200,7 @@ func (v *rtmpConnection) serve(ctx context.Context, conn *net.TCPConn) error { } // Simple handshake with client. - hs := rtmp.NewHandshake() + hs := v.newHandshake() if _, err := hs.ReadC0S0(conn); err != nil { return errors.Wrapf(err, "read c0") } @@ -172,7 +220,7 @@ func (v *rtmpConnection) serve(ctx context.Context, conn *net.TCPConn) error { return errors.Wrapf(err, "read c2") } - client := rtmp.NewProtocol(conn) + client := v.newProtocol(conn) logger.Debug(ctx, "RTMP simple handshake done") // Expect RTMP connect command with tcUrl. @@ -229,15 +277,16 @@ func (v *rtmpConnection) serve(ctx context.Context, conn *net.TCPConn) error { var response rtmp.Packet switch pkt := identifyReq.(type) { case *rtmp.CallPacket: - if pkt.CommandName == "createStream" { + switch pkt.CommandName { + case "createStream": identifyRes := rtmp.NewCreateStreamResPacket(pkt.TransactionID) response = identifyRes nextStreamID = 1 identifyRes.SetStreamID(nextStreamID) - } else if pkt.CommandName == "getStreamLength" { + case "getStreamLength": // Ignore and do not reply these packets. - } else { + default: // For releaseStream, FCPublish, etc. identifyRes := rtmp.NewCallPacket() response = identifyRes @@ -294,9 +343,7 @@ func (v *rtmpConnection) serve(ctx context.Context, conn *net.TCPConn) error { tcUrl, streamName, currentStreamID, clientType) // Find a backend SRS server to proxy the RTMP stream. - backend = newRTMPClientToBackend(func(client *rtmpClientToBackend) { - client.typ = clientType - }) + backend = v.newBackend(clientType) defer backend.Close() if err := backend.Connect(ctx, tcUrl, streamName); err != nil { @@ -304,7 +351,8 @@ func (v *rtmpConnection) serve(ctx context.Context, conn *net.TCPConn) error { } // Start the streaming. - if clientType == RTMPClientTypePublisher { + switch clientType { + case RTMPClientTypePublisher: identifyRes := rtmp.NewCallPacket() identifyRes.CommandName = "onStatus" @@ -320,7 +368,7 @@ func (v *rtmpConnection) serve(ctx context.Context, conn *net.TCPConn) error { if err := client.WritePacket(ctx, identifyRes, currentStreamID); err != nil { return errors.Wrapf(err, "start publish") } - } else if clientType == RTMPClientTypeViewer { + case RTMPClientTypeViewer: identifyRes := rtmp.NewCallPacket() identifyRes.CommandName = "onStatus" @@ -423,16 +471,40 @@ const ( // rtmpClientToBackend is an RTMP client to proxy the RTMP stream to backend. type rtmpClientToBackend struct { - // The underlayer tcp client. - tcpConn *net.TCPConn + // The underlayer connection to backend. Stored as io.ReadWriteCloser so tests + // can inject a fake connection by overriding dial. + tcpConn io.ReadWriteCloser // The RTMP protocol client. client rtmp.Protocol // The stream type. typ RTMPClientType + // The load balancer for origin servers. + loadBalancer lb.OriginLoadBalancer + // dial opens a connection to a backend SRS server. Defaults to a real TCP dial; + // tests may override via a functional option to supply a fake connection. + dial func(ctx context.Context, ip string, port int) (io.ReadWriteCloser, error) + // newHandshake creates a fresh RTMP handshake instance. Defaults to a real handshake; + // tests may override via a functional option to supply a fake. + newHandshake func() rtmp.Handshake + // newProtocol creates a fresh RTMP protocol instance over the given stream. Defaults to + // a real protocol; tests may override via a functional option to supply a fake. + newProtocol func(rw io.ReadWriter) rtmp.Protocol } func newRTMPClientToBackend(opts ...func(*rtmpClientToBackend)) *rtmpClientToBackend { v := &rtmpClientToBackend{} + + // Default dial: a real TCP connection to the backend. Uses Dialer.DialContext + // so ctx cancellation/deadline aborts the connect (net.DialTCP ignores ctx). + v.dial = func(ctx context.Context, ip string, port int) (io.ReadWriteCloser, error) { + var d net.Dialer + return d.DialContext(ctx, "tcp", net.JoinHostPort(ip, strconv.Itoa(port))) + } + // Default handshake factory: a real RTMP handshake. + v.newHandshake = rtmp.NewHandshake + // Default protocol factory: a real RTMP protocol. + v.newProtocol = rtmp.NewProtocol + for _, opt := range opts { opt(v) } @@ -454,7 +526,7 @@ func (v *rtmpClientToBackend) Connect(ctx context.Context, tcUrl, streamName str } // Pick a backend SRS server to proxy the RTMP stream. - backend, err := lb.SrsLoadBalancer.Pick(ctx, streamURL) + backend, err := v.loadBalancer.Pick(ctx, streamURL) if err != nil { return errors.Wrapf(err, "pick backend for %v", streamURL) } @@ -471,16 +543,15 @@ func (v *rtmpClientToBackend) Connect(ctx context.Context, tcUrl, streamName str rtmpPort = int(iv) } - // Connect to backend SRS server via TCP client. - addr := &net.TCPAddr{IP: net.ParseIP(backend.IP), Port: rtmpPort} - c, err := net.DialTCP("tcp", nil, addr) + // Connect to backend SRS server. + c, err := v.dial(ctx, backend.IP, rtmpPort) if err != nil { - return errors.Wrapf(err, "dial backend addr=%v, srs=%v", addr, backend) + return errors.Wrapf(err, "dial backend ip=%v, port=%v, srs=%v", backend.IP, rtmpPort, backend) } v.tcpConn = c - hs := rtmp.NewHandshake() - client := rtmp.NewProtocol(c) + hs := v.newHandshake() + client := v.newProtocol(c) v.client = client // Simple RTMP handshake with server. @@ -500,7 +571,7 @@ func (v *rtmpClientToBackend) Connect(ctx context.Context, tcUrl, streamName str if _, err = hs.ReadC2S2(c); err != nil { return errors.Wrapf(err, "read c2") } - logger.Debug(ctx, "backend simple handshake done, server=%v", addr) + logger.Debug(ctx, "backend simple handshake done, server=%v:%v", backend.IP, rtmpPort) if err := hs.WriteC2S2(c, hs.C1S1()); err != nil { return errors.Wrapf(err, "write c2") diff --git a/internal/proxy/rtmp_test.go b/internal/proxy/rtmp_test.go new file mode 100644 index 000000000..77aeabfc6 --- /dev/null +++ b/internal/proxy/rtmp_test.go @@ -0,0 +1,1287 @@ +// Copyright (c) 2026 Winlin +// +// SPDX-License-Identifier: MIT +package proxy + +import ( + "context" + "errors" + "io" + "net" + "strings" + "sync/atomic" + "testing" + "time" + + "srsx/internal/env/envfakes" + "srsx/internal/lb" + "srsx/internal/lb/lbfakes" + "srsx/internal/rtmp" + "srsx/internal/rtmp/rtmpfakes" +) + +// fakeConn is an in-memory io.ReadWriteCloser used to replace the TCP +// connection returned by dial. Read/Write are no-ops because every protocol +// call on the connection is intercepted by FakeHandshake/FakeProtocol. +type fakeConn struct { + closed atomic.Bool +} + +func (c *fakeConn) Read(p []byte) (int, error) { return 0, io.EOF } +func (c *fakeConn) Write(p []byte) (int, error) { return len(p), nil } +func (c *fakeConn) Close() error { + c.closed.Store(true) + return nil +} + +// backendFixture bundles the fakes plus an rtmpClientToBackend wired against +// them. Tests configure the fakes, then exercise the methods. +type backendFixture struct { + conn *fakeConn + lb *lbfakes.FakeOriginLoadBalancer + handshake *rtmpfakes.FakeHandshake + protocol *rtmpfakes.FakeProtocol + client *rtmpClientToBackend +} + +func newBackendFixture(typ RTMPClientType) *backendFixture { + f := &backendFixture{ + conn: &fakeConn{}, + lb: &lbfakes.FakeOriginLoadBalancer{}, + handshake: &rtmpfakes.FakeHandshake{}, + protocol: &rtmpfakes.FakeProtocol{}, + } + f.client = newRTMPClientToBackend(func(c *rtmpClientToBackend) { + c.typ = typ + c.loadBalancer = f.lb + c.dial = func(ctx context.Context, ip string, port int) (io.ReadWriteCloser, error) { + return f.conn, nil + } + c.newHandshake = func() rtmp.Handshake { return f.handshake } + c.newProtocol = func(rw io.ReadWriter) rtmp.Protocol { return f.protocol } + }) + return f +} + +// queueDecode programs FakeProtocol.DecodeMessage to return the given packets +// in order, one per call. After the queue is drained, it returns an EOF-ish +// error to fail the test fast instead of looping forever. +func queueDecode(p *rtmpfakes.FakeProtocol, packets ...rtmp.Packet) { + var i atomic.Int32 + p.DecodeMessageStub = func(m rtmp.Message) (rtmp.Packet, error) { + idx := int(i.Add(1)) - 1 + if idx >= len(packets) { + return nil, errors.New("decode queue drained") + } + return packets[idx], nil + } +} + +// readMessageOK programs ReadMessage to always return a fresh empty Message. +// The payload is irrelevant because DecodeMessage is stubbed. +func readMessageOK(p *rtmpfakes.FakeProtocol) { + p.ReadMessageStub = func(ctx context.Context) (rtmp.Message, error) { + return rtmp.NewMessage(), nil + } +} + +// onStatusPacket builds a *rtmp.CallPacket whose Args is an Amf0Object +// carrying the given code. Used to drive both publish() (which inspects +// Args via Amf0Converter) and play() (which uses ArgsCode()). +func onStatusPacket(code string) *rtmp.CallPacket { + pkt := rtmp.NewCallPacket() + pkt.CommandName = "onStatus" + pkt.CommandObject = rtmp.NewAmf0Null() + data := rtmp.NewAmf0Object() + data.Set("code", rtmp.NewAmf0String(code)) + pkt.Args = data + return pkt +} + +func resultCallPacket() *rtmp.CallPacket { + pkt := rtmp.NewCallPacket() + pkt.CommandName = "_result" + return pkt +} + +func createStreamRes(id int) *rtmp.CreateStreamResPacket { + pkt := rtmp.NewCreateStreamResPacket(0) + pkt.SetStreamID(id) + return pkt +} + +// pickOK programs the load balancer to return a backend with one RTMP +// endpoint, mimicking a typical registered SRS origin. +func pickOK(f *backendFixture) { + f.lb.PickReturns(&lb.OriginServer{IP: "10.0.0.1", RTMP: []string{"1935"}}, nil) +} + +// --------------------------------------------------------------------------- +// Close() +// --------------------------------------------------------------------------- + +func TestRtmpClientToBackend_Close_NilConn(t *testing.T) { + c := newRTMPClientToBackend() + if err := c.Close(); err != nil { + t.Fatalf("Close with nil tcpConn: %v", err) + } +} + +func TestRtmpClientToBackend_Close_FakeConn(t *testing.T) { + conn := &fakeConn{} + c := newRTMPClientToBackend(func(c *rtmpClientToBackend) { + c.tcpConn = conn + }) + if err := c.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + if !conn.closed.Load() { + t.Fatal("fakeConn was not closed") + } +} + +// --------------------------------------------------------------------------- +// Connect() error paths +// --------------------------------------------------------------------------- + +func TestRtmpClientToBackend_Connect_BuildStreamURLError(t *testing.T) { + // url.Parse rejects URLs that start with a colon (no scheme/host parseable), + // so this drives BuildStreamURL's error branch before LB.Pick is reached. + f := newBackendFixture(RTMPClientTypePublisher) + err := f.client.Connect(context.Background(), ":bad-url", "stream") + if err == nil { + t.Fatal("expected error from BuildStreamURL") + } + if !strings.Contains(err.Error(), "build stream url") { + t.Fatalf("unexpected error %v", err) + } + if f.lb.PickCallCount() != 0 { + t.Fatalf("LB.Pick should not be called when URL is bad; got %d calls", f.lb.PickCallCount()) + } +} + +func TestRtmpClientToBackend_Connect_PickError(t *testing.T) { + f := newBackendFixture(RTMPClientTypePublisher) + f.lb.PickReturns(nil, errors.New("no backend")) + + err := f.client.Connect(context.Background(), "rtmp://1.2.3.4/live", "stream") + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "pick backend") { + t.Fatalf("unexpected error %v", err) + } +} + +func TestRtmpClientToBackend_Connect_NoRTMPEndpoint(t *testing.T) { + f := newBackendFixture(RTMPClientTypePublisher) + f.lb.PickReturns(&lb.OriginServer{IP: "10.0.0.1"}, nil) // empty RTMP slice + + err := f.client.Connect(context.Background(), "rtmp://1.2.3.4/live", "stream") + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "no rtmp server") { + t.Fatalf("unexpected error %v", err) + } +} + +func TestRtmpClientToBackend_Connect_BadRTMPPort(t *testing.T) { + f := newBackendFixture(RTMPClientTypePublisher) + f.lb.PickReturns(&lb.OriginServer{IP: "10.0.0.1", RTMP: []string{"not-a-port"}}, nil) + + err := f.client.Connect(context.Background(), "rtmp://1.2.3.4/live", "stream") + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "parse backend") { + t.Fatalf("unexpected error %v", err) + } +} + +func TestRtmpClientToBackend_Connect_DialError(t *testing.T) { + f := newBackendFixture(RTMPClientTypePublisher) + pickOK(f) + f.client.dial = func(ctx context.Context, ip string, port int) (io.ReadWriteCloser, error) { + return nil, errors.New("dial refused") + } + + err := f.client.Connect(context.Background(), "rtmp://1.2.3.4/live", "stream") + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "dial backend") { + t.Fatalf("unexpected error %v", err) + } +} + +func TestRtmpClientToBackend_Connect_DialHonorsCtxCancel(t *testing.T) { + // The default dial uses net.Dialer.DialContext, so a canceled ctx must + // surface as a dial error rather than hanging on the kernel connect. + // We assert this contract by having the test dial honor ctx itself. + f := newBackendFixture(RTMPClientTypePublisher) + pickOK(f) + f.client.dial = func(ctx context.Context, ip string, port int) (io.ReadWriteCloser, error) { + return nil, ctx.Err() + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // already-canceled ctx + err := f.client.Connect(ctx, "rtmp://1.2.3.4/live", "stream") + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "dial backend") { + t.Fatalf("unexpected error %v", err) + } + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected ctx.Canceled in chain, got %v", err) + } +} + +func TestRtmpClientToBackend_Connect_HandshakeWriteC0Error(t *testing.T) { + f := newBackendFixture(RTMPClientTypePublisher) + pickOK(f) + f.handshake.WriteC0S0Returns(errors.New("write c0")) + + err := f.client.Connect(context.Background(), "rtmp://1.2.3.4/live", "stream") + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "write c0") { + t.Fatalf("unexpected error %v", err) + } +} + +func TestRtmpClientToBackend_Connect_HandshakeReadS0Error(t *testing.T) { + f := newBackendFixture(RTMPClientTypePublisher) + pickOK(f) + f.handshake.ReadC0S0Returns(nil, errors.New("read s0")) + + err := f.client.Connect(context.Background(), "rtmp://1.2.3.4/live", "stream") + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "read s0") { + t.Fatalf("unexpected error %v", err) + } +} + +func TestRtmpClientToBackend_Connect_WriteConnectAppError(t *testing.T) { + f := newBackendFixture(RTMPClientTypePublisher) + pickOK(f) + f.protocol.WritePacketReturns(errors.New("write packet")) + + err := f.client.Connect(context.Background(), "rtmp://1.2.3.4/live", "stream") + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "write connect app") { + t.Fatalf("unexpected error %v", err) + } +} + +func TestRtmpClientToBackend_Connect_ExpectConnectAppResError(t *testing.T) { + f := newBackendFixture(RTMPClientTypePublisher) + pickOK(f) + // WritePacket succeeds, but ReadMessage inside ExpectPacket fails. + f.protocol.ReadMessageReturns(nil, errors.New("read message")) + + err := f.client.Connect(context.Background(), "rtmp://1.2.3.4/live", "stream") + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "expect connect app res") { + t.Fatalf("unexpected error %v", err) + } +} + +// --------------------------------------------------------------------------- +// Connect() happy paths +// --------------------------------------------------------------------------- + +func TestRtmpClientToBackend_Connect_PublisherHappyPath(t *testing.T) { + f := newBackendFixture(RTMPClientTypePublisher) + pickOK(f) + readMessageOK(f.protocol) + queueDecode(f.protocol, + rtmp.NewConnectAppResPacket(0), // connect app res + resultCallPacket(), // releaseStream res + resultCallPacket(), // FCPublish res + createStreamRes(1), // createStream res + onStatusPacket("NetStream.Publish.Start"), // publish onStatus + ) + + if err := f.client.Connect(context.Background(), "rtmp://1.2.3.4/live", "stream"); err != nil { + t.Fatalf("Connect: %+v", err) + } + + if f.client.tcpConn != f.conn { + t.Fatal("tcpConn should be the fake conn from dial") + } + if f.client.client != f.protocol { + t.Fatal("client field should be the fake protocol") + } + // One WritePacket each for: connectApp, releaseStream, FCPublish, createStream, publish. + if got := f.protocol.WritePacketCallCount(); got != 5 { + t.Fatalf("WritePacket called %d times, want 5", got) + } +} + +func TestRtmpClientToBackend_Connect_ViewerHappyPath(t *testing.T) { + f := newBackendFixture(RTMPClientTypeViewer) + pickOK(f) + readMessageOK(f.protocol) + queueDecode(f.protocol, + rtmp.NewConnectAppResPacket(0), // connect app res + createStreamRes(1), // createStream res + onStatusPacket("NetStream.Play.Start"), // play onStatus + ) + + if err := f.client.Connect(context.Background(), "rtmp://1.2.3.4/live", "stream"); err != nil { + t.Fatalf("Connect: %+v", err) + } + + // One WritePacket each for: connectApp, createStream, play. + if got := f.protocol.WritePacketCallCount(); got != 3 { + t.Fatalf("WritePacket called %d times, want 3", got) + } +} + +// --------------------------------------------------------------------------- +// publish() in isolation +// --------------------------------------------------------------------------- + +func newIsolatedBackend(t *testing.T, typ RTMPClientType) (*rtmpClientToBackend, *rtmpfakes.FakeProtocol) { + t.Helper() + p := &rtmpfakes.FakeProtocol{} + readMessageOK(p) + c := newRTMPClientToBackend(func(c *rtmpClientToBackend) { c.typ = typ }) + return c, p +} + +func TestRtmpClientToBackend_Publish_HappyPath(t *testing.T) { + c, p := newIsolatedBackend(t, RTMPClientTypePublisher) + queueDecode(p, + resultCallPacket(), // releaseStream _result + resultCallPacket(), // FCPublish _result + createStreamRes(7), // createStream res + onStatusPacket("NetStream.Publish.Start"), // final publish onStatus + ) + + if err := c.publish(context.Background(), p, "stream"); err != nil { + t.Fatalf("publish: %+v", err) + } + if got := p.WritePacketCallCount(); got != 4 { + t.Fatalf("WritePacket called %d times, want 4", got) + } +} + +func TestRtmpClientToBackend_Publish_ReleaseStreamWriteError(t *testing.T) { + c, p := newIsolatedBackend(t, RTMPClientTypePublisher) + p.WritePacketReturns(errors.New("boom")) + + err := c.publish(context.Background(), p, "stream") + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "releaseStream") { + t.Fatalf("unexpected error %v", err) + } +} + +func TestRtmpClientToBackend_Publish_FCPublishExpectError(t *testing.T) { + c, p := newIsolatedBackend(t, RTMPClientTypePublisher) + // First ExpectPacket (releaseStream res) succeeds; the second (FCPublish res) + // must fail. We fail ReadMessage on its second call. + var reads atomic.Int32 + p.ReadMessageStub = func(ctx context.Context) (rtmp.Message, error) { + if reads.Add(1) >= 2 { + return nil, errors.New("read fail") + } + return rtmp.NewMessage(), nil + } + queueDecode(p, resultCallPacket()) // only the first decode is consumed + + err := c.publish(context.Background(), p, "stream") + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "FCPublish") { + t.Fatalf("unexpected error %v", err) + } +} + +func TestRtmpClientToBackend_Publish_CreateStreamSkipsZeroID(t *testing.T) { + // The createStream loop continues until StreamID != 0; verify it ignores + // the first packet (StreamID 0) and accepts the second (StreamID 9). + c, p := newIsolatedBackend(t, RTMPClientTypePublisher) + queueDecode(p, + resultCallPacket(), // releaseStream res + resultCallPacket(), // FCPublish res + createStreamRes(0), // ignored + createStreamRes(9), // accepted + onStatusPacket("NetStream.Publish.Start"), // final publish onStatus + ) + + if err := c.publish(context.Background(), p, "stream"); err != nil { + t.Fatalf("publish: %+v", err) + } +} + +func TestRtmpClientToBackend_Publish_SkipsNonOnStatus(t *testing.T) { + // publish() loops past onFCPublish (a CallPacket whose CommandName != onStatus) + // until it sees onStatus(NetStream.Publish.Start). + c, p := newIsolatedBackend(t, RTMPClientTypePublisher) + onFC := rtmp.NewCallPacket() + onFC.CommandName = "onFCPublish" + queueDecode(p, + resultCallPacket(), + resultCallPacket(), + createStreamRes(1), + onFC, // skipped: not onStatus + onStatusPacket("NetStream.Publish.Start"), + ) + + if err := c.publish(context.Background(), p, "stream"); err != nil { + t.Fatalf("publish: %+v", err) + } +} + +func TestRtmpClientToBackend_Publish_OnStatusArgsNotObject(t *testing.T) { + c, p := newIsolatedBackend(t, RTMPClientTypePublisher) + bad := rtmp.NewCallPacket() + bad.CommandName = "onStatus" + bad.Args = rtmp.NewAmf0String("not-an-object") + queueDecode(p, + resultCallPacket(), + resultCallPacket(), + createStreamRes(1), + bad, + ) + + err := c.publish(context.Background(), p, "stream") + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "args not object") { + t.Fatalf("unexpected error %v", err) + } +} + +func TestRtmpClientToBackend_Publish_OnStatusMissingCode(t *testing.T) { + c, p := newIsolatedBackend(t, RTMPClientTypePublisher) + bad := rtmp.NewCallPacket() + bad.CommandName = "onStatus" + bad.Args = rtmp.NewAmf0Object() // empty: no "code" + queueDecode(p, + resultCallPacket(), + resultCallPacket(), + createStreamRes(1), + bad, + ) + + err := c.publish(context.Background(), p, "stream") + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "code not string") { + t.Fatalf("unexpected error %v", err) + } +} + +func TestRtmpClientToBackend_Publish_OnStatusWrongCode(t *testing.T) { + c, p := newIsolatedBackend(t, RTMPClientTypePublisher) + queueDecode(p, + resultCallPacket(), + resultCallPacket(), + createStreamRes(1), + onStatusPacket("NetStream.Publish.Failed"), + ) + + err := c.publish(context.Background(), p, "stream") + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "NetStream.Publish.Start") { + t.Fatalf("unexpected error %v", err) + } +} + +// --------------------------------------------------------------------------- +// play() in isolation +// --------------------------------------------------------------------------- + +func TestRtmpClientToBackend_Play_HappyPath(t *testing.T) { + c, p := newIsolatedBackend(t, RTMPClientTypeViewer) + queueDecode(p, + createStreamRes(3), + onStatusPacket("NetStream.Play.Start"), + ) + + if err := c.play(context.Background(), p, "stream"); err != nil { + t.Fatalf("play: %+v", err) + } + // One WritePacket each for: createStream and play. + if got := p.WritePacketCallCount(); got != 2 { + t.Fatalf("WritePacket called %d times, want 2", got) + } +} + +func TestRtmpClientToBackend_Play_CreateStreamWriteError(t *testing.T) { + c, p := newIsolatedBackend(t, RTMPClientTypeViewer) + p.WritePacketReturns(errors.New("boom")) + + err := c.play(context.Background(), p, "stream") + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "createStream") { + t.Fatalf("unexpected error %v", err) + } +} + +func TestRtmpClientToBackend_Play_CreateStreamExpectError(t *testing.T) { + c, p := newIsolatedBackend(t, RTMPClientTypeViewer) + p.ReadMessageReturns(nil, errors.New("read fail")) + + err := c.play(context.Background(), p, "stream") + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "createStream res") { + t.Fatalf("unexpected error %v", err) + } +} + +func TestRtmpClientToBackend_Play_CreateStreamSkipsZeroID(t *testing.T) { + c, p := newIsolatedBackend(t, RTMPClientTypeViewer) + queueDecode(p, + createStreamRes(0), // skipped + createStreamRes(5), + onStatusPacket("NetStream.Play.Start"), + ) + + if err := c.play(context.Background(), p, "stream"); err != nil { + t.Fatalf("play: %+v", err) + } +} + +func TestRtmpClientToBackend_Play_FiltersUntilPlayStart(t *testing.T) { + // play() ignores onStatus packets whose code is not NetStream.Play.Start + // (e.g. the proxy sees a NetStream.Play.Reset first). + c, p := newIsolatedBackend(t, RTMPClientTypeViewer) + queueDecode(p, + createStreamRes(1), + onStatusPacket("NetStream.Play.Reset"), // skipped + onStatusPacket("NetStream.Play.Start"), + ) + + if err := c.play(context.Background(), p, "stream"); err != nil { + t.Fatalf("play: %+v", err) + } +} + +// --------------------------------------------------------------------------- +// rtmpConnection: fakes, fixture, and packet builders +// --------------------------------------------------------------------------- + +// fakeNetConn is a net.Conn replacement for serve(), which takes net.Conn. +// Read/Write are no-ops because every protocol call is intercepted by the +// fake handshake/protocol; RemoteAddr/Close are called directly by serve. +type fakeNetConn struct { + closed atomic.Bool +} + +func (c *fakeNetConn) Read(p []byte) (int, error) { return 0, io.EOF } +func (c *fakeNetConn) Write(p []byte) (int, error) { return len(p), nil } +func (c *fakeNetConn) Close() error { c.closed.Store(true); return nil } +func (c *fakeNetConn) LocalAddr() net.Addr { return fakeAddr{} } +func (c *fakeNetConn) RemoteAddr() net.Addr { return fakeAddr{} } +func (c *fakeNetConn) SetDeadline(time.Time) error { return nil } +func (c *fakeNetConn) SetReadDeadline(time.Time) error { return nil } +func (c *fakeNetConn) SetWriteDeadline(time.Time) error { + return nil +} + +type fakeAddr struct{} + +func (fakeAddr) Network() string { return "fake" } +func (fakeAddr) String() string { return "fake-addr" } + +// connFixture bundles the fakes plus an rtmpConnection wired against them. +// Tests configure the fakes, then call rc.serve(ctx, conn). +// +// The injected newBackend always returns a "terminating" backend whose +// inner load balancer fails Pick. This drives serve() far enough to call +// newBackend (so we can assert clientType), but Connect then fails fast +// so the test does not need to drive the proxy goroutines. +type connFixture struct { + netConn *fakeNetConn + clientHs *rtmpfakes.FakeHandshake + clientProto *rtmpfakes.FakeProtocol + lb *lbfakes.FakeOriginLoadBalancer + + backendCalls atomic.Int32 + backendClientType atomic.Value // RTMPClientType + + rc *rtmpConnection +} + +func newConnFixture() *connFixture { + f := &connFixture{ + netConn: &fakeNetConn{}, + clientHs: &rtmpfakes.FakeHandshake{}, + clientProto: &rtmpfakes.FakeProtocol{}, + lb: &lbfakes.FakeOriginLoadBalancer{}, + } + // Default: protocol.ReadMessage returns a fresh empty Message so + // ExpectPacket can proceed to DecodeMessage. DecodeMessage must be + // queued per test via queueDecode. + readMessageOK(f.clientProto) + + f.rc = newRTMPConnection(func(c *rtmpConnection) { + c.loadBalancer = f.lb + c.newHandshake = func() rtmp.Handshake { return f.clientHs } + c.newProtocol = func(rw io.ReadWriter) rtmp.Protocol { return f.clientProto } + c.newBackend = func(clientType RTMPClientType) *rtmpClientToBackend { + f.backendCalls.Add(1) + f.backendClientType.Store(clientType) + // Terminating backend: inner LB.Pick fails, so backend.Connect + // returns an error wrapped by serve() as "connect backend". + terminateLb := &lbfakes.FakeOriginLoadBalancer{} + terminateLb.PickReturns(nil, errors.New("test terminate")) + return newRTMPClientToBackend(func(b *rtmpClientToBackend) { + b.typ = clientType + b.loadBalancer = terminateLb + }) + } + }) + return f +} + +// connectReqPacket builds a ConnectAppPacket with the given tcUrl. +func connectReqPacket(tcUrl string) *rtmp.ConnectAppPacket { + p := rtmp.NewConnectAppPacket() + p.CommandObject.Set("tcUrl", rtmp.NewAmf0String(tcUrl)) + return p +} + +// publishReqPacket builds a PublishPacket with the given stream name. +func publishReqPacket(streamName string) *rtmp.PublishPacket { + p := rtmp.NewPublishPacket() + p.StreamName = rtmp.NewAmf0String(streamName) + return p +} + +// playReqPacket builds a PlayPacket with the given stream name. +func playReqPacket(streamName string) *rtmp.PlayPacket { + p := rtmp.NewPlayPacket() + p.StreamName = rtmp.NewAmf0String(streamName) + return p +} + +// rtmp.CallPacket's CommandName field uses an unexported amf0String, which +// only accepts untyped string literals. The three identify-loop branches +// each get their own helper. +func createStreamCallPacket() *rtmp.CallPacket { + p := rtmp.NewCallPacket() + p.CommandName = "createStream" + return p +} + +func releaseStreamCallPacket() *rtmp.CallPacket { + p := rtmp.NewCallPacket() + p.CommandName = "releaseStream" + return p +} + +func getStreamLengthCallPacket() *rtmp.CallPacket { + p := rtmp.NewCallPacket() + p.CommandName = "getStreamLength" + return p +} + +// --------------------------------------------------------------------------- +// rtmpConnection: constructor & defaults +// --------------------------------------------------------------------------- + +func TestRtmpConnection_NewSetsDefaults(t *testing.T) { + c := newRTMPConnection() + if c.newHandshake == nil { + t.Fatal("newHandshake should default to a non-nil factory") + } + if c.newProtocol == nil { + t.Fatal("newProtocol should default to a non-nil factory") + } + if c.newBackend == nil { + t.Fatal("newBackend should default to a non-nil factory") + } + // Defaults are real factories — call them to confirm they return + // non-nil concrete values. + if hs := c.newHandshake(); hs == nil { + t.Fatal("default newHandshake returned nil") + } + if p := c.newProtocol(&fakeConn{}); p == nil { + t.Fatal("default newProtocol returned nil") + } +} + +func TestRtmpConnection_DefaultNewBackendWiresFields(t *testing.T) { + lbInst := &lbfakes.FakeOriginLoadBalancer{} + c := newRTMPConnection(func(c *rtmpConnection) { + c.loadBalancer = lbInst + }) + + pub := c.newBackend(RTMPClientTypePublisher) + if pub.typ != RTMPClientTypePublisher { + t.Fatalf("publisher backend typ=%v, want %v", pub.typ, RTMPClientTypePublisher) + } + if pub.loadBalancer != lbInst { + t.Fatal("publisher backend should reuse the connection's load balancer") + } + + view := c.newBackend(RTMPClientTypeViewer) + if view.typ != RTMPClientTypeViewer { + t.Fatalf("viewer backend typ=%v, want %v", view.typ, RTMPClientTypeViewer) + } + if view.loadBalancer != lbInst { + t.Fatal("viewer backend should reuse the connection's load balancer") + } +} + +func TestRtmpConnection_OptionOverridesNewBackend(t *testing.T) { + var called atomic.Int32 + override := func(clientType RTMPClientType) *rtmpClientToBackend { + called.Add(1) + return newRTMPClientToBackend() + } + c := newRTMPConnection(func(c *rtmpConnection) { c.newBackend = override }) + + _ = c.newBackend(RTMPClientTypePublisher) + if got := called.Load(); got != 1 { + t.Fatalf("override newBackend called %d times, want 1", got) + } +} + +// --------------------------------------------------------------------------- +// rtmpConnection.serve: handshake error paths +// --------------------------------------------------------------------------- + +func TestRtmpConnection_Serve_HandshakeReadC0Error(t *testing.T) { + f := newConnFixture() + f.clientHs.ReadC0S0Returns(nil, errors.New("boom")) + + err := f.rc.serve(context.Background(), f.netConn) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "read c0") { + t.Fatalf("unexpected error %v", err) + } + if f.backendCalls.Load() != 0 { + t.Fatal("newBackend should not be called on handshake failure") + } +} + +func TestRtmpConnection_Serve_HandshakeWriteC0Error(t *testing.T) { + f := newConnFixture() + f.clientHs.WriteC0S0Returns(errors.New("boom")) + + err := f.rc.serve(context.Background(), f.netConn) + if err == nil { + t.Fatal("expected error") + } + // The write-c0 branch is wrapped as "write s1" in serve() (typo in + // production, but the test pins the current behavior). + if !strings.Contains(err.Error(), "write s1") { + t.Fatalf("unexpected error %v", err) + } +} + +func TestRtmpConnection_Serve_HandshakeReadC2Error(t *testing.T) { + f := newConnFixture() + f.clientHs.ReadC2S2Returns(nil, errors.New("boom")) + + err := f.rc.serve(context.Background(), f.netConn) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "read c2") { + t.Fatalf("unexpected error %v", err) + } +} + +// --------------------------------------------------------------------------- +// rtmpConnection.serve: protocol error paths +// --------------------------------------------------------------------------- + +func TestRtmpConnection_Serve_ExpectConnectReqError(t *testing.T) { + f := newConnFixture() + // Fail ReadMessage so ExpectPacket returns immediately. + f.clientProto.ReadMessageStub = nil + f.clientProto.ReadMessageReturns(nil, errors.New("read fail")) + + err := f.rc.serve(context.Background(), f.netConn) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "expect connect req") { + t.Fatalf("unexpected error %v", err) + } + if f.backendCalls.Load() != 0 { + t.Fatal("newBackend should not be called when connect req fails") + } +} + +func TestRtmpConnection_Serve_WriteAckSizeError(t *testing.T) { + f := newConnFixture() + queueDecode(f.clientProto, connectReqPacket("rtmp://1.2.3.4/live")) + // First WritePacket is the WindowAcknowledgementSize. + f.clientProto.WritePacketReturnsOnCall(0, errors.New("boom")) + + err := f.rc.serve(context.Background(), f.netConn) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "write set ack size") { + t.Fatalf("unexpected error %v", err) + } +} + +func TestRtmpConnection_Serve_WriteConnectResError(t *testing.T) { + f := newConnFixture() + queueDecode(f.clientProto, connectReqPacket("rtmp://1.2.3.4/live")) + // Third WritePacket is the ConnectAppResPacket; ack and chunk-size precede it. + f.clientProto.WritePacketReturnsOnCall(2, errors.New("boom")) + + err := f.rc.serve(context.Background(), f.netConn) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "write connect res") { + t.Fatalf("unexpected error %v", err) + } +} + +func TestRtmpConnection_Serve_ExpectIdentifyReqError(t *testing.T) { + f := newConnFixture() + // Connect req decodes fine, then the next ReadMessage fails. + queueDecode(f.clientProto, connectReqPacket("rtmp://1.2.3.4/live")) + var reads atomic.Int32 + f.clientProto.ReadMessageStub = func(ctx context.Context) (rtmp.Message, error) { + if reads.Add(1) >= 2 { + return nil, errors.New("read fail") + } + return rtmp.NewMessage(), nil + } + + err := f.rc.serve(context.Background(), f.netConn) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "expect identify req") { + t.Fatalf("unexpected error %v", err) + } +} + +// --------------------------------------------------------------------------- +// rtmpConnection.serve: identify-loop branches +// --------------------------------------------------------------------------- + +func TestRtmpConnection_Serve_IdentifyCreateStreamThenPublisher(t *testing.T) { + f := newConnFixture() + queueDecode(f.clientProto, + connectReqPacket("rtmp://1.2.3.4/live"), + createStreamCallPacket(), + publishReqPacket("stream"), + ) + + err := f.rc.serve(context.Background(), f.netConn) + // Reaches backend.Connect, which fails via the terminating LB. + if err == nil || !strings.Contains(err.Error(), "connect backend") { + t.Fatalf("expected connect backend error, got %v", err) + } + // WritePacket calls: ack, chunk, connectRes, createStreamRes, onFCPublish. + if got := f.clientProto.WritePacketCallCount(); got != 5 { + t.Fatalf("WritePacket called %d times, want 5", got) + } + if v := f.backendClientType.Load(); v != RTMPClientTypePublisher { + t.Fatalf("backend clientType=%v, want publisher", v) + } +} + +func TestRtmpConnection_Serve_IdentifyDefaultCallThenViewer(t *testing.T) { + // A generic CallPacket (e.g. releaseStream) must be acknowledged with + // a _result reply before the identify loop sees the Play packet. + f := newConnFixture() + queueDecode(f.clientProto, + connectReqPacket("rtmp://1.2.3.4/live"), + releaseStreamCallPacket(), + playReqPacket("stream"), + ) + + err := f.rc.serve(context.Background(), f.netConn) + if err == nil || !strings.Contains(err.Error(), "connect backend") { + t.Fatalf("expected connect backend error, got %v", err) + } + // WritePacket calls: ack, chunk, connectRes, _result, onStatus(play). + if got := f.clientProto.WritePacketCallCount(); got != 5 { + t.Fatalf("WritePacket called %d times, want 5", got) + } + if v := f.backendClientType.Load(); v != RTMPClientTypeViewer { + t.Fatalf("backend clientType=%v, want viewer", v) + } +} + +func TestRtmpConnection_Serve_IdentifyGetStreamLengthSkipsResponse(t *testing.T) { + // getStreamLength is ignored — no response written, no error, the loop + // just reads the next packet (the Publish). + f := newConnFixture() + queueDecode(f.clientProto, + connectReqPacket("rtmp://1.2.3.4/live"), + getStreamLengthCallPacket(), + publishReqPacket("stream"), + ) + + err := f.rc.serve(context.Background(), f.netConn) + if err == nil || !strings.Contains(err.Error(), "connect backend") { + t.Fatalf("expected connect backend error, got %v", err) + } + // WritePacket calls: ack, chunk, connectRes, onFCPublish. + // getStreamLength contributes nothing because the switch falls through. + if got := f.clientProto.WritePacketCallCount(); got != 4 { + t.Fatalf("WritePacket called %d times, want 4", got) + } +} + +func TestRtmpConnection_Serve_IdentifyResponseWriteError(t *testing.T) { + f := newConnFixture() + queueDecode(f.clientProto, + connectReqPacket("rtmp://1.2.3.4/live"), + createStreamCallPacket(), + ) + // First three WritePacket calls (ack, chunk, connectRes) succeed; + // the fourth (createStream response) fails. + f.clientProto.WritePacketReturnsOnCall(3, errors.New("boom")) + + err := f.rc.serve(context.Background(), f.netConn) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "write identify res") { + t.Fatalf("unexpected error %v", err) + } +} + +// --------------------------------------------------------------------------- +// rtmpConnection.serve: newBackend invocation contract +// --------------------------------------------------------------------------- + +func TestRtmpConnection_Serve_PublisherInvokesNewBackend(t *testing.T) { + // A direct Publish (no createStream beforehand) still drives the + // identify loop to set clientType=Publisher and call newBackend. + f := newConnFixture() + queueDecode(f.clientProto, + connectReqPacket("rtmp://1.2.3.4/live"), + publishReqPacket("stream"), + ) + + err := f.rc.serve(context.Background(), f.netConn) + if err == nil || !strings.Contains(err.Error(), "connect backend") { + t.Fatalf("expected connect backend error, got %v", err) + } + if got := f.backendCalls.Load(); got != 1 { + t.Fatalf("newBackend called %d times, want 1", got) + } + if v := f.backendClientType.Load(); v != RTMPClientTypePublisher { + t.Fatalf("backend clientType=%v, want publisher", v) + } +} + +func TestRtmpConnection_Serve_ViewerInvokesNewBackend(t *testing.T) { + f := newConnFixture() + queueDecode(f.clientProto, + connectReqPacket("rtmp://1.2.3.4/live"), + playReqPacket("stream"), + ) + + err := f.rc.serve(context.Background(), f.netConn) + if err == nil || !strings.Contains(err.Error(), "connect backend") { + t.Fatalf("expected connect backend error, got %v", err) + } + if got := f.backendCalls.Load(); got != 1 { + t.Fatalf("newBackend called %d times, want 1", got) + } + if v := f.backendClientType.Load(); v != RTMPClientTypeViewer { + t.Fatalf("backend clientType=%v, want viewer", v) + } +} + +// --------------------------------------------------------------------------- +// rtmpProxyServer: fakes and fixture +// --------------------------------------------------------------------------- + +// fakeListener is a net.Listener whose Accept returns connections pushed via +// push() and unblocks Accept with a "use of closed network connection" error +// on Close. The error message satisfies utils.IsClosedNetworkError so the +// accept loop exits via the graceful branch. +type fakeListener struct { + conns chan net.Conn + closed atomic.Bool +} + +func newFakeListener() *fakeListener { + return &fakeListener{conns: make(chan net.Conn, 4)} +} + +func (l *fakeListener) push(c net.Conn) { l.conns <- c } + +func (l *fakeListener) Accept() (net.Conn, error) { + c, ok := <-l.conns + if !ok { + return nil, errors.New("use of closed network connection") + } + return c, nil +} + +func (l *fakeListener) Close() error { + if l.closed.CompareAndSwap(false, true) { + close(l.conns) + } + return nil +} + +func (l *fakeListener) Addr() net.Addr { return fakeAddr{} } + +// proxyFixture bundles the fakes plus an rtmpProxyServer wired against them. +// The injected newConnection returns a connection whose handshake fails on +// ReadC0S0, so serve() returns fast without needing to drive the full RTMP +// protocol. Tests can assert how many connections were dispatched via +// newConnCalls and how the listen() option was invoked via listenCalls/ +// listenAddr. +type proxyFixture struct { + env *envfakes.FakeProxyEnvironment + lb *lbfakes.FakeOriginLoadBalancer + listener *fakeListener + listenCalls atomic.Int32 + listenAddr atomic.Value // string + newConnCalls atomic.Int32 + serveDone chan struct{} + server *rtmpProxyServer +} + +func newProxyFixture() *proxyFixture { + f := &proxyFixture{ + env: &envfakes.FakeProxyEnvironment{}, + lb: &lbfakes.FakeOriginLoadBalancer{}, + listener: newFakeListener(), + serveDone: make(chan struct{}, 16), + } + f.env.RtmpServerReturns("1935") + + srv := NewRTMPProxyServer(f.env, f.lb, func(v *rtmpProxyServer) { + v.listen = func(ctx context.Context, addr string) (net.Listener, error) { + f.listenCalls.Add(1) + f.listenAddr.Store(addr) + return f.listener, nil + } + v.newConnection = func() *rtmpConnection { + f.newConnCalls.Add(1) + hs := &rtmpfakes.FakeHandshake{} + hs.ReadC0S0Returns(nil, errors.New("test terminate")) + return newRTMPConnection(func(c *rtmpConnection) { + // Signal when serve() actually enters the handshake step so + // tests can sync on "per-conn goroutine has started". + c.newHandshake = func() rtmp.Handshake { + f.serveDone <- struct{}{} + return hs + } + }) + } + }) + f.server = srv.(*rtmpProxyServer) + return f +} + +// --------------------------------------------------------------------------- +// rtmpProxyServer: constructor & defaults +// --------------------------------------------------------------------------- + +func TestRTMPProxyServer_NewSetsDefaults(t *testing.T) { + srv := NewRTMPProxyServer(&envfakes.FakeProxyEnvironment{}, &lbfakes.FakeOriginLoadBalancer{}) + v := srv.(*rtmpProxyServer) + if v.listen == nil { + t.Fatal("listen should default to a non-nil factory") + } + if v.newConnection == nil { + t.Fatal("newConnection should default to a non-nil factory") + } + // Default newConnection returns a wired-up rtmpConnection that reuses + // the server's load balancer. + rc := v.newConnection() + if rc == nil { + t.Fatal("default newConnection returned nil") + } + if rc.loadBalancer != v.loadBalancer { + t.Fatal("default newConnection should propagate the server's loadBalancer") + } +} + +func TestRTMPProxyServer_NewAppliesOptions(t *testing.T) { + var listenCalls, newConnCalls atomic.Int32 + srv := NewRTMPProxyServer( + &envfakes.FakeProxyEnvironment{}, + &lbfakes.FakeOriginLoadBalancer{}, + func(v *rtmpProxyServer) { + v.listen = func(ctx context.Context, addr string) (net.Listener, error) { + listenCalls.Add(1) + return nil, errors.New("unused") + } + v.newConnection = func() *rtmpConnection { + newConnCalls.Add(1) + return newRTMPConnection() + } + }, + ) + v := srv.(*rtmpProxyServer) + _, _ = v.listen(context.Background(), ":0") + _ = v.newConnection() + if got := listenCalls.Load(); got != 1 { + t.Fatalf("custom listen called %d times, want 1", got) + } + if got := newConnCalls.Load(); got != 1 { + t.Fatalf("custom newConnection called %d times, want 1", got) + } +} + +// --------------------------------------------------------------------------- +// rtmpProxyServer.Close +// --------------------------------------------------------------------------- + +func TestRTMPProxyServer_Close_NoListener(t *testing.T) { + // Close before Run must not panic, must not hang, and must not error. + srv := NewRTMPProxyServer(&envfakes.FakeProxyEnvironment{}, &lbfakes.FakeOriginLoadBalancer{}) + done := make(chan error, 1) + go func() { done <- srv.Close() }() + select { + case err := <-done: + if err != nil { + t.Fatalf("Close: %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("Close hung with no listener") + } +} + +// --------------------------------------------------------------------------- +// rtmpProxyServer.Run: listen and endpoint normalization +// --------------------------------------------------------------------------- + +func TestRTMPProxyServer_Run_ListenError(t *testing.T) { + envFake := &envfakes.FakeProxyEnvironment{} + envFake.RtmpServerReturns("1935") + srv := NewRTMPProxyServer(envFake, &lbfakes.FakeOriginLoadBalancer{}, func(v *rtmpProxyServer) { + v.listen = func(ctx context.Context, addr string) (net.Listener, error) { + return nil, errors.New("permission denied") + } + }) + + err := srv.Run(context.Background()) + if err == nil { + t.Fatal("expected error from Run when listen fails") + } + if !strings.Contains(err.Error(), "listen rtmp addr") { + t.Fatalf("unexpected error %v", err) + } +} + +func TestRTMPProxyServer_Run_EndpointWithoutColon(t *testing.T) { + // A bare port like "1935" must be normalized to ":1935" before reaching listen(). + f := newProxyFixture() + f.env.RtmpServerReturns("1935") + + if err := f.server.Run(context.Background()); err != nil { + t.Fatalf("Run: %v", err) + } + defer f.server.Close() + + if got := f.listenAddr.Load(); got != ":1935" { + t.Fatalf("listen addr=%v, want :1935", got) + } +} + +func TestRTMPProxyServer_Run_EndpointWithColon(t *testing.T) { + // An endpoint that already contains ":" must be passed through unchanged. + f := newProxyFixture() + f.env.RtmpServerReturns("127.0.0.1:1935") + + if err := f.server.Run(context.Background()); err != nil { + t.Fatalf("Run: %v", err) + } + defer f.server.Close() + + if got := f.listenAddr.Load(); got != "127.0.0.1:1935" { + t.Fatalf("listen addr=%v, want 127.0.0.1:1935", got) + } +} + +// --------------------------------------------------------------------------- +// rtmpProxyServer.Run: accept loop +// --------------------------------------------------------------------------- + +func TestRTMPProxyServer_Run_AcceptInvokesNewConnection(t *testing.T) { + f := newProxyFixture() + if err := f.server.Run(context.Background()); err != nil { + t.Fatalf("Run: %v", err) + } + + conn := &fakeNetConn{} + f.listener.push(conn) + + // Wait for the per-conn goroutine to start (newHandshake is observed). + select { + case <-f.serveDone: + case <-time.After(2 * time.Second): + t.Fatal("newConnection was not invoked for accepted conn") + } + + if got := f.newConnCalls.Load(); got != 1 { + t.Fatalf("newConnection called %d times, want 1", got) + } + + // Close shuts the listener and drains the accept goroutine cleanly. + if err := f.server.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + // The accepted conn should have been closed by the per-conn goroutine's + // defer once serve() returned. + if !conn.closed.Load() { + t.Fatal("accepted conn was not closed after serve returned") + } +} + +func TestRTMPProxyServer_Run_CloseShutsDownAcceptLoop(t *testing.T) { + // Start Run with an idle listener (no queued conns). Accept blocks. Close + // must unblock it and let Run/Close return cleanly via the closed-network + // branch in the accept loop. + f := newProxyFixture() + if err := f.server.Run(context.Background()); err != nil { + t.Fatalf("Run: %v", err) + } + + done := make(chan error, 1) + go func() { done <- f.server.Close() }() + select { + case err := <-done: + if err != nil { + t.Fatalf("Close: %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("Close hung — accept loop did not exit on listener close") + } + + if f.newConnCalls.Load() != 0 { + t.Fatal("newConnection should not be called when no conn was accepted") + } +} diff --git a/internal/server/srt.go b/internal/proxy/srt.go similarity index 83% rename from internal/server/srt.go rename to internal/proxy/srt.go index 0da23f51f..2ecb97696 100644 --- a/internal/server/srt.go +++ b/internal/proxy/srt.go @@ -1,14 +1,16 @@ // Copyright (c) 2026 Winlin // // SPDX-License-Identifier: MIT -package server +package proxy import ( "bytes" "context" "encoding/binary" "fmt" + "io" "net" + "strconv" "strings" stdSync "sync" "time" @@ -21,14 +23,17 @@ import ( "srsx/internal/utils" ) -// srsSRTServer is the proxy for SRS server via SRT. It will figure out which backend server to +// srsSRTProxyServer is the proxy for SRS server via SRT. It will figure out which backend server to // proxy to. It only parses the SRT handshake messages, parses the stream id, and proxy to the // backend server. -type srsSRTServer struct { +type srsSRTProxyServer struct { // The environment interface. environment env.ProxyEnvironment - // The UDP listener for SRT server. - listener *net.UDPConn + // The load balancer for origin servers. + loadBalancer lb.OriginLoadBalancer + // The UDP listener for SRT server. Stored as net.PacketConn so tests + // can inject a fake listener via listenUDP. + listener net.PacketConn // The SRT connections, identify by the socket ID. sockets sync.Map[uint32, *SRTConnection] @@ -37,13 +42,28 @@ type srsSRTServer struct { // The wait group for server. wg stdSync.WaitGroup + + // listenUDP opens the UDP listener for the SRT server. Defaults to a real + // net.ListenUDP on the resolved endpoint; tests may override via a functional + // option to supply a fake listener. + listenUDP func(ctx context.Context, endpoint string) (net.PacketConn, error) } -func NewSRSSRTServer(environment env.ProxyEnvironment, opts ...func(*srsSRTServer)) *srsSRTServer { - v := &srsSRTServer{ - environment: environment, - start: time.Now(), - sockets: sync.NewMap[uint32, *SRTConnection](), +func NewSRSSRTProxyServer(environment env.ProxyEnvironment, loadBalancer lb.OriginLoadBalancer, opts ...func(*srsSRTProxyServer)) *srsSRTProxyServer { + v := &srsSRTProxyServer{ + environment: environment, + loadBalancer: loadBalancer, + start: time.Now(), + sockets: sync.NewMap[uint32, *SRTConnection](), + } + + // Default listenUDP: resolve the endpoint and open a real UDP socket. + v.listenUDP = func(ctx context.Context, endpoint string) (net.PacketConn, error) { + saddr, err := net.ResolveUDPAddr("udp", endpoint) + if err != nil { + return nil, errors.Wrapf(err, "resolve udp addr %v", endpoint) + } + return net.ListenUDP("udp", saddr) } for _, opt := range opts { @@ -52,33 +72,28 @@ func NewSRSSRTServer(environment env.ProxyEnvironment, opts ...func(*srsSRTServe return v } -func (v *srsSRTServer) Close() error { +func (v *srsSRTProxyServer) Close() error { if v.listener != nil { - v.listener.Close() + _ = v.listener.Close() } v.wg.Wait() return nil } -func (v *srsSRTServer) Run(ctx context.Context) error { +func (v *srsSRTProxyServer) Run(ctx context.Context) error { // Parse address to listen. endpoint := v.environment.SRTServer() if !strings.Contains(endpoint, ":") { endpoint = ":" + endpoint } - saddr, err := net.ResolveUDPAddr("udp", endpoint) + listener, err := v.listenUDP(ctx, endpoint) if err != nil { - return errors.Wrapf(err, "resolve udp addr %v", endpoint) - } - - listener, err := net.ListenUDP("udp", saddr) - if err != nil { - return errors.Wrapf(err, "listen udp %v", saddr) + return errors.Wrapf(err, "listen udp %v", endpoint) } v.listener = listener - logger.Debug(ctx, "SRT server listen at %v", saddr) + logger.Debug(ctx, "SRT server listen at %v", listener.LocalAddr()) // Consume all messages from UDP media transport. v.wg.Add(1) @@ -87,7 +102,7 @@ func (v *srsSRTServer) Run(ctx context.Context) error { for ctx.Err() == nil { buf := make([]byte, 4096) - n, caddr, err := v.listener.ReadFromUDP(buf) + n, caddr, err := v.listener.ReadFrom(buf) if err != nil { // If context is canceled or connection is closed, exit gracefully without logging error. if ctx.Err() != nil || utils.IsClosedNetworkError(err) { @@ -109,7 +124,7 @@ func (v *srsSRTServer) Run(ctx context.Context) error { return nil } -func (v *srsSRTServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr, data []byte) error { +func (v *srsSRTProxyServer) handleClientUDP(ctx context.Context, addr net.Addr, data []byte) error { socketID := utils.SrtParseSocketID(data) var pkt *SRTHandshakePacket @@ -127,6 +142,7 @@ func (v *srsSRTServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr, d conn, ok := v.sockets.LoadOrStore(socketID, NewSRTConnection(func(c *SRTConnection) { c.ctx = logger.WithContext(ctx) c.listenerUDP, c.socketID = v.listener, socketID + c.loadBalancer = v.loadBalancer c.start = v.start })) @@ -158,14 +174,18 @@ func (v *srsSRTServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr, d type SRTConnection struct { // The stream context for SRT connection. ctx context.Context + // The load balancer for origin servers. + loadBalancer lb.OriginLoadBalancer // The current socket ID. socketID uint32 - // The UDP connection proxy to backend. - backendUDP *net.UDPConn - // The listener UDP connection, used to send messages to client. - listenerUDP *net.UDPConn + // The UDP connection proxy to backend. Stored as io.ReadWriteCloser so tests + // can inject a fake connection by overriding dialBackendUDP. + backendUDP io.ReadWriteCloser + // The listener UDP connection, used to send messages to client. Stored as + // net.PacketConn so tests can inject a fake listener. + listenerUDP net.PacketConn // Listener start time. start time.Time @@ -175,17 +195,29 @@ type SRTConnection struct { handshake1 *SRTHandshakePacket handshake2 *SRTHandshakePacket handshake3 *SRTHandshakePacket + + // dialBackendUDP opens a UDP connection to a backend SRS server. Defaults to a real + // UDP dial; tests may override via a functional option to supply a fake connection. + dialBackendUDP func(ctx context.Context, ip string, port int) (io.ReadWriteCloser, error) } func NewSRTConnection(opts ...func(*SRTConnection)) *SRTConnection { v := &SRTConnection{} + + // Default dial: a real UDP connection to the backend. Uses Dialer.DialContext + // so ctx cancellation/deadline aborts DNS resolution (UDP itself has no handshake). + v.dialBackendUDP = func(ctx context.Context, ip string, port int) (io.ReadWriteCloser, error) { + var d net.Dialer + return d.DialContext(ctx, "udp", net.JoinHostPort(ip, strconv.Itoa(port))) + } + for _, opt := range opts { opt(v) } return v } -func (v *SRTConnection) HandlePacket(pkt *SRTHandshakePacket, addr *net.UDPAddr, data []byte) (uint32, error) { +func (v *SRTConnection) HandlePacket(pkt *SRTHandshakePacket, addr net.Addr, data []byte) (uint32, error) { ctx := v.ctx // If not handshake, try to proxy to backend directly. @@ -208,7 +240,7 @@ func (v *SRTConnection) HandlePacket(pkt *SRTHandshakePacket, addr *net.UDPAddr, return v.socketID, nil } -func (v *SRTConnection) handleHandshake(ctx context.Context, pkt *SRTHandshakePacket, addr *net.UDPAddr, data []byte) error { +func (v *SRTConnection) handleHandshake(ctx context.Context, pkt *SRTHandshakePacket, addr net.Addr, data []byte) error { // Handle handshake 0 and 1 messages. if pkt.SynCookie == 0 { // Save handshake 0 packet. @@ -238,7 +270,7 @@ func (v *SRTConnection) handleHandshake(ctx context.Context, pkt *SRTHandshakePa if b, err := v.handshake1.MarshalBinary(); err != nil { return errors.Wrapf(err, "marshal handshake 1") - } else if _, err = v.listenerUDP.WriteToUDP(b, addr); err != nil { + } else if _, err = v.listenerUDP.WriteTo(b, addr); err != nil { return errors.Wrapf(err, "write handshake 1") } @@ -303,15 +335,17 @@ func (v *SRTConnection) handleHandshake(ctx context.Context, pkt *SRTHandshakePa } logger.Debug(ctx, "Proxy got handshake 3: %v", handshake3p) - // Response handshake 3 to client. - v.handshake3 = &*handshake3p + // Response handshake 3 to client. Copy so rewriting the cookie below does + // not mutate the struct just decoded from the backend. + handshake3c := *handshake3p + v.handshake3 = &handshake3c v.handshake3.SynCookie = v.handshake1.SynCookie v.socketID = handshake3p.SRTSocketID logger.Debug(ctx, "Handshake 3: %v", v.handshake3) if b, err := v.handshake3.MarshalBinary(); err != nil { return errors.Wrapf(err, "marshal handshake 3") - } else if _, err = v.listenerUDP.WriteToUDP(b, addr); err != nil { + } else if _, err = v.listenerUDP.WriteTo(b, addr); err != nil { return errors.Wrapf(err, "write handshake 3") } @@ -325,7 +359,7 @@ func (v *SRTConnection) handleHandshake(ctx context.Context, pkt *SRTHandshakePa logger.Warn(ctx, "read from backend failed, err=%v", err) return } - if _, err = v.listenerUDP.WriteToUDP(b[:nn], addr); err != nil { + if _, err = v.listenerUDP.WriteTo(b[:nn], addr); err != nil { // TODO: If backend server closed unexpectedly, we should notice the stream to quit. logger.Warn(ctx, "write to client failed, err=%v", err) return @@ -356,7 +390,7 @@ func (v *SRTConnection) connectBackend(ctx context.Context, streamID string) err } // Pick a backend SRS server to proxy the SRT stream. - backend, err := lb.SrsLoadBalancer.Pick(ctx, streamURL) + backend, err := v.loadBalancer.Pick(ctx, streamURL) if err != nil { return errors.Wrapf(err, "pick backend for %v", streamURL) } @@ -373,12 +407,11 @@ func (v *SRTConnection) connectBackend(ctx context.Context, streamID string) err // Connect to backend SRS server via UDP client. // TODO: FIXME: Support close the connection when timeout or client disconnected. - backendAddr := net.UDPAddr{IP: net.ParseIP(backend.IP), Port: int(udpPort)} - if backendUDP, err := net.DialUDP("udp", nil, &backendAddr); err != nil { - return errors.Wrapf(err, "dial udp to %v of %v for %v", backendAddr, backend, streamURL) - } else { - v.backendUDP = backendUDP + backendUDP, err := v.dialBackendUDP(ctx, backend.IP, int(udpPort)) + if err != nil { + return errors.Wrapf(err, "dial udp to %v:%v of %v for %v", backend.IP, udpPort, backend, streamURL) } + v.backendUDP = backendUDP return nil } diff --git a/internal/proxy/srt_test.go b/internal/proxy/srt_test.go new file mode 100644 index 000000000..8c4d42f84 --- /dev/null +++ b/internal/proxy/srt_test.go @@ -0,0 +1,987 @@ +// Copyright (c) 2026 Winlin +// +// SPDX-License-Identifier: MIT +package proxy + +import ( + "context" + "encoding/binary" + "errors" + "io" + "net" + "strings" + "sync/atomic" + "testing" + "time" + + "srsx/internal/env/envfakes" + "srsx/internal/lb" + "srsx/internal/lb/lbfakes" + "srsx/internal/logger" +) + +// encodeSRTStreamIDExt builds an SRT extension block carrying the given stream +// id as extension type 0x05. The wire format places the type and length (in +// 4-byte words) as big-endian uint16s, followed by the payload with each +// 4-byte word stored in little-endian byte order — the inverse of what +// SRTHandshakePacket.StreamID does on read. +func encodeSRTStreamIDExt(sid string) []byte { + padded := []byte(sid) + if rem := len(padded) % 4; rem != 0 { + padded = append(padded, make([]byte, 4-rem)...) + } + + swapped := make([]byte, len(padded)) + for i := 0; i < len(padded); i += 4 { + swapped[i+0] = padded[i+3] + swapped[i+1] = padded[i+2] + swapped[i+2] = padded[i+1] + swapped[i+3] = padded[i+0] + } + + hdr := make([]byte, 4) + binary.BigEndian.PutUint16(hdr[0:], 0x05) + binary.BigEndian.PutUint16(hdr[2:], uint16(len(padded)/4)) + return append(hdr, swapped...) +} + +func TestSRTHandshakePacket_FlagPredicates(t *testing.T) { + cases := []struct { + name string + flag uint8 + ctype uint16 + stype uint16 + isData bool + isControl bool + isHandshake bool + }{ + {"data-packet", 0x00, 0, 0, true, false, false}, + {"handshake", 0x80, 0, 0, false, true, true}, + {"control-not-handshake-by-ctype", 0x80, 1, 0, false, true, false}, + {"control-not-handshake-by-stype", 0x80, 0, 1, false, true, false}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + p := &SRTHandshakePacket{ControlFlag: c.flag, ControlType: c.ctype, SubType: c.stype} + if got := p.IsData(); got != c.isData { + t.Fatalf("IsData=%v, want %v", got, c.isData) + } + if got := p.IsControl(); got != c.isControl { + t.Fatalf("IsControl=%v, want %v", got, c.isControl) + } + if got := p.IsHandshake(); got != c.isHandshake { + t.Fatalf("IsHandshake=%v, want %v", got, c.isHandshake) + } + }) + } +} + +func TestSRTHandshakePacket_String_ContainsKeyFields(t *testing.T) { + p := &SRTHandshakePacket{ + ControlFlag: 0x80, + SocketID: 0xdeadbeef, + SRTSocketID: 0xcafebabe, + PeerIP: net.ParseIP("1.2.3.4"), + ExtraData: []byte{0, 1, 2, 3, 4}, + } + s := p.String() + for _, want := range []string{"Control=true", "SocketID=3735928559", "SRTSocketID=3405691582", "Peer=16B", "Extra=5B"} { + if !strings.Contains(s, want) { + t.Fatalf("String()=%q missing %q", s, want) + } + } +} + +func TestSRTHandshakePacket_UnmarshalBinary_ShortBuffers(t *testing.T) { + if err := (&SRTHandshakePacket{}).UnmarshalBinary([]byte{0x80}); err == nil { + t.Fatal("expected error for <4 byte buffer") + } + if err := (&SRTHandshakePacket{}).UnmarshalBinary(make([]byte, 32)); err == nil { + t.Fatal("expected error for <64 byte buffer") + } +} + +func TestSRTHandshakePacket_UnmarshalBinary_ParsesControlBits(t *testing.T) { + b := make([]byte, 64) + // First 16 bits: top bit = control flag (0x80), bottom 15 bits = ControlType (0x1234). + binary.BigEndian.PutUint16(b[0:], 0x8000|0x1234) + binary.BigEndian.PutUint16(b[2:], 0x5678) // SubType. + + p := &SRTHandshakePacket{} + if err := p.UnmarshalBinary(b); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if p.ControlFlag != 0x80 { + t.Fatalf("ControlFlag=0x%02x, want 0x80", p.ControlFlag) + } + if p.ControlType != 0x1234 { + t.Fatalf("ControlType=0x%04x, want 0x1234", p.ControlType) + } + if p.SubType != 0x5678 { + t.Fatalf("SubType=0x%04x, want 0x5678", p.SubType) + } +} + +func TestSRTHandshakePacket_UnmarshalBinary_PeerIPByteReversed(t *testing.T) { + b := make([]byte, 64) + // Wire bytes 48..51 are stored in reverse order; the parser flips them back + // to produce IPv4(b[51], b[50], b[49], b[48]). + b[48] = 4 + b[49] = 3 + b[50] = 2 + b[51] = 1 + + p := &SRTHandshakePacket{} + if err := p.UnmarshalBinary(b); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if want := net.ParseIP("1.2.3.4"); !p.PeerIP.Equal(want) { + t.Fatalf("PeerIP=%v, want %v", p.PeerIP, want) + } +} + +func TestSRTHandshakePacket_MarshalBinary_Layout(t *testing.T) { + p := &SRTHandshakePacket{ + ControlFlag: 0x80, + ControlType: 0x1234, + SubType: 0x5678, + AdditionalInfo: 0x11111111, + Timestamp: 0x22222222, + SocketID: 0x33333333, + Version: 5, + EncryptionField: 2, + ExtensionField: 0x4A17, + InitSequence: 0x44444444, + MTU: 1500, + FlowWindow: 8192, + HandshakeType: 1, + SRTSocketID: 0x55555555, + SynCookie: 0x66666666, + PeerIP: net.ParseIP("10.20.30.40"), + ExtraData: []byte{0xaa, 0xbb}, + } + + b, err := p.MarshalBinary() + if err != nil { + t.Fatalf("marshal: %v", err) + } + if got, want := len(b), 64+len(p.ExtraData); got != want { + t.Fatalf("len=%d, want %d", got, want) + } + if got := binary.BigEndian.Uint16(b[0:]); got != 0x8000|0x1234 { + t.Fatalf("word0=0x%04x, want 0x9234", got) + } + if got := binary.BigEndian.Uint16(b[2:]); got != 0x5678 { + t.Fatalf("SubType=0x%04x, want 0x5678", got) + } + // PeerIP is laid out in reversed octet order on the wire. + if b[48] != 40 || b[49] != 30 || b[50] != 20 || b[51] != 10 { + t.Fatalf("PeerIP bytes=[%d %d %d %d], want [40 30 20 10]", b[48], b[49], b[50], b[51]) + } + if b[64] != 0xaa || b[65] != 0xbb { + t.Fatalf("ExtraData not copied at offset 64") + } +} + +func TestSRTHandshakePacket_Roundtrip(t *testing.T) { + orig := &SRTHandshakePacket{ + ControlFlag: 0x80, + ControlType: 0x0001, + SubType: 0x0002, + AdditionalInfo: 0xa1a1a1a1, + Timestamp: 0xb2b2b2b2, + SocketID: 0xc3c3c3c3, + Version: 5, + EncryptionField: 0, + ExtensionField: 0x4A17, + InitSequence: 0xd4d4d4d4, + MTU: 1500, + FlowWindow: 8192, + HandshakeType: 1, + SRTSocketID: 0xe5e5e5e5, + SynCookie: 0xf6f6f6f6, + PeerIP: net.ParseIP("192.168.1.42"), + ExtraData: encodeSRTStreamIDExt("#!::r=live/stream"), + } + + b, err := orig.MarshalBinary() + if err != nil { + t.Fatalf("marshal: %v", err) + } + + got := &SRTHandshakePacket{} + if err := got.UnmarshalBinary(b); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if got.ControlFlag != orig.ControlFlag || + got.ControlType != orig.ControlType || + got.SubType != orig.SubType || + got.AdditionalInfo != orig.AdditionalInfo || + got.Timestamp != orig.Timestamp || + got.SocketID != orig.SocketID || + got.Version != orig.Version || + got.EncryptionField != orig.EncryptionField || + got.ExtensionField != orig.ExtensionField || + got.InitSequence != orig.InitSequence || + got.MTU != orig.MTU || + got.FlowWindow != orig.FlowWindow || + got.HandshakeType != orig.HandshakeType || + got.SRTSocketID != orig.SRTSocketID || + got.SynCookie != orig.SynCookie { + t.Fatalf("scalar field mismatch\n got=%+v\nwant=%+v", got, orig) + } + if !got.PeerIP.Equal(orig.PeerIP) { + t.Fatalf("PeerIP=%v, want %v", got.PeerIP, orig.PeerIP) + } + if sid, err := got.StreamID(); err != nil { + t.Fatalf("StreamID: %v", err) + } else if sid != "#!::r=live/stream" { + t.Fatalf("StreamID=%q, want %q", sid, "#!::r=live/stream") + } +} + +func TestSRTHandshakePacket_StreamID(t *testing.T) { + t.Run("single-extension-padded", func(t *testing.T) { + p := &SRTHandshakePacket{ExtraData: encodeSRTStreamIDExt("abc")} + sid, err := p.StreamID() + if err != nil { + t.Fatalf("StreamID: %v", err) + } + if sid != "abc" { + t.Fatalf("StreamID=%q, want %q", sid, "abc") + } + }) + + t.Run("multi-word-payload", func(t *testing.T) { + p := &SRTHandshakePacket{ExtraData: encodeSRTStreamIDExt("abcdefgh")} + sid, err := p.StreamID() + if err != nil { + t.Fatalf("StreamID: %v", err) + } + if sid != "abcdefgh" { + t.Fatalf("StreamID=%q, want %q", sid, "abcdefgh") + } + }) + + t.Run("skip-other-extensions", func(t *testing.T) { + // First a non-0x05 extension of size 1 word, then the real stream id. + other := []byte{0x00, 0x01, 0x00, 0x01, 0xde, 0xad, 0xbe, 0xef} + p := &SRTHandshakePacket{ExtraData: append(other, encodeSRTStreamIDExt("live/stream")...)} + sid, err := p.StreamID() + if err != nil { + t.Fatalf("StreamID: %v", err) + } + if sid != "live/stream" { + t.Fatalf("StreamID=%q, want %q", sid, "live/stream") + } + }) + + t.Run("trims-trailing-nuls", func(t *testing.T) { + // "ab" → padded to "ab\x00\x00", wire-swapped to {0,0,'b','a'}, then + // parsed back to "ab\x00\x00" and trimmed to "ab". + p := &SRTHandshakePacket{ExtraData: encodeSRTStreamIDExt("ab")} + sid, err := p.StreamID() + if err != nil { + t.Fatalf("StreamID: %v", err) + } + if sid != "ab" { + t.Fatalf("StreamID=%q, want %q", sid, "ab") + } + }) + + t.Run("empty-extra-returns-error", func(t *testing.T) { + p := &SRTHandshakePacket{} + if _, err := p.StreamID(); err == nil { + t.Fatal("expected error for empty ExtraData") + } + }) + + t.Run("declared-size-exceeds-buffer", func(t *testing.T) { + // Extension type 0x05 claims 4 words (16 bytes) but only 4 bytes follow. + p := &SRTHandshakePacket{ExtraData: []byte{0x00, 0x05, 0x00, 0x04, 0xaa, 0xbb, 0xcc, 0xdd}} + if _, err := p.StreamID(); err == nil { + t.Fatal("expected error when declared size exceeds buffer") + } + }) + + t.Run("only-non-streamid-extension-returns-error", func(t *testing.T) { + // One full extension that's not type 0x05; walker advances and then + // runs out of bytes for the next header → error. + p := &SRTHandshakePacket{ExtraData: []byte{0x00, 0x01, 0x00, 0x01, 0xde, 0xad, 0xbe, 0xef}} + if _, err := p.StreamID(); err == nil { + t.Fatal("expected error when no stream id extension is present") + } + }) +} + +// --------------------------------------------------------------------------- +// SRTConnection: fakes, fixture, and tests +// --------------------------------------------------------------------------- + +// newHandshake0 builds a client INDUCTION handshake packet (SynCookie == 0). +func newHandshake0(srtSocketID uint32) *SRTHandshakePacket { + return &SRTHandshakePacket{ + ControlFlag: 0x80, + ControlType: 0, + SubType: 0, + MTU: 1500, + FlowWindow: 8192, + HandshakeType: 1, + Version: 4, + InitSequence: 0xdeadbeef, + SRTSocketID: srtSocketID, + PeerIP: net.ParseIP("127.0.0.1"), + } +} + +// newHandshake2 builds a client CONCLUSION handshake packet carrying the given +// stream id (SynCookie must be non-zero so it enters the handshake-2 branch). +func newHandshake2(srtSocketID uint32, cookie uint32, streamID string) *SRTHandshakePacket { + return &SRTHandshakePacket{ + ControlFlag: 0x80, + ControlType: 0, + SubType: 0, + Version: 5, + HandshakeType: 0xFFFFFFFF, // CONCLUSION + SRTSocketID: srtSocketID, + SynCookie: cookie, + PeerIP: net.ParseIP("127.0.0.1"), + ExtraData: encodeSRTStreamIDExt(streamID), + } +} + +// marshalOrFatal marshals a handshake packet; fails the test on error. +func marshalOrFatal(t *testing.T, p *SRTHandshakePacket) []byte { + t.Helper() + b, err := p.MarshalBinary() + if err != nil { + t.Fatalf("marshal: %v", err) + } + return b +} + +// srtConnFixture wires an SRTConnection with fakes for the load balancer, +// listener, and backend dial seam. +type srtConnFixture struct { + conn *SRTConnection + lb *lbfakes.FakeOriginLoadBalancer + listener *fakePacketConn + backend *fakeBackendUDP + dialErr error + dialIP string + dialPort int +} + +func newSRTConnFixture() *srtConnFixture { + f := &srtConnFixture{ + lb: &lbfakes.FakeOriginLoadBalancer{}, + listener: newFakePacketConn(), + backend: newFakeBackendUDP(), + } + f.conn = NewSRTConnection(func(c *SRTConnection) { + c.ctx = logger.WithContext(context.Background()) + c.loadBalancer = f.lb + c.listenerUDP = f.listener + c.start = time.Now() + c.dialBackendUDP = func(ctx context.Context, ip string, port int) (io.ReadWriteCloser, error) { + f.dialIP, f.dialPort = ip, port + if f.dialErr != nil { + return nil, f.dialErr + } + return f.backend, nil + } + }) + return f +} + +func TestNewSRTConnection(t *testing.T) { + t.Run("defaults dialBackendUDP", func(t *testing.T) { + c := NewSRTConnection() + if c.dialBackendUDP == nil { + t.Fatal("expected dialBackendUDP to be defaulted") + } + }) + + t.Run("applies functional options", func(t *testing.T) { + c := NewSRTConnection(func(c *SRTConnection) { + c.socketID = 0xabc + }) + if c.socketID != 0xabc { + t.Fatalf("socketID=%x, want 0xabc", c.socketID) + } + }) + + t.Run("options override default dialBackendUDP", func(t *testing.T) { + called := false + dial := func(ctx context.Context, ip string, port int) (io.ReadWriteCloser, error) { + called = true + return nil, nil + } + c := NewSRTConnection(func(c *SRTConnection) { c.dialBackendUDP = dial }) + _, _ = c.dialBackendUDP(context.Background(), "", 0) + if !called { + t.Fatal("expected overridden dialBackendUDP to be invoked") + } + }) +} + +func TestSRTConnection_HandlePacket_NoHandshake(t *testing.T) { + t.Run("noop when backendUDP not set", func(t *testing.T) { + f := newSRTConnFixture() + f.conn.socketID = 42 + + sid, err := f.conn.HandlePacket(nil, &net.UDPAddr{}, []byte("payload")) + if err != nil { + t.Fatalf("unexpected err=%v", err) + } + if sid != 42 { + t.Fatalf("socketID=%d, want 42", sid) + } + }) + + t.Run("writes data to backend", func(t *testing.T) { + f := newSRTConnFixture() + f.conn.backendUDP = f.backend + f.conn.socketID = 7 + + sid, err := f.conn.HandlePacket(nil, &net.UDPAddr{}, []byte("payload")) + if err != nil { + t.Fatalf("unexpected err=%v", err) + } + if sid != 7 { + t.Fatalf("socketID=%d, want 7", sid) + } + + select { + case got := <-f.backend.writes: + if string(got) != "payload" { + t.Fatalf("backend got %q, want %q", got, "payload") + } + case <-time.After(time.Second): + t.Fatal("timeout waiting for backend write") + } + }) + + t.Run("propagates backend write error", func(t *testing.T) { + f := newSRTConnFixture() + f.conn.backendUDP = f.backend + f.backend.writeErr = errors.New("write-fail") + + _, err := f.conn.HandlePacket(nil, &net.UDPAddr{}, []byte("payload")) + if err == nil || !strings.Contains(err.Error(), "write-fail") { + t.Fatalf("expected write-fail err, got %v", err) + } + }) +} + +func TestSRTConnection_HandleHandshake_Step0(t *testing.T) { + t.Run("replies handshake 1 with proxy cookie", func(t *testing.T) { + f := newSRTConnFixture() + client := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 9000} + + hs0 := newHandshake0(0x11111111) + if _, err := f.conn.HandlePacket(hs0, client, marshalOrFatal(t, hs0)); err != nil { + t.Fatalf("HandlePacket err=%v", err) + } + + if f.conn.handshake0 != hs0 { + t.Fatal("handshake0 was not saved on the connection") + } + if f.conn.handshake1 == nil { + t.Fatal("handshake1 was not built") + } + // Proxy always replies INDUCTION with its own fixed cookie and the + // SRT magic ExtensionField, per the RFC induction message format. + if f.conn.handshake1.SynCookie != 0x418d5e4e { + t.Fatalf("handshake1.SynCookie=0x%08x, want 0x418d5e4e", f.conn.handshake1.SynCookie) + } + if f.conn.handshake1.ExtensionField != 0x4A17 { + t.Fatalf("handshake1.ExtensionField=0x%04x, want 0x4A17", f.conn.handshake1.ExtensionField) + } + + select { + case got := <-f.listener.writes: + if got.addr != client { + t.Fatalf("listener got addr=%v, want %v", got.addr, client) + } + parsed := &SRTHandshakePacket{} + if err := parsed.UnmarshalBinary(got.data); err != nil { + t.Fatalf("unmarshal listener write: %v", err) + } + if parsed.SynCookie != 0x418d5e4e { + t.Fatalf("on-wire SynCookie=0x%08x, want 0x418d5e4e", parsed.SynCookie) + } + case <-time.After(time.Second): + t.Fatal("timeout waiting for listener write") + } + }) + + t.Run("listener write error is propagated", func(t *testing.T) { + f := newSRTConnFixture() + f.listener.writeErr = errors.New("listen-write-fail") + + hs0 := newHandshake0(0x11111111) + _, err := f.conn.HandlePacket(hs0, &net.UDPAddr{}, marshalOrFatal(t, hs0)) + if err == nil || !strings.Contains(err.Error(), "listen-write-fail") { + t.Fatalf("expected propagated listener err, got %v", err) + } + }) +} + +func TestSRTConnection_HandleHandshake_Step2_StreamIDError(t *testing.T) { + f := newSRTConnFixture() + // Cookie != 0 puts us on the handshake-2 path; no 0x05 extension means + // StreamID() returns an error before we ever touch the load balancer. + pkt := &SRTHandshakePacket{ + ControlFlag: 0x80, + HandshakeType: 0xFFFFFFFF, + SRTSocketID: 1, + SynCookie: 0x418d5e4e, + PeerIP: net.ParseIP("127.0.0.1"), + } + _, err := f.conn.HandlePacket(pkt, &net.UDPAddr{}, marshalOrFatal(t, pkt)) + if err == nil || !strings.Contains(err.Error(), "parse stream id") { + t.Fatalf("expected parse-stream-id err, got %v", err) + } + if f.lb.PickCallCount() != 0 { + t.Fatal("expected Pick not to be called when stream id parse fails") + } +} + +func TestSRTConnection_HandleHandshake_Step2_FullFlow(t *testing.T) { + f := newSRTConnFixture() + f.lb.PickReturns(&lb.OriginServer{IP: "127.0.0.1", SRT: []string{"20080"}}, nil) + client := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 9000} + + // Step 0 first, to populate handshake0 and the proxy's handshake1 (cookie + // 0x418d5e4e). The listener write for hs1 is drained so it does not block + // later assertions. + hs0 := newHandshake0(0x11111111) + if _, err := f.conn.HandlePacket(hs0, client, marshalOrFatal(t, hs0)); err != nil { + t.Fatalf("hs0 HandlePacket err=%v", err) + } + <-f.listener.writes + + // Pre-feed backend's hs1 (with its own cookie) and hs3 (with its own + // socket id) so the synchronous Reads inside handleHandshake unblock. + const backendCookie uint32 = 0x12345678 + const backendSocketID uint32 = 0xabcd1234 + f.backend.reads <- marshalOrFatal(t, &SRTHandshakePacket{ + ControlFlag: 0x80, SynCookie: backendCookie, PeerIP: net.ParseIP("127.0.0.1"), + }) + f.backend.reads <- marshalOrFatal(t, &SRTHandshakePacket{ + ControlFlag: 0x80, SRTSocketID: backendSocketID, SynCookie: backendCookie, PeerIP: net.ParseIP("127.0.0.1"), + }) + + hs2 := newHandshake2(0x11111111, 0x418d5e4e, "#!::r=live/stream") + sid, err := f.conn.HandlePacket(hs2, client, marshalOrFatal(t, hs2)) + if err != nil { + t.Fatalf("hs2 HandlePacket err=%v", err) + } + if sid != backendSocketID { + t.Fatalf("returned socketID=0x%08x, want 0x%08x", sid, backendSocketID) + } + if f.conn.socketID != backendSocketID { + t.Fatalf("conn.socketID=0x%08x, want 0x%08x", f.conn.socketID, backendSocketID) + } + if f.dialIP != "127.0.0.1" || f.dialPort != 20080 { + t.Fatalf("dial got ip=%q port=%d, want 127.0.0.1:20080", f.dialIP, f.dialPort) + } + + // First backend write is the raw hs0 from the client; second is hs2 with + // the cookie rewritten to the backend's value (not the proxy's). + got0 := drainBackendWrite(t, f.backend) + parsed0 := &SRTHandshakePacket{} + if err := parsed0.UnmarshalBinary(got0); err != nil { + t.Fatalf("unmarshal hs0 sent to backend: %v", err) + } + if parsed0.SynCookie != 0 { + t.Fatalf("hs0 forwarded with SynCookie=0x%08x, want 0", parsed0.SynCookie) + } + + got2 := drainBackendWrite(t, f.backend) + parsed2 := &SRTHandshakePacket{} + if err := parsed2.UnmarshalBinary(got2); err != nil { + t.Fatalf("unmarshal hs2 sent to backend: %v", err) + } + if parsed2.SynCookie != backendCookie { + t.Fatalf("hs2 to backend SynCookie=0x%08x, want 0x%08x", parsed2.SynCookie, backendCookie) + } + + // hs3 to the client must carry the proxy's cookie, not the backend's. + got3 := drainListenerWrite(t, f.listener, client) + parsed3 := &SRTHandshakePacket{} + if err := parsed3.UnmarshalBinary(got3); err != nil { + t.Fatalf("unmarshal hs3 sent to client: %v", err) + } + if parsed3.SynCookie != 0x418d5e4e { + t.Fatalf("hs3 to client SynCookie=0x%08x, want 0x418d5e4e", parsed3.SynCookie) + } + if parsed3.SRTSocketID != backendSocketID { + t.Fatalf("hs3 to client SRTSocketID=0x%08x, want 0x%08x", parsed3.SRTSocketID, backendSocketID) + } + + // Cleanly terminate the background backend→client forwarder goroutine. + _ = f.backend.Close() +} + +func drainBackendWrite(t *testing.T, b *fakeBackendUDP) []byte { + t.Helper() + select { + case got := <-b.writes: + return got + case <-time.After(time.Second): + t.Fatal("timeout waiting for backend write") + return nil + } +} + +func drainListenerWrite(t *testing.T, l *fakePacketConn, wantAddr net.Addr) []byte { + t.Helper() + select { + case got := <-l.writes: + if got.addr != wantAddr { + t.Fatalf("listener addr=%v, want %v", got.addr, wantAddr) + } + return got.data + case <-time.After(time.Second): + t.Fatal("timeout waiting for listener write") + return nil + } +} + +func TestSRTConnection_ConnectBackend(t *testing.T) { + t.Run("noop when already connected", func(t *testing.T) { + f := newSRTConnFixture() + f.conn.backendUDP = f.backend + if err := f.conn.connectBackend(context.Background(), "#!::r=live/stream"); err != nil { + t.Fatalf("unexpected err=%v", err) + } + if f.lb.PickCallCount() != 0 { + t.Fatal("expected Pick not to be called when already connected") + } + }) + + t.Run("propagates ParseSRTStreamID error", func(t *testing.T) { + f := newSRTConnFixture() + err := f.conn.connectBackend(context.Background(), "no-resource-key") + if err == nil || !strings.Contains(err.Error(), "parse stream id") { + t.Fatalf("expected parse-stream-id err, got %v", err) + } + }) + + t.Run("propagates Pick error", func(t *testing.T) { + f := newSRTConnFixture() + f.lb.PickReturns(nil, errors.New("pick-fail")) + err := f.conn.connectBackend(context.Background(), "#!::r=live/stream") + if err == nil || !strings.Contains(err.Error(), "pick-fail") { + t.Fatalf("expected pick err, got %v", err) + } + }) + + t.Run("errors when backend has no SRT endpoints", func(t *testing.T) { + f := newSRTConnFixture() + f.lb.PickReturns(&lb.OriginServer{IP: "127.0.0.1"}, nil) + err := f.conn.connectBackend(context.Background(), "#!::r=live/stream") + if err == nil || !strings.Contains(err.Error(), "no udp server") { + t.Fatalf("expected no-udp-server err, got %v", err) + } + }) + + t.Run("propagates ParseListenEndpoint error", func(t *testing.T) { + f := newSRTConnFixture() + f.lb.PickReturns(&lb.OriginServer{IP: "127.0.0.1", SRT: []string{"not-a-port"}}, nil) + err := f.conn.connectBackend(context.Background(), "#!::r=live/stream") + if err == nil || !strings.Contains(err.Error(), "parse udp port") { + t.Fatalf("expected parse-udp-port err, got %v", err) + } + }) + + t.Run("propagates dial error", func(t *testing.T) { + f := newSRTConnFixture() + f.lb.PickReturns(&lb.OriginServer{IP: "127.0.0.1", SRT: []string{"20080"}}, nil) + f.dialErr = errors.New("dial-fail") + err := f.conn.connectBackend(context.Background(), "#!::r=live/stream") + if err == nil || !strings.Contains(err.Error(), "dial-fail") { + t.Fatalf("expected dial err, got %v", err) + } + if f.conn.backendUDP != nil { + t.Fatal("backendUDP should remain nil on dial failure") + } + }) + + t.Run("success sets backendUDP and forwards ip/port", func(t *testing.T) { + f := newSRTConnFixture() + f.lb.PickReturns(&lb.OriginServer{IP: "10.0.0.5", SRT: []string{"20080"}}, nil) + if err := f.conn.connectBackend(context.Background(), "#!::r=live/stream"); err != nil { + t.Fatalf("unexpected err=%v", err) + } + if f.conn.backendUDP != f.backend { + t.Fatal("backendUDP not set to dialed connection") + } + if f.dialIP != "10.0.0.5" || f.dialPort != 20080 { + t.Fatalf("dial got ip=%q port=%d, want 10.0.0.5:20080", f.dialIP, f.dialPort) + } + }) + + t.Run("defaults host to localhost when stream id has no h=", func(t *testing.T) { + f := newSRTConnFixture() + f.lb.PickReturns(&lb.OriginServer{IP: "10.0.0.5", SRT: []string{"20080"}}, nil) + if err := f.conn.connectBackend(context.Background(), "#!::r=live/stream"); err != nil { + t.Fatalf("unexpected err=%v", err) + } + // Pick is called with a stream URL built from "srt://localhost/live/stream"; + // BuildStreamURL normalizes hostnames without a "." to __defaultVhost__. + _, gotURL := f.lb.PickArgsForCall(0) + if !strings.Contains(gotURL, "__defaultVhost__") { + t.Fatalf("Pick streamURL=%q, want default-vhost form", gotURL) + } + }) +} + +// --------------------------------------------------------------------------- +// srsSRTProxyServer: fixture and tests +// --------------------------------------------------------------------------- + +// srtServerFixture wires a srsSRTProxyServer with fake env, lb, and listener. +// The default listenUDP returns the fixture's blocking listener so tests can +// drive Run() through it; tests that exercise handleClientUDP directly can +// instead set v.listener to f.listener without ever calling Run(). +type srtServerFixture struct { + env *envfakes.FakeProxyEnvironment + lb *lbfakes.FakeOriginLoadBalancer + listener *blockingUDPListener + server *srsSRTProxyServer +} + +func newSRTServerFixture() *srtServerFixture { + f := &srtServerFixture{ + env: &envfakes.FakeProxyEnvironment{}, + lb: &lbfakes.FakeOriginLoadBalancer{}, + listener: newBlockingUDPListener(), + } + f.env.SRTServerReturns("20080") + f.server = NewSRSSRTProxyServer(f.env, f.lb, func(v *srsSRTProxyServer) { + v.listenUDP = func(ctx context.Context, endpoint string) (net.PacketConn, error) { + return f.listener, nil + } + }) + return f +} + +func TestNewSRSSRTProxyServer_SetsDefaults(t *testing.T) { + v := NewSRSSRTProxyServer(&envfakes.FakeProxyEnvironment{}, &lbfakes.FakeOriginLoadBalancer{}) + if v.listenUDP == nil { + t.Fatal("listenUDP should default to a non-nil factory") + } + if v.start.IsZero() { + t.Fatal("start should be initialized to time.Now()") + } +} + +func TestNewSRSSRTProxyServer_AppliesOptions(t *testing.T) { + called := false + listenUDP := func(ctx context.Context, endpoint string) (net.PacketConn, error) { + called = true + return nil, errors.New("test") + } + v := NewSRSSRTProxyServer(&envfakes.FakeProxyEnvironment{}, &lbfakes.FakeOriginLoadBalancer{}, + func(s *srsSRTProxyServer) { s.listenUDP = listenUDP }) + _, _ = v.listenUDP(context.Background(), "") + if !called { + t.Fatal("expected overridden listenUDP to be invoked") + } +} + +func TestSRSSRTProxyServer_Close_NilListener(t *testing.T) { + // Close before Run must not panic, must not hang, and must not error. + v := NewSRSSRTProxyServer(&envfakes.FakeProxyEnvironment{}, &lbfakes.FakeOriginLoadBalancer{}) + done := make(chan error, 1) + go func() { done <- v.Close() }() + select { + case err := <-done: + if err != nil { + t.Fatalf("Close: %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("Close hung with no listener") + } +} + +func TestSRSSRTProxyServer_Run_ListenError(t *testing.T) { + envFake := &envfakes.FakeProxyEnvironment{} + envFake.SRTServerReturns("20080") + v := NewSRSSRTProxyServer(envFake, &lbfakes.FakeOriginLoadBalancer{}, func(s *srsSRTProxyServer) { + s.listenUDP = func(ctx context.Context, endpoint string) (net.PacketConn, error) { + return nil, errors.New("permission denied") + } + }) + + err := v.Run(context.Background()) + if err == nil || !strings.Contains(err.Error(), "listen udp") { + t.Fatalf("expected listen-udp err, got %v", err) + } +} + +func TestSRSSRTProxyServer_Run_EndpointWithoutColon(t *testing.T) { + // A bare port like "20080" must be normalized to ":20080". + envFake := &envfakes.FakeProxyEnvironment{} + envFake.SRTServerReturns("20080") + listener := newBlockingUDPListener() + var captured atomic.Value + v := NewSRSSRTProxyServer(envFake, &lbfakes.FakeOriginLoadBalancer{}, func(s *srsSRTProxyServer) { + s.listenUDP = func(ctx context.Context, endpoint string) (net.PacketConn, error) { + captured.Store(endpoint) + return listener, nil + } + }) + + if err := v.Run(context.Background()); err != nil { + t.Fatalf("Run: %v", err) + } + defer v.Close() + + if got := captured.Load(); got != ":20080" { + t.Fatalf("listenUDP endpoint=%v, want :20080", got) + } +} + +func TestSRSSRTProxyServer_Run_EndpointWithColon(t *testing.T) { + envFake := &envfakes.FakeProxyEnvironment{} + envFake.SRTServerReturns("127.0.0.1:20080") + listener := newBlockingUDPListener() + var captured atomic.Value + v := NewSRSSRTProxyServer(envFake, &lbfakes.FakeOriginLoadBalancer{}, func(s *srsSRTProxyServer) { + s.listenUDP = func(ctx context.Context, endpoint string) (net.PacketConn, error) { + captured.Store(endpoint) + return listener, nil + } + }) + + if err := v.Run(context.Background()); err != nil { + t.Fatalf("Run: %v", err) + } + defer v.Close() + + if got := captured.Load(); got != "127.0.0.1:20080" { + t.Fatalf("listenUDP endpoint=%v, want 127.0.0.1:20080", got) + } +} + +func TestSRSSRTProxyServer_Run_CloseStopsReadLoop(t *testing.T) { + // Start Run with an idle listener (no packets queued). The read goroutine + // blocks in ReadFrom. Close must unblock it via the "closed network + // connection" error and allow the wait group to drain. + f := newSRTServerFixture() + if err := f.server.Run(context.Background()); err != nil { + t.Fatalf("Run: %v", err) + } + + done := make(chan error, 1) + go func() { done <- f.server.Close() }() + select { + case err := <-done: + if err != nil { + t.Fatalf("Close: %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("Close hung — read loop did not exit on listener close") + } +} + +// --------------------------------------------------------------------------- +// srsSRTProxyServer.handleClientUDP — routing only +// --------------------------------------------------------------------------- + +// buildNonHandshakeUDPPayload assembles a UDP payload whose first 4 bytes do +// NOT match the SRT handshake magic (so utils.SrtIsHandshake returns false) +// but whose destination socket ID at offset 12..15 equals the given id. +func buildNonHandshakeUDPPayload(destSocketID uint32, tail []byte) []byte { + out := make([]byte, 16+len(tail)) + // data[0]=0x00 — top bit clear, so SrtIsHandshake is false. + binary.BigEndian.PutUint32(out[12:16], destSocketID) + copy(out[16:], tail) + return out +} + +func TestSRSSRTProxyServer_HandleClientUDP_RoutesNonHandshakeToExistingConn(t *testing.T) { + f := newSRTServerFixture() + // handleClientUDP wires v.listener into newly-created connections, but for + // this test the existing conn already has its own backend, so v.listener is + // only relevant to satisfy the LoadOrStore path (and never read from). + f.server.listener = f.listener + + backend := newFakeBackendUDP() + existing := NewSRTConnection(func(c *SRTConnection) { + c.ctx = logger.WithContext(context.Background()) + c.backendUDP = backend + c.socketID = 0x12345678 + }) + f.server.sockets.Store(0x12345678, existing) + + payload := buildNonHandshakeUDPPayload(0x12345678, []byte("media-bytes")) + if err := f.server.handleClientUDP(context.Background(), &net.UDPAddr{}, payload); err != nil { + t.Fatalf("handleClientUDP err=%v", err) + } + + select { + case got := <-backend.writes: + // The full datagram is forwarded as-is. + if string(got) != string(payload) { + t.Fatalf("backend got %q, want %q", got, payload) + } + case <-time.After(time.Second): + t.Fatal("timeout waiting for backend write") + } +} + +func TestSRSSRTProxyServer_HandleClientUDP_HandshakeCreatesConnection(t *testing.T) { + f := newSRTServerFixture() + f.server.listener = f.listener + + const srtSocketID uint32 = 0xaabbccdd + hs0 := newHandshake0(srtSocketID) + data := marshalOrFatal(t, hs0) + // hs0 has SocketID(dest)=0 on the wire, so handleClientUDP must fall back + // to pkt.SRTSocketID to key the sockets map. + + client := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 9000} + if err := f.server.handleClientUDP(context.Background(), client, data); err != nil { + t.Fatalf("handleClientUDP err=%v", err) + } + + if _, ok := f.server.sockets.Load(srtSocketID); !ok { + t.Fatalf("expected sockets map to have entry under 0x%08x", srtSocketID) + } + + // hs1 reply must have been written back to the client via the listener. + select { + case got := <-f.listener.writes: + if got.addr != client { + t.Fatalf("listener addr=%v, want %v", got.addr, client) + } + parsed := &SRTHandshakePacket{} + if err := parsed.UnmarshalBinary(got.data); err != nil { + t.Fatalf("unmarshal hs1: %v", err) + } + if parsed.SynCookie != 0x418d5e4e { + t.Fatalf("hs1 SynCookie=0x%08x, want 0x418d5e4e", parsed.SynCookie) + } + case <-time.After(time.Second): + t.Fatal("timeout waiting for hs1 listener write") + } +} + +func TestSRSSRTProxyServer_HandleClientUDP_BadHandshakeUnmarshalError(t *testing.T) { + f := newSRTServerFixture() + f.server.listener = f.listener + + // First 4 bytes match the SRT handshake magic so SrtIsHandshake returns + // true, but the buffer is shorter than 64 bytes so UnmarshalBinary errors. + bad := []byte{0x80, 0x00, 0x00, 0x00, 0x01, 0x02, 0x03, 0x04} + err := f.server.handleClientUDP(context.Background(), &net.UDPAddr{}, bad) + if err == nil || !strings.Contains(err.Error(), "Invalid packet length") { + t.Fatalf("expected unmarshal err, got %v", err) + } +} diff --git a/internal/redisclient/gen.go b/internal/redisclient/gen.go new file mode 100644 index 000000000..5ce43b7be --- /dev/null +++ b/internal/redisclient/gen.go @@ -0,0 +1,6 @@ +// Copyright (c) 2026 Winlin +// +// SPDX-License-Identifier: MIT +package redisclient + +//go:generate go tool counterfeiter -o redisclientfakes/fake_redis_client.go . RedisClient diff --git a/internal/redisclient/redisclient.go b/internal/redisclient/redisclient.go new file mode 100644 index 000000000..78c85d976 --- /dev/null +++ b/internal/redisclient/redisclient.go @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Winlin +// +// SPDX-License-Identifier: MIT +package redisclient + +import ( + "context" + "time" + + // Use v8 because we use Go 1.16+, while v9 requires Go 1.18+ + "github.com/go-redis/redis/v8" +) + +// RedisClient is the subset of *redis.Client methods used by callers in this +// codebase. Declared as an interface so tests can substitute a fake without +// standing up a real Redis server. *redis.Client satisfies this interface +// directly. +type RedisClient interface { + Ping(ctx context.Context) *redis.StatusCmd + Get(ctx context.Context, key string) *redis.StringCmd + Set(ctx context.Context, key string, value interface{}, expiration time.Duration) *redis.StatusCmd + String() string +} + +// New connects to a Redis server at addr (host:port) with the given password +// and database index. Returns a RedisClient satisfied by *redis.Client. +func New(addr, password string, db int) RedisClient { + return redis.NewClient(&redis.Options{ + Addr: addr, + Password: password, + DB: db, + }) +} diff --git a/internal/redisclient/redisclientfakes/fake_redis_client.go b/internal/redisclient/redisclientfakes/fake_redis_client.go new file mode 100644 index 000000000..1ed9c03bd --- /dev/null +++ b/internal/redisclient/redisclientfakes/fake_redis_client.go @@ -0,0 +1,327 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package redisclientfakes + +import ( + "context" + "srsx/internal/redisclient" + "sync" + "time" + + redis "github.com/go-redis/redis/v8" +) + +type FakeRedisClient struct { + GetStub func(context.Context, string) *redis.StringCmd + getMutex sync.RWMutex + getArgsForCall []struct { + arg1 context.Context + arg2 string + } + getReturns struct { + result1 *redis.StringCmd + } + getReturnsOnCall map[int]struct { + result1 *redis.StringCmd + } + PingStub func(context.Context) *redis.StatusCmd + pingMutex sync.RWMutex + pingArgsForCall []struct { + arg1 context.Context + } + pingReturns struct { + result1 *redis.StatusCmd + } + pingReturnsOnCall map[int]struct { + result1 *redis.StatusCmd + } + SetStub func(context.Context, string, interface{}, time.Duration) *redis.StatusCmd + setMutex sync.RWMutex + setArgsForCall []struct { + arg1 context.Context + arg2 string + arg3 interface{} + arg4 time.Duration + } + setReturns struct { + result1 *redis.StatusCmd + } + setReturnsOnCall map[int]struct { + result1 *redis.StatusCmd + } + StringStub func() string + stringMutex sync.RWMutex + stringArgsForCall []struct { + } + stringReturns struct { + result1 string + } + stringReturnsOnCall map[int]struct { + result1 string + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeRedisClient) Get(arg1 context.Context, arg2 string) *redis.StringCmd { + fake.getMutex.Lock() + ret, specificReturn := fake.getReturnsOnCall[len(fake.getArgsForCall)] + fake.getArgsForCall = append(fake.getArgsForCall, struct { + arg1 context.Context + arg2 string + }{arg1, arg2}) + stub := fake.GetStub + fakeReturns := fake.getReturns + fake.recordInvocation("Get", []interface{}{arg1, arg2}) + fake.getMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeRedisClient) GetCallCount() int { + fake.getMutex.RLock() + defer fake.getMutex.RUnlock() + return len(fake.getArgsForCall) +} + +func (fake *FakeRedisClient) GetCalls(stub func(context.Context, string) *redis.StringCmd) { + fake.getMutex.Lock() + defer fake.getMutex.Unlock() + fake.GetStub = stub +} + +func (fake *FakeRedisClient) GetArgsForCall(i int) (context.Context, string) { + fake.getMutex.RLock() + defer fake.getMutex.RUnlock() + argsForCall := fake.getArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeRedisClient) GetReturns(result1 *redis.StringCmd) { + fake.getMutex.Lock() + defer fake.getMutex.Unlock() + fake.GetStub = nil + fake.getReturns = struct { + result1 *redis.StringCmd + }{result1} +} + +func (fake *FakeRedisClient) GetReturnsOnCall(i int, result1 *redis.StringCmd) { + fake.getMutex.Lock() + defer fake.getMutex.Unlock() + fake.GetStub = nil + if fake.getReturnsOnCall == nil { + fake.getReturnsOnCall = make(map[int]struct { + result1 *redis.StringCmd + }) + } + fake.getReturnsOnCall[i] = struct { + result1 *redis.StringCmd + }{result1} +} + +func (fake *FakeRedisClient) Ping(arg1 context.Context) *redis.StatusCmd { + fake.pingMutex.Lock() + ret, specificReturn := fake.pingReturnsOnCall[len(fake.pingArgsForCall)] + fake.pingArgsForCall = append(fake.pingArgsForCall, struct { + arg1 context.Context + }{arg1}) + stub := fake.PingStub + fakeReturns := fake.pingReturns + fake.recordInvocation("Ping", []interface{}{arg1}) + fake.pingMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeRedisClient) PingCallCount() int { + fake.pingMutex.RLock() + defer fake.pingMutex.RUnlock() + return len(fake.pingArgsForCall) +} + +func (fake *FakeRedisClient) PingCalls(stub func(context.Context) *redis.StatusCmd) { + fake.pingMutex.Lock() + defer fake.pingMutex.Unlock() + fake.PingStub = stub +} + +func (fake *FakeRedisClient) PingArgsForCall(i int) context.Context { + fake.pingMutex.RLock() + defer fake.pingMutex.RUnlock() + argsForCall := fake.pingArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeRedisClient) PingReturns(result1 *redis.StatusCmd) { + fake.pingMutex.Lock() + defer fake.pingMutex.Unlock() + fake.PingStub = nil + fake.pingReturns = struct { + result1 *redis.StatusCmd + }{result1} +} + +func (fake *FakeRedisClient) PingReturnsOnCall(i int, result1 *redis.StatusCmd) { + fake.pingMutex.Lock() + defer fake.pingMutex.Unlock() + fake.PingStub = nil + if fake.pingReturnsOnCall == nil { + fake.pingReturnsOnCall = make(map[int]struct { + result1 *redis.StatusCmd + }) + } + fake.pingReturnsOnCall[i] = struct { + result1 *redis.StatusCmd + }{result1} +} + +func (fake *FakeRedisClient) Set(arg1 context.Context, arg2 string, arg3 interface{}, arg4 time.Duration) *redis.StatusCmd { + fake.setMutex.Lock() + ret, specificReturn := fake.setReturnsOnCall[len(fake.setArgsForCall)] + fake.setArgsForCall = append(fake.setArgsForCall, struct { + arg1 context.Context + arg2 string + arg3 interface{} + arg4 time.Duration + }{arg1, arg2, arg3, arg4}) + stub := fake.SetStub + fakeReturns := fake.setReturns + fake.recordInvocation("Set", []interface{}{arg1, arg2, arg3, arg4}) + fake.setMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3, arg4) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeRedisClient) SetCallCount() int { + fake.setMutex.RLock() + defer fake.setMutex.RUnlock() + return len(fake.setArgsForCall) +} + +func (fake *FakeRedisClient) SetCalls(stub func(context.Context, string, interface{}, time.Duration) *redis.StatusCmd) { + fake.setMutex.Lock() + defer fake.setMutex.Unlock() + fake.SetStub = stub +} + +func (fake *FakeRedisClient) SetArgsForCall(i int) (context.Context, string, interface{}, time.Duration) { + fake.setMutex.RLock() + defer fake.setMutex.RUnlock() + argsForCall := fake.setArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4 +} + +func (fake *FakeRedisClient) SetReturns(result1 *redis.StatusCmd) { + fake.setMutex.Lock() + defer fake.setMutex.Unlock() + fake.SetStub = nil + fake.setReturns = struct { + result1 *redis.StatusCmd + }{result1} +} + +func (fake *FakeRedisClient) SetReturnsOnCall(i int, result1 *redis.StatusCmd) { + fake.setMutex.Lock() + defer fake.setMutex.Unlock() + fake.SetStub = nil + if fake.setReturnsOnCall == nil { + fake.setReturnsOnCall = make(map[int]struct { + result1 *redis.StatusCmd + }) + } + fake.setReturnsOnCall[i] = struct { + result1 *redis.StatusCmd + }{result1} +} + +func (fake *FakeRedisClient) String() string { + fake.stringMutex.Lock() + ret, specificReturn := fake.stringReturnsOnCall[len(fake.stringArgsForCall)] + fake.stringArgsForCall = append(fake.stringArgsForCall, struct { + }{}) + stub := fake.StringStub + fakeReturns := fake.stringReturns + fake.recordInvocation("String", []interface{}{}) + fake.stringMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeRedisClient) StringCallCount() int { + fake.stringMutex.RLock() + defer fake.stringMutex.RUnlock() + return len(fake.stringArgsForCall) +} + +func (fake *FakeRedisClient) StringCalls(stub func() string) { + fake.stringMutex.Lock() + defer fake.stringMutex.Unlock() + fake.StringStub = stub +} + +func (fake *FakeRedisClient) StringReturns(result1 string) { + fake.stringMutex.Lock() + defer fake.stringMutex.Unlock() + fake.StringStub = nil + fake.stringReturns = struct { + result1 string + }{result1} +} + +func (fake *FakeRedisClient) StringReturnsOnCall(i int, result1 string) { + fake.stringMutex.Lock() + defer fake.stringMutex.Unlock() + fake.StringStub = nil + if fake.stringReturnsOnCall == nil { + fake.stringReturnsOnCall = make(map[int]struct { + result1 string + }) + } + fake.stringReturnsOnCall[i] = struct { + result1 string + }{result1} +} + +func (fake *FakeRedisClient) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeRedisClient) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ redisclient.RedisClient = new(FakeRedisClient) diff --git a/internal/rtmp/amf0.go b/internal/rtmp/amf0.go index 7fd2c7a3d..e316f1b8f 100644 --- a/internal/rtmp/amf0.go +++ b/internal/rtmp/amf0.go @@ -90,7 +90,9 @@ type amf0Buffer interface { Write(p []byte) (n int, err error) } -var createBuffer = func() amf0Buffer { +// defaultBufFactory is the production amf0Buffer factory. Tests override the +// per-instance bufFactory field on amf0ObjectBase instead of swapping a global. +func defaultBufFactory() amf0Buffer { return &bytes.Buffer{} } @@ -399,6 +401,10 @@ type amf0Property struct { type amf0ObjectBase struct { properties []*amf0Property lock sync.Mutex + // bufFactory creates the amf0Buffer used by MarshalBinary. Held as a + // per-instance field (not a package global) so concurrent tests can each + // install their own buggy buffers without racing on shared state. + bufFactory func() amf0Buffer } func (v *amf0ObjectBase) Size() int { @@ -562,6 +568,7 @@ func NewAmf0Object() Amf0Object { func newAmf0Object() *amf0Object { v := &amf0Object{} v.properties = []*amf0Property{} + v.bufFactory = defaultBufFactory return v } @@ -600,7 +607,7 @@ func (v *amf0Object) UnmarshalBinary(data []byte) (err error) { } func (v *amf0Object) MarshalBinary() (data []byte, err error) { - b := createBuffer() + b := v.bufFactory() if err = b.WriteByte(byte(amf0MarkerObject)); err != nil { return nil, errors.Wrap(err, "marshal") @@ -640,6 +647,7 @@ func NewAmf0EcmaArray() Amf0EcmaArray { func newAmf0EcmaArray() *amf0EcmaArray { v := &amf0EcmaArray{} v.properties = []*amf0Property{} + v.bufFactory = defaultBufFactory return v } @@ -678,7 +686,7 @@ func (v *amf0EcmaArray) UnmarshalBinary(data []byte) (err error) { } func (v *amf0EcmaArray) MarshalBinary() (data []byte, err error) { - b := createBuffer() + b := v.bufFactory() if err = b.WriteByte(byte(amf0MarkerEcmaArray)); err != nil { return nil, errors.Wrap(err, "marshal") @@ -717,6 +725,7 @@ type amf0StrictArray struct { func NewAmf0StrictArray() Amf0StrictArray { v := &amf0StrictArray{} v.properties = []*amf0Property{} + v.bufFactory = defaultBufFactory return v } @@ -759,7 +768,7 @@ func (v *amf0StrictArray) UnmarshalBinary(data []byte) (err error) { } func (v *amf0StrictArray) MarshalBinary() (data []byte, err error) { - b := createBuffer() + b := v.bufFactory() if err = b.WriteByte(byte(amf0MarkerStrictArray)); err != nil { return nil, errors.Wrap(err, "marshal") diff --git a/internal/rtmp/amf0_test.go b/internal/rtmp/amf0_test.go index a2c240360..da102e7b1 100644 --- a/internal/rtmp/amf0_test.go +++ b/internal/rtmp/amf0_test.go @@ -436,10 +436,21 @@ func (v *errorAmf0Any) amf0Marker() amf0Marker { return amf0MarkerNumber } -func TestAmf0MarshalErrors(t *testing.T) { - originalCreateBuffer := createBuffer - defer func() { createBuffer = originalCreateBuffer }() +// setBufFactory replaces the bufFactory on whichever amf0 object-like type +// underlies v. Concurrent tests can use this safely because each value carries +// its own factory. +func setBufFactory(v Amf0Any, fn func() amf0Buffer) { + switch v := v.(type) { + case *amf0Object: + v.bufFactory = fn + case *amf0EcmaArray: + v.bufFactory = fn + case *amf0StrictArray: + v.bufFactory = fn + } +} +func TestAmf0MarshalErrors(t *testing.T) { for _, tt := range []struct { name string make func() Amf0Any @@ -449,15 +460,16 @@ func TestAmf0MarshalErrors(t *testing.T) { {"strict-array", func() Amf0Any { return NewAmf0StrictArray() }}, } { t.Run(tt.name+" write-byte", func(t *testing.T) { - createBuffer = func() amf0Buffer { return &errorAmf0Buffer{writeByteErr: true} } - if _, err := tt.make().MarshalBinary(); err == nil { + value := tt.make() + setBufFactory(value, func() amf0Buffer { return &errorAmf0Buffer{writeByteErr: true} }) + if _, err := value.MarshalBinary(); err == nil { t.Fatal("MarshalBinary() should fail") } }) t.Run(tt.name+" write-prop", func(t *testing.T) { - createBuffer = func() amf0Buffer { return &errorAmf0Buffer{writeErr: true} } value := tt.make() + setBufFactory(value, func() amf0Buffer { return &errorAmf0Buffer{writeErr: true} }) switch v := value.(type) { case Amf0Object: v.Set("name", NewAmf0String("stream")) @@ -473,7 +485,6 @@ func TestAmf0MarshalErrors(t *testing.T) { }) } - createBuffer = originalCreateBuffer for _, tt := range []struct { name string make func() Amf0Any diff --git a/internal/signal/signal.go b/internal/signal/signal.go index b8930480b..a23b36a0f 100644 --- a/internal/signal/signal.go +++ b/internal/signal/signal.go @@ -15,15 +15,26 @@ import ( "srsx/internal/logger" ) -// Indirections so tests can substitute signal delivery and process exit. -var ( - signalNotify = signal.Notify - osExit = os.Exit -) +// Handler installs OS signal handlers and the force-quit timer. The notify +// and exit indirections are struct fields (not package globals) so concurrent +// tests can each construct a handler with their own fakes without racing on +// shared state. +type Handler struct { + notify func(c chan<- os.Signal, sig ...os.Signal) + exit func(code int) +} -func InstallSignals(ctx context.Context, cancel context.CancelFunc) { +// NewHandler returns a Handler wired to the real OS implementations. +func NewHandler() *Handler { + return &Handler{ + notify: signal.Notify, + exit: os.Exit, + } +} + +func (h *Handler) InstallSignals(ctx context.Context, cancel context.CancelFunc) { sc := make(chan os.Signal, 1) - signalNotify(sc, syscall.SIGINT, syscall.SIGTERM, os.Interrupt) + h.notify(sc, syscall.SIGINT, syscall.SIGTERM, os.Interrupt) go func() { for s := range sc { @@ -33,7 +44,7 @@ func InstallSignals(ctx context.Context, cancel context.CancelFunc) { }() } -func InstallForceQuit(ctx context.Context, environment env.ProxyEnvironment) error { +func (h *Handler) InstallForceQuit(ctx context.Context, environment env.ProxyEnvironment) error { var forceTimeout time.Duration timeoutStr := environment.ForceQuitTimeout() if t, err := time.ParseDuration(timeoutStr); err != nil { @@ -46,7 +57,7 @@ func InstallForceQuit(ctx context.Context, environment env.ProxyEnvironment) err <-ctx.Done() time.Sleep(forceTimeout) logger.Warn(ctx, "Force to exit by timeout") - osExit(1) + h.exit(1) }() return nil } diff --git a/internal/signal/signal_test.go b/internal/signal/signal_test.go index ea3fac252..fc1471dd6 100644 --- a/internal/signal/signal_test.go +++ b/internal/signal/signal_test.go @@ -16,59 +16,60 @@ import ( "srsx/internal/env/envfakes" ) -// swapNotify replaces signalNotify with a capturing fake and returns a getter -// for the channel registered by the code under test plus a restore func. -func swapNotify(t *testing.T) (func() chan<- os.Signal, func()) { - t.Helper() - orig := signalNotify +// captureNotify returns a Handler whose notify field records the channel +// passed by the code under test, plus a getter that retrieves it. +func captureNotify() (*Handler, func() chan<- os.Signal) { var ( mu sync.Mutex ch chan<- os.Signal ) - signalNotify = func(c chan<- os.Signal, _ ...os.Signal) { - mu.Lock() - defer mu.Unlock() - ch = c - } - return func() chan<- os.Signal { + h := &Handler{ + notify: func(c chan<- os.Signal, _ ...os.Signal) { mu.Lock() defer mu.Unlock() - return ch - }, func() { - signalNotify = orig - } + ch = c + }, + exit: os.Exit, + } + return h, func() chan<- os.Signal { + mu.Lock() + defer mu.Unlock() + return ch + } } -func swapExit(t *testing.T) (*int32, chan int, func()) { - t.Helper() - orig := osExit +// captureExit returns a Handler whose exit field records the code and never +// returns, plus a flag and channel that observe the call. +func captureExit() (*Handler, *int32, chan int) { var called int32 done := make(chan int, 1) - osExit = func(code int) { - atomic.StoreInt32(&called, 1) - select { - case done <- code: - default: - } - // Block to mimic os.Exit never returning; the goroutine holding us - // here is abandoned when the test ends. - select {} + h := &Handler{ + notify: func(chan<- os.Signal, ...os.Signal) {}, + exit: func(code int) { + atomic.StoreInt32(&called, 1) + select { + case done <- code: + default: + } + // Block to mimic os.Exit never returning; the goroutine holding us + // here is abandoned when the test ends. + select {} + }, } - return &called, done, func() { osExit = orig } + return h, &called, done } func TestInstallSignals_CancelsOnSignal(t *testing.T) { - getCh, restore := swapNotify(t) - defer restore() + h, getCh := captureNotify() ctx, cancel := context.WithCancel(t.Context()) defer cancel() - InstallSignals(ctx, cancel) + h.InstallSignals(ctx, cancel) ch := getCh() if ch == nil { - t.Fatal("signalNotify was not called") + t.Fatal("notify was not called") } ch <- syscall.SIGINT @@ -80,13 +81,12 @@ func TestInstallSignals_CancelsOnSignal(t *testing.T) { } func TestInstallSignals_HandlesRepeatedSignals(t *testing.T) { - getCh, restore := swapNotify(t) - defer restore() + h, getCh := captureNotify() ctx, cancel := context.WithCancel(t.Context()) defer cancel() - InstallSignals(ctx, cancel) + h.InstallSignals(ctx, cancel) ch := getCh() // Multiple signals must not panic; cancel() is idempotent. @@ -105,7 +105,7 @@ func TestInstallForceQuit_InvalidDurationReturnsError(t *testing.T) { fakeEnv := &envfakes.FakeProxyEnvironment{} fakeEnv.ForceQuitTimeoutReturns("not-a-duration") - err := InstallForceQuit(t.Context(), fakeEnv) + err := NewHandler().InstallForceQuit(t.Context(), fakeEnv) if err == nil { t.Fatal("want error for bad duration") } @@ -118,20 +118,19 @@ func TestInstallForceQuit_InvalidDurationReturnsError(t *testing.T) { } func TestInstallForceQuit_ExitsAfterTimeout(t *testing.T) { - called, done, restore := swapExit(t) - defer restore() + h, called, done := captureExit() fakeEnv := &envfakes.FakeProxyEnvironment{} fakeEnv.ForceQuitTimeoutReturns("1ms") ctx, cancel := context.WithCancel(t.Context()) - if err := InstallForceQuit(ctx, fakeEnv); err != nil { + if err := h.InstallForceQuit(ctx, fakeEnv); err != nil { t.Fatalf("unexpected err: %v", err) } // Before cancel, the goroutine is blocked and exit must not fire. if atomic.LoadInt32(called) != 0 { - t.Fatal("osExit called before ctx cancel") + t.Fatal("exit called before ctx cancel") } cancel() @@ -141,30 +140,39 @@ func TestInstallForceQuit_ExitsAfterTimeout(t *testing.T) { t.Fatalf("exit code = %d, want 1", code) } case <-time.After(time.Second): - t.Fatal("osExit not called after cancel + timeout") + t.Fatal("exit not called after cancel + timeout") } } func TestInstallForceQuit_WaitsForCancelBeforeSleeping(t *testing.T) { - called, done, restore := swapExit(t) - defer restore() + h, called, done := captureExit() fakeEnv := &envfakes.FakeProxyEnvironment{} fakeEnv.ForceQuitTimeoutReturns("10ms") - // Intentionally use a never-canceled context and leak the goroutine: - // if we canceled at test end, the goroutine would wake and race with - // restore() writing osExit. - if err := InstallForceQuit(context.Background(), fakeEnv); err != nil { + // Intentionally use a never-canceled context and leak the goroutine: the + // handler's exit closure is owned by this test instance, so leaving the + // goroutine alive doesn't race other tests. + if err := h.InstallForceQuit(context.Background(), fakeEnv); err != nil { t.Fatalf("unexpected err: %v", err) } select { case <-done: - t.Fatal("osExit fired without ctx cancel") + t.Fatal("exit fired without ctx cancel") case <-time.After(30 * time.Millisecond): } if atomic.LoadInt32(called) != 0 { - t.Fatal("osExit called unexpectedly") + t.Fatal("exit called unexpectedly") + } +} + +func TestNewHandler_UsesRealOSDefaults(t *testing.T) { + h := NewHandler() + if h.notify == nil { + t.Error("notify default not set") + } + if h.exit == nil { + t.Error("exit default not set") } } diff --git a/internal/version/version.go b/internal/version/version.go index f71511505..6e54527b8 100644 --- a/internal/version/version.go +++ b/internal/version/version.go @@ -15,7 +15,7 @@ func VersionMinor() int { } func VersionRevision() int { - return 147 + return 148 } func Version() string { diff --git a/memory b/memory deleted file mode 120000 index 20495cb7b..000000000 --- a/memory +++ /dev/null @@ -1 +0,0 @@ -.openclaw/memory \ No newline at end of file diff --git a/.openclaw/memory/2026-02-06.md b/memory/2026-02-06.md similarity index 100% rename from .openclaw/memory/2026-02-06.md rename to memory/2026-02-06.md diff --git a/.openclaw/memory/srs-codebase-map.md b/memory/srs-codebase-map.md similarity index 89% rename from .openclaw/memory/srs-codebase-map.md rename to memory/srs-codebase-map.md index 0beae9cae..6f2c9d018 100644 --- a/.openclaw/memory/srs-codebase-map.md +++ b/memory/srs-codebase-map.md @@ -215,11 +215,13 @@ The next-generation server (`cmd/` + `internal/`) is written in Go and maintaine `internal/bootstrap` — Server startup and lifecycle orchestration. Sets up logging context, signal handlers, loads environment, installs force-quit timer, optionally starts pprof, initializes the load balancer (memory or Redis based on `PROXY_LOAD_BALANCER_TYPE`), then starts all six servers sequentially (RTMP, WebRTC, HTTP API, SRT, System API, HTTP Stream) and blocks until context is cancelled. Deferred `Close()` on each server ensures graceful shutdown. -`internal/server` — Proxy server implementations. Each server accepts client connections, parses just enough of the protocol to extract the stream URL, picks a backend via the load balancer, and proxies traffic bidirectionally. Contains five proxy servers: (1) **RTMP proxy** (`rtmp.go`) — TCP listener, simple handshake, parses connect/publish/play to get stream URL, bidirectional RTMP message copying, stateless. (2) **HTTP stream proxy** (`http.go`) — serves static files, proxies HTTP-FLV/TS via reverse-proxy, proxies HLS m3u8 with `spbhid` rewriting so TS segment requests route to the same backend. (3) **WebRTC proxy** (`rtc.go`) — two-phase: WHIP/WHEP signaling (SDP rewrite to replace backend UDP port with proxy's) + UDP media transport (identifies connections by STUN ufrag, supports address migration), stateful. (4) **SRT proxy** (`srt.go`) — intercepts SRT 4-step handshake locally, parses stream ID on handshake 2, replays full handshake with backend, then proxies UDP bidirectionally, stateful per-connection. (5) **HTTP API + System API** (`api.go`) — HTTP API delegates WHIP/WHEP to WebRTC server; System API provides `/api/v1/srs/register` where backend SRS C++ servers register themselves so the load balancer knows about them. +`internal/proxy` — Proxy server implementations. Each server accepts client connections, parses just enough of the protocol to extract the stream URL, picks a backend via the load balancer, and proxies traffic bidirectionally. Contains five proxy servers: (1) **RTMP proxy** (`rtmp.go`) — TCP listener, simple handshake, parses connect/publish/play to get stream URL, bidirectional RTMP message copying, stateless. (2) **HTTP stream proxy** (`http.go`) — serves static files, proxies HTTP-FLV/TS via reverse-proxy, proxies HLS m3u8 with `spbhid` rewriting so TS segment requests route to the same backend. (3) **WebRTC proxy** (`rtc.go`) — two-phase: WHIP/WHEP signaling (SDP rewrite to replace backend UDP port with proxy's) + UDP media transport (identifies connections by STUN ufrag, supports address migration), stateful. (4) **SRT proxy** (`srt.go`) — intercepts SRT 4-step handshake locally, parses stream ID on handshake 2, replays full handshake with backend, then proxies UDP bidirectionally, stateful per-connection. (5) **HTTP API + System API** (`api.go`) — HTTP API delegates WHIP/WHEP to WebRTC server; System API provides `/api/v1/srs/register` where backend SRS C++ servers register themselves so the load balancer knows about them. `internal/rtmp` — RTMP protocol implementation (parsing, not proxying). Full RTMP chunk stream and message protocol: simple handshake (C0/C1/C2), chunk stream reader/writer with all four format types, extended timestamp, message reassembly from chunks. Defines all RTMP message types, chunk stream IDs, and command names. Packet types include ConnectApp, CreateStream, Publish, Play, Call, SetChunkSize, WindowAcknowledgementSize, SetPeerBandwidth, UserControl. Uses Go generics (`ExpectPacket[T]`) to read until a specific packet type arrives. Also includes full AMF0 encoder/decoder supporting Number, Boolean, String, Object, Null, Undefined, EcmaArray, StrictArray, Date, LongString — with ordered key-value maps, auto-type-discovery, and safe type converters. -`internal/lb` — Load balancer abstraction and two implementations. Defines `SRSLoadBalancer` interface and core types in `lb.go` (Initialize, Update, Pick, HLS/WebRTC state management) and `SRSServer` struct representing a backend origin (IP, listen endpoints for RTMP/HTTP/API/SRT/RTC, heartbeat tracking). **Memory LB** (`mem.go`) — in-memory using `sync.Map`, sticky random pick per stream URL, single-proxy deployment. **Redis LB** (`redis.go`) — Redis-backed shared state with TTL-based expiration, enables multi-proxy horizontal scaling behind a network load balancer. Also includes a debug helper (`debug.go`) that creates a fake backend from env vars when `PROXY_DEFAULT_BACKEND_ENABLED=on` for development without real SRS registration. +`internal/lb` — Load balancer abstraction and two implementations. Defines `OriginLoadBalancer` interface and core types in `lb.go` (Initialize, Update, Pick, HLS/WebRTC state management) and `OriginServer` struct representing a backend origin (IP, listen endpoints for RTMP/HTTP/API/SRT/RTC, heartbeat tracking). **Memory LB** (`mem.go`) — in-memory using `sync.Map`, sticky random pick per stream URL, single-proxy deployment. **Redis LB** (`redis.go`) — Redis-backed shared state with TTL-based expiration, enables multi-proxy horizontal scaling behind a network load balancer. Also includes a debug helper (`debug.go`) that creates a fake backend from env vars when `PROXY_DEFAULT_BACKEND_ENABLED=on` for development without real SRS registration. + +`internal/redisclient` — Thin Redis client abstraction. Defines a minimal `RedisClient` interface (`Ping`/`Get`/`Set`/`String`) satisfied by `*redis.Client` from `github.com/go-redis/redis/v8`, plus a `New(addr, password, db)` constructor. Used by `internal/lb/redis.go`. `internal/logger` — Structured logging with context IDs. Four log levels: Debug/Info (stdout), Warn/Error (stderr). Emits JSON via `log/slog` with `pid` and `cid` attributes. Each connection/request gets a unique 7-char hex context ID for log correlation, stored in `context.Context`. @@ -346,12 +348,15 @@ How to verify SRS works correctly. - Reconnecting Load Test - Janus -`.openclaw/skills/srs-develop/scripts/` — Go proxy verification scripts: +`.openclaw/skills/srs-develop/scripts/` — Go proxy verification and setup scripts: - `proxy-utest.sh` — Runs Go proxy unit tests with optional coverage. - `proxy-e2e-test.sh` — Single-origin RTMP proxy E2E test. - `proxy-e2e-cluster-test.sh` — Multi-origin memory load-balancer E2E test. - `proxy-e2e-redis-test.sh` — Multi-proxy Redis load-balancer E2E test. - `proxy-e2e-transmux-test.sh` — RTMP publish through proxy, then verify RTMP, HTTP-FLV, HLS, and WebRTC playback. +- `proxy-e2e-srt-test.sh` — SRT publish through proxy, then verify SRT, RTMP, HTTP-FLV, and HLS playback (WebRTC WHEP is a placeholder). +- `proxy-e2e-whip-test.sh` — WHIP (WebRTC) publish through proxy, then verify RTMP, HTTP-FLV, and HLS playback via the origin's `rtc_to_rtmp` bridge (WebRTC WHEP is a placeholder). +- `setup-ffmpeg-with-whip.sh` — macOS-only: build ffmpeg from source into `~/.local/` with WHIP (openssl DTLS) and SRT support; auto-invoked by `proxy-e2e-srt-test.sh` and `proxy-e2e-whip-test.sh` when no suitable ffmpeg is found. **Summary: The Key Differences** diff --git a/.openclaw/memory/srs-coroutines.md b/memory/srs-coroutines.md similarity index 100% rename from .openclaw/memory/srs-coroutines.md rename to memory/srs-coroutines.md diff --git a/.openclaw/memory/srs-overview.md b/memory/srs-overview.md similarity index 100% rename from .openclaw/memory/srs-overview.md rename to memory/srs-overview.md diff --git a/.openclaw/skills/srs-develop/SKILL.md b/skills/srs-develop/SKILL.md similarity index 92% rename from .openclaw/skills/srs-develop/SKILL.md rename to skills/srs-develop/SKILL.md index 4b444521a..6602e1fe0 100644 --- a/.openclaw/skills/srs-develop/SKILL.md +++ b/skills/srs-develop/SKILL.md @@ -156,10 +156,18 @@ Only after the user confirms the routing do you proceed to Step 2. ``` bash scripts/proxy-e2e-redis-test.sh ``` - - RTMP transmuxing test (starts proxy + one SRS origin, publishes RTMP, verifies RTMP/HTTP-FLV/HLS playback, and verifies WebRTC WHEP playback when `PROXY_TRANSMUX_TEST_RTC=on`): + - RTMP transmuxing test (starts proxy + one SRS origin, publishes RTMP, verifies RTMP/HTTP-FLV/HLS playback; WebRTC WHEP is a placeholder): ``` bash scripts/proxy-e2e-transmux-test.sh ``` + - SRT proxy + transmuxing test (starts proxy + one SRS origin, publishes SRT, verifies SRT/RTMP/HTTP-FLV/HLS playback; WebRTC WHEP is a placeholder). Requires an ffmpeg built with libsrt; the script auto-runs `scripts/setup-ffmpeg-with-whip.sh` to build one into `~/.local/` if no SRT-capable ffmpeg is found: + ``` + bash scripts/proxy-e2e-srt-test.sh + ``` + - WHIP proxy + transmuxing test (starts proxy + one SRS origin, publishes WebRTC via WHIP, verifies RTMP/HTTP-FLV/HLS playback; WebRTC WHEP is a placeholder). Requires an ffmpeg with the `whip` muxer (built with `--enable-openssl`); the script auto-runs `scripts/setup-ffmpeg-with-whip.sh` if no suitable ffmpeg is found: + ``` + bash scripts/proxy-e2e-whip-test.sh + ``` 5. If any tests fail, fix the issues and re-run until all tests pass. All script paths are relative to this skill's directory. diff --git a/.openclaw/skills/srs-develop/scripts/proxy-e2e-cluster-test.sh b/skills/srs-develop/scripts/proxy-e2e-cluster-test.sh similarity index 96% rename from .openclaw/skills/srs-develop/scripts/proxy-e2e-cluster-test.sh rename to skills/srs-develop/scripts/proxy-e2e-cluster-test.sh index eba2f5077..3721d8ac1 100755 --- a/.openclaw/skills/srs-develop/scripts/proxy-e2e-cluster-test.sh +++ b/skills/srs-develop/scripts/proxy-e2e-cluster-test.sh @@ -5,11 +5,16 @@ set -e SCRIPT_DIR="$(cd -P "$(dirname "$0")" && pwd)" -# Navigate: scripts/ -> srs-develop/ -> skills/ -> .openclaw/ -> srs -WORKSPACE="$(cd -P "$SCRIPT_DIR/../../../.." && pwd)" +# Walk up from SCRIPT_DIR looking for go.mod. This avoids brittle "../../../.." +# counting when the skills directory is reached via a symlink (which changes +# the symbolic vs. physical depth). +WORKSPACE="$SCRIPT_DIR" +while [[ "$WORKSPACE" != "/" && ! -f "$WORKSPACE/go.mod" ]]; do + WORKSPACE="$(dirname "$WORKSPACE")" +done if [[ ! -f "$WORKSPACE/go.mod" ]]; then - echo "Error: go.mod not found in WORKSPACE: $WORKSPACE" >&2 + echo "Error: go.mod not found walking up from: $SCRIPT_DIR" >&2 exit 1 fi diff --git a/.openclaw/skills/srs-develop/scripts/proxy-e2e-redis-test.sh b/skills/srs-develop/scripts/proxy-e2e-redis-test.sh similarity index 92% rename from .openclaw/skills/srs-develop/scripts/proxy-e2e-redis-test.sh rename to skills/srs-develop/scripts/proxy-e2e-redis-test.sh index 148a81d18..fee4a0b94 100755 --- a/.openclaw/skills/srs-develop/scripts/proxy-e2e-redis-test.sh +++ b/skills/srs-develop/scripts/proxy-e2e-redis-test.sh @@ -6,11 +6,16 @@ set -e SCRIPT_DIR="$(cd -P "$(dirname "$0")" && pwd)" -# Navigate: scripts/ -> srs-develop/ -> skills/ -> .openclaw/ -> srs -WORKSPACE="$(cd -P "$SCRIPT_DIR/../../../.." && pwd)" +# Walk up from SCRIPT_DIR looking for go.mod. This avoids brittle "../../../.." +# counting when the skills directory is reached via a symlink (which changes +# the symbolic vs. physical depth). +WORKSPACE="$SCRIPT_DIR" +while [[ "$WORKSPACE" != "/" && ! -f "$WORKSPACE/go.mod" ]]; do + WORKSPACE="$(dirname "$WORKSPACE")" +done if [[ ! -f "$WORKSPACE/go.mod" ]]; then - echo "Error: go.mod not found in WORKSPACE: $WORKSPACE" >&2 + echo "Error: go.mod not found walking up from: $SCRIPT_DIR" >&2 exit 1 fi @@ -38,7 +43,12 @@ PYTHON_BIN="${PYTHON_BIN:-python3}" SOURCE_FLV="$WORKSPACE/trunk/doc/source.flv" SRS_BINARY="$WORKSPACE/trunk/objs/srs" -TEST_STREAM_URL="__defaultVhost__/live/livestream" +# Randomize per run so each invocation uses unique Redis keys and never shares +# state with sibling E2E tests or a developer's local proxy that publishes to +# "live/livestream". +STREAM_NAME="redis$(date +%s)" +STREAM_PATH="live/$STREAM_NAME" +TEST_STREAM_URL="__defaultVhost__/$STREAM_PATH" # PIDs to clean up on exit. PROXY_A_PID="" @@ -279,7 +289,7 @@ echo "SRS origin started and registered in Redis." # --- Step 6: Publish RTMP stream to proxy A --- echo "=== Step 6: Publishing RTMP stream to proxy A ===" ffmpeg -stream_loop -1 -re -i "$SOURCE_FLV" -c copy -f flv \ - "rtmp://localhost:$PROXY_A_RTMP_PORT/live/livestream" >/tmp/srs-ffmpeg-redis-e2e.log 2>&1 & + "rtmp://localhost:$PROXY_A_RTMP_PORT/$STREAM_PATH" >/tmp/srs-ffmpeg-redis-e2e.log 2>&1 & FFMPEG_PID=$! echo "FFmpeg publisher PID: $FFMPEG_PID" @@ -296,7 +306,7 @@ echo "Stream publishing through proxy A." # --- Step 7: Verify RTMP playback through proxy B --- echo "=== Step 7: Verifying RTMP playback through proxy B ===" PROBE_OUTPUT=$(ffprobe -v error -show_streams \ - "rtmp://localhost:$PROXY_B_RTMP_PORT/live/livestream" 2>&1 || true) + "rtmp://localhost:$PROXY_B_RTMP_PORT/$STREAM_PATH" 2>&1 || true) if echo "$PROBE_OUTPUT" | grep -q "codec_type=video"; then echo "PASS: Video stream detected through proxy B." diff --git a/skills/srs-develop/scripts/proxy-e2e-srt-test.sh b/skills/srs-develop/scripts/proxy-e2e-srt-test.sh new file mode 100755 index 000000000..38832e746 --- /dev/null +++ b/skills/srs-develop/scripts/proxy-e2e-srt-test.sh @@ -0,0 +1,295 @@ +#!/bin/bash +# E2E test for SRT proxy: starts proxy + SRS origin, publishes an SRT stream +# through the proxy, then verifies playback through the proxy in every form +# the origin can transmux from SRT (srt_to_rtmp + rtmp_to_rtc): +# - SRT play (passthrough) +# - RTMP play (via srt_to_rtmp on origin) +# - HTTP-FLV (HTTP remux of the bridged RTMP) +# - HLS (m3u8 + TS segments) +# - WebRTC WHEP (placeholder only, not actually verified here) +set -e + +SCRIPT_DIR="$(cd -P "$(dirname "$0")" && pwd)" +# Walk up from SCRIPT_DIR looking for go.mod. This avoids brittle "../../../.." +# counting when the skills directory is reached via a symlink (which changes +# the symbolic vs. physical depth). +WORKSPACE="$SCRIPT_DIR" +while [[ "$WORKSPACE" != "/" && ! -f "$WORKSPACE/go.mod" ]]; do + WORKSPACE="$(dirname "$WORKSPACE")" +done + +if [[ ! -f "$WORKSPACE/go.mod" ]]; then + echo "Error: go.mod not found walking up from: $SCRIPT_DIR" >&2 + exit 1 +fi + +# Proxy ports — same layout as proxy-e2e-test.sh / proxy-e2e-transmux-test.sh. +PROXY_RTMP_PORT=11935 +PROXY_HTTP_API_PORT=11985 +PROXY_HTTP_SERVER_PORT=18080 +PROXY_WEBRTC_PORT=18000 +PROXY_SRT_PORT=20080 +PROXY_SYSTEM_API_PORT=12025 + +# Origin ports (from origin1-for-proxy.conf). +ORIGIN_RTMP_PORT=19351 +ORIGIN_HTTP_PORT=8081 +ORIGIN_API_PORT=19851 +ORIGIN_RTC_PORT=8001 +ORIGIN_SRT_PORT=10081 + +SOURCE_FLV="$WORKSPACE/trunk/doc/source.flv" +SRS_BINARY="$WORKSPACE/trunk/objs/srs" +# Randomize per run so each invocation starts from clean origin state (HLS +# segments, RTMP source, proxy stream registry) and never shares state with +# sibling E2E tests that publish to "live/livestream". +STREAM_URL="live/srt$(date +%s)" + +# SRT streamid format used by SRS: "#!::r=/,m=publish|request". +# @see trunk/3rdparty/srs-docs/doc/srt.md and internal/proxy/srt.go. +SRT_PUBLISH_URL="srt://localhost:$PROXY_SRT_PORT?streamid=#!::r=$STREAM_URL,m=publish" +SRT_PLAY_URL="srt://localhost:$PROXY_SRT_PORT?streamid=#!::r=$STREAM_URL,m=request" + +# PIDs to clean up on exit. +PROXY_PID="" +ORIGIN_PID="" +FFMPEG_PID="" + +cleanup() { + echo "" + echo "=== Cleaning up ===" + for pid in $PROXY_PID $ORIGIN_PID $FFMPEG_PID; do + if [[ -n "$pid" ]] && kill -0 "$pid" 2>/dev/null; then + kill "$pid" 2>/dev/null || true + fi + done + sleep 1 + for pid in $PROXY_PID $ORIGIN_PID $FFMPEG_PID; do + if [[ -n "$pid" ]] && kill -0 "$pid" 2>/dev/null; then + kill -9 "$pid" 2>/dev/null || true + fi + done + echo "Cleanup done." +} +trap cleanup EXIT + +probe_has_audio_video() { + local name="$1" + local url="$2" + + echo "Verifying $name playback: $url" + local output + output=$("$FFPROBE_BIN" -v error -show_streams "$url" 2>&1 || true) + + if echo "$output" | grep -q "codec_type=video"; then + echo "PASS: $name video stream detected." + else + echo "FAIL: $name no video stream detected." >&2 + echo "ffprobe output:" >&2 + echo "$output" >&2 + exit 1 + fi + + if echo "$output" | grep -q "codec_type=audio"; then + echo "PASS: $name audio stream detected." + else + echo "FAIL: $name no audio stream detected." >&2 + echo "ffprobe output:" >&2 + echo "$output" >&2 + exit 1 + fi +} + +wait_for_hls_playlist() { + local url="$1" + local deadline=45 + + echo "Waiting for HLS playlist to be generated (up to ${deadline}s): $url" + for ((i = 1; i <= deadline; i++)); do + if curl -fsS "$url" 2>/dev/null | grep -q "#EXTM3U"; then + echo "HLS playlist is ready." + return + fi + sleep 1 + done + + echo "FAIL: HLS playlist was not generated in ${deadline}s." >&2 + echo "Last HLS response:" >&2 + curl -v "$url" 2>&1 || true + exit 1 +} + +echo "=== E2E SRT Proxy Test ===" +echo "Workspace: $WORKSPACE" +echo "Stream: $STREAM_URL" +echo "" + +# --- Pre-checks --- +if [[ ! -f "$SOURCE_FLV" ]]; then + echo "Error: test source not found: $SOURCE_FLV" >&2 + exit 1 +fi +if ! command -v curl &>/dev/null; then + echo "Error: curl not found in PATH" >&2 + exit 1 +fi + +# SRT URLs need libsrt compiled into ffmpeg/ffprobe. The default Homebrew +# ffmpeg formula does NOT include libsrt. Resolution order: +# 1. Use ffmpeg/ffprobe from PATH if they support SRT. +# 2. Otherwise, use ~/.local/bin/ffmpeg/ffprobe if previously built there. +# 3. Otherwise, build from source via setup-ffmpeg-with-whip.sh (installs +# into ~/.local/) and use the freshly built binaries. +ffmpeg_has_srt() { + local bin="$1" + [[ -x "$bin" ]] && "$bin" -hide_banner -protocols 2>/dev/null | grep -qw srt +} + +resolve_ffmpeg() { + local sys_ffmpeg sys_ffprobe local_ffmpeg local_ffprobe + sys_ffmpeg="$(command -v ffmpeg || true)" + sys_ffprobe="$(command -v ffprobe || true)" + local_ffmpeg="$HOME/.local/bin/ffmpeg" + local_ffprobe="$HOME/.local/bin/ffprobe" + + if [[ -n "$sys_ffprobe" ]] && ffmpeg_has_srt "$sys_ffmpeg"; then + FFMPEG_BIN="$sys_ffmpeg" + FFPROBE_BIN="$sys_ffprobe" + return 0 + fi + if [[ -x "$local_ffprobe" ]] && ffmpeg_has_srt "$local_ffmpeg"; then + FFMPEG_BIN="$local_ffmpeg" + FFPROBE_BIN="$local_ffprobe" + return 0 + fi + return 1 +} + +if ! resolve_ffmpeg; then + echo "No ffmpeg with SRT support found on PATH or in ~/.local/bin." + echo "Building ffmpeg from source via setup-ffmpeg-with-whip.sh — this can take several minutes." + bash "$SCRIPT_DIR/setup-ffmpeg-with-whip.sh" + FFMPEG_BIN="$HOME/.local/bin/ffmpeg" + FFPROBE_BIN="$HOME/.local/bin/ffprobe" + if ! ffmpeg_has_srt "$FFMPEG_BIN"; then + echo "Error: ffmpeg still lacks SRT support after running setup-ffmpeg-with-whip.sh." >&2 + exit 1 + fi + if [[ ! -x "$FFPROBE_BIN" ]]; then + echo "Error: ffprobe missing at $FFPROBE_BIN after running setup-ffmpeg-with-whip.sh." >&2 + exit 1 + fi +fi +echo "ffmpeg : $FFMPEG_BIN" +echo "ffprobe: $FFPROBE_BIN" + +# --- Step 0: Clean up stale state --- +rm -f "$WORKSPACE/trunk/objs/origin1.pid" +ALL_PORTS="$PROXY_RTMP_PORT $PROXY_HTTP_API_PORT $PROXY_HTTP_SERVER_PORT $PROXY_WEBRTC_PORT $PROXY_SRT_PORT $PROXY_SYSTEM_API_PORT $ORIGIN_RTMP_PORT $ORIGIN_HTTP_PORT $ORIGIN_API_PORT $ORIGIN_RTC_PORT $ORIGIN_SRT_PORT" +for port in $ALL_PORTS; do + lsof -ti :"$port" 2>/dev/null | xargs kill 2>/dev/null || true +done +sleep 1 + +# --- Step 1: Build proxy --- +echo "=== Step 1: Building proxy ===" +cd "$WORKSPACE" +make -s 2>&1 +echo "Proxy built: $WORKSPACE/bin/srs-proxy" + +# --- Step 2: Build SRS origin (if not already built) --- +if [[ ! -f "$SRS_BINARY" ]]; then + echo "=== Step 2: Building SRS origin ===" + cd "$WORKSPACE/trunk" + ./configure && make 2>&1 | tail -3 + echo "SRS origin built: $SRS_BINARY" +else + echo "=== Step 2: SRS origin already built ===" +fi + +# --- Step 3: Start proxy --- +echo "=== Step 3: Starting proxy (SRT :$PROXY_SRT_PORT, System API :$PROXY_SYSTEM_API_PORT) ===" +cd "$WORKSPACE" +env PROXY_RTMP_SERVER=$PROXY_RTMP_PORT \ + PROXY_HTTP_API=$PROXY_HTTP_API_PORT \ + PROXY_HTTP_SERVER=$PROXY_HTTP_SERVER_PORT \ + PROXY_WEBRTC_SERVER=$PROXY_WEBRTC_PORT \ + PROXY_SRT_SERVER=$PROXY_SRT_PORT \ + PROXY_SYSTEM_API=$PROXY_SYSTEM_API_PORT \ + PROXY_LOAD_BALANCER_TYPE=memory \ + ./bin/srs-proxy >/tmp/srs-proxy-srt-e2e.log 2>&1 & +PROXY_PID=$! +echo "Proxy PID: $PROXY_PID" +sleep 1 + +if ! kill -0 "$PROXY_PID" 2>/dev/null; then + echo "Error: proxy failed to start. Logs:" >&2 + cat /tmp/srs-proxy-srt-e2e.log >&2 + exit 1 +fi +echo "Proxy started." + +# --- Step 4: Start SRS origin --- +echo "=== Step 4: Starting SRS origin ===" +ulimit -n 10000 2>/dev/null || true +cd "$WORKSPACE/trunk" +./objs/srs -c conf/origin1-for-proxy.conf >/tmp/srs-origin-srt-e2e.log 2>&1 & +ORIGIN_PID=$! +echo "SRS origin PID: $ORIGIN_PID" + +# Wait for SRS to start and register with proxy (heartbeat interval is 9s). +echo "Waiting for SRS origin to register with proxy (up to 15s)..." +sleep 12 + +if ! kill -0 "$ORIGIN_PID" 2>/dev/null; then + echo "Error: SRS origin failed to start. Logs:" >&2 + cat /tmp/srs-origin-srt-e2e.log >&2 + exit 1 +fi +echo "SRS origin started and registered." + +# --- Step 5: Publish SRT stream --- +echo "=== Step 5: Publishing SRT stream to proxy ===" +echo "Publish URL: $SRT_PUBLISH_URL" +"$FFMPEG_BIN" -stream_loop -1 -re -i "$SOURCE_FLV" -c copy -f mpegts \ + "$SRT_PUBLISH_URL" >/tmp/srs-ffmpeg-srt-e2e.log 2>&1 & +FFMPEG_PID=$! +echo "FFmpeg publisher PID: $FFMPEG_PID" + +# Wait for the SRT handshake, the proxy<->backend bridge, and the +# origin's srt_to_rtmp pipeline to spin up the derived streams. +sleep 5 + +if ! kill -0 "$FFMPEG_PID" 2>/dev/null; then + echo "Error: FFmpeg publisher failed. Logs:" >&2 + cat /tmp/srs-ffmpeg-srt-e2e.log >&2 + exit 1 +fi +echo "Stream publishing." + +# --- Step 6: Verify SRT playback (passthrough) --- +echo "=== Step 6: Verifying SRT playback via proxy ===" +probe_has_audio_video "SRT" "$SRT_PLAY_URL" + +# --- Step 7: Verify RTMP playback (srt_to_rtmp) --- +echo "=== Step 7: Verifying RTMP playback via proxy ===" +probe_has_audio_video "RTMP" "rtmp://localhost:$PROXY_RTMP_PORT/$STREAM_URL" + +# --- Step 8: Verify HTTP-FLV playback --- +echo "=== Step 8: Verifying HTTP-FLV playback via proxy ===" +probe_has_audio_video "HTTP-FLV" "http://localhost:$PROXY_HTTP_SERVER_PORT/$STREAM_URL.flv" + +# --- Step 9: Verify HLS playback --- +echo "=== Step 9: Verifying HLS playback via proxy ===" +HLS_URL="http://localhost:$PROXY_HTTP_SERVER_PORT/$STREAM_URL.m3u8" +wait_for_hls_playlist "$HLS_URL" +probe_has_audio_video "HLS" "$HLS_URL" + +# --- Step 10: WebRTC WHEP playback (placeholder) --- +echo "=== Step 10: WebRTC WHEP playback (placeholder) ===" +echo "SKIP: WebRTC WHEP playback is not verified by this script." +echo " The origin has rtmp_to_rtc enabled, so SRT->RTMP->RTC should work end-to-end," +echo " but actual playback verification is intentionally left as a TODO here." + +echo "" +echo "=== E2E SRT Proxy Test PASSED ===" diff --git a/.openclaw/skills/srs-develop/scripts/proxy-e2e-test.sh b/skills/srs-develop/scripts/proxy-e2e-test.sh similarity index 86% rename from .openclaw/skills/srs-develop/scripts/proxy-e2e-test.sh rename to skills/srs-develop/scripts/proxy-e2e-test.sh index 093a731e5..294ca0cf8 100755 --- a/.openclaw/skills/srs-develop/scripts/proxy-e2e-test.sh +++ b/skills/srs-develop/scripts/proxy-e2e-test.sh @@ -4,11 +4,16 @@ set -e SCRIPT_DIR="$(cd -P "$(dirname "$0")" && pwd)" -# Navigate: scripts/ -> srs-develop/ -> skills/ -> .openclaw/ -> srs -WORKSPACE="$(cd -P "$SCRIPT_DIR/../../../.." && pwd)" +# Walk up from SCRIPT_DIR looking for go.mod. This avoids brittle "../../../.." +# counting when the skills directory is reached via a symlink (which changes +# the symbolic vs. physical depth). +WORKSPACE="$SCRIPT_DIR" +while [[ "$WORKSPACE" != "/" && ! -f "$WORKSPACE/go.mod" ]]; do + WORKSPACE="$(dirname "$WORKSPACE")" +done if [[ ! -f "$WORKSPACE/go.mod" ]]; then - echo "Error: go.mod not found in WORKSPACE: $WORKSPACE" >&2 + echo "Error: go.mod not found walking up from: $SCRIPT_DIR" >&2 exit 1 fi @@ -24,6 +29,10 @@ PROXY_SYSTEM_API_PORT=12025 SOURCE_FLV="$WORKSPACE/trunk/doc/source.flv" SRS_BINARY="$WORKSPACE/trunk/objs/srs" ORIGIN_CONF="$WORKSPACE/trunk/conf/origin1-for-proxy.conf" +# Randomize per run so each invocation starts from clean origin state (HLS +# segments, RTMP source, proxy stream registry) and never shares state with +# sibling E2E tests that publish to "live/livestream". +STREAM_URL="live/rtmp$(date +%s)" # PIDs to clean up on exit. PROXY_PID="" @@ -143,7 +152,7 @@ echo "SRS origin started and registered." # --- Step 5: Publish RTMP stream --- echo "=== Step 5: Publishing RTMP stream to proxy ===" ffmpeg -stream_loop -1 -re -i "$SOURCE_FLV" -c copy -f flv \ - "rtmp://localhost:$PROXY_RTMP_PORT/live/livestream" >/tmp/srs-ffmpeg-e2e.log 2>&1 & + "rtmp://localhost:$PROXY_RTMP_PORT/$STREAM_URL" >/tmp/srs-ffmpeg-e2e.log 2>&1 & FFMPEG_PID=$! echo "FFmpeg publisher PID: $FFMPEG_PID" @@ -160,7 +169,7 @@ echo "Stream publishing." # --- Step 6: Verify RTMP playback --- echo "=== Step 6: Verifying RTMP playback via proxy ===" PROBE_OUTPUT=$(ffprobe -v error -show_streams \ - "rtmp://localhost:$PROXY_RTMP_PORT/live/livestream" 2>&1 || true) + "rtmp://localhost:$PROXY_RTMP_PORT/$STREAM_URL" 2>&1 || true) if echo "$PROBE_OUTPUT" | grep -q "codec_type=video"; then echo "PASS: Video stream detected." diff --git a/.openclaw/skills/srs-develop/scripts/proxy-e2e-transmux-test.sh b/skills/srs-develop/scripts/proxy-e2e-transmux-test.sh similarity index 72% rename from .openclaw/skills/srs-develop/scripts/proxy-e2e-transmux-test.sh rename to skills/srs-develop/scripts/proxy-e2e-transmux-test.sh index 523284e98..df38e4489 100755 --- a/.openclaw/skills/srs-develop/scripts/proxy-e2e-transmux-test.sh +++ b/skills/srs-develop/scripts/proxy-e2e-transmux-test.sh @@ -1,16 +1,21 @@ #!/bin/bash # E2E test for RTMP-to-multiple-protocol transmuxing through the proxy: # starts one proxy with memory load balancer + one SRS origin, publishes one -# RTMP stream, then verifies RTMP, HTTP-FLV, HLS, and optional WebRTC WHEP -# playback through the proxy. +# RTMP stream, then verifies RTMP, HTTP-FLV, and HLS playback through the +# proxy. WebRTC WHEP verification is intentionally a placeholder (not run). set -e SCRIPT_DIR="$(cd -P "$(dirname "$0")" && pwd)" -# Navigate: scripts/ -> srs-develop/ -> skills/ -> .openclaw/ -> srs -WORKSPACE="$(cd -P "$SCRIPT_DIR/../../../.." && pwd)" +# Walk up from SCRIPT_DIR looking for go.mod. This avoids brittle "../../../.." +# counting when the skills directory is reached via a symlink (which changes +# the symbolic vs. physical depth). +WORKSPACE="$SCRIPT_DIR" +while [[ "$WORKSPACE" != "/" && ! -f "$WORKSPACE/go.mod" ]]; do + WORKSPACE="$(dirname "$WORKSPACE")" +done if [[ ! -f "$WORKSPACE/go.mod" ]]; then - echo "Error: go.mod not found in WORKSPACE: $WORKSPACE" >&2 + echo "Error: go.mod not found walking up from: $SCRIPT_DIR" >&2 exit 1 fi @@ -32,12 +37,10 @@ ORIGIN_SRT_PORT=10081 SOURCE_FLV="$WORKSPACE/trunk/doc/source.flv" SRS_BINARY="$WORKSPACE/trunk/objs/srs" -SRS_TEST_BINARY="$WORKSPACE/trunk/3rdparty/srs-bench/objs/srs_test" -STREAM_URL="live/livestream" - -# WebRTC requires the srs-bench regression test binary. Keep it enabled by -# default because it exercises the proxy WHEP API path; set to "off" to skip. -PROXY_TRANSMUX_TEST_RTC="${PROXY_TRANSMUX_TEST_RTC:-on}" +# Randomize per run so each invocation starts from clean origin state (HLS +# segments, RTMP source, proxy stream registry) and never shares state with +# sibling E2E tests that publish to "live/livestream". +STREAM_URL="live/transmux$(date +%s)" # PIDs to clean up on exit. PROXY_PID="" @@ -155,18 +158,8 @@ else echo "=== Step 2: SRS origin already built ===" fi -# --- Step 3: Build WebRTC regression tool (if enabled and needed) --- -if [[ "$PROXY_TRANSMUX_TEST_RTC" == "on" && ! -x "$SRS_TEST_BINARY" ]]; then - echo "=== Step 3: Building WebRTC regression tool ===" - cd "$WORKSPACE/trunk/3rdparty/srs-bench" - make ./objs/srs_test - echo "WebRTC regression tool built: $SRS_TEST_BINARY" -else - echo "=== Step 3: WebRTC regression tool build skipped ===" -fi - -# --- Step 4: Start proxy --- -echo "=== Step 4: Starting proxy (memory LB) ===" +# --- Step 3: Start proxy --- +echo "=== Step 3: Starting proxy (memory LB) ===" cd "$WORKSPACE" env PROXY_RTMP_SERVER=$PROXY_RTMP_PORT \ PROXY_HTTP_API=$PROXY_HTTP_API_PORT \ @@ -187,8 +180,8 @@ if ! kill -0 "$PROXY_PID" 2>/dev/null; then fi echo "Proxy started." -# --- Step 5: Start SRS origin --- -echo "=== Step 5: Starting SRS origin ===" +# --- Step 4: Start SRS origin --- +echo "=== Step 4: Starting SRS origin ===" ulimit -n 10000 2>/dev/null || true cd "$WORKSPACE/trunk" ./objs/srs -c conf/origin1-for-proxy.conf >/tmp/srs-origin-transmux-e2e.log 2>&1 & @@ -206,8 +199,8 @@ if ! kill -0 "$ORIGIN_PID" 2>/dev/null; then fi echo "SRS origin started and registered." -# --- Step 6: Publish RTMP stream --- -echo "=== Step 6: Publishing RTMP stream to proxy ===" +# --- Step 5: Publish RTMP stream --- +echo "=== Step 5: Publishing RTMP stream to proxy ===" ffmpeg -stream_loop -1 -re -i "$SOURCE_FLV" -c copy -f flv \ "rtmp://localhost:$PROXY_RTMP_PORT/$STREAM_URL" >/tmp/srs-ffmpeg-transmux-e2e.log 2>&1 & FFMPEG_PID=$! @@ -223,38 +216,25 @@ if ! kill -0 "$FFMPEG_PID" 2>/dev/null; then fi echo "Stream publishing." -# --- Step 7: Verify RTMP playback --- -echo "=== Step 7: Verifying RTMP playback via proxy ===" +# --- Step 6: Verify RTMP playback --- +echo "=== Step 6: Verifying RTMP playback via proxy ===" probe_has_audio_video "RTMP" "rtmp://localhost:$PROXY_RTMP_PORT/$STREAM_URL" -# --- Step 8: Verify HTTP-FLV playback --- -echo "=== Step 8: Verifying HTTP-FLV playback via proxy ===" +# --- Step 7: Verify HTTP-FLV playback --- +echo "=== Step 7: Verifying HTTP-FLV playback via proxy ===" probe_has_audio_video "HTTP-FLV" "http://localhost:$PROXY_HTTP_SERVER_PORT/$STREAM_URL.flv" -# --- Step 9: Verify HLS playback --- -echo "=== Step 9: Verifying HLS playback via proxy ===" +# --- Step 8: Verify HLS playback --- +echo "=== Step 8: Verifying HLS playback via proxy ===" HLS_URL="http://localhost:$PROXY_HTTP_SERVER_PORT/$STREAM_URL.m3u8" wait_for_hls_playlist "$HLS_URL" probe_has_audio_video "HLS" "$HLS_URL" -# --- Step 10: Verify WebRTC WHEP signaling via proxy --- -echo "=== Step 10: Verifying WebRTC WHEP signaling via proxy ===" -if [[ "$PROXY_TRANSMUX_TEST_RTC" == "on" ]]; then - if [[ ! -x "$SRS_TEST_BINARY" ]]; then - echo "FAIL: WebRTC regression tool not found: $SRS_TEST_BINARY" >&2 - exit 1 - fi - - cd "$WORKSPACE/trunk/3rdparty/srs-bench" - "$SRS_TEST_BINARY" \ - -test.run '^TestBugfix2371_RTMP2RTC_PlayWithNack$' \ - -srs-server "127.0.0.1:$PROXY_HTTP_API_PORT" \ - -srs-stream "/$STREAM_URL" \ - -srs-timeout 10000 - echo "PASS: WebRTC WHEP signaling succeeded." -else - echo "SKIP: WebRTC WHEP test disabled by PROXY_TRANSMUX_TEST_RTC=$PROXY_TRANSMUX_TEST_RTC." -fi +# --- Step 9: WebRTC WHEP playback (placeholder) --- +echo "=== Step 9: WebRTC WHEP playback (placeholder) ===" +echo "SKIP: WebRTC WHEP playback is not verified by this script." +echo " The origin has rtmp_to_rtc enabled, so RTMP->RTC should work end-to-end," +echo " but actual playback verification is intentionally left as a TODO here." echo "" echo "NOTE: RTSP is not tested here because the Go proxy currently has no RTSP listener." diff --git a/skills/srs-develop/scripts/proxy-e2e-whip-test.sh b/skills/srs-develop/scripts/proxy-e2e-whip-test.sh new file mode 100755 index 000000000..f0961c641 --- /dev/null +++ b/skills/srs-develop/scripts/proxy-e2e-whip-test.sh @@ -0,0 +1,350 @@ +#!/bin/bash +# E2E test for WHIP proxy: starts proxy + SRS origin, publishes a WebRTC +# stream via WHIP through the proxy, then verifies playback through the proxy +# in every form the origin can transmux from WebRTC (rtc_to_rtmp + http_remux +# + hls): +# - RTMP play (via rtc_to_rtmp on origin) +# - HTTP-FLV (HTTP remux of the bridged RTMP) +# - HLS (m3u8 + TS segments) +# - WebRTC WHEP (placeholder only, not actually verified here) +set -e + +SCRIPT_DIR="$(cd -P "$(dirname "$0")" && pwd)" +# Walk up from SCRIPT_DIR looking for go.mod. This avoids brittle "../../../.." +# counting when the skills directory is reached via a symlink (which changes +# the symbolic vs. physical depth). +WORKSPACE="$SCRIPT_DIR" +while [[ "$WORKSPACE" != "/" && ! -f "$WORKSPACE/go.mod" ]]; do + WORKSPACE="$(dirname "$WORKSPACE")" +done + +if [[ ! -f "$WORKSPACE/go.mod" ]]; then + echo "Error: go.mod not found walking up from: $SCRIPT_DIR" >&2 + exit 1 +fi + +# Proxy ports — same layout as proxy-e2e-srt-test.sh / proxy-e2e-transmux-test.sh. +PROXY_RTMP_PORT=11935 +PROXY_HTTP_API_PORT=11985 +PROXY_HTTP_SERVER_PORT=18080 +PROXY_WEBRTC_PORT=18000 +PROXY_SRT_PORT=20080 +PROXY_SYSTEM_API_PORT=12025 + +# Origin ports (from origin1-for-proxy.conf). +ORIGIN_RTMP_PORT=19351 +ORIGIN_HTTP_PORT=8081 +ORIGIN_API_PORT=19851 +ORIGIN_RTC_PORT=8001 +ORIGIN_SRT_PORT=10081 + +SOURCE_FLV="$WORKSPACE/trunk/doc/source.flv" +SRS_BINARY="$WORKSPACE/trunk/objs/srs" +# Randomize the stream name per run so each test starts from a clean origin +# state (HLS segments, RTMP source, proxy stream registry) and never shares +# state with sibling E2E tests that publish to "live/livestream". +STREAM_NAME="whip$(date +%s)" +STREAM_URL="live/$STREAM_NAME" + +# WHIP endpoint exposed by the proxy. The proxy parses ?app=&stream= via +# utils.ConvertURLToStreamURL, then forwards the SDP exchange to the backend +# SRS origin. @see internal/proxy/api.go and internal/proxy/rtc.go. +WHIP_PUBLISH_URL="http://localhost:$PROXY_HTTP_API_PORT/rtc/v1/whip/?app=live&stream=$STREAM_NAME" + +# Make the SRS origin advertise a host candidate that loops back through the +# proxy. The proxy rewrites only the port in the SDP answer (origin RTC port +# -> proxy WebRTC port), so the candidate IP must already be reachable for +# the publisher; 127.0.0.1 works for all-local E2E. +ORIGIN_CANDIDATE="127.0.0.1" + +# PIDs to clean up on exit. +PROXY_PID="" +ORIGIN_PID="" +FFMPEG_PID="" + +cleanup() { + echo "" + echo "=== Cleaning up ===" + for pid in $PROXY_PID $ORIGIN_PID $FFMPEG_PID; do + if [[ -n "$pid" ]] && kill -0 "$pid" 2>/dev/null; then + kill "$pid" 2>/dev/null || true + fi + done + sleep 1 + for pid in $PROXY_PID $ORIGIN_PID $FFMPEG_PID; do + if [[ -n "$pid" ]] && kill -0 "$pid" 2>/dev/null; then + kill -9 "$pid" 2>/dev/null || true + fi + done + echo "Cleanup done." +} +trap cleanup EXIT + +probe_has_audio_video() { + local name="$1" + local url="$2" + + echo "Verifying $name playback: $url" + local output + output=$("$FFPROBE_BIN" -v error -show_streams "$url" 2>&1 || true) + + if echo "$output" | grep -q "codec_type=video"; then + echo "PASS: $name video stream detected." + else + echo "FAIL: $name no video stream detected." >&2 + echo "ffprobe output:" >&2 + echo "$output" >&2 + exit 1 + fi + + if echo "$output" | grep -q "codec_type=audio"; then + echo "PASS: $name audio stream detected." + else + echo "FAIL: $name no audio stream detected." >&2 + echo "ffprobe output:" >&2 + echo "$output" >&2 + exit 1 + fi +} + +wait_for_hls_playlist() { + local url="$1" + local deadline=60 + + echo "Waiting for HLS playlist to be generated (up to ${deadline}s): $url" + for ((i = 1; i <= deadline; i++)); do + if curl -fsS "$url" 2>/dev/null | grep -q "#EXTM3U"; then + echo "HLS playlist is ready." + return + fi + sleep 1 + done + + echo "FAIL: HLS playlist was not generated in ${deadline}s." >&2 + echo "Last HLS response:" >&2 + curl -v "$url" 2>&1 || true + exit 1 +} + +first_hls_segment() { + local url="$1" + + curl -fsS "$url" 2>/dev/null | awk ' + /^[[:space:]]*$/ { next } + /^#/ { next } + { print; exit } + ' +} + +wait_for_hls_to_skip_first_segment() { + local url="$1" + local deadline=60 + local first_segment current_segment output + + first_segment="$(first_hls_segment "$url")" + if [[ -z "$first_segment" ]]; then + echo "FAIL: HLS playlist has no media segment: $url" >&2 + curl -fsS "$url" 2>&1 || true + exit 1 + fi + + echo "Waiting for HLS to skip the first possibly incomplete segment (up to ${deadline}s): $first_segment" + for ((i = 1; i <= deadline; i++)); do + current_segment="$(first_hls_segment "$url")" + if [[ -n "$current_segment" && "$current_segment" != "$first_segment" ]]; then + output=$("$FFPROBE_BIN" -v error -show_streams "$url" 2>&1 || true) + if echo "$output" | grep -q "codec_type=video" && echo "$output" | grep -q "codec_type=audio"; then + echo "HLS first segment advanced and audio/video is ready: $current_segment" + return + fi + fi + sleep 1 + done + + echo "FAIL: HLS did not skip the first segment and expose audio/video in ${deadline}s." >&2 + echo "Last HLS response:" >&2 + curl -fsS "$url" 2>&1 || true + echo "Last ffprobe output:" >&2 + echo "$output" >&2 + exit 1 +} + +echo "=== E2E WHIP Proxy Test ===" +echo "Workspace: $WORKSPACE" +echo "Stream: $STREAM_URL" +echo "" + +# --- Pre-checks --- +if [[ ! -f "$SOURCE_FLV" ]]; then + echo "Error: test source not found: $SOURCE_FLV" >&2 + exit 1 +fi +if ! command -v curl &>/dev/null; then + echo "Error: curl not found in PATH" >&2 + exit 1 +fi + +# WHIP needs an ffmpeg with the `whip` muxer (added in ffmpeg 7.1, requires +# --enable-openssl at build time for DTLS-SRTP). Neither vanilla brew nor the +# homebrew-ffmpeg tap enable it. Resolution order: +# 1. Use ffmpeg/ffprobe from PATH if they include the whip muxer. +# 2. Otherwise, use ~/.local/bin/ffmpeg/ffprobe if previously built there. +# 3. Otherwise, build from source via setup-ffmpeg-with-whip.sh (installs +# into ~/.local/) and use the freshly built binaries. +ffmpeg_has_whip() { + local bin="$1" + [[ -x "$bin" ]] && "$bin" -hide_banner -muxers 2>/dev/null | grep -qw whip +} + +resolve_ffmpeg() { + local sys_ffmpeg sys_ffprobe local_ffmpeg local_ffprobe + sys_ffmpeg="$(command -v ffmpeg || true)" + sys_ffprobe="$(command -v ffprobe || true)" + local_ffmpeg="$HOME/.local/bin/ffmpeg" + local_ffprobe="$HOME/.local/bin/ffprobe" + + if [[ -n "$sys_ffprobe" ]] && ffmpeg_has_whip "$sys_ffmpeg"; then + FFMPEG_BIN="$sys_ffmpeg" + FFPROBE_BIN="$sys_ffprobe" + return 0 + fi + if [[ -x "$local_ffprobe" ]] && ffmpeg_has_whip "$local_ffmpeg"; then + FFMPEG_BIN="$local_ffmpeg" + FFPROBE_BIN="$local_ffprobe" + return 0 + fi + return 1 +} + +if ! resolve_ffmpeg; then + echo "No ffmpeg with WHIP muxer found on PATH or in ~/.local/bin." + echo "Building ffmpeg from source via setup-ffmpeg-with-whip.sh — this can take several minutes." + bash "$SCRIPT_DIR/setup-ffmpeg-with-whip.sh" + FFMPEG_BIN="$HOME/.local/bin/ffmpeg" + FFPROBE_BIN="$HOME/.local/bin/ffprobe" + if ! ffmpeg_has_whip "$FFMPEG_BIN"; then + echo "Error: ffmpeg still lacks WHIP muxer after running setup-ffmpeg-with-whip.sh." >&2 + exit 1 + fi + if [[ ! -x "$FFPROBE_BIN" ]]; then + echo "Error: ffprobe missing at $FFPROBE_BIN after running setup-ffmpeg-with-whip.sh." >&2 + exit 1 + fi +fi +echo "ffmpeg : $FFMPEG_BIN" +echo "ffprobe: $FFPROBE_BIN" + +# --- Step 0: Clean up stale state --- +rm -f "$WORKSPACE/trunk/objs/origin1.pid" +ALL_PORTS="$PROXY_RTMP_PORT $PROXY_HTTP_API_PORT $PROXY_HTTP_SERVER_PORT $PROXY_WEBRTC_PORT $PROXY_SRT_PORT $PROXY_SYSTEM_API_PORT $ORIGIN_RTMP_PORT $ORIGIN_HTTP_PORT $ORIGIN_API_PORT $ORIGIN_RTC_PORT $ORIGIN_SRT_PORT" +for port in $ALL_PORTS; do + lsof -ti :"$port" 2>/dev/null | xargs kill 2>/dev/null || true +done +sleep 1 + +# --- Step 1: Build proxy --- +echo "=== Step 1: Building proxy ===" +cd "$WORKSPACE" +make -s 2>&1 +echo "Proxy built: $WORKSPACE/bin/srs-proxy" + +# --- Step 2: Build SRS origin (if not already built) --- +if [[ ! -f "$SRS_BINARY" ]]; then + echo "=== Step 2: Building SRS origin ===" + cd "$WORKSPACE/trunk" + ./configure && make 2>&1 | tail -3 + echo "SRS origin built: $SRS_BINARY" +else + echo "=== Step 2: SRS origin already built ===" +fi + +# --- Step 3: Start proxy --- +echo "=== Step 3: Starting proxy (HTTP API :$PROXY_HTTP_API_PORT, WebRTC :$PROXY_WEBRTC_PORT) ===" +cd "$WORKSPACE" +env PROXY_RTMP_SERVER=$PROXY_RTMP_PORT \ + PROXY_HTTP_API=$PROXY_HTTP_API_PORT \ + PROXY_HTTP_SERVER=$PROXY_HTTP_SERVER_PORT \ + PROXY_WEBRTC_SERVER=$PROXY_WEBRTC_PORT \ + PROXY_SRT_SERVER=$PROXY_SRT_PORT \ + PROXY_SYSTEM_API=$PROXY_SYSTEM_API_PORT \ + PROXY_LOAD_BALANCER_TYPE=memory \ + ./bin/srs-proxy >/tmp/srs-proxy-whip-e2e.log 2>&1 & +PROXY_PID=$! +echo "Proxy PID: $PROXY_PID" +sleep 1 + +if ! kill -0 "$PROXY_PID" 2>/dev/null; then + echo "Error: proxy failed to start. Logs:" >&2 + cat /tmp/srs-proxy-whip-e2e.log >&2 + exit 1 +fi +echo "Proxy started." + +# --- Step 4: Start SRS origin (with CANDIDATE=$ORIGIN_CANDIDATE for WebRTC) --- +echo "=== Step 4: Starting SRS origin (CANDIDATE=$ORIGIN_CANDIDATE) ===" +ulimit -n 10000 2>/dev/null || true +cd "$WORKSPACE/trunk" +env CANDIDATE="$ORIGIN_CANDIDATE" \ + ./objs/srs -c conf/origin1-for-proxy.conf >/tmp/srs-origin-whip-e2e.log 2>&1 & +ORIGIN_PID=$! +echo "SRS origin PID: $ORIGIN_PID" + +# Wait for SRS to start and register with proxy (heartbeat interval is 9s). +echo "Waiting for SRS origin to register with proxy (up to 15s)..." +sleep 12 + +if ! kill -0 "$ORIGIN_PID" 2>/dev/null; then + echo "Error: SRS origin failed to start. Logs:" >&2 + cat /tmp/srs-origin-whip-e2e.log >&2 + exit 1 +fi +echo "SRS origin started and registered." + +# --- Step 5: Publish WHIP stream --- +# WebRTC requires H.264 (baseline-friendly) + Opus. source.flv is H.264 High +# profile + AAC, so transcode video to baseline and audio to Opus. Use +# zerolatency/ultrafast so the encoder keeps up with -re. +echo "=== Step 5: Publishing WHIP stream to proxy ===" +echo "Publish URL: $WHIP_PUBLISH_URL" +"$FFMPEG_BIN" -stream_loop -1 -re -i "$SOURCE_FLV" \ + -c:v libx264 -profile:v baseline -level 3.1 -pix_fmt yuv420p \ + -tune zerolatency -preset ultrafast \ + -c:a libopus -ar 48000 -ac 2 \ + -f whip "$WHIP_PUBLISH_URL" >/tmp/srs-ffmpeg-whip-e2e.log 2>&1 & +FFMPEG_PID=$! +echo "FFmpeg publisher PID: $FFMPEG_PID" + +# Wait for WHIP SDP exchange + DTLS-SRTP handshake + the origin's +# rtc_to_rtmp pipeline to spin up the bridged RTMP stream. +sleep 8 + +if ! kill -0 "$FFMPEG_PID" 2>/dev/null; then + echo "Error: FFmpeg WHIP publisher failed. Logs:" >&2 + cat /tmp/srs-ffmpeg-whip-e2e.log >&2 + exit 1 +fi +echo "Stream publishing." + +# --- Step 6: Verify RTMP playback (rtc_to_rtmp) --- +echo "=== Step 6: Verifying RTMP playback via proxy ===" +probe_has_audio_video "RTMP" "rtmp://localhost:$PROXY_RTMP_PORT/$STREAM_URL" + +# --- Step 7: Verify HTTP-FLV playback --- +echo "=== Step 7: Verifying HTTP-FLV playback via proxy ===" +probe_has_audio_video "HTTP-FLV" "http://localhost:$PROXY_HTTP_SERVER_PORT/$STREAM_URL.flv" + +# --- Step 8: Verify HLS playback --- +echo "=== Step 8: Verifying HLS playback via proxy ===" +HLS_URL="http://localhost:$PROXY_HTTP_SERVER_PORT/$STREAM_URL.m3u8" +wait_for_hls_playlist "$HLS_URL" +wait_for_hls_to_skip_first_segment "$HLS_URL" +probe_has_audio_video "HLS" "$HLS_URL" + +# --- Step 9: WebRTC WHEP playback (placeholder) --- +echo "=== Step 9: WebRTC WHEP playback (placeholder) ===" +echo "SKIP: WebRTC WHEP playback is not verified by this script." +echo " The origin has rtmp_to_rtc enabled, so WHIP->RTMP->RTC should work end-to-end," +echo " but actual playback verification is intentionally left as a TODO here." + +echo "" +echo "=== E2E WHIP Proxy Test PASSED ===" diff --git a/.openclaw/skills/srs-develop/scripts/proxy-utest.sh b/skills/srs-develop/scripts/proxy-utest.sh similarity index 75% rename from .openclaw/skills/srs-develop/scripts/proxy-utest.sh rename to skills/srs-develop/scripts/proxy-utest.sh index 9e030d99f..52b6f590c 100755 --- a/.openclaw/skills/srs-develop/scripts/proxy-utest.sh +++ b/skills/srs-develop/scripts/proxy-utest.sh @@ -27,11 +27,16 @@ for arg in "$@"; do done SCRIPT_DIR="$(cd -P "$(dirname "$0")" && pwd)" -# Navigate: scripts/ -> srs-develop/ -> skills/ -> .openclaw/ -> srs -WORKSPACE="$(cd -P "$SCRIPT_DIR/../../../.." && pwd)" +# Walk up from SCRIPT_DIR looking for go.mod. This avoids brittle "../../../.." +# counting when the skills directory is reached via a symlink (which changes +# the symbolic vs. physical depth). +WORKSPACE="$SCRIPT_DIR" +while [[ "$WORKSPACE" != "/" && ! -f "$WORKSPACE/go.mod" ]]; do + WORKSPACE="$(dirname "$WORKSPACE")" +done if [[ ! -f "$WORKSPACE/go.mod" ]]; then - echo "Error: go.mod not found in WORKSPACE: $WORKSPACE" >&2 + echo "Error: go.mod not found walking up from: $SCRIPT_DIR" >&2 exit 1 fi diff --git a/skills/srs-develop/scripts/setup-ffmpeg-with-whip.sh b/skills/srs-develop/scripts/setup-ffmpeg-with-whip.sh new file mode 100755 index 000000000..bd50eff14 --- /dev/null +++ b/skills/srs-develop/scripts/setup-ffmpeg-with-whip.sh @@ -0,0 +1,232 @@ +#!/bin/bash +# Build ffmpeg from source with the codecs and protocols the SRS proxy E2E +# tests need — in particular WHIP (WebRTC-HTTP Ingestion Protocol), which the +# default Homebrew formulas (vanilla + homebrew-ffmpeg tap) do not enable. +# +# Modelled on the ossrs/dev-docker ubuntu20 base images: +# https://github.com/ossrs/dev-docker/blob/ubuntu20/Dockerfile.base +# https://github.com/ossrs/dev-docker/blob/ubuntu20/Dockerfile.base2 +# https://github.com/ossrs/dev-docker/blob/ubuntu20/Dockerfile.base3 +# The Dockerfiles describe *which* libraries and configure flags to enable; +# on macOS we install those deps via Homebrew (shared libs) instead of +# rebuilding each one from tarballs. ffmpeg itself is built from source so we +# can pass --enable-libsrtp / --enable-openssl, which neither Homebrew formula +# turns on. +# +# Output: +# ~/.local/src/ffmpeg (git clone) +# ~/.local/bin/{ffmpeg,ffprobe,ffplay} +# ~/.local/lib, ~/.local/share, ... (ffmpeg --prefix tree) +# +# Re-running is safe: deps already installed are skipped, the git clone is +# reused (fetch + checkout), and ffmpeg is rebuilt incrementally. +set -e + +FFMPEG_TAG="${FFMPEG_TAG:-n8.1.1}" +PREFIX="${PREFIX:-$HOME/.local}" +SRC_DIR="${SRC_DIR:-$HOME/.local/src/ffmpeg}" +JOBS="${JOBS:-$(sysctl -n hw.ncpu 2>/dev/null || echo 4)}" + +REQUIRED_BREW_PKGS=( + # build toolchain (Homebrew renamed pkg-config → pkgconf in 2024) + pkgconf nasm + # WHIP needs DTLS (openssl). ffmpeg's WHIP muxer uses ffmpeg's *internal* + # SRTP implementation (srtp_protocol_select="rtp_protocol srtp" in configure) + # — no external libsrtp dependency, so no extra brew package here. + openssl@3 + # SRT + srt + # video codecs (matches Dockerfile.base / .base2) + x264 x265 libvpx + # audio codecs (fdk-aac requires --enable-nonfree below) + fdk-aac lame opus + # subtitle / font stack (matches Dockerfile.base3) + freetype fontconfig harfbuzz fribidi libass +) + +log() { printf '\n=== %s ===\n' "$*"; } + +# --- Pre-checks --------------------------------------------------------------- + +if [[ "$(uname -s)" != "Darwin" ]]; then + echo "Error: this script targets macOS. For Linux, follow Dockerfile.base/2/3 directly." >&2 + exit 1 +fi + +if [[ "$(id -u)" -eq 0 ]]; then + echo "Error: do not run as root. Homebrew refuses, and the prefix is your \$HOME." >&2 + exit 1 +fi + +if ! command -v brew &>/dev/null; then + echo "Error: Homebrew not found. Install from https://brew.sh and retry." >&2 + exit 1 +fi + +if ! command -v git &>/dev/null; then + echo "Error: git not found in PATH." >&2 + exit 1 +fi + +log "Configuration" +echo "FFmpeg tag : $FFMPEG_TAG" +echo "Prefix : $PREFIX" +echo "Source dir : $SRC_DIR" +echo "Parallel : $JOBS jobs" + +# --- Step 1: Install Homebrew deps ------------------------------------------- + +log "Step 1: Installing Homebrew dependencies" +INSTALLED="$(brew list --formula 2>/dev/null || true)" +TO_INSTALL=() +for pkg in "${REQUIRED_BREW_PKGS[@]}"; do + if echo "$INSTALLED" | grep -qx "$pkg"; then + echo " ok $pkg" + else + echo " miss $pkg" + TO_INSTALL+=("$pkg") + fi +done + +if [[ ${#TO_INSTALL[@]} -gt 0 ]]; then + echo "" + echo "Installing missing packages: ${TO_INSTALL[*]}" + brew install "${TO_INSTALL[@]}" +else + echo "" + echo "All required Homebrew packages already installed." +fi + +# --- Step 2: Clone or refresh ffmpeg source ---------------------------------- + +log "Step 2: Fetching ffmpeg source at $FFMPEG_TAG" +mkdir -p "$(dirname "$SRC_DIR")" +if [[ ! -d "$SRC_DIR/.git" ]]; then + # Shallow clone of just the tag we want — full history is ~600 MB and + # takes minutes to index; this drops to ~80 MB and seconds. + git clone --depth 1 --branch "$FFMPEG_TAG" \ + https://github.com/FFmpeg/FFmpeg.git "$SRC_DIR" +else + cd "$SRC_DIR" + git fetch --depth 1 --tags --quiet origin "$FFMPEG_TAG" +fi +cd "$SRC_DIR" +git checkout --quiet "$FFMPEG_TAG" +echo "Checked out: $(git describe --tags --always)" + +# --- Step 3: Configure -------------------------------------------------------- + +log "Step 3: Configuring ffmpeg" + +# openssl@3 is keg-only in Homebrew, so its .pc files are not on the default +# PKG_CONFIG_PATH. Other deps (lame in particular) ship no .pc at all, so +# also pass --extra-cflags / --extra-ldflags pointing at the Homebrew prefix +# (works for both Apple Silicon /opt/homebrew and Intel /usr/local). +BREW_PREFIX="$(brew --prefix)" +OPENSSL_PREFIX="$(brew --prefix openssl@3)" +export PKG_CONFIG_PATH="$OPENSSL_PREFIX/lib/pkgconfig:${PKG_CONFIG_PATH:-}" +echo "PKG_CONFIG_PATH=$PKG_CONFIG_PATH" + +# --enable-nonfree : required by --enable-libfdk-aac +# --enable-gpl : required by --enable-libx264 / --enable-libx265 / --enable-libass +# --enable-version3 : allow (L)GPLv3 components alongside GPL/nonfree +# --enable-openssl : DTLS backend; the WHIP muxer needs this for the +# DTLS-SRTP handshake. ffmpeg's WHIP muxer uses ffmpeg's +# *internal* SRTP (libavformat/srtp.c), not external +# libsrtp — so no --enable-libsrtp flag exists/is needed. +# --enable-libsrt : SRT protocol (publish/play) +./configure \ + --prefix="$PREFIX" \ + --extra-cflags="-I$BREW_PREFIX/include" \ + --extra-ldflags="-L$BREW_PREFIX/lib" \ + --enable-gpl \ + --enable-nonfree \ + --enable-version3 \ + --enable-openssl \ + --enable-libsrt \ + --enable-libx264 \ + --enable-libx265 \ + --enable-libvpx \ + --enable-libfdk-aac \ + --enable-libmp3lame \ + --enable-libopus \ + --enable-libass \ + --enable-libfreetype \ + --enable-libfontconfig \ + --enable-libharfbuzz \ + --enable-libfribidi \ + --enable-videotoolbox \ + --enable-audiotoolbox \ + --disable-debug + +# --- Step 4: Build and install ----------------------------------------------- + +log "Step 4: Building ffmpeg ($JOBS jobs)" +make -j"$JOBS" + +log "Step 5: Installing to $PREFIX" +make install + +# --- Step 5: Verify ----------------------------------------------------------- + +log "Step 6: Verifying installed binary" +FFMPEG_BIN="$PREFIX/bin/ffmpeg" +if [[ ! -x "$FFMPEG_BIN" ]]; then + echo "Error: $FFMPEG_BIN missing after install." >&2 + exit 1 +fi + +echo "" +"$FFMPEG_BIN" -version | head -2 +echo "" + +if "$FFMPEG_BIN" -hide_banner -muxers 2>/dev/null | grep -qw whip; then + echo "PASS: WHIP muxer is available." +else + echo "FAIL: WHIP muxer is NOT in the installed ffmpeg." >&2 + exit 1 +fi + +if "$FFMPEG_BIN" -hide_banner -protocols 2>/dev/null | grep -qw srt; then + echo "PASS: SRT protocol is available." +else + echo "FAIL: SRT protocol is NOT in the installed ffmpeg." >&2 + exit 1 +fi + +if "$FFMPEG_BIN" -hide_banner -codecs 2>/dev/null | grep -E "libx264|libx265" | grep -q libx264; then + echo "PASS: libx264 encoder is available." +else + echo "FAIL: libx264 encoder missing." >&2 + exit 1 +fi + +if "$FFMPEG_BIN" -hide_banner -codecs 2>/dev/null | grep -q libx265; then + echo "PASS: libx265 encoder is available." +else + echo "FAIL: libx265 encoder missing." >&2 + exit 1 +fi + +echo "" +log "Done" +echo "Binary: $FFMPEG_BIN" +echo "Version: $("$FFMPEG_BIN" -version | head -1)" +echo "" + +# Warn if PATH won't pick up the new binary. +if ! echo ":$PATH:" | grep -q ":$PREFIX/bin:"; then + echo "NOTE: $PREFIX/bin is not on your PATH." + echo " Add this to ~/.zshrc (or your shell rc):" + echo " export PATH=\"\$HOME/.local/bin:\$PATH\"" + echo " Then 'which ffmpeg' should resolve to $FFMPEG_BIN." +else + RESOLVED="$(command -v ffmpeg || true)" + if [[ "$RESOLVED" == "$FFMPEG_BIN" ]]; then + echo "PATH check: 'ffmpeg' resolves to $RESOLVED ✓" + else + echo "NOTE: $PREFIX/bin is on PATH but 'ffmpeg' currently resolves to:" + echo " $RESOLVED" + echo " Reorder PATH so $PREFIX/bin comes before $(dirname "$RESOLVED")." + fi +fi diff --git a/.openclaw/skills/srs-support/.gitignore b/skills/srs-support/.gitignore similarity index 100% rename from .openclaw/skills/srs-support/.gitignore rename to skills/srs-support/.gitignore diff --git a/.openclaw/skills/srs-support/SKILL.md b/skills/srs-support/SKILL.md similarity index 100% rename from .openclaw/skills/srs-support/SKILL.md rename to skills/srs-support/SKILL.md diff --git a/.openclaw/skills/srs-support/evals/evals.json b/skills/srs-support/evals/evals.json similarity index 100% rename from .openclaw/skills/srs-support/evals/evals.json rename to skills/srs-support/evals/evals.json diff --git a/trunk/doc/CHANGELOG.md b/trunk/doc/CHANGELOG.md index ae3c9bc0c..b8f383319 100644 --- a/trunk/doc/CHANGELOG.md +++ b/trunk/doc/CHANGELOG.md @@ -7,6 +7,7 @@ The changelog for SRS. ## SRS 7.0 Changelog +* v7.0, 2026-05-17, Merge [#4675](https://github.com/ossrs/srs/pull/4675): Proxy: Refactor for testability; add SRT/WHIP E2E and unit tests. v7.0.148 (#4675) * v7.0, 2026-05-02, Merge [#4672](https://github.com/ossrs/srs/pull/4672): Proxy: Refactor server APIs and expand RTMP test coverage. v7.0.147 (#4672) * v7.0, 2026-04-28, Merge [#4670](https://github.com/ossrs/srs/pull/4670): Proxy: Refine logger and environment APIs. v7.0.146 (#4670) * v7.0, 2026-04-23, Merge [#4667](https://github.com/ossrs/srs/pull/4667): Proxy: Refactor internal/errors and internal/sync, and add unit tests across internal/*. v7.0.145 (#4667) diff --git a/trunk/src/core/srs_core_version7.hpp b/trunk/src/core/srs_core_version7.hpp index 55b2d2e04..50d1acabc 100644 --- a/trunk/src/core/srs_core_version7.hpp +++ b/trunk/src/core/srs_core_version7.hpp @@ -9,6 +9,6 @@ #define VERSION_MAJOR 7 #define VERSION_MINOR 0 -#define VERSION_REVISION 147 +#define VERSION_REVISION 148 #endif