From 9b6842da9a3d697ec3f637373aa477b91d846ed9 Mon Sep 17 00:00:00 2001 From: winlin Date: Wed, 29 Apr 2026 11:26:40 -0400 Subject: [PATCH] Codex: Expose RTMP AMF0 interfaces. Add public AMF0 and RTMP protocol interfaces, update the RTMP proxy to use the accessor APIs, and cover AMF0 encoding/decoding with unit tests and examples. --- internal/rtmp/amf0.go | 254 ++++++++++++----- internal/rtmp/amf0_test.go | 509 ++++++++++++++++++++++++++++++++++ internal/rtmp/example_test.go | 62 +++++ internal/rtmp/rtmp.go | 247 ++++++++++------- internal/server/rtmp.go | 26 +- 5 files changed, 924 insertions(+), 174 deletions(-) create mode 100644 internal/rtmp/amf0_test.go create mode 100644 internal/rtmp/example_test.go diff --git a/internal/rtmp/amf0.go b/internal/rtmp/amf0.go index 86a476308..7fd2c7a3d 100644 --- a/internal/rtmp/amf0.go +++ b/internal/rtmp/amf0.go @@ -95,7 +95,7 @@ var createBuffer = func() amf0Buffer { } // All AMF0 things. -type amf0Any interface { +type Amf0Any interface { // Binary marshaler and unmarshaler. encoding.BinaryUnmarshaler encoding.BinaryMarshaler @@ -106,59 +106,83 @@ type amf0Any interface { amf0Marker() amf0Marker } -type amf0Converter struct { - from amf0Any +type Amf0Converter interface { + ToNumber() Amf0Number + ToBoolean() Amf0Boolean + ToString() Amf0String + ToObject() Amf0Object + ToNull() Amf0Null + ToUndefined() Amf0Undefined + ToEcmaArray() Amf0EcmaArray + ToStrictArray() Amf0StrictArray } -func NewAmf0Converter(from amf0Any) *amf0Converter { +type amf0Converter struct { + from Amf0Any +} + +func NewAmf0Converter(from Amf0Any) Amf0Converter { return &amf0Converter{from: from} } -func (v *amf0Converter) ToNumber() *amf0Number { - return amf0AnyTo[*amf0Number](v.from) -} - -func (v *amf0Converter) ToBoolean() *amf0Boolean { - return amf0AnyTo[*amf0Boolean](v.from) -} - -func (v *amf0Converter) ToString() *amf0String { - return amf0AnyTo[*amf0String](v.from) -} - -func (v *amf0Converter) ToObject() *amf0Object { - return amf0AnyTo[*amf0Object](v.from) -} - -func (v *amf0Converter) ToNull() *amf0Null { - return amf0AnyTo[*amf0Null](v.from) -} - -func (v *amf0Converter) ToUndefined() *amf0Undefined { - return amf0AnyTo[*amf0Undefined](v.from) -} - -func (v *amf0Converter) ToEcmaArray() *amf0EcmaArray { - return amf0AnyTo[*amf0EcmaArray](v.from) -} - -func (v *amf0Converter) ToStrictArray() *amf0StrictArray { - return amf0AnyTo[*amf0StrictArray](v.from) -} - -// Convert any to specified object. -func amf0AnyTo[T amf0Any](a amf0Any) T { - var to T - if a != nil { - if v, ok := a.(T); ok { - return v - } +func (v *amf0Converter) ToNumber() Amf0Number { + if r, ok := v.from.(Amf0Number); ok { + return r } - return to + return nil +} + +func (v *amf0Converter) ToBoolean() Amf0Boolean { + if r, ok := v.from.(Amf0Boolean); ok { + return r + } + return nil +} + +func (v *amf0Converter) ToString() Amf0String { + if r, ok := v.from.(Amf0String); ok { + return r + } + return nil +} + +func (v *amf0Converter) ToObject() Amf0Object { + if r, ok := v.from.(Amf0Object); ok { + return r + } + return nil +} + +func (v *amf0Converter) ToNull() Amf0Null { + if r, ok := v.from.(Amf0Null); ok { + return r + } + return nil +} + +func (v *amf0Converter) ToUndefined() Amf0Undefined { + if r, ok := v.from.(Amf0Undefined); ok { + return r + } + return nil +} + +func (v *amf0Converter) ToEcmaArray() Amf0EcmaArray { + if r, ok := v.from.(Amf0EcmaArray); ok { + return r + } + return nil +} + +func (v *amf0Converter) ToStrictArray() Amf0StrictArray { + if r, ok := v.from.(Amf0StrictArray); ok { + return r + } + return nil } // Discovery the amf0 object from the bytes b. -func Amf0Discovery(p []byte) (a amf0Any, err error) { +func Amf0Discovery(p []byte) (a Amf0Any, err error) { if len(p) < 1 { return nil, errors.Errorf("require 1 bytes only %v", len(p)) } @@ -228,14 +252,24 @@ func (v *amf0UTF8) MarshalBinary() (data []byte, err error) { return } +// Amf0Number is the AMF0 number type. +type Amf0Number interface { + Amf0Any + Float64() float64 +} + // The number object, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.2 Number Type type amf0Number float64 -func NewAmf0Number(f float64) *amf0Number { +func NewAmf0Number(f float64) Amf0Number { v := amf0Number(f) return &v } +func (v *amf0Number) Float64() float64 { + return float64(*v) +} + func (v *amf0Number) amf0Marker() amf0Marker { return amf0MarkerNumber } @@ -266,14 +300,28 @@ func (v *amf0Number) MarshalBinary() (data []byte, err error) { return } +// Amf0String is the AMF0 string type. +type Amf0String interface { + Amf0Any + String() string +} + // The string objet, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.4 String Type type amf0String string -func NewAmf0String(s string) *amf0String { +func NewAmf0String(s string) Amf0String { + return newAmf0String(s) +} + +func newAmf0String(s string) *amf0String { v := amf0String(s) return &v } +func (v *amf0String) String() string { + return string(*v) +} + func (v *amf0String) amf0Marker() amf0Marker { return amf0MarkerString } @@ -344,7 +392,7 @@ func (v *amf0ObjectEOF) MarshalBinary() (data []byte, err error) { // Use array for object and ecma array, to keep the original order. type amf0Property struct { key amf0UTF8 - value amf0Any + value Amf0Any } // The object-like AMF0 structure, like object and ecma array and strict array. @@ -367,7 +415,7 @@ func (v *amf0ObjectBase) Size() int { return size } -func (v *amf0ObjectBase) Get(key string) amf0Any { +func (v *amf0ObjectBase) Get(key string) Amf0Any { v.lock.Lock() defer v.lock.Unlock() @@ -380,7 +428,7 @@ func (v *amf0ObjectBase) Get(key string) amf0Any { return nil } -func (v *amf0ObjectBase) Set(key string, value amf0Any) *amf0ObjectBase { +func (v *amf0ObjectBase) Set(key string, value Amf0Any) *amf0ObjectBase { v.lock.Lock() defer v.lock.Unlock() @@ -411,21 +459,21 @@ func (v *amf0ObjectBase) unmarshal(p []byte, eof bool, maxElems int) (err error) return errors.Errorf("maxElems=%v with eof", maxElems) } - readOne := func() (amf0UTF8, amf0Any, error) { + readOne := func() (amf0UTF8, Amf0Any, error) { var u amf0UTF8 if err = u.UnmarshalBinary(p); err != nil { return "", nil, errors.WithMessage(err, "prop name") } p = p[u.Size():] - var a amf0Any + var a Amf0Any if a, err = Amf0Discovery(p); err != nil { return "", nil, errors.WithMessage(err, fmt.Sprintf("discover prop %v", string(u))) } return u, a, nil } - pushOne := func(u amf0UTF8, a amf0Any) error { + pushOne := func(u amf0UTF8, a Amf0Any) error { // For object property, consume the whole bytes. if err = a.UnmarshalBinary(p); err != nil { return errors.WithMessage(err, fmt.Sprintf("unmarshal prop %v", string(u))) @@ -494,13 +542,24 @@ func (v *amf0ObjectBase) marshal(b amf0Buffer) (err error) { return } +// Amf0Object is the AMF0 object type. +type Amf0Object interface { + Amf0Any + Get(key string) Amf0Any + Set(key string, value Amf0Any) Amf0Object +} + // The AMF0 object, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.5 Object Type type amf0Object struct { amf0ObjectBase eof amf0ObjectEOF } -func NewAmf0Object() *amf0Object { +func NewAmf0Object() Amf0Object { + return newAmf0Object() +} + +func newAmf0Object() *amf0Object { v := &amf0Object{} v.properties = []*amf0Property{} return v @@ -510,6 +569,15 @@ func (v *amf0Object) amf0Marker() amf0Marker { return amf0MarkerObject } +func (v *amf0Object) Get(key string) Amf0Any { + return v.amf0ObjectBase.Get(key) +} + +func (v *amf0Object) Set(key string, value Amf0Any) Amf0Object { + v.amf0ObjectBase.Set(key, value) + return v +} + func (v *amf0Object) Size() int { return int(1) + v.eof.Size() + v.amf0ObjectBase.Size() } @@ -542,17 +610,22 @@ func (v *amf0Object) MarshalBinary() (data []byte, err error) { return nil, errors.WithMessage(err, "marshal") } - var pb []byte - if pb, err = v.eof.MarshalBinary(); err != nil { + if pb, err := v.eof.MarshalBinary(); err != nil { return nil, errors.WithMessage(err, "marshal") - } - if _, err = b.Write(pb); err != nil { + } else if _, err = b.Write(pb); err != nil { return nil, errors.Wrap(err, "marshal") } return b.Bytes(), nil } +// Amf0EcmaArray is the AMF0 ECMA array type. +type Amf0EcmaArray interface { + Amf0Any + Get(key string) Amf0Any + Set(key string, value Amf0Any) Amf0EcmaArray +} + // The AMF0 ecma array, please read @doc amf0_spec_121207.pdf, @page 6, @section 2.10 ECMA Array Type type amf0EcmaArray struct { amf0ObjectBase @@ -560,7 +633,11 @@ type amf0EcmaArray struct { eof amf0ObjectEOF } -func NewAmf0EcmaArray() *amf0EcmaArray { +func NewAmf0EcmaArray() Amf0EcmaArray { + return newAmf0EcmaArray() +} + +func newAmf0EcmaArray() *amf0EcmaArray { v := &amf0EcmaArray{} v.properties = []*amf0Property{} return v @@ -570,6 +647,15 @@ func (v *amf0EcmaArray) amf0Marker() amf0Marker { return amf0MarkerEcmaArray } +func (v *amf0EcmaArray) Get(key string) Amf0Any { + return v.amf0ObjectBase.Get(key) +} + +func (v *amf0EcmaArray) Set(key string, value Amf0Any) Amf0EcmaArray { + v.amf0ObjectBase.Set(key, value) + return v +} + func (v *amf0EcmaArray) Size() int { return int(1) + 4 + v.eof.Size() + v.amf0ObjectBase.Size() } @@ -606,24 +692,29 @@ func (v *amf0EcmaArray) MarshalBinary() (data []byte, err error) { return nil, errors.WithMessage(err, "marshal") } - var pb []byte - if pb, err = v.eof.MarshalBinary(); err != nil { + if pb, err := v.eof.MarshalBinary(); err != nil { return nil, errors.WithMessage(err, "marshal") - } - if _, err = b.Write(pb); err != nil { + } else if _, err = b.Write(pb); err != nil { return nil, errors.Wrap(err, "marshal") } return b.Bytes(), nil } +// Amf0StrictArray is the AMF0 strict array type. +type Amf0StrictArray interface { + Amf0Any + Get(key string) Amf0Any + Set(key string, value Amf0Any) Amf0StrictArray +} + // The AMF0 strict array, please read @doc amf0_spec_121207.pdf, @page 7, @section 2.12 Strict Array Type type amf0StrictArray struct { amf0ObjectBase count uint32 } -func NewAmf0StrictArray() *amf0StrictArray { +func NewAmf0StrictArray() Amf0StrictArray { v := &amf0StrictArray{} v.properties = []*amf0Property{} return v @@ -633,6 +724,15 @@ func (v *amf0StrictArray) amf0Marker() amf0Marker { return amf0MarkerStrictArray } +func (v *amf0StrictArray) Get(key string) Amf0Any { + return v.amf0ObjectBase.Get(key) +} + +func (v *amf0StrictArray) Set(key string, value Amf0Any) Amf0StrictArray { + v.amf0ObjectBase.Set(key, value) + return v +} + func (v *amf0StrictArray) Size() int { return int(1) + 4 + v.amf0ObjectBase.Size() } @@ -708,36 +808,56 @@ func (v *amf0SingleMarkerObject) MarshalBinary() (data []byte, err error) { return []byte{byte(v.target)}, nil } +// Amf0Null is the AMF0 null type. +type Amf0Null interface { + Amf0Any +} + // The AMF0 null, please read @doc amf0_spec_121207.pdf, @page 6, @section 2.7 null Type type amf0Null struct { amf0SingleMarkerObject } -func NewAmf0Null() *amf0Null { +func NewAmf0Null() Amf0Null { v := amf0Null{} v.amf0SingleMarkerObject = newAmf0SingleMarkerObject(amf0MarkerNull) return &v } +// Amf0Undefined is the AMF0 undefined type. +type Amf0Undefined interface { + Amf0Any +} + // The AMF0 undefined, please read @doc amf0_spec_121207.pdf, @page 6, @section 2.8 undefined Type type amf0Undefined struct { amf0SingleMarkerObject } -func NewAmf0Undefined() amf0Any { +func NewAmf0Undefined() Amf0Undefined { v := amf0Undefined{} v.amf0SingleMarkerObject = newAmf0SingleMarkerObject(amf0MarkerUndefined) return &v } +// Amf0Boolean is the public typed view of an AMF0 boolean. +type Amf0Boolean interface { + Amf0Any + Bool() bool +} + // The AMF0 boolean, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.3 Boolean Type type amf0Boolean bool -func NewAmf0Boolean(b bool) amf0Any { +func NewAmf0Boolean(b bool) Amf0Boolean { v := amf0Boolean(b) return &v } +func (v *amf0Boolean) Bool() bool { + return bool(*v) +} + func (v *amf0Boolean) amf0Marker() amf0Marker { return amf0MarkerBoolean } diff --git a/internal/rtmp/amf0_test.go b/internal/rtmp/amf0_test.go new file mode 100644 index 000000000..a2c240360 --- /dev/null +++ b/internal/rtmp/amf0_test.go @@ -0,0 +1,509 @@ +// Copyright (c) 2026 Winlin +// +// SPDX-License-Identifier: MIT +package rtmp + +import ( + "bytes" + "fmt" + "math" + "strings" + "testing" +) + +func TestAmf0MarkerString(t *testing.T) { + for _, tt := range []struct { + marker amf0Marker + want string + }{ + {amf0MarkerNumber, "Amf0Number"}, + {amf0MarkerBoolean, "amf0Boolean"}, + {amf0MarkerString, "Amf0String"}, + {amf0MarkerObject, "Amf0Object"}, + {amf0MarkerMovieClip, "MovieClip"}, + {amf0MarkerNull, "Null"}, + {amf0MarkerUndefined, "Undefined"}, + {amf0MarkerReference, "Reference"}, + {amf0MarkerEcmaArray, "EcmaArray"}, + {amf0MarkerObjectEnd, "ObjectEnd"}, + {amf0MarkerStrictArray, "StrictArray"}, + {amf0MarkerDate, "Date"}, + {amf0MarkerLongString, "LongString"}, + {amf0MarkerUnsupported, "Unsupported"}, + {amf0MarkerRecordSet, "RecordSet"}, + {amf0MarkerXmlDocument, "XmlDocument"}, + {amf0MarkerTypedObject, "TypedObject"}, + {amf0MarkerAvmPlusObject, "AvmPlusObject"}, + {amf0MarkerForbidden, "Forbidden"}, + {amf0Marker(0xee), "Forbidden"}, + } { + if got := tt.marker.String(); got != tt.want { + t.Fatalf("marker=%#x String()=%v, want %v", byte(tt.marker), got, tt.want) + } + } +} + +func TestAmf0Discovery(t *testing.T) { + for _, tt := range []struct { + name string + data []byte + ok func(Amf0Any) bool + }{ + {"number", []byte{byte(amf0MarkerNumber)}, func(v Amf0Any) bool { _, ok := v.(Amf0Number); return ok }}, + {"boolean", []byte{byte(amf0MarkerBoolean)}, func(v Amf0Any) bool { _, ok := v.(Amf0Boolean); return ok }}, + {"string", []byte{byte(amf0MarkerString)}, func(v Amf0Any) bool { _, ok := v.(Amf0String); return ok }}, + {"object", []byte{byte(amf0MarkerObject)}, func(v Amf0Any) bool { _, ok := v.(Amf0Object); return ok }}, + {"null", []byte{byte(amf0MarkerNull)}, func(v Amf0Any) bool { _, ok := v.(Amf0Null); return ok }}, + {"undefined", []byte{byte(amf0MarkerUndefined)}, func(v Amf0Any) bool { _, ok := v.(Amf0Undefined); return ok }}, + {"ecma-array", []byte{byte(amf0MarkerEcmaArray)}, func(v Amf0Any) bool { _, ok := v.(Amf0EcmaArray); return ok }}, + {"object-end", []byte{byte(amf0MarkerObjectEnd)}, func(v Amf0Any) bool { _, ok := v.(*amf0ObjectEOF); return ok }}, + {"strict-array", []byte{byte(amf0MarkerStrictArray)}, func(v Amf0Any) bool { _, ok := v.(Amf0StrictArray); return ok }}, + } { + t.Run(tt.name, func(t *testing.T) { + value, err := Amf0Discovery(tt.data) + if err != nil { + t.Fatalf("Amf0Discovery() err=%v", err) + } + if !tt.ok(value) { + t.Fatalf("Amf0Discovery()=%T", value) + } + }) + } + + for _, data := range [][]byte{{}, {byte(amf0MarkerReference)}, {byte(amf0MarkerDate)}, {byte(amf0MarkerForbidden)}} { + if value, err := Amf0Discovery(data); err == nil || value != nil { + t.Fatalf("Amf0Discovery(%v) value=%T, err=%v, want error", data, value, err) + } + } +} + +func TestAmf0Converter(t *testing.T) { + values := []struct { + name string + in Amf0Any + ok func(Amf0Converter) bool + }{ + {"number", NewAmf0Number(1), func(c Amf0Converter) bool { return c.ToNumber() != nil }}, + {"boolean", NewAmf0Boolean(true), func(c Amf0Converter) bool { return c.ToBoolean() != nil }}, + {"string", NewAmf0String("v"), func(c Amf0Converter) bool { return c.ToString() != nil }}, + {"object", NewAmf0Object(), func(c Amf0Converter) bool { return c.ToObject() != nil }}, + {"null", NewAmf0Null(), func(c Amf0Converter) bool { return c.ToNull() != nil }}, + {"undefined", NewAmf0Undefined(), func(c Amf0Converter) bool { return c.ToUndefined() != nil }}, + {"ecma-array", NewAmf0EcmaArray(), func(c Amf0Converter) bool { return c.ToEcmaArray() != nil }}, + {"strict-array", NewAmf0StrictArray(), func(c Amf0Converter) bool { return c.ToStrictArray() != nil }}, + } + + for _, tt := range values { + t.Run(tt.name, func(t *testing.T) { + converter := NewAmf0Converter(tt.in) + if !tt.ok(converter) { + t.Fatalf("expected successful conversion for %T", tt.in) + } + }) + } + + nilConverter := NewAmf0Converter(nil) + if nilConverter.ToNumber() != nil || nilConverter.ToBoolean() != nil || nilConverter.ToString() != nil || + nilConverter.ToObject() != nil || nilConverter.ToNull() != nil || nilConverter.ToUndefined() != nil || + nilConverter.ToEcmaArray() != nil || nilConverter.ToStrictArray() != nil { + t.Fatal("nil converter should not convert") + } +} + +func TestAmf0UTF8(t *testing.T) { + var value amf0UTF8 = "hello" + b, err := value.MarshalBinary() + if err != nil { + t.Fatalf("MarshalBinary() err=%v", err) + } + if value.Size() != len(b) { + t.Fatalf("Size()=%v, len=%v", value.Size(), len(b)) + } + + var decoded amf0UTF8 + if err := decoded.UnmarshalBinary(b); err != nil { + t.Fatalf("UnmarshalBinary() err=%v", err) + } + if decoded != value { + t.Fatalf("decoded=%v, want %v", decoded, value) + } + + for _, data := range [][]byte{{0x00}, {0x00, 0x05, 'h'}} { + if err := decoded.UnmarshalBinary(data); err == nil { + t.Fatalf("UnmarshalBinary(%v) should fail", data) + } + } +} + +func TestAmf0Number(t *testing.T) { + number := NewAmf0Number(math.Pi) + if number.Size() != 9 || number.(*amf0Number).amf0Marker() != amf0MarkerNumber { + t.Fatalf("unexpected number metadata") + } + + b, err := number.MarshalBinary() + if err != nil { + t.Fatalf("MarshalBinary() err=%v", err) + } + + decoded := NewAmf0Number(0) + if err := decoded.UnmarshalBinary(b); err != nil { + t.Fatalf("UnmarshalBinary() err=%v", err) + } + if got := decoded.Float64(); got != math.Pi { + t.Fatalf("Float64()=%v, want %v", got, math.Pi) + } + + for _, data := range [][]byte{{byte(amf0MarkerNumber)}, append([]byte{byte(amf0MarkerString)}, b[1:]...)} { + if err := decoded.UnmarshalBinary(data); err == nil { + t.Fatalf("UnmarshalBinary(%v) should fail", data) + } + } +} + +func TestAmf0Boolean(t *testing.T) { + for _, want := range []bool{false, true} { + boolean := NewAmf0Boolean(want) + if boolean.Size() != 2 || boolean.(*amf0Boolean).amf0Marker() != amf0MarkerBoolean { + t.Fatalf("unexpected boolean metadata") + } + + b, err := boolean.MarshalBinary() + if err != nil { + t.Fatalf("MarshalBinary() err=%v", err) + } + + decoded := NewAmf0Boolean(!want) + if err := decoded.UnmarshalBinary(b); err != nil { + t.Fatalf("UnmarshalBinary() err=%v", err) + } + if got := decoded.Bool(); got != want { + t.Fatalf("Bool()=%v, want %v", got, want) + } + } + + decoded := NewAmf0Boolean(false) + for _, data := range [][]byte{{byte(amf0MarkerBoolean)}, {byte(amf0MarkerNumber), 1}} { + if err := decoded.UnmarshalBinary(data); err == nil { + t.Fatalf("UnmarshalBinary(%v) should fail", data) + } + } +} + +func TestAmf0String(t *testing.T) { + value := NewAmf0String("hello") + if value.Size() != 8 || value.(*amf0String).amf0Marker() != amf0MarkerString { + t.Fatalf("unexpected string metadata") + } + + b, err := value.MarshalBinary() + if err != nil { + t.Fatalf("MarshalBinary() err=%v", err) + } + + decoded := NewAmf0String("") + if err := decoded.UnmarshalBinary(b); err != nil { + t.Fatalf("UnmarshalBinary() err=%v", err) + } + if got := decoded.String(); got != "hello" { + t.Fatalf("String()=%v, want hello", got) + } + + for _, data := range [][]byte{{}, {byte(amf0MarkerNumber), 0, 0}, {byte(amf0MarkerString), 0, 5, 'h'}} { + if err := decoded.UnmarshalBinary(data); err == nil { + t.Fatalf("UnmarshalBinary(%v) should fail", data) + } + } +} + +func TestAmf0ObjectEOF(t *testing.T) { + eof := &amf0ObjectEOF{} + if eof.Size() != 3 || eof.amf0Marker() != amf0MarkerObjectEnd { + t.Fatalf("unexpected eof metadata") + } + + b, err := eof.MarshalBinary() + if err != nil { + t.Fatalf("MarshalBinary() err=%v", err) + } + if !bytes.Equal(b, []byte{0, 0, 9}) { + t.Fatalf("MarshalBinary()=%v", b) + } + for _, data := range [][]byte{b, {0, 0, 9, 1}} { + if err := eof.UnmarshalBinary(data); err != nil { + t.Fatalf("UnmarshalBinary(%v) err=%v", data, err) + } + } + for _, data := range [][]byte{{0, 0}, {0, 1, 9}} { + if err := eof.UnmarshalBinary(data); err == nil { + t.Fatalf("UnmarshalBinary(%v) should fail", data) + } + } +} + +func TestAmf0Object(t *testing.T) { + object := NewAmf0Object(). + Set("name", NewAmf0String("stream")). + Set("code", NewAmf0Number(100)). + Set("ok", NewAmf0Boolean(true)) + object.Set("code", NewAmf0Number(200)) + + if object.(*amf0Object).amf0Marker() != amf0MarkerObject || object.Size() == 0 { + t.Fatalf("unexpected object metadata") + } + if object.Get("missing") != nil { + t.Fatal("missing property should be nil") + } + + b, err := object.MarshalBinary() + if err != nil { + t.Fatalf("MarshalBinary() err=%v", err) + } + + decoded := NewAmf0Object() + if err := decoded.UnmarshalBinary(b); err != nil { + t.Fatalf("UnmarshalBinary() err=%v", err) + } + if got := NewAmf0Converter(decoded.Get("name")).ToString().String(); got != "stream" { + t.Fatalf("name=%v", got) + } + if got := NewAmf0Converter(decoded.Get("code")).ToNumber().Float64(); got != 200 { + t.Fatalf("code=%v", got) + } + if got := NewAmf0Converter(decoded.Get("ok")).ToBoolean().Bool(); !got { + t.Fatalf("ok=%v", got) + } + + for _, data := range [][]byte{{}, {byte(amf0MarkerString)}, {byte(amf0MarkerObject), 0, 4, 'n'}} { + if err := decoded.UnmarshalBinary(data); err == nil { + t.Fatalf("UnmarshalBinary(%v) should fail", data) + } + } + + base := &amf0ObjectBase{} + if err := base.unmarshal(nil, false, -1); err == nil { + t.Fatal("unmarshal without eof and negative maxElems should fail") + } + if err := base.unmarshal(nil, true, 0); err == nil { + t.Fatal("unmarshal with eof and non-negative maxElems should fail") + } +} + +func TestAmf0EcmaArray(t *testing.T) { + array := NewAmf0EcmaArray(). + Set("name", NewAmf0String("stream")). + Set("code", NewAmf0Number(100)) + + if array.(*amf0EcmaArray).amf0Marker() != amf0MarkerEcmaArray || array.Size() == 0 { + t.Fatalf("unexpected ecma array metadata") + } + if array.Get("missing") != nil { + t.Fatal("missing property should be nil") + } + + b, err := array.MarshalBinary() + if err != nil { + t.Fatalf("MarshalBinary() err=%v", err) + } + + decoded := NewAmf0EcmaArray() + if err := decoded.UnmarshalBinary(b); err != nil { + t.Fatalf("UnmarshalBinary() err=%v", err) + } + if got := NewAmf0Converter(decoded.Get("name")).ToString().String(); got != "stream" { + t.Fatalf("name=%v", got) + } + if got := NewAmf0Converter(decoded.Get("code")).ToNumber().Float64(); got != 100 { + t.Fatalf("code=%v", got) + } + + for _, data := range [][]byte{{}, {byte(amf0MarkerEcmaArray), 0}, {byte(amf0MarkerString), 0, 0, 0, 0}, {byte(amf0MarkerEcmaArray), 0, 0, 0, 0, 0, 4, 'n'}} { + if err := decoded.UnmarshalBinary(data); err == nil { + t.Fatalf("UnmarshalBinary(%v) should fail", data) + } + } +} + +func TestAmf0StrictArray(t *testing.T) { + array := NewAmf0StrictArray(). + Set("name", NewAmf0String("stream")). + Set("code", NewAmf0Number(100)) + array.(*amf0StrictArray).count = 2 + + if array.(*amf0StrictArray).amf0Marker() != amf0MarkerStrictArray || array.Size() == 0 { + t.Fatalf("unexpected strict array metadata") + } + if array.Get("missing") != nil { + t.Fatal("missing property should be nil") + } + + b, err := array.MarshalBinary() + if err != nil { + t.Fatalf("MarshalBinary() err=%v", err) + } + + decoded := NewAmf0StrictArray() + if err := decoded.UnmarshalBinary(b); err != nil { + t.Fatalf("UnmarshalBinary() err=%v", err) + } + if got := NewAmf0Converter(decoded.Get("name")).ToString().String(); got != "stream" { + t.Fatalf("name=%v", got) + } + if got := NewAmf0Converter(decoded.Get("code")).ToNumber().Float64(); got != 100 { + t.Fatalf("code=%v", got) + } + + empty := append([]byte{byte(amf0MarkerStrictArray)}, 0, 0, 0, 0) + if err := decoded.UnmarshalBinary(empty); err != nil { + t.Fatalf("UnmarshalBinary(empty) err=%v", err) + } + for _, data := range [][]byte{{}, {byte(amf0MarkerStrictArray), 0}, {byte(amf0MarkerString), 0, 0, 0, 0}, {byte(amf0MarkerStrictArray), 0, 0, 0, 1, 0, 4, 'n'}} { + if err := NewAmf0StrictArray().UnmarshalBinary(data); err == nil { + t.Fatalf("UnmarshalBinary(%v) should fail", data) + } + } +} + +func TestAmf0SingleMarkerObjects(t *testing.T) { + for _, tt := range []struct { + name string + value Amf0Any + marker amf0Marker + }{ + {"null", NewAmf0Null(), amf0MarkerNull}, + {"undefined", NewAmf0Undefined(), amf0MarkerUndefined}, + } { + t.Run(tt.name, func(t *testing.T) { + if tt.value.Size() != 1 || tt.value.amf0Marker() != tt.marker { + t.Fatalf("unexpected metadata") + } + b, err := tt.value.MarshalBinary() + if err != nil { + t.Fatalf("MarshalBinary() err=%v", err) + } + if err := tt.value.UnmarshalBinary(b); err != nil { + t.Fatalf("UnmarshalBinary() err=%v", err) + } + for _, data := range [][]byte{{}, {byte(amf0MarkerString)}} { + if err := tt.value.UnmarshalBinary(data); err == nil { + t.Fatalf("UnmarshalBinary(%v) should fail", data) + } + } + }) + } +} + +type errorAmf0Buffer struct { + writeByteErr bool + writeErr bool +} + +func (v *errorAmf0Buffer) Bytes() []byte { + return nil +} + +func (v *errorAmf0Buffer) WriteByte(byte) error { + if v.writeByteErr { + return fmt.Errorf("write byte") + } + return nil +} + +func (v *errorAmf0Buffer) Write([]byte) (int, error) { + if v.writeErr { + return 0, fmt.Errorf("write") + } + return 0, nil +} + +type errorAmf0Any struct { + Amf0Any +} + +func (v *errorAmf0Any) Size() int { + return 1 +} + +func (v *errorAmf0Any) MarshalBinary() ([]byte, error) { + return nil, fmt.Errorf("marshal") +} + +func (v *errorAmf0Any) UnmarshalBinary([]byte) error { + return nil +} + +func (v *errorAmf0Any) amf0Marker() amf0Marker { + return amf0MarkerNumber +} + +func TestAmf0MarshalErrors(t *testing.T) { + originalCreateBuffer := createBuffer + defer func() { createBuffer = originalCreateBuffer }() + + for _, tt := range []struct { + name string + make func() Amf0Any + }{ + {"object", func() Amf0Any { return NewAmf0Object() }}, + {"ecma-array", func() Amf0Any { return NewAmf0EcmaArray() }}, + {"strict-array", func() Amf0Any { return NewAmf0StrictArray() }}, + } { + t.Run(tt.name+" write-byte", func(t *testing.T) { + createBuffer = func() amf0Buffer { return &errorAmf0Buffer{writeByteErr: true} } + if _, err := tt.make().MarshalBinary(); err == nil { + t.Fatal("MarshalBinary() should fail") + } + }) + + t.Run(tt.name+" write-prop", func(t *testing.T) { + createBuffer = func() amf0Buffer { return &errorAmf0Buffer{writeErr: true} } + value := tt.make() + switch v := value.(type) { + case Amf0Object: + v.Set("name", NewAmf0String("stream")) + case Amf0EcmaArray: + v.Set("name", NewAmf0String("stream")) + case Amf0StrictArray: + v.Set("name", NewAmf0String("stream")) + v.(*amf0StrictArray).count = 1 + } + if _, err := value.MarshalBinary(); err == nil { + t.Fatal("MarshalBinary() should fail") + } + }) + } + + createBuffer = originalCreateBuffer + for _, tt := range []struct { + name string + make func() Amf0Any + }{ + {"object", func() Amf0Any { return NewAmf0Object().Set("bad", &errorAmf0Any{}) }}, + {"ecma-array", func() Amf0Any { return NewAmf0EcmaArray().Set("bad", &errorAmf0Any{}) }}, + {"strict-array", func() Amf0Any { + value := NewAmf0StrictArray().Set("bad", &errorAmf0Any{}) + value.(*amf0StrictArray).count = 1 + return value + }}, + } { + t.Run(tt.name+" marshal-value", func(t *testing.T) { + if _, err := tt.make().MarshalBinary(); err == nil { + t.Fatal("MarshalBinary() should fail") + } + }) + } +} + +func TestAmf0UnmarshalNestedErrors(t *testing.T) { + // Object property with unsupported marker. + data := []byte{byte(amf0MarkerObject), 0, 3, 'b', 'a', 'd', byte(amf0MarkerDate)} + if err := NewAmf0Object().UnmarshalBinary(data); err == nil || !strings.Contains(err.Error(), "discover prop bad") { + t.Fatalf("err=%v, want discover prop bad", err) + } + + // Object property with invalid payload size. + data = []byte{byte(amf0MarkerObject), 0, 3, 'b', 'a', 'd', byte(amf0MarkerNumber), 0} + if err := NewAmf0Object().UnmarshalBinary(data); err == nil || !strings.Contains(err.Error(), "unmarshal prop bad") { + t.Fatalf("err=%v, want unmarshal prop bad", err) + } +} diff --git a/internal/rtmp/example_test.go b/internal/rtmp/example_test.go new file mode 100644 index 000000000..2b5a9f7d6 --- /dev/null +++ b/internal/rtmp/example_test.go @@ -0,0 +1,62 @@ +// Copyright (c) 2026 Winlin +// +// SPDX-License-Identifier: MIT +package rtmp_test + +import ( + "fmt" + + "srsx/internal/rtmp" +) + +func ExampleAmf0Number() { + number := rtmp.NewAmf0Number(3.14) + b, err := number.MarshalBinary() + if err != nil { + panic(err) + } + + value, err := rtmp.Amf0Discovery(b) + if err != nil { + panic(err) + } + if err := value.UnmarshalBinary(b); err != nil { + panic(err) + } + + converter := rtmp.NewAmf0Converter(value) + fmt.Println("number:", converter.ToNumber().Float64()) + fmt.Println("is string:", converter.ToString() != nil) + + // Output: + // number: 3.14 + // is string: false +} + +func ExampleAmf0Object() { + object := rtmp.NewAmf0Object(). + Set("code", rtmp.NewAmf0Number(100)). + Set("level", rtmp.NewAmf0String("status")) + b, err := object.MarshalBinary() + if err != nil { + panic(err) + } + + value, err := rtmp.Amf0Discovery(b) + if err != nil { + panic(err) + } + if err := value.UnmarshalBinary(b); err != nil { + panic(err) + } + + converter := rtmp.NewAmf0Converter(value) + fmt.Println("code:", rtmp.NewAmf0Converter(converter.ToObject().Get("code")).ToNumber().Float64()) + fmt.Println("level:", rtmp.NewAmf0Converter(converter.ToObject().Get("level")).ToString().String()) + fmt.Println("is number:", converter.ToNumber() != nil) + + // Output: + // code: 100 + // level: status + // is number: false +} diff --git a/internal/rtmp/rtmp.go b/internal/rtmp/rtmp.go index b24a12de5..08b67b986 100644 --- a/internal/rtmp/rtmp.go +++ b/internal/rtmp/rtmp.go @@ -17,21 +17,31 @@ import ( "srsx/internal/errors" ) -// The handshake implements the RTMP handshake protocol. -type Handshake struct { +// Handshake implements the RTMP handshake protocol. +type Handshake interface { + C1S1() []byte + WriteC0S0(w io.Writer) error + ReadC0S0(r io.Reader) ([]byte, error) + WriteC1S1(w io.Writer) error + ReadC1S1(r io.Reader) ([]byte, error) + WriteC2S2(w io.Writer, s1c1 []byte) error + ReadC2S2(r io.Reader) ([]byte, error) +} + +type handshake struct { // The c1s1 cache. c1s1 []byte } -func NewHandshake() *Handshake { - return &Handshake{} +func NewHandshake() Handshake { + return &handshake{} } -func (v *Handshake) C1S1() []byte { +func (v *handshake) C1S1() []byte { return v.c1s1 } -func (v *Handshake) WriteC0S0(w io.Writer) (err error) { +func (v *handshake) WriteC0S0(w io.Writer) (err error) { r := bytes.NewReader([]byte{0x03}) if _, err = io.Copy(w, r); err != nil { return errors.Wrap(err, "write c0s0") @@ -40,7 +50,7 @@ func (v *Handshake) WriteC0S0(w io.Writer) (err error) { return } -func (v *Handshake) ReadC0S0(r io.Reader) (c0 []byte, err error) { +func (v *handshake) ReadC0S0(r io.Reader) (c0 []byte, err error) { b := &bytes.Buffer{} if _, err = io.CopyN(b, r, 1); err != nil { return nil, errors.Wrap(err, "read c0s0") @@ -51,7 +61,7 @@ func (v *Handshake) ReadC0S0(r io.Reader) (c0 []byte, err error) { return } -func (v *Handshake) WriteC1S1(w io.Writer) (err error) { +func (v *handshake) WriteC1S1(w io.Writer) (err error) { p := make([]byte, 1536) // Use crypto/rand for thread-safe random generation @@ -67,7 +77,7 @@ func (v *Handshake) WriteC1S1(w io.Writer) (err error) { return } -func (v *Handshake) ReadC1S1(r io.Reader) (c1s1 []byte, err error) { +func (v *handshake) ReadC1S1(r io.Reader) (c1s1 []byte, err error) { b := &bytes.Buffer{} if _, err = io.CopyN(b, r, 1536); err != nil { return nil, errors.Wrap(err, "read c1s1") @@ -79,7 +89,7 @@ func (v *Handshake) ReadC1S1(r io.Reader) (c1s1 []byte, err error) { return } -func (v *Handshake) WriteC2S2(w io.Writer, s1c1 []byte) (err error) { +func (v *handshake) WriteC2S2(w io.Writer, s1c1 []byte) (err error) { r := bytes.NewReader(s1c1[:]) if _, err = io.Copy(w, r); err != nil { return errors.Wrap(err, "write c2s2") @@ -88,7 +98,7 @@ func (v *Handshake) WriteC2S2(w io.Writer, s1c1 []byte) (err error) { return } -func (v *Handshake) ReadC2S2(r io.Reader) (c2 []byte, err error) { +func (v *handshake) ReadC2S2(r io.Reader) (c2 []byte, err error) { b := &bytes.Buffer{} if _, err = io.CopyN(b, r, 1536); err != nil { return nil, errors.Wrap(err, "read c2s2") @@ -129,7 +139,7 @@ type chunkStream struct { format formatType cid chunkID header messageHeader - message *Message + message *message count uint64 extendedTimestamp bool } @@ -138,8 +148,18 @@ func newChunkStream() *chunkStream { return &chunkStream{} } -// The protocol implements the RTMP command and chunk stack. -type Protocol struct { +// Protocol implements the RTMP command and chunk stack. +type Protocol interface { + // Deprecated: Please use rtmp.ExpectPacket instead. + ExpectPacket(ctx context.Context, ppkt any) (Message, error) + ExpectMessage(ctx context.Context, types ...MessageType) (Message, error) + DecodeMessage(m Message) (Packet, error) + ReadMessage(ctx context.Context) (Message, error) + WritePacket(ctx context.Context, pkt Packet, streamID int) error + WriteMessage(ctx context.Context, m Message) error +} + +type protocol struct { r *bufio.Reader w *bufio.Writer input struct { @@ -154,8 +174,8 @@ type Protocol struct { } } -func NewProtocol(rw io.ReadWriter) *Protocol { - v := &Protocol{ +func NewProtocol(rw io.ReadWriter) Protocol { + v := &protocol{ r: bufio.NewReader(rw), w: bufio.NewWriter(rw), } @@ -169,7 +189,7 @@ func NewProtocol(rw io.ReadWriter) *Protocol { return v } -func ExpectPacket[T Packet](ctx context.Context, v *Protocol, ppkt *T) (m *Message, err error) { +func ExpectPacket[T Packet](ctx context.Context, v Protocol, ppkt *T) (m Message, err error) { for { if m, err = v.ReadMessage(ctx); err != nil { return nil, errors.WithMessage(err, "read message") @@ -190,11 +210,11 @@ func ExpectPacket[T Packet](ctx context.Context, v *Protocol, ppkt *T) (m *Messa } // Deprecated: Please use rtmp.ExpectPacket instead. -func (v *Protocol) ExpectPacket(ctx context.Context, ppkt any) (m *Message, err error) { +func (v *protocol) ExpectPacket(ctx context.Context, ppkt any) (m Message, err error) { panic("Please use rtmp.ExpectPacket instead") } -func (v *Protocol) ExpectMessage(ctx context.Context, types ...MessageType) (m *Message, err error) { +func (v *protocol) ExpectMessage(ctx context.Context, types ...MessageType) (m Message, err error) { for { if m, err = v.ReadMessage(ctx); err != nil { return nil, errors.WithMessage(err, "read message") @@ -205,7 +225,7 @@ func (v *Protocol) ExpectMessage(ctx context.Context, types ...MessageType) (m * } for _, t := range types { - if m.MessageType == t { + if m.MessageType() == t { return } } @@ -214,7 +234,7 @@ func (v *Protocol) ExpectMessage(ctx context.Context, types ...MessageType) (m * return } -func (v *Protocol) parseAMFObject(p []byte) (pkt Packet, err error) { +func (v *protocol) parseAMFObject(p []byte) (pkt Packet, err error) { var commandName amf0String if err = commandName.UnmarshalBinary(p); err != nil { return nil, errors.WithMessage(err, "unmarshal command name") @@ -266,18 +286,18 @@ func (v *Protocol) parseAMFObject(p []byte) (pkt Packet, err error) { } } -func (v *Protocol) DecodeMessage(m *Message) (pkt Packet, err error) { - p := m.Payload[:] +func (v *protocol) DecodeMessage(m Message) (pkt Packet, err error) { + p := m.Payload()[:] if len(p) == 0 { return nil, errors.New("Empty packet") } - switch m.MessageType { + switch m.MessageType() { case MessageTypeAMF3Command, MessageTypeAMF3Data: p = p[1:] } - switch m.MessageType { + switch m.MessageType() { case MessageTypeSetChunkSize: pkt = NewSetChunkSize() case MessageTypeWindowAcknowledgementSize: @@ -286,22 +306,22 @@ func (v *Protocol) DecodeMessage(m *Message) (pkt Packet, err error) { pkt = NewSetPeerBandwidth() case MessageTypeAMF0Command, MessageTypeAMF3Command, MessageTypeAMF0Data, MessageTypeAMF3Data: if pkt, err = v.parseAMFObject(p); err != nil { - return nil, errors.WithMessage(err, fmt.Sprintf("Parse AMF %v", m.MessageType)) + return nil, errors.WithMessage(err, fmt.Sprintf("Parse AMF %v", m.MessageType())) } case MessageTypeUserControl: pkt = NewUserControl() default: - return nil, errors.Errorf("Unknown message %v", m.MessageType) + return nil, errors.Errorf("Unknown message %v", m.MessageType()) } if err = pkt.UnmarshalBinary(p); err != nil { - return nil, errors.WithMessage(err, fmt.Sprintf("Unmarshal %v", m.MessageType)) + return nil, errors.WithMessage(err, fmt.Sprintf("Unmarshal %v", m.MessageType())) } return } -func (v *Protocol) ReadMessage(ctx context.Context) (m *Message, err error) { +func (v *protocol) ReadMessage(ctx context.Context) (m Message, err error) { for m == nil { // TODO: We should convert buffered io to async io, because we will be stuck in block io here, // TODO: but the risk is acceptable because we literally will set the underlay io timeout. @@ -331,15 +351,17 @@ func (v *Protocol) ReadMessage(ctx context.Context) (m *Message, err error) { return nil, errors.WithMessage(err, "read message payload") } - if err = v.onMessageArrivated(m); err != nil { - return nil, errors.WithMessage(err, "on message") + if m != nil { + if err = v.onMessageArrivated(m.asMessage()); err != nil { + return nil, errors.WithMessage(err, "on message") + } } } return } -func (v *Protocol) readMessagePayload(ctx context.Context, chunk *chunkStream) (m *Message, err error) { +func (v *protocol) readMessagePayload(ctx context.Context, chunk *chunkStream) (m Message, err error) { // Empty payload message. if chunk.message.payloadLength == 0 { m = chunk.message @@ -348,7 +370,7 @@ func (v *Protocol) readMessagePayload(ctx context.Context, chunk *chunkStream) ( } // Calculate the chunk payload size. - chunkedPayloadSize := int(chunk.message.payloadLength) - len(chunk.message.Payload) + chunkedPayloadSize := int(chunk.message.payloadLength) - len(chunk.message.payload) if chunkedPayloadSize > int(v.input.opt.chunkSize) { chunkedPayloadSize = int(v.input.opt.chunkSize) } @@ -357,10 +379,10 @@ func (v *Protocol) readMessagePayload(ctx context.Context, chunk *chunkStream) ( if _, err = io.ReadFull(v.r, b); err != nil { return nil, errors.Wrapf(err, "read chunk %vB", chunkedPayloadSize) } - chunk.message.Payload = append(chunk.message.Payload, b...) + chunk.message.payload = append(chunk.message.payload, b...) // Got entire RTMP message? - if int(chunk.message.payloadLength) == len(chunk.message.Payload) { + if int(chunk.message.payloadLength) == len(chunk.message.payload) { m = chunk.message chunk.message = nil } @@ -426,7 +448,7 @@ var messageHeaderSizes = []int{11, 7, 3, 0} // fmt=1, 0x4X // fmt=2, 0x8X // fmt=3, 0xCX -func (v *Protocol) readMessageHeader(ctx context.Context, chunk *chunkStream, format formatType) (err error) { +func (v *protocol) readMessageHeader(ctx context.Context, chunk *chunkStream, format formatType) (err error) { // We should not assert anything about fmt, for the first packet. // (when first packet, the chunk.message is nil). // the fmt maybe 0/1/2/3, the FMLE will send a 0xC4 for some audio packet. @@ -480,7 +502,7 @@ func (v *Protocol) readMessageHeader(ctx context.Context, chunk *chunkStream, fo // Create msg when new chunk stream start if chunk.message == nil { - chunk.message = NewMessage() + chunk.message = newMessage() } // Read the message header. @@ -659,7 +681,7 @@ func (v *Protocol) readMessageHeader(ctx context.Context, chunk *chunkStream, fo // // Chunk stream IDs with values 64-319 could be represented by both 2- // byte version and 3-byte version of this field. -func (v *Protocol) readBasicHeader(ctx context.Context) (format formatType, cid chunkID, err error) { +func (v *protocol) readBasicHeader(ctx context.Context) (format formatType, cid chunkID, err error) { // 2-63, 1B chunk header var t uint8 if err = binary.Read(v.r, binary.BigEndian, &t); err != nil { @@ -689,14 +711,14 @@ func (v *Protocol) readBasicHeader(ctx context.Context) (format formatType, cid return } -func (v *Protocol) WritePacket(ctx context.Context, pkt Packet, streamID int) (err error) { - m := NewMessage() +func (v *protocol) WritePacket(ctx context.Context, pkt Packet, streamID int) (err error) { + m := newMessage() - if m.Payload, err = pkt.MarshalBinary(); err != nil { + if m.payload, err = pkt.MarshalBinary(); err != nil { return errors.WithMessage(err, "marshal payload") } - m.MessageType = pkt.Type() + m.messageHeader.MessageType = pkt.Type() m.streamID = uint32(streamID) m.betterCid = pkt.BetterCid() @@ -711,7 +733,7 @@ func (v *Protocol) WritePacket(ctx context.Context, pkt Packet, streamID int) (e return } -func (v *Protocol) onPacketWriten(m *Message, pkt Packet) (err error) { +func (v *protocol) onPacketWriten(m *message, pkt Packet) (err error) { var tid amf0Number var name amf0String @@ -734,16 +756,16 @@ func (v *Protocol) onPacketWriten(m *Message, pkt Packet) (err error) { return } -func (v *Protocol) onMessageArrivated(m *Message) (err error) { +func (v *protocol) onMessageArrivated(m *message) (err error) { if m == nil { return } var pkt Packet - switch m.MessageType { + switch m.MessageType() { case MessageTypeSetChunkSize, MessageTypeUserControl, MessageTypeWindowAcknowledgementSize: if pkt, err = v.DecodeMessage(m); err != nil { - return errors.Errorf("decode message %v", m.MessageType) + return errors.Errorf("decode message %v", m.MessageType()) } } @@ -755,19 +777,20 @@ func (v *Protocol) onMessageArrivated(m *Message) (err error) { return } -func (v *Protocol) WriteMessage(ctx context.Context, m *Message) (err error) { - m.payloadLength = uint32(len(m.Payload)) +func (v *protocol) WriteMessage(ctx context.Context, m Message) (err error) { + msg := m.asMessage() + msg.payloadLength = uint32(len(msg.payload)) var c0h, c3h []byte - if c0h, err = m.generateC0Header(); err != nil { + if c0h, err = msg.generateC0Header(); err != nil { return errors.WithMessage(err, "generate c0 header") } - if c3h, err = m.generateC3Header(); err != nil { + if c3h, err = msg.generateC3Header(); err != nil { return errors.WithMessage(err, "generate c3 header") } var h []byte - p := m.Payload + p := msg.payload for len(p) > 0 { // TODO: We should convert buffered io to async io, because we will be stuck in block io here, // TODO: but the risk is acceptable because we literally will set the underlay io timeout. @@ -899,29 +922,56 @@ type messageHeader struct { Timestamp uint64 } -// The RTMP message, transport over chunk stream in RTMP. +// Message is an RTMP message transported over a chunk stream. // Please read the cs id of @doc rtmp_specification_1.0.pdf, @page 30, @section 4.1. Message Header -type Message struct { +type Message interface { + MessageType() MessageType + Timestamp() uint64 + Payload() []byte + asMessage() *message +} + +type message struct { messageHeader // The payload which carries the RTMP packet. - Payload []byte + payload []byte } -func NewMessage() *Message { - return &Message{} +func NewMessage() Message { + return newMessage() } -func NewStreamMessage(streamID int) *Message { - v := NewMessage() +func newMessage() *message { + return &message{} +} + +func NewStreamMessage(streamID int) Message { + v := newMessage() v.streamID = uint32(streamID) v.betterCid = chunkIDOverStream return v } -func (v *Message) generateC3Header() ([]byte, error) { +func (v *message) MessageType() MessageType { + return v.messageHeader.MessageType +} + +func (v *message) Timestamp() uint64 { + return v.messageHeader.Timestamp +} + +func (v *message) Payload() []byte { + return v.payload +} + +func (v *message) asMessage() *message { + return v +} + +func (v *message) generateC3Header() ([]byte, error) { var c3h []byte - if v.Timestamp < extendedTimestamp { + if v.messageHeader.Timestamp < extendedTimestamp { c3h = make([]byte, 1) } else { c3h = make([]byte, 1+4) @@ -935,19 +985,19 @@ func (v *Message) generateC3Header() ([]byte, error) { // but actually all products from adobe, such as FMS/AMS and Flash player and FMLE, // always carry a extended timestamp in C3 header. // @see: http://blog.csdn.net/win_lin/article/details/13363699 - if v.Timestamp >= extendedTimestamp { - p[0] = byte(v.Timestamp >> 24) - p[1] = byte(v.Timestamp >> 16) - p[2] = byte(v.Timestamp >> 8) - p[3] = byte(v.Timestamp) + if v.messageHeader.Timestamp >= extendedTimestamp { + p[0] = byte(v.messageHeader.Timestamp >> 24) + p[1] = byte(v.messageHeader.Timestamp >> 16) + p[2] = byte(v.messageHeader.Timestamp >> 8) + p[3] = byte(v.messageHeader.Timestamp) } return c3h, nil } -func (v *Message) generateC0Header() ([]byte, error) { +func (v *message) generateC0Header() ([]byte, error) { var c0h []byte - if v.Timestamp < extendedTimestamp { + if v.messageHeader.Timestamp < extendedTimestamp { c0h = make([]byte, 1+3+3+1+4) } else { c0h = make([]byte, 1+3+3+1+4+4) @@ -957,10 +1007,10 @@ func (v *Message) generateC0Header() ([]byte, error) { p[0] = byte(v.betterCid) & 0x3f p = p[1:] - if v.Timestamp < extendedTimestamp { - p[0] = byte(v.Timestamp >> 16) - p[1] = byte(v.Timestamp >> 8) - p[2] = byte(v.Timestamp) + if v.messageHeader.Timestamp < extendedTimestamp { + p[0] = byte(v.messageHeader.Timestamp >> 16) + p[1] = byte(v.messageHeader.Timestamp >> 8) + p[2] = byte(v.messageHeader.Timestamp) } else { p[0] = 0xff p[1] = 0xff @@ -973,7 +1023,7 @@ func (v *Message) generateC0Header() ([]byte, error) { p[2] = byte(v.payloadLength) p = p[3:] - p[0] = byte(v.MessageType) + p[0] = byte(v.messageHeader.MessageType) p = p[1:] p[0] = byte(v.streamID) @@ -982,11 +1032,11 @@ func (v *Message) generateC0Header() ([]byte, error) { p[3] = byte(v.streamID >> 24) p = p[4:] - if v.Timestamp >= extendedTimestamp { - p[0] = byte(v.Timestamp >> 24) - p[1] = byte(v.Timestamp >> 16) - p[2] = byte(v.Timestamp >> 8) - p[3] = byte(v.Timestamp) + if v.messageHeader.Timestamp >= extendedTimestamp { + p[0] = byte(v.messageHeader.Timestamp >> 24) + p[1] = byte(v.messageHeader.Timestamp >> 16) + p[2] = byte(v.messageHeader.Timestamp >> 8) + p[3] = byte(v.messageHeader.Timestamp) } return c0h, nil @@ -1039,8 +1089,8 @@ type Packet interface { type objectCallPacket struct { CommandName amf0String TransactionID amf0Number - CommandObject *amf0Object - Args *amf0Object + CommandObject Amf0Object + Args Amf0Object } func (v *objectCallPacket) BetterCid() chunkID { @@ -1081,7 +1131,7 @@ func (v *objectCallPacket) UnmarshalBinary(data []byte) (err error) { return } - v.Args = NewAmf0Object() + v.Args = newAmf0Object() if err = v.Args.UnmarshalBinary(p); err != nil { return errors.WithMessage(err, "unmarshal args") } @@ -1149,8 +1199,8 @@ func (v *ConnectAppPacket) UnmarshalBinary(data []byte) (err error) { func (v *ConnectAppPacket) TcUrl() string { if v.CommandObject != nil { - if v, ok := v.CommandObject.Get("tcUrl").(*amf0String); ok { - return string(*v) + if v, ok := v.CommandObject.Get("tcUrl").(Amf0String); ok { + return v.String() } } return "" @@ -1172,9 +1222,9 @@ func NewConnectAppResPacket(tid amf0Number) *ConnectAppResPacket { func (v *ConnectAppResPacket) SrsID() string { if v.Args != nil { - if v, ok := v.Args.Get("data").(*amf0EcmaArray); ok { - if v, ok := v.Get("srs_id").(*amf0String); ok { - return string(*v) + if v, ok := v.Args.Get("data").(Amf0EcmaArray); ok { + if v, ok := v.Get("srs_id").(Amf0String); ok { + return v.String() } } } @@ -1197,7 +1247,7 @@ func (v *ConnectAppResPacket) UnmarshalBinary(data []byte) (err error) { type variantCallPacket struct { CommandName amf0String TransactionID amf0Number - CommandObject amf0Any // object or null + CommandObject Amf0Any // object or null } func (v *variantCallPacket) BetterCid() chunkID { @@ -1273,7 +1323,7 @@ func (v *variantCallPacket) MarshalBinary() (data []byte, err error) { // @remark onStatus packet is a call packet. type CallPacket struct { variantCallPacket - Args amf0Any // optional or object or null + Args Amf0Any // optional or object or null } func NewCallPacket() *CallPacket { @@ -1282,9 +1332,9 @@ func NewCallPacket() *CallPacket { func (v *CallPacket) ArgsCode() string { if v.Args != nil { - if v, ok := v.Args.(*amf0Object); ok { - if code, ok := v.Get("code").(*amf0String); ok { - return string(*code) + if v, ok := v.Args.(Amf0Object); ok { + if code, ok := v.Get("code").(Amf0String); ok { + return code.String() } } } @@ -1370,6 +1420,10 @@ func NewCreateStreamResPacket(tid amf0Number) *CreateStreamResPacket { return v } +func (v *CreateStreamResPacket) SetStreamID(streamID int) { + v.StreamID = amf0Number(streamID) +} + func (v *CreateStreamResPacket) Size() int { return v.variantCallPacket.Size() + v.StreamID.Size() } @@ -1407,15 +1461,16 @@ func (v *CreateStreamResPacket) MarshalBinary() (data []byte, err error) { // Please read @doc rtmp_specification_1.0.pdf, @page 64, @section 4.2.6. Publish type PublishPacket struct { variantCallPacket - StreamName amf0String - StreamType amf0String + StreamName Amf0String + StreamType Amf0String } func NewPublishPacket() *PublishPacket { v := &PublishPacket{} v.CommandName = commandPublish v.CommandObject = NewAmf0Null() - v.StreamType = "live" + v.StreamName = NewAmf0String("") + v.StreamType = NewAmf0String("live") return v } @@ -1431,11 +1486,13 @@ func (v *PublishPacket) UnmarshalBinary(data []byte) (err error) { } p = p[v.variantCallPacket.Size():] + v.StreamName = newAmf0String("") if err = v.StreamName.UnmarshalBinary(p); err != nil { return errors.WithMessage(err, "unmarshal stream name") } p = p[v.StreamName.Size():] + v.StreamType = newAmf0String("") if err = v.StreamType.UnmarshalBinary(p); err != nil { return errors.WithMessage(err, "unmarshal stream type") } @@ -1466,13 +1523,14 @@ func (v *PublishPacket) MarshalBinary() (data []byte, err error) { // Please read @doc rtmp_specification_1.0.pdf, @page 54, @section 4.2.1. play type PlayPacket struct { variantCallPacket - StreamName amf0String + StreamName Amf0String } func NewPlayPacket() *PlayPacket { v := &PlayPacket{} v.CommandName = commandPlay v.CommandObject = NewAmf0Null() + v.StreamName = NewAmf0String("") return v } @@ -1488,6 +1546,7 @@ func (v *PlayPacket) UnmarshalBinary(data []byte) (err error) { } p = p[v.variantCallPacket.Size():] + v.StreamName = newAmf0String("") if err = v.StreamName.UnmarshalBinary(p); err != nil { return errors.WithMessage(err, "unmarshal stream name") } diff --git a/internal/server/rtmp.go b/internal/server/rtmp.go index 80be13ba6..33b5f7bd7 100644 --- a/internal/server/rtmp.go +++ b/internal/server/rtmp.go @@ -229,7 +229,7 @@ func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error { response = identifyRes nextStreamID = 1 - identifyRes.StreamID = *rtmp.NewAmf0Number(float64(nextStreamID)) + identifyRes.SetStreamID(nextStreamID) } else if pkt.CommandName == "getStreamLength" { // Ignore and do not reply these packets. } else { @@ -243,7 +243,7 @@ func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error { identifyRes.Args = rtmp.NewAmf0Undefined() } case *rtmp.PublishPacket: - streamName = string(pkt.StreamName) + streamName = pkt.StreamName.String() clientType = RTMPClientTypePublisher identifyRes := rtmp.NewCallPacket() @@ -257,7 +257,7 @@ func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error { data.Set("description", rtmp.NewAmf0String("Started publishing stream.")) identifyRes.Args = data case *rtmp.PlayPacket: - streamName = string(pkt.StreamName) + streamName = pkt.StreamName.String() clientType = RTMPClientTypeViewer identifyRes := rtmp.NewCallPacket() @@ -352,7 +352,7 @@ func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error { if err != nil { return errors.Wrapf(err, "read message") } - //logger.Debug(ctx, "client<- %v %v %vB", m.MessageType, m.Timestamp, len(m.Payload)) + //logger.Debug(ctx, "client<- %v %v %vB", m.MessageType(), m.Timestamp(), len(m.Payload())) // TODO: Update the stream ID if not the same. if err := client.WriteMessage(ctx, m); err != nil { @@ -375,7 +375,7 @@ func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error { if err != nil { return errors.Wrapf(err, "read message") } - //logger.Debug(ctx, "client-> %v %v %vB", m.MessageType, m.Timestamp, len(m.Payload)) + //logger.Debug(ctx, "client-> %v %v %vB", m.MessageType(), m.Timestamp(), len(m.Payload())) // TODO: Update the stream ID if not the same. if err := backend.client.WriteMessage(ctx, m); err != nil { @@ -421,7 +421,7 @@ type RTMPClientToBackend struct { // The underlayer tcp client. tcpConn *net.TCPConn // The RTMP protocol client. - client *rtmp.Protocol + client rtmp.Protocol // The stream type. typ RTMPClientType } @@ -527,7 +527,7 @@ func (v *RTMPClientToBackend) Connect(ctx context.Context, tcUrl, streamName str return v.publish(ctx, client, streamName) } -func (v *RTMPClientToBackend) publish(ctx context.Context, client *rtmp.Protocol, streamName string) error { +func (v *RTMPClientToBackend) publish(ctx context.Context, client rtmp.Protocol, streamName string) error { if true { identifyReq := rtmp.NewCallPacket() identifyReq.CommandName = "releaseStream" @@ -592,8 +592,8 @@ func (v *RTMPClientToBackend) publish(ctx context.Context, client *rtmp.Protocol publishStream := rtmp.NewPublishPacket() publishStream.TransactionID = 5 publishStream.CommandObject = rtmp.NewAmf0Null() - publishStream.StreamName = *rtmp.NewAmf0String(streamName) - publishStream.StreamType = *rtmp.NewAmf0String("live") + publishStream.StreamName = rtmp.NewAmf0String(streamName) + publishStream.StreamType = rtmp.NewAmf0String("live") if err := client.WritePacket(ctx, publishStream, currentStreamID); err != nil { return errors.Wrapf(err, "publish") } @@ -609,8 +609,8 @@ func (v *RTMPClientToBackend) publish(ctx context.Context, client *rtmp.Protocol return errors.Errorf("onStatus args not object") } else if code := rtmp.NewAmf0Converter(data.Get("code")).ToString(); code == nil { return errors.Errorf("onStatus code not string") - } else if *code != "NetStream.Publish.Start" { - return errors.Errorf("onStatus code=%v not NetStream.Publish.Start", *code) + } else if code.String() != "NetStream.Publish.Start" { + return errors.Errorf("onStatus code=%v not NetStream.Publish.Start", code.String()) } break } @@ -620,7 +620,7 @@ func (v *RTMPClientToBackend) publish(ctx context.Context, client *rtmp.Protocol return nil } -func (v *RTMPClientToBackend) play(ctx context.Context, client *rtmp.Protocol, streamName string) error { +func (v *RTMPClientToBackend) play(ctx context.Context, client rtmp.Protocol, streamName string) error { var currentStreamID int if true { createStream := rtmp.NewCreateStreamPacket() @@ -642,7 +642,7 @@ func (v *RTMPClientToBackend) play(ctx context.Context, client *rtmp.Protocol, s } playStream := rtmp.NewPlayPacket() - playStream.StreamName = *rtmp.NewAmf0String(streamName) + playStream.StreamName = rtmp.NewAmf0String(streamName) if err := client.WritePacket(ctx, playStream, currentStreamID); err != nil { return errors.Wrapf(err, "play") }