- Refactor the Go proxy for dependency injection: every proxy server, the bootstrap, the signal handler, the load balancers, and AMF0 now accept functional-option seams (factories/closures) so tests can inject fakes without binding real sockets, talking to real Redis, or racing on package globals.
- Drop the package-global `lb.SrsLoadBalancer`. The bootstrap creates the LB locally and threads it through every proxy server constructor. Two old global indirections in `internal/signal` and `internal/rtmp/amf0` are likewise replaced by per-instance fields.
- Rename `internal/server` → `internal/proxy` and rename the `lb` public surface for clarity: `SRSLoadBalancer` is split into `OriginService` / `HLSService` / `RTCService` and recomposed as `OriginLoadBalancer`; `SRSServer` → `OriginServer`; all proxy server types gain a `Proxy` qualifier (e.g. `RTMPServer` → `RTMPProxyServer`).
- Extract the Redis client behind a new `internal/redisclient` package with a minimal `RedisClient` interface and a counterfeiter fake.
- Add counterfeiter fakes (`proxyfakes`, `lbfakes`, `redisclientfakes`) and ~7.5k lines of unit tests covering bootstrap, memory + Redis LBs, all five proxy servers, the signal handler, and AMF0.
- Add two new E2E flows — `proxy-e2e-srt-test.sh` (SRT publish through proxy, verify SRT/RTMP/HTTP-FLV/HLS playback) and `proxy-e2e-whip-test.sh` (WHIP publish, verify RTMP/HTTP-FLV/HLS via origin `rtc_to_rtmp`) — plus `setup-ffmpeg-with-whip.sh`, a macOS builder for an ffmpeg with OpenSSL-DTLS WHIP and SRT support, which the two scripts auto-invoke when needed.
- Workspace reorg: move `memory/` and `skills/` to the repo root so all agent tools (Claude / Codex / Kiro / OpenClaw) share one source of truth via symlinks. Sync `docs/proxy/proxy-load-balancer.md` and `memory/srs-codebase-map.md` with the new names.

