From a387fb63696af9dbfad97d347a7f956e72e5ee58 Mon Sep 17 00:00:00 2001 From: winlin Date: Sun, 19 Apr 2026 19:48:54 -0400 Subject: [PATCH] Proxy: Convert internal/sync.Map to an interface and add unit tests. Co-Authored-By: Claude Opus 4.7 (1M context) --- internal/lb/mem.go | 8 +- internal/protocol/rtc.go | 6 +- internal/protocol/srt.go | 1 + internal/sync/map.go | 27 ++++-- internal/sync/map_test.go | 182 ++++++++++++++++++++++++++++++++++++++ 5 files changed, 215 insertions(+), 9 deletions(-) create mode 100644 internal/sync/map_test.go diff --git a/internal/lb/mem.go b/internal/lb/mem.go index 3901ed93b..8fe3602a5 100644 --- a/internal/lb/mem.go +++ b/internal/lb/mem.go @@ -36,7 +36,13 @@ type MemoryLoadBalancer struct { // NewMemoryLoadBalancer creates a new memory-based load balancer. func NewMemoryLoadBalancer(environment env.Environment) SRSLoadBalancer { return &MemoryLoadBalancer{ - environment: environment, + environment: environment, + servers: sync.NewMap[string, *SRSServer](), + picked: sync.NewMap[string, *SRSServer](), + hlsStreamURL: sync.NewMap[string, HLSPlayStream](), + hlsSPBHID: sync.NewMap[string, HLSPlayStream](), + rtcStreamURL: sync.NewMap[string, RTCConnection](), + rtcUfrag: sync.NewMap[string, RTCConnection](), } } diff --git a/internal/protocol/rtc.go b/internal/protocol/rtc.go index b1f43bce2..add8bdf00 100644 --- a/internal/protocol/rtc.go +++ b/internal/protocol/rtc.go @@ -45,7 +45,11 @@ type srsWebRTCServer struct { } func NewSRSWebRTCServer(environment env.Environment, opts ...func(*srsWebRTCServer)) *srsWebRTCServer { - v := &srsWebRTCServer{environment: environment} + v := &srsWebRTCServer{ + environment: environment, + usernames: sync.NewMap[string, *RTCConnection](), + addresses: sync.NewMap[string, *RTCConnection](), + } for _, opt := range opts { opt(v) } diff --git a/internal/protocol/srt.go b/internal/protocol/srt.go index f51724c2a..ced994ef6 100644 --- a/internal/protocol/srt.go +++ b/internal/protocol/srt.go @@ -43,6 +43,7 @@ func NewSRSSRTServer(environment env.Environment, opts ...func(*srsSRTServer)) * v := &srsSRTServer{ environment: environment, start: time.Now(), + sockets: sync.NewMap[uint32, *SRTConnection](), } for _, opt := range opts { diff --git a/internal/sync/map.go b/internal/sync/map.go index 05f628a44..16387ec03 100644 --- a/internal/sync/map.go +++ b/internal/sync/map.go @@ -5,15 +5,28 @@ package sync import "sync" -type Map[K comparable, V any] struct { +type Map[K comparable, V any] interface { + Delete(key K) + Load(key K) (value V, ok bool) + LoadAndDelete(key K) (value V, loaded bool) + LoadOrStore(key K, value V) (actual V, loaded bool) + Range(f func(key K, value V) bool) + Store(key K, value V) +} + +func NewMap[K comparable, V any]() Map[K, V] { + return &mapImpl[K, V]{} +} + +type mapImpl[K comparable, V any] struct { m sync.Map } -func (m *Map[K, V]) Delete(key K) { +func (m *mapImpl[K, V]) Delete(key K) { m.m.Delete(key) } -func (m *Map[K, V]) Load(key K) (value V, ok bool) { +func (m *mapImpl[K, V]) Load(key K) (value V, ok bool) { v, ok := m.m.Load(key) if !ok { return value, ok @@ -21,7 +34,7 @@ func (m *Map[K, V]) Load(key K) (value V, ok bool) { return v.(V), ok } -func (m *Map[K, V]) LoadAndDelete(key K) (value V, loaded bool) { +func (m *mapImpl[K, V]) LoadAndDelete(key K) (value V, loaded bool) { v, loaded := m.m.LoadAndDelete(key) if !loaded { return value, loaded @@ -29,17 +42,17 @@ func (m *Map[K, V]) LoadAndDelete(key K) (value V, loaded bool) { return v.(V), loaded } -func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) { +func (m *mapImpl[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) { a, loaded := m.m.LoadOrStore(key, value) return a.(V), loaded } -func (m *Map[K, V]) Range(f func(key K, value V) bool) { +func (m *mapImpl[K, V]) Range(f func(key K, value V) bool) { m.m.Range(func(key, value any) bool { return f(key.(K), value.(V)) }) } -func (m *Map[K, V]) Store(key K, value V) { +func (m *mapImpl[K, V]) Store(key K, value V) { m.m.Store(key, value) } diff --git a/internal/sync/map_test.go b/internal/sync/map_test.go new file mode 100644 index 000000000..e23d0d698 --- /dev/null +++ b/internal/sync/map_test.go @@ -0,0 +1,182 @@ +// Copyright (c) 2026 Winlin +// +// SPDX-License-Identifier: MIT +package sync + +import ( + "sort" + "testing" +) + +func TestNewMap_ReturnsEmpty(t *testing.T) { + m := NewMap[string, int]() + if m == nil { + t.Fatal("NewMap returned nil") + } + if v, ok := m.Load("missing"); ok || v != 0 { + t.Fatalf("Load(missing) = (%v, %v), want (0, false)", v, ok) + } +} + +func TestStore_AndLoad(t *testing.T) { + m := NewMap[string, int]() + m.Store("a", 1) + + v, ok := m.Load("a") + if !ok || v != 1 { + t.Fatalf("Load(a) = (%v, %v), want (1, true)", v, ok) + } +} + +func TestLoad_MissingReturnsZero(t *testing.T) { + m := NewMap[string, int]() + v, ok := m.Load("nope") + if ok { + t.Fatal("Load on missing key returned ok=true") + } + if v != 0 { + t.Fatalf("Load on missing key returned %v, want zero", v) + } +} + +func TestDelete_RemovesKey(t *testing.T) { + m := NewMap[string, int]() + m.Store("a", 1) + m.Delete("a") + + if _, ok := m.Load("a"); ok { + t.Fatal("Load(a) returned ok=true after Delete") + } +} + +func TestDelete_MissingIsNoop(t *testing.T) { + m := NewMap[string, int]() + m.Delete("never-stored") +} + +func TestLoadAndDelete_Present(t *testing.T) { + m := NewMap[string, int]() + m.Store("a", 42) + + v, loaded := m.LoadAndDelete("a") + if !loaded { + t.Fatal("LoadAndDelete returned loaded=false for present key") + } + if v != 42 { + t.Fatalf("LoadAndDelete returned %v, want 42", v) + } + if _, ok := m.Load("a"); ok { + t.Fatal("key still present after LoadAndDelete") + } +} + +func TestLoadAndDelete_Absent(t *testing.T) { + m := NewMap[string, int]() + v, loaded := m.LoadAndDelete("nope") + if loaded { + t.Fatal("LoadAndDelete returned loaded=true for absent key") + } + if v != 0 { + t.Fatalf("LoadAndDelete on absent key returned %v, want zero", v) + } +} + +func TestLoadOrStore_StoresWhenAbsent(t *testing.T) { + m := NewMap[string, int]() + actual, loaded := m.LoadOrStore("a", 7) + if loaded { + t.Fatal("LoadOrStore returned loaded=true for absent key") + } + if actual != 7 { + t.Fatalf("LoadOrStore returned %v, want 7", actual) + } + + v, ok := m.Load("a") + if !ok || v != 7 { + t.Fatalf("Load after LoadOrStore = (%v, %v), want (7, true)", v, ok) + } +} + +func TestLoadOrStore_LoadsWhenPresent(t *testing.T) { + m := NewMap[string, int]() + m.Store("a", 1) + + actual, loaded := m.LoadOrStore("a", 999) + if !loaded { + t.Fatal("LoadOrStore returned loaded=false for present key") + } + if actual != 1 { + t.Fatalf("LoadOrStore returned %v, want existing value 1", actual) + } + + v, _ := m.Load("a") + if v != 1 { + t.Fatalf("LoadOrStore overwrote existing value: got %v, want 1", v) + } +} + +func TestRange_VisitsAllEntries(t *testing.T) { + m := NewMap[string, int]() + want := map[string]int{"a": 1, "b": 2, "c": 3} + for k, v := range want { + m.Store(k, v) + } + + got := map[string]int{} + m.Range(func(key string, value int) bool { + got[key] = value + return true + }) + + if len(got) != len(want) { + t.Fatalf("Range visited %d entries, want %d", len(got), len(want)) + } + for k, v := range want { + if got[k] != v { + t.Fatalf("Range got[%q] = %v, want %v", k, got[k], v) + } + } +} + +func TestRange_EarlyStop(t *testing.T) { + m := NewMap[string, int]() + m.Store("a", 1) + m.Store("b", 2) + m.Store("c", 3) + + visited := 0 + m.Range(func(key string, value int) bool { + visited++ + return false + }) + + if visited != 1 { + t.Fatalf("Range visited %d entries after returning false, want 1", visited) + } +} + +func TestMap_PointerValueType(t *testing.T) { + type entry struct{ n int } + m := NewMap[string, *entry]() + + e := &entry{n: 5} + m.Store("k", e) + + got, ok := m.Load("k") + if !ok { + t.Fatal("Load returned ok=false") + } + if got != e { + t.Fatalf("Load returned different pointer: %p vs %p", got, e) + } + + keys := []string{} + m.Range(func(key string, value *entry) bool { + keys = append(keys, key) + return true + }) + sort.Strings(keys) + if len(keys) != 1 || keys[0] != "k" { + t.Fatalf("Range keys = %v, want [k]", keys) + } +}