Claude: Add RTMP proxy seams and unit tests.
Refactor internal/proxy/rtmp.go to expose functional-option seams (listen, newConnection, newHandshake, newProtocol, newBackend, dial) and widen the proxy server and connection to net.Listener / net.Conn so fakes can be injected. Tighten the identify switch in serve() to a real switch on CommandName. Add internal/proxy/rtmp_test.go covering rtmpProxyServer (constructor defaults, options, Close, listen error, endpoint normalization, accept-loop, graceful shutdown), rtmpConnection (defaults, serve handshake/protocol error paths, identify-loop branches, newBackend invocation contract), and rtmpClientToBackend (Close, Connect happy and error paths, publish, play). rtmp.go statement coverage rises to 76.9% with every function exercised.
This commit is contained in:
parent
3060bf8e7c
commit
7b4c4dc999
|
|
@ -6,6 +6,7 @@ package proxy
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
|
@ -33,14 +34,38 @@ type rtmpProxyServer struct {
|
|||
environment env.ProxyEnvironment
|
||||
// The load balancer for origin servers.
|
||||
loadBalancer lb.OriginLoadBalancer
|
||||
// The TCP listener for RTMP server.
|
||||
listener *net.TCPListener
|
||||
// The listener for RTMP server. Stored as net.Listener so tests can inject
|
||||
// a fake listener by overriding listen.
|
||||
listener net.Listener
|
||||
// The wait group for all goroutines.
|
||||
wg sync.WaitGroup
|
||||
// listen opens a listener on the given address. Defaults to a real TCP listener;
|
||||
// tests may override via a functional option to supply a fake listener.
|
||||
listen func(ctx context.Context, addr string) (net.Listener, error)
|
||||
// newConnection creates a fresh rtmpConnection wired up with this server's
|
||||
// load balancer. Defaults to a real rtmpConnection; tests may override via
|
||||
// a functional option to supply a fake.
|
||||
newConnection func() *rtmpConnection
|
||||
}
|
||||
|
||||
func NewRTMPProxyServer(environment env.ProxyEnvironment, loadBalancer lb.OriginLoadBalancer, opts ...func(*rtmpProxyServer)) RTMPProxyServer {
|
||||
v := &rtmpProxyServer{environment: environment, loadBalancer: loadBalancer}
|
||||
|
||||
// Default listen: a real TCP listener. Uses ListenConfig.Listen so ctx is
|
||||
// consulted during setup (mainly address resolution); the listener itself
|
||||
// is still torn down via Close(), not ctx cancellation.
|
||||
v.listen = func(ctx context.Context, addr string) (net.Listener, error) {
|
||||
var lc net.ListenConfig
|
||||
return lc.Listen(ctx, "tcp", addr)
|
||||
}
|
||||
// Default connection factory: a real rtmpConnection wired up with the
|
||||
// server's load balancer.
|
||||
v.newConnection = func() *rtmpConnection {
|
||||
return newRTMPConnection(func(c *rtmpConnection) {
|
||||
c.loadBalancer = v.loadBalancer
|
||||
})
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(v)
|
||||
}
|
||||
|
|
@ -62,24 +87,19 @@ func (v *rtmpProxyServer) Run(ctx context.Context) error {
|
|||
endpoint = ":" + endpoint
|
||||
}
|
||||
|
||||
addr, err := net.ResolveTCPAddr("tcp", endpoint)
|
||||
listener, err := v.listen(ctx, endpoint)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "resolve rtmp addr %v", endpoint)
|
||||
}
|
||||
|
||||
listener, err := net.ListenTCP("tcp", addr)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "listen rtmp addr %v", addr)
|
||||
return errors.Wrapf(err, "listen rtmp addr %v", endpoint)
|
||||
}
|
||||
v.listener = listener
|
||||
logger.Debug(ctx, "RTMP server listen at %v", addr)
|
||||
logger.Debug(ctx, "RTMP server listen at %v", listener.Addr())
|
||||
|
||||
v.wg.Add(1)
|
||||
go func() {
|
||||
defer v.wg.Done()
|
||||
|
||||
for {
|
||||
conn, err := v.listener.AcceptTCP()
|
||||
conn, err := v.listener.Accept()
|
||||
if err != nil {
|
||||
// If context is canceled or connection is closed, exit gracefully without logging error.
|
||||
if ctx.Err() != nil || utils.IsClosedNetworkError(err) {
|
||||
|
|
@ -92,7 +112,7 @@ func (v *rtmpProxyServer) Run(ctx context.Context) error {
|
|||
}
|
||||
|
||||
v.wg.Add(1)
|
||||
go func(ctx context.Context, conn *net.TCPConn) {
|
||||
go func(ctx context.Context, conn net.Conn) {
|
||||
defer v.wg.Done()
|
||||
defer conn.Close()
|
||||
|
||||
|
|
@ -104,9 +124,7 @@ func (v *rtmpProxyServer) Run(ctx context.Context) error {
|
|||
}
|
||||
}
|
||||
|
||||
rc := newRTMPConnection(func(c *rtmpConnection) {
|
||||
c.loadBalancer = v.loadBalancer
|
||||
})
|
||||
rc := v.newConnection()
|
||||
if err := rc.serve(ctx, conn); err != nil {
|
||||
handleErr(err)
|
||||
} else {
|
||||
|
|
@ -128,17 +146,41 @@ func (v *rtmpProxyServer) Run(ctx context.Context) error {
|
|||
type rtmpConnection struct {
|
||||
// The load balancer for origin servers.
|
||||
loadBalancer lb.OriginLoadBalancer
|
||||
// newHandshake creates a fresh RTMP handshake instance. Defaults to a real handshake;
|
||||
// tests may override via a functional option to supply a fake.
|
||||
newHandshake func() rtmp.Handshake
|
||||
// newProtocol creates a fresh RTMP protocol instance over the given stream. Defaults to
|
||||
// a real protocol; tests may override via a functional option to supply a fake.
|
||||
newProtocol func(rw io.ReadWriter) rtmp.Protocol
|
||||
// newBackend creates a fresh backend client wired up with the given clientType and the
|
||||
// connection's load balancer. Defaults to a real rtmpClientToBackend; tests may override
|
||||
// via a functional option to supply a fake.
|
||||
newBackend func(clientType RTMPClientType) *rtmpClientToBackend
|
||||
}
|
||||
|
||||
func newRTMPConnection(opts ...func(*rtmpConnection)) *rtmpConnection {
|
||||
v := &rtmpConnection{}
|
||||
|
||||
// Default handshake factory: a real RTMP handshake.
|
||||
v.newHandshake = rtmp.NewHandshake
|
||||
// Default protocol factory: a real RTMP protocol.
|
||||
v.newProtocol = rtmp.NewProtocol
|
||||
// Default backend factory: a real rtmpClientToBackend wired up with the connection's
|
||||
// load balancer and the given clientType.
|
||||
v.newBackend = func(clientType RTMPClientType) *rtmpClientToBackend {
|
||||
return newRTMPClientToBackend(func(client *rtmpClientToBackend) {
|
||||
client.typ = clientType
|
||||
client.loadBalancer = v.loadBalancer
|
||||
})
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(v)
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func (v *rtmpConnection) serve(ctx context.Context, conn *net.TCPConn) error {
|
||||
func (v *rtmpConnection) serve(ctx context.Context, conn net.Conn) error {
|
||||
logger.Debug(ctx, "Got RTMP client from %v", conn.RemoteAddr())
|
||||
|
||||
// If any goroutine quit, cancel another one.
|
||||
|
|
@ -158,7 +200,7 @@ func (v *rtmpConnection) serve(ctx context.Context, conn *net.TCPConn) error {
|
|||
}
|
||||
|
||||
// Simple handshake with client.
|
||||
hs := rtmp.NewHandshake()
|
||||
hs := v.newHandshake()
|
||||
if _, err := hs.ReadC0S0(conn); err != nil {
|
||||
return errors.Wrapf(err, "read c0")
|
||||
}
|
||||
|
|
@ -178,7 +220,7 @@ func (v *rtmpConnection) serve(ctx context.Context, conn *net.TCPConn) error {
|
|||
return errors.Wrapf(err, "read c2")
|
||||
}
|
||||
|
||||
client := rtmp.NewProtocol(conn)
|
||||
client := v.newProtocol(conn)
|
||||
logger.Debug(ctx, "RTMP simple handshake done")
|
||||
|
||||
// Expect RTMP connect command with tcUrl.
|
||||
|
|
@ -235,15 +277,16 @@ func (v *rtmpConnection) serve(ctx context.Context, conn *net.TCPConn) error {
|
|||
var response rtmp.Packet
|
||||
switch pkt := identifyReq.(type) {
|
||||
case *rtmp.CallPacket:
|
||||
if pkt.CommandName == "createStream" {
|
||||
switch pkt.CommandName {
|
||||
case "createStream":
|
||||
identifyRes := rtmp.NewCreateStreamResPacket(pkt.TransactionID)
|
||||
response = identifyRes
|
||||
|
||||
nextStreamID = 1
|
||||
identifyRes.SetStreamID(nextStreamID)
|
||||
} else if pkt.CommandName == "getStreamLength" {
|
||||
case "getStreamLength":
|
||||
// Ignore and do not reply these packets.
|
||||
} else {
|
||||
default:
|
||||
// For releaseStream, FCPublish, etc.
|
||||
identifyRes := rtmp.NewCallPacket()
|
||||
response = identifyRes
|
||||
|
|
@ -300,10 +343,7 @@ func (v *rtmpConnection) serve(ctx context.Context, conn *net.TCPConn) error {
|
|||
tcUrl, streamName, currentStreamID, clientType)
|
||||
|
||||
// Find a backend SRS server to proxy the RTMP stream.
|
||||
backend = newRTMPClientToBackend(func(client *rtmpClientToBackend) {
|
||||
client.typ = clientType
|
||||
client.loadBalancer = v.loadBalancer
|
||||
})
|
||||
backend = v.newBackend(clientType)
|
||||
defer backend.Close()
|
||||
|
||||
if err := backend.Connect(ctx, tcUrl, streamName); err != nil {
|
||||
|
|
@ -311,7 +351,8 @@ func (v *rtmpConnection) serve(ctx context.Context, conn *net.TCPConn) error {
|
|||
}
|
||||
|
||||
// Start the streaming.
|
||||
if clientType == RTMPClientTypePublisher {
|
||||
switch clientType {
|
||||
case RTMPClientTypePublisher:
|
||||
identifyRes := rtmp.NewCallPacket()
|
||||
|
||||
identifyRes.CommandName = "onStatus"
|
||||
|
|
@ -327,7 +368,7 @@ func (v *rtmpConnection) serve(ctx context.Context, conn *net.TCPConn) error {
|
|||
if err := client.WritePacket(ctx, identifyRes, currentStreamID); err != nil {
|
||||
return errors.Wrapf(err, "start publish")
|
||||
}
|
||||
} else if clientType == RTMPClientTypeViewer {
|
||||
case RTMPClientTypeViewer:
|
||||
identifyRes := rtmp.NewCallPacket()
|
||||
|
||||
identifyRes.CommandName = "onStatus"
|
||||
|
|
@ -430,18 +471,40 @@ const (
|
|||
|
||||
// rtmpClientToBackend is an RTMP client to proxy the RTMP stream to backend.
|
||||
type rtmpClientToBackend struct {
|
||||
// The underlayer tcp client.
|
||||
tcpConn *net.TCPConn
|
||||
// The underlayer connection to backend. Stored as io.ReadWriteCloser so tests
|
||||
// can inject a fake connection by overriding dial.
|
||||
tcpConn io.ReadWriteCloser
|
||||
// The RTMP protocol client.
|
||||
client rtmp.Protocol
|
||||
// The stream type.
|
||||
typ RTMPClientType
|
||||
// The load balancer for origin servers.
|
||||
loadBalancer lb.OriginLoadBalancer
|
||||
// dial opens a connection to a backend SRS server. Defaults to a real TCP dial;
|
||||
// tests may override via a functional option to supply a fake connection.
|
||||
dial func(ctx context.Context, ip string, port int) (io.ReadWriteCloser, error)
|
||||
// newHandshake creates a fresh RTMP handshake instance. Defaults to a real handshake;
|
||||
// tests may override via a functional option to supply a fake.
|
||||
newHandshake func() rtmp.Handshake
|
||||
// newProtocol creates a fresh RTMP protocol instance over the given stream. Defaults to
|
||||
// a real protocol; tests may override via a functional option to supply a fake.
|
||||
newProtocol func(rw io.ReadWriter) rtmp.Protocol
|
||||
}
|
||||
|
||||
func newRTMPClientToBackend(opts ...func(*rtmpClientToBackend)) *rtmpClientToBackend {
|
||||
v := &rtmpClientToBackend{}
|
||||
|
||||
// Default dial: a real TCP connection to the backend. Uses Dialer.DialContext
|
||||
// so ctx cancellation/deadline aborts the connect (net.DialTCP ignores ctx).
|
||||
v.dial = func(ctx context.Context, ip string, port int) (io.ReadWriteCloser, error) {
|
||||
var d net.Dialer
|
||||
return d.DialContext(ctx, "tcp", net.JoinHostPort(ip, strconv.Itoa(port)))
|
||||
}
|
||||
// Default handshake factory: a real RTMP handshake.
|
||||
v.newHandshake = rtmp.NewHandshake
|
||||
// Default protocol factory: a real RTMP protocol.
|
||||
v.newProtocol = rtmp.NewProtocol
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(v)
|
||||
}
|
||||
|
|
@ -480,16 +543,15 @@ func (v *rtmpClientToBackend) Connect(ctx context.Context, tcUrl, streamName str
|
|||
rtmpPort = int(iv)
|
||||
}
|
||||
|
||||
// Connect to backend SRS server via TCP client.
|
||||
addr := &net.TCPAddr{IP: net.ParseIP(backend.IP), Port: rtmpPort}
|
||||
c, err := net.DialTCP("tcp", nil, addr)
|
||||
// Connect to backend SRS server.
|
||||
c, err := v.dial(ctx, backend.IP, rtmpPort)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "dial backend addr=%v, srs=%v", addr, backend)
|
||||
return errors.Wrapf(err, "dial backend ip=%v, port=%v, srs=%v", backend.IP, rtmpPort, backend)
|
||||
}
|
||||
v.tcpConn = c
|
||||
|
||||
hs := rtmp.NewHandshake()
|
||||
client := rtmp.NewProtocol(c)
|
||||
hs := v.newHandshake()
|
||||
client := v.newProtocol(c)
|
||||
v.client = client
|
||||
|
||||
// Simple RTMP handshake with server.
|
||||
|
|
@ -509,7 +571,7 @@ func (v *rtmpClientToBackend) Connect(ctx context.Context, tcUrl, streamName str
|
|||
if _, err = hs.ReadC2S2(c); err != nil {
|
||||
return errors.Wrapf(err, "read c2")
|
||||
}
|
||||
logger.Debug(ctx, "backend simple handshake done, server=%v", addr)
|
||||
logger.Debug(ctx, "backend simple handshake done, server=%v:%v", backend.IP, rtmpPort)
|
||||
|
||||
if err := hs.WriteC2S2(c, hs.C1S1()); err != nil {
|
||||
return errors.Wrapf(err, "write c2")
|
||||
|
|
|
|||
1287
internal/proxy/rtmp_test.go
Normal file
1287
internal/proxy/rtmp_test.go
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user