Codex: Add RTMP protocol tests
Add coverage for RTMP handshakes, chunk headers, message IO, packet decoding, transactions, and error branches. Rename examples so Go documents the exported constructors they exercise. Co-authored-by: chatgpt-codex-connector[bot] <199175422+chatgpt-codex-connector[bot]@users.noreply.github.com>
This commit is contained in:
parent
c2f2dc768a
commit
88a26804dd
|
|
@ -65,7 +65,7 @@ func ExampleAmf0Object() {
|
|||
// is number: false
|
||||
}
|
||||
|
||||
func ExampleRTMPHandshake() {
|
||||
func ExampleNewHandshake() {
|
||||
client := rtmp.NewHandshake()
|
||||
server := rtmp.NewHandshake()
|
||||
|
||||
|
|
@ -136,7 +136,7 @@ func ExampleRTMPHandshake() {
|
|||
// server cached c1: true
|
||||
}
|
||||
|
||||
func ExampleRTMPProtocol() {
|
||||
func ExampleNewProtocol() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
|
|
|
|||
728
internal/rtmp/rtmp_test.go
Normal file
728
internal/rtmp/rtmp_test.go
Normal file
|
|
@ -0,0 +1,728 @@
|
|||
// Copyright (c) 2026 Winlin
|
||||
//
|
||||
// SPDX-License-Identifier: MIT
|
||||
package rtmp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type errWriter struct{}
|
||||
|
||||
func (errWriter) Write([]byte) (int, error) { return 0, io.ErrClosedPipe }
|
||||
|
||||
func TestHandshakeSimpleAndErrors(t *testing.T) {
|
||||
h := NewHandshake()
|
||||
var b bytes.Buffer
|
||||
if err := h.WriteC0S0(&b); err != nil {
|
||||
t.Fatalf("WriteC0S0 err=%v", err)
|
||||
}
|
||||
c0, err := h.ReadC0S0(&b)
|
||||
if err != nil || !bytes.Equal(c0, []byte{3}) {
|
||||
t.Fatalf("ReadC0S0=%v, err=%v", c0, err)
|
||||
}
|
||||
if err := h.WriteC0S0(errWriter{}); err == nil {
|
||||
t.Fatal("WriteC0S0 should fail")
|
||||
}
|
||||
if _, err := h.ReadC0S0(bytes.NewReader(nil)); err == nil {
|
||||
t.Fatal("ReadC0S0 should fail")
|
||||
}
|
||||
|
||||
b.Reset()
|
||||
if err := h.WriteC1S1(&b); err != nil {
|
||||
t.Fatalf("WriteC1S1 err=%v", err)
|
||||
}
|
||||
if b.Len() != 1536 {
|
||||
t.Fatalf("C1S1 len=%v", b.Len())
|
||||
}
|
||||
c1, err := h.ReadC1S1(&b)
|
||||
if err != nil || len(c1) != 1536 || !bytes.Equal(h.C1S1(), c1) {
|
||||
t.Fatalf("ReadC1S1 len=%v, cached=%v, err=%v", len(c1), bytes.Equal(h.C1S1(), c1), err)
|
||||
}
|
||||
if err := h.WriteC1S1(errWriter{}); err == nil {
|
||||
t.Fatal("WriteC1S1 should fail")
|
||||
}
|
||||
if _, err := h.ReadC1S1(bytes.NewReader(make([]byte, 1535))); err == nil {
|
||||
t.Fatal("ReadC1S1 should fail")
|
||||
}
|
||||
|
||||
b.Reset()
|
||||
if err := h.WriteC2S2(&b, c1); err != nil {
|
||||
t.Fatalf("WriteC2S2 err=%v", err)
|
||||
}
|
||||
c2, err := h.ReadC2S2(&b)
|
||||
if err != nil || !bytes.Equal(c2, c1) {
|
||||
t.Fatalf("ReadC2S2 match=%v, err=%v", bytes.Equal(c2, c1), err)
|
||||
}
|
||||
if err := h.WriteC2S2(errWriter{}, c1); err == nil {
|
||||
t.Fatal("WriteC2S2 should fail")
|
||||
}
|
||||
if _, err := h.ReadC2S2(bytes.NewReader(make([]byte, 1535))); err == nil {
|
||||
t.Fatal("ReadC2S2 should fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSettingsChunkStreamAndMessageConstructors(t *testing.T) {
|
||||
if s := newSettings(); s.chunkSize != defaultChunkSize {
|
||||
t.Fatalf("chunk size=%v", s.chunkSize)
|
||||
}
|
||||
if c := newChunkStream(); c == nil || c.count != 0 {
|
||||
t.Fatalf("chunk stream=%#v", c)
|
||||
}
|
||||
m := NewMessage().asMessage()
|
||||
m.messageHeader.MessageType = MessageTypeAudio
|
||||
m.messageHeader.Timestamp = 99
|
||||
m.payload = []byte{1, 2, 3}
|
||||
if m.MessageType() != MessageTypeAudio || m.Timestamp() != 99 || !bytes.Equal(m.Payload(), []byte{1, 2, 3}) || m.asMessage() != m {
|
||||
t.Fatalf("bad message accessors")
|
||||
}
|
||||
sm := NewStreamMessage(7).asMessage()
|
||||
if sm.streamID != 7 || sm.betterCid != chunkIDOverStream {
|
||||
t.Fatalf("stream message=%#v", sm.messageHeader)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBasicHeaderVariantsAndErrors(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
cases := []struct {
|
||||
name string
|
||||
data []byte
|
||||
fmt formatType
|
||||
cid chunkID
|
||||
}{
|
||||
{"one-byte", []byte{0x85}, formatType2, 5},
|
||||
{"two-byte", []byte{0x40, 0x0a}, formatType1, 74},
|
||||
{"three-byte-code-path", []byte{0xc1, 0x01, 0x02}, formatType3, 65},
|
||||
}
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := NewProtocol(bytes.NewBuffer(tt.data)).(*protocol)
|
||||
fmt, cid, err := p.readBasicHeader(ctx)
|
||||
if err != nil || fmt != tt.fmt || cid != tt.cid {
|
||||
t.Fatalf("fmt=%v cid=%v err=%v", fmt, cid, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
for _, data := range [][]byte{{}, {0x00}} {
|
||||
p := NewProtocol(bytes.NewBuffer(data)).(*protocol)
|
||||
if _, _, err := p.readBasicHeader(ctx); err == nil {
|
||||
t.Fatalf("readBasicHeader(%x) should fail", data)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadMessageHeadersPayloadsAndChunks(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
var in bytes.Buffer
|
||||
// fmt0 cid=5, timestamp=10, len=3, type audio, stream=1, payload 010203.
|
||||
in.Write([]byte{0x05, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x03, byte(MessageTypeAudio), 0x01, 0x00, 0x00, 0x00, 1, 2, 3})
|
||||
// fmt1 same cid, delta=5, len=2, type video, payload 0405.
|
||||
in.Write([]byte{0x45, 0x00, 0x00, 0x05, 0x00, 0x00, 0x02, byte(MessageTypeVideo), 4, 5})
|
||||
// fmt2 same cid, delta=7, reuses len/type/stream, payload 0607.
|
||||
in.Write([]byte{0x85, 0x00, 0x00, 0x07, 6, 7})
|
||||
// fmt3 same cid, reuses delta and advances timestamp, payload 0809.
|
||||
in.Write([]byte{0xc5, 8, 9})
|
||||
|
||||
p := NewProtocol(&in).(*protocol)
|
||||
for i, want := range []struct {
|
||||
typ MessageType
|
||||
ts uint64
|
||||
pl []byte
|
||||
}{
|
||||
{MessageTypeAudio, 10, []byte{1, 2, 3}},
|
||||
{MessageTypeVideo, 15, []byte{4, 5}},
|
||||
{MessageTypeVideo, 22, []byte{6, 7}},
|
||||
{MessageTypeVideo, 29, []byte{8, 9}},
|
||||
} {
|
||||
m, err := p.ReadMessage(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadMessage #%v err=%v", i, err)
|
||||
}
|
||||
if m.MessageType() != want.typ || m.Timestamp() != want.ts || !bytes.Equal(m.Payload(), want.pl) {
|
||||
t.Fatalf("message #%v type=%v ts=%v payload=%x", i, m.MessageType(), m.Timestamp(), m.Payload())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadMessageExtendedTimestampAndChunking(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
var in bytes.Buffer
|
||||
payload := []byte{1, 2, 3, 4, 5}
|
||||
// fmt0 cid=5, normal timestamp=0xffffff, extended timestamp has high bit set and should be masked.
|
||||
in.Write([]byte{0x05, 0xff, 0xff, 0xff, 0x00, 0x00, byte(len(payload)), byte(MessageTypeVideo), 0x01, 0x00, 0x00, 0x00})
|
||||
binary.Write(&in, binary.BigEndian, uint32(0x8000002a))
|
||||
in.Write(payload[:2])
|
||||
// continuation chunk has fmt3 and extended timestamp too.
|
||||
in.Write([]byte{0xc5})
|
||||
binary.Write(&in, binary.BigEndian, uint32(0x8000002a))
|
||||
in.Write(payload[2:4])
|
||||
in.Write([]byte{0xc5})
|
||||
binary.Write(&in, binary.BigEndian, uint32(0x8000002a))
|
||||
in.Write(payload[4:])
|
||||
|
||||
p := NewProtocol(&in).(*protocol)
|
||||
p.input.opt.chunkSize = 2
|
||||
m, err := p.ReadMessage(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadMessage err=%v", err)
|
||||
}
|
||||
if m.Timestamp() != 42 || !bytes.Equal(m.Payload(), payload) {
|
||||
t.Fatalf("ts=%v payload=%x", m.Timestamp(), m.Payload())
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadMessageHeaderErrors(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
// Fresh non-zero chunk with fmt1 is rejected.
|
||||
p := NewProtocol(bytes.NewBuffer([]byte{0x45})).(*protocol)
|
||||
if _, err := p.ReadMessage(ctx); err == nil || !strings.Contains(err.Error(), "fresh chunk") {
|
||||
t.Fatalf("fresh fmt1 err=%v", err)
|
||||
}
|
||||
// Existing partial message cannot restart with fmt0.
|
||||
p = NewProtocol(bytes.NewBuffer([]byte{0x05, 0, 0, 0, 0, 0, 3, byte(MessageTypeAudio), 1, 0, 0, 0, 1, 0x05})).(*protocol)
|
||||
p.input.opt.chunkSize = 1
|
||||
if _, err := p.ReadMessage(ctx); err == nil || !strings.Contains(err.Error(), "exists chunk") {
|
||||
t.Fatalf("restart err=%v", err)
|
||||
}
|
||||
// Size change in a continuation header is rejected.
|
||||
p = NewProtocol(bytes.NewBuffer([]byte{0x05, 0, 0, 0, 0, 0, 3, byte(MessageTypeAudio), 1, 0, 0, 0, 1, 0x45, 0, 0, 1, 0, 0, 4, byte(MessageTypeAudio)})).(*protocol)
|
||||
p.input.opt.chunkSize = 1
|
||||
if _, err := p.ReadMessage(ctx); err == nil || !strings.Contains(err.Error(), "message size") {
|
||||
t.Fatalf("size change err=%v", err)
|
||||
}
|
||||
// Short payload and short extended timestamp.
|
||||
p = NewProtocol(bytes.NewBuffer([]byte{0x05, 0, 0, 0, 0, 0, 2, byte(MessageTypeAudio), 1, 0, 0, 0, 1})).(*protocol)
|
||||
if _, err := p.ReadMessage(ctx); err == nil || !strings.Contains(err.Error(), "read chunk") {
|
||||
t.Fatalf("payload err=%v", err)
|
||||
}
|
||||
p = NewProtocol(bytes.NewBuffer([]byte{0x05, 0xff, 0xff, 0xff, 0, 0, 0, byte(MessageTypeAudio), 1, 0, 0, 0, 1, 2})).(*protocol)
|
||||
if _, err := p.ReadMessage(ctx); err == nil || !strings.Contains(err.Error(), "ext-ts") {
|
||||
t.Fatalf("ext-ts err=%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteMessageHeadersChunkingAndErrors(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
var out bytes.Buffer
|
||||
p := NewProtocol(&out).(*protocol)
|
||||
p.output.opt.chunkSize = 2
|
||||
m := NewStreamMessage(7).asMessage()
|
||||
m.messageHeader.MessageType = MessageTypeVideo
|
||||
m.messageHeader.Timestamp = extendedTimestamp + 9
|
||||
m.payload = []byte{1, 2, 3, 4, 5}
|
||||
if err := p.WriteMessage(ctx, m); err != nil {
|
||||
t.Fatalf("WriteMessage err=%v", err)
|
||||
}
|
||||
want := []byte{0x05, 0xff, 0xff, 0xff, 0, 0, 5, byte(MessageTypeVideo), 7, 0, 0, 0, 0x01, 0x00, 0x00, 0x08, 1, 2, 0xc5, 0x01, 0x00, 0x00, 0x08, 3, 4, 0xc5, 0x01, 0x00, 0x00, 0x08, 5}
|
||||
if !bytes.Equal(out.Bytes(), want) {
|
||||
t.Fatalf("written=%x want=%x", out.Bytes(), want)
|
||||
}
|
||||
if err := p.WriteMessage(ctx, (&message{})); err != nil {
|
||||
t.Fatalf("empty WriteMessage err=%v", err)
|
||||
}
|
||||
canceled, cancel := context.WithCancel(ctx)
|
||||
cancel()
|
||||
if err := p.WriteMessage(canceled, m); err != context.Canceled {
|
||||
t.Fatalf("canceled WriteMessage err=%v", err)
|
||||
}
|
||||
p = NewProtocol(struct {
|
||||
io.Reader
|
||||
io.Writer
|
||||
}{bytes.NewReader(nil), errWriter{}}).(*protocol)
|
||||
if err := p.WriteMessage(ctx, m); err == nil {
|
||||
t.Fatal("WriteMessage to bad writer should fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProtocolDecodeMessageAndControls(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
p := NewProtocol(&bytes.Buffer{}).(*protocol)
|
||||
if _, err := p.DecodeMessage((&message{})); err == nil || !strings.Contains(err.Error(), "Empty packet") {
|
||||
t.Fatalf("empty decode err=%v", err)
|
||||
}
|
||||
unknown := &message{}
|
||||
unknown.messageHeader.MessageType = MessageTypeAudio
|
||||
unknown.payload = []byte{1}
|
||||
if _, err := p.DecodeMessage(unknown); err == nil || !strings.Contains(err.Error(), "Unknown message") {
|
||||
t.Fatalf("unknown err=%v", err)
|
||||
}
|
||||
bad := &message{}
|
||||
bad.messageHeader.MessageType = MessageTypeSetChunkSize
|
||||
bad.payload = []byte{1, 2}
|
||||
if _, err := p.DecodeMessage(bad); err == nil || !strings.Contains(err.Error(), "Unmarshal") {
|
||||
t.Fatalf("bad control err=%v", err)
|
||||
}
|
||||
|
||||
for _, pkt := range []Packet{
|
||||
&SetChunkSize{ChunkSize: 4096},
|
||||
&WindowAcknowledgementSize{AckSize: 2500000},
|
||||
&SetPeerBandwidth{Bandwidth: 1000, LimitType: LimitTypeSoft},
|
||||
&UserControl{EventType: EventTypePingRequest, EventData: 123},
|
||||
} {
|
||||
data, err := pkt.MarshalBinary()
|
||||
if err != nil {
|
||||
t.Fatalf("marshal %T err=%v", pkt, err)
|
||||
}
|
||||
m := &message{payload: data}
|
||||
m.messageHeader.MessageType = pkt.Type()
|
||||
got, err := p.DecodeMessage(m)
|
||||
if err != nil {
|
||||
t.Fatalf("DecodeMessage %T err=%v", pkt, err)
|
||||
}
|
||||
if reflect.TypeOf(got) != reflect.TypeOf(pkt) {
|
||||
t.Fatalf("got %T want %T", got, pkt)
|
||||
}
|
||||
}
|
||||
|
||||
chunk := &SetChunkSize{ChunkSize: 3}
|
||||
m := &message{}
|
||||
m.messageHeader.MessageType = chunk.Type()
|
||||
m.payload, _ = chunk.MarshalBinary()
|
||||
if err := p.onMessageArrivated(m); err != nil || p.input.opt.chunkSize != 3 {
|
||||
t.Fatalf("onMessageArrivated err=%v chunk=%v", err, p.input.opt.chunkSize)
|
||||
}
|
||||
if err := p.onMessageArrivated(nil); err != nil {
|
||||
t.Fatalf("nil onMessageArrivated err=%v", err)
|
||||
}
|
||||
bad.Payload()[0] = 1
|
||||
if err := p.onMessageArrivated(bad); err == nil {
|
||||
t.Fatal("bad onMessageArrivated should fail")
|
||||
}
|
||||
if _, err := p.ExpectMessage(ctx); err == nil {
|
||||
t.Fatal("ExpectMessage on empty reader should fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProtocolPacketsAndTransactions(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
var wire bytes.Buffer
|
||||
writer := NewProtocol(&wire).(*protocol)
|
||||
connect := NewConnectAppPacket()
|
||||
connect.CommandObject.Set("tcUrl", NewAmf0String("rtmp://host/live"))
|
||||
if connect.Size() == 0 || connect.BetterCid() != chunkIDOverConnection || connect.Type() != MessageTypeAMF0Command || connect.TcUrl() == "" {
|
||||
t.Fatalf("connect metadata invalid")
|
||||
}
|
||||
if err := writer.WritePacket(ctx, connect, 0); err != nil {
|
||||
t.Fatalf("WritePacket connect err=%v", err)
|
||||
}
|
||||
if _, ok := writer.input.transactions[connect.TransactionID]; !ok {
|
||||
t.Fatal("connect transaction not tracked")
|
||||
}
|
||||
create := NewCreateStreamPacket()
|
||||
if err := writer.WritePacket(ctx, create, 0); err != nil {
|
||||
t.Fatalf("WritePacket create err=%v", err)
|
||||
}
|
||||
call := NewCallPacket()
|
||||
call.CommandName = commandReleaseStream
|
||||
call.TransactionID = 3
|
||||
call.CommandObject = NewAmf0Null()
|
||||
if err := writer.WritePacket(ctx, call, 0); err != nil {
|
||||
t.Fatalf("WritePacket call err=%v", err)
|
||||
}
|
||||
reader := NewProtocol(&wire)
|
||||
var gotConnect *ConnectAppPacket
|
||||
if _, err := ExpectPacket(ctx, reader, &gotConnect); err != nil || gotConnect.TcUrl() != "rtmp://host/live" {
|
||||
t.Fatalf("gotConnect=%v err=%v", gotConnect, err)
|
||||
}
|
||||
var gotCreate *CallPacket
|
||||
if _, err := ExpectPacket(ctx, reader, &gotCreate); err != nil || gotCreate.CommandName != commandCreateStream {
|
||||
t.Fatalf("gotCreate=%v err=%v", gotCreate, err)
|
||||
}
|
||||
|
||||
decoder := NewProtocol(&bytes.Buffer{}).(*protocol)
|
||||
decoder.input.transactions[1] = commandConnect
|
||||
if pkt, err := decoder.parseAMFObject(mustPacketBytes(t, NewConnectAppResPacket(1))); err != nil || reflect.TypeOf(pkt) != reflect.TypeOf(NewConnectAppResPacket(1)) {
|
||||
t.Fatalf("connect res pkt=%T err=%v", pkt, err)
|
||||
}
|
||||
decoder.input.transactions[2] = commandCreateStream
|
||||
csr := NewCreateStreamResPacket(2)
|
||||
csr.SetStreamID(99)
|
||||
if pkt, err := decoder.parseAMFObject(mustPacketBytes(t, csr)); err != nil || reflect.TypeOf(pkt) != reflect.TypeOf(csr) {
|
||||
t.Fatalf("create res pkt=%T err=%v", pkt, err)
|
||||
}
|
||||
decoder.input.transactions[3] = commandReleaseStream
|
||||
res := NewCallPacket()
|
||||
res.CommandName = commandResult
|
||||
res.TransactionID = 3
|
||||
res.CommandObject = NewAmf0Null()
|
||||
if pkt, err := decoder.parseAMFObject(mustPacketBytes(t, res)); err != nil || reflect.TypeOf(pkt) != reflect.TypeOf(res) {
|
||||
t.Fatalf("call res pkt=%T err=%v", pkt, err)
|
||||
}
|
||||
for _, name := range []amf0String{commandPublish, commandPlay, commandOnStatus} {
|
||||
pkt := NewCallPacket()
|
||||
pkt.CommandName = name
|
||||
pkt.TransactionID = 0
|
||||
pkt.CommandObject = NewAmf0Null()
|
||||
if name == commandPublish {
|
||||
pub := NewPublishPacket()
|
||||
pub.TransactionID = 0
|
||||
pub.StreamName = NewAmf0String("s")
|
||||
pub.StreamType = NewAmf0String("live")
|
||||
if decoded, err := decoder.parseAMFObject(mustPacketBytes(t, pub)); err != nil || reflect.TypeOf(decoded) != reflect.TypeOf(pub) {
|
||||
t.Fatalf("publish decoded=%T err=%v", decoded, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if name == commandPlay {
|
||||
play := NewPlayPacket()
|
||||
play.TransactionID = 0
|
||||
play.StreamName = NewAmf0String("s")
|
||||
if decoded, err := decoder.parseAMFObject(mustPacketBytes(t, play)); err != nil || reflect.TypeOf(decoded) != reflect.TypeOf(play) {
|
||||
t.Fatalf("play decoded=%T err=%v", decoded, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if decoded, err := decoder.parseAMFObject(mustPacketBytes(t, pkt)); err != nil || reflect.TypeOf(decoded) != reflect.TypeOf(pkt) {
|
||||
t.Fatalf("call decoded=%T err=%v", decoded, err)
|
||||
}
|
||||
}
|
||||
decoder.input.transactions[9] = commandPause
|
||||
errPkt := NewCallPacket()
|
||||
errPkt.CommandName = commandError
|
||||
errPkt.TransactionID = 9
|
||||
errPkt.CommandObject = NewAmf0Null()
|
||||
if _, err := decoder.parseAMFObject(mustPacketBytes(t, errPkt)); err == nil || !strings.Contains(err.Error(), "No request") {
|
||||
t.Fatalf("unknown request err=%v", err)
|
||||
}
|
||||
if _, err := decoder.parseAMFObject(mustPacketBytes(t, errPkt)); err == nil || !strings.Contains(err.Error(), "No matched request") {
|
||||
t.Fatalf("missing transaction err=%v", err)
|
||||
}
|
||||
if _, err := decoder.parseAMFObject([]byte{byte(amf0MarkerString), 0, 8, 'c'}); err == nil {
|
||||
t.Fatal("bad AMF parse should fail")
|
||||
}
|
||||
|
||||
cctx, cancel := context.WithCancel(ctx)
|
||||
cancel()
|
||||
if err := writer.WritePacket(cctx, connect, 0); err == nil || !strings.Contains(err.Error(), context.Canceled.Error()) {
|
||||
t.Fatalf("WritePacket canceled err=%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeprecatedExpectPacketPanics(t *testing.T) {
|
||||
defer func() {
|
||||
if recover() == nil {
|
||||
t.Fatal("Expected panic")
|
||||
}
|
||||
}()
|
||||
NewProtocol(&bytes.Buffer{}).ExpectPacket(context.Background(), nil)
|
||||
}
|
||||
|
||||
func TestPacketRoundTripsAndErrors(t *testing.T) {
|
||||
packets := []Packet{
|
||||
NewConnectAppPacket(),
|
||||
NewConnectAppResPacket(7),
|
||||
NewCallPacket(),
|
||||
NewCreateStreamPacket(),
|
||||
func() Packet { p := NewCreateStreamResPacket(2); p.SetStreamID(1); return p }(),
|
||||
func() Packet {
|
||||
p := NewPublishPacket()
|
||||
p.TransactionID = 0
|
||||
p.StreamName = NewAmf0String("s")
|
||||
return p
|
||||
}(),
|
||||
func() Packet { p := NewPlayPacket(); p.TransactionID = 0; p.StreamName = NewAmf0String("s"); return p }(),
|
||||
&SetChunkSize{ChunkSize: 1},
|
||||
&WindowAcknowledgementSize{AckSize: 2},
|
||||
&SetPeerBandwidth{Bandwidth: 3, LimitType: LimitTypeDynamic},
|
||||
&UserControl{EventType: EventTypeFmsEvent0, EventData: 1},
|
||||
&UserControl{EventType: EventTypeSetBufferLength, EventData: 1, ExtraData: 2},
|
||||
}
|
||||
// Initialize the generic call packet so it is marshalable.
|
||||
packets[2].(*CallPacket).CommandName = commandOnStatus
|
||||
packets[2].(*CallPacket).TransactionID = 0
|
||||
packets[2].(*CallPacket).CommandObject = NewAmf0Null()
|
||||
packets[2].(*CallPacket).Args = NewAmf0Object().Set("code", NewAmf0String("ok"))
|
||||
packets[1].(*ConnectAppResPacket).Args.Set("data", NewAmf0EcmaArray().Set("srs_id", NewAmf0String("sid")))
|
||||
|
||||
for _, pkt := range packets {
|
||||
t.Run(reflect.TypeOf(pkt).String(), func(t *testing.T) {
|
||||
data, err := pkt.MarshalBinary()
|
||||
if err != nil {
|
||||
t.Fatalf("MarshalBinary err=%v", err)
|
||||
}
|
||||
if len(data) != pkt.Size() {
|
||||
t.Fatalf("len=%v Size=%v", len(data), pkt.Size())
|
||||
}
|
||||
fresh := reflect.New(reflect.TypeOf(pkt).Elem()).Interface().(Packet)
|
||||
switch v := fresh.(type) {
|
||||
case *ConnectAppPacket:
|
||||
*v = *NewConnectAppPacket()
|
||||
case *ConnectAppResPacket:
|
||||
*v = *NewConnectAppResPacket(0)
|
||||
case *CreateStreamPacket:
|
||||
*v = *NewCreateStreamPacket()
|
||||
case *CreateStreamResPacket:
|
||||
*v = *NewCreateStreamResPacket(0)
|
||||
case *PublishPacket:
|
||||
*v = *NewPublishPacket()
|
||||
case *PlayPacket:
|
||||
*v = *NewPlayPacket()
|
||||
}
|
||||
if err := fresh.UnmarshalBinary(data); err != nil {
|
||||
t.Fatalf("UnmarshalBinary err=%v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
if packets[1].(*ConnectAppResPacket).SrsID() != "sid" || packets[2].(*CallPacket).ArgsCode() != "ok" {
|
||||
t.Fatalf("packet helpers failed")
|
||||
}
|
||||
if NewConnectAppResPacket(1).SrsID() != "" || NewCallPacket().ArgsCode() != "" || NewConnectAppPacket().TcUrl() != "" {
|
||||
t.Fatalf("empty helpers failed")
|
||||
}
|
||||
|
||||
badConnect := NewConnectAppPacket()
|
||||
badConnect.CommandName = commandPlay
|
||||
if err := badConnect.UnmarshalBinary(mustPacketBytes(t, badConnect)); err == nil {
|
||||
t.Fatal("bad connect name should fail")
|
||||
}
|
||||
badConnect = NewConnectAppPacket()
|
||||
badConnect.TransactionID = 2
|
||||
if err := badConnect.UnmarshalBinary(mustPacketBytes(t, badConnect)); err == nil {
|
||||
t.Fatal("bad connect tid should fail")
|
||||
}
|
||||
badRes := NewConnectAppResPacket(1)
|
||||
badRes.CommandName = commandPlay
|
||||
if err := badRes.UnmarshalBinary(mustPacketBytes(t, badRes)); err == nil {
|
||||
t.Fatal("bad connect response name should fail")
|
||||
}
|
||||
for _, pkt := range []Packet{NewConnectAppPacket(), NewCallPacket(), NewCreateStreamResPacket(1), NewPublishPacket(), NewPlayPacket()} {
|
||||
if err := pkt.UnmarshalBinary([]byte{byte(amf0MarkerString)}); err == nil {
|
||||
t.Fatalf("%T short unmarshal should fail", pkt)
|
||||
}
|
||||
}
|
||||
for _, pkt := range []Packet{&SetChunkSize{}, &WindowAcknowledgementSize{}, &SetPeerBandwidth{}, &UserControl{}} {
|
||||
if err := pkt.UnmarshalBinary([]byte{0, 1}); err == nil {
|
||||
t.Fatalf("%T short unmarshal should fail", pkt)
|
||||
}
|
||||
}
|
||||
uc := &UserControl{}
|
||||
if err := uc.UnmarshalBinary([]byte{0, byte(EventTypeSetBufferLength), 1, 2, 3, 4, 5}); err == nil {
|
||||
t.Fatal("short set-buffer-length should fail")
|
||||
}
|
||||
}
|
||||
|
||||
func mustPacketBytes(t *testing.T, pkt Packet) []byte {
|
||||
t.Helper()
|
||||
data, err := pkt.MarshalBinary()
|
||||
if err != nil {
|
||||
t.Fatalf("MarshalBinary %T err=%v", pkt, err)
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
type failPacket struct{}
|
||||
|
||||
func (failPacket) Size() int { return 0 }
|
||||
func (failPacket) UnmarshalBinary([]byte) error { return io.ErrUnexpectedEOF }
|
||||
func (failPacket) MarshalBinary() ([]byte, error) { return nil, io.ErrClosedPipe }
|
||||
func (failPacket) BetterCid() chunkID { return chunkIDOverConnection }
|
||||
func (failPacket) Type() MessageType { return MessageTypeAMF0Command }
|
||||
|
||||
type stepWriter struct {
|
||||
writes int
|
||||
failAt int
|
||||
}
|
||||
|
||||
func (w *stepWriter) Write(p []byte) (int, error) {
|
||||
w.writes++
|
||||
if w.writes == w.failAt {
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func TestProtocolAdditionalBranches(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
p := NewProtocol(&bytes.Buffer{}).(*protocol)
|
||||
if NewSetPeerBandwidth().BetterCid() != chunkIDProtocolControl || NewUserControl().BetterCid() != chunkIDProtocolControl {
|
||||
t.Fatal("control better cid failed")
|
||||
}
|
||||
if err := p.WritePacket(ctx, failPacket{}, 0); err == nil || !strings.Contains(err.Error(), "marshal payload") {
|
||||
t.Fatalf("WritePacket marshal err=%v", err)
|
||||
}
|
||||
|
||||
payloadWriter := &stepWriter{failAt: 1}
|
||||
p = NewProtocol(struct {
|
||||
io.Reader
|
||||
io.Writer
|
||||
}{bytes.NewReader(nil), payloadWriter}).(*protocol)
|
||||
m := NewStreamMessage(1).asMessage()
|
||||
m.messageHeader.MessageType = MessageTypeVideo
|
||||
m.payload = bytes.Repeat([]byte{1}, 5000)
|
||||
if err := p.WriteMessage(ctx, m); err == nil || !strings.Contains(err.Error(), "write chunk payload") {
|
||||
t.Fatalf("WriteMessage payload err=%v", err)
|
||||
}
|
||||
flushWriter := &stepWriter{failAt: 1}
|
||||
p = NewProtocol(struct {
|
||||
io.Reader
|
||||
io.Writer
|
||||
}{bytes.NewReader(nil), flushWriter}).(*protocol)
|
||||
m.payload = []byte{1}
|
||||
if err := p.WriteMessage(ctx, m); err == nil || !strings.Contains(err.Error(), "flush writer") {
|
||||
t.Fatalf("WriteMessage flush err=%v writes=%v", err, flushWriter.writes)
|
||||
}
|
||||
|
||||
// Zero-length payload returns a complete message without reading chunk bytes.
|
||||
in := bytes.NewBuffer([]byte{0x05, 0, 0, 1, 0, 0, 0, byte(MessageTypeAudio), 1, 0, 0, 0})
|
||||
p = NewProtocol(in).(*protocol)
|
||||
if msg, err := p.ReadMessage(ctx); err != nil || msg.MessageType() != MessageTypeAudio || len(msg.Payload()) != 0 {
|
||||
t.Fatalf("zero payload msg=%v err=%v", msg, err)
|
||||
}
|
||||
|
||||
// ExpectMessage skips unwanted message types before returning the desired one.
|
||||
var wire bytes.Buffer
|
||||
writer := NewProtocol(&wire).(*protocol)
|
||||
am := NewStreamMessage(1).asMessage()
|
||||
am.messageHeader.MessageType = MessageTypeAudio
|
||||
am.payload = []byte{1}
|
||||
vm := NewStreamMessage(1).asMessage()
|
||||
vm.messageHeader.MessageType = MessageTypeVideo
|
||||
vm.payload = []byte{2}
|
||||
if err := writer.WriteMessage(ctx, am); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := writer.WriteMessage(ctx, vm); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
reader := NewProtocol(&wire)
|
||||
if got, err := reader.ExpectMessage(ctx, MessageTypeVideo); err != nil || got.MessageType() != MessageTypeVideo {
|
||||
t.Fatalf("ExpectMessage got=%v err=%v", got, err)
|
||||
}
|
||||
|
||||
// Generic ExpectPacket skips non-matching packets, then returns matching; it also reports decode/read errors.
|
||||
wire.Reset()
|
||||
writer = NewProtocol(&wire).(*protocol)
|
||||
if err := writer.WritePacket(ctx, &WindowAcknowledgementSize{AckSize: 1}, 0); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := writer.WritePacket(ctx, &SetChunkSize{ChunkSize: 2}, 0); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
reader = NewProtocol(&wire)
|
||||
var chunk *SetChunkSize
|
||||
if _, err := ExpectPacket(ctx, reader, &chunk); err != nil || chunk.ChunkSize != 2 {
|
||||
t.Fatalf("ExpectPacket chunk=%v err=%v", chunk, err)
|
||||
}
|
||||
reader = NewProtocol(bytes.NewBuffer([]byte{0x05, 0, 0, 0, 0, 0, 1, byte(MessageTypeSetChunkSize), 1, 0, 0, 0, 0}))
|
||||
if _, err := ExpectPacket(ctx, reader, &chunk); err == nil || !strings.Contains(err.Error(), "decode message") {
|
||||
t.Fatalf("ExpectPacket decode err=%v", err)
|
||||
}
|
||||
reader = NewProtocol(bytes.NewBuffer(nil))
|
||||
if _, err := ExpectPacket(ctx, reader, &chunk); err == nil || !strings.Contains(err.Error(), "read message") {
|
||||
t.Fatalf("ExpectPacket read err=%v", err)
|
||||
}
|
||||
|
||||
// AMF3 strips the leading byte before AMF0 decoding.
|
||||
pub := NewPublishPacket()
|
||||
pub.TransactionID = 0
|
||||
pub.StreamName = NewAmf0String("stream")
|
||||
data := append([]byte{0}, mustPacketBytes(t, pub)...)
|
||||
msg := &message{payload: data}
|
||||
msg.messageHeader.MessageType = MessageTypeAMF3Command
|
||||
if pkt, err := NewProtocol(&bytes.Buffer{}).DecodeMessage(msg); err != nil || pkt.(*PublishPacket).StreamName.String() != "stream" {
|
||||
t.Fatalf("AMF3 decode pkt=%T err=%v", pkt, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProtocolErrorBranchesForCoverage(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
// ExpectMessage without requested types returns the first message.
|
||||
var wire bytes.Buffer
|
||||
w := NewProtocol(&wire).(*protocol)
|
||||
msg := NewStreamMessage(1).asMessage()
|
||||
msg.messageHeader.MessageType = MessageTypeAudio
|
||||
msg.payload = []byte{1}
|
||||
if err := w.WriteMessage(ctx, msg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got, err := NewProtocol(&wire).ExpectMessage(ctx); err != nil || got.MessageType() != MessageTypeAudio {
|
||||
t.Fatalf("ExpectMessage any got=%v err=%v", got, err)
|
||||
}
|
||||
cctx, cancel := context.WithCancel(ctx)
|
||||
cancel()
|
||||
if _, err := NewProtocol(bytes.NewBuffer(nil)).ReadMessage(cctx); err != context.Canceled {
|
||||
t.Fatalf("ReadMessage canceled err=%v", err)
|
||||
}
|
||||
if err := w.WriteMessage(cctx, (&message{})); err != context.Canceled {
|
||||
t.Fatalf("WriteMessage empty canceled err=%v", err)
|
||||
}
|
||||
badAMF := &message{payload: []byte{0xff}}
|
||||
badAMF.messageHeader.MessageType = MessageTypeAMF0Command
|
||||
if _, err := NewProtocol(&bytes.Buffer{}).DecodeMessage(badAMF); err == nil || !strings.Contains(err.Error(), "Parse AMF") {
|
||||
t.Fatalf("bad AMF decode err=%v", err)
|
||||
}
|
||||
rn := commandResult
|
||||
resultName, _ := (&rn).MarshalBinary()
|
||||
if _, err := NewProtocol(&bytes.Buffer{}).(*protocol).parseAMFObject(append(resultName, 0)); err == nil || !strings.Contains(err.Error(), "unmarshal tid") {
|
||||
t.Fatalf("bad result tid err=%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPacketUnmarshalErrorBranchesForCoverage(t *testing.T) {
|
||||
cn := commandConnect
|
||||
name, _ := (&cn).MarshalBinary()
|
||||
tn := amf0Number(1)
|
||||
tid, _ := (&tn).MarshalBinary()
|
||||
obj, _ := NewAmf0Object().MarshalBinary()
|
||||
base := append(append([]byte{}, name...), tid...)
|
||||
oc := &objectCallPacket{CommandObject: NewAmf0Object()}
|
||||
if err := oc.UnmarshalBinary(append([]byte{}, name...)); err == nil || !strings.Contains(err.Error(), "unmarshal tid") {
|
||||
t.Fatalf("object tid err=%v", err)
|
||||
}
|
||||
if err := oc.UnmarshalBinary(append(append([]byte{}, base...), 0xff)); err == nil || !strings.Contains(err.Error(), "unmarshal command") {
|
||||
t.Fatalf("object command err=%v", err)
|
||||
}
|
||||
withObj := append(append([]byte{}, base...), obj...)
|
||||
if err := oc.UnmarshalBinary(append(withObj, 0xff)); err == nil || !strings.Contains(err.Error(), "unmarshal args") {
|
||||
t.Fatalf("object args err=%v", err)
|
||||
}
|
||||
|
||||
vc := &variantCallPacket{}
|
||||
if err := vc.UnmarshalBinary(append([]byte{}, name...)); err == nil || !strings.Contains(err.Error(), "unmarshal tid") {
|
||||
t.Fatalf("variant tid err=%v", err)
|
||||
}
|
||||
if err := vc.UnmarshalBinary(append(append([]byte{}, base...), 0xff)); err == nil || !strings.Contains(err.Error(), "discovery command object") {
|
||||
t.Fatalf("variant discovery err=%v", err)
|
||||
}
|
||||
if err := vc.UnmarshalBinary(append(append([]byte{}, base...), byte(amf0MarkerString), 0, 3, 'a')); err == nil || !strings.Contains(err.Error(), "unmarshal command object") {
|
||||
t.Fatalf("variant command object err=%v", err)
|
||||
}
|
||||
|
||||
call := NewCallPacket()
|
||||
call.CommandName = commandOnStatus
|
||||
call.TransactionID = 0
|
||||
call.CommandObject = NewAmf0Null()
|
||||
callBase := mustPacketBytes(t, call)
|
||||
if err := NewCallPacket().UnmarshalBinary(append(callBase, 0xff)); err == nil || !strings.Contains(err.Error(), "discovery args") {
|
||||
t.Fatalf("call discovery args err=%v", err)
|
||||
}
|
||||
if err := NewCallPacket().UnmarshalBinary(append(callBase, byte(amf0MarkerString), 0, 3, 'a')); err == nil || !strings.Contains(err.Error(), "unmarshal args") {
|
||||
t.Fatalf("call unmarshal args err=%v", err)
|
||||
}
|
||||
csr := NewCreateStreamResPacket(2)
|
||||
if err := NewCreateStreamResPacket(0).UnmarshalBinary(mustPacketBytes(t, &csr.variantCallPacket)); err == nil || !strings.Contains(err.Error(), "unmarshal sid") {
|
||||
t.Fatalf("create stream sid err=%v", err)
|
||||
}
|
||||
pub := NewPublishPacket()
|
||||
pub.TransactionID = 0
|
||||
pubPrefix, _ := pub.variantCallPacket.MarshalBinary()
|
||||
if err := NewPublishPacket().UnmarshalBinary(append(pubPrefix, 0xff)); err == nil || !strings.Contains(err.Error(), "stream name") {
|
||||
t.Fatalf("publish stream name err=%v", err)
|
||||
}
|
||||
streamName, _ := NewAmf0String("s").MarshalBinary()
|
||||
if err := NewPublishPacket().UnmarshalBinary(append(append(pubPrefix, streamName...), 0xff)); err == nil || !strings.Contains(err.Error(), "stream type") {
|
||||
t.Fatalf("publish stream type err=%v", err)
|
||||
}
|
||||
play := NewPlayPacket()
|
||||
play.TransactionID = 0
|
||||
playPrefix, _ := play.variantCallPacket.MarshalBinary()
|
||||
if err := NewPlayPacket().UnmarshalBinary(append(playPrefix, 0xff)); err == nil || !strings.Contains(err.Error(), "stream name") {
|
||||
t.Fatalf("play stream name err=%v", err)
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user