No protocol, log, HTTP API, or wire-format changes. Refactor only — all externally observable proxy behavior is unchanged.
---------
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: chatgpt-codex-connector[bot] <199175422+chatgpt-codex-connector[bot]@users.noreply.github.com>
1112 lines
37 KiB
Go
1112 lines
37 KiB
Go
// Copyright (c) 2026 Winlin
|
|
//
|
|
// SPDX-License-Identifier: MIT
|
|
package proxy
|
|
|
|
import (
|
|
"context"
|
|
"encoding/binary"
|
|
"errors"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"srsx/internal/env/envfakes"
|
|
"srsx/internal/lb"
|
|
"srsx/internal/lb/lbfakes"
|
|
)
|
|
|
|
func TestRtcICEPair_Ufrag(t *testing.T) {
|
|
cases := []struct {
|
|
name string
|
|
pair rtcICEPair
|
|
want string
|
|
}{
|
|
{
|
|
name: "typical",
|
|
pair: rtcICEPair{
|
|
RemoteICEUfrag: "remote-ufrag",
|
|
RemoteICEPwd: "remote-pwd",
|
|
LocalICEUfrag: "local-ufrag",
|
|
LocalICEPwd: "local-pwd",
|
|
},
|
|
want: "local-ufrag:remote-ufrag",
|
|
},
|
|
{
|
|
name: "both empty",
|
|
pair: rtcICEPair{},
|
|
want: ":",
|
|
},
|
|
{
|
|
name: "only local",
|
|
pair: rtcICEPair{LocalICEUfrag: "L"},
|
|
want: "L:",
|
|
},
|
|
{
|
|
name: "only remote",
|
|
pair: rtcICEPair{RemoteICEUfrag: "R"},
|
|
want: ":R",
|
|
},
|
|
{
|
|
name: "pwd fields do not affect ufrag",
|
|
pair: rtcICEPair{
|
|
RemoteICEUfrag: "r",
|
|
RemoteICEPwd: "should-be-ignored",
|
|
LocalICEUfrag: "l",
|
|
LocalICEPwd: "should-be-ignored",
|
|
},
|
|
want: "l:r",
|
|
},
|
|
}
|
|
|
|
for _, c := range cases {
|
|
t.Run(c.name, func(t *testing.T) {
|
|
if got := c.pair.Ufrag(); got != c.want {
|
|
t.Fatalf("Ufrag()=%q, want %q", got, c.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// fakeBackendUDP is an in-memory io.ReadWriteCloser that replaces the UDP
// socket an rtcConnection dials toward its backend.
//
//   - Write pushes a private copy of each payload onto writes, or returns
//     writeErr when that is non-nil.
//   - Read returns readErr when that is non-nil. Otherwise it blocks until a
//     payload arrives on reads and copies it into the caller's buffer. Once
//     reads has been closed (see Close) it returns io.EOF.
//   - Close closes reads exactly once; later calls are no-ops.
type fakeBackendUDP struct {
	writes chan []byte
	reads  chan []byte
	closed atomic.Bool

	writeErr error
	readErr  error

	// readOnce is never consulted by Read below. It does NOT make a second
	// Read return io.EOF. NOTE(review): implement or remove once no test
	// references it.
	readOnce atomic.Bool
	// bytesRead accumulates the number of bytes Read has copied out.
	bytesRead atomic.Int64
}

var _ io.ReadWriteCloser = (*fakeBackendUDP)(nil)

// newFakeBackendUDP returns a fake whose writes and reads channels each
// buffer up to 16 payloads before the sender blocks.
func newFakeBackendUDP() *fakeBackendUDP {
	fb := new(fakeBackendUDP)
	fb.writes = make(chan []byte, 16)
	fb.reads = make(chan []byte, 16)
	return fb
}

// Read delivers the next queued payload, truncated to len(p).
func (f *fakeBackendUDP) Read(p []byte) (int, error) {
	if err := f.readErr; err != nil {
		return 0, err
	}
	payload, open := <-f.reads
	if !open {
		return 0, io.EOF
	}
	copied := copy(p, payload)
	f.bytesRead.Add(int64(copied))
	return copied, nil
}

// Write records a copy of p on writes and reports the whole slice written.
func (f *fakeBackendUDP) Write(p []byte) (int, error) {
	if err := f.writeErr; err != nil {
		return 0, err
	}
	// Copy so later mutation of the caller's buffer cannot alter what the
	// test observes.
	dup := make([]byte, len(p))
	copy(dup, p)
	f.writes <- dup
	return len(p), nil
}

// Close closes reads on the first call only, which wakes any pending Read.
func (f *fakeBackendUDP) Close() error {
	if !f.closed.CompareAndSwap(false, true) {
		return nil // already closed
	}
	close(f.reads)
	return nil
}
|
|
|
|
// fakePacketConn is an in-memory net.PacketConn that stands in for the
// proxy's UDP listener in rtcConnection tests.
//
// WriteTo records a copy of each datagram, with its destination, on writes
// unless writeErr is set. Every other method is inert: ReadFrom always
// reports io.EOF, Close and the deadline setters return nil, and LocalAddr
// returns nil.
type fakePacketConn struct {
	writes   chan packetWrite
	writeErr error
}

var _ net.PacketConn = (*fakePacketConn)(nil)

// packetWrite is one datagram captured by a fake listener's WriteTo.
type packetWrite struct {
	data []byte
	addr net.Addr
}

// newFakePacketConn returns a fakePacketConn whose writes channel buffers up
// to 16 datagrams before WriteTo blocks.
func newFakePacketConn() *fakePacketConn {
	pc := new(fakePacketConn)
	pc.writes = make(chan packetWrite, 16)
	return pc
}

// WriteTo copies p, queues it together with addr on writes, and reports
// len(p) bytes written. It returns writeErr instead when that is non-nil.
func (f *fakePacketConn) WriteTo(p []byte, addr net.Addr) (int, error) {
	if err := f.writeErr; err != nil {
		return 0, err
	}
	dup := make([]byte, len(p))
	copy(dup, p)
	f.writes <- packetWrite{addr: addr, data: dup}
	return len(p), nil
}

// ReadFrom never yields data.
func (f *fakePacketConn) ReadFrom([]byte) (int, net.Addr, error) {
	return 0, nil, io.EOF
}

// Close is a no-op.
func (f *fakePacketConn) Close() error {
	return nil
}

// LocalAddr reports no address.
func (f *fakePacketConn) LocalAddr() net.Addr {
	return nil
}

// SetDeadline is a no-op.
func (f *fakePacketConn) SetDeadline(time.Time) error {
	return nil
}

// SetReadDeadline is a no-op.
func (f *fakePacketConn) SetReadDeadline(time.Time) error {
	return nil
}

// SetWriteDeadline is a no-op.
func (f *fakePacketConn) SetWriteDeadline(time.Time) error {
	return nil
}
|
|
|
|
func TestNewRTCConnection(t *testing.T) {
|
|
t.Run("defaults dialBackendUDP", func(t *testing.T) {
|
|
c := newRTCConnection()
|
|
if c.dialBackendUDP == nil {
|
|
t.Fatal("expected dialBackendUDP to be defaulted")
|
|
}
|
|
})
|
|
|
|
t.Run("applies functional options", func(t *testing.T) {
|
|
c := newRTCConnection(func(c *rtcConnection) {
|
|
c.StreamURL = "vhost/app/stream"
|
|
c.Ufrag = "L:R"
|
|
})
|
|
if c.StreamURL != "vhost/app/stream" {
|
|
t.Fatalf("StreamURL=%q", c.StreamURL)
|
|
}
|
|
if c.Ufrag != "L:R" {
|
|
t.Fatalf("Ufrag=%q", c.Ufrag)
|
|
}
|
|
})
|
|
|
|
t.Run("options override default dialBackendUDP", func(t *testing.T) {
|
|
called := false
|
|
dial := func(ctx context.Context, ip string, port int) (io.ReadWriteCloser, error) {
|
|
called = true
|
|
return nil, nil
|
|
}
|
|
c := newRTCConnection(func(c *rtcConnection) { c.dialBackendUDP = dial })
|
|
_, _ = c.dialBackendUDP(context.Background(), "", 0)
|
|
if !called {
|
|
t.Fatal("expected overridden dialBackendUDP to be invoked")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestRtcConnection_Initialize(t *testing.T) {
|
|
t.Run("sets ctx when nil", func(t *testing.T) {
|
|
c := newRTCConnection()
|
|
listener := newFakePacketConn()
|
|
ret := c.Initialize(context.Background(), listener)
|
|
if c.ctx == nil {
|
|
t.Fatal("expected ctx to be set")
|
|
}
|
|
if c.listenerUDP != listener {
|
|
t.Fatal("expected listenerUDP to be set")
|
|
}
|
|
if ret != c {
|
|
t.Fatal("expected Initialize to return receiver")
|
|
}
|
|
})
|
|
|
|
t.Run("does not overwrite existing ctx", func(t *testing.T) {
|
|
type ctxKey struct{}
|
|
original := context.WithValue(context.Background(), ctxKey{}, "marker")
|
|
c := newRTCConnection(func(c *rtcConnection) { c.ctx = original })
|
|
c.Initialize(context.Background(), nil)
|
|
if got := c.ctx.Value(ctxKey{}); got != "marker" {
|
|
t.Fatalf("ctx was overwritten; got value=%v", got)
|
|
}
|
|
})
|
|
|
|
t.Run("nil listener does not overwrite existing", func(t *testing.T) {
|
|
existing := newFakePacketConn()
|
|
c := newRTCConnection(func(c *rtcConnection) { c.listenerUDP = existing })
|
|
c.Initialize(context.Background(), nil)
|
|
if c.listenerUDP != existing {
|
|
t.Fatal("nil listener overwrote existing listenerUDP")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestRtcConnection_GetUfrag(t *testing.T) {
|
|
c := newRTCConnection(func(c *rtcConnection) { c.Ufrag = "abc:def" })
|
|
if got := c.GetUfrag(); got != "abc:def" {
|
|
t.Fatalf("GetUfrag()=%q", got)
|
|
}
|
|
}
|
|
|
|
// rtcConnFixture wires an rtcConnection with fakes for the load balancer,
|
|
// listener, and backend dial seam.
|
|
type rtcConnFixture struct {
|
|
conn *rtcConnection
|
|
lb *lbfakes.FakeOriginLoadBalancer
|
|
listener *fakePacketConn
|
|
backend *fakeBackendUDP
|
|
dialErr error
|
|
dialIP string
|
|
dialPort int
|
|
}
|
|
|
|
func newRtcConnFixture() *rtcConnFixture {
|
|
f := &rtcConnFixture{
|
|
lb: &lbfakes.FakeOriginLoadBalancer{},
|
|
listener: newFakePacketConn(),
|
|
backend: newFakeBackendUDP(),
|
|
}
|
|
f.conn = newRTCConnection(func(c *rtcConnection) {
|
|
c.loadBalancer = f.lb
|
|
c.StreamURL = "vhost/app/stream"
|
|
c.Ufrag = "L:R"
|
|
c.listenerUDP = f.listener
|
|
c.dialBackendUDP = func(ctx context.Context, ip string, port int) (io.ReadWriteCloser, error) {
|
|
f.dialIP, f.dialPort = ip, port
|
|
if f.dialErr != nil {
|
|
return nil, f.dialErr
|
|
}
|
|
return f.backend, nil
|
|
}
|
|
})
|
|
return f
|
|
}
|
|
|
|
func TestRtcConnection_ConnectBackend(t *testing.T) {
|
|
t.Run("noop when already connected", func(t *testing.T) {
|
|
f := newRtcConnFixture()
|
|
f.conn.backendUDP = f.backend
|
|
if err := f.conn.connectBackend(context.Background()); err != nil {
|
|
t.Fatalf("unexpected err=%v", err)
|
|
}
|
|
if f.lb.PickCallCount() != 0 {
|
|
t.Fatal("expected Pick not to be called when already connected")
|
|
}
|
|
})
|
|
|
|
t.Run("propagates Pick error", func(t *testing.T) {
|
|
f := newRtcConnFixture()
|
|
f.lb.PickReturns(nil, errors.New("boom"))
|
|
err := f.conn.connectBackend(context.Background())
|
|
if err == nil || !strings.Contains(err.Error(), "boom") {
|
|
t.Fatalf("expected pick err, got %v", err)
|
|
}
|
|
})
|
|
|
|
t.Run("errors when backend has no RTC endpoints", func(t *testing.T) {
|
|
f := newRtcConnFixture()
|
|
f.lb.PickReturns(&lb.OriginServer{IP: "127.0.0.1"}, nil)
|
|
err := f.conn.connectBackend(context.Background())
|
|
if err == nil || !strings.Contains(err.Error(), "no udp server") {
|
|
t.Fatalf("expected no-udp-server err, got %v", err)
|
|
}
|
|
})
|
|
|
|
t.Run("propagates dial error", func(t *testing.T) {
|
|
f := newRtcConnFixture()
|
|
f.lb.PickReturns(&lb.OriginServer{IP: "127.0.0.1", RTC: []string{"18000"}}, nil)
|
|
f.dialErr = errors.New("dial-failed")
|
|
err := f.conn.connectBackend(context.Background())
|
|
if err == nil || !strings.Contains(err.Error(), "dial-failed") {
|
|
t.Fatalf("expected dial err, got %v", err)
|
|
}
|
|
if f.conn.backendUDP != nil {
|
|
t.Fatal("backendUDP should remain nil on dial failure")
|
|
}
|
|
})
|
|
|
|
t.Run("success sets backendUDP and forwards ip/port", func(t *testing.T) {
|
|
f := newRtcConnFixture()
|
|
f.lb.PickReturns(&lb.OriginServer{IP: "10.0.0.5", RTC: []string{"18000"}}, nil)
|
|
if err := f.conn.connectBackend(context.Background()); err != nil {
|
|
t.Fatalf("unexpected err=%v", err)
|
|
}
|
|
if f.conn.backendUDP != f.backend {
|
|
t.Fatal("backendUDP not set")
|
|
}
|
|
if f.dialIP != "10.0.0.5" || f.dialPort != 18000 {
|
|
t.Fatalf("dial got ip=%q port=%d", f.dialIP, f.dialPort)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestRtcConnection_HandlePacket(t *testing.T) {
|
|
t.Run("writes data to backend and stores client addr", func(t *testing.T) {
|
|
f := newRtcConnFixture()
|
|
f.lb.PickReturns(&lb.OriginServer{IP: "127.0.0.1", RTC: []string{"18000"}}, nil)
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
f.conn.Initialize(ctx, f.listener)
|
|
|
|
clientAddr := &net.UDPAddr{IP: net.ParseIP("192.168.1.5"), Port: 5000}
|
|
payload := []byte("hello-backend")
|
|
if err := f.conn.HandlePacket(clientAddr, payload); err != nil {
|
|
t.Fatalf("HandlePacket err=%v", err)
|
|
}
|
|
|
|
if f.conn.clientUDP != clientAddr {
|
|
t.Fatal("clientUDP not updated")
|
|
}
|
|
|
|
select {
|
|
case got := <-f.backend.writes:
|
|
if string(got) != string(payload) {
|
|
t.Fatalf("backend got %q, want %q", got, payload)
|
|
}
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout waiting for backend write")
|
|
}
|
|
})
|
|
|
|
t.Run("propagates connectBackend error", func(t *testing.T) {
|
|
f := newRtcConnFixture()
|
|
f.lb.PickReturns(nil, errors.New("pick-fail"))
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
f.conn.Initialize(ctx, f.listener)
|
|
|
|
err := f.conn.HandlePacket(&net.UDPAddr{}, []byte("x"))
|
|
if err == nil || !strings.Contains(err.Error(), "pick-fail") {
|
|
t.Fatalf("expected propagated pick err, got %v", err)
|
|
}
|
|
})
|
|
|
|
t.Run("propagates backend write error", func(t *testing.T) {
|
|
f := newRtcConnFixture()
|
|
f.lb.PickReturns(&lb.OriginServer{IP: "127.0.0.1", RTC: []string{"18000"}}, nil)
|
|
f.backend.writeErr = errors.New("write-fail")
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
f.conn.Initialize(ctx, f.listener)
|
|
|
|
err := f.conn.HandlePacket(&net.UDPAddr{}, []byte("x"))
|
|
if err == nil || !strings.Contains(err.Error(), "write-fail") {
|
|
t.Fatalf("expected propagated write err, got %v", err)
|
|
}
|
|
})
|
|
|
|
t.Run("backend reads are forwarded to listener", func(t *testing.T) {
|
|
f := newRtcConnFixture()
|
|
f.lb.PickReturns(&lb.OriginServer{IP: "127.0.0.1", RTC: []string{"18000"}}, nil)
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
f.conn.Initialize(ctx, f.listener)
|
|
|
|
clientAddr := &net.UDPAddr{IP: net.ParseIP("192.168.1.5"), Port: 5000}
|
|
if err := f.conn.HandlePacket(clientAddr, []byte("trigger")); err != nil {
|
|
t.Fatalf("HandlePacket err=%v", err)
|
|
}
|
|
// drain the trigger packet sent to backend
|
|
<-f.backend.writes
|
|
|
|
// Feed a packet from the backend; expect it forwarded to the listener.
|
|
f.backend.reads <- []byte("from-backend")
|
|
|
|
select {
|
|
case got := <-f.listener.writes:
|
|
if string(got.data) != "from-backend" {
|
|
t.Fatalf("listener got %q, want %q", got.data, "from-backend")
|
|
}
|
|
if got.addr != clientAddr {
|
|
t.Fatalf("listener addr=%v, want %v", got.addr, clientAddr)
|
|
}
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout waiting for listener write")
|
|
}
|
|
|
|
// Cleanly terminate the read loop.
|
|
_ = f.backend.Close()
|
|
})
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// webRTCProxyServer: fakes, helpers, and fixtures
|
|
// ---------------------------------------------------------------------------
|
|
|
|
// blockingUDPListener stands in for the WebRTC UDP listener used by Run().
|
|
// ReadFrom blocks until packets are pushed via push(); Close unblocks the
|
|
// reader with a "use of closed network connection" error so the accept loop
|
|
// hits utils.IsClosedNetworkError and exits gracefully.
|
|
type blockingUDPListener struct {
|
|
packets chan udpPacket
|
|
writes chan packetWrite
|
|
closed atomic.Bool
|
|
}
|
|
|
|
type udpPacket struct {
|
|
data []byte
|
|
addr net.Addr
|
|
}
|
|
|
|
func newBlockingUDPListener() *blockingUDPListener {
|
|
return &blockingUDPListener{
|
|
packets: make(chan udpPacket, 8),
|
|
writes: make(chan packetWrite, 16),
|
|
}
|
|
}
|
|
|
|
func (l *blockingUDPListener) push(p udpPacket) { l.packets <- p }
|
|
|
|
func (l *blockingUDPListener) ReadFrom(buf []byte) (int, net.Addr, error) {
|
|
p, ok := <-l.packets
|
|
if !ok {
|
|
return 0, nil, errors.New("use of closed network connection")
|
|
}
|
|
n := copy(buf, p.data)
|
|
return n, p.addr, nil
|
|
}
|
|
|
|
func (l *blockingUDPListener) WriteTo(p []byte, addr net.Addr) (int, error) {
|
|
cp := make([]byte, len(p))
|
|
copy(cp, p)
|
|
l.writes <- packetWrite{data: cp, addr: addr}
|
|
return len(p), nil
|
|
}
|
|
|
|
func (l *blockingUDPListener) Close() error {
|
|
if l.closed.CompareAndSwap(false, true) {
|
|
close(l.packets)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (l *blockingUDPListener) LocalAddr() net.Addr { return fakeAddr{} }
|
|
func (l *blockingUDPListener) SetDeadline(time.Time) error { return nil }
|
|
func (l *blockingUDPListener) SetReadDeadline(time.Time) error { return nil }
|
|
func (l *blockingUDPListener) SetWriteDeadline(time.Time) error { return nil }
|
|
|
|
// newStunBindingRequest builds a minimal STUN Binding Request (RFC 5389)
// whose only attribute is USERNAME (type 0x0006), carrying ufrag.
//
// Layout:
//   - 20-byte header: message type 0x0001 (Binding Request), message length
//     (the size of the attribute section), the RFC 5389 magic cookie
//     0x2112A442, and an all-zero 96-bit transaction ID.
//   - USERNAME attribute: 2-byte type, 2-byte value length, the ufrag bytes,
//     then zero padding up to a 4-byte boundary. The attribute length holds
//     the unpadded ufrag length; the header's message length includes the
//     padding, so rtcStunPacket.UnmarshalBinary can walk the attribute.
//
// The first byte is 0x00, so utils.RtcIsSTUN classifies the packet as STUN.
// ufrag is assumed shorter than 64 KiB; longer values would overflow the
// 16-bit length fields.
func newStunBindingRequest(ufrag string) []byte {
	const (
		headerSize     = 20
		bindingRequest = 0x0001
		attrUsername   = 0x0006
		magicCookie    = 0x2112A442
	)

	// Round the value length up to the next multiple of 4.
	padded := (len(ufrag) + 3) &^ 3
	attrs := make([]byte, 4+padded)
	binary.BigEndian.PutUint16(attrs[0:2], attrUsername)
	binary.BigEndian.PutUint16(attrs[2:4], uint16(len(ufrag)))
	copy(attrs[4:], ufrag)

	pkt := make([]byte, headerSize, headerSize+len(attrs))
	binary.BigEndian.PutUint16(pkt[0:2], bindingRequest)
	binary.BigEndian.PutUint16(pkt[2:4], uint16(len(attrs)))
	binary.BigEndian.PutUint32(pkt[4:8], magicCookie)
	return append(pkt, attrs...)
}
|
|
|
|
// fakeNonStunPacket returns a fresh 4-byte UDP payload {0x42, 0, 0, 0}.
// Its first byte is neither 0 nor 1, so utils.RtcIsSTUN reports false. It is
// also shorter than the 12 bytes an RTP/RTCP packet needs. handleClientUDP
// therefore treats it as unknown and skips parsing.
func fakeNonStunPacket() []byte {
	pkt := make([]byte, 4)
	pkt[0] = 0x42
	return pkt
}
|
|
|
|
// fakeRTPPacket returns a fresh 12-byte payload whose first byte is 0x80 and
// whose remaining bytes are zero. That satisfies utils.RtcIsRTPOrRTCP
// (len >= 12, first byte 0x80), so handleClientUDP skips its STUN parser.
func fakeRTPPacket() []byte {
	return append([]byte{0x80}, make([]byte, 11)...)
}
|
|
|
|
// webRTCFixture bundles fakes plus a webRTCProxyServer wired against them.
|
|
// The default listenUDP returns the fixture's blocking listener; tests can
|
|
// either drive Run() through it or call handler methods directly without
|
|
// starting Run() at all.
|
|
type webRTCFixture struct {
|
|
env *envfakes.FakeProxyEnvironment
|
|
lb *lbfakes.FakeOriginLoadBalancer
|
|
listener *blockingUDPListener
|
|
server *webRTCProxyServer
|
|
}
|
|
|
|
func newWebRTCFixture() *webRTCFixture {
|
|
f := &webRTCFixture{
|
|
env: &envfakes.FakeProxyEnvironment{},
|
|
lb: &lbfakes.FakeOriginLoadBalancer{},
|
|
listener: newBlockingUDPListener(),
|
|
}
|
|
f.env.WebRTCServerReturns("18000")
|
|
|
|
srv := NewWebRTCProxyServer(f.env, f.lb, func(v *webRTCProxyServer) {
|
|
v.listenUDP = func(ctx context.Context, endpoint string) (net.PacketConn, error) {
|
|
return f.listener, nil
|
|
}
|
|
})
|
|
f.server = srv.(*webRTCProxyServer)
|
|
return f
|
|
}
|
|
|
|
// sampleSDPOffer is the request body used by the WHIP/WHEP tests. It is not
// a complete SDP session description: it carries only the version line plus
// the a=ice-ufrag and a=ice-pwd attributes that ParseIceUfragPwd looks for.
const sampleSDPOffer = "v=0\r\na=ice-ufrag:remote-ufrag\r\na=ice-pwd:remote-pwd-very-long-value-32xx\r\n"
|
|
|
|
// sampleSDPAnswer returns an SDP-like answer (CRLF-terminated lines) with
// local ICE credentials and a single host candidate at 1.2.3.4 on the given
// port. Placing the backend's RTC port in the candidate lets tests exercise
// the proxy's port-rewrite path.
func sampleSDPAnswer(port string) string {
	lines := []string{
		"v=0",
		"a=ice-ufrag:local-ufrag",
		"a=ice-pwd:local-pwd-very-long-value-32xxxx",
		"a=candidate:1 1 udp 1 1.2.3.4 " + port + " typ host",
	}
	return strings.Join(lines, "\r\n") + "\r\n"
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// NewWebRTCProxyServer: constructor & defaults
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestNewWebRTCProxyServer_SetsDefaults(t *testing.T) {
|
|
srv := NewWebRTCProxyServer(&envfakes.FakeProxyEnvironment{}, &lbfakes.FakeOriginLoadBalancer{})
|
|
v := srv.(*webRTCProxyServer)
|
|
if v.listenUDP == nil {
|
|
t.Fatal("listenUDP should default to a non-nil factory")
|
|
}
|
|
if v.backendURL == nil {
|
|
t.Fatal("backendURL should default to a non-nil factory")
|
|
}
|
|
}
|
|
|
|
func TestNewWebRTCProxyServer_DefaultBackendURL_NoAPI(t *testing.T) {
|
|
srv := NewWebRTCProxyServer(&envfakes.FakeProxyEnvironment{}, &lbfakes.FakeOriginLoadBalancer{})
|
|
v := srv.(*webRTCProxyServer)
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "http://example.com/rtc/v1/whip/", strings.NewReader(""))
|
|
_, err := v.backendURL(&lb.OriginServer{IP: "10.0.0.1"}, req)
|
|
if err == nil || !strings.Contains(err.Error(), "no http api server") {
|
|
t.Fatalf("expected no-api error, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestNewWebRTCProxyServer_DefaultBackendURL_BadPort(t *testing.T) {
|
|
srv := NewWebRTCProxyServer(&envfakes.FakeProxyEnvironment{}, &lbfakes.FakeOriginLoadBalancer{})
|
|
v := srv.(*webRTCProxyServer)
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "http://example.com/x", strings.NewReader(""))
|
|
_, err := v.backendURL(&lb.OriginServer{IP: "10.0.0.1", API: []string{"not-a-port"}}, req)
|
|
if err == nil || !strings.Contains(err.Error(), "parse http port") {
|
|
t.Fatalf("expected parse-port error, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestNewWebRTCProxyServer_DefaultBackendURL_Success(t *testing.T) {
|
|
srv := NewWebRTCProxyServer(&envfakes.FakeProxyEnvironment{}, &lbfakes.FakeOriginLoadBalancer{})
|
|
v := srv.(*webRTCProxyServer)
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "http://example.com/rtc/v1/whip/?app=live&stream=demo", strings.NewReader(""))
|
|
got, err := v.backendURL(&lb.OriginServer{IP: "10.0.0.1", API: []string{"1985"}}, req)
|
|
if err != nil {
|
|
t.Fatalf("backendURL: %v", err)
|
|
}
|
|
want := "http://10.0.0.1:1985/rtc/v1/whip/?app=live&stream=demo"
|
|
if got != want {
|
|
t.Fatalf("backendURL=%q, want %q", got, want)
|
|
}
|
|
}
|
|
|
|
func TestNewWebRTCProxyServer_DefaultBackendURL_NoQuery(t *testing.T) {
|
|
// When the inbound request has no raw query, the URL must not get a
|
|
// dangling "?" appended.
|
|
srv := NewWebRTCProxyServer(&envfakes.FakeProxyEnvironment{}, &lbfakes.FakeOriginLoadBalancer{})
|
|
v := srv.(*webRTCProxyServer)
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "http://example.com/rtc/v1/whep/", strings.NewReader(""))
|
|
got, err := v.backendURL(&lb.OriginServer{IP: "10.0.0.1", API: []string{"1985"}}, req)
|
|
if err != nil {
|
|
t.Fatalf("backendURL: %v", err)
|
|
}
|
|
want := "http://10.0.0.1:1985/rtc/v1/whep/"
|
|
if got != want {
|
|
t.Fatalf("backendURL=%q, want %q", got, want)
|
|
}
|
|
}
|
|
|
|
func TestNewWebRTCProxyServer_AppliesOptions(t *testing.T) {
|
|
var listenCalls, backendCalls atomic.Int32
|
|
srv := NewWebRTCProxyServer(
|
|
&envfakes.FakeProxyEnvironment{},
|
|
&lbfakes.FakeOriginLoadBalancer{},
|
|
func(v *webRTCProxyServer) {
|
|
v.listenUDP = func(ctx context.Context, endpoint string) (net.PacketConn, error) {
|
|
listenCalls.Add(1)
|
|
return nil, errors.New("unused")
|
|
}
|
|
v.backendURL = func(backend *lb.OriginServer, r *http.Request) (string, error) {
|
|
backendCalls.Add(1)
|
|
return "http://example.test", nil
|
|
}
|
|
},
|
|
)
|
|
v := srv.(*webRTCProxyServer)
|
|
_, _ = v.listenUDP(context.Background(), ":0")
|
|
_, _ = v.backendURL(&lb.OriginServer{}, httptest.NewRequest(http.MethodGet, "/", nil))
|
|
if got := listenCalls.Load(); got != 1 {
|
|
t.Fatalf("custom listenUDP called %d times, want 1", got)
|
|
}
|
|
if got := backendCalls.Load(); got != 1 {
|
|
t.Fatalf("custom backendURL called %d times, want 1", got)
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// webRTCProxyServer.Close
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestWebRTCProxyServer_Close_NilListener(t *testing.T) {
|
|
// Close before Run must not panic, must not hang, and must not error.
|
|
srv := NewWebRTCProxyServer(&envfakes.FakeProxyEnvironment{}, &lbfakes.FakeOriginLoadBalancer{})
|
|
done := make(chan error, 1)
|
|
go func() { done <- srv.Close() }()
|
|
select {
|
|
case err := <-done:
|
|
if err != nil {
|
|
t.Fatalf("Close: %v", err)
|
|
}
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("Close hung with no listener")
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// webRTCProxyServer.Run
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestWebRTCProxyServer_Run_ListenError(t *testing.T) {
|
|
envFake := &envfakes.FakeProxyEnvironment{}
|
|
envFake.WebRTCServerReturns("18000")
|
|
srv := NewWebRTCProxyServer(envFake, &lbfakes.FakeOriginLoadBalancer{}, func(v *webRTCProxyServer) {
|
|
v.listenUDP = func(ctx context.Context, endpoint string) (net.PacketConn, error) {
|
|
return nil, errors.New("permission denied")
|
|
}
|
|
})
|
|
|
|
err := srv.Run(context.Background())
|
|
if err == nil {
|
|
t.Fatal("expected error from Run when listenUDP fails")
|
|
}
|
|
if !strings.Contains(err.Error(), "listen udp") {
|
|
t.Fatalf("unexpected error %v", err)
|
|
}
|
|
}
|
|
|
|
func TestWebRTCProxyServer_Run_EndpointWithoutColon(t *testing.T) {
|
|
// A bare port like "18000" must be normalized to ":18000".
|
|
envFake := &envfakes.FakeProxyEnvironment{}
|
|
envFake.WebRTCServerReturns("18000")
|
|
listener := newBlockingUDPListener()
|
|
var captured atomic.Value
|
|
srv := NewWebRTCProxyServer(envFake, &lbfakes.FakeOriginLoadBalancer{}, func(v *webRTCProxyServer) {
|
|
v.listenUDP = func(ctx context.Context, endpoint string) (net.PacketConn, error) {
|
|
captured.Store(endpoint)
|
|
return listener, nil
|
|
}
|
|
})
|
|
|
|
if err := srv.Run(context.Background()); err != nil {
|
|
t.Fatalf("Run: %v", err)
|
|
}
|
|
defer srv.Close()
|
|
|
|
if got := captured.Load(); got != ":18000" {
|
|
t.Fatalf("listenUDP endpoint=%v, want :18000", got)
|
|
}
|
|
}
|
|
|
|
func TestWebRTCProxyServer_Run_EndpointWithColon(t *testing.T) {
|
|
// An endpoint that already contains ":" must be passed through unchanged.
|
|
envFake := &envfakes.FakeProxyEnvironment{}
|
|
envFake.WebRTCServerReturns("127.0.0.1:18000")
|
|
listener := newBlockingUDPListener()
|
|
var captured atomic.Value
|
|
srv := NewWebRTCProxyServer(envFake, &lbfakes.FakeOriginLoadBalancer{}, func(v *webRTCProxyServer) {
|
|
v.listenUDP = func(ctx context.Context, endpoint string) (net.PacketConn, error) {
|
|
captured.Store(endpoint)
|
|
return listener, nil
|
|
}
|
|
})
|
|
|
|
if err := srv.Run(context.Background()); err != nil {
|
|
t.Fatalf("Run: %v", err)
|
|
}
|
|
defer srv.Close()
|
|
|
|
if got := captured.Load(); got != "127.0.0.1:18000" {
|
|
t.Fatalf("listenUDP endpoint=%v, want 127.0.0.1:18000", got)
|
|
}
|
|
}
|
|
|
|
func TestWebRTCProxyServer_Run_CloseStopsReadLoop(t *testing.T) {
|
|
// Start Run with an idle listener (no packets queued). The read goroutine
|
|
// blocks in ReadFrom. Close must unblock it via the "closed network
|
|
// connection" error and allow the wait group to drain.
|
|
f := newWebRTCFixture()
|
|
if err := f.server.Run(context.Background()); err != nil {
|
|
t.Fatalf("Run: %v", err)
|
|
}
|
|
|
|
done := make(chan error, 1)
|
|
go func() { done <- f.server.Close() }()
|
|
select {
|
|
case err := <-done:
|
|
if err != nil {
|
|
t.Fatalf("Close: %v", err)
|
|
}
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("Close hung — read loop did not exit on listener close")
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// webRTCProxyServer.HandleApiForWHIP / HandleApiForWHEP
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestWebRTCProxyServer_HandleApiForWHIP_CORSPreflight(t *testing.T) {
|
|
// OPTIONS short-circuits before reading the body, so the LB is untouched.
|
|
f := newWebRTCFixture()
|
|
req := httptest.NewRequest(http.MethodOptions, "http://example.com/rtc/v1/whip/", nil)
|
|
rec := httptest.NewRecorder()
|
|
|
|
if err := f.server.HandleApiForWHIP(context.Background(), rec, req); err != nil {
|
|
t.Fatalf("WHIP: %v", err)
|
|
}
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("status=%d, want 200", rec.Code)
|
|
}
|
|
if f.lb.PickCallCount() != 0 {
|
|
t.Fatal("LB.Pick should not be called for CORS preflight")
|
|
}
|
|
}
|
|
|
|
func TestWebRTCProxyServer_HandleApiForWHEP_CORSPreflight(t *testing.T) {
|
|
f := newWebRTCFixture()
|
|
req := httptest.NewRequest(http.MethodOptions, "http://example.com/rtc/v1/whep/", nil)
|
|
rec := httptest.NewRecorder()
|
|
|
|
if err := f.server.HandleApiForWHEP(context.Background(), rec, req); err != nil {
|
|
t.Fatalf("WHEP: %v", err)
|
|
}
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("status=%d, want 200", rec.Code)
|
|
}
|
|
if f.lb.PickCallCount() != 0 {
|
|
t.Fatal("LB.Pick should not be called for CORS preflight")
|
|
}
|
|
}
|
|
|
|
func TestWebRTCProxyServer_HandleApiForWHIP_PickError(t *testing.T) {
|
|
f := newWebRTCFixture()
|
|
f.lb.PickReturns(nil, errors.New("no backend"))
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "http://example.com/rtc/v1/whip/?app=live&stream=demo", strings.NewReader(sampleSDPOffer))
|
|
rec := httptest.NewRecorder()
|
|
|
|
err := f.server.HandleApiForWHIP(context.Background(), rec, req)
|
|
if err == nil || !strings.Contains(err.Error(), "pick backend") {
|
|
t.Fatalf("expected pick-backend error, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestWebRTCProxyServer_HandleApiForWHEP_PickError(t *testing.T) {
|
|
f := newWebRTCFixture()
|
|
f.lb.PickReturns(nil, errors.New("no backend"))
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "http://example.com/rtc/v1/whep/?app=live&stream=demo", strings.NewReader(sampleSDPOffer))
|
|
rec := httptest.NewRecorder()
|
|
|
|
err := f.server.HandleApiForWHEP(context.Background(), rec, req)
|
|
if err == nil || !strings.Contains(err.Error(), "pick backend") {
|
|
t.Fatalf("expected pick-backend error, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestWebRTCProxyServer_HandleApiForWHIP_HappyPath(t *testing.T) {
|
|
// Drive a full WHIP exchange: the proxy forwards the offer to an httptest
|
|
// backend, rewrites the UDP port in the answer, and calls StoreWebRTC.
|
|
f := newWebRTCFixture()
|
|
f.env.WebRTCServerReturns("19000")
|
|
|
|
const backendRTCPort = "18000"
|
|
var backendSawOffer atomic.Bool
|
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
body, _ := io.ReadAll(r.Body)
|
|
if string(body) == sampleSDPOffer {
|
|
backendSawOffer.Store(true)
|
|
}
|
|
w.WriteHeader(http.StatusCreated)
|
|
_, _ = w.Write([]byte(sampleSDPAnswer(backendRTCPort)))
|
|
}))
|
|
defer backend.Close()
|
|
|
|
// Override backendURL so the proxy talks to the httptest server.
|
|
f.server.backendURL = func(b *lb.OriginServer, r *http.Request) (string, error) {
|
|
return backend.URL + r.URL.Path, nil
|
|
}
|
|
|
|
f.lb.PickReturns(&lb.OriginServer{IP: "10.0.0.1", API: []string{"1985"}, RTC: []string{backendRTCPort}}, nil)
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "http://example.com/rtc/v1/whip/?app=live&stream=demo", strings.NewReader(sampleSDPOffer))
|
|
rec := httptest.NewRecorder()
|
|
|
|
if err := f.server.HandleApiForWHIP(context.Background(), rec, req); err != nil {
|
|
t.Fatalf("WHIP: %v", err)
|
|
}
|
|
if !backendSawOffer.Load() {
|
|
t.Fatal("backend did not receive the SDP offer body")
|
|
}
|
|
if rec.Code != http.StatusCreated {
|
|
t.Fatalf("client status=%d, want 201", rec.Code)
|
|
}
|
|
body := rec.Body.String()
|
|
if !strings.Contains(body, " 19000 typ host") {
|
|
t.Fatalf("answer did not rewrite backend port; got %q", body)
|
|
}
|
|
if strings.Contains(body, " "+backendRTCPort+" typ host") {
|
|
t.Fatalf("answer still contains original backend port; got %q", body)
|
|
}
|
|
if f.lb.StoreWebRTCCallCount() != 1 {
|
|
t.Fatalf("StoreWebRTC called %d times, want 1", f.lb.StoreWebRTCCallCount())
|
|
}
|
|
_, streamURL, stored := f.lb.StoreWebRTCArgsForCall(0)
|
|
if !strings.HasSuffix(streamURL, "/live/demo") {
|
|
t.Fatalf("StoreWebRTC streamURL=%q, want suffix /live/demo", streamURL)
|
|
}
|
|
if got := stored.GetUfrag(); got != "local-ufrag:remote-ufrag" {
|
|
t.Fatalf("stored ufrag=%q, want local-ufrag:remote-ufrag", got)
|
|
}
|
|
}
|
|
|
|
func TestWebRTCProxyServer_HandleApiForWHEP_HappyPath(t *testing.T) {
|
|
f := newWebRTCFixture()
|
|
f.env.WebRTCServerReturns("19000")
|
|
|
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, _ = io.ReadAll(r.Body)
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = w.Write([]byte(sampleSDPAnswer("18000")))
|
|
}))
|
|
defer backend.Close()
|
|
|
|
f.server.backendURL = func(b *lb.OriginServer, r *http.Request) (string, error) {
|
|
return backend.URL + r.URL.Path, nil
|
|
}
|
|
f.lb.PickReturns(&lb.OriginServer{IP: "10.0.0.1", API: []string{"1985"}, RTC: []string{"18000"}}, nil)
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "http://example.com/rtc/v1/whep/?app=live&stream=demo", strings.NewReader(sampleSDPOffer))
|
|
rec := httptest.NewRecorder()
|
|
|
|
if err := f.server.HandleApiForWHEP(context.Background(), rec, req); err != nil {
|
|
t.Fatalf("WHEP: %v", err)
|
|
}
|
|
if f.lb.StoreWebRTCCallCount() != 1 {
|
|
t.Fatalf("StoreWebRTC called %d times, want 1", f.lb.StoreWebRTCCallCount())
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// webRTCProxyServer.proxyApiToBackend: error paths
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestWebRTCProxyServer_ProxyApiToBackend_BackendURLError(t *testing.T) {
|
|
f := newWebRTCFixture()
|
|
f.server.backendURL = func(b *lb.OriginServer, r *http.Request) (string, error) {
|
|
return "", errors.New("build err")
|
|
}
|
|
f.lb.PickReturns(&lb.OriginServer{IP: "10.0.0.1", API: []string{"1985"}, RTC: []string{"18000"}}, nil)
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "http://example.com/rtc/v1/whip/?app=live&stream=demo", strings.NewReader(sampleSDPOffer))
|
|
rec := httptest.NewRecorder()
|
|
|
|
err := f.server.HandleApiForWHIP(context.Background(), rec, req)
|
|
if err == nil || !strings.Contains(err.Error(), "build err") {
|
|
t.Fatalf("expected build err, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestWebRTCProxyServer_ProxyApiToBackend_BackendNon200(t *testing.T) {
|
|
f := newWebRTCFixture()
|
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusBadGateway)
|
|
}))
|
|
defer backend.Close()
|
|
|
|
f.server.backendURL = func(b *lb.OriginServer, r *http.Request) (string, error) {
|
|
return backend.URL + r.URL.Path, nil
|
|
}
|
|
f.lb.PickReturns(&lb.OriginServer{IP: "10.0.0.1", API: []string{"1985"}, RTC: []string{"18000"}}, nil)
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "http://example.com/rtc/v1/whip/?app=live&stream=demo", strings.NewReader(sampleSDPOffer))
|
|
rec := httptest.NewRecorder()
|
|
|
|
err := f.server.HandleApiForWHIP(context.Background(), rec, req)
|
|
if err == nil || !strings.Contains(err.Error(), "proxy api to") {
|
|
t.Fatalf("expected proxy-api error, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestWebRTCProxyServer_ProxyApiToBackend_BadAnswerNoIceUfrag(t *testing.T) {
|
|
// Backend returns an answer missing the ice-ufrag/pwd attributes; the
|
|
// proxy must surface the ParseIceUfragPwd error rather than calling
|
|
// StoreWebRTC.
|
|
f := newWebRTCFixture()
|
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = w.Write([]byte("v=0\r\n"))
|
|
}))
|
|
defer backend.Close()
|
|
|
|
f.server.backendURL = func(b *lb.OriginServer, r *http.Request) (string, error) {
|
|
return backend.URL + r.URL.Path, nil
|
|
}
|
|
f.lb.PickReturns(&lb.OriginServer{IP: "10.0.0.1", API: []string{"1985"}, RTC: []string{"18000"}}, nil)
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "http://example.com/rtc/v1/whip/?app=live&stream=demo", strings.NewReader(sampleSDPOffer))
|
|
rec := httptest.NewRecorder()
|
|
|
|
err := f.server.HandleApiForWHIP(context.Background(), rec, req)
|
|
if err == nil || !strings.Contains(err.Error(), "parse local sdp answer") {
|
|
t.Fatalf("expected parse-answer error, got %v", err)
|
|
}
|
|
if f.lb.StoreWebRTCCallCount() != 0 {
|
|
t.Fatal("StoreWebRTC should not be called when answer is malformed")
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// webRTCProxyServer.handleClientUDP
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestWebRTCProxyServer_HandleClientUDP_NonStunIgnored(t *testing.T) {
|
|
// A non-STUN, non-RTP/RTCP packet with no cached connection must return
|
|
// without touching the LB.
|
|
f := newWebRTCFixture()
|
|
addr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 7000}
|
|
|
|
if err := f.server.handleClientUDP(context.Background(), addr, fakeNonStunPacket()); err != nil {
|
|
t.Fatalf("handleClientUDP: %v", err)
|
|
}
|
|
if f.lb.LoadWebRTCByUfragCallCount() != 0 {
|
|
t.Fatal("LB.LoadWebRTCByUfrag should not be called for non-STUN packet")
|
|
}
|
|
}
|
|
|
|
func TestWebRTCProxyServer_HandleClientUDP_RTPLikeIgnored(t *testing.T) {
|
|
// An RTP-like packet (first byte 0x80) skips STUN parsing entirely; the
|
|
// LB must not be consulted because no connection lookup happens.
|
|
f := newWebRTCFixture()
|
|
addr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 7001}
|
|
|
|
if err := f.server.handleClientUDP(context.Background(), addr, fakeRTPPacket()); err != nil {
|
|
t.Fatalf("handleClientUDP: %v", err)
|
|
}
|
|
if f.lb.LoadWebRTCByUfragCallCount() != 0 {
|
|
t.Fatal("LB.LoadWebRTCByUfrag should not be called for RTP-like packet")
|
|
}
|
|
}
|
|
|
|
func TestWebRTCProxyServer_HandleClientUDP_StunBadPacket(t *testing.T) {
|
|
// A short payload that satisfies utils.RtcIsSTUN (first byte 0x00) but
|
|
// is shorter than the 20-byte STUN header should surface the
|
|
// unmarshaler's "too short" error.
|
|
f := newWebRTCFixture()
|
|
addr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 7002}
|
|
|
|
err := f.server.handleClientUDP(context.Background(), addr, []byte{0x00, 0x00, 0x00})
|
|
if err == nil || !strings.Contains(err.Error(), "stun packet too short") {
|
|
t.Fatalf("expected too-short err, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestWebRTCProxyServer_HandleClientUDP_StunCachedUsername(t *testing.T) {
|
|
// A STUN packet whose USERNAME matches a connection already in the
|
|
// username cache must route directly to that connection. We pre-wire
|
|
// the connection so its load balancer fails Pick, so HandlePacket exits
|
|
// quickly with a recognizable error and we can assert routing.
|
|
f := newWebRTCFixture()
|
|
|
|
cachedLB := &lbfakes.FakeOriginLoadBalancer{}
|
|
cachedLB.PickReturns(nil, errors.New("test terminate"))
|
|
cached := newRTCConnection(func(c *rtcConnection) {
|
|
c.loadBalancer = cachedLB
|
|
c.StreamURL = "vhost/app/stream"
|
|
c.Ufrag = "L:R"
|
|
})
|
|
cached.Initialize(context.Background(), f.listener)
|
|
f.server.usernames.Store("L:R", cached)
|
|
|
|
addr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 7003}
|
|
err := f.server.handleClientUDP(context.Background(), addr, newStunBindingRequest("L:R"))
|
|
if err == nil || !strings.Contains(err.Error(), "test terminate") {
|
|
t.Fatalf("expected terminate err, got %v", err)
|
|
}
|
|
// The address cache must have learned this addr.
|
|
if _, ok := f.server.addresses.Load(addr.String()); !ok {
|
|
t.Fatal("expected addr to be cached after routing via username")
|
|
}
|
|
if f.lb.LoadWebRTCByUfragCallCount() != 0 {
|
|
t.Fatal("LB.LoadWebRTCByUfrag should not be called when cached")
|
|
}
|
|
}
|
|
|
|
func TestWebRTCProxyServer_HandleClientUDP_StunLoadsFromLB(t *testing.T) {
|
|
// STUN packet whose USERNAME is not in the cache: the proxy must consult
|
|
// the load balancer, cache the returned connection by username, and then
|
|
// dispatch to it. handleClientUDP rewires the loaded connection's
|
|
// loadBalancer to the server's LB, so we make f.lb.Pick fail to keep the
|
|
// HandlePacket call deterministic.
|
|
f := newWebRTCFixture()
|
|
f.lb.PickReturns(nil, errors.New("test terminate"))
|
|
|
|
loaded := newRTCConnection(func(c *rtcConnection) {
|
|
c.StreamURL = "vhost/app/stream"
|
|
c.Ufrag = "L:R"
|
|
})
|
|
f.lb.LoadWebRTCByUfragReturns(loaded, nil)
|
|
|
|
addr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 7004}
|
|
err := f.server.handleClientUDP(context.Background(), addr, newStunBindingRequest("L:R"))
|
|
if err == nil || !strings.Contains(err.Error(), "test terminate") {
|
|
t.Fatalf("expected terminate err, got %v", err)
|
|
}
|
|
if got := f.lb.LoadWebRTCByUfragCallCount(); got != 1 {
|
|
t.Fatalf("LoadWebRTCByUfrag called %d times, want 1", got)
|
|
}
|
|
if _, ok := f.server.usernames.Load("L:R"); !ok {
|
|
t.Fatal("expected username to be cached after LB load")
|
|
}
|
|
// The loaded connection should have been rewired to use the server's LB.
|
|
if loaded.loadBalancer != f.lb {
|
|
t.Fatal("loaded connection should adopt the server's load balancer")
|
|
}
|
|
}
|
|
|
|
func TestWebRTCProxyServer_HandleClientUDP_StunLBError(t *testing.T) {
|
|
// LB.LoadWebRTCByUfrag failure must surface as a wrapped error.
|
|
f := newWebRTCFixture()
|
|
f.lb.LoadWebRTCByUfragReturns(nil, errors.New("lookup failed"))
|
|
|
|
addr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 7005}
|
|
err := f.server.handleClientUDP(context.Background(), addr, newStunBindingRequest("missing"))
|
|
if err == nil || !strings.Contains(err.Error(), "load webrtc by ufrag") {
|
|
t.Fatalf("expected load-webrtc err, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestWebRTCProxyServer_HandleClientUDP_UsesCachedAddress(t *testing.T) {
|
|
// A non-STUN packet from an address already in the address cache must be
|
|
// dispatched to the cached connection without consulting the LB.
|
|
f := newWebRTCFixture()
|
|
cachedLB := &lbfakes.FakeOriginLoadBalancer{}
|
|
cachedLB.PickReturns(nil, errors.New("test terminate"))
|
|
cached := newRTCConnection(func(c *rtcConnection) {
|
|
c.loadBalancer = cachedLB
|
|
c.StreamURL = "vhost/app/stream"
|
|
c.Ufrag = "L:R"
|
|
})
|
|
cached.Initialize(context.Background(), f.listener)
|
|
|
|
addr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 7006}
|
|
f.server.addresses.Store(addr.String(), cached)
|
|
|
|
err := f.server.handleClientUDP(context.Background(), addr, fakeRTPPacket())
|
|
if err == nil || !strings.Contains(err.Error(), "test terminate") {
|
|
t.Fatalf("expected terminate err, got %v", err)
|
|
}
|
|
if f.lb.LoadWebRTCByUfragCallCount() != 0 {
|
|
t.Fatal("LB.LoadWebRTCByUfrag should not be called when address is cached")
|
|
}
|
|
}
|