Codex: Expose RTMP AMF0 interfaces.

Add public AMF0 and RTMP protocol interfaces, update the RTMP proxy to use the accessor APIs, and cover AMF0 encoding/decoding with unit tests and examples.
This commit is contained in:
winlin 2026-04-29 11:26:40 -04:00
parent a76a982563
commit 9b6842da9a
5 changed files with 924 additions and 174 deletions

View File

@ -95,7 +95,7 @@ var createBuffer = func() amf0Buffer {
}
// All AMF0 things.
type amf0Any interface {
type Amf0Any interface {
// Binary marshaler and unmarshaler.
encoding.BinaryUnmarshaler
encoding.BinaryMarshaler
@ -106,59 +106,83 @@ type amf0Any interface {
amf0Marker() amf0Marker
}
type amf0Converter struct {
from amf0Any
type Amf0Converter interface {
ToNumber() Amf0Number
ToBoolean() Amf0Boolean
ToString() Amf0String
ToObject() Amf0Object
ToNull() Amf0Null
ToUndefined() Amf0Undefined
ToEcmaArray() Amf0EcmaArray
ToStrictArray() Amf0StrictArray
}
func NewAmf0Converter(from amf0Any) *amf0Converter {
type amf0Converter struct {
from Amf0Any
}
func NewAmf0Converter(from Amf0Any) Amf0Converter {
return &amf0Converter{from: from}
}
func (v *amf0Converter) ToNumber() *amf0Number {
return amf0AnyTo[*amf0Number](v.from)
}
func (v *amf0Converter) ToBoolean() *amf0Boolean {
return amf0AnyTo[*amf0Boolean](v.from)
}
func (v *amf0Converter) ToString() *amf0String {
return amf0AnyTo[*amf0String](v.from)
}
func (v *amf0Converter) ToObject() *amf0Object {
return amf0AnyTo[*amf0Object](v.from)
}
func (v *amf0Converter) ToNull() *amf0Null {
return amf0AnyTo[*amf0Null](v.from)
}
func (v *amf0Converter) ToUndefined() *amf0Undefined {
return amf0AnyTo[*amf0Undefined](v.from)
}
func (v *amf0Converter) ToEcmaArray() *amf0EcmaArray {
return amf0AnyTo[*amf0EcmaArray](v.from)
}
func (v *amf0Converter) ToStrictArray() *amf0StrictArray {
return amf0AnyTo[*amf0StrictArray](v.from)
}
// Convert any to specified object.
func amf0AnyTo[T amf0Any](a amf0Any) T {
var to T
if a != nil {
if v, ok := a.(T); ok {
return v
}
func (v *amf0Converter) ToNumber() Amf0Number {
if r, ok := v.from.(Amf0Number); ok {
return r
}
return to
return nil
}
func (v *amf0Converter) ToBoolean() Amf0Boolean {
if r, ok := v.from.(Amf0Boolean); ok {
return r
}
return nil
}
func (v *amf0Converter) ToString() Amf0String {
if r, ok := v.from.(Amf0String); ok {
return r
}
return nil
}
func (v *amf0Converter) ToObject() Amf0Object {
if r, ok := v.from.(Amf0Object); ok {
return r
}
return nil
}
func (v *amf0Converter) ToNull() Amf0Null {
if r, ok := v.from.(Amf0Null); ok {
return r
}
return nil
}
func (v *amf0Converter) ToUndefined() Amf0Undefined {
if r, ok := v.from.(Amf0Undefined); ok {
return r
}
return nil
}
func (v *amf0Converter) ToEcmaArray() Amf0EcmaArray {
if r, ok := v.from.(Amf0EcmaArray); ok {
return r
}
return nil
}
func (v *amf0Converter) ToStrictArray() Amf0StrictArray {
if r, ok := v.from.(Amf0StrictArray); ok {
return r
}
return nil
}
// Discovery the amf0 object from the bytes b.
func Amf0Discovery(p []byte) (a amf0Any, err error) {
func Amf0Discovery(p []byte) (a Amf0Any, err error) {
if len(p) < 1 {
return nil, errors.Errorf("require 1 bytes only %v", len(p))
}
@ -228,14 +252,24 @@ func (v *amf0UTF8) MarshalBinary() (data []byte, err error) {
return
}
// Amf0Number is the AMF0 number type.
type Amf0Number interface {
Amf0Any
Float64() float64
}
// The number object, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.2 Number Type
type amf0Number float64
func NewAmf0Number(f float64) *amf0Number {
func NewAmf0Number(f float64) Amf0Number {
v := amf0Number(f)
return &v
}
func (v *amf0Number) Float64() float64 {
return float64(*v)
}
func (v *amf0Number) amf0Marker() amf0Marker {
return amf0MarkerNumber
}
@ -266,14 +300,28 @@ func (v *amf0Number) MarshalBinary() (data []byte, err error) {
return
}
// Amf0String is the AMF0 string type.
type Amf0String interface {
Amf0Any
String() string
}
// The string objet, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.4 String Type
type amf0String string
func NewAmf0String(s string) *amf0String {
func NewAmf0String(s string) Amf0String {
return newAmf0String(s)
}
func newAmf0String(s string) *amf0String {
v := amf0String(s)
return &v
}
func (v *amf0String) String() string {
return string(*v)
}
func (v *amf0String) amf0Marker() amf0Marker {
return amf0MarkerString
}
@ -344,7 +392,7 @@ func (v *amf0ObjectEOF) MarshalBinary() (data []byte, err error) {
// Use array for object and ecma array, to keep the original order.
type amf0Property struct {
key amf0UTF8
value amf0Any
value Amf0Any
}
// The object-like AMF0 structure, like object and ecma array and strict array.
@ -367,7 +415,7 @@ func (v *amf0ObjectBase) Size() int {
return size
}
func (v *amf0ObjectBase) Get(key string) amf0Any {
func (v *amf0ObjectBase) Get(key string) Amf0Any {
v.lock.Lock()
defer v.lock.Unlock()
@ -380,7 +428,7 @@ func (v *amf0ObjectBase) Get(key string) amf0Any {
return nil
}
func (v *amf0ObjectBase) Set(key string, value amf0Any) *amf0ObjectBase {
func (v *amf0ObjectBase) Set(key string, value Amf0Any) *amf0ObjectBase {
v.lock.Lock()
defer v.lock.Unlock()
@ -411,21 +459,21 @@ func (v *amf0ObjectBase) unmarshal(p []byte, eof bool, maxElems int) (err error)
return errors.Errorf("maxElems=%v with eof", maxElems)
}
readOne := func() (amf0UTF8, amf0Any, error) {
readOne := func() (amf0UTF8, Amf0Any, error) {
var u amf0UTF8
if err = u.UnmarshalBinary(p); err != nil {
return "", nil, errors.WithMessage(err, "prop name")
}
p = p[u.Size():]
var a amf0Any
var a Amf0Any
if a, err = Amf0Discovery(p); err != nil {
return "", nil, errors.WithMessage(err, fmt.Sprintf("discover prop %v", string(u)))
}
return u, a, nil
}
pushOne := func(u amf0UTF8, a amf0Any) error {
pushOne := func(u amf0UTF8, a Amf0Any) error {
// For object property, consume the whole bytes.
if err = a.UnmarshalBinary(p); err != nil {
return errors.WithMessage(err, fmt.Sprintf("unmarshal prop %v", string(u)))
@ -494,13 +542,24 @@ func (v *amf0ObjectBase) marshal(b amf0Buffer) (err error) {
return
}
// Amf0Object is the AMF0 object type.
type Amf0Object interface {
Amf0Any
Get(key string) Amf0Any
Set(key string, value Amf0Any) Amf0Object
}
// The AMF0 object, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.5 Object Type
type amf0Object struct {
amf0ObjectBase
eof amf0ObjectEOF
}
func NewAmf0Object() *amf0Object {
func NewAmf0Object() Amf0Object {
return newAmf0Object()
}
func newAmf0Object() *amf0Object {
v := &amf0Object{}
v.properties = []*amf0Property{}
return v
@ -510,6 +569,15 @@ func (v *amf0Object) amf0Marker() amf0Marker {
return amf0MarkerObject
}
func (v *amf0Object) Get(key string) Amf0Any {
return v.amf0ObjectBase.Get(key)
}
func (v *amf0Object) Set(key string, value Amf0Any) Amf0Object {
v.amf0ObjectBase.Set(key, value)
return v
}
func (v *amf0Object) Size() int {
return int(1) + v.eof.Size() + v.amf0ObjectBase.Size()
}
@ -542,17 +610,22 @@ func (v *amf0Object) MarshalBinary() (data []byte, err error) {
return nil, errors.WithMessage(err, "marshal")
}
var pb []byte
if pb, err = v.eof.MarshalBinary(); err != nil {
if pb, err := v.eof.MarshalBinary(); err != nil {
return nil, errors.WithMessage(err, "marshal")
}
if _, err = b.Write(pb); err != nil {
} else if _, err = b.Write(pb); err != nil {
return nil, errors.Wrap(err, "marshal")
}
return b.Bytes(), nil
}
// Amf0EcmaArray is the AMF0 ECMA array type.
type Amf0EcmaArray interface {
Amf0Any
Get(key string) Amf0Any
Set(key string, value Amf0Any) Amf0EcmaArray
}
// The AMF0 ecma array, please read @doc amf0_spec_121207.pdf, @page 6, @section 2.10 ECMA Array Type
type amf0EcmaArray struct {
amf0ObjectBase
@ -560,7 +633,11 @@ type amf0EcmaArray struct {
eof amf0ObjectEOF
}
func NewAmf0EcmaArray() *amf0EcmaArray {
func NewAmf0EcmaArray() Amf0EcmaArray {
return newAmf0EcmaArray()
}
func newAmf0EcmaArray() *amf0EcmaArray {
v := &amf0EcmaArray{}
v.properties = []*amf0Property{}
return v
@ -570,6 +647,15 @@ func (v *amf0EcmaArray) amf0Marker() amf0Marker {
return amf0MarkerEcmaArray
}
func (v *amf0EcmaArray) Get(key string) Amf0Any {
return v.amf0ObjectBase.Get(key)
}
func (v *amf0EcmaArray) Set(key string, value Amf0Any) Amf0EcmaArray {
v.amf0ObjectBase.Set(key, value)
return v
}
func (v *amf0EcmaArray) Size() int {
return int(1) + 4 + v.eof.Size() + v.amf0ObjectBase.Size()
}
@ -606,24 +692,29 @@ func (v *amf0EcmaArray) MarshalBinary() (data []byte, err error) {
return nil, errors.WithMessage(err, "marshal")
}
var pb []byte
if pb, err = v.eof.MarshalBinary(); err != nil {
if pb, err := v.eof.MarshalBinary(); err != nil {
return nil, errors.WithMessage(err, "marshal")
}
if _, err = b.Write(pb); err != nil {
} else if _, err = b.Write(pb); err != nil {
return nil, errors.Wrap(err, "marshal")
}
return b.Bytes(), nil
}
// Amf0StrictArray is the AMF0 strict array type.
type Amf0StrictArray interface {
Amf0Any
Get(key string) Amf0Any
Set(key string, value Amf0Any) Amf0StrictArray
}
// The AMF0 strict array, please read @doc amf0_spec_121207.pdf, @page 7, @section 2.12 Strict Array Type
type amf0StrictArray struct {
amf0ObjectBase
count uint32
}
func NewAmf0StrictArray() *amf0StrictArray {
func NewAmf0StrictArray() Amf0StrictArray {
v := &amf0StrictArray{}
v.properties = []*amf0Property{}
return v
@ -633,6 +724,15 @@ func (v *amf0StrictArray) amf0Marker() amf0Marker {
return amf0MarkerStrictArray
}
func (v *amf0StrictArray) Get(key string) Amf0Any {
return v.amf0ObjectBase.Get(key)
}
func (v *amf0StrictArray) Set(key string, value Amf0Any) Amf0StrictArray {
v.amf0ObjectBase.Set(key, value)
return v
}
func (v *amf0StrictArray) Size() int {
return int(1) + 4 + v.amf0ObjectBase.Size()
}
@ -708,36 +808,56 @@ func (v *amf0SingleMarkerObject) MarshalBinary() (data []byte, err error) {
return []byte{byte(v.target)}, nil
}
// Amf0Null is the AMF0 null type.
type Amf0Null interface {
Amf0Any
}
// The AMF0 null, please read @doc amf0_spec_121207.pdf, @page 6, @section 2.7 null Type
type amf0Null struct {
amf0SingleMarkerObject
}
func NewAmf0Null() *amf0Null {
func NewAmf0Null() Amf0Null {
v := amf0Null{}
v.amf0SingleMarkerObject = newAmf0SingleMarkerObject(amf0MarkerNull)
return &v
}
// Amf0Undefined is the AMF0 undefined type.
type Amf0Undefined interface {
Amf0Any
}
// The AMF0 undefined, please read @doc amf0_spec_121207.pdf, @page 6, @section 2.8 undefined Type
type amf0Undefined struct {
amf0SingleMarkerObject
}
func NewAmf0Undefined() amf0Any {
func NewAmf0Undefined() Amf0Undefined {
v := amf0Undefined{}
v.amf0SingleMarkerObject = newAmf0SingleMarkerObject(amf0MarkerUndefined)
return &v
}
// Amf0Boolean is the public typed view of an AMF0 boolean.
type Amf0Boolean interface {
Amf0Any
Bool() bool
}
// The AMF0 boolean, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.3 Boolean Type
type amf0Boolean bool
func NewAmf0Boolean(b bool) amf0Any {
func NewAmf0Boolean(b bool) Amf0Boolean {
v := amf0Boolean(b)
return &v
}
func (v *amf0Boolean) Bool() bool {
return bool(*v)
}
func (v *amf0Boolean) amf0Marker() amf0Marker {
return amf0MarkerBoolean
}

509
internal/rtmp/amf0_test.go Normal file
View File

@ -0,0 +1,509 @@
// Copyright (c) 2026 Winlin
//
// SPDX-License-Identifier: MIT
package rtmp
import (
"bytes"
"fmt"
"math"
"strings"
"testing"
)
func TestAmf0MarkerString(t *testing.T) {
for _, tt := range []struct {
marker amf0Marker
want string
}{
{amf0MarkerNumber, "Amf0Number"},
{amf0MarkerBoolean, "amf0Boolean"},
{amf0MarkerString, "Amf0String"},
{amf0MarkerObject, "Amf0Object"},
{amf0MarkerMovieClip, "MovieClip"},
{amf0MarkerNull, "Null"},
{amf0MarkerUndefined, "Undefined"},
{amf0MarkerReference, "Reference"},
{amf0MarkerEcmaArray, "EcmaArray"},
{amf0MarkerObjectEnd, "ObjectEnd"},
{amf0MarkerStrictArray, "StrictArray"},
{amf0MarkerDate, "Date"},
{amf0MarkerLongString, "LongString"},
{amf0MarkerUnsupported, "Unsupported"},
{amf0MarkerRecordSet, "RecordSet"},
{amf0MarkerXmlDocument, "XmlDocument"},
{amf0MarkerTypedObject, "TypedObject"},
{amf0MarkerAvmPlusObject, "AvmPlusObject"},
{amf0MarkerForbidden, "Forbidden"},
{amf0Marker(0xee), "Forbidden"},
} {
if got := tt.marker.String(); got != tt.want {
t.Fatalf("marker=%#x String()=%v, want %v", byte(tt.marker), got, tt.want)
}
}
}
func TestAmf0Discovery(t *testing.T) {
for _, tt := range []struct {
name string
data []byte
ok func(Amf0Any) bool
}{
{"number", []byte{byte(amf0MarkerNumber)}, func(v Amf0Any) bool { _, ok := v.(Amf0Number); return ok }},
{"boolean", []byte{byte(amf0MarkerBoolean)}, func(v Amf0Any) bool { _, ok := v.(Amf0Boolean); return ok }},
{"string", []byte{byte(amf0MarkerString)}, func(v Amf0Any) bool { _, ok := v.(Amf0String); return ok }},
{"object", []byte{byte(amf0MarkerObject)}, func(v Amf0Any) bool { _, ok := v.(Amf0Object); return ok }},
{"null", []byte{byte(amf0MarkerNull)}, func(v Amf0Any) bool { _, ok := v.(Amf0Null); return ok }},
{"undefined", []byte{byte(amf0MarkerUndefined)}, func(v Amf0Any) bool { _, ok := v.(Amf0Undefined); return ok }},
{"ecma-array", []byte{byte(amf0MarkerEcmaArray)}, func(v Amf0Any) bool { _, ok := v.(Amf0EcmaArray); return ok }},
{"object-end", []byte{byte(amf0MarkerObjectEnd)}, func(v Amf0Any) bool { _, ok := v.(*amf0ObjectEOF); return ok }},
{"strict-array", []byte{byte(amf0MarkerStrictArray)}, func(v Amf0Any) bool { _, ok := v.(Amf0StrictArray); return ok }},
} {
t.Run(tt.name, func(t *testing.T) {
value, err := Amf0Discovery(tt.data)
if err != nil {
t.Fatalf("Amf0Discovery() err=%v", err)
}
if !tt.ok(value) {
t.Fatalf("Amf0Discovery()=%T", value)
}
})
}
for _, data := range [][]byte{{}, {byte(amf0MarkerReference)}, {byte(amf0MarkerDate)}, {byte(amf0MarkerForbidden)}} {
if value, err := Amf0Discovery(data); err == nil || value != nil {
t.Fatalf("Amf0Discovery(%v) value=%T, err=%v, want error", data, value, err)
}
}
}
func TestAmf0Converter(t *testing.T) {
values := []struct {
name string
in Amf0Any
ok func(Amf0Converter) bool
}{
{"number", NewAmf0Number(1), func(c Amf0Converter) bool { return c.ToNumber() != nil }},
{"boolean", NewAmf0Boolean(true), func(c Amf0Converter) bool { return c.ToBoolean() != nil }},
{"string", NewAmf0String("v"), func(c Amf0Converter) bool { return c.ToString() != nil }},
{"object", NewAmf0Object(), func(c Amf0Converter) bool { return c.ToObject() != nil }},
{"null", NewAmf0Null(), func(c Amf0Converter) bool { return c.ToNull() != nil }},
{"undefined", NewAmf0Undefined(), func(c Amf0Converter) bool { return c.ToUndefined() != nil }},
{"ecma-array", NewAmf0EcmaArray(), func(c Amf0Converter) bool { return c.ToEcmaArray() != nil }},
{"strict-array", NewAmf0StrictArray(), func(c Amf0Converter) bool { return c.ToStrictArray() != nil }},
}
for _, tt := range values {
t.Run(tt.name, func(t *testing.T) {
converter := NewAmf0Converter(tt.in)
if !tt.ok(converter) {
t.Fatalf("expected successful conversion for %T", tt.in)
}
})
}
nilConverter := NewAmf0Converter(nil)
if nilConverter.ToNumber() != nil || nilConverter.ToBoolean() != nil || nilConverter.ToString() != nil ||
nilConverter.ToObject() != nil || nilConverter.ToNull() != nil || nilConverter.ToUndefined() != nil ||
nilConverter.ToEcmaArray() != nil || nilConverter.ToStrictArray() != nil {
t.Fatal("nil converter should not convert")
}
}
func TestAmf0UTF8(t *testing.T) {
var value amf0UTF8 = "hello"
b, err := value.MarshalBinary()
if err != nil {
t.Fatalf("MarshalBinary() err=%v", err)
}
if value.Size() != len(b) {
t.Fatalf("Size()=%v, len=%v", value.Size(), len(b))
}
var decoded amf0UTF8
if err := decoded.UnmarshalBinary(b); err != nil {
t.Fatalf("UnmarshalBinary() err=%v", err)
}
if decoded != value {
t.Fatalf("decoded=%v, want %v", decoded, value)
}
for _, data := range [][]byte{{0x00}, {0x00, 0x05, 'h'}} {
if err := decoded.UnmarshalBinary(data); err == nil {
t.Fatalf("UnmarshalBinary(%v) should fail", data)
}
}
}
func TestAmf0Number(t *testing.T) {
number := NewAmf0Number(math.Pi)
if number.Size() != 9 || number.(*amf0Number).amf0Marker() != amf0MarkerNumber {
t.Fatalf("unexpected number metadata")
}
b, err := number.MarshalBinary()
if err != nil {
t.Fatalf("MarshalBinary() err=%v", err)
}
decoded := NewAmf0Number(0)
if err := decoded.UnmarshalBinary(b); err != nil {
t.Fatalf("UnmarshalBinary() err=%v", err)
}
if got := decoded.Float64(); got != math.Pi {
t.Fatalf("Float64()=%v, want %v", got, math.Pi)
}
for _, data := range [][]byte{{byte(amf0MarkerNumber)}, append([]byte{byte(amf0MarkerString)}, b[1:]...)} {
if err := decoded.UnmarshalBinary(data); err == nil {
t.Fatalf("UnmarshalBinary(%v) should fail", data)
}
}
}
func TestAmf0Boolean(t *testing.T) {
for _, want := range []bool{false, true} {
boolean := NewAmf0Boolean(want)
if boolean.Size() != 2 || boolean.(*amf0Boolean).amf0Marker() != amf0MarkerBoolean {
t.Fatalf("unexpected boolean metadata")
}
b, err := boolean.MarshalBinary()
if err != nil {
t.Fatalf("MarshalBinary() err=%v", err)
}
decoded := NewAmf0Boolean(!want)
if err := decoded.UnmarshalBinary(b); err != nil {
t.Fatalf("UnmarshalBinary() err=%v", err)
}
if got := decoded.Bool(); got != want {
t.Fatalf("Bool()=%v, want %v", got, want)
}
}
decoded := NewAmf0Boolean(false)
for _, data := range [][]byte{{byte(amf0MarkerBoolean)}, {byte(amf0MarkerNumber), 1}} {
if err := decoded.UnmarshalBinary(data); err == nil {
t.Fatalf("UnmarshalBinary(%v) should fail", data)
}
}
}
func TestAmf0String(t *testing.T) {
value := NewAmf0String("hello")
if value.Size() != 8 || value.(*amf0String).amf0Marker() != amf0MarkerString {
t.Fatalf("unexpected string metadata")
}
b, err := value.MarshalBinary()
if err != nil {
t.Fatalf("MarshalBinary() err=%v", err)
}
decoded := NewAmf0String("")
if err := decoded.UnmarshalBinary(b); err != nil {
t.Fatalf("UnmarshalBinary() err=%v", err)
}
if got := decoded.String(); got != "hello" {
t.Fatalf("String()=%v, want hello", got)
}
for _, data := range [][]byte{{}, {byte(amf0MarkerNumber), 0, 0}, {byte(amf0MarkerString), 0, 5, 'h'}} {
if err := decoded.UnmarshalBinary(data); err == nil {
t.Fatalf("UnmarshalBinary(%v) should fail", data)
}
}
}
func TestAmf0ObjectEOF(t *testing.T) {
eof := &amf0ObjectEOF{}
if eof.Size() != 3 || eof.amf0Marker() != amf0MarkerObjectEnd {
t.Fatalf("unexpected eof metadata")
}
b, err := eof.MarshalBinary()
if err != nil {
t.Fatalf("MarshalBinary() err=%v", err)
}
if !bytes.Equal(b, []byte{0, 0, 9}) {
t.Fatalf("MarshalBinary()=%v", b)
}
for _, data := range [][]byte{b, {0, 0, 9, 1}} {
if err := eof.UnmarshalBinary(data); err != nil {
t.Fatalf("UnmarshalBinary(%v) err=%v", data, err)
}
}
for _, data := range [][]byte{{0, 0}, {0, 1, 9}} {
if err := eof.UnmarshalBinary(data); err == nil {
t.Fatalf("UnmarshalBinary(%v) should fail", data)
}
}
}
func TestAmf0Object(t *testing.T) {
object := NewAmf0Object().
Set("name", NewAmf0String("stream")).
Set("code", NewAmf0Number(100)).
Set("ok", NewAmf0Boolean(true))
object.Set("code", NewAmf0Number(200))
if object.(*amf0Object).amf0Marker() != amf0MarkerObject || object.Size() == 0 {
t.Fatalf("unexpected object metadata")
}
if object.Get("missing") != nil {
t.Fatal("missing property should be nil")
}
b, err := object.MarshalBinary()
if err != nil {
t.Fatalf("MarshalBinary() err=%v", err)
}
decoded := NewAmf0Object()
if err := decoded.UnmarshalBinary(b); err != nil {
t.Fatalf("UnmarshalBinary() err=%v", err)
}
if got := NewAmf0Converter(decoded.Get("name")).ToString().String(); got != "stream" {
t.Fatalf("name=%v", got)
}
if got := NewAmf0Converter(decoded.Get("code")).ToNumber().Float64(); got != 200 {
t.Fatalf("code=%v", got)
}
if got := NewAmf0Converter(decoded.Get("ok")).ToBoolean().Bool(); !got {
t.Fatalf("ok=%v", got)
}
for _, data := range [][]byte{{}, {byte(amf0MarkerString)}, {byte(amf0MarkerObject), 0, 4, 'n'}} {
if err := decoded.UnmarshalBinary(data); err == nil {
t.Fatalf("UnmarshalBinary(%v) should fail", data)
}
}
base := &amf0ObjectBase{}
if err := base.unmarshal(nil, false, -1); err == nil {
t.Fatal("unmarshal without eof and negative maxElems should fail")
}
if err := base.unmarshal(nil, true, 0); err == nil {
t.Fatal("unmarshal with eof and non-negative maxElems should fail")
}
}
func TestAmf0EcmaArray(t *testing.T) {
array := NewAmf0EcmaArray().
Set("name", NewAmf0String("stream")).
Set("code", NewAmf0Number(100))
if array.(*amf0EcmaArray).amf0Marker() != amf0MarkerEcmaArray || array.Size() == 0 {
t.Fatalf("unexpected ecma array metadata")
}
if array.Get("missing") != nil {
t.Fatal("missing property should be nil")
}
b, err := array.MarshalBinary()
if err != nil {
t.Fatalf("MarshalBinary() err=%v", err)
}
decoded := NewAmf0EcmaArray()
if err := decoded.UnmarshalBinary(b); err != nil {
t.Fatalf("UnmarshalBinary() err=%v", err)
}
if got := NewAmf0Converter(decoded.Get("name")).ToString().String(); got != "stream" {
t.Fatalf("name=%v", got)
}
if got := NewAmf0Converter(decoded.Get("code")).ToNumber().Float64(); got != 100 {
t.Fatalf("code=%v", got)
}
for _, data := range [][]byte{{}, {byte(amf0MarkerEcmaArray), 0}, {byte(amf0MarkerString), 0, 0, 0, 0}, {byte(amf0MarkerEcmaArray), 0, 0, 0, 0, 0, 4, 'n'}} {
if err := decoded.UnmarshalBinary(data); err == nil {
t.Fatalf("UnmarshalBinary(%v) should fail", data)
}
}
}
func TestAmf0StrictArray(t *testing.T) {
array := NewAmf0StrictArray().
Set("name", NewAmf0String("stream")).
Set("code", NewAmf0Number(100))
array.(*amf0StrictArray).count = 2
if array.(*amf0StrictArray).amf0Marker() != amf0MarkerStrictArray || array.Size() == 0 {
t.Fatalf("unexpected strict array metadata")
}
if array.Get("missing") != nil {
t.Fatal("missing property should be nil")
}
b, err := array.MarshalBinary()
if err != nil {
t.Fatalf("MarshalBinary() err=%v", err)
}
decoded := NewAmf0StrictArray()
if err := decoded.UnmarshalBinary(b); err != nil {
t.Fatalf("UnmarshalBinary() err=%v", err)
}
if got := NewAmf0Converter(decoded.Get("name")).ToString().String(); got != "stream" {
t.Fatalf("name=%v", got)
}
if got := NewAmf0Converter(decoded.Get("code")).ToNumber().Float64(); got != 100 {
t.Fatalf("code=%v", got)
}
empty := append([]byte{byte(amf0MarkerStrictArray)}, 0, 0, 0, 0)
if err := decoded.UnmarshalBinary(empty); err != nil {
t.Fatalf("UnmarshalBinary(empty) err=%v", err)
}
for _, data := range [][]byte{{}, {byte(amf0MarkerStrictArray), 0}, {byte(amf0MarkerString), 0, 0, 0, 0}, {byte(amf0MarkerStrictArray), 0, 0, 0, 1, 0, 4, 'n'}} {
if err := NewAmf0StrictArray().UnmarshalBinary(data); err == nil {
t.Fatalf("UnmarshalBinary(%v) should fail", data)
}
}
}
func TestAmf0SingleMarkerObjects(t *testing.T) {
for _, tt := range []struct {
name string
value Amf0Any
marker amf0Marker
}{
{"null", NewAmf0Null(), amf0MarkerNull},
{"undefined", NewAmf0Undefined(), amf0MarkerUndefined},
} {
t.Run(tt.name, func(t *testing.T) {
if tt.value.Size() != 1 || tt.value.amf0Marker() != tt.marker {
t.Fatalf("unexpected metadata")
}
b, err := tt.value.MarshalBinary()
if err != nil {
t.Fatalf("MarshalBinary() err=%v", err)
}
if err := tt.value.UnmarshalBinary(b); err != nil {
t.Fatalf("UnmarshalBinary() err=%v", err)
}
for _, data := range [][]byte{{}, {byte(amf0MarkerString)}} {
if err := tt.value.UnmarshalBinary(data); err == nil {
t.Fatalf("UnmarshalBinary(%v) should fail", data)
}
}
})
}
}
type errorAmf0Buffer struct {
writeByteErr bool
writeErr bool
}
func (v *errorAmf0Buffer) Bytes() []byte {
return nil
}
func (v *errorAmf0Buffer) WriteByte(byte) error {
if v.writeByteErr {
return fmt.Errorf("write byte")
}
return nil
}
func (v *errorAmf0Buffer) Write([]byte) (int, error) {
if v.writeErr {
return 0, fmt.Errorf("write")
}
return 0, nil
}
type errorAmf0Any struct {
Amf0Any
}
func (v *errorAmf0Any) Size() int {
return 1
}
func (v *errorAmf0Any) MarshalBinary() ([]byte, error) {
return nil, fmt.Errorf("marshal")
}
func (v *errorAmf0Any) UnmarshalBinary([]byte) error {
return nil
}
func (v *errorAmf0Any) amf0Marker() amf0Marker {
return amf0MarkerNumber
}
func TestAmf0MarshalErrors(t *testing.T) {
originalCreateBuffer := createBuffer
defer func() { createBuffer = originalCreateBuffer }()
for _, tt := range []struct {
name string
make func() Amf0Any
}{
{"object", func() Amf0Any { return NewAmf0Object() }},
{"ecma-array", func() Amf0Any { return NewAmf0EcmaArray() }},
{"strict-array", func() Amf0Any { return NewAmf0StrictArray() }},
} {
t.Run(tt.name+" write-byte", func(t *testing.T) {
createBuffer = func() amf0Buffer { return &errorAmf0Buffer{writeByteErr: true} }
if _, err := tt.make().MarshalBinary(); err == nil {
t.Fatal("MarshalBinary() should fail")
}
})
t.Run(tt.name+" write-prop", func(t *testing.T) {
createBuffer = func() amf0Buffer { return &errorAmf0Buffer{writeErr: true} }
value := tt.make()
switch v := value.(type) {
case Amf0Object:
v.Set("name", NewAmf0String("stream"))
case Amf0EcmaArray:
v.Set("name", NewAmf0String("stream"))
case Amf0StrictArray:
v.Set("name", NewAmf0String("stream"))
v.(*amf0StrictArray).count = 1
}
if _, err := value.MarshalBinary(); err == nil {
t.Fatal("MarshalBinary() should fail")
}
})
}
createBuffer = originalCreateBuffer
for _, tt := range []struct {
name string
make func() Amf0Any
}{
{"object", func() Amf0Any { return NewAmf0Object().Set("bad", &errorAmf0Any{}) }},
{"ecma-array", func() Amf0Any { return NewAmf0EcmaArray().Set("bad", &errorAmf0Any{}) }},
{"strict-array", func() Amf0Any {
value := NewAmf0StrictArray().Set("bad", &errorAmf0Any{})
value.(*amf0StrictArray).count = 1
return value
}},
} {
t.Run(tt.name+" marshal-value", func(t *testing.T) {
if _, err := tt.make().MarshalBinary(); err == nil {
t.Fatal("MarshalBinary() should fail")
}
})
}
}
func TestAmf0UnmarshalNestedErrors(t *testing.T) {
// Object property with unsupported marker.
data := []byte{byte(amf0MarkerObject), 0, 3, 'b', 'a', 'd', byte(amf0MarkerDate)}
if err := NewAmf0Object().UnmarshalBinary(data); err == nil || !strings.Contains(err.Error(), "discover prop bad") {
t.Fatalf("err=%v, want discover prop bad", err)
}
// Object property with invalid payload size.
data = []byte{byte(amf0MarkerObject), 0, 3, 'b', 'a', 'd', byte(amf0MarkerNumber), 0}
if err := NewAmf0Object().UnmarshalBinary(data); err == nil || !strings.Contains(err.Error(), "unmarshal prop bad") {
t.Fatalf("err=%v, want unmarshal prop bad", err)
}
}

View File

@ -0,0 +1,62 @@
// Copyright (c) 2026 Winlin
//
// SPDX-License-Identifier: MIT
package rtmp_test
import (
"fmt"
"srsx/internal/rtmp"
)
func ExampleAmf0Number() {
number := rtmp.NewAmf0Number(3.14)
b, err := number.MarshalBinary()
if err != nil {
panic(err)
}
value, err := rtmp.Amf0Discovery(b)
if err != nil {
panic(err)
}
if err := value.UnmarshalBinary(b); err != nil {
panic(err)
}
converter := rtmp.NewAmf0Converter(value)
fmt.Println("number:", converter.ToNumber().Float64())
fmt.Println("is string:", converter.ToString() != nil)
// Output:
// number: 3.14
// is string: false
}
func ExampleAmf0Object() {
object := rtmp.NewAmf0Object().
Set("code", rtmp.NewAmf0Number(100)).
Set("level", rtmp.NewAmf0String("status"))
b, err := object.MarshalBinary()
if err != nil {
panic(err)
}
value, err := rtmp.Amf0Discovery(b)
if err != nil {
panic(err)
}
if err := value.UnmarshalBinary(b); err != nil {
panic(err)
}
converter := rtmp.NewAmf0Converter(value)
fmt.Println("code:", rtmp.NewAmf0Converter(converter.ToObject().Get("code")).ToNumber().Float64())
fmt.Println("level:", rtmp.NewAmf0Converter(converter.ToObject().Get("level")).ToString().String())
fmt.Println("is number:", converter.ToNumber() != nil)
// Output:
// code: 100
// level: status
// is number: false
}

View File

@ -17,21 +17,31 @@ import (
"srsx/internal/errors"
)
// The handshake implements the RTMP handshake protocol.
type Handshake struct {
// Handshake implements the RTMP handshake protocol.
type Handshake interface {
C1S1() []byte
WriteC0S0(w io.Writer) error
ReadC0S0(r io.Reader) ([]byte, error)
WriteC1S1(w io.Writer) error
ReadC1S1(r io.Reader) ([]byte, error)
WriteC2S2(w io.Writer, s1c1 []byte) error
ReadC2S2(r io.Reader) ([]byte, error)
}
type handshake struct {
// The c1s1 cache.
c1s1 []byte
}
func NewHandshake() *Handshake {
return &Handshake{}
func NewHandshake() Handshake {
return &handshake{}
}
func (v *Handshake) C1S1() []byte {
func (v *handshake) C1S1() []byte {
return v.c1s1
}
func (v *Handshake) WriteC0S0(w io.Writer) (err error) {
func (v *handshake) WriteC0S0(w io.Writer) (err error) {
r := bytes.NewReader([]byte{0x03})
if _, err = io.Copy(w, r); err != nil {
return errors.Wrap(err, "write c0s0")
@ -40,7 +50,7 @@ func (v *Handshake) WriteC0S0(w io.Writer) (err error) {
return
}
func (v *Handshake) ReadC0S0(r io.Reader) (c0 []byte, err error) {
func (v *handshake) ReadC0S0(r io.Reader) (c0 []byte, err error) {
b := &bytes.Buffer{}
if _, err = io.CopyN(b, r, 1); err != nil {
return nil, errors.Wrap(err, "read c0s0")
@ -51,7 +61,7 @@ func (v *Handshake) ReadC0S0(r io.Reader) (c0 []byte, err error) {
return
}
func (v *Handshake) WriteC1S1(w io.Writer) (err error) {
func (v *handshake) WriteC1S1(w io.Writer) (err error) {
p := make([]byte, 1536)
// Use crypto/rand for thread-safe random generation
@ -67,7 +77,7 @@ func (v *Handshake) WriteC1S1(w io.Writer) (err error) {
return
}
func (v *Handshake) ReadC1S1(r io.Reader) (c1s1 []byte, err error) {
func (v *handshake) ReadC1S1(r io.Reader) (c1s1 []byte, err error) {
b := &bytes.Buffer{}
if _, err = io.CopyN(b, r, 1536); err != nil {
return nil, errors.Wrap(err, "read c1s1")
@ -79,7 +89,7 @@ func (v *Handshake) ReadC1S1(r io.Reader) (c1s1 []byte, err error) {
return
}
func (v *Handshake) WriteC2S2(w io.Writer, s1c1 []byte) (err error) {
func (v *handshake) WriteC2S2(w io.Writer, s1c1 []byte) (err error) {
r := bytes.NewReader(s1c1[:])
if _, err = io.Copy(w, r); err != nil {
return errors.Wrap(err, "write c2s2")
@ -88,7 +98,7 @@ func (v *Handshake) WriteC2S2(w io.Writer, s1c1 []byte) (err error) {
return
}
func (v *Handshake) ReadC2S2(r io.Reader) (c2 []byte, err error) {
func (v *handshake) ReadC2S2(r io.Reader) (c2 []byte, err error) {
b := &bytes.Buffer{}
if _, err = io.CopyN(b, r, 1536); err != nil {
return nil, errors.Wrap(err, "read c2s2")
@ -129,7 +139,7 @@ type chunkStream struct {
format formatType
cid chunkID
header messageHeader
message *Message
message *message
count uint64
extendedTimestamp bool
}
@ -138,8 +148,18 @@ func newChunkStream() *chunkStream {
return &chunkStream{}
}
// The protocol implements the RTMP command and chunk stack.
type Protocol struct {
// Protocol implements the RTMP command and chunk stack.
type Protocol interface {
// Deprecated: Please use rtmp.ExpectPacket instead.
ExpectPacket(ctx context.Context, ppkt any) (Message, error)
ExpectMessage(ctx context.Context, types ...MessageType) (Message, error)
DecodeMessage(m Message) (Packet, error)
ReadMessage(ctx context.Context) (Message, error)
WritePacket(ctx context.Context, pkt Packet, streamID int) error
WriteMessage(ctx context.Context, m Message) error
}
type protocol struct {
r *bufio.Reader
w *bufio.Writer
input struct {
@ -154,8 +174,8 @@ type Protocol struct {
}
}
func NewProtocol(rw io.ReadWriter) *Protocol {
v := &Protocol{
func NewProtocol(rw io.ReadWriter) Protocol {
v := &protocol{
r: bufio.NewReader(rw),
w: bufio.NewWriter(rw),
}
@ -169,7 +189,7 @@ func NewProtocol(rw io.ReadWriter) *Protocol {
return v
}
func ExpectPacket[T Packet](ctx context.Context, v *Protocol, ppkt *T) (m *Message, err error) {
func ExpectPacket[T Packet](ctx context.Context, v Protocol, ppkt *T) (m Message, err error) {
for {
if m, err = v.ReadMessage(ctx); err != nil {
return nil, errors.WithMessage(err, "read message")
@ -190,11 +210,11 @@ func ExpectPacket[T Packet](ctx context.Context, v *Protocol, ppkt *T) (m *Messa
}
// Deprecated: Please use rtmp.ExpectPacket instead.
func (v *Protocol) ExpectPacket(ctx context.Context, ppkt any) (m *Message, err error) {
func (v *protocol) ExpectPacket(ctx context.Context, ppkt any) (m Message, err error) {
panic("Please use rtmp.ExpectPacket instead")
}
func (v *Protocol) ExpectMessage(ctx context.Context, types ...MessageType) (m *Message, err error) {
func (v *protocol) ExpectMessage(ctx context.Context, types ...MessageType) (m Message, err error) {
for {
if m, err = v.ReadMessage(ctx); err != nil {
return nil, errors.WithMessage(err, "read message")
@ -205,7 +225,7 @@ func (v *Protocol) ExpectMessage(ctx context.Context, types ...MessageType) (m *
}
for _, t := range types {
if m.MessageType == t {
if m.MessageType() == t {
return
}
}
@ -214,7 +234,7 @@ func (v *Protocol) ExpectMessage(ctx context.Context, types ...MessageType) (m *
return
}
func (v *Protocol) parseAMFObject(p []byte) (pkt Packet, err error) {
func (v *protocol) parseAMFObject(p []byte) (pkt Packet, err error) {
var commandName amf0String
if err = commandName.UnmarshalBinary(p); err != nil {
return nil, errors.WithMessage(err, "unmarshal command name")
@ -266,18 +286,18 @@ func (v *Protocol) parseAMFObject(p []byte) (pkt Packet, err error) {
}
}
func (v *Protocol) DecodeMessage(m *Message) (pkt Packet, err error) {
p := m.Payload[:]
func (v *protocol) DecodeMessage(m Message) (pkt Packet, err error) {
p := m.Payload()[:]
if len(p) == 0 {
return nil, errors.New("Empty packet")
}
switch m.MessageType {
switch m.MessageType() {
case MessageTypeAMF3Command, MessageTypeAMF3Data:
p = p[1:]
}
switch m.MessageType {
switch m.MessageType() {
case MessageTypeSetChunkSize:
pkt = NewSetChunkSize()
case MessageTypeWindowAcknowledgementSize:
@ -286,22 +306,22 @@ func (v *Protocol) DecodeMessage(m *Message) (pkt Packet, err error) {
pkt = NewSetPeerBandwidth()
case MessageTypeAMF0Command, MessageTypeAMF3Command, MessageTypeAMF0Data, MessageTypeAMF3Data:
if pkt, err = v.parseAMFObject(p); err != nil {
return nil, errors.WithMessage(err, fmt.Sprintf("Parse AMF %v", m.MessageType))
return nil, errors.WithMessage(err, fmt.Sprintf("Parse AMF %v", m.MessageType()))
}
case MessageTypeUserControl:
pkt = NewUserControl()
default:
return nil, errors.Errorf("Unknown message %v", m.MessageType)
return nil, errors.Errorf("Unknown message %v", m.MessageType())
}
if err = pkt.UnmarshalBinary(p); err != nil {
return nil, errors.WithMessage(err, fmt.Sprintf("Unmarshal %v", m.MessageType))
return nil, errors.WithMessage(err, fmt.Sprintf("Unmarshal %v", m.MessageType()))
}
return
}
func (v *Protocol) ReadMessage(ctx context.Context) (m *Message, err error) {
func (v *protocol) ReadMessage(ctx context.Context) (m Message, err error) {
for m == nil {
// TODO: We should convert buffered io to async io, because we will be stuck in block io here,
// TODO: but the risk is acceptable because we literally will set the underlay io timeout.
@ -331,15 +351,17 @@ func (v *Protocol) ReadMessage(ctx context.Context) (m *Message, err error) {
return nil, errors.WithMessage(err, "read message payload")
}
if err = v.onMessageArrivated(m); err != nil {
return nil, errors.WithMessage(err, "on message")
if m != nil {
if err = v.onMessageArrivated(m.asMessage()); err != nil {
return nil, errors.WithMessage(err, "on message")
}
}
}
return
}
func (v *Protocol) readMessagePayload(ctx context.Context, chunk *chunkStream) (m *Message, err error) {
func (v *protocol) readMessagePayload(ctx context.Context, chunk *chunkStream) (m Message, err error) {
// Empty payload message.
if chunk.message.payloadLength == 0 {
m = chunk.message
@ -348,7 +370,7 @@ func (v *Protocol) readMessagePayload(ctx context.Context, chunk *chunkStream) (
}
// Calculate the chunk payload size.
chunkedPayloadSize := int(chunk.message.payloadLength) - len(chunk.message.Payload)
chunkedPayloadSize := int(chunk.message.payloadLength) - len(chunk.message.payload)
if chunkedPayloadSize > int(v.input.opt.chunkSize) {
chunkedPayloadSize = int(v.input.opt.chunkSize)
}
@ -357,10 +379,10 @@ func (v *Protocol) readMessagePayload(ctx context.Context, chunk *chunkStream) (
if _, err = io.ReadFull(v.r, b); err != nil {
return nil, errors.Wrapf(err, "read chunk %vB", chunkedPayloadSize)
}
chunk.message.Payload = append(chunk.message.Payload, b...)
chunk.message.payload = append(chunk.message.payload, b...)
// Got entire RTMP message?
if int(chunk.message.payloadLength) == len(chunk.message.Payload) {
if int(chunk.message.payloadLength) == len(chunk.message.payload) {
m = chunk.message
chunk.message = nil
}
@ -426,7 +448,7 @@ var messageHeaderSizes = []int{11, 7, 3, 0}
// fmt=1, 0x4X
// fmt=2, 0x8X
// fmt=3, 0xCX
func (v *Protocol) readMessageHeader(ctx context.Context, chunk *chunkStream, format formatType) (err error) {
func (v *protocol) readMessageHeader(ctx context.Context, chunk *chunkStream, format formatType) (err error) {
// We should not assert anything about fmt, for the first packet.
// (when first packet, the chunk.message is nil).
// the fmt maybe 0/1/2/3, the FMLE will send a 0xC4 for some audio packet.
@ -480,7 +502,7 @@ func (v *Protocol) readMessageHeader(ctx context.Context, chunk *chunkStream, fo
// Create msg when new chunk stream start
if chunk.message == nil {
chunk.message = NewMessage()
chunk.message = newMessage()
}
// Read the message header.
@ -659,7 +681,7 @@ func (v *Protocol) readMessageHeader(ctx context.Context, chunk *chunkStream, fo
//
// Chunk stream IDs with values 64-319 could be represented by both 2-
// byte version and 3-byte version of this field.
func (v *Protocol) readBasicHeader(ctx context.Context) (format formatType, cid chunkID, err error) {
func (v *protocol) readBasicHeader(ctx context.Context) (format formatType, cid chunkID, err error) {
// 2-63, 1B chunk header
var t uint8
if err = binary.Read(v.r, binary.BigEndian, &t); err != nil {
@ -689,14 +711,14 @@ func (v *Protocol) readBasicHeader(ctx context.Context) (format formatType, cid
return
}
func (v *Protocol) WritePacket(ctx context.Context, pkt Packet, streamID int) (err error) {
m := NewMessage()
func (v *protocol) WritePacket(ctx context.Context, pkt Packet, streamID int) (err error) {
m := newMessage()
if m.Payload, err = pkt.MarshalBinary(); err != nil {
if m.payload, err = pkt.MarshalBinary(); err != nil {
return errors.WithMessage(err, "marshal payload")
}
m.MessageType = pkt.Type()
m.messageHeader.MessageType = pkt.Type()
m.streamID = uint32(streamID)
m.betterCid = pkt.BetterCid()
@ -711,7 +733,7 @@ func (v *Protocol) WritePacket(ctx context.Context, pkt Packet, streamID int) (e
return
}
func (v *Protocol) onPacketWriten(m *Message, pkt Packet) (err error) {
func (v *protocol) onPacketWriten(m *message, pkt Packet) (err error) {
var tid amf0Number
var name amf0String
@ -734,16 +756,16 @@ func (v *Protocol) onPacketWriten(m *Message, pkt Packet) (err error) {
return
}
func (v *Protocol) onMessageArrivated(m *Message) (err error) {
func (v *protocol) onMessageArrivated(m *message) (err error) {
if m == nil {
return
}
var pkt Packet
switch m.MessageType {
switch m.MessageType() {
case MessageTypeSetChunkSize, MessageTypeUserControl, MessageTypeWindowAcknowledgementSize:
if pkt, err = v.DecodeMessage(m); err != nil {
return errors.Errorf("decode message %v", m.MessageType)
return errors.Errorf("decode message %v", m.MessageType())
}
}
@ -755,19 +777,20 @@ func (v *Protocol) onMessageArrivated(m *Message) (err error) {
return
}
func (v *Protocol) WriteMessage(ctx context.Context, m *Message) (err error) {
m.payloadLength = uint32(len(m.Payload))
func (v *protocol) WriteMessage(ctx context.Context, m Message) (err error) {
msg := m.asMessage()
msg.payloadLength = uint32(len(msg.payload))
var c0h, c3h []byte
if c0h, err = m.generateC0Header(); err != nil {
if c0h, err = msg.generateC0Header(); err != nil {
return errors.WithMessage(err, "generate c0 header")
}
if c3h, err = m.generateC3Header(); err != nil {
if c3h, err = msg.generateC3Header(); err != nil {
return errors.WithMessage(err, "generate c3 header")
}
var h []byte
p := m.Payload
p := msg.payload
for len(p) > 0 {
// TODO: We should convert buffered io to async io, because we will be stuck in block io here,
// TODO: but the risk is acceptable because we literally will set the underlay io timeout.
@ -899,29 +922,56 @@ type messageHeader struct {
Timestamp uint64
}
// The RTMP message, transport over chunk stream in RTMP.
// Message is an RTMP message transported over a chunk stream.
// Please read the cs id of @doc rtmp_specification_1.0.pdf, @page 30, @section 4.1. Message Header
type Message struct {
type Message interface {
MessageType() MessageType
Timestamp() uint64
Payload() []byte
asMessage() *message
}
type message struct {
messageHeader
// The payload which carries the RTMP packet.
Payload []byte
payload []byte
}
func NewMessage() *Message {
return &Message{}
func NewMessage() Message {
return newMessage()
}
func NewStreamMessage(streamID int) *Message {
v := NewMessage()
func newMessage() *message {
return &message{}
}
func NewStreamMessage(streamID int) Message {
v := newMessage()
v.streamID = uint32(streamID)
v.betterCid = chunkIDOverStream
return v
}
func (v *Message) generateC3Header() ([]byte, error) {
func (v *message) MessageType() MessageType {
return v.messageHeader.MessageType
}
func (v *message) Timestamp() uint64 {
return v.messageHeader.Timestamp
}
func (v *message) Payload() []byte {
return v.payload
}
func (v *message) asMessage() *message {
return v
}
func (v *message) generateC3Header() ([]byte, error) {
var c3h []byte
if v.Timestamp < extendedTimestamp {
if v.messageHeader.Timestamp < extendedTimestamp {
c3h = make([]byte, 1)
} else {
c3h = make([]byte, 1+4)
@ -935,19 +985,19 @@ func (v *Message) generateC3Header() ([]byte, error) {
// but actually all products from adobe, such as FMS/AMS and Flash player and FMLE,
// always carry a extended timestamp in C3 header.
// @see: http://blog.csdn.net/win_lin/article/details/13363699
if v.Timestamp >= extendedTimestamp {
p[0] = byte(v.Timestamp >> 24)
p[1] = byte(v.Timestamp >> 16)
p[2] = byte(v.Timestamp >> 8)
p[3] = byte(v.Timestamp)
if v.messageHeader.Timestamp >= extendedTimestamp {
p[0] = byte(v.messageHeader.Timestamp >> 24)
p[1] = byte(v.messageHeader.Timestamp >> 16)
p[2] = byte(v.messageHeader.Timestamp >> 8)
p[3] = byte(v.messageHeader.Timestamp)
}
return c3h, nil
}
func (v *Message) generateC0Header() ([]byte, error) {
func (v *message) generateC0Header() ([]byte, error) {
var c0h []byte
if v.Timestamp < extendedTimestamp {
if v.messageHeader.Timestamp < extendedTimestamp {
c0h = make([]byte, 1+3+3+1+4)
} else {
c0h = make([]byte, 1+3+3+1+4+4)
@ -957,10 +1007,10 @@ func (v *Message) generateC0Header() ([]byte, error) {
p[0] = byte(v.betterCid) & 0x3f
p = p[1:]
if v.Timestamp < extendedTimestamp {
p[0] = byte(v.Timestamp >> 16)
p[1] = byte(v.Timestamp >> 8)
p[2] = byte(v.Timestamp)
if v.messageHeader.Timestamp < extendedTimestamp {
p[0] = byte(v.messageHeader.Timestamp >> 16)
p[1] = byte(v.messageHeader.Timestamp >> 8)
p[2] = byte(v.messageHeader.Timestamp)
} else {
p[0] = 0xff
p[1] = 0xff
@ -973,7 +1023,7 @@ func (v *Message) generateC0Header() ([]byte, error) {
p[2] = byte(v.payloadLength)
p = p[3:]
p[0] = byte(v.MessageType)
p[0] = byte(v.messageHeader.MessageType)
p = p[1:]
p[0] = byte(v.streamID)
@ -982,11 +1032,11 @@ func (v *Message) generateC0Header() ([]byte, error) {
p[3] = byte(v.streamID >> 24)
p = p[4:]
if v.Timestamp >= extendedTimestamp {
p[0] = byte(v.Timestamp >> 24)
p[1] = byte(v.Timestamp >> 16)
p[2] = byte(v.Timestamp >> 8)
p[3] = byte(v.Timestamp)
if v.messageHeader.Timestamp >= extendedTimestamp {
p[0] = byte(v.messageHeader.Timestamp >> 24)
p[1] = byte(v.messageHeader.Timestamp >> 16)
p[2] = byte(v.messageHeader.Timestamp >> 8)
p[3] = byte(v.messageHeader.Timestamp)
}
return c0h, nil
@ -1039,8 +1089,8 @@ type Packet interface {
type objectCallPacket struct {
CommandName amf0String
TransactionID amf0Number
CommandObject *amf0Object
Args *amf0Object
CommandObject Amf0Object
Args Amf0Object
}
func (v *objectCallPacket) BetterCid() chunkID {
@ -1081,7 +1131,7 @@ func (v *objectCallPacket) UnmarshalBinary(data []byte) (err error) {
return
}
v.Args = NewAmf0Object()
v.Args = newAmf0Object()
if err = v.Args.UnmarshalBinary(p); err != nil {
return errors.WithMessage(err, "unmarshal args")
}
@ -1149,8 +1199,8 @@ func (v *ConnectAppPacket) UnmarshalBinary(data []byte) (err error) {
func (v *ConnectAppPacket) TcUrl() string {
if v.CommandObject != nil {
if v, ok := v.CommandObject.Get("tcUrl").(*amf0String); ok {
return string(*v)
if v, ok := v.CommandObject.Get("tcUrl").(Amf0String); ok {
return v.String()
}
}
return ""
@ -1172,9 +1222,9 @@ func NewConnectAppResPacket(tid amf0Number) *ConnectAppResPacket {
func (v *ConnectAppResPacket) SrsID() string {
if v.Args != nil {
if v, ok := v.Args.Get("data").(*amf0EcmaArray); ok {
if v, ok := v.Get("srs_id").(*amf0String); ok {
return string(*v)
if v, ok := v.Args.Get("data").(Amf0EcmaArray); ok {
if v, ok := v.Get("srs_id").(Amf0String); ok {
return v.String()
}
}
}
@ -1197,7 +1247,7 @@ func (v *ConnectAppResPacket) UnmarshalBinary(data []byte) (err error) {
type variantCallPacket struct {
CommandName amf0String
TransactionID amf0Number
CommandObject amf0Any // object or null
CommandObject Amf0Any // object or null
}
func (v *variantCallPacket) BetterCid() chunkID {
@ -1273,7 +1323,7 @@ func (v *variantCallPacket) MarshalBinary() (data []byte, err error) {
// @remark onStatus packet is a call packet.
type CallPacket struct {
variantCallPacket
Args amf0Any // optional or object or null
Args Amf0Any // optional or object or null
}
func NewCallPacket() *CallPacket {
@ -1282,9 +1332,9 @@ func NewCallPacket() *CallPacket {
func (v *CallPacket) ArgsCode() string {
if v.Args != nil {
if v, ok := v.Args.(*amf0Object); ok {
if code, ok := v.Get("code").(*amf0String); ok {
return string(*code)
if v, ok := v.Args.(Amf0Object); ok {
if code, ok := v.Get("code").(Amf0String); ok {
return code.String()
}
}
}
@ -1370,6 +1420,10 @@ func NewCreateStreamResPacket(tid amf0Number) *CreateStreamResPacket {
return v
}
func (v *CreateStreamResPacket) SetStreamID(streamID int) {
v.StreamID = amf0Number(streamID)
}
func (v *CreateStreamResPacket) Size() int {
return v.variantCallPacket.Size() + v.StreamID.Size()
}
@ -1407,15 +1461,16 @@ func (v *CreateStreamResPacket) MarshalBinary() (data []byte, err error) {
// Please read @doc rtmp_specification_1.0.pdf, @page 64, @section 4.2.6. Publish
type PublishPacket struct {
variantCallPacket
StreamName amf0String
StreamType amf0String
StreamName Amf0String
StreamType Amf0String
}
func NewPublishPacket() *PublishPacket {
v := &PublishPacket{}
v.CommandName = commandPublish
v.CommandObject = NewAmf0Null()
v.StreamType = "live"
v.StreamName = NewAmf0String("")
v.StreamType = NewAmf0String("live")
return v
}
@ -1431,11 +1486,13 @@ func (v *PublishPacket) UnmarshalBinary(data []byte) (err error) {
}
p = p[v.variantCallPacket.Size():]
v.StreamName = newAmf0String("")
if err = v.StreamName.UnmarshalBinary(p); err != nil {
return errors.WithMessage(err, "unmarshal stream name")
}
p = p[v.StreamName.Size():]
v.StreamType = newAmf0String("")
if err = v.StreamType.UnmarshalBinary(p); err != nil {
return errors.WithMessage(err, "unmarshal stream type")
}
@ -1466,13 +1523,14 @@ func (v *PublishPacket) MarshalBinary() (data []byte, err error) {
// Please read @doc rtmp_specification_1.0.pdf, @page 54, @section 4.2.1. play
type PlayPacket struct {
variantCallPacket
StreamName amf0String
StreamName Amf0String
}
func NewPlayPacket() *PlayPacket {
v := &PlayPacket{}
v.CommandName = commandPlay
v.CommandObject = NewAmf0Null()
v.StreamName = NewAmf0String("")
return v
}
@ -1488,6 +1546,7 @@ func (v *PlayPacket) UnmarshalBinary(data []byte) (err error) {
}
p = p[v.variantCallPacket.Size():]
v.StreamName = newAmf0String("")
if err = v.StreamName.UnmarshalBinary(p); err != nil {
return errors.WithMessage(err, "unmarshal stream name")
}

View File

@ -229,7 +229,7 @@ func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error {
response = identifyRes
nextStreamID = 1
identifyRes.StreamID = *rtmp.NewAmf0Number(float64(nextStreamID))
identifyRes.SetStreamID(nextStreamID)
} else if pkt.CommandName == "getStreamLength" {
// Ignore and do not reply these packets.
} else {
@ -243,7 +243,7 @@ func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error {
identifyRes.Args = rtmp.NewAmf0Undefined()
}
case *rtmp.PublishPacket:
streamName = string(pkt.StreamName)
streamName = pkt.StreamName.String()
clientType = RTMPClientTypePublisher
identifyRes := rtmp.NewCallPacket()
@ -257,7 +257,7 @@ func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error {
data.Set("description", rtmp.NewAmf0String("Started publishing stream."))
identifyRes.Args = data
case *rtmp.PlayPacket:
streamName = string(pkt.StreamName)
streamName = pkt.StreamName.String()
clientType = RTMPClientTypeViewer
identifyRes := rtmp.NewCallPacket()
@ -352,7 +352,7 @@ func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error {
if err != nil {
return errors.Wrapf(err, "read message")
}
//logger.Debug(ctx, "client<- %v %v %vB", m.MessageType, m.Timestamp, len(m.Payload))
//logger.Debug(ctx, "client<- %v %v %vB", m.MessageType(), m.Timestamp(), len(m.Payload()))
// TODO: Update the stream ID if not the same.
if err := client.WriteMessage(ctx, m); err != nil {
@ -375,7 +375,7 @@ func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error {
if err != nil {
return errors.Wrapf(err, "read message")
}
//logger.Debug(ctx, "client-> %v %v %vB", m.MessageType, m.Timestamp, len(m.Payload))
//logger.Debug(ctx, "client-> %v %v %vB", m.MessageType(), m.Timestamp(), len(m.Payload()))
// TODO: Update the stream ID if not the same.
if err := backend.client.WriteMessage(ctx, m); err != nil {
@ -421,7 +421,7 @@ type RTMPClientToBackend struct {
// The underlayer tcp client.
tcpConn *net.TCPConn
// The RTMP protocol client.
client *rtmp.Protocol
client rtmp.Protocol
// The stream type.
typ RTMPClientType
}
@ -527,7 +527,7 @@ func (v *RTMPClientToBackend) Connect(ctx context.Context, tcUrl, streamName str
return v.publish(ctx, client, streamName)
}
func (v *RTMPClientToBackend) publish(ctx context.Context, client *rtmp.Protocol, streamName string) error {
func (v *RTMPClientToBackend) publish(ctx context.Context, client rtmp.Protocol, streamName string) error {
if true {
identifyReq := rtmp.NewCallPacket()
identifyReq.CommandName = "releaseStream"
@ -592,8 +592,8 @@ func (v *RTMPClientToBackend) publish(ctx context.Context, client *rtmp.Protocol
publishStream := rtmp.NewPublishPacket()
publishStream.TransactionID = 5
publishStream.CommandObject = rtmp.NewAmf0Null()
publishStream.StreamName = *rtmp.NewAmf0String(streamName)
publishStream.StreamType = *rtmp.NewAmf0String("live")
publishStream.StreamName = rtmp.NewAmf0String(streamName)
publishStream.StreamType = rtmp.NewAmf0String("live")
if err := client.WritePacket(ctx, publishStream, currentStreamID); err != nil {
return errors.Wrapf(err, "publish")
}
@ -609,8 +609,8 @@ func (v *RTMPClientToBackend) publish(ctx context.Context, client *rtmp.Protocol
return errors.Errorf("onStatus args not object")
} else if code := rtmp.NewAmf0Converter(data.Get("code")).ToString(); code == nil {
return errors.Errorf("onStatus code not string")
} else if *code != "NetStream.Publish.Start" {
return errors.Errorf("onStatus code=%v not NetStream.Publish.Start", *code)
} else if code.String() != "NetStream.Publish.Start" {
return errors.Errorf("onStatus code=%v not NetStream.Publish.Start", code.String())
}
break
}
@ -620,7 +620,7 @@ func (v *RTMPClientToBackend) publish(ctx context.Context, client *rtmp.Protocol
return nil
}
func (v *RTMPClientToBackend) play(ctx context.Context, client *rtmp.Protocol, streamName string) error {
func (v *RTMPClientToBackend) play(ctx context.Context, client rtmp.Protocol, streamName string) error {
var currentStreamID int
if true {
createStream := rtmp.NewCreateStreamPacket()
@ -642,7 +642,7 @@ func (v *RTMPClientToBackend) play(ctx context.Context, client *rtmp.Protocol, s
}
playStream := rtmp.NewPlayPacket()
playStream.StreamName = *rtmp.NewAmf0String(streamName)
playStream.StreamName = rtmp.NewAmf0String(streamName)
if err := client.WritePacket(ctx, playStream, currentStreamID); err != nil {
return errors.Wrapf(err, "play")
}