Fixes three RTMP chunk-stream decoding bugs in the proxy and hardens AMF0 command-packet unmarshalling against malformed input, backed by a new protocol unit-test suite. All changes are confined to the `internal/rtmp` package. No public API, log format, or emitted wire format changes — these are decode-correctness and robustness fixes only.

**3-byte chunk basic header decode (`readBasicHeader`)**

The 3-byte basic-header form (cid 64–65599) was selected by testing `cid == 1` *after* `cid` had already been overwritten with `64 + t`, so it was never detected. Capture the original marker before overwriting and test that instead.

**Extended-timestamp handling (`chunkStream`, `readMessageHeader`)**

- Use the extended timestamp as a delta for fmt=1/2 chunks (and for a fmt=3 first chunk continuing them); this is required when the delta is ≥ `0xffffff`. Timestamp computation is unified into a single post-step: the extended timestamp when present, otherwise the 3-byte header delta; fmt=0 is absolute, fmt=1/2 is accumulated.
- Detect Type-3 chunks that omit the extended timestamp. FMLE/FMS/Flash follow the RTMP 2012 spec and always send it on Type-3 chunks; librtmp/ffmpeg may not. Switched from an unconditional 4-byte read to `Peek` + conditional `Discard`: if the peeked value differs from the stored one on a non-first chunk, those 4 bytes are payload and are left in the reader.
- Split the single `extendedTimestamp` bool into `hasExtendedTimestamp` (bool) and `extendedTimestamp` (the last raw value, used for the detection above).

**Packet unmarshal hardening**

- Add an `advanceBytes(p, n)` helper that bounds-checks each `p = p[field.Size():]` advance, turning a slice-out-of-range panic into a clean error on truncated/untrusted input. Applied in `CallPacket`, `CreateStreamResPacket`, `PublishPacket`, and `PlayPacket`.
- Reset the optional `CommandObject` / `Args` to nil before probing for their presence, so a stale constructor default (e.g. Null) isn't counted by `Size()` and can't overflow a later advance.

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2040 lines
83 KiB
Go
2040 lines
83 KiB
Go
// Copyright (c) 2026 Winlin
|
|
//
|
|
// SPDX-License-Identifier: MIT
|
|
package rtmp
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/binary"
|
|
"fmt"
|
|
"io"
|
|
"reflect"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
)
|
|
|
|
// errWriter is an io.Writer stub whose every Write fails with
// io.ErrClosedPipe; tests use it to drive write-error paths.
type errWriter struct{}

// Write ignores its input and reports zero bytes written together with
// io.ErrClosedPipe.
func (errWriter) Write(_ []byte) (written int, err error) {
	return 0, io.ErrClosedPipe
}
|
|
|
|
func TestHandshakeSimpleAndErrors(t *testing.T) {
|
|
h := NewHandshake()
|
|
var b bytes.Buffer
|
|
if err := h.WriteC0S0(&b); err != nil {
|
|
t.Fatalf("WriteC0S0 err=%v", err)
|
|
}
|
|
c0, err := h.ReadC0S0(&b)
|
|
if err != nil || !bytes.Equal(c0, []byte{3}) {
|
|
t.Fatalf("ReadC0S0=%v, err=%v", c0, err)
|
|
}
|
|
if err := h.WriteC0S0(errWriter{}); err == nil {
|
|
t.Fatal("WriteC0S0 should fail")
|
|
}
|
|
if _, err := h.ReadC0S0(bytes.NewReader(nil)); err == nil {
|
|
t.Fatal("ReadC0S0 should fail")
|
|
}
|
|
|
|
b.Reset()
|
|
if err := h.WriteC1S1(&b); err != nil {
|
|
t.Fatalf("WriteC1S1 err=%v", err)
|
|
}
|
|
if b.Len() != 1536 {
|
|
t.Fatalf("C1S1 len=%v", b.Len())
|
|
}
|
|
c1, err := h.ReadC1S1(&b)
|
|
if err != nil || len(c1) != 1536 || !bytes.Equal(h.C1S1(), c1) {
|
|
t.Fatalf("ReadC1S1 len=%v, cached=%v, err=%v", len(c1), bytes.Equal(h.C1S1(), c1), err)
|
|
}
|
|
if err := h.WriteC1S1(errWriter{}); err == nil {
|
|
t.Fatal("WriteC1S1 should fail")
|
|
}
|
|
if _, err := h.ReadC1S1(bytes.NewReader(make([]byte, 1535))); err == nil {
|
|
t.Fatal("ReadC1S1 should fail")
|
|
}
|
|
|
|
b.Reset()
|
|
if err := h.WriteC2S2(&b, c1); err != nil {
|
|
t.Fatalf("WriteC2S2 err=%v", err)
|
|
}
|
|
c2, err := h.ReadC2S2(&b)
|
|
if err != nil || !bytes.Equal(c2, c1) {
|
|
t.Fatalf("ReadC2S2 match=%v, err=%v", bytes.Equal(c2, c1), err)
|
|
}
|
|
if err := h.WriteC2S2(errWriter{}, c1); err == nil {
|
|
t.Fatal("WriteC2S2 should fail")
|
|
}
|
|
if _, err := h.ReadC2S2(bytes.NewReader(make([]byte, 1535))); err == nil {
|
|
t.Fatal("ReadC2S2 should fail")
|
|
}
|
|
}
|
|
|
|
func TestSettingsChunkStreamAndMessageConstructors(t *testing.T) {
|
|
if s := newSettings(); s.chunkSize != defaultChunkSize {
|
|
t.Fatalf("chunk size=%v", s.chunkSize)
|
|
}
|
|
if c := newChunkStream(); c == nil || c.count != 0 {
|
|
t.Fatalf("chunk stream=%#v", c)
|
|
}
|
|
m := NewMessage().asMessage()
|
|
m.messageHeader.MessageType = MessageTypeAudio
|
|
m.messageHeader.Timestamp = 99
|
|
m.payload = []byte{1, 2, 3}
|
|
if m.MessageType() != MessageTypeAudio || m.Timestamp() != 99 || !bytes.Equal(m.Payload(), []byte{1, 2, 3}) || m.asMessage() != m {
|
|
t.Fatalf("bad message accessors")
|
|
}
|
|
sm := NewStreamMessage(7).asMessage()
|
|
if sm.streamID != 7 || sm.betterCid != chunkIDOverStream {
|
|
t.Fatalf("stream message=%#v", sm.messageHeader)
|
|
}
|
|
}
|
|
|
|
func TestBasicHeaderVariantsAndErrors(t *testing.T) {
|
|
ctx := context.Background()
|
|
cases := []struct {
|
|
name string
|
|
data []byte
|
|
fmt formatType
|
|
cid chunkID
|
|
}{
|
|
{"one-byte", []byte{0x85}, formatType2, 5},
|
|
{"two-byte", []byte{0x40, 0x0a}, formatType1, 74},
|
|
{"three-byte", []byte{0xc1, 0x01, 0x02}, formatType3, 577},
|
|
}
|
|
for _, tt := range cases {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
p := NewProtocol(bytes.NewBuffer(tt.data)).(*protocol)
|
|
fmt, cid, err := p.readBasicHeader(ctx)
|
|
if err != nil || fmt != tt.fmt || cid != tt.cid {
|
|
t.Fatalf("fmt=%v cid=%v err=%v", fmt, cid, err)
|
|
}
|
|
})
|
|
}
|
|
for _, data := range [][]byte{{}, {0x00}} {
|
|
p := NewProtocol(bytes.NewBuffer(data)).(*protocol)
|
|
if _, _, err := p.readBasicHeader(ctx); err == nil {
|
|
t.Fatalf("readBasicHeader(%x) should fail", data)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestReadMessageHeadersPayloadsAndChunks(t *testing.T) {
|
|
ctx := context.Background()
|
|
var in bytes.Buffer
|
|
// fmt0 cid=5, timestamp=10, len=3, type audio, stream=1, payload 010203.
|
|
in.Write([]byte{0x05, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x03, byte(MessageTypeAudio), 0x01, 0x00, 0x00, 0x00, 1, 2, 3})
|
|
// fmt1 same cid, delta=5, len=2, type video, payload 0405.
|
|
in.Write([]byte{0x45, 0x00, 0x00, 0x05, 0x00, 0x00, 0x02, byte(MessageTypeVideo), 4, 5})
|
|
// fmt2 same cid, delta=7, reuses len/type/stream, payload 0607.
|
|
in.Write([]byte{0x85, 0x00, 0x00, 0x07, 6, 7})
|
|
// fmt3 same cid, reuses delta and advances timestamp, payload 0809.
|
|
in.Write([]byte{0xc5, 8, 9})
|
|
|
|
p := NewProtocol(&in).(*protocol)
|
|
for i, want := range []struct {
|
|
typ MessageType
|
|
ts uint64
|
|
pl []byte
|
|
}{
|
|
{MessageTypeAudio, 10, []byte{1, 2, 3}},
|
|
{MessageTypeVideo, 15, []byte{4, 5}},
|
|
{MessageTypeVideo, 22, []byte{6, 7}},
|
|
{MessageTypeVideo, 29, []byte{8, 9}},
|
|
} {
|
|
m, err := p.ReadMessage(ctx)
|
|
if err != nil {
|
|
t.Fatalf("ReadMessage #%v err=%v", i, err)
|
|
}
|
|
if m.MessageType() != want.typ || m.Timestamp() != want.ts || !bytes.Equal(m.Payload(), want.pl) {
|
|
t.Fatalf("message #%v type=%v ts=%v payload=%x", i, m.MessageType(), m.Timestamp(), m.Payload())
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestReadMessageExtendedTimestampAndChunking(t *testing.T) {
|
|
ctx := context.Background()
|
|
var in bytes.Buffer
|
|
payload := []byte{1, 2, 3, 4, 5}
|
|
// fmt0 cid=5, normal timestamp=0xffffff, extended timestamp has high bit set and should be masked.
|
|
in.Write([]byte{0x05, 0xff, 0xff, 0xff, 0x00, 0x00, byte(len(payload)), byte(MessageTypeVideo), 0x01, 0x00, 0x00, 0x00})
|
|
binary.Write(&in, binary.BigEndian, uint32(0x8000002a))
|
|
in.Write(payload[:2])
|
|
// continuation chunk has fmt3 and extended timestamp too.
|
|
in.Write([]byte{0xc5})
|
|
binary.Write(&in, binary.BigEndian, uint32(0x8000002a))
|
|
in.Write(payload[2:4])
|
|
in.Write([]byte{0xc5})
|
|
binary.Write(&in, binary.BigEndian, uint32(0x8000002a))
|
|
in.Write(payload[4:])
|
|
|
|
p := NewProtocol(&in).(*protocol)
|
|
p.input.opt.chunkSize = 2
|
|
m, err := p.ReadMessage(ctx)
|
|
if err != nil {
|
|
t.Fatalf("ReadMessage err=%v", err)
|
|
}
|
|
if m.Timestamp() != 42 || !bytes.Equal(m.Payload(), payload) {
|
|
t.Fatalf("ts=%v payload=%x", m.Timestamp(), m.Payload())
|
|
}
|
|
}
|
|
|
|
func TestReadMessageExtendedTimestampAsDeltaForFmt1(t *testing.T) {
|
|
ctx := context.Background()
|
|
var in bytes.Buffer
|
|
// fmt0 cid=5, timestamp=10, len=1, type video, stream=1, payload AA.
|
|
in.Write([]byte{0x05, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x01, byte(MessageTypeVideo), 0x01, 0x00, 0x00, 0x00, 0xAA})
|
|
// fmt1 cid=5, delta=0xffffff so the real delta is carried in the extended timestamp (=100),
|
|
// len=1, type video, payload BB. For fmt=1/2 the extended timestamp is a delta, so the
|
|
// message timestamp must accumulate: 10 + 100 = 110 (not be replaced by 100).
|
|
in.Write([]byte{0x45, 0xff, 0xff, 0xff, 0x00, 0x00, 0x01, byte(MessageTypeVideo)})
|
|
binary.Write(&in, binary.BigEndian, uint32(100))
|
|
in.Write([]byte{0xBB})
|
|
|
|
p := NewProtocol(&in).(*protocol)
|
|
for i, want := range []struct {
|
|
ts uint64
|
|
pl []byte
|
|
}{
|
|
{10, []byte{0xAA}},
|
|
{110, []byte{0xBB}},
|
|
} {
|
|
m, err := p.ReadMessage(ctx)
|
|
if err != nil {
|
|
t.Fatalf("ReadMessage #%v err=%v", i, err)
|
|
}
|
|
if m.Timestamp() != want.ts || !bytes.Equal(m.Payload(), want.pl) {
|
|
t.Fatalf("message #%v ts=%v payload=%x", i, m.Timestamp(), m.Payload())
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestReadMessageType3OmitsExtendedTimestamp(t *testing.T) {
|
|
ctx := context.Background()
|
|
var in bytes.Buffer
|
|
// fmt0 cid=5, timestamp=0xffffff so an extended timestamp (=100) is present, len=8,
|
|
// type video, stream=1, with the first 4 payload bytes.
|
|
in.Write([]byte{0x05, 0xff, 0xff, 0xff, 0x00, 0x00, 0x08, byte(MessageTypeVideo), 0x01, 0x00, 0x00, 0x00})
|
|
binary.Write(&in, binary.BigEndian, uint32(100))
|
|
in.Write([]byte{0x01, 0x02, 0x03, 0x04})
|
|
// fmt3 continuation from a librtmp/ffmpeg-style sender that omits the extended timestamp.
|
|
// The next 4 bytes are payload, not an extended timestamp; the parser must detect the
|
|
// mismatch against the stored value (100) and treat them as payload, keeping ts=100.
|
|
in.Write([]byte{0xc5, 0x05, 0x06, 0x07, 0x08})
|
|
|
|
p := NewProtocol(&in).(*protocol)
|
|
p.input.opt.chunkSize = 4
|
|
m, err := p.ReadMessage(ctx)
|
|
if err != nil {
|
|
t.Fatalf("ReadMessage err=%v", err)
|
|
}
|
|
if m.Timestamp() != 100 || !bytes.Equal(m.Payload(), []byte{1, 2, 3, 4, 5, 6, 7, 8}) {
|
|
t.Fatalf("ts=%v payload=%x", m.Timestamp(), m.Payload())
|
|
}
|
|
}
|
|
|
|
func TestReadMessageHeaderErrors(t *testing.T) {
|
|
ctx := context.Background()
|
|
// Fresh non-zero chunk with fmt1 is rejected.
|
|
p := NewProtocol(bytes.NewBuffer([]byte{0x45})).(*protocol)
|
|
if _, err := p.ReadMessage(ctx); err == nil || !strings.Contains(err.Error(), "fresh chunk") {
|
|
t.Fatalf("fresh fmt1 err=%v", err)
|
|
}
|
|
// Existing partial message cannot restart with fmt0.
|
|
p = NewProtocol(bytes.NewBuffer([]byte{0x05, 0, 0, 0, 0, 0, 3, byte(MessageTypeAudio), 1, 0, 0, 0, 1, 0x05})).(*protocol)
|
|
p.input.opt.chunkSize = 1
|
|
if _, err := p.ReadMessage(ctx); err == nil || !strings.Contains(err.Error(), "exists chunk") {
|
|
t.Fatalf("restart err=%v", err)
|
|
}
|
|
// Size change in a continuation header is rejected.
|
|
p = NewProtocol(bytes.NewBuffer([]byte{0x05, 0, 0, 0, 0, 0, 3, byte(MessageTypeAudio), 1, 0, 0, 0, 1, 0x45, 0, 0, 1, 0, 0, 4, byte(MessageTypeAudio)})).(*protocol)
|
|
p.input.opt.chunkSize = 1
|
|
if _, err := p.ReadMessage(ctx); err == nil || !strings.Contains(err.Error(), "message size") {
|
|
t.Fatalf("size change err=%v", err)
|
|
}
|
|
// Short payload and short extended timestamp.
|
|
p = NewProtocol(bytes.NewBuffer([]byte{0x05, 0, 0, 0, 0, 0, 2, byte(MessageTypeAudio), 1, 0, 0, 0, 1})).(*protocol)
|
|
if _, err := p.ReadMessage(ctx); err == nil || !strings.Contains(err.Error(), "read chunk") {
|
|
t.Fatalf("payload err=%v", err)
|
|
}
|
|
p = NewProtocol(bytes.NewBuffer([]byte{0x05, 0xff, 0xff, 0xff, 0, 0, 0, byte(MessageTypeAudio), 1, 0, 0, 0, 1, 2})).(*protocol)
|
|
if _, err := p.ReadMessage(ctx); err == nil || !strings.Contains(err.Error(), "ext-ts") {
|
|
t.Fatalf("ext-ts err=%v", err)
|
|
}
|
|
}
|
|
|
|
func TestWriteMessageHeadersChunkingAndErrors(t *testing.T) {
|
|
ctx := context.Background()
|
|
var out bytes.Buffer
|
|
p := NewProtocol(&out).(*protocol)
|
|
p.output.opt.chunkSize = 2
|
|
m := NewStreamMessage(7).asMessage()
|
|
m.messageHeader.MessageType = MessageTypeVideo
|
|
m.messageHeader.Timestamp = extendedTimestamp + 9
|
|
m.payload = []byte{1, 2, 3, 4, 5}
|
|
if err := p.WriteMessage(ctx, m); err != nil {
|
|
t.Fatalf("WriteMessage err=%v", err)
|
|
}
|
|
want := []byte{0x05, 0xff, 0xff, 0xff, 0, 0, 5, byte(MessageTypeVideo), 7, 0, 0, 0, 0x01, 0x00, 0x00, 0x08, 1, 2, 0xc5, 0x01, 0x00, 0x00, 0x08, 3, 4, 0xc5, 0x01, 0x00, 0x00, 0x08, 5}
|
|
if !bytes.Equal(out.Bytes(), want) {
|
|
t.Fatalf("written=%x want=%x", out.Bytes(), want)
|
|
}
|
|
if err := p.WriteMessage(ctx, (&message{})); err != nil {
|
|
t.Fatalf("empty WriteMessage err=%v", err)
|
|
}
|
|
canceled, cancel := context.WithCancel(ctx)
|
|
cancel()
|
|
if err := p.WriteMessage(canceled, m); err != context.Canceled {
|
|
t.Fatalf("canceled WriteMessage err=%v", err)
|
|
}
|
|
p = NewProtocol(struct {
|
|
io.Reader
|
|
io.Writer
|
|
}{bytes.NewReader(nil), errWriter{}}).(*protocol)
|
|
if err := p.WriteMessage(ctx, m); err == nil {
|
|
t.Fatal("WriteMessage to bad writer should fail")
|
|
}
|
|
}
|
|
|
|
func TestProtocolDecodeMessageAndControls(t *testing.T) {
|
|
ctx := context.Background()
|
|
p := NewProtocol(&bytes.Buffer{}).(*protocol)
|
|
if _, err := p.DecodeMessage((&message{})); err == nil || !strings.Contains(err.Error(), "Empty packet") {
|
|
t.Fatalf("empty decode err=%v", err)
|
|
}
|
|
unknown := &message{}
|
|
unknown.messageHeader.MessageType = MessageTypeAudio
|
|
unknown.payload = []byte{1}
|
|
if _, err := p.DecodeMessage(unknown); err == nil || !strings.Contains(err.Error(), "Unknown message") {
|
|
t.Fatalf("unknown err=%v", err)
|
|
}
|
|
bad := &message{}
|
|
bad.messageHeader.MessageType = MessageTypeSetChunkSize
|
|
bad.payload = []byte{1, 2}
|
|
if _, err := p.DecodeMessage(bad); err == nil || !strings.Contains(err.Error(), "Unmarshal") {
|
|
t.Fatalf("bad control err=%v", err)
|
|
}
|
|
|
|
for _, pkt := range []Packet{
|
|
&SetChunkSize{ChunkSize: 4096},
|
|
&WindowAcknowledgementSize{AckSize: 2500000},
|
|
&SetPeerBandwidth{Bandwidth: 1000, LimitType: LimitTypeSoft},
|
|
&UserControl{EventType: EventTypePingRequest, EventData: 123},
|
|
} {
|
|
data, err := pkt.MarshalBinary()
|
|
if err != nil {
|
|
t.Fatalf("marshal %T err=%v", pkt, err)
|
|
}
|
|
m := &message{payload: data}
|
|
m.messageHeader.MessageType = pkt.Type()
|
|
got, err := p.DecodeMessage(m)
|
|
if err != nil {
|
|
t.Fatalf("DecodeMessage %T err=%v", pkt, err)
|
|
}
|
|
if reflect.TypeOf(got) != reflect.TypeOf(pkt) {
|
|
t.Fatalf("got %T want %T", got, pkt)
|
|
}
|
|
}
|
|
|
|
chunk := &SetChunkSize{ChunkSize: 3}
|
|
m := &message{}
|
|
m.messageHeader.MessageType = chunk.Type()
|
|
m.payload, _ = chunk.MarshalBinary()
|
|
if err := p.onMessageArrivated(m); err != nil || p.input.opt.chunkSize != 3 {
|
|
t.Fatalf("onMessageArrivated err=%v chunk=%v", err, p.input.opt.chunkSize)
|
|
}
|
|
if err := p.onMessageArrivated(nil); err != nil {
|
|
t.Fatalf("nil onMessageArrivated err=%v", err)
|
|
}
|
|
bad.Payload()[0] = 1
|
|
if err := p.onMessageArrivated(bad); err == nil {
|
|
t.Fatal("bad onMessageArrivated should fail")
|
|
}
|
|
if _, err := p.ExpectMessage(ctx); err == nil {
|
|
t.Fatal("ExpectMessage on empty reader should fail")
|
|
}
|
|
}
|
|
|
|
func TestProtocolPacketsAndTransactions(t *testing.T) {
|
|
ctx := context.Background()
|
|
var wire bytes.Buffer
|
|
writer := NewProtocol(&wire).(*protocol)
|
|
connect := NewConnectAppPacket()
|
|
connect.CommandObject.Set("tcUrl", NewAmf0String("rtmp://host/live"))
|
|
if connect.Size() == 0 || connect.BetterCid() != chunkIDOverConnection || connect.Type() != MessageTypeAMF0Command || connect.TcUrl() == "" {
|
|
t.Fatalf("connect metadata invalid")
|
|
}
|
|
if err := writer.WritePacket(ctx, connect, 0); err != nil {
|
|
t.Fatalf("WritePacket connect err=%v", err)
|
|
}
|
|
if _, ok := writer.input.transactions[connect.TransactionID]; !ok {
|
|
t.Fatal("connect transaction not tracked")
|
|
}
|
|
create := NewCreateStreamPacket()
|
|
if err := writer.WritePacket(ctx, create, 0); err != nil {
|
|
t.Fatalf("WritePacket create err=%v", err)
|
|
}
|
|
call := NewCallPacket()
|
|
call.CommandName = commandReleaseStream
|
|
call.TransactionID = 3
|
|
call.CommandObject = NewAmf0Null()
|
|
if err := writer.WritePacket(ctx, call, 0); err != nil {
|
|
t.Fatalf("WritePacket call err=%v", err)
|
|
}
|
|
reader := NewProtocol(&wire)
|
|
var gotConnect *ConnectAppPacket
|
|
if _, err := ExpectPacket(ctx, reader, &gotConnect); err != nil || gotConnect.TcUrl() != "rtmp://host/live" {
|
|
t.Fatalf("gotConnect=%v err=%v", gotConnect, err)
|
|
}
|
|
var gotCreate *CallPacket
|
|
if _, err := ExpectPacket(ctx, reader, &gotCreate); err != nil || gotCreate.CommandName != commandCreateStream {
|
|
t.Fatalf("gotCreate=%v err=%v", gotCreate, err)
|
|
}
|
|
|
|
decoder := NewProtocol(&bytes.Buffer{}).(*protocol)
|
|
decoder.input.transactions[1] = commandConnect
|
|
if pkt, err := decoder.parseAMFObject(mustPacketBytes(t, NewConnectAppResPacket(1))); err != nil || reflect.TypeOf(pkt) != reflect.TypeOf(NewConnectAppResPacket(1)) {
|
|
t.Fatalf("connect res pkt=%T err=%v", pkt, err)
|
|
}
|
|
decoder.input.transactions[2] = commandCreateStream
|
|
csr := NewCreateStreamResPacket(2)
|
|
csr.SetStreamID(99)
|
|
if pkt, err := decoder.parseAMFObject(mustPacketBytes(t, csr)); err != nil || reflect.TypeOf(pkt) != reflect.TypeOf(csr) {
|
|
t.Fatalf("create res pkt=%T err=%v", pkt, err)
|
|
}
|
|
decoder.input.transactions[3] = commandReleaseStream
|
|
res := NewCallPacket()
|
|
res.CommandName = commandResult
|
|
res.TransactionID = 3
|
|
res.CommandObject = NewAmf0Null()
|
|
if pkt, err := decoder.parseAMFObject(mustPacketBytes(t, res)); err != nil || reflect.TypeOf(pkt) != reflect.TypeOf(res) {
|
|
t.Fatalf("call res pkt=%T err=%v", pkt, err)
|
|
}
|
|
for _, name := range []amf0String{commandPublish, commandPlay, commandOnStatus} {
|
|
pkt := NewCallPacket()
|
|
pkt.CommandName = name
|
|
pkt.TransactionID = 0
|
|
pkt.CommandObject = NewAmf0Null()
|
|
if name == commandPublish {
|
|
pub := NewPublishPacket()
|
|
pub.TransactionID = 0
|
|
pub.StreamName = NewAmf0String("s")
|
|
pub.StreamType = NewAmf0String("live")
|
|
if decoded, err := decoder.parseAMFObject(mustPacketBytes(t, pub)); err != nil || reflect.TypeOf(decoded) != reflect.TypeOf(pub) {
|
|
t.Fatalf("publish decoded=%T err=%v", decoded, err)
|
|
}
|
|
continue
|
|
}
|
|
if name == commandPlay {
|
|
play := NewPlayPacket()
|
|
play.TransactionID = 0
|
|
play.StreamName = NewAmf0String("s")
|
|
if decoded, err := decoder.parseAMFObject(mustPacketBytes(t, play)); err != nil || reflect.TypeOf(decoded) != reflect.TypeOf(play) {
|
|
t.Fatalf("play decoded=%T err=%v", decoded, err)
|
|
}
|
|
continue
|
|
}
|
|
if decoded, err := decoder.parseAMFObject(mustPacketBytes(t, pkt)); err != nil || reflect.TypeOf(decoded) != reflect.TypeOf(pkt) {
|
|
t.Fatalf("call decoded=%T err=%v", decoded, err)
|
|
}
|
|
}
|
|
decoder.input.transactions[9] = commandPause
|
|
errPkt := NewCallPacket()
|
|
errPkt.CommandName = commandError
|
|
errPkt.TransactionID = 9
|
|
errPkt.CommandObject = NewAmf0Null()
|
|
if _, err := decoder.parseAMFObject(mustPacketBytes(t, errPkt)); err == nil || !strings.Contains(err.Error(), "No request") {
|
|
t.Fatalf("unknown request err=%v", err)
|
|
}
|
|
if _, err := decoder.parseAMFObject(mustPacketBytes(t, errPkt)); err == nil || !strings.Contains(err.Error(), "No matched request") {
|
|
t.Fatalf("missing transaction err=%v", err)
|
|
}
|
|
if _, err := decoder.parseAMFObject([]byte{byte(amf0MarkerString), 0, 8, 'c'}); err == nil {
|
|
t.Fatal("bad AMF parse should fail")
|
|
}
|
|
|
|
cctx, cancel := context.WithCancel(ctx)
|
|
cancel()
|
|
if err := writer.WritePacket(cctx, connect, 0); err == nil || !strings.Contains(err.Error(), context.Canceled.Error()) {
|
|
t.Fatalf("WritePacket canceled err=%v", err)
|
|
}
|
|
}
|
|
|
|
func TestDeprecatedExpectPacketPanics(t *testing.T) {
|
|
defer func() {
|
|
if recover() == nil {
|
|
t.Fatal("Expected panic")
|
|
}
|
|
}()
|
|
NewProtocol(&bytes.Buffer{}).ExpectPacket(context.Background(), nil)
|
|
}
|
|
|
|
func TestPacketRoundTripsAndErrors(t *testing.T) {
|
|
packets := []Packet{
|
|
NewConnectAppPacket(),
|
|
NewConnectAppResPacket(7),
|
|
NewCallPacket(),
|
|
NewCreateStreamPacket(),
|
|
func() Packet { p := NewCreateStreamResPacket(2); p.SetStreamID(1); return p }(),
|
|
func() Packet {
|
|
p := NewPublishPacket()
|
|
p.TransactionID = 0
|
|
p.StreamName = NewAmf0String("s")
|
|
return p
|
|
}(),
|
|
func() Packet { p := NewPlayPacket(); p.TransactionID = 0; p.StreamName = NewAmf0String("s"); return p }(),
|
|
&SetChunkSize{ChunkSize: 1},
|
|
&WindowAcknowledgementSize{AckSize: 2},
|
|
&SetPeerBandwidth{Bandwidth: 3, LimitType: LimitTypeDynamic},
|
|
&UserControl{EventType: EventTypeFmsEvent0, EventData: 1},
|
|
&UserControl{EventType: EventTypeSetBufferLength, EventData: 1, ExtraData: 2},
|
|
}
|
|
// Initialize the generic call packet so it is marshalable.
|
|
packets[2].(*CallPacket).CommandName = commandOnStatus
|
|
packets[2].(*CallPacket).TransactionID = 0
|
|
packets[2].(*CallPacket).CommandObject = NewAmf0Null()
|
|
packets[2].(*CallPacket).Args = NewAmf0Object().Set("code", NewAmf0String("ok"))
|
|
packets[1].(*ConnectAppResPacket).Args.Set("data", NewAmf0EcmaArray().Set("srs_id", NewAmf0String("sid")))
|
|
|
|
for _, pkt := range packets {
|
|
t.Run(reflect.TypeOf(pkt).String(), func(t *testing.T) {
|
|
data, err := pkt.MarshalBinary()
|
|
if err != nil {
|
|
t.Fatalf("MarshalBinary err=%v", err)
|
|
}
|
|
if len(data) != pkt.Size() {
|
|
t.Fatalf("len=%v Size=%v", len(data), pkt.Size())
|
|
}
|
|
fresh := reflect.New(reflect.TypeOf(pkt).Elem()).Interface().(Packet)
|
|
switch v := fresh.(type) {
|
|
case *ConnectAppPacket:
|
|
*v = *NewConnectAppPacket()
|
|
case *ConnectAppResPacket:
|
|
*v = *NewConnectAppResPacket(0)
|
|
case *CreateStreamPacket:
|
|
*v = *NewCreateStreamPacket()
|
|
case *CreateStreamResPacket:
|
|
*v = *NewCreateStreamResPacket(0)
|
|
case *PublishPacket:
|
|
*v = *NewPublishPacket()
|
|
case *PlayPacket:
|
|
*v = *NewPlayPacket()
|
|
}
|
|
if err := fresh.UnmarshalBinary(data); err != nil {
|
|
t.Fatalf("UnmarshalBinary err=%v", err)
|
|
}
|
|
})
|
|
}
|
|
if packets[1].(*ConnectAppResPacket).SrsID() != "sid" || packets[2].(*CallPacket).ArgsCode() != "ok" {
|
|
t.Fatalf("packet helpers failed")
|
|
}
|
|
if NewConnectAppResPacket(1).SrsID() != "" || NewCallPacket().ArgsCode() != "" || NewConnectAppPacket().TcUrl() != "" {
|
|
t.Fatalf("empty helpers failed")
|
|
}
|
|
|
|
badConnect := NewConnectAppPacket()
|
|
badConnect.CommandName = commandPlay
|
|
if err := badConnect.UnmarshalBinary(mustPacketBytes(t, badConnect)); err == nil {
|
|
t.Fatal("bad connect name should fail")
|
|
}
|
|
badConnect = NewConnectAppPacket()
|
|
badConnect.TransactionID = 2
|
|
if err := badConnect.UnmarshalBinary(mustPacketBytes(t, badConnect)); err == nil {
|
|
t.Fatal("bad connect tid should fail")
|
|
}
|
|
badRes := NewConnectAppResPacket(1)
|
|
badRes.CommandName = commandPlay
|
|
if err := badRes.UnmarshalBinary(mustPacketBytes(t, badRes)); err == nil {
|
|
t.Fatal("bad connect response name should fail")
|
|
}
|
|
for _, pkt := range []Packet{NewConnectAppPacket(), NewCallPacket(), NewCreateStreamResPacket(1), NewPublishPacket(), NewPlayPacket()} {
|
|
if err := pkt.UnmarshalBinary([]byte{byte(amf0MarkerString)}); err == nil {
|
|
t.Fatalf("%T short unmarshal should fail", pkt)
|
|
}
|
|
}
|
|
for _, pkt := range []Packet{&SetChunkSize{}, &WindowAcknowledgementSize{}, &SetPeerBandwidth{}, &UserControl{}} {
|
|
if err := pkt.UnmarshalBinary([]byte{0, 1}); err == nil {
|
|
t.Fatalf("%T short unmarshal should fail", pkt)
|
|
}
|
|
}
|
|
uc := &UserControl{}
|
|
if err := uc.UnmarshalBinary([]byte{0, byte(EventTypeSetBufferLength), 1, 2, 3, 4, 5}); err == nil {
|
|
t.Fatal("short set-buffer-length should fail")
|
|
}
|
|
}
|
|
|
|
func mustPacketBytes(t *testing.T, pkt Packet) []byte {
|
|
t.Helper()
|
|
data, err := pkt.MarshalBinary()
|
|
if err != nil {
|
|
t.Fatalf("MarshalBinary %T err=%v", pkt, err)
|
|
}
|
|
return data
|
|
}
|
|
|
|
type failPacket struct{}
|
|
|
|
func (failPacket) Size() int { return 0 }
|
|
func (failPacket) UnmarshalBinary([]byte) error { return io.ErrUnexpectedEOF }
|
|
func (failPacket) MarshalBinary() ([]byte, error) { return nil, io.ErrClosedPipe }
|
|
func (failPacket) BetterCid() chunkID { return chunkIDOverConnection }
|
|
func (failPacket) Type() MessageType { return MessageTypeAMF0Command }
|
|
|
|
// stepWriter counts Write calls and fails exactly the failAt-th one (1-based)
// with io.ErrClosedPipe. Every other call reports the full input as written.
type stepWriter struct {
	writes int // number of Write calls so far
	failAt int // 1-based index of the call that fails; 0 never fails
}

// Write increments the call counter, then fails if this is the configured call.
func (sw *stepWriter) Write(data []byte) (int, error) {
	sw.writes++
	if sw.writes != sw.failAt {
		return len(data), nil
	}
	return 0, io.ErrClosedPipe
}
|
|
|
|
func TestProtocolAdditionalBranches(t *testing.T) {
|
|
ctx := context.Background()
|
|
p := NewProtocol(&bytes.Buffer{}).(*protocol)
|
|
if NewSetPeerBandwidth().BetterCid() != chunkIDProtocolControl || NewUserControl().BetterCid() != chunkIDProtocolControl {
|
|
t.Fatal("control better cid failed")
|
|
}
|
|
if err := p.WritePacket(ctx, failPacket{}, 0); err == nil || !strings.Contains(err.Error(), "marshal payload") {
|
|
t.Fatalf("WritePacket marshal err=%v", err)
|
|
}
|
|
|
|
payloadWriter := &stepWriter{failAt: 1}
|
|
p = NewProtocol(struct {
|
|
io.Reader
|
|
io.Writer
|
|
}{bytes.NewReader(nil), payloadWriter}).(*protocol)
|
|
m := NewStreamMessage(1).asMessage()
|
|
m.messageHeader.MessageType = MessageTypeVideo
|
|
m.payload = bytes.Repeat([]byte{1}, 5000)
|
|
if err := p.WriteMessage(ctx, m); err == nil || !strings.Contains(err.Error(), "write chunk payload") {
|
|
t.Fatalf("WriteMessage payload err=%v", err)
|
|
}
|
|
flushWriter := &stepWriter{failAt: 1}
|
|
p = NewProtocol(struct {
|
|
io.Reader
|
|
io.Writer
|
|
}{bytes.NewReader(nil), flushWriter}).(*protocol)
|
|
m.payload = []byte{1}
|
|
if err := p.WriteMessage(ctx, m); err == nil || !strings.Contains(err.Error(), "flush writer") {
|
|
t.Fatalf("WriteMessage flush err=%v writes=%v", err, flushWriter.writes)
|
|
}
|
|
|
|
// Zero-length payload returns a complete message without reading chunk bytes.
|
|
in := bytes.NewBuffer([]byte{0x05, 0, 0, 1, 0, 0, 0, byte(MessageTypeAudio), 1, 0, 0, 0})
|
|
p = NewProtocol(in).(*protocol)
|
|
if msg, err := p.ReadMessage(ctx); err != nil || msg.MessageType() != MessageTypeAudio || len(msg.Payload()) != 0 {
|
|
t.Fatalf("zero payload msg=%v err=%v", msg, err)
|
|
}
|
|
|
|
// ExpectMessage skips unwanted message types before returning the desired one.
|
|
var wire bytes.Buffer
|
|
writer := NewProtocol(&wire).(*protocol)
|
|
am := NewStreamMessage(1).asMessage()
|
|
am.messageHeader.MessageType = MessageTypeAudio
|
|
am.payload = []byte{1}
|
|
vm := NewStreamMessage(1).asMessage()
|
|
vm.messageHeader.MessageType = MessageTypeVideo
|
|
vm.payload = []byte{2}
|
|
if err := writer.WriteMessage(ctx, am); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := writer.WriteMessage(ctx, vm); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
reader := NewProtocol(&wire)
|
|
if got, err := reader.ExpectMessage(ctx, MessageTypeVideo); err != nil || got.MessageType() != MessageTypeVideo {
|
|
t.Fatalf("ExpectMessage got=%v err=%v", got, err)
|
|
}
|
|
|
|
// Generic ExpectPacket skips non-matching packets, then returns matching; it also reports decode/read errors.
|
|
wire.Reset()
|
|
writer = NewProtocol(&wire).(*protocol)
|
|
if err := writer.WritePacket(ctx, &WindowAcknowledgementSize{AckSize: 1}, 0); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := writer.WritePacket(ctx, &SetChunkSize{ChunkSize: 2}, 0); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
reader = NewProtocol(&wire)
|
|
var chunk *SetChunkSize
|
|
if _, err := ExpectPacket(ctx, reader, &chunk); err != nil || chunk.ChunkSize != 2 {
|
|
t.Fatalf("ExpectPacket chunk=%v err=%v", chunk, err)
|
|
}
|
|
reader = NewProtocol(bytes.NewBuffer([]byte{0x05, 0, 0, 0, 0, 0, 1, byte(MessageTypeSetChunkSize), 1, 0, 0, 0, 0}))
|
|
if _, err := ExpectPacket(ctx, reader, &chunk); err == nil || !strings.Contains(err.Error(), "decode message") {
|
|
t.Fatalf("ExpectPacket decode err=%v", err)
|
|
}
|
|
reader = NewProtocol(bytes.NewBuffer(nil))
|
|
if _, err := ExpectPacket(ctx, reader, &chunk); err == nil || !strings.Contains(err.Error(), "read message") {
|
|
t.Fatalf("ExpectPacket read err=%v", err)
|
|
}
|
|
|
|
// AMF3 strips the leading byte before AMF0 decoding.
|
|
pub := NewPublishPacket()
|
|
pub.TransactionID = 0
|
|
pub.StreamName = NewAmf0String("stream")
|
|
data := append([]byte{0}, mustPacketBytes(t, pub)...)
|
|
msg := &message{payload: data}
|
|
msg.messageHeader.MessageType = MessageTypeAMF3Command
|
|
if pkt, err := NewProtocol(&bytes.Buffer{}).DecodeMessage(msg); err != nil || pkt.(*PublishPacket).StreamName.String() != "stream" {
|
|
t.Fatalf("AMF3 decode pkt=%T err=%v", pkt, err)
|
|
}
|
|
}
|
|
|
|
func TestProtocolErrorBranchesForCoverage(t *testing.T) {
|
|
ctx := context.Background()
|
|
// ExpectMessage without requested types returns the first message.
|
|
var wire bytes.Buffer
|
|
w := NewProtocol(&wire).(*protocol)
|
|
msg := NewStreamMessage(1).asMessage()
|
|
msg.messageHeader.MessageType = MessageTypeAudio
|
|
msg.payload = []byte{1}
|
|
if err := w.WriteMessage(ctx, msg); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if got, err := NewProtocol(&wire).ExpectMessage(ctx); err != nil || got.MessageType() != MessageTypeAudio {
|
|
t.Fatalf("ExpectMessage any got=%v err=%v", got, err)
|
|
}
|
|
cctx, cancel := context.WithCancel(ctx)
|
|
cancel()
|
|
if _, err := NewProtocol(bytes.NewBuffer(nil)).ReadMessage(cctx); err != context.Canceled {
|
|
t.Fatalf("ReadMessage canceled err=%v", err)
|
|
}
|
|
if err := w.WriteMessage(cctx, (&message{})); err != context.Canceled {
|
|
t.Fatalf("WriteMessage empty canceled err=%v", err)
|
|
}
|
|
badAMF := &message{payload: []byte{0xff}}
|
|
badAMF.messageHeader.MessageType = MessageTypeAMF0Command
|
|
if _, err := NewProtocol(&bytes.Buffer{}).DecodeMessage(badAMF); err == nil || !strings.Contains(err.Error(), "Parse AMF") {
|
|
t.Fatalf("bad AMF decode err=%v", err)
|
|
}
|
|
rn := commandResult
|
|
resultName, _ := (&rn).MarshalBinary()
|
|
if _, err := NewProtocol(&bytes.Buffer{}).(*protocol).parseAMFObject(append(resultName, 0)); err == nil || !strings.Contains(err.Error(), "unmarshal tid") {
|
|
t.Fatalf("bad result tid err=%v", err)
|
|
}
|
|
}
|
|
|
|
func TestPacketUnmarshalErrorBranchesForCoverage(t *testing.T) {
|
|
cn := commandConnect
|
|
name, _ := (&cn).MarshalBinary()
|
|
tn := amf0Number(1)
|
|
tid, _ := (&tn).MarshalBinary()
|
|
obj, _ := NewAmf0Object().MarshalBinary()
|
|
base := append(append([]byte{}, name...), tid...)
|
|
oc := &objectCallPacket{CommandObject: NewAmf0Object()}
|
|
if err := oc.UnmarshalBinary(append([]byte{}, name...)); err == nil || !strings.Contains(err.Error(), "unmarshal tid") {
|
|
t.Fatalf("object tid err=%v", err)
|
|
}
|
|
if err := oc.UnmarshalBinary(append(append([]byte{}, base...), 0xff)); err == nil || !strings.Contains(err.Error(), "unmarshal command") {
|
|
t.Fatalf("object command err=%v", err)
|
|
}
|
|
withObj := append(append([]byte{}, base...), obj...)
|
|
if err := oc.UnmarshalBinary(append(withObj, 0xff)); err == nil || !strings.Contains(err.Error(), "unmarshal args") {
|
|
t.Fatalf("object args err=%v", err)
|
|
}
|
|
|
|
vc := &variantCallPacket{}
|
|
if err := vc.UnmarshalBinary(append([]byte{}, name...)); err == nil || !strings.Contains(err.Error(), "unmarshal tid") {
|
|
t.Fatalf("variant tid err=%v", err)
|
|
}
|
|
if err := vc.UnmarshalBinary(append(append([]byte{}, base...), 0xff)); err == nil || !strings.Contains(err.Error(), "discovery command object") {
|
|
t.Fatalf("variant discovery err=%v", err)
|
|
}
|
|
if err := vc.UnmarshalBinary(append(append([]byte{}, base...), byte(amf0MarkerString), 0, 3, 'a')); err == nil || !strings.Contains(err.Error(), "unmarshal command object") {
|
|
t.Fatalf("variant command object err=%v", err)
|
|
}
|
|
|
|
call := NewCallPacket()
|
|
call.CommandName = commandOnStatus
|
|
call.TransactionID = 0
|
|
call.CommandObject = NewAmf0Null()
|
|
callBase := mustPacketBytes(t, call)
|
|
if err := NewCallPacket().UnmarshalBinary(append(callBase, 0xff)); err == nil || !strings.Contains(err.Error(), "discovery args") {
|
|
t.Fatalf("call discovery args err=%v", err)
|
|
}
|
|
if err := NewCallPacket().UnmarshalBinary(append(callBase, byte(amf0MarkerString), 0, 3, 'a')); err == nil || !strings.Contains(err.Error(), "unmarshal args") {
|
|
t.Fatalf("call unmarshal args err=%v", err)
|
|
}
|
|
csr := NewCreateStreamResPacket(2)
|
|
if err := NewCreateStreamResPacket(0).UnmarshalBinary(mustPacketBytes(t, &csr.variantCallPacket)); err == nil || !strings.Contains(err.Error(), "unmarshal sid") {
|
|
t.Fatalf("create stream sid err=%v", err)
|
|
}
|
|
pub := NewPublishPacket()
|
|
pub.TransactionID = 0
|
|
pubPrefix, _ := pub.variantCallPacket.MarshalBinary()
|
|
if err := NewPublishPacket().UnmarshalBinary(append(pubPrefix, 0xff)); err == nil || !strings.Contains(err.Error(), "stream name") {
|
|
t.Fatalf("publish stream name err=%v", err)
|
|
}
|
|
streamName, _ := NewAmf0String("s").MarshalBinary()
|
|
if err := NewPublishPacket().UnmarshalBinary(append(append(pubPrefix, streamName...), 0xff)); err == nil || !strings.Contains(err.Error(), "stream type") {
|
|
t.Fatalf("publish stream type err=%v", err)
|
|
}
|
|
play := NewPlayPacket()
|
|
play.TransactionID = 0
|
|
playPrefix, _ := play.variantCallPacket.MarshalBinary()
|
|
if err := NewPlayPacket().UnmarshalBinary(append(playPrefix, 0xff)); err == nil || !strings.Contains(err.Error(), "stream name") {
|
|
t.Fatalf("play stream name err=%v", err)
|
|
}
|
|
}
|
|
|
|
// TestReadMessageInterleavedMultiStream covers P1: two or more chunk streams interleaved on the
|
|
// wire, each reassembling independently via protocol.input.chunks. All other read tests use a
|
|
// single cid (5), so the per-cid map and per-cid header state are never exercised under
|
|
// interleaving. Mirrors the C++ srs_utest_manual_protocol.cpp ProtocolRecvVAVMessage /
|
|
// ProtocolRecvVAVFmt1/2/3 family.
|
|
func TestReadMessageInterleavedMultiStream(t *testing.T) {
|
|
ctx := context.Background()
|
|
|
|
// Chunk-by-chunk interleave: three multi-chunk messages on cid 6 (video), 7 (audio) and 8
|
|
// (data) are split at chunkSize=3 and woven together on the wire. Each message must reassemble
|
|
// from its own cid's chunks regardless of the chunks belonging to other cids in between, and
|
|
// surface in the order its final chunk arrives (V, then A, then D).
|
|
t.Run("chunk-by-chunk", func(t *testing.T) {
|
|
var in bytes.Buffer
|
|
// V fmt0 cid=6, ts=100, len=6, video, stream=1, first 3 payload bytes.
|
|
in.Write([]byte{0x06, 0x00, 0x00, 0x64, 0x00, 0x00, 0x06, byte(MessageTypeVideo), 0x01, 0x00, 0x00, 0x00, 0xa1, 0xa2, 0xa3})
|
|
// A fmt0 cid=7, ts=200, len=4, audio, stream=1, first 3 payload bytes.
|
|
in.Write([]byte{0x07, 0x00, 0x00, 0xc8, 0x00, 0x00, 0x04, byte(MessageTypeAudio), 0x01, 0x00, 0x00, 0x00, 0xb1, 0xb2, 0xb3})
|
|
// D fmt0 cid=8, ts=300, len=5, data, stream=1, first 3 payload bytes.
|
|
in.Write([]byte{0x08, 0x00, 0x01, 0x2c, 0x00, 0x00, 0x05, byte(MessageTypeAMF0Data), 0x01, 0x00, 0x00, 0x00, 0xc1, 0xc2, 0xc3})
|
|
// V fmt3 cid=6 continuation, last 3 payload bytes -> V completes.
|
|
in.Write([]byte{0xc6, 0xa4, 0xa5, 0xa6})
|
|
// A fmt3 cid=7 continuation, last payload byte -> A completes.
|
|
in.Write([]byte{0xc7, 0xb4})
|
|
// D fmt3 cid=8 continuation, last 2 payload bytes -> D completes.
|
|
in.Write([]byte{0xc8, 0xc4, 0xc5})
|
|
|
|
p := NewProtocol(&in).(*protocol)
|
|
p.input.opt.chunkSize = 3
|
|
for i, want := range []struct {
|
|
typ MessageType
|
|
ts uint64
|
|
pl []byte
|
|
}{
|
|
{MessageTypeVideo, 100, []byte{0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6}},
|
|
{MessageTypeAudio, 200, []byte{0xb1, 0xb2, 0xb3, 0xb4}},
|
|
{MessageTypeAMF0Data, 300, []byte{0xc1, 0xc2, 0xc3, 0xc4, 0xc5}},
|
|
} {
|
|
m, err := p.ReadMessage(ctx)
|
|
if err != nil {
|
|
t.Fatalf("ReadMessage #%v err=%v", i, err)
|
|
}
|
|
if m.MessageType() != want.typ || m.Timestamp() != want.ts || !bytes.Equal(m.Payload(), want.pl) {
|
|
t.Fatalf("message #%v type=%v ts=%v payload=%x", i, m.MessageType(), m.Timestamp(), m.Payload())
|
|
}
|
|
}
|
|
})
|
|
|
|
// Per-cid header-state isolation: single-chunk messages alternate between the video cid (6)
|
|
// and audio cid (7). The second and later messages on each cid use fmt1/2/3 headers, which
|
|
// inherit timestamp delta / payload length / type from the *previous message on the same cid*.
|
|
// An interleaved message on the other cid must not perturb that state, so the video deltas
|
|
// accumulate only over video (1000 -> 1010 -> 1015) and audio only over audio
|
|
// (5000 -> 5020 -> 5040).
|
|
t.Run("per-cid-header-state", func(t *testing.T) {
|
|
var in bytes.Buffer
|
|
// V1 fmt0 cid=6, ts=1000, len=2, video, stream=1.
|
|
in.Write([]byte{0x06, 0x00, 0x03, 0xe8, 0x00, 0x00, 0x02, byte(MessageTypeVideo), 0x01, 0x00, 0x00, 0x00, 0x11, 0x12})
|
|
// A1 fmt0 cid=7, ts=5000, len=2, audio, stream=1.
|
|
in.Write([]byte{0x07, 0x00, 0x13, 0x88, 0x00, 0x00, 0x02, byte(MessageTypeAudio), 0x01, 0x00, 0x00, 0x00, 0x21, 0x22})
|
|
// V2 fmt1 cid=6, delta=10, len=2, video -> ts 1000+10=1010 (inherits from V1, not A1).
|
|
in.Write([]byte{0x46, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x02, byte(MessageTypeVideo), 0x13, 0x14})
|
|
// A2 fmt1 cid=7, delta=20, len=2, audio -> ts 5000+20=5020 (inherits from A1).
|
|
in.Write([]byte{0x47, 0x00, 0x00, 0x14, 0x00, 0x00, 0x02, byte(MessageTypeAudio), 0x23, 0x24})
|
|
// V3 fmt2 cid=6, delta=5 (len/type reused from V2) -> ts 1010+5=1015.
|
|
in.Write([]byte{0x86, 0x00, 0x00, 0x05, 0x15, 0x16})
|
|
// A3 fmt3 cid=7 (delta=20, len/type reused from A2) -> ts 5020+20=5040.
|
|
in.Write([]byte{0xc7, 0x25, 0x26})
|
|
|
|
p := NewProtocol(&in).(*protocol)
|
|
for i, want := range []struct {
|
|
typ MessageType
|
|
ts uint64
|
|
pl []byte
|
|
}{
|
|
{MessageTypeVideo, 1000, []byte{0x11, 0x12}},
|
|
{MessageTypeAudio, 5000, []byte{0x21, 0x22}},
|
|
{MessageTypeVideo, 1010, []byte{0x13, 0x14}},
|
|
{MessageTypeAudio, 5020, []byte{0x23, 0x24}},
|
|
{MessageTypeVideo, 1015, []byte{0x15, 0x16}},
|
|
{MessageTypeAudio, 5040, []byte{0x25, 0x26}},
|
|
} {
|
|
m, err := p.ReadMessage(ctx)
|
|
if err != nil {
|
|
t.Fatalf("ReadMessage #%v err=%v", i, err)
|
|
}
|
|
if m.MessageType() != want.typ || m.Timestamp() != want.ts || !bytes.Equal(m.Payload(), want.pl) {
|
|
t.Fatalf("message #%v type=%v ts=%v payload=%x", i, m.MessageType(), m.Timestamp(), m.Payload())
|
|
}
|
|
}
|
|
})
|
|
|
|
// Per-cid payload-length change: successive messages on the video cid (6) carry *different*
|
|
// payload lengths via fmt1 headers (2 -> 3 -> 1), while audio (cid 7) is interleaved in
|
|
// between. fmt1 begins a new message (chunk.message is nil), so the length check is skipped
|
|
// and the chunkStream must adopt each new length rather than reusing the previous one. Mirrors
|
|
// the C++ srs_utest_manual_protocol2.cpp ProtocolRecvVAVVFmt11Length / Fmt12Length cases.
|
|
t.Run("per-cid-length-change", func(t *testing.T) {
|
|
var in bytes.Buffer
|
|
// V1 fmt0 cid=6, ts=0x10, len=2, video, stream=1.
|
|
in.Write([]byte{0x06, 0x00, 0x00, 0x10, 0x00, 0x00, 0x02, byte(MessageTypeVideo), 0x01, 0x00, 0x00, 0x00, 0x11, 0x12})
|
|
// A1 fmt0 cid=7, ts=0x15, len=2, audio, stream=1.
|
|
in.Write([]byte{0x07, 0x00, 0x00, 0x15, 0x00, 0x00, 0x02, byte(MessageTypeAudio), 0x01, 0x00, 0x00, 0x00, 0x21, 0x22})
|
|
// V2 fmt1 cid=6, delta=0x10, len=3 (changed 2->3), video -> ts 0x20.
|
|
in.Write([]byte{0x46, 0x00, 0x00, 0x10, 0x00, 0x00, 0x03, byte(MessageTypeVideo), 0x13, 0x14, 0x15})
|
|
// V3 fmt1 cid=6, delta=0x20, len=1 (changed 3->1), video -> ts 0x40.
|
|
in.Write([]byte{0x46, 0x00, 0x00, 0x20, 0x00, 0x00, 0x01, byte(MessageTypeVideo), 0x16})
|
|
// A2 fmt1 cid=7, delta=0x05, len=4 (changed 2->4), audio -> ts 0x1a.
|
|
in.Write([]byte{0x47, 0x00, 0x00, 0x05, 0x00, 0x00, 0x04, byte(MessageTypeAudio), 0x23, 0x24, 0x25, 0x26})
|
|
|
|
p := NewProtocol(&in).(*protocol)
|
|
for i, want := range []struct {
|
|
typ MessageType
|
|
ts uint64
|
|
pl []byte
|
|
}{
|
|
{MessageTypeVideo, 0x10, []byte{0x11, 0x12}},
|
|
{MessageTypeAudio, 0x15, []byte{0x21, 0x22}},
|
|
{MessageTypeVideo, 0x20, []byte{0x13, 0x14, 0x15}},
|
|
{MessageTypeVideo, 0x40, []byte{0x16}},
|
|
{MessageTypeAudio, 0x1a, []byte{0x23, 0x24, 0x25, 0x26}},
|
|
} {
|
|
m, err := p.ReadMessage(ctx)
|
|
if err != nil {
|
|
t.Fatalf("ReadMessage #%v err=%v", i, err)
|
|
}
|
|
if m.MessageType() != want.typ || m.Timestamp() != want.ts || !bytes.Equal(m.Payload(), want.pl) {
|
|
t.Fatalf("message #%v type=%v ts=%v payload=%x", i, m.MessageType(), m.Timestamp(), m.Payload())
|
|
}
|
|
}
|
|
})
|
|
}
|
|
|
|
// TestReadMessageLargeChunkStreamID covers P2: reading a complete message (basic header + message
|
|
// header + payload) whose chunks carry the chunk-stream ID in the 2-byte (cid 64-319) and 3-byte
|
|
// (cid 64-65599) basic-header forms. Every other read test uses a 1-byte cid (5), so
|
|
// readBasicHeader's multi-byte cid decode is exercised for header decode only
|
|
// (TestBasicHeaderVariantsAndErrors) and never end-to-end through ReadMessage carrying a real
|
|
// payload. Asserting the decoded cid keys input.chunks also proves the encode/decode are inverses
|
|
// (a swapped 2nd/3rd byte would land on a different cid). Mirrors the C++
|
|
// srs_utest_manual_protocol.cpp / protocol2.cpp ProtocolRecvVCid2B* / Cid3B* family.
|
|
func TestReadMessageLargeChunkStreamID(t *testing.T) {
|
|
ctx := context.Background()
|
|
|
|
// basicHeader encodes fmt+cid in the smallest basic-header form that fits cid: 1 byte for
|
|
// 2-63, 2 bytes for 64-319, 3 bytes for 320-65599. The ID math is the inverse of
|
|
// readBasicHeader: 2-byte cid = 64 + b1; 3-byte cid = 64 + b1 + b2*256.
|
|
basicHeader := func(format formatType, cid chunkID) []byte {
|
|
f := byte(format) << 6
|
|
switch {
|
|
case cid <= 63:
|
|
return []byte{f | byte(cid)}
|
|
case cid <= 319:
|
|
return []byte{f, byte(cid - 64)} // 2-byte marker is 0 in the low 6 bits
|
|
default:
|
|
v := uint32(cid) - 64
|
|
return []byte{f | 0x01, byte(v % 256), byte(v / 256)}
|
|
}
|
|
}
|
|
|
|
// Single-chunk fmt0 message on each boundary cid: 2-byte min (64) and max (319), 3-byte first
|
|
// value (320) and max (65599). The cid must decode to the right value (which keys input.chunks)
|
|
// and the payload must reassemble.
|
|
t.Run("single-chunk-boundaries", func(t *testing.T) {
|
|
for _, cid := range []chunkID{64, 319, 320, 65599} {
|
|
var in bytes.Buffer
|
|
in.Write(basicHeader(formatType0, cid))
|
|
// fmt0 message header: ts=0x40, len=3, video, stream=1, then payload a1a2a3.
|
|
in.Write([]byte{0x00, 0x00, 0x40, 0x00, 0x00, 0x03, byte(MessageTypeVideo), 0x01, 0x00, 0x00, 0x00})
|
|
in.Write([]byte{0xa1, 0xa2, 0xa3})
|
|
|
|
p := NewProtocol(&in).(*protocol)
|
|
m, err := p.ReadMessage(ctx)
|
|
if err != nil {
|
|
t.Fatalf("cid=%v ReadMessage err=%v", cid, err)
|
|
}
|
|
if m.MessageType() != MessageTypeVideo || m.Timestamp() != 0x40 || !bytes.Equal(m.Payload(), []byte{0xa1, 0xa2, 0xa3}) {
|
|
t.Fatalf("cid=%v type=%v ts=%v payload=%x", cid, m.MessageType(), m.Timestamp(), m.Payload())
|
|
}
|
|
if _, ok := p.input.chunks[cid]; !ok {
|
|
t.Fatalf("cid=%v not keyed in chunks map: %v", cid, p.input.chunks)
|
|
}
|
|
}
|
|
})
|
|
|
|
// Multi-chunk message on a 3-byte cid: the fmt3 continuation must re-encode the same large cid
|
|
// in its basic header, so readBasicHeader is invoked again mid-message and must resolve to the
|
|
// same chunkStream for the payload to reassemble.
|
|
t.Run("multi-chunk-large-cid", func(t *testing.T) {
|
|
const cid chunkID = 65599
|
|
var in bytes.Buffer
|
|
// fmt0 on the 3-byte cid: ts=0x50, len=5, video, stream=1, first 3 payload bytes.
|
|
in.Write(basicHeader(formatType0, cid))
|
|
in.Write([]byte{0x00, 0x00, 0x50, 0x00, 0x00, 0x05, byte(MessageTypeVideo), 0x01, 0x00, 0x00, 0x00})
|
|
in.Write([]byte{0xb1, 0xb2, 0xb3})
|
|
// fmt3 continuation, same 3-byte cid re-encoded, last 2 payload bytes -> message completes.
|
|
in.Write(basicHeader(formatType3, cid))
|
|
in.Write([]byte{0xb4, 0xb5})
|
|
|
|
p := NewProtocol(&in).(*protocol)
|
|
p.input.opt.chunkSize = 3
|
|
m, err := p.ReadMessage(ctx)
|
|
if err != nil {
|
|
t.Fatalf("ReadMessage err=%v", err)
|
|
}
|
|
if m.Timestamp() != 0x50 || !bytes.Equal(m.Payload(), []byte{0xb1, 0xb2, 0xb3, 0xb4, 0xb5}) {
|
|
t.Fatalf("ts=%v payload=%x", m.Timestamp(), m.Payload())
|
|
}
|
|
})
|
|
|
|
// The spec allows cid 64-319 in either the 2-byte or 3-byte form, and both must decode to the
|
|
// same cid. Read cid 64 first via its 2-byte form, then a second message via the (non-minimal)
|
|
// 3-byte form, and confirm both land on a single chunk stream so fmt1's delta accumulates over
|
|
// the first message (0x10 -> 0x15) rather than starting a second stream.
|
|
t.Run("cid-64-both-forms", func(t *testing.T) {
|
|
const cid chunkID = 64
|
|
var in bytes.Buffer
|
|
// 2-byte form fmt0: ts=0x10, len=1, video, stream=1, payload c1.
|
|
in.Write([]byte{0x00, 0x00}) // fmt0, 2-byte cid 64
|
|
in.Write([]byte{0x00, 0x00, 0x10, 0x00, 0x00, 0x01, byte(MessageTypeVideo), 0x01, 0x00, 0x00, 0x00, 0xc1})
|
|
// 3-byte form fmt1 on the same cid: delta=5, len=1, video, payload c2 -> ts 0x10+5=0x15.
|
|
in.Write([]byte{0x41, 0x00, 0x00}) // fmt1, 3-byte cid 64
|
|
in.Write([]byte{0x00, 0x00, 0x05, 0x00, 0x00, 0x01, byte(MessageTypeVideo), 0xc2})
|
|
|
|
p := NewProtocol(&in).(*protocol)
|
|
for i, want := range []struct {
|
|
ts uint64
|
|
pl []byte
|
|
}{
|
|
{0x10, []byte{0xc1}},
|
|
{0x15, []byte{0xc2}},
|
|
} {
|
|
m, err := p.ReadMessage(ctx)
|
|
if err != nil {
|
|
t.Fatalf("message #%v ReadMessage err=%v", i, err)
|
|
}
|
|
if m.Timestamp() != want.ts || !bytes.Equal(m.Payload(), want.pl) {
|
|
t.Fatalf("message #%v ts=%v payload=%x", i, m.Timestamp(), m.Payload())
|
|
}
|
|
}
|
|
if _, ok := p.input.chunks[cid]; !ok || len(p.input.chunks) != 1 {
|
|
t.Fatalf("both forms must share one chunk stream at cid %v: %v", cid, p.input.chunks)
|
|
}
|
|
})
|
|
}
|
|
|
|
// TestProtocolWritePacketReadMessageRoundTrip covers P3: every Packet type round-tripped through
|
|
// the full chunk-stream layer at a representative span of chunk sizes — 1 (every payload byte in
|
|
// its own chunk, the extreme case), 2 (small), 128 (default), and a value larger than any payload
|
|
// here (never chunks). For each (packet, chunk size) pair the packet is written via WritePacket,
|
|
// read back via ReadMessage, and decoded via DecodeMessage; the decoded packet must have the
|
|
// expected concrete type and re-marshal to the same wire bytes as the original. This proves the
|
|
// Go encoder and decoder agree end-to-end. The existing TestPacketRoundTripsAndErrors only checks
|
|
// MarshalBinary↔UnmarshalBinary in isolation (no chunk layer), and TestProtocolPacketsAndTransactions
|
|
// covers only a handful of packets at the default chunk size — so neither exercises the cross
|
|
// product of every typed packet against multi-chunk reassembly. Mirrors the C++
|
|
// srs_utest_manual_protocol.cpp ProtocolSendSrs*Packet family and
|
|
// srs_utest_manual_rtmp.cpp ProtocolRTMPTest.DecodeMessages / OnDecodeMessages family.
|
|
//
|
|
// Wire-byte equivalence is the equality check rather than reflect.DeepEqual: AMF0 packets carry an
|
|
// amf0ObjectBase.bufFactory func field that is non-nil in both sides, and reflect.DeepEqual treats
|
|
// any pair of non-nil funcs as not-equal. Comparing MarshalBinary() output is also the strongest
|
|
// practical assertion for "encoder and decoder agree on the wire," which is what P3 is testing.
|
|
//
|
|
// Some packets decode to a different concrete type than what we wrote. parseAMFObject's default
|
|
// branch returns NewCallPacket() for any AMF0 command name not in its special-case switch, so
|
|
// e.g. createStream comes back as *CallPacket. The variantCallPacket marshal layout is identical
|
|
// between CreateStreamPacket and a CallPacket carrying the same fields (Args nil is skipped on
|
|
// marshal), so the wire-bytes check still holds. wantType records the expected decoded type per
|
|
// case so the asymmetry is explicit.
|
|
func TestProtocolWritePacketReadMessageRoundTrip(t *testing.T) {
|
|
ctx := context.Background()
|
|
|
|
// Builder functions for each packet, populated with non-default fields so the round-trip
|
|
// exercises real values instead of zero values. Use closures so the test can rebuild a fresh
|
|
// orig per (chunk size) iteration without state leaking between subtests.
|
|
makeConnect := func() Packet {
|
|
p := NewConnectAppPacket()
|
|
p.CommandObject.Set("tcUrl", NewAmf0String("rtmp://host/live"))
|
|
p.CommandObject.Set("app", NewAmf0String("live"))
|
|
return p
|
|
}
|
|
makeConnectRes := func() Packet {
|
|
p := NewConnectAppResPacket(1) // tid matches makeConnect's default TransactionID=1
|
|
p.Args.Set("data", NewAmf0EcmaArray().Set("srs_id", NewAmf0String("sid")))
|
|
return p
|
|
}
|
|
makeCreateStream := func() Packet { return NewCreateStreamPacket() }
|
|
makeCreateStreamRes := func() Packet {
|
|
p := NewCreateStreamResPacket(2) // tid=2 matches makeCreateStream's TransactionID
|
|
p.SetStreamID(99)
|
|
return p
|
|
}
|
|
makeOnStatus := func() Packet {
|
|
p := NewCallPacket()
|
|
p.CommandName = commandOnStatus
|
|
p.TransactionID = 0
|
|
p.CommandObject = NewAmf0Null()
|
|
p.Args = NewAmf0Object().Set("level", NewAmf0String("status")).Set("code", NewAmf0String("NetStream.Play.Start"))
|
|
return p
|
|
}
|
|
makeReleaseStreamRes := func() Packet {
|
|
// _result for a releaseStream call: parseAMFObject returns NewCallPacket() with the tid
|
|
// pre-filled. We send the response shape that the C++ side sends back to FMLE.
|
|
p := NewCallPacket()
|
|
p.CommandName = commandResult
|
|
p.TransactionID = 4
|
|
p.CommandObject = NewAmf0Null()
|
|
return p
|
|
}
|
|
makePublish := func() Packet {
|
|
p := NewPublishPacket()
|
|
p.TransactionID = 0
|
|
p.StreamName = NewAmf0String("livestream")
|
|
p.StreamType = NewAmf0String("live")
|
|
return p
|
|
}
|
|
makePlay := func() Packet {
|
|
p := NewPlayPacket()
|
|
p.TransactionID = 0
|
|
p.StreamName = NewAmf0String("livestream")
|
|
return p
|
|
}
|
|
|
|
type seedFn func(*protocol)
|
|
seedConnect := func(p *protocol) { p.input.transactions[1] = commandConnect }
|
|
seedCreateStream := func(p *protocol) { p.input.transactions[2] = commandCreateStream }
|
|
seedRelease := func(p *protocol) { p.input.transactions[4] = commandReleaseStream }
|
|
|
|
cases := []struct {
|
|
name string
|
|
build func() Packet
|
|
wantType reflect.Type
|
|
// seed registers a tid → request-name mapping on the reader's transactions map so that
|
|
// parseAMFObject can resolve the concrete response type for *_result/_error packets.
|
|
seed seedFn
|
|
}{
|
|
{"connect-app", makeConnect, reflect.TypeOf((*ConnectAppPacket)(nil)), nil},
|
|
{"connect-app-res", makeConnectRes, reflect.TypeOf((*ConnectAppResPacket)(nil)), seedConnect},
|
|
{"create-stream", makeCreateStream, reflect.TypeOf((*CallPacket)(nil)), nil},
|
|
{"create-stream-res", makeCreateStreamRes, reflect.TypeOf((*CreateStreamResPacket)(nil)), seedCreateStream},
|
|
{"call-onstatus", makeOnStatus, reflect.TypeOf((*CallPacket)(nil)), nil},
|
|
{"call-result-releaseStream", makeReleaseStreamRes, reflect.TypeOf((*CallPacket)(nil)), seedRelease},
|
|
{"publish", makePublish, reflect.TypeOf((*PublishPacket)(nil)), nil},
|
|
{"play", makePlay, reflect.TypeOf((*PlayPacket)(nil)), nil},
|
|
{"set-chunk-size", func() Packet { return &SetChunkSize{ChunkSize: 4096} }, reflect.TypeOf((*SetChunkSize)(nil)), nil},
|
|
{"window-ack-size", func() Packet { return &WindowAcknowledgementSize{AckSize: 2500000} }, reflect.TypeOf((*WindowAcknowledgementSize)(nil)), nil},
|
|
{"set-peer-bandwidth", func() Packet {
|
|
return &SetPeerBandwidth{Bandwidth: 2500000, LimitType: LimitTypeDynamic}
|
|
}, reflect.TypeOf((*SetPeerBandwidth)(nil)), nil},
|
|
{"user-control-ping", func() Packet {
|
|
return &UserControl{EventType: EventTypePingRequest, EventData: 12345}
|
|
}, reflect.TypeOf((*UserControl)(nil)), nil},
|
|
{"user-control-buffer-len", func() Packet {
|
|
return &UserControl{EventType: EventTypeSetBufferLength, EventData: 1, ExtraData: 1500}
|
|
}, reflect.TypeOf((*UserControl)(nil)), nil},
|
|
{"user-control-fms-event0", func() Packet {
|
|
return &UserControl{EventType: EventTypeFmsEvent0, EventData: 1}
|
|
}, reflect.TypeOf((*UserControl)(nil)), nil},
|
|
}
|
|
|
|
// Chunk sizes: 1 forces every payload byte into its own chunk (maximum c3 continuations);
|
|
// 2 is a small non-trivial chunking; 128 is the protocol default; 4096 is larger than every
|
|
// payload in this table (the connect packet is ~60 bytes), so the message is sent as a single
|
|
// chunk with no c3 continuations.
|
|
chunkSizes := []uint32{1, 2, 128, 4096}
|
|
|
|
for _, c := range cases {
|
|
for _, chunkSize := range chunkSizes {
|
|
t.Run(fmt.Sprintf("%s/chunk=%d", c.name, chunkSize), func(t *testing.T) {
|
|
orig := c.build()
|
|
origBytes, err := orig.MarshalBinary()
|
|
if err != nil {
|
|
t.Fatalf("MarshalBinary orig err=%v", err)
|
|
}
|
|
|
|
var wire bytes.Buffer
|
|
writer := NewProtocol(&wire).(*protocol)
|
|
writer.output.opt.chunkSize = chunkSize
|
|
if err := writer.WritePacket(ctx, orig, 1); err != nil {
|
|
t.Fatalf("WritePacket err=%v", err)
|
|
}
|
|
|
|
// Verify the writer actually chunked at this size: a payload longer than
|
|
// chunkSize must produce at least one c3 continuation chunk on the wire. The
|
|
// continuation header is a single byte 0xc0|cid; the basic-header byte for the
|
|
// initial c0 chunk is 0x0X|cid (fmt=0). Counting bytes that match the c3 form is
|
|
// a coarse but adequate check that the chunk path was actually exercised.
|
|
if uint32(len(origBytes)) > chunkSize {
|
|
wantContinuations := (uint32(len(origBytes))-1)/chunkSize + 1 - 1
|
|
var got uint32
|
|
for _, b := range wire.Bytes() {
|
|
if b == 0xc0|byte(orig.BetterCid()) {
|
|
got++
|
|
}
|
|
}
|
|
if got < wantContinuations {
|
|
t.Fatalf("expected >=%v c3 continuations on the wire, saw %v (wire=%x)", wantContinuations, got, wire.Bytes())
|
|
}
|
|
}
|
|
|
|
reader := NewProtocol(&wire).(*protocol)
|
|
reader.input.opt.chunkSize = chunkSize
|
|
if c.seed != nil {
|
|
c.seed(reader)
|
|
}
|
|
m, err := reader.ReadMessage(ctx)
|
|
if err != nil {
|
|
t.Fatalf("ReadMessage err=%v", err)
|
|
}
|
|
if m.MessageType() != orig.Type() {
|
|
t.Fatalf("message type=%v want=%v", m.MessageType(), orig.Type())
|
|
}
|
|
got, err := reader.DecodeMessage(m)
|
|
if err != nil {
|
|
t.Fatalf("DecodeMessage err=%v", err)
|
|
}
|
|
if reflect.TypeOf(got) != c.wantType {
|
|
t.Fatalf("decoded type=%T want=%v", got, c.wantType)
|
|
}
|
|
|
|
gotBytes, err := got.MarshalBinary()
|
|
if err != nil {
|
|
t.Fatalf("MarshalBinary got err=%v", err)
|
|
}
|
|
if !bytes.Equal(origBytes, gotBytes) {
|
|
t.Fatalf("round-trip mismatch:\n orig=%x\n got =%x", origBytes, gotBytes)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestReadMessageTimestampDiscontinuity covers P5: timestamp continuity edges on the per-cid
|
|
// chunkStream that the existing monotonically-increasing tests don't exercise. The C++
|
|
// ProtocolStackTest suite has no equivalent coverage either, so this is new coverage beyond the
|
|
// C++ reference.
|
|
//
|
|
// 1. backward-jump-fmt0: a new fmt0 message whose absolute timestamp is *smaller* than the
|
|
// previous message's timestamp must replace the stored Timestamp, not add to it. fmt0 is
|
|
// absolute by the spec; fmt1/2/3 deltas are unsigned, so a real backward jump can only ride
|
|
// on fmt0.
|
|
// 2. wraparound-31bit-mask: delta accumulation crossing the 31-bit boundary must wrap to the
|
|
// low 31 bits via `chunk.header.Timestamp &= 0x7fffffff`. ts 0x7ffffff0 + delta 0x20 =
|
|
// 0x80000010 -> masked 0x10. Easy to regress if anyone splits the accumulate-then-mask
|
|
// sequence.
|
|
// 3. forward-jump-fmt0: a new fmt0 message whose absolute timestamp is much *larger* than the
|
|
// previous one (carried in the extended timestamp because the 3-byte field saturates at
|
|
// 0xffffff) must also replace, not add. Mirror image of case 1; together they prove fmt0 is
|
|
// absolute regardless of direction.
|
|
func TestReadMessageTimestampDiscontinuity(t *testing.T) {
|
|
ctx := context.Background()
|
|
|
|
t.Run("backward-jump-fmt0", func(t *testing.T) {
|
|
var in bytes.Buffer
|
|
// M1 fmt0 cid=5, ts=2000=0x7d0, len=2, video, stream=1, payload AA BB.
|
|
in.Write([]byte{0x05, 0x00, 0x07, 0xd0, 0x00, 0x00, 0x02, byte(MessageTypeVideo), 0x01, 0x00, 0x00, 0x00, 0xAA, 0xBB})
|
|
// M2 fmt0 cid=5, ts=1000=0x3e8 (< 2000), len=2, video, stream=1, payload CC DD.
|
|
// fmt0 sets the message timestamp absolutely; if it were ever accidentally accumulated,
|
|
// M2.Timestamp would be 2000+1000=3000 instead of 1000.
|
|
in.Write([]byte{0x05, 0x00, 0x03, 0xe8, 0x00, 0x00, 0x02, byte(MessageTypeVideo), 0x01, 0x00, 0x00, 0x00, 0xCC, 0xDD})
|
|
|
|
p := NewProtocol(&in).(*protocol)
|
|
for i, want := range []struct {
|
|
ts uint64
|
|
pl []byte
|
|
}{
|
|
{2000, []byte{0xAA, 0xBB}},
|
|
{1000, []byte{0xCC, 0xDD}},
|
|
} {
|
|
m, err := p.ReadMessage(ctx)
|
|
if err != nil {
|
|
t.Fatalf("ReadMessage #%v err=%v", i, err)
|
|
}
|
|
if m.Timestamp() != want.ts || !bytes.Equal(m.Payload(), want.pl) {
|
|
t.Fatalf("message #%v ts=%v payload=%x", i, m.Timestamp(), m.Payload())
|
|
}
|
|
}
|
|
})
|
|
|
|
t.Run("wraparound-31bit-mask", func(t *testing.T) {
|
|
var in bytes.Buffer
|
|
// M1 fmt0 cid=5, ts(3B)=0xffffff so an extended timestamp is present, len=1, video,
|
|
// stream=1, ext-ts=0x7ffffff0 (just below the 31-bit edge), payload AA. The 31-bit mask
|
|
// at the end of readMessageHeader leaves 0x7ffffff0 unchanged because bit 31 is clear,
|
|
// so chunk.header.Timestamp settles at 0x7ffffff0 between messages.
|
|
in.Write([]byte{0x05, 0xff, 0xff, 0xff, 0x00, 0x00, 0x01, byte(MessageTypeVideo), 0x01, 0x00, 0x00, 0x00})
|
|
binary.Write(&in, binary.BigEndian, uint32(0x7ffffff0))
|
|
in.Write([]byte{0xAA})
|
|
// M2 fmt1 cid=5, delta=0x20 (3-byte, no ext-ts because 0x20 < 0xffffff), len=1, video,
|
|
// payload BB. Accumulation: 0x7ffffff0 + 0x20 = 0x80000010 (bit 31 set), then the
|
|
// `chunk.header.Timestamp &= 0x7fffffff` mask drops bit 31 -> 0x10. Drop the mask and
|
|
// M2.Timestamp would surface as 0x80000010 instead.
|
|
in.Write([]byte{0x45, 0x00, 0x00, 0x20, 0x00, 0x00, 0x01, byte(MessageTypeVideo), 0xBB})
|
|
|
|
p := NewProtocol(&in).(*protocol)
|
|
for i, want := range []struct {
|
|
ts uint64
|
|
pl []byte
|
|
}{
|
|
{0x7ffffff0, []byte{0xAA}},
|
|
{0x10, []byte{0xBB}},
|
|
} {
|
|
m, err := p.ReadMessage(ctx)
|
|
if err != nil {
|
|
t.Fatalf("ReadMessage #%v err=%v", i, err)
|
|
}
|
|
if m.Timestamp() != want.ts || !bytes.Equal(m.Payload(), want.pl) {
|
|
t.Fatalf("message #%v ts=%v payload=%x", i, m.Timestamp(), m.Payload())
|
|
}
|
|
}
|
|
})
|
|
|
|
t.Run("forward-jump-fmt0", func(t *testing.T) {
|
|
var in bytes.Buffer
|
|
// M1 fmt0 cid=5, ts=10, len=1, video, stream=1, payload AA.
|
|
in.Write([]byte{0x05, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x01, byte(MessageTypeVideo), 0x01, 0x00, 0x00, 0x00, 0xAA})
|
|
// M2 fmt0 cid=5, ts(3B)=0xffffff (sentinel), len=1, video, stream=1,
|
|
// ext-ts=0x12345678 (large forward absolute, bit 31 clear so no mask interaction),
|
|
// payload BB. fmt0 replaces absolutely, so M2.Timestamp must be 0x12345678, not
|
|
// 10 + 0x12345678 = 0x12345682.
|
|
in.Write([]byte{0x05, 0xff, 0xff, 0xff, 0x00, 0x00, 0x01, byte(MessageTypeVideo), 0x01, 0x00, 0x00, 0x00})
|
|
binary.Write(&in, binary.BigEndian, uint32(0x12345678))
|
|
in.Write([]byte{0xBB})
|
|
|
|
p := NewProtocol(&in).(*protocol)
|
|
for i, want := range []struct {
|
|
ts uint64
|
|
pl []byte
|
|
}{
|
|
{10, []byte{0xAA}},
|
|
{0x12345678, []byte{0xBB}},
|
|
} {
|
|
m, err := p.ReadMessage(ctx)
|
|
if err != nil {
|
|
t.Fatalf("ReadMessage #%v err=%v", i, err)
|
|
}
|
|
if m.Timestamp() != want.ts || !bytes.Equal(m.Payload(), want.pl) {
|
|
t.Fatalf("message #%v ts=%v payload=%x", i, m.Timestamp(), m.Payload())
|
|
}
|
|
}
|
|
})
|
|
}
|
|
|
|
// TestReadWriteLargePayloadChunkBoundaries covers P4: large-payload and chunk-boundary stress on
|
|
// both the write (chunking) and read (reassembly) paths. The existing tests only write a single
|
|
// 5000-byte message and never read a large multi-chunk message back, nor exercise the exact
|
|
// payload==chunkSize / payload==N*chunkSize boundaries where an off-by-one in the chunking loop
|
|
// would surface as a spurious empty trailing chunk. The 3-byte max length (0xffffff) parse is also
|
|
// pinned here; its DoS/truncation behavior is separately covered by
|
|
// TestPacketUnmarshalAdversarialInputs (P8, oversized-length-truncated).
|
|
//
|
|
// C++ reference:
|
|
// - srs_utest_manual_rtmp.cpp :: TEST(ProtocolRTMPTest, HugeMessages) — a 256B audio payload
|
|
// at chunkSize=128 serializes to exactly 269 wire bytes (12B c0 header + 128 + 1B c3 + 128),
|
|
// i.e. no empty trailing chunk.
|
|
// - srs_utest_manual_rtmp.cpp :: TEST(ProtocolRTMPTest, SendHugePacket) — a 1024B send.
|
|
// - srs_utest_manual_protocol.cpp :: TEST(ProtocolStackTest, ProtocolRecvVMessage2Trunk) — read a
|
|
// 272B video message split across 3 chunks (2 c3 continuations) at chunkSize=128.
|
|
func TestReadWriteLargePayloadChunkBoundaries(t *testing.T) {
|
|
ctx := context.Background()
|
|
|
|
// payload-equals-chunksize: a payload of exactly chunkSize fits in a single chunk. The write
|
|
// loop (`for len(p) > 0`) must emit zero c3 continuations — emitting an empty trailing chunk
|
|
// for the 0 bytes "remaining" after the first full chunk would be an off-by-one bug. The exact
|
|
// wire length (1B basic + 11B msg header + chunkSize payload) proves the single-chunk shape,
|
|
// and the read path reassembles back to the original payload.
|
|
t.Run("payload-equals-chunksize", func(t *testing.T) {
|
|
const chunkSize = 128
|
|
payload := make([]byte, chunkSize)
|
|
for i := range payload {
|
|
payload[i] = byte(i)
|
|
}
|
|
|
|
var wire bytes.Buffer
|
|
w := NewProtocol(&wire).(*protocol)
|
|
w.output.opt.chunkSize = chunkSize
|
|
m := NewStreamMessage(1).asMessage()
|
|
m.messageHeader.MessageType = MessageTypeVideo
|
|
m.messageHeader.Timestamp = 40
|
|
m.payload = payload
|
|
if err := w.WriteMessage(ctx, m); err != nil {
|
|
t.Fatalf("WriteMessage err=%v", err)
|
|
}
|
|
if want := 1 + 11 + chunkSize; wire.Len() != want {
|
|
t.Fatalf("wire len=%v want=%v (a spurious empty trailing chunk would add a c3 byte)", wire.Len(), want)
|
|
}
|
|
|
|
r := NewProtocol(&wire).(*protocol)
|
|
r.input.opt.chunkSize = chunkSize
|
|
got, err := r.ReadMessage(ctx)
|
|
if err != nil {
|
|
t.Fatalf("ReadMessage err=%v", err)
|
|
}
|
|
if got.Timestamp() != 40 || !bytes.Equal(got.Payload(), payload) {
|
|
t.Fatalf("ts=%v payloadLen=%v", got.Timestamp(), len(got.Payload()))
|
|
}
|
|
})
|
|
|
|
// payload-exact-multiple: a payload that is an exact multiple of chunkSize (256 = 2*128) must
|
|
// serialize to exactly two chunks — c0+128 then c3+128 — and NOT a third empty trailing chunk.
|
|
// This mirrors the C++ HugeMessages golden: 269 wire bytes for a 256B payload with ts<0xffffff
|
|
// (no extended timestamp). The continuation header at offset 140 must be the 1-byte c3 form,
|
|
// and the message reassembles back to the original payload.
|
|
t.Run("payload-exact-multiple", func(t *testing.T) {
|
|
const chunkSize = 128
|
|
payload := make([]byte, 2*chunkSize)
|
|
for i := range payload {
|
|
payload[i] = byte(i)
|
|
}
|
|
|
|
var wire bytes.Buffer
|
|
w := NewProtocol(&wire).(*protocol)
|
|
w.output.opt.chunkSize = chunkSize
|
|
m := NewStreamMessage(1).asMessage()
|
|
m.messageHeader.MessageType = MessageTypeAudio
|
|
m.messageHeader.Timestamp = 1000
|
|
m.payload = payload
|
|
if err := w.WriteMessage(ctx, m); err != nil {
|
|
t.Fatalf("WriteMessage err=%v", err)
|
|
}
|
|
// 12B c0 header + 128 + 1B c3 header + 128 = 269 (the C++ HugeMessages value). 270 would
|
|
// mean an empty trailing chunk was emitted.
|
|
if want := 1 + 11 + chunkSize + 1 + chunkSize; wire.Len() != want {
|
|
t.Fatalf("wire len=%v want=%v", wire.Len(), want)
|
|
}
|
|
if got := wire.Bytes()[1+11+chunkSize]; got != 0xc0|byte(chunkIDOverStream) {
|
|
t.Fatalf("continuation header=%#x want=%#x (1-byte c3 form)", got, 0xc0|byte(chunkIDOverStream))
|
|
}
|
|
|
|
r := NewProtocol(&wire).(*protocol)
|
|
r.input.opt.chunkSize = chunkSize
|
|
got, err := r.ReadMessage(ctx)
|
|
if err != nil {
|
|
t.Fatalf("ReadMessage err=%v", err)
|
|
}
|
|
if got.Timestamp() != 1000 || !bytes.Equal(got.Payload(), payload) {
|
|
t.Fatalf("ts=%v payloadLen=%v", got.Timestamp(), len(got.Payload()))
|
|
}
|
|
})
|
|
|
|
// read-multichunk-handbuilt: read a 300-byte video message split across 3 chunks (c0 + two c3
|
|
// continuations) at chunkSize=128, against a hand-built wire layout rather than this package's
|
|
// own writer output — directly mirroring the C++ ProtocolRecvVMessage2Trunk reassembly test.
|
|
// Proves readMessagePayload accumulates across multiple c3 continuations into one message.
|
|
t.Run("read-multichunk-handbuilt", func(t *testing.T) {
|
|
const chunkSize = 128
|
|
payload := make([]byte, 300)
|
|
for i := range payload {
|
|
payload[i] = byte(i)
|
|
}
|
|
|
|
var in bytes.Buffer
|
|
// fmt0 cid=3, ts=0, len=300 (0x00012c), video, stream=0, then payload[0:128].
|
|
in.Write([]byte{0x03, 0x00, 0x00, 0x00, 0x00, 0x01, 0x2c, byte(MessageTypeVideo), 0x00, 0x00, 0x00, 0x00})
|
|
in.Write(payload[0:chunkSize])
|
|
in.Write([]byte{0xc3}) // fmt3 continuation, cid=3.
|
|
in.Write(payload[chunkSize : 2*chunkSize])
|
|
in.Write([]byte{0xc3}) // fmt3 continuation, cid=3.
|
|
in.Write(payload[2*chunkSize:])
|
|
|
|
r := NewProtocol(&in).(*protocol)
|
|
r.input.opt.chunkSize = chunkSize
|
|
got, err := r.ReadMessage(ctx)
|
|
if err != nil {
|
|
t.Fatalf("ReadMessage err=%v", err)
|
|
}
|
|
if got.MessageType() != MessageTypeVideo || !bytes.Equal(got.Payload(), payload) {
|
|
t.Fatalf("type=%v payloadLen=%v", got.MessageType(), len(got.Payload()))
|
|
}
|
|
})
|
|
|
|
// read-large-roundtrip: a 5000-byte payload at chunkSize=128 spans 40 chunks (39 c3
|
|
// continuations). The exact wire length (12B c0 + 5000 payload + 39 c3 bytes = 5051) proves the
|
|
// writer chunked into exactly 40 chunks with no empty trailing chunk, and the reader reassembles
|
|
// the full 5000 bytes back. This is the "many c3 continuations + large multi-chunk read" case
|
|
// the existing 5000-byte write test never read back.
|
|
t.Run("read-large-roundtrip", func(t *testing.T) {
|
|
const chunkSize = 128
|
|
payload := make([]byte, 5000)
|
|
for i := range payload {
|
|
payload[i] = byte(i % 251)
|
|
}
|
|
// 5000 bytes: first chunk carries 128, the remaining 4872 take ceil(4872/128)=39 c3 chunks.
|
|
const wantContinuations = 39
|
|
|
|
var wire bytes.Buffer
|
|
w := NewProtocol(&wire).(*protocol)
|
|
w.output.opt.chunkSize = chunkSize
|
|
m := NewStreamMessage(1).asMessage()
|
|
m.messageHeader.MessageType = MessageTypeVideo
|
|
m.messageHeader.Timestamp = 12345
|
|
m.payload = payload
|
|
if err := w.WriteMessage(ctx, m); err != nil {
|
|
t.Fatalf("WriteMessage err=%v", err)
|
|
}
|
|
if want := 12 + len(payload) + wantContinuations; wire.Len() != want {
|
|
t.Fatalf("wire len=%v want=%v", wire.Len(), want)
|
|
}
|
|
|
|
r := NewProtocol(&wire).(*protocol)
|
|
r.input.opt.chunkSize = chunkSize
|
|
got, err := r.ReadMessage(ctx)
|
|
if err != nil {
|
|
t.Fatalf("ReadMessage err=%v", err)
|
|
}
|
|
if got.Timestamp() != 12345 || !bytes.Equal(got.Payload(), payload) {
|
|
t.Fatalf("ts=%v payloadLen=%v", got.Timestamp(), len(got.Payload()))
|
|
}
|
|
})
|
|
|
|
// max-3byte-length-parse: a fmt0 header declaring payloadLength = 0xffffff (the 3-byte field
|
|
// maxed out) must decode to exactly 0xffffff. Drive the header parse directly — reassembling
|
|
// 16MiB of payload is pointless, and the DoS/truncation behavior at this length is covered by
|
|
// TestPacketUnmarshalAdversarialInputs (P8). A shift/mask regression in the length decode would
|
|
// corrupt the value, which is what this pins.
|
|
t.Run("max-3byte-length-parse", func(t *testing.T) {
|
|
// fmt0 cid=5, ts=0 (so no extended timestamp), len=0xffffff, video, stream=1.
|
|
in := bytes.NewBuffer([]byte{0x05, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, byte(MessageTypeVideo), 0x01, 0x00, 0x00, 0x00})
|
|
p := NewProtocol(in).(*protocol)
|
|
format, cid, err := p.readBasicHeader(ctx)
|
|
if err != nil {
|
|
t.Fatalf("readBasicHeader err=%v", err)
|
|
}
|
|
// Mirror ReadMessage's chunkStream setup.
|
|
chunk := newChunkStream()
|
|
p.input.chunks[cid] = chunk
|
|
chunk.header.betterCid = cid
|
|
if err := p.readMessageHeader(ctx, chunk, format); err != nil {
|
|
t.Fatalf("readMessageHeader err=%v", err)
|
|
}
|
|
if chunk.message.payloadLength != 0xffffff {
|
|
t.Fatalf("payloadLength=%#x want=0xffffff", chunk.message.payloadLength)
|
|
}
|
|
})
|
|
}
|
|
|
|
// TestGoldenWireBytes covers P6: golden wire-byte regression for the chunk headers and control
|
|
// packets. The existing TestWriteMessageHeadersChunkingAndErrors pins golden bytes for one
|
|
// extended-timestamp video message only; this locks the remaining wire shapes a refactor could
|
|
// silently change — both forms of the C0 and C3 headers, and the full on-wire framing of every
|
|
// control packet.
|
|
//
|
|
// C++ reference (send/golden):
|
|
//
|
|
// srs_utest_manual_protocol.cpp :: TEST(ProtocolStackTest,
|
|
// ProtocolSendSrsSetChunkSizePacket / ProtocolSendSrsSetWindowAckSizePacket /
|
|
// ProtocolSendSrsSetPeerBandwidthPacket / ProtocolSendSrsUserControlPacket)
|
|
//
|
|
// The control-packet payload bytes below match those C++ goldens exactly; the only difference is
|
|
// the chunk basic-header byte — Go frames protocol-control packets on cid=2 (-> 0x02) where the
|
|
// C++ goldens show 0x03 — so the payload portion is the cross-implementation wire-format invariant.
|
|
func TestGoldenWireBytes(t *testing.T) {
|
|
ctx := context.Background()
|
|
|
|
// C0/C3 chunk headers, both timestamp forms. A message on cid=5, video, stream=7, payload
|
|
// length 5. With ts < 0xffffff the C0 header is 12 bytes carrying the timestamp inline; with
|
|
// ts >= 0xffffff the 3-byte field saturates to 0xffffff and a 4-byte extended timestamp is
|
|
// appended (16 bytes total). The C3 header is the 1-byte continuation form normally, but
|
|
// inherits the same extended-timestamp quirk (Adobe always re-sends it), so it grows to 5
|
|
// bytes when ts >= 0xffffff.
|
|
t.Run("chunk-headers", func(t *testing.T) {
|
|
// MessageType and Timestamp are method names on *message, shadowing the promoted
|
|
// messageHeader fields, so set those two through the embedded struct explicitly.
|
|
shortTs := &message{}
|
|
shortTs.betterCid = chunkIDOverStream // 5
|
|
shortTs.messageHeader.MessageType = MessageTypeVideo
|
|
shortTs.streamID = 7
|
|
shortTs.payloadLength = 5
|
|
shortTs.messageHeader.Timestamp = 0x0a
|
|
|
|
extTs := &message{}
|
|
extTs.betterCid = chunkIDOverStream
|
|
extTs.messageHeader.MessageType = MessageTypeVideo
|
|
extTs.streamID = 7
|
|
extTs.payloadLength = 5
|
|
extTs.messageHeader.Timestamp = extendedTimestamp + 9 // 0x01000008, >= 0xffffff
|
|
|
|
cases := []struct {
|
|
name string
|
|
gen func() ([]byte, error)
|
|
want []byte
|
|
}{
|
|
// basic(0x05) | ts(00 00 0a) | len(00 00 05) | type(09) | streamID LE(07 00 00 00)
|
|
{"c0-short-ts", shortTs.generateC0Header, []byte{0x05, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x05, byte(MessageTypeVideo), 0x07, 0x00, 0x00, 0x00}},
|
|
// ts field saturates to ff ff ff; ext-ts(01 00 00 08) appended after streamID.
|
|
{"c0-ext-ts", extTs.generateC0Header, []byte{0x05, 0xff, 0xff, 0xff, 0x00, 0x00, 0x05, byte(MessageTypeVideo), 0x07, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x08}},
|
|
// 1-byte continuation: 0xc0 | cid.
|
|
{"c3-short-ts", shortTs.generateC3Header, []byte{0xc5}},
|
|
// continuation + re-sent 4-byte ext-ts.
|
|
{"c3-ext-ts", extTs.generateC3Header, []byte{0xc5, 0x01, 0x00, 0x00, 0x08}},
|
|
}
|
|
for _, c := range cases {
|
|
t.Run(c.name, func(t *testing.T) {
|
|
got, err := c.gen()
|
|
if err != nil {
|
|
t.Fatalf("gen err=%v", err)
|
|
}
|
|
if !bytes.Equal(got, c.want) {
|
|
t.Fatalf("got=%x want=%x", got, c.want)
|
|
}
|
|
})
|
|
}
|
|
})
|
|
|
|
// Control packets, full on-wire framing via WritePacket. WritePacket frames each control packet
|
|
// on cid=2 (chunkIDProtocolControl) with ts=0 and streamID=0, so the wire is the 12-byte
|
|
// short-ts C0 header followed by the marshaled payload (all shorter than the default chunk size,
|
|
// hence a single chunk). The payload values mirror the C++ send goldens.
|
|
t.Run("control-packets", func(t *testing.T) {
|
|
cases := []struct {
|
|
name string
|
|
pkt Packet
|
|
want []byte
|
|
}{
|
|
// SetChunkSize 1024=0x00000400, type 0x01, len 4.
|
|
{
|
|
"set-chunk-size",
|
|
&SetChunkSize{ChunkSize: 1024},
|
|
[]byte{0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00},
|
|
},
|
|
// WindowAcknowledgementSize 102400=0x00019000, type 0x05, len 4.
|
|
{
|
|
"window-ack-size",
|
|
&WindowAcknowledgementSize{AckSize: 102400},
|
|
[]byte{0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x90, 0x00},
|
|
},
|
|
// SetPeerBandwidth 1024=0x00000400 + limit soft(0x01), type 0x06, len 5.
|
|
{
|
|
"set-peer-bandwidth",
|
|
&SetPeerBandwidth{Bandwidth: 1024, LimitType: LimitTypeSoft},
|
|
[]byte{0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x01},
|
|
},
|
|
// UserControl SetBufferLength: event-type 0x0003, event-data 0x00000001,
|
|
// extra-data 0x00000010; type 0x04, len 10=0x0a.
|
|
{
|
|
"user-control-set-buffer-length",
|
|
&UserControl{EventType: EventTypeSetBufferLength, EventData: 0x01, ExtraData: 0x10},
|
|
[]byte{0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0a, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x10},
|
|
},
|
|
}
|
|
for _, c := range cases {
|
|
t.Run(c.name, func(t *testing.T) {
|
|
var wire bytes.Buffer
|
|
p := NewProtocol(&wire).(*protocol)
|
|
if err := p.WritePacket(ctx, c.pkt, 0); err != nil {
|
|
t.Fatalf("WritePacket err=%v", err)
|
|
}
|
|
if !bytes.Equal(wire.Bytes(), c.want) {
|
|
t.Fatalf("got=%x want=%x", wire.Bytes(), c.want)
|
|
}
|
|
})
|
|
}
|
|
})
|
|
|
|
// UserControl payload, all three Size() branches. The marshaler has three event-data shapes: a
|
|
// normal 4-byte event-data (e.g. PingRequest), an 8-byte event-data with the extra
|
|
// buffer-length word (SetBufferLength), and the special 1-byte event-data for FmsEvent0
|
|
// (0x001a). Pin the raw payload bytes for each.
|
|
t.Run("user-control-event-forms", func(t *testing.T) {
|
|
cases := []struct {
|
|
name string
|
|
pkt *UserControl
|
|
want []byte
|
|
}{
|
|
// event-type 0x0006, event-data 0x12345678 (4 bytes).
|
|
{"ping-request-4byte", &UserControl{EventType: EventTypePingRequest, EventData: 0x12345678}, []byte{0x00, 0x06, 0x12, 0x34, 0x56, 0x78}},
|
|
// event-type 0x0003, event-data 0x00000001, extra 0x000005dc (8 bytes total).
|
|
{"set-buffer-length-8byte", &UserControl{EventType: EventTypeSetBufferLength, EventData: 1, ExtraData: 1500}, []byte{0x00, 0x03, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x05, 0xdc}},
|
|
// event-type 0x001a, event-data 0x01 (1 byte).
|
|
{"fms-event0-1byte", &UserControl{EventType: EventTypeFmsEvent0, EventData: 0x01}, []byte{0x00, 0x1a, 0x01}},
|
|
}
|
|
for _, c := range cases {
|
|
t.Run(c.name, func(t *testing.T) {
|
|
got, err := c.pkt.MarshalBinary()
|
|
if err != nil {
|
|
t.Fatalf("MarshalBinary err=%v", err)
|
|
}
|
|
if !bytes.Equal(got, c.want) {
|
|
t.Fatalf("got=%x want=%x", got, c.want)
|
|
}
|
|
})
|
|
}
|
|
})
|
|
}
|
|
|
|
// === P7: Fuzz targets ===
|
|
//
|
|
// Three Go native fuzz targets covering the untrusted-input parsers in this package.
|
|
// FuzzReadMessage — (*protocol).ReadMessage on arbitrary wire bytes.
|
|
// FuzzDecodeMessage — (*protocol).DecodeMessage on arbitrary (MessageType, payload).
|
|
// FuzzPacketUnmarshal — *Packet.UnmarshalBinary on arbitrary bytes across every Packet type.
|
|
//
|
|
// Each target's contract is "no panic". Termination is guaranteed: every fuzz body consumes
// a finite in-memory input (a bytes.Buffer or a plain byte slice) and skips inputs larger
// than fuzzInputCap to keep iterations cheap. Real OOM / resource-exhaustion surfaces (e.g.
// an attacker-controlled SetChunkSize followed by a large payload length forcing a multi-MB
// make) are intentionally NOT pre-guarded here — fuzzing is how P8's adversarial cases get
// discovered.
|
|
//
|
|
// Seeds come from the existing happy-path tests so the fuzzer starts at valid wire bytes
|
|
// and explores the nearby malformed space.
|
|
|
|
// fuzzInputCap is the largest input, in bytes, that a fuzz iteration will process; larger
// inputs are skipped. It exists to keep each iteration fast and the corpus small — it is a
// testing convenience, not a security boundary. (8 KiB.)
const fuzzInputCap = 8 << 10
|
|
|
|
// FuzzReadMessage drives the full chunk-stream reader against arbitrary bytes. The
|
|
// target asserts no panic across readBasicHeader (1/2/3-byte cid), readMessageHeader
|
|
// (every fmt + ext-ts), readMessagePayload (chunked reassembly), and onMessageArrivated
|
|
// (SetChunkSize side effect on subsequent reads).
|
|
func FuzzReadMessage(f *testing.F) {
|
|
// Seed 1: a single fmt0 audio message on cid=5, ts=10, len=3.
|
|
f.Add([]byte{0x05, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x03, byte(MessageTypeAudio), 0x01, 0x00, 0x00, 0x00, 1, 2, 3})
|
|
|
|
// Seed 2: fmt0 -> fmt1 -> fmt2 -> fmt3 sequence on cid=5
|
|
// (lifted from TestReadMessageHeadersPayloadsAndChunks).
|
|
{
|
|
var s bytes.Buffer
|
|
s.Write([]byte{0x05, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x03, byte(MessageTypeAudio), 0x01, 0x00, 0x00, 0x00, 1, 2, 3})
|
|
s.Write([]byte{0x45, 0x00, 0x00, 0x05, 0x00, 0x00, 0x02, byte(MessageTypeVideo), 4, 5})
|
|
s.Write([]byte{0x85, 0x00, 0x00, 0x07, 6, 7})
|
|
s.Write([]byte{0xc5, 8, 9})
|
|
f.Add(s.Bytes())
|
|
}
|
|
|
|
// Seed 3: extended-timestamp fmt0 with payload split across chunks at the default
|
|
// chunk size. Exercises the ext-ts read + accumulate path.
|
|
{
|
|
var s bytes.Buffer
|
|
s.Write([]byte{0x05, 0xff, 0xff, 0xff, 0x00, 0x00, 0x05, byte(MessageTypeVideo), 0x01, 0x00, 0x00, 0x00})
|
|
binary.Write(&s, binary.BigEndian, uint32(42))
|
|
s.Write([]byte{1, 2, 3, 4, 5})
|
|
f.Add(s.Bytes())
|
|
}
|
|
|
|
// Seed 4: 2-byte cid header (cid=74) wrapping a complete fmt0 message. Exercises
|
|
// the 2-byte readBasicHeader branch end-to-end.
|
|
f.Add([]byte{0x00, 0x0a, 0x00, 0x00, 0x05, 0x00, 0x00, 0x01, byte(MessageTypeVideo), 0x01, 0x00, 0x00, 0x00, 0xAA})
|
|
|
|
// Seed 5: SetChunkSize=128 followed by an audio message on cid=5. Exercises the
|
|
// onMessageArrivated -> input.chunkSize update path.
|
|
{
|
|
var s bytes.Buffer
|
|
s.Write([]byte{0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, byte(MessageTypeSetChunkSize), 0x00, 0x00, 0x00, 0x00})
|
|
binary.Write(&s, binary.BigEndian, uint32(128))
|
|
s.Write([]byte{0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, byte(MessageTypeAudio), 0x01, 0x00, 0x00, 0x00, 0xAA})
|
|
f.Add(s.Bytes())
|
|
}
|
|
|
|
f.Fuzz(func(t *testing.T, data []byte) {
|
|
if len(data) > fuzzInputCap {
|
|
return
|
|
}
|
|
ctx := context.Background()
|
|
p := NewProtocol(bytes.NewBuffer(append([]byte(nil), data...)))
|
|
// bytes.Buffer EOFs deterministically once drained, so a small cap on the
|
|
// number of messages we accept here is enough to bound the iteration.
|
|
for range 16 {
|
|
if _, err := p.ReadMessage(ctx); err != nil {
|
|
return
|
|
}
|
|
}
|
|
})
|
|
}
|
|
|
|
// FuzzDecodeMessage drives DecodeMessage with an arbitrary (MessageType, payload) pair.
|
|
// It covers the SetChunkSize / WindowAcknowledgementSize / SetPeerBandwidth / UserControl
|
|
// branches and, for AMF0/AMF3 command/data types, the parseAMFObject dispatch into every
|
|
// concrete *Packet's UnmarshalBinary.
|
|
func FuzzDecodeMessage(f *testing.F) {
|
|
// Seed control-message payloads at their exact required sizes.
|
|
f.Add(uint8(MessageTypeSetChunkSize), []byte{0, 0, 0, 128})
|
|
f.Add(uint8(MessageTypeWindowAcknowledgementSize), []byte{0, 0, 0x10, 0})
|
|
f.Add(uint8(MessageTypeSetPeerBandwidth), []byte{0, 0, 0x10, 0, byte(LimitTypeDynamic)})
|
|
// UserControl PingRequest: 2B event-type + 4B data.
|
|
f.Add(uint8(MessageTypeUserControl), []byte{0x00, byte(EventTypePingRequest), 0x00, 0x00, 0x00, 0x01})
|
|
|
|
// Seed an AMF0 command with a well-formed publish packet. Also seed the AMF3 form,
|
|
// which differs only by a leading byte that DecodeMessage strips.
|
|
pub := NewPublishPacket()
|
|
pub.TransactionID = 0
|
|
pub.StreamName = NewAmf0String("s")
|
|
pubBytes, err := pub.MarshalBinary()
|
|
if err != nil {
|
|
f.Fatalf("seed marshal publish: %v", err)
|
|
}
|
|
f.Add(uint8(MessageTypeAMF0Command), pubBytes)
|
|
f.Add(uint8(MessageTypeAMF3Command), append([]byte{0}, pubBytes...))
|
|
|
|
f.Fuzz(func(t *testing.T, mtype uint8, payload []byte) {
|
|
if len(payload) > fuzzInputCap {
|
|
return
|
|
}
|
|
p := NewProtocol(&bytes.Buffer{})
|
|
m := &message{payload: payload}
|
|
m.messageHeader.MessageType = MessageType(mtype)
|
|
_, _ = p.DecodeMessage(m)
|
|
})
|
|
}
|
|
|
|
// FuzzPacketUnmarshal drives every Packet's UnmarshalBinary against arbitrary bytes.
|
|
// One target, dispatched by a kind discriminator, so the fuzzer can share a corpus and
|
|
// mutate the kind alongside the bytes.
|
|
func FuzzPacketUnmarshal(f *testing.F) {
|
|
// Build a seed per packet type from its own round-trippable bytes.
|
|
type seed struct {
|
|
kind uint8
|
|
pkt Packet
|
|
}
|
|
connRes := NewConnectAppResPacket(7)
|
|
connRes.Args.Set("data", NewAmf0EcmaArray().Set("srs_id", NewAmf0String("sid")))
|
|
call := NewCallPacket()
|
|
call.CommandName = commandOnStatus
|
|
call.TransactionID = 0
|
|
call.CommandObject = NewAmf0Null()
|
|
pub := NewPublishPacket()
|
|
pub.TransactionID = 0
|
|
pub.StreamName = NewAmf0String("s")
|
|
play := NewPlayPacket()
|
|
play.TransactionID = 0
|
|
play.StreamName = NewAmf0String("s")
|
|
seeds := []seed{
|
|
{0, NewConnectAppPacket()},
|
|
{1, connRes},
|
|
{2, call},
|
|
{3, NewCreateStreamPacket()},
|
|
{4, NewCreateStreamResPacket(2)},
|
|
{5, pub},
|
|
{6, play},
|
|
{7, &SetChunkSize{ChunkSize: 128}},
|
|
{8, &WindowAcknowledgementSize{AckSize: 2500000}},
|
|
{9, &SetPeerBandwidth{Bandwidth: 2500000, LimitType: LimitTypeDynamic}},
|
|
{10, &UserControl{EventType: EventTypePingRequest, EventData: 1}},
|
|
}
|
|
for _, s := range seeds {
|
|
b, err := s.pkt.MarshalBinary()
|
|
if err != nil {
|
|
f.Fatalf("seed marshal %T: %v", s.pkt, err)
|
|
}
|
|
f.Add(s.kind, b)
|
|
}
|
|
|
|
f.Fuzz(func(t *testing.T, kind uint8, data []byte) {
|
|
if len(data) > fuzzInputCap {
|
|
return
|
|
}
|
|
var pkt Packet
|
|
switch kind % 11 {
|
|
case 0:
|
|
pkt = NewConnectAppPacket()
|
|
case 1:
|
|
pkt = NewConnectAppResPacket(0)
|
|
case 2:
|
|
pkt = NewCallPacket()
|
|
case 3:
|
|
pkt = NewCreateStreamPacket()
|
|
case 4:
|
|
pkt = NewCreateStreamResPacket(0)
|
|
case 5:
|
|
pkt = NewPublishPacket()
|
|
case 6:
|
|
pkt = NewPlayPacket()
|
|
case 7:
|
|
pkt = NewSetChunkSize()
|
|
case 8:
|
|
pkt = NewWindowAcknowledgementSize()
|
|
case 9:
|
|
pkt = NewSetPeerBandwidth()
|
|
case 10:
|
|
pkt = NewUserControl()
|
|
}
|
|
_ = pkt.UnmarshalBinary(data)
|
|
})
|
|
}
|
|
|
|
// TestPacketUnmarshalAdversarialInputs covers P8: malformed and truncated wire input
|
|
// must never panic the parser, only error. The P7 fuzzers found a panic class where a
|
|
// New*Packet constructor pre-set an optional AMF0 field (variantCallPacket.CommandObject
|
|
// or CallPacket.Args), and Size() then counted that phantom default even when the wire
|
|
// was exhausted before it — so the caller's p = p[Size():] advance sliced out of range
|
|
// (rtmp.go:1512, "slice bounds out of range"). These cases lock in the fix; the two
|
|
// minimized fuzz inputs are also committed under testdata/fuzz as regression corpus.
|
|
func TestPacketUnmarshalAdversarialInputs(t *testing.T) {
|
|
// safeUnmarshal runs UnmarshalBinary and turns any panic into an immediate failure.
|
|
safeUnmarshal := func(t *testing.T, pkt Packet, data []byte) (err error) {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
t.Fatalf("panic on %T with %x: %v", pkt, data, r)
|
|
}
|
|
}()
|
|
return pkt.UnmarshalBinary(data)
|
|
}
|
|
|
|
// The two inputs the P7 fuzzers minimized to. Both are a "publish" command name + a
|
|
// number transaction id with nothing after, so the optional command object is absent
|
|
// and must not be counted by Size(). They previously panicked; expect a clean error.
|
|
t.Run("fuzz-crashers", func(t *testing.T) {
|
|
// FuzzPacketUnmarshal/2b0534f8182fac96: direct PublishPacket.UnmarshalBinary.
|
|
if err := safeUnmarshal(t, NewPublishPacket(), []byte("\x02\x00\x00\x0000000000")); err == nil {
|
|
t.Fatalf("truncated publish: want error, got nil")
|
|
}
|
|
// FuzzDecodeMessage/20ed1884f5b4f009: the AMF3 form reaches the same packet via
|
|
// DecodeMessage, which strips the leading AMF3 byte ('0'=0x30) before dispatching.
|
|
func() {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
t.Fatalf("DecodeMessage panic: %v", r)
|
|
}
|
|
}()
|
|
p := NewProtocol(&bytes.Buffer{})
|
|
m := &message{payload: []byte("0\x02\x00\apublish\x0000000000")}
|
|
m.messageHeader.MessageType = MessageTypeAMF3Command
|
|
if _, err := p.DecodeMessage(m); err == nil {
|
|
t.Fatalf("truncated AMF3 publish: want error, got nil")
|
|
}
|
|
}()
|
|
})
|
|
|
|
// For every variantCallPacket-derived packet, marshal a valid instance then feed every
|
|
// truncation of its bytes back to a fresh packet. No prefix may panic, and the
|
|
// full-length bytes must still round-trip.
|
|
t.Run("truncations", func(t *testing.T) {
|
|
call := NewCallPacket()
|
|
call.CommandName = commandOnStatus
|
|
call.TransactionID = 0
|
|
call.CommandObject = NewAmf0Null()
|
|
pub := NewPublishPacket()
|
|
pub.TransactionID = 0
|
|
pub.StreamName = NewAmf0String("s")
|
|
play := NewPlayPacket()
|
|
play.TransactionID = 0
|
|
play.StreamName = NewAmf0String("s")
|
|
cases := []struct {
|
|
name string
|
|
full Packet
|
|
fresh func() Packet
|
|
}{
|
|
{"call", call, func() Packet { return NewCallPacket() }},
|
|
{"publish", pub, func() Packet { return NewPublishPacket() }},
|
|
{"play", play, func() Packet { return NewPlayPacket() }},
|
|
{"createStreamRes", NewCreateStreamResPacket(2), func() Packet { return NewCreateStreamResPacket(0) }},
|
|
}
|
|
for _, c := range cases {
|
|
b, err := c.full.MarshalBinary()
|
|
if err != nil {
|
|
t.Fatalf("%v marshal: %v", c.name, err)
|
|
}
|
|
for n := 0; n <= len(b); n++ {
|
|
err := safeUnmarshal(t, c.fresh(), b[:n])
|
|
if n == len(b) && err != nil {
|
|
t.Fatalf("%v full unmarshal: %v", c.name, err)
|
|
}
|
|
}
|
|
}
|
|
})
|
|
|
|
// An oversized declared message length with a truncated stream must error via the
|
|
// incremental chunk read (readMessagePayload caps each read at chunkSize and lets
|
|
// io.ReadFull fail), not allocate ~16MB up front or hang. The header declares
|
|
// payloadLength = 0xffffff but only four payload bytes follow.
|
|
t.Run("oversized-length-truncated", func(t *testing.T) {
|
|
var in bytes.Buffer
|
|
in.Write([]byte{0x05, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, byte(MessageTypeVideo), 0x01, 0x00, 0x00, 0x00})
|
|
in.Write([]byte{1, 2, 3, 4})
|
|
p := NewProtocol(&in)
|
|
if _, err := p.ReadMessage(context.Background()); err == nil {
|
|
t.Fatalf("oversized truncated message: want error, got nil")
|
|
}
|
|
})
|
|
}
|
|
|
|
// P9 — concurrency / race on the transaction map.
|
|
//
|
|
// WritePacket registers a tid -> request-name entry under input.ltransactions
|
|
// (via onPacketWriten); parseAMFObject, reached from DecodeMessage when a
|
|
// _result/_error arrives, reads and deletes that entry under the same lock.
|
|
// This test hammers both paths from two goroutines over an overlapping tid
|
|
// space so the writes and reads/deletes genuinely interleave on
|
|
// input.transactions. Run with -race to validate the locking; even without it,
|
|
// the Go runtime panics on unsynchronized concurrent map access, so a dropped
|
|
// lock in either path fails the test.
|
|
//
|
|
// The shared state is only the transactions map: WritePacket touches just the
|
|
// writer (io.Discard here) and DecodeMessage operates on the Message it is
|
|
// handed, never the reader. Read-side "No matched request" errors are expected
|
|
// and tolerated — a tid may not be registered yet or may already be consumed;
|
|
// the assertion is no race and no panic, not that every lookup hits.
|
|
//
|
|
// No C++ reference: the C++ ProtocolStackTest suite has no concurrency test.
|
|
// New coverage.
|
|
func TestProtocolTransactionMapConcurrency(t *testing.T) {
|
|
const (
|
|
keys = 16 // tid space both goroutines cycle through (overlap forces hits)
|
|
iterations = 4000 // per goroutine
|
|
)
|
|
|
|
// WritePacket only writes to v.w; DecodeMessage only reads the Message it is
|
|
// given. io.Discard keeps the single-writer goroutine from growing a buffer.
|
|
rw := struct {
|
|
io.Reader
|
|
io.Writer
|
|
}{strings.NewReader(""), io.Discard}
|
|
p := NewProtocol(rw).(*protocol)
|
|
|
|
// Pre-build the _result bytes per tid so the read loop exercises only the
|
|
// map access in parseAMFObject, not AMF marshaling. releaseStream is one of
|
|
// the request names parseAMFObject resolves a _result against, so a hit
|
|
// returns a CallPacket and deletes the entry.
|
|
resultBytes := make([][]byte, keys+1)
|
|
for tid := 1; tid <= keys; tid++ {
|
|
res := NewCallPacket()
|
|
res.CommandName = commandResult
|
|
res.TransactionID = amf0Number(tid)
|
|
res.CommandObject = NewAmf0Null()
|
|
resultBytes[tid] = mustPacketBytes(t, res)
|
|
}
|
|
|
|
ctx := context.Background()
|
|
var wg sync.WaitGroup
|
|
wg.Add(2)
|
|
|
|
// Writer: registers tid -> releaseStream under the lock, cycling the tid space.
|
|
go func() {
|
|
defer wg.Done()
|
|
for i := 0; i < iterations; i++ {
|
|
call := NewCallPacket()
|
|
call.CommandName = commandReleaseStream
|
|
call.TransactionID = amf0Number(i%keys + 1)
|
|
call.CommandObject = NewAmf0Null()
|
|
if err := p.WritePacket(ctx, call, 0); err != nil {
|
|
t.Errorf("WritePacket err=%v", err)
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
// Reader: decodes _result messages whose tids overlap the writer's range.
|
|
// Hits delete the entry; misses return "No matched request". Both take the lock.
|
|
go func() {
|
|
defer wg.Done()
|
|
for i := 0; i < iterations; i++ {
|
|
msg := &message{payload: resultBytes[i%keys+1]}
|
|
msg.messageHeader.MessageType = MessageTypeAMF0Command
|
|
if _, err := p.DecodeMessage(msg); err != nil &&
|
|
!strings.Contains(err.Error(), "No matched request") {
|
|
t.Errorf("DecodeMessage err=%v", err)
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
wg.Wait()
|
|
}
|