Claude: Split lb interfaces, extract redisclient, drop race-prone globals.

- Split OriginLoadBalancer into OriginService / HLSService / RTCService;
  the original interface now embeds the three role interfaces. Generate
  counterfeiter fakes for all four.
- Extract internal/redisclient: RedisClient interface + New() factory.
  internal/lb/redis.go no longer imports github.com/go-redis/redis/v8.
- Add unit tests for lb.go (OriginServer.ID/String/Format/NewOriginServer)
  and for the full memory + redis load balancers.
- Replace package-level test seams (memoryKeepaliveInterval, newRedisClient,
  redisKeepaliveInterval, signal.signalNotify/osExit, rtmp.createBuffer) with
  per-instance struct fields so concurrent tests can't race on them.
- Promote signal.InstallSignals / InstallForceQuit onto a new signal.Handler
  type; update bootstrap to construct one.
- Move rtmp createBuffer onto amf0ObjectBase as bufFactory; the three AMF0
  marshalers and their tests use the per-instance factory.
- Make proxy test scripts locate the workspace by walking up to go.mod
  instead of brittle '../../../..' counting (symlink-aware).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
winlin 2026-05-16 11:11:18 -04:00
parent f45bf30b46
commit 3060bf8e7c
26 changed files with 2811 additions and 116 deletions

View File

@ -33,7 +33,7 @@ func (b *proxyBootstrap) Start(ctx context.Context) error {
// Install signals.
ctx, cancel := context.WithCancel(ctx)
signal.InstallSignals(ctx, cancel)
signal.NewHandler().InstallSignals(ctx, cancel)
// Run the main loop, ignore the user cancel error.
err := b.run(ctx)
@ -58,7 +58,7 @@ func (b *proxyBootstrap) run(ctx context.Context) error {
// When cancelled, the program is forced to exit due to a timeout. Normally, this doesn't occur
// because the main thread exits after the context is cancelled. However, sometimes the main thread
// may be blocked for some reason, so a forced exit is necessary to ensure the program terminates.
if err := signal.InstallForceQuit(ctx, environment); err != nil {
if err := signal.NewHandler().InstallForceQuit(ctx, environment); err != nil {
return errors.Wrapf(err, "install force quit")
}

9
internal/lb/gen.go Normal file
View File

@ -0,0 +1,9 @@
// Copyright (c) 2026 Winlin
//
// SPDX-License-Identifier: MIT
package lb
//go:generate go tool counterfeiter -o lbfakes/fake_origin_load_balancer.go . OriginLoadBalancer
//go:generate go tool counterfeiter -o lbfakes/fake_origin_service.go . OriginService
//go:generate go tool counterfeiter -o lbfakes/fake_hls_service.go . HLSService
//go:generate go tool counterfeiter -o lbfakes/fake_rtc_service.go . RTCService

View File

@ -109,20 +109,35 @@ type RTCConnection interface {
GetUfrag() string
}
// OriginLoadBalancer is the interface to load balance the SRS servers.
type OriginLoadBalancer interface {
// Initialize the load balancer.
Initialize(ctx context.Context) error
// OriginService is the interface for origin-server registry and stream routing.
type OriginService interface {
// Update records the latest registration or heartbeat for an origin server.
Update(ctx context.Context, server *OriginServer) error
// Pick a backend server for the specified stream URL.
Pick(ctx context.Context, streamURL string) (*OriginServer, error)
}
// HLSService is the interface for HLS session state, indexed by stream URL and SPBHID.
type HLSService interface {
// Load or store the HLS streaming for the specified stream URL.
LoadOrStoreHLS(ctx context.Context, streamURL string, value HLSPlayStream) (HLSPlayStream, error)
// Load the HLS streaming by SPBHID, the SRS Proxy Backend HLS ID.
LoadHLSBySPBHID(ctx context.Context, spbhid string) (HLSPlayStream, error)
}
// RTCService is the interface for WebRTC session state, indexed by stream URL and ICE ufrag.
type RTCService interface {
// Store the WebRTC streaming for the specified stream URL.
StoreWebRTC(ctx context.Context, streamURL string, value RTCConnection) error
// Load the WebRTC streaming by ufrag, the ICE username.
LoadWebRTCByUfrag(ctx context.Context, ufrag string) (RTCConnection, error)
}
// OriginLoadBalancer is the interface to load balance the SRS servers.
type OriginLoadBalancer interface {
OriginService
HLSService
RTCService
// Initialize the load balancer.
Initialize(ctx context.Context) error
}

141
internal/lb/lb_test.go Normal file
View File

@ -0,0 +1,141 @@
// Copyright (c) 2026 Winlin
//
// SPDX-License-Identifier: MIT
package lb
import (
"fmt"
"strings"
"testing"
"time"
)
func TestOriginServerID(t *testing.T) {
for _, tt := range []struct {
name string
v *OriginServer
want string
}{
{"populated", &OriginServer{ServerID: "srv", ServiceID: "svc", PID: "1234"}, "srv-svc-1234"},
{"empty", &OriginServer{}, "--"},
} {
t.Run(tt.name, func(t *testing.T) {
if got := tt.v.ID(); got != tt.want {
t.Fatalf("ID()=%q, want %q", got, tt.want)
}
})
}
}
func TestOriginServerString(t *testing.T) {
// String() routes through Format with the %v default branch.
v := &OriginServer{IP: "1.2.3.4", ServerID: "srv", ServiceID: "svc", PID: "p"}
got := v.String()
if want := "SRS ip=1.2.3.4, id=srv-svc-p"; got != want {
t.Fatalf("String()=%q, want %q", got, want)
}
}
func TestOriginServerFormat_ShortVerbs(t *testing.T) {
v := &OriginServer{IP: "10.0.0.1", ServerID: "srv", ServiceID: "svc", PID: "9"}
want := "SRS ip=10.0.0.1, id=srv-svc-9"
for _, verb := range []string{"%v", "%s"} {
got := fmt.Sprintf(verb, v)
if got != want {
t.Fatalf("Sprintf(%q)=%q, want %q", verb, got, want)
}
}
}
func TestOriginServerFormat_PlusVerbsAllFields(t *testing.T) {
ts := time.Date(2026, 5, 16, 10, 30, 45, 123_000_000, time.UTC)
v := &OriginServer{
IP: "10.0.0.1", DeviceID: "dev1",
ServerID: "srv", ServiceID: "svc", PID: "9",
RTMP: []string{":1935", ":1936"},
HTTP: []string{":8080"},
API: []string{":1985"},
SRT: []string{":10080"},
RTC: []string{":8000"},
UpdatedAt: ts,
}
for _, verb := range []string{"%+v", "%+s"} {
got := fmt.Sprintf(verb, v)
for _, sub := range []string{
"SRS ip=10.0.0.1",
"id=srv-svc-9",
"pid=9, server=srv, service=svc",
"device=dev1",
"rtmp=[:1935,:1936]",
"http=[:8080]",
"api=[:1985]",
"srt=[:10080]",
"rtc=[:8000]",
"update=2026-05-16 10:30:45.123",
} {
if !strings.Contains(got, sub) {
t.Fatalf("Sprintf(%q)=%q missing %q", verb, got, sub)
}
}
}
}
func TestOriginServerFormat_PlusVerbMinimal(t *testing.T) {
// Plus verb with no optional fields populated exercises the false
// branches of every "if len(X) > 0 / X != \"\"" guard in Format.
v := &OriginServer{ServerID: "srv", ServiceID: "svc", PID: "9"}
got := fmt.Sprintf("%+v", v)
if !strings.Contains(got, "pid=9, server=srv, service=svc") {
t.Fatalf("%%+v output %q missing core ids", got)
}
if !strings.Contains(got, "update=") {
t.Fatalf("%%+v output %q missing update timestamp", got)
}
for _, sub := range []string{"device=", "rtmp=", "http=", "api=", "srt=", "rtc="} {
if strings.Contains(got, sub) {
t.Fatalf("%%+v output %q should not contain %q for an empty field", got, sub)
}
}
}
func TestOriginServerFormat_OtherVerb(t *testing.T) {
// A non-v/s verb falls through to the default branch, which recursively
// formats with %v and appends ", fmt=%<verb>".
v := &OriginServer{IP: "1.2.3.4", ServerID: "srv", ServiceID: "svc", PID: "p"}
got := fmt.Sprintf("%d", v)
want := "SRS ip=1.2.3.4, id=srv-svc-p, fmt=%d"
if got != want {
t.Fatalf("%%d output %q, want %q", got, want)
}
}
func TestNewOriginServer(t *testing.T) {
t.Run("no opts", func(t *testing.T) {
v := NewOriginServer()
if v == nil {
t.Fatal("NewOriginServer() returned nil")
}
if v.IP != "" || v.DeviceID != "" || v.ServerID != "" || v.ServiceID != "" || v.PID != "" {
t.Fatalf("expected zero value, got %+v", v)
}
if len(v.RTMP)+len(v.HTTP)+len(v.API)+len(v.SRT)+len(v.RTC) != 0 {
t.Fatalf("expected empty endpoints, got %+v", v)
}
if !v.UpdatedAt.IsZero() {
t.Fatalf("expected zero UpdatedAt, got %v", v.UpdatedAt)
}
})
t.Run("with opts", func(t *testing.T) {
v := NewOriginServer(
func(s *OriginServer) { s.IP = "9.9.9.9" },
func(s *OriginServer) { s.ServerID = "abc" },
func(s *OriginServer) { s.RTMP = []string{":1935"} },
)
if v.IP != "9.9.9.9" || v.ServerID != "abc" || len(v.RTMP) != 1 || v.RTMP[0] != ":1935" {
t.Fatalf("opts not applied: got %+v", v)
}
})
}

View File

@ -0,0 +1,197 @@
// Code generated by counterfeiter. DO NOT EDIT.
package lbfakes
import (
"context"
"srsx/internal/lb"
"sync"
)
type FakeHLSService struct {
LoadHLSBySPBHIDStub func(context.Context, string) (lb.HLSPlayStream, error)
loadHLSBySPBHIDMutex sync.RWMutex
loadHLSBySPBHIDArgsForCall []struct {
arg1 context.Context
arg2 string
}
loadHLSBySPBHIDReturns struct {
result1 lb.HLSPlayStream
result2 error
}
loadHLSBySPBHIDReturnsOnCall map[int]struct {
result1 lb.HLSPlayStream
result2 error
}
LoadOrStoreHLSStub func(context.Context, string, lb.HLSPlayStream) (lb.HLSPlayStream, error)
loadOrStoreHLSMutex sync.RWMutex
loadOrStoreHLSArgsForCall []struct {
arg1 context.Context
arg2 string
arg3 lb.HLSPlayStream
}
loadOrStoreHLSReturns struct {
result1 lb.HLSPlayStream
result2 error
}
loadOrStoreHLSReturnsOnCall map[int]struct {
result1 lb.HLSPlayStream
result2 error
}
invocations map[string][][]interface{}
invocationsMutex sync.RWMutex
}
func (fake *FakeHLSService) LoadHLSBySPBHID(arg1 context.Context, arg2 string) (lb.HLSPlayStream, error) {
fake.loadHLSBySPBHIDMutex.Lock()
ret, specificReturn := fake.loadHLSBySPBHIDReturnsOnCall[len(fake.loadHLSBySPBHIDArgsForCall)]
fake.loadHLSBySPBHIDArgsForCall = append(fake.loadHLSBySPBHIDArgsForCall, struct {
arg1 context.Context
arg2 string
}{arg1, arg2})
stub := fake.LoadHLSBySPBHIDStub
fakeReturns := fake.loadHLSBySPBHIDReturns
fake.recordInvocation("LoadHLSBySPBHID", []interface{}{arg1, arg2})
fake.loadHLSBySPBHIDMutex.Unlock()
if stub != nil {
return stub(arg1, arg2)
}
if specificReturn {
return ret.result1, ret.result2
}
return fakeReturns.result1, fakeReturns.result2
}
func (fake *FakeHLSService) LoadHLSBySPBHIDCallCount() int {
fake.loadHLSBySPBHIDMutex.RLock()
defer fake.loadHLSBySPBHIDMutex.RUnlock()
return len(fake.loadHLSBySPBHIDArgsForCall)
}
func (fake *FakeHLSService) LoadHLSBySPBHIDCalls(stub func(context.Context, string) (lb.HLSPlayStream, error)) {
fake.loadHLSBySPBHIDMutex.Lock()
defer fake.loadHLSBySPBHIDMutex.Unlock()
fake.LoadHLSBySPBHIDStub = stub
}
func (fake *FakeHLSService) LoadHLSBySPBHIDArgsForCall(i int) (context.Context, string) {
fake.loadHLSBySPBHIDMutex.RLock()
defer fake.loadHLSBySPBHIDMutex.RUnlock()
argsForCall := fake.loadHLSBySPBHIDArgsForCall[i]
return argsForCall.arg1, argsForCall.arg2
}
func (fake *FakeHLSService) LoadHLSBySPBHIDReturns(result1 lb.HLSPlayStream, result2 error) {
fake.loadHLSBySPBHIDMutex.Lock()
defer fake.loadHLSBySPBHIDMutex.Unlock()
fake.LoadHLSBySPBHIDStub = nil
fake.loadHLSBySPBHIDReturns = struct {
result1 lb.HLSPlayStream
result2 error
}{result1, result2}
}
func (fake *FakeHLSService) LoadHLSBySPBHIDReturnsOnCall(i int, result1 lb.HLSPlayStream, result2 error) {
fake.loadHLSBySPBHIDMutex.Lock()
defer fake.loadHLSBySPBHIDMutex.Unlock()
fake.LoadHLSBySPBHIDStub = nil
if fake.loadHLSBySPBHIDReturnsOnCall == nil {
fake.loadHLSBySPBHIDReturnsOnCall = make(map[int]struct {
result1 lb.HLSPlayStream
result2 error
})
}
fake.loadHLSBySPBHIDReturnsOnCall[i] = struct {
result1 lb.HLSPlayStream
result2 error
}{result1, result2}
}
func (fake *FakeHLSService) LoadOrStoreHLS(arg1 context.Context, arg2 string, arg3 lb.HLSPlayStream) (lb.HLSPlayStream, error) {
fake.loadOrStoreHLSMutex.Lock()
ret, specificReturn := fake.loadOrStoreHLSReturnsOnCall[len(fake.loadOrStoreHLSArgsForCall)]
fake.loadOrStoreHLSArgsForCall = append(fake.loadOrStoreHLSArgsForCall, struct {
arg1 context.Context
arg2 string
arg3 lb.HLSPlayStream
}{arg1, arg2, arg3})
stub := fake.LoadOrStoreHLSStub
fakeReturns := fake.loadOrStoreHLSReturns
fake.recordInvocation("LoadOrStoreHLS", []interface{}{arg1, arg2, arg3})
fake.loadOrStoreHLSMutex.Unlock()
if stub != nil {
return stub(arg1, arg2, arg3)
}
if specificReturn {
return ret.result1, ret.result2
}
return fakeReturns.result1, fakeReturns.result2
}
func (fake *FakeHLSService) LoadOrStoreHLSCallCount() int {
fake.loadOrStoreHLSMutex.RLock()
defer fake.loadOrStoreHLSMutex.RUnlock()
return len(fake.loadOrStoreHLSArgsForCall)
}
func (fake *FakeHLSService) LoadOrStoreHLSCalls(stub func(context.Context, string, lb.HLSPlayStream) (lb.HLSPlayStream, error)) {
fake.loadOrStoreHLSMutex.Lock()
defer fake.loadOrStoreHLSMutex.Unlock()
fake.LoadOrStoreHLSStub = stub
}
func (fake *FakeHLSService) LoadOrStoreHLSArgsForCall(i int) (context.Context, string, lb.HLSPlayStream) {
fake.loadOrStoreHLSMutex.RLock()
defer fake.loadOrStoreHLSMutex.RUnlock()
argsForCall := fake.loadOrStoreHLSArgsForCall[i]
return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3
}
func (fake *FakeHLSService) LoadOrStoreHLSReturns(result1 lb.HLSPlayStream, result2 error) {
fake.loadOrStoreHLSMutex.Lock()
defer fake.loadOrStoreHLSMutex.Unlock()
fake.LoadOrStoreHLSStub = nil
fake.loadOrStoreHLSReturns = struct {
result1 lb.HLSPlayStream
result2 error
}{result1, result2}
}
func (fake *FakeHLSService) LoadOrStoreHLSReturnsOnCall(i int, result1 lb.HLSPlayStream, result2 error) {
fake.loadOrStoreHLSMutex.Lock()
defer fake.loadOrStoreHLSMutex.Unlock()
fake.LoadOrStoreHLSStub = nil
if fake.loadOrStoreHLSReturnsOnCall == nil {
fake.loadOrStoreHLSReturnsOnCall = make(map[int]struct {
result1 lb.HLSPlayStream
result2 error
})
}
fake.loadOrStoreHLSReturnsOnCall[i] = struct {
result1 lb.HLSPlayStream
result2 error
}{result1, result2}
}
func (fake *FakeHLSService) Invocations() map[string][][]interface{} {
fake.invocationsMutex.RLock()
defer fake.invocationsMutex.RUnlock()
copiedInvocations := map[string][][]interface{}{}
for key, value := range fake.invocations {
copiedInvocations[key] = value
}
return copiedInvocations
}
func (fake *FakeHLSService) recordInvocation(key string, args []interface{}) {
fake.invocationsMutex.Lock()
defer fake.invocationsMutex.Unlock()
if fake.invocations == nil {
fake.invocations = map[string][][]interface{}{}
}
if fake.invocations[key] == nil {
fake.invocations[key] = [][]interface{}{}
}
fake.invocations[key] = append(fake.invocations[key], args)
}
var _ lb.HLSService = new(FakeHLSService)

View File

@ -0,0 +1,577 @@
// Code generated by counterfeiter. DO NOT EDIT.
package lbfakes
import (
"context"
"srsx/internal/lb"
"sync"
)
type FakeOriginLoadBalancer struct {
InitializeStub func(context.Context) error
initializeMutex sync.RWMutex
initializeArgsForCall []struct {
arg1 context.Context
}
initializeReturns struct {
result1 error
}
initializeReturnsOnCall map[int]struct {
result1 error
}
LoadHLSBySPBHIDStub func(context.Context, string) (lb.HLSPlayStream, error)
loadHLSBySPBHIDMutex sync.RWMutex
loadHLSBySPBHIDArgsForCall []struct {
arg1 context.Context
arg2 string
}
loadHLSBySPBHIDReturns struct {
result1 lb.HLSPlayStream
result2 error
}
loadHLSBySPBHIDReturnsOnCall map[int]struct {
result1 lb.HLSPlayStream
result2 error
}
LoadOrStoreHLSStub func(context.Context, string, lb.HLSPlayStream) (lb.HLSPlayStream, error)
loadOrStoreHLSMutex sync.RWMutex
loadOrStoreHLSArgsForCall []struct {
arg1 context.Context
arg2 string
arg3 lb.HLSPlayStream
}
loadOrStoreHLSReturns struct {
result1 lb.HLSPlayStream
result2 error
}
loadOrStoreHLSReturnsOnCall map[int]struct {
result1 lb.HLSPlayStream
result2 error
}
LoadWebRTCByUfragStub func(context.Context, string) (lb.RTCConnection, error)
loadWebRTCByUfragMutex sync.RWMutex
loadWebRTCByUfragArgsForCall []struct {
arg1 context.Context
arg2 string
}
loadWebRTCByUfragReturns struct {
result1 lb.RTCConnection
result2 error
}
loadWebRTCByUfragReturnsOnCall map[int]struct {
result1 lb.RTCConnection
result2 error
}
PickStub func(context.Context, string) (*lb.OriginServer, error)
pickMutex sync.RWMutex
pickArgsForCall []struct {
arg1 context.Context
arg2 string
}
pickReturns struct {
result1 *lb.OriginServer
result2 error
}
pickReturnsOnCall map[int]struct {
result1 *lb.OriginServer
result2 error
}
StoreWebRTCStub func(context.Context, string, lb.RTCConnection) error
storeWebRTCMutex sync.RWMutex
storeWebRTCArgsForCall []struct {
arg1 context.Context
arg2 string
arg3 lb.RTCConnection
}
storeWebRTCReturns struct {
result1 error
}
storeWebRTCReturnsOnCall map[int]struct {
result1 error
}
UpdateStub func(context.Context, *lb.OriginServer) error
updateMutex sync.RWMutex
updateArgsForCall []struct {
arg1 context.Context
arg2 *lb.OriginServer
}
updateReturns struct {
result1 error
}
updateReturnsOnCall map[int]struct {
result1 error
}
invocations map[string][][]interface{}
invocationsMutex sync.RWMutex
}
func (fake *FakeOriginLoadBalancer) Initialize(arg1 context.Context) error {
fake.initializeMutex.Lock()
ret, specificReturn := fake.initializeReturnsOnCall[len(fake.initializeArgsForCall)]
fake.initializeArgsForCall = append(fake.initializeArgsForCall, struct {
arg1 context.Context
}{arg1})
stub := fake.InitializeStub
fakeReturns := fake.initializeReturns
fake.recordInvocation("Initialize", []interface{}{arg1})
fake.initializeMutex.Unlock()
if stub != nil {
return stub(arg1)
}
if specificReturn {
return ret.result1
}
return fakeReturns.result1
}
func (fake *FakeOriginLoadBalancer) InitializeCallCount() int {
fake.initializeMutex.RLock()
defer fake.initializeMutex.RUnlock()
return len(fake.initializeArgsForCall)
}
func (fake *FakeOriginLoadBalancer) InitializeCalls(stub func(context.Context) error) {
fake.initializeMutex.Lock()
defer fake.initializeMutex.Unlock()
fake.InitializeStub = stub
}
func (fake *FakeOriginLoadBalancer) InitializeArgsForCall(i int) context.Context {
fake.initializeMutex.RLock()
defer fake.initializeMutex.RUnlock()
argsForCall := fake.initializeArgsForCall[i]
return argsForCall.arg1
}
func (fake *FakeOriginLoadBalancer) InitializeReturns(result1 error) {
fake.initializeMutex.Lock()
defer fake.initializeMutex.Unlock()
fake.InitializeStub = nil
fake.initializeReturns = struct {
result1 error
}{result1}
}
func (fake *FakeOriginLoadBalancer) InitializeReturnsOnCall(i int, result1 error) {
fake.initializeMutex.Lock()
defer fake.initializeMutex.Unlock()
fake.InitializeStub = nil
if fake.initializeReturnsOnCall == nil {
fake.initializeReturnsOnCall = make(map[int]struct {
result1 error
})
}
fake.initializeReturnsOnCall[i] = struct {
result1 error
}{result1}
}
func (fake *FakeOriginLoadBalancer) LoadHLSBySPBHID(arg1 context.Context, arg2 string) (lb.HLSPlayStream, error) {
fake.loadHLSBySPBHIDMutex.Lock()
ret, specificReturn := fake.loadHLSBySPBHIDReturnsOnCall[len(fake.loadHLSBySPBHIDArgsForCall)]
fake.loadHLSBySPBHIDArgsForCall = append(fake.loadHLSBySPBHIDArgsForCall, struct {
arg1 context.Context
arg2 string
}{arg1, arg2})
stub := fake.LoadHLSBySPBHIDStub
fakeReturns := fake.loadHLSBySPBHIDReturns
fake.recordInvocation("LoadHLSBySPBHID", []interface{}{arg1, arg2})
fake.loadHLSBySPBHIDMutex.Unlock()
if stub != nil {
return stub(arg1, arg2)
}
if specificReturn {
return ret.result1, ret.result2
}
return fakeReturns.result1, fakeReturns.result2
}
func (fake *FakeOriginLoadBalancer) LoadHLSBySPBHIDCallCount() int {
fake.loadHLSBySPBHIDMutex.RLock()
defer fake.loadHLSBySPBHIDMutex.RUnlock()
return len(fake.loadHLSBySPBHIDArgsForCall)
}
func (fake *FakeOriginLoadBalancer) LoadHLSBySPBHIDCalls(stub func(context.Context, string) (lb.HLSPlayStream, error)) {
fake.loadHLSBySPBHIDMutex.Lock()
defer fake.loadHLSBySPBHIDMutex.Unlock()
fake.LoadHLSBySPBHIDStub = stub
}
func (fake *FakeOriginLoadBalancer) LoadHLSBySPBHIDArgsForCall(i int) (context.Context, string) {
fake.loadHLSBySPBHIDMutex.RLock()
defer fake.loadHLSBySPBHIDMutex.RUnlock()
argsForCall := fake.loadHLSBySPBHIDArgsForCall[i]
return argsForCall.arg1, argsForCall.arg2
}
func (fake *FakeOriginLoadBalancer) LoadHLSBySPBHIDReturns(result1 lb.HLSPlayStream, result2 error) {
fake.loadHLSBySPBHIDMutex.Lock()
defer fake.loadHLSBySPBHIDMutex.Unlock()
fake.LoadHLSBySPBHIDStub = nil
fake.loadHLSBySPBHIDReturns = struct {
result1 lb.HLSPlayStream
result2 error
}{result1, result2}
}
func (fake *FakeOriginLoadBalancer) LoadHLSBySPBHIDReturnsOnCall(i int, result1 lb.HLSPlayStream, result2 error) {
fake.loadHLSBySPBHIDMutex.Lock()
defer fake.loadHLSBySPBHIDMutex.Unlock()
fake.LoadHLSBySPBHIDStub = nil
if fake.loadHLSBySPBHIDReturnsOnCall == nil {
fake.loadHLSBySPBHIDReturnsOnCall = make(map[int]struct {
result1 lb.HLSPlayStream
result2 error
})
}
fake.loadHLSBySPBHIDReturnsOnCall[i] = struct {
result1 lb.HLSPlayStream
result2 error
}{result1, result2}
}
func (fake *FakeOriginLoadBalancer) LoadOrStoreHLS(arg1 context.Context, arg2 string, arg3 lb.HLSPlayStream) (lb.HLSPlayStream, error) {
fake.loadOrStoreHLSMutex.Lock()
ret, specificReturn := fake.loadOrStoreHLSReturnsOnCall[len(fake.loadOrStoreHLSArgsForCall)]
fake.loadOrStoreHLSArgsForCall = append(fake.loadOrStoreHLSArgsForCall, struct {
arg1 context.Context
arg2 string
arg3 lb.HLSPlayStream
}{arg1, arg2, arg3})
stub := fake.LoadOrStoreHLSStub
fakeReturns := fake.loadOrStoreHLSReturns
fake.recordInvocation("LoadOrStoreHLS", []interface{}{arg1, arg2, arg3})
fake.loadOrStoreHLSMutex.Unlock()
if stub != nil {
return stub(arg1, arg2, arg3)
}
if specificReturn {
return ret.result1, ret.result2
}
return fakeReturns.result1, fakeReturns.result2
}
func (fake *FakeOriginLoadBalancer) LoadOrStoreHLSCallCount() int {
fake.loadOrStoreHLSMutex.RLock()
defer fake.loadOrStoreHLSMutex.RUnlock()
return len(fake.loadOrStoreHLSArgsForCall)
}
func (fake *FakeOriginLoadBalancer) LoadOrStoreHLSCalls(stub func(context.Context, string, lb.HLSPlayStream) (lb.HLSPlayStream, error)) {
fake.loadOrStoreHLSMutex.Lock()
defer fake.loadOrStoreHLSMutex.Unlock()
fake.LoadOrStoreHLSStub = stub
}
func (fake *FakeOriginLoadBalancer) LoadOrStoreHLSArgsForCall(i int) (context.Context, string, lb.HLSPlayStream) {
fake.loadOrStoreHLSMutex.RLock()
defer fake.loadOrStoreHLSMutex.RUnlock()
argsForCall := fake.loadOrStoreHLSArgsForCall[i]
return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3
}
func (fake *FakeOriginLoadBalancer) LoadOrStoreHLSReturns(result1 lb.HLSPlayStream, result2 error) {
fake.loadOrStoreHLSMutex.Lock()
defer fake.loadOrStoreHLSMutex.Unlock()
fake.LoadOrStoreHLSStub = nil
fake.loadOrStoreHLSReturns = struct {
result1 lb.HLSPlayStream
result2 error
}{result1, result2}
}
func (fake *FakeOriginLoadBalancer) LoadOrStoreHLSReturnsOnCall(i int, result1 lb.HLSPlayStream, result2 error) {
fake.loadOrStoreHLSMutex.Lock()
defer fake.loadOrStoreHLSMutex.Unlock()
fake.LoadOrStoreHLSStub = nil
if fake.loadOrStoreHLSReturnsOnCall == nil {
fake.loadOrStoreHLSReturnsOnCall = make(map[int]struct {
result1 lb.HLSPlayStream
result2 error
})
}
fake.loadOrStoreHLSReturnsOnCall[i] = struct {
result1 lb.HLSPlayStream
result2 error
}{result1, result2}
}
func (fake *FakeOriginLoadBalancer) LoadWebRTCByUfrag(arg1 context.Context, arg2 string) (lb.RTCConnection, error) {
fake.loadWebRTCByUfragMutex.Lock()
ret, specificReturn := fake.loadWebRTCByUfragReturnsOnCall[len(fake.loadWebRTCByUfragArgsForCall)]
fake.loadWebRTCByUfragArgsForCall = append(fake.loadWebRTCByUfragArgsForCall, struct {
arg1 context.Context
arg2 string
}{arg1, arg2})
stub := fake.LoadWebRTCByUfragStub
fakeReturns := fake.loadWebRTCByUfragReturns
fake.recordInvocation("LoadWebRTCByUfrag", []interface{}{arg1, arg2})
fake.loadWebRTCByUfragMutex.Unlock()
if stub != nil {
return stub(arg1, arg2)
}
if specificReturn {
return ret.result1, ret.result2
}
return fakeReturns.result1, fakeReturns.result2
}
func (fake *FakeOriginLoadBalancer) LoadWebRTCByUfragCallCount() int {
fake.loadWebRTCByUfragMutex.RLock()
defer fake.loadWebRTCByUfragMutex.RUnlock()
return len(fake.loadWebRTCByUfragArgsForCall)
}
func (fake *FakeOriginLoadBalancer) LoadWebRTCByUfragCalls(stub func(context.Context, string) (lb.RTCConnection, error)) {
fake.loadWebRTCByUfragMutex.Lock()
defer fake.loadWebRTCByUfragMutex.Unlock()
fake.LoadWebRTCByUfragStub = stub
}
func (fake *FakeOriginLoadBalancer) LoadWebRTCByUfragArgsForCall(i int) (context.Context, string) {
fake.loadWebRTCByUfragMutex.RLock()
defer fake.loadWebRTCByUfragMutex.RUnlock()
argsForCall := fake.loadWebRTCByUfragArgsForCall[i]
return argsForCall.arg1, argsForCall.arg2
}
func (fake *FakeOriginLoadBalancer) LoadWebRTCByUfragReturns(result1 lb.RTCConnection, result2 error) {
fake.loadWebRTCByUfragMutex.Lock()
defer fake.loadWebRTCByUfragMutex.Unlock()
fake.LoadWebRTCByUfragStub = nil
fake.loadWebRTCByUfragReturns = struct {
result1 lb.RTCConnection
result2 error
}{result1, result2}
}
func (fake *FakeOriginLoadBalancer) LoadWebRTCByUfragReturnsOnCall(i int, result1 lb.RTCConnection, result2 error) {
fake.loadWebRTCByUfragMutex.Lock()
defer fake.loadWebRTCByUfragMutex.Unlock()
fake.LoadWebRTCByUfragStub = nil
if fake.loadWebRTCByUfragReturnsOnCall == nil {
fake.loadWebRTCByUfragReturnsOnCall = make(map[int]struct {
result1 lb.RTCConnection
result2 error
})
}
fake.loadWebRTCByUfragReturnsOnCall[i] = struct {
result1 lb.RTCConnection
result2 error
}{result1, result2}
}
func (fake *FakeOriginLoadBalancer) Pick(arg1 context.Context, arg2 string) (*lb.OriginServer, error) {
fake.pickMutex.Lock()
ret, specificReturn := fake.pickReturnsOnCall[len(fake.pickArgsForCall)]
fake.pickArgsForCall = append(fake.pickArgsForCall, struct {
arg1 context.Context
arg2 string
}{arg1, arg2})
stub := fake.PickStub
fakeReturns := fake.pickReturns
fake.recordInvocation("Pick", []interface{}{arg1, arg2})
fake.pickMutex.Unlock()
if stub != nil {
return stub(arg1, arg2)
}
if specificReturn {
return ret.result1, ret.result2
}
return fakeReturns.result1, fakeReturns.result2
}
func (fake *FakeOriginLoadBalancer) PickCallCount() int {
fake.pickMutex.RLock()
defer fake.pickMutex.RUnlock()
return len(fake.pickArgsForCall)
}
func (fake *FakeOriginLoadBalancer) PickCalls(stub func(context.Context, string) (*lb.OriginServer, error)) {
fake.pickMutex.Lock()
defer fake.pickMutex.Unlock()
fake.PickStub = stub
}
func (fake *FakeOriginLoadBalancer) PickArgsForCall(i int) (context.Context, string) {
fake.pickMutex.RLock()
defer fake.pickMutex.RUnlock()
argsForCall := fake.pickArgsForCall[i]
return argsForCall.arg1, argsForCall.arg2
}
func (fake *FakeOriginLoadBalancer) PickReturns(result1 *lb.OriginServer, result2 error) {
fake.pickMutex.Lock()
defer fake.pickMutex.Unlock()
fake.PickStub = nil
fake.pickReturns = struct {
result1 *lb.OriginServer
result2 error
}{result1, result2}
}
func (fake *FakeOriginLoadBalancer) PickReturnsOnCall(i int, result1 *lb.OriginServer, result2 error) {
fake.pickMutex.Lock()
defer fake.pickMutex.Unlock()
fake.PickStub = nil
if fake.pickReturnsOnCall == nil {
fake.pickReturnsOnCall = make(map[int]struct {
result1 *lb.OriginServer
result2 error
})
}
fake.pickReturnsOnCall[i] = struct {
result1 *lb.OriginServer
result2 error
}{result1, result2}
}
func (fake *FakeOriginLoadBalancer) StoreWebRTC(arg1 context.Context, arg2 string, arg3 lb.RTCConnection) error {
fake.storeWebRTCMutex.Lock()
ret, specificReturn := fake.storeWebRTCReturnsOnCall[len(fake.storeWebRTCArgsForCall)]
fake.storeWebRTCArgsForCall = append(fake.storeWebRTCArgsForCall, struct {
arg1 context.Context
arg2 string
arg3 lb.RTCConnection
}{arg1, arg2, arg3})
stub := fake.StoreWebRTCStub
fakeReturns := fake.storeWebRTCReturns
fake.recordInvocation("StoreWebRTC", []interface{}{arg1, arg2, arg3})
fake.storeWebRTCMutex.Unlock()
if stub != nil {
return stub(arg1, arg2, arg3)
}
if specificReturn {
return ret.result1
}
return fakeReturns.result1
}
func (fake *FakeOriginLoadBalancer) StoreWebRTCCallCount() int {
fake.storeWebRTCMutex.RLock()
defer fake.storeWebRTCMutex.RUnlock()
return len(fake.storeWebRTCArgsForCall)
}
func (fake *FakeOriginLoadBalancer) StoreWebRTCCalls(stub func(context.Context, string, lb.RTCConnection) error) {
fake.storeWebRTCMutex.Lock()
defer fake.storeWebRTCMutex.Unlock()
fake.StoreWebRTCStub = stub
}
func (fake *FakeOriginLoadBalancer) StoreWebRTCArgsForCall(i int) (context.Context, string, lb.RTCConnection) {
fake.storeWebRTCMutex.RLock()
defer fake.storeWebRTCMutex.RUnlock()
argsForCall := fake.storeWebRTCArgsForCall[i]
return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3
}
func (fake *FakeOriginLoadBalancer) StoreWebRTCReturns(result1 error) {
fake.storeWebRTCMutex.Lock()
defer fake.storeWebRTCMutex.Unlock()
fake.StoreWebRTCStub = nil
fake.storeWebRTCReturns = struct {
result1 error
}{result1}
}
func (fake *FakeOriginLoadBalancer) StoreWebRTCReturnsOnCall(i int, result1 error) {
fake.storeWebRTCMutex.Lock()
defer fake.storeWebRTCMutex.Unlock()
fake.StoreWebRTCStub = nil
if fake.storeWebRTCReturnsOnCall == nil {
fake.storeWebRTCReturnsOnCall = make(map[int]struct {
result1 error
})
}
fake.storeWebRTCReturnsOnCall[i] = struct {
result1 error
}{result1}
}
func (fake *FakeOriginLoadBalancer) Update(arg1 context.Context, arg2 *lb.OriginServer) error {
fake.updateMutex.Lock()
ret, specificReturn := fake.updateReturnsOnCall[len(fake.updateArgsForCall)]
fake.updateArgsForCall = append(fake.updateArgsForCall, struct {
arg1 context.Context
arg2 *lb.OriginServer
}{arg1, arg2})
stub := fake.UpdateStub
fakeReturns := fake.updateReturns
fake.recordInvocation("Update", []interface{}{arg1, arg2})
fake.updateMutex.Unlock()
if stub != nil {
return stub(arg1, arg2)
}
if specificReturn {
return ret.result1
}
return fakeReturns.result1
}
func (fake *FakeOriginLoadBalancer) UpdateCallCount() int {
fake.updateMutex.RLock()
defer fake.updateMutex.RUnlock()
return len(fake.updateArgsForCall)
}
func (fake *FakeOriginLoadBalancer) UpdateCalls(stub func(context.Context, *lb.OriginServer) error) {
fake.updateMutex.Lock()
defer fake.updateMutex.Unlock()
fake.UpdateStub = stub
}
func (fake *FakeOriginLoadBalancer) UpdateArgsForCall(i int) (context.Context, *lb.OriginServer) {
fake.updateMutex.RLock()
defer fake.updateMutex.RUnlock()
argsForCall := fake.updateArgsForCall[i]
return argsForCall.arg1, argsForCall.arg2
}
func (fake *FakeOriginLoadBalancer) UpdateReturns(result1 error) {
fake.updateMutex.Lock()
defer fake.updateMutex.Unlock()
fake.UpdateStub = nil
fake.updateReturns = struct {
result1 error
}{result1}
}
func (fake *FakeOriginLoadBalancer) UpdateReturnsOnCall(i int, result1 error) {
fake.updateMutex.Lock()
defer fake.updateMutex.Unlock()
fake.UpdateStub = nil
if fake.updateReturnsOnCall == nil {
fake.updateReturnsOnCall = make(map[int]struct {
result1 error
})
}
fake.updateReturnsOnCall[i] = struct {
result1 error
}{result1}
}
func (fake *FakeOriginLoadBalancer) Invocations() map[string][][]interface{} {
fake.invocationsMutex.RLock()
defer fake.invocationsMutex.RUnlock()
copiedInvocations := map[string][][]interface{}{}
for key, value := range fake.invocations {
copiedInvocations[key] = value
}
return copiedInvocations
}
func (fake *FakeOriginLoadBalancer) recordInvocation(key string, args []interface{}) {
fake.invocationsMutex.Lock()
defer fake.invocationsMutex.Unlock()
if fake.invocations == nil {
fake.invocations = map[string][][]interface{}{}
}
if fake.invocations[key] == nil {
fake.invocations[key] = [][]interface{}{}
}
fake.invocations[key] = append(fake.invocations[key], args)
}
var _ lb.OriginLoadBalancer = new(FakeOriginLoadBalancer)

View File

@ -0,0 +1,190 @@
// Code generated by counterfeiter. DO NOT EDIT.
package lbfakes
import (
"context"
"srsx/internal/lb"
"sync"
)
type FakeOriginService struct {
PickStub func(context.Context, string) (*lb.OriginServer, error)
pickMutex sync.RWMutex
pickArgsForCall []struct {
arg1 context.Context
arg2 string
}
pickReturns struct {
result1 *lb.OriginServer
result2 error
}
pickReturnsOnCall map[int]struct {
result1 *lb.OriginServer
result2 error
}
UpdateStub func(context.Context, *lb.OriginServer) error
updateMutex sync.RWMutex
updateArgsForCall []struct {
arg1 context.Context
arg2 *lb.OriginServer
}
updateReturns struct {
result1 error
}
updateReturnsOnCall map[int]struct {
result1 error
}
invocations map[string][][]interface{}
invocationsMutex sync.RWMutex
}
func (fake *FakeOriginService) Pick(arg1 context.Context, arg2 string) (*lb.OriginServer, error) {
fake.pickMutex.Lock()
ret, specificReturn := fake.pickReturnsOnCall[len(fake.pickArgsForCall)]
fake.pickArgsForCall = append(fake.pickArgsForCall, struct {
arg1 context.Context
arg2 string
}{arg1, arg2})
stub := fake.PickStub
fakeReturns := fake.pickReturns
fake.recordInvocation("Pick", []interface{}{arg1, arg2})
fake.pickMutex.Unlock()
if stub != nil {
return stub(arg1, arg2)
}
if specificReturn {
return ret.result1, ret.result2
}
return fakeReturns.result1, fakeReturns.result2
}
func (fake *FakeOriginService) PickCallCount() int {
fake.pickMutex.RLock()
defer fake.pickMutex.RUnlock()
return len(fake.pickArgsForCall)
}
func (fake *FakeOriginService) PickCalls(stub func(context.Context, string) (*lb.OriginServer, error)) {
fake.pickMutex.Lock()
defer fake.pickMutex.Unlock()
fake.PickStub = stub
}
func (fake *FakeOriginService) PickArgsForCall(i int) (context.Context, string) {
fake.pickMutex.RLock()
defer fake.pickMutex.RUnlock()
argsForCall := fake.pickArgsForCall[i]
return argsForCall.arg1, argsForCall.arg2
}
func (fake *FakeOriginService) PickReturns(result1 *lb.OriginServer, result2 error) {
fake.pickMutex.Lock()
defer fake.pickMutex.Unlock()
fake.PickStub = nil
fake.pickReturns = struct {
result1 *lb.OriginServer
result2 error
}{result1, result2}
}
func (fake *FakeOriginService) PickReturnsOnCall(i int, result1 *lb.OriginServer, result2 error) {
fake.pickMutex.Lock()
defer fake.pickMutex.Unlock()
fake.PickStub = nil
if fake.pickReturnsOnCall == nil {
fake.pickReturnsOnCall = make(map[int]struct {
result1 *lb.OriginServer
result2 error
})
}
fake.pickReturnsOnCall[i] = struct {
result1 *lb.OriginServer
result2 error
}{result1, result2}
}
func (fake *FakeOriginService) Update(arg1 context.Context, arg2 *lb.OriginServer) error {
fake.updateMutex.Lock()
ret, specificReturn := fake.updateReturnsOnCall[len(fake.updateArgsForCall)]
fake.updateArgsForCall = append(fake.updateArgsForCall, struct {
arg1 context.Context
arg2 *lb.OriginServer
}{arg1, arg2})
stub := fake.UpdateStub
fakeReturns := fake.updateReturns
fake.recordInvocation("Update", []interface{}{arg1, arg2})
fake.updateMutex.Unlock()
if stub != nil {
return stub(arg1, arg2)
}
if specificReturn {
return ret.result1
}
return fakeReturns.result1
}
func (fake *FakeOriginService) UpdateCallCount() int {
fake.updateMutex.RLock()
defer fake.updateMutex.RUnlock()
return len(fake.updateArgsForCall)
}
func (fake *FakeOriginService) UpdateCalls(stub func(context.Context, *lb.OriginServer) error) {
fake.updateMutex.Lock()
defer fake.updateMutex.Unlock()
fake.UpdateStub = stub
}
func (fake *FakeOriginService) UpdateArgsForCall(i int) (context.Context, *lb.OriginServer) {
fake.updateMutex.RLock()
defer fake.updateMutex.RUnlock()
argsForCall := fake.updateArgsForCall[i]
return argsForCall.arg1, argsForCall.arg2
}
func (fake *FakeOriginService) UpdateReturns(result1 error) {
fake.updateMutex.Lock()
defer fake.updateMutex.Unlock()
fake.UpdateStub = nil
fake.updateReturns = struct {
result1 error
}{result1}
}
func (fake *FakeOriginService) UpdateReturnsOnCall(i int, result1 error) {
fake.updateMutex.Lock()
defer fake.updateMutex.Unlock()
fake.UpdateStub = nil
if fake.updateReturnsOnCall == nil {
fake.updateReturnsOnCall = make(map[int]struct {
result1 error
})
}
fake.updateReturnsOnCall[i] = struct {
result1 error
}{result1}
}
func (fake *FakeOriginService) Invocations() map[string][][]interface{} {
fake.invocationsMutex.RLock()
defer fake.invocationsMutex.RUnlock()
copiedInvocations := map[string][][]interface{}{}
for key, value := range fake.invocations {
copiedInvocations[key] = value
}
return copiedInvocations
}
func (fake *FakeOriginService) recordInvocation(key string, args []interface{}) {
fake.invocationsMutex.Lock()
defer fake.invocationsMutex.Unlock()
if fake.invocations == nil {
fake.invocations = map[string][][]interface{}{}
}
if fake.invocations[key] == nil {
fake.invocations[key] = [][]interface{}{}
}
fake.invocations[key] = append(fake.invocations[key], args)
}
var _ lb.OriginService = new(FakeOriginService)

View File

@ -0,0 +1,192 @@
// Code generated by counterfeiter. DO NOT EDIT.
package lbfakes
import (
"context"
"srsx/internal/lb"
"sync"
)
type FakeRTCService struct {
LoadWebRTCByUfragStub func(context.Context, string) (lb.RTCConnection, error)
loadWebRTCByUfragMutex sync.RWMutex
loadWebRTCByUfragArgsForCall []struct {
arg1 context.Context
arg2 string
}
loadWebRTCByUfragReturns struct {
result1 lb.RTCConnection
result2 error
}
loadWebRTCByUfragReturnsOnCall map[int]struct {
result1 lb.RTCConnection
result2 error
}
StoreWebRTCStub func(context.Context, string, lb.RTCConnection) error
storeWebRTCMutex sync.RWMutex
storeWebRTCArgsForCall []struct {
arg1 context.Context
arg2 string
arg3 lb.RTCConnection
}
storeWebRTCReturns struct {
result1 error
}
storeWebRTCReturnsOnCall map[int]struct {
result1 error
}
invocations map[string][][]interface{}
invocationsMutex sync.RWMutex
}
func (fake *FakeRTCService) LoadWebRTCByUfrag(arg1 context.Context, arg2 string) (lb.RTCConnection, error) {
fake.loadWebRTCByUfragMutex.Lock()
ret, specificReturn := fake.loadWebRTCByUfragReturnsOnCall[len(fake.loadWebRTCByUfragArgsForCall)]
fake.loadWebRTCByUfragArgsForCall = append(fake.loadWebRTCByUfragArgsForCall, struct {
arg1 context.Context
arg2 string
}{arg1, arg2})
stub := fake.LoadWebRTCByUfragStub
fakeReturns := fake.loadWebRTCByUfragReturns
fake.recordInvocation("LoadWebRTCByUfrag", []interface{}{arg1, arg2})
fake.loadWebRTCByUfragMutex.Unlock()
if stub != nil {
return stub(arg1, arg2)
}
if specificReturn {
return ret.result1, ret.result2
}
return fakeReturns.result1, fakeReturns.result2
}
func (fake *FakeRTCService) LoadWebRTCByUfragCallCount() int {
fake.loadWebRTCByUfragMutex.RLock()
defer fake.loadWebRTCByUfragMutex.RUnlock()
return len(fake.loadWebRTCByUfragArgsForCall)
}
func (fake *FakeRTCService) LoadWebRTCByUfragCalls(stub func(context.Context, string) (lb.RTCConnection, error)) {
fake.loadWebRTCByUfragMutex.Lock()
defer fake.loadWebRTCByUfragMutex.Unlock()
fake.LoadWebRTCByUfragStub = stub
}
func (fake *FakeRTCService) LoadWebRTCByUfragArgsForCall(i int) (context.Context, string) {
fake.loadWebRTCByUfragMutex.RLock()
defer fake.loadWebRTCByUfragMutex.RUnlock()
argsForCall := fake.loadWebRTCByUfragArgsForCall[i]
return argsForCall.arg1, argsForCall.arg2
}
func (fake *FakeRTCService) LoadWebRTCByUfragReturns(result1 lb.RTCConnection, result2 error) {
fake.loadWebRTCByUfragMutex.Lock()
defer fake.loadWebRTCByUfragMutex.Unlock()
fake.LoadWebRTCByUfragStub = nil
fake.loadWebRTCByUfragReturns = struct {
result1 lb.RTCConnection
result2 error
}{result1, result2}
}
func (fake *FakeRTCService) LoadWebRTCByUfragReturnsOnCall(i int, result1 lb.RTCConnection, result2 error) {
fake.loadWebRTCByUfragMutex.Lock()
defer fake.loadWebRTCByUfragMutex.Unlock()
fake.LoadWebRTCByUfragStub = nil
if fake.loadWebRTCByUfragReturnsOnCall == nil {
fake.loadWebRTCByUfragReturnsOnCall = make(map[int]struct {
result1 lb.RTCConnection
result2 error
})
}
fake.loadWebRTCByUfragReturnsOnCall[i] = struct {
result1 lb.RTCConnection
result2 error
}{result1, result2}
}
func (fake *FakeRTCService) StoreWebRTC(arg1 context.Context, arg2 string, arg3 lb.RTCConnection) error {
fake.storeWebRTCMutex.Lock()
ret, specificReturn := fake.storeWebRTCReturnsOnCall[len(fake.storeWebRTCArgsForCall)]
fake.storeWebRTCArgsForCall = append(fake.storeWebRTCArgsForCall, struct {
arg1 context.Context
arg2 string
arg3 lb.RTCConnection
}{arg1, arg2, arg3})
stub := fake.StoreWebRTCStub
fakeReturns := fake.storeWebRTCReturns
fake.recordInvocation("StoreWebRTC", []interface{}{arg1, arg2, arg3})
fake.storeWebRTCMutex.Unlock()
if stub != nil {
return stub(arg1, arg2, arg3)
}
if specificReturn {
return ret.result1
}
return fakeReturns.result1
}
func (fake *FakeRTCService) StoreWebRTCCallCount() int {
fake.storeWebRTCMutex.RLock()
defer fake.storeWebRTCMutex.RUnlock()
return len(fake.storeWebRTCArgsForCall)
}
func (fake *FakeRTCService) StoreWebRTCCalls(stub func(context.Context, string, lb.RTCConnection) error) {
fake.storeWebRTCMutex.Lock()
defer fake.storeWebRTCMutex.Unlock()
fake.StoreWebRTCStub = stub
}
func (fake *FakeRTCService) StoreWebRTCArgsForCall(i int) (context.Context, string, lb.RTCConnection) {
fake.storeWebRTCMutex.RLock()
defer fake.storeWebRTCMutex.RUnlock()
argsForCall := fake.storeWebRTCArgsForCall[i]
return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3
}
func (fake *FakeRTCService) StoreWebRTCReturns(result1 error) {
fake.storeWebRTCMutex.Lock()
defer fake.storeWebRTCMutex.Unlock()
fake.StoreWebRTCStub = nil
fake.storeWebRTCReturns = struct {
result1 error
}{result1}
}
func (fake *FakeRTCService) StoreWebRTCReturnsOnCall(i int, result1 error) {
fake.storeWebRTCMutex.Lock()
defer fake.storeWebRTCMutex.Unlock()
fake.StoreWebRTCStub = nil
if fake.storeWebRTCReturnsOnCall == nil {
fake.storeWebRTCReturnsOnCall = make(map[int]struct {
result1 error
})
}
fake.storeWebRTCReturnsOnCall[i] = struct {
result1 error
}{result1}
}
func (fake *FakeRTCService) Invocations() map[string][][]interface{} {
fake.invocationsMutex.RLock()
defer fake.invocationsMutex.RUnlock()
copiedInvocations := map[string][][]interface{}{}
for key, value := range fake.invocations {
copiedInvocations[key] = value
}
return copiedInvocations
}
func (fake *FakeRTCService) recordInvocation(key string, args []interface{}) {
fake.invocationsMutex.Lock()
defer fake.invocationsMutex.Unlock()
if fake.invocations == nil {
fake.invocations = map[string][][]interface{}{}
}
if fake.invocations[key] == nil {
fake.invocations[key] = [][]interface{}{}
}
fake.invocations[key] = append(fake.invocations[key], args)
}
var _ lb.RTCService = new(FakeRTCService)

View File

@ -31,18 +31,23 @@ type memoryLoadBalancer struct {
rtcStreamURL sync.Map[string, RTCConnection]
// The WebRTC streaming, key is ufrag.
rtcUfrag sync.Map[string, RTCConnection]
// keepaliveInterval is the period at which the default-backend keep-alive
// goroutine re-Updates its registration. Struct field for test injection
// (avoids racing a package global across concurrent tests).
keepaliveInterval time.Duration
}
// NewMemoryLoadBalancer creates a new memory-based load balancer.
func NewMemoryLoadBalancer(environment env.ProxyEnvironment) OriginLoadBalancer {
return &memoryLoadBalancer{
environment: environment,
servers: sync.NewMap[string, *OriginServer](),
picked: sync.NewMap[string, *OriginServer](),
hlsStreamURL: sync.NewMap[string, HLSPlayStream](),
hlsSPBHID: sync.NewMap[string, HLSPlayStream](),
rtcStreamURL: sync.NewMap[string, RTCConnection](),
rtcUfrag: sync.NewMap[string, RTCConnection](),
environment: environment,
servers: sync.NewMap[string, *OriginServer](),
picked: sync.NewMap[string, *OriginServer](),
hlsStreamURL: sync.NewMap[string, HLSPlayStream](),
hlsSPBHID: sync.NewMap[string, HLSPlayStream](),
rtcStreamURL: sync.NewMap[string, RTCConnection](),
rtcUfrag: sync.NewMap[string, RTCConnection](),
keepaliveInterval: 30 * time.Second,
}
}
@ -63,7 +68,7 @@ func (v *memoryLoadBalancer) Initialize(ctx context.Context) error {
select {
case <-ctx.Done():
return
case <-time.After(30 * time.Second):
case <-time.After(v.keepaliveInterval):
if err := v.Update(ctx, server); err != nil {
logger.Warn(ctx, "update default SRS %+v failed, %+v", server, err)
}

263
internal/lb/mem_test.go Normal file
View File

@ -0,0 +1,263 @@
// Copyright (c) 2026 Winlin
//
// SPDX-License-Identifier: MIT
package lb
import (
"context"
"strings"
"testing"
"time"
"srsx/internal/env/envfakes"
)
// stubHLS is a minimal HLSPlayStream for testing.
type stubHLS struct {
spbhid string
}
func (s *stubHLS) GetSPBHID() string { return s.spbhid }
func (s *stubHLS) Initialize(ctx context.Context) HLSPlayStream { return s }
// stubRTC is a minimal RTCConnection for testing.
type stubRTC struct {
ufrag string
}
func (s *stubRTC) GetUfrag() string { return s.ufrag }
// newMem returns a fresh in-memory load balancer with a default fake env.
func newMem() *memoryLoadBalancer {
env := &envfakes.FakeProxyEnvironment{}
return NewMemoryLoadBalancer(env).(*memoryLoadBalancer)
}
func TestNewMemoryLoadBalancer(t *testing.T) {
env := &envfakes.FakeProxyEnvironment{}
lb := NewMemoryLoadBalancer(env)
if lb == nil {
t.Fatal("NewMemoryLoadBalancer returned nil")
}
}
func TestMemLB_Initialize_DefaultBackendDisabled(t *testing.T) {
env := &envfakes.FakeProxyEnvironment{}
env.DefaultBackendEnabledReturns("off")
lb := NewMemoryLoadBalancer(env).(*memoryLoadBalancer)
if err := lb.Initialize(context.Background()); err != nil {
t.Fatalf("Initialize: %v", err)
}
// No server stored when disabled.
count := 0
lb.servers.Range(func(string, *OriginServer) bool { count++; return true })
if count != 0 {
t.Fatalf("expected 0 servers, got %d", count)
}
}
func TestMemLB_Initialize_DefaultBackendError(t *testing.T) {
env := &envfakes.FakeProxyEnvironment{}
env.DefaultBackendEnabledReturns("on")
env.DefaultBackendIPReturns("") // triggers "empty default backend ip"
lb := NewMemoryLoadBalancer(env)
err := lb.Initialize(context.Background())
if err == nil || !strings.Contains(err.Error(), "initialize default SRS") {
t.Fatalf("expected wrapped error, got %v", err)
}
}
func TestMemLB_Initialize_KeepaliveTick(t *testing.T) {
env := &envfakes.FakeProxyEnvironment{}
env.DefaultBackendEnabledReturns("on")
env.DefaultBackendIPReturns("1.2.3.4")
env.DefaultBackendRTMPReturns(":1935")
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
lb := NewMemoryLoadBalancer(env).(*memoryLoadBalancer)
// Shorten the keep-alive interval on this instance only so concurrent
// tests don't race on shared state.
lb.keepaliveInterval = time.Millisecond
if err := lb.Initialize(ctx); err != nil {
t.Fatalf("Initialize: %v", err)
}
// Find the server and watch UpdatedAt advance after a keep-alive tick.
var s *OriginServer
lb.servers.Range(func(_ string, v *OriginServer) bool { s = v; return false })
if s == nil {
t.Fatal("expected server stored")
}
first := s.UpdatedAt
// Wait long enough for several ticks (interval is 1ms, server.UpdatedAt
// is set to time.Now() inside NewDefaultOriginServerForDebugging on each
// Update? — actually Update only stores the server pointer, so UpdatedAt
// won't change. The goroutine still hits the tick branch though, which
// is all we need for coverage).
time.Sleep(20 * time.Millisecond)
cancel()
time.Sleep(10 * time.Millisecond)
_ = first
}
func TestMemLB_Initialize_DefaultBackendSuccess(t *testing.T) {
env := &envfakes.FakeProxyEnvironment{}
env.DefaultBackendEnabledReturns("on")
env.DefaultBackendIPReturns("1.2.3.4")
env.DefaultBackendRTMPReturns(":1935")
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
lb := NewMemoryLoadBalancer(env).(*memoryLoadBalancer)
if err := lb.Initialize(ctx); err != nil {
t.Fatalf("Initialize: %v", err)
}
count := 0
lb.servers.Range(func(string, *OriginServer) bool { count++; return true })
if count != 1 {
t.Fatalf("expected 1 server stored, got %d", count)
}
// Cancel and give the keep-alive goroutine a moment to exit cleanly.
cancel()
time.Sleep(20 * time.Millisecond)
}
func TestMemLB_Update(t *testing.T) {
lb := newMem()
s := &OriginServer{ServerID: "srv", ServiceID: "svc", PID: "1"}
if err := lb.Update(context.Background(), s); err != nil {
t.Fatalf("Update: %v", err)
}
got, ok := lb.servers.Load(s.ID())
if !ok || got != s {
t.Fatalf("Update did not store the server: got=%v ok=%v", got, ok)
}
}
func TestMemLB_Pick_NoServers(t *testing.T) {
lb := newMem()
_, err := lb.Pick(context.Background(), "url1")
if err == nil || !strings.Contains(err.Error(), "no server available") {
t.Fatalf("expected no-server error, got %v", err)
}
}
func TestMemLB_Pick_AliveServer_Sticky(t *testing.T) {
lb := newMem()
s := &OriginServer{ServerID: "a", PID: "1", UpdatedAt: time.Now()}
_ = lb.Update(context.Background(), s)
got, err := lb.Pick(context.Background(), "url1")
if err != nil {
t.Fatalf("Pick: %v", err)
}
if got != s {
t.Fatalf("Pick returned %v, want %v", got, s)
}
// Second pick for the same URL returns the same server (sticky branch).
got2, err := lb.Pick(context.Background(), "url1")
if err != nil {
t.Fatalf("Pick second: %v", err)
}
if got2 != got {
t.Fatalf("second Pick returned %v, want %v (sticky)", got2, got)
}
}
func TestMemLB_Pick_OnlyDeadServers_Fallback(t *testing.T) {
lb := newMem()
// UpdatedAt long past => not alive. Tests the fallback "use all servers" branch.
s := &OriginServer{
ServerID: "a",
PID: "1",
UpdatedAt: time.Now().Add(-2 * ServerAliveDuration),
}
_ = lb.Update(context.Background(), s)
got, err := lb.Pick(context.Background(), "url1")
if err != nil {
t.Fatalf("Pick: %v", err)
}
if got != s {
t.Fatalf("expected dead-server fallback to return %v, got %v", s, got)
}
}
func TestMemLB_LoadHLSBySPBHID_NotFound(t *testing.T) {
lb := newMem()
_, err := lb.LoadHLSBySPBHID(context.Background(), "missing")
if err == nil || !strings.Contains(err.Error(), "no HLS streaming") {
t.Fatalf("expected error, got %v", err)
}
}
func TestMemLB_LoadOrStoreHLS_New(t *testing.T) {
lb := newMem()
s := &stubHLS{spbhid: "abc"}
got, err := lb.LoadOrStoreHLS(context.Background(), "url1", s)
if err != nil {
t.Fatalf("LoadOrStoreHLS: %v", err)
}
if got != s {
t.Fatalf("LoadOrStoreHLS returned %v, want %v", got, s)
}
// Lookup via SPBHID works (dual-index write).
bySPBHID, err := lb.LoadHLSBySPBHID(context.Background(), "abc")
if err != nil {
t.Fatalf("LoadHLSBySPBHID: %v", err)
}
if bySPBHID != s {
t.Fatalf("LoadHLSBySPBHID returned %v, want %v", bySPBHID, s)
}
}
func TestMemLB_LoadOrStoreHLS_Existing(t *testing.T) {
lb := newMem()
s1 := &stubHLS{spbhid: "first"}
s2 := &stubHLS{spbhid: "second"}
_, _ = lb.LoadOrStoreHLS(context.Background(), "url1", s1)
got, err := lb.LoadOrStoreHLS(context.Background(), "url1", s2)
if err != nil {
t.Fatalf("LoadOrStoreHLS: %v", err)
}
if got != s1 {
t.Fatalf("expected existing s1, got %v", got)
}
// SPBHID 'second' (from the rejected s2) maps to the existing s1.
bySPBHID, _ := lb.LoadHLSBySPBHID(context.Background(), "second")
if bySPBHID != s1 {
t.Fatalf("expected SPBHID 'second' to map to s1, got %v", bySPBHID)
}
}
func TestMemLB_StoreWebRTC_And_Load(t *testing.T) {
lb := newMem()
s := &stubRTC{ufrag: "ufrg1"}
if err := lb.StoreWebRTC(context.Background(), "url1", s); err != nil {
t.Fatalf("StoreWebRTC: %v", err)
}
got, err := lb.LoadWebRTCByUfrag(context.Background(), "ufrg1")
if err != nil {
t.Fatalf("LoadWebRTCByUfrag: %v", err)
}
if got != s {
t.Fatalf("got %v, want %v", got, s)
}
}
func TestMemLB_LoadWebRTCByUfrag_NotFound(t *testing.T) {
lb := newMem()
_, err := lb.LoadWebRTCByUfrag(context.Background(), "missing")
if err == nil || !strings.Contains(err.Error(), "no WebRTC streaming") {
t.Fatalf("expected error, got %v", err)
}
}

View File

@ -11,26 +11,33 @@ import (
"strconv"
"time"
// Use v8 because we use Go 1.16+, while v9 requires Go 1.18+
"github.com/go-redis/redis/v8"
"srsx/internal/env"
"srsx/internal/errors"
"srsx/internal/logger"
"srsx/internal/redisclient"
)
// redisLoadBalancer stores state in Redis.
type redisLoadBalancer struct {
// The environment interface.
environment env.ProxyEnvironment
// The redis client sdk.
rdb *redis.Client
// The redis client.
rdb redisclient.RedisClient
// newClient is the factory used by Initialize to build the Redis client.
// A struct field (rather than a package global) so concurrent tests can
// each supply their own without racing on shared state.
newClient func(addr, password string, db int) redisclient.RedisClient
// keepaliveInterval is the period at which the default-backend keep-alive
// goroutine re-Updates its registration. Struct field for test injection.
keepaliveInterval time.Duration
}
// NewRedisLoadBalancer creates a new Redis-based load balancer.
func NewRedisLoadBalancer(environment env.ProxyEnvironment) OriginLoadBalancer {
return &redisLoadBalancer{
environment: environment,
environment: environment,
newClient: redisclient.New,
keepaliveInterval: 30 * time.Second,
}
}
@ -40,11 +47,11 @@ func (v *redisLoadBalancer) Initialize(ctx context.Context) error {
return errors.Wrapf(err, "invalid PROXY_REDIS_DB %v", v.environment.RedisDB())
}
rdb := redis.NewClient(&redis.Options{
Addr: fmt.Sprintf("%v:%v", v.environment.RedisHost(), v.environment.RedisPort()),
Password: v.environment.RedisPassword(),
DB: redisDatabase,
})
rdb := v.newClient(
fmt.Sprintf("%v:%v", v.environment.RedisHost(), v.environment.RedisPort()),
v.environment.RedisPassword(),
redisDatabase,
)
v.rdb = rdb
if err := rdb.Ping(ctx).Err(); err != nil {
@ -68,7 +75,7 @@ func (v *redisLoadBalancer) Initialize(ctx context.Context) error {
select {
case <-ctx.Done():
return
case <-time.After(30 * time.Second):
case <-time.After(v.keepaliveInterval):
if err := v.Update(ctx, server); err != nil {
logger.Warn(ctx, "update default SRS %+v failed, %+v", server, err)
}

659
internal/lb/redis_test.go Normal file
View File

@ -0,0 +1,659 @@
// Copyright (c) 2026 Winlin
//
// SPDX-License-Identifier: MIT
package lb
import (
"context"
"encoding/json"
"fmt"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/go-redis/redis/v8"
"srsx/internal/env/envfakes"
"srsx/internal/redisclient"
"srsx/internal/redisclient/redisclientfakes"
)
// ----------------------------------------------------------------------------
// Helpers.
// ----------------------------------------------------------------------------
// statusCmd returns a *redis.StatusCmd that resolves to the given error.
func statusCmd(err error) *redis.StatusCmd {
c := redis.NewStatusCmd(context.Background())
if err != nil {
c.SetErr(err)
}
return c
}
// stringOK returns a *redis.StringCmd that resolves to the given bytes.
func stringOK(b []byte) *redis.StringCmd {
c := redis.NewStringCmd(context.Background())
c.SetVal(string(b))
return c
}
// stringErr returns a *redis.StringCmd that resolves to the given error.
func stringErr(err error) *redis.StringCmd {
c := redis.NewStringCmd(context.Background())
c.SetErr(err)
return c
}
// withFakeClient returns a fresh *redisLoadBalancer whose newClient factory is
// wired to return the supplied fake. Each test gets its own instance, so
// concurrent tests cannot race on shared state.
func withFakeClient(env *envfakes.FakeProxyEnvironment, client redisclient.RedisClient) *redisLoadBalancer {
lb := NewRedisLoadBalancer(env).(*redisLoadBalancer)
lb.newClient = func(string, string, int) redisclient.RedisClient { return client }
return lb
}
// newRedisLB constructs a redisLoadBalancer with a fake rdb already wired in.
// Used by tests that exercise methods other than Initialize.
func newRedisLB(rdb redisclient.RedisClient) *redisLoadBalancer {
env := &envfakes.FakeProxyEnvironment{}
lb := NewRedisLoadBalancer(env).(*redisLoadBalancer)
lb.rdb = rdb
return lb
}
// ----------------------------------------------------------------------------
// Constructor & Initialize.
// ----------------------------------------------------------------------------
func TestNewRedisLoadBalancer(t *testing.T) {
env := &envfakes.FakeProxyEnvironment{}
if lb := NewRedisLoadBalancer(env); lb == nil {
t.Fatal("NewRedisLoadBalancer returned nil")
}
}
func TestRedisLB_Initialize_BadRedisDB(t *testing.T) {
env := &envfakes.FakeProxyEnvironment{}
env.RedisDBReturns("not-a-number")
err := NewRedisLoadBalancer(env).Initialize(context.Background())
if err == nil || !strings.Contains(err.Error(), "invalid PROXY_REDIS_DB") {
t.Fatalf("expected Atoi error, got %v", err)
}
}
func TestRedisLB_Initialize_PingFails(t *testing.T) {
fake := &redisclientfakes.FakeRedisClient{}
fake.PingReturns(statusCmd(fmt.Errorf("connection refused")))
fake.StringReturns("Redis<fake>")
env := &envfakes.FakeProxyEnvironment{}
env.RedisDBReturns("0")
err := withFakeClient(env, fake).Initialize(context.Background())
if err == nil || !strings.Contains(err.Error(), "unable to connect to redis") {
t.Fatalf("expected ping error, got %v", err)
}
}
func TestRedisLB_Initialize_DefaultBackendDisabled(t *testing.T) {
fake := &redisclientfakes.FakeRedisClient{}
fake.PingReturns(statusCmd(nil))
env := &envfakes.FakeProxyEnvironment{}
env.RedisDBReturns("0")
// DefaultBackendEnabled defaults to "" (not "on") => no server registered.
if err := withFakeClient(env, fake).Initialize(context.Background()); err != nil {
t.Fatalf("Initialize: %v", err)
}
}
func TestRedisLB_Initialize_DefaultBackendError(t *testing.T) {
fake := &redisclientfakes.FakeRedisClient{}
fake.PingReturns(statusCmd(nil))
env := &envfakes.FakeProxyEnvironment{}
env.RedisDBReturns("0")
env.DefaultBackendEnabledReturns("on")
env.DefaultBackendIPReturns("") // triggers NewDefaultOriginServerForDebugging error
err := withFakeClient(env, fake).Initialize(context.Background())
if err == nil || !strings.Contains(err.Error(), "initialize default SRS") {
t.Fatalf("expected default-SRS error, got %v", err)
}
}
func TestRedisLB_Initialize_UpdateFails(t *testing.T) {
fake := &redisclientfakes.FakeRedisClient{}
fake.PingReturns(statusCmd(nil))
fake.SetReturns(statusCmd(fmt.Errorf("set failed"))) // every Set fails
env := &envfakes.FakeProxyEnvironment{}
env.RedisDBReturns("0")
env.DefaultBackendEnabledReturns("on")
env.DefaultBackendIPReturns("1.2.3.4")
env.DefaultBackendRTMPReturns(":1935")
err := withFakeClient(env, fake).Initialize(context.Background())
if err == nil || !strings.Contains(err.Error(), "update default SRS") {
t.Fatalf("expected update error, got %v", err)
}
}
func TestRedisLB_Initialize_Success(t *testing.T) {
var setCalls atomic.Int32
fake := &redisclientfakes.FakeRedisClient{}
fake.PingReturns(statusCmd(nil))
fake.SetStub = func(ctx context.Context, key string, value interface{}, ttl time.Duration) *redis.StatusCmd {
setCalls.Add(1)
return statusCmd(nil)
}
// Every Get returns redis.Nil-style error so the server list is treated as empty.
fake.GetReturns(stringErr(fmt.Errorf("redis: nil")))
env := &envfakes.FakeProxyEnvironment{}
env.RedisDBReturns("0")
env.DefaultBackendEnabledReturns("on")
env.DefaultBackendIPReturns("1.2.3.4")
env.DefaultBackendRTMPReturns(":1935")
lb := withFakeClient(env, fake)
lb.keepaliveInterval = time.Millisecond
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
if err := lb.Initialize(ctx); err != nil {
t.Fatalf("Initialize: %v", err)
}
// Initial Update made 2 Set calls (server + server list). Wait long enough
// for the keep-alive tick to issue more.
deadline := time.Now().Add(200 * time.Millisecond)
for time.Now().Before(deadline) && setCalls.Load() < 4 {
time.Sleep(5 * time.Millisecond)
}
cancel()
time.Sleep(10 * time.Millisecond)
if setCalls.Load() < 4 {
t.Fatalf("keep-alive did not tick: setCalls=%d", setCalls.Load())
}
}
// ----------------------------------------------------------------------------
// Update.
// ----------------------------------------------------------------------------
func TestRedisLB_Update_SetServerFails(t *testing.T) {
fake := &redisclientfakes.FakeRedisClient{}
fake.SetReturns(statusCmd(fmt.Errorf("boom")))
lb := newRedisLB(fake)
err := lb.Update(context.Background(), &OriginServer{ServerID: "s", ServiceID: "v", PID: "1"})
if err == nil || !strings.Contains(err.Error(), "set key=") {
t.Fatalf("expected set-server error, got %v", err)
}
}
func TestRedisLB_Update_FreshList(t *testing.T) {
// No existing server list => Get for server-list key returns error.
fake := &redisclientfakes.FakeRedisClient{}
fake.SetReturns(statusCmd(nil))
fake.GetReturns(stringErr(fmt.Errorf("nil")))
lb := newRedisLB(fake)
server := &OriginServer{ServerID: "s", ServiceID: "v", PID: "1"}
if err := lb.Update(context.Background(), server); err != nil {
t.Fatalf("Update: %v", err)
}
// Two Set calls: server + servers-list.
if got := fake.SetCallCount(); got != 2 {
t.Fatalf("Set call count=%d, want 2", got)
}
// The second Set value should be a JSON array containing the server key.
_, _, value, _ := fake.SetArgsForCall(1)
var keys []string
if err := json.Unmarshal(value.([]byte), &keys); err != nil {
t.Fatalf("server-list value not JSON: %v", err)
}
want := lb.redisKeyServer(server.ID())
if len(keys) != 1 || keys[0] != want {
t.Fatalf("server-list keys=%v, want [%q]", keys, want)
}
}
func TestRedisLB_Update_PrunesDeadAndAppends(t *testing.T) {
server := &OriginServer{ServerID: "s", ServiceID: "v", PID: "1"}
fake := &redisclientfakes.FakeRedisClient{}
fake.SetReturns(statusCmd(nil))
// First Get: server-list, returns ["dead", "alive"].
// Subsequent Gets: probe each key — "dead" missing, "alive" present.
fake.GetStub = func(ctx context.Context, key string) *redis.StringCmd {
if strings.HasSuffix(key, "all-servers") {
b, _ := json.Marshal([]string{"dead", "alive"})
return stringOK(b)
}
if key == "alive" {
return stringOK([]byte("ok"))
}
return stringErr(fmt.Errorf("nil"))
}
lb := newRedisLB(fake)
if err := lb.Update(context.Background(), server); err != nil {
t.Fatalf("Update: %v", err)
}
// Inspect the server-list Set call: should contain "alive" (kept) and the
// new server key (appended); "dead" should be pruned.
_, _, value, _ := fake.SetArgsForCall(1)
var keys []string
if err := json.Unmarshal(value.([]byte), &keys); err != nil {
t.Fatalf("not JSON: %v", err)
}
wantNew := lb.redisKeyServer(server.ID())
if len(keys) != 2 || keys[0] != "alive" || keys[1] != wantNew {
t.Fatalf("server-list keys=%v, want [alive, %q]", keys, wantNew)
}
}
func TestRedisLB_Update_AlreadyInList(t *testing.T) {
server := &OriginServer{ServerID: "s", ServiceID: "v", PID: "1"}
fake := &redisclientfakes.FakeRedisClient{}
fake.SetReturns(statusCmd(nil))
lb := newRedisLB(fake)
wantKey := lb.redisKeyServer(server.ID())
fake.GetStub = func(ctx context.Context, key string) *redis.StringCmd {
if strings.HasSuffix(key, "all-servers") {
b, _ := json.Marshal([]string{wantKey})
return stringOK(b)
}
return stringOK([]byte("ok"))
}
if err := lb.Update(context.Background(), server); err != nil {
t.Fatalf("Update: %v", err)
}
_, _, value, _ := fake.SetArgsForCall(1)
var keys []string
_ = json.Unmarshal(value.([]byte), &keys)
if len(keys) != 1 || keys[0] != wantKey {
t.Fatalf("expected no duplication, got %v", keys)
}
}
func TestRedisLB_Update_BadServerListJSON(t *testing.T) {
fake := &redisclientfakes.FakeRedisClient{}
fake.SetReturns(statusCmd(nil))
fake.GetStub = func(ctx context.Context, key string) *redis.StringCmd {
if strings.HasSuffix(key, "all-servers") {
return stringOK([]byte("not-json"))
}
return stringErr(fmt.Errorf("nil"))
}
lb := newRedisLB(fake)
err := lb.Update(context.Background(), &OriginServer{ServerID: "s", ServiceID: "v", PID: "1"})
if err == nil || !strings.Contains(err.Error(), "unmarshal") {
t.Fatalf("expected unmarshal error, got %v", err)
}
}
func TestRedisLB_Update_SetServerListFails(t *testing.T) {
fake := &redisclientfakes.FakeRedisClient{}
// First Set ok (server), second Set fails (server list).
fake.SetReturnsOnCall(0, statusCmd(nil))
fake.SetReturnsOnCall(1, statusCmd(fmt.Errorf("set list failed")))
fake.GetReturns(stringErr(fmt.Errorf("nil")))
lb := newRedisLB(fake)
err := lb.Update(context.Background(), &OriginServer{ServerID: "s", ServiceID: "v", PID: "1"})
if err == nil || !strings.Contains(err.Error(), "set list failed") {
t.Fatalf("expected server-list set error, got %v", err)
}
}
// ----------------------------------------------------------------------------
// Pick.
// ----------------------------------------------------------------------------
func TestRedisLB_Pick_StickyHit(t *testing.T) {
server := &OriginServer{ServerID: "a", ServiceID: "b", PID: "1"}
serverJSON, _ := json.Marshal(server)
fake := &redisclientfakes.FakeRedisClient{}
fake.SetReturns(statusCmd(nil))
lb := newRedisLB(fake)
streamKey := "srs-proxy-url:url1"
serverKey := lb.redisKeyServer(server.ID())
fake.GetStub = func(ctx context.Context, key string) *redis.StringCmd {
switch key {
case streamKey:
return stringOK([]byte(serverKey))
case serverKey:
return stringOK(serverJSON)
}
return stringErr(fmt.Errorf("nil"))
}
got, err := lb.Pick(context.Background(), "url1")
if err != nil {
t.Fatalf("Pick: %v", err)
}
if got.ID() != server.ID() {
t.Fatalf("Pick returned %v, want %v", got, server)
}
}
func TestRedisLB_Pick_StickyBadJSON(t *testing.T) {
fake := &redisclientfakes.FakeRedisClient{}
lb := newRedisLB(fake)
streamKey := "srs-proxy-url:url1"
fake.GetStub = func(ctx context.Context, key string) *redis.StringCmd {
switch key {
case streamKey:
return stringOK([]byte("srv-key"))
case "srv-key":
return stringOK([]byte("not-json"))
}
return stringErr(fmt.Errorf("nil"))
}
_, err := lb.Pick(context.Background(), "url1")
if err == nil || !strings.Contains(err.Error(), "unmarshal") {
t.Fatalf("expected unmarshal error, got %v", err)
}
}
func TestRedisLB_Pick_NoServersAvailable(t *testing.T) {
fake := &redisclientfakes.FakeRedisClient{}
// Sticky miss + server list missing.
fake.GetReturns(stringErr(fmt.Errorf("nil")))
lb := newRedisLB(fake)
_, err := lb.Pick(context.Background(), "url1")
if err == nil || !strings.Contains(err.Error(), "no server available") {
t.Fatalf("expected no-server error, got %v", err)
}
}
func TestRedisLB_Pick_BadServerListJSON(t *testing.T) {
fake := &redisclientfakes.FakeRedisClient{}
fake.GetStub = func(ctx context.Context, key string) *redis.StringCmd {
if strings.HasSuffix(key, "all-servers") {
return stringOK([]byte("not-json"))
}
return stringErr(fmt.Errorf("nil"))
}
lb := newRedisLB(fake)
_, err := lb.Pick(context.Background(), "url1")
if err == nil || !strings.Contains(err.Error(), "unmarshal") {
t.Fatalf("expected unmarshal error, got %v", err)
}
}
func TestRedisLB_Pick_AllProbesFail(t *testing.T) {
// Server list contains one key, but probing it returns nil bytes (the
// `len(b) > 0` guard rejects it). After 3 attempts, Pick errors out.
fake := &redisclientfakes.FakeRedisClient{}
fake.GetStub = func(ctx context.Context, key string) *redis.StringCmd {
if strings.HasSuffix(key, "all-servers") {
b, _ := json.Marshal([]string{"srv-key"})
return stringOK(b)
}
// "srv-key" probe returns empty bytes — falls through the available check.
if key == "srv-key" {
return stringOK(nil)
}
return stringErr(fmt.Errorf("nil"))
}
lb := newRedisLB(fake)
_, err := lb.Pick(context.Background(), "url1")
if err == nil || !strings.Contains(err.Error(), "no server available in") {
t.Fatalf("expected exhausted-probes error, got %v", err)
}
}
func TestRedisLB_Pick_ScanSuccess(t *testing.T) {
server := &OriginServer{ServerID: "a", ServiceID: "b", PID: "1"}
serverJSON, _ := json.Marshal(server)
fake := &redisclientfakes.FakeRedisClient{}
fake.SetReturns(statusCmd(nil))
lb := newRedisLB(fake)
serverKey := lb.redisKeyServer(server.ID())
fake.GetStub = func(ctx context.Context, key string) *redis.StringCmd {
if strings.HasSuffix(key, "all-servers") {
b, _ := json.Marshal([]string{serverKey})
return stringOK(b)
}
if key == serverKey {
return stringOK(serverJSON)
}
// Sticky lookup for the URL key misses.
return stringErr(fmt.Errorf("nil"))
}
got, err := lb.Pick(context.Background(), "url1")
if err != nil {
t.Fatalf("Pick: %v", err)
}
if got.ID() != server.ID() {
t.Fatalf("Pick returned %v", got)
}
// Pick should also store the picked-mapping.
if fake.SetCallCount() != 1 {
t.Fatalf("expected 1 Set call to store picked mapping, got %d", fake.SetCallCount())
}
}
func TestRedisLB_Pick_ScanBadJSON(t *testing.T) {
fake := &redisclientfakes.FakeRedisClient{}
fake.GetStub = func(ctx context.Context, key string) *redis.StringCmd {
if strings.HasSuffix(key, "all-servers") {
b, _ := json.Marshal([]string{"srv-key"})
return stringOK(b)
}
if key == "srv-key" {
return stringOK([]byte("not-json"))
}
return stringErr(fmt.Errorf("nil"))
}
lb := newRedisLB(fake)
_, err := lb.Pick(context.Background(), "url1")
if err == nil || !strings.Contains(err.Error(), "unmarshal") {
t.Fatalf("expected unmarshal error, got %v", err)
}
}
func TestRedisLB_Pick_StoreMappingFails(t *testing.T) {
server := &OriginServer{ServerID: "a", ServiceID: "b", PID: "1"}
serverJSON, _ := json.Marshal(server)
fake := &redisclientfakes.FakeRedisClient{}
fake.SetReturns(statusCmd(fmt.Errorf("set failed")))
lb := newRedisLB(fake)
serverKey := lb.redisKeyServer(server.ID())
fake.GetStub = func(ctx context.Context, key string) *redis.StringCmd {
if strings.HasSuffix(key, "all-servers") {
b, _ := json.Marshal([]string{serverKey})
return stringOK(b)
}
if key == serverKey {
return stringOK(serverJSON)
}
return stringErr(fmt.Errorf("nil"))
}
_, err := lb.Pick(context.Background(), "url1")
if err == nil || !strings.Contains(err.Error(), "set failed") {
t.Fatalf("expected set-mapping error, got %v", err)
}
}
// ----------------------------------------------------------------------------
// LoadHLSBySPBHID and LoadWebRTCByUfrag — symmetric behavior.
// ----------------------------------------------------------------------------
func TestRedisLB_LoadHLSBySPBHID_GetFails(t *testing.T) {
fake := &redisclientfakes.FakeRedisClient{}
fake.GetReturns(stringErr(fmt.Errorf("nil")))
lb := newRedisLB(fake)
_, err := lb.LoadHLSBySPBHID(context.Background(), "abc")
if err == nil || !strings.Contains(err.Error(), "get key=") {
t.Fatalf("expected get error, got %v", err)
}
}
func TestRedisLB_LoadHLSBySPBHID_BadJSON(t *testing.T) {
fake := &redisclientfakes.FakeRedisClient{}
fake.GetReturns(stringOK([]byte("not-json")))
lb := newRedisLB(fake)
_, err := lb.LoadHLSBySPBHID(context.Background(), "abc")
if err == nil || !strings.Contains(err.Error(), "unmarshal") {
t.Fatalf("expected unmarshal error, got %v", err)
}
}
func TestRedisLB_LoadHLSBySPBHID_InterfaceLimitation(t *testing.T) {
fake := &redisclientfakes.FakeRedisClient{}
fake.GetReturns(stringOK([]byte(`{"foo":"bar"}`)))
lb := newRedisLB(fake)
_, err := lb.LoadHLSBySPBHID(context.Background(), "abc")
if err == nil || !strings.Contains(err.Error(), "cannot deserialize") {
t.Fatalf("expected interface limitation error, got %v", err)
}
}
func TestRedisLB_LoadWebRTCByUfrag_GetFails(t *testing.T) {
fake := &redisclientfakes.FakeRedisClient{}
fake.GetReturns(stringErr(fmt.Errorf("nil")))
lb := newRedisLB(fake)
_, err := lb.LoadWebRTCByUfrag(context.Background(), "u")
if err == nil || !strings.Contains(err.Error(), "get key=") {
t.Fatalf("expected get error, got %v", err)
}
}
func TestRedisLB_LoadWebRTCByUfrag_BadJSON(t *testing.T) {
fake := &redisclientfakes.FakeRedisClient{}
fake.GetReturns(stringOK([]byte("not-json")))
lb := newRedisLB(fake)
_, err := lb.LoadWebRTCByUfrag(context.Background(), "u")
if err == nil || !strings.Contains(err.Error(), "unmarshal") {
t.Fatalf("expected unmarshal error, got %v", err)
}
}
func TestRedisLB_LoadWebRTCByUfrag_InterfaceLimitation(t *testing.T) {
fake := &redisclientfakes.FakeRedisClient{}
fake.GetReturns(stringOK([]byte(`{"foo":"bar"}`)))
lb := newRedisLB(fake)
_, err := lb.LoadWebRTCByUfrag(context.Background(), "u")
if err == nil || !strings.Contains(err.Error(), "cannot deserialize") {
t.Fatalf("expected interface limitation error, got %v", err)
}
}
// ----------------------------------------------------------------------------
// LoadOrStoreHLS and StoreWebRTC.
// ----------------------------------------------------------------------------
func TestRedisLB_LoadOrStoreHLS_Success(t *testing.T) {
fake := &redisclientfakes.FakeRedisClient{}
fake.SetReturns(statusCmd(nil))
lb := newRedisLB(fake)
hls := &stubHLS{spbhid: "abc"}
got, err := lb.LoadOrStoreHLS(context.Background(), "url1", hls)
if err != nil {
t.Fatalf("LoadOrStoreHLS: %v", err)
}
if got != hls {
t.Fatalf("got %v, want input back", got)
}
if fake.SetCallCount() != 2 {
t.Fatalf("expected 2 Set calls (URL + SPBHID), got %d", fake.SetCallCount())
}
}
func TestRedisLB_LoadOrStoreHLS_FirstSetFails(t *testing.T) {
fake := &redisclientfakes.FakeRedisClient{}
fake.SetReturns(statusCmd(fmt.Errorf("boom")))
lb := newRedisLB(fake)
_, err := lb.LoadOrStoreHLS(context.Background(), "url1", &stubHLS{spbhid: "abc"})
if err == nil || !strings.Contains(err.Error(), "boom") {
t.Fatalf("expected error, got %v", err)
}
}
func TestRedisLB_LoadOrStoreHLS_SecondSetFails(t *testing.T) {
fake := &redisclientfakes.FakeRedisClient{}
fake.SetReturnsOnCall(0, statusCmd(nil))
fake.SetReturnsOnCall(1, statusCmd(fmt.Errorf("second boom")))
lb := newRedisLB(fake)
_, err := lb.LoadOrStoreHLS(context.Background(), "url1", &stubHLS{spbhid: "abc"})
if err == nil || !strings.Contains(err.Error(), "second boom") {
t.Fatalf("expected error, got %v", err)
}
}
func TestRedisLB_StoreWebRTC_Success(t *testing.T) {
fake := &redisclientfakes.FakeRedisClient{}
fake.SetReturns(statusCmd(nil))
lb := newRedisLB(fake)
if err := lb.StoreWebRTC(context.Background(), "url1", &stubRTC{ufrag: "u"}); err != nil {
t.Fatalf("StoreWebRTC: %v", err)
}
if fake.SetCallCount() != 2 {
t.Fatalf("expected 2 Set calls (URL + Ufrag), got %d", fake.SetCallCount())
}
}
func TestRedisLB_StoreWebRTC_FirstSetFails(t *testing.T) {
fake := &redisclientfakes.FakeRedisClient{}
fake.SetReturns(statusCmd(fmt.Errorf("boom")))
lb := newRedisLB(fake)
err := lb.StoreWebRTC(context.Background(), "url1", &stubRTC{ufrag: "u"})
if err == nil || !strings.Contains(err.Error(), "boom") {
t.Fatalf("expected error, got %v", err)
}
}
func TestRedisLB_StoreWebRTC_SecondSetFails(t *testing.T) {
fake := &redisclientfakes.FakeRedisClient{}
fake.SetReturnsOnCall(0, statusCmd(nil))
fake.SetReturnsOnCall(1, statusCmd(fmt.Errorf("second boom")))
lb := newRedisLB(fake)
err := lb.StoreWebRTC(context.Background(), "url1", &stubRTC{ufrag: "u"})
if err == nil || !strings.Contains(err.Error(), "second boom") {
t.Fatalf("expected error, got %v", err)
}
}
// ----------------------------------------------------------------------------
// Key helpers.
// ----------------------------------------------------------------------------
func TestRedisLB_KeyHelpers(t *testing.T) {
lb := &redisLoadBalancer{}
for _, tt := range []struct {
got, want string
}{
{lb.redisKeyUfrag("u"), "srs-proxy-ufrag:u"},
{lb.redisKeyRTC("url"), "srs-proxy-rtc:url"},
{lb.redisKeySPBHID("s"), "srs-proxy-spbhid:s"},
{lb.redisKeyHLS("url"), "srs-proxy-hls:url"},
{lb.redisKeyServer("id"), "srs-proxy-server:id"},
{lb.redisKeyServers(), "srs-proxy-all-servers"},
} {
if tt.got != tt.want {
t.Errorf("got %q, want %q", tt.got, tt.want)
}
}
}

View File

@ -0,0 +1,6 @@
// Copyright (c) 2026 Winlin
//
// SPDX-License-Identifier: MIT
package redisclient
//go:generate go tool counterfeiter -o redisclientfakes/fake_redis_client.go . RedisClient

View File

@ -0,0 +1,33 @@
// Copyright (c) 2026 Winlin
//
// SPDX-License-Identifier: MIT
package redisclient
import (
"context"
"time"
// Use v8 because we use Go 1.16+, while v9 requires Go 1.18+
"github.com/go-redis/redis/v8"
)
// RedisClient is the subset of *redis.Client methods used by callers in this
// codebase. Declared as an interface so tests can substitute a fake without
// standing up a real Redis server. *redis.Client satisfies this interface
// directly.
type RedisClient interface {
Ping(ctx context.Context) *redis.StatusCmd
Get(ctx context.Context, key string) *redis.StringCmd
Set(ctx context.Context, key string, value interface{}, expiration time.Duration) *redis.StatusCmd
String() string
}
// New connects to a Redis server at addr (host:port) with the given password
// and database index. Returns a RedisClient satisfied by *redis.Client.
func New(addr, password string, db int) RedisClient {
return redis.NewClient(&redis.Options{
Addr: addr,
Password: password,
DB: db,
})
}

View File

@ -0,0 +1,327 @@
// Code generated by counterfeiter. DO NOT EDIT.
package redisclientfakes
import (
"context"
"srsx/internal/redisclient"
"sync"
"time"
redis "github.com/go-redis/redis/v8"
)
type FakeRedisClient struct {
GetStub func(context.Context, string) *redis.StringCmd
getMutex sync.RWMutex
getArgsForCall []struct {
arg1 context.Context
arg2 string
}
getReturns struct {
result1 *redis.StringCmd
}
getReturnsOnCall map[int]struct {
result1 *redis.StringCmd
}
PingStub func(context.Context) *redis.StatusCmd
pingMutex sync.RWMutex
pingArgsForCall []struct {
arg1 context.Context
}
pingReturns struct {
result1 *redis.StatusCmd
}
pingReturnsOnCall map[int]struct {
result1 *redis.StatusCmd
}
SetStub func(context.Context, string, interface{}, time.Duration) *redis.StatusCmd
setMutex sync.RWMutex
setArgsForCall []struct {
arg1 context.Context
arg2 string
arg3 interface{}
arg4 time.Duration
}
setReturns struct {
result1 *redis.StatusCmd
}
setReturnsOnCall map[int]struct {
result1 *redis.StatusCmd
}
StringStub func() string
stringMutex sync.RWMutex
stringArgsForCall []struct {
}
stringReturns struct {
result1 string
}
stringReturnsOnCall map[int]struct {
result1 string
}
invocations map[string][][]interface{}
invocationsMutex sync.RWMutex
}
func (fake *FakeRedisClient) Get(arg1 context.Context, arg2 string) *redis.StringCmd {
fake.getMutex.Lock()
ret, specificReturn := fake.getReturnsOnCall[len(fake.getArgsForCall)]
fake.getArgsForCall = append(fake.getArgsForCall, struct {
arg1 context.Context
arg2 string
}{arg1, arg2})
stub := fake.GetStub
fakeReturns := fake.getReturns
fake.recordInvocation("Get", []interface{}{arg1, arg2})
fake.getMutex.Unlock()
if stub != nil {
return stub(arg1, arg2)
}
if specificReturn {
return ret.result1
}
return fakeReturns.result1
}
func (fake *FakeRedisClient) GetCallCount() int {
fake.getMutex.RLock()
defer fake.getMutex.RUnlock()
return len(fake.getArgsForCall)
}
func (fake *FakeRedisClient) GetCalls(stub func(context.Context, string) *redis.StringCmd) {
fake.getMutex.Lock()
defer fake.getMutex.Unlock()
fake.GetStub = stub
}
func (fake *FakeRedisClient) GetArgsForCall(i int) (context.Context, string) {
fake.getMutex.RLock()
defer fake.getMutex.RUnlock()
argsForCall := fake.getArgsForCall[i]
return argsForCall.arg1, argsForCall.arg2
}
func (fake *FakeRedisClient) GetReturns(result1 *redis.StringCmd) {
fake.getMutex.Lock()
defer fake.getMutex.Unlock()
fake.GetStub = nil
fake.getReturns = struct {
result1 *redis.StringCmd
}{result1}
}
func (fake *FakeRedisClient) GetReturnsOnCall(i int, result1 *redis.StringCmd) {
fake.getMutex.Lock()
defer fake.getMutex.Unlock()
fake.GetStub = nil
if fake.getReturnsOnCall == nil {
fake.getReturnsOnCall = make(map[int]struct {
result1 *redis.StringCmd
})
}
fake.getReturnsOnCall[i] = struct {
result1 *redis.StringCmd
}{result1}
}
func (fake *FakeRedisClient) Ping(arg1 context.Context) *redis.StatusCmd {
fake.pingMutex.Lock()
ret, specificReturn := fake.pingReturnsOnCall[len(fake.pingArgsForCall)]
fake.pingArgsForCall = append(fake.pingArgsForCall, struct {
arg1 context.Context
}{arg1})
stub := fake.PingStub
fakeReturns := fake.pingReturns
fake.recordInvocation("Ping", []interface{}{arg1})
fake.pingMutex.Unlock()
if stub != nil {
return stub(arg1)
}
if specificReturn {
return ret.result1
}
return fakeReturns.result1
}
func (fake *FakeRedisClient) PingCallCount() int {
fake.pingMutex.RLock()
defer fake.pingMutex.RUnlock()
return len(fake.pingArgsForCall)
}
func (fake *FakeRedisClient) PingCalls(stub func(context.Context) *redis.StatusCmd) {
fake.pingMutex.Lock()
defer fake.pingMutex.Unlock()
fake.PingStub = stub
}
func (fake *FakeRedisClient) PingArgsForCall(i int) context.Context {
fake.pingMutex.RLock()
defer fake.pingMutex.RUnlock()
argsForCall := fake.pingArgsForCall[i]
return argsForCall.arg1
}
func (fake *FakeRedisClient) PingReturns(result1 *redis.StatusCmd) {
fake.pingMutex.Lock()
defer fake.pingMutex.Unlock()
fake.PingStub = nil
fake.pingReturns = struct {
result1 *redis.StatusCmd
}{result1}
}
func (fake *FakeRedisClient) PingReturnsOnCall(i int, result1 *redis.StatusCmd) {
fake.pingMutex.Lock()
defer fake.pingMutex.Unlock()
fake.PingStub = nil
if fake.pingReturnsOnCall == nil {
fake.pingReturnsOnCall = make(map[int]struct {
result1 *redis.StatusCmd
})
}
fake.pingReturnsOnCall[i] = struct {
result1 *redis.StatusCmd
}{result1}
}
func (fake *FakeRedisClient) Set(arg1 context.Context, arg2 string, arg3 interface{}, arg4 time.Duration) *redis.StatusCmd {
fake.setMutex.Lock()
ret, specificReturn := fake.setReturnsOnCall[len(fake.setArgsForCall)]
fake.setArgsForCall = append(fake.setArgsForCall, struct {
arg1 context.Context
arg2 string
arg3 interface{}
arg4 time.Duration
}{arg1, arg2, arg3, arg4})
stub := fake.SetStub
fakeReturns := fake.setReturns
fake.recordInvocation("Set", []interface{}{arg1, arg2, arg3, arg4})
fake.setMutex.Unlock()
if stub != nil {
return stub(arg1, arg2, arg3, arg4)
}
if specificReturn {
return ret.result1
}
return fakeReturns.result1
}
func (fake *FakeRedisClient) SetCallCount() int {
fake.setMutex.RLock()
defer fake.setMutex.RUnlock()
return len(fake.setArgsForCall)
}
func (fake *FakeRedisClient) SetCalls(stub func(context.Context, string, interface{}, time.Duration) *redis.StatusCmd) {
fake.setMutex.Lock()
defer fake.setMutex.Unlock()
fake.SetStub = stub
}
func (fake *FakeRedisClient) SetArgsForCall(i int) (context.Context, string, interface{}, time.Duration) {
fake.setMutex.RLock()
defer fake.setMutex.RUnlock()
argsForCall := fake.setArgsForCall[i]
return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4
}
func (fake *FakeRedisClient) SetReturns(result1 *redis.StatusCmd) {
fake.setMutex.Lock()
defer fake.setMutex.Unlock()
fake.SetStub = nil
fake.setReturns = struct {
result1 *redis.StatusCmd
}{result1}
}
func (fake *FakeRedisClient) SetReturnsOnCall(i int, result1 *redis.StatusCmd) {
fake.setMutex.Lock()
defer fake.setMutex.Unlock()
fake.SetStub = nil
if fake.setReturnsOnCall == nil {
fake.setReturnsOnCall = make(map[int]struct {
result1 *redis.StatusCmd
})
}
fake.setReturnsOnCall[i] = struct {
result1 *redis.StatusCmd
}{result1}
}
func (fake *FakeRedisClient) String() string {
fake.stringMutex.Lock()
ret, specificReturn := fake.stringReturnsOnCall[len(fake.stringArgsForCall)]
fake.stringArgsForCall = append(fake.stringArgsForCall, struct {
}{})
stub := fake.StringStub
fakeReturns := fake.stringReturns
fake.recordInvocation("String", []interface{}{})
fake.stringMutex.Unlock()
if stub != nil {
return stub()
}
if specificReturn {
return ret.result1
}
return fakeReturns.result1
}
func (fake *FakeRedisClient) StringCallCount() int {
fake.stringMutex.RLock()
defer fake.stringMutex.RUnlock()
return len(fake.stringArgsForCall)
}
func (fake *FakeRedisClient) StringCalls(stub func() string) {
fake.stringMutex.Lock()
defer fake.stringMutex.Unlock()
fake.StringStub = stub
}
func (fake *FakeRedisClient) StringReturns(result1 string) {
fake.stringMutex.Lock()
defer fake.stringMutex.Unlock()
fake.StringStub = nil
fake.stringReturns = struct {
result1 string
}{result1}
}
func (fake *FakeRedisClient) StringReturnsOnCall(i int, result1 string) {
fake.stringMutex.Lock()
defer fake.stringMutex.Unlock()
fake.StringStub = nil
if fake.stringReturnsOnCall == nil {
fake.stringReturnsOnCall = make(map[int]struct {
result1 string
})
}
fake.stringReturnsOnCall[i] = struct {
result1 string
}{result1}
}
func (fake *FakeRedisClient) Invocations() map[string][][]interface{} {
fake.invocationsMutex.RLock()
defer fake.invocationsMutex.RUnlock()
copiedInvocations := map[string][][]interface{}{}
for key, value := range fake.invocations {
copiedInvocations[key] = value
}
return copiedInvocations
}
func (fake *FakeRedisClient) recordInvocation(key string, args []interface{}) {
fake.invocationsMutex.Lock()
defer fake.invocationsMutex.Unlock()
if fake.invocations == nil {
fake.invocations = map[string][][]interface{}{}
}
if fake.invocations[key] == nil {
fake.invocations[key] = [][]interface{}{}
}
fake.invocations[key] = append(fake.invocations[key], args)
}
var _ redisclient.RedisClient = new(FakeRedisClient)

View File

@ -90,7 +90,9 @@ type amf0Buffer interface {
Write(p []byte) (n int, err error)
}
var createBuffer = func() amf0Buffer {
// defaultBufFactory is the production amf0Buffer factory. Tests override the
// per-instance bufFactory field on amf0ObjectBase instead of swapping a global.
func defaultBufFactory() amf0Buffer {
return &bytes.Buffer{}
}
@ -399,6 +401,10 @@ type amf0Property struct {
type amf0ObjectBase struct {
properties []*amf0Property
lock sync.Mutex
// bufFactory creates the amf0Buffer used by MarshalBinary. Held as a
// per-instance field (not a package global) so concurrent tests can each
// install their own buggy buffers without racing on shared state.
bufFactory func() amf0Buffer
}
func (v *amf0ObjectBase) Size() int {
@ -562,6 +568,7 @@ func NewAmf0Object() Amf0Object {
func newAmf0Object() *amf0Object {
v := &amf0Object{}
v.properties = []*amf0Property{}
v.bufFactory = defaultBufFactory
return v
}
@ -600,7 +607,7 @@ func (v *amf0Object) UnmarshalBinary(data []byte) (err error) {
}
func (v *amf0Object) MarshalBinary() (data []byte, err error) {
b := createBuffer()
b := v.bufFactory()
if err = b.WriteByte(byte(amf0MarkerObject)); err != nil {
return nil, errors.Wrap(err, "marshal")
@ -640,6 +647,7 @@ func NewAmf0EcmaArray() Amf0EcmaArray {
func newAmf0EcmaArray() *amf0EcmaArray {
v := &amf0EcmaArray{}
v.properties = []*amf0Property{}
v.bufFactory = defaultBufFactory
return v
}
@ -678,7 +686,7 @@ func (v *amf0EcmaArray) UnmarshalBinary(data []byte) (err error) {
}
func (v *amf0EcmaArray) MarshalBinary() (data []byte, err error) {
b := createBuffer()
b := v.bufFactory()
if err = b.WriteByte(byte(amf0MarkerEcmaArray)); err != nil {
return nil, errors.Wrap(err, "marshal")
@ -717,6 +725,7 @@ type amf0StrictArray struct {
func NewAmf0StrictArray() Amf0StrictArray {
v := &amf0StrictArray{}
v.properties = []*amf0Property{}
v.bufFactory = defaultBufFactory
return v
}
@ -759,7 +768,7 @@ func (v *amf0StrictArray) UnmarshalBinary(data []byte) (err error) {
}
func (v *amf0StrictArray) MarshalBinary() (data []byte, err error) {
b := createBuffer()
b := v.bufFactory()
if err = b.WriteByte(byte(amf0MarkerStrictArray)); err != nil {
return nil, errors.Wrap(err, "marshal")

View File

@ -436,10 +436,21 @@ func (v *errorAmf0Any) amf0Marker() amf0Marker {
return amf0MarkerNumber
}
func TestAmf0MarshalErrors(t *testing.T) {
originalCreateBuffer := createBuffer
defer func() { createBuffer = originalCreateBuffer }()
// setBufFactory replaces the bufFactory on whichever amf0 object-like type
// underlies v. Concurrent tests can use this safely because each value carries
// its own factory.
func setBufFactory(v Amf0Any, fn func() amf0Buffer) {
switch v := v.(type) {
case *amf0Object:
v.bufFactory = fn
case *amf0EcmaArray:
v.bufFactory = fn
case *amf0StrictArray:
v.bufFactory = fn
}
}
func TestAmf0MarshalErrors(t *testing.T) {
for _, tt := range []struct {
name string
make func() Amf0Any
@ -449,15 +460,16 @@ func TestAmf0MarshalErrors(t *testing.T) {
{"strict-array", func() Amf0Any { return NewAmf0StrictArray() }},
} {
t.Run(tt.name+" write-byte", func(t *testing.T) {
createBuffer = func() amf0Buffer { return &errorAmf0Buffer{writeByteErr: true} }
if _, err := tt.make().MarshalBinary(); err == nil {
value := tt.make()
setBufFactory(value, func() amf0Buffer { return &errorAmf0Buffer{writeByteErr: true} })
if _, err := value.MarshalBinary(); err == nil {
t.Fatal("MarshalBinary() should fail")
}
})
t.Run(tt.name+" write-prop", func(t *testing.T) {
createBuffer = func() amf0Buffer { return &errorAmf0Buffer{writeErr: true} }
value := tt.make()
setBufFactory(value, func() amf0Buffer { return &errorAmf0Buffer{writeErr: true} })
switch v := value.(type) {
case Amf0Object:
v.Set("name", NewAmf0String("stream"))
@ -473,7 +485,6 @@ func TestAmf0MarshalErrors(t *testing.T) {
})
}
createBuffer = originalCreateBuffer
for _, tt := range []struct {
name string
make func() Amf0Any

View File

@ -15,15 +15,26 @@ import (
"srsx/internal/logger"
)
// Indirections so tests can substitute signal delivery and process exit.
var (
signalNotify = signal.Notify
osExit = os.Exit
)
// Handler installs OS signal handlers and the force-quit timer. The notify
// and exit indirections are struct fields (not package globals) so concurrent
// tests can each construct a handler with their own fakes without racing on
// shared state.
type Handler struct {
notify func(c chan<- os.Signal, sig ...os.Signal)
exit func(code int)
}
func InstallSignals(ctx context.Context, cancel context.CancelFunc) {
// NewHandler returns a Handler wired to the real OS implementations.
func NewHandler() *Handler {
return &Handler{
notify: signal.Notify,
exit: os.Exit,
}
}
func (h *Handler) InstallSignals(ctx context.Context, cancel context.CancelFunc) {
sc := make(chan os.Signal, 1)
signalNotify(sc, syscall.SIGINT, syscall.SIGTERM, os.Interrupt)
h.notify(sc, syscall.SIGINT, syscall.SIGTERM, os.Interrupt)
go func() {
for s := range sc {
@ -33,7 +44,7 @@ func InstallSignals(ctx context.Context, cancel context.CancelFunc) {
}()
}
func InstallForceQuit(ctx context.Context, environment env.ProxyEnvironment) error {
func (h *Handler) InstallForceQuit(ctx context.Context, environment env.ProxyEnvironment) error {
var forceTimeout time.Duration
timeoutStr := environment.ForceQuitTimeout()
if t, err := time.ParseDuration(timeoutStr); err != nil {
@ -46,7 +57,7 @@ func InstallForceQuit(ctx context.Context, environment env.ProxyEnvironment) err
<-ctx.Done()
time.Sleep(forceTimeout)
logger.Warn(ctx, "Force to exit by timeout")
osExit(1)
h.exit(1)
}()
return nil
}

View File

@ -16,59 +16,60 @@ import (
"srsx/internal/env/envfakes"
)
// swapNotify replaces signalNotify with a capturing fake and returns a getter
// for the channel registered by the code under test plus a restore func.
func swapNotify(t *testing.T) (func() chan<- os.Signal, func()) {
t.Helper()
orig := signalNotify
// captureNotify returns a Handler whose notify field records the channel
// passed by the code under test, plus a getter that retrieves it.
func captureNotify() (*Handler, func() chan<- os.Signal) {
var (
mu sync.Mutex
ch chan<- os.Signal
)
signalNotify = func(c chan<- os.Signal, _ ...os.Signal) {
mu.Lock()
defer mu.Unlock()
ch = c
}
return func() chan<- os.Signal {
h := &Handler{
notify: func(c chan<- os.Signal, _ ...os.Signal) {
mu.Lock()
defer mu.Unlock()
return ch
}, func() {
signalNotify = orig
}
ch = c
},
exit: os.Exit,
}
return h, func() chan<- os.Signal {
mu.Lock()
defer mu.Unlock()
return ch
}
}
func swapExit(t *testing.T) (*int32, chan int, func()) {
t.Helper()
orig := osExit
// captureExit returns a Handler whose exit field records the code and never
// returns, plus a flag and channel that observe the call.
func captureExit() (*Handler, *int32, chan int) {
var called int32
done := make(chan int, 1)
osExit = func(code int) {
atomic.StoreInt32(&called, 1)
select {
case done <- code:
default:
}
// Block to mimic os.Exit never returning; the goroutine holding us
// here is abandoned when the test ends.
select {}
h := &Handler{
notify: func(chan<- os.Signal, ...os.Signal) {},
exit: func(code int) {
atomic.StoreInt32(&called, 1)
select {
case done <- code:
default:
}
// Block to mimic os.Exit never returning; the goroutine holding us
// here is abandoned when the test ends.
select {}
},
}
return &called, done, func() { osExit = orig }
return h, &called, done
}
func TestInstallSignals_CancelsOnSignal(t *testing.T) {
getCh, restore := swapNotify(t)
defer restore()
h, getCh := captureNotify()
ctx, cancel := context.WithCancel(t.Context())
defer cancel()
InstallSignals(ctx, cancel)
h.InstallSignals(ctx, cancel)
ch := getCh()
if ch == nil {
t.Fatal("signalNotify was not called")
t.Fatal("notify was not called")
}
ch <- syscall.SIGINT
@ -80,13 +81,12 @@ func TestInstallSignals_CancelsOnSignal(t *testing.T) {
}
func TestInstallSignals_HandlesRepeatedSignals(t *testing.T) {
getCh, restore := swapNotify(t)
defer restore()
h, getCh := captureNotify()
ctx, cancel := context.WithCancel(t.Context())
defer cancel()
InstallSignals(ctx, cancel)
h.InstallSignals(ctx, cancel)
ch := getCh()
// Multiple signals must not panic; cancel() is idempotent.
@ -105,7 +105,7 @@ func TestInstallForceQuit_InvalidDurationReturnsError(t *testing.T) {
fakeEnv := &envfakes.FakeProxyEnvironment{}
fakeEnv.ForceQuitTimeoutReturns("not-a-duration")
err := InstallForceQuit(t.Context(), fakeEnv)
err := NewHandler().InstallForceQuit(t.Context(), fakeEnv)
if err == nil {
t.Fatal("want error for bad duration")
}
@ -118,20 +118,19 @@ func TestInstallForceQuit_InvalidDurationReturnsError(t *testing.T) {
}
func TestInstallForceQuit_ExitsAfterTimeout(t *testing.T) {
called, done, restore := swapExit(t)
defer restore()
h, called, done := captureExit()
fakeEnv := &envfakes.FakeProxyEnvironment{}
fakeEnv.ForceQuitTimeoutReturns("1ms")
ctx, cancel := context.WithCancel(t.Context())
if err := InstallForceQuit(ctx, fakeEnv); err != nil {
if err := h.InstallForceQuit(ctx, fakeEnv); err != nil {
t.Fatalf("unexpected err: %v", err)
}
// Before cancel, the goroutine is blocked and exit must not fire.
if atomic.LoadInt32(called) != 0 {
t.Fatal("osExit called before ctx cancel")
t.Fatal("exit called before ctx cancel")
}
cancel()
@ -141,30 +140,39 @@ func TestInstallForceQuit_ExitsAfterTimeout(t *testing.T) {
t.Fatalf("exit code = %d, want 1", code)
}
case <-time.After(time.Second):
t.Fatal("osExit not called after cancel + timeout")
t.Fatal("exit not called after cancel + timeout")
}
}
func TestInstallForceQuit_WaitsForCancelBeforeSleeping(t *testing.T) {
called, done, restore := swapExit(t)
defer restore()
h, called, done := captureExit()
fakeEnv := &envfakes.FakeProxyEnvironment{}
fakeEnv.ForceQuitTimeoutReturns("10ms")
// Intentionally use a never-canceled context and leak the goroutine:
// if we canceled at test end, the goroutine would wake and race with
// restore() writing osExit.
if err := InstallForceQuit(context.Background(), fakeEnv); err != nil {
// Intentionally use a never-canceled context and leak the goroutine: the
// handler's exit closure is owned by this test instance, so leaving the
// goroutine alive doesn't race other tests.
if err := h.InstallForceQuit(context.Background(), fakeEnv); err != nil {
t.Fatalf("unexpected err: %v", err)
}
select {
case <-done:
t.Fatal("osExit fired without ctx cancel")
t.Fatal("exit fired without ctx cancel")
case <-time.After(30 * time.Millisecond):
}
if atomic.LoadInt32(called) != 0 {
t.Fatal("osExit called unexpectedly")
t.Fatal("exit called unexpectedly")
}
}
func TestNewHandler_UsesRealOSDefaults(t *testing.T) {
h := NewHandler()
if h.notify == nil {
t.Error("notify default not set")
}
if h.exit == nil {
t.Error("exit default not set")
}
}

View File

@ -5,11 +5,16 @@
set -e
SCRIPT_DIR="$(cd -P "$(dirname "$0")" && pwd)"
# Navigate: scripts/ -> srs-develop/ -> skills/ -> .openclaw/ -> srs
WORKSPACE="$(cd -P "$SCRIPT_DIR/../../../.." && pwd)"
# Walk up from SCRIPT_DIR looking for go.mod. This avoids brittle "../../../.."
# counting when the skills directory is reached via a symlink (which changes
# the symbolic vs. physical depth).
WORKSPACE="$SCRIPT_DIR"
while [[ "$WORKSPACE" != "/" && ! -f "$WORKSPACE/go.mod" ]]; do
WORKSPACE="$(dirname "$WORKSPACE")"
done
if [[ ! -f "$WORKSPACE/go.mod" ]]; then
echo "Error: go.mod not found in WORKSPACE: $WORKSPACE" >&2
echo "Error: go.mod not found walking up from: $SCRIPT_DIR" >&2
exit 1
fi

View File

@ -6,11 +6,16 @@
set -e
SCRIPT_DIR="$(cd -P "$(dirname "$0")" && pwd)"
# Navigate: scripts/ -> srs-develop/ -> skills/ -> .openclaw/ -> srs
WORKSPACE="$(cd -P "$SCRIPT_DIR/../../../.." && pwd)"
# Walk up from SCRIPT_DIR looking for go.mod. This avoids brittle "../../../.."
# counting when the skills directory is reached via a symlink (which changes
# the symbolic vs. physical depth).
WORKSPACE="$SCRIPT_DIR"
while [[ "$WORKSPACE" != "/" && ! -f "$WORKSPACE/go.mod" ]]; do
WORKSPACE="$(dirname "$WORKSPACE")"
done
if [[ ! -f "$WORKSPACE/go.mod" ]]; then
echo "Error: go.mod not found in WORKSPACE: $WORKSPACE" >&2
echo "Error: go.mod not found walking up from: $SCRIPT_DIR" >&2
exit 1
fi

View File

@ -10,11 +10,16 @@
set -e
SCRIPT_DIR="$(cd -P "$(dirname "$0")" && pwd)"
# Navigate: scripts/ -> srs-develop/ -> skills/ -> .openclaw/ -> srs
WORKSPACE="$(cd -P "$SCRIPT_DIR/../../../.." && pwd)"
# Walk up from SCRIPT_DIR looking for go.mod. This avoids brittle "../../../.."
# counting when the skills directory is reached via a symlink (which changes
# the symbolic vs. physical depth).
WORKSPACE="$SCRIPT_DIR"
while [[ "$WORKSPACE" != "/" && ! -f "$WORKSPACE/go.mod" ]]; do
WORKSPACE="$(dirname "$WORKSPACE")"
done
if [[ ! -f "$WORKSPACE/go.mod" ]]; then
echo "Error: go.mod not found in WORKSPACE: $WORKSPACE" >&2
echo "Error: go.mod not found walking up from: $SCRIPT_DIR" >&2
exit 1
fi

View File

@ -4,11 +4,16 @@
set -e
SCRIPT_DIR="$(cd -P "$(dirname "$0")" && pwd)"
# Navigate: scripts/ -> srs-develop/ -> skills/ -> .openclaw/ -> srs
WORKSPACE="$(cd -P "$SCRIPT_DIR/../../../.." && pwd)"
# Walk up from SCRIPT_DIR looking for go.mod. This avoids brittle "../../../.."
# counting when the skills directory is reached via a symlink (which changes
# the symbolic vs. physical depth).
WORKSPACE="$SCRIPT_DIR"
while [[ "$WORKSPACE" != "/" && ! -f "$WORKSPACE/go.mod" ]]; do
WORKSPACE="$(dirname "$WORKSPACE")"
done
if [[ ! -f "$WORKSPACE/go.mod" ]]; then
echo "Error: go.mod not found in WORKSPACE: $WORKSPACE" >&2
echo "Error: go.mod not found walking up from: $SCRIPT_DIR" >&2
exit 1
fi

View File

@ -6,11 +6,16 @@
set -e
SCRIPT_DIR="$(cd -P "$(dirname "$0")" && pwd)"
# Navigate: scripts/ -> srs-develop/ -> skills/ -> .openclaw/ -> srs
WORKSPACE="$(cd -P "$SCRIPT_DIR/../../../.." && pwd)"
# Walk up from SCRIPT_DIR looking for go.mod. This avoids brittle "../../../.."
# counting when the skills directory is reached via a symlink (which changes
# the symbolic vs. physical depth).
WORKSPACE="$SCRIPT_DIR"
while [[ "$WORKSPACE" != "/" && ! -f "$WORKSPACE/go.mod" ]]; do
WORKSPACE="$(dirname "$WORKSPACE")"
done
if [[ ! -f "$WORKSPACE/go.mod" ]]; then
echo "Error: go.mod not found in WORKSPACE: $WORKSPACE" >&2
echo "Error: go.mod not found walking up from: $SCRIPT_DIR" >&2
exit 1
fi

View File

@ -10,11 +10,16 @@
set -e
SCRIPT_DIR="$(cd -P "$(dirname "$0")" && pwd)"
# Navigate: scripts/ -> srs-develop/ -> skills/ -> .openclaw/ -> srs
WORKSPACE="$(cd -P "$SCRIPT_DIR/../../../.." && pwd)"
# Walk up from SCRIPT_DIR looking for go.mod. This avoids brittle "../../../.."
# counting when the skills directory is reached via a symlink (which changes
# the symbolic vs. physical depth).
WORKSPACE="$SCRIPT_DIR"
while [[ "$WORKSPACE" != "/" && ! -f "$WORKSPACE/go.mod" ]]; do
WORKSPACE="$(dirname "$WORKSPACE")"
done
if [[ ! -f "$WORKSPACE/go.mod" ]]; then
echo "Error: go.mod not found in WORKSPACE: $WORKSPACE" >&2
echo "Error: go.mod not found walking up from: $SCRIPT_DIR" >&2
exit 1
fi

View File

@ -27,11 +27,16 @@ for arg in "$@"; do
done
SCRIPT_DIR="$(cd -P "$(dirname "$0")" && pwd)"
# Navigate: scripts/ -> srs-develop/ -> skills/ -> .openclaw/ -> srs
WORKSPACE="$(cd -P "$SCRIPT_DIR/../../../.." && pwd)"
# Walk up from SCRIPT_DIR looking for go.mod. This avoids brittle "../../../.."
# counting when the skills directory is reached via a symlink (which changes
# the symbolic vs. physical depth).
WORKSPACE="$SCRIPT_DIR"
while [[ "$WORKSPACE" != "/" && ! -f "$WORKSPACE/go.mod" ]]; do
WORKSPACE="$(dirname "$WORKSPACE")"
done
if [[ ! -f "$WORKSPACE/go.mod" ]]; then
echo "Error: go.mod not found in WORKSPACE: $WORKSPACE" >&2
echo "Error: go.mod not found walking up from: $SCRIPT_DIR" >&2
exit 1
fi