diff --git a/internal/rtmp/rtmp.go b/internal/rtmp/rtmp.go index 988804b3e..319ede779 100644 --- a/internal/rtmp/rtmp.go +++ b/internal/rtmp/rtmp.go @@ -136,12 +136,21 @@ func newSettings() *settings { // The chunk stream which transport a message once. type chunkStream struct { - format formatType - cid chunkID - header messageHeader - message *message - count uint64 - extendedTimestamp bool + format formatType + cid chunkID + header messageHeader + message *message + count uint64 + + // Whether the chunk carries an extended timestamp, set when the (delta) timestamp in + // the message header equals 0xffffff. Type-3 continuation chunks inherit this from the + // preceding Type-0/1/2 chunk. + hasExtendedTimestamp bool + // The raw value last read from the extended timestamp field. Kept separately from + // header.Timestamp (the accumulated message timestamp) so we can both detect Type-3 + // chunks that omit the extended timestamp and use it as a delta for fmt=1/2 chunks. + // See readMessageHeader. + extendedTimestamp uint32 } func newChunkStream() *chunkStream { @@ -540,29 +549,7 @@ func (v *protocol) readMessageHeader(ctx context.Context, chunk *chunkStream, fo // 0x00ffffff), this value MUST be 16777215, and the 'extended // timestamp header' MUST be present. Otherwise, this value SHOULD be // the entire delta. - chunk.extendedTimestamp = uint64(chunk.header.timestampDelta) >= extendedTimestamp - if !chunk.extendedTimestamp { - // Extended timestamp: 0 or 4 bytes - // This field MUST be sent when the normal timsestamp is set to - // 0xffffff, it MUST NOT be sent if the normal timestamp is set to - // anything else. So for values less than 0xffffff the normal - // timestamp field SHOULD be used in which case the extended timestamp - // MUST NOT be present. For values greater than or equal to 0xffffff - // the normal timestamp field MUST NOT be used and MUST be set to - // 0xffffff and the extended timestamp MUST be sent. - if format == formatType0 { - // 6.1.2.1. Type 0 - // For a type-0 chunk, the absolute timestamp of the message is sent - // here. - chunk.header.Timestamp = uint64(chunk.header.timestampDelta) - } else { - // 6.1.2.2. Type 1 - // 6.1.2.3. Type 2 - // For a type-1 or type-2 chunk, the difference between the previous - // chunk's timestamp and the current chunk's timestamp is sent here. - chunk.header.Timestamp += uint64(chunk.header.timestampDelta) - } - } + chunk.hasExtendedTimestamp = uint64(chunk.header.timestampDelta) >= extendedTimestamp if format <= formatType1 { payloadLength := uint32(p[0])<<16 | uint32(p[1])<<8 | uint32(p[2]) @@ -585,27 +572,58 @@ func (v *protocol) readMessageHeader(ctx context.Context, chunk *chunkStream, fo p = p[4:] } } - } else { - // Update the timestamp even fmt=3 for first chunk packet - if isFirstChunkOfMsg && !chunk.extendedTimestamp { - chunk.header.Timestamp += uint64(chunk.header.timestampDelta) - } } - // Read extended-timestamp - if chunk.extendedTimestamp { - var timestamp uint32 - if err = binary.Read(v.r, binary.BigEndian, ×tamp); err != nil { + // Read extended-timestamp, present when the (delta) timestamp in the message header is + // 0xffffff. Type-3 chunks inherit hasExtendedTimestamp from the preceding chunk. + if chunk.hasExtendedTimestamp { + // Peek instead of read, so the 4 bytes can be left in place when a sender omits the + // extended timestamp on a Type-3 chunk (see the detection below). + var b []byte + if b, err = v.r.Peek(4); err != nil { return errors.Wrapf(err, "read ext-ts, pkt-ts=%v", chunk.header.Timestamp) } // We always use 31bits timestamp, for some server may use 32bits extended timestamp. // @see https://github.com/ossrs/srs/issues/111 - timestamp &= 0x7fffffff + timestamp := binary.BigEndian.Uint32(b) & 0x7fffffff - // TODO: FIXME: Support detect the extended timestamp. + // For the RTMP v1 2009 version (6.1.3. Extended Timestamp), Type 3 chunks MUST NOT + // have this field. For the RTMP v1 2012 version (5.3.1.3. Extended Timestamp), it is + // present in Type 3 chunks when the most recent Type 0/1/2 chunk indicated one. + // + // FMLE/FMS/Flash Player follow the 2012 version and always send the extended + // timestamp in Type 3 chunks; librtmp/ffmpeg may not. So detect it: if this is not + // the first chunk of the message and the peeked value differs from the previously + // stored extended timestamp, the sender omitted it and these 4 bytes are payload, so + // leave them in the reader. Otherwise consume and store them. // @see http://blog.csdn.net/win_lin/article/details/13363699 + // @see https://github.com/veovera/enhanced-rtmp/issues/42 + if !isFirstChunkOfMsg && chunk.extendedTimestamp > 0 && chunk.extendedTimestamp != timestamp { + // No extended timestamp on this Type-3 chunk; the 4 bytes belong to the payload. + } else { + if _, err = v.r.Discard(4); err != nil { + return errors.Wrapf(err, "discard ext-ts, pkt-ts=%v", chunk.header.Timestamp) + } + chunk.extendedTimestamp = timestamp + } + } + + // Compute the message timestamp. The source is the extended timestamp when present, + // otherwise the 3-byte (delta) timestamp from the message header. + // + // fmt=0: the value is the absolute timestamp of the message. + // fmt=1/2 (and a fmt=3 first chunk continuing them): the value is a delta and is + // accumulated onto the previous timestamp. This is required when the delta is >= 0xffffff + // and is therefore carried in the extended timestamp. + timestamp := chunk.header.timestampDelta + if chunk.hasExtendedTimestamp { + timestamp = chunk.extendedTimestamp + } + if format == formatType0 { chunk.header.Timestamp = uint64(timestamp) + } else if isFirstChunkOfMsg { + chunk.header.Timestamp += uint64(timestamp) } // The extended-timestamp must be unsigned-int, @@ -696,6 +714,11 @@ func (v *protocol) readBasicHeader(ctx context.Context) (format formatType, cid return } + // Here cid is 0 or 1: a marker selecting the 2B or 3B form, not the real cid. Keep it, + // because cid is overwritten below and the marker decides whether a third byte (the + // high-order part of the cid) follows. Do not test the overwritten cid for this. + marker := cid + // 64-319, 2B chunk header if err = binary.Read(v.r, binary.BigEndian, &t); err != nil { return format, cid, errors.Wrapf(err, "read basic header for cid=%v", cid) @@ -703,7 +726,7 @@ func (v *protocol) readBasicHeader(ctx context.Context) (format formatType, cid cid = chunkID(64 + uint32(t)) // 64-65599, 3B chunk header - if cid == 1 { + if marker == 1 { if err = binary.Read(v.r, binary.BigEndian, &t); err != nil { return format, cid, errors.Wrapf(err, "read basic header for cid=%v", cid) } @@ -1283,6 +1306,12 @@ func (v *variantCallPacket) UnmarshalBinary(data []byte) (err error) { } p = p[v.TransactionID.Size():] + // Reset the optional command object before deciding whether it is present. + // A New*Packet constructor may have pre-set it to a default (e.g. Null), but + // when the wire data is exhausted here the object is absent. Leaving the stale + // default would make Size() count bytes that were never parsed, overflowing the + // caller's p = p[Size():] advance on truncated, untrusted input. + v.CommandObject = nil if len(p) > 0 { if v.CommandObject, err = Amf0Discovery(p); err != nil { return errors.WithMessage(err, "discovery command object") @@ -1353,14 +1382,32 @@ func (v *CallPacket) Size() int { return size } +// advanceBytes returns p[n:] after verifying n lies within p. Packet +// UnmarshalBinary advances its cursor by each embedded field's decoded Size(); +// on untrusted wire input a malformed length can make Size() exceed the bytes +// actually present, so this guard turns a slice-out-of-range panic into a clean +// error. See the RTMP test plan, P8 (adversarial resource-safety). +func advanceBytes(p []byte, n int) ([]byte, error) { + if n < 0 || n > len(p) { + return nil, errors.Errorf("advance %v exceeds remaining %v bytes", n, len(p)) + } + return p[n:], nil +} + func (v *CallPacket) UnmarshalBinary(data []byte) (err error) { p := data if err = v.variantCallPacket.UnmarshalBinary(p); err != nil { return errors.WithMessage(err, "unmarshal call") } - p = p[v.variantCallPacket.Size():] + if p, err = advanceBytes(p, v.variantCallPacket.Size()); err != nil { + return errors.WithMessage(err, "advance call") + } + // Reset the optional args before deciding whether they are present, for the + // same reason as variantCallPacket.CommandObject: a stale constructor default + // would be counted by Size() and overflow a later advance. + v.Args = nil if len(p) > 0 { if v.Args, err = Amf0Discovery(p); err != nil { return errors.WithMessage(err, "discovery args") @@ -1436,7 +1483,9 @@ func (v *CreateStreamResPacket) UnmarshalBinary(data []byte) (err error) { if err = v.variantCallPacket.UnmarshalBinary(p); err != nil { return errors.WithMessage(err, "unmarshal call") } - p = p[v.variantCallPacket.Size():] + if p, err = advanceBytes(p, v.variantCallPacket.Size()); err != nil { + return errors.WithMessage(err, "advance call") + } if err = v.StreamID.UnmarshalBinary(p); err != nil { return errors.WithMessage(err, "unmarshal sid") @@ -1486,7 +1535,9 @@ func (v *PublishPacket) UnmarshalBinary(data []byte) (err error) { if err = v.variantCallPacket.UnmarshalBinary(p); err != nil { return errors.WithMessage(err, "unmarshal call") } - p = p[v.variantCallPacket.Size():] + if p, err = advanceBytes(p, v.variantCallPacket.Size()); err != nil { + return errors.WithMessage(err, "advance call") + } v.StreamName = newAmf0String("") if err = v.StreamName.UnmarshalBinary(p); err != nil { @@ -1546,7 +1597,9 @@ func (v *PlayPacket) UnmarshalBinary(data []byte) (err error) { if err = v.variantCallPacket.UnmarshalBinary(p); err != nil { return errors.WithMessage(err, "unmarshal call") } - p = p[v.variantCallPacket.Size():] + if p, err = advanceBytes(p, v.variantCallPacket.Size()); err != nil { + return errors.WithMessage(err, "advance call") + } v.StreamName = newAmf0String("") if err = v.StreamName.UnmarshalBinary(p); err != nil { diff --git a/internal/rtmp/rtmp_test.go b/internal/rtmp/rtmp_test.go index 9dc0013ca..779104bf2 100644 --- a/internal/rtmp/rtmp_test.go +++ b/internal/rtmp/rtmp_test.go @@ -7,9 +7,11 @@ import ( "bytes" "context" "encoding/binary" + "fmt" "io" "reflect" "strings" + "sync" "testing" ) @@ -98,7 +100,7 @@ func TestBasicHeaderVariantsAndErrors(t *testing.T) { }{ {"one-byte", []byte{0x85}, formatType2, 5}, {"two-byte", []byte{0x40, 0x0a}, formatType1, 74}, - {"three-byte-code-path", []byte{0xc1, 0x01, 0x02}, formatType3, 65}, + {"three-byte", []byte{0xc1, 0x01, 0x02}, formatType3, 577}, } for _, tt := range cases { t.Run(tt.name, func(t *testing.T) { @@ -177,6 +179,60 @@ func TestReadMessageExtendedTimestampAndChunking(t *testing.T) { } } +func TestReadMessageExtendedTimestampAsDeltaForFmt1(t *testing.T) { + ctx := context.Background() + var in bytes.Buffer + // fmt0 cid=5, timestamp=10, len=1, type video, stream=1, payload AA. + in.Write([]byte{0x05, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x01, byte(MessageTypeVideo), 0x01, 0x00, 0x00, 0x00, 0xAA}) + // fmt1 cid=5, delta=0xffffff so the real delta is carried in the extended timestamp (=100), + // len=1, type video, payload BB. For fmt=1/2 the extended timestamp is a delta, so the + // message timestamp must accumulate: 10 + 100 = 110 (not be replaced by 100). + in.Write([]byte{0x45, 0xff, 0xff, 0xff, 0x00, 0x00, 0x01, byte(MessageTypeVideo)}) + binary.Write(&in, binary.BigEndian, uint32(100)) + in.Write([]byte{0xBB}) + + p := NewProtocol(&in).(*protocol) + for i, want := range []struct { + ts uint64 + pl []byte + }{ + {10, []byte{0xAA}}, + {110, []byte{0xBB}}, + } { + m, err := p.ReadMessage(ctx) + if err != nil { + t.Fatalf("ReadMessage #%v err=%v", i, err) + } + if m.Timestamp() != want.ts || !bytes.Equal(m.Payload(), want.pl) { + t.Fatalf("message #%v ts=%v payload=%x", i, m.Timestamp(), m.Payload()) + } + } +} + +func TestReadMessageType3OmitsExtendedTimestamp(t *testing.T) { + ctx := context.Background() + var in bytes.Buffer + // fmt0 cid=5, timestamp=0xffffff so an extended timestamp (=100) is present, len=8, + // type video, stream=1, with the first 4 payload bytes. + in.Write([]byte{0x05, 0xff, 0xff, 0xff, 0x00, 0x00, 0x08, byte(MessageTypeVideo), 0x01, 0x00, 0x00, 0x00}) + binary.Write(&in, binary.BigEndian, uint32(100)) + in.Write([]byte{0x01, 0x02, 0x03, 0x04}) + // fmt3 continuation from a librtmp/ffmpeg-style sender that omits the extended timestamp. + // The next 4 bytes are payload, not an extended timestamp; the parser must detect the + // mismatch against the stored value (100) and treat them as payload, keeping ts=100. + in.Write([]byte{0xc5, 0x05, 0x06, 0x07, 0x08}) + + p := NewProtocol(&in).(*protocol) + p.input.opt.chunkSize = 4 + m, err := p.ReadMessage(ctx) + if err != nil { + t.Fatalf("ReadMessage err=%v", err) + } + if m.Timestamp() != 100 || !bytes.Equal(m.Payload(), []byte{1, 2, 3, 4, 5, 6, 7, 8}) { + t.Fatalf("ts=%v payload=%x", m.Timestamp(), m.Payload()) + } +} + func TestReadMessageHeaderErrors(t *testing.T) { ctx := context.Background() // Fresh non-zero chunk with fmt1 is rejected. @@ -726,3 +782,1258 @@ func TestPacketUnmarshalErrorBranchesForCoverage(t *testing.T) { t.Fatalf("play stream name err=%v", err) } } + +// TestReadMessageInterleavedMultiStream covers P1: two or more chunk streams interleaved on the +// wire, each reassembling independently via protocol.input.chunks. All other read tests use a +// single cid (5), so the per-cid map and per-cid header state are never exercised under +// interleaving. Mirrors the C++ srs_utest_manual_protocol.cpp ProtocolRecvVAVMessage / +// ProtocolRecvVAVFmt1/2/3 family. +func TestReadMessageInterleavedMultiStream(t *testing.T) { + ctx := context.Background() + + // Chunk-by-chunk interleave: three multi-chunk messages on cid 6 (video), 7 (audio) and 8 + // (data) are split at chunkSize=3 and woven together on the wire. Each message must reassemble + // from its own cid's chunks regardless of the chunks belonging to other cids in between, and + // surface in the order its final chunk arrives (V, then A, then D). + t.Run("chunk-by-chunk", func(t *testing.T) { + var in bytes.Buffer + // V fmt0 cid=6, ts=100, len=6, video, stream=1, first 3 payload bytes. + in.Write([]byte{0x06, 0x00, 0x00, 0x64, 0x00, 0x00, 0x06, byte(MessageTypeVideo), 0x01, 0x00, 0x00, 0x00, 0xa1, 0xa2, 0xa3}) + // A fmt0 cid=7, ts=200, len=4, audio, stream=1, first 3 payload bytes. + in.Write([]byte{0x07, 0x00, 0x00, 0xc8, 0x00, 0x00, 0x04, byte(MessageTypeAudio), 0x01, 0x00, 0x00, 0x00, 0xb1, 0xb2, 0xb3}) + // D fmt0 cid=8, ts=300, len=5, data, stream=1, first 3 payload bytes. + in.Write([]byte{0x08, 0x00, 0x01, 0x2c, 0x00, 0x00, 0x05, byte(MessageTypeAMF0Data), 0x01, 0x00, 0x00, 0x00, 0xc1, 0xc2, 0xc3}) + // V fmt3 cid=6 continuation, last 3 payload bytes -> V completes. + in.Write([]byte{0xc6, 0xa4, 0xa5, 0xa6}) + // A fmt3 cid=7 continuation, last payload byte -> A completes. + in.Write([]byte{0xc7, 0xb4}) + // D fmt3 cid=8 continuation, last 2 payload bytes -> D completes. + in.Write([]byte{0xc8, 0xc4, 0xc5}) + + p := NewProtocol(&in).(*protocol) + p.input.opt.chunkSize = 3 + for i, want := range []struct { + typ MessageType + ts uint64 + pl []byte + }{ + {MessageTypeVideo, 100, []byte{0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6}}, + {MessageTypeAudio, 200, []byte{0xb1, 0xb2, 0xb3, 0xb4}}, + {MessageTypeAMF0Data, 300, []byte{0xc1, 0xc2, 0xc3, 0xc4, 0xc5}}, + } { + m, err := p.ReadMessage(ctx) + if err != nil { + t.Fatalf("ReadMessage #%v err=%v", i, err) + } + if m.MessageType() != want.typ || m.Timestamp() != want.ts || !bytes.Equal(m.Payload(), want.pl) { + t.Fatalf("message #%v type=%v ts=%v payload=%x", i, m.MessageType(), m.Timestamp(), m.Payload()) + } + } + }) + + // Per-cid header-state isolation: single-chunk messages alternate between the video cid (6) + // and audio cid (7). The second and later messages on each cid use fmt1/2/3 headers, which + // inherit timestamp delta / payload length / type from the *previous message on the same cid*. + // An interleaved message on the other cid must not perturb that state, so the video deltas + // accumulate only over video (1000 -> 1010 -> 1015) and audio only over audio + // (5000 -> 5020 -> 5040). + t.Run("per-cid-header-state", func(t *testing.T) { + var in bytes.Buffer + // V1 fmt0 cid=6, ts=1000, len=2, video, stream=1. + in.Write([]byte{0x06, 0x00, 0x03, 0xe8, 0x00, 0x00, 0x02, byte(MessageTypeVideo), 0x01, 0x00, 0x00, 0x00, 0x11, 0x12}) + // A1 fmt0 cid=7, ts=5000, len=2, audio, stream=1. + in.Write([]byte{0x07, 0x00, 0x13, 0x88, 0x00, 0x00, 0x02, byte(MessageTypeAudio), 0x01, 0x00, 0x00, 0x00, 0x21, 0x22}) + // V2 fmt1 cid=6, delta=10, len=2, video -> ts 1000+10=1010 (inherits from V1, not A1). + in.Write([]byte{0x46, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x02, byte(MessageTypeVideo), 0x13, 0x14}) + // A2 fmt1 cid=7, delta=20, len=2, audio -> ts 5000+20=5020 (inherits from A1). + in.Write([]byte{0x47, 0x00, 0x00, 0x14, 0x00, 0x00, 0x02, byte(MessageTypeAudio), 0x23, 0x24}) + // V3 fmt2 cid=6, delta=5 (len/type reused from V2) -> ts 1010+5=1015. + in.Write([]byte{0x86, 0x00, 0x00, 0x05, 0x15, 0x16}) + // A3 fmt3 cid=7 (delta=20, len/type reused from A2) -> ts 5020+20=5040. + in.Write([]byte{0xc7, 0x25, 0x26}) + + p := NewProtocol(&in).(*protocol) + for i, want := range []struct { + typ MessageType + ts uint64 + pl []byte + }{ + {MessageTypeVideo, 1000, []byte{0x11, 0x12}}, + {MessageTypeAudio, 5000, []byte{0x21, 0x22}}, + {MessageTypeVideo, 1010, []byte{0x13, 0x14}}, + {MessageTypeAudio, 5020, []byte{0x23, 0x24}}, + {MessageTypeVideo, 1015, []byte{0x15, 0x16}}, + {MessageTypeAudio, 5040, []byte{0x25, 0x26}}, + } { + m, err := p.ReadMessage(ctx) + if err != nil { + t.Fatalf("ReadMessage #%v err=%v", i, err) + } + if m.MessageType() != want.typ || m.Timestamp() != want.ts || !bytes.Equal(m.Payload(), want.pl) { + t.Fatalf("message #%v type=%v ts=%v payload=%x", i, m.MessageType(), m.Timestamp(), m.Payload()) + } + } + }) + + // Per-cid payload-length change: successive messages on the video cid (6) carry *different* + // payload lengths via fmt1 headers (2 -> 3 -> 1), while audio (cid 7) is interleaved in + // between. fmt1 begins a new message (chunk.message is nil), so the length check is skipped + // and the chunkStream must adopt each new length rather than reusing the previous one. Mirrors + // the C++ srs_utest_manual_protocol2.cpp ProtocolRecvVAVVFmt11Length / Fmt12Length cases. + t.Run("per-cid-length-change", func(t *testing.T) { + var in bytes.Buffer + // V1 fmt0 cid=6, ts=0x10, len=2, video, stream=1. + in.Write([]byte{0x06, 0x00, 0x00, 0x10, 0x00, 0x00, 0x02, byte(MessageTypeVideo), 0x01, 0x00, 0x00, 0x00, 0x11, 0x12}) + // A1 fmt0 cid=7, ts=0x15, len=2, audio, stream=1. + in.Write([]byte{0x07, 0x00, 0x00, 0x15, 0x00, 0x00, 0x02, byte(MessageTypeAudio), 0x01, 0x00, 0x00, 0x00, 0x21, 0x22}) + // V2 fmt1 cid=6, delta=0x10, len=3 (changed 2->3), video -> ts 0x20. + in.Write([]byte{0x46, 0x00, 0x00, 0x10, 0x00, 0x00, 0x03, byte(MessageTypeVideo), 0x13, 0x14, 0x15}) + // V3 fmt1 cid=6, delta=0x20, len=1 (changed 3->1), video -> ts 0x40. + in.Write([]byte{0x46, 0x00, 0x00, 0x20, 0x00, 0x00, 0x01, byte(MessageTypeVideo), 0x16}) + // A2 fmt1 cid=7, delta=0x05, len=4 (changed 2->4), audio -> ts 0x1a. + in.Write([]byte{0x47, 0x00, 0x00, 0x05, 0x00, 0x00, 0x04, byte(MessageTypeAudio), 0x23, 0x24, 0x25, 0x26}) + + p := NewProtocol(&in).(*protocol) + for i, want := range []struct { + typ MessageType + ts uint64 + pl []byte + }{ + {MessageTypeVideo, 0x10, []byte{0x11, 0x12}}, + {MessageTypeAudio, 0x15, []byte{0x21, 0x22}}, + {MessageTypeVideo, 0x20, []byte{0x13, 0x14, 0x15}}, + {MessageTypeVideo, 0x40, []byte{0x16}}, + {MessageTypeAudio, 0x1a, []byte{0x23, 0x24, 0x25, 0x26}}, + } { + m, err := p.ReadMessage(ctx) + if err != nil { + t.Fatalf("ReadMessage #%v err=%v", i, err) + } + if m.MessageType() != want.typ || m.Timestamp() != want.ts || !bytes.Equal(m.Payload(), want.pl) { + t.Fatalf("message #%v type=%v ts=%v payload=%x", i, m.MessageType(), m.Timestamp(), m.Payload()) + } + } + }) +} + +// TestReadMessageLargeChunkStreamID covers P2: reading a complete message (basic header + message +// header + payload) whose chunks carry the chunk-stream ID in the 2-byte (cid 64-319) and 3-byte +// (cid 64-65599) basic-header forms. Every other read test uses a 1-byte cid (5), so +// readBasicHeader's multi-byte cid decode is exercised for header decode only +// (TestBasicHeaderVariantsAndErrors) and never end-to-end through ReadMessage carrying a real +// payload. Asserting the decoded cid keys input.chunks also proves the encode/decode are inverses +// (a swapped 2nd/3rd byte would land on a different cid). Mirrors the C++ +// srs_utest_manual_protocol.cpp / protocol2.cpp ProtocolRecvVCid2B* / Cid3B* family. +func TestReadMessageLargeChunkStreamID(t *testing.T) { + ctx := context.Background() + + // basicHeader encodes fmt+cid in the smallest basic-header form that fits cid: 1 byte for + // 2-63, 2 bytes for 64-319, 3 bytes for 320-65599. The ID math is the inverse of + // readBasicHeader: 2-byte cid = 64 + b1; 3-byte cid = 64 + b1 + b2*256. + basicHeader := func(format formatType, cid chunkID) []byte { + f := byte(format) << 6 + switch { + case cid <= 63: + return []byte{f | byte(cid)} + case cid <= 319: + return []byte{f, byte(cid - 64)} // 2-byte marker is 0 in the low 6 bits + default: + v := uint32(cid) - 64 + return []byte{f | 0x01, byte(v % 256), byte(v / 256)} + } + } + + // Single-chunk fmt0 message on each boundary cid: 2-byte min (64) and max (319), 3-byte first + // value (320) and max (65599). The cid must decode to the right value (which keys input.chunks) + // and the payload must reassemble. + t.Run("single-chunk-boundaries", func(t *testing.T) { + for _, cid := range []chunkID{64, 319, 320, 65599} { + var in bytes.Buffer + in.Write(basicHeader(formatType0, cid)) + // fmt0 message header: ts=0x40, len=3, video, stream=1, then payload a1a2a3. + in.Write([]byte{0x00, 0x00, 0x40, 0x00, 0x00, 0x03, byte(MessageTypeVideo), 0x01, 0x00, 0x00, 0x00}) + in.Write([]byte{0xa1, 0xa2, 0xa3}) + + p := NewProtocol(&in).(*protocol) + m, err := p.ReadMessage(ctx) + if err != nil { + t.Fatalf("cid=%v ReadMessage err=%v", cid, err) + } + if m.MessageType() != MessageTypeVideo || m.Timestamp() != 0x40 || !bytes.Equal(m.Payload(), []byte{0xa1, 0xa2, 0xa3}) { + t.Fatalf("cid=%v type=%v ts=%v payload=%x", cid, m.MessageType(), m.Timestamp(), m.Payload()) + } + if _, ok := p.input.chunks[cid]; !ok { + t.Fatalf("cid=%v not keyed in chunks map: %v", cid, p.input.chunks) + } + } + }) + + // Multi-chunk message on a 3-byte cid: the fmt3 continuation must re-encode the same large cid + // in its basic header, so readBasicHeader is invoked again mid-message and must resolve to the + // same chunkStream for the payload to reassemble. + t.Run("multi-chunk-large-cid", func(t *testing.T) { + const cid chunkID = 65599 + var in bytes.Buffer + // fmt0 on the 3-byte cid: ts=0x50, len=5, video, stream=1, first 3 payload bytes. + in.Write(basicHeader(formatType0, cid)) + in.Write([]byte{0x00, 0x00, 0x50, 0x00, 0x00, 0x05, byte(MessageTypeVideo), 0x01, 0x00, 0x00, 0x00}) + in.Write([]byte{0xb1, 0xb2, 0xb3}) + // fmt3 continuation, same 3-byte cid re-encoded, last 2 payload bytes -> message completes. + in.Write(basicHeader(formatType3, cid)) + in.Write([]byte{0xb4, 0xb5}) + + p := NewProtocol(&in).(*protocol) + p.input.opt.chunkSize = 3 + m, err := p.ReadMessage(ctx) + if err != nil { + t.Fatalf("ReadMessage err=%v", err) + } + if m.Timestamp() != 0x50 || !bytes.Equal(m.Payload(), []byte{0xb1, 0xb2, 0xb3, 0xb4, 0xb5}) { + t.Fatalf("ts=%v payload=%x", m.Timestamp(), m.Payload()) + } + }) + + // The spec allows cid 64-319 in either the 2-byte or 3-byte form, and both must decode to the + // same cid. Read cid 64 first via its 2-byte form, then a second message via the (non-minimal) + // 3-byte form, and confirm both land on a single chunk stream so fmt1's delta accumulates over + // the first message (0x10 -> 0x15) rather than starting a second stream. + t.Run("cid-64-both-forms", func(t *testing.T) { + const cid chunkID = 64 + var in bytes.Buffer + // 2-byte form fmt0: ts=0x10, len=1, video, stream=1, payload c1. + in.Write([]byte{0x00, 0x00}) // fmt0, 2-byte cid 64 + in.Write([]byte{0x00, 0x00, 0x10, 0x00, 0x00, 0x01, byte(MessageTypeVideo), 0x01, 0x00, 0x00, 0x00, 0xc1}) + // 3-byte form fmt1 on the same cid: delta=5, len=1, video, payload c2 -> ts 0x10+5=0x15. + in.Write([]byte{0x41, 0x00, 0x00}) // fmt1, 3-byte cid 64 + in.Write([]byte{0x00, 0x00, 0x05, 0x00, 0x00, 0x01, byte(MessageTypeVideo), 0xc2}) + + p := NewProtocol(&in).(*protocol) + for i, want := range []struct { + ts uint64 + pl []byte + }{ + {0x10, []byte{0xc1}}, + {0x15, []byte{0xc2}}, + } { + m, err := p.ReadMessage(ctx) + if err != nil { + t.Fatalf("message #%v ReadMessage err=%v", i, err) + } + if m.Timestamp() != want.ts || !bytes.Equal(m.Payload(), want.pl) { + t.Fatalf("message #%v ts=%v payload=%x", i, m.Timestamp(), m.Payload()) + } + } + if _, ok := p.input.chunks[cid]; !ok || len(p.input.chunks) != 1 { + t.Fatalf("both forms must share one chunk stream at cid %v: %v", cid, p.input.chunks) + } + }) +} + +// TestProtocolWritePacketReadMessageRoundTrip covers P3: every Packet type round-tripped through +// the full chunk-stream layer at a representative span of chunk sizes — 1 (every payload byte in +// its own chunk, the extreme case), 2 (small), 128 (default), and a value larger than any payload +// here (never chunks). For each (packet, chunk size) pair the packet is written via WritePacket, +// read back via ReadMessage, and decoded via DecodeMessage; the decoded packet must have the +// expected concrete type and re-marshal to the same wire bytes as the original. This proves the +// Go encoder and decoder agree end-to-end. The existing TestPacketRoundTripsAndErrors only checks +// MarshalBinary↔UnmarshalBinary in isolation (no chunk layer), and TestProtocolPacketsAndTransactions +// covers only a handful of packets at the default chunk size — so neither exercises the cross +// product of every typed packet against multi-chunk reassembly. Mirrors the C++ +// srs_utest_manual_protocol.cpp ProtocolSendSrs*Packet family and +// srs_utest_manual_rtmp.cpp ProtocolRTMPTest.DecodeMessages / OnDecodeMessages family. +// +// Wire-byte equivalence is the equality check rather than reflect.DeepEqual: AMF0 packets carry an +// amf0ObjectBase.bufFactory func field that is non-nil in both sides, and reflect.DeepEqual treats +// any pair of non-nil funcs as not-equal. Comparing MarshalBinary() output is also the strongest +// practical assertion for "encoder and decoder agree on the wire," which is what P3 is testing. +// +// Some packets decode to a different concrete type than what we wrote. parseAMFObject's default +// branch returns NewCallPacket() for any AMF0 command name not in its special-case switch, so +// e.g. createStream comes back as *CallPacket. The variantCallPacket marshal layout is identical +// between CreateStreamPacket and a CallPacket carrying the same fields (Args nil is skipped on +// marshal), so the wire-bytes check still holds. wantType records the expected decoded type per +// case so the asymmetry is explicit. +func TestProtocolWritePacketReadMessageRoundTrip(t *testing.T) { + ctx := context.Background() + + // Builder functions for each packet, populated with non-default fields so the round-trip + // exercises real values instead of zero values. Use closures so the test can rebuild a fresh + // orig per (chunk size) iteration without state leaking between subtests. + makeConnect := func() Packet { + p := NewConnectAppPacket() + p.CommandObject.Set("tcUrl", NewAmf0String("rtmp://host/live")) + p.CommandObject.Set("app", NewAmf0String("live")) + return p + } + makeConnectRes := func() Packet { + p := NewConnectAppResPacket(1) // tid matches makeConnect's default TransactionID=1 + p.Args.Set("data", NewAmf0EcmaArray().Set("srs_id", NewAmf0String("sid"))) + return p + } + makeCreateStream := func() Packet { return NewCreateStreamPacket() } + makeCreateStreamRes := func() Packet { + p := NewCreateStreamResPacket(2) // tid=2 matches makeCreateStream's TransactionID + p.SetStreamID(99) + return p + } + makeOnStatus := func() Packet { + p := NewCallPacket() + p.CommandName = commandOnStatus + p.TransactionID = 0 + p.CommandObject = NewAmf0Null() + p.Args = NewAmf0Object().Set("level", NewAmf0String("status")).Set("code", NewAmf0String("NetStream.Play.Start")) + return p + } + makeReleaseStreamRes := func() Packet { + // _result for a releaseStream call: parseAMFObject returns NewCallPacket() with the tid + // pre-filled. We send the response shape that the C++ side sends back to FMLE. + p := NewCallPacket() + p.CommandName = commandResult + p.TransactionID = 4 + p.CommandObject = NewAmf0Null() + return p + } + makePublish := func() Packet { + p := NewPublishPacket() + p.TransactionID = 0 + p.StreamName = NewAmf0String("livestream") + p.StreamType = NewAmf0String("live") + return p + } + makePlay := func() Packet { + p := NewPlayPacket() + p.TransactionID = 0 + p.StreamName = NewAmf0String("livestream") + return p + } + + type seedFn func(*protocol) + seedConnect := func(p *protocol) { p.input.transactions[1] = commandConnect } + seedCreateStream := func(p *protocol) { p.input.transactions[2] = commandCreateStream } + seedRelease := func(p *protocol) { p.input.transactions[4] = commandReleaseStream } + + cases := []struct { + name string + build func() Packet + wantType reflect.Type + // seed registers a tid → request-name mapping on the reader's transactions map so that + // parseAMFObject can resolve the concrete response type for *_result/_error packets. + seed seedFn + }{ + {"connect-app", makeConnect, reflect.TypeOf((*ConnectAppPacket)(nil)), nil}, + {"connect-app-res", makeConnectRes, reflect.TypeOf((*ConnectAppResPacket)(nil)), seedConnect}, + {"create-stream", makeCreateStream, reflect.TypeOf((*CallPacket)(nil)), nil}, + {"create-stream-res", makeCreateStreamRes, reflect.TypeOf((*CreateStreamResPacket)(nil)), seedCreateStream}, + {"call-onstatus", makeOnStatus, reflect.TypeOf((*CallPacket)(nil)), nil}, + {"call-result-releaseStream", makeReleaseStreamRes, reflect.TypeOf((*CallPacket)(nil)), seedRelease}, + {"publish", makePublish, reflect.TypeOf((*PublishPacket)(nil)), nil}, + {"play", makePlay, reflect.TypeOf((*PlayPacket)(nil)), nil}, + {"set-chunk-size", func() Packet { return &SetChunkSize{ChunkSize: 4096} }, reflect.TypeOf((*SetChunkSize)(nil)), nil}, + {"window-ack-size", func() Packet { return &WindowAcknowledgementSize{AckSize: 2500000} }, reflect.TypeOf((*WindowAcknowledgementSize)(nil)), nil}, + {"set-peer-bandwidth", func() Packet { + return &SetPeerBandwidth{Bandwidth: 2500000, LimitType: LimitTypeDynamic} + }, reflect.TypeOf((*SetPeerBandwidth)(nil)), nil}, + {"user-control-ping", func() Packet { + return &UserControl{EventType: EventTypePingRequest, EventData: 12345} + }, reflect.TypeOf((*UserControl)(nil)), nil}, + {"user-control-buffer-len", func() Packet { + return &UserControl{EventType: EventTypeSetBufferLength, EventData: 1, ExtraData: 1500} + }, reflect.TypeOf((*UserControl)(nil)), nil}, + {"user-control-fms-event0", func() Packet { + return &UserControl{EventType: EventTypeFmsEvent0, EventData: 1} + }, reflect.TypeOf((*UserControl)(nil)), nil}, + } + + // Chunk sizes: 1 forces every payload byte into its own chunk (maximum c3 continuations); + // 2 is a small non-trivial chunking; 128 is the protocol default; 4096 is larger than every + // payload in this table (the connect packet is ~60 bytes), so the message is sent as a single + // chunk with no c3 continuations. + chunkSizes := []uint32{1, 2, 128, 4096} + + for _, c := range cases { + for _, chunkSize := range chunkSizes { + t.Run(fmt.Sprintf("%s/chunk=%d", c.name, chunkSize), func(t *testing.T) { + orig := c.build() + origBytes, err := orig.MarshalBinary() + if err != nil { + t.Fatalf("MarshalBinary orig err=%v", err) + } + + var wire bytes.Buffer + writer := NewProtocol(&wire).(*protocol) + writer.output.opt.chunkSize = chunkSize + if err := writer.WritePacket(ctx, orig, 1); err != nil { + t.Fatalf("WritePacket err=%v", err) + } + + // Verify the writer actually chunked at this size: a payload longer than + // chunkSize must produce at least one c3 continuation chunk on the wire. The + // continuation header is a single byte 0xc0|cid; the basic-header byte for the + // initial c0 chunk is 0x0X|cid (fmt=0). Counting bytes that match the c3 form is + // a coarse but adequate check that the chunk path was actually exercised. + if uint32(len(origBytes)) > chunkSize { + wantContinuations := (uint32(len(origBytes))-1)/chunkSize + 1 - 1 + var got uint32 + for _, b := range wire.Bytes() { + if b == 0xc0|byte(orig.BetterCid()) { + got++ + } + } + if got < wantContinuations { + t.Fatalf("expected >=%v c3 continuations on the wire, saw %v (wire=%x)", wantContinuations, got, wire.Bytes()) + } + } + + reader := NewProtocol(&wire).(*protocol) + reader.input.opt.chunkSize = chunkSize + if c.seed != nil { + c.seed(reader) + } + m, err := reader.ReadMessage(ctx) + if err != nil { + t.Fatalf("ReadMessage err=%v", err) + } + if m.MessageType() != orig.Type() { + t.Fatalf("message type=%v want=%v", m.MessageType(), orig.Type()) + } + got, err := reader.DecodeMessage(m) + if err != nil { + t.Fatalf("DecodeMessage err=%v", err) + } + if reflect.TypeOf(got) != c.wantType { + t.Fatalf("decoded type=%T want=%v", got, c.wantType) + } + + gotBytes, err := got.MarshalBinary() + if err != nil { + t.Fatalf("MarshalBinary got err=%v", err) + } + if !bytes.Equal(origBytes, gotBytes) { + t.Fatalf("round-trip mismatch:\n orig=%x\n got =%x", origBytes, gotBytes) + } + }) + } + } +} + +// TestReadMessageTimestampDiscontinuity covers P5: timestamp continuity edges on the per-cid +// chunkStream that the existing monotonically-increasing tests don't exercise. The C++ +// ProtocolStackTest suite has no equivalent coverage either, so this is new coverage beyond the +// C++ reference. +// +// 1. backward-jump-fmt0: a new fmt0 message whose absolute timestamp is *smaller* than the +// previous message's timestamp must replace the stored Timestamp, not add to it. fmt0 is +// absolute by the spec; fmt1/2/3 deltas are unsigned, so a real backward jump can only ride +// on fmt0. +// 2. wraparound-31bit-mask: delta accumulation crossing the 31-bit boundary must wrap to the +// low 31 bits via `chunk.header.Timestamp &= 0x7fffffff`. ts 0x7ffffff0 + delta 0x20 = +// 0x80000010 -> masked 0x10. Easy to regress if anyone splits the accumulate-then-mask +// sequence. +// 3. forward-jump-fmt0: a new fmt0 message whose absolute timestamp is much *larger* than the +// previous one (carried in the extended timestamp because the 3-byte field saturates at +// 0xffffff) must also replace, not add. Mirror image of case 1; together they prove fmt0 is +// absolute regardless of direction. +func TestReadMessageTimestampDiscontinuity(t *testing.T) { + ctx := context.Background() + + t.Run("backward-jump-fmt0", func(t *testing.T) { + var in bytes.Buffer + // M1 fmt0 cid=5, ts=2000=0x7d0, len=2, video, stream=1, payload AA BB. + in.Write([]byte{0x05, 0x00, 0x07, 0xd0, 0x00, 0x00, 0x02, byte(MessageTypeVideo), 0x01, 0x00, 0x00, 0x00, 0xAA, 0xBB}) + // M2 fmt0 cid=5, ts=1000=0x3e8 (< 2000), len=2, video, stream=1, payload CC DD. + // fmt0 sets the message timestamp absolutely; if it were ever accidentally accumulated, + // M2.Timestamp would be 2000+1000=3000 instead of 1000. + in.Write([]byte{0x05, 0x00, 0x03, 0xe8, 0x00, 0x00, 0x02, byte(MessageTypeVideo), 0x01, 0x00, 0x00, 0x00, 0xCC, 0xDD}) + + p := NewProtocol(&in).(*protocol) + for i, want := range []struct { + ts uint64 + pl []byte + }{ + {2000, []byte{0xAA, 0xBB}}, + {1000, []byte{0xCC, 0xDD}}, + } { + m, err := p.ReadMessage(ctx) + if err != nil { + t.Fatalf("ReadMessage #%v err=%v", i, err) + } + if m.Timestamp() != want.ts || !bytes.Equal(m.Payload(), want.pl) { + t.Fatalf("message #%v ts=%v payload=%x", i, m.Timestamp(), m.Payload()) + } + } + }) + + t.Run("wraparound-31bit-mask", func(t *testing.T) { + var in bytes.Buffer + // M1 fmt0 cid=5, ts(3B)=0xffffff so an extended timestamp is present, len=1, video, + // stream=1, ext-ts=0x7ffffff0 (just below the 31-bit edge), payload AA. The 31-bit mask + // at the end of readMessageHeader leaves 0x7ffffff0 unchanged because bit 31 is clear, + // so chunk.header.Timestamp settles at 0x7ffffff0 between messages. + in.Write([]byte{0x05, 0xff, 0xff, 0xff, 0x00, 0x00, 0x01, byte(MessageTypeVideo), 0x01, 0x00, 0x00, 0x00}) + binary.Write(&in, binary.BigEndian, uint32(0x7ffffff0)) + in.Write([]byte{0xAA}) + // M2 fmt1 cid=5, delta=0x20 (3-byte, no ext-ts because 0x20 < 0xffffff), len=1, video, + // payload BB. Accumulation: 0x7ffffff0 + 0x20 = 0x80000010 (bit 31 set), then the + // `chunk.header.Timestamp &= 0x7fffffff` mask drops bit 31 -> 0x10. Drop the mask and + // M2.Timestamp would surface as 0x80000010 instead. + in.Write([]byte{0x45, 0x00, 0x00, 0x20, 0x00, 0x00, 0x01, byte(MessageTypeVideo), 0xBB}) + + p := NewProtocol(&in).(*protocol) + for i, want := range []struct { + ts uint64 + pl []byte + }{ + {0x7ffffff0, []byte{0xAA}}, + {0x10, []byte{0xBB}}, + } { + m, err := p.ReadMessage(ctx) + if err != nil { + t.Fatalf("ReadMessage #%v err=%v", i, err) + } + if m.Timestamp() != want.ts || !bytes.Equal(m.Payload(), want.pl) { + t.Fatalf("message #%v ts=%v payload=%x", i, m.Timestamp(), m.Payload()) + } + } + }) + + t.Run("forward-jump-fmt0", func(t *testing.T) { + var in bytes.Buffer + // M1 fmt0 cid=5, ts=10, len=1, video, stream=1, payload AA. + in.Write([]byte{0x05, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x01, byte(MessageTypeVideo), 0x01, 0x00, 0x00, 0x00, 0xAA}) + // M2 fmt0 cid=5, ts(3B)=0xffffff (sentinel), len=1, video, stream=1, + // ext-ts=0x12345678 (large forward absolute, bit 31 clear so no mask interaction), + // payload BB. fmt0 replaces absolutely, so M2.Timestamp must be 0x12345678, not + // 10 + 0x12345678 = 0x12345682. + in.Write([]byte{0x05, 0xff, 0xff, 0xff, 0x00, 0x00, 0x01, byte(MessageTypeVideo), 0x01, 0x00, 0x00, 0x00}) + binary.Write(&in, binary.BigEndian, uint32(0x12345678)) + in.Write([]byte{0xBB}) + + p := NewProtocol(&in).(*protocol) + for i, want := range []struct { + ts uint64 + pl []byte + }{ + {10, []byte{0xAA}}, + {0x12345678, []byte{0xBB}}, + } { + m, err := p.ReadMessage(ctx) + if err != nil { + t.Fatalf("ReadMessage #%v err=%v", i, err) + } + if m.Timestamp() != want.ts || !bytes.Equal(m.Payload(), want.pl) { + t.Fatalf("message #%v ts=%v payload=%x", i, m.Timestamp(), m.Payload()) + } + } + }) +} + +// TestReadWriteLargePayloadChunkBoundaries covers P4: large-payload and chunk-boundary stress on +// both the write (chunking) and read (reassembly) paths. The existing tests only write a single +// 5000-byte message and never read a large multi-chunk message back, nor exercise the exact +// payload==chunkSize / payload==N*chunkSize boundaries where an off-by-one in the chunking loop +// would surface as a spurious empty trailing chunk. The 3-byte max length (0xffffff) parse is also +// pinned here; its DoS/truncation behavior is separately covered by +// TestPacketUnmarshalAdversarialInputs (P8, oversized-length-truncated). +// +// C++ reference: +// - srs_utest_manual_rtmp.cpp :: TEST(ProtocolRTMPTest, HugeMessages) — a 256B audio payload +// at chunkSize=128 serializes to exactly 269 wire bytes (12B c0 header + 128 + 1B c3 + 128), +// i.e. no empty trailing chunk. +// - srs_utest_manual_rtmp.cpp :: TEST(ProtocolRTMPTest, SendHugePacket) — a 1024B send. +// - srs_utest_manual_protocol.cpp :: TEST(ProtocolStackTest, ProtocolRecvVMessage2Trunk) — read a +// 272B video message split across 3 chunks (2 c3 continuations) at chunkSize=128. +func TestReadWriteLargePayloadChunkBoundaries(t *testing.T) { + ctx := context.Background() + + // payload-equals-chunksize: a payload of exactly chunkSize fits in a single chunk. The write + // loop (`for len(p) > 0`) must emit zero c3 continuations — emitting an empty trailing chunk + // for the 0 bytes "remaining" after the first full chunk would be an off-by-one bug. The exact + // wire length (1B basic + 11B msg header + chunkSize payload) proves the single-chunk shape, + // and the read path reassembles back to the original payload. + t.Run("payload-equals-chunksize", func(t *testing.T) { + const chunkSize = 128 + payload := make([]byte, chunkSize) + for i := range payload { + payload[i] = byte(i) + } + + var wire bytes.Buffer + w := NewProtocol(&wire).(*protocol) + w.output.opt.chunkSize = chunkSize + m := NewStreamMessage(1).asMessage() + m.messageHeader.MessageType = MessageTypeVideo + m.messageHeader.Timestamp = 40 + m.payload = payload + if err := w.WriteMessage(ctx, m); err != nil { + t.Fatalf("WriteMessage err=%v", err) + } + if want := 1 + 11 + chunkSize; wire.Len() != want { + t.Fatalf("wire len=%v want=%v (a spurious empty trailing chunk would add a c3 byte)", wire.Len(), want) + } + + r := NewProtocol(&wire).(*protocol) + r.input.opt.chunkSize = chunkSize + got, err := r.ReadMessage(ctx) + if err != nil { + t.Fatalf("ReadMessage err=%v", err) + } + if got.Timestamp() != 40 || !bytes.Equal(got.Payload(), payload) { + t.Fatalf("ts=%v payloadLen=%v", got.Timestamp(), len(got.Payload())) + } + }) + + // payload-exact-multiple: a payload that is an exact multiple of chunkSize (256 = 2*128) must + // serialize to exactly two chunks — c0+128 then c3+128 — and NOT a third empty trailing chunk. + // This mirrors the C++ HugeMessages golden: 269 wire bytes for a 256B payload with ts<0xffffff + // (no extended timestamp). The continuation header at offset 140 must be the 1-byte c3 form, + // and the message reassembles back to the original payload. + t.Run("payload-exact-multiple", func(t *testing.T) { + const chunkSize = 128 + payload := make([]byte, 2*chunkSize) + for i := range payload { + payload[i] = byte(i) + } + + var wire bytes.Buffer + w := NewProtocol(&wire).(*protocol) + w.output.opt.chunkSize = chunkSize + m := NewStreamMessage(1).asMessage() + m.messageHeader.MessageType = MessageTypeAudio + m.messageHeader.Timestamp = 1000 + m.payload = payload + if err := w.WriteMessage(ctx, m); err != nil { + t.Fatalf("WriteMessage err=%v", err) + } + // 12B c0 header + 128 + 1B c3 header + 128 = 269 (the C++ HugeMessages value). 270 would + // mean an empty trailing chunk was emitted. + if want := 1 + 11 + chunkSize + 1 + chunkSize; wire.Len() != want { + t.Fatalf("wire len=%v want=%v", wire.Len(), want) + } + if got := wire.Bytes()[1+11+chunkSize]; got != 0xc0|byte(chunkIDOverStream) { + t.Fatalf("continuation header=%#x want=%#x (1-byte c3 form)", got, 0xc0|byte(chunkIDOverStream)) + } + + r := NewProtocol(&wire).(*protocol) + r.input.opt.chunkSize = chunkSize + got, err := r.ReadMessage(ctx) + if err != nil { + t.Fatalf("ReadMessage err=%v", err) + } + if got.Timestamp() != 1000 || !bytes.Equal(got.Payload(), payload) { + t.Fatalf("ts=%v payloadLen=%v", got.Timestamp(), len(got.Payload())) + } + }) + + // read-multichunk-handbuilt: read a 300-byte video message split across 3 chunks (c0 + two c3 + // continuations) at chunkSize=128, against a hand-built wire layout rather than this package's + // own writer output — directly mirroring the C++ ProtocolRecvVMessage2Trunk reassembly test. + // Proves readMessagePayload accumulates across multiple c3 continuations into one message. + t.Run("read-multichunk-handbuilt", func(t *testing.T) { + const chunkSize = 128 + payload := make([]byte, 300) + for i := range payload { + payload[i] = byte(i) + } + + var in bytes.Buffer + // fmt0 cid=3, ts=0, len=300 (0x00012c), video, stream=0, then payload[0:128]. + in.Write([]byte{0x03, 0x00, 0x00, 0x00, 0x00, 0x01, 0x2c, byte(MessageTypeVideo), 0x00, 0x00, 0x00, 0x00}) + in.Write(payload[0:chunkSize]) + in.Write([]byte{0xc3}) // fmt3 continuation, cid=3. + in.Write(payload[chunkSize : 2*chunkSize]) + in.Write([]byte{0xc3}) // fmt3 continuation, cid=3. + in.Write(payload[2*chunkSize:]) + + r := NewProtocol(&in).(*protocol) + r.input.opt.chunkSize = chunkSize + got, err := r.ReadMessage(ctx) + if err != nil { + t.Fatalf("ReadMessage err=%v", err) + } + if got.MessageType() != MessageTypeVideo || !bytes.Equal(got.Payload(), payload) { + t.Fatalf("type=%v payloadLen=%v", got.MessageType(), len(got.Payload())) + } + }) + + // read-large-roundtrip: a 5000-byte payload at chunkSize=128 spans 40 chunks (39 c3 + // continuations). The exact wire length (12B c0 + 5000 payload + 39 c3 bytes = 5051) proves the + // writer chunked into exactly 40 chunks with no empty trailing chunk, and the reader reassembles + // the full 5000 bytes back. This is the "many c3 continuations + large multi-chunk read" case + // the existing 5000-byte write test never read back. + t.Run("read-large-roundtrip", func(t *testing.T) { + const chunkSize = 128 + payload := make([]byte, 5000) + for i := range payload { + payload[i] = byte(i % 251) + } + // 5000 bytes: first chunk carries 128, the remaining 4872 take ceil(4872/128)=39 c3 chunks. + const wantContinuations = 39 + + var wire bytes.Buffer + w := NewProtocol(&wire).(*protocol) + w.output.opt.chunkSize = chunkSize + m := NewStreamMessage(1).asMessage() + m.messageHeader.MessageType = MessageTypeVideo + m.messageHeader.Timestamp = 12345 + m.payload = payload + if err := w.WriteMessage(ctx, m); err != nil { + t.Fatalf("WriteMessage err=%v", err) + } + if want := 12 + len(payload) + wantContinuations; wire.Len() != want { + t.Fatalf("wire len=%v want=%v", wire.Len(), want) + } + + r := NewProtocol(&wire).(*protocol) + r.input.opt.chunkSize = chunkSize + got, err := r.ReadMessage(ctx) + if err != nil { + t.Fatalf("ReadMessage err=%v", err) + } + if got.Timestamp() != 12345 || !bytes.Equal(got.Payload(), payload) { + t.Fatalf("ts=%v payloadLen=%v", got.Timestamp(), len(got.Payload())) + } + }) + + // max-3byte-length-parse: a fmt0 header declaring payloadLength = 0xffffff (the 3-byte field + // maxed out) must decode to exactly 0xffffff. Drive the header parse directly — reassembling + // 16MiB of payload is pointless, and the DoS/truncation behavior at this length is covered by + // TestPacketUnmarshalAdversarialInputs (P8). A shift/mask regression in the length decode would + // corrupt the value, which is what this pins. + t.Run("max-3byte-length-parse", func(t *testing.T) { + // fmt0 cid=5, ts=0 (so no extended timestamp), len=0xffffff, video, stream=1. + in := bytes.NewBuffer([]byte{0x05, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, byte(MessageTypeVideo), 0x01, 0x00, 0x00, 0x00}) + p := NewProtocol(in).(*protocol) + format, cid, err := p.readBasicHeader(ctx) + if err != nil { + t.Fatalf("readBasicHeader err=%v", err) + } + // Mirror ReadMessage's chunkStream setup. + chunk := newChunkStream() + p.input.chunks[cid] = chunk + chunk.header.betterCid = cid + if err := p.readMessageHeader(ctx, chunk, format); err != nil { + t.Fatalf("readMessageHeader err=%v", err) + } + if chunk.message.payloadLength != 0xffffff { + t.Fatalf("payloadLength=%#x want=0xffffff", chunk.message.payloadLength) + } + }) +} + +// TestGoldenWireBytes covers P6: golden wire-byte regression for the chunk headers and control +// packets. The existing TestWriteMessageHeadersChunkingAndErrors pins golden bytes for one +// extended-timestamp video message only; this locks the remaining wire shapes a refactor could +// silently change — both forms of the C0 and C3 headers, and the full on-wire framing of every +// control packet. +// +// C++ reference (send/golden): +// +// srs_utest_manual_protocol.cpp :: TEST(ProtocolStackTest, +// ProtocolSendSrsSetChunkSizePacket / ProtocolSendSrsSetWindowAckSizePacket / +// ProtocolSendSrsSetPeerBandwidthPacket / ProtocolSendSrsUserControlPacket) +// +// The control-packet payload bytes below match those C++ goldens exactly; the only difference is +// the chunk basic-header byte — Go frames protocol-control packets on cid=2 (-> 0x02) where the +// C++ goldens show 0x03 — so the payload portion is the cross-implementation wire-format invariant. +func TestGoldenWireBytes(t *testing.T) { + ctx := context.Background() + + // C0/C3 chunk headers, both timestamp forms. A message on cid=5, video, stream=7, payload + // length 5. With ts < 0xffffff the C0 header is 12 bytes carrying the timestamp inline; with + // ts >= 0xffffff the 3-byte field saturates to 0xffffff and a 4-byte extended timestamp is + // appended (16 bytes total). The C3 header is the 1-byte continuation form normally, but + // inherits the same extended-timestamp quirk (Adobe always re-sends it), so it grows to 5 + // bytes when ts >= 0xffffff. + t.Run("chunk-headers", func(t *testing.T) { + // MessageType and Timestamp are method names on *message, shadowing the promoted + // messageHeader fields, so set those two through the embedded struct explicitly. + shortTs := &message{} + shortTs.betterCid = chunkIDOverStream // 5 + shortTs.messageHeader.MessageType = MessageTypeVideo + shortTs.streamID = 7 + shortTs.payloadLength = 5 + shortTs.messageHeader.Timestamp = 0x0a + + extTs := &message{} + extTs.betterCid = chunkIDOverStream + extTs.messageHeader.MessageType = MessageTypeVideo + extTs.streamID = 7 + extTs.payloadLength = 5 + extTs.messageHeader.Timestamp = extendedTimestamp + 9 // 0x01000008, >= 0xffffff + + cases := []struct { + name string + gen func() ([]byte, error) + want []byte + }{ + // basic(0x05) | ts(00 00 0a) | len(00 00 05) | type(09) | streamID LE(07 00 00 00) + {"c0-short-ts", shortTs.generateC0Header, []byte{0x05, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x05, byte(MessageTypeVideo), 0x07, 0x00, 0x00, 0x00}}, + // ts field saturates to ff ff ff; ext-ts(01 00 00 08) appended after streamID. + {"c0-ext-ts", extTs.generateC0Header, []byte{0x05, 0xff, 0xff, 0xff, 0x00, 0x00, 0x05, byte(MessageTypeVideo), 0x07, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x08}}, + // 1-byte continuation: 0xc0 | cid. + {"c3-short-ts", shortTs.generateC3Header, []byte{0xc5}}, + // continuation + re-sent 4-byte ext-ts. + {"c3-ext-ts", extTs.generateC3Header, []byte{0xc5, 0x01, 0x00, 0x00, 0x08}}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + got, err := c.gen() + if err != nil { + t.Fatalf("gen err=%v", err) + } + if !bytes.Equal(got, c.want) { + t.Fatalf("got=%x want=%x", got, c.want) + } + }) + } + }) + + // Control packets, full on-wire framing via WritePacket. WritePacket frames each control packet + // on cid=2 (chunkIDProtocolControl) with ts=0 and streamID=0, so the wire is the 12-byte + // short-ts C0 header followed by the marshaled payload (all shorter than the default chunk size, + // hence a single chunk). The payload values mirror the C++ send goldens. + t.Run("control-packets", func(t *testing.T) { + cases := []struct { + name string + pkt Packet + want []byte + }{ + // SetChunkSize 1024=0x00000400, type 0x01, len 4. + { + "set-chunk-size", + &SetChunkSize{ChunkSize: 1024}, + []byte{0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00}, + }, + // WindowAcknowledgementSize 102400=0x00019000, type 0x05, len 4. + { + "window-ack-size", + &WindowAcknowledgementSize{AckSize: 102400}, + []byte{0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x90, 0x00}, + }, + // SetPeerBandwidth 1024=0x00000400 + limit soft(0x01), type 0x06, len 5. + { + "set-peer-bandwidth", + &SetPeerBandwidth{Bandwidth: 1024, LimitType: LimitTypeSoft}, + []byte{0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x01}, + }, + // UserControl SetBufferLength: event-type 0x0003, event-data 0x00000001, + // extra-data 0x00000010; type 0x04, len 10=0x0a. + { + "user-control-set-buffer-length", + &UserControl{EventType: EventTypeSetBufferLength, EventData: 0x01, ExtraData: 0x10}, + []byte{0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0a, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x10}, + }, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + var wire bytes.Buffer + p := NewProtocol(&wire).(*protocol) + if err := p.WritePacket(ctx, c.pkt, 0); err != nil { + t.Fatalf("WritePacket err=%v", err) + } + if !bytes.Equal(wire.Bytes(), c.want) { + t.Fatalf("got=%x want=%x", wire.Bytes(), c.want) + } + }) + } + }) + + // UserControl payload, all three Size() branches. The marshaler has three event-data shapes: a + // normal 4-byte event-data (e.g. PingRequest), an 8-byte event-data with the extra + // buffer-length word (SetBufferLength), and the special 1-byte event-data for FmsEvent0 + // (0x001a). Pin the raw payload bytes for each. + t.Run("user-control-event-forms", func(t *testing.T) { + cases := []struct { + name string + pkt *UserControl + want []byte + }{ + // event-type 0x0006, event-data 0x12345678 (4 bytes). + {"ping-request-4byte", &UserControl{EventType: EventTypePingRequest, EventData: 0x12345678}, []byte{0x00, 0x06, 0x12, 0x34, 0x56, 0x78}}, + // event-type 0x0003, event-data 0x00000001, extra 0x000005dc (8 bytes total). + {"set-buffer-length-8byte", &UserControl{EventType: EventTypeSetBufferLength, EventData: 1, ExtraData: 1500}, []byte{0x00, 0x03, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x05, 0xdc}}, + // event-type 0x001a, event-data 0x01 (1 byte). + {"fms-event0-1byte", &UserControl{EventType: EventTypeFmsEvent0, EventData: 0x01}, []byte{0x00, 0x1a, 0x01}}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + got, err := c.pkt.MarshalBinary() + if err != nil { + t.Fatalf("MarshalBinary err=%v", err) + } + if !bytes.Equal(got, c.want) { + t.Fatalf("got=%x want=%x", got, c.want) + } + }) + } + }) +} + +// === P7: Fuzz targets === +// +// Three Go native fuzz targets covering the untrusted-input parsers in this package. +// FuzzReadMessage — (*protocol).ReadMessage on arbitrary wire bytes. +// FuzzDecodeMessage — (*protocol).DecodeMessage on arbitrary (MessageType, payload). +// FuzzPacketUnmarshal — *Packet.UnmarshalBinary on arbitrary bytes across every Packet type. +// +// Each target's contract is "no panic". Termination is guaranteed: every fuzz body reads +// from a finite bytes.Buffer and caps input size to keep iterations cheap. Real OOM / +// resource-exhaustion surfaces (e.g. an attacker-controlled SetChunkSize followed by a +// large payload length forcing a multi-MB make) are intentionally NOT pre-guarded here — +// fuzzing is how P8's adversarial cases get discovered. +// +// Seeds come from the existing happy-path tests so the fuzzer starts at valid wire bytes +// and explores the nearby malformed space. + +// fuzzInputCap bounds the input fuzz can feed each iteration. The cap keeps single +// iterations under a millisecond on a laptop and stops the corpus from growing +// arbitrarily — it is not a security boundary. +const fuzzInputCap = 8 * 1024 + +// FuzzReadMessage drives the full chunk-stream reader against arbitrary bytes. The +// target asserts no panic across readBasicHeader (1/2/3-byte cid), readMessageHeader +// (every fmt + ext-ts), readMessagePayload (chunked reassembly), and onMessageArrivated +// (SetChunkSize side effect on subsequent reads). +func FuzzReadMessage(f *testing.F) { + // Seed 1: a single fmt0 audio message on cid=5, ts=10, len=3. + f.Add([]byte{0x05, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x03, byte(MessageTypeAudio), 0x01, 0x00, 0x00, 0x00, 1, 2, 3}) + + // Seed 2: fmt0 -> fmt1 -> fmt2 -> fmt3 sequence on cid=5 + // (lifted from TestReadMessageHeadersPayloadsAndChunks). + { + var s bytes.Buffer + s.Write([]byte{0x05, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x03, byte(MessageTypeAudio), 0x01, 0x00, 0x00, 0x00, 1, 2, 3}) + s.Write([]byte{0x45, 0x00, 0x00, 0x05, 0x00, 0x00, 0x02, byte(MessageTypeVideo), 4, 5}) + s.Write([]byte{0x85, 0x00, 0x00, 0x07, 6, 7}) + s.Write([]byte{0xc5, 8, 9}) + f.Add(s.Bytes()) + } + + // Seed 3: extended-timestamp fmt0 with payload split across chunks at the default + // chunk size. Exercises the ext-ts read + accumulate path. + { + var s bytes.Buffer + s.Write([]byte{0x05, 0xff, 0xff, 0xff, 0x00, 0x00, 0x05, byte(MessageTypeVideo), 0x01, 0x00, 0x00, 0x00}) + binary.Write(&s, binary.BigEndian, uint32(42)) + s.Write([]byte{1, 2, 3, 4, 5}) + f.Add(s.Bytes()) + } + + // Seed 4: 2-byte cid header (cid=74) wrapping a complete fmt0 message. Exercises + // the 2-byte readBasicHeader branch end-to-end. + f.Add([]byte{0x00, 0x0a, 0x00, 0x00, 0x05, 0x00, 0x00, 0x01, byte(MessageTypeVideo), 0x01, 0x00, 0x00, 0x00, 0xAA}) + + // Seed 5: SetChunkSize=128 followed by an audio message on cid=5. Exercises the + // onMessageArrivated -> input.chunkSize update path. + { + var s bytes.Buffer + s.Write([]byte{0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, byte(MessageTypeSetChunkSize), 0x00, 0x00, 0x00, 0x00}) + binary.Write(&s, binary.BigEndian, uint32(128)) + s.Write([]byte{0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, byte(MessageTypeAudio), 0x01, 0x00, 0x00, 0x00, 0xAA}) + f.Add(s.Bytes()) + } + + f.Fuzz(func(t *testing.T, data []byte) { + if len(data) > fuzzInputCap { + return + } + ctx := context.Background() + p := NewProtocol(bytes.NewBuffer(append([]byte(nil), data...))) + // bytes.Buffer EOFs deterministically once drained, so a small cap on the + // number of messages we accept here is enough to bound the iteration. + for range 16 { + if _, err := p.ReadMessage(ctx); err != nil { + return + } + } + }) +} + +// FuzzDecodeMessage drives DecodeMessage with an arbitrary (MessageType, payload) pair. +// It covers the SetChunkSize / WindowAcknowledgementSize / SetPeerBandwidth / UserControl +// branches and, for AMF0/AMF3 command/data types, the parseAMFObject dispatch into every +// concrete *Packet's UnmarshalBinary. +func FuzzDecodeMessage(f *testing.F) { + // Seed control-message payloads at their exact required sizes. + f.Add(uint8(MessageTypeSetChunkSize), []byte{0, 0, 0, 128}) + f.Add(uint8(MessageTypeWindowAcknowledgementSize), []byte{0, 0, 0x10, 0}) + f.Add(uint8(MessageTypeSetPeerBandwidth), []byte{0, 0, 0x10, 0, byte(LimitTypeDynamic)}) + // UserControl PingRequest: 2B event-type + 4B data. + f.Add(uint8(MessageTypeUserControl), []byte{0x00, byte(EventTypePingRequest), 0x00, 0x00, 0x00, 0x01}) + + // Seed an AMF0 command with a well-formed publish packet. Also seed the AMF3 form, + // which differs only by a leading byte that DecodeMessage strips. + pub := NewPublishPacket() + pub.TransactionID = 0 + pub.StreamName = NewAmf0String("s") + pubBytes, err := pub.MarshalBinary() + if err != nil { + f.Fatalf("seed marshal publish: %v", err) + } + f.Add(uint8(MessageTypeAMF0Command), pubBytes) + f.Add(uint8(MessageTypeAMF3Command), append([]byte{0}, pubBytes...)) + + f.Fuzz(func(t *testing.T, mtype uint8, payload []byte) { + if len(payload) > fuzzInputCap { + return + } + p := NewProtocol(&bytes.Buffer{}) + m := &message{payload: payload} + m.messageHeader.MessageType = MessageType(mtype) + _, _ = p.DecodeMessage(m) + }) +} + +// FuzzPacketUnmarshal drives every Packet's UnmarshalBinary against arbitrary bytes. +// One target, dispatched by a kind discriminator, so the fuzzer can share a corpus and +// mutate the kind alongside the bytes. +func FuzzPacketUnmarshal(f *testing.F) { + // Build a seed per packet type from its own round-trippable bytes. + type seed struct { + kind uint8 + pkt Packet + } + connRes := NewConnectAppResPacket(7) + connRes.Args.Set("data", NewAmf0EcmaArray().Set("srs_id", NewAmf0String("sid"))) + call := NewCallPacket() + call.CommandName = commandOnStatus + call.TransactionID = 0 + call.CommandObject = NewAmf0Null() + pub := NewPublishPacket() + pub.TransactionID = 0 + pub.StreamName = NewAmf0String("s") + play := NewPlayPacket() + play.TransactionID = 0 + play.StreamName = NewAmf0String("s") + seeds := []seed{ + {0, NewConnectAppPacket()}, + {1, connRes}, + {2, call}, + {3, NewCreateStreamPacket()}, + {4, NewCreateStreamResPacket(2)}, + {5, pub}, + {6, play}, + {7, &SetChunkSize{ChunkSize: 128}}, + {8, &WindowAcknowledgementSize{AckSize: 2500000}}, + {9, &SetPeerBandwidth{Bandwidth: 2500000, LimitType: LimitTypeDynamic}}, + {10, &UserControl{EventType: EventTypePingRequest, EventData: 1}}, + } + for _, s := range seeds { + b, err := s.pkt.MarshalBinary() + if err != nil { + f.Fatalf("seed marshal %T: %v", s.pkt, err) + } + f.Add(s.kind, b) + } + + f.Fuzz(func(t *testing.T, kind uint8, data []byte) { + if len(data) > fuzzInputCap { + return + } + var pkt Packet + switch kind % 11 { + case 0: + pkt = NewConnectAppPacket() + case 1: + pkt = NewConnectAppResPacket(0) + case 2: + pkt = NewCallPacket() + case 3: + pkt = NewCreateStreamPacket() + case 4: + pkt = NewCreateStreamResPacket(0) + case 5: + pkt = NewPublishPacket() + case 6: + pkt = NewPlayPacket() + case 7: + pkt = NewSetChunkSize() + case 8: + pkt = NewWindowAcknowledgementSize() + case 9: + pkt = NewSetPeerBandwidth() + case 10: + pkt = NewUserControl() + } + _ = pkt.UnmarshalBinary(data) + }) +} + +// TestPacketUnmarshalAdversarialInputs covers P8: malformed and truncated wire input +// must never panic the parser, only error. The P7 fuzzers found a panic class where a +// New*Packet constructor pre-set an optional AMF0 field (variantCallPacket.CommandObject +// or CallPacket.Args), and Size() then counted that phantom default even when the wire +// was exhausted before it — so the caller's p = p[Size():] advance sliced out of range +// (rtmp.go:1512, "slice bounds out of range"). These cases lock in the fix; the two +// minimized fuzz inputs are also committed under testdata/fuzz as regression corpus. +func TestPacketUnmarshalAdversarialInputs(t *testing.T) { + // safeUnmarshal runs UnmarshalBinary and turns any panic into an immediate failure. + safeUnmarshal := func(t *testing.T, pkt Packet, data []byte) (err error) { + defer func() { + if r := recover(); r != nil { + t.Fatalf("panic on %T with %x: %v", pkt, data, r) + } + }() + return pkt.UnmarshalBinary(data) + } + + // The two inputs the P7 fuzzers minimized to. Both are a "publish" command name + a + // number transaction id with nothing after, so the optional command object is absent + // and must not be counted by Size(). They previously panicked; expect a clean error. + t.Run("fuzz-crashers", func(t *testing.T) { + // FuzzPacketUnmarshal/2b0534f8182fac96: direct PublishPacket.UnmarshalBinary. + if err := safeUnmarshal(t, NewPublishPacket(), []byte("\x02\x00\x00\x0000000000")); err == nil { + t.Fatalf("truncated publish: want error, got nil") + } + // FuzzDecodeMessage/20ed1884f5b4f009: the AMF3 form reaches the same packet via + // DecodeMessage, which strips the leading AMF3 byte ('0'=0x30) before dispatching. + func() { + defer func() { + if r := recover(); r != nil { + t.Fatalf("DecodeMessage panic: %v", r) + } + }() + p := NewProtocol(&bytes.Buffer{}) + m := &message{payload: []byte("0\x02\x00\apublish\x0000000000")} + m.messageHeader.MessageType = MessageTypeAMF3Command + if _, err := p.DecodeMessage(m); err == nil { + t.Fatalf("truncated AMF3 publish: want error, got nil") + } + }() + }) + + // For every variantCallPacket-derived packet, marshal a valid instance then feed every + // truncation of its bytes back to a fresh packet. No prefix may panic, and the + // full-length bytes must still round-trip. + t.Run("truncations", func(t *testing.T) { + call := NewCallPacket() + call.CommandName = commandOnStatus + call.TransactionID = 0 + call.CommandObject = NewAmf0Null() + pub := NewPublishPacket() + pub.TransactionID = 0 + pub.StreamName = NewAmf0String("s") + play := NewPlayPacket() + play.TransactionID = 0 + play.StreamName = NewAmf0String("s") + cases := []struct { + name string + full Packet + fresh func() Packet + }{ + {"call", call, func() Packet { return NewCallPacket() }}, + {"publish", pub, func() Packet { return NewPublishPacket() }}, + {"play", play, func() Packet { return NewPlayPacket() }}, + {"createStreamRes", NewCreateStreamResPacket(2), func() Packet { return NewCreateStreamResPacket(0) }}, + } + for _, c := range cases { + b, err := c.full.MarshalBinary() + if err != nil { + t.Fatalf("%v marshal: %v", c.name, err) + } + for n := 0; n <= len(b); n++ { + err := safeUnmarshal(t, c.fresh(), b[:n]) + if n == len(b) && err != nil { + t.Fatalf("%v full unmarshal: %v", c.name, err) + } + } + } + }) + + // An oversized declared message length with a truncated stream must error via the + // incremental chunk read (readMessagePayload caps each read at chunkSize and lets + // io.ReadFull fail), not allocate ~16MB up front or hang. The header declares + // payloadLength = 0xffffff but only four payload bytes follow. + t.Run("oversized-length-truncated", func(t *testing.T) { + var in bytes.Buffer + in.Write([]byte{0x05, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, byte(MessageTypeVideo), 0x01, 0x00, 0x00, 0x00}) + in.Write([]byte{1, 2, 3, 4}) + p := NewProtocol(&in) + if _, err := p.ReadMessage(context.Background()); err == nil { + t.Fatalf("oversized truncated message: want error, got nil") + } + }) +} + +// P9 — concurrency / race on the transaction map. +// +// WritePacket registers a tid -> request-name entry under input.ltransactions +// (via onPacketWriten); parseAMFObject, reached from DecodeMessage when a +// _result/_error arrives, reads and deletes that entry under the same lock. +// This test hammers both paths from two goroutines over an overlapping tid +// space so the writes and reads/deletes genuinely interleave on +// input.transactions. Run with -race to validate the locking; even without it, +// the Go runtime panics on unsynchronized concurrent map access, so a dropped +// lock in either path fails the test. +// +// The shared state is only the transactions map: WritePacket touches just the +// writer (io.Discard here) and DecodeMessage operates on the Message it is +// handed, never the reader. Read-side "No matched request" errors are expected +// and tolerated — a tid may not be registered yet or may already be consumed; +// the assertion is no race and no panic, not that every lookup hits. +// +// No C++ reference: the C++ ProtocolStackTest suite has no concurrency test. +// New coverage. +func TestProtocolTransactionMapConcurrency(t *testing.T) { + const ( + keys = 16 // tid space both goroutines cycle through (overlap forces hits) + iterations = 4000 // per goroutine + ) + + // WritePacket only writes to v.w; DecodeMessage only reads the Message it is + // given. io.Discard keeps the single-writer goroutine from growing a buffer. + rw := struct { + io.Reader + io.Writer + }{strings.NewReader(""), io.Discard} + p := NewProtocol(rw).(*protocol) + + // Pre-build the _result bytes per tid so the read loop exercises only the + // map access in parseAMFObject, not AMF marshaling. releaseStream is one of + // the request names parseAMFObject resolves a _result against, so a hit + // returns a CallPacket and deletes the entry. + resultBytes := make([][]byte, keys+1) + for tid := 1; tid <= keys; tid++ { + res := NewCallPacket() + res.CommandName = commandResult + res.TransactionID = amf0Number(tid) + res.CommandObject = NewAmf0Null() + resultBytes[tid] = mustPacketBytes(t, res) + } + + ctx := context.Background() + var wg sync.WaitGroup + wg.Add(2) + + // Writer: registers tid -> releaseStream under the lock, cycling the tid space. + go func() { + defer wg.Done() + for i := 0; i < iterations; i++ { + call := NewCallPacket() + call.CommandName = commandReleaseStream + call.TransactionID = amf0Number(i%keys + 1) + call.CommandObject = NewAmf0Null() + if err := p.WritePacket(ctx, call, 0); err != nil { + t.Errorf("WritePacket err=%v", err) + return + } + } + }() + + // Reader: decodes _result messages whose tids overlap the writer's range. + // Hits delete the entry; misses return "No matched request". Both take the lock. + go func() { + defer wg.Done() + for i := 0; i < iterations; i++ { + msg := &message{payload: resultBytes[i%keys+1]} + msg.messageHeader.MessageType = MessageTypeAMF0Command + if _, err := p.DecodeMessage(msg); err != nil && + !strings.Contains(err.Error(), "No matched request") { + t.Errorf("DecodeMessage err=%v", err) + return + } + } + }() + + wg.Wait() +} diff --git a/internal/version/version.go b/internal/version/version.go index 1f271694c..2c570a9b1 100644 --- a/internal/version/version.go +++ b/internal/version/version.go @@ -15,7 +15,7 @@ func VersionMinor() int { } func VersionRevision() int { - return 2 + return 3 } func Version() string { diff --git a/skills/srs-develop/SKILL.md b/skills/srs-develop/SKILL.md index a252e9345..968f490c5 100644 --- a/skills/srs-develop/SKILL.md +++ b/skills/srs-develop/SKILL.md @@ -143,7 +143,7 @@ Only after the user confirms the routing do you proceed to Step 2. ``` bash scripts/proxy-utest.sh --coverage ``` -4. Run the proxy E2E tests: +4. Run **all** of the proxy E2E tests below — every one, not just the first. Run them one at a time (they bind fixed ports, so they cannot run in parallel), and do not stop early: a later test can fail even when the earlier ones pass. - Single-origin RTMP proxy test (starts proxy + one SRS origin, publishes RTMP, verifies playback): ``` bash scripts/proxy-e2e-test.sh diff --git a/trunk/doc/CHANGELOG.md b/trunk/doc/CHANGELOG.md index b10cb60e4..d95eca246 100644 --- a/trunk/doc/CHANGELOG.md +++ b/trunk/doc/CHANGELOG.md @@ -7,6 +7,7 @@ The changelog for SRS. ## SRS 8.0 Changelog +* v8.0, 2026-05-28, Merge [#4680](https://github.com/ossrs/srs/pull/4680): RTMP: Fix chunk timestamp/basic-header decoding and harden packet unmarshal. v8.0.3 (#4680) * v8.0, 2026-05-19, Merge [#4678](https://github.com/ossrs/srs/pull/4678): Edge: Fix HTTP-FLV 404 and RTMP late-join missing sequence headers. v8.0.2 (#4678) * v8.0, 2026-05-17, Merge [#4676](https://github.com/ossrs/srs/pull/4676): Proxy: Fix RTC/SRT reader goroutine leak; unwrap legacy WHEP JSON envelope; add WHEP pprof guide. v8.0.1 (#4676) * v8.0, 2026-05-17, Init SRS 8.0, code Free. v8.0.0 diff --git a/trunk/src/core/srs_core_version8.hpp b/trunk/src/core/srs_core_version8.hpp index ac376c261..867dcbd9c 100644 --- a/trunk/src/core/srs_core_version8.hpp +++ b/trunk/src/core/srs_core_version8.hpp @@ -9,6 +9,6 @@ #define VERSION_MAJOR 8 #define VERSION_MINOR 0 -#define VERSION_REVISION 2 +#define VERSION_REVISION 3 #endif