Refactor internal/proxy/rtmp.go to expose functional-option seams (listen, newConnection, newHandshake, newProtocol, newBackend, dial) and widen the proxy server and connection to net.Listener / net.Conn so fakes can be injected. Tighten the identify switch in serve() to a real switch on CommandName. Add internal/proxy/rtmp_test.go covering rtmpProxyServer (constructor defaults, options, Close, listen error, endpoint normalization, accept-loop, graceful shutdown), rtmpConnection (defaults, serve handshake/protocol error paths, identify-loop branches, newBackend invocation contract), and rtmpClientToBackend (Close, Connect happy and error paths, publish, play). rtmp.go statement coverage rises to 76.9% with every function exercised.
1288 lines · 41 KiB · Go
// Copyright (c) 2026 Winlin
|
|
//
|
|
// SPDX-License-Identifier: MIT
|
|
package proxy
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"io"
|
|
"net"
|
|
"strings"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"srsx/internal/env/envfakes"
|
|
"srsx/internal/lb"
|
|
"srsx/internal/lb/lbfakes"
|
|
"srsx/internal/rtmp"
|
|
"srsx/internal/rtmp/rtmpfakes"
|
|
)
|
|
|
|
// fakeConn is an in-memory io.ReadWriteCloser that stands in for the TCP
// connection returned by dial. Read always reports io.EOF and Write
// discards its input, because every protocol call on the connection is
// intercepted by FakeHandshake/FakeProtocol; only Close is observable.
type fakeConn struct {
	// closed records whether Close has been called.
	closed atomic.Bool
}

// Compile-time check that fakeConn satisfies the interface dial returns.
var _ io.ReadWriteCloser = (*fakeConn)(nil)

// Read reports end-of-stream without touching p.
func (c *fakeConn) Read(p []byte) (int, error) { return 0, io.EOF }

// Write pretends to accept all of p.
func (c *fakeConn) Write(p []byte) (int, error) { return len(p), nil }

// Close marks the connection closed; it never fails.
func (c *fakeConn) Close() error {
	c.closed.Store(true)
	return nil
}
|
|
|
|
// backendFixture bundles the fakes plus an rtmpClientToBackend wired against
|
|
// them. Tests configure the fakes, then exercise the methods.
|
|
type backendFixture struct {
|
|
conn *fakeConn
|
|
lb *lbfakes.FakeOriginLoadBalancer
|
|
handshake *rtmpfakes.FakeHandshake
|
|
protocol *rtmpfakes.FakeProtocol
|
|
client *rtmpClientToBackend
|
|
}
|
|
|
|
func newBackendFixture(typ RTMPClientType) *backendFixture {
|
|
f := &backendFixture{
|
|
conn: &fakeConn{},
|
|
lb: &lbfakes.FakeOriginLoadBalancer{},
|
|
handshake: &rtmpfakes.FakeHandshake{},
|
|
protocol: &rtmpfakes.FakeProtocol{},
|
|
}
|
|
f.client = newRTMPClientToBackend(func(c *rtmpClientToBackend) {
|
|
c.typ = typ
|
|
c.loadBalancer = f.lb
|
|
c.dial = func(ctx context.Context, ip string, port int) (io.ReadWriteCloser, error) {
|
|
return f.conn, nil
|
|
}
|
|
c.newHandshake = func() rtmp.Handshake { return f.handshake }
|
|
c.newProtocol = func(rw io.ReadWriter) rtmp.Protocol { return f.protocol }
|
|
})
|
|
return f
|
|
}
|
|
|
|
// queueDecode programs FakeProtocol.DecodeMessage to return the given packets
|
|
// in order, one per call. After the queue is drained, it returns an EOF-ish
|
|
// error to fail the test fast instead of looping forever.
|
|
func queueDecode(p *rtmpfakes.FakeProtocol, packets ...rtmp.Packet) {
|
|
var i atomic.Int32
|
|
p.DecodeMessageStub = func(m rtmp.Message) (rtmp.Packet, error) {
|
|
idx := int(i.Add(1)) - 1
|
|
if idx >= len(packets) {
|
|
return nil, errors.New("decode queue drained")
|
|
}
|
|
return packets[idx], nil
|
|
}
|
|
}
|
|
|
|
// readMessageOK programs ReadMessage to always return a fresh empty Message.
|
|
// The payload is irrelevant because DecodeMessage is stubbed.
|
|
func readMessageOK(p *rtmpfakes.FakeProtocol) {
|
|
p.ReadMessageStub = func(ctx context.Context) (rtmp.Message, error) {
|
|
return rtmp.NewMessage(), nil
|
|
}
|
|
}
|
|
|
|
// onStatusPacket builds a *rtmp.CallPacket whose Args is an Amf0Object
|
|
// carrying the given code. Used to drive both publish() (which inspects
|
|
// Args via Amf0Converter) and play() (which uses ArgsCode()).
|
|
func onStatusPacket(code string) *rtmp.CallPacket {
|
|
pkt := rtmp.NewCallPacket()
|
|
pkt.CommandName = "onStatus"
|
|
pkt.CommandObject = rtmp.NewAmf0Null()
|
|
data := rtmp.NewAmf0Object()
|
|
data.Set("code", rtmp.NewAmf0String(code))
|
|
pkt.Args = data
|
|
return pkt
|
|
}
|
|
|
|
func resultCallPacket() *rtmp.CallPacket {
|
|
pkt := rtmp.NewCallPacket()
|
|
pkt.CommandName = "_result"
|
|
return pkt
|
|
}
|
|
|
|
func createStreamRes(id int) *rtmp.CreateStreamResPacket {
|
|
pkt := rtmp.NewCreateStreamResPacket(0)
|
|
pkt.SetStreamID(id)
|
|
return pkt
|
|
}
|
|
|
|
// pickOK programs the load balancer to return a backend with one RTMP
|
|
// endpoint, mimicking a typical registered SRS origin.
|
|
func pickOK(f *backendFixture) {
|
|
f.lb.PickReturns(&lb.OriginServer{IP: "10.0.0.1", RTMP: []string{"1935"}}, nil)
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Close()
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestRtmpClientToBackend_Close_NilConn(t *testing.T) {
|
|
c := newRTMPClientToBackend()
|
|
if err := c.Close(); err != nil {
|
|
t.Fatalf("Close with nil tcpConn: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestRtmpClientToBackend_Close_FakeConn(t *testing.T) {
|
|
conn := &fakeConn{}
|
|
c := newRTMPClientToBackend(func(c *rtmpClientToBackend) {
|
|
c.tcpConn = conn
|
|
})
|
|
if err := c.Close(); err != nil {
|
|
t.Fatalf("Close: %v", err)
|
|
}
|
|
if !conn.closed.Load() {
|
|
t.Fatal("fakeConn was not closed")
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Connect() error paths
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestRtmpClientToBackend_Connect_BuildStreamURLError(t *testing.T) {
|
|
// url.Parse rejects URLs that start with a colon (no scheme/host parseable),
|
|
// so this drives BuildStreamURL's error branch before LB.Pick is reached.
|
|
f := newBackendFixture(RTMPClientTypePublisher)
|
|
err := f.client.Connect(context.Background(), ":bad-url", "stream")
|
|
if err == nil {
|
|
t.Fatal("expected error from BuildStreamURL")
|
|
}
|
|
if !strings.Contains(err.Error(), "build stream url") {
|
|
t.Fatalf("unexpected error %v", err)
|
|
}
|
|
if f.lb.PickCallCount() != 0 {
|
|
t.Fatalf("LB.Pick should not be called when URL is bad; got %d calls", f.lb.PickCallCount())
|
|
}
|
|
}
|
|
|
|
func TestRtmpClientToBackend_Connect_PickError(t *testing.T) {
|
|
f := newBackendFixture(RTMPClientTypePublisher)
|
|
f.lb.PickReturns(nil, errors.New("no backend"))
|
|
|
|
err := f.client.Connect(context.Background(), "rtmp://1.2.3.4/live", "stream")
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
if !strings.Contains(err.Error(), "pick backend") {
|
|
t.Fatalf("unexpected error %v", err)
|
|
}
|
|
}
|
|
|
|
func TestRtmpClientToBackend_Connect_NoRTMPEndpoint(t *testing.T) {
|
|
f := newBackendFixture(RTMPClientTypePublisher)
|
|
f.lb.PickReturns(&lb.OriginServer{IP: "10.0.0.1"}, nil) // empty RTMP slice
|
|
|
|
err := f.client.Connect(context.Background(), "rtmp://1.2.3.4/live", "stream")
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
if !strings.Contains(err.Error(), "no rtmp server") {
|
|
t.Fatalf("unexpected error %v", err)
|
|
}
|
|
}
|
|
|
|
func TestRtmpClientToBackend_Connect_BadRTMPPort(t *testing.T) {
|
|
f := newBackendFixture(RTMPClientTypePublisher)
|
|
f.lb.PickReturns(&lb.OriginServer{IP: "10.0.0.1", RTMP: []string{"not-a-port"}}, nil)
|
|
|
|
err := f.client.Connect(context.Background(), "rtmp://1.2.3.4/live", "stream")
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
if !strings.Contains(err.Error(), "parse backend") {
|
|
t.Fatalf("unexpected error %v", err)
|
|
}
|
|
}
|
|
|
|
func TestRtmpClientToBackend_Connect_DialError(t *testing.T) {
|
|
f := newBackendFixture(RTMPClientTypePublisher)
|
|
pickOK(f)
|
|
f.client.dial = func(ctx context.Context, ip string, port int) (io.ReadWriteCloser, error) {
|
|
return nil, errors.New("dial refused")
|
|
}
|
|
|
|
err := f.client.Connect(context.Background(), "rtmp://1.2.3.4/live", "stream")
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
if !strings.Contains(err.Error(), "dial backend") {
|
|
t.Fatalf("unexpected error %v", err)
|
|
}
|
|
}
|
|
|
|
func TestRtmpClientToBackend_Connect_DialHonorsCtxCancel(t *testing.T) {
|
|
// The default dial uses net.Dialer.DialContext, so a canceled ctx must
|
|
// surface as a dial error rather than hanging on the kernel connect.
|
|
// We assert this contract by having the test dial honor ctx itself.
|
|
f := newBackendFixture(RTMPClientTypePublisher)
|
|
pickOK(f)
|
|
f.client.dial = func(ctx context.Context, ip string, port int) (io.ReadWriteCloser, error) {
|
|
return nil, ctx.Err()
|
|
}
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
cancel() // already-canceled ctx
|
|
err := f.client.Connect(ctx, "rtmp://1.2.3.4/live", "stream")
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
if !strings.Contains(err.Error(), "dial backend") {
|
|
t.Fatalf("unexpected error %v", err)
|
|
}
|
|
if !errors.Is(err, context.Canceled) {
|
|
t.Fatalf("expected ctx.Canceled in chain, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestRtmpClientToBackend_Connect_HandshakeWriteC0Error(t *testing.T) {
|
|
f := newBackendFixture(RTMPClientTypePublisher)
|
|
pickOK(f)
|
|
f.handshake.WriteC0S0Returns(errors.New("write c0"))
|
|
|
|
err := f.client.Connect(context.Background(), "rtmp://1.2.3.4/live", "stream")
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
if !strings.Contains(err.Error(), "write c0") {
|
|
t.Fatalf("unexpected error %v", err)
|
|
}
|
|
}
|
|
|
|
func TestRtmpClientToBackend_Connect_HandshakeReadS0Error(t *testing.T) {
|
|
f := newBackendFixture(RTMPClientTypePublisher)
|
|
pickOK(f)
|
|
f.handshake.ReadC0S0Returns(nil, errors.New("read s0"))
|
|
|
|
err := f.client.Connect(context.Background(), "rtmp://1.2.3.4/live", "stream")
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
if !strings.Contains(err.Error(), "read s0") {
|
|
t.Fatalf("unexpected error %v", err)
|
|
}
|
|
}
|
|
|
|
func TestRtmpClientToBackend_Connect_WriteConnectAppError(t *testing.T) {
|
|
f := newBackendFixture(RTMPClientTypePublisher)
|
|
pickOK(f)
|
|
f.protocol.WritePacketReturns(errors.New("write packet"))
|
|
|
|
err := f.client.Connect(context.Background(), "rtmp://1.2.3.4/live", "stream")
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
if !strings.Contains(err.Error(), "write connect app") {
|
|
t.Fatalf("unexpected error %v", err)
|
|
}
|
|
}
|
|
|
|
func TestRtmpClientToBackend_Connect_ExpectConnectAppResError(t *testing.T) {
|
|
f := newBackendFixture(RTMPClientTypePublisher)
|
|
pickOK(f)
|
|
// WritePacket succeeds, but ReadMessage inside ExpectPacket fails.
|
|
f.protocol.ReadMessageReturns(nil, errors.New("read message"))
|
|
|
|
err := f.client.Connect(context.Background(), "rtmp://1.2.3.4/live", "stream")
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
if !strings.Contains(err.Error(), "expect connect app res") {
|
|
t.Fatalf("unexpected error %v", err)
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Connect() happy paths
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestRtmpClientToBackend_Connect_PublisherHappyPath(t *testing.T) {
|
|
f := newBackendFixture(RTMPClientTypePublisher)
|
|
pickOK(f)
|
|
readMessageOK(f.protocol)
|
|
queueDecode(f.protocol,
|
|
rtmp.NewConnectAppResPacket(0), // connect app res
|
|
resultCallPacket(), // releaseStream res
|
|
resultCallPacket(), // FCPublish res
|
|
createStreamRes(1), // createStream res
|
|
onStatusPacket("NetStream.Publish.Start"), // publish onStatus
|
|
)
|
|
|
|
if err := f.client.Connect(context.Background(), "rtmp://1.2.3.4/live", "stream"); err != nil {
|
|
t.Fatalf("Connect: %+v", err)
|
|
}
|
|
|
|
if f.client.tcpConn != f.conn {
|
|
t.Fatal("tcpConn should be the fake conn from dial")
|
|
}
|
|
if f.client.client != f.protocol {
|
|
t.Fatal("client field should be the fake protocol")
|
|
}
|
|
// One WritePacket each for: connectApp, releaseStream, FCPublish, createStream, publish.
|
|
if got := f.protocol.WritePacketCallCount(); got != 5 {
|
|
t.Fatalf("WritePacket called %d times, want 5", got)
|
|
}
|
|
}
|
|
|
|
func TestRtmpClientToBackend_Connect_ViewerHappyPath(t *testing.T) {
|
|
f := newBackendFixture(RTMPClientTypeViewer)
|
|
pickOK(f)
|
|
readMessageOK(f.protocol)
|
|
queueDecode(f.protocol,
|
|
rtmp.NewConnectAppResPacket(0), // connect app res
|
|
createStreamRes(1), // createStream res
|
|
onStatusPacket("NetStream.Play.Start"), // play onStatus
|
|
)
|
|
|
|
if err := f.client.Connect(context.Background(), "rtmp://1.2.3.4/live", "stream"); err != nil {
|
|
t.Fatalf("Connect: %+v", err)
|
|
}
|
|
|
|
// One WritePacket each for: connectApp, createStream, play.
|
|
if got := f.protocol.WritePacketCallCount(); got != 3 {
|
|
t.Fatalf("WritePacket called %d times, want 3", got)
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// publish() in isolation
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func newIsolatedBackend(t *testing.T, typ RTMPClientType) (*rtmpClientToBackend, *rtmpfakes.FakeProtocol) {
|
|
t.Helper()
|
|
p := &rtmpfakes.FakeProtocol{}
|
|
readMessageOK(p)
|
|
c := newRTMPClientToBackend(func(c *rtmpClientToBackend) { c.typ = typ })
|
|
return c, p
|
|
}
|
|
|
|
func TestRtmpClientToBackend_Publish_HappyPath(t *testing.T) {
|
|
c, p := newIsolatedBackend(t, RTMPClientTypePublisher)
|
|
queueDecode(p,
|
|
resultCallPacket(), // releaseStream _result
|
|
resultCallPacket(), // FCPublish _result
|
|
createStreamRes(7), // createStream res
|
|
onStatusPacket("NetStream.Publish.Start"), // final publish onStatus
|
|
)
|
|
|
|
if err := c.publish(context.Background(), p, "stream"); err != nil {
|
|
t.Fatalf("publish: %+v", err)
|
|
}
|
|
if got := p.WritePacketCallCount(); got != 4 {
|
|
t.Fatalf("WritePacket called %d times, want 4", got)
|
|
}
|
|
}
|
|
|
|
func TestRtmpClientToBackend_Publish_ReleaseStreamWriteError(t *testing.T) {
|
|
c, p := newIsolatedBackend(t, RTMPClientTypePublisher)
|
|
p.WritePacketReturns(errors.New("boom"))
|
|
|
|
err := c.publish(context.Background(), p, "stream")
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
if !strings.Contains(err.Error(), "releaseStream") {
|
|
t.Fatalf("unexpected error %v", err)
|
|
}
|
|
}
|
|
|
|
func TestRtmpClientToBackend_Publish_FCPublishExpectError(t *testing.T) {
|
|
c, p := newIsolatedBackend(t, RTMPClientTypePublisher)
|
|
// First ExpectPacket (releaseStream res) succeeds; the second (FCPublish res)
|
|
// must fail. We fail ReadMessage on its second call.
|
|
var reads atomic.Int32
|
|
p.ReadMessageStub = func(ctx context.Context) (rtmp.Message, error) {
|
|
if reads.Add(1) >= 2 {
|
|
return nil, errors.New("read fail")
|
|
}
|
|
return rtmp.NewMessage(), nil
|
|
}
|
|
queueDecode(p, resultCallPacket()) // only the first decode is consumed
|
|
|
|
err := c.publish(context.Background(), p, "stream")
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
if !strings.Contains(err.Error(), "FCPublish") {
|
|
t.Fatalf("unexpected error %v", err)
|
|
}
|
|
}
|
|
|
|
func TestRtmpClientToBackend_Publish_CreateStreamSkipsZeroID(t *testing.T) {
|
|
// The createStream loop continues until StreamID != 0; verify it ignores
|
|
// the first packet (StreamID 0) and accepts the second (StreamID 9).
|
|
c, p := newIsolatedBackend(t, RTMPClientTypePublisher)
|
|
queueDecode(p,
|
|
resultCallPacket(), // releaseStream res
|
|
resultCallPacket(), // FCPublish res
|
|
createStreamRes(0), // ignored
|
|
createStreamRes(9), // accepted
|
|
onStatusPacket("NetStream.Publish.Start"), // final publish onStatus
|
|
)
|
|
|
|
if err := c.publish(context.Background(), p, "stream"); err != nil {
|
|
t.Fatalf("publish: %+v", err)
|
|
}
|
|
}
|
|
|
|
func TestRtmpClientToBackend_Publish_SkipsNonOnStatus(t *testing.T) {
|
|
// publish() loops past onFCPublish (a CallPacket whose CommandName != onStatus)
|
|
// until it sees onStatus(NetStream.Publish.Start).
|
|
c, p := newIsolatedBackend(t, RTMPClientTypePublisher)
|
|
onFC := rtmp.NewCallPacket()
|
|
onFC.CommandName = "onFCPublish"
|
|
queueDecode(p,
|
|
resultCallPacket(),
|
|
resultCallPacket(),
|
|
createStreamRes(1),
|
|
onFC, // skipped: not onStatus
|
|
onStatusPacket("NetStream.Publish.Start"),
|
|
)
|
|
|
|
if err := c.publish(context.Background(), p, "stream"); err != nil {
|
|
t.Fatalf("publish: %+v", err)
|
|
}
|
|
}
|
|
|
|
func TestRtmpClientToBackend_Publish_OnStatusArgsNotObject(t *testing.T) {
|
|
c, p := newIsolatedBackend(t, RTMPClientTypePublisher)
|
|
bad := rtmp.NewCallPacket()
|
|
bad.CommandName = "onStatus"
|
|
bad.Args = rtmp.NewAmf0String("not-an-object")
|
|
queueDecode(p,
|
|
resultCallPacket(),
|
|
resultCallPacket(),
|
|
createStreamRes(1),
|
|
bad,
|
|
)
|
|
|
|
err := c.publish(context.Background(), p, "stream")
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
if !strings.Contains(err.Error(), "args not object") {
|
|
t.Fatalf("unexpected error %v", err)
|
|
}
|
|
}
|
|
|
|
func TestRtmpClientToBackend_Publish_OnStatusMissingCode(t *testing.T) {
|
|
c, p := newIsolatedBackend(t, RTMPClientTypePublisher)
|
|
bad := rtmp.NewCallPacket()
|
|
bad.CommandName = "onStatus"
|
|
bad.Args = rtmp.NewAmf0Object() // empty: no "code"
|
|
queueDecode(p,
|
|
resultCallPacket(),
|
|
resultCallPacket(),
|
|
createStreamRes(1),
|
|
bad,
|
|
)
|
|
|
|
err := c.publish(context.Background(), p, "stream")
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
if !strings.Contains(err.Error(), "code not string") {
|
|
t.Fatalf("unexpected error %v", err)
|
|
}
|
|
}
|
|
|
|
func TestRtmpClientToBackend_Publish_OnStatusWrongCode(t *testing.T) {
|
|
c, p := newIsolatedBackend(t, RTMPClientTypePublisher)
|
|
queueDecode(p,
|
|
resultCallPacket(),
|
|
resultCallPacket(),
|
|
createStreamRes(1),
|
|
onStatusPacket("NetStream.Publish.Failed"),
|
|
)
|
|
|
|
err := c.publish(context.Background(), p, "stream")
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
if !strings.Contains(err.Error(), "NetStream.Publish.Start") {
|
|
t.Fatalf("unexpected error %v", err)
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// play() in isolation
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestRtmpClientToBackend_Play_HappyPath(t *testing.T) {
|
|
c, p := newIsolatedBackend(t, RTMPClientTypeViewer)
|
|
queueDecode(p,
|
|
createStreamRes(3),
|
|
onStatusPacket("NetStream.Play.Start"),
|
|
)
|
|
|
|
if err := c.play(context.Background(), p, "stream"); err != nil {
|
|
t.Fatalf("play: %+v", err)
|
|
}
|
|
// One WritePacket each for: createStream and play.
|
|
if got := p.WritePacketCallCount(); got != 2 {
|
|
t.Fatalf("WritePacket called %d times, want 2", got)
|
|
}
|
|
}
|
|
|
|
func TestRtmpClientToBackend_Play_CreateStreamWriteError(t *testing.T) {
|
|
c, p := newIsolatedBackend(t, RTMPClientTypeViewer)
|
|
p.WritePacketReturns(errors.New("boom"))
|
|
|
|
err := c.play(context.Background(), p, "stream")
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
if !strings.Contains(err.Error(), "createStream") {
|
|
t.Fatalf("unexpected error %v", err)
|
|
}
|
|
}
|
|
|
|
func TestRtmpClientToBackend_Play_CreateStreamExpectError(t *testing.T) {
|
|
c, p := newIsolatedBackend(t, RTMPClientTypeViewer)
|
|
p.ReadMessageReturns(nil, errors.New("read fail"))
|
|
|
|
err := c.play(context.Background(), p, "stream")
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
if !strings.Contains(err.Error(), "createStream res") {
|
|
t.Fatalf("unexpected error %v", err)
|
|
}
|
|
}
|
|
|
|
func TestRtmpClientToBackend_Play_CreateStreamSkipsZeroID(t *testing.T) {
|
|
c, p := newIsolatedBackend(t, RTMPClientTypeViewer)
|
|
queueDecode(p,
|
|
createStreamRes(0), // skipped
|
|
createStreamRes(5),
|
|
onStatusPacket("NetStream.Play.Start"),
|
|
)
|
|
|
|
if err := c.play(context.Background(), p, "stream"); err != nil {
|
|
t.Fatalf("play: %+v", err)
|
|
}
|
|
}
|
|
|
|
func TestRtmpClientToBackend_Play_FiltersUntilPlayStart(t *testing.T) {
|
|
// play() ignores onStatus packets whose code is not NetStream.Play.Start
|
|
// (e.g. the proxy sees a NetStream.Play.Reset first).
|
|
c, p := newIsolatedBackend(t, RTMPClientTypeViewer)
|
|
queueDecode(p,
|
|
createStreamRes(1),
|
|
onStatusPacket("NetStream.Play.Reset"), // skipped
|
|
onStatusPacket("NetStream.Play.Start"),
|
|
)
|
|
|
|
if err := c.play(context.Background(), p, "stream"); err != nil {
|
|
t.Fatalf("play: %+v", err)
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// rtmpConnection: fakes, fixture, and packet builders
|
|
// ---------------------------------------------------------------------------
|
|
|
|
// fakeNetConn is a net.Conn stand-in for serve(), which takes a net.Conn.
// Read reports io.EOF and Write discards its input, because every protocol
// call is intercepted by the fake handshake/protocol. RemoteAddr and Close
// are the methods serve is expected to call directly.
type fakeNetConn struct {
	// closed records whether Close has been called.
	closed atomic.Bool
}

// fakeAddr is the fixed net.Addr that fakeNetConn reports for both ends.
type fakeAddr struct{}

// Compile-time checks that the fakes satisfy the interfaces they replace.
var (
	_ net.Conn = (*fakeNetConn)(nil)
	_ net.Addr = fakeAddr{}
)

func (c *fakeNetConn) Read(p []byte) (int, error)  { return 0, io.EOF }
func (c *fakeNetConn) Write(p []byte) (int, error) { return len(p), nil }

// Close marks the connection closed; it never fails.
func (c *fakeNetConn) Close() error {
	c.closed.Store(true)
	return nil
}

func (c *fakeNetConn) LocalAddr() net.Addr  { return fakeAddr{} }
func (c *fakeNetConn) RemoteAddr() net.Addr { return fakeAddr{} }

// Deadlines are accepted and ignored.
func (c *fakeNetConn) SetDeadline(time.Time) error      { return nil }
func (c *fakeNetConn) SetReadDeadline(time.Time) error  { return nil }
func (c *fakeNetConn) SetWriteDeadline(time.Time) error { return nil }

func (fakeAddr) Network() string { return "fake" }
func (fakeAddr) String() string  { return "fake-addr" }
|
|
|
|
// connFixture bundles the fakes plus an rtmpConnection wired against them.
|
|
// Tests configure the fakes, then call rc.serve(ctx, conn).
|
|
//
|
|
// The injected newBackend always returns a "terminating" backend whose
|
|
// inner load balancer fails Pick. This drives serve() far enough to call
|
|
// newBackend (so we can assert clientType), but Connect then fails fast
|
|
// so the test does not need to drive the proxy goroutines.
|
|
type connFixture struct {
|
|
netConn *fakeNetConn
|
|
clientHs *rtmpfakes.FakeHandshake
|
|
clientProto *rtmpfakes.FakeProtocol
|
|
lb *lbfakes.FakeOriginLoadBalancer
|
|
|
|
backendCalls atomic.Int32
|
|
backendClientType atomic.Value // RTMPClientType
|
|
|
|
rc *rtmpConnection
|
|
}
|
|
|
|
func newConnFixture() *connFixture {
|
|
f := &connFixture{
|
|
netConn: &fakeNetConn{},
|
|
clientHs: &rtmpfakes.FakeHandshake{},
|
|
clientProto: &rtmpfakes.FakeProtocol{},
|
|
lb: &lbfakes.FakeOriginLoadBalancer{},
|
|
}
|
|
// Default: protocol.ReadMessage returns a fresh empty Message so
|
|
// ExpectPacket can proceed to DecodeMessage. DecodeMessage must be
|
|
// queued per test via queueDecode.
|
|
readMessageOK(f.clientProto)
|
|
|
|
f.rc = newRTMPConnection(func(c *rtmpConnection) {
|
|
c.loadBalancer = f.lb
|
|
c.newHandshake = func() rtmp.Handshake { return f.clientHs }
|
|
c.newProtocol = func(rw io.ReadWriter) rtmp.Protocol { return f.clientProto }
|
|
c.newBackend = func(clientType RTMPClientType) *rtmpClientToBackend {
|
|
f.backendCalls.Add(1)
|
|
f.backendClientType.Store(clientType)
|
|
// Terminating backend: inner LB.Pick fails, so backend.Connect
|
|
// returns an error wrapped by serve() as "connect backend".
|
|
terminateLb := &lbfakes.FakeOriginLoadBalancer{}
|
|
terminateLb.PickReturns(nil, errors.New("test terminate"))
|
|
return newRTMPClientToBackend(func(b *rtmpClientToBackend) {
|
|
b.typ = clientType
|
|
b.loadBalancer = terminateLb
|
|
})
|
|
}
|
|
})
|
|
return f
|
|
}
|
|
|
|
// connectReqPacket builds a ConnectAppPacket with the given tcUrl.
|
|
func connectReqPacket(tcUrl string) *rtmp.ConnectAppPacket {
|
|
p := rtmp.NewConnectAppPacket()
|
|
p.CommandObject.Set("tcUrl", rtmp.NewAmf0String(tcUrl))
|
|
return p
|
|
}
|
|
|
|
// publishReqPacket builds a PublishPacket with the given stream name.
|
|
func publishReqPacket(streamName string) *rtmp.PublishPacket {
|
|
p := rtmp.NewPublishPacket()
|
|
p.StreamName = rtmp.NewAmf0String(streamName)
|
|
return p
|
|
}
|
|
|
|
// playReqPacket builds a PlayPacket with the given stream name.
|
|
func playReqPacket(streamName string) *rtmp.PlayPacket {
|
|
p := rtmp.NewPlayPacket()
|
|
p.StreamName = rtmp.NewAmf0String(streamName)
|
|
return p
|
|
}
|
|
|
|
// rtmp.CallPacket's CommandName field uses an unexported amf0String, which
|
|
// only accepts untyped string literals. The three identify-loop branches
|
|
// each get their own helper.
|
|
func createStreamCallPacket() *rtmp.CallPacket {
|
|
p := rtmp.NewCallPacket()
|
|
p.CommandName = "createStream"
|
|
return p
|
|
}
|
|
|
|
func releaseStreamCallPacket() *rtmp.CallPacket {
|
|
p := rtmp.NewCallPacket()
|
|
p.CommandName = "releaseStream"
|
|
return p
|
|
}
|
|
|
|
func getStreamLengthCallPacket() *rtmp.CallPacket {
|
|
p := rtmp.NewCallPacket()
|
|
p.CommandName = "getStreamLength"
|
|
return p
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// rtmpConnection: constructor & defaults
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestRtmpConnection_NewSetsDefaults(t *testing.T) {
|
|
c := newRTMPConnection()
|
|
if c.newHandshake == nil {
|
|
t.Fatal("newHandshake should default to a non-nil factory")
|
|
}
|
|
if c.newProtocol == nil {
|
|
t.Fatal("newProtocol should default to a non-nil factory")
|
|
}
|
|
if c.newBackend == nil {
|
|
t.Fatal("newBackend should default to a non-nil factory")
|
|
}
|
|
// Defaults are real factories — call them to confirm they return
|
|
// non-nil concrete values.
|
|
if hs := c.newHandshake(); hs == nil {
|
|
t.Fatal("default newHandshake returned nil")
|
|
}
|
|
if p := c.newProtocol(&fakeConn{}); p == nil {
|
|
t.Fatal("default newProtocol returned nil")
|
|
}
|
|
}
|
|
|
|
func TestRtmpConnection_DefaultNewBackendWiresFields(t *testing.T) {
|
|
lbInst := &lbfakes.FakeOriginLoadBalancer{}
|
|
c := newRTMPConnection(func(c *rtmpConnection) {
|
|
c.loadBalancer = lbInst
|
|
})
|
|
|
|
pub := c.newBackend(RTMPClientTypePublisher)
|
|
if pub.typ != RTMPClientTypePublisher {
|
|
t.Fatalf("publisher backend typ=%v, want %v", pub.typ, RTMPClientTypePublisher)
|
|
}
|
|
if pub.loadBalancer != lbInst {
|
|
t.Fatal("publisher backend should reuse the connection's load balancer")
|
|
}
|
|
|
|
view := c.newBackend(RTMPClientTypeViewer)
|
|
if view.typ != RTMPClientTypeViewer {
|
|
t.Fatalf("viewer backend typ=%v, want %v", view.typ, RTMPClientTypeViewer)
|
|
}
|
|
if view.loadBalancer != lbInst {
|
|
t.Fatal("viewer backend should reuse the connection's load balancer")
|
|
}
|
|
}
|
|
|
|
func TestRtmpConnection_OptionOverridesNewBackend(t *testing.T) {
|
|
var called atomic.Int32
|
|
override := func(clientType RTMPClientType) *rtmpClientToBackend {
|
|
called.Add(1)
|
|
return newRTMPClientToBackend()
|
|
}
|
|
c := newRTMPConnection(func(c *rtmpConnection) { c.newBackend = override })
|
|
|
|
_ = c.newBackend(RTMPClientTypePublisher)
|
|
if got := called.Load(); got != 1 {
|
|
t.Fatalf("override newBackend called %d times, want 1", got)
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// rtmpConnection.serve: handshake error paths
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestRtmpConnection_Serve_HandshakeReadC0Error(t *testing.T) {
|
|
f := newConnFixture()
|
|
f.clientHs.ReadC0S0Returns(nil, errors.New("boom"))
|
|
|
|
err := f.rc.serve(context.Background(), f.netConn)
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
if !strings.Contains(err.Error(), "read c0") {
|
|
t.Fatalf("unexpected error %v", err)
|
|
}
|
|
if f.backendCalls.Load() != 0 {
|
|
t.Fatal("newBackend should not be called on handshake failure")
|
|
}
|
|
}
|
|
|
|
func TestRtmpConnection_Serve_HandshakeWriteC0Error(t *testing.T) {
|
|
f := newConnFixture()
|
|
f.clientHs.WriteC0S0Returns(errors.New("boom"))
|
|
|
|
err := f.rc.serve(context.Background(), f.netConn)
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
// The write-c0 branch is wrapped as "write s1" in serve() (typo in
|
|
// production, but the test pins the current behavior).
|
|
if !strings.Contains(err.Error(), "write s1") {
|
|
t.Fatalf("unexpected error %v", err)
|
|
}
|
|
}
|
|
|
|
func TestRtmpConnection_Serve_HandshakeReadC2Error(t *testing.T) {
|
|
f := newConnFixture()
|
|
f.clientHs.ReadC2S2Returns(nil, errors.New("boom"))
|
|
|
|
err := f.rc.serve(context.Background(), f.netConn)
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
if !strings.Contains(err.Error(), "read c2") {
|
|
t.Fatalf("unexpected error %v", err)
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// rtmpConnection.serve: protocol error paths
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestRtmpConnection_Serve_ExpectConnectReqError(t *testing.T) {
|
|
f := newConnFixture()
|
|
// Fail ReadMessage so ExpectPacket returns immediately.
|
|
f.clientProto.ReadMessageStub = nil
|
|
f.clientProto.ReadMessageReturns(nil, errors.New("read fail"))
|
|
|
|
err := f.rc.serve(context.Background(), f.netConn)
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
if !strings.Contains(err.Error(), "expect connect req") {
|
|
t.Fatalf("unexpected error %v", err)
|
|
}
|
|
if f.backendCalls.Load() != 0 {
|
|
t.Fatal("newBackend should not be called when connect req fails")
|
|
}
|
|
}
|
|
|
|
func TestRtmpConnection_Serve_WriteAckSizeError(t *testing.T) {
|
|
f := newConnFixture()
|
|
queueDecode(f.clientProto, connectReqPacket("rtmp://1.2.3.4/live"))
|
|
// First WritePacket is the WindowAcknowledgementSize.
|
|
f.clientProto.WritePacketReturnsOnCall(0, errors.New("boom"))
|
|
|
|
err := f.rc.serve(context.Background(), f.netConn)
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
if !strings.Contains(err.Error(), "write set ack size") {
|
|
t.Fatalf("unexpected error %v", err)
|
|
}
|
|
}
|
|
|
|
func TestRtmpConnection_Serve_WriteConnectResError(t *testing.T) {
|
|
f := newConnFixture()
|
|
queueDecode(f.clientProto, connectReqPacket("rtmp://1.2.3.4/live"))
|
|
// Third WritePacket is the ConnectAppResPacket; ack and chunk-size precede it.
|
|
f.clientProto.WritePacketReturnsOnCall(2, errors.New("boom"))
|
|
|
|
err := f.rc.serve(context.Background(), f.netConn)
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
if !strings.Contains(err.Error(), "write connect res") {
|
|
t.Fatalf("unexpected error %v", err)
|
|
}
|
|
}
|
|
|
|
func TestRtmpConnection_Serve_ExpectIdentifyReqError(t *testing.T) {
|
|
f := newConnFixture()
|
|
// Connect req decodes fine, then the next ReadMessage fails.
|
|
queueDecode(f.clientProto, connectReqPacket("rtmp://1.2.3.4/live"))
|
|
var reads atomic.Int32
|
|
f.clientProto.ReadMessageStub = func(ctx context.Context) (rtmp.Message, error) {
|
|
if reads.Add(1) >= 2 {
|
|
return nil, errors.New("read fail")
|
|
}
|
|
return rtmp.NewMessage(), nil
|
|
}
|
|
|
|
err := f.rc.serve(context.Background(), f.netConn)
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
if !strings.Contains(err.Error(), "expect identify req") {
|
|
t.Fatalf("unexpected error %v", err)
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// rtmpConnection.serve: identify-loop branches
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestRtmpConnection_Serve_IdentifyCreateStreamThenPublisher(t *testing.T) {
|
|
f := newConnFixture()
|
|
queueDecode(f.clientProto,
|
|
connectReqPacket("rtmp://1.2.3.4/live"),
|
|
createStreamCallPacket(),
|
|
publishReqPacket("stream"),
|
|
)
|
|
|
|
err := f.rc.serve(context.Background(), f.netConn)
|
|
// Reaches backend.Connect, which fails via the terminating LB.
|
|
if err == nil || !strings.Contains(err.Error(), "connect backend") {
|
|
t.Fatalf("expected connect backend error, got %v", err)
|
|
}
|
|
// WritePacket calls: ack, chunk, connectRes, createStreamRes, onFCPublish.
|
|
if got := f.clientProto.WritePacketCallCount(); got != 5 {
|
|
t.Fatalf("WritePacket called %d times, want 5", got)
|
|
}
|
|
if v := f.backendClientType.Load(); v != RTMPClientTypePublisher {
|
|
t.Fatalf("backend clientType=%v, want publisher", v)
|
|
}
|
|
}
|
|
|
|
func TestRtmpConnection_Serve_IdentifyDefaultCallThenViewer(t *testing.T) {
|
|
// A generic CallPacket (e.g. releaseStream) must be acknowledged with
|
|
// a _result reply before the identify loop sees the Play packet.
|
|
f := newConnFixture()
|
|
queueDecode(f.clientProto,
|
|
connectReqPacket("rtmp://1.2.3.4/live"),
|
|
releaseStreamCallPacket(),
|
|
playReqPacket("stream"),
|
|
)
|
|
|
|
err := f.rc.serve(context.Background(), f.netConn)
|
|
if err == nil || !strings.Contains(err.Error(), "connect backend") {
|
|
t.Fatalf("expected connect backend error, got %v", err)
|
|
}
|
|
// WritePacket calls: ack, chunk, connectRes, _result, onStatus(play).
|
|
if got := f.clientProto.WritePacketCallCount(); got != 5 {
|
|
t.Fatalf("WritePacket called %d times, want 5", got)
|
|
}
|
|
if v := f.backendClientType.Load(); v != RTMPClientTypeViewer {
|
|
t.Fatalf("backend clientType=%v, want viewer", v)
|
|
}
|
|
}
|
|
|
|
func TestRtmpConnection_Serve_IdentifyGetStreamLengthSkipsResponse(t *testing.T) {
|
|
// getStreamLength is ignored — no response written, no error, the loop
|
|
// just reads the next packet (the Publish).
|
|
f := newConnFixture()
|
|
queueDecode(f.clientProto,
|
|
connectReqPacket("rtmp://1.2.3.4/live"),
|
|
getStreamLengthCallPacket(),
|
|
publishReqPacket("stream"),
|
|
)
|
|
|
|
err := f.rc.serve(context.Background(), f.netConn)
|
|
if err == nil || !strings.Contains(err.Error(), "connect backend") {
|
|
t.Fatalf("expected connect backend error, got %v", err)
|
|
}
|
|
// WritePacket calls: ack, chunk, connectRes, onFCPublish.
|
|
// getStreamLength contributes nothing because the switch falls through.
|
|
if got := f.clientProto.WritePacketCallCount(); got != 4 {
|
|
t.Fatalf("WritePacket called %d times, want 4", got)
|
|
}
|
|
}
|
|
|
|
func TestRtmpConnection_Serve_IdentifyResponseWriteError(t *testing.T) {
|
|
f := newConnFixture()
|
|
queueDecode(f.clientProto,
|
|
connectReqPacket("rtmp://1.2.3.4/live"),
|
|
createStreamCallPacket(),
|
|
)
|
|
// First three WritePacket calls (ack, chunk, connectRes) succeed;
|
|
// the fourth (createStream response) fails.
|
|
f.clientProto.WritePacketReturnsOnCall(3, errors.New("boom"))
|
|
|
|
err := f.rc.serve(context.Background(), f.netConn)
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
if !strings.Contains(err.Error(), "write identify res") {
|
|
t.Fatalf("unexpected error %v", err)
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// rtmpConnection.serve: newBackend invocation contract
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestRtmpConnection_Serve_PublisherInvokesNewBackend(t *testing.T) {
|
|
// A direct Publish (no createStream beforehand) still drives the
|
|
// identify loop to set clientType=Publisher and call newBackend.
|
|
f := newConnFixture()
|
|
queueDecode(f.clientProto,
|
|
connectReqPacket("rtmp://1.2.3.4/live"),
|
|
publishReqPacket("stream"),
|
|
)
|
|
|
|
err := f.rc.serve(context.Background(), f.netConn)
|
|
if err == nil || !strings.Contains(err.Error(), "connect backend") {
|
|
t.Fatalf("expected connect backend error, got %v", err)
|
|
}
|
|
if got := f.backendCalls.Load(); got != 1 {
|
|
t.Fatalf("newBackend called %d times, want 1", got)
|
|
}
|
|
if v := f.backendClientType.Load(); v != RTMPClientTypePublisher {
|
|
t.Fatalf("backend clientType=%v, want publisher", v)
|
|
}
|
|
}
|
|
|
|
func TestRtmpConnection_Serve_ViewerInvokesNewBackend(t *testing.T) {
|
|
f := newConnFixture()
|
|
queueDecode(f.clientProto,
|
|
connectReqPacket("rtmp://1.2.3.4/live"),
|
|
playReqPacket("stream"),
|
|
)
|
|
|
|
err := f.rc.serve(context.Background(), f.netConn)
|
|
if err == nil || !strings.Contains(err.Error(), "connect backend") {
|
|
t.Fatalf("expected connect backend error, got %v", err)
|
|
}
|
|
if got := f.backendCalls.Load(); got != 1 {
|
|
t.Fatalf("newBackend called %d times, want 1", got)
|
|
}
|
|
if v := f.backendClientType.Load(); v != RTMPClientTypeViewer {
|
|
t.Fatalf("backend clientType=%v, want viewer", v)
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// rtmpProxyServer: fakes and fixture
|
|
// ---------------------------------------------------------------------------
|
|
|
|
// fakeListener is an in-memory net.Listener for driving the accept loop.
// Accept hands out connections queued with push. After Close, once the queue
// is drained, Accept returns net.ErrClosed.
//
// net.ErrClosed's message is "use of closed network connection". A
// closed-network check therefore recognises it whether it matches on the
// message text or uses errors.Is(err, net.ErrClosed). A freshly built
// errors.New with the same text only satisfies the former.
type fakeListener struct {
	conns  chan net.Conn
	closed atomic.Bool
}

// newFakeListener returns an open listener. Its queue buffers up to 4
// connections before push blocks.
func newFakeListener() *fakeListener {
	return &fakeListener{conns: make(chan net.Conn, 4)}
}

// push queues c for a later Accept. It must not be called after Close,
// because sending on the closed queue panics.
func (l *fakeListener) push(c net.Conn) { l.conns <- c }

// Accept blocks until a connection is pushed or the listener is closed.
// Connections queued before Close are still delivered. After that it returns
// net.ErrClosed.
func (l *fakeListener) Accept() (net.Conn, error) {
	c, ok := <-l.conns
	if !ok {
		return nil, net.ErrClosed
	}
	return c, nil
}

// Close closes the queue exactly once, which unblocks any pending Accept.
// Repeated calls are no-ops. It always returns nil.
func (l *fakeListener) Close() error {
	if l.closed.CompareAndSwap(false, true) {
		close(l.conns)
	}
	return nil
}
|
|
|
|
// Addr completes the net.Listener interface. It always reports the
// placeholder fakeAddr value.
func (l *fakeListener) Addr() net.Addr {
	return fakeAddr{}
}
|
|
|
|
// proxyFixture bundles the fakes plus an rtmpProxyServer wired against them.
|
|
// The injected newConnection returns a connection whose handshake fails on
|
|
// ReadC0S0, so serve() returns fast without needing to drive the full RTMP
|
|
// protocol. Tests can assert how many connections were dispatched via
|
|
// newConnCalls and how the listen() option was invoked via listenCalls/
|
|
// listenAddr.
|
|
type proxyFixture struct {
|
|
env *envfakes.FakeProxyEnvironment
|
|
lb *lbfakes.FakeOriginLoadBalancer
|
|
listener *fakeListener
|
|
listenCalls atomic.Int32
|
|
listenAddr atomic.Value // string
|
|
newConnCalls atomic.Int32
|
|
serveDone chan struct{}
|
|
server *rtmpProxyServer
|
|
}
|
|
|
|
func newProxyFixture() *proxyFixture {
|
|
f := &proxyFixture{
|
|
env: &envfakes.FakeProxyEnvironment{},
|
|
lb: &lbfakes.FakeOriginLoadBalancer{},
|
|
listener: newFakeListener(),
|
|
serveDone: make(chan struct{}, 16),
|
|
}
|
|
f.env.RtmpServerReturns("1935")
|
|
|
|
srv := NewRTMPProxyServer(f.env, f.lb, func(v *rtmpProxyServer) {
|
|
v.listen = func(ctx context.Context, addr string) (net.Listener, error) {
|
|
f.listenCalls.Add(1)
|
|
f.listenAddr.Store(addr)
|
|
return f.listener, nil
|
|
}
|
|
v.newConnection = func() *rtmpConnection {
|
|
f.newConnCalls.Add(1)
|
|
hs := &rtmpfakes.FakeHandshake{}
|
|
hs.ReadC0S0Returns(nil, errors.New("test terminate"))
|
|
return newRTMPConnection(func(c *rtmpConnection) {
|
|
// Signal when serve() actually enters the handshake step so
|
|
// tests can sync on "per-conn goroutine has started".
|
|
c.newHandshake = func() rtmp.Handshake {
|
|
f.serveDone <- struct{}{}
|
|
return hs
|
|
}
|
|
})
|
|
}
|
|
})
|
|
f.server = srv.(*rtmpProxyServer)
|
|
return f
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// rtmpProxyServer: constructor & defaults
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestRTMPProxyServer_NewSetsDefaults(t *testing.T) {
|
|
srv := NewRTMPProxyServer(&envfakes.FakeProxyEnvironment{}, &lbfakes.FakeOriginLoadBalancer{})
|
|
v := srv.(*rtmpProxyServer)
|
|
if v.listen == nil {
|
|
t.Fatal("listen should default to a non-nil factory")
|
|
}
|
|
if v.newConnection == nil {
|
|
t.Fatal("newConnection should default to a non-nil factory")
|
|
}
|
|
// Default newConnection returns a wired-up rtmpConnection that reuses
|
|
// the server's load balancer.
|
|
rc := v.newConnection()
|
|
if rc == nil {
|
|
t.Fatal("default newConnection returned nil")
|
|
}
|
|
if rc.loadBalancer != v.loadBalancer {
|
|
t.Fatal("default newConnection should propagate the server's loadBalancer")
|
|
}
|
|
}
|
|
|
|
func TestRTMPProxyServer_NewAppliesOptions(t *testing.T) {
|
|
var listenCalls, newConnCalls atomic.Int32
|
|
srv := NewRTMPProxyServer(
|
|
&envfakes.FakeProxyEnvironment{},
|
|
&lbfakes.FakeOriginLoadBalancer{},
|
|
func(v *rtmpProxyServer) {
|
|
v.listen = func(ctx context.Context, addr string) (net.Listener, error) {
|
|
listenCalls.Add(1)
|
|
return nil, errors.New("unused")
|
|
}
|
|
v.newConnection = func() *rtmpConnection {
|
|
newConnCalls.Add(1)
|
|
return newRTMPConnection()
|
|
}
|
|
},
|
|
)
|
|
v := srv.(*rtmpProxyServer)
|
|
_, _ = v.listen(context.Background(), ":0")
|
|
_ = v.newConnection()
|
|
if got := listenCalls.Load(); got != 1 {
|
|
t.Fatalf("custom listen called %d times, want 1", got)
|
|
}
|
|
if got := newConnCalls.Load(); got != 1 {
|
|
t.Fatalf("custom newConnection called %d times, want 1", got)
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// rtmpProxyServer.Close
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestRTMPProxyServer_Close_NoListener(t *testing.T) {
|
|
// Close before Run must not panic, must not hang, and must not error.
|
|
srv := NewRTMPProxyServer(&envfakes.FakeProxyEnvironment{}, &lbfakes.FakeOriginLoadBalancer{})
|
|
done := make(chan error, 1)
|
|
go func() { done <- srv.Close() }()
|
|
select {
|
|
case err := <-done:
|
|
if err != nil {
|
|
t.Fatalf("Close: %v", err)
|
|
}
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("Close hung with no listener")
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// rtmpProxyServer.Run: listen and endpoint normalization
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestRTMPProxyServer_Run_ListenError(t *testing.T) {
|
|
envFake := &envfakes.FakeProxyEnvironment{}
|
|
envFake.RtmpServerReturns("1935")
|
|
srv := NewRTMPProxyServer(envFake, &lbfakes.FakeOriginLoadBalancer{}, func(v *rtmpProxyServer) {
|
|
v.listen = func(ctx context.Context, addr string) (net.Listener, error) {
|
|
return nil, errors.New("permission denied")
|
|
}
|
|
})
|
|
|
|
err := srv.Run(context.Background())
|
|
if err == nil {
|
|
t.Fatal("expected error from Run when listen fails")
|
|
}
|
|
if !strings.Contains(err.Error(), "listen rtmp addr") {
|
|
t.Fatalf("unexpected error %v", err)
|
|
}
|
|
}
|
|
|
|
func TestRTMPProxyServer_Run_EndpointWithoutColon(t *testing.T) {
|
|
// A bare port like "1935" must be normalized to ":1935" before reaching listen().
|
|
f := newProxyFixture()
|
|
f.env.RtmpServerReturns("1935")
|
|
|
|
if err := f.server.Run(context.Background()); err != nil {
|
|
t.Fatalf("Run: %v", err)
|
|
}
|
|
defer f.server.Close()
|
|
|
|
if got := f.listenAddr.Load(); got != ":1935" {
|
|
t.Fatalf("listen addr=%v, want :1935", got)
|
|
}
|
|
}
|
|
|
|
func TestRTMPProxyServer_Run_EndpointWithColon(t *testing.T) {
|
|
// An endpoint that already contains ":" must be passed through unchanged.
|
|
f := newProxyFixture()
|
|
f.env.RtmpServerReturns("127.0.0.1:1935")
|
|
|
|
if err := f.server.Run(context.Background()); err != nil {
|
|
t.Fatalf("Run: %v", err)
|
|
}
|
|
defer f.server.Close()
|
|
|
|
if got := f.listenAddr.Load(); got != "127.0.0.1:1935" {
|
|
t.Fatalf("listen addr=%v, want 127.0.0.1:1935", got)
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// rtmpProxyServer.Run: accept loop
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestRTMPProxyServer_Run_AcceptInvokesNewConnection(t *testing.T) {
|
|
f := newProxyFixture()
|
|
if err := f.server.Run(context.Background()); err != nil {
|
|
t.Fatalf("Run: %v", err)
|
|
}
|
|
|
|
conn := &fakeNetConn{}
|
|
f.listener.push(conn)
|
|
|
|
// Wait for the per-conn goroutine to start (newHandshake is observed).
|
|
select {
|
|
case <-f.serveDone:
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("newConnection was not invoked for accepted conn")
|
|
}
|
|
|
|
if got := f.newConnCalls.Load(); got != 1 {
|
|
t.Fatalf("newConnection called %d times, want 1", got)
|
|
}
|
|
|
|
// Close shuts the listener and drains the accept goroutine cleanly.
|
|
if err := f.server.Close(); err != nil {
|
|
t.Fatalf("Close: %v", err)
|
|
}
|
|
// The accepted conn should have been closed by the per-conn goroutine's
|
|
// defer once serve() returned.
|
|
if !conn.closed.Load() {
|
|
t.Fatal("accepted conn was not closed after serve returned")
|
|
}
|
|
}
|
|
|
|
func TestRTMPProxyServer_Run_CloseShutsDownAcceptLoop(t *testing.T) {
|
|
// Start Run with an idle listener (no queued conns). Accept blocks. Close
|
|
// must unblock it and let Run/Close return cleanly via the closed-network
|
|
// branch in the accept loop.
|
|
f := newProxyFixture()
|
|
if err := f.server.Run(context.Background()); err != nil {
|
|
t.Fatalf("Run: %v", err)
|
|
}
|
|
|
|
done := make(chan error, 1)
|
|
go func() { done <- f.server.Close() }()
|
|
select {
|
|
case err := <-done:
|
|
if err != nil {
|
|
t.Fatalf("Close: %v", err)
|
|
}
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("Close hung — accept loop did not exit on listener close")
|
|
}
|
|
|
|
if f.newConnCalls.Load() != 0 {
|
|
t.Fatal("newConnection should not be called when no conn was accepted")
|
|
}
|
|
}
|