Claude: Add RTMP proxy seams and unit tests.

Refactor internal/proxy/rtmp.go to expose functional-option seams
(listen, newConnection, newHandshake, newProtocol, newBackend, dial)
and widen the proxy server and connection to net.Listener / net.Conn
so fakes can be injected. Tighten the identify switch in serve() to
a real switch on CommandName.

Add internal/proxy/rtmp_test.go covering rtmpProxyServer (constructor
defaults, options, Close, listen error, endpoint normalization,
accept-loop, graceful shutdown), rtmpConnection (defaults, serve
handshake/protocol error paths, identify-loop branches, newBackend
invocation contract), and rtmpClientToBackend (Close, Connect happy
and error paths, publish, play). rtmp.go statement coverage rises to
76.9% with every function exercised.
This commit is contained in:
winlin 2026-05-16 16:54:11 -04:00
parent 3060bf8e7c
commit 7b4c4dc999
2 changed files with 1385 additions and 36 deletions

View File

@ -6,6 +6,7 @@ package proxy
import (
"context"
"fmt"
"io"
"net"
"strconv"
"strings"
@ -33,14 +34,38 @@ type rtmpProxyServer struct {
environment env.ProxyEnvironment
// The load balancer for origin servers.
loadBalancer lb.OriginLoadBalancer
// The TCP listener for RTMP server.
listener *net.TCPListener
// The listener for RTMP server. Stored as net.Listener so tests can inject
// a fake listener by overriding listen.
listener net.Listener
// The wait group for all goroutines.
wg sync.WaitGroup
// listen opens a listener on the given address. Defaults to a real TCP listener;
// tests may override via a functional option to supply a fake listener.
listen func(ctx context.Context, addr string) (net.Listener, error)
// newConnection creates a fresh rtmpConnection wired up with this server's
// load balancer. Defaults to a real rtmpConnection; tests may override via
// a functional option to supply a fake.
newConnection func() *rtmpConnection
}
func NewRTMPProxyServer(environment env.ProxyEnvironment, loadBalancer lb.OriginLoadBalancer, opts ...func(*rtmpProxyServer)) RTMPProxyServer {
v := &rtmpProxyServer{environment: environment, loadBalancer: loadBalancer}
// Default listen: a real TCP listener. Uses ListenConfig.Listen so ctx is
// consulted during setup (mainly address resolution); the listener itself
// is still torn down via Close(), not ctx cancellation.
v.listen = func(ctx context.Context, addr string) (net.Listener, error) {
var lc net.ListenConfig
return lc.Listen(ctx, "tcp", addr)
}
// Default connection factory: a real rtmpConnection wired up with the
// server's load balancer.
v.newConnection = func() *rtmpConnection {
return newRTMPConnection(func(c *rtmpConnection) {
c.loadBalancer = v.loadBalancer
})
}
for _, opt := range opts {
opt(v)
}
@ -62,24 +87,19 @@ func (v *rtmpProxyServer) Run(ctx context.Context) error {
endpoint = ":" + endpoint
}
addr, err := net.ResolveTCPAddr("tcp", endpoint)
listener, err := v.listen(ctx, endpoint)
if err != nil {
return errors.Wrapf(err, "resolve rtmp addr %v", endpoint)
}
listener, err := net.ListenTCP("tcp", addr)
if err != nil {
return errors.Wrapf(err, "listen rtmp addr %v", addr)
return errors.Wrapf(err, "listen rtmp addr %v", endpoint)
}
v.listener = listener
logger.Debug(ctx, "RTMP server listen at %v", addr)
logger.Debug(ctx, "RTMP server listen at %v", listener.Addr())
v.wg.Add(1)
go func() {
defer v.wg.Done()
for {
conn, err := v.listener.AcceptTCP()
conn, err := v.listener.Accept()
if err != nil {
// If context is canceled or connection is closed, exit gracefully without logging error.
if ctx.Err() != nil || utils.IsClosedNetworkError(err) {
@ -92,7 +112,7 @@ func (v *rtmpProxyServer) Run(ctx context.Context) error {
}
v.wg.Add(1)
go func(ctx context.Context, conn *net.TCPConn) {
go func(ctx context.Context, conn net.Conn) {
defer v.wg.Done()
defer conn.Close()
@ -104,9 +124,7 @@ func (v *rtmpProxyServer) Run(ctx context.Context) error {
}
}
rc := newRTMPConnection(func(c *rtmpConnection) {
c.loadBalancer = v.loadBalancer
})
rc := v.newConnection()
if err := rc.serve(ctx, conn); err != nil {
handleErr(err)
} else {
@ -128,17 +146,41 @@ func (v *rtmpProxyServer) Run(ctx context.Context) error {
type rtmpConnection struct {
// The load balancer for origin servers.
loadBalancer lb.OriginLoadBalancer
// newHandshake creates a fresh RTMP handshake instance. Defaults to a real handshake;
// tests may override via a functional option to supply a fake.
newHandshake func() rtmp.Handshake
// newProtocol creates a fresh RTMP protocol instance over the given stream. Defaults to
// a real protocol; tests may override via a functional option to supply a fake.
newProtocol func(rw io.ReadWriter) rtmp.Protocol
// newBackend creates a fresh backend client wired up with the given clientType and the
// connection's load balancer. Defaults to a real rtmpClientToBackend; tests may override
// via a functional option to supply a fake.
newBackend func(clientType RTMPClientType) *rtmpClientToBackend
}
func newRTMPConnection(opts ...func(*rtmpConnection)) *rtmpConnection {
v := &rtmpConnection{}
// Default handshake factory: a real RTMP handshake.
v.newHandshake = rtmp.NewHandshake
// Default protocol factory: a real RTMP protocol.
v.newProtocol = rtmp.NewProtocol
// Default backend factory: a real rtmpClientToBackend wired up with the connection's
// load balancer and the given clientType.
v.newBackend = func(clientType RTMPClientType) *rtmpClientToBackend {
return newRTMPClientToBackend(func(client *rtmpClientToBackend) {
client.typ = clientType
client.loadBalancer = v.loadBalancer
})
}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *rtmpConnection) serve(ctx context.Context, conn *net.TCPConn) error {
func (v *rtmpConnection) serve(ctx context.Context, conn net.Conn) error {
logger.Debug(ctx, "Got RTMP client from %v", conn.RemoteAddr())
// If any goroutine quit, cancel another one.
@ -158,7 +200,7 @@ func (v *rtmpConnection) serve(ctx context.Context, conn *net.TCPConn) error {
}
// Simple handshake with client.
hs := rtmp.NewHandshake()
hs := v.newHandshake()
if _, err := hs.ReadC0S0(conn); err != nil {
return errors.Wrapf(err, "read c0")
}
@ -178,7 +220,7 @@ func (v *rtmpConnection) serve(ctx context.Context, conn *net.TCPConn) error {
return errors.Wrapf(err, "read c2")
}
client := rtmp.NewProtocol(conn)
client := v.newProtocol(conn)
logger.Debug(ctx, "RTMP simple handshake done")
// Expect RTMP connect command with tcUrl.
@ -235,15 +277,16 @@ func (v *rtmpConnection) serve(ctx context.Context, conn *net.TCPConn) error {
var response rtmp.Packet
switch pkt := identifyReq.(type) {
case *rtmp.CallPacket:
if pkt.CommandName == "createStream" {
switch pkt.CommandName {
case "createStream":
identifyRes := rtmp.NewCreateStreamResPacket(pkt.TransactionID)
response = identifyRes
nextStreamID = 1
identifyRes.SetStreamID(nextStreamID)
} else if pkt.CommandName == "getStreamLength" {
case "getStreamLength":
// Ignore and do not reply these packets.
} else {
default:
// For releaseStream, FCPublish, etc.
identifyRes := rtmp.NewCallPacket()
response = identifyRes
@ -300,10 +343,7 @@ func (v *rtmpConnection) serve(ctx context.Context, conn *net.TCPConn) error {
tcUrl, streamName, currentStreamID, clientType)
// Find a backend SRS server to proxy the RTMP stream.
backend = newRTMPClientToBackend(func(client *rtmpClientToBackend) {
client.typ = clientType
client.loadBalancer = v.loadBalancer
})
backend = v.newBackend(clientType)
defer backend.Close()
if err := backend.Connect(ctx, tcUrl, streamName); err != nil {
@ -311,7 +351,8 @@ func (v *rtmpConnection) serve(ctx context.Context, conn *net.TCPConn) error {
}
// Start the streaming.
if clientType == RTMPClientTypePublisher {
switch clientType {
case RTMPClientTypePublisher:
identifyRes := rtmp.NewCallPacket()
identifyRes.CommandName = "onStatus"
@ -327,7 +368,7 @@ func (v *rtmpConnection) serve(ctx context.Context, conn *net.TCPConn) error {
if err := client.WritePacket(ctx, identifyRes, currentStreamID); err != nil {
return errors.Wrapf(err, "start publish")
}
} else if clientType == RTMPClientTypeViewer {
case RTMPClientTypeViewer:
identifyRes := rtmp.NewCallPacket()
identifyRes.CommandName = "onStatus"
@ -430,18 +471,40 @@ const (
// rtmpClientToBackend is an RTMP client to proxy the RTMP stream to backend.
type rtmpClientToBackend struct {
// The underlayer tcp client.
tcpConn *net.TCPConn
// The underlayer connection to backend. Stored as io.ReadWriteCloser so tests
// can inject a fake connection by overriding dial.
tcpConn io.ReadWriteCloser
// The RTMP protocol client.
client rtmp.Protocol
// The stream type.
typ RTMPClientType
// The load balancer for origin servers.
loadBalancer lb.OriginLoadBalancer
// dial opens a connection to a backend SRS server. Defaults to a real TCP dial;
// tests may override via a functional option to supply a fake connection.
dial func(ctx context.Context, ip string, port int) (io.ReadWriteCloser, error)
// newHandshake creates a fresh RTMP handshake instance. Defaults to a real handshake;
// tests may override via a functional option to supply a fake.
newHandshake func() rtmp.Handshake
// newProtocol creates a fresh RTMP protocol instance over the given stream. Defaults to
// a real protocol; tests may override via a functional option to supply a fake.
newProtocol func(rw io.ReadWriter) rtmp.Protocol
}
func newRTMPClientToBackend(opts ...func(*rtmpClientToBackend)) *rtmpClientToBackend {
v := &rtmpClientToBackend{}
// Default dial: a real TCP connection to the backend. Uses Dialer.DialContext
// so ctx cancellation/deadline aborts the connect (net.DialTCP ignores ctx).
v.dial = func(ctx context.Context, ip string, port int) (io.ReadWriteCloser, error) {
var d net.Dialer
return d.DialContext(ctx, "tcp", net.JoinHostPort(ip, strconv.Itoa(port)))
}
// Default handshake factory: a real RTMP handshake.
v.newHandshake = rtmp.NewHandshake
// Default protocol factory: a real RTMP protocol.
v.newProtocol = rtmp.NewProtocol
for _, opt := range opts {
opt(v)
}
@ -480,16 +543,15 @@ func (v *rtmpClientToBackend) Connect(ctx context.Context, tcUrl, streamName str
rtmpPort = int(iv)
}
// Connect to backend SRS server via TCP client.
addr := &net.TCPAddr{IP: net.ParseIP(backend.IP), Port: rtmpPort}
c, err := net.DialTCP("tcp", nil, addr)
// Connect to backend SRS server.
c, err := v.dial(ctx, backend.IP, rtmpPort)
if err != nil {
return errors.Wrapf(err, "dial backend addr=%v, srs=%v", addr, backend)
return errors.Wrapf(err, "dial backend ip=%v, port=%v, srs=%v", backend.IP, rtmpPort, backend)
}
v.tcpConn = c
hs := rtmp.NewHandshake()
client := rtmp.NewProtocol(c)
hs := v.newHandshake()
client := v.newProtocol(c)
v.client = client
// Simple RTMP handshake with server.
@ -509,7 +571,7 @@ func (v *rtmpClientToBackend) Connect(ctx context.Context, tcUrl, streamName str
if _, err = hs.ReadC2S2(c); err != nil {
return errors.Wrapf(err, "read c2")
}
logger.Debug(ctx, "backend simple handshake done, server=%v", addr)
logger.Debug(ctx, "backend simple handshake done, server=%v:%v", backend.IP, rtmpPort)
if err := hs.WriteC2S2(c, hs.C1S1()); err != nil {
return errors.Wrapf(err, "write c2")

1287
internal/proxy/rtmp_test.go Normal file

File diff suppressed because it is too large Load Diff