Claude: Split lb interfaces, extract redisclient, drop race-prone globals.
- Split OriginLoadBalancer into OriginService / HLSService / RTCService; the original interface now embeds the three role interfaces. Generate counterfeiter fakes for all four. - Extract internal/redisclient: RedisClient interface + New() factory. internal/lb/redis.go no longer imports github.com/go-redis/redis/v8. - Add unit tests for lb.go (OriginServer.ID/String/Format/NewOriginServer) and for the full memory + redis load balancers. - Replace package-level test seams (memoryKeepaliveInterval, newRedisClient, redisKeepaliveInterval, signal.signalNotify/osExit, rtmp.createBuffer) with per-instance struct fields so concurrent tests can't race on them. - Promote signal.InstallSignals / InstallForceQuit onto a new signal.Handler type; update bootstrap to construct one. - Move rtmp createBuffer onto amf0ObjectBase as bufFactory; the three AMF0 marshalers and their tests use the per-instance factory. - Make proxy test scripts locate the workspace by walking up to go.mod instead of brittle '../../../..' counting (symlink-aware). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
f45bf30b46
commit
3060bf8e7c
|
|
@ -33,7 +33,7 @@ func (b *proxyBootstrap) Start(ctx context.Context) error {
|
|||
|
||||
// Install signals.
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
signal.InstallSignals(ctx, cancel)
|
||||
signal.NewHandler().InstallSignals(ctx, cancel)
|
||||
|
||||
// Run the main loop, ignore the user cancel error.
|
||||
err := b.run(ctx)
|
||||
|
|
@ -58,7 +58,7 @@ func (b *proxyBootstrap) run(ctx context.Context) error {
|
|||
// When cancelled, the program is forced to exit due to a timeout. Normally, this doesn't occur
|
||||
// because the main thread exits after the context is cancelled. However, sometimes the main thread
|
||||
// may be blocked for some reason, so a forced exit is necessary to ensure the program terminates.
|
||||
if err := signal.InstallForceQuit(ctx, environment); err != nil {
|
||||
if err := signal.NewHandler().InstallForceQuit(ctx, environment); err != nil {
|
||||
return errors.Wrapf(err, "install force quit")
|
||||
}
|
||||
|
||||
|
|
|
|||
9
internal/lb/gen.go
Normal file
9
internal/lb/gen.go
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
// Copyright (c) 2026 Winlin
|
||||
//
|
||||
// SPDX-License-Identifier: MIT
|
||||
package lb
|
||||
|
||||
//go:generate go tool counterfeiter -o lbfakes/fake_origin_load_balancer.go . OriginLoadBalancer
|
||||
//go:generate go tool counterfeiter -o lbfakes/fake_origin_service.go . OriginService
|
||||
//go:generate go tool counterfeiter -o lbfakes/fake_hls_service.go . HLSService
|
||||
//go:generate go tool counterfeiter -o lbfakes/fake_rtc_service.go . RTCService
|
||||
|
|
@ -109,20 +109,35 @@ type RTCConnection interface {
|
|||
GetUfrag() string
|
||||
}
|
||||
|
||||
// OriginLoadBalancer is the interface to load balance the SRS servers.
|
||||
type OriginLoadBalancer interface {
|
||||
// Initialize the load balancer.
|
||||
Initialize(ctx context.Context) error
|
||||
// OriginService is the interface for origin-server registry and stream routing.
|
||||
type OriginService interface {
|
||||
// Update records the latest registration or heartbeat for an origin server.
|
||||
Update(ctx context.Context, server *OriginServer) error
|
||||
// Pick a backend server for the specified stream URL.
|
||||
Pick(ctx context.Context, streamURL string) (*OriginServer, error)
|
||||
}
|
||||
|
||||
// HLSService is the interface for HLS session state, indexed by stream URL and SPBHID.
|
||||
type HLSService interface {
|
||||
// Load or store the HLS streaming for the specified stream URL.
|
||||
LoadOrStoreHLS(ctx context.Context, streamURL string, value HLSPlayStream) (HLSPlayStream, error)
|
||||
// Load the HLS streaming by SPBHID, the SRS Proxy Backend HLS ID.
|
||||
LoadHLSBySPBHID(ctx context.Context, spbhid string) (HLSPlayStream, error)
|
||||
}
|
||||
|
||||
// RTCService is the interface for WebRTC session state, indexed by stream URL and ICE ufrag.
|
||||
type RTCService interface {
|
||||
// Store the WebRTC streaming for the specified stream URL.
|
||||
StoreWebRTC(ctx context.Context, streamURL string, value RTCConnection) error
|
||||
// Load the WebRTC streaming by ufrag, the ICE username.
|
||||
LoadWebRTCByUfrag(ctx context.Context, ufrag string) (RTCConnection, error)
|
||||
}
|
||||
|
||||
// OriginLoadBalancer is the interface to load balance the SRS servers.
|
||||
type OriginLoadBalancer interface {
|
||||
OriginService
|
||||
HLSService
|
||||
RTCService
|
||||
// Initialize the load balancer.
|
||||
Initialize(ctx context.Context) error
|
||||
}
|
||||
|
|
|
|||
141
internal/lb/lb_test.go
Normal file
141
internal/lb/lb_test.go
Normal file
|
|
@ -0,0 +1,141 @@
|
|||
// Copyright (c) 2026 Winlin
|
||||
//
|
||||
// SPDX-License-Identifier: MIT
|
||||
package lb
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestOriginServerID(t *testing.T) {
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
v *OriginServer
|
||||
want string
|
||||
}{
|
||||
{"populated", &OriginServer{ServerID: "srv", ServiceID: "svc", PID: "1234"}, "srv-svc-1234"},
|
||||
{"empty", &OriginServer{}, "--"},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.v.ID(); got != tt.want {
|
||||
t.Fatalf("ID()=%q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOriginServerString(t *testing.T) {
|
||||
// String() routes through Format with the %v default branch.
|
||||
v := &OriginServer{IP: "1.2.3.4", ServerID: "srv", ServiceID: "svc", PID: "p"}
|
||||
got := v.String()
|
||||
if want := "SRS ip=1.2.3.4, id=srv-svc-p"; got != want {
|
||||
t.Fatalf("String()=%q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOriginServerFormat_ShortVerbs(t *testing.T) {
|
||||
v := &OriginServer{IP: "10.0.0.1", ServerID: "srv", ServiceID: "svc", PID: "9"}
|
||||
want := "SRS ip=10.0.0.1, id=srv-svc-9"
|
||||
for _, verb := range []string{"%v", "%s"} {
|
||||
got := fmt.Sprintf(verb, v)
|
||||
if got != want {
|
||||
t.Fatalf("Sprintf(%q)=%q, want %q", verb, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestOriginServerFormat_PlusVerbsAllFields(t *testing.T) {
|
||||
ts := time.Date(2026, 5, 16, 10, 30, 45, 123_000_000, time.UTC)
|
||||
v := &OriginServer{
|
||||
IP: "10.0.0.1", DeviceID: "dev1",
|
||||
ServerID: "srv", ServiceID: "svc", PID: "9",
|
||||
RTMP: []string{":1935", ":1936"},
|
||||
HTTP: []string{":8080"},
|
||||
API: []string{":1985"},
|
||||
SRT: []string{":10080"},
|
||||
RTC: []string{":8000"},
|
||||
UpdatedAt: ts,
|
||||
}
|
||||
|
||||
for _, verb := range []string{"%+v", "%+s"} {
|
||||
got := fmt.Sprintf(verb, v)
|
||||
for _, sub := range []string{
|
||||
"SRS ip=10.0.0.1",
|
||||
"id=srv-svc-9",
|
||||
"pid=9, server=srv, service=svc",
|
||||
"device=dev1",
|
||||
"rtmp=[:1935,:1936]",
|
||||
"http=[:8080]",
|
||||
"api=[:1985]",
|
||||
"srt=[:10080]",
|
||||
"rtc=[:8000]",
|
||||
"update=2026-05-16 10:30:45.123",
|
||||
} {
|
||||
if !strings.Contains(got, sub) {
|
||||
t.Fatalf("Sprintf(%q)=%q missing %q", verb, got, sub)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestOriginServerFormat_PlusVerbMinimal(t *testing.T) {
|
||||
// Plus verb with no optional fields populated exercises the false
|
||||
// branches of every "if len(X) > 0 / X != \"\"" guard in Format.
|
||||
v := &OriginServer{ServerID: "srv", ServiceID: "svc", PID: "9"}
|
||||
got := fmt.Sprintf("%+v", v)
|
||||
|
||||
if !strings.Contains(got, "pid=9, server=srv, service=svc") {
|
||||
t.Fatalf("%%+v output %q missing core ids", got)
|
||||
}
|
||||
if !strings.Contains(got, "update=") {
|
||||
t.Fatalf("%%+v output %q missing update timestamp", got)
|
||||
}
|
||||
for _, sub := range []string{"device=", "rtmp=", "http=", "api=", "srt=", "rtc="} {
|
||||
if strings.Contains(got, sub) {
|
||||
t.Fatalf("%%+v output %q should not contain %q for an empty field", got, sub)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestOriginServerFormat_OtherVerb(t *testing.T) {
|
||||
// A non-v/s verb falls through to the default branch, which recursively
|
||||
// formats with %v and appends ", fmt=%<verb>".
|
||||
v := &OriginServer{IP: "1.2.3.4", ServerID: "srv", ServiceID: "svc", PID: "p"}
|
||||
got := fmt.Sprintf("%d", v)
|
||||
want := "SRS ip=1.2.3.4, id=srv-svc-p, fmt=%d"
|
||||
if got != want {
|
||||
t.Fatalf("%%d output %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewOriginServer(t *testing.T) {
|
||||
t.Run("no opts", func(t *testing.T) {
|
||||
v := NewOriginServer()
|
||||
if v == nil {
|
||||
t.Fatal("NewOriginServer() returned nil")
|
||||
}
|
||||
if v.IP != "" || v.DeviceID != "" || v.ServerID != "" || v.ServiceID != "" || v.PID != "" {
|
||||
t.Fatalf("expected zero value, got %+v", v)
|
||||
}
|
||||
if len(v.RTMP)+len(v.HTTP)+len(v.API)+len(v.SRT)+len(v.RTC) != 0 {
|
||||
t.Fatalf("expected empty endpoints, got %+v", v)
|
||||
}
|
||||
if !v.UpdatedAt.IsZero() {
|
||||
t.Fatalf("expected zero UpdatedAt, got %v", v.UpdatedAt)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with opts", func(t *testing.T) {
|
||||
v := NewOriginServer(
|
||||
func(s *OriginServer) { s.IP = "9.9.9.9" },
|
||||
func(s *OriginServer) { s.ServerID = "abc" },
|
||||
func(s *OriginServer) { s.RTMP = []string{":1935"} },
|
||||
)
|
||||
if v.IP != "9.9.9.9" || v.ServerID != "abc" || len(v.RTMP) != 1 || v.RTMP[0] != ":1935" {
|
||||
t.Fatalf("opts not applied: got %+v", v)
|
||||
}
|
||||
})
|
||||
}
|
||||
197
internal/lb/lbfakes/fake_hls_service.go
Normal file
197
internal/lb/lbfakes/fake_hls_service.go
Normal file
|
|
@ -0,0 +1,197 @@
|
|||
// Code generated by counterfeiter. DO NOT EDIT.
|
||||
package lbfakes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"srsx/internal/lb"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type FakeHLSService struct {
|
||||
LoadHLSBySPBHIDStub func(context.Context, string) (lb.HLSPlayStream, error)
|
||||
loadHLSBySPBHIDMutex sync.RWMutex
|
||||
loadHLSBySPBHIDArgsForCall []struct {
|
||||
arg1 context.Context
|
||||
arg2 string
|
||||
}
|
||||
loadHLSBySPBHIDReturns struct {
|
||||
result1 lb.HLSPlayStream
|
||||
result2 error
|
||||
}
|
||||
loadHLSBySPBHIDReturnsOnCall map[int]struct {
|
||||
result1 lb.HLSPlayStream
|
||||
result2 error
|
||||
}
|
||||
LoadOrStoreHLSStub func(context.Context, string, lb.HLSPlayStream) (lb.HLSPlayStream, error)
|
||||
loadOrStoreHLSMutex sync.RWMutex
|
||||
loadOrStoreHLSArgsForCall []struct {
|
||||
arg1 context.Context
|
||||
arg2 string
|
||||
arg3 lb.HLSPlayStream
|
||||
}
|
||||
loadOrStoreHLSReturns struct {
|
||||
result1 lb.HLSPlayStream
|
||||
result2 error
|
||||
}
|
||||
loadOrStoreHLSReturnsOnCall map[int]struct {
|
||||
result1 lb.HLSPlayStream
|
||||
result2 error
|
||||
}
|
||||
invocations map[string][][]interface{}
|
||||
invocationsMutex sync.RWMutex
|
||||
}
|
||||
|
||||
func (fake *FakeHLSService) LoadHLSBySPBHID(arg1 context.Context, arg2 string) (lb.HLSPlayStream, error) {
|
||||
fake.loadHLSBySPBHIDMutex.Lock()
|
||||
ret, specificReturn := fake.loadHLSBySPBHIDReturnsOnCall[len(fake.loadHLSBySPBHIDArgsForCall)]
|
||||
fake.loadHLSBySPBHIDArgsForCall = append(fake.loadHLSBySPBHIDArgsForCall, struct {
|
||||
arg1 context.Context
|
||||
arg2 string
|
||||
}{arg1, arg2})
|
||||
stub := fake.LoadHLSBySPBHIDStub
|
||||
fakeReturns := fake.loadHLSBySPBHIDReturns
|
||||
fake.recordInvocation("LoadHLSBySPBHID", []interface{}{arg1, arg2})
|
||||
fake.loadHLSBySPBHIDMutex.Unlock()
|
||||
if stub != nil {
|
||||
return stub(arg1, arg2)
|
||||
}
|
||||
if specificReturn {
|
||||
return ret.result1, ret.result2
|
||||
}
|
||||
return fakeReturns.result1, fakeReturns.result2
|
||||
}
|
||||
|
||||
func (fake *FakeHLSService) LoadHLSBySPBHIDCallCount() int {
|
||||
fake.loadHLSBySPBHIDMutex.RLock()
|
||||
defer fake.loadHLSBySPBHIDMutex.RUnlock()
|
||||
return len(fake.loadHLSBySPBHIDArgsForCall)
|
||||
}
|
||||
|
||||
func (fake *FakeHLSService) LoadHLSBySPBHIDCalls(stub func(context.Context, string) (lb.HLSPlayStream, error)) {
|
||||
fake.loadHLSBySPBHIDMutex.Lock()
|
||||
defer fake.loadHLSBySPBHIDMutex.Unlock()
|
||||
fake.LoadHLSBySPBHIDStub = stub
|
||||
}
|
||||
|
||||
func (fake *FakeHLSService) LoadHLSBySPBHIDArgsForCall(i int) (context.Context, string) {
|
||||
fake.loadHLSBySPBHIDMutex.RLock()
|
||||
defer fake.loadHLSBySPBHIDMutex.RUnlock()
|
||||
argsForCall := fake.loadHLSBySPBHIDArgsForCall[i]
|
||||
return argsForCall.arg1, argsForCall.arg2
|
||||
}
|
||||
|
||||
func (fake *FakeHLSService) LoadHLSBySPBHIDReturns(result1 lb.HLSPlayStream, result2 error) {
|
||||
fake.loadHLSBySPBHIDMutex.Lock()
|
||||
defer fake.loadHLSBySPBHIDMutex.Unlock()
|
||||
fake.LoadHLSBySPBHIDStub = nil
|
||||
fake.loadHLSBySPBHIDReturns = struct {
|
||||
result1 lb.HLSPlayStream
|
||||
result2 error
|
||||
}{result1, result2}
|
||||
}
|
||||
|
||||
func (fake *FakeHLSService) LoadHLSBySPBHIDReturnsOnCall(i int, result1 lb.HLSPlayStream, result2 error) {
|
||||
fake.loadHLSBySPBHIDMutex.Lock()
|
||||
defer fake.loadHLSBySPBHIDMutex.Unlock()
|
||||
fake.LoadHLSBySPBHIDStub = nil
|
||||
if fake.loadHLSBySPBHIDReturnsOnCall == nil {
|
||||
fake.loadHLSBySPBHIDReturnsOnCall = make(map[int]struct {
|
||||
result1 lb.HLSPlayStream
|
||||
result2 error
|
||||
})
|
||||
}
|
||||
fake.loadHLSBySPBHIDReturnsOnCall[i] = struct {
|
||||
result1 lb.HLSPlayStream
|
||||
result2 error
|
||||
}{result1, result2}
|
||||
}
|
||||
|
||||
func (fake *FakeHLSService) LoadOrStoreHLS(arg1 context.Context, arg2 string, arg3 lb.HLSPlayStream) (lb.HLSPlayStream, error) {
|
||||
fake.loadOrStoreHLSMutex.Lock()
|
||||
ret, specificReturn := fake.loadOrStoreHLSReturnsOnCall[len(fake.loadOrStoreHLSArgsForCall)]
|
||||
fake.loadOrStoreHLSArgsForCall = append(fake.loadOrStoreHLSArgsForCall, struct {
|
||||
arg1 context.Context
|
||||
arg2 string
|
||||
arg3 lb.HLSPlayStream
|
||||
}{arg1, arg2, arg3})
|
||||
stub := fake.LoadOrStoreHLSStub
|
||||
fakeReturns := fake.loadOrStoreHLSReturns
|
||||
fake.recordInvocation("LoadOrStoreHLS", []interface{}{arg1, arg2, arg3})
|
||||
fake.loadOrStoreHLSMutex.Unlock()
|
||||
if stub != nil {
|
||||
return stub(arg1, arg2, arg3)
|
||||
}
|
||||
if specificReturn {
|
||||
return ret.result1, ret.result2
|
||||
}
|
||||
return fakeReturns.result1, fakeReturns.result2
|
||||
}
|
||||
|
||||
func (fake *FakeHLSService) LoadOrStoreHLSCallCount() int {
|
||||
fake.loadOrStoreHLSMutex.RLock()
|
||||
defer fake.loadOrStoreHLSMutex.RUnlock()
|
||||
return len(fake.loadOrStoreHLSArgsForCall)
|
||||
}
|
||||
|
||||
func (fake *FakeHLSService) LoadOrStoreHLSCalls(stub func(context.Context, string, lb.HLSPlayStream) (lb.HLSPlayStream, error)) {
|
||||
fake.loadOrStoreHLSMutex.Lock()
|
||||
defer fake.loadOrStoreHLSMutex.Unlock()
|
||||
fake.LoadOrStoreHLSStub = stub
|
||||
}
|
||||
|
||||
func (fake *FakeHLSService) LoadOrStoreHLSArgsForCall(i int) (context.Context, string, lb.HLSPlayStream) {
|
||||
fake.loadOrStoreHLSMutex.RLock()
|
||||
defer fake.loadOrStoreHLSMutex.RUnlock()
|
||||
argsForCall := fake.loadOrStoreHLSArgsForCall[i]
|
||||
return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3
|
||||
}
|
||||
|
||||
func (fake *FakeHLSService) LoadOrStoreHLSReturns(result1 lb.HLSPlayStream, result2 error) {
|
||||
fake.loadOrStoreHLSMutex.Lock()
|
||||
defer fake.loadOrStoreHLSMutex.Unlock()
|
||||
fake.LoadOrStoreHLSStub = nil
|
||||
fake.loadOrStoreHLSReturns = struct {
|
||||
result1 lb.HLSPlayStream
|
||||
result2 error
|
||||
}{result1, result2}
|
||||
}
|
||||
|
||||
func (fake *FakeHLSService) LoadOrStoreHLSReturnsOnCall(i int, result1 lb.HLSPlayStream, result2 error) {
|
||||
fake.loadOrStoreHLSMutex.Lock()
|
||||
defer fake.loadOrStoreHLSMutex.Unlock()
|
||||
fake.LoadOrStoreHLSStub = nil
|
||||
if fake.loadOrStoreHLSReturnsOnCall == nil {
|
||||
fake.loadOrStoreHLSReturnsOnCall = make(map[int]struct {
|
||||
result1 lb.HLSPlayStream
|
||||
result2 error
|
||||
})
|
||||
}
|
||||
fake.loadOrStoreHLSReturnsOnCall[i] = struct {
|
||||
result1 lb.HLSPlayStream
|
||||
result2 error
|
||||
}{result1, result2}
|
||||
}
|
||||
|
||||
func (fake *FakeHLSService) Invocations() map[string][][]interface{} {
|
||||
fake.invocationsMutex.RLock()
|
||||
defer fake.invocationsMutex.RUnlock()
|
||||
copiedInvocations := map[string][][]interface{}{}
|
||||
for key, value := range fake.invocations {
|
||||
copiedInvocations[key] = value
|
||||
}
|
||||
return copiedInvocations
|
||||
}
|
||||
|
||||
func (fake *FakeHLSService) recordInvocation(key string, args []interface{}) {
|
||||
fake.invocationsMutex.Lock()
|
||||
defer fake.invocationsMutex.Unlock()
|
||||
if fake.invocations == nil {
|
||||
fake.invocations = map[string][][]interface{}{}
|
||||
}
|
||||
if fake.invocations[key] == nil {
|
||||
fake.invocations[key] = [][]interface{}{}
|
||||
}
|
||||
fake.invocations[key] = append(fake.invocations[key], args)
|
||||
}
|
||||
|
||||
var _ lb.HLSService = new(FakeHLSService)
|
||||
577
internal/lb/lbfakes/fake_origin_load_balancer.go
Normal file
577
internal/lb/lbfakes/fake_origin_load_balancer.go
Normal file
|
|
@ -0,0 +1,577 @@
|
|||
// Code generated by counterfeiter. DO NOT EDIT.
|
||||
package lbfakes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"srsx/internal/lb"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type FakeOriginLoadBalancer struct {
|
||||
InitializeStub func(context.Context) error
|
||||
initializeMutex sync.RWMutex
|
||||
initializeArgsForCall []struct {
|
||||
arg1 context.Context
|
||||
}
|
||||
initializeReturns struct {
|
||||
result1 error
|
||||
}
|
||||
initializeReturnsOnCall map[int]struct {
|
||||
result1 error
|
||||
}
|
||||
LoadHLSBySPBHIDStub func(context.Context, string) (lb.HLSPlayStream, error)
|
||||
loadHLSBySPBHIDMutex sync.RWMutex
|
||||
loadHLSBySPBHIDArgsForCall []struct {
|
||||
arg1 context.Context
|
||||
arg2 string
|
||||
}
|
||||
loadHLSBySPBHIDReturns struct {
|
||||
result1 lb.HLSPlayStream
|
||||
result2 error
|
||||
}
|
||||
loadHLSBySPBHIDReturnsOnCall map[int]struct {
|
||||
result1 lb.HLSPlayStream
|
||||
result2 error
|
||||
}
|
||||
LoadOrStoreHLSStub func(context.Context, string, lb.HLSPlayStream) (lb.HLSPlayStream, error)
|
||||
loadOrStoreHLSMutex sync.RWMutex
|
||||
loadOrStoreHLSArgsForCall []struct {
|
||||
arg1 context.Context
|
||||
arg2 string
|
||||
arg3 lb.HLSPlayStream
|
||||
}
|
||||
loadOrStoreHLSReturns struct {
|
||||
result1 lb.HLSPlayStream
|
||||
result2 error
|
||||
}
|
||||
loadOrStoreHLSReturnsOnCall map[int]struct {
|
||||
result1 lb.HLSPlayStream
|
||||
result2 error
|
||||
}
|
||||
LoadWebRTCByUfragStub func(context.Context, string) (lb.RTCConnection, error)
|
||||
loadWebRTCByUfragMutex sync.RWMutex
|
||||
loadWebRTCByUfragArgsForCall []struct {
|
||||
arg1 context.Context
|
||||
arg2 string
|
||||
}
|
||||
loadWebRTCByUfragReturns struct {
|
||||
result1 lb.RTCConnection
|
||||
result2 error
|
||||
}
|
||||
loadWebRTCByUfragReturnsOnCall map[int]struct {
|
||||
result1 lb.RTCConnection
|
||||
result2 error
|
||||
}
|
||||
PickStub func(context.Context, string) (*lb.OriginServer, error)
|
||||
pickMutex sync.RWMutex
|
||||
pickArgsForCall []struct {
|
||||
arg1 context.Context
|
||||
arg2 string
|
||||
}
|
||||
pickReturns struct {
|
||||
result1 *lb.OriginServer
|
||||
result2 error
|
||||
}
|
||||
pickReturnsOnCall map[int]struct {
|
||||
result1 *lb.OriginServer
|
||||
result2 error
|
||||
}
|
||||
StoreWebRTCStub func(context.Context, string, lb.RTCConnection) error
|
||||
storeWebRTCMutex sync.RWMutex
|
||||
storeWebRTCArgsForCall []struct {
|
||||
arg1 context.Context
|
||||
arg2 string
|
||||
arg3 lb.RTCConnection
|
||||
}
|
||||
storeWebRTCReturns struct {
|
||||
result1 error
|
||||
}
|
||||
storeWebRTCReturnsOnCall map[int]struct {
|
||||
result1 error
|
||||
}
|
||||
UpdateStub func(context.Context, *lb.OriginServer) error
|
||||
updateMutex sync.RWMutex
|
||||
updateArgsForCall []struct {
|
||||
arg1 context.Context
|
||||
arg2 *lb.OriginServer
|
||||
}
|
||||
updateReturns struct {
|
||||
result1 error
|
||||
}
|
||||
updateReturnsOnCall map[int]struct {
|
||||
result1 error
|
||||
}
|
||||
invocations map[string][][]interface{}
|
||||
invocationsMutex sync.RWMutex
|
||||
}
|
||||
|
||||
func (fake *FakeOriginLoadBalancer) Initialize(arg1 context.Context) error {
|
||||
fake.initializeMutex.Lock()
|
||||
ret, specificReturn := fake.initializeReturnsOnCall[len(fake.initializeArgsForCall)]
|
||||
fake.initializeArgsForCall = append(fake.initializeArgsForCall, struct {
|
||||
arg1 context.Context
|
||||
}{arg1})
|
||||
stub := fake.InitializeStub
|
||||
fakeReturns := fake.initializeReturns
|
||||
fake.recordInvocation("Initialize", []interface{}{arg1})
|
||||
fake.initializeMutex.Unlock()
|
||||
if stub != nil {
|
||||
return stub(arg1)
|
||||
}
|
||||
if specificReturn {
|
||||
return ret.result1
|
||||
}
|
||||
return fakeReturns.result1
|
||||
}
|
||||
|
||||
func (fake *FakeOriginLoadBalancer) InitializeCallCount() int {
|
||||
fake.initializeMutex.RLock()
|
||||
defer fake.initializeMutex.RUnlock()
|
||||
return len(fake.initializeArgsForCall)
|
||||
}
|
||||
|
||||
func (fake *FakeOriginLoadBalancer) InitializeCalls(stub func(context.Context) error) {
|
||||
fake.initializeMutex.Lock()
|
||||
defer fake.initializeMutex.Unlock()
|
||||
fake.InitializeStub = stub
|
||||
}
|
||||
|
||||
func (fake *FakeOriginLoadBalancer) InitializeArgsForCall(i int) context.Context {
|
||||
fake.initializeMutex.RLock()
|
||||
defer fake.initializeMutex.RUnlock()
|
||||
argsForCall := fake.initializeArgsForCall[i]
|
||||
return argsForCall.arg1
|
||||
}
|
||||
|
||||
func (fake *FakeOriginLoadBalancer) InitializeReturns(result1 error) {
|
||||
fake.initializeMutex.Lock()
|
||||
defer fake.initializeMutex.Unlock()
|
||||
fake.InitializeStub = nil
|
||||
fake.initializeReturns = struct {
|
||||
result1 error
|
||||
}{result1}
|
||||
}
|
||||
|
||||
func (fake *FakeOriginLoadBalancer) InitializeReturnsOnCall(i int, result1 error) {
|
||||
fake.initializeMutex.Lock()
|
||||
defer fake.initializeMutex.Unlock()
|
||||
fake.InitializeStub = nil
|
||||
if fake.initializeReturnsOnCall == nil {
|
||||
fake.initializeReturnsOnCall = make(map[int]struct {
|
||||
result1 error
|
||||
})
|
||||
}
|
||||
fake.initializeReturnsOnCall[i] = struct {
|
||||
result1 error
|
||||
}{result1}
|
||||
}
|
||||
|
||||
func (fake *FakeOriginLoadBalancer) LoadHLSBySPBHID(arg1 context.Context, arg2 string) (lb.HLSPlayStream, error) {
|
||||
fake.loadHLSBySPBHIDMutex.Lock()
|
||||
ret, specificReturn := fake.loadHLSBySPBHIDReturnsOnCall[len(fake.loadHLSBySPBHIDArgsForCall)]
|
||||
fake.loadHLSBySPBHIDArgsForCall = append(fake.loadHLSBySPBHIDArgsForCall, struct {
|
||||
arg1 context.Context
|
||||
arg2 string
|
||||
}{arg1, arg2})
|
||||
stub := fake.LoadHLSBySPBHIDStub
|
||||
fakeReturns := fake.loadHLSBySPBHIDReturns
|
||||
fake.recordInvocation("LoadHLSBySPBHID", []interface{}{arg1, arg2})
|
||||
fake.loadHLSBySPBHIDMutex.Unlock()
|
||||
if stub != nil {
|
||||
return stub(arg1, arg2)
|
||||
}
|
||||
if specificReturn {
|
||||
return ret.result1, ret.result2
|
||||
}
|
||||
return fakeReturns.result1, fakeReturns.result2
|
||||
}
|
||||
|
||||
func (fake *FakeOriginLoadBalancer) LoadHLSBySPBHIDCallCount() int {
|
||||
fake.loadHLSBySPBHIDMutex.RLock()
|
||||
defer fake.loadHLSBySPBHIDMutex.RUnlock()
|
||||
return len(fake.loadHLSBySPBHIDArgsForCall)
|
||||
}
|
||||
|
||||
func (fake *FakeOriginLoadBalancer) LoadHLSBySPBHIDCalls(stub func(context.Context, string) (lb.HLSPlayStream, error)) {
|
||||
fake.loadHLSBySPBHIDMutex.Lock()
|
||||
defer fake.loadHLSBySPBHIDMutex.Unlock()
|
||||
fake.LoadHLSBySPBHIDStub = stub
|
||||
}
|
||||
|
||||
func (fake *FakeOriginLoadBalancer) LoadHLSBySPBHIDArgsForCall(i int) (context.Context, string) {
|
||||
fake.loadHLSBySPBHIDMutex.RLock()
|
||||
defer fake.loadHLSBySPBHIDMutex.RUnlock()
|
||||
argsForCall := fake.loadHLSBySPBHIDArgsForCall[i]
|
||||
return argsForCall.arg1, argsForCall.arg2
|
||||
}
|
||||
|
||||
func (fake *FakeOriginLoadBalancer) LoadHLSBySPBHIDReturns(result1 lb.HLSPlayStream, result2 error) {
|
||||
fake.loadHLSBySPBHIDMutex.Lock()
|
||||
defer fake.loadHLSBySPBHIDMutex.Unlock()
|
||||
fake.LoadHLSBySPBHIDStub = nil
|
||||
fake.loadHLSBySPBHIDReturns = struct {
|
||||
result1 lb.HLSPlayStream
|
||||
result2 error
|
||||
}{result1, result2}
|
||||
}
|
||||
|
||||
func (fake *FakeOriginLoadBalancer) LoadHLSBySPBHIDReturnsOnCall(i int, result1 lb.HLSPlayStream, result2 error) {
|
||||
fake.loadHLSBySPBHIDMutex.Lock()
|
||||
defer fake.loadHLSBySPBHIDMutex.Unlock()
|
||||
fake.LoadHLSBySPBHIDStub = nil
|
||||
if fake.loadHLSBySPBHIDReturnsOnCall == nil {
|
||||
fake.loadHLSBySPBHIDReturnsOnCall = make(map[int]struct {
|
||||
result1 lb.HLSPlayStream
|
||||
result2 error
|
||||
})
|
||||
}
|
||||
fake.loadHLSBySPBHIDReturnsOnCall[i] = struct {
|
||||
result1 lb.HLSPlayStream
|
||||
result2 error
|
||||
}{result1, result2}
|
||||
}
|
||||
|
||||
func (fake *FakeOriginLoadBalancer) LoadOrStoreHLS(arg1 context.Context, arg2 string, arg3 lb.HLSPlayStream) (lb.HLSPlayStream, error) {
|
||||
fake.loadOrStoreHLSMutex.Lock()
|
||||
ret, specificReturn := fake.loadOrStoreHLSReturnsOnCall[len(fake.loadOrStoreHLSArgsForCall)]
|
||||
fake.loadOrStoreHLSArgsForCall = append(fake.loadOrStoreHLSArgsForCall, struct {
|
||||
arg1 context.Context
|
||||
arg2 string
|
||||
arg3 lb.HLSPlayStream
|
||||
}{arg1, arg2, arg3})
|
||||
stub := fake.LoadOrStoreHLSStub
|
||||
fakeReturns := fake.loadOrStoreHLSReturns
|
||||
fake.recordInvocation("LoadOrStoreHLS", []interface{}{arg1, arg2, arg3})
|
||||
fake.loadOrStoreHLSMutex.Unlock()
|
||||
if stub != nil {
|
||||
return stub(arg1, arg2, arg3)
|
||||
}
|
||||
if specificReturn {
|
||||
return ret.result1, ret.result2
|
||||
}
|
||||
return fakeReturns.result1, fakeReturns.result2
|
||||
}
|
||||
|
||||
func (fake *FakeOriginLoadBalancer) LoadOrStoreHLSCallCount() int {
|
||||
fake.loadOrStoreHLSMutex.RLock()
|
||||
defer fake.loadOrStoreHLSMutex.RUnlock()
|
||||
return len(fake.loadOrStoreHLSArgsForCall)
|
||||
}
|
||||
|
||||
func (fake *FakeOriginLoadBalancer) LoadOrStoreHLSCalls(stub func(context.Context, string, lb.HLSPlayStream) (lb.HLSPlayStream, error)) {
|
||||
fake.loadOrStoreHLSMutex.Lock()
|
||||
defer fake.loadOrStoreHLSMutex.Unlock()
|
||||
fake.LoadOrStoreHLSStub = stub
|
||||
}
|
||||
|
||||
func (fake *FakeOriginLoadBalancer) LoadOrStoreHLSArgsForCall(i int) (context.Context, string, lb.HLSPlayStream) {
|
||||
fake.loadOrStoreHLSMutex.RLock()
|
||||
defer fake.loadOrStoreHLSMutex.RUnlock()
|
||||
argsForCall := fake.loadOrStoreHLSArgsForCall[i]
|
||||
return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3
|
||||
}
|
||||
|
||||
func (fake *FakeOriginLoadBalancer) LoadOrStoreHLSReturns(result1 lb.HLSPlayStream, result2 error) {
|
||||
fake.loadOrStoreHLSMutex.Lock()
|
||||
defer fake.loadOrStoreHLSMutex.Unlock()
|
||||
fake.LoadOrStoreHLSStub = nil
|
||||
fake.loadOrStoreHLSReturns = struct {
|
||||
result1 lb.HLSPlayStream
|
||||
result2 error
|
||||
}{result1, result2}
|
||||
}
|
||||
|
||||
func (fake *FakeOriginLoadBalancer) LoadOrStoreHLSReturnsOnCall(i int, result1 lb.HLSPlayStream, result2 error) {
|
||||
fake.loadOrStoreHLSMutex.Lock()
|
||||
defer fake.loadOrStoreHLSMutex.Unlock()
|
||||
fake.LoadOrStoreHLSStub = nil
|
||||
if fake.loadOrStoreHLSReturnsOnCall == nil {
|
||||
fake.loadOrStoreHLSReturnsOnCall = make(map[int]struct {
|
||||
result1 lb.HLSPlayStream
|
||||
result2 error
|
||||
})
|
||||
}
|
||||
fake.loadOrStoreHLSReturnsOnCall[i] = struct {
|
||||
result1 lb.HLSPlayStream
|
||||
result2 error
|
||||
}{result1, result2}
|
||||
}
|
||||
|
||||
func (fake *FakeOriginLoadBalancer) LoadWebRTCByUfrag(arg1 context.Context, arg2 string) (lb.RTCConnection, error) {
|
||||
fake.loadWebRTCByUfragMutex.Lock()
|
||||
ret, specificReturn := fake.loadWebRTCByUfragReturnsOnCall[len(fake.loadWebRTCByUfragArgsForCall)]
|
||||
fake.loadWebRTCByUfragArgsForCall = append(fake.loadWebRTCByUfragArgsForCall, struct {
|
||||
arg1 context.Context
|
||||
arg2 string
|
||||
}{arg1, arg2})
|
||||
stub := fake.LoadWebRTCByUfragStub
|
||||
fakeReturns := fake.loadWebRTCByUfragReturns
|
||||
fake.recordInvocation("LoadWebRTCByUfrag", []interface{}{arg1, arg2})
|
||||
fake.loadWebRTCByUfragMutex.Unlock()
|
||||
if stub != nil {
|
||||
return stub(arg1, arg2)
|
||||
}
|
||||
if specificReturn {
|
||||
return ret.result1, ret.result2
|
||||
}
|
||||
return fakeReturns.result1, fakeReturns.result2
|
||||
}
|
||||
|
||||
func (fake *FakeOriginLoadBalancer) LoadWebRTCByUfragCallCount() int {
|
||||
fake.loadWebRTCByUfragMutex.RLock()
|
||||
defer fake.loadWebRTCByUfragMutex.RUnlock()
|
||||
return len(fake.loadWebRTCByUfragArgsForCall)
|
||||
}
|
||||
|
||||
func (fake *FakeOriginLoadBalancer) LoadWebRTCByUfragCalls(stub func(context.Context, string) (lb.RTCConnection, error)) {
|
||||
fake.loadWebRTCByUfragMutex.Lock()
|
||||
defer fake.loadWebRTCByUfragMutex.Unlock()
|
||||
fake.LoadWebRTCByUfragStub = stub
|
||||
}
|
||||
|
||||
func (fake *FakeOriginLoadBalancer) LoadWebRTCByUfragArgsForCall(i int) (context.Context, string) {
|
||||
fake.loadWebRTCByUfragMutex.RLock()
|
||||
defer fake.loadWebRTCByUfragMutex.RUnlock()
|
||||
argsForCall := fake.loadWebRTCByUfragArgsForCall[i]
|
||||
return argsForCall.arg1, argsForCall.arg2
|
||||
}
|
||||
|
||||
func (fake *FakeOriginLoadBalancer) LoadWebRTCByUfragReturns(result1 lb.RTCConnection, result2 error) {
|
||||
fake.loadWebRTCByUfragMutex.Lock()
|
||||
defer fake.loadWebRTCByUfragMutex.Unlock()
|
||||
fake.LoadWebRTCByUfragStub = nil
|
||||
fake.loadWebRTCByUfragReturns = struct {
|
||||
result1 lb.RTCConnection
|
||||
result2 error
|
||||
}{result1, result2}
|
||||
}
|
||||
|
||||
func (fake *FakeOriginLoadBalancer) LoadWebRTCByUfragReturnsOnCall(i int, result1 lb.RTCConnection, result2 error) {
|
||||
fake.loadWebRTCByUfragMutex.Lock()
|
||||
defer fake.loadWebRTCByUfragMutex.Unlock()
|
||||
fake.LoadWebRTCByUfragStub = nil
|
||||
if fake.loadWebRTCByUfragReturnsOnCall == nil {
|
||||
fake.loadWebRTCByUfragReturnsOnCall = make(map[int]struct {
|
||||
result1 lb.RTCConnection
|
||||
result2 error
|
||||
})
|
||||
}
|
||||
fake.loadWebRTCByUfragReturnsOnCall[i] = struct {
|
||||
result1 lb.RTCConnection
|
||||
result2 error
|
||||
}{result1, result2}
|
||||
}
|
||||
|
||||
func (fake *FakeOriginLoadBalancer) Pick(arg1 context.Context, arg2 string) (*lb.OriginServer, error) {
|
||||
fake.pickMutex.Lock()
|
||||
ret, specificReturn := fake.pickReturnsOnCall[len(fake.pickArgsForCall)]
|
||||
fake.pickArgsForCall = append(fake.pickArgsForCall, struct {
|
||||
arg1 context.Context
|
||||
arg2 string
|
||||
}{arg1, arg2})
|
||||
stub := fake.PickStub
|
||||
fakeReturns := fake.pickReturns
|
||||
fake.recordInvocation("Pick", []interface{}{arg1, arg2})
|
||||
fake.pickMutex.Unlock()
|
||||
if stub != nil {
|
||||
return stub(arg1, arg2)
|
||||
}
|
||||
if specificReturn {
|
||||
return ret.result1, ret.result2
|
||||
}
|
||||
return fakeReturns.result1, fakeReturns.result2
|
||||
}
|
||||
|
||||
func (fake *FakeOriginLoadBalancer) PickCallCount() int {
|
||||
fake.pickMutex.RLock()
|
||||
defer fake.pickMutex.RUnlock()
|
||||
return len(fake.pickArgsForCall)
|
||||
}
|
||||
|
||||
func (fake *FakeOriginLoadBalancer) PickCalls(stub func(context.Context, string) (*lb.OriginServer, error)) {
|
||||
fake.pickMutex.Lock()
|
||||
defer fake.pickMutex.Unlock()
|
||||
fake.PickStub = stub
|
||||
}
|
||||
|
||||
func (fake *FakeOriginLoadBalancer) PickArgsForCall(i int) (context.Context, string) {
|
||||
fake.pickMutex.RLock()
|
||||
defer fake.pickMutex.RUnlock()
|
||||
argsForCall := fake.pickArgsForCall[i]
|
||||
return argsForCall.arg1, argsForCall.arg2
|
||||
}
|
||||
|
||||
func (fake *FakeOriginLoadBalancer) PickReturns(result1 *lb.OriginServer, result2 error) {
|
||||
fake.pickMutex.Lock()
|
||||
defer fake.pickMutex.Unlock()
|
||||
fake.PickStub = nil
|
||||
fake.pickReturns = struct {
|
||||
result1 *lb.OriginServer
|
||||
result2 error
|
||||
}{result1, result2}
|
||||
}
|
||||
|
||||
func (fake *FakeOriginLoadBalancer) PickReturnsOnCall(i int, result1 *lb.OriginServer, result2 error) {
|
||||
fake.pickMutex.Lock()
|
||||
defer fake.pickMutex.Unlock()
|
||||
fake.PickStub = nil
|
||||
if fake.pickReturnsOnCall == nil {
|
||||
fake.pickReturnsOnCall = make(map[int]struct {
|
||||
result1 *lb.OriginServer
|
||||
result2 error
|
||||
})
|
||||
}
|
||||
fake.pickReturnsOnCall[i] = struct {
|
||||
result1 *lb.OriginServer
|
||||
result2 error
|
||||
}{result1, result2}
|
||||
}
|
||||
|
||||
func (fake *FakeOriginLoadBalancer) StoreWebRTC(arg1 context.Context, arg2 string, arg3 lb.RTCConnection) error {
|
||||
fake.storeWebRTCMutex.Lock()
|
||||
ret, specificReturn := fake.storeWebRTCReturnsOnCall[len(fake.storeWebRTCArgsForCall)]
|
||||
fake.storeWebRTCArgsForCall = append(fake.storeWebRTCArgsForCall, struct {
|
||||
arg1 context.Context
|
||||
arg2 string
|
||||
arg3 lb.RTCConnection
|
||||
}{arg1, arg2, arg3})
|
||||
stub := fake.StoreWebRTCStub
|
||||
fakeReturns := fake.storeWebRTCReturns
|
||||
fake.recordInvocation("StoreWebRTC", []interface{}{arg1, arg2, arg3})
|
||||
fake.storeWebRTCMutex.Unlock()
|
||||
if stub != nil {
|
||||
return stub(arg1, arg2, arg3)
|
||||
}
|
||||
if specificReturn {
|
||||
return ret.result1
|
||||
}
|
||||
return fakeReturns.result1
|
||||
}
|
||||
|
||||
func (fake *FakeOriginLoadBalancer) StoreWebRTCCallCount() int {
|
||||
fake.storeWebRTCMutex.RLock()
|
||||
defer fake.storeWebRTCMutex.RUnlock()
|
||||
return len(fake.storeWebRTCArgsForCall)
|
||||
}
|
||||
|
||||
func (fake *FakeOriginLoadBalancer) StoreWebRTCCalls(stub func(context.Context, string, lb.RTCConnection) error) {
|
||||
fake.storeWebRTCMutex.Lock()
|
||||
defer fake.storeWebRTCMutex.Unlock()
|
||||
fake.StoreWebRTCStub = stub
|
||||
}
|
||||
|
||||
func (fake *FakeOriginLoadBalancer) StoreWebRTCArgsForCall(i int) (context.Context, string, lb.RTCConnection) {
|
||||
fake.storeWebRTCMutex.RLock()
|
||||
defer fake.storeWebRTCMutex.RUnlock()
|
||||
argsForCall := fake.storeWebRTCArgsForCall[i]
|
||||
return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3
|
||||
}
|
||||
|
||||
func (fake *FakeOriginLoadBalancer) StoreWebRTCReturns(result1 error) {
|
||||
fake.storeWebRTCMutex.Lock()
|
||||
defer fake.storeWebRTCMutex.Unlock()
|
||||
fake.StoreWebRTCStub = nil
|
||||
fake.storeWebRTCReturns = struct {
|
||||
result1 error
|
||||
}{result1}
|
||||
}
|
||||
|
||||
func (fake *FakeOriginLoadBalancer) StoreWebRTCReturnsOnCall(i int, result1 error) {
|
||||
fake.storeWebRTCMutex.Lock()
|
||||
defer fake.storeWebRTCMutex.Unlock()
|
||||
fake.StoreWebRTCStub = nil
|
||||
if fake.storeWebRTCReturnsOnCall == nil {
|
||||
fake.storeWebRTCReturnsOnCall = make(map[int]struct {
|
||||
result1 error
|
||||
})
|
||||
}
|
||||
fake.storeWebRTCReturnsOnCall[i] = struct {
|
||||
result1 error
|
||||
}{result1}
|
||||
}
|
||||
|
||||
func (fake *FakeOriginLoadBalancer) Update(arg1 context.Context, arg2 *lb.OriginServer) error {
|
||||
fake.updateMutex.Lock()
|
||||
ret, specificReturn := fake.updateReturnsOnCall[len(fake.updateArgsForCall)]
|
||||
fake.updateArgsForCall = append(fake.updateArgsForCall, struct {
|
||||
arg1 context.Context
|
||||
arg2 *lb.OriginServer
|
||||
}{arg1, arg2})
|
||||
stub := fake.UpdateStub
|
||||
fakeReturns := fake.updateReturns
|
||||
fake.recordInvocation("Update", []interface{}{arg1, arg2})
|
||||
fake.updateMutex.Unlock()
|
||||
if stub != nil {
|
||||
return stub(arg1, arg2)
|
||||
}
|
||||
if specificReturn {
|
||||
return ret.result1
|
||||
}
|
||||
return fakeReturns.result1
|
||||
}
|
||||
|
||||
func (fake *FakeOriginLoadBalancer) UpdateCallCount() int {
|
||||
fake.updateMutex.RLock()
|
||||
defer fake.updateMutex.RUnlock()
|
||||
return len(fake.updateArgsForCall)
|
||||
}
|
||||
|
||||
func (fake *FakeOriginLoadBalancer) UpdateCalls(stub func(context.Context, *lb.OriginServer) error) {
|
||||
fake.updateMutex.Lock()
|
||||
defer fake.updateMutex.Unlock()
|
||||
fake.UpdateStub = stub
|
||||
}
|
||||
|
||||
func (fake *FakeOriginLoadBalancer) UpdateArgsForCall(i int) (context.Context, *lb.OriginServer) {
|
||||
fake.updateMutex.RLock()
|
||||
defer fake.updateMutex.RUnlock()
|
||||
argsForCall := fake.updateArgsForCall[i]
|
||||
return argsForCall.arg1, argsForCall.arg2
|
||||
}
|
||||
|
||||
func (fake *FakeOriginLoadBalancer) UpdateReturns(result1 error) {
|
||||
fake.updateMutex.Lock()
|
||||
defer fake.updateMutex.Unlock()
|
||||
fake.UpdateStub = nil
|
||||
fake.updateReturns = struct {
|
||||
result1 error
|
||||
}{result1}
|
||||
}
|
||||
|
||||
func (fake *FakeOriginLoadBalancer) UpdateReturnsOnCall(i int, result1 error) {
|
||||
fake.updateMutex.Lock()
|
||||
defer fake.updateMutex.Unlock()
|
||||
fake.UpdateStub = nil
|
||||
if fake.updateReturnsOnCall == nil {
|
||||
fake.updateReturnsOnCall = make(map[int]struct {
|
||||
result1 error
|
||||
})
|
||||
}
|
||||
fake.updateReturnsOnCall[i] = struct {
|
||||
result1 error
|
||||
}{result1}
|
||||
}
|
||||
|
||||
func (fake *FakeOriginLoadBalancer) Invocations() map[string][][]interface{} {
|
||||
fake.invocationsMutex.RLock()
|
||||
defer fake.invocationsMutex.RUnlock()
|
||||
copiedInvocations := map[string][][]interface{}{}
|
||||
for key, value := range fake.invocations {
|
||||
copiedInvocations[key] = value
|
||||
}
|
||||
return copiedInvocations
|
||||
}
|
||||
|
||||
func (fake *FakeOriginLoadBalancer) recordInvocation(key string, args []interface{}) {
|
||||
fake.invocationsMutex.Lock()
|
||||
defer fake.invocationsMutex.Unlock()
|
||||
if fake.invocations == nil {
|
||||
fake.invocations = map[string][][]interface{}{}
|
||||
}
|
||||
if fake.invocations[key] == nil {
|
||||
fake.invocations[key] = [][]interface{}{}
|
||||
}
|
||||
fake.invocations[key] = append(fake.invocations[key], args)
|
||||
}
|
||||
|
||||
var _ lb.OriginLoadBalancer = new(FakeOriginLoadBalancer)
|
||||
190
internal/lb/lbfakes/fake_origin_service.go
Normal file
190
internal/lb/lbfakes/fake_origin_service.go
Normal file
|
|
@ -0,0 +1,190 @@
|
|||
// Code generated by counterfeiter. DO NOT EDIT.
|
||||
package lbfakes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"srsx/internal/lb"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type FakeOriginService struct {
|
||||
PickStub func(context.Context, string) (*lb.OriginServer, error)
|
||||
pickMutex sync.RWMutex
|
||||
pickArgsForCall []struct {
|
||||
arg1 context.Context
|
||||
arg2 string
|
||||
}
|
||||
pickReturns struct {
|
||||
result1 *lb.OriginServer
|
||||
result2 error
|
||||
}
|
||||
pickReturnsOnCall map[int]struct {
|
||||
result1 *lb.OriginServer
|
||||
result2 error
|
||||
}
|
||||
UpdateStub func(context.Context, *lb.OriginServer) error
|
||||
updateMutex sync.RWMutex
|
||||
updateArgsForCall []struct {
|
||||
arg1 context.Context
|
||||
arg2 *lb.OriginServer
|
||||
}
|
||||
updateReturns struct {
|
||||
result1 error
|
||||
}
|
||||
updateReturnsOnCall map[int]struct {
|
||||
result1 error
|
||||
}
|
||||
invocations map[string][][]interface{}
|
||||
invocationsMutex sync.RWMutex
|
||||
}
|
||||
|
||||
func (fake *FakeOriginService) Pick(arg1 context.Context, arg2 string) (*lb.OriginServer, error) {
|
||||
fake.pickMutex.Lock()
|
||||
ret, specificReturn := fake.pickReturnsOnCall[len(fake.pickArgsForCall)]
|
||||
fake.pickArgsForCall = append(fake.pickArgsForCall, struct {
|
||||
arg1 context.Context
|
||||
arg2 string
|
||||
}{arg1, arg2})
|
||||
stub := fake.PickStub
|
||||
fakeReturns := fake.pickReturns
|
||||
fake.recordInvocation("Pick", []interface{}{arg1, arg2})
|
||||
fake.pickMutex.Unlock()
|
||||
if stub != nil {
|
||||
return stub(arg1, arg2)
|
||||
}
|
||||
if specificReturn {
|
||||
return ret.result1, ret.result2
|
||||
}
|
||||
return fakeReturns.result1, fakeReturns.result2
|
||||
}
|
||||
|
||||
func (fake *FakeOriginService) PickCallCount() int {
|
||||
fake.pickMutex.RLock()
|
||||
defer fake.pickMutex.RUnlock()
|
||||
return len(fake.pickArgsForCall)
|
||||
}
|
||||
|
||||
func (fake *FakeOriginService) PickCalls(stub func(context.Context, string) (*lb.OriginServer, error)) {
|
||||
fake.pickMutex.Lock()
|
||||
defer fake.pickMutex.Unlock()
|
||||
fake.PickStub = stub
|
||||
}
|
||||
|
||||
func (fake *FakeOriginService) PickArgsForCall(i int) (context.Context, string) {
|
||||
fake.pickMutex.RLock()
|
||||
defer fake.pickMutex.RUnlock()
|
||||
argsForCall := fake.pickArgsForCall[i]
|
||||
return argsForCall.arg1, argsForCall.arg2
|
||||
}
|
||||
|
||||
func (fake *FakeOriginService) PickReturns(result1 *lb.OriginServer, result2 error) {
|
||||
fake.pickMutex.Lock()
|
||||
defer fake.pickMutex.Unlock()
|
||||
fake.PickStub = nil
|
||||
fake.pickReturns = struct {
|
||||
result1 *lb.OriginServer
|
||||
result2 error
|
||||
}{result1, result2}
|
||||
}
|
||||
|
||||
func (fake *FakeOriginService) PickReturnsOnCall(i int, result1 *lb.OriginServer, result2 error) {
|
||||
fake.pickMutex.Lock()
|
||||
defer fake.pickMutex.Unlock()
|
||||
fake.PickStub = nil
|
||||
if fake.pickReturnsOnCall == nil {
|
||||
fake.pickReturnsOnCall = make(map[int]struct {
|
||||
result1 *lb.OriginServer
|
||||
result2 error
|
||||
})
|
||||
}
|
||||
fake.pickReturnsOnCall[i] = struct {
|
||||
result1 *lb.OriginServer
|
||||
result2 error
|
||||
}{result1, result2}
|
||||
}
|
||||
|
||||
func (fake *FakeOriginService) Update(arg1 context.Context, arg2 *lb.OriginServer) error {
|
||||
fake.updateMutex.Lock()
|
||||
ret, specificReturn := fake.updateReturnsOnCall[len(fake.updateArgsForCall)]
|
||||
fake.updateArgsForCall = append(fake.updateArgsForCall, struct {
|
||||
arg1 context.Context
|
||||
arg2 *lb.OriginServer
|
||||
}{arg1, arg2})
|
||||
stub := fake.UpdateStub
|
||||
fakeReturns := fake.updateReturns
|
||||
fake.recordInvocation("Update", []interface{}{arg1, arg2})
|
||||
fake.updateMutex.Unlock()
|
||||
if stub != nil {
|
||||
return stub(arg1, arg2)
|
||||
}
|
||||
if specificReturn {
|
||||
return ret.result1
|
||||
}
|
||||
return fakeReturns.result1
|
||||
}
|
||||
|
||||
func (fake *FakeOriginService) UpdateCallCount() int {
|
||||
fake.updateMutex.RLock()
|
||||
defer fake.updateMutex.RUnlock()
|
||||
return len(fake.updateArgsForCall)
|
||||
}
|
||||
|
||||
func (fake *FakeOriginService) UpdateCalls(stub func(context.Context, *lb.OriginServer) error) {
|
||||
fake.updateMutex.Lock()
|
||||
defer fake.updateMutex.Unlock()
|
||||
fake.UpdateStub = stub
|
||||
}
|
||||
|
||||
func (fake *FakeOriginService) UpdateArgsForCall(i int) (context.Context, *lb.OriginServer) {
|
||||
fake.updateMutex.RLock()
|
||||
defer fake.updateMutex.RUnlock()
|
||||
argsForCall := fake.updateArgsForCall[i]
|
||||
return argsForCall.arg1, argsForCall.arg2
|
||||
}
|
||||
|
||||
func (fake *FakeOriginService) UpdateReturns(result1 error) {
|
||||
fake.updateMutex.Lock()
|
||||
defer fake.updateMutex.Unlock()
|
||||
fake.UpdateStub = nil
|
||||
fake.updateReturns = struct {
|
||||
result1 error
|
||||
}{result1}
|
||||
}
|
||||
|
||||
func (fake *FakeOriginService) UpdateReturnsOnCall(i int, result1 error) {
|
||||
fake.updateMutex.Lock()
|
||||
defer fake.updateMutex.Unlock()
|
||||
fake.UpdateStub = nil
|
||||
if fake.updateReturnsOnCall == nil {
|
||||
fake.updateReturnsOnCall = make(map[int]struct {
|
||||
result1 error
|
||||
})
|
||||
}
|
||||
fake.updateReturnsOnCall[i] = struct {
|
||||
result1 error
|
||||
}{result1}
|
||||
}
|
||||
|
||||
func (fake *FakeOriginService) Invocations() map[string][][]interface{} {
|
||||
fake.invocationsMutex.RLock()
|
||||
defer fake.invocationsMutex.RUnlock()
|
||||
copiedInvocations := map[string][][]interface{}{}
|
||||
for key, value := range fake.invocations {
|
||||
copiedInvocations[key] = value
|
||||
}
|
||||
return copiedInvocations
|
||||
}
|
||||
|
||||
func (fake *FakeOriginService) recordInvocation(key string, args []interface{}) {
|
||||
fake.invocationsMutex.Lock()
|
||||
defer fake.invocationsMutex.Unlock()
|
||||
if fake.invocations == nil {
|
||||
fake.invocations = map[string][][]interface{}{}
|
||||
}
|
||||
if fake.invocations[key] == nil {
|
||||
fake.invocations[key] = [][]interface{}{}
|
||||
}
|
||||
fake.invocations[key] = append(fake.invocations[key], args)
|
||||
}
|
||||
|
||||
var _ lb.OriginService = new(FakeOriginService)
|
||||
192
internal/lb/lbfakes/fake_rtc_service.go
Normal file
192
internal/lb/lbfakes/fake_rtc_service.go
Normal file
|
|
@ -0,0 +1,192 @@
|
|||
// Code generated by counterfeiter. DO NOT EDIT.
|
||||
package lbfakes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"srsx/internal/lb"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type FakeRTCService struct {
|
||||
LoadWebRTCByUfragStub func(context.Context, string) (lb.RTCConnection, error)
|
||||
loadWebRTCByUfragMutex sync.RWMutex
|
||||
loadWebRTCByUfragArgsForCall []struct {
|
||||
arg1 context.Context
|
||||
arg2 string
|
||||
}
|
||||
loadWebRTCByUfragReturns struct {
|
||||
result1 lb.RTCConnection
|
||||
result2 error
|
||||
}
|
||||
loadWebRTCByUfragReturnsOnCall map[int]struct {
|
||||
result1 lb.RTCConnection
|
||||
result2 error
|
||||
}
|
||||
StoreWebRTCStub func(context.Context, string, lb.RTCConnection) error
|
||||
storeWebRTCMutex sync.RWMutex
|
||||
storeWebRTCArgsForCall []struct {
|
||||
arg1 context.Context
|
||||
arg2 string
|
||||
arg3 lb.RTCConnection
|
||||
}
|
||||
storeWebRTCReturns struct {
|
||||
result1 error
|
||||
}
|
||||
storeWebRTCReturnsOnCall map[int]struct {
|
||||
result1 error
|
||||
}
|
||||
invocations map[string][][]interface{}
|
||||
invocationsMutex sync.RWMutex
|
||||
}
|
||||
|
||||
func (fake *FakeRTCService) LoadWebRTCByUfrag(arg1 context.Context, arg2 string) (lb.RTCConnection, error) {
|
||||
fake.loadWebRTCByUfragMutex.Lock()
|
||||
ret, specificReturn := fake.loadWebRTCByUfragReturnsOnCall[len(fake.loadWebRTCByUfragArgsForCall)]
|
||||
fake.loadWebRTCByUfragArgsForCall = append(fake.loadWebRTCByUfragArgsForCall, struct {
|
||||
arg1 context.Context
|
||||
arg2 string
|
||||
}{arg1, arg2})
|
||||
stub := fake.LoadWebRTCByUfragStub
|
||||
fakeReturns := fake.loadWebRTCByUfragReturns
|
||||
fake.recordInvocation("LoadWebRTCByUfrag", []interface{}{arg1, arg2})
|
||||
fake.loadWebRTCByUfragMutex.Unlock()
|
||||
if stub != nil {
|
||||
return stub(arg1, arg2)
|
||||
}
|
||||
if specificReturn {
|
||||
return ret.result1, ret.result2
|
||||
}
|
||||
return fakeReturns.result1, fakeReturns.result2
|
||||
}
|
||||
|
||||
func (fake *FakeRTCService) LoadWebRTCByUfragCallCount() int {
|
||||
fake.loadWebRTCByUfragMutex.RLock()
|
||||
defer fake.loadWebRTCByUfragMutex.RUnlock()
|
||||
return len(fake.loadWebRTCByUfragArgsForCall)
|
||||
}
|
||||
|
||||
func (fake *FakeRTCService) LoadWebRTCByUfragCalls(stub func(context.Context, string) (lb.RTCConnection, error)) {
|
||||
fake.loadWebRTCByUfragMutex.Lock()
|
||||
defer fake.loadWebRTCByUfragMutex.Unlock()
|
||||
fake.LoadWebRTCByUfragStub = stub
|
||||
}
|
||||
|
||||
func (fake *FakeRTCService) LoadWebRTCByUfragArgsForCall(i int) (context.Context, string) {
|
||||
fake.loadWebRTCByUfragMutex.RLock()
|
||||
defer fake.loadWebRTCByUfragMutex.RUnlock()
|
||||
argsForCall := fake.loadWebRTCByUfragArgsForCall[i]
|
||||
return argsForCall.arg1, argsForCall.arg2
|
||||
}
|
||||
|
||||
func (fake *FakeRTCService) LoadWebRTCByUfragReturns(result1 lb.RTCConnection, result2 error) {
|
||||
fake.loadWebRTCByUfragMutex.Lock()
|
||||
defer fake.loadWebRTCByUfragMutex.Unlock()
|
||||
fake.LoadWebRTCByUfragStub = nil
|
||||
fake.loadWebRTCByUfragReturns = struct {
|
||||
result1 lb.RTCConnection
|
||||
result2 error
|
||||
}{result1, result2}
|
||||
}
|
||||
|
||||
func (fake *FakeRTCService) LoadWebRTCByUfragReturnsOnCall(i int, result1 lb.RTCConnection, result2 error) {
|
||||
fake.loadWebRTCByUfragMutex.Lock()
|
||||
defer fake.loadWebRTCByUfragMutex.Unlock()
|
||||
fake.LoadWebRTCByUfragStub = nil
|
||||
if fake.loadWebRTCByUfragReturnsOnCall == nil {
|
||||
fake.loadWebRTCByUfragReturnsOnCall = make(map[int]struct {
|
||||
result1 lb.RTCConnection
|
||||
result2 error
|
||||
})
|
||||
}
|
||||
fake.loadWebRTCByUfragReturnsOnCall[i] = struct {
|
||||
result1 lb.RTCConnection
|
||||
result2 error
|
||||
}{result1, result2}
|
||||
}
|
||||
|
||||
func (fake *FakeRTCService) StoreWebRTC(arg1 context.Context, arg2 string, arg3 lb.RTCConnection) error {
|
||||
fake.storeWebRTCMutex.Lock()
|
||||
ret, specificReturn := fake.storeWebRTCReturnsOnCall[len(fake.storeWebRTCArgsForCall)]
|
||||
fake.storeWebRTCArgsForCall = append(fake.storeWebRTCArgsForCall, struct {
|
||||
arg1 context.Context
|
||||
arg2 string
|
||||
arg3 lb.RTCConnection
|
||||
}{arg1, arg2, arg3})
|
||||
stub := fake.StoreWebRTCStub
|
||||
fakeReturns := fake.storeWebRTCReturns
|
||||
fake.recordInvocation("StoreWebRTC", []interface{}{arg1, arg2, arg3})
|
||||
fake.storeWebRTCMutex.Unlock()
|
||||
if stub != nil {
|
||||
return stub(arg1, arg2, arg3)
|
||||
}
|
||||
if specificReturn {
|
||||
return ret.result1
|
||||
}
|
||||
return fakeReturns.result1
|
||||
}
|
||||
|
||||
func (fake *FakeRTCService) StoreWebRTCCallCount() int {
|
||||
fake.storeWebRTCMutex.RLock()
|
||||
defer fake.storeWebRTCMutex.RUnlock()
|
||||
return len(fake.storeWebRTCArgsForCall)
|
||||
}
|
||||
|
||||
func (fake *FakeRTCService) StoreWebRTCCalls(stub func(context.Context, string, lb.RTCConnection) error) {
|
||||
fake.storeWebRTCMutex.Lock()
|
||||
defer fake.storeWebRTCMutex.Unlock()
|
||||
fake.StoreWebRTCStub = stub
|
||||
}
|
||||
|
||||
func (fake *FakeRTCService) StoreWebRTCArgsForCall(i int) (context.Context, string, lb.RTCConnection) {
|
||||
fake.storeWebRTCMutex.RLock()
|
||||
defer fake.storeWebRTCMutex.RUnlock()
|
||||
argsForCall := fake.storeWebRTCArgsForCall[i]
|
||||
return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3
|
||||
}
|
||||
|
||||
func (fake *FakeRTCService) StoreWebRTCReturns(result1 error) {
|
||||
fake.storeWebRTCMutex.Lock()
|
||||
defer fake.storeWebRTCMutex.Unlock()
|
||||
fake.StoreWebRTCStub = nil
|
||||
fake.storeWebRTCReturns = struct {
|
||||
result1 error
|
||||
}{result1}
|
||||
}
|
||||
|
||||
func (fake *FakeRTCService) StoreWebRTCReturnsOnCall(i int, result1 error) {
|
||||
fake.storeWebRTCMutex.Lock()
|
||||
defer fake.storeWebRTCMutex.Unlock()
|
||||
fake.StoreWebRTCStub = nil
|
||||
if fake.storeWebRTCReturnsOnCall == nil {
|
||||
fake.storeWebRTCReturnsOnCall = make(map[int]struct {
|
||||
result1 error
|
||||
})
|
||||
}
|
||||
fake.storeWebRTCReturnsOnCall[i] = struct {
|
||||
result1 error
|
||||
}{result1}
|
||||
}
|
||||
|
||||
func (fake *FakeRTCService) Invocations() map[string][][]interface{} {
|
||||
fake.invocationsMutex.RLock()
|
||||
defer fake.invocationsMutex.RUnlock()
|
||||
copiedInvocations := map[string][][]interface{}{}
|
||||
for key, value := range fake.invocations {
|
||||
copiedInvocations[key] = value
|
||||
}
|
||||
return copiedInvocations
|
||||
}
|
||||
|
||||
func (fake *FakeRTCService) recordInvocation(key string, args []interface{}) {
|
||||
fake.invocationsMutex.Lock()
|
||||
defer fake.invocationsMutex.Unlock()
|
||||
if fake.invocations == nil {
|
||||
fake.invocations = map[string][][]interface{}{}
|
||||
}
|
||||
if fake.invocations[key] == nil {
|
||||
fake.invocations[key] = [][]interface{}{}
|
||||
}
|
||||
fake.invocations[key] = append(fake.invocations[key], args)
|
||||
}
|
||||
|
||||
var _ lb.RTCService = new(FakeRTCService)
|
||||
|
|
@ -31,18 +31,23 @@ type memoryLoadBalancer struct {
|
|||
rtcStreamURL sync.Map[string, RTCConnection]
|
||||
// The WebRTC streaming, key is ufrag.
|
||||
rtcUfrag sync.Map[string, RTCConnection]
|
||||
// keepaliveInterval is the period at which the default-backend keep-alive
|
||||
// goroutine re-Updates its registration. Struct field for test injection
|
||||
// (avoids racing a package global across concurrent tests).
|
||||
keepaliveInterval time.Duration
|
||||
}
|
||||
|
||||
// NewMemoryLoadBalancer creates a new memory-based load balancer.
|
||||
func NewMemoryLoadBalancer(environment env.ProxyEnvironment) OriginLoadBalancer {
|
||||
return &memoryLoadBalancer{
|
||||
environment: environment,
|
||||
servers: sync.NewMap[string, *OriginServer](),
|
||||
picked: sync.NewMap[string, *OriginServer](),
|
||||
hlsStreamURL: sync.NewMap[string, HLSPlayStream](),
|
||||
hlsSPBHID: sync.NewMap[string, HLSPlayStream](),
|
||||
rtcStreamURL: sync.NewMap[string, RTCConnection](),
|
||||
rtcUfrag: sync.NewMap[string, RTCConnection](),
|
||||
environment: environment,
|
||||
servers: sync.NewMap[string, *OriginServer](),
|
||||
picked: sync.NewMap[string, *OriginServer](),
|
||||
hlsStreamURL: sync.NewMap[string, HLSPlayStream](),
|
||||
hlsSPBHID: sync.NewMap[string, HLSPlayStream](),
|
||||
rtcStreamURL: sync.NewMap[string, RTCConnection](),
|
||||
rtcUfrag: sync.NewMap[string, RTCConnection](),
|
||||
keepaliveInterval: 30 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -63,7 +68,7 @@ func (v *memoryLoadBalancer) Initialize(ctx context.Context) error {
|
|||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(30 * time.Second):
|
||||
case <-time.After(v.keepaliveInterval):
|
||||
if err := v.Update(ctx, server); err != nil {
|
||||
logger.Warn(ctx, "update default SRS %+v failed, %+v", server, err)
|
||||
}
|
||||
|
|
|
|||
263
internal/lb/mem_test.go
Normal file
263
internal/lb/mem_test.go
Normal file
|
|
@ -0,0 +1,263 @@
|
|||
// Copyright (c) 2026 Winlin
|
||||
//
|
||||
// SPDX-License-Identifier: MIT
|
||||
package lb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"srsx/internal/env/envfakes"
|
||||
)
|
||||
|
||||
// stubHLS is a minimal HLSPlayStream for testing.
|
||||
type stubHLS struct {
|
||||
spbhid string
|
||||
}
|
||||
|
||||
func (s *stubHLS) GetSPBHID() string { return s.spbhid }
|
||||
func (s *stubHLS) Initialize(ctx context.Context) HLSPlayStream { return s }
|
||||
|
||||
// stubRTC is a minimal RTCConnection for testing.
|
||||
type stubRTC struct {
|
||||
ufrag string
|
||||
}
|
||||
|
||||
func (s *stubRTC) GetUfrag() string { return s.ufrag }
|
||||
|
||||
// newMem returns a fresh in-memory load balancer with a default fake env.
|
||||
func newMem() *memoryLoadBalancer {
|
||||
env := &envfakes.FakeProxyEnvironment{}
|
||||
return NewMemoryLoadBalancer(env).(*memoryLoadBalancer)
|
||||
}
|
||||
|
||||
func TestNewMemoryLoadBalancer(t *testing.T) {
|
||||
env := &envfakes.FakeProxyEnvironment{}
|
||||
lb := NewMemoryLoadBalancer(env)
|
||||
if lb == nil {
|
||||
t.Fatal("NewMemoryLoadBalancer returned nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemLB_Initialize_DefaultBackendDisabled(t *testing.T) {
|
||||
env := &envfakes.FakeProxyEnvironment{}
|
||||
env.DefaultBackendEnabledReturns("off")
|
||||
lb := NewMemoryLoadBalancer(env).(*memoryLoadBalancer)
|
||||
if err := lb.Initialize(context.Background()); err != nil {
|
||||
t.Fatalf("Initialize: %v", err)
|
||||
}
|
||||
// No server stored when disabled.
|
||||
count := 0
|
||||
lb.servers.Range(func(string, *OriginServer) bool { count++; return true })
|
||||
if count != 0 {
|
||||
t.Fatalf("expected 0 servers, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemLB_Initialize_DefaultBackendError(t *testing.T) {
|
||||
env := &envfakes.FakeProxyEnvironment{}
|
||||
env.DefaultBackendEnabledReturns("on")
|
||||
env.DefaultBackendIPReturns("") // triggers "empty default backend ip"
|
||||
lb := NewMemoryLoadBalancer(env)
|
||||
err := lb.Initialize(context.Background())
|
||||
if err == nil || !strings.Contains(err.Error(), "initialize default SRS") {
|
||||
t.Fatalf("expected wrapped error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemLB_Initialize_KeepaliveTick(t *testing.T) {
|
||||
env := &envfakes.FakeProxyEnvironment{}
|
||||
env.DefaultBackendEnabledReturns("on")
|
||||
env.DefaultBackendIPReturns("1.2.3.4")
|
||||
env.DefaultBackendRTMPReturns(":1935")
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
lb := NewMemoryLoadBalancer(env).(*memoryLoadBalancer)
|
||||
// Shorten the keep-alive interval on this instance only so concurrent
|
||||
// tests don't race on shared state.
|
||||
lb.keepaliveInterval = time.Millisecond
|
||||
if err := lb.Initialize(ctx); err != nil {
|
||||
t.Fatalf("Initialize: %v", err)
|
||||
}
|
||||
|
||||
// Find the server and watch UpdatedAt advance after a keep-alive tick.
|
||||
var s *OriginServer
|
||||
lb.servers.Range(func(_ string, v *OriginServer) bool { s = v; return false })
|
||||
if s == nil {
|
||||
t.Fatal("expected server stored")
|
||||
}
|
||||
first := s.UpdatedAt
|
||||
|
||||
// Wait long enough for several ticks (interval is 1ms, server.UpdatedAt
|
||||
// is set to time.Now() inside NewDefaultOriginServerForDebugging on each
|
||||
// Update? — actually Update only stores the server pointer, so UpdatedAt
|
||||
// won't change. The goroutine still hits the tick branch though, which
|
||||
// is all we need for coverage).
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
cancel()
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
_ = first
|
||||
}
|
||||
|
||||
func TestMemLB_Initialize_DefaultBackendSuccess(t *testing.T) {
|
||||
env := &envfakes.FakeProxyEnvironment{}
|
||||
env.DefaultBackendEnabledReturns("on")
|
||||
env.DefaultBackendIPReturns("1.2.3.4")
|
||||
env.DefaultBackendRTMPReturns(":1935")
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
lb := NewMemoryLoadBalancer(env).(*memoryLoadBalancer)
|
||||
if err := lb.Initialize(ctx); err != nil {
|
||||
t.Fatalf("Initialize: %v", err)
|
||||
}
|
||||
|
||||
count := 0
|
||||
lb.servers.Range(func(string, *OriginServer) bool { count++; return true })
|
||||
if count != 1 {
|
||||
t.Fatalf("expected 1 server stored, got %d", count)
|
||||
}
|
||||
|
||||
// Cancel and give the keep-alive goroutine a moment to exit cleanly.
|
||||
cancel()
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
|
||||
func TestMemLB_Update(t *testing.T) {
|
||||
lb := newMem()
|
||||
s := &OriginServer{ServerID: "srv", ServiceID: "svc", PID: "1"}
|
||||
if err := lb.Update(context.Background(), s); err != nil {
|
||||
t.Fatalf("Update: %v", err)
|
||||
}
|
||||
got, ok := lb.servers.Load(s.ID())
|
||||
if !ok || got != s {
|
||||
t.Fatalf("Update did not store the server: got=%v ok=%v", got, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemLB_Pick_NoServers(t *testing.T) {
|
||||
lb := newMem()
|
||||
_, err := lb.Pick(context.Background(), "url1")
|
||||
if err == nil || !strings.Contains(err.Error(), "no server available") {
|
||||
t.Fatalf("expected no-server error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemLB_Pick_AliveServer_Sticky(t *testing.T) {
|
||||
lb := newMem()
|
||||
s := &OriginServer{ServerID: "a", PID: "1", UpdatedAt: time.Now()}
|
||||
_ = lb.Update(context.Background(), s)
|
||||
|
||||
got, err := lb.Pick(context.Background(), "url1")
|
||||
if err != nil {
|
||||
t.Fatalf("Pick: %v", err)
|
||||
}
|
||||
if got != s {
|
||||
t.Fatalf("Pick returned %v, want %v", got, s)
|
||||
}
|
||||
|
||||
// Second pick for the same URL returns the same server (sticky branch).
|
||||
got2, err := lb.Pick(context.Background(), "url1")
|
||||
if err != nil {
|
||||
t.Fatalf("Pick second: %v", err)
|
||||
}
|
||||
if got2 != got {
|
||||
t.Fatalf("second Pick returned %v, want %v (sticky)", got2, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemLB_Pick_OnlyDeadServers_Fallback(t *testing.T) {
|
||||
lb := newMem()
|
||||
// UpdatedAt long past => not alive. Tests the fallback "use all servers" branch.
|
||||
s := &OriginServer{
|
||||
ServerID: "a",
|
||||
PID: "1",
|
||||
UpdatedAt: time.Now().Add(-2 * ServerAliveDuration),
|
||||
}
|
||||
_ = lb.Update(context.Background(), s)
|
||||
|
||||
got, err := lb.Pick(context.Background(), "url1")
|
||||
if err != nil {
|
||||
t.Fatalf("Pick: %v", err)
|
||||
}
|
||||
if got != s {
|
||||
t.Fatalf("expected dead-server fallback to return %v, got %v", s, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemLB_LoadHLSBySPBHID_NotFound(t *testing.T) {
|
||||
lb := newMem()
|
||||
_, err := lb.LoadHLSBySPBHID(context.Background(), "missing")
|
||||
if err == nil || !strings.Contains(err.Error(), "no HLS streaming") {
|
||||
t.Fatalf("expected error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemLB_LoadOrStoreHLS_New(t *testing.T) {
|
||||
lb := newMem()
|
||||
s := &stubHLS{spbhid: "abc"}
|
||||
got, err := lb.LoadOrStoreHLS(context.Background(), "url1", s)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadOrStoreHLS: %v", err)
|
||||
}
|
||||
if got != s {
|
||||
t.Fatalf("LoadOrStoreHLS returned %v, want %v", got, s)
|
||||
}
|
||||
|
||||
// Lookup via SPBHID works (dual-index write).
|
||||
bySPBHID, err := lb.LoadHLSBySPBHID(context.Background(), "abc")
|
||||
if err != nil {
|
||||
t.Fatalf("LoadHLSBySPBHID: %v", err)
|
||||
}
|
||||
if bySPBHID != s {
|
||||
t.Fatalf("LoadHLSBySPBHID returned %v, want %v", bySPBHID, s)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemLB_LoadOrStoreHLS_Existing(t *testing.T) {
|
||||
lb := newMem()
|
||||
s1 := &stubHLS{spbhid: "first"}
|
||||
s2 := &stubHLS{spbhid: "second"}
|
||||
_, _ = lb.LoadOrStoreHLS(context.Background(), "url1", s1)
|
||||
got, err := lb.LoadOrStoreHLS(context.Background(), "url1", s2)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadOrStoreHLS: %v", err)
|
||||
}
|
||||
if got != s1 {
|
||||
t.Fatalf("expected existing s1, got %v", got)
|
||||
}
|
||||
// SPBHID 'second' (from the rejected s2) maps to the existing s1.
|
||||
bySPBHID, _ := lb.LoadHLSBySPBHID(context.Background(), "second")
|
||||
if bySPBHID != s1 {
|
||||
t.Fatalf("expected SPBHID 'second' to map to s1, got %v", bySPBHID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemLB_StoreWebRTC_And_Load(t *testing.T) {
|
||||
lb := newMem()
|
||||
s := &stubRTC{ufrag: "ufrg1"}
|
||||
if err := lb.StoreWebRTC(context.Background(), "url1", s); err != nil {
|
||||
t.Fatalf("StoreWebRTC: %v", err)
|
||||
}
|
||||
got, err := lb.LoadWebRTCByUfrag(context.Background(), "ufrg1")
|
||||
if err != nil {
|
||||
t.Fatalf("LoadWebRTCByUfrag: %v", err)
|
||||
}
|
||||
if got != s {
|
||||
t.Fatalf("got %v, want %v", got, s)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemLB_LoadWebRTCByUfrag_NotFound(t *testing.T) {
|
||||
lb := newMem()
|
||||
_, err := lb.LoadWebRTCByUfrag(context.Background(), "missing")
|
||||
if err == nil || !strings.Contains(err.Error(), "no WebRTC streaming") {
|
||||
t.Fatalf("expected error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
|
@ -11,26 +11,33 @@ import (
|
|||
"strconv"
|
||||
"time"
|
||||
|
||||
// Use v8 because we use Go 1.16+, while v9 requires Go 1.18+
|
||||
"github.com/go-redis/redis/v8"
|
||||
|
||||
"srsx/internal/env"
|
||||
"srsx/internal/errors"
|
||||
"srsx/internal/logger"
|
||||
"srsx/internal/redisclient"
|
||||
)
|
||||
|
||||
// redisLoadBalancer stores state in Redis.
|
||||
type redisLoadBalancer struct {
|
||||
// The environment interface.
|
||||
environment env.ProxyEnvironment
|
||||
// The redis client sdk.
|
||||
rdb *redis.Client
|
||||
// The redis client.
|
||||
rdb redisclient.RedisClient
|
||||
// newClient is the factory used by Initialize to build the Redis client.
|
||||
// A struct field (rather than a package global) so concurrent tests can
|
||||
// each supply their own without racing on shared state.
|
||||
newClient func(addr, password string, db int) redisclient.RedisClient
|
||||
// keepaliveInterval is the period at which the default-backend keep-alive
|
||||
// goroutine re-Updates its registration. Struct field for test injection.
|
||||
keepaliveInterval time.Duration
|
||||
}
|
||||
|
||||
// NewRedisLoadBalancer creates a new Redis-based load balancer.
|
||||
func NewRedisLoadBalancer(environment env.ProxyEnvironment) OriginLoadBalancer {
|
||||
return &redisLoadBalancer{
|
||||
environment: environment,
|
||||
environment: environment,
|
||||
newClient: redisclient.New,
|
||||
keepaliveInterval: 30 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -40,11 +47,11 @@ func (v *redisLoadBalancer) Initialize(ctx context.Context) error {
|
|||
return errors.Wrapf(err, "invalid PROXY_REDIS_DB %v", v.environment.RedisDB())
|
||||
}
|
||||
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: fmt.Sprintf("%v:%v", v.environment.RedisHost(), v.environment.RedisPort()),
|
||||
Password: v.environment.RedisPassword(),
|
||||
DB: redisDatabase,
|
||||
})
|
||||
rdb := v.newClient(
|
||||
fmt.Sprintf("%v:%v", v.environment.RedisHost(), v.environment.RedisPort()),
|
||||
v.environment.RedisPassword(),
|
||||
redisDatabase,
|
||||
)
|
||||
v.rdb = rdb
|
||||
|
||||
if err := rdb.Ping(ctx).Err(); err != nil {
|
||||
|
|
@ -68,7 +75,7 @@ func (v *redisLoadBalancer) Initialize(ctx context.Context) error {
|
|||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(30 * time.Second):
|
||||
case <-time.After(v.keepaliveInterval):
|
||||
if err := v.Update(ctx, server); err != nil {
|
||||
logger.Warn(ctx, "update default SRS %+v failed, %+v", server, err)
|
||||
}
|
||||
|
|
|
|||
659
internal/lb/redis_test.go
Normal file
659
internal/lb/redis_test.go
Normal file
|
|
@ -0,0 +1,659 @@
|
|||
// Copyright (c) 2026 Winlin
|
||||
//
|
||||
// SPDX-License-Identifier: MIT
|
||||
package lb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
|
||||
"srsx/internal/env/envfakes"
|
||||
"srsx/internal/redisclient"
|
||||
"srsx/internal/redisclient/redisclientfakes"
|
||||
)
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Helpers.
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// statusCmd returns a *redis.StatusCmd that resolves to the given error.
|
||||
func statusCmd(err error) *redis.StatusCmd {
|
||||
c := redis.NewStatusCmd(context.Background())
|
||||
if err != nil {
|
||||
c.SetErr(err)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// stringOK returns a *redis.StringCmd that resolves to the given bytes.
|
||||
func stringOK(b []byte) *redis.StringCmd {
|
||||
c := redis.NewStringCmd(context.Background())
|
||||
c.SetVal(string(b))
|
||||
return c
|
||||
}
|
||||
|
||||
// stringErr returns a *redis.StringCmd that resolves to the given error.
|
||||
func stringErr(err error) *redis.StringCmd {
|
||||
c := redis.NewStringCmd(context.Background())
|
||||
c.SetErr(err)
|
||||
return c
|
||||
}
|
||||
|
||||
// withFakeClient returns a fresh *redisLoadBalancer whose newClient factory is
|
||||
// wired to return the supplied fake. Each test gets its own instance, so
|
||||
// concurrent tests cannot race on shared state.
|
||||
func withFakeClient(env *envfakes.FakeProxyEnvironment, client redisclient.RedisClient) *redisLoadBalancer {
|
||||
lb := NewRedisLoadBalancer(env).(*redisLoadBalancer)
|
||||
lb.newClient = func(string, string, int) redisclient.RedisClient { return client }
|
||||
return lb
|
||||
}
|
||||
|
||||
// newRedisLB constructs a redisLoadBalancer with a fake rdb already wired in.
|
||||
// Used by tests that exercise methods other than Initialize.
|
||||
func newRedisLB(rdb redisclient.RedisClient) *redisLoadBalancer {
|
||||
env := &envfakes.FakeProxyEnvironment{}
|
||||
lb := NewRedisLoadBalancer(env).(*redisLoadBalancer)
|
||||
lb.rdb = rdb
|
||||
return lb
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Constructor & Initialize.
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func TestNewRedisLoadBalancer(t *testing.T) {
|
||||
env := &envfakes.FakeProxyEnvironment{}
|
||||
if lb := NewRedisLoadBalancer(env); lb == nil {
|
||||
t.Fatal("NewRedisLoadBalancer returned nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisLB_Initialize_BadRedisDB(t *testing.T) {
|
||||
env := &envfakes.FakeProxyEnvironment{}
|
||||
env.RedisDBReturns("not-a-number")
|
||||
err := NewRedisLoadBalancer(env).Initialize(context.Background())
|
||||
if err == nil || !strings.Contains(err.Error(), "invalid PROXY_REDIS_DB") {
|
||||
t.Fatalf("expected Atoi error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisLB_Initialize_PingFails(t *testing.T) {
|
||||
fake := &redisclientfakes.FakeRedisClient{}
|
||||
fake.PingReturns(statusCmd(fmt.Errorf("connection refused")))
|
||||
fake.StringReturns("Redis<fake>")
|
||||
|
||||
env := &envfakes.FakeProxyEnvironment{}
|
||||
env.RedisDBReturns("0")
|
||||
err := withFakeClient(env, fake).Initialize(context.Background())
|
||||
if err == nil || !strings.Contains(err.Error(), "unable to connect to redis") {
|
||||
t.Fatalf("expected ping error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisLB_Initialize_DefaultBackendDisabled(t *testing.T) {
|
||||
fake := &redisclientfakes.FakeRedisClient{}
|
||||
fake.PingReturns(statusCmd(nil))
|
||||
|
||||
env := &envfakes.FakeProxyEnvironment{}
|
||||
env.RedisDBReturns("0")
|
||||
// DefaultBackendEnabled defaults to "" (not "on") => no server registered.
|
||||
if err := withFakeClient(env, fake).Initialize(context.Background()); err != nil {
|
||||
t.Fatalf("Initialize: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisLB_Initialize_DefaultBackendError(t *testing.T) {
|
||||
fake := &redisclientfakes.FakeRedisClient{}
|
||||
fake.PingReturns(statusCmd(nil))
|
||||
|
||||
env := &envfakes.FakeProxyEnvironment{}
|
||||
env.RedisDBReturns("0")
|
||||
env.DefaultBackendEnabledReturns("on")
|
||||
env.DefaultBackendIPReturns("") // triggers NewDefaultOriginServerForDebugging error
|
||||
err := withFakeClient(env, fake).Initialize(context.Background())
|
||||
if err == nil || !strings.Contains(err.Error(), "initialize default SRS") {
|
||||
t.Fatalf("expected default-SRS error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisLB_Initialize_UpdateFails(t *testing.T) {
|
||||
fake := &redisclientfakes.FakeRedisClient{}
|
||||
fake.PingReturns(statusCmd(nil))
|
||||
fake.SetReturns(statusCmd(fmt.Errorf("set failed"))) // every Set fails
|
||||
|
||||
env := &envfakes.FakeProxyEnvironment{}
|
||||
env.RedisDBReturns("0")
|
||||
env.DefaultBackendEnabledReturns("on")
|
||||
env.DefaultBackendIPReturns("1.2.3.4")
|
||||
env.DefaultBackendRTMPReturns(":1935")
|
||||
err := withFakeClient(env, fake).Initialize(context.Background())
|
||||
if err == nil || !strings.Contains(err.Error(), "update default SRS") {
|
||||
t.Fatalf("expected update error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisLB_Initialize_Success(t *testing.T) {
|
||||
var setCalls atomic.Int32
|
||||
fake := &redisclientfakes.FakeRedisClient{}
|
||||
fake.PingReturns(statusCmd(nil))
|
||||
fake.SetStub = func(ctx context.Context, key string, value interface{}, ttl time.Duration) *redis.StatusCmd {
|
||||
setCalls.Add(1)
|
||||
return statusCmd(nil)
|
||||
}
|
||||
// Every Get returns redis.Nil-style error so the server list is treated as empty.
|
||||
fake.GetReturns(stringErr(fmt.Errorf("redis: nil")))
|
||||
|
||||
env := &envfakes.FakeProxyEnvironment{}
|
||||
env.RedisDBReturns("0")
|
||||
env.DefaultBackendEnabledReturns("on")
|
||||
env.DefaultBackendIPReturns("1.2.3.4")
|
||||
env.DefaultBackendRTMPReturns(":1935")
|
||||
|
||||
lb := withFakeClient(env, fake)
|
||||
lb.keepaliveInterval = time.Millisecond
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
if err := lb.Initialize(ctx); err != nil {
|
||||
t.Fatalf("Initialize: %v", err)
|
||||
}
|
||||
|
||||
// Initial Update made 2 Set calls (server + server list). Wait long enough
|
||||
// for the keep-alive tick to issue more.
|
||||
deadline := time.Now().Add(200 * time.Millisecond)
|
||||
for time.Now().Before(deadline) && setCalls.Load() < 4 {
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
cancel()
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
if setCalls.Load() < 4 {
|
||||
t.Fatalf("keep-alive did not tick: setCalls=%d", setCalls.Load())
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Update.
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func TestRedisLB_Update_SetServerFails(t *testing.T) {
|
||||
fake := &redisclientfakes.FakeRedisClient{}
|
||||
fake.SetReturns(statusCmd(fmt.Errorf("boom")))
|
||||
lb := newRedisLB(fake)
|
||||
err := lb.Update(context.Background(), &OriginServer{ServerID: "s", ServiceID: "v", PID: "1"})
|
||||
if err == nil || !strings.Contains(err.Error(), "set key=") {
|
||||
t.Fatalf("expected set-server error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisLB_Update_FreshList(t *testing.T) {
|
||||
// No existing server list => Get for server-list key returns error.
|
||||
fake := &redisclientfakes.FakeRedisClient{}
|
||||
fake.SetReturns(statusCmd(nil))
|
||||
fake.GetReturns(stringErr(fmt.Errorf("nil")))
|
||||
|
||||
lb := newRedisLB(fake)
|
||||
server := &OriginServer{ServerID: "s", ServiceID: "v", PID: "1"}
|
||||
if err := lb.Update(context.Background(), server); err != nil {
|
||||
t.Fatalf("Update: %v", err)
|
||||
}
|
||||
|
||||
// Two Set calls: server + servers-list.
|
||||
if got := fake.SetCallCount(); got != 2 {
|
||||
t.Fatalf("Set call count=%d, want 2", got)
|
||||
}
|
||||
// The second Set value should be a JSON array containing the server key.
|
||||
_, _, value, _ := fake.SetArgsForCall(1)
|
||||
var keys []string
|
||||
if err := json.Unmarshal(value.([]byte), &keys); err != nil {
|
||||
t.Fatalf("server-list value not JSON: %v", err)
|
||||
}
|
||||
want := lb.redisKeyServer(server.ID())
|
||||
if len(keys) != 1 || keys[0] != want {
|
||||
t.Fatalf("server-list keys=%v, want [%q]", keys, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisLB_Update_PrunesDeadAndAppends(t *testing.T) {
|
||||
server := &OriginServer{ServerID: "s", ServiceID: "v", PID: "1"}
|
||||
|
||||
fake := &redisclientfakes.FakeRedisClient{}
|
||||
fake.SetReturns(statusCmd(nil))
|
||||
|
||||
// First Get: server-list, returns ["dead", "alive"].
|
||||
// Subsequent Gets: probe each key — "dead" missing, "alive" present.
|
||||
fake.GetStub = func(ctx context.Context, key string) *redis.StringCmd {
|
||||
if strings.HasSuffix(key, "all-servers") {
|
||||
b, _ := json.Marshal([]string{"dead", "alive"})
|
||||
return stringOK(b)
|
||||
}
|
||||
if key == "alive" {
|
||||
return stringOK([]byte("ok"))
|
||||
}
|
||||
return stringErr(fmt.Errorf("nil"))
|
||||
}
|
||||
|
||||
lb := newRedisLB(fake)
|
||||
if err := lb.Update(context.Background(), server); err != nil {
|
||||
t.Fatalf("Update: %v", err)
|
||||
}
|
||||
|
||||
// Inspect the server-list Set call: should contain "alive" (kept) and the
|
||||
// new server key (appended); "dead" should be pruned.
|
||||
_, _, value, _ := fake.SetArgsForCall(1)
|
||||
var keys []string
|
||||
if err := json.Unmarshal(value.([]byte), &keys); err != nil {
|
||||
t.Fatalf("not JSON: %v", err)
|
||||
}
|
||||
wantNew := lb.redisKeyServer(server.ID())
|
||||
if len(keys) != 2 || keys[0] != "alive" || keys[1] != wantNew {
|
||||
t.Fatalf("server-list keys=%v, want [alive, %q]", keys, wantNew)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisLB_Update_AlreadyInList(t *testing.T) {
|
||||
server := &OriginServer{ServerID: "s", ServiceID: "v", PID: "1"}
|
||||
|
||||
fake := &redisclientfakes.FakeRedisClient{}
|
||||
fake.SetReturns(statusCmd(nil))
|
||||
lb := newRedisLB(fake)
|
||||
wantKey := lb.redisKeyServer(server.ID())
|
||||
|
||||
fake.GetStub = func(ctx context.Context, key string) *redis.StringCmd {
|
||||
if strings.HasSuffix(key, "all-servers") {
|
||||
b, _ := json.Marshal([]string{wantKey})
|
||||
return stringOK(b)
|
||||
}
|
||||
return stringOK([]byte("ok"))
|
||||
}
|
||||
|
||||
if err := lb.Update(context.Background(), server); err != nil {
|
||||
t.Fatalf("Update: %v", err)
|
||||
}
|
||||
_, _, value, _ := fake.SetArgsForCall(1)
|
||||
var keys []string
|
||||
_ = json.Unmarshal(value.([]byte), &keys)
|
||||
if len(keys) != 1 || keys[0] != wantKey {
|
||||
t.Fatalf("expected no duplication, got %v", keys)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisLB_Update_BadServerListJSON(t *testing.T) {
|
||||
fake := &redisclientfakes.FakeRedisClient{}
|
||||
fake.SetReturns(statusCmd(nil))
|
||||
fake.GetStub = func(ctx context.Context, key string) *redis.StringCmd {
|
||||
if strings.HasSuffix(key, "all-servers") {
|
||||
return stringOK([]byte("not-json"))
|
||||
}
|
||||
return stringErr(fmt.Errorf("nil"))
|
||||
}
|
||||
lb := newRedisLB(fake)
|
||||
err := lb.Update(context.Background(), &OriginServer{ServerID: "s", ServiceID: "v", PID: "1"})
|
||||
if err == nil || !strings.Contains(err.Error(), "unmarshal") {
|
||||
t.Fatalf("expected unmarshal error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisLB_Update_SetServerListFails(t *testing.T) {
|
||||
fake := &redisclientfakes.FakeRedisClient{}
|
||||
// First Set ok (server), second Set fails (server list).
|
||||
fake.SetReturnsOnCall(0, statusCmd(nil))
|
||||
fake.SetReturnsOnCall(1, statusCmd(fmt.Errorf("set list failed")))
|
||||
fake.GetReturns(stringErr(fmt.Errorf("nil")))
|
||||
|
||||
lb := newRedisLB(fake)
|
||||
err := lb.Update(context.Background(), &OriginServer{ServerID: "s", ServiceID: "v", PID: "1"})
|
||||
if err == nil || !strings.Contains(err.Error(), "set list failed") {
|
||||
t.Fatalf("expected server-list set error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Pick.
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func TestRedisLB_Pick_StickyHit(t *testing.T) {
|
||||
server := &OriginServer{ServerID: "a", ServiceID: "b", PID: "1"}
|
||||
serverJSON, _ := json.Marshal(server)
|
||||
|
||||
fake := &redisclientfakes.FakeRedisClient{}
|
||||
fake.SetReturns(statusCmd(nil))
|
||||
lb := newRedisLB(fake)
|
||||
streamKey := "srs-proxy-url:url1"
|
||||
serverKey := lb.redisKeyServer(server.ID())
|
||||
|
||||
fake.GetStub = func(ctx context.Context, key string) *redis.StringCmd {
|
||||
switch key {
|
||||
case streamKey:
|
||||
return stringOK([]byte(serverKey))
|
||||
case serverKey:
|
||||
return stringOK(serverJSON)
|
||||
}
|
||||
return stringErr(fmt.Errorf("nil"))
|
||||
}
|
||||
|
||||
got, err := lb.Pick(context.Background(), "url1")
|
||||
if err != nil {
|
||||
t.Fatalf("Pick: %v", err)
|
||||
}
|
||||
if got.ID() != server.ID() {
|
||||
t.Fatalf("Pick returned %v, want %v", got, server)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisLB_Pick_StickyBadJSON(t *testing.T) {
|
||||
fake := &redisclientfakes.FakeRedisClient{}
|
||||
lb := newRedisLB(fake)
|
||||
streamKey := "srs-proxy-url:url1"
|
||||
|
||||
fake.GetStub = func(ctx context.Context, key string) *redis.StringCmd {
|
||||
switch key {
|
||||
case streamKey:
|
||||
return stringOK([]byte("srv-key"))
|
||||
case "srv-key":
|
||||
return stringOK([]byte("not-json"))
|
||||
}
|
||||
return stringErr(fmt.Errorf("nil"))
|
||||
}
|
||||
|
||||
_, err := lb.Pick(context.Background(), "url1")
|
||||
if err == nil || !strings.Contains(err.Error(), "unmarshal") {
|
||||
t.Fatalf("expected unmarshal error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisLB_Pick_NoServersAvailable(t *testing.T) {
|
||||
fake := &redisclientfakes.FakeRedisClient{}
|
||||
// Sticky miss + server list missing.
|
||||
fake.GetReturns(stringErr(fmt.Errorf("nil")))
|
||||
lb := newRedisLB(fake)
|
||||
_, err := lb.Pick(context.Background(), "url1")
|
||||
if err == nil || !strings.Contains(err.Error(), "no server available") {
|
||||
t.Fatalf("expected no-server error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisLB_Pick_BadServerListJSON(t *testing.T) {
|
||||
fake := &redisclientfakes.FakeRedisClient{}
|
||||
fake.GetStub = func(ctx context.Context, key string) *redis.StringCmd {
|
||||
if strings.HasSuffix(key, "all-servers") {
|
||||
return stringOK([]byte("not-json"))
|
||||
}
|
||||
return stringErr(fmt.Errorf("nil"))
|
||||
}
|
||||
lb := newRedisLB(fake)
|
||||
_, err := lb.Pick(context.Background(), "url1")
|
||||
if err == nil || !strings.Contains(err.Error(), "unmarshal") {
|
||||
t.Fatalf("expected unmarshal error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisLB_Pick_AllProbesFail(t *testing.T) {
|
||||
// Server list contains one key, but probing it returns nil bytes (the
|
||||
// `len(b) > 0` guard rejects it). After 3 attempts, Pick errors out.
|
||||
fake := &redisclientfakes.FakeRedisClient{}
|
||||
fake.GetStub = func(ctx context.Context, key string) *redis.StringCmd {
|
||||
if strings.HasSuffix(key, "all-servers") {
|
||||
b, _ := json.Marshal([]string{"srv-key"})
|
||||
return stringOK(b)
|
||||
}
|
||||
// "srv-key" probe returns empty bytes — falls through the available check.
|
||||
if key == "srv-key" {
|
||||
return stringOK(nil)
|
||||
}
|
||||
return stringErr(fmt.Errorf("nil"))
|
||||
}
|
||||
lb := newRedisLB(fake)
|
||||
_, err := lb.Pick(context.Background(), "url1")
|
||||
if err == nil || !strings.Contains(err.Error(), "no server available in") {
|
||||
t.Fatalf("expected exhausted-probes error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisLB_Pick_ScanSuccess(t *testing.T) {
|
||||
server := &OriginServer{ServerID: "a", ServiceID: "b", PID: "1"}
|
||||
serverJSON, _ := json.Marshal(server)
|
||||
|
||||
fake := &redisclientfakes.FakeRedisClient{}
|
||||
fake.SetReturns(statusCmd(nil))
|
||||
lb := newRedisLB(fake)
|
||||
serverKey := lb.redisKeyServer(server.ID())
|
||||
|
||||
fake.GetStub = func(ctx context.Context, key string) *redis.StringCmd {
|
||||
if strings.HasSuffix(key, "all-servers") {
|
||||
b, _ := json.Marshal([]string{serverKey})
|
||||
return stringOK(b)
|
||||
}
|
||||
if key == serverKey {
|
||||
return stringOK(serverJSON)
|
||||
}
|
||||
// Sticky lookup for the URL key misses.
|
||||
return stringErr(fmt.Errorf("nil"))
|
||||
}
|
||||
|
||||
got, err := lb.Pick(context.Background(), "url1")
|
||||
if err != nil {
|
||||
t.Fatalf("Pick: %v", err)
|
||||
}
|
||||
if got.ID() != server.ID() {
|
||||
t.Fatalf("Pick returned %v", got)
|
||||
}
|
||||
// Pick should also store the picked-mapping.
|
||||
if fake.SetCallCount() != 1 {
|
||||
t.Fatalf("expected 1 Set call to store picked mapping, got %d", fake.SetCallCount())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisLB_Pick_ScanBadJSON(t *testing.T) {
|
||||
fake := &redisclientfakes.FakeRedisClient{}
|
||||
fake.GetStub = func(ctx context.Context, key string) *redis.StringCmd {
|
||||
if strings.HasSuffix(key, "all-servers") {
|
||||
b, _ := json.Marshal([]string{"srv-key"})
|
||||
return stringOK(b)
|
||||
}
|
||||
if key == "srv-key" {
|
||||
return stringOK([]byte("not-json"))
|
||||
}
|
||||
return stringErr(fmt.Errorf("nil"))
|
||||
}
|
||||
lb := newRedisLB(fake)
|
||||
_, err := lb.Pick(context.Background(), "url1")
|
||||
if err == nil || !strings.Contains(err.Error(), "unmarshal") {
|
||||
t.Fatalf("expected unmarshal error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisLB_Pick_StoreMappingFails(t *testing.T) {
|
||||
server := &OriginServer{ServerID: "a", ServiceID: "b", PID: "1"}
|
||||
serverJSON, _ := json.Marshal(server)
|
||||
|
||||
fake := &redisclientfakes.FakeRedisClient{}
|
||||
fake.SetReturns(statusCmd(fmt.Errorf("set failed")))
|
||||
lb := newRedisLB(fake)
|
||||
serverKey := lb.redisKeyServer(server.ID())
|
||||
|
||||
fake.GetStub = func(ctx context.Context, key string) *redis.StringCmd {
|
||||
if strings.HasSuffix(key, "all-servers") {
|
||||
b, _ := json.Marshal([]string{serverKey})
|
||||
return stringOK(b)
|
||||
}
|
||||
if key == serverKey {
|
||||
return stringOK(serverJSON)
|
||||
}
|
||||
return stringErr(fmt.Errorf("nil"))
|
||||
}
|
||||
|
||||
_, err := lb.Pick(context.Background(), "url1")
|
||||
if err == nil || !strings.Contains(err.Error(), "set failed") {
|
||||
t.Fatalf("expected set-mapping error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// LoadHLSBySPBHID and LoadWebRTCByUfrag — symmetric behavior.
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func TestRedisLB_LoadHLSBySPBHID_GetFails(t *testing.T) {
|
||||
fake := &redisclientfakes.FakeRedisClient{}
|
||||
fake.GetReturns(stringErr(fmt.Errorf("nil")))
|
||||
lb := newRedisLB(fake)
|
||||
_, err := lb.LoadHLSBySPBHID(context.Background(), "abc")
|
||||
if err == nil || !strings.Contains(err.Error(), "get key=") {
|
||||
t.Fatalf("expected get error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisLB_LoadHLSBySPBHID_BadJSON(t *testing.T) {
|
||||
fake := &redisclientfakes.FakeRedisClient{}
|
||||
fake.GetReturns(stringOK([]byte("not-json")))
|
||||
lb := newRedisLB(fake)
|
||||
_, err := lb.LoadHLSBySPBHID(context.Background(), "abc")
|
||||
if err == nil || !strings.Contains(err.Error(), "unmarshal") {
|
||||
t.Fatalf("expected unmarshal error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisLB_LoadHLSBySPBHID_InterfaceLimitation(t *testing.T) {
|
||||
fake := &redisclientfakes.FakeRedisClient{}
|
||||
fake.GetReturns(stringOK([]byte(`{"foo":"bar"}`)))
|
||||
lb := newRedisLB(fake)
|
||||
_, err := lb.LoadHLSBySPBHID(context.Background(), "abc")
|
||||
if err == nil || !strings.Contains(err.Error(), "cannot deserialize") {
|
||||
t.Fatalf("expected interface limitation error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisLB_LoadWebRTCByUfrag_GetFails(t *testing.T) {
|
||||
fake := &redisclientfakes.FakeRedisClient{}
|
||||
fake.GetReturns(stringErr(fmt.Errorf("nil")))
|
||||
lb := newRedisLB(fake)
|
||||
_, err := lb.LoadWebRTCByUfrag(context.Background(), "u")
|
||||
if err == nil || !strings.Contains(err.Error(), "get key=") {
|
||||
t.Fatalf("expected get error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisLB_LoadWebRTCByUfrag_BadJSON(t *testing.T) {
|
||||
fake := &redisclientfakes.FakeRedisClient{}
|
||||
fake.GetReturns(stringOK([]byte("not-json")))
|
||||
lb := newRedisLB(fake)
|
||||
_, err := lb.LoadWebRTCByUfrag(context.Background(), "u")
|
||||
if err == nil || !strings.Contains(err.Error(), "unmarshal") {
|
||||
t.Fatalf("expected unmarshal error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisLB_LoadWebRTCByUfrag_InterfaceLimitation(t *testing.T) {
|
||||
fake := &redisclientfakes.FakeRedisClient{}
|
||||
fake.GetReturns(stringOK([]byte(`{"foo":"bar"}`)))
|
||||
lb := newRedisLB(fake)
|
||||
_, err := lb.LoadWebRTCByUfrag(context.Background(), "u")
|
||||
if err == nil || !strings.Contains(err.Error(), "cannot deserialize") {
|
||||
t.Fatalf("expected interface limitation error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// LoadOrStoreHLS and StoreWebRTC.
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func TestRedisLB_LoadOrStoreHLS_Success(t *testing.T) {
|
||||
fake := &redisclientfakes.FakeRedisClient{}
|
||||
fake.SetReturns(statusCmd(nil))
|
||||
lb := newRedisLB(fake)
|
||||
|
||||
hls := &stubHLS{spbhid: "abc"}
|
||||
got, err := lb.LoadOrStoreHLS(context.Background(), "url1", hls)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadOrStoreHLS: %v", err)
|
||||
}
|
||||
if got != hls {
|
||||
t.Fatalf("got %v, want input back", got)
|
||||
}
|
||||
if fake.SetCallCount() != 2 {
|
||||
t.Fatalf("expected 2 Set calls (URL + SPBHID), got %d", fake.SetCallCount())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisLB_LoadOrStoreHLS_FirstSetFails(t *testing.T) {
|
||||
fake := &redisclientfakes.FakeRedisClient{}
|
||||
fake.SetReturns(statusCmd(fmt.Errorf("boom")))
|
||||
lb := newRedisLB(fake)
|
||||
_, err := lb.LoadOrStoreHLS(context.Background(), "url1", &stubHLS{spbhid: "abc"})
|
||||
if err == nil || !strings.Contains(err.Error(), "boom") {
|
||||
t.Fatalf("expected error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisLB_LoadOrStoreHLS_SecondSetFails(t *testing.T) {
|
||||
fake := &redisclientfakes.FakeRedisClient{}
|
||||
fake.SetReturnsOnCall(0, statusCmd(nil))
|
||||
fake.SetReturnsOnCall(1, statusCmd(fmt.Errorf("second boom")))
|
||||
lb := newRedisLB(fake)
|
||||
_, err := lb.LoadOrStoreHLS(context.Background(), "url1", &stubHLS{spbhid: "abc"})
|
||||
if err == nil || !strings.Contains(err.Error(), "second boom") {
|
||||
t.Fatalf("expected error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisLB_StoreWebRTC_Success(t *testing.T) {
|
||||
fake := &redisclientfakes.FakeRedisClient{}
|
||||
fake.SetReturns(statusCmd(nil))
|
||||
lb := newRedisLB(fake)
|
||||
if err := lb.StoreWebRTC(context.Background(), "url1", &stubRTC{ufrag: "u"}); err != nil {
|
||||
t.Fatalf("StoreWebRTC: %v", err)
|
||||
}
|
||||
if fake.SetCallCount() != 2 {
|
||||
t.Fatalf("expected 2 Set calls (URL + Ufrag), got %d", fake.SetCallCount())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisLB_StoreWebRTC_FirstSetFails(t *testing.T) {
|
||||
fake := &redisclientfakes.FakeRedisClient{}
|
||||
fake.SetReturns(statusCmd(fmt.Errorf("boom")))
|
||||
lb := newRedisLB(fake)
|
||||
err := lb.StoreWebRTC(context.Background(), "url1", &stubRTC{ufrag: "u"})
|
||||
if err == nil || !strings.Contains(err.Error(), "boom") {
|
||||
t.Fatalf("expected error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisLB_StoreWebRTC_SecondSetFails(t *testing.T) {
|
||||
fake := &redisclientfakes.FakeRedisClient{}
|
||||
fake.SetReturnsOnCall(0, statusCmd(nil))
|
||||
fake.SetReturnsOnCall(1, statusCmd(fmt.Errorf("second boom")))
|
||||
lb := newRedisLB(fake)
|
||||
err := lb.StoreWebRTC(context.Background(), "url1", &stubRTC{ufrag: "u"})
|
||||
if err == nil || !strings.Contains(err.Error(), "second boom") {
|
||||
t.Fatalf("expected error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Key helpers.
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func TestRedisLB_KeyHelpers(t *testing.T) {
|
||||
lb := &redisLoadBalancer{}
|
||||
for _, tt := range []struct {
|
||||
got, want string
|
||||
}{
|
||||
{lb.redisKeyUfrag("u"), "srs-proxy-ufrag:u"},
|
||||
{lb.redisKeyRTC("url"), "srs-proxy-rtc:url"},
|
||||
{lb.redisKeySPBHID("s"), "srs-proxy-spbhid:s"},
|
||||
{lb.redisKeyHLS("url"), "srs-proxy-hls:url"},
|
||||
{lb.redisKeyServer("id"), "srs-proxy-server:id"},
|
||||
{lb.redisKeyServers(), "srs-proxy-all-servers"},
|
||||
} {
|
||||
if tt.got != tt.want {
|
||||
t.Errorf("got %q, want %q", tt.got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
6
internal/redisclient/gen.go
Normal file
6
internal/redisclient/gen.go
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
// Copyright (c) 2026 Winlin
|
||||
//
|
||||
// SPDX-License-Identifier: MIT
|
||||
package redisclient
|
||||
|
||||
//go:generate go tool counterfeiter -o redisclientfakes/fake_redis_client.go . RedisClient
|
||||
33
internal/redisclient/redisclient.go
Normal file
33
internal/redisclient/redisclient.go
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
// Copyright (c) 2026 Winlin
|
||||
//
|
||||
// SPDX-License-Identifier: MIT
|
||||
package redisclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
// Use v8 because we use Go 1.16+, while v9 requires Go 1.18+
|
||||
"github.com/go-redis/redis/v8"
|
||||
)
|
||||
|
||||
// RedisClient is the subset of *redis.Client methods used by callers in this
|
||||
// codebase. Declared as an interface so tests can substitute a fake without
|
||||
// standing up a real Redis server. *redis.Client satisfies this interface
|
||||
// directly.
|
||||
type RedisClient interface {
|
||||
Ping(ctx context.Context) *redis.StatusCmd
|
||||
Get(ctx context.Context, key string) *redis.StringCmd
|
||||
Set(ctx context.Context, key string, value interface{}, expiration time.Duration) *redis.StatusCmd
|
||||
String() string
|
||||
}
|
||||
|
||||
// New connects to a Redis server at addr (host:port) with the given password
|
||||
// and database index. Returns a RedisClient satisfied by *redis.Client.
|
||||
func New(addr, password string, db int) RedisClient {
|
||||
return redis.NewClient(&redis.Options{
|
||||
Addr: addr,
|
||||
Password: password,
|
||||
DB: db,
|
||||
})
|
||||
}
|
||||
327
internal/redisclient/redisclientfakes/fake_redis_client.go
Normal file
327
internal/redisclient/redisclientfakes/fake_redis_client.go
Normal file
|
|
@ -0,0 +1,327 @@
|
|||
// Code generated by counterfeiter. DO NOT EDIT.
|
||||
package redisclientfakes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"srsx/internal/redisclient"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
redis "github.com/go-redis/redis/v8"
|
||||
)
|
||||
|
||||
type FakeRedisClient struct {
|
||||
GetStub func(context.Context, string) *redis.StringCmd
|
||||
getMutex sync.RWMutex
|
||||
getArgsForCall []struct {
|
||||
arg1 context.Context
|
||||
arg2 string
|
||||
}
|
||||
getReturns struct {
|
||||
result1 *redis.StringCmd
|
||||
}
|
||||
getReturnsOnCall map[int]struct {
|
||||
result1 *redis.StringCmd
|
||||
}
|
||||
PingStub func(context.Context) *redis.StatusCmd
|
||||
pingMutex sync.RWMutex
|
||||
pingArgsForCall []struct {
|
||||
arg1 context.Context
|
||||
}
|
||||
pingReturns struct {
|
||||
result1 *redis.StatusCmd
|
||||
}
|
||||
pingReturnsOnCall map[int]struct {
|
||||
result1 *redis.StatusCmd
|
||||
}
|
||||
SetStub func(context.Context, string, interface{}, time.Duration) *redis.StatusCmd
|
||||
setMutex sync.RWMutex
|
||||
setArgsForCall []struct {
|
||||
arg1 context.Context
|
||||
arg2 string
|
||||
arg3 interface{}
|
||||
arg4 time.Duration
|
||||
}
|
||||
setReturns struct {
|
||||
result1 *redis.StatusCmd
|
||||
}
|
||||
setReturnsOnCall map[int]struct {
|
||||
result1 *redis.StatusCmd
|
||||
}
|
||||
StringStub func() string
|
||||
stringMutex sync.RWMutex
|
||||
stringArgsForCall []struct {
|
||||
}
|
||||
stringReturns struct {
|
||||
result1 string
|
||||
}
|
||||
stringReturnsOnCall map[int]struct {
|
||||
result1 string
|
||||
}
|
||||
invocations map[string][][]interface{}
|
||||
invocationsMutex sync.RWMutex
|
||||
}
|
||||
|
||||
func (fake *FakeRedisClient) Get(arg1 context.Context, arg2 string) *redis.StringCmd {
|
||||
fake.getMutex.Lock()
|
||||
ret, specificReturn := fake.getReturnsOnCall[len(fake.getArgsForCall)]
|
||||
fake.getArgsForCall = append(fake.getArgsForCall, struct {
|
||||
arg1 context.Context
|
||||
arg2 string
|
||||
}{arg1, arg2})
|
||||
stub := fake.GetStub
|
||||
fakeReturns := fake.getReturns
|
||||
fake.recordInvocation("Get", []interface{}{arg1, arg2})
|
||||
fake.getMutex.Unlock()
|
||||
if stub != nil {
|
||||
return stub(arg1, arg2)
|
||||
}
|
||||
if specificReturn {
|
||||
return ret.result1
|
||||
}
|
||||
return fakeReturns.result1
|
||||
}
|
||||
|
||||
func (fake *FakeRedisClient) GetCallCount() int {
|
||||
fake.getMutex.RLock()
|
||||
defer fake.getMutex.RUnlock()
|
||||
return len(fake.getArgsForCall)
|
||||
}
|
||||
|
||||
func (fake *FakeRedisClient) GetCalls(stub func(context.Context, string) *redis.StringCmd) {
|
||||
fake.getMutex.Lock()
|
||||
defer fake.getMutex.Unlock()
|
||||
fake.GetStub = stub
|
||||
}
|
||||
|
||||
func (fake *FakeRedisClient) GetArgsForCall(i int) (context.Context, string) {
|
||||
fake.getMutex.RLock()
|
||||
defer fake.getMutex.RUnlock()
|
||||
argsForCall := fake.getArgsForCall[i]
|
||||
return argsForCall.arg1, argsForCall.arg2
|
||||
}
|
||||
|
||||
func (fake *FakeRedisClient) GetReturns(result1 *redis.StringCmd) {
|
||||
fake.getMutex.Lock()
|
||||
defer fake.getMutex.Unlock()
|
||||
fake.GetStub = nil
|
||||
fake.getReturns = struct {
|
||||
result1 *redis.StringCmd
|
||||
}{result1}
|
||||
}
|
||||
|
||||
func (fake *FakeRedisClient) GetReturnsOnCall(i int, result1 *redis.StringCmd) {
|
||||
fake.getMutex.Lock()
|
||||
defer fake.getMutex.Unlock()
|
||||
fake.GetStub = nil
|
||||
if fake.getReturnsOnCall == nil {
|
||||
fake.getReturnsOnCall = make(map[int]struct {
|
||||
result1 *redis.StringCmd
|
||||
})
|
||||
}
|
||||
fake.getReturnsOnCall[i] = struct {
|
||||
result1 *redis.StringCmd
|
||||
}{result1}
|
||||
}
|
||||
|
||||
func (fake *FakeRedisClient) Ping(arg1 context.Context) *redis.StatusCmd {
|
||||
fake.pingMutex.Lock()
|
||||
ret, specificReturn := fake.pingReturnsOnCall[len(fake.pingArgsForCall)]
|
||||
fake.pingArgsForCall = append(fake.pingArgsForCall, struct {
|
||||
arg1 context.Context
|
||||
}{arg1})
|
||||
stub := fake.PingStub
|
||||
fakeReturns := fake.pingReturns
|
||||
fake.recordInvocation("Ping", []interface{}{arg1})
|
||||
fake.pingMutex.Unlock()
|
||||
if stub != nil {
|
||||
return stub(arg1)
|
||||
}
|
||||
if specificReturn {
|
||||
return ret.result1
|
||||
}
|
||||
return fakeReturns.result1
|
||||
}
|
||||
|
||||
func (fake *FakeRedisClient) PingCallCount() int {
|
||||
fake.pingMutex.RLock()
|
||||
defer fake.pingMutex.RUnlock()
|
||||
return len(fake.pingArgsForCall)
|
||||
}
|
||||
|
||||
func (fake *FakeRedisClient) PingCalls(stub func(context.Context) *redis.StatusCmd) {
|
||||
fake.pingMutex.Lock()
|
||||
defer fake.pingMutex.Unlock()
|
||||
fake.PingStub = stub
|
||||
}
|
||||
|
||||
func (fake *FakeRedisClient) PingArgsForCall(i int) context.Context {
|
||||
fake.pingMutex.RLock()
|
||||
defer fake.pingMutex.RUnlock()
|
||||
argsForCall := fake.pingArgsForCall[i]
|
||||
return argsForCall.arg1
|
||||
}
|
||||
|
||||
func (fake *FakeRedisClient) PingReturns(result1 *redis.StatusCmd) {
|
||||
fake.pingMutex.Lock()
|
||||
defer fake.pingMutex.Unlock()
|
||||
fake.PingStub = nil
|
||||
fake.pingReturns = struct {
|
||||
result1 *redis.StatusCmd
|
||||
}{result1}
|
||||
}
|
||||
|
||||
func (fake *FakeRedisClient) PingReturnsOnCall(i int, result1 *redis.StatusCmd) {
|
||||
fake.pingMutex.Lock()
|
||||
defer fake.pingMutex.Unlock()
|
||||
fake.PingStub = nil
|
||||
if fake.pingReturnsOnCall == nil {
|
||||
fake.pingReturnsOnCall = make(map[int]struct {
|
||||
result1 *redis.StatusCmd
|
||||
})
|
||||
}
|
||||
fake.pingReturnsOnCall[i] = struct {
|
||||
result1 *redis.StatusCmd
|
||||
}{result1}
|
||||
}
|
||||
|
||||
func (fake *FakeRedisClient) Set(arg1 context.Context, arg2 string, arg3 interface{}, arg4 time.Duration) *redis.StatusCmd {
|
||||
fake.setMutex.Lock()
|
||||
ret, specificReturn := fake.setReturnsOnCall[len(fake.setArgsForCall)]
|
||||
fake.setArgsForCall = append(fake.setArgsForCall, struct {
|
||||
arg1 context.Context
|
||||
arg2 string
|
||||
arg3 interface{}
|
||||
arg4 time.Duration
|
||||
}{arg1, arg2, arg3, arg4})
|
||||
stub := fake.SetStub
|
||||
fakeReturns := fake.setReturns
|
||||
fake.recordInvocation("Set", []interface{}{arg1, arg2, arg3, arg4})
|
||||
fake.setMutex.Unlock()
|
||||
if stub != nil {
|
||||
return stub(arg1, arg2, arg3, arg4)
|
||||
}
|
||||
if specificReturn {
|
||||
return ret.result1
|
||||
}
|
||||
return fakeReturns.result1
|
||||
}
|
||||
|
||||
func (fake *FakeRedisClient) SetCallCount() int {
|
||||
fake.setMutex.RLock()
|
||||
defer fake.setMutex.RUnlock()
|
||||
return len(fake.setArgsForCall)
|
||||
}
|
||||
|
||||
func (fake *FakeRedisClient) SetCalls(stub func(context.Context, string, interface{}, time.Duration) *redis.StatusCmd) {
|
||||
fake.setMutex.Lock()
|
||||
defer fake.setMutex.Unlock()
|
||||
fake.SetStub = stub
|
||||
}
|
||||
|
||||
func (fake *FakeRedisClient) SetArgsForCall(i int) (context.Context, string, interface{}, time.Duration) {
|
||||
fake.setMutex.RLock()
|
||||
defer fake.setMutex.RUnlock()
|
||||
argsForCall := fake.setArgsForCall[i]
|
||||
return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4
|
||||
}
|
||||
|
||||
func (fake *FakeRedisClient) SetReturns(result1 *redis.StatusCmd) {
|
||||
fake.setMutex.Lock()
|
||||
defer fake.setMutex.Unlock()
|
||||
fake.SetStub = nil
|
||||
fake.setReturns = struct {
|
||||
result1 *redis.StatusCmd
|
||||
}{result1}
|
||||
}
|
||||
|
||||
func (fake *FakeRedisClient) SetReturnsOnCall(i int, result1 *redis.StatusCmd) {
|
||||
fake.setMutex.Lock()
|
||||
defer fake.setMutex.Unlock()
|
||||
fake.SetStub = nil
|
||||
if fake.setReturnsOnCall == nil {
|
||||
fake.setReturnsOnCall = make(map[int]struct {
|
||||
result1 *redis.StatusCmd
|
||||
})
|
||||
}
|
||||
fake.setReturnsOnCall[i] = struct {
|
||||
result1 *redis.StatusCmd
|
||||
}{result1}
|
||||
}
|
||||
|
||||
func (fake *FakeRedisClient) String() string {
|
||||
fake.stringMutex.Lock()
|
||||
ret, specificReturn := fake.stringReturnsOnCall[len(fake.stringArgsForCall)]
|
||||
fake.stringArgsForCall = append(fake.stringArgsForCall, struct {
|
||||
}{})
|
||||
stub := fake.StringStub
|
||||
fakeReturns := fake.stringReturns
|
||||
fake.recordInvocation("String", []interface{}{})
|
||||
fake.stringMutex.Unlock()
|
||||
if stub != nil {
|
||||
return stub()
|
||||
}
|
||||
if specificReturn {
|
||||
return ret.result1
|
||||
}
|
||||
return fakeReturns.result1
|
||||
}
|
||||
|
||||
func (fake *FakeRedisClient) StringCallCount() int {
|
||||
fake.stringMutex.RLock()
|
||||
defer fake.stringMutex.RUnlock()
|
||||
return len(fake.stringArgsForCall)
|
||||
}
|
||||
|
||||
func (fake *FakeRedisClient) StringCalls(stub func() string) {
|
||||
fake.stringMutex.Lock()
|
||||
defer fake.stringMutex.Unlock()
|
||||
fake.StringStub = stub
|
||||
}
|
||||
|
||||
func (fake *FakeRedisClient) StringReturns(result1 string) {
|
||||
fake.stringMutex.Lock()
|
||||
defer fake.stringMutex.Unlock()
|
||||
fake.StringStub = nil
|
||||
fake.stringReturns = struct {
|
||||
result1 string
|
||||
}{result1}
|
||||
}
|
||||
|
||||
func (fake *FakeRedisClient) StringReturnsOnCall(i int, result1 string) {
|
||||
fake.stringMutex.Lock()
|
||||
defer fake.stringMutex.Unlock()
|
||||
fake.StringStub = nil
|
||||
if fake.stringReturnsOnCall == nil {
|
||||
fake.stringReturnsOnCall = make(map[int]struct {
|
||||
result1 string
|
||||
})
|
||||
}
|
||||
fake.stringReturnsOnCall[i] = struct {
|
||||
result1 string
|
||||
}{result1}
|
||||
}
|
||||
|
||||
func (fake *FakeRedisClient) Invocations() map[string][][]interface{} {
|
||||
fake.invocationsMutex.RLock()
|
||||
defer fake.invocationsMutex.RUnlock()
|
||||
copiedInvocations := map[string][][]interface{}{}
|
||||
for key, value := range fake.invocations {
|
||||
copiedInvocations[key] = value
|
||||
}
|
||||
return copiedInvocations
|
||||
}
|
||||
|
||||
func (fake *FakeRedisClient) recordInvocation(key string, args []interface{}) {
|
||||
fake.invocationsMutex.Lock()
|
||||
defer fake.invocationsMutex.Unlock()
|
||||
if fake.invocations == nil {
|
||||
fake.invocations = map[string][][]interface{}{}
|
||||
}
|
||||
if fake.invocations[key] == nil {
|
||||
fake.invocations[key] = [][]interface{}{}
|
||||
}
|
||||
fake.invocations[key] = append(fake.invocations[key], args)
|
||||
}
|
||||
|
||||
var _ redisclient.RedisClient = new(FakeRedisClient)
|
||||
|
|
@ -90,7 +90,9 @@ type amf0Buffer interface {
|
|||
Write(p []byte) (n int, err error)
|
||||
}
|
||||
|
||||
var createBuffer = func() amf0Buffer {
|
||||
// defaultBufFactory is the production amf0Buffer factory. Tests override the
|
||||
// per-instance bufFactory field on amf0ObjectBase instead of swapping a global.
|
||||
func defaultBufFactory() amf0Buffer {
|
||||
return &bytes.Buffer{}
|
||||
}
|
||||
|
||||
|
|
@ -399,6 +401,10 @@ type amf0Property struct {
|
|||
type amf0ObjectBase struct {
|
||||
properties []*amf0Property
|
||||
lock sync.Mutex
|
||||
// bufFactory creates the amf0Buffer used by MarshalBinary. Held as a
|
||||
// per-instance field (not a package global) so concurrent tests can each
|
||||
// install their own buggy buffers without racing on shared state.
|
||||
bufFactory func() amf0Buffer
|
||||
}
|
||||
|
||||
func (v *amf0ObjectBase) Size() int {
|
||||
|
|
@ -562,6 +568,7 @@ func NewAmf0Object() Amf0Object {
|
|||
func newAmf0Object() *amf0Object {
|
||||
v := &amf0Object{}
|
||||
v.properties = []*amf0Property{}
|
||||
v.bufFactory = defaultBufFactory
|
||||
return v
|
||||
}
|
||||
|
||||
|
|
@ -600,7 +607,7 @@ func (v *amf0Object) UnmarshalBinary(data []byte) (err error) {
|
|||
}
|
||||
|
||||
func (v *amf0Object) MarshalBinary() (data []byte, err error) {
|
||||
b := createBuffer()
|
||||
b := v.bufFactory()
|
||||
|
||||
if err = b.WriteByte(byte(amf0MarkerObject)); err != nil {
|
||||
return nil, errors.Wrap(err, "marshal")
|
||||
|
|
@ -640,6 +647,7 @@ func NewAmf0EcmaArray() Amf0EcmaArray {
|
|||
func newAmf0EcmaArray() *amf0EcmaArray {
|
||||
v := &amf0EcmaArray{}
|
||||
v.properties = []*amf0Property{}
|
||||
v.bufFactory = defaultBufFactory
|
||||
return v
|
||||
}
|
||||
|
||||
|
|
@ -678,7 +686,7 @@ func (v *amf0EcmaArray) UnmarshalBinary(data []byte) (err error) {
|
|||
}
|
||||
|
||||
func (v *amf0EcmaArray) MarshalBinary() (data []byte, err error) {
|
||||
b := createBuffer()
|
||||
b := v.bufFactory()
|
||||
|
||||
if err = b.WriteByte(byte(amf0MarkerEcmaArray)); err != nil {
|
||||
return nil, errors.Wrap(err, "marshal")
|
||||
|
|
@ -717,6 +725,7 @@ type amf0StrictArray struct {
|
|||
func NewAmf0StrictArray() Amf0StrictArray {
|
||||
v := &amf0StrictArray{}
|
||||
v.properties = []*amf0Property{}
|
||||
v.bufFactory = defaultBufFactory
|
||||
return v
|
||||
}
|
||||
|
||||
|
|
@ -759,7 +768,7 @@ func (v *amf0StrictArray) UnmarshalBinary(data []byte) (err error) {
|
|||
}
|
||||
|
||||
func (v *amf0StrictArray) MarshalBinary() (data []byte, err error) {
|
||||
b := createBuffer()
|
||||
b := v.bufFactory()
|
||||
|
||||
if err = b.WriteByte(byte(amf0MarkerStrictArray)); err != nil {
|
||||
return nil, errors.Wrap(err, "marshal")
|
||||
|
|
|
|||
|
|
@ -436,10 +436,21 @@ func (v *errorAmf0Any) amf0Marker() amf0Marker {
|
|||
return amf0MarkerNumber
|
||||
}
|
||||
|
||||
func TestAmf0MarshalErrors(t *testing.T) {
|
||||
originalCreateBuffer := createBuffer
|
||||
defer func() { createBuffer = originalCreateBuffer }()
|
||||
// setBufFactory replaces the bufFactory on whichever amf0 object-like type
|
||||
// underlies v. Concurrent tests can use this safely because each value carries
|
||||
// its own factory.
|
||||
func setBufFactory(v Amf0Any, fn func() amf0Buffer) {
|
||||
switch v := v.(type) {
|
||||
case *amf0Object:
|
||||
v.bufFactory = fn
|
||||
case *amf0EcmaArray:
|
||||
v.bufFactory = fn
|
||||
case *amf0StrictArray:
|
||||
v.bufFactory = fn
|
||||
}
|
||||
}
|
||||
|
||||
func TestAmf0MarshalErrors(t *testing.T) {
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
make func() Amf0Any
|
||||
|
|
@ -449,15 +460,16 @@ func TestAmf0MarshalErrors(t *testing.T) {
|
|||
{"strict-array", func() Amf0Any { return NewAmf0StrictArray() }},
|
||||
} {
|
||||
t.Run(tt.name+" write-byte", func(t *testing.T) {
|
||||
createBuffer = func() amf0Buffer { return &errorAmf0Buffer{writeByteErr: true} }
|
||||
if _, err := tt.make().MarshalBinary(); err == nil {
|
||||
value := tt.make()
|
||||
setBufFactory(value, func() amf0Buffer { return &errorAmf0Buffer{writeByteErr: true} })
|
||||
if _, err := value.MarshalBinary(); err == nil {
|
||||
t.Fatal("MarshalBinary() should fail")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run(tt.name+" write-prop", func(t *testing.T) {
|
||||
createBuffer = func() amf0Buffer { return &errorAmf0Buffer{writeErr: true} }
|
||||
value := tt.make()
|
||||
setBufFactory(value, func() amf0Buffer { return &errorAmf0Buffer{writeErr: true} })
|
||||
switch v := value.(type) {
|
||||
case Amf0Object:
|
||||
v.Set("name", NewAmf0String("stream"))
|
||||
|
|
@ -473,7 +485,6 @@ func TestAmf0MarshalErrors(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
createBuffer = originalCreateBuffer
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
make func() Amf0Any
|
||||
|
|
|
|||
|
|
@ -15,15 +15,26 @@ import (
|
|||
"srsx/internal/logger"
|
||||
)
|
||||
|
||||
// Indirections so tests can substitute signal delivery and process exit.
|
||||
var (
|
||||
signalNotify = signal.Notify
|
||||
osExit = os.Exit
|
||||
)
|
||||
// Handler installs OS signal handlers and the force-quit timer. The notify
|
||||
// and exit indirections are struct fields (not package globals) so concurrent
|
||||
// tests can each construct a handler with their own fakes without racing on
|
||||
// shared state.
|
||||
type Handler struct {
|
||||
notify func(c chan<- os.Signal, sig ...os.Signal)
|
||||
exit func(code int)
|
||||
}
|
||||
|
||||
func InstallSignals(ctx context.Context, cancel context.CancelFunc) {
|
||||
// NewHandler returns a Handler wired to the real OS implementations.
|
||||
func NewHandler() *Handler {
|
||||
return &Handler{
|
||||
notify: signal.Notify,
|
||||
exit: os.Exit,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) InstallSignals(ctx context.Context, cancel context.CancelFunc) {
|
||||
sc := make(chan os.Signal, 1)
|
||||
signalNotify(sc, syscall.SIGINT, syscall.SIGTERM, os.Interrupt)
|
||||
h.notify(sc, syscall.SIGINT, syscall.SIGTERM, os.Interrupt)
|
||||
|
||||
go func() {
|
||||
for s := range sc {
|
||||
|
|
@ -33,7 +44,7 @@ func InstallSignals(ctx context.Context, cancel context.CancelFunc) {
|
|||
}()
|
||||
}
|
||||
|
||||
func InstallForceQuit(ctx context.Context, environment env.ProxyEnvironment) error {
|
||||
func (h *Handler) InstallForceQuit(ctx context.Context, environment env.ProxyEnvironment) error {
|
||||
var forceTimeout time.Duration
|
||||
timeoutStr := environment.ForceQuitTimeout()
|
||||
if t, err := time.ParseDuration(timeoutStr); err != nil {
|
||||
|
|
@ -46,7 +57,7 @@ func InstallForceQuit(ctx context.Context, environment env.ProxyEnvironment) err
|
|||
<-ctx.Done()
|
||||
time.Sleep(forceTimeout)
|
||||
logger.Warn(ctx, "Force to exit by timeout")
|
||||
osExit(1)
|
||||
h.exit(1)
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,59 +16,60 @@ import (
|
|||
"srsx/internal/env/envfakes"
|
||||
)
|
||||
|
||||
// swapNotify replaces signalNotify with a capturing fake and returns a getter
|
||||
// for the channel registered by the code under test plus a restore func.
|
||||
func swapNotify(t *testing.T) (func() chan<- os.Signal, func()) {
|
||||
t.Helper()
|
||||
orig := signalNotify
|
||||
// captureNotify returns a Handler whose notify field records the channel
|
||||
// passed by the code under test, plus a getter that retrieves it.
|
||||
func captureNotify() (*Handler, func() chan<- os.Signal) {
|
||||
var (
|
||||
mu sync.Mutex
|
||||
ch chan<- os.Signal
|
||||
)
|
||||
signalNotify = func(c chan<- os.Signal, _ ...os.Signal) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
ch = c
|
||||
}
|
||||
return func() chan<- os.Signal {
|
||||
h := &Handler{
|
||||
notify: func(c chan<- os.Signal, _ ...os.Signal) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return ch
|
||||
}, func() {
|
||||
signalNotify = orig
|
||||
}
|
||||
ch = c
|
||||
},
|
||||
exit: os.Exit,
|
||||
}
|
||||
return h, func() chan<- os.Signal {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return ch
|
||||
}
|
||||
}
|
||||
|
||||
func swapExit(t *testing.T) (*int32, chan int, func()) {
|
||||
t.Helper()
|
||||
orig := osExit
|
||||
// captureExit returns a Handler whose exit field records the code and never
|
||||
// returns, plus a flag and channel that observe the call.
|
||||
func captureExit() (*Handler, *int32, chan int) {
|
||||
var called int32
|
||||
done := make(chan int, 1)
|
||||
osExit = func(code int) {
|
||||
atomic.StoreInt32(&called, 1)
|
||||
select {
|
||||
case done <- code:
|
||||
default:
|
||||
}
|
||||
// Block to mimic os.Exit never returning; the goroutine holding us
|
||||
// here is abandoned when the test ends.
|
||||
select {}
|
||||
h := &Handler{
|
||||
notify: func(chan<- os.Signal, ...os.Signal) {},
|
||||
exit: func(code int) {
|
||||
atomic.StoreInt32(&called, 1)
|
||||
select {
|
||||
case done <- code:
|
||||
default:
|
||||
}
|
||||
// Block to mimic os.Exit never returning; the goroutine holding us
|
||||
// here is abandoned when the test ends.
|
||||
select {}
|
||||
},
|
||||
}
|
||||
return &called, done, func() { osExit = orig }
|
||||
return h, &called, done
|
||||
}
|
||||
|
||||
func TestInstallSignals_CancelsOnSignal(t *testing.T) {
|
||||
getCh, restore := swapNotify(t)
|
||||
defer restore()
|
||||
h, getCh := captureNotify()
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
defer cancel()
|
||||
|
||||
InstallSignals(ctx, cancel)
|
||||
h.InstallSignals(ctx, cancel)
|
||||
|
||||
ch := getCh()
|
||||
if ch == nil {
|
||||
t.Fatal("signalNotify was not called")
|
||||
t.Fatal("notify was not called")
|
||||
}
|
||||
ch <- syscall.SIGINT
|
||||
|
||||
|
|
@ -80,13 +81,12 @@ func TestInstallSignals_CancelsOnSignal(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestInstallSignals_HandlesRepeatedSignals(t *testing.T) {
|
||||
getCh, restore := swapNotify(t)
|
||||
defer restore()
|
||||
h, getCh := captureNotify()
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
defer cancel()
|
||||
|
||||
InstallSignals(ctx, cancel)
|
||||
h.InstallSignals(ctx, cancel)
|
||||
ch := getCh()
|
||||
|
||||
// Multiple signals must not panic; cancel() is idempotent.
|
||||
|
|
@ -105,7 +105,7 @@ func TestInstallForceQuit_InvalidDurationReturnsError(t *testing.T) {
|
|||
fakeEnv := &envfakes.FakeProxyEnvironment{}
|
||||
fakeEnv.ForceQuitTimeoutReturns("not-a-duration")
|
||||
|
||||
err := InstallForceQuit(t.Context(), fakeEnv)
|
||||
err := NewHandler().InstallForceQuit(t.Context(), fakeEnv)
|
||||
if err == nil {
|
||||
t.Fatal("want error for bad duration")
|
||||
}
|
||||
|
|
@ -118,20 +118,19 @@ func TestInstallForceQuit_InvalidDurationReturnsError(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestInstallForceQuit_ExitsAfterTimeout(t *testing.T) {
|
||||
called, done, restore := swapExit(t)
|
||||
defer restore()
|
||||
h, called, done := captureExit()
|
||||
|
||||
fakeEnv := &envfakes.FakeProxyEnvironment{}
|
||||
fakeEnv.ForceQuitTimeoutReturns("1ms")
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
if err := InstallForceQuit(ctx, fakeEnv); err != nil {
|
||||
if err := h.InstallForceQuit(ctx, fakeEnv); err != nil {
|
||||
t.Fatalf("unexpected err: %v", err)
|
||||
}
|
||||
|
||||
// Before cancel, the goroutine is blocked and exit must not fire.
|
||||
if atomic.LoadInt32(called) != 0 {
|
||||
t.Fatal("osExit called before ctx cancel")
|
||||
t.Fatal("exit called before ctx cancel")
|
||||
}
|
||||
cancel()
|
||||
|
||||
|
|
@ -141,30 +140,39 @@ func TestInstallForceQuit_ExitsAfterTimeout(t *testing.T) {
|
|||
t.Fatalf("exit code = %d, want 1", code)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("osExit not called after cancel + timeout")
|
||||
t.Fatal("exit not called after cancel + timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstallForceQuit_WaitsForCancelBeforeSleeping(t *testing.T) {
|
||||
called, done, restore := swapExit(t)
|
||||
defer restore()
|
||||
h, called, done := captureExit()
|
||||
|
||||
fakeEnv := &envfakes.FakeProxyEnvironment{}
|
||||
fakeEnv.ForceQuitTimeoutReturns("10ms")
|
||||
|
||||
// Intentionally use a never-canceled context and leak the goroutine:
|
||||
// if we canceled at test end, the goroutine would wake and race with
|
||||
// restore() writing osExit.
|
||||
if err := InstallForceQuit(context.Background(), fakeEnv); err != nil {
|
||||
// Intentionally use a never-canceled context and leak the goroutine: the
|
||||
// handler's exit closure is owned by this test instance, so leaving the
|
||||
// goroutine alive doesn't race other tests.
|
||||
if err := h.InstallForceQuit(context.Background(), fakeEnv); err != nil {
|
||||
t.Fatalf("unexpected err: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
t.Fatal("osExit fired without ctx cancel")
|
||||
t.Fatal("exit fired without ctx cancel")
|
||||
case <-time.After(30 * time.Millisecond):
|
||||
}
|
||||
if atomic.LoadInt32(called) != 0 {
|
||||
t.Fatal("osExit called unexpectedly")
|
||||
t.Fatal("exit called unexpectedly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewHandler_UsesRealOSDefaults(t *testing.T) {
|
||||
h := NewHandler()
|
||||
if h.notify == nil {
|
||||
t.Error("notify default not set")
|
||||
}
|
||||
if h.exit == nil {
|
||||
t.Error("exit default not set")
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,11 +5,16 @@
|
|||
set -e
|
||||
|
||||
SCRIPT_DIR="$(cd -P "$(dirname "$0")" && pwd)"
|
||||
# Navigate: scripts/ -> srs-develop/ -> skills/ -> .openclaw/ -> srs
|
||||
WORKSPACE="$(cd -P "$SCRIPT_DIR/../../../.." && pwd)"
|
||||
# Walk up from SCRIPT_DIR looking for go.mod. This avoids brittle "../../../.."
|
||||
# counting when the skills directory is reached via a symlink (which changes
|
||||
# the symbolic vs. physical depth).
|
||||
WORKSPACE="$SCRIPT_DIR"
|
||||
while [[ "$WORKSPACE" != "/" && ! -f "$WORKSPACE/go.mod" ]]; do
|
||||
WORKSPACE="$(dirname "$WORKSPACE")"
|
||||
done
|
||||
|
||||
if [[ ! -f "$WORKSPACE/go.mod" ]]; then
|
||||
echo "Error: go.mod not found in WORKSPACE: $WORKSPACE" >&2
|
||||
echo "Error: go.mod not found walking up from: $SCRIPT_DIR" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
|
|
|||
|
|
@ -6,11 +6,16 @@
|
|||
set -e
|
||||
|
||||
SCRIPT_DIR="$(cd -P "$(dirname "$0")" && pwd)"
|
||||
# Navigate: scripts/ -> srs-develop/ -> skills/ -> .openclaw/ -> srs
|
||||
WORKSPACE="$(cd -P "$SCRIPT_DIR/../../../.." && pwd)"
|
||||
# Walk up from SCRIPT_DIR looking for go.mod. This avoids brittle "../../../.."
|
||||
# counting when the skills directory is reached via a symlink (which changes
|
||||
# the symbolic vs. physical depth).
|
||||
WORKSPACE="$SCRIPT_DIR"
|
||||
while [[ "$WORKSPACE" != "/" && ! -f "$WORKSPACE/go.mod" ]]; do
|
||||
WORKSPACE="$(dirname "$WORKSPACE")"
|
||||
done
|
||||
|
||||
if [[ ! -f "$WORKSPACE/go.mod" ]]; then
|
||||
echo "Error: go.mod not found in WORKSPACE: $WORKSPACE" >&2
|
||||
echo "Error: go.mod not found walking up from: $SCRIPT_DIR" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
|
|
|||
|
|
@ -10,11 +10,16 @@
|
|||
set -e
|
||||
|
||||
SCRIPT_DIR="$(cd -P "$(dirname "$0")" && pwd)"
|
||||
# Navigate: scripts/ -> srs-develop/ -> skills/ -> .openclaw/ -> srs
|
||||
WORKSPACE="$(cd -P "$SCRIPT_DIR/../../../.." && pwd)"
|
||||
# Walk up from SCRIPT_DIR looking for go.mod. This avoids brittle "../../../.."
|
||||
# counting when the skills directory is reached via a symlink (which changes
|
||||
# the symbolic vs. physical depth).
|
||||
WORKSPACE="$SCRIPT_DIR"
|
||||
while [[ "$WORKSPACE" != "/" && ! -f "$WORKSPACE/go.mod" ]]; do
|
||||
WORKSPACE="$(dirname "$WORKSPACE")"
|
||||
done
|
||||
|
||||
if [[ ! -f "$WORKSPACE/go.mod" ]]; then
|
||||
echo "Error: go.mod not found in WORKSPACE: $WORKSPACE" >&2
|
||||
echo "Error: go.mod not found walking up from: $SCRIPT_DIR" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
|
|
|||
|
|
@ -4,11 +4,16 @@
|
|||
set -e
|
||||
|
||||
SCRIPT_DIR="$(cd -P "$(dirname "$0")" && pwd)"
|
||||
# Navigate: scripts/ -> srs-develop/ -> skills/ -> .openclaw/ -> srs
|
||||
WORKSPACE="$(cd -P "$SCRIPT_DIR/../../../.." && pwd)"
|
||||
# Walk up from SCRIPT_DIR looking for go.mod. This avoids brittle "../../../.."
|
||||
# counting when the skills directory is reached via a symlink (which changes
|
||||
# the symbolic vs. physical depth).
|
||||
WORKSPACE="$SCRIPT_DIR"
|
||||
while [[ "$WORKSPACE" != "/" && ! -f "$WORKSPACE/go.mod" ]]; do
|
||||
WORKSPACE="$(dirname "$WORKSPACE")"
|
||||
done
|
||||
|
||||
if [[ ! -f "$WORKSPACE/go.mod" ]]; then
|
||||
echo "Error: go.mod not found in WORKSPACE: $WORKSPACE" >&2
|
||||
echo "Error: go.mod not found walking up from: $SCRIPT_DIR" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
|
|
|||
|
|
@ -6,11 +6,16 @@
|
|||
set -e
|
||||
|
||||
SCRIPT_DIR="$(cd -P "$(dirname "$0")" && pwd)"
|
||||
# Navigate: scripts/ -> srs-develop/ -> skills/ -> .openclaw/ -> srs
|
||||
WORKSPACE="$(cd -P "$SCRIPT_DIR/../../../.." && pwd)"
|
||||
# Walk up from SCRIPT_DIR looking for go.mod. This avoids brittle "../../../.."
|
||||
# counting when the skills directory is reached via a symlink (which changes
|
||||
# the symbolic vs. physical depth).
|
||||
WORKSPACE="$SCRIPT_DIR"
|
||||
while [[ "$WORKSPACE" != "/" && ! -f "$WORKSPACE/go.mod" ]]; do
|
||||
WORKSPACE="$(dirname "$WORKSPACE")"
|
||||
done
|
||||
|
||||
if [[ ! -f "$WORKSPACE/go.mod" ]]; then
|
||||
echo "Error: go.mod not found in WORKSPACE: $WORKSPACE" >&2
|
||||
echo "Error: go.mod not found walking up from: $SCRIPT_DIR" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
|
|
|||
|
|
@ -10,11 +10,16 @@
|
|||
set -e
|
||||
|
||||
SCRIPT_DIR="$(cd -P "$(dirname "$0")" && pwd)"
|
||||
# Navigate: scripts/ -> srs-develop/ -> skills/ -> .openclaw/ -> srs
|
||||
WORKSPACE="$(cd -P "$SCRIPT_DIR/../../../.." && pwd)"
|
||||
# Walk up from SCRIPT_DIR looking for go.mod. This avoids brittle "../../../.."
|
||||
# counting when the skills directory is reached via a symlink (which changes
|
||||
# the symbolic vs. physical depth).
|
||||
WORKSPACE="$SCRIPT_DIR"
|
||||
while [[ "$WORKSPACE" != "/" && ! -f "$WORKSPACE/go.mod" ]]; do
|
||||
WORKSPACE="$(dirname "$WORKSPACE")"
|
||||
done
|
||||
|
||||
if [[ ! -f "$WORKSPACE/go.mod" ]]; then
|
||||
echo "Error: go.mod not found in WORKSPACE: $WORKSPACE" >&2
|
||||
echo "Error: go.mod not found walking up from: $SCRIPT_DIR" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
|
|
|||
|
|
@ -27,11 +27,16 @@ for arg in "$@"; do
|
|||
done
|
||||
|
||||
SCRIPT_DIR="$(cd -P "$(dirname "$0")" && pwd)"
|
||||
# Navigate: scripts/ -> srs-develop/ -> skills/ -> .openclaw/ -> srs
|
||||
WORKSPACE="$(cd -P "$SCRIPT_DIR/../../../.." && pwd)"
|
||||
# Walk up from SCRIPT_DIR looking for go.mod. This avoids brittle "../../../.."
|
||||
# counting when the skills directory is reached via a symlink (which changes
|
||||
# the symbolic vs. physical depth).
|
||||
WORKSPACE="$SCRIPT_DIR"
|
||||
while [[ "$WORKSPACE" != "/" && ! -f "$WORKSPACE/go.mod" ]]; do
|
||||
WORKSPACE="$(dirname "$WORKSPACE")"
|
||||
done
|
||||
|
||||
if [[ ! -f "$WORKSPACE/go.mod" ]]; then
|
||||
echo "Error: go.mod not found in WORKSPACE: $WORKSPACE" >&2
|
||||
echo "Error: go.mod not found walking up from: $SCRIPT_DIR" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user