Claude: Add WebRTC proxy seams and unit tests.

Introduce listenUDP and backendURL functional-option seams on
webRTCProxyServer and a dialBackendUDP seam on rtcConnection, mirroring
the pattern already used by rtmpProxyServer. The seams default to the
real net.ListenUDP / http URL builder / net.Dialer so production
behavior is unchanged, but unit tests can now inject fakes.

Cover webRTCProxyServer with focused tests: constructor defaults
(including the three default-backendURL branches), Close with no
listener, Run's listen error / endpoint normalization / graceful
shutdown, HandleApiForWHIP and HandleApiForWHEP CORS preflight, Pick
error, full happy-path against an httptest backend asserting SDP port
rewrite and StoreWebRTC wiring, proxyApiToBackend error paths
(backendURL error, non-2xx, malformed answer), and handleClientUDP's
non-STUN, RTP-like, short-STUN, cached-username, LB-load, LB-error,
and cached-address paths. internal/proxy package coverage rises from
~23% to 43.4%.

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
winlin 2026-05-16 19:21:34 -04:00
parent 7b4c4dc999
commit 953b0d63ca
2 changed files with 1193 additions and 46 deletions

View File

@ -7,6 +7,7 @@ import (
"context"
"encoding/binary"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
@ -38,8 +39,9 @@ type webRTCProxyServer struct {
environment env.ProxyEnvironment
// The load balancer for origin servers.
loadBalancer lb.OriginLoadBalancer
// The UDP listener for WebRTC server.
listener *net.UDPConn
// The UDP listener for WebRTC server. Stored as net.PacketConn so tests
// can inject a fake listener via listenUDP.
listener net.PacketConn
// Fast cache for the username to identify the connection.
// The key is username, the value is the UDP address.
@ -51,6 +53,16 @@ type webRTCProxyServer struct {
// The wait group for server.
wg stdSync.WaitGroup
// backendURL builds the URL to forward a WHIP/WHEP SDP exchange to a backend
// SRS server. Defaults to "http://<ip>:<api-port><path>?<query>"; tests may
// override to redirect requests to an httptest.Server.
backendURL func(backend *lb.OriginServer, r *http.Request) (string, error)
// listenUDP opens the UDP listener for the WebRTC server. Defaults to a real
// net.ListenUDP on the resolved endpoint; tests may override via a functional
// option to supply a fake listener.
listenUDP func(ctx context.Context, endpoint string) (net.PacketConn, error)
}
func NewWebRTCProxyServer(environment env.ProxyEnvironment, loadBalancer lb.OriginLoadBalancer, opts ...func(*webRTCProxyServer)) WebRTCProxyServer {
@ -60,6 +72,33 @@ func NewWebRTCProxyServer(environment env.ProxyEnvironment, loadBalancer lb.Orig
usernames: sync.NewMap[string, *rtcConnection](),
addresses: sync.NewMap[string, *rtcConnection](),
}
// Default listenUDP: resolve the endpoint and open a real UDP socket.
v.listenUDP = func(ctx context.Context, endpoint string) (net.PacketConn, error) {
saddr, err := net.ResolveUDPAddr("udp", endpoint)
if err != nil {
return nil, errors.Wrapf(err, "resolve udp addr %v", endpoint)
}
return net.ListenUDP("udp", saddr)
}
// Default backendURL: validate API endpoint, parse port, format URL preserving
// the inbound request's path and raw query.
v.backendURL = func(backend *lb.OriginServer, r *http.Request) (string, error) {
if len(backend.API) == 0 {
return "", errors.Errorf("no http api server")
}
apiPort, err := strconv.ParseInt(backend.API[0], 10, 64)
if err != nil {
return "", errors.Wrapf(err, "parse http port %v", backend.API[0])
}
u := fmt.Sprintf("http://%v:%v%s", backend.IP, apiPort, r.URL.Path)
if r.URL.RawQuery != "" {
u += "?" + r.URL.RawQuery
}
return u, nil
}
for _, opt := range opts {
opt(v)
}
@ -85,7 +124,7 @@ func (v *webRTCProxyServer) HandleApiForWHIP(ctx context.Context, w http.Respons
}
// Read remote SDP offer from body.
remoteSDPOffer, err := ioutil.ReadAll(r.Body)
remoteSDPOffer, err := io.ReadAll(r.Body)
if err != nil {
return errors.Wrapf(err, "read remote sdp offer")
}
@ -122,7 +161,7 @@ func (v *webRTCProxyServer) HandleApiForWHEP(ctx context.Context, w http.Respons
}
// Read remote SDP offer from body.
remoteSDPOffer, err := ioutil.ReadAll(r.Body)
remoteSDPOffer, err := io.ReadAll(r.Body)
if err != nil {
return errors.Wrapf(err, "read remote sdp offer")
}
@ -153,22 +192,11 @@ func (v *webRTCProxyServer) proxyApiToBackend(
ctx context.Context, w http.ResponseWriter, r *http.Request, backend *lb.OriginServer,
remoteSDPOffer string, streamURL string,
) error {
// Parse HTTP port from backend.
if len(backend.API) == 0 {
return errors.Errorf("no http api server")
}
var apiPort int
if iv, err := strconv.ParseInt(backend.API[0], 10, 64); err != nil {
return errors.Wrapf(err, "parse http port %v", backend.API[0])
} else {
apiPort = int(iv)
}
// Connect to backend SRS server via HTTP client.
backendURL := fmt.Sprintf("http://%v:%v%s", backend.IP, apiPort, r.URL.Path)
if r.URL.RawQuery != "" {
backendURL += "?" + r.URL.RawQuery
// Resolve the backend URL via the configurable seam (so tests can redirect to
// an httptest.Server).
backendURL, err := v.backendURL(backend, r)
if err != nil {
return errors.Wrapf(err, "build backend url")
}
req, err := http.NewRequestWithContext(ctx, r.Method, backendURL, strings.NewReader(remoteSDPOffer))
@ -257,17 +285,12 @@ func (v *webRTCProxyServer) Run(ctx context.Context) error {
endpoint = fmt.Sprintf(":%v", endpoint)
}
saddr, err := net.ResolveUDPAddr("udp", endpoint)
listener, err := v.listenUDP(ctx, endpoint)
if err != nil {
return errors.Wrapf(err, "resolve udp addr %v", endpoint)
}
listener, err := net.ListenUDP("udp", saddr)
if err != nil {
return errors.Wrapf(err, "listen udp %v", saddr)
return errors.Wrapf(err, "listen udp %v", endpoint)
}
v.listener = listener
logger.Debug(ctx, "WebRTC server listen at %v", saddr)
logger.Debug(ctx, "WebRTC server listen at %v", listener.LocalAddr())
// Consume all messages from UDP media transport.
v.wg.Add(1)
@ -276,7 +299,7 @@ func (v *webRTCProxyServer) Run(ctx context.Context) error {
for ctx.Err() == nil {
buf := make([]byte, 4096)
n, caddr, err := listener.ReadFromUDP(buf)
n, addr, err := listener.ReadFrom(buf)
if err != nil {
// If context is canceled or connection is closed, exit gracefully without logging error.
if ctx.Err() != nil || utils.IsClosedNetworkError(err) {
@ -289,8 +312,8 @@ func (v *webRTCProxyServer) Run(ctx context.Context) error {
continue
}
if err := v.handleClientUDP(ctx, caddr, buf[:n]); err != nil {
logger.Warn(ctx, "WebRTC handle udp %vB failed, addr=%v, err=%+v", n, caddr, err)
if err := v.handleClientUDP(ctx, addr, buf[:n]); err != nil {
logger.Warn(ctx, "WebRTC handle udp %vB failed, addr=%v, err=%+v", n, addr, err)
}
}
}()
@ -298,7 +321,7 @@ func (v *webRTCProxyServer) Run(ctx context.Context) error {
return nil
}
func (v *webRTCProxyServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr, data []byte) error {
func (v *webRTCProxyServer) handleClientUDP(ctx context.Context, addr net.Addr, data []byte) error {
var connection *rtcConnection
// If STUN binding request, parse the ufrag and identify the connection.
@ -379,23 +402,37 @@ type rtcConnection struct {
// The ufrag for this WebRTC connection.
Ufrag string `json:"ufrag"`
// The UDP connection proxy to backend.
backendUDP *net.UDPConn
// The UDP connection proxy to backend. Stored as io.ReadWriteCloser so tests
// can inject a fake connection by overriding dialBackendUDP.
backendUDP io.ReadWriteCloser
// The client UDP address. Note that it may change.
clientUDP *net.UDPAddr
// The listener UDP connection, used to send messages to client.
listenerUDP *net.UDPConn
clientUDP net.Addr
// The listener UDP connection, used to send messages to client. Stored as
// net.PacketConn so tests can inject a fake listener.
listenerUDP net.PacketConn
// dialBackendUDP opens a UDP connection to a backend SRS server. Defaults to a real
// UDP dial; tests may override via a functional option to supply a fake connection.
dialBackendUDP func(ctx context.Context, ip string, port int) (io.ReadWriteCloser, error)
}
func newRTCConnection(opts ...func(*rtcConnection)) *rtcConnection {
v := &rtcConnection{}
// Default dial: a real UDP connection to the backend. Uses Dialer.DialContext
// so ctx cancellation/deadline aborts DNS resolution (UDP itself has no handshake).
v.dialBackendUDP = func(ctx context.Context, ip string, port int) (io.ReadWriteCloser, error) {
var d net.Dialer
return d.DialContext(ctx, "udp", net.JoinHostPort(ip, strconv.Itoa(port)))
}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *rtcConnection) Initialize(ctx context.Context, listener *net.UDPConn) *rtcConnection {
func (v *rtcConnection) Initialize(ctx context.Context, listener net.PacketConn) *rtcConnection {
if v.ctx == nil {
v.ctx = logger.WithContext(ctx)
}
@ -409,7 +446,7 @@ func (v *rtcConnection) GetUfrag() string {
return v.Ufrag
}
func (v *rtcConnection) HandlePacket(addr *net.UDPAddr, data []byte) error {
func (v *rtcConnection) HandlePacket(addr net.Addr, data []byte) error {
ctx := v.ctx
// Update the current UDP address.
@ -429,14 +466,14 @@ func (v *rtcConnection) HandlePacket(addr *net.UDPAddr, data []byte) error {
go func() {
for ctx.Err() == nil {
buf := make([]byte, 4096)
n, _, err := v.backendUDP.ReadFromUDP(buf)
n, err := v.backendUDP.Read(buf)
if err != nil {
// TODO: If backend server closed unexpectedly, we should notice the stream to quit.
logger.Warn(ctx, "read from backend failed, err=%v", err)
break
}
if _, err = v.listenerUDP.WriteToUDP(buf[:n], v.clientUDP); err != nil {
if _, err = v.listenerUDP.WriteTo(buf[:n], v.clientUDP); err != nil {
// TODO: If backend server closed unexpectedly, we should notice the stream to quit.
logger.Warn(ctx, "write to client failed, err=%v", err)
break
@ -474,12 +511,11 @@ func (v *rtcConnection) connectBackend(ctx context.Context) error {
// Connect to backend SRS server via UDP client.
// TODO: FIXME: Support close the connection when timeout or DTLS alert.
backendAddr := net.UDPAddr{IP: net.ParseIP(backend.IP), Port: int(udpPort)}
if backendUDP, err := net.DialUDP("udp", nil, &backendAddr); err != nil {
return errors.Wrapf(err, "dial udp to %v", backendAddr)
} else {
v.backendUDP = backendUDP
backendUDP, err := v.dialBackendUDP(ctx, backend.IP, int(udpPort))
if err != nil {
return errors.Wrapf(err, "dial udp to %v:%v", backend.IP, udpPort)
}
v.backendUDP = backendUDP
return nil
}

1111
internal/proxy/rtc_test.go Normal file

File diff suppressed because it is too large Load Diff