diff --git a/internal/bootstrap/proxy.go b/internal/bootstrap/proxy.go index 29667aef6..bb2cf4d6f 100644 --- a/internal/bootstrap/proxy.go +++ b/internal/bootstrap/proxy.go @@ -33,7 +33,7 @@ func (b *proxyBootstrap) Start(ctx context.Context) error { // Install signals. ctx, cancel := context.WithCancel(ctx) - signal.InstallSignals(ctx, cancel) + signal.NewHandler().InstallSignals(ctx, cancel) // Run the main loop, ignore the user cancel error. err := b.run(ctx) @@ -58,7 +58,7 @@ func (b *proxyBootstrap) run(ctx context.Context) error { // When cancelled, the program is forced to exit due to a timeout. Normally, this doesn't occur // because the main thread exits after the context is cancelled. However, sometimes the main thread // may be blocked for some reason, so a forced exit is necessary to ensure the program terminates. - if err := signal.InstallForceQuit(ctx, environment); err != nil { + if err := signal.NewHandler().InstallForceQuit(ctx, environment); err != nil { return errors.Wrapf(err, "install force quit") } diff --git a/internal/lb/gen.go b/internal/lb/gen.go new file mode 100644 index 000000000..f9822e11f --- /dev/null +++ b/internal/lb/gen.go @@ -0,0 +1,9 @@ +// Copyright (c) 2026 Winlin +// +// SPDX-License-Identifier: MIT +package lb + +//go:generate go tool counterfeiter -o lbfakes/fake_origin_load_balancer.go . OriginLoadBalancer +//go:generate go tool counterfeiter -o lbfakes/fake_origin_service.go . OriginService +//go:generate go tool counterfeiter -o lbfakes/fake_hls_service.go . HLSService +//go:generate go tool counterfeiter -o lbfakes/fake_rtc_service.go . RTCService diff --git a/internal/lb/lb.go b/internal/lb/lb.go index f15f552cf..46bb3498e 100644 --- a/internal/lb/lb.go +++ b/internal/lb/lb.go @@ -109,20 +109,35 @@ type RTCConnection interface { GetUfrag() string } -// OriginLoadBalancer is the interface to load balance the SRS servers. -type OriginLoadBalancer interface { - // Initialize the load balancer. - Initialize(ctx context.Context) error +// OriginService is the interface for origin-server registry and stream routing. +type OriginService interface { // Update records the latest registration or heartbeat for an origin server. Update(ctx context.Context, server *OriginServer) error // Pick a backend server for the specified stream URL. Pick(ctx context.Context, streamURL string) (*OriginServer, error) +} + +// HLSService is the interface for HLS session state, indexed by stream URL and SPBHID. +type HLSService interface { // Load or store the HLS streaming for the specified stream URL. LoadOrStoreHLS(ctx context.Context, streamURL string, value HLSPlayStream) (HLSPlayStream, error) // Load the HLS streaming by SPBHID, the SRS Proxy Backend HLS ID. LoadHLSBySPBHID(ctx context.Context, spbhid string) (HLSPlayStream, error) +} + +// RTCService is the interface for WebRTC session state, indexed by stream URL and ICE ufrag. +type RTCService interface { // Store the WebRTC streaming for the specified stream URL. StoreWebRTC(ctx context.Context, streamURL string, value RTCConnection) error // Load the WebRTC streaming by ufrag, the ICE username. LoadWebRTCByUfrag(ctx context.Context, ufrag string) (RTCConnection, error) } + +// OriginLoadBalancer is the interface to load balance the SRS servers. +type OriginLoadBalancer interface { + OriginService + HLSService + RTCService + // Initialize the load balancer. + Initialize(ctx context.Context) error +} diff --git a/internal/lb/lb_test.go b/internal/lb/lb_test.go new file mode 100644 index 000000000..74adaa6cd --- /dev/null +++ b/internal/lb/lb_test.go @@ -0,0 +1,141 @@ +// Copyright (c) 2026 Winlin +// +// SPDX-License-Identifier: MIT +package lb + +import ( + "fmt" + "strings" + "testing" + "time" +) + +func TestOriginServerID(t *testing.T) { + for _, tt := range []struct { + name string + v *OriginServer + want string + }{ + {"populated", &OriginServer{ServerID: "srv", ServiceID: "svc", PID: "1234"}, "srv-svc-1234"}, + {"empty", &OriginServer{}, "--"}, + } { + t.Run(tt.name, func(t *testing.T) { + if got := tt.v.ID(); got != tt.want { + t.Fatalf("ID()=%q, want %q", got, tt.want) + } + }) + } +} + +func TestOriginServerString(t *testing.T) { + // String() routes through Format with the %v default branch. + v := &OriginServer{IP: "1.2.3.4", ServerID: "srv", ServiceID: "svc", PID: "p"} + got := v.String() + if want := "SRS ip=1.2.3.4, id=srv-svc-p"; got != want { + t.Fatalf("String()=%q, want %q", got, want) + } +} + +func TestOriginServerFormat_ShortVerbs(t *testing.T) { + v := &OriginServer{IP: "10.0.0.1", ServerID: "srv", ServiceID: "svc", PID: "9"} + want := "SRS ip=10.0.0.1, id=srv-svc-9" + for _, verb := range []string{"%v", "%s"} { + got := fmt.Sprintf(verb, v) + if got != want { + t.Fatalf("Sprintf(%q)=%q, want %q", verb, got, want) + } + } +} + +func TestOriginServerFormat_PlusVerbsAllFields(t *testing.T) { + ts := time.Date(2026, 5, 16, 10, 30, 45, 123_000_000, time.UTC) + v := &OriginServer{ + IP: "10.0.0.1", DeviceID: "dev1", + ServerID: "srv", ServiceID: "svc", PID: "9", + RTMP: []string{":1935", ":1936"}, + HTTP: []string{":8080"}, + API: []string{":1985"}, + SRT: []string{":10080"}, + RTC: []string{":8000"}, + UpdatedAt: ts, + } + + for _, verb := range []string{"%+v", "%+s"} { + got := fmt.Sprintf(verb, v) + for _, sub := range []string{ + "SRS ip=10.0.0.1", + "id=srv-svc-9", + "pid=9, server=srv, service=svc", + "device=dev1", + "rtmp=[:1935,:1936]", + "http=[:8080]", + "api=[:1985]", + "srt=[:10080]", + "rtc=[:8000]", + "update=2026-05-16 10:30:45.123", + } { + if !strings.Contains(got, sub) { + t.Fatalf("Sprintf(%q)=%q missing %q", verb, got, sub) + } + } + } +} + +func TestOriginServerFormat_PlusVerbMinimal(t *testing.T) { + // Plus verb with no optional fields populated exercises the false + // branches of every "if len(X) > 0 / X != \"\"" guard in Format. + v := &OriginServer{ServerID: "srv", ServiceID: "svc", PID: "9"} + got := fmt.Sprintf("%+v", v) + + if !strings.Contains(got, "pid=9, server=srv, service=svc") { + t.Fatalf("%%+v output %q missing core ids", got) + } + if !strings.Contains(got, "update=") { + t.Fatalf("%%+v output %q missing update timestamp", got) + } + for _, sub := range []string{"device=", "rtmp=", "http=", "api=", "srt=", "rtc="} { + if strings.Contains(got, sub) { + t.Fatalf("%%+v output %q should not contain %q for an empty field", got, sub) + } + } +} + +func TestOriginServerFormat_OtherVerb(t *testing.T) { + // A non-v/s verb falls through to the default branch, which recursively + // formats with %v and appends ", fmt=%". + v := &OriginServer{IP: "1.2.3.4", ServerID: "srv", ServiceID: "svc", PID: "p"} + got := fmt.Sprintf("%d", v) + want := "SRS ip=1.2.3.4, id=srv-svc-p, fmt=%d" + if got != want { + t.Fatalf("%%d output %q, want %q", got, want) + } +} + +func TestNewOriginServer(t *testing.T) { + t.Run("no opts", func(t *testing.T) { + v := NewOriginServer() + if v == nil { + t.Fatal("NewOriginServer() returned nil") + } + if v.IP != "" || v.DeviceID != "" || v.ServerID != "" || v.ServiceID != "" || v.PID != "" { + t.Fatalf("expected zero value, got %+v", v) + } + if len(v.RTMP)+len(v.HTTP)+len(v.API)+len(v.SRT)+len(v.RTC) != 0 { + t.Fatalf("expected empty endpoints, got %+v", v) + } + if !v.UpdatedAt.IsZero() { + t.Fatalf("expected zero UpdatedAt, got %v", v.UpdatedAt) + } + }) + + t.Run("with opts", func(t *testing.T) { + v := NewOriginServer( + func(s *OriginServer) { s.IP = "9.9.9.9" }, + func(s *OriginServer) { s.ServerID = "abc" }, + func(s *OriginServer) { s.RTMP = []string{":1935"} }, + ) + if v.IP != "9.9.9.9" || v.ServerID != "abc" || len(v.RTMP) != 1 || v.RTMP[0] != ":1935" { + t.Fatalf("opts not applied: got %+v", v) + } + }) +} diff --git a/internal/lb/lbfakes/fake_hls_service.go b/internal/lb/lbfakes/fake_hls_service.go new file mode 100644 index 000000000..8aa7a8340 --- /dev/null +++ b/internal/lb/lbfakes/fake_hls_service.go @@ -0,0 +1,197 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package lbfakes + +import ( + "context" + "srsx/internal/lb" + "sync" +) + +type FakeHLSService struct { + LoadHLSBySPBHIDStub func(context.Context, string) (lb.HLSPlayStream, error) + loadHLSBySPBHIDMutex sync.RWMutex + loadHLSBySPBHIDArgsForCall []struct { + arg1 context.Context + arg2 string + } + loadHLSBySPBHIDReturns struct { + result1 lb.HLSPlayStream + result2 error + } + loadHLSBySPBHIDReturnsOnCall map[int]struct { + result1 lb.HLSPlayStream + result2 error + } + LoadOrStoreHLSStub func(context.Context, string, lb.HLSPlayStream) (lb.HLSPlayStream, error) + loadOrStoreHLSMutex sync.RWMutex + loadOrStoreHLSArgsForCall []struct { + arg1 context.Context + arg2 string + arg3 lb.HLSPlayStream + } + loadOrStoreHLSReturns struct { + result1 lb.HLSPlayStream + result2 error + } + loadOrStoreHLSReturnsOnCall map[int]struct { + result1 lb.HLSPlayStream + result2 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeHLSService) LoadHLSBySPBHID(arg1 context.Context, arg2 string) (lb.HLSPlayStream, error) { + fake.loadHLSBySPBHIDMutex.Lock() + ret, specificReturn := fake.loadHLSBySPBHIDReturnsOnCall[len(fake.loadHLSBySPBHIDArgsForCall)] + fake.loadHLSBySPBHIDArgsForCall = append(fake.loadHLSBySPBHIDArgsForCall, struct { + arg1 context.Context + arg2 string + }{arg1, arg2}) + stub := fake.LoadHLSBySPBHIDStub + fakeReturns := fake.loadHLSBySPBHIDReturns + fake.recordInvocation("LoadHLSBySPBHID", []interface{}{arg1, arg2}) + fake.loadHLSBySPBHIDMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeHLSService) LoadHLSBySPBHIDCallCount() int { + fake.loadHLSBySPBHIDMutex.RLock() + defer fake.loadHLSBySPBHIDMutex.RUnlock() + return len(fake.loadHLSBySPBHIDArgsForCall) +} + +func (fake *FakeHLSService) LoadHLSBySPBHIDCalls(stub func(context.Context, string) (lb.HLSPlayStream, error)) { + fake.loadHLSBySPBHIDMutex.Lock() + defer fake.loadHLSBySPBHIDMutex.Unlock() + fake.LoadHLSBySPBHIDStub = stub +} + +func (fake *FakeHLSService) LoadHLSBySPBHIDArgsForCall(i int) (context.Context, string) { + fake.loadHLSBySPBHIDMutex.RLock() + defer fake.loadHLSBySPBHIDMutex.RUnlock() + argsForCall := fake.loadHLSBySPBHIDArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeHLSService) LoadHLSBySPBHIDReturns(result1 lb.HLSPlayStream, result2 error) { + fake.loadHLSBySPBHIDMutex.Lock() + defer fake.loadHLSBySPBHIDMutex.Unlock() + fake.LoadHLSBySPBHIDStub = nil + fake.loadHLSBySPBHIDReturns = struct { + result1 lb.HLSPlayStream + result2 error + }{result1, result2} +} + +func (fake *FakeHLSService) LoadHLSBySPBHIDReturnsOnCall(i int, result1 lb.HLSPlayStream, result2 error) { + fake.loadHLSBySPBHIDMutex.Lock() + defer fake.loadHLSBySPBHIDMutex.Unlock() + fake.LoadHLSBySPBHIDStub = nil + if fake.loadHLSBySPBHIDReturnsOnCall == nil { + fake.loadHLSBySPBHIDReturnsOnCall = make(map[int]struct { + result1 lb.HLSPlayStream + result2 error + }) + } + fake.loadHLSBySPBHIDReturnsOnCall[i] = struct { + result1 lb.HLSPlayStream + result2 error + }{result1, result2} +} + +func (fake *FakeHLSService) LoadOrStoreHLS(arg1 context.Context, arg2 string, arg3 lb.HLSPlayStream) (lb.HLSPlayStream, error) { + fake.loadOrStoreHLSMutex.Lock() + ret, specificReturn := fake.loadOrStoreHLSReturnsOnCall[len(fake.loadOrStoreHLSArgsForCall)] + fake.loadOrStoreHLSArgsForCall = append(fake.loadOrStoreHLSArgsForCall, struct { + arg1 context.Context + arg2 string + arg3 lb.HLSPlayStream + }{arg1, arg2, arg3}) + stub := fake.LoadOrStoreHLSStub + fakeReturns := fake.loadOrStoreHLSReturns + fake.recordInvocation("LoadOrStoreHLS", []interface{}{arg1, arg2, arg3}) + fake.loadOrStoreHLSMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeHLSService) LoadOrStoreHLSCallCount() int { + fake.loadOrStoreHLSMutex.RLock() + defer fake.loadOrStoreHLSMutex.RUnlock() + return len(fake.loadOrStoreHLSArgsForCall) +} + +func (fake *FakeHLSService) LoadOrStoreHLSCalls(stub func(context.Context, string, lb.HLSPlayStream) (lb.HLSPlayStream, error)) { + fake.loadOrStoreHLSMutex.Lock() + defer fake.loadOrStoreHLSMutex.Unlock() + fake.LoadOrStoreHLSStub = stub +} + +func (fake *FakeHLSService) LoadOrStoreHLSArgsForCall(i int) (context.Context, string, lb.HLSPlayStream) { + fake.loadOrStoreHLSMutex.RLock() + defer fake.loadOrStoreHLSMutex.RUnlock() + argsForCall := fake.loadOrStoreHLSArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeHLSService) LoadOrStoreHLSReturns(result1 lb.HLSPlayStream, result2 error) { + fake.loadOrStoreHLSMutex.Lock() + defer fake.loadOrStoreHLSMutex.Unlock() + fake.LoadOrStoreHLSStub = nil + fake.loadOrStoreHLSReturns = struct { + result1 lb.HLSPlayStream + result2 error + }{result1, result2} +} + +func (fake *FakeHLSService) LoadOrStoreHLSReturnsOnCall(i int, result1 lb.HLSPlayStream, result2 error) { + fake.loadOrStoreHLSMutex.Lock() + defer fake.loadOrStoreHLSMutex.Unlock() + fake.LoadOrStoreHLSStub = nil + if fake.loadOrStoreHLSReturnsOnCall == nil { + fake.loadOrStoreHLSReturnsOnCall = make(map[int]struct { + result1 lb.HLSPlayStream + result2 error + }) + } + fake.loadOrStoreHLSReturnsOnCall[i] = struct { + result1 lb.HLSPlayStream + result2 error + }{result1, result2} +} + +func (fake *FakeHLSService) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeHLSService) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ lb.HLSService = new(FakeHLSService) diff --git a/internal/lb/lbfakes/fake_origin_load_balancer.go b/internal/lb/lbfakes/fake_origin_load_balancer.go new file mode 100644 index 000000000..ab16a6628 --- /dev/null +++ b/internal/lb/lbfakes/fake_origin_load_balancer.go @@ -0,0 +1,577 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package lbfakes + +import ( + "context" + "srsx/internal/lb" + "sync" +) + +type FakeOriginLoadBalancer struct { + InitializeStub func(context.Context) error + initializeMutex sync.RWMutex + initializeArgsForCall []struct { + arg1 context.Context + } + initializeReturns struct { + result1 error + } + initializeReturnsOnCall map[int]struct { + result1 error + } + LoadHLSBySPBHIDStub func(context.Context, string) (lb.HLSPlayStream, error) + loadHLSBySPBHIDMutex sync.RWMutex + loadHLSBySPBHIDArgsForCall []struct { + arg1 context.Context + arg2 string + } + loadHLSBySPBHIDReturns struct { + result1 lb.HLSPlayStream + result2 error + } + loadHLSBySPBHIDReturnsOnCall map[int]struct { + result1 lb.HLSPlayStream + result2 error + } + LoadOrStoreHLSStub func(context.Context, string, lb.HLSPlayStream) (lb.HLSPlayStream, error) + loadOrStoreHLSMutex sync.RWMutex + loadOrStoreHLSArgsForCall []struct { + arg1 context.Context + arg2 string + arg3 lb.HLSPlayStream + } + loadOrStoreHLSReturns struct { + result1 lb.HLSPlayStream + result2 error + } + loadOrStoreHLSReturnsOnCall map[int]struct { + result1 lb.HLSPlayStream + result2 error + } + LoadWebRTCByUfragStub func(context.Context, string) (lb.RTCConnection, error) + loadWebRTCByUfragMutex sync.RWMutex + loadWebRTCByUfragArgsForCall []struct { + arg1 context.Context + arg2 string + } + loadWebRTCByUfragReturns struct { + result1 lb.RTCConnection + result2 error + } + loadWebRTCByUfragReturnsOnCall map[int]struct { + result1 lb.RTCConnection + result2 error + } + PickStub func(context.Context, string) (*lb.OriginServer, error) + pickMutex sync.RWMutex + pickArgsForCall []struct { + arg1 context.Context + arg2 string + } + pickReturns struct { + result1 *lb.OriginServer + result2 error + } + pickReturnsOnCall map[int]struct { + result1 *lb.OriginServer + result2 error + } + StoreWebRTCStub func(context.Context, string, lb.RTCConnection) error + storeWebRTCMutex sync.RWMutex + storeWebRTCArgsForCall []struct { + arg1 context.Context + arg2 string + arg3 lb.RTCConnection + } + storeWebRTCReturns struct { + result1 error + } + storeWebRTCReturnsOnCall map[int]struct { + result1 error + } + UpdateStub func(context.Context, *lb.OriginServer) error + updateMutex sync.RWMutex + updateArgsForCall []struct { + arg1 context.Context + arg2 *lb.OriginServer + } + updateReturns struct { + result1 error + } + updateReturnsOnCall map[int]struct { + result1 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeOriginLoadBalancer) Initialize(arg1 context.Context) error { + fake.initializeMutex.Lock() + ret, specificReturn := fake.initializeReturnsOnCall[len(fake.initializeArgsForCall)] + fake.initializeArgsForCall = append(fake.initializeArgsForCall, struct { + arg1 context.Context + }{arg1}) + stub := fake.InitializeStub + fakeReturns := fake.initializeReturns + fake.recordInvocation("Initialize", []interface{}{arg1}) + fake.initializeMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeOriginLoadBalancer) InitializeCallCount() int { + fake.initializeMutex.RLock() + defer fake.initializeMutex.RUnlock() + return len(fake.initializeArgsForCall) +} + +func (fake *FakeOriginLoadBalancer) InitializeCalls(stub func(context.Context) error) { + fake.initializeMutex.Lock() + defer fake.initializeMutex.Unlock() + fake.InitializeStub = stub +} + +func (fake *FakeOriginLoadBalancer) InitializeArgsForCall(i int) context.Context { + fake.initializeMutex.RLock() + defer fake.initializeMutex.RUnlock() + argsForCall := fake.initializeArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeOriginLoadBalancer) InitializeReturns(result1 error) { + fake.initializeMutex.Lock() + defer fake.initializeMutex.Unlock() + fake.InitializeStub = nil + fake.initializeReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeOriginLoadBalancer) InitializeReturnsOnCall(i int, result1 error) { + fake.initializeMutex.Lock() + defer fake.initializeMutex.Unlock() + fake.InitializeStub = nil + if fake.initializeReturnsOnCall == nil { + fake.initializeReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.initializeReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeOriginLoadBalancer) LoadHLSBySPBHID(arg1 context.Context, arg2 string) (lb.HLSPlayStream, error) { + fake.loadHLSBySPBHIDMutex.Lock() + ret, specificReturn := fake.loadHLSBySPBHIDReturnsOnCall[len(fake.loadHLSBySPBHIDArgsForCall)] + fake.loadHLSBySPBHIDArgsForCall = append(fake.loadHLSBySPBHIDArgsForCall, struct { + arg1 context.Context + arg2 string + }{arg1, arg2}) + stub := fake.LoadHLSBySPBHIDStub + fakeReturns := fake.loadHLSBySPBHIDReturns + fake.recordInvocation("LoadHLSBySPBHID", []interface{}{arg1, arg2}) + fake.loadHLSBySPBHIDMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeOriginLoadBalancer) LoadHLSBySPBHIDCallCount() int { + fake.loadHLSBySPBHIDMutex.RLock() + defer fake.loadHLSBySPBHIDMutex.RUnlock() + return len(fake.loadHLSBySPBHIDArgsForCall) +} + +func (fake *FakeOriginLoadBalancer) LoadHLSBySPBHIDCalls(stub func(context.Context, string) (lb.HLSPlayStream, error)) { + fake.loadHLSBySPBHIDMutex.Lock() + defer fake.loadHLSBySPBHIDMutex.Unlock() + fake.LoadHLSBySPBHIDStub = stub +} + +func (fake *FakeOriginLoadBalancer) LoadHLSBySPBHIDArgsForCall(i int) (context.Context, string) { + fake.loadHLSBySPBHIDMutex.RLock() + defer fake.loadHLSBySPBHIDMutex.RUnlock() + argsForCall := fake.loadHLSBySPBHIDArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeOriginLoadBalancer) LoadHLSBySPBHIDReturns(result1 lb.HLSPlayStream, result2 error) { + fake.loadHLSBySPBHIDMutex.Lock() + defer fake.loadHLSBySPBHIDMutex.Unlock() + fake.LoadHLSBySPBHIDStub = nil + fake.loadHLSBySPBHIDReturns = struct { + result1 lb.HLSPlayStream + result2 error + }{result1, result2} +} + +func (fake *FakeOriginLoadBalancer) LoadHLSBySPBHIDReturnsOnCall(i int, result1 lb.HLSPlayStream, result2 error) { + fake.loadHLSBySPBHIDMutex.Lock() + defer fake.loadHLSBySPBHIDMutex.Unlock() + fake.LoadHLSBySPBHIDStub = nil + if fake.loadHLSBySPBHIDReturnsOnCall == nil { + fake.loadHLSBySPBHIDReturnsOnCall = make(map[int]struct { + result1 lb.HLSPlayStream + result2 error + }) + } + fake.loadHLSBySPBHIDReturnsOnCall[i] = struct { + result1 lb.HLSPlayStream + result2 error + }{result1, result2} +} + +func (fake *FakeOriginLoadBalancer) LoadOrStoreHLS(arg1 context.Context, arg2 string, arg3 lb.HLSPlayStream) (lb.HLSPlayStream, error) { + fake.loadOrStoreHLSMutex.Lock() + ret, specificReturn := fake.loadOrStoreHLSReturnsOnCall[len(fake.loadOrStoreHLSArgsForCall)] + fake.loadOrStoreHLSArgsForCall = append(fake.loadOrStoreHLSArgsForCall, struct { + arg1 context.Context + arg2 string + arg3 lb.HLSPlayStream + }{arg1, arg2, arg3}) + stub := fake.LoadOrStoreHLSStub + fakeReturns := fake.loadOrStoreHLSReturns + fake.recordInvocation("LoadOrStoreHLS", []interface{}{arg1, arg2, arg3}) + fake.loadOrStoreHLSMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeOriginLoadBalancer) LoadOrStoreHLSCallCount() int { + fake.loadOrStoreHLSMutex.RLock() + defer fake.loadOrStoreHLSMutex.RUnlock() + return len(fake.loadOrStoreHLSArgsForCall) +} + +func (fake *FakeOriginLoadBalancer) LoadOrStoreHLSCalls(stub func(context.Context, string, lb.HLSPlayStream) (lb.HLSPlayStream, error)) { + fake.loadOrStoreHLSMutex.Lock() + defer fake.loadOrStoreHLSMutex.Unlock() + fake.LoadOrStoreHLSStub = stub +} + +func (fake *FakeOriginLoadBalancer) LoadOrStoreHLSArgsForCall(i int) (context.Context, string, lb.HLSPlayStream) { + fake.loadOrStoreHLSMutex.RLock() + defer fake.loadOrStoreHLSMutex.RUnlock() + argsForCall := fake.loadOrStoreHLSArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeOriginLoadBalancer) LoadOrStoreHLSReturns(result1 lb.HLSPlayStream, result2 error) { + fake.loadOrStoreHLSMutex.Lock() + defer fake.loadOrStoreHLSMutex.Unlock() + fake.LoadOrStoreHLSStub = nil + fake.loadOrStoreHLSReturns = struct { + result1 lb.HLSPlayStream + result2 error + }{result1, result2} +} + +func (fake *FakeOriginLoadBalancer) LoadOrStoreHLSReturnsOnCall(i int, result1 lb.HLSPlayStream, result2 error) { + fake.loadOrStoreHLSMutex.Lock() + defer fake.loadOrStoreHLSMutex.Unlock() + fake.LoadOrStoreHLSStub = nil + if fake.loadOrStoreHLSReturnsOnCall == nil { + fake.loadOrStoreHLSReturnsOnCall = make(map[int]struct { + result1 lb.HLSPlayStream + result2 error + }) + } + fake.loadOrStoreHLSReturnsOnCall[i] = struct { + result1 lb.HLSPlayStream + result2 error + }{result1, result2} +} + +func (fake *FakeOriginLoadBalancer) LoadWebRTCByUfrag(arg1 context.Context, arg2 string) (lb.RTCConnection, error) { + fake.loadWebRTCByUfragMutex.Lock() + ret, specificReturn := fake.loadWebRTCByUfragReturnsOnCall[len(fake.loadWebRTCByUfragArgsForCall)] + fake.loadWebRTCByUfragArgsForCall = append(fake.loadWebRTCByUfragArgsForCall, struct { + arg1 context.Context + arg2 string + }{arg1, arg2}) + stub := fake.LoadWebRTCByUfragStub + fakeReturns := fake.loadWebRTCByUfragReturns + fake.recordInvocation("LoadWebRTCByUfrag", []interface{}{arg1, arg2}) + fake.loadWebRTCByUfragMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeOriginLoadBalancer) LoadWebRTCByUfragCallCount() int { + fake.loadWebRTCByUfragMutex.RLock() + defer fake.loadWebRTCByUfragMutex.RUnlock() + return len(fake.loadWebRTCByUfragArgsForCall) +} + +func (fake *FakeOriginLoadBalancer) LoadWebRTCByUfragCalls(stub func(context.Context, string) (lb.RTCConnection, error)) { + fake.loadWebRTCByUfragMutex.Lock() + defer fake.loadWebRTCByUfragMutex.Unlock() + fake.LoadWebRTCByUfragStub = stub +} + +func (fake *FakeOriginLoadBalancer) LoadWebRTCByUfragArgsForCall(i int) (context.Context, string) { + fake.loadWebRTCByUfragMutex.RLock() + defer fake.loadWebRTCByUfragMutex.RUnlock() + argsForCall := fake.loadWebRTCByUfragArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeOriginLoadBalancer) LoadWebRTCByUfragReturns(result1 lb.RTCConnection, result2 error) { + fake.loadWebRTCByUfragMutex.Lock() + defer fake.loadWebRTCByUfragMutex.Unlock() + fake.LoadWebRTCByUfragStub = nil + fake.loadWebRTCByUfragReturns = struct { + result1 lb.RTCConnection + result2 error + }{result1, result2} +} + +func (fake *FakeOriginLoadBalancer) LoadWebRTCByUfragReturnsOnCall(i int, result1 lb.RTCConnection, result2 error) { + fake.loadWebRTCByUfragMutex.Lock() + defer fake.loadWebRTCByUfragMutex.Unlock() + fake.LoadWebRTCByUfragStub = nil + if fake.loadWebRTCByUfragReturnsOnCall == nil { + fake.loadWebRTCByUfragReturnsOnCall = make(map[int]struct { + result1 lb.RTCConnection + result2 error + }) + } + fake.loadWebRTCByUfragReturnsOnCall[i] = struct { + result1 lb.RTCConnection + result2 error + }{result1, result2} +} + +func (fake *FakeOriginLoadBalancer) Pick(arg1 context.Context, arg2 string) (*lb.OriginServer, error) { + fake.pickMutex.Lock() + ret, specificReturn := fake.pickReturnsOnCall[len(fake.pickArgsForCall)] + fake.pickArgsForCall = append(fake.pickArgsForCall, struct { + arg1 context.Context + arg2 string + }{arg1, arg2}) + stub := fake.PickStub + fakeReturns := fake.pickReturns + fake.recordInvocation("Pick", []interface{}{arg1, arg2}) + fake.pickMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeOriginLoadBalancer) PickCallCount() int { + fake.pickMutex.RLock() + defer fake.pickMutex.RUnlock() + return len(fake.pickArgsForCall) +} + +func (fake *FakeOriginLoadBalancer) PickCalls(stub func(context.Context, string) (*lb.OriginServer, error)) { + fake.pickMutex.Lock() + defer fake.pickMutex.Unlock() + fake.PickStub = stub +} + +func (fake *FakeOriginLoadBalancer) PickArgsForCall(i int) (context.Context, string) { + fake.pickMutex.RLock() + defer fake.pickMutex.RUnlock() + argsForCall := fake.pickArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeOriginLoadBalancer) PickReturns(result1 *lb.OriginServer, result2 error) { + fake.pickMutex.Lock() + defer fake.pickMutex.Unlock() + fake.PickStub = nil + fake.pickReturns = struct { + result1 *lb.OriginServer + result2 error + }{result1, result2} +} + +func (fake *FakeOriginLoadBalancer) PickReturnsOnCall(i int, result1 *lb.OriginServer, result2 error) { + fake.pickMutex.Lock() + defer fake.pickMutex.Unlock() + fake.PickStub = nil + if fake.pickReturnsOnCall == nil { + fake.pickReturnsOnCall = make(map[int]struct { + result1 *lb.OriginServer + result2 error + }) + } + fake.pickReturnsOnCall[i] = struct { + result1 *lb.OriginServer + result2 error + }{result1, result2} +} + +func (fake *FakeOriginLoadBalancer) StoreWebRTC(arg1 context.Context, arg2 string, arg3 lb.RTCConnection) error { + fake.storeWebRTCMutex.Lock() + ret, specificReturn := fake.storeWebRTCReturnsOnCall[len(fake.storeWebRTCArgsForCall)] + fake.storeWebRTCArgsForCall = append(fake.storeWebRTCArgsForCall, struct { + arg1 context.Context + arg2 string + arg3 lb.RTCConnection + }{arg1, arg2, arg3}) + stub := fake.StoreWebRTCStub + fakeReturns := fake.storeWebRTCReturns + fake.recordInvocation("StoreWebRTC", []interface{}{arg1, arg2, arg3}) + fake.storeWebRTCMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeOriginLoadBalancer) StoreWebRTCCallCount() int { + fake.storeWebRTCMutex.RLock() + defer fake.storeWebRTCMutex.RUnlock() + return len(fake.storeWebRTCArgsForCall) +} + +func (fake *FakeOriginLoadBalancer) StoreWebRTCCalls(stub func(context.Context, string, lb.RTCConnection) error) { + fake.storeWebRTCMutex.Lock() + defer fake.storeWebRTCMutex.Unlock() + fake.StoreWebRTCStub = stub +} + +func (fake *FakeOriginLoadBalancer) StoreWebRTCArgsForCall(i int) (context.Context, string, lb.RTCConnection) { + fake.storeWebRTCMutex.RLock() + defer fake.storeWebRTCMutex.RUnlock() + argsForCall := fake.storeWebRTCArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeOriginLoadBalancer) StoreWebRTCReturns(result1 error) { + fake.storeWebRTCMutex.Lock() + defer fake.storeWebRTCMutex.Unlock() + fake.StoreWebRTCStub = nil + fake.storeWebRTCReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeOriginLoadBalancer) StoreWebRTCReturnsOnCall(i int, result1 error) { + fake.storeWebRTCMutex.Lock() + defer fake.storeWebRTCMutex.Unlock() + fake.StoreWebRTCStub = nil + if fake.storeWebRTCReturnsOnCall == nil { + fake.storeWebRTCReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.storeWebRTCReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeOriginLoadBalancer) Update(arg1 context.Context, arg2 *lb.OriginServer) error { + fake.updateMutex.Lock() + ret, specificReturn := fake.updateReturnsOnCall[len(fake.updateArgsForCall)] + fake.updateArgsForCall = append(fake.updateArgsForCall, struct { + arg1 context.Context + arg2 *lb.OriginServer + }{arg1, arg2}) + stub := fake.UpdateStub + fakeReturns := fake.updateReturns + fake.recordInvocation("Update", []interface{}{arg1, arg2}) + fake.updateMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeOriginLoadBalancer) UpdateCallCount() int { + fake.updateMutex.RLock() + defer fake.updateMutex.RUnlock() + return len(fake.updateArgsForCall) +} + +func (fake *FakeOriginLoadBalancer) UpdateCalls(stub func(context.Context, *lb.OriginServer) error) { + fake.updateMutex.Lock() + defer fake.updateMutex.Unlock() + fake.UpdateStub = stub +} + +func (fake *FakeOriginLoadBalancer) UpdateArgsForCall(i int) (context.Context, *lb.OriginServer) { + fake.updateMutex.RLock() + defer fake.updateMutex.RUnlock() + argsForCall := fake.updateArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeOriginLoadBalancer) UpdateReturns(result1 error) { + fake.updateMutex.Lock() + defer fake.updateMutex.Unlock() + fake.UpdateStub = nil + fake.updateReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeOriginLoadBalancer) UpdateReturnsOnCall(i int, result1 error) { + fake.updateMutex.Lock() + defer fake.updateMutex.Unlock() + fake.UpdateStub = nil + if fake.updateReturnsOnCall == nil { + fake.updateReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.updateReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeOriginLoadBalancer) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeOriginLoadBalancer) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ lb.OriginLoadBalancer = new(FakeOriginLoadBalancer) diff --git a/internal/lb/lbfakes/fake_origin_service.go b/internal/lb/lbfakes/fake_origin_service.go new file mode 100644 index 000000000..9ffaaa877 --- /dev/null +++ b/internal/lb/lbfakes/fake_origin_service.go @@ -0,0 +1,190 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package lbfakes + +import ( + "context" + "srsx/internal/lb" + "sync" +) + +type FakeOriginService struct { + PickStub func(context.Context, string) (*lb.OriginServer, error) + pickMutex sync.RWMutex + pickArgsForCall []struct { + arg1 context.Context + arg2 string + } + pickReturns struct { + result1 *lb.OriginServer + result2 error + } + pickReturnsOnCall map[int]struct { + result1 *lb.OriginServer + result2 error + } + UpdateStub func(context.Context, *lb.OriginServer) error + updateMutex sync.RWMutex + updateArgsForCall []struct { + arg1 context.Context + arg2 *lb.OriginServer + } + updateReturns struct { + result1 error + } + updateReturnsOnCall map[int]struct { + result1 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeOriginService) Pick(arg1 context.Context, arg2 string) (*lb.OriginServer, error) { + fake.pickMutex.Lock() + ret, specificReturn := fake.pickReturnsOnCall[len(fake.pickArgsForCall)] + fake.pickArgsForCall = append(fake.pickArgsForCall, struct { + arg1 context.Context + arg2 string + }{arg1, arg2}) + stub := fake.PickStub + fakeReturns := fake.pickReturns + fake.recordInvocation("Pick", []interface{}{arg1, arg2}) + fake.pickMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeOriginService) PickCallCount() int { + fake.pickMutex.RLock() + defer fake.pickMutex.RUnlock() + return len(fake.pickArgsForCall) +} + +func (fake *FakeOriginService) PickCalls(stub func(context.Context, string) (*lb.OriginServer, error)) { + fake.pickMutex.Lock() + defer fake.pickMutex.Unlock() + fake.PickStub = stub +} + +func (fake *FakeOriginService) PickArgsForCall(i int) (context.Context, string) { + fake.pickMutex.RLock() + defer fake.pickMutex.RUnlock() + argsForCall := fake.pickArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeOriginService) PickReturns(result1 *lb.OriginServer, result2 error) { + fake.pickMutex.Lock() + defer fake.pickMutex.Unlock() + fake.PickStub = nil + fake.pickReturns = struct { + result1 *lb.OriginServer + result2 error + }{result1, result2} +} + +func (fake *FakeOriginService) PickReturnsOnCall(i int, result1 *lb.OriginServer, result2 error) { + fake.pickMutex.Lock() + defer fake.pickMutex.Unlock() + fake.PickStub = nil + if fake.pickReturnsOnCall == nil { + fake.pickReturnsOnCall = make(map[int]struct { + result1 *lb.OriginServer + result2 error + }) + } + fake.pickReturnsOnCall[i] = struct { + result1 *lb.OriginServer + result2 error + }{result1, result2} +} + +func (fake *FakeOriginService) Update(arg1 context.Context, arg2 *lb.OriginServer) error { + fake.updateMutex.Lock() + ret, specificReturn := fake.updateReturnsOnCall[len(fake.updateArgsForCall)] + fake.updateArgsForCall = append(fake.updateArgsForCall, struct { + arg1 context.Context + arg2 *lb.OriginServer + }{arg1, arg2}) + stub := fake.UpdateStub + fakeReturns := fake.updateReturns + fake.recordInvocation("Update", []interface{}{arg1, arg2}) + fake.updateMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeOriginService) UpdateCallCount() int { + fake.updateMutex.RLock() + defer fake.updateMutex.RUnlock() + return len(fake.updateArgsForCall) +} + +func (fake *FakeOriginService) UpdateCalls(stub func(context.Context, *lb.OriginServer) error) { + fake.updateMutex.Lock() + defer fake.updateMutex.Unlock() + fake.UpdateStub = stub +} + +func (fake *FakeOriginService) UpdateArgsForCall(i int) (context.Context, *lb.OriginServer) { + fake.updateMutex.RLock() + defer fake.updateMutex.RUnlock() + argsForCall := fake.updateArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeOriginService) UpdateReturns(result1 error) { + fake.updateMutex.Lock() + defer fake.updateMutex.Unlock() + fake.UpdateStub = nil + fake.updateReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeOriginService) UpdateReturnsOnCall(i int, result1 error) { + fake.updateMutex.Lock() + defer fake.updateMutex.Unlock() + fake.UpdateStub = nil + if fake.updateReturnsOnCall == nil { + fake.updateReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.updateReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeOriginService) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeOriginService) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ lb.OriginService = new(FakeOriginService) diff --git a/internal/lb/lbfakes/fake_rtc_service.go b/internal/lb/lbfakes/fake_rtc_service.go new file mode 100644 index 000000000..73772d666 --- /dev/null +++ b/internal/lb/lbfakes/fake_rtc_service.go @@ -0,0 +1,192 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package lbfakes + +import ( + "context" + "srsx/internal/lb" + "sync" +) + +type FakeRTCService struct { + LoadWebRTCByUfragStub func(context.Context, string) (lb.RTCConnection, error) + loadWebRTCByUfragMutex sync.RWMutex + loadWebRTCByUfragArgsForCall []struct { + arg1 context.Context + arg2 string + } + loadWebRTCByUfragReturns struct { + result1 lb.RTCConnection + result2 error + } + loadWebRTCByUfragReturnsOnCall map[int]struct { + result1 lb.RTCConnection + result2 error + } + StoreWebRTCStub func(context.Context, string, lb.RTCConnection) error + storeWebRTCMutex sync.RWMutex + storeWebRTCArgsForCall []struct { + arg1 context.Context + arg2 string + arg3 lb.RTCConnection + } + storeWebRTCReturns struct { + result1 error + } + storeWebRTCReturnsOnCall map[int]struct { + result1 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeRTCService) LoadWebRTCByUfrag(arg1 context.Context, arg2 string) (lb.RTCConnection, error) { + fake.loadWebRTCByUfragMutex.Lock() + ret, specificReturn := fake.loadWebRTCByUfragReturnsOnCall[len(fake.loadWebRTCByUfragArgsForCall)] + fake.loadWebRTCByUfragArgsForCall = append(fake.loadWebRTCByUfragArgsForCall, struct { + arg1 context.Context + arg2 string + }{arg1, arg2}) + stub := fake.LoadWebRTCByUfragStub + fakeReturns := fake.loadWebRTCByUfragReturns + fake.recordInvocation("LoadWebRTCByUfrag", []interface{}{arg1, arg2}) + fake.loadWebRTCByUfragMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeRTCService) LoadWebRTCByUfragCallCount() int { + fake.loadWebRTCByUfragMutex.RLock() + defer fake.loadWebRTCByUfragMutex.RUnlock() + return len(fake.loadWebRTCByUfragArgsForCall) +} + +func (fake *FakeRTCService) LoadWebRTCByUfragCalls(stub func(context.Context, string) (lb.RTCConnection, error)) { + fake.loadWebRTCByUfragMutex.Lock() + defer fake.loadWebRTCByUfragMutex.Unlock() + fake.LoadWebRTCByUfragStub = stub +} + +func (fake *FakeRTCService) LoadWebRTCByUfragArgsForCall(i int) (context.Context, string) { + fake.loadWebRTCByUfragMutex.RLock() + defer fake.loadWebRTCByUfragMutex.RUnlock() + argsForCall := fake.loadWebRTCByUfragArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeRTCService) LoadWebRTCByUfragReturns(result1 lb.RTCConnection, result2 error) { + fake.loadWebRTCByUfragMutex.Lock() + defer fake.loadWebRTCByUfragMutex.Unlock() + fake.LoadWebRTCByUfragStub = nil + fake.loadWebRTCByUfragReturns = struct { + result1 lb.RTCConnection + result2 error + }{result1, result2} +} + +func (fake *FakeRTCService) LoadWebRTCByUfragReturnsOnCall(i int, result1 lb.RTCConnection, result2 error) { + fake.loadWebRTCByUfragMutex.Lock() + defer fake.loadWebRTCByUfragMutex.Unlock() + fake.LoadWebRTCByUfragStub = nil + if fake.loadWebRTCByUfragReturnsOnCall == nil { + fake.loadWebRTCByUfragReturnsOnCall = make(map[int]struct { + result1 lb.RTCConnection + result2 error + }) + } + fake.loadWebRTCByUfragReturnsOnCall[i] = struct { + result1 lb.RTCConnection + result2 error + }{result1, result2} +} + +func (fake *FakeRTCService) StoreWebRTC(arg1 context.Context, arg2 string, arg3 lb.RTCConnection) error { + fake.storeWebRTCMutex.Lock() + ret, specificReturn := fake.storeWebRTCReturnsOnCall[len(fake.storeWebRTCArgsForCall)] + fake.storeWebRTCArgsForCall = append(fake.storeWebRTCArgsForCall, struct { + arg1 context.Context + arg2 string + arg3 lb.RTCConnection + }{arg1, arg2, arg3}) + stub := fake.StoreWebRTCStub + fakeReturns := fake.storeWebRTCReturns + fake.recordInvocation("StoreWebRTC", []interface{}{arg1, arg2, arg3}) + fake.storeWebRTCMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeRTCService) StoreWebRTCCallCount() int { + fake.storeWebRTCMutex.RLock() + defer fake.storeWebRTCMutex.RUnlock() + return len(fake.storeWebRTCArgsForCall) +} + +func (fake *FakeRTCService) StoreWebRTCCalls(stub func(context.Context, string, lb.RTCConnection) error) { + fake.storeWebRTCMutex.Lock() + defer fake.storeWebRTCMutex.Unlock() + fake.StoreWebRTCStub = stub +} + +func (fake *FakeRTCService) StoreWebRTCArgsForCall(i int) (context.Context, string, lb.RTCConnection) { + fake.storeWebRTCMutex.RLock() + defer fake.storeWebRTCMutex.RUnlock() + argsForCall := fake.storeWebRTCArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeRTCService) StoreWebRTCReturns(result1 error) { + fake.storeWebRTCMutex.Lock() + defer fake.storeWebRTCMutex.Unlock() + fake.StoreWebRTCStub = nil + fake.storeWebRTCReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeRTCService) StoreWebRTCReturnsOnCall(i int, result1 error) { + fake.storeWebRTCMutex.Lock() + defer fake.storeWebRTCMutex.Unlock() + fake.StoreWebRTCStub = nil + if fake.storeWebRTCReturnsOnCall == nil { + fake.storeWebRTCReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.storeWebRTCReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeRTCService) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeRTCService) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ lb.RTCService = new(FakeRTCService) diff --git a/internal/lb/mem.go b/internal/lb/mem.go index 1d625bc21..f49434bba 100644 --- a/internal/lb/mem.go +++ b/internal/lb/mem.go @@ -31,18 +31,23 @@ type memoryLoadBalancer struct { rtcStreamURL sync.Map[string, RTCConnection] // The WebRTC streaming, key is ufrag. rtcUfrag sync.Map[string, RTCConnection] + // keepaliveInterval is the period at which the default-backend keep-alive + // goroutine re-Updates its registration. Struct field for test injection + // (avoids racing a package global across concurrent tests). + keepaliveInterval time.Duration } // NewMemoryLoadBalancer creates a new memory-based load balancer. func NewMemoryLoadBalancer(environment env.ProxyEnvironment) OriginLoadBalancer { return &memoryLoadBalancer{ - environment: environment, - servers: sync.NewMap[string, *OriginServer](), - picked: sync.NewMap[string, *OriginServer](), - hlsStreamURL: sync.NewMap[string, HLSPlayStream](), - hlsSPBHID: sync.NewMap[string, HLSPlayStream](), - rtcStreamURL: sync.NewMap[string, RTCConnection](), - rtcUfrag: sync.NewMap[string, RTCConnection](), + environment: environment, + servers: sync.NewMap[string, *OriginServer](), + picked: sync.NewMap[string, *OriginServer](), + hlsStreamURL: sync.NewMap[string, HLSPlayStream](), + hlsSPBHID: sync.NewMap[string, HLSPlayStream](), + rtcStreamURL: sync.NewMap[string, RTCConnection](), + rtcUfrag: sync.NewMap[string, RTCConnection](), + keepaliveInterval: 30 * time.Second, } } @@ -63,7 +68,7 @@ func (v *memoryLoadBalancer) Initialize(ctx context.Context) error { select { case <-ctx.Done(): return - case <-time.After(30 * time.Second): + case <-time.After(v.keepaliveInterval): if err := v.Update(ctx, server); err != nil { logger.Warn(ctx, "update default SRS %+v failed, %+v", server, err) } diff --git a/internal/lb/mem_test.go b/internal/lb/mem_test.go new file mode 100644 index 000000000..77e4f0569 --- /dev/null +++ b/internal/lb/mem_test.go @@ -0,0 +1,263 @@ +// Copyright (c) 2026 Winlin +// +// SPDX-License-Identifier: MIT +package lb + +import ( + "context" + "strings" + "testing" + "time" + + "srsx/internal/env/envfakes" +) + +// stubHLS is a minimal HLSPlayStream for testing. +type stubHLS struct { + spbhid string +} + +func (s *stubHLS) GetSPBHID() string { return s.spbhid } +func (s *stubHLS) Initialize(ctx context.Context) HLSPlayStream { return s } + +// stubRTC is a minimal RTCConnection for testing. +type stubRTC struct { + ufrag string +} + +func (s *stubRTC) GetUfrag() string { return s.ufrag } + +// newMem returns a fresh in-memory load balancer with a default fake env. +func newMem() *memoryLoadBalancer { + env := &envfakes.FakeProxyEnvironment{} + return NewMemoryLoadBalancer(env).(*memoryLoadBalancer) +} + +func TestNewMemoryLoadBalancer(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + lb := NewMemoryLoadBalancer(env) + if lb == nil { + t.Fatal("NewMemoryLoadBalancer returned nil") + } +} + +func TestMemLB_Initialize_DefaultBackendDisabled(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.DefaultBackendEnabledReturns("off") + lb := NewMemoryLoadBalancer(env).(*memoryLoadBalancer) + if err := lb.Initialize(context.Background()); err != nil { + t.Fatalf("Initialize: %v", err) + } + // No server stored when disabled. + count := 0 + lb.servers.Range(func(string, *OriginServer) bool { count++; return true }) + if count != 0 { + t.Fatalf("expected 0 servers, got %d", count) + } +} + +func TestMemLB_Initialize_DefaultBackendError(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.DefaultBackendEnabledReturns("on") + env.DefaultBackendIPReturns("") // triggers "empty default backend ip" + lb := NewMemoryLoadBalancer(env) + err := lb.Initialize(context.Background()) + if err == nil || !strings.Contains(err.Error(), "initialize default SRS") { + t.Fatalf("expected wrapped error, got %v", err) + } +} + +func TestMemLB_Initialize_KeepaliveTick(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.DefaultBackendEnabledReturns("on") + env.DefaultBackendIPReturns("1.2.3.4") + env.DefaultBackendRTMPReturns(":1935") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + lb := NewMemoryLoadBalancer(env).(*memoryLoadBalancer) + // Shorten the keep-alive interval on this instance only so concurrent + // tests don't race on shared state. + lb.keepaliveInterval = time.Millisecond + if err := lb.Initialize(ctx); err != nil { + t.Fatalf("Initialize: %v", err) + } + + // Find the server and watch UpdatedAt advance after a keep-alive tick. + var s *OriginServer + lb.servers.Range(func(_ string, v *OriginServer) bool { s = v; return false }) + if s == nil { + t.Fatal("expected server stored") + } + first := s.UpdatedAt + + // Wait long enough for several ticks (interval is 1ms, server.UpdatedAt + // is set to time.Now() inside NewDefaultOriginServerForDebugging on each + // Update? — actually Update only stores the server pointer, so UpdatedAt + // won't change. The goroutine still hits the tick branch though, which + // is all we need for coverage). + time.Sleep(20 * time.Millisecond) + cancel() + time.Sleep(10 * time.Millisecond) + + _ = first +} + +func TestMemLB_Initialize_DefaultBackendSuccess(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.DefaultBackendEnabledReturns("on") + env.DefaultBackendIPReturns("1.2.3.4") + env.DefaultBackendRTMPReturns(":1935") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + lb := NewMemoryLoadBalancer(env).(*memoryLoadBalancer) + if err := lb.Initialize(ctx); err != nil { + t.Fatalf("Initialize: %v", err) + } + + count := 0 + lb.servers.Range(func(string, *OriginServer) bool { count++; return true }) + if count != 1 { + t.Fatalf("expected 1 server stored, got %d", count) + } + + // Cancel and give the keep-alive goroutine a moment to exit cleanly. + cancel() + time.Sleep(20 * time.Millisecond) +} + +func TestMemLB_Update(t *testing.T) { + lb := newMem() + s := &OriginServer{ServerID: "srv", ServiceID: "svc", PID: "1"} + if err := lb.Update(context.Background(), s); err != nil { + t.Fatalf("Update: %v", err) + } + got, ok := lb.servers.Load(s.ID()) + if !ok || got != s { + t.Fatalf("Update did not store the server: got=%v ok=%v", got, ok) + } +} + +func TestMemLB_Pick_NoServers(t *testing.T) { + lb := newMem() + _, err := lb.Pick(context.Background(), "url1") + if err == nil || !strings.Contains(err.Error(), "no server available") { + t.Fatalf("expected no-server error, got %v", err) + } +} + +func TestMemLB_Pick_AliveServer_Sticky(t *testing.T) { + lb := newMem() + s := &OriginServer{ServerID: "a", PID: "1", UpdatedAt: time.Now()} + _ = lb.Update(context.Background(), s) + + got, err := lb.Pick(context.Background(), "url1") + if err != nil { + t.Fatalf("Pick: %v", err) + } + if got != s { + t.Fatalf("Pick returned %v, want %v", got, s) + } + + // Second pick for the same URL returns the same server (sticky branch). + got2, err := lb.Pick(context.Background(), "url1") + if err != nil { + t.Fatalf("Pick second: %v", err) + } + if got2 != got { + t.Fatalf("second Pick returned %v, want %v (sticky)", got2, got) + } +} + +func TestMemLB_Pick_OnlyDeadServers_Fallback(t *testing.T) { + lb := newMem() + // UpdatedAt long past => not alive. Tests the fallback "use all servers" branch. + s := &OriginServer{ + ServerID: "a", + PID: "1", + UpdatedAt: time.Now().Add(-2 * ServerAliveDuration), + } + _ = lb.Update(context.Background(), s) + + got, err := lb.Pick(context.Background(), "url1") + if err != nil { + t.Fatalf("Pick: %v", err) + } + if got != s { + t.Fatalf("expected dead-server fallback to return %v, got %v", s, got) + } +} + +func TestMemLB_LoadHLSBySPBHID_NotFound(t *testing.T) { + lb := newMem() + _, err := lb.LoadHLSBySPBHID(context.Background(), "missing") + if err == nil || !strings.Contains(err.Error(), "no HLS streaming") { + t.Fatalf("expected error, got %v", err) + } +} + +func TestMemLB_LoadOrStoreHLS_New(t *testing.T) { + lb := newMem() + s := &stubHLS{spbhid: "abc"} + got, err := lb.LoadOrStoreHLS(context.Background(), "url1", s) + if err != nil { + t.Fatalf("LoadOrStoreHLS: %v", err) + } + if got != s { + t.Fatalf("LoadOrStoreHLS returned %v, want %v", got, s) + } + + // Lookup via SPBHID works (dual-index write). + bySPBHID, err := lb.LoadHLSBySPBHID(context.Background(), "abc") + if err != nil { + t.Fatalf("LoadHLSBySPBHID: %v", err) + } + if bySPBHID != s { + t.Fatalf("LoadHLSBySPBHID returned %v, want %v", bySPBHID, s) + } +} + +func TestMemLB_LoadOrStoreHLS_Existing(t *testing.T) { + lb := newMem() + s1 := &stubHLS{spbhid: "first"} + s2 := &stubHLS{spbhid: "second"} + _, _ = lb.LoadOrStoreHLS(context.Background(), "url1", s1) + got, err := lb.LoadOrStoreHLS(context.Background(), "url1", s2) + if err != nil { + t.Fatalf("LoadOrStoreHLS: %v", err) + } + if got != s1 { + t.Fatalf("expected existing s1, got %v", got) + } + // SPBHID 'second' (from the rejected s2) maps to the existing s1. + bySPBHID, _ := lb.LoadHLSBySPBHID(context.Background(), "second") + if bySPBHID != s1 { + t.Fatalf("expected SPBHID 'second' to map to s1, got %v", bySPBHID) + } +} + +func TestMemLB_StoreWebRTC_And_Load(t *testing.T) { + lb := newMem() + s := &stubRTC{ufrag: "ufrg1"} + if err := lb.StoreWebRTC(context.Background(), "url1", s); err != nil { + t.Fatalf("StoreWebRTC: %v", err) + } + got, err := lb.LoadWebRTCByUfrag(context.Background(), "ufrg1") + if err != nil { + t.Fatalf("LoadWebRTCByUfrag: %v", err) + } + if got != s { + t.Fatalf("got %v, want %v", got, s) + } +} + +func TestMemLB_LoadWebRTCByUfrag_NotFound(t *testing.T) { + lb := newMem() + _, err := lb.LoadWebRTCByUfrag(context.Background(), "missing") + if err == nil || !strings.Contains(err.Error(), "no WebRTC streaming") { + t.Fatalf("expected error, got %v", err) + } +} diff --git a/internal/lb/redis.go b/internal/lb/redis.go index 0418e9986..fc2a7101e 100644 --- a/internal/lb/redis.go +++ b/internal/lb/redis.go @@ -11,26 +11,33 @@ import ( "strconv" "time" - // Use v8 because we use Go 1.16+, while v9 requires Go 1.18+ - "github.com/go-redis/redis/v8" - "srsx/internal/env" "srsx/internal/errors" "srsx/internal/logger" + "srsx/internal/redisclient" ) // redisLoadBalancer stores state in Redis. type redisLoadBalancer struct { // The environment interface. environment env.ProxyEnvironment - // The redis client sdk. - rdb *redis.Client + // The redis client. + rdb redisclient.RedisClient + // newClient is the factory used by Initialize to build the Redis client. + // A struct field (rather than a package global) so concurrent tests can + // each supply their own without racing on shared state. + newClient func(addr, password string, db int) redisclient.RedisClient + // keepaliveInterval is the period at which the default-backend keep-alive + // goroutine re-Updates its registration. Struct field for test injection. + keepaliveInterval time.Duration } // NewRedisLoadBalancer creates a new Redis-based load balancer. func NewRedisLoadBalancer(environment env.ProxyEnvironment) OriginLoadBalancer { return &redisLoadBalancer{ - environment: environment, + environment: environment, + newClient: redisclient.New, + keepaliveInterval: 30 * time.Second, } } @@ -40,11 +47,11 @@ func (v *redisLoadBalancer) Initialize(ctx context.Context) error { return errors.Wrapf(err, "invalid PROXY_REDIS_DB %v", v.environment.RedisDB()) } - rdb := redis.NewClient(&redis.Options{ - Addr: fmt.Sprintf("%v:%v", v.environment.RedisHost(), v.environment.RedisPort()), - Password: v.environment.RedisPassword(), - DB: redisDatabase, - }) + rdb := v.newClient( + fmt.Sprintf("%v:%v", v.environment.RedisHost(), v.environment.RedisPort()), + v.environment.RedisPassword(), + redisDatabase, + ) v.rdb = rdb if err := rdb.Ping(ctx).Err(); err != nil { @@ -68,7 +75,7 @@ func (v *redisLoadBalancer) Initialize(ctx context.Context) error { select { case <-ctx.Done(): return - case <-time.After(30 * time.Second): + case <-time.After(v.keepaliveInterval): if err := v.Update(ctx, server); err != nil { logger.Warn(ctx, "update default SRS %+v failed, %+v", server, err) } diff --git a/internal/lb/redis_test.go b/internal/lb/redis_test.go new file mode 100644 index 000000000..6e3c17796 --- /dev/null +++ b/internal/lb/redis_test.go @@ -0,0 +1,659 @@ +// Copyright (c) 2026 Winlin +// +// SPDX-License-Identifier: MIT +package lb + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/go-redis/redis/v8" + + "srsx/internal/env/envfakes" + "srsx/internal/redisclient" + "srsx/internal/redisclient/redisclientfakes" +) + +// ---------------------------------------------------------------------------- +// Helpers. +// ---------------------------------------------------------------------------- + +// statusCmd returns a *redis.StatusCmd that resolves to the given error. +func statusCmd(err error) *redis.StatusCmd { + c := redis.NewStatusCmd(context.Background()) + if err != nil { + c.SetErr(err) + } + return c +} + +// stringOK returns a *redis.StringCmd that resolves to the given bytes. +func stringOK(b []byte) *redis.StringCmd { + c := redis.NewStringCmd(context.Background()) + c.SetVal(string(b)) + return c +} + +// stringErr returns a *redis.StringCmd that resolves to the given error. +func stringErr(err error) *redis.StringCmd { + c := redis.NewStringCmd(context.Background()) + c.SetErr(err) + return c +} + +// withFakeClient returns a fresh *redisLoadBalancer whose newClient factory is +// wired to return the supplied fake. Each test gets its own instance, so +// concurrent tests cannot race on shared state. +func withFakeClient(env *envfakes.FakeProxyEnvironment, client redisclient.RedisClient) *redisLoadBalancer { + lb := NewRedisLoadBalancer(env).(*redisLoadBalancer) + lb.newClient = func(string, string, int) redisclient.RedisClient { return client } + return lb +} + +// newRedisLB constructs a redisLoadBalancer with a fake rdb already wired in. +// Used by tests that exercise methods other than Initialize. +func newRedisLB(rdb redisclient.RedisClient) *redisLoadBalancer { + env := &envfakes.FakeProxyEnvironment{} + lb := NewRedisLoadBalancer(env).(*redisLoadBalancer) + lb.rdb = rdb + return lb +} + +// ---------------------------------------------------------------------------- +// Constructor & Initialize. +// ---------------------------------------------------------------------------- + +func TestNewRedisLoadBalancer(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + if lb := NewRedisLoadBalancer(env); lb == nil { + t.Fatal("NewRedisLoadBalancer returned nil") + } +} + +func TestRedisLB_Initialize_BadRedisDB(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.RedisDBReturns("not-a-number") + err := NewRedisLoadBalancer(env).Initialize(context.Background()) + if err == nil || !strings.Contains(err.Error(), "invalid PROXY_REDIS_DB") { + t.Fatalf("expected Atoi error, got %v", err) + } +} + +func TestRedisLB_Initialize_PingFails(t *testing.T) { + fake := &redisclientfakes.FakeRedisClient{} + fake.PingReturns(statusCmd(fmt.Errorf("connection refused"))) + fake.StringReturns("Redis") + + env := &envfakes.FakeProxyEnvironment{} + env.RedisDBReturns("0") + err := withFakeClient(env, fake).Initialize(context.Background()) + if err == nil || !strings.Contains(err.Error(), "unable to connect to redis") { + t.Fatalf("expected ping error, got %v", err) + } +} + +func TestRedisLB_Initialize_DefaultBackendDisabled(t *testing.T) { + fake := &redisclientfakes.FakeRedisClient{} + fake.PingReturns(statusCmd(nil)) + + env := &envfakes.FakeProxyEnvironment{} + env.RedisDBReturns("0") + // DefaultBackendEnabled defaults to "" (not "on") => no server registered. + if err := withFakeClient(env, fake).Initialize(context.Background()); err != nil { + t.Fatalf("Initialize: %v", err) + } +} + +func TestRedisLB_Initialize_DefaultBackendError(t *testing.T) { + fake := &redisclientfakes.FakeRedisClient{} + fake.PingReturns(statusCmd(nil)) + + env := &envfakes.FakeProxyEnvironment{} + env.RedisDBReturns("0") + env.DefaultBackendEnabledReturns("on") + env.DefaultBackendIPReturns("") // triggers NewDefaultOriginServerForDebugging error + err := withFakeClient(env, fake).Initialize(context.Background()) + if err == nil || !strings.Contains(err.Error(), "initialize default SRS") { + t.Fatalf("expected default-SRS error, got %v", err) + } +} + +func TestRedisLB_Initialize_UpdateFails(t *testing.T) { + fake := &redisclientfakes.FakeRedisClient{} + fake.PingReturns(statusCmd(nil)) + fake.SetReturns(statusCmd(fmt.Errorf("set failed"))) // every Set fails + + env := &envfakes.FakeProxyEnvironment{} + env.RedisDBReturns("0") + env.DefaultBackendEnabledReturns("on") + env.DefaultBackendIPReturns("1.2.3.4") + env.DefaultBackendRTMPReturns(":1935") + err := withFakeClient(env, fake).Initialize(context.Background()) + if err == nil || !strings.Contains(err.Error(), "update default SRS") { + t.Fatalf("expected update error, got %v", err) + } +} + +func TestRedisLB_Initialize_Success(t *testing.T) { + var setCalls atomic.Int32 + fake := &redisclientfakes.FakeRedisClient{} + fake.PingReturns(statusCmd(nil)) + fake.SetStub = func(ctx context.Context, key string, value interface{}, ttl time.Duration) *redis.StatusCmd { + setCalls.Add(1) + return statusCmd(nil) + } + // Every Get returns redis.Nil-style error so the server list is treated as empty. + fake.GetReturns(stringErr(fmt.Errorf("redis: nil"))) + + env := &envfakes.FakeProxyEnvironment{} + env.RedisDBReturns("0") + env.DefaultBackendEnabledReturns("on") + env.DefaultBackendIPReturns("1.2.3.4") + env.DefaultBackendRTMPReturns(":1935") + + lb := withFakeClient(env, fake) + lb.keepaliveInterval = time.Millisecond + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + if err := lb.Initialize(ctx); err != nil { + t.Fatalf("Initialize: %v", err) + } + + // Initial Update made 2 Set calls (server + server list). Wait long enough + // for the keep-alive tick to issue more. + deadline := time.Now().Add(200 * time.Millisecond) + for time.Now().Before(deadline) && setCalls.Load() < 4 { + time.Sleep(5 * time.Millisecond) + } + cancel() + time.Sleep(10 * time.Millisecond) + if setCalls.Load() < 4 { + t.Fatalf("keep-alive did not tick: setCalls=%d", setCalls.Load()) + } +} + +// ---------------------------------------------------------------------------- +// Update. +// ---------------------------------------------------------------------------- + +func TestRedisLB_Update_SetServerFails(t *testing.T) { + fake := &redisclientfakes.FakeRedisClient{} + fake.SetReturns(statusCmd(fmt.Errorf("boom"))) + lb := newRedisLB(fake) + err := lb.Update(context.Background(), &OriginServer{ServerID: "s", ServiceID: "v", PID: "1"}) + if err == nil || !strings.Contains(err.Error(), "set key=") { + t.Fatalf("expected set-server error, got %v", err) + } +} + +func TestRedisLB_Update_FreshList(t *testing.T) { + // No existing server list => Get for server-list key returns error. + fake := &redisclientfakes.FakeRedisClient{} + fake.SetReturns(statusCmd(nil)) + fake.GetReturns(stringErr(fmt.Errorf("nil"))) + + lb := newRedisLB(fake) + server := &OriginServer{ServerID: "s", ServiceID: "v", PID: "1"} + if err := lb.Update(context.Background(), server); err != nil { + t.Fatalf("Update: %v", err) + } + + // Two Set calls: server + servers-list. + if got := fake.SetCallCount(); got != 2 { + t.Fatalf("Set call count=%d, want 2", got) + } + // The second Set value should be a JSON array containing the server key. + _, _, value, _ := fake.SetArgsForCall(1) + var keys []string + if err := json.Unmarshal(value.([]byte), &keys); err != nil { + t.Fatalf("server-list value not JSON: %v", err) + } + want := lb.redisKeyServer(server.ID()) + if len(keys) != 1 || keys[0] != want { + t.Fatalf("server-list keys=%v, want [%q]", keys, want) + } +} + +func TestRedisLB_Update_PrunesDeadAndAppends(t *testing.T) { + server := &OriginServer{ServerID: "s", ServiceID: "v", PID: "1"} + + fake := &redisclientfakes.FakeRedisClient{} + fake.SetReturns(statusCmd(nil)) + + // First Get: server-list, returns ["dead", "alive"]. + // Subsequent Gets: probe each key — "dead" missing, "alive" present. + fake.GetStub = func(ctx context.Context, key string) *redis.StringCmd { + if strings.HasSuffix(key, "all-servers") { + b, _ := json.Marshal([]string{"dead", "alive"}) + return stringOK(b) + } + if key == "alive" { + return stringOK([]byte("ok")) + } + return stringErr(fmt.Errorf("nil")) + } + + lb := newRedisLB(fake) + if err := lb.Update(context.Background(), server); err != nil { + t.Fatalf("Update: %v", err) + } + + // Inspect the server-list Set call: should contain "alive" (kept) and the + // new server key (appended); "dead" should be pruned. + _, _, value, _ := fake.SetArgsForCall(1) + var keys []string + if err := json.Unmarshal(value.([]byte), &keys); err != nil { + t.Fatalf("not JSON: %v", err) + } + wantNew := lb.redisKeyServer(server.ID()) + if len(keys) != 2 || keys[0] != "alive" || keys[1] != wantNew { + t.Fatalf("server-list keys=%v, want [alive, %q]", keys, wantNew) + } +} + +func TestRedisLB_Update_AlreadyInList(t *testing.T) { + server := &OriginServer{ServerID: "s", ServiceID: "v", PID: "1"} + + fake := &redisclientfakes.FakeRedisClient{} + fake.SetReturns(statusCmd(nil)) + lb := newRedisLB(fake) + wantKey := lb.redisKeyServer(server.ID()) + + fake.GetStub = func(ctx context.Context, key string) *redis.StringCmd { + if strings.HasSuffix(key, "all-servers") { + b, _ := json.Marshal([]string{wantKey}) + return stringOK(b) + } + return stringOK([]byte("ok")) + } + + if err := lb.Update(context.Background(), server); err != nil { + t.Fatalf("Update: %v", err) + } + _, _, value, _ := fake.SetArgsForCall(1) + var keys []string + _ = json.Unmarshal(value.([]byte), &keys) + if len(keys) != 1 || keys[0] != wantKey { + t.Fatalf("expected no duplication, got %v", keys) + } +} + +func TestRedisLB_Update_BadServerListJSON(t *testing.T) { + fake := &redisclientfakes.FakeRedisClient{} + fake.SetReturns(statusCmd(nil)) + fake.GetStub = func(ctx context.Context, key string) *redis.StringCmd { + if strings.HasSuffix(key, "all-servers") { + return stringOK([]byte("not-json")) + } + return stringErr(fmt.Errorf("nil")) + } + lb := newRedisLB(fake) + err := lb.Update(context.Background(), &OriginServer{ServerID: "s", ServiceID: "v", PID: "1"}) + if err == nil || !strings.Contains(err.Error(), "unmarshal") { + t.Fatalf("expected unmarshal error, got %v", err) + } +} + +func TestRedisLB_Update_SetServerListFails(t *testing.T) { + fake := &redisclientfakes.FakeRedisClient{} + // First Set ok (server), second Set fails (server list). + fake.SetReturnsOnCall(0, statusCmd(nil)) + fake.SetReturnsOnCall(1, statusCmd(fmt.Errorf("set list failed"))) + fake.GetReturns(stringErr(fmt.Errorf("nil"))) + + lb := newRedisLB(fake) + err := lb.Update(context.Background(), &OriginServer{ServerID: "s", ServiceID: "v", PID: "1"}) + if err == nil || !strings.Contains(err.Error(), "set list failed") { + t.Fatalf("expected server-list set error, got %v", err) + } +} + +// ---------------------------------------------------------------------------- +// Pick. +// ---------------------------------------------------------------------------- + +func TestRedisLB_Pick_StickyHit(t *testing.T) { + server := &OriginServer{ServerID: "a", ServiceID: "b", PID: "1"} + serverJSON, _ := json.Marshal(server) + + fake := &redisclientfakes.FakeRedisClient{} + fake.SetReturns(statusCmd(nil)) + lb := newRedisLB(fake) + streamKey := "srs-proxy-url:url1" + serverKey := lb.redisKeyServer(server.ID()) + + fake.GetStub = func(ctx context.Context, key string) *redis.StringCmd { + switch key { + case streamKey: + return stringOK([]byte(serverKey)) + case serverKey: + return stringOK(serverJSON) + } + return stringErr(fmt.Errorf("nil")) + } + + got, err := lb.Pick(context.Background(), "url1") + if err != nil { + t.Fatalf("Pick: %v", err) + } + if got.ID() != server.ID() { + t.Fatalf("Pick returned %v, want %v", got, server) + } +} + +func TestRedisLB_Pick_StickyBadJSON(t *testing.T) { + fake := &redisclientfakes.FakeRedisClient{} + lb := newRedisLB(fake) + streamKey := "srs-proxy-url:url1" + + fake.GetStub = func(ctx context.Context, key string) *redis.StringCmd { + switch key { + case streamKey: + return stringOK([]byte("srv-key")) + case "srv-key": + return stringOK([]byte("not-json")) + } + return stringErr(fmt.Errorf("nil")) + } + + _, err := lb.Pick(context.Background(), "url1") + if err == nil || !strings.Contains(err.Error(), "unmarshal") { + t.Fatalf("expected unmarshal error, got %v", err) + } +} + +func TestRedisLB_Pick_NoServersAvailable(t *testing.T) { + fake := &redisclientfakes.FakeRedisClient{} + // Sticky miss + server list missing. + fake.GetReturns(stringErr(fmt.Errorf("nil"))) + lb := newRedisLB(fake) + _, err := lb.Pick(context.Background(), "url1") + if err == nil || !strings.Contains(err.Error(), "no server available") { + t.Fatalf("expected no-server error, got %v", err) + } +} + +func TestRedisLB_Pick_BadServerListJSON(t *testing.T) { + fake := &redisclientfakes.FakeRedisClient{} + fake.GetStub = func(ctx context.Context, key string) *redis.StringCmd { + if strings.HasSuffix(key, "all-servers") { + return stringOK([]byte("not-json")) + } + return stringErr(fmt.Errorf("nil")) + } + lb := newRedisLB(fake) + _, err := lb.Pick(context.Background(), "url1") + if err == nil || !strings.Contains(err.Error(), "unmarshal") { + t.Fatalf("expected unmarshal error, got %v", err) + } +} + +func TestRedisLB_Pick_AllProbesFail(t *testing.T) { + // Server list contains one key, but probing it returns nil bytes (the + // `len(b) > 0` guard rejects it). After 3 attempts, Pick errors out. + fake := &redisclientfakes.FakeRedisClient{} + fake.GetStub = func(ctx context.Context, key string) *redis.StringCmd { + if strings.HasSuffix(key, "all-servers") { + b, _ := json.Marshal([]string{"srv-key"}) + return stringOK(b) + } + // "srv-key" probe returns empty bytes — falls through the available check. + if key == "srv-key" { + return stringOK(nil) + } + return stringErr(fmt.Errorf("nil")) + } + lb := newRedisLB(fake) + _, err := lb.Pick(context.Background(), "url1") + if err == nil || !strings.Contains(err.Error(), "no server available in") { + t.Fatalf("expected exhausted-probes error, got %v", err) + } +} + +func TestRedisLB_Pick_ScanSuccess(t *testing.T) { + server := &OriginServer{ServerID: "a", ServiceID: "b", PID: "1"} + serverJSON, _ := json.Marshal(server) + + fake := &redisclientfakes.FakeRedisClient{} + fake.SetReturns(statusCmd(nil)) + lb := newRedisLB(fake) + serverKey := lb.redisKeyServer(server.ID()) + + fake.GetStub = func(ctx context.Context, key string) *redis.StringCmd { + if strings.HasSuffix(key, "all-servers") { + b, _ := json.Marshal([]string{serverKey}) + return stringOK(b) + } + if key == serverKey { + return stringOK(serverJSON) + } + // Sticky lookup for the URL key misses. + return stringErr(fmt.Errorf("nil")) + } + + got, err := lb.Pick(context.Background(), "url1") + if err != nil { + t.Fatalf("Pick: %v", err) + } + if got.ID() != server.ID() { + t.Fatalf("Pick returned %v", got) + } + // Pick should also store the picked-mapping. + if fake.SetCallCount() != 1 { + t.Fatalf("expected 1 Set call to store picked mapping, got %d", fake.SetCallCount()) + } +} + +func TestRedisLB_Pick_ScanBadJSON(t *testing.T) { + fake := &redisclientfakes.FakeRedisClient{} + fake.GetStub = func(ctx context.Context, key string) *redis.StringCmd { + if strings.HasSuffix(key, "all-servers") { + b, _ := json.Marshal([]string{"srv-key"}) + return stringOK(b) + } + if key == "srv-key" { + return stringOK([]byte("not-json")) + } + return stringErr(fmt.Errorf("nil")) + } + lb := newRedisLB(fake) + _, err := lb.Pick(context.Background(), "url1") + if err == nil || !strings.Contains(err.Error(), "unmarshal") { + t.Fatalf("expected unmarshal error, got %v", err) + } +} + +func TestRedisLB_Pick_StoreMappingFails(t *testing.T) { + server := &OriginServer{ServerID: "a", ServiceID: "b", PID: "1"} + serverJSON, _ := json.Marshal(server) + + fake := &redisclientfakes.FakeRedisClient{} + fake.SetReturns(statusCmd(fmt.Errorf("set failed"))) + lb := newRedisLB(fake) + serverKey := lb.redisKeyServer(server.ID()) + + fake.GetStub = func(ctx context.Context, key string) *redis.StringCmd { + if strings.HasSuffix(key, "all-servers") { + b, _ := json.Marshal([]string{serverKey}) + return stringOK(b) + } + if key == serverKey { + return stringOK(serverJSON) + } + return stringErr(fmt.Errorf("nil")) + } + + _, err := lb.Pick(context.Background(), "url1") + if err == nil || !strings.Contains(err.Error(), "set failed") { + t.Fatalf("expected set-mapping error, got %v", err) + } +} + +// ---------------------------------------------------------------------------- +// LoadHLSBySPBHID and LoadWebRTCByUfrag — symmetric behavior. +// ---------------------------------------------------------------------------- + +func TestRedisLB_LoadHLSBySPBHID_GetFails(t *testing.T) { + fake := &redisclientfakes.FakeRedisClient{} + fake.GetReturns(stringErr(fmt.Errorf("nil"))) + lb := newRedisLB(fake) + _, err := lb.LoadHLSBySPBHID(context.Background(), "abc") + if err == nil || !strings.Contains(err.Error(), "get key=") { + t.Fatalf("expected get error, got %v", err) + } +} + +func TestRedisLB_LoadHLSBySPBHID_BadJSON(t *testing.T) { + fake := &redisclientfakes.FakeRedisClient{} + fake.GetReturns(stringOK([]byte("not-json"))) + lb := newRedisLB(fake) + _, err := lb.LoadHLSBySPBHID(context.Background(), "abc") + if err == nil || !strings.Contains(err.Error(), "unmarshal") { + t.Fatalf("expected unmarshal error, got %v", err) + } +} + +func TestRedisLB_LoadHLSBySPBHID_InterfaceLimitation(t *testing.T) { + fake := &redisclientfakes.FakeRedisClient{} + fake.GetReturns(stringOK([]byte(`{"foo":"bar"}`))) + lb := newRedisLB(fake) + _, err := lb.LoadHLSBySPBHID(context.Background(), "abc") + if err == nil || !strings.Contains(err.Error(), "cannot deserialize") { + t.Fatalf("expected interface limitation error, got %v", err) + } +} + +func TestRedisLB_LoadWebRTCByUfrag_GetFails(t *testing.T) { + fake := &redisclientfakes.FakeRedisClient{} + fake.GetReturns(stringErr(fmt.Errorf("nil"))) + lb := newRedisLB(fake) + _, err := lb.LoadWebRTCByUfrag(context.Background(), "u") + if err == nil || !strings.Contains(err.Error(), "get key=") { + t.Fatalf("expected get error, got %v", err) + } +} + +func TestRedisLB_LoadWebRTCByUfrag_BadJSON(t *testing.T) { + fake := &redisclientfakes.FakeRedisClient{} + fake.GetReturns(stringOK([]byte("not-json"))) + lb := newRedisLB(fake) + _, err := lb.LoadWebRTCByUfrag(context.Background(), "u") + if err == nil || !strings.Contains(err.Error(), "unmarshal") { + t.Fatalf("expected unmarshal error, got %v", err) + } +} + +func TestRedisLB_LoadWebRTCByUfrag_InterfaceLimitation(t *testing.T) { + fake := &redisclientfakes.FakeRedisClient{} + fake.GetReturns(stringOK([]byte(`{"foo":"bar"}`))) + lb := newRedisLB(fake) + _, err := lb.LoadWebRTCByUfrag(context.Background(), "u") + if err == nil || !strings.Contains(err.Error(), "cannot deserialize") { + t.Fatalf("expected interface limitation error, got %v", err) + } +} + +// ---------------------------------------------------------------------------- +// LoadOrStoreHLS and StoreWebRTC. +// ---------------------------------------------------------------------------- + +func TestRedisLB_LoadOrStoreHLS_Success(t *testing.T) { + fake := &redisclientfakes.FakeRedisClient{} + fake.SetReturns(statusCmd(nil)) + lb := newRedisLB(fake) + + hls := &stubHLS{spbhid: "abc"} + got, err := lb.LoadOrStoreHLS(context.Background(), "url1", hls) + if err != nil { + t.Fatalf("LoadOrStoreHLS: %v", err) + } + if got != hls { + t.Fatalf("got %v, want input back", got) + } + if fake.SetCallCount() != 2 { + t.Fatalf("expected 2 Set calls (URL + SPBHID), got %d", fake.SetCallCount()) + } +} + +func TestRedisLB_LoadOrStoreHLS_FirstSetFails(t *testing.T) { + fake := &redisclientfakes.FakeRedisClient{} + fake.SetReturns(statusCmd(fmt.Errorf("boom"))) + lb := newRedisLB(fake) + _, err := lb.LoadOrStoreHLS(context.Background(), "url1", &stubHLS{spbhid: "abc"}) + if err == nil || !strings.Contains(err.Error(), "boom") { + t.Fatalf("expected error, got %v", err) + } +} + +func TestRedisLB_LoadOrStoreHLS_SecondSetFails(t *testing.T) { + fake := &redisclientfakes.FakeRedisClient{} + fake.SetReturnsOnCall(0, statusCmd(nil)) + fake.SetReturnsOnCall(1, statusCmd(fmt.Errorf("second boom"))) + lb := newRedisLB(fake) + _, err := lb.LoadOrStoreHLS(context.Background(), "url1", &stubHLS{spbhid: "abc"}) + if err == nil || !strings.Contains(err.Error(), "second boom") { + t.Fatalf("expected error, got %v", err) + } +} + +func TestRedisLB_StoreWebRTC_Success(t *testing.T) { + fake := &redisclientfakes.FakeRedisClient{} + fake.SetReturns(statusCmd(nil)) + lb := newRedisLB(fake) + if err := lb.StoreWebRTC(context.Background(), "url1", &stubRTC{ufrag: "u"}); err != nil { + t.Fatalf("StoreWebRTC: %v", err) + } + if fake.SetCallCount() != 2 { + t.Fatalf("expected 2 Set calls (URL + Ufrag), got %d", fake.SetCallCount()) + } +} + +func TestRedisLB_StoreWebRTC_FirstSetFails(t *testing.T) { + fake := &redisclientfakes.FakeRedisClient{} + fake.SetReturns(statusCmd(fmt.Errorf("boom"))) + lb := newRedisLB(fake) + err := lb.StoreWebRTC(context.Background(), "url1", &stubRTC{ufrag: "u"}) + if err == nil || !strings.Contains(err.Error(), "boom") { + t.Fatalf("expected error, got %v", err) + } +} + +func TestRedisLB_StoreWebRTC_SecondSetFails(t *testing.T) { + fake := &redisclientfakes.FakeRedisClient{} + fake.SetReturnsOnCall(0, statusCmd(nil)) + fake.SetReturnsOnCall(1, statusCmd(fmt.Errorf("second boom"))) + lb := newRedisLB(fake) + err := lb.StoreWebRTC(context.Background(), "url1", &stubRTC{ufrag: "u"}) + if err == nil || !strings.Contains(err.Error(), "second boom") { + t.Fatalf("expected error, got %v", err) + } +} + +// ---------------------------------------------------------------------------- +// Key helpers. +// ---------------------------------------------------------------------------- + +func TestRedisLB_KeyHelpers(t *testing.T) { + lb := &redisLoadBalancer{} + for _, tt := range []struct { + got, want string + }{ + {lb.redisKeyUfrag("u"), "srs-proxy-ufrag:u"}, + {lb.redisKeyRTC("url"), "srs-proxy-rtc:url"}, + {lb.redisKeySPBHID("s"), "srs-proxy-spbhid:s"}, + {lb.redisKeyHLS("url"), "srs-proxy-hls:url"}, + {lb.redisKeyServer("id"), "srs-proxy-server:id"}, + {lb.redisKeyServers(), "srs-proxy-all-servers"}, + } { + if tt.got != tt.want { + t.Errorf("got %q, want %q", tt.got, tt.want) + } + } +} diff --git a/internal/redisclient/gen.go b/internal/redisclient/gen.go new file mode 100644 index 000000000..5ce43b7be --- /dev/null +++ b/internal/redisclient/gen.go @@ -0,0 +1,6 @@ +// Copyright (c) 2026 Winlin +// +// SPDX-License-Identifier: MIT +package redisclient + +//go:generate go tool counterfeiter -o redisclientfakes/fake_redis_client.go . RedisClient diff --git a/internal/redisclient/redisclient.go b/internal/redisclient/redisclient.go new file mode 100644 index 000000000..78c85d976 --- /dev/null +++ b/internal/redisclient/redisclient.go @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Winlin +// +// SPDX-License-Identifier: MIT +package redisclient + +import ( + "context" + "time" + + // Use v8 because we use Go 1.16+, while v9 requires Go 1.18+ + "github.com/go-redis/redis/v8" +) + +// RedisClient is the subset of *redis.Client methods used by callers in this +// codebase. Declared as an interface so tests can substitute a fake without +// standing up a real Redis server. *redis.Client satisfies this interface +// directly. +type RedisClient interface { + Ping(ctx context.Context) *redis.StatusCmd + Get(ctx context.Context, key string) *redis.StringCmd + Set(ctx context.Context, key string, value interface{}, expiration time.Duration) *redis.StatusCmd + String() string +} + +// New connects to a Redis server at addr (host:port) with the given password +// and database index. Returns a RedisClient satisfied by *redis.Client. +func New(addr, password string, db int) RedisClient { + return redis.NewClient(&redis.Options{ + Addr: addr, + Password: password, + DB: db, + }) +} diff --git a/internal/redisclient/redisclientfakes/fake_redis_client.go b/internal/redisclient/redisclientfakes/fake_redis_client.go new file mode 100644 index 000000000..1ed9c03bd --- /dev/null +++ b/internal/redisclient/redisclientfakes/fake_redis_client.go @@ -0,0 +1,327 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package redisclientfakes + +import ( + "context" + "srsx/internal/redisclient" + "sync" + "time" + + redis "github.com/go-redis/redis/v8" +) + +type FakeRedisClient struct { + GetStub func(context.Context, string) *redis.StringCmd + getMutex sync.RWMutex + getArgsForCall []struct { + arg1 context.Context + arg2 string + } + getReturns struct { + result1 *redis.StringCmd + } + getReturnsOnCall map[int]struct { + result1 *redis.StringCmd + } + PingStub func(context.Context) *redis.StatusCmd + pingMutex sync.RWMutex + pingArgsForCall []struct { + arg1 context.Context + } + pingReturns struct { + result1 *redis.StatusCmd + } + pingReturnsOnCall map[int]struct { + result1 *redis.StatusCmd + } + SetStub func(context.Context, string, interface{}, time.Duration) *redis.StatusCmd + setMutex sync.RWMutex + setArgsForCall []struct { + arg1 context.Context + arg2 string + arg3 interface{} + arg4 time.Duration + } + setReturns struct { + result1 *redis.StatusCmd + } + setReturnsOnCall map[int]struct { + result1 *redis.StatusCmd + } + StringStub func() string + stringMutex sync.RWMutex + stringArgsForCall []struct { + } + stringReturns struct { + result1 string + } + stringReturnsOnCall map[int]struct { + result1 string + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeRedisClient) Get(arg1 context.Context, arg2 string) *redis.StringCmd { + fake.getMutex.Lock() + ret, specificReturn := fake.getReturnsOnCall[len(fake.getArgsForCall)] + fake.getArgsForCall = append(fake.getArgsForCall, struct { + arg1 context.Context + arg2 string + }{arg1, arg2}) + stub := fake.GetStub + fakeReturns := fake.getReturns + fake.recordInvocation("Get", []interface{}{arg1, arg2}) + fake.getMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeRedisClient) GetCallCount() int { + fake.getMutex.RLock() + defer fake.getMutex.RUnlock() + return len(fake.getArgsForCall) +} + +func (fake *FakeRedisClient) GetCalls(stub func(context.Context, string) *redis.StringCmd) { + fake.getMutex.Lock() + defer fake.getMutex.Unlock() + fake.GetStub = stub +} + +func (fake *FakeRedisClient) GetArgsForCall(i int) (context.Context, string) { + fake.getMutex.RLock() + defer fake.getMutex.RUnlock() + argsForCall := fake.getArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeRedisClient) GetReturns(result1 *redis.StringCmd) { + fake.getMutex.Lock() + defer fake.getMutex.Unlock() + fake.GetStub = nil + fake.getReturns = struct { + result1 *redis.StringCmd + }{result1} +} + +func (fake *FakeRedisClient) GetReturnsOnCall(i int, result1 *redis.StringCmd) { + fake.getMutex.Lock() + defer fake.getMutex.Unlock() + fake.GetStub = nil + if fake.getReturnsOnCall == nil { + fake.getReturnsOnCall = make(map[int]struct { + result1 *redis.StringCmd + }) + } + fake.getReturnsOnCall[i] = struct { + result1 *redis.StringCmd + }{result1} +} + +func (fake *FakeRedisClient) Ping(arg1 context.Context) *redis.StatusCmd { + fake.pingMutex.Lock() + ret, specificReturn := fake.pingReturnsOnCall[len(fake.pingArgsForCall)] + fake.pingArgsForCall = append(fake.pingArgsForCall, struct { + arg1 context.Context + }{arg1}) + stub := fake.PingStub + fakeReturns := fake.pingReturns + fake.recordInvocation("Ping", []interface{}{arg1}) + fake.pingMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeRedisClient) PingCallCount() int { + fake.pingMutex.RLock() + defer fake.pingMutex.RUnlock() + return len(fake.pingArgsForCall) +} + +func (fake *FakeRedisClient) PingCalls(stub func(context.Context) *redis.StatusCmd) { + fake.pingMutex.Lock() + defer fake.pingMutex.Unlock() + fake.PingStub = stub +} + +func (fake *FakeRedisClient) PingArgsForCall(i int) context.Context { + fake.pingMutex.RLock() + defer fake.pingMutex.RUnlock() + argsForCall := fake.pingArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeRedisClient) PingReturns(result1 *redis.StatusCmd) { + fake.pingMutex.Lock() + defer fake.pingMutex.Unlock() + fake.PingStub = nil + fake.pingReturns = struct { + result1 *redis.StatusCmd + }{result1} +} + +func (fake *FakeRedisClient) PingReturnsOnCall(i int, result1 *redis.StatusCmd) { + fake.pingMutex.Lock() + defer fake.pingMutex.Unlock() + fake.PingStub = nil + if fake.pingReturnsOnCall == nil { + fake.pingReturnsOnCall = make(map[int]struct { + result1 *redis.StatusCmd + }) + } + fake.pingReturnsOnCall[i] = struct { + result1 *redis.StatusCmd + }{result1} +} + +func (fake *FakeRedisClient) Set(arg1 context.Context, arg2 string, arg3 interface{}, arg4 time.Duration) *redis.StatusCmd { + fake.setMutex.Lock() + ret, specificReturn := fake.setReturnsOnCall[len(fake.setArgsForCall)] + fake.setArgsForCall = append(fake.setArgsForCall, struct { + arg1 context.Context + arg2 string + arg3 interface{} + arg4 time.Duration + }{arg1, arg2, arg3, arg4}) + stub := fake.SetStub + fakeReturns := fake.setReturns + fake.recordInvocation("Set", []interface{}{arg1, arg2, arg3, arg4}) + fake.setMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3, arg4) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeRedisClient) SetCallCount() int { + fake.setMutex.RLock() + defer fake.setMutex.RUnlock() + return len(fake.setArgsForCall) +} + +func (fake *FakeRedisClient) SetCalls(stub func(context.Context, string, interface{}, time.Duration) *redis.StatusCmd) { + fake.setMutex.Lock() + defer fake.setMutex.Unlock() + fake.SetStub = stub +} + +func (fake *FakeRedisClient) SetArgsForCall(i int) (context.Context, string, interface{}, time.Duration) { + fake.setMutex.RLock() + defer fake.setMutex.RUnlock() + argsForCall := fake.setArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4 +} + +func (fake *FakeRedisClient) SetReturns(result1 *redis.StatusCmd) { + fake.setMutex.Lock() + defer fake.setMutex.Unlock() + fake.SetStub = nil + fake.setReturns = struct { + result1 *redis.StatusCmd + }{result1} +} + +func (fake *FakeRedisClient) SetReturnsOnCall(i int, result1 *redis.StatusCmd) { + fake.setMutex.Lock() + defer fake.setMutex.Unlock() + fake.SetStub = nil + if fake.setReturnsOnCall == nil { + fake.setReturnsOnCall = make(map[int]struct { + result1 *redis.StatusCmd + }) + } + fake.setReturnsOnCall[i] = struct { + result1 *redis.StatusCmd + }{result1} +} + +func (fake *FakeRedisClient) String() string { + fake.stringMutex.Lock() + ret, specificReturn := fake.stringReturnsOnCall[len(fake.stringArgsForCall)] + fake.stringArgsForCall = append(fake.stringArgsForCall, struct { + }{}) + stub := fake.StringStub + fakeReturns := fake.stringReturns + fake.recordInvocation("String", []interface{}{}) + fake.stringMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeRedisClient) StringCallCount() int { + fake.stringMutex.RLock() + defer fake.stringMutex.RUnlock() + return len(fake.stringArgsForCall) +} + +func (fake *FakeRedisClient) StringCalls(stub func() string) { + fake.stringMutex.Lock() + defer fake.stringMutex.Unlock() + fake.StringStub = stub +} + +func (fake *FakeRedisClient) StringReturns(result1 string) { + fake.stringMutex.Lock() + defer fake.stringMutex.Unlock() + fake.StringStub = nil + fake.stringReturns = struct { + result1 string + }{result1} +} + +func (fake *FakeRedisClient) StringReturnsOnCall(i int, result1 string) { + fake.stringMutex.Lock() + defer fake.stringMutex.Unlock() + fake.StringStub = nil + if fake.stringReturnsOnCall == nil { + fake.stringReturnsOnCall = make(map[int]struct { + result1 string + }) + } + fake.stringReturnsOnCall[i] = struct { + result1 string + }{result1} +} + +func (fake *FakeRedisClient) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeRedisClient) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ redisclient.RedisClient = new(FakeRedisClient) diff --git a/internal/rtmp/amf0.go b/internal/rtmp/amf0.go index 7fd2c7a3d..e316f1b8f 100644 --- a/internal/rtmp/amf0.go +++ b/internal/rtmp/amf0.go @@ -90,7 +90,9 @@ type amf0Buffer interface { Write(p []byte) (n int, err error) } -var createBuffer = func() amf0Buffer { +// defaultBufFactory is the production amf0Buffer factory. Tests override the +// per-instance bufFactory field on amf0ObjectBase instead of swapping a global. +func defaultBufFactory() amf0Buffer { return &bytes.Buffer{} } @@ -399,6 +401,10 @@ type amf0Property struct { type amf0ObjectBase struct { properties []*amf0Property lock sync.Mutex + // bufFactory creates the amf0Buffer used by MarshalBinary. Held as a + // per-instance field (not a package global) so concurrent tests can each + // install their own buggy buffers without racing on shared state. + bufFactory func() amf0Buffer } func (v *amf0ObjectBase) Size() int { @@ -562,6 +568,7 @@ func NewAmf0Object() Amf0Object { func newAmf0Object() *amf0Object { v := &amf0Object{} v.properties = []*amf0Property{} + v.bufFactory = defaultBufFactory return v } @@ -600,7 +607,7 @@ func (v *amf0Object) UnmarshalBinary(data []byte) (err error) { } func (v *amf0Object) MarshalBinary() (data []byte, err error) { - b := createBuffer() + b := v.bufFactory() if err = b.WriteByte(byte(amf0MarkerObject)); err != nil { return nil, errors.Wrap(err, "marshal") @@ -640,6 +647,7 @@ func NewAmf0EcmaArray() Amf0EcmaArray { func newAmf0EcmaArray() *amf0EcmaArray { v := &amf0EcmaArray{} v.properties = []*amf0Property{} + v.bufFactory = defaultBufFactory return v } @@ -678,7 +686,7 @@ func (v *amf0EcmaArray) UnmarshalBinary(data []byte) (err error) { } func (v *amf0EcmaArray) MarshalBinary() (data []byte, err error) { - b := createBuffer() + b := v.bufFactory() if err = b.WriteByte(byte(amf0MarkerEcmaArray)); err != nil { return nil, errors.Wrap(err, "marshal") @@ -717,6 +725,7 @@ type amf0StrictArray struct { func NewAmf0StrictArray() Amf0StrictArray { v := &amf0StrictArray{} v.properties = []*amf0Property{} + v.bufFactory = defaultBufFactory return v } @@ -759,7 +768,7 @@ func (v *amf0StrictArray) UnmarshalBinary(data []byte) (err error) { } func (v *amf0StrictArray) MarshalBinary() (data []byte, err error) { - b := createBuffer() + b := v.bufFactory() if err = b.WriteByte(byte(amf0MarkerStrictArray)); err != nil { return nil, errors.Wrap(err, "marshal") diff --git a/internal/rtmp/amf0_test.go b/internal/rtmp/amf0_test.go index a2c240360..da102e7b1 100644 --- a/internal/rtmp/amf0_test.go +++ b/internal/rtmp/amf0_test.go @@ -436,10 +436,21 @@ func (v *errorAmf0Any) amf0Marker() amf0Marker { return amf0MarkerNumber } -func TestAmf0MarshalErrors(t *testing.T) { - originalCreateBuffer := createBuffer - defer func() { createBuffer = originalCreateBuffer }() +// setBufFactory replaces the bufFactory on whichever amf0 object-like type +// underlies v. Concurrent tests can use this safely because each value carries +// its own factory. +func setBufFactory(v Amf0Any, fn func() amf0Buffer) { + switch v := v.(type) { + case *amf0Object: + v.bufFactory = fn + case *amf0EcmaArray: + v.bufFactory = fn + case *amf0StrictArray: + v.bufFactory = fn + } +} +func TestAmf0MarshalErrors(t *testing.T) { for _, tt := range []struct { name string make func() Amf0Any @@ -449,15 +460,16 @@ func TestAmf0MarshalErrors(t *testing.T) { {"strict-array", func() Amf0Any { return NewAmf0StrictArray() }}, } { t.Run(tt.name+" write-byte", func(t *testing.T) { - createBuffer = func() amf0Buffer { return &errorAmf0Buffer{writeByteErr: true} } - if _, err := tt.make().MarshalBinary(); err == nil { + value := tt.make() + setBufFactory(value, func() amf0Buffer { return &errorAmf0Buffer{writeByteErr: true} }) + if _, err := value.MarshalBinary(); err == nil { t.Fatal("MarshalBinary() should fail") } }) t.Run(tt.name+" write-prop", func(t *testing.T) { - createBuffer = func() amf0Buffer { return &errorAmf0Buffer{writeErr: true} } value := tt.make() + setBufFactory(value, func() amf0Buffer { return &errorAmf0Buffer{writeErr: true} }) switch v := value.(type) { case Amf0Object: v.Set("name", NewAmf0String("stream")) @@ -473,7 +485,6 @@ func TestAmf0MarshalErrors(t *testing.T) { }) } - createBuffer = originalCreateBuffer for _, tt := range []struct { name string make func() Amf0Any diff --git a/internal/signal/signal.go b/internal/signal/signal.go index b8930480b..a23b36a0f 100644 --- a/internal/signal/signal.go +++ b/internal/signal/signal.go @@ -15,15 +15,26 @@ import ( "srsx/internal/logger" ) -// Indirections so tests can substitute signal delivery and process exit. -var ( - signalNotify = signal.Notify - osExit = os.Exit -) +// Handler installs OS signal handlers and the force-quit timer. The notify +// and exit indirections are struct fields (not package globals) so concurrent +// tests can each construct a handler with their own fakes without racing on +// shared state. +type Handler struct { + notify func(c chan<- os.Signal, sig ...os.Signal) + exit func(code int) +} -func InstallSignals(ctx context.Context, cancel context.CancelFunc) { +// NewHandler returns a Handler wired to the real OS implementations. +func NewHandler() *Handler { + return &Handler{ + notify: signal.Notify, + exit: os.Exit, + } +} + +func (h *Handler) InstallSignals(ctx context.Context, cancel context.CancelFunc) { sc := make(chan os.Signal, 1) - signalNotify(sc, syscall.SIGINT, syscall.SIGTERM, os.Interrupt) + h.notify(sc, syscall.SIGINT, syscall.SIGTERM, os.Interrupt) go func() { for s := range sc { @@ -33,7 +44,7 @@ func InstallSignals(ctx context.Context, cancel context.CancelFunc) { }() } -func InstallForceQuit(ctx context.Context, environment env.ProxyEnvironment) error { +func (h *Handler) InstallForceQuit(ctx context.Context, environment env.ProxyEnvironment) error { var forceTimeout time.Duration timeoutStr := environment.ForceQuitTimeout() if t, err := time.ParseDuration(timeoutStr); err != nil { @@ -46,7 +57,7 @@ func InstallForceQuit(ctx context.Context, environment env.ProxyEnvironment) err <-ctx.Done() time.Sleep(forceTimeout) logger.Warn(ctx, "Force to exit by timeout") - osExit(1) + h.exit(1) }() return nil } diff --git a/internal/signal/signal_test.go b/internal/signal/signal_test.go index ea3fac252..fc1471dd6 100644 --- a/internal/signal/signal_test.go +++ b/internal/signal/signal_test.go @@ -16,59 +16,60 @@ import ( "srsx/internal/env/envfakes" ) -// swapNotify replaces signalNotify with a capturing fake and returns a getter -// for the channel registered by the code under test plus a restore func. -func swapNotify(t *testing.T) (func() chan<- os.Signal, func()) { - t.Helper() - orig := signalNotify +// captureNotify returns a Handler whose notify field records the channel +// passed by the code under test, plus a getter that retrieves it. +func captureNotify() (*Handler, func() chan<- os.Signal) { var ( mu sync.Mutex ch chan<- os.Signal ) - signalNotify = func(c chan<- os.Signal, _ ...os.Signal) { - mu.Lock() - defer mu.Unlock() - ch = c - } - return func() chan<- os.Signal { + h := &Handler{ + notify: func(c chan<- os.Signal, _ ...os.Signal) { mu.Lock() defer mu.Unlock() - return ch - }, func() { - signalNotify = orig - } + ch = c + }, + exit: os.Exit, + } + return h, func() chan<- os.Signal { + mu.Lock() + defer mu.Unlock() + return ch + } } -func swapExit(t *testing.T) (*int32, chan int, func()) { - t.Helper() - orig := osExit +// captureExit returns a Handler whose exit field records the code and never +// returns, plus a flag and channel that observe the call. +func captureExit() (*Handler, *int32, chan int) { var called int32 done := make(chan int, 1) - osExit = func(code int) { - atomic.StoreInt32(&called, 1) - select { - case done <- code: - default: - } - // Block to mimic os.Exit never returning; the goroutine holding us - // here is abandoned when the test ends. - select {} + h := &Handler{ + notify: func(chan<- os.Signal, ...os.Signal) {}, + exit: func(code int) { + atomic.StoreInt32(&called, 1) + select { + case done <- code: + default: + } + // Block to mimic os.Exit never returning; the goroutine holding us + // here is abandoned when the test ends. + select {} + }, } - return &called, done, func() { osExit = orig } + return h, &called, done } func TestInstallSignals_CancelsOnSignal(t *testing.T) { - getCh, restore := swapNotify(t) - defer restore() + h, getCh := captureNotify() ctx, cancel := context.WithCancel(t.Context()) defer cancel() - InstallSignals(ctx, cancel) + h.InstallSignals(ctx, cancel) ch := getCh() if ch == nil { - t.Fatal("signalNotify was not called") + t.Fatal("notify was not called") } ch <- syscall.SIGINT @@ -80,13 +81,12 @@ func TestInstallSignals_CancelsOnSignal(t *testing.T) { } func TestInstallSignals_HandlesRepeatedSignals(t *testing.T) { - getCh, restore := swapNotify(t) - defer restore() + h, getCh := captureNotify() ctx, cancel := context.WithCancel(t.Context()) defer cancel() - InstallSignals(ctx, cancel) + h.InstallSignals(ctx, cancel) ch := getCh() // Multiple signals must not panic; cancel() is idempotent. @@ -105,7 +105,7 @@ func TestInstallForceQuit_InvalidDurationReturnsError(t *testing.T) { fakeEnv := &envfakes.FakeProxyEnvironment{} fakeEnv.ForceQuitTimeoutReturns("not-a-duration") - err := InstallForceQuit(t.Context(), fakeEnv) + err := NewHandler().InstallForceQuit(t.Context(), fakeEnv) if err == nil { t.Fatal("want error for bad duration") } @@ -118,20 +118,19 @@ func TestInstallForceQuit_InvalidDurationReturnsError(t *testing.T) { } func TestInstallForceQuit_ExitsAfterTimeout(t *testing.T) { - called, done, restore := swapExit(t) - defer restore() + h, called, done := captureExit() fakeEnv := &envfakes.FakeProxyEnvironment{} fakeEnv.ForceQuitTimeoutReturns("1ms") ctx, cancel := context.WithCancel(t.Context()) - if err := InstallForceQuit(ctx, fakeEnv); err != nil { + if err := h.InstallForceQuit(ctx, fakeEnv); err != nil { t.Fatalf("unexpected err: %v", err) } // Before cancel, the goroutine is blocked and exit must not fire. if atomic.LoadInt32(called) != 0 { - t.Fatal("osExit called before ctx cancel") + t.Fatal("exit called before ctx cancel") } cancel() @@ -141,30 +140,39 @@ func TestInstallForceQuit_ExitsAfterTimeout(t *testing.T) { t.Fatalf("exit code = %d, want 1", code) } case <-time.After(time.Second): - t.Fatal("osExit not called after cancel + timeout") + t.Fatal("exit not called after cancel + timeout") } } func TestInstallForceQuit_WaitsForCancelBeforeSleeping(t *testing.T) { - called, done, restore := swapExit(t) - defer restore() + h, called, done := captureExit() fakeEnv := &envfakes.FakeProxyEnvironment{} fakeEnv.ForceQuitTimeoutReturns("10ms") - // Intentionally use a never-canceled context and leak the goroutine: - // if we canceled at test end, the goroutine would wake and race with - // restore() writing osExit. - if err := InstallForceQuit(context.Background(), fakeEnv); err != nil { + // Intentionally use a never-canceled context and leak the goroutine: the + // handler's exit closure is owned by this test instance, so leaving the + // goroutine alive doesn't race other tests. + if err := h.InstallForceQuit(context.Background(), fakeEnv); err != nil { t.Fatalf("unexpected err: %v", err) } select { case <-done: - t.Fatal("osExit fired without ctx cancel") + t.Fatal("exit fired without ctx cancel") case <-time.After(30 * time.Millisecond): } if atomic.LoadInt32(called) != 0 { - t.Fatal("osExit called unexpectedly") + t.Fatal("exit called unexpectedly") + } +} + +func TestNewHandler_UsesRealOSDefaults(t *testing.T) { + h := NewHandler() + if h.notify == nil { + t.Error("notify default not set") + } + if h.exit == nil { + t.Error("exit default not set") } } diff --git a/skills/srs-develop/scripts/proxy-e2e-cluster-test.sh b/skills/srs-develop/scripts/proxy-e2e-cluster-test.sh index eba2f5077..3721d8ac1 100755 --- a/skills/srs-develop/scripts/proxy-e2e-cluster-test.sh +++ b/skills/srs-develop/scripts/proxy-e2e-cluster-test.sh @@ -5,11 +5,16 @@ set -e SCRIPT_DIR="$(cd -P "$(dirname "$0")" && pwd)" -# Navigate: scripts/ -> srs-develop/ -> skills/ -> .openclaw/ -> srs -WORKSPACE="$(cd -P "$SCRIPT_DIR/../../../.." && pwd)" +# Walk up from SCRIPT_DIR looking for go.mod. This avoids brittle "../../../.." +# counting when the skills directory is reached via a symlink (which changes +# the symbolic vs. physical depth). +WORKSPACE="$SCRIPT_DIR" +while [[ "$WORKSPACE" != "/" && ! -f "$WORKSPACE/go.mod" ]]; do + WORKSPACE="$(dirname "$WORKSPACE")" +done if [[ ! -f "$WORKSPACE/go.mod" ]]; then - echo "Error: go.mod not found in WORKSPACE: $WORKSPACE" >&2 + echo "Error: go.mod not found walking up from: $SCRIPT_DIR" >&2 exit 1 fi diff --git a/skills/srs-develop/scripts/proxy-e2e-redis-test.sh b/skills/srs-develop/scripts/proxy-e2e-redis-test.sh index e5e54f77b..fee4a0b94 100755 --- a/skills/srs-develop/scripts/proxy-e2e-redis-test.sh +++ b/skills/srs-develop/scripts/proxy-e2e-redis-test.sh @@ -6,11 +6,16 @@ set -e SCRIPT_DIR="$(cd -P "$(dirname "$0")" && pwd)" -# Navigate: scripts/ -> srs-develop/ -> skills/ -> .openclaw/ -> srs -WORKSPACE="$(cd -P "$SCRIPT_DIR/../../../.." && pwd)" +# Walk up from SCRIPT_DIR looking for go.mod. This avoids brittle "../../../.." +# counting when the skills directory is reached via a symlink (which changes +# the symbolic vs. physical depth). +WORKSPACE="$SCRIPT_DIR" +while [[ "$WORKSPACE" != "/" && ! -f "$WORKSPACE/go.mod" ]]; do + WORKSPACE="$(dirname "$WORKSPACE")" +done if [[ ! -f "$WORKSPACE/go.mod" ]]; then - echo "Error: go.mod not found in WORKSPACE: $WORKSPACE" >&2 + echo "Error: go.mod not found walking up from: $SCRIPT_DIR" >&2 exit 1 fi diff --git a/skills/srs-develop/scripts/proxy-e2e-srt-test.sh b/skills/srs-develop/scripts/proxy-e2e-srt-test.sh index 030894b47..38832e746 100755 --- a/skills/srs-develop/scripts/proxy-e2e-srt-test.sh +++ b/skills/srs-develop/scripts/proxy-e2e-srt-test.sh @@ -10,11 +10,16 @@ set -e SCRIPT_DIR="$(cd -P "$(dirname "$0")" && pwd)" -# Navigate: scripts/ -> srs-develop/ -> skills/ -> .openclaw/ -> srs -WORKSPACE="$(cd -P "$SCRIPT_DIR/../../../.." && pwd)" +# Walk up from SCRIPT_DIR looking for go.mod. This avoids brittle "../../../.." +# counting when the skills directory is reached via a symlink (which changes +# the symbolic vs. physical depth). +WORKSPACE="$SCRIPT_DIR" +while [[ "$WORKSPACE" != "/" && ! -f "$WORKSPACE/go.mod" ]]; do + WORKSPACE="$(dirname "$WORKSPACE")" +done if [[ ! -f "$WORKSPACE/go.mod" ]]; then - echo "Error: go.mod not found in WORKSPACE: $WORKSPACE" >&2 + echo "Error: go.mod not found walking up from: $SCRIPT_DIR" >&2 exit 1 fi diff --git a/skills/srs-develop/scripts/proxy-e2e-test.sh b/skills/srs-develop/scripts/proxy-e2e-test.sh index 61f25608e..294ca0cf8 100755 --- a/skills/srs-develop/scripts/proxy-e2e-test.sh +++ b/skills/srs-develop/scripts/proxy-e2e-test.sh @@ -4,11 +4,16 @@ set -e SCRIPT_DIR="$(cd -P "$(dirname "$0")" && pwd)" -# Navigate: scripts/ -> srs-develop/ -> skills/ -> .openclaw/ -> srs -WORKSPACE="$(cd -P "$SCRIPT_DIR/../../../.." && pwd)" +# Walk up from SCRIPT_DIR looking for go.mod. This avoids brittle "../../../.." +# counting when the skills directory is reached via a symlink (which changes +# the symbolic vs. physical depth). +WORKSPACE="$SCRIPT_DIR" +while [[ "$WORKSPACE" != "/" && ! -f "$WORKSPACE/go.mod" ]]; do + WORKSPACE="$(dirname "$WORKSPACE")" +done if [[ ! -f "$WORKSPACE/go.mod" ]]; then - echo "Error: go.mod not found in WORKSPACE: $WORKSPACE" >&2 + echo "Error: go.mod not found walking up from: $SCRIPT_DIR" >&2 exit 1 fi diff --git a/skills/srs-develop/scripts/proxy-e2e-transmux-test.sh b/skills/srs-develop/scripts/proxy-e2e-transmux-test.sh index 0f9069857..df38e4489 100755 --- a/skills/srs-develop/scripts/proxy-e2e-transmux-test.sh +++ b/skills/srs-develop/scripts/proxy-e2e-transmux-test.sh @@ -6,11 +6,16 @@ set -e SCRIPT_DIR="$(cd -P "$(dirname "$0")" && pwd)" -# Navigate: scripts/ -> srs-develop/ -> skills/ -> .openclaw/ -> srs -WORKSPACE="$(cd -P "$SCRIPT_DIR/../../../.." && pwd)" +# Walk up from SCRIPT_DIR looking for go.mod. This avoids brittle "../../../.." +# counting when the skills directory is reached via a symlink (which changes +# the symbolic vs. physical depth). +WORKSPACE="$SCRIPT_DIR" +while [[ "$WORKSPACE" != "/" && ! -f "$WORKSPACE/go.mod" ]]; do + WORKSPACE="$(dirname "$WORKSPACE")" +done if [[ ! -f "$WORKSPACE/go.mod" ]]; then - echo "Error: go.mod not found in WORKSPACE: $WORKSPACE" >&2 + echo "Error: go.mod not found walking up from: $SCRIPT_DIR" >&2 exit 1 fi diff --git a/skills/srs-develop/scripts/proxy-e2e-whip-test.sh b/skills/srs-develop/scripts/proxy-e2e-whip-test.sh index 0127c6f58..f0961c641 100755 --- a/skills/srs-develop/scripts/proxy-e2e-whip-test.sh +++ b/skills/srs-develop/scripts/proxy-e2e-whip-test.sh @@ -10,11 +10,16 @@ set -e SCRIPT_DIR="$(cd -P "$(dirname "$0")" && pwd)" -# Navigate: scripts/ -> srs-develop/ -> skills/ -> .openclaw/ -> srs -WORKSPACE="$(cd -P "$SCRIPT_DIR/../../../.." && pwd)" +# Walk up from SCRIPT_DIR looking for go.mod. This avoids brittle "../../../.." +# counting when the skills directory is reached via a symlink (which changes +# the symbolic vs. physical depth). +WORKSPACE="$SCRIPT_DIR" +while [[ "$WORKSPACE" != "/" && ! -f "$WORKSPACE/go.mod" ]]; do + WORKSPACE="$(dirname "$WORKSPACE")" +done if [[ ! -f "$WORKSPACE/go.mod" ]]; then - echo "Error: go.mod not found in WORKSPACE: $WORKSPACE" >&2 + echo "Error: go.mod not found walking up from: $SCRIPT_DIR" >&2 exit 1 fi diff --git a/skills/srs-develop/scripts/proxy-utest.sh b/skills/srs-develop/scripts/proxy-utest.sh index 9e030d99f..52b6f590c 100755 --- a/skills/srs-develop/scripts/proxy-utest.sh +++ b/skills/srs-develop/scripts/proxy-utest.sh @@ -27,11 +27,16 @@ for arg in "$@"; do done SCRIPT_DIR="$(cd -P "$(dirname "$0")" && pwd)" -# Navigate: scripts/ -> srs-develop/ -> skills/ -> .openclaw/ -> srs -WORKSPACE="$(cd -P "$SCRIPT_DIR/../../../.." && pwd)" +# Walk up from SCRIPT_DIR looking for go.mod. This avoids brittle "../../../.." +# counting when the skills directory is reached via a symlink (which changes +# the symbolic vs. physical depth). +WORKSPACE="$SCRIPT_DIR" +while [[ "$WORKSPACE" != "/" && ! -f "$WORKSPACE/go.mod" ]]; do + WORKSPACE="$(dirname "$WORKSPACE")" +done if [[ ! -f "$WORKSPACE/go.mod" ]]; then - echo "Error: go.mod not found in WORKSPACE: $WORKSPACE" >&2 + echo "Error: go.mod not found walking up from: $SCRIPT_DIR" >&2 exit 1 fi