Claude: Add WebRTC proxy seams and unit tests.
Introduce listenUDP and backendURL functional-option seams on webRTCProxyServer and a dialBackendUDP seam on rtcConnection, mirroring the pattern already used by rtmpProxyServer. The seams default to the real net.ListenUDP / http URL builder / net.Dialer so production behavior is unchanged, but unit tests can now inject fakes. Cover webRTCProxyServer with focused tests: constructor defaults (including the three default-backendURL branches), Close with no listener, Run's listen error / endpoint normalization / graceful shutdown, HandleApiForWHIP and HandleApiForWHEP CORS preflight, Pick error, full happy-path against an httptest backend asserting SDP port rewrite and StoreWebRTC wiring, proxyApiToBackend error paths (backendURL error, non-2xx, malformed answer), and handleClientUDP's non-STUN, RTP-like, short-STUN, cached-username, LB-load, LB-error, and cached-address paths. internal/proxy package coverage rises from ~23% to 43.4%. Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
7b4c4dc999
commit
953b0d63ca
|
|
@ -7,6 +7,7 @@ import (
|
|||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
|
|
@ -38,8 +39,9 @@ type webRTCProxyServer struct {
|
|||
environment env.ProxyEnvironment
|
||||
// The load balancer for origin servers.
|
||||
loadBalancer lb.OriginLoadBalancer
|
||||
// The UDP listener for WebRTC server.
|
||||
listener *net.UDPConn
|
||||
// The UDP listener for WebRTC server. Stored as net.PacketConn so tests
|
||||
// can inject a fake listener via listenUDP.
|
||||
listener net.PacketConn
|
||||
|
||||
// Fast cache for the username to identify the connection.
|
||||
// The key is username, the value is the UDP address.
|
||||
|
|
@ -51,6 +53,16 @@ type webRTCProxyServer struct {
|
|||
|
||||
// The wait group for server.
|
||||
wg stdSync.WaitGroup
|
||||
|
||||
// backendURL builds the URL to forward a WHIP/WHEP SDP exchange to a backend
|
||||
// SRS server. Defaults to "http://<ip>:<api-port><path>?<query>"; tests may
|
||||
// override to redirect requests to an httptest.Server.
|
||||
backendURL func(backend *lb.OriginServer, r *http.Request) (string, error)
|
||||
|
||||
// listenUDP opens the UDP listener for the WebRTC server. Defaults to a real
|
||||
// net.ListenUDP on the resolved endpoint; tests may override via a functional
|
||||
// option to supply a fake listener.
|
||||
listenUDP func(ctx context.Context, endpoint string) (net.PacketConn, error)
|
||||
}
|
||||
|
||||
func NewWebRTCProxyServer(environment env.ProxyEnvironment, loadBalancer lb.OriginLoadBalancer, opts ...func(*webRTCProxyServer)) WebRTCProxyServer {
|
||||
|
|
@ -60,6 +72,33 @@ func NewWebRTCProxyServer(environment env.ProxyEnvironment, loadBalancer lb.Orig
|
|||
usernames: sync.NewMap[string, *rtcConnection](),
|
||||
addresses: sync.NewMap[string, *rtcConnection](),
|
||||
}
|
||||
|
||||
// Default listenUDP: resolve the endpoint and open a real UDP socket.
|
||||
v.listenUDP = func(ctx context.Context, endpoint string) (net.PacketConn, error) {
|
||||
saddr, err := net.ResolveUDPAddr("udp", endpoint)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "resolve udp addr %v", endpoint)
|
||||
}
|
||||
return net.ListenUDP("udp", saddr)
|
||||
}
|
||||
|
||||
// Default backendURL: validate API endpoint, parse port, format URL preserving
|
||||
// the inbound request's path and raw query.
|
||||
v.backendURL = func(backend *lb.OriginServer, r *http.Request) (string, error) {
|
||||
if len(backend.API) == 0 {
|
||||
return "", errors.Errorf("no http api server")
|
||||
}
|
||||
apiPort, err := strconv.ParseInt(backend.API[0], 10, 64)
|
||||
if err != nil {
|
||||
return "", errors.Wrapf(err, "parse http port %v", backend.API[0])
|
||||
}
|
||||
u := fmt.Sprintf("http://%v:%v%s", backend.IP, apiPort, r.URL.Path)
|
||||
if r.URL.RawQuery != "" {
|
||||
u += "?" + r.URL.RawQuery
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(v)
|
||||
}
|
||||
|
|
@ -85,7 +124,7 @@ func (v *webRTCProxyServer) HandleApiForWHIP(ctx context.Context, w http.Respons
|
|||
}
|
||||
|
||||
// Read remote SDP offer from body.
|
||||
remoteSDPOffer, err := ioutil.ReadAll(r.Body)
|
||||
remoteSDPOffer, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "read remote sdp offer")
|
||||
}
|
||||
|
|
@ -122,7 +161,7 @@ func (v *webRTCProxyServer) HandleApiForWHEP(ctx context.Context, w http.Respons
|
|||
}
|
||||
|
||||
// Read remote SDP offer from body.
|
||||
remoteSDPOffer, err := ioutil.ReadAll(r.Body)
|
||||
remoteSDPOffer, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "read remote sdp offer")
|
||||
}
|
||||
|
|
@ -153,22 +192,11 @@ func (v *webRTCProxyServer) proxyApiToBackend(
|
|||
ctx context.Context, w http.ResponseWriter, r *http.Request, backend *lb.OriginServer,
|
||||
remoteSDPOffer string, streamURL string,
|
||||
) error {
|
||||
// Parse HTTP port from backend.
|
||||
if len(backend.API) == 0 {
|
||||
return errors.Errorf("no http api server")
|
||||
}
|
||||
|
||||
var apiPort int
|
||||
if iv, err := strconv.ParseInt(backend.API[0], 10, 64); err != nil {
|
||||
return errors.Wrapf(err, "parse http port %v", backend.API[0])
|
||||
} else {
|
||||
apiPort = int(iv)
|
||||
}
|
||||
|
||||
// Connect to backend SRS server via HTTP client.
|
||||
backendURL := fmt.Sprintf("http://%v:%v%s", backend.IP, apiPort, r.URL.Path)
|
||||
if r.URL.RawQuery != "" {
|
||||
backendURL += "?" + r.URL.RawQuery
|
||||
// Resolve the backend URL via the configurable seam (so tests can redirect to
|
||||
// an httptest.Server).
|
||||
backendURL, err := v.backendURL(backend, r)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "build backend url")
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, r.Method, backendURL, strings.NewReader(remoteSDPOffer))
|
||||
|
|
@ -257,17 +285,12 @@ func (v *webRTCProxyServer) Run(ctx context.Context) error {
|
|||
endpoint = fmt.Sprintf(":%v", endpoint)
|
||||
}
|
||||
|
||||
saddr, err := net.ResolveUDPAddr("udp", endpoint)
|
||||
listener, err := v.listenUDP(ctx, endpoint)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "resolve udp addr %v", endpoint)
|
||||
}
|
||||
|
||||
listener, err := net.ListenUDP("udp", saddr)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "listen udp %v", saddr)
|
||||
return errors.Wrapf(err, "listen udp %v", endpoint)
|
||||
}
|
||||
v.listener = listener
|
||||
logger.Debug(ctx, "WebRTC server listen at %v", saddr)
|
||||
logger.Debug(ctx, "WebRTC server listen at %v", listener.LocalAddr())
|
||||
|
||||
// Consume all messages from UDP media transport.
|
||||
v.wg.Add(1)
|
||||
|
|
@ -276,7 +299,7 @@ func (v *webRTCProxyServer) Run(ctx context.Context) error {
|
|||
|
||||
for ctx.Err() == nil {
|
||||
buf := make([]byte, 4096)
|
||||
n, caddr, err := listener.ReadFromUDP(buf)
|
||||
n, addr, err := listener.ReadFrom(buf)
|
||||
if err != nil {
|
||||
// If context is canceled or connection is closed, exit gracefully without logging error.
|
||||
if ctx.Err() != nil || utils.IsClosedNetworkError(err) {
|
||||
|
|
@ -289,8 +312,8 @@ func (v *webRTCProxyServer) Run(ctx context.Context) error {
|
|||
continue
|
||||
}
|
||||
|
||||
if err := v.handleClientUDP(ctx, caddr, buf[:n]); err != nil {
|
||||
logger.Warn(ctx, "WebRTC handle udp %vB failed, addr=%v, err=%+v", n, caddr, err)
|
||||
if err := v.handleClientUDP(ctx, addr, buf[:n]); err != nil {
|
||||
logger.Warn(ctx, "WebRTC handle udp %vB failed, addr=%v, err=%+v", n, addr, err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
|
@ -298,7 +321,7 @@ func (v *webRTCProxyServer) Run(ctx context.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (v *webRTCProxyServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr, data []byte) error {
|
||||
func (v *webRTCProxyServer) handleClientUDP(ctx context.Context, addr net.Addr, data []byte) error {
|
||||
var connection *rtcConnection
|
||||
|
||||
// If STUN binding request, parse the ufrag and identify the connection.
|
||||
|
|
@ -379,23 +402,37 @@ type rtcConnection struct {
|
|||
// The ufrag for this WebRTC connection.
|
||||
Ufrag string `json:"ufrag"`
|
||||
|
||||
// The UDP connection proxy to backend.
|
||||
backendUDP *net.UDPConn
|
||||
// The UDP connection proxy to backend. Stored as io.ReadWriteCloser so tests
|
||||
// can inject a fake connection by overriding dialBackendUDP.
|
||||
backendUDP io.ReadWriteCloser
|
||||
// The client UDP address. Note that it may change.
|
||||
clientUDP *net.UDPAddr
|
||||
// The listener UDP connection, used to send messages to client.
|
||||
listenerUDP *net.UDPConn
|
||||
clientUDP net.Addr
|
||||
// The listener UDP connection, used to send messages to client. Stored as
|
||||
// net.PacketConn so tests can inject a fake listener.
|
||||
listenerUDP net.PacketConn
|
||||
|
||||
// dialBackendUDP opens a UDP connection to a backend SRS server. Defaults to a real
|
||||
// UDP dial; tests may override via a functional option to supply a fake connection.
|
||||
dialBackendUDP func(ctx context.Context, ip string, port int) (io.ReadWriteCloser, error)
|
||||
}
|
||||
|
||||
func newRTCConnection(opts ...func(*rtcConnection)) *rtcConnection {
|
||||
v := &rtcConnection{}
|
||||
|
||||
// Default dial: a real UDP connection to the backend. Uses Dialer.DialContext
|
||||
// so ctx cancellation/deadline aborts DNS resolution (UDP itself has no handshake).
|
||||
v.dialBackendUDP = func(ctx context.Context, ip string, port int) (io.ReadWriteCloser, error) {
|
||||
var d net.Dialer
|
||||
return d.DialContext(ctx, "udp", net.JoinHostPort(ip, strconv.Itoa(port)))
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(v)
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func (v *rtcConnection) Initialize(ctx context.Context, listener *net.UDPConn) *rtcConnection {
|
||||
func (v *rtcConnection) Initialize(ctx context.Context, listener net.PacketConn) *rtcConnection {
|
||||
if v.ctx == nil {
|
||||
v.ctx = logger.WithContext(ctx)
|
||||
}
|
||||
|
|
@ -409,7 +446,7 @@ func (v *rtcConnection) GetUfrag() string {
|
|||
return v.Ufrag
|
||||
}
|
||||
|
||||
func (v *rtcConnection) HandlePacket(addr *net.UDPAddr, data []byte) error {
|
||||
func (v *rtcConnection) HandlePacket(addr net.Addr, data []byte) error {
|
||||
ctx := v.ctx
|
||||
|
||||
// Update the current UDP address.
|
||||
|
|
@ -429,14 +466,14 @@ func (v *rtcConnection) HandlePacket(addr *net.UDPAddr, data []byte) error {
|
|||
go func() {
|
||||
for ctx.Err() == nil {
|
||||
buf := make([]byte, 4096)
|
||||
n, _, err := v.backendUDP.ReadFromUDP(buf)
|
||||
n, err := v.backendUDP.Read(buf)
|
||||
if err != nil {
|
||||
// TODO: If backend server closed unexpectedly, we should notice the stream to quit.
|
||||
logger.Warn(ctx, "read from backend failed, err=%v", err)
|
||||
break
|
||||
}
|
||||
|
||||
if _, err = v.listenerUDP.WriteToUDP(buf[:n], v.clientUDP); err != nil {
|
||||
if _, err = v.listenerUDP.WriteTo(buf[:n], v.clientUDP); err != nil {
|
||||
// TODO: If backend server closed unexpectedly, we should notice the stream to quit.
|
||||
logger.Warn(ctx, "write to client failed, err=%v", err)
|
||||
break
|
||||
|
|
@ -474,12 +511,11 @@ func (v *rtcConnection) connectBackend(ctx context.Context) error {
|
|||
|
||||
// Connect to backend SRS server via UDP client.
|
||||
// TODO: FIXME: Support close the connection when timeout or DTLS alert.
|
||||
backendAddr := net.UDPAddr{IP: net.ParseIP(backend.IP), Port: int(udpPort)}
|
||||
if backendUDP, err := net.DialUDP("udp", nil, &backendAddr); err != nil {
|
||||
return errors.Wrapf(err, "dial udp to %v", backendAddr)
|
||||
} else {
|
||||
v.backendUDP = backendUDP
|
||||
backendUDP, err := v.dialBackendUDP(ctx, backend.IP, int(udpPort))
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "dial udp to %v:%v", backend.IP, udpPort)
|
||||
}
|
||||
v.backendUDP = backendUDP
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
1111
internal/proxy/rtc_test.go
Normal file
1111
internal/proxy/rtc_test.go
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user