// Copyright (c) 2026 Winlin // // SPDX-License-Identifier: MIT package lb import ( "context" "strings" "testing" "time" "srsx/internal/env/envfakes" ) // stubHLS is a minimal HLSPlayStream for testing. type stubHLS struct { spbhid string } func (s *stubHLS) GetSPBHID() string { return s.spbhid } func (s *stubHLS) Initialize(ctx context.Context) HLSPlayStream { return s } // stubRTC is a minimal RTCConnection for testing. type stubRTC struct { ufrag string } func (s *stubRTC) GetUfrag() string { return s.ufrag } // newMem returns a fresh in-memory load balancer with a default fake env. func newMem() *memoryLoadBalancer { env := &envfakes.FakeProxyEnvironment{} return NewMemoryLoadBalancer(env).(*memoryLoadBalancer) } func TestNewMemoryLoadBalancer(t *testing.T) { env := &envfakes.FakeProxyEnvironment{} lb := NewMemoryLoadBalancer(env) if lb == nil { t.Fatal("NewMemoryLoadBalancer returned nil") } } func TestMemLB_Initialize_DefaultBackendDisabled(t *testing.T) { env := &envfakes.FakeProxyEnvironment{} env.DefaultBackendEnabledReturns("off") lb := NewMemoryLoadBalancer(env).(*memoryLoadBalancer) if err := lb.Initialize(context.Background()); err != nil { t.Fatalf("Initialize: %v", err) } // No server stored when disabled. count := 0 lb.servers.Range(func(string, *OriginServer) bool { count++; return true }) if count != 0 { t.Fatalf("expected 0 servers, got %d", count) } } func TestMemLB_Initialize_DefaultBackendError(t *testing.T) { env := &envfakes.FakeProxyEnvironment{} env.DefaultBackendEnabledReturns("on") env.DefaultBackendIPReturns("") // triggers "empty default backend ip" lb := NewMemoryLoadBalancer(env) err := lb.Initialize(context.Background()) if err == nil || !strings.Contains(err.Error(), "initialize default SRS") { t.Fatalf("expected wrapped error, got %v", err) } } func TestMemLB_Initialize_KeepaliveTick(t *testing.T) { env := &envfakes.FakeProxyEnvironment{} env.DefaultBackendEnabledReturns("on") env.DefaultBackendIPReturns("1.2.3.4") env.DefaultBackendRTMPReturns(":1935") ctx, cancel := context.WithCancel(context.Background()) defer cancel() lb := NewMemoryLoadBalancer(env).(*memoryLoadBalancer) // Shorten the keep-alive interval on this instance only so concurrent // tests don't race on shared state. lb.keepaliveInterval = time.Millisecond if err := lb.Initialize(ctx); err != nil { t.Fatalf("Initialize: %v", err) } // Find the server and watch UpdatedAt advance after a keep-alive tick. var s *OriginServer lb.servers.Range(func(_ string, v *OriginServer) bool { s = v; return false }) if s == nil { t.Fatal("expected server stored") } first := s.UpdatedAt // Wait long enough for several ticks (interval is 1ms, server.UpdatedAt // is set to time.Now() inside NewDefaultOriginServerForDebugging on each // Update? — actually Update only stores the server pointer, so UpdatedAt // won't change. The goroutine still hits the tick branch though, which // is all we need for coverage). time.Sleep(20 * time.Millisecond) cancel() time.Sleep(10 * time.Millisecond) _ = first } func TestMemLB_Initialize_DefaultBackendSuccess(t *testing.T) { env := &envfakes.FakeProxyEnvironment{} env.DefaultBackendEnabledReturns("on") env.DefaultBackendIPReturns("1.2.3.4") env.DefaultBackendRTMPReturns(":1935") ctx, cancel := context.WithCancel(context.Background()) defer cancel() lb := NewMemoryLoadBalancer(env).(*memoryLoadBalancer) if err := lb.Initialize(ctx); err != nil { t.Fatalf("Initialize: %v", err) } count := 0 lb.servers.Range(func(string, *OriginServer) bool { count++; return true }) if count != 1 { t.Fatalf("expected 1 server stored, got %d", count) } // Cancel and give the keep-alive goroutine a moment to exit cleanly. cancel() time.Sleep(20 * time.Millisecond) } func TestMemLB_Update(t *testing.T) { lb := newMem() s := &OriginServer{ServerID: "srv", ServiceID: "svc", PID: "1"} if err := lb.Update(context.Background(), s); err != nil { t.Fatalf("Update: %v", err) } got, ok := lb.servers.Load(s.ID()) if !ok || got != s { t.Fatalf("Update did not store the server: got=%v ok=%v", got, ok) } } func TestMemLB_Pick_NoServers(t *testing.T) { lb := newMem() _, err := lb.Pick(context.Background(), "url1") if err == nil || !strings.Contains(err.Error(), "no server available") { t.Fatalf("expected no-server error, got %v", err) } } func TestMemLB_Pick_AliveServer_Sticky(t *testing.T) { lb := newMem() s := &OriginServer{ServerID: "a", PID: "1", UpdatedAt: time.Now()} _ = lb.Update(context.Background(), s) got, err := lb.Pick(context.Background(), "url1") if err != nil { t.Fatalf("Pick: %v", err) } if got != s { t.Fatalf("Pick returned %v, want %v", got, s) } // Second pick for the same URL returns the same server (sticky branch). got2, err := lb.Pick(context.Background(), "url1") if err != nil { t.Fatalf("Pick second: %v", err) } if got2 != got { t.Fatalf("second Pick returned %v, want %v (sticky)", got2, got) } } func TestMemLB_Pick_OnlyDeadServers_Fallback(t *testing.T) { lb := newMem() // UpdatedAt long past => not alive. Tests the fallback "use all servers" branch. s := &OriginServer{ ServerID: "a", PID: "1", UpdatedAt: time.Now().Add(-2 * ServerAliveDuration), } _ = lb.Update(context.Background(), s) got, err := lb.Pick(context.Background(), "url1") if err != nil { t.Fatalf("Pick: %v", err) } if got != s { t.Fatalf("expected dead-server fallback to return %v, got %v", s, got) } } func TestMemLB_LoadHLSBySPBHID_NotFound(t *testing.T) { lb := newMem() _, err := lb.LoadHLSBySPBHID(context.Background(), "missing") if err == nil || !strings.Contains(err.Error(), "no HLS streaming") { t.Fatalf("expected error, got %v", err) } } func TestMemLB_LoadOrStoreHLS_New(t *testing.T) { lb := newMem() s := &stubHLS{spbhid: "abc"} got, err := lb.LoadOrStoreHLS(context.Background(), "url1", s) if err != nil { t.Fatalf("LoadOrStoreHLS: %v", err) } if got != s { t.Fatalf("LoadOrStoreHLS returned %v, want %v", got, s) } // Lookup via SPBHID works (dual-index write). bySPBHID, err := lb.LoadHLSBySPBHID(context.Background(), "abc") if err != nil { t.Fatalf("LoadHLSBySPBHID: %v", err) } if bySPBHID != s { t.Fatalf("LoadHLSBySPBHID returned %v, want %v", bySPBHID, s) } } func TestMemLB_LoadOrStoreHLS_Existing(t *testing.T) { lb := newMem() s1 := &stubHLS{spbhid: "first"} s2 := &stubHLS{spbhid: "second"} _, _ = lb.LoadOrStoreHLS(context.Background(), "url1", s1) got, err := lb.LoadOrStoreHLS(context.Background(), "url1", s2) if err != nil { t.Fatalf("LoadOrStoreHLS: %v", err) } if got != s1 { t.Fatalf("expected existing s1, got %v", got) } // SPBHID 'second' (from the rejected s2) maps to the existing s1. bySPBHID, _ := lb.LoadHLSBySPBHID(context.Background(), "second") if bySPBHID != s1 { t.Fatalf("expected SPBHID 'second' to map to s1, got %v", bySPBHID) } } func TestMemLB_StoreWebRTC_And_Load(t *testing.T) { lb := newMem() s := &stubRTC{ufrag: "ufrg1"} if err := lb.StoreWebRTC(context.Background(), "url1", s); err != nil { t.Fatalf("StoreWebRTC: %v", err) } got, err := lb.LoadWebRTCByUfrag(context.Background(), "ufrg1") if err != nil { t.Fatalf("LoadWebRTCByUfrag: %v", err) } if got != s { t.Fatalf("got %v, want %v", got, s) } } func TestMemLB_LoadWebRTCByUfrag_NotFound(t *testing.T) { lb := newMem() _, err := lb.LoadWebRTCByUfrag(context.Background(), "missing") if err == nil || !strings.Contains(err.Error(), "no WebRTC streaming") { t.Fatalf("expected error, got %v", err) } }