// Copyright (c) 2026 Winlin // // SPDX-License-Identifier: MIT package rtmp import ( "bytes" "context" "encoding/binary" "io" "reflect" "strings" "testing" ) type errWriter struct{} func (errWriter) Write([]byte) (int, error) { return 0, io.ErrClosedPipe } func TestHandshakeSimpleAndErrors(t *testing.T) { h := NewHandshake() var b bytes.Buffer if err := h.WriteC0S0(&b); err != nil { t.Fatalf("WriteC0S0 err=%v", err) } c0, err := h.ReadC0S0(&b) if err != nil || !bytes.Equal(c0, []byte{3}) { t.Fatalf("ReadC0S0=%v, err=%v", c0, err) } if err := h.WriteC0S0(errWriter{}); err == nil { t.Fatal("WriteC0S0 should fail") } if _, err := h.ReadC0S0(bytes.NewReader(nil)); err == nil { t.Fatal("ReadC0S0 should fail") } b.Reset() if err := h.WriteC1S1(&b); err != nil { t.Fatalf("WriteC1S1 err=%v", err) } if b.Len() != 1536 { t.Fatalf("C1S1 len=%v", b.Len()) } c1, err := h.ReadC1S1(&b) if err != nil || len(c1) != 1536 || !bytes.Equal(h.C1S1(), c1) { t.Fatalf("ReadC1S1 len=%v, cached=%v, err=%v", len(c1), bytes.Equal(h.C1S1(), c1), err) } if err := h.WriteC1S1(errWriter{}); err == nil { t.Fatal("WriteC1S1 should fail") } if _, err := h.ReadC1S1(bytes.NewReader(make([]byte, 1535))); err == nil { t.Fatal("ReadC1S1 should fail") } b.Reset() if err := h.WriteC2S2(&b, c1); err != nil { t.Fatalf("WriteC2S2 err=%v", err) } c2, err := h.ReadC2S2(&b) if err != nil || !bytes.Equal(c2, c1) { t.Fatalf("ReadC2S2 match=%v, err=%v", bytes.Equal(c2, c1), err) } if err := h.WriteC2S2(errWriter{}, c1); err == nil { t.Fatal("WriteC2S2 should fail") } if _, err := h.ReadC2S2(bytes.NewReader(make([]byte, 1535))); err == nil { t.Fatal("ReadC2S2 should fail") } } func TestSettingsChunkStreamAndMessageConstructors(t *testing.T) { if s := newSettings(); s.chunkSize != defaultChunkSize { t.Fatalf("chunk size=%v", s.chunkSize) } if c := newChunkStream(); c == nil || c.count != 0 { t.Fatalf("chunk stream=%#v", c) } m := NewMessage().asMessage() m.messageHeader.MessageType = MessageTypeAudio m.messageHeader.Timestamp = 99 m.payload = []byte{1, 2, 3} if m.MessageType() != MessageTypeAudio || m.Timestamp() != 99 || !bytes.Equal(m.Payload(), []byte{1, 2, 3}) || m.asMessage() != m { t.Fatalf("bad message accessors") } sm := NewStreamMessage(7).asMessage() if sm.streamID != 7 || sm.betterCid != chunkIDOverStream { t.Fatalf("stream message=%#v", sm.messageHeader) } } func TestBasicHeaderVariantsAndErrors(t *testing.T) { ctx := context.Background() cases := []struct { name string data []byte fmt formatType cid chunkID }{ {"one-byte", []byte{0x85}, formatType2, 5}, {"two-byte", []byte{0x40, 0x0a}, formatType1, 74}, {"three-byte-code-path", []byte{0xc1, 0x01, 0x02}, formatType3, 65}, } for _, tt := range cases { t.Run(tt.name, func(t *testing.T) { p := NewProtocol(bytes.NewBuffer(tt.data)).(*protocol) fmt, cid, err := p.readBasicHeader(ctx) if err != nil || fmt != tt.fmt || cid != tt.cid { t.Fatalf("fmt=%v cid=%v err=%v", fmt, cid, err) } }) } for _, data := range [][]byte{{}, {0x00}} { p := NewProtocol(bytes.NewBuffer(data)).(*protocol) if _, _, err := p.readBasicHeader(ctx); err == nil { t.Fatalf("readBasicHeader(%x) should fail", data) } } } func TestReadMessageHeadersPayloadsAndChunks(t *testing.T) { ctx := context.Background() var in bytes.Buffer // fmt0 cid=5, timestamp=10, len=3, type audio, stream=1, payload 010203. in.Write([]byte{0x05, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x03, byte(MessageTypeAudio), 0x01, 0x00, 0x00, 0x00, 1, 2, 3}) // fmt1 same cid, delta=5, len=2, type video, payload 0405. in.Write([]byte{0x45, 0x00, 0x00, 0x05, 0x00, 0x00, 0x02, byte(MessageTypeVideo), 4, 5}) // fmt2 same cid, delta=7, reuses len/type/stream, payload 0607. in.Write([]byte{0x85, 0x00, 0x00, 0x07, 6, 7}) // fmt3 same cid, reuses delta and advances timestamp, payload 0809. in.Write([]byte{0xc5, 8, 9}) p := NewProtocol(&in).(*protocol) for i, want := range []struct { typ MessageType ts uint64 pl []byte }{ {MessageTypeAudio, 10, []byte{1, 2, 3}}, {MessageTypeVideo, 15, []byte{4, 5}}, {MessageTypeVideo, 22, []byte{6, 7}}, {MessageTypeVideo, 29, []byte{8, 9}}, } { m, err := p.ReadMessage(ctx) if err != nil { t.Fatalf("ReadMessage #%v err=%v", i, err) } if m.MessageType() != want.typ || m.Timestamp() != want.ts || !bytes.Equal(m.Payload(), want.pl) { t.Fatalf("message #%v type=%v ts=%v payload=%x", i, m.MessageType(), m.Timestamp(), m.Payload()) } } } func TestReadMessageExtendedTimestampAndChunking(t *testing.T) { ctx := context.Background() var in bytes.Buffer payload := []byte{1, 2, 3, 4, 5} // fmt0 cid=5, normal timestamp=0xffffff, extended timestamp has high bit set and should be masked. in.Write([]byte{0x05, 0xff, 0xff, 0xff, 0x00, 0x00, byte(len(payload)), byte(MessageTypeVideo), 0x01, 0x00, 0x00, 0x00}) binary.Write(&in, binary.BigEndian, uint32(0x8000002a)) in.Write(payload[:2]) // continuation chunk has fmt3 and extended timestamp too. in.Write([]byte{0xc5}) binary.Write(&in, binary.BigEndian, uint32(0x8000002a)) in.Write(payload[2:4]) in.Write([]byte{0xc5}) binary.Write(&in, binary.BigEndian, uint32(0x8000002a)) in.Write(payload[4:]) p := NewProtocol(&in).(*protocol) p.input.opt.chunkSize = 2 m, err := p.ReadMessage(ctx) if err != nil { t.Fatalf("ReadMessage err=%v", err) } if m.Timestamp() != 42 || !bytes.Equal(m.Payload(), payload) { t.Fatalf("ts=%v payload=%x", m.Timestamp(), m.Payload()) } } func TestReadMessageHeaderErrors(t *testing.T) { ctx := context.Background() // Fresh non-zero chunk with fmt1 is rejected. p := NewProtocol(bytes.NewBuffer([]byte{0x45})).(*protocol) if _, err := p.ReadMessage(ctx); err == nil || !strings.Contains(err.Error(), "fresh chunk") { t.Fatalf("fresh fmt1 err=%v", err) } // Existing partial message cannot restart with fmt0. p = NewProtocol(bytes.NewBuffer([]byte{0x05, 0, 0, 0, 0, 0, 3, byte(MessageTypeAudio), 1, 0, 0, 0, 1, 0x05})).(*protocol) p.input.opt.chunkSize = 1 if _, err := p.ReadMessage(ctx); err == nil || !strings.Contains(err.Error(), "exists chunk") { t.Fatalf("restart err=%v", err) } // Size change in a continuation header is rejected. p = NewProtocol(bytes.NewBuffer([]byte{0x05, 0, 0, 0, 0, 0, 3, byte(MessageTypeAudio), 1, 0, 0, 0, 1, 0x45, 0, 0, 1, 0, 0, 4, byte(MessageTypeAudio)})).(*protocol) p.input.opt.chunkSize = 1 if _, err := p.ReadMessage(ctx); err == nil || !strings.Contains(err.Error(), "message size") { t.Fatalf("size change err=%v", err) } // Short payload and short extended timestamp. p = NewProtocol(bytes.NewBuffer([]byte{0x05, 0, 0, 0, 0, 0, 2, byte(MessageTypeAudio), 1, 0, 0, 0, 1})).(*protocol) if _, err := p.ReadMessage(ctx); err == nil || !strings.Contains(err.Error(), "read chunk") { t.Fatalf("payload err=%v", err) } p = NewProtocol(bytes.NewBuffer([]byte{0x05, 0xff, 0xff, 0xff, 0, 0, 0, byte(MessageTypeAudio), 1, 0, 0, 0, 1, 2})).(*protocol) if _, err := p.ReadMessage(ctx); err == nil || !strings.Contains(err.Error(), "ext-ts") { t.Fatalf("ext-ts err=%v", err) } } func TestWriteMessageHeadersChunkingAndErrors(t *testing.T) { ctx := context.Background() var out bytes.Buffer p := NewProtocol(&out).(*protocol) p.output.opt.chunkSize = 2 m := NewStreamMessage(7).asMessage() m.messageHeader.MessageType = MessageTypeVideo m.messageHeader.Timestamp = extendedTimestamp + 9 m.payload = []byte{1, 2, 3, 4, 5} if err := p.WriteMessage(ctx, m); err != nil { t.Fatalf("WriteMessage err=%v", err) } want := []byte{0x05, 0xff, 0xff, 0xff, 0, 0, 5, byte(MessageTypeVideo), 7, 0, 0, 0, 0x01, 0x00, 0x00, 0x08, 1, 2, 0xc5, 0x01, 0x00, 0x00, 0x08, 3, 4, 0xc5, 0x01, 0x00, 0x00, 0x08, 5} if !bytes.Equal(out.Bytes(), want) { t.Fatalf("written=%x want=%x", out.Bytes(), want) } if err := p.WriteMessage(ctx, (&message{})); err != nil { t.Fatalf("empty WriteMessage err=%v", err) } canceled, cancel := context.WithCancel(ctx) cancel() if err := p.WriteMessage(canceled, m); err != context.Canceled { t.Fatalf("canceled WriteMessage err=%v", err) } p = NewProtocol(struct { io.Reader io.Writer }{bytes.NewReader(nil), errWriter{}}).(*protocol) if err := p.WriteMessage(ctx, m); err == nil { t.Fatal("WriteMessage to bad writer should fail") } } func TestProtocolDecodeMessageAndControls(t *testing.T) { ctx := context.Background() p := NewProtocol(&bytes.Buffer{}).(*protocol) if _, err := p.DecodeMessage((&message{})); err == nil || !strings.Contains(err.Error(), "Empty packet") { t.Fatalf("empty decode err=%v", err) } unknown := &message{} unknown.messageHeader.MessageType = MessageTypeAudio unknown.payload = []byte{1} if _, err := p.DecodeMessage(unknown); err == nil || !strings.Contains(err.Error(), "Unknown message") { t.Fatalf("unknown err=%v", err) } bad := &message{} bad.messageHeader.MessageType = MessageTypeSetChunkSize bad.payload = []byte{1, 2} if _, err := p.DecodeMessage(bad); err == nil || !strings.Contains(err.Error(), "Unmarshal") { t.Fatalf("bad control err=%v", err) } for _, pkt := range []Packet{ &SetChunkSize{ChunkSize: 4096}, &WindowAcknowledgementSize{AckSize: 2500000}, &SetPeerBandwidth{Bandwidth: 1000, LimitType: LimitTypeSoft}, &UserControl{EventType: EventTypePingRequest, EventData: 123}, } { data, err := pkt.MarshalBinary() if err != nil { t.Fatalf("marshal %T err=%v", pkt, err) } m := &message{payload: data} m.messageHeader.MessageType = pkt.Type() got, err := p.DecodeMessage(m) if err != nil { t.Fatalf("DecodeMessage %T err=%v", pkt, err) } if reflect.TypeOf(got) != reflect.TypeOf(pkt) { t.Fatalf("got %T want %T", got, pkt) } } chunk := &SetChunkSize{ChunkSize: 3} m := &message{} m.messageHeader.MessageType = chunk.Type() m.payload, _ = chunk.MarshalBinary() if err := p.onMessageArrivated(m); err != nil || p.input.opt.chunkSize != 3 { t.Fatalf("onMessageArrivated err=%v chunk=%v", err, p.input.opt.chunkSize) } if err := p.onMessageArrivated(nil); err != nil { t.Fatalf("nil onMessageArrivated err=%v", err) } bad.Payload()[0] = 1 if err := p.onMessageArrivated(bad); err == nil { t.Fatal("bad onMessageArrivated should fail") } if _, err := p.ExpectMessage(ctx); err == nil { t.Fatal("ExpectMessage on empty reader should fail") } } func TestProtocolPacketsAndTransactions(t *testing.T) { ctx := context.Background() var wire bytes.Buffer writer := NewProtocol(&wire).(*protocol) connect := NewConnectAppPacket() connect.CommandObject.Set("tcUrl", NewAmf0String("rtmp://host/live")) if connect.Size() == 0 || connect.BetterCid() != chunkIDOverConnection || connect.Type() != MessageTypeAMF0Command || connect.TcUrl() == "" { t.Fatalf("connect metadata invalid") } if err := writer.WritePacket(ctx, connect, 0); err != nil { t.Fatalf("WritePacket connect err=%v", err) } if _, ok := writer.input.transactions[connect.TransactionID]; !ok { t.Fatal("connect transaction not tracked") } create := NewCreateStreamPacket() if err := writer.WritePacket(ctx, create, 0); err != nil { t.Fatalf("WritePacket create err=%v", err) } call := NewCallPacket() call.CommandName = commandReleaseStream call.TransactionID = 3 call.CommandObject = NewAmf0Null() if err := writer.WritePacket(ctx, call, 0); err != nil { t.Fatalf("WritePacket call err=%v", err) } reader := NewProtocol(&wire) var gotConnect *ConnectAppPacket if _, err := ExpectPacket(ctx, reader, &gotConnect); err != nil || gotConnect.TcUrl() != "rtmp://host/live" { t.Fatalf("gotConnect=%v err=%v", gotConnect, err) } var gotCreate *CallPacket if _, err := ExpectPacket(ctx, reader, &gotCreate); err != nil || gotCreate.CommandName != commandCreateStream { t.Fatalf("gotCreate=%v err=%v", gotCreate, err) } decoder := NewProtocol(&bytes.Buffer{}).(*protocol) decoder.input.transactions[1] = commandConnect if pkt, err := decoder.parseAMFObject(mustPacketBytes(t, NewConnectAppResPacket(1))); err != nil || reflect.TypeOf(pkt) != reflect.TypeOf(NewConnectAppResPacket(1)) { t.Fatalf("connect res pkt=%T err=%v", pkt, err) } decoder.input.transactions[2] = commandCreateStream csr := NewCreateStreamResPacket(2) csr.SetStreamID(99) if pkt, err := decoder.parseAMFObject(mustPacketBytes(t, csr)); err != nil || reflect.TypeOf(pkt) != reflect.TypeOf(csr) { t.Fatalf("create res pkt=%T err=%v", pkt, err) } decoder.input.transactions[3] = commandReleaseStream res := NewCallPacket() res.CommandName = commandResult res.TransactionID = 3 res.CommandObject = NewAmf0Null() if pkt, err := decoder.parseAMFObject(mustPacketBytes(t, res)); err != nil || reflect.TypeOf(pkt) != reflect.TypeOf(res) { t.Fatalf("call res pkt=%T err=%v", pkt, err) } for _, name := range []amf0String{commandPublish, commandPlay, commandOnStatus} { pkt := NewCallPacket() pkt.CommandName = name pkt.TransactionID = 0 pkt.CommandObject = NewAmf0Null() if name == commandPublish { pub := NewPublishPacket() pub.TransactionID = 0 pub.StreamName = NewAmf0String("s") pub.StreamType = NewAmf0String("live") if decoded, err := decoder.parseAMFObject(mustPacketBytes(t, pub)); err != nil || reflect.TypeOf(decoded) != reflect.TypeOf(pub) { t.Fatalf("publish decoded=%T err=%v", decoded, err) } continue } if name == commandPlay { play := NewPlayPacket() play.TransactionID = 0 play.StreamName = NewAmf0String("s") if decoded, err := decoder.parseAMFObject(mustPacketBytes(t, play)); err != nil || reflect.TypeOf(decoded) != reflect.TypeOf(play) { t.Fatalf("play decoded=%T err=%v", decoded, err) } continue } if decoded, err := decoder.parseAMFObject(mustPacketBytes(t, pkt)); err != nil || reflect.TypeOf(decoded) != reflect.TypeOf(pkt) { t.Fatalf("call decoded=%T err=%v", decoded, err) } } decoder.input.transactions[9] = commandPause errPkt := NewCallPacket() errPkt.CommandName = commandError errPkt.TransactionID = 9 errPkt.CommandObject = NewAmf0Null() if _, err := decoder.parseAMFObject(mustPacketBytes(t, errPkt)); err == nil || !strings.Contains(err.Error(), "No request") { t.Fatalf("unknown request err=%v", err) } if _, err := decoder.parseAMFObject(mustPacketBytes(t, errPkt)); err == nil || !strings.Contains(err.Error(), "No matched request") { t.Fatalf("missing transaction err=%v", err) } if _, err := decoder.parseAMFObject([]byte{byte(amf0MarkerString), 0, 8, 'c'}); err == nil { t.Fatal("bad AMF parse should fail") } cctx, cancel := context.WithCancel(ctx) cancel() if err := writer.WritePacket(cctx, connect, 0); err == nil || !strings.Contains(err.Error(), context.Canceled.Error()) { t.Fatalf("WritePacket canceled err=%v", err) } } func TestDeprecatedExpectPacketPanics(t *testing.T) { defer func() { if recover() == nil { t.Fatal("Expected panic") } }() NewProtocol(&bytes.Buffer{}).ExpectPacket(context.Background(), nil) } func TestPacketRoundTripsAndErrors(t *testing.T) { packets := []Packet{ NewConnectAppPacket(), NewConnectAppResPacket(7), NewCallPacket(), NewCreateStreamPacket(), func() Packet { p := NewCreateStreamResPacket(2); p.SetStreamID(1); return p }(), func() Packet { p := NewPublishPacket() p.TransactionID = 0 p.StreamName = NewAmf0String("s") return p }(), func() Packet { p := NewPlayPacket(); p.TransactionID = 0; p.StreamName = NewAmf0String("s"); return p }(), &SetChunkSize{ChunkSize: 1}, &WindowAcknowledgementSize{AckSize: 2}, &SetPeerBandwidth{Bandwidth: 3, LimitType: LimitTypeDynamic}, &UserControl{EventType: EventTypeFmsEvent0, EventData: 1}, &UserControl{EventType: EventTypeSetBufferLength, EventData: 1, ExtraData: 2}, } // Initialize the generic call packet so it is marshalable. packets[2].(*CallPacket).CommandName = commandOnStatus packets[2].(*CallPacket).TransactionID = 0 packets[2].(*CallPacket).CommandObject = NewAmf0Null() packets[2].(*CallPacket).Args = NewAmf0Object().Set("code", NewAmf0String("ok")) packets[1].(*ConnectAppResPacket).Args.Set("data", NewAmf0EcmaArray().Set("srs_id", NewAmf0String("sid"))) for _, pkt := range packets { t.Run(reflect.TypeOf(pkt).String(), func(t *testing.T) { data, err := pkt.MarshalBinary() if err != nil { t.Fatalf("MarshalBinary err=%v", err) } if len(data) != pkt.Size() { t.Fatalf("len=%v Size=%v", len(data), pkt.Size()) } fresh := reflect.New(reflect.TypeOf(pkt).Elem()).Interface().(Packet) switch v := fresh.(type) { case *ConnectAppPacket: *v = *NewConnectAppPacket() case *ConnectAppResPacket: *v = *NewConnectAppResPacket(0) case *CreateStreamPacket: *v = *NewCreateStreamPacket() case *CreateStreamResPacket: *v = *NewCreateStreamResPacket(0) case *PublishPacket: *v = *NewPublishPacket() case *PlayPacket: *v = *NewPlayPacket() } if err := fresh.UnmarshalBinary(data); err != nil { t.Fatalf("UnmarshalBinary err=%v", err) } }) } if packets[1].(*ConnectAppResPacket).SrsID() != "sid" || packets[2].(*CallPacket).ArgsCode() != "ok" { t.Fatalf("packet helpers failed") } if NewConnectAppResPacket(1).SrsID() != "" || NewCallPacket().ArgsCode() != "" || NewConnectAppPacket().TcUrl() != "" { t.Fatalf("empty helpers failed") } badConnect := NewConnectAppPacket() badConnect.CommandName = commandPlay if err := badConnect.UnmarshalBinary(mustPacketBytes(t, badConnect)); err == nil { t.Fatal("bad connect name should fail") } badConnect = NewConnectAppPacket() badConnect.TransactionID = 2 if err := badConnect.UnmarshalBinary(mustPacketBytes(t, badConnect)); err == nil { t.Fatal("bad connect tid should fail") } badRes := NewConnectAppResPacket(1) badRes.CommandName = commandPlay if err := badRes.UnmarshalBinary(mustPacketBytes(t, badRes)); err == nil { t.Fatal("bad connect response name should fail") } for _, pkt := range []Packet{NewConnectAppPacket(), NewCallPacket(), NewCreateStreamResPacket(1), NewPublishPacket(), NewPlayPacket()} { if err := pkt.UnmarshalBinary([]byte{byte(amf0MarkerString)}); err == nil { t.Fatalf("%T short unmarshal should fail", pkt) } } for _, pkt := range []Packet{&SetChunkSize{}, &WindowAcknowledgementSize{}, &SetPeerBandwidth{}, &UserControl{}} { if err := pkt.UnmarshalBinary([]byte{0, 1}); err == nil { t.Fatalf("%T short unmarshal should fail", pkt) } } uc := &UserControl{} if err := uc.UnmarshalBinary([]byte{0, byte(EventTypeSetBufferLength), 1, 2, 3, 4, 5}); err == nil { t.Fatal("short set-buffer-length should fail") } } func mustPacketBytes(t *testing.T, pkt Packet) []byte { t.Helper() data, err := pkt.MarshalBinary() if err != nil { t.Fatalf("MarshalBinary %T err=%v", pkt, err) } return data } type failPacket struct{} func (failPacket) Size() int { return 0 } func (failPacket) UnmarshalBinary([]byte) error { return io.ErrUnexpectedEOF } func (failPacket) MarshalBinary() ([]byte, error) { return nil, io.ErrClosedPipe } func (failPacket) BetterCid() chunkID { return chunkIDOverConnection } func (failPacket) Type() MessageType { return MessageTypeAMF0Command } type stepWriter struct { writes int failAt int } func (w *stepWriter) Write(p []byte) (int, error) { w.writes++ if w.writes == w.failAt { return 0, io.ErrClosedPipe } return len(p), nil } func TestProtocolAdditionalBranches(t *testing.T) { ctx := context.Background() p := NewProtocol(&bytes.Buffer{}).(*protocol) if NewSetPeerBandwidth().BetterCid() != chunkIDProtocolControl || NewUserControl().BetterCid() != chunkIDProtocolControl { t.Fatal("control better cid failed") } if err := p.WritePacket(ctx, failPacket{}, 0); err == nil || !strings.Contains(err.Error(), "marshal payload") { t.Fatalf("WritePacket marshal err=%v", err) } payloadWriter := &stepWriter{failAt: 1} p = NewProtocol(struct { io.Reader io.Writer }{bytes.NewReader(nil), payloadWriter}).(*protocol) m := NewStreamMessage(1).asMessage() m.messageHeader.MessageType = MessageTypeVideo m.payload = bytes.Repeat([]byte{1}, 5000) if err := p.WriteMessage(ctx, m); err == nil || !strings.Contains(err.Error(), "write chunk payload") { t.Fatalf("WriteMessage payload err=%v", err) } flushWriter := &stepWriter{failAt: 1} p = NewProtocol(struct { io.Reader io.Writer }{bytes.NewReader(nil), flushWriter}).(*protocol) m.payload = []byte{1} if err := p.WriteMessage(ctx, m); err == nil || !strings.Contains(err.Error(), "flush writer") { t.Fatalf("WriteMessage flush err=%v writes=%v", err, flushWriter.writes) } // Zero-length payload returns a complete message without reading chunk bytes. in := bytes.NewBuffer([]byte{0x05, 0, 0, 1, 0, 0, 0, byte(MessageTypeAudio), 1, 0, 0, 0}) p = NewProtocol(in).(*protocol) if msg, err := p.ReadMessage(ctx); err != nil || msg.MessageType() != MessageTypeAudio || len(msg.Payload()) != 0 { t.Fatalf("zero payload msg=%v err=%v", msg, err) } // ExpectMessage skips unwanted message types before returning the desired one. var wire bytes.Buffer writer := NewProtocol(&wire).(*protocol) am := NewStreamMessage(1).asMessage() am.messageHeader.MessageType = MessageTypeAudio am.payload = []byte{1} vm := NewStreamMessage(1).asMessage() vm.messageHeader.MessageType = MessageTypeVideo vm.payload = []byte{2} if err := writer.WriteMessage(ctx, am); err != nil { t.Fatal(err) } if err := writer.WriteMessage(ctx, vm); err != nil { t.Fatal(err) } reader := NewProtocol(&wire) if got, err := reader.ExpectMessage(ctx, MessageTypeVideo); err != nil || got.MessageType() != MessageTypeVideo { t.Fatalf("ExpectMessage got=%v err=%v", got, err) } // Generic ExpectPacket skips non-matching packets, then returns matching; it also reports decode/read errors. wire.Reset() writer = NewProtocol(&wire).(*protocol) if err := writer.WritePacket(ctx, &WindowAcknowledgementSize{AckSize: 1}, 0); err != nil { t.Fatal(err) } if err := writer.WritePacket(ctx, &SetChunkSize{ChunkSize: 2}, 0); err != nil { t.Fatal(err) } reader = NewProtocol(&wire) var chunk *SetChunkSize if _, err := ExpectPacket(ctx, reader, &chunk); err != nil || chunk.ChunkSize != 2 { t.Fatalf("ExpectPacket chunk=%v err=%v", chunk, err) } reader = NewProtocol(bytes.NewBuffer([]byte{0x05, 0, 0, 0, 0, 0, 1, byte(MessageTypeSetChunkSize), 1, 0, 0, 0, 0})) if _, err := ExpectPacket(ctx, reader, &chunk); err == nil || !strings.Contains(err.Error(), "decode message") { t.Fatalf("ExpectPacket decode err=%v", err) } reader = NewProtocol(bytes.NewBuffer(nil)) if _, err := ExpectPacket(ctx, reader, &chunk); err == nil || !strings.Contains(err.Error(), "read message") { t.Fatalf("ExpectPacket read err=%v", err) } // AMF3 strips the leading byte before AMF0 decoding. pub := NewPublishPacket() pub.TransactionID = 0 pub.StreamName = NewAmf0String("stream") data := append([]byte{0}, mustPacketBytes(t, pub)...) msg := &message{payload: data} msg.messageHeader.MessageType = MessageTypeAMF3Command if pkt, err := NewProtocol(&bytes.Buffer{}).DecodeMessage(msg); err != nil || pkt.(*PublishPacket).StreamName.String() != "stream" { t.Fatalf("AMF3 decode pkt=%T err=%v", pkt, err) } } func TestProtocolErrorBranchesForCoverage(t *testing.T) { ctx := context.Background() // ExpectMessage without requested types returns the first message. var wire bytes.Buffer w := NewProtocol(&wire).(*protocol) msg := NewStreamMessage(1).asMessage() msg.messageHeader.MessageType = MessageTypeAudio msg.payload = []byte{1} if err := w.WriteMessage(ctx, msg); err != nil { t.Fatal(err) } if got, err := NewProtocol(&wire).ExpectMessage(ctx); err != nil || got.MessageType() != MessageTypeAudio { t.Fatalf("ExpectMessage any got=%v err=%v", got, err) } cctx, cancel := context.WithCancel(ctx) cancel() if _, err := NewProtocol(bytes.NewBuffer(nil)).ReadMessage(cctx); err != context.Canceled { t.Fatalf("ReadMessage canceled err=%v", err) } if err := w.WriteMessage(cctx, (&message{})); err != context.Canceled { t.Fatalf("WriteMessage empty canceled err=%v", err) } badAMF := &message{payload: []byte{0xff}} badAMF.messageHeader.MessageType = MessageTypeAMF0Command if _, err := NewProtocol(&bytes.Buffer{}).DecodeMessage(badAMF); err == nil || !strings.Contains(err.Error(), "Parse AMF") { t.Fatalf("bad AMF decode err=%v", err) } rn := commandResult resultName, _ := (&rn).MarshalBinary() if _, err := NewProtocol(&bytes.Buffer{}).(*protocol).parseAMFObject(append(resultName, 0)); err == nil || !strings.Contains(err.Error(), "unmarshal tid") { t.Fatalf("bad result tid err=%v", err) } } func TestPacketUnmarshalErrorBranchesForCoverage(t *testing.T) { cn := commandConnect name, _ := (&cn).MarshalBinary() tn := amf0Number(1) tid, _ := (&tn).MarshalBinary() obj, _ := NewAmf0Object().MarshalBinary() base := append(append([]byte{}, name...), tid...) oc := &objectCallPacket{CommandObject: NewAmf0Object()} if err := oc.UnmarshalBinary(append([]byte{}, name...)); err == nil || !strings.Contains(err.Error(), "unmarshal tid") { t.Fatalf("object tid err=%v", err) } if err := oc.UnmarshalBinary(append(append([]byte{}, base...), 0xff)); err == nil || !strings.Contains(err.Error(), "unmarshal command") { t.Fatalf("object command err=%v", err) } withObj := append(append([]byte{}, base...), obj...) if err := oc.UnmarshalBinary(append(withObj, 0xff)); err == nil || !strings.Contains(err.Error(), "unmarshal args") { t.Fatalf("object args err=%v", err) } vc := &variantCallPacket{} if err := vc.UnmarshalBinary(append([]byte{}, name...)); err == nil || !strings.Contains(err.Error(), "unmarshal tid") { t.Fatalf("variant tid err=%v", err) } if err := vc.UnmarshalBinary(append(append([]byte{}, base...), 0xff)); err == nil || !strings.Contains(err.Error(), "discovery command object") { t.Fatalf("variant discovery err=%v", err) } if err := vc.UnmarshalBinary(append(append([]byte{}, base...), byte(amf0MarkerString), 0, 3, 'a')); err == nil || !strings.Contains(err.Error(), "unmarshal command object") { t.Fatalf("variant command object err=%v", err) } call := NewCallPacket() call.CommandName = commandOnStatus call.TransactionID = 0 call.CommandObject = NewAmf0Null() callBase := mustPacketBytes(t, call) if err := NewCallPacket().UnmarshalBinary(append(callBase, 0xff)); err == nil || !strings.Contains(err.Error(), "discovery args") { t.Fatalf("call discovery args err=%v", err) } if err := NewCallPacket().UnmarshalBinary(append(callBase, byte(amf0MarkerString), 0, 3, 'a')); err == nil || !strings.Contains(err.Error(), "unmarshal args") { t.Fatalf("call unmarshal args err=%v", err) } csr := NewCreateStreamResPacket(2) if err := NewCreateStreamResPacket(0).UnmarshalBinary(mustPacketBytes(t, &csr.variantCallPacket)); err == nil || !strings.Contains(err.Error(), "unmarshal sid") { t.Fatalf("create stream sid err=%v", err) } pub := NewPublishPacket() pub.TransactionID = 0 pubPrefix, _ := pub.variantCallPacket.MarshalBinary() if err := NewPublishPacket().UnmarshalBinary(append(pubPrefix, 0xff)); err == nil || !strings.Contains(err.Error(), "stream name") { t.Fatalf("publish stream name err=%v", err) } streamName, _ := NewAmf0String("s").MarshalBinary() if err := NewPublishPacket().UnmarshalBinary(append(append(pubPrefix, streamName...), 0xff)); err == nil || !strings.Contains(err.Error(), "stream type") { t.Fatalf("publish stream type err=%v", err) } play := NewPlayPacket() play.TransactionID = 0 playPrefix, _ := play.variantCallPacket.MarshalBinary() if err := NewPlayPacket().UnmarshalBinary(append(playPrefix, 0xff)); err == nil || !strings.Contains(err.Error(), "stream name") { t.Fatalf("play stream name err=%v", err) } }