Refine code and add tests for #4289. v7.0.45 (#4412)

Use AI to understand, add comments, add utests, refactor code for PR
#4289
This commit is contained in:
Winlin 2025-07-04 17:26:12 -04:00 committed by GitHub
parent c5b6b72876
commit b2a827f8cf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
202 changed files with 6554 additions and 2166 deletions

View File

@ -1,226 +1,50 @@
// Package h265reader implements a H265 Annex-B Reader
// Package gb28181 provides GB28181 protocol support
package gb28181
import (
"bytes"
"errors"
"io"
"github.com/pion/webrtc/v4/pkg/media/h265reader"
)
type NalUnitType uint8
// Type aliases for compatibility with existing code
type H265Reader = h265reader.H265Reader
type NAL = h265reader.NAL
type NalUnitType = h265reader.NalUnitType
// Enums for NalUnitTypes
// NAL unit type constants for compatibility with existing code
const (
NaluTypeSliceTrailN NalUnitType = 0 // 0x0
NaluTypeSliceTrailR NalUnitType = 1 // 0x01
NaluTypeSliceTsaN NalUnitType = 2 // 0x02
NaluTypeSliceTsaR NalUnitType = 3 // 0x03
NaluTypeSliceStsaN NalUnitType = 4 // 0x04
NaluTypeSliceStsaR NalUnitType = 5 // 0x05
NaluTypeSliceRadlN NalUnitType = 6 // 0x06
NaluTypeSliceRadlR NalUnitType = 7 // 0x07
NaluTypeSliceRaslN NalUnitType = 8 // 0x06
NaluTypeSliceRaslR NalUnitType = 9 // 0x09
NaluTypeSliceTrailN = h265reader.NalUnitTypeTrailN
NaluTypeSliceTrailR = h265reader.NalUnitTypeTrailR
NaluTypeSliceTsaN = h265reader.NalUnitTypeTsaN
NaluTypeSliceTsaR = h265reader.NalUnitTypeTsaR
NaluTypeSliceStsaN = h265reader.NalUnitTypeStsaN
NaluTypeSliceStsaR = h265reader.NalUnitTypeStsaR
NaluTypeSliceRadlN = h265reader.NalUnitTypeRadlN
NaluTypeSliceRadlR = h265reader.NalUnitTypeRadlR
NaluTypeSliceRaslN = h265reader.NalUnitTypeRaslN
NaluTypeSliceRaslR = h265reader.NalUnitTypeRaslR
NaluTypeSliceBlaWlp NalUnitType = 16 // 0x10
NaluTypeSliceBlaWradl NalUnitType = 17 // 0x11
NaluTypeSliceBlaNlp NalUnitType = 18 // 0x12
NaluTypeSliceIdr NalUnitType = 19 // 0x13
NaluTypeSliceIdrNlp NalUnitType = 20 // 0x14
NaluTypeSliceCranut NalUnitType = 21 // 0x15
NaluTypeSliceRsvIrapVcl22 NalUnitType = 22 // 0x16
NaluTypeSliceRsvIrapVcl23 NalUnitType = 23 // 0x17
NaluTypeSliceBlaWlp = h265reader.NalUnitTypeBlaWLp
NaluTypeSliceBlaWradl = h265reader.NalUnitTypeBlaWRadl
NaluTypeSliceBlaNlp = h265reader.NalUnitTypeBlaNLp
NaluTypeSliceIdr = h265reader.NalUnitTypeIdrWRadl
NaluTypeSliceIdrNlp = h265reader.NalUnitTypeIdrNLp
NaluTypeSliceCranut = h265reader.NalUnitTypeCraNut
NaluTypeSliceRsvIrapVcl22 = h265reader.NalUnitTypeReserved41 // Approximate mapping
NaluTypeSliceRsvIrapVcl23 = h265reader.NalUnitTypeReserved47 // Approximate mapping
NaluTypeVps NalUnitType = 32 // 0x20
NaluTypeSps NalUnitType = 33 // 0x21
NaluTypePps NalUnitType = 34 // 0x22
NaluTypeAud NalUnitType = 35 // 0x23
NaluTypeSei NalUnitType = 39 // 0x27
NaluTypeSeiSuffix NalUnitType = 40 // 0x28
NaluTypeVps = h265reader.NalUnitTypeVps
NaluTypeSps = h265reader.NalUnitTypeSps
NaluTypePps = h265reader.NalUnitTypePps
NaluTypeAud = h265reader.NalUnitTypeAud
NaluTypeSei = h265reader.NalUnitTypePrefixSei
NaluTypeSeiSuffix = h265reader.NalUnitTypeSuffixSei
NaluTypeUnspecified NalUnitType = 48 // 0x30
NaluTypeUnspecified = h265reader.NalUnitTypeUnspec48
)
// H265Reader reads data from stream and constructs h265 nal units
type H265Reader struct {
stream io.Reader
nalBuffer []byte
countOfConsecutiveZeroBytes int
nalPrefixParsed bool
readBuffer []byte
}
var (
errNilReader = errors.New("stream is nil")
errDataIsNotH265Stream = errors.New("data is not a H265 bitstream")
)
// NewReader creates new H265Reader
// NewReader creates new H265Reader using Pion's implementation
func NewReader(in io.Reader) (*H265Reader, error) {
if in == nil {
return nil, errNilReader
}
reader := &H265Reader{
stream: in,
nalBuffer: make([]byte, 0),
nalPrefixParsed: false,
readBuffer: make([]byte, 0),
}
return reader, nil
}
// NAL H.265 Network Abstraction Layer
type NAL struct {
PictureOrderCount uint32
// NAL header
ForbiddenZeroBit bool
UnitType NalUnitType
NuhLayerId uint8
NuhTemporalIdPlus1 uint8
Data []byte // header byte + rbsp
}
func (reader *H265Reader) read(numToRead int) (data []byte) {
for len(reader.readBuffer) < numToRead {
buf := make([]byte, 4096)
n, err := reader.stream.Read(buf)
if n == 0 || err != nil {
break
}
buf = buf[0:n]
reader.readBuffer = append(reader.readBuffer, buf...)
}
var numShouldRead int
if numToRead <= len(reader.readBuffer) {
numShouldRead = numToRead
} else {
numShouldRead = len(reader.readBuffer)
}
data = reader.readBuffer[0:numShouldRead]
reader.readBuffer = reader.readBuffer[numShouldRead:]
return data
}
func (reader *H265Reader) bitStreamStartsWithH265Prefix() (prefixLength int, e error) {
nalPrefix3Bytes := []byte{0, 0, 1}
nalPrefix4Bytes := []byte{0, 0, 0, 1}
prefixBuffer := reader.read(4)
n := len(prefixBuffer)
if n == 0 {
return 0, io.EOF
}
if n < 3 {
return 0, errDataIsNotH265Stream
}
nalPrefix3BytesFound := bytes.Equal(nalPrefix3Bytes, prefixBuffer[:3])
if n == 3 {
if nalPrefix3BytesFound {
return 0, io.EOF
}
return 0, errDataIsNotH265Stream
}
// n == 4
if nalPrefix3BytesFound {
reader.nalBuffer = append(reader.nalBuffer, prefixBuffer[3])
return 3, nil
}
nalPrefix4BytesFound := bytes.Equal(nalPrefix4Bytes, prefixBuffer)
if nalPrefix4BytesFound {
return 4, nil
}
return 0, errDataIsNotH265Stream
}
// NextNAL reads from stream and returns then next NAL,
// and an error if there is incomplete frame data.
// Returns all nil values when no more NALs are available.
func (reader *H265Reader) NextNAL() (*NAL, error) {
if !reader.nalPrefixParsed {
_, err := reader.bitStreamStartsWithH265Prefix()
if err != nil {
return nil, err
}
reader.nalPrefixParsed = true
}
for {
buffer := reader.read(1)
n := len(buffer)
if n != 1 {
break
}
readByte := buffer[0]
nalFound := reader.processByte(readByte)
if nalFound {
nal := newNal(reader.nalBuffer)
nal.parseHeader()
if nal.UnitType == NaluTypeSeiSuffix || nal.UnitType == NaluTypeSei {
reader.nalBuffer = nil
continue
} else {
break
}
}
reader.nalBuffer = append(reader.nalBuffer, readByte)
}
if len(reader.nalBuffer) == 0 {
return nil, io.EOF
}
nal := newNal(reader.nalBuffer)
reader.nalBuffer = nil
nal.parseHeader()
return nal, nil
}
func (reader *H265Reader) processByte(readByte byte) (nalFound bool) {
nalFound = false
switch readByte {
case 0:
reader.countOfConsecutiveZeroBytes++
case 1:
if reader.countOfConsecutiveZeroBytes >= 2 {
countOfConsecutiveZeroBytesInPrefix := 2
if reader.countOfConsecutiveZeroBytes > 2 {
countOfConsecutiveZeroBytesInPrefix = 3
}
nalUnitLength := len(reader.nalBuffer) - countOfConsecutiveZeroBytesInPrefix
reader.nalBuffer = reader.nalBuffer[0:nalUnitLength]
reader.countOfConsecutiveZeroBytes = 0
nalFound = true
} else {
reader.countOfConsecutiveZeroBytes = 0
}
default:
reader.countOfConsecutiveZeroBytes = 0
}
return nalFound
}
func newNal(data []byte) *NAL {
return &NAL{PictureOrderCount: 0, ForbiddenZeroBit: false, UnitType: NaluTypeUnspecified, Data: data}
}
func (h *NAL) parseHeader() {
firstByte := h.Data[0]
h.ForbiddenZeroBit = (((firstByte & 0x80) >> 7) == 1) // 0x80 = 0b10000000
h.UnitType = NalUnitType((firstByte & 0x7E) >> 1) // 0x1F = 0b01111110
return h265reader.NewReader(in)
}

View File

@ -456,13 +456,13 @@ func (v *PSIngester) writeH265(ctx context.Context, pack *PSPackStream, h265 *H2
videoFrames = append(videoFrames, frame)
logger.If(ctx, "NALU %v PictureOrderCount=%v, ForbiddenZeroBit=%v, %v bytes",
frame.UnitType, frame.PictureOrderCount, frame.ForbiddenZeroBit, len(frame.Data))
frame.NalUnitType, frame.PictureOrderCount, frame.ForbiddenZeroBit, len(frame.Data))
if frame.UnitType == NaluTypeVps {
if frame.NalUnitType == NaluTypeVps {
vps = frame
} else if frame.UnitType == NaluTypeSps {
} else if frame.NalUnitType == NaluTypeSps {
sps = frame
} else if frame.UnitType == NaluTypePps {
} else if frame.NalUnitType == NaluTypePps {
pps = frame
} else {
break

View File

@ -1,6 +1,6 @@
module github.com/ossrs/srs-bench
go 1.21
go 1.23.0
require (
github.com/ghettovoice/gosip v0.0.0-20220929080231-de8ba881be83
@ -8,13 +8,13 @@ require (
github.com/haivision/srtgo v0.0.0-20230627061225-a70d53fcd618
github.com/ossrs/go-oryx-lib v0.0.9
github.com/pion/ice/v4 v4.0.10
github.com/pion/interceptor v0.1.37
github.com/pion/logging v0.2.3
github.com/pion/interceptor v0.1.40
github.com/pion/logging v0.2.4
github.com/pion/rtcp v1.2.15
github.com/pion/rtp v1.8.15
github.com/pion/sdp/v3 v3.0.11
github.com/pion/rtp v1.8.20
github.com/pion/sdp/v3 v3.0.14
github.com/pion/transport/v3 v3.0.7
github.com/pion/webrtc/v4 v4.1.1
github.com/pion/webrtc/v4 v4.1.3
github.com/pkg/errors v0.9.1
github.com/yapingcat/gomedia/codec v0.0.0-20220617074658-94762898dc25
github.com/yapingcat/gomedia/mpeg2 v0.0.0-20220617074658-94762898dc25
@ -35,16 +35,16 @@ require (
github.com/pion/mdns/v2 v2.0.7 // indirect
github.com/pion/randutil v0.1.0 // indirect
github.com/pion/sctp v1.8.39 // indirect
github.com/pion/srtp/v3 v3.0.4 // indirect
github.com/pion/srtp/v3 v3.0.6 // indirect
github.com/pion/stun/v3 v3.0.0 // indirect
github.com/pion/turn/v4 v4.0.0 // indirect
github.com/pion/turn/v4 v4.0.2 // indirect
github.com/satori/go.uuid v1.2.1-0.20181028125025-b2ce2384e17b // indirect
github.com/sirupsen/logrus v1.4.2 // indirect
github.com/tevino/abool v0.0.0-20170917061928-9b9efcf221b5 // indirect
github.com/wlynxg/anet v0.0.5 // indirect
github.com/x-cray/logrus-prefixed-formatter v0.5.2 // indirect
golang.org/x/crypto v0.33.0 // indirect
golang.org/x/net v0.35.0 // indirect
golang.org/x/sys v0.30.0 // indirect
golang.org/x/term v0.29.0 // indirect
golang.org/x/crypto v0.39.0 // indirect
golang.org/x/net v0.41.0 // indirect
golang.org/x/sys v0.33.0 // indirect
golang.org/x/term v0.32.0 // indirect
)

View File

@ -65,32 +65,32 @@ github.com/pion/dtls/v3 v3.0.6 h1:7Hkd8WhAJNbRgq9RgdNh1aaWlZlGpYTzdqjy9x9sK2E=
github.com/pion/dtls/v3 v3.0.6/go.mod h1:iJxNQ3Uhn1NZWOMWlLxEEHAN5yX7GyPvvKw04v9bzYU=
github.com/pion/ice/v4 v4.0.10 h1:P59w1iauC/wPk9PdY8Vjl4fOFL5B+USq1+xbDcN6gT4=
github.com/pion/ice/v4 v4.0.10/go.mod h1:y3M18aPhIxLlcO/4dn9X8LzLLSma84cx6emMSu14FGw=
github.com/pion/interceptor v0.1.37 h1:aRA8Zpab/wE7/c0O3fh1PqY0AJI3fCSEM5lRWJVorwI=
github.com/pion/interceptor v0.1.37/go.mod h1:JzxbJ4umVTlZAf+/utHzNesY8tmRkM2lVmkS82TTj8Y=
github.com/pion/logging v0.2.3 h1:gHuf0zpoh1GW67Nr6Gj4cv5Z9ZscU7g/EaoC/Ke/igI=
github.com/pion/logging v0.2.3/go.mod h1:z8YfknkquMe1csOrxK5kc+5/ZPAzMxbKLX5aXpbpC90=
github.com/pion/interceptor v0.1.40 h1:e0BjnPcGpr2CFQgKhrQisBU7V3GXK6wrfYrGYaU6Jq4=
github.com/pion/interceptor v0.1.40/go.mod h1:Z6kqH7M/FYirg3frjGJ21VLSRJGBXB/KqaTIrdqnOic=
github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8=
github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so=
github.com/pion/mdns/v2 v2.0.7 h1:c9kM8ewCgjslaAmicYMFQIde2H9/lrZpjBkN8VwoVtM=
github.com/pion/mdns/v2 v2.0.7/go.mod h1:vAdSYNAT0Jy3Ru0zl2YiW3Rm/fJCwIeM0nToenfOJKA=
github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA=
github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8=
github.com/pion/rtcp v1.2.15 h1:LZQi2JbdipLOj4eBjK4wlVoQWfrZbh3Q6eHtWtJBZBo=
github.com/pion/rtcp v1.2.15/go.mod h1:jlGuAjHMEXwMUHK78RgX0UmEJFV4zUKOFHR7OP+D3D0=
github.com/pion/rtp v1.8.15 h1:MuhuGn1cxpVCPLNY1lI7F1tQ8Spntpgf12ob+pOYT8s=
github.com/pion/rtp v1.8.15/go.mod h1:bAu2UFKScgzyFqvUKmbvzSdPr+NGbZtv6UB2hesqXBk=
github.com/pion/rtp v1.8.20 h1:8zcyqohadZE8FCBeGdyEvHiclPIezcwRQH9zfapFyYI=
github.com/pion/rtp v1.8.20/go.mod h1:bAu2UFKScgzyFqvUKmbvzSdPr+NGbZtv6UB2hesqXBk=
github.com/pion/sctp v1.8.39 h1:PJma40vRHa3UTO3C4MyeJDQ+KIobVYRZQZ0Nt7SjQnE=
github.com/pion/sctp v1.8.39/go.mod h1:cNiLdchXra8fHQwmIoqw0MbLLMs+f7uQ+dGMG2gWebE=
github.com/pion/sdp/v3 v3.0.11 h1:VhgVSopdsBKwhCFoyyPmT1fKMeV9nLMrEKxNOdy3IVI=
github.com/pion/sdp/v3 v3.0.11/go.mod h1:88GMahN5xnScv1hIMTqLdu/cOcUkj6a9ytbncwMCq2E=
github.com/pion/srtp/v3 v3.0.4 h1:2Z6vDVxzrX3UHEgrUyIGM4rRouoC7v+NiF1IHtp9B5M=
github.com/pion/srtp/v3 v3.0.4/go.mod h1:1Jx3FwDoxpRaTh1oRV8A/6G1BnFL+QI82eK4ms8EEJQ=
github.com/pion/sdp/v3 v3.0.14 h1:1h7gBr9FhOWH5GjWWY5lcw/U85MtdcibTyt/o6RxRUI=
github.com/pion/sdp/v3 v3.0.14/go.mod h1:88GMahN5xnScv1hIMTqLdu/cOcUkj6a9ytbncwMCq2E=
github.com/pion/srtp/v3 v3.0.6 h1:E2gyj1f5X10sB/qILUGIkL4C2CqK269Xq167PbGCc/4=
github.com/pion/srtp/v3 v3.0.6/go.mod h1:BxvziG3v/armJHAaJ87euvkhHqWe9I7iiOy50K2QkhY=
github.com/pion/stun/v3 v3.0.0 h1:4h1gwhWLWuZWOJIJR9s2ferRO+W3zA/b6ijOI6mKzUw=
github.com/pion/stun/v3 v3.0.0/go.mod h1:HvCN8txt8mwi4FBvS3EmDghW6aQJ24T+y+1TKjB5jyU=
github.com/pion/transport/v3 v3.0.7 h1:iRbMH05BzSNwhILHoBoAPxoB9xQgOaJk+591KC9P1o0=
github.com/pion/transport/v3 v3.0.7/go.mod h1:YleKiTZ4vqNxVwh77Z0zytYi7rXHl7j6uPLGhhz9rwo=
github.com/pion/turn/v4 v4.0.0 h1:qxplo3Rxa9Yg1xXDxxH8xaqcyGUtbHYw4QSCvmFWvhM=
github.com/pion/turn/v4 v4.0.0/go.mod h1:MuPDkm15nYSklKpN8vWJ9W2M0PlyQZqYt1McGuxG7mA=
github.com/pion/webrtc/v4 v4.1.1 h1:PMFPtLg1kpD2pVtun+LGUzA3k54JdFl87WO0Z1+HKug=
github.com/pion/webrtc/v4 v4.1.1/go.mod h1:cgEGkcpxGkT6Di2ClBYO5lP9mFXbCfEOrkYUpjjCQO4=
github.com/pion/turn/v4 v4.0.2 h1:ZqgQ3+MjP32ug30xAbD6Mn+/K4Sxi3SdNOTFf+7mpps=
github.com/pion/turn/v4 v4.0.2/go.mod h1:pMMKP/ieNAG/fN5cZiN4SDuyKsXtNTr0ccN7IToA1zs=
github.com/pion/webrtc/v4 v4.1.3 h1:YZ67Boj9X/hk190jJZ8+HFGQ6DqSZ/fYP3sLAZv7c3c=
github.com/pion/webrtc/v4 v4.1.3/go.mod h1:rsq+zQ82ryfR9vbb0L1umPJ6Ogq7zm8mcn9fcGnxomM=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
@ -117,8 +117,8 @@ github.com/yapingcat/gomedia/mpeg2 v0.0.0-20220617074658-94762898dc25/go.mod h1:
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM=
golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@ -126,8 +126,8 @@ golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
golang.org/x/net v0.0.0-20201202161906-c7110b5ffcbb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw=
golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@ -144,15 +144,15 @@ golang.org/x/sys v0.0.0-20200926100807-9d91bd62050c/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201207223542-d4d67f95c62d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201214095126-aec9a390925b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.29.0 h1:L6pJp37ocefwRRtYPKSWOWzOtWSxVajvz2ldH/xi3iU=
golang.org/x/term v0.29.0/go.mod h1:6bl4lRlvVuDgSf3179VpIxBF0o10JUpXWOnI7nErv7s=
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/term v0.32.0 h1:DR4lr0TjUs3epypdhTOkMmuF5CDFJ/8pOnbzMZPQ7bg=
golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M=
golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

View File

@ -102,7 +102,7 @@ func TestRtmpPublishPlay(t *testing.T) {
}
}
func TestRtmpPublish_RtcPlay(t *testing.T) {
func TestRtmpPublish_RtcPlay_AVC(t *testing.T) {
ctx := logger.WithContext(context.Background())
ctx, cancel := context.WithTimeout(ctx, time.Duration(*srsTimeout)*time.Millisecond)
@ -119,10 +119,11 @@ func TestRtmpPublish_RtcPlay(t *testing.T) {
return err
}
// Setup the RTC player.
// Setup the RTC player with AVC codec support.
var thePlayer *testPlayer
if thePlayer, err = newTestPlayer(registerMiniCodecs, func(play *testPlayer) error {
play.streamSuffix = streamSuffix
play.streamCodec = "h264"
var nnPlayReadRTP uint64
return play.Setup(*srsVnetClientIP, func(api *testWebRTCAPI) {
api.registry.Add(newRTPInterceptor(func(i *rtpInterceptor) {
@ -234,7 +235,7 @@ func TestRtmpPublish_MultipleSequences(t *testing.T) {
}
// Ingore the duplicated sps/pps.
if IsAvccrEquals(previousAvccr, avccr) {
if isAvccrEquals(previousAvccr, avccr) {
return nil
}
previousAvccr = avccr
@ -316,7 +317,7 @@ func TestRtmpPublish_MultipleSequences_RtcPlay(t *testing.T) {
return nn, attr, err
}
annexb, nalus, err := DemuxRtpSpsPps(payload[:nn])
annexb, nalus, err := demuxRtpSpsPps(payload[:nn])
if err != nil || len(nalus) == 0 ||
(nalus[0].NALUType != avc.NALUTypeSPS && nalus[0].NALUType != avc.NALUTypePPS) ||
bytes.Equal(annexb, previousSpsPps) {
@ -640,3 +641,95 @@ func TestRtmpPublish_HttpFlvPlayNoVideo(t *testing.T) {
t.Errorf("err %+v", err)
}
}
// TestRtmpPublish_RtcPlay_HEVC tests HEVC support in RTMP to RTC pipeline (PR 4289)
// This test publishes H.265 video via RTMP and plays it back via WebRTC with codec=hevc
func TestRtmpPublish_RtcPlay_HEVC(t *testing.T) {
ctx := logger.WithContext(context.Background())
ctx, cancel := context.WithTimeout(ctx, time.Duration(*srsTimeout)*time.Millisecond)
var r0, r1 error
err := func() (err error) {
streamSuffix := fmt.Sprintf("rtmp-hevc-regression-%v-%v", os.Getpid(), rand.Int())
rtmpUrl := fmt.Sprintf("%v://%v%v-%v", srsSchema, *srsServer, *srsStream, streamSuffix)
// Publisher connect to a RTMP stream.
publisher := NewRTMPPublisher()
defer publisher.Close()
if err := publisher.Publish(ctx, rtmpUrl); err != nil {
return err
}
// Setup the RTC player with HEVC codec support.
var thePlayer *testPlayer
if thePlayer, err = newTestPlayer(registerHEVCCodecs, func(play *testPlayer) error {
play.streamSuffix = streamSuffix
play.streamCodec = "hevc"
var nnPlayReadRTP uint64
return play.Setup(*srsVnetClientIP, func(api *testWebRTCAPI) {
api.registry.Add(newRTPInterceptor(func(i *rtpInterceptor) {
i.rtpReader = func(payload []byte, attributes interceptor.Attributes) (int, interceptor.Attributes, error) {
nn, attr, err := i.nextRTPReader.Read(payload, attributes)
if err != nil {
return nn, attr, err
}
if nnPlayReadRTP++; nnPlayReadRTP >= uint64(*srsPlayOKPackets) {
cancel() // Completed.
}
logger.Tf(ctx, "Play RECV RTP #%v %vB", nnPlayReadRTP, nn)
return nn, attr, err
}
}))
})
}); err != nil {
return err
}
defer thePlayer.Close()
// Run publisher and players.
var wg sync.WaitGroup
defer wg.Wait()
var playerIceReady context.Context
playerIceReady, thePlayer.iceReadyCancel = context.WithCancel(ctx)
wg.Add(1)
go func() {
defer wg.Done()
if r1 = thePlayer.Run(logger.WithContext(ctx), cancel); r1 != nil {
cancel()
}
logger.Tf(ctx, "player done")
}()
wg.Add(1)
go func() {
defer wg.Done()
// Wait for player ready.
select {
case <-ctx.Done():
return
case <-playerIceReady.Done():
}
publisher.onSendPacket = func(m *rtmp.Message) error {
time.Sleep(100 * time.Microsecond)
return nil
}
// Use H.265 file directly without ffmpeg transcoding, implementing new h265 demuxer
// in RTMPPublisher using pion pkg/media/h265reader with enhanced-RTMP fourcc 'hvc1'
if r0 = publisher.Ingest(ctx, *srsPublishVideoH265); r0 != nil {
cancel()
}
logger.Tf(ctx, "publisher done")
}()
return nil
}()
if err := filterTestError(ctx.Err(), err, r0, r1); err != nil {
t.Errorf("err %+v", err)
}
}

View File

@ -61,6 +61,7 @@ import (
"github.com/pion/transport/v3/vnet"
"github.com/pion/webrtc/v4"
"github.com/pion/webrtc/v4/pkg/media/h264reader"
"github.com/pion/webrtc/v4/pkg/media/h265reader"
)
var srsHttps *bool
@ -80,6 +81,7 @@ var srsStream *string
var srsLiveStream *string
var srsPublishAudio *string
var srsPublishVideo *string
var srsPublishVideoH265 *string
var srsPublishAvatar *string
var srsPublishBBB *string
var srsVnetClientIP *string
@ -97,6 +99,7 @@ func prepareTest() (err error) {
srsPublishOKPackets = flag.Int("srs-publish-ok-packets", 3, "If send N RTP, recv N RTCP packets, it's ok, or fail")
srsPublishAudio = flag.String("srs-publish-audio", "avatar.ogg", "The audio file for publisher.")
srsPublishVideo = flag.String("srs-publish-video", "avatar.h264", "The video file for publisher.")
srsPublishVideoH265 = flag.String("srs-publish-video-h265", "avatar.h265", "The H.265 video file for publisher.")
srsPublishAvatar = flag.String("srs-publish-avatar", "avatar.flv", "The avatar file for publisher.")
srsPublishBBB = flag.String("srs-publish-bbb", "bbb.flv", "The bbb file for publisher.")
srsPublishVideoFps = flag.Int("srs-publish-video-fps", 25, "The video fps for publisher.")
@ -143,6 +146,10 @@ func prepareTest() (err error) {
return err
}
if *srsPublishVideoH265, err = tryOpenFile(*srsPublishVideoH265); err != nil {
return err
}
if *srsPublishAvatar, err = tryOpenFile(*srsPublishAvatar); err != nil {
return err
}
@ -633,16 +640,28 @@ func registerMiniCodecs(api *testWebRTCAPI) error {
v := api
if err := v.mediaEngine.RegisterCodec(webrtc.RTPCodecParameters{
RTPCodecCapability: webrtc.RTPCodecCapability{webrtc.MimeTypeOpus, 48000, 2, "minptime=10;useinbandfec=1", nil},
PayloadType: 111,
RTPCodecCapability: webrtc.RTPCodecCapability{
MimeType: webrtc.MimeTypeOpus, ClockRate: 48000, Channels: 2,
SDPFmtpLine: "minptime=10;useinbandfec=1", RTCPFeedback: nil,
},
PayloadType: 111,
}, webrtc.RTPCodecTypeAudio); err != nil {
return err
}
videoRTCPFeedback := []webrtc.RTCPFeedback{{"goog-remb", ""}, {"ccm", "fir"}, {"nack", ""}, {"nack", "pli"}}
videoRTCPFeedback := []webrtc.RTCPFeedback{
{Type: "goog-remb", Parameter: ""},
{Type: "ccm", Parameter: "fir"},
{Type: "nack", Parameter: ""},
{Type: "nack", Parameter: "pli"},
}
if err := v.mediaEngine.RegisterCodec(webrtc.RTPCodecParameters{
RTPCodecCapability: webrtc.RTPCodecCapability{webrtc.MimeTypeH264, 90000, 0, "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f", videoRTCPFeedback},
PayloadType: 108,
RTPCodecCapability: webrtc.RTPCodecCapability{
MimeType: webrtc.MimeTypeH264, ClockRate: 90000, Channels: 0,
SDPFmtpLine: "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f",
RTCPFeedback: videoRTCPFeedback,
},
PayloadType: 108,
}, webrtc.RTPCodecTypeVideo); err != nil {
return err
}
@ -656,16 +675,26 @@ func registerMiniCodecsWithoutNack(api *testWebRTCAPI) error {
v := api
if err := v.mediaEngine.RegisterCodec(webrtc.RTPCodecParameters{
RTPCodecCapability: webrtc.RTPCodecCapability{webrtc.MimeTypeOpus, 48000, 2, "minptime=10;useinbandfec=1", nil},
PayloadType: 111,
RTPCodecCapability: webrtc.RTPCodecCapability{
MimeType: webrtc.MimeTypeOpus, ClockRate: 48000, Channels: 2,
SDPFmtpLine: "minptime=10;useinbandfec=1", RTCPFeedback: nil,
},
PayloadType: 111,
}, webrtc.RTPCodecTypeAudio); err != nil {
return err
}
videoRTCPFeedback := []webrtc.RTCPFeedback{{"goog-remb", ""}, {"ccm", "fir"}}
videoRTCPFeedback := []webrtc.RTCPFeedback{
{Type: "goog-remb", Parameter: ""},
{Type: "ccm", Parameter: "fir"},
}
if err := v.mediaEngine.RegisterCodec(webrtc.RTPCodecParameters{
RTPCodecCapability: webrtc.RTPCodecCapability{webrtc.MimeTypeH264, 90000, 0, "level-asymmetry-allowed=1;packetization-mode=0;profile-level-id=42e01f", videoRTCPFeedback},
PayloadType: 108,
RTPCodecCapability: webrtc.RTPCodecCapability{
MimeType: webrtc.MimeTypeH264, ClockRate: 90000, Channels: 0,
SDPFmtpLine: "level-asymmetry-allowed=1;packetization-mode=0;profile-level-id=42e01f",
RTCPFeedback: videoRTCPFeedback,
},
PayloadType: 108,
}, webrtc.RTPCodecTypeVideo); err != nil {
return err
}
@ -680,17 +709,28 @@ func registerHEVCCodecs(api *testWebRTCAPI) error {
// Register Opus audio codec
if err := v.mediaEngine.RegisterCodec(webrtc.RTPCodecParameters{
RTPCodecCapability: webrtc.RTPCodecCapability{webrtc.MimeTypeOpus, 48000, 2, "minptime=10;useinbandfec=1", nil},
PayloadType: 111,
RTPCodecCapability: webrtc.RTPCodecCapability{
MimeType: webrtc.MimeTypeOpus, ClockRate: 48000, Channels: 2,
SDPFmtpLine: "minptime=10;useinbandfec=1", RTCPFeedback: nil,
},
PayloadType: 111,
}, webrtc.RTPCodecTypeAudio); err != nil {
return err
}
// Register HEVC/H.265 video codec
videoRTCPFeedback := []webrtc.RTCPFeedback{{"goog-remb", ""}, {"ccm", "fir"}, {"nack", ""}, {"nack", "pli"}}
videoRTCPFeedback := []webrtc.RTCPFeedback{
{Type: "goog-remb", Parameter: ""},
{Type: "ccm", Parameter: "fir"},
{Type: "nack", Parameter: ""},
{Type: "nack", Parameter: "pli"},
}
if err := v.mediaEngine.RegisterCodec(webrtc.RTPCodecParameters{
RTPCodecCapability: webrtc.RTPCodecCapability{webrtc.MimeTypeH265, 90000, 0, "profile-id=1", videoRTCPFeedback},
PayloadType: 49, // Use payload type 49 for HEVC as mentioned in PR description
RTPCodecCapability: webrtc.RTPCodecCapability{
MimeType: webrtc.MimeTypeH265, ClockRate: 90000, Channels: 0,
SDPFmtpLine: "level-id=180;profile-id=1;tier-flag=0;tx-mode=SRST", RTCPFeedback: videoRTCPFeedback,
},
PayloadType: 49, // Use payload type 49 for HEVC as mentioned in PR description
}, webrtc.RTPCodecTypeVideo); err != nil {
return err
}
@ -812,6 +852,8 @@ type testPlayer struct {
api *testWebRTCAPI
// Optional suffix for stream url.
streamSuffix string
// Optional codec for stream, e.g., "hevc", "h264".
streamCodec string
// Optional app/stream to play, use srsStream by default.
defaultStream string
}
@ -864,6 +906,13 @@ func (v *testPlayer) Run(ctx context.Context, cancel context.CancelFunc) error {
if v.streamSuffix != "" {
r = fmt.Sprintf("%v-%v", r, v.streamSuffix)
}
if v.streamCodec != "" {
if strings.Contains(r, "?") {
r = fmt.Sprintf("%v&codec=%v", r, v.streamCodec)
} else {
r = fmt.Sprintf("%v?codec=%v", r, v.streamCodec)
}
}
pli := time.Duration(*srsPlayPLI) * time.Millisecond
logger.Tf(ctx, "Run play url=%v", r)
@ -1634,7 +1683,7 @@ func (v *RTMPPublisher) Publish(ctx context.Context, rtmpUrl string) error {
return v.client.Publish(ctx, rtmpUrl)
}
func (v *RTMPPublisher) Ingest(ctx context.Context, flvInput string) error {
func (v *RTMPPublisher) Ingest(ctx context.Context, input string) error {
// If ctx is cancelled, close the RTMP transport.
var wg sync.WaitGroup
defer wg.Wait()
@ -1649,8 +1698,16 @@ func (v *RTMPPublisher) Ingest(ctx context.Context, flvInput string) error {
}()
// Consume all packets.
logger.Tf(ctx, "Start to ingest %v", flvInput)
err := v.ingest(ctx, flvInput)
logger.Tf(ctx, "Start to ingest %v", input)
// Check file extension to determine format
var err error
if strings.HasSuffix(strings.ToLower(input), ".h265") {
err = v.ingestH265(ctx, input)
} else {
// Default to FLV format for H.264
err = v.ingestFLV(ctx, input)
}
if err == io.EOF {
return nil
}
@ -1660,7 +1717,7 @@ func (v *RTMPPublisher) Ingest(ctx context.Context, flvInput string) error {
return err
}
func (v *RTMPPublisher) ingest(ctx context.Context, flvInput string) error {
func (v *RTMPPublisher) ingestFLV(ctx context.Context, flvInput string) error {
p := v.client
fs, err := os.Open(flvInput)
@ -1718,6 +1775,247 @@ func (v *RTMPPublisher) ingest(ctx context.Context, flvInput string) error {
return nil
}
func (v *RTMPPublisher) ingestH265(ctx context.Context, h265Input string) error {
p := v.client
fs, err := os.Open(h265Input)
if err != nil {
return err
}
defer fs.Close()
logger.Tf(ctx, "Open H.265 input %v", h265Input)
h265Reader, err := h265reader.NewReader(fs)
if err != nil {
return err
}
// Send sequence header first
var vps, sps, pps []byte
var timestamp uint64 = 0
// Read NALUs to find VPS, SPS, PPS
for {
nal, err := h265Reader.NextNAL()
if err != nil {
if err == io.EOF {
break
}
return err
}
if nal == nil {
break
}
// Extract parameter sets using pion constants
switch nal.NalUnitType {
case h265reader.NalUnitTypeVps: // VPS (32)
vps = nal.Data
case h265reader.NalUnitTypeSps: // SPS (33)
sps = nal.Data
case h265reader.NalUnitTypePps: // PPS (34)
pps = nal.Data
}
// Once we have all parameter sets, send sequence header
if len(vps) > 0 && len(sps) > 0 && len(pps) > 0 {
if err := v.sendH265SequenceHeader(p, vps, sps, pps, timestamp); err != nil {
return err
}
break
}
}
// Reset reader for actual frame data
fs.Seek(0, 0)
h265Reader, err = h265reader.NewReader(fs)
if err != nil {
return err
}
// Send video frames
frameCount := 0
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
nal, err := h265Reader.NextNAL()
if err != nil {
if err == io.EOF {
return nil
}
return err
}
if nal == nil {
return nil
}
// Skip parameter sets as they were already sent
if nal.NalUnitType == h265reader.NalUnitTypeVps || nal.NalUnitType == h265reader.NalUnitTypeSps || nal.NalUnitType == h265reader.NalUnitTypePps {
continue
}
// Send video frame - check for IDR and CRA frames (key frames)
isKeyFrame := (nal.NalUnitType >= h265reader.NalUnitTypeBlaWLp && nal.NalUnitType <= h265reader.NalUnitTypeCraNut)
if err := v.sendH265Frame(p, nal.Data, timestamp, isKeyFrame); err != nil {
return err
}
frameCount++
timestamp += 40 // 25fps = 40ms per frame
if v.onSendPacket != nil {
m := rtmp.NewStreamMessage(p.streamID)
m.MessageType = rtmp.MessageTypeVideo
m.Timestamp = timestamp
m.Payload = nal.Data
if err = v.onSendPacket(m); err != nil {
return err
}
}
}
}
func (v *RTMPPublisher) sendH265SequenceHeader(p *RTMPClient, vps, sps, pps []byte, timestamp uint64) error {
// Create HEVC sequence header using enhanced-RTMP format
// Format: [IsExHeader | FrameType | PacketType] [fourcc 'hvc1'] [HEVCDecoderConfigurationRecord]
// @see: https://veovera.org/docs/enhanced/enhanced-rtmp-v1.pdf, page 9
const flvIsExHeader = 0x80
const videoAvcFrameTypeKeyFrame = 1
const videoHEVCFrameTraitPacketTypeSequenceStart = 0
// IsExHeader | FrameType | PacketType
frameTypeAndPacket := byte(flvIsExHeader | (videoAvcFrameTypeKeyFrame << 4) | videoHEVCFrameTraitPacketTypeSequenceStart)
// Enhanced-RTMP fourcc 'hvc1' for HEVC (0x68766331)
fourcc := []byte{'h', 'v', 'c', '1'}
// Create proper HEVCDecoderConfigurationRecord
hvcc := v.createHVCC(vps, sps, pps)
// Build enhanced-RTMP packet
var payload []byte
payload = append(payload, frameTypeAndPacket)
payload = append(payload, fourcc...)
payload = append(payload, hvcc...)
m := rtmp.NewStreamMessage(p.streamID)
m.MessageType = rtmp.MessageTypeVideo
m.Timestamp = timestamp
m.Payload = payload
return p.proto.WriteMessage(m)
}
func (v *RTMPPublisher) createHVCC(vps, sps, pps []byte) []byte {
// Create HEVCDecoderConfigurationRecord based on SRS format
// @see: trunk/src/protocol/srs_protocol_raw_avc.cpp mux_sequence_header
var hvcc []byte
// configuration_version (1 byte) - must be 1
hvcc = append(hvcc, 0x01)
// general_profile_space (2 bits) + general_tier_flag (1 bit) + general_profile_idc (5 bits)
hvcc = append(hvcc, 0x01) // simplified: profile_space=0, tier_flag=0, profile_idc=1
// general_profile_compatibility_flags (4 bytes)
hvcc = append(hvcc, 0x60, 0x00, 0x00, 0x00)
// general_constraint_indicator_flags (6 bytes)
hvcc = append(hvcc, 0x90, 0x00, 0x00, 0x00, 0x00, 0x00)
// general_level_idc (1 byte)
hvcc = append(hvcc, 0x5d) // Level 3.1
// min_spatial_segmentation_idc (12 bits) + reserved (4 bits)
hvcc = append(hvcc, 0xf0, 0x00)
// parallelismType (2 bits) + reserved (6 bits)
hvcc = append(hvcc, 0xfc)
// chromaFormat (2 bits) + reserved (6 bits)
hvcc = append(hvcc, 0xfd)
// bitDepthLumaMinus8 (3 bits) + reserved (5 bits)
hvcc = append(hvcc, 0xf8)
// bitDepthChromaMinus8 (3 bits) + reserved (5 bits)
hvcc = append(hvcc, 0xf8)
// avgFrameRate (2 bytes)
hvcc = append(hvcc, 0x00, 0x00)
// constantFrameRate (2 bits) + numTemporalLayers (3 bits) + temporalIdNested (1 bit) + lengthSizeMinusOne (2 bits)
hvcc = append(hvcc, 0x0f) // lengthSizeMinusOne = 3 (4-byte length)
// numOfArrays (1 byte) - we have 3 arrays: VPS, SPS, PPS
hvcc = append(hvcc, 0x03)
// VPS array
hvcc = append(hvcc, 0x20) // array_completeness=0, reserved=0, NAL_unit_type=32 (VPS)
hvcc = append(hvcc, 0x00, 0x01) // numNalus = 1
hvcc = append(hvcc, byte(len(vps)>>8), byte(len(vps))) // nalUnitLength
hvcc = append(hvcc, vps...)
// SPS array
hvcc = append(hvcc, 0x21) // array_completeness=0, reserved=0, NAL_unit_type=33 (SPS)
hvcc = append(hvcc, 0x00, 0x01) // numNalus = 1
hvcc = append(hvcc, byte(len(sps)>>8), byte(len(sps))) // nalUnitLength
hvcc = append(hvcc, sps...)
// PPS array
hvcc = append(hvcc, 0x22) // array_completeness=0, reserved=0, NAL_unit_type=34 (PPS)
hvcc = append(hvcc, 0x00, 0x01) // numNalus = 1
hvcc = append(hvcc, byte(len(pps)>>8), byte(len(pps))) // nalUnitLength
hvcc = append(hvcc, pps...)
return hvcc
}
func (v *RTMPPublisher) sendH265Frame(p *RTMPClient, nalData []byte, timestamp uint64, isKeyFrame bool) error {
// Create HEVC frame packet using enhanced-RTMP format
// Format: [IsExHeader | FrameType | PacketType] [fourcc 'hvc1'] [NALU data]
// @see: https://veovera.org/docs/enhanced/enhanced-rtmp-v1.pdf, page 9
const flvIsExHeader = 0x80
const videoAvcFrameTypeKeyFrame = 1
const videoAvcFrameTypeInterFrame = 2
const videoHEVCFrameTraitPacketTypeCodedFramesX = 3
var frameType byte = videoAvcFrameTypeInterFrame
if isKeyFrame {
frameType = videoAvcFrameTypeKeyFrame
}
// IsExHeader | FrameType | PacketType
frameTypeAndPacket := byte(flvIsExHeader | (frameType << 4) | videoHEVCFrameTraitPacketTypeCodedFramesX)
// Enhanced-RTMP fourcc 'hvc1' for HEVC (0x68766331)
fourcc := []byte{'h', 'v', 'c', '1'}
var payload []byte
payload = append(payload, frameTypeAndPacket)
payload = append(payload, fourcc...)
// Add NALU length and data (IBMF format)
payload = append(payload, byte(len(nalData)>>24), byte(len(nalData)>>16), byte(len(nalData)>>8), byte(len(nalData)))
payload = append(payload, nalData...)
m := rtmp.NewStreamMessage(p.streamID)
m.MessageType = rtmp.MessageTypeVideo
m.Timestamp = timestamp
m.Payload = payload
return p.proto.WriteMessage(m)
}
type RTMPPlayer struct {
// Transport.
client *RTMPClient
@ -1922,7 +2220,7 @@ func (v *FLVPlayer) consume(ctx context.Context) (err error) {
}
}
func IsAvccrEquals(a, b *avc.AVCDecoderConfigurationRecord) bool {
func isAvccrEquals(a, b *avc.AVCDecoderConfigurationRecord) bool {
if a == nil || b == nil {
return false
}
@ -1936,13 +2234,13 @@ func IsAvccrEquals(a, b *avc.AVCDecoderConfigurationRecord) bool {
}
for i := 0; i < len(a.SequenceParameterSetNALUnits); i++ {
if !IsNALUEquals(a.SequenceParameterSetNALUnits[i], b.SequenceParameterSetNALUnits[i]) {
if !isNALUEquals(a.SequenceParameterSetNALUnits[i], b.SequenceParameterSetNALUnits[i]) {
return false
}
}
for i := 0; i < len(a.PictureParameterSetNALUnits); i++ {
if !IsNALUEquals(a.PictureParameterSetNALUnits[i], b.PictureParameterSetNALUnits[i]) {
if !isNALUEquals(a.PictureParameterSetNALUnits[i], b.PictureParameterSetNALUnits[i]) {
return false
}
}
@ -1950,7 +2248,7 @@ func IsAvccrEquals(a, b *avc.AVCDecoderConfigurationRecord) bool {
return true
}
func IsNALUEquals(a, b *avc.NALU) bool {
func isNALUEquals(a, b *avc.NALU) bool {
if a == nil || b == nil {
return false
}
@ -1962,7 +2260,7 @@ func IsNALUEquals(a, b *avc.NALU) bool {
return bytes.Equal(a.Data, b.Data)
}
func DemuxRtpSpsPps(payload []byte) ([]byte, []*avc.NALU, error) {
func demuxRtpSpsPps(payload []byte) ([]byte, []*avc.NALU, error) {
// Parse RTP packet.
pkt := rtp.Packet{}
if err := pkt.Unmarshal(payload); err != nil {

View File

@ -19,23 +19,42 @@ linters-settings:
recommendations:
- errors
forbidigo:
analyze-types: true
forbid:
- ^fmt.Print(f|ln)?$
- ^log.(Panic|Fatal|Print)(f|ln)?$
- ^os.Exit$
- ^panic$
- ^print(ln)?$
- p: ^testing.T.(Error|Errorf|Fatal|Fatalf|Fail|FailNow)$
pkg: ^testing$
msg: "use testify/assert instead"
varnamelen:
max-distance: 12
min-name-length: 2
ignore-type-assert-ok: true
ignore-map-index-ok: true
ignore-chan-recv-ok: true
ignore-decls:
- i int
- n int
- w io.Writer
- r io.Reader
- b []byte
linters:
enable:
- asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers
- bidichk # Checks for dangerous unicode character sequences
- bodyclose # checks whether HTTP response body is closed successfully
- containedctx # containedctx is a linter that detects struct contained context.Context field
- contextcheck # check the function whether use a non-inherited context
- cyclop # checks function and package cyclomatic complexity
- decorder # check declaration order and count of types, constants, variables and functions
- dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f())
- dupl # Tool for code clone detection
- durationcheck # check for two durations multiplied together
- err113 # Golang linter to check the errors handling expressions
- errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases
- errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted.
- errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`.
@ -46,18 +65,17 @@ linters:
- forcetypeassert # finds forced type assertions
- gci # Gci control golang package import order and make it always deterministic.
- gochecknoglobals # Checks that no globals are present in Go code
- gochecknoinits # Checks that no init functions are present in Go code
- gocognit # Computes and checks the cognitive complexity of functions
- goconst # Finds repeated strings that could be replaced by a constant
- gocritic # The most opinionated Go source code linter
- gocyclo # Computes and checks the cyclomatic complexity of functions
- godot # Check if comments end in a period
- godox # Tool for detection of FIXME, TODO and other comment keywords
- err113 # Golang linter to check the errors handling expressions
- gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification
- gofumpt # Gofumpt checks whether code was gofumpt-ed.
- goheader # Checks is file header matches to pattern
- goimports # Goimports does everything that gofmt does. Additionally it checks unused imports
- gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod.
- gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations.
- goprintffuncname # Checks that printf-like functions are named with `f` at the end
- gosec # Inspects source code for security problems
- gosimple # Linter for Go source code that specializes in simplifying a code
@ -65,9 +83,15 @@ linters:
- grouper # An analyzer to analyze expression groups.
- importas # Enforces consistent import aliases
- ineffassign # Detects when assignments to existing variables are not used
- lll # Reports long lines
- maintidx # maintidx measures the maintainability index of each function.
- makezero # Finds slice declarations with non-zero initial length
- misspell # Finds commonly misspelled English words in comments
- nakedret # Finds naked returns in functions greater than a specified function length
- nestif # Reports deeply nested if statements
- nilerr # Finds the code that returns nil even if it checks that the error is not nil.
- nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value.
- nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity
- noctx # noctx finds sending http request without context.Context
- predeclared # find code that shadows one of Go's predeclared identifiers
- revive # golint replacement, finds style mistakes
@ -75,28 +99,22 @@ linters:
- stylecheck # Stylecheck is a replacement for golint
- tagliatelle # Checks the struct tags.
- tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17
- tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes
- thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers
- typecheck # Like the front-end of a Go compiler, parses and type-checks Go code
- unconvert # Remove unnecessary type conversions
- unparam # Reports unused function parameters
- unused # Checks Go code for unused constants, variables, functions and types
- varnamelen # checks that the length of a variable's name matches its scope
- wastedassign # wastedassign finds wasted assignment statements
- whitespace # Tool for detection of leading and trailing whitespace
disable:
- depguard # Go linter that checks if package imports are in a list of acceptable packages
- containedctx # containedctx is a linter that detects struct contained context.Context field
- cyclop # checks function and package cyclomatic complexity
- funlen # Tool for detection of long functions
- gocyclo # Computes and checks the cyclomatic complexity of functions
- godot # Check if comments end in a period
- gomnd # An analyzer to detect magic numbers.
- gochecknoinits # Checks that no init functions are present in Go code
- gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations.
- interfacebloat # A linter that checks length of interface.
- ireturn # Accept Interfaces, Return Concrete Types
- lll # Reports long lines
- maintidx # maintidx measures the maintainability index of each function.
- makezero # Finds slice declarations with non-zero initial length
- nakedret # Finds naked returns in functions greater than a specified function length
- nestif # Reports deeply nested if statements
- nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity
- mnd # An analyzer to detect magic numbers
- nolintlint # Reports ill-formed or insufficient nolint directives
- paralleltest # paralleltest detects missing usage of t.Parallel() method in your Go test
- prealloc # Finds slice declarations that could potentially be preallocated
@ -104,8 +122,7 @@ linters:
- rowserrcheck # checks whether Err of rows is checked successfully
- sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed.
- testpackage # linter that makes you use a separate _test package
- thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers
- varnamelen # checks that the length of a variable's name matches its scope
- tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes
- wrapcheck # Checks that errors returned from external packages are wrapped
- wsl # Whitespace Linter - Forces you to use empty lines!
@ -114,9 +131,12 @@ issues:
exclude-dirs-use-default: false
exclude-rules:
# Allow complex tests and examples, better to be self contained
- path: (examples|main\.go|_test\.go)
- path: (examples|main\.go)
linters:
- gocognit
- forbidigo
- path: _test\.go
linters:
- gocognit
# Allow forbidden identifiers in CLI commands

View File

@ -3,10 +3,10 @@
Pion Interceptor
<br>
</h1>
<h4 align="center">RTCP and RTCP processors for building real time communications</h4>
<h4 align="center">RTP and RTCP processors for building real time communications</h4>
<p align="center">
<a href="https://pion.ly"><img src="https://img.shields.io/badge/pion-interceptor-gray.svg?longCache=true&colorB=brightgreen" alt="Pion Interceptor"></a>
<a href="https://pion.ly/slack"><img src="https://img.shields.io/badge/join-us%20on%20slack-gray.svg?longCache=true&logo=slack&colorB=brightgreen" alt="Slack Widget"></a>
<a href="https://discord.gg/PngbdqpFbt"><img src="https://img.shields.io/badge/join-us%20on%20discord-gray.svg?longCache=true&logo=discord&colorB=brightblue" alt="join us on Discord"></a> <a href="https://bsky.app/profile/pion.ly"><img src="https://img.shields.io/badge/follow-us%20on%20bluesky-gray.svg?longCache=true&logo=bluesky&colorB=brightblue" alt="Follow us on Bluesky"></a>
<br>
<img alt="GitHub Workflow Status" src="https://img.shields.io/github/actions/workflow/status/pion/interceptor/test.yaml">
<a href="https://pkg.go.dev/github.com/pion/interceptor"><img src="https://pkg.go.dev/badge/github.com/pion/interceptor.svg" alt="Go Reference"></a>
@ -36,12 +36,12 @@ by anyone. With the following tenets in mind.
* [Google Congestion Control](https://github.com/pion/interceptor/tree/master/pkg/gcc)
* [Stats](https://github.com/pion/interceptor/tree/master/pkg/stats) A [webrtc-stats](https://www.w3.org/TR/webrtc-stats/) compliant statistics generation
* [Interval PLI](https://github.com/pion/interceptor/tree/master/pkg/intervalpli) Generate PLI on a interval. Useful when no decoder is available.
* [FlexFec](https://github.com/pion/interceptor/tree/master/pkg/flexfec) [FlexFEC-03](https://datatracker.ietf.org/doc/html/draft-ietf-payload-flexible-fec-scheme-03) encoder implementation
### Planned Interceptors
* Bandwidth Estimation
- [NADA](https://tools.ietf.org/html/rfc8698)
* JitterBuffer, re-order packets and wait for arrival
* [FlexFec](https://tools.ietf.org/html/draft-ietf-payload-flexible-fec-scheme-20)
* [RTCP Feedback for Congestion Control](https://datatracker.ietf.org/doc/html/rfc8888) the standardized alternative to TWCC.
### Interceptor Public API
@ -70,9 +70,9 @@ You should also look in [pion/webrtc](https://github.com/pion/webrtc) for real w
The library is used as a part of our WebRTC implementation. Please refer to that [roadmap](https://github.com/pion/webrtc/issues/9) to track our major milestones.
### Community
Pion has an active community on the [Slack](https://pion.ly/slack).
Pion has an active community on the [Discord](https://discord.gg/PngbdqpFbt).
Follow the [Pion Twitter](https://twitter.com/_pion) for project updates and important WebRTC news.
Follow the [Pion Bluesky](https://bsky.app/profile/pion.ly) or [Pion Twitter](https://twitter.com/_pion) for project updates and important WebRTC news.
We are always looking to support **your projects**. Please reach out if you have something to build!
If you need commercial support or don't want to use public methods you can contact us at [team@pion.ly](mailto:team@pion.ly)

View File

@ -19,7 +19,7 @@ const (
var errInvalidType = errors.New("found value of invalid type in attributes map")
// Attributes are a generic key/value store used by interceptors
// Attributes are a generic key/value store used by interceptors.
type Attributes map[interface{}]interface{}
// Get returns the attribute associated with key.
@ -39,6 +39,7 @@ func (a Attributes) GetRTPHeader(raw []byte) (*rtp.Header, error) {
if header, ok := val.(*rtp.Header); ok {
return header, nil
}
return nil, errInvalidType
}
header := &rtp.Header{}
@ -46,6 +47,7 @@ func (a Attributes) GetRTPHeader(raw []byte) (*rtp.Header, error) {
return nil, err
}
a[rtpHeaderKey] = header
return header, nil
}
@ -57,6 +59,7 @@ func (a Attributes) GetRTCPPackets(raw []byte) ([]rtcp.Packet, error) {
if packets, ok := val.([]rtcp.Packet); ok {
return packets, nil
}
return nil, errInvalidType
}
pkts, err := rtcp.Unmarshal(raw)
@ -64,5 +67,6 @@ func (a Attributes) GetRTCPPackets(raw []byte) ([]rtcp.Packet, error) {
return nil, err
}
a[rtcpPacketsKey] = pkts
return pkts, nil
}

View File

@ -50,7 +50,8 @@ func (i *Chain) UnbindLocalStream(ctx *StreamInfo) {
}
}
// BindRemoteStream lets you modify any incoming RTP packets. It is called once for per RemoteStream. The returned method
// BindRemoteStream lets you modify any incoming RTP packets.
// It is called once for per RemoteStream. The returned method
// will be called once per rtp packet.
func (i *Chain) BindRemoteStream(ctx *StreamInfo, reader RTPReader) RTPReader {
for _, interceptor := range i.interceptors {

View File

@ -18,6 +18,7 @@ func flattenErrs(errs []error) error {
if len(errs2) == 0 {
return nil
}
return multiError(errs2)
}
@ -50,5 +51,6 @@ func (me multiError) Is(err error) bool {
}
}
}
return false
}

View File

@ -12,7 +12,7 @@ import (
"github.com/pion/rtp"
)
// Factory provides an interface for constructing interceptors
// Factory provides an interface for constructing interceptors.
type Factory interface {
NewInterceptor(id string) (Interceptor, error)
}
@ -35,7 +35,8 @@ type Interceptor interface {
// UnbindLocalStream is called when the Stream is removed. It can be used to clean up any data related to that track.
UnbindLocalStream(info *StreamInfo)
// BindRemoteStream lets you modify any incoming RTP packets. It is called once for per RemoteStream. The returned method
// BindRemoteStream lets you modify any incoming RTP packets.
// It is called once for per RemoteStream. The returned method
// will be called once per rtp packet.
BindRemoteStream(info *StreamInfo, reader RTPReader) RTPReader
@ -69,34 +70,34 @@ type RTCPReader interface {
Read([]byte, Attributes) (int, Attributes, error)
}
// RTPWriterFunc is an adapter for RTPWrite interface
// RTPWriterFunc is an adapter for RTPWrite interface.
type RTPWriterFunc func(header *rtp.Header, payload []byte, attributes Attributes) (int, error)
// RTPReaderFunc is an adapter for RTPReader interface
// RTPReaderFunc is an adapter for RTPReader interface.
type RTPReaderFunc func([]byte, Attributes) (int, Attributes, error)
// RTCPWriterFunc is an adapter for RTCPWriter interface
// RTCPWriterFunc is an adapter for RTCPWriter interface.
type RTCPWriterFunc func(pkts []rtcp.Packet, attributes Attributes) (int, error)
// RTCPReaderFunc is an adapter for RTCPReader interface
// RTCPReaderFunc is an adapter for RTCPReader interface.
type RTCPReaderFunc func([]byte, Attributes) (int, Attributes, error)
// Write a rtp packet
// Write a rtp packet.
func (f RTPWriterFunc) Write(header *rtp.Header, payload []byte, attributes Attributes) (int, error) {
return f(header, payload, attributes)
}
// Read a rtp packet
// Read a rtp packet.
func (f RTPReaderFunc) Read(b []byte, a Attributes) (int, Attributes, error) {
return f(b, a)
}
// Write a batch of rtcp packets
// Write a batch of rtcp packets.
func (f RTCPWriterFunc) Write(pkts []rtcp.Packet, attributes Attributes) (int, error) {
return f(pkts, attributes)
}
// Read a batch of rtcp packets
// Read a batch of rtcp packets.
func (f RTCPReaderFunc) Read(b []byte, a Attributes) (int, Attributes, error) {
return f(b, a)
}

View File

@ -9,7 +9,7 @@ import (
"time"
)
// ToNTP converts a time.Time oboject to an uint64 NTP timestamp
// ToNTP converts a time.Time oboject to an uint64 NTP timestamp.
func ToNTP(t time.Time) uint64 {
// seconds since 1st January 1900
s := (float64(t.UnixNano()) / 1000000000) + 2208988800
@ -17,14 +17,31 @@ func ToNTP(t time.Time) uint64 {
// higher 32 bits are the integer part, lower 32 bits are the fractional part
integerPart := uint32(s)
fractionalPart := uint32((s - float64(integerPart)) * 0xFFFFFFFF)
return uint64(integerPart)<<32 | uint64(fractionalPart)
return uint64(integerPart)<<32 | uint64(fractionalPart) //nolint:gosec // G115
}
// ToTime converts a uint64 NTP timestamps to a time.Time object
// ToNTP32 converts a time.Time object to a uint32 NTP timestamp.
func ToNTP32(t time.Time) uint32 {
return uint32(ToNTP(t) >> 16) //nolint:gosec // G115
}
// ToTime converts a uint64 NTP timestamps to a time.Time object.
func ToTime(t uint64) time.Time {
seconds := (t & 0xFFFFFFFF00000000) >> 32
fractional := float64(t&0x00000000FFFFFFFF) / float64(0xFFFFFFFF)
//nolint:gosec // G115
d := time.Duration(seconds)*time.Second + time.Duration(fractional*1e9)*time.Nanosecond
return time.Unix(0, 0).Add(-2208988800 * time.Second).Add(d)
}
// ToTime32 converts a uint32 NTP timestamp to a time.Time object, using the
// highest 16 bit of the reference to recover the lost bits. The low 16 bits are
// not recovered.
func ToTime32(t uint32, reference time.Time) time.Time {
referenceNTP := ToNTP(reference) & 0xFFFF000000000000
tu64 := ((uint64(t) << 16) & 0x0000FFFFFFFF0000) | referenceNTP
return ToTime(tu64)
}

View File

@ -0,0 +1,16 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package rtpbuffer
import "errors"
// ErrInvalidSize is returned by newReceiveLog/newRTPBuffer, when an incorrect buffer size is supplied.
var ErrInvalidSize = errors.New("invalid buffer size")
var (
errPacketReleased = errors.New("could not retain packet, already released")
errFailedToCastHeaderPool = errors.New("could not access header pool, failed cast")
errFailedToCastPayloadPool = errors.New("could not access payload pool, failed cast")
errPaddingOverflow = errors.New("padding size exceeds payload size")
)

View File

@ -0,0 +1,149 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package rtpbuffer
import (
"encoding/binary"
"io"
"sync"
"github.com/pion/rtp"
)
const rtxSsrcByteLength = 2
// PacketFactory allows custom logic around the handle of RTP Packets before they added to the RTPBuffer.
// The NoOpPacketFactory doesn't copy packets, while the RetainablePacket will take a copy before adding.
type PacketFactory interface {
NewPacket(header *rtp.Header, payload []byte, rtxSsrc uint32, rtxPayloadType uint8) (*RetainablePacket, error)
}
// PacketFactoryCopy is PacketFactory that takes a copy of packets when added to the RTPBuffer.
type PacketFactoryCopy struct {
headerPool *sync.Pool
payloadPool *sync.Pool
rtxSequencer rtp.Sequencer
}
// NewPacketFactoryCopy constructs a PacketFactory that takes a copy of packets when added to the RTPBuffer.
func NewPacketFactoryCopy() *PacketFactoryCopy {
return &PacketFactoryCopy{
headerPool: &sync.Pool{
New: func() interface{} {
return &rtp.Header{}
},
},
payloadPool: &sync.Pool{
New: func() interface{} {
buf := make([]byte, maxPayloadLen)
return &buf
},
},
rtxSequencer: rtp.NewRandomSequencer(),
}
}
// NewPacket constructs a new RetainablePacket that can be added to the RTPBuffer.
//
//nolint:cyclop
func (m *PacketFactoryCopy) NewPacket(
header *rtp.Header, payload []byte, rtxSsrc uint32, rtxPayloadType uint8,
) (*RetainablePacket, error) {
if len(payload) > maxPayloadLen {
return nil, io.ErrShortBuffer
}
retainablePacket := &RetainablePacket{
onRelease: m.releasePacket,
sequenceNumber: header.SequenceNumber,
// new packets have retain count of 1
count: 1,
}
var ok bool
retainablePacket.header, ok = m.headerPool.Get().(*rtp.Header)
if !ok {
return nil, errFailedToCastHeaderPool
}
*retainablePacket.header = header.Clone()
if payload != nil {
retainablePacket.buffer, ok = m.payloadPool.Get().(*[]byte)
if !ok {
return nil, errFailedToCastPayloadPool
}
if rtxSsrc != 0 && rtxPayloadType != 0 {
size := copy((*retainablePacket.buffer)[rtxSsrcByteLength:], payload)
retainablePacket.payload = (*retainablePacket.buffer)[:size+rtxSsrcByteLength]
} else {
size := copy(*retainablePacket.buffer, payload)
retainablePacket.payload = (*retainablePacket.buffer)[:size]
}
}
if rtxSsrc != 0 && rtxPayloadType != 0 { //nolint:nestif
if payload == nil {
retainablePacket.buffer, ok = m.payloadPool.Get().(*[]byte)
if !ok {
return nil, errFailedToCastPayloadPool
}
retainablePacket.payload = (*retainablePacket.buffer)[:rtxSsrcByteLength]
}
// Write the original sequence number at the beginning of the payload.
binary.BigEndian.PutUint16(retainablePacket.payload, retainablePacket.header.SequenceNumber)
// Rewrite the SSRC.
retainablePacket.header.SSRC = rtxSsrc
// Rewrite the payload type.
retainablePacket.header.PayloadType = rtxPayloadType
// Rewrite the sequence number.
retainablePacket.header.SequenceNumber = m.rtxSequencer.NextSequenceNumber()
// Remove padding if present.
if retainablePacket.header.Padding {
// Older versions of pion/rtp didn't have the Header.PaddingSize field and as a workaround
// users had to add padding to the payload. We need to handle this case here.
if retainablePacket.header.PaddingSize == 0 && len(retainablePacket.payload) > 0 {
paddingLength := int(retainablePacket.payload[len(retainablePacket.payload)-1])
if paddingLength > len(retainablePacket.payload) {
return nil, errPaddingOverflow
}
retainablePacket.payload = (*retainablePacket.buffer)[:len(retainablePacket.payload)-paddingLength]
}
retainablePacket.header.Padding = false
retainablePacket.header.PaddingSize = 0
}
}
return retainablePacket, nil
}
func (m *PacketFactoryCopy) releasePacket(header *rtp.Header, payload *[]byte) {
m.headerPool.Put(header)
if payload != nil {
m.payloadPool.Put(payload)
}
}
// PacketFactoryNoOp is a PacketFactory implementation that doesn't copy packets.
type PacketFactoryNoOp struct{}
// NewPacket constructs a new RetainablePacket that can be added to the RTPBuffer.
func (f *PacketFactoryNoOp) NewPacket(
header *rtp.Header, payload []byte, _ uint32, _ uint8,
) (*RetainablePacket, error) {
return &RetainablePacket{
onRelease: f.releasePacket,
count: 1,
header: header,
payload: payload,
sequenceNumber: header.SequenceNumber,
}, nil
}
func (f *PacketFactoryNoOp) releasePacket(_ *rtp.Header, _ *[]byte) {
// no-op
}

View File

@ -0,0 +1,62 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package rtpbuffer
import (
"sync"
"github.com/pion/rtp"
)
// RetainablePacket is a referenced counted RTP packet.
type RetainablePacket struct {
onRelease func(*rtp.Header, *[]byte)
countMu sync.Mutex
count int
header *rtp.Header
buffer *[]byte
payload []byte
sequenceNumber uint16
}
// Header returns the RTP Header of the RetainablePacket.
func (p *RetainablePacket) Header() *rtp.Header {
return p.header
}
// Payload returns the RTP Payload of the RetainablePacket.
func (p *RetainablePacket) Payload() []byte {
return p.payload
}
// Retain increases the reference count of the RetainablePacket.
func (p *RetainablePacket) Retain() error {
p.countMu.Lock()
defer p.countMu.Unlock()
if p.count == 0 {
// already released
return errPacketReleased
}
p.count++
return nil
}
// Release decreases the reference count of the RetainablePacket and frees if needed.
func (p *RetainablePacket) Release() {
p.countMu.Lock()
defer p.countMu.Unlock()
p.count--
if p.count == 0 {
// release back to pool
p.onRelease(p.header, p.buffer)
p.header = nil
p.buffer = nil
p.payload = nil
}
}

View File

@ -0,0 +1,107 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
// Package rtpbuffer provides a buffer for storing RTP packets
package rtpbuffer
import (
"fmt"
)
const (
// Uint16SizeHalf is half of a math.Uint16.
Uint16SizeHalf = 1 << 15
maxPayloadLen = 1460
)
// RTPBuffer stores RTP packets and allows custom logic
// around the lifetime of them via the PacketFactory.
type RTPBuffer struct {
packets []*RetainablePacket
size uint16
highestAdded uint16
started bool
}
// NewRTPBuffer constructs a new RTPBuffer.
func NewRTPBuffer(size uint16) (*RTPBuffer, error) {
allowedSizes := make([]uint16, 0)
correctSize := false
for i := 0; i < 16; i++ {
if size == 1<<i {
correctSize = true
break
}
allowedSizes = append(allowedSizes, 1<<i)
}
if !correctSize {
return nil, fmt.Errorf("%w: %d is not a valid size, allowed sizes: %v", ErrInvalidSize, size, allowedSizes)
}
return &RTPBuffer{
packets: make([]*RetainablePacket, size),
size: size,
}, nil
}
// Add places the RetainablePacket in the RTPBuffer.
func (r *RTPBuffer) Add(packet *RetainablePacket) {
seq := packet.sequenceNumber
if !r.started {
r.packets[seq%r.size] = packet
r.highestAdded = seq
r.started = true
return
}
diff := seq - r.highestAdded
if diff == 0 {
return
} else if diff < Uint16SizeHalf {
for i := r.highestAdded + 1; i != seq; i++ {
idx := i % r.size
prevPacket := r.packets[idx]
if prevPacket != nil {
prevPacket.Release()
}
r.packets[idx] = nil
}
r.highestAdded = seq
}
idx := seq % r.size
prevPacket := r.packets[idx]
if prevPacket != nil {
prevPacket.Release()
}
r.packets[idx] = packet
}
// Get returns the RetainablePacket for the requested sequence number.
func (r *RTPBuffer) Get(seq uint16) *RetainablePacket {
diff := r.highestAdded - seq
if diff >= Uint16SizeHalf {
return nil
}
if diff >= r.size {
return nil
}
pkt := r.packets[seq%r.size]
if pkt != nil {
if pkt.sequenceNumber != seq {
return nil
}
// already released
if err := pkt.Retain(); err != nil {
return nil
}
}
return pkt
}

View File

@ -9,7 +9,7 @@ const (
breakpoint = 32768 // half of max uint16
)
// Unwrapper stores an unwrapped sequence number
// Unwrapper stores an unwrapped sequence number.
type Unwrapper struct {
init bool
lastUnwrapped int64
@ -19,18 +19,20 @@ func isNewer(value, previous uint16) bool {
if value-previous == breakpoint {
return value > previous
}
return value != previous && (value-previous) < breakpoint
}
// Unwrap unwraps the next sequencenumber
// Unwrap unwraps the next sequencenumber.
func (u *Unwrapper) Unwrap(i uint16) int64 {
if !u.init {
u.init = true
u.lastUnwrapped = int64(i)
return u.lastUnwrapped
}
lastWrapped := uint16(u.lastUnwrapped)
lastWrapped := uint16(u.lastUnwrapped) //nolint:gosec // G115
delta := int64(i - lastWrapped)
if isNewer(i, lastWrapped) {
if delta < 0 {
@ -41,5 +43,6 @@ func (u *Unwrapper) Unwrap(i uint16) int64 {
}
u.lastUnwrapped += delta
return u.lastUnwrapped
}

View File

@ -28,7 +28,8 @@ func (i *NoOp) BindLocalStream(_ *StreamInfo, writer RTPWriter) RTPWriter {
// UnbindLocalStream is called when the Stream is removed. It can be used to clean up any data related to that track.
func (i *NoOp) UnbindLocalStream(_ *StreamInfo) {}
// BindRemoteStream lets you modify any incoming RTP packets. It is called once for per RemoteStream. The returned method
// BindRemoteStream lets you modify any incoming RTP packets.
// It is called once for per RemoteStream. The returned method
// will be called once per rtp packet.
func (i *NoOp) BindRemoteStream(_ *StreamInfo, reader RTPReader) RTPReader {
return reader

View File

@ -0,0 +1,128 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package flexfec
import (
"errors"
"sync"
"github.com/pion/interceptor"
"github.com/pion/rtp"
)
// streamState holds the state for a single stream.
type streamState struct {
mu sync.Mutex
flexFecEncoder FlexEncoder
packetBuffer []rtp.Packet
}
// FecInterceptor implements FlexFec.
type FecInterceptor struct {
interceptor.NoOp
mu sync.Mutex
streams map[uint32]*streamState
numMediaPackets uint32
numFecPackets uint32
encoderFactory EncoderFactory
}
// FecInterceptorFactory creates new FecInterceptors.
type FecInterceptorFactory struct {
opts []FecOption
}
// NewFecInterceptor returns a new Fec interceptor factory.
func NewFecInterceptor(opts ...FecOption) (*FecInterceptorFactory, error) {
return &FecInterceptorFactory{opts: opts}, nil
}
// NewInterceptor constructs a new FecInterceptor.
func (r *FecInterceptorFactory) NewInterceptor(_ string) (interceptor.Interceptor, error) {
interceptor := &FecInterceptor{
streams: make(map[uint32]*streamState),
numMediaPackets: 5,
numFecPackets: 2,
encoderFactory: FlexEncoder03Factory{},
}
for _, opt := range r.opts {
if err := opt(interceptor); err != nil {
return nil, err
}
}
return interceptor, nil
}
// UnbindLocalStream removes the stream state for a specific SSRC.
func (r *FecInterceptor) UnbindLocalStream(info *interceptor.StreamInfo) {
r.mu.Lock()
defer r.mu.Unlock()
delete(r.streams, info.SSRC)
}
// BindLocalStream lets you modify any outgoing RTP packets. It is called once for per LocalStream. The returned method
// will be called once per rtp packet.
func (r *FecInterceptor) BindLocalStream(
info *interceptor.StreamInfo, writer interceptor.RTPWriter,
) interceptor.RTPWriter {
if info.PayloadTypeForwardErrorCorrection == 0 || info.SSRCForwardErrorCorrection == 0 {
return writer
}
mediaSSRC := info.SSRC
r.mu.Lock()
stream := &streamState{
// Chromium supports version flexfec-03 of existing draft, this is the one we will configure by default
// although we should support configuring the latest (flexfec-20) as well.
flexFecEncoder: r.encoderFactory.NewEncoder(info.PayloadTypeForwardErrorCorrection, info.SSRCForwardErrorCorrection),
packetBuffer: make([]rtp.Packet, 0),
}
r.streams[mediaSSRC] = stream
r.mu.Unlock()
return interceptor.RTPWriterFunc(
func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) {
// Ignore non-media packets
if header.SSRC != mediaSSRC {
return writer.Write(header, payload, attributes)
}
var fecPackets []rtp.Packet
stream.mu.Lock()
stream.packetBuffer = append(stream.packetBuffer, rtp.Packet{
Header: *header,
Payload: payload,
})
// Check if we have enough packets to generate FEC
if len(stream.packetBuffer) == int(r.numMediaPackets) {
fecPackets = stream.flexFecEncoder.EncodeFec(stream.packetBuffer, r.numFecPackets)
// Reset the packet buffer now that we've sent the corresponding FEC packets.
stream.packetBuffer = nil
}
stream.mu.Unlock()
var errs []error
result, err := writer.Write(header, payload, attributes)
if err != nil {
errs = append(errs, err)
}
for _, packet := range fecPackets {
header := packet.Header
_, err = writer.Write(&header, packet.Payload, attributes)
if err != nil {
errs = append(errs, err)
}
}
return result, errors.Join(errs...)
},
)
}

View File

@ -0,0 +1,177 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package flexfec
import (
"github.com/pion/interceptor/pkg/flexfec/util"
"github.com/pion/rtp"
)
// Maximum number of media packets that can be protected by a single FEC packet.
// We are not supporting the possibility of having an FEC packet protect multiple
// SSRC source packets for now.
// https://datatracker.ietf.org/doc/html/rfc8627#section-4.2.2.1
const (
MaxMediaPackets uint32 = 110
MaxFecPackets uint32 = MaxMediaPackets
)
// ProtectionCoverage defines the map of RTP packets that individual Fec packets protect.
type ProtectionCoverage struct {
// Array of masks, each mask capable of covering up to maxMediaPkts = 110.
// A mask is represented as a grouping of bytes where each individual bit
// represents the coverage for the media packet at the corresponding index.
packetMasks [MaxFecPackets]util.BitArray
numFecPackets uint32
numMediaPackets uint32
mediaPackets []rtp.Packet
}
// NewCoverage returns a new ProtectionCoverage object. numFecPackets represents the number of
// Fec packets that we will be generating to cover the list of mediaPackets. This allows us to know
// how big the underlying map should be.
func NewCoverage(mediaPackets []rtp.Packet, numFecPackets uint32) *ProtectionCoverage {
numMediaPackets := uint32(len(mediaPackets)) //nolint:gosec // G115
// Basic sanity checks
if numMediaPackets <= 0 || numMediaPackets > MaxMediaPackets {
return nil
}
// We allocate the biggest array of bitmasks that respects the max constraints.
var packetMasks [MaxFecPackets]util.BitArray
for i := 0; i < int(MaxFecPackets); i++ {
packetMasks[i] = util.BitArray{}
}
coverage := &ProtectionCoverage{
packetMasks: packetMasks,
numFecPackets: 0,
numMediaPackets: 0,
mediaPackets: nil,
}
coverage.UpdateCoverage(mediaPackets, numFecPackets)
return coverage
}
// UpdateCoverage updates the ProtectionCoverage object with new bitmasks accounting for the numFecPackets
// we want to use to protect the batch media packets.
func (p *ProtectionCoverage) UpdateCoverage(mediaPackets []rtp.Packet, numFecPackets uint32) {
numMediaPackets := uint32(len(mediaPackets)) //nolint:gosec // G115
// Basic sanity checks
if numMediaPackets <= 0 || numMediaPackets > MaxMediaPackets {
return
}
p.mediaPackets = mediaPackets
if numFecPackets == p.numFecPackets && numMediaPackets == p.numMediaPackets {
// We have the same number of FEC packets covering the same number of media packets, we can simply
// reuse the previous coverage map with the updated media packets.
return
}
p.numFecPackets = numFecPackets
p.numMediaPackets = numMediaPackets
// The number of FEC packets and/or the number of packets has changed, we need to update the coverage map
// to reflect these new values.
p.resetCoverage()
// Generate FEC bit mask where numFecPackets FEC packets are covering numMediaPackets Media packets.
// In the packetMasks array, each FEC packet is represented by a single BitArray, each bit in a given BitArray
// corresponds to a specific Media packet.
// Ex: Row I, Col J is set to 1 -> FEC packet I will protect media packet J.
for fecPacketIndex := uint32(0); fecPacketIndex < numFecPackets; fecPacketIndex++ {
// We use an interleaved method to determine coverage. Given N FEC packets, Media packet X will be
// covered by FEC packet X % N.
coveredMediaPacketIndex := fecPacketIndex
for coveredMediaPacketIndex < numMediaPackets {
p.packetMasks[fecPacketIndex].SetBit(coveredMediaPacketIndex)
coveredMediaPacketIndex += numFecPackets
}
}
}
// ResetCoverage clears the underlying map so that we can reuse it for new batches of RTP packets.
func (p *ProtectionCoverage) resetCoverage() {
for i := uint32(0); i < MaxFecPackets; i++ {
p.packetMasks[i].Reset()
}
}
// GetCoveredBy returns an iterator over RTP packets that are protected by the specified Fec packet index.
func (p *ProtectionCoverage) GetCoveredBy(fecPacketIndex uint32) *util.MediaPacketIterator {
coverage := make([]uint32, 0, p.numMediaPackets)
for mediaPacketIndex := uint32(0); mediaPacketIndex < p.numMediaPackets; mediaPacketIndex++ {
if p.packetMasks[fecPacketIndex].GetBit(mediaPacketIndex) == 1 {
coverage = append(coverage, mediaPacketIndex)
}
}
return util.NewMediaPacketIterator(p.mediaPackets, coverage)
}
// ExtractMask1 returns the first section of the bitmask as defined by the FEC header.
// https://datatracker.ietf.org/doc/html/rfc8627#section-4.2.2.1
func (p *ProtectionCoverage) ExtractMask1(fecPacketIndex uint32) uint16 {
return extractMask1(p.packetMasks[fecPacketIndex])
}
// ExtractMask2 returns the second section of the bitmask as defined by the FEC header.
// https://datatracker.ietf.org/doc/html/rfc8627#section-4.2.2.1
func (p *ProtectionCoverage) ExtractMask2(fecPacketIndex uint32) uint32 {
return extractMask2(p.packetMasks[fecPacketIndex])
}
// ExtractMask3 returns the third section of the bitmask as defined by the FEC header.
// https://datatracker.ietf.org/doc/html/rfc8627#section-4.2.2.1
func (p *ProtectionCoverage) ExtractMask3(fecPacketIndex uint32) uint64 {
return extractMask3(p.packetMasks[fecPacketIndex])
}
// ExtractMask3_03 returns the third section of the bitmask as defined by the FEC header.
// https://datatracker.ietf.org/doc/html/draft-ietf-payload-flexible-fec-scheme-03#section-4.2
func (p *ProtectionCoverage) ExtractMask3_03(fecPacketIndex uint32) uint64 {
return extractMask3_03(p.packetMasks[fecPacketIndex])
}
func extractMask1(mask util.BitArray) uint16 {
// We get the first 16 bits (64 - 16 -> shift by 48) and we shift once more for K field
mask1 := mask.Lo >> 49
return uint16(mask1) //nolint:gosec // G115
}
func extractMask2(mask util.BitArray) uint32 {
// We remove the first 15 bits
mask2 := mask.Lo << 15
// We get the first 31 bits (64 - 32 -> shift by 32) and we shift once more for K field
mask2 >>= 33
return uint32(mask2) //nolint:gosec
}
func extractMask3(mask util.BitArray) uint64 {
// We remove the first 46 bits
maskLo := mask.Lo << 46
maskHi := mask.Hi >> 18
mask3 := maskLo | maskHi
return mask3
}
func extractMask3_03(mask util.BitArray) uint64 {
// We remove the first 46 bits
maskLo := mask.Lo << 46
maskHi := mask.Hi >> 18
mask3 := maskLo | maskHi
// We shift once for the K bit.
mask3 >>= 1
return mask3
}

View File

@ -0,0 +1,445 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
// Package flexfec implements FlexFEC-03 to recover missing RTP packets due to packet loss.
// https://datatracker.ietf.org/doc/html/draft-ietf-payload-flexible-fec-scheme-03
package flexfec
import (
"encoding/binary"
"errors"
"fmt"
"sort"
"github.com/pion/logging"
"github.com/pion/rtp"
)
// Static errors for the flexfec package.
var (
errPacketTruncated = errors.New("packet truncated")
errRetransmissionBitSet = errors.New("packet with retransmission bit set not supported")
errInflexibleGeneratorMatrix = errors.New("packet with inflexible generator matrix not supported")
errMultipleSSRCProtection = errors.New("multiple ssrc protection not supported")
errLastOptionalMaskKBitSetToFalse = errors.New("k-bit of last optional mask is set to false")
)
// fecDecoder is a WIP implementation decoder used for testing purposes.
type fecDecoder struct {
logger logging.LeveledLogger
ssrc uint32
protectedStreamSSRC uint32
maxMediaPackets int
maxFECPackets int
recoveredPackets []rtp.Packet
receivedFECPackets []fecPacketState
}
func newFECDecoder(ssrc uint32, protectedStreamSSRC uint32) *fecDecoder {
return &fecDecoder{
logger: logging.NewDefaultLoggerFactory().NewLogger("fec_decoder"),
ssrc: ssrc,
protectedStreamSSRC: protectedStreamSSRC,
maxMediaPackets: 100,
maxFECPackets: 100,
recoveredPackets: make([]rtp.Packet, 0),
receivedFECPackets: make([]fecPacketState, 0),
}
}
func (d *fecDecoder) DecodeFec(receivedPacket rtp.Packet) []rtp.Packet {
if len(d.recoveredPackets) == d.maxMediaPackets {
backRecoveredPacket := d.recoveredPackets[len(d.recoveredPackets)-1]
if backRecoveredPacket.SSRC == receivedPacket.SSRC {
seqDiffVal := seqDiff(receivedPacket.SequenceNumber, backRecoveredPacket.SequenceNumber)
if seqDiffVal > uint16(d.maxMediaPackets) { //nolint:gosec
d.logger.Info("big gap in media sequence numbers - resetting buffers")
d.recoveredPackets = nil
d.receivedFECPackets = nil
}
}
}
d.insertPacket(receivedPacket)
return d.attemptRecovery()
}
func (d *fecDecoder) insertPacket(receivedPkt rtp.Packet) {
// Discard old FEC packets such that the sequence numbers in
// `received_fec_packets_` span at most 1/2 of the sequence number space.
// This is important for keeping `received_fec_packets_` sorted, and may
// also reduce the possibility of incorrect decoding due to sequence number
// wrap-around.
if len(d.receivedFECPackets) > 0 && receivedPkt.SSRC == d.ssrc {
toRemove := 0
for _, fecPkt := range d.receivedFECPackets {
if abs(int(receivedPkt.SequenceNumber)-int(fecPkt.packet.SequenceNumber)) > 0x3fff {
toRemove++
} else {
// No need to keep iterating, since |received_fec_packets_| is sorted.
break
}
}
}
switch receivedPkt.SSRC {
case d.ssrc:
d.insertFECPacket(receivedPkt)
case d.protectedStreamSSRC:
d.insertMediaPacket(receivedPkt)
}
d.discardOldRecoveredPackets()
}
func (d *fecDecoder) insertMediaPacket(receivedPkt rtp.Packet) {
for _, recoveredPacket := range d.recoveredPackets {
if recoveredPacket.SequenceNumber == receivedPkt.SequenceNumber {
return
}
}
d.recoveredPackets = append(d.recoveredPackets, receivedPkt)
sort.Slice(d.recoveredPackets, func(i, j int) bool {
return isNewerSeq(d.recoveredPackets[i].SequenceNumber, d.recoveredPackets[j].SequenceNumber)
})
d.updateCoveringFecPackets(receivedPkt)
}
func (d *fecDecoder) updateCoveringFecPackets(receivedPkt rtp.Packet) {
for _, fecPkt := range d.receivedFECPackets {
for _, protectedPacket := range fecPkt.protectedPackets {
if protectedPacket.seq == receivedPkt.SequenceNumber {
protectedPacket.packet = &receivedPkt
}
}
}
}
func (d *fecDecoder) insertFECPacket(fecPkt rtp.Packet) { //nolint:cyclop
for _, existingFECPacket := range d.receivedFECPackets {
if existingFECPacket.packet.SequenceNumber == fecPkt.SequenceNumber {
return
}
}
fec, err := parseFlexFEC03Header(fecPkt.Payload)
if err != nil {
d.logger.Errorf("failed to parse flexfec03 header: %v", err)
return
}
if fec.protectedSSRC != d.protectedStreamSSRC {
d.logger.Errorf("fec is protecting unknown ssrc, expected %d, got %d", fec.protectedSSRC, d.protectedStreamSSRC)
return
}
protectedSeqs := decodeMask(uint64(fec.mask0), 15, fec.seqNumBase)
if fec.mask1 != 0 {
protectedSeqs = append(protectedSeqs, decodeMask(uint64(fec.mask1), 31, fec.seqNumBase+15)...)
}
if fec.mask2 != 0 {
protectedSeqs = append(protectedSeqs, decodeMask(fec.mask2, 63, fec.seqNumBase+46)...)
}
if len(protectedSeqs) == 0 {
d.logger.Warn("empty fec packet mask")
return
}
protectedPackets := make([]*protectedPacket, 0, len(protectedSeqs))
protectedSeqIt := 0
recoveredPacketIt := 0
for protectedSeqIt < len(protectedSeqs) && recoveredPacketIt < len(d.recoveredPackets) {
switch {
case isNewerSeq(protectedSeqs[protectedSeqIt], d.recoveredPackets[recoveredPacketIt].SequenceNumber):
protectedPackets = append(protectedPackets, &protectedPacket{
seq: protectedSeqs[protectedSeqIt],
packet: nil,
})
protectedSeqIt++
case isNewerSeq(d.recoveredPackets[recoveredPacketIt].SequenceNumber, protectedSeqs[protectedSeqIt]):
recoveredPacketIt++
default:
protectedPackets = append(protectedPackets, &protectedPacket{
seq: protectedSeqs[protectedSeqIt],
packet: &d.recoveredPackets[recoveredPacketIt],
})
protectedSeqIt++
recoveredPacketIt++
}
}
for protectedSeqIt < len(protectedSeqs) {
protectedPackets = append(protectedPackets, &protectedPacket{
seq: protectedSeqs[protectedSeqIt],
packet: nil,
})
protectedSeqIt++
}
d.receivedFECPackets = append(d.receivedFECPackets, fecPacketState{
packet: fecPkt,
flexFec: fec,
protectedPackets: protectedPackets,
})
sort.Slice(d.receivedFECPackets, func(i, j int) bool {
return isNewerSeq(d.receivedFECPackets[i].packet.SequenceNumber, d.receivedFECPackets[j].packet.SequenceNumber)
})
if len(d.receivedFECPackets) > d.maxFECPackets {
d.receivedFECPackets = d.receivedFECPackets[1:]
}
}
func (d *fecDecoder) attemptRecovery() []rtp.Packet {
recoveredPackets := make([]rtp.Packet, 0)
for {
packetsRecovered := 0
for _, fecPkt := range d.receivedFECPackets {
packetsMissing := 0
for _, pkt := range fecPkt.protectedPackets {
if pkt.packet == nil {
packetsMissing++
if packetsMissing > 1 {
break
}
}
}
if packetsMissing != 1 {
continue
}
recovered, err := d.recoverPacket(&fecPkt) //nolint:gosec
if err != nil {
d.logger.Errorf("failed to recover packet: %v", err)
}
recoveredPackets = append(recoveredPackets, recovered)
d.recoveredPackets = append(d.recoveredPackets, recovered)
sort.Slice(d.recoveredPackets, func(i, j int) bool {
return isNewerSeq(d.recoveredPackets[i].SequenceNumber, d.recoveredPackets[j].SequenceNumber)
})
d.updateCoveringFecPackets(recovered)
d.discardOldRecoveredPackets()
packetsRecovered++
}
if packetsRecovered == 0 {
break
}
}
return recoveredPackets
}
func (d *fecDecoder) recoverPacket(fec *fecPacketState) (rtp.Packet, error) {
// https://datatracker.ietf.org/doc/html/draft-ietf-payload-flexible-fec-scheme-03#section-6.3.2
// 2. For the repair packet in T, extract the FEC bit string as the
// first 80 bits of the FEC header.
headerRecovery := make([]byte, 12)
copy(headerRecovery, fec.packet.Payload[:10])
var seqnum uint16
for _, protectedPacket := range fec.protectedPackets {
if protectedPacket.packet != nil {
// 1. For each of the source packets that are successfully received in
// T, compute the 80-bit string by concatenating the first 64 bits
// of their RTP header and the unsigned network-ordered 16-bit
// representation of their length in bytes minus 12.
receivedHeader, err := protectedPacket.packet.Header.Marshal()
if err != nil {
return rtp.Packet{}, fmt.Errorf("marshal received header: %w", err)
}
binary.BigEndian.PutUint16(receivedHeader[2:4], uint16(protectedPacket.packet.MarshalSize()-12)) //nolint:gosec
for i := 0; i < 8; i++ {
headerRecovery[i] ^= receivedHeader[i]
}
} else {
seqnum = protectedPacket.seq
}
}
// set version to 2
headerRecovery[0] |= 0x80
headerRecovery[0] &= 0xbf
payloadLength := binary.BigEndian.Uint16(headerRecovery[2:4])
binary.BigEndian.PutUint16(headerRecovery[2:4], seqnum)
binary.BigEndian.PutUint32(headerRecovery[8:12], d.protectedStreamSSRC)
payloadRecovery := make([]byte, payloadLength)
copy(payloadRecovery, fec.flexFec.payload)
for _, protectedPacket := range fec.protectedPackets {
if protectedPacket.packet != nil {
packet, err := protectedPacket.packet.Marshal()
if err != nil {
return rtp.Packet{}, fmt.Errorf("marshal protected packet: %w", err)
}
for i := 0; i < minInt(int(payloadLength), len(packet)-12); i++ {
payloadRecovery[i] ^= packet[12+i]
}
}
}
headerRecovery = append(headerRecovery, payloadRecovery...) //nolint:makezero
var packet rtp.Packet
err := packet.Unmarshal(headerRecovery)
if err != nil {
return rtp.Packet{}, fmt.Errorf("unmarshal recovered: %w", err)
}
return packet, nil
}
func (d *fecDecoder) discardOldRecoveredPackets() {
const limit = 192
if len(d.recoveredPackets) > limit {
d.recoveredPackets = d.recoveredPackets[len(d.recoveredPackets)-192:]
}
}
func decodeMask(mask uint64, bitCount uint16, seqNumBase uint16) []uint16 {
res := make([]uint16, 0)
for i := uint16(0); i < bitCount; i++ {
if (mask>>(bitCount-1-i))&1 == 1 {
res = append(res, seqNumBase+i)
}
}
return res
}
type fecPacketState struct {
packet rtp.Packet
flexFec flexFec
protectedPackets []*protectedPacket
}
type flexFec struct {
protectedSSRC uint32
seqNumBase uint16
mask0 uint16
mask1 uint32
mask2 uint64
payload []byte
}
type protectedPacket struct {
seq uint16
packet *rtp.Packet
}
func parseFlexFEC03Header(data []byte) (flexFec, error) {
if len(data) < 20 {
return flexFec{}, fmt.Errorf("%w: length %d", errPacketTruncated, len(data))
}
rBit := (data[0] & 0x80) != 0
if rBit {
return flexFec{}, errRetransmissionBitSet
}
fBit := (data[0] & 0x40) != 0
if fBit {
return flexFec{}, errInflexibleGeneratorMatrix
}
ssrcCount := data[8]
if ssrcCount != 1 {
return flexFec{}, fmt.Errorf("%w: count %d", errMultipleSSRCProtection, ssrcCount)
}
protectedSSRC := binary.BigEndian.Uint32(data[12:])
seqNumBase := binary.BigEndian.Uint16(data[16:])
rawPacketMask := data[18:]
var payload []byte
kBit0 := (rawPacketMask[0] & 0x80) != 0
maskPart0 := binary.BigEndian.Uint16(rawPacketMask[0:2]) & 0x7FFF
var maskPart1 uint32
var maskPart2 uint64
if kBit0 { //nolint:nestif
payload = rawPacketMask[2:]
} else {
if len(data) < 24 {
return flexFec{}, fmt.Errorf("%w: length %d", errPacketTruncated, len(data))
}
kBit1 := (rawPacketMask[2] & 0x80) != 0
maskPart1 = binary.BigEndian.Uint32(rawPacketMask[2:]) & 0x7FFFFFFF
if kBit1 {
payload = rawPacketMask[6:]
} else {
if len(data) < 32 {
return flexFec{}, fmt.Errorf("%w: length %d", errPacketTruncated, len(data))
}
kBit2 := (rawPacketMask[6] & 0x80) != 0
maskPart2 = binary.BigEndian.Uint64(rawPacketMask[6:]) & 0x7FFFFFFFFFFFFFFF
if kBit2 {
payload = rawPacketMask[14:]
} else {
return flexFec{}, errLastOptionalMaskKBitSetToFalse
}
}
}
return flexFec{
protectedSSRC: protectedSSRC,
seqNumBase: seqNumBase,
mask0: maskPart0,
mask1: maskPart1,
mask2: maskPart2,
payload: payload,
}, nil
}
func seqDiff(a, b uint16) uint16 {
return minUInt16(a-b, b-a)
}
func minInt(a, b int) int {
if a < b {
return a
}
return b
}
func minUInt16(a, b uint16) uint16 {
if a < b {
return a
}
return b
}
func abs(x int) int {
if x >= 0 {
return x
}
return -x
}
func isNewerSeq(prevValue, value uint16) bool {
// half-way mark
breakpoint := uint16(0x8000)
if value-prevValue == breakpoint {
return value > prevValue
}
return value != prevValue && (value-prevValue) < breakpoint
}

View File

@ -0,0 +1,209 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
// Package flexfec implements FlexFEC to recover missing RTP packets due to packet loss.
// https://datatracker.ietf.org/doc/html/rfc8627
package flexfec
import (
"encoding/binary"
"github.com/pion/interceptor/pkg/flexfec/util"
"github.com/pion/rtp"
)
const (
// BaseRTPHeaderSize represents the minium RTP packet header size in bytes.
BaseRTPHeaderSize = 12
// BaseFecHeaderSize represents the minium FEC payload's header size including the
// required first mask.
BaseFecHeaderSize = 12
)
// EncoderFactory is an interface for generic FEC encoders.
type EncoderFactory interface {
NewEncoder(payloadType uint8, ssrc uint32) FlexEncoder
}
// FlexEncoder is the interface that FecInterceptor uses to encode Fec packets.
type FlexEncoder interface {
EncodeFec(mediaPackets []rtp.Packet, numFecPackets uint32) []rtp.Packet
}
// FlexEncoder20 implementation is WIP, contains bugs and no tests. Check out FlexEncoder03.
type FlexEncoder20 struct {
fecBaseSn uint16
payloadType uint8
ssrc uint32
coverage *ProtectionCoverage
}
// NewFlexEncoder returns a new FlexEncoder20.
// FlexEncoder20 implementation is WIP, contains bugs and no tests. Check out FlexEncoder03.
func NewFlexEncoder(payloadType uint8, ssrc uint32) *FlexEncoder20 {
return &FlexEncoder20{
payloadType: payloadType,
ssrc: ssrc,
fecBaseSn: uint16(1000),
}
}
// EncodeFec returns a list of generated RTP packets with FEC payloads that protect the specified mediaPackets.
// This method does not account for missing RTP packets in the mediaPackets array nor does it account for
// them being passed out of order.
func (flex *FlexEncoder20) EncodeFec(mediaPackets []rtp.Packet, numFecPackets uint32) []rtp.Packet {
// Start by defining which FEC packets cover which media packets
if flex.coverage == nil {
flex.coverage = NewCoverage(mediaPackets, numFecPackets)
} else {
flex.coverage.UpdateCoverage(mediaPackets, numFecPackets)
}
if flex.coverage == nil {
return nil
}
// Generate FEC payloads
fecPackets := make([]rtp.Packet, numFecPackets)
for fecPacketIndex := uint32(0); fecPacketIndex < numFecPackets; fecPacketIndex++ {
fecPackets[fecPacketIndex] = flex.encodeFlexFecPacket(fecPacketIndex, mediaPackets[0].SequenceNumber)
}
return fecPackets
}
func (flex *FlexEncoder20) encodeFlexFecPacket(fecPacketIndex uint32, mediaBaseSn uint16) rtp.Packet {
mediaPacketsIt := flex.coverage.GetCoveredBy(fecPacketIndex)
flexFecHeader := flex.encodeFlexFecHeader(
mediaPacketsIt,
flex.coverage.ExtractMask1(fecPacketIndex),
flex.coverage.ExtractMask2(fecPacketIndex),
flex.coverage.ExtractMask3(fecPacketIndex),
mediaBaseSn,
)
flexFecRepairPayload := flex.encodeFlexFecRepairPayload(mediaPacketsIt.Reset())
packet := rtp.Packet{
Header: rtp.Header{
Version: 2,
Padding: false,
Extension: false,
Marker: false,
PayloadType: flex.payloadType,
SequenceNumber: flex.fecBaseSn,
Timestamp: 54243243,
SSRC: flex.ssrc,
CSRC: []uint32{},
},
Payload: append(flexFecHeader, flexFecRepairPayload...),
}
flex.fecBaseSn++
return packet
}
func (flex *FlexEncoder20) encodeFlexFecHeader(
mediaPackets *util.MediaPacketIterator,
mask1 uint16,
optionalMask2 uint32,
optionalMask3 uint64,
mediaBaseSn uint16,
) []byte {
/*
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|0|0|P|X| CC |M| PT recovery | length recovery |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| TS recovery |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| SN base_i |k| Mask [0-14] |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|k| Mask [15-45] (optional) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Mask [46-109] (optional) |
| |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| ... next SN base and Mask for CSRC_i in CSRC list ... |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
: Repair "Payload" follows FEC Header :
: :
*/
// Get header size - This depends on the size of the bitmask.
headerSize := BaseFecHeaderSize
if optionalMask2 > 0 {
headerSize += 4
}
if optionalMask3 > 0 {
headerSize += 8
}
// Allocate the FlexFec header
flexFecHeader := make([]byte, headerSize)
// XOR the relevant fields for the header
// TO DO - CHECK TO SEE IF THE MARSHALTO() call works with this.
tmpMediaPacketBuf := make([]byte, headerSize)
for mediaPackets.HasNext() {
mediaPacket := mediaPackets.Next()
n, err := mediaPacket.MarshalTo(tmpMediaPacketBuf)
if n == 0 || err != nil {
return nil
}
// XOR the first 2 bytes of the header: V, P, X, CC, M, PT fields
flexFecHeader[0] ^= tmpMediaPacketBuf[0]
flexFecHeader[1] ^= tmpMediaPacketBuf[1]
// XOR the length recovery field
lengthRecoveryVal := uint16(mediaPacket.MarshalSize() - BaseRTPHeaderSize) //nolint:gosec // G115
flexFecHeader[2] ^= uint8(lengthRecoveryVal >> 8) //nolint:gosec // G115
flexFecHeader[3] ^= uint8(lengthRecoveryVal) //nolint:gosec // G115
// XOR the 5th to 8th bytes of the header: the timestamp field
flexFecHeader[4] ^= flexFecHeader[4]
flexFecHeader[5] ^= flexFecHeader[5]
flexFecHeader[6] ^= flexFecHeader[6]
flexFecHeader[7] ^= flexFecHeader[7]
}
// Write the base SN for the batch of media packets
binary.BigEndian.PutUint16(flexFecHeader[8:10], mediaBaseSn)
// Write the bitmasks to the header
binary.BigEndian.PutUint16(flexFecHeader[10:12], mask1)
if optionalMask2 > 0 {
binary.BigEndian.PutUint32(flexFecHeader[12:16], optionalMask2)
flexFecHeader[10] |= 0b10000000
}
if optionalMask3 > 0 {
binary.BigEndian.PutUint64(flexFecHeader[16:24], optionalMask3)
flexFecHeader[12] |= 0b10000000
}
return flexFecHeader
}
func (flex *FlexEncoder20) encodeFlexFecRepairPayload(mediaPackets *util.MediaPacketIterator) []byte {
flexFecPayload := make([]byte, len(mediaPackets.First().Payload))
for mediaPackets.HasNext() {
mediaPacketPayload := mediaPackets.Next().Payload
if len(flexFecPayload) < len(mediaPacketPayload) {
// Expected FEC packet payload is bigger that what we can currently store,
// we need to resize.
flexFecPayloadTmp := make([]byte, len(mediaPacketPayload))
copy(flexFecPayloadTmp, flexFecPayload)
flexFecPayload = flexFecPayloadTmp
}
for byteIndex := 0; byteIndex < len(mediaPacketPayload); byteIndex++ {
flexFecPayload[byteIndex] ^= mediaPacketPayload[byteIndex]
}
}
return flexFecPayload
}

View File

@ -0,0 +1,255 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
// Package flexfec implements FlexFEC to recover missing RTP packets due to packet loss.
// https://datatracker.ietf.org/doc/html/rfc8627
package flexfec
import (
"encoding/binary"
"github.com/pion/interceptor/pkg/flexfec/util"
"github.com/pion/rtp"
)
const (
// BaseFec03HeaderSize represents the minium FEC payload's header size including the
// required first mask.
BaseFec03HeaderSize = 20
)
// FlexEncoder03 implements the Fec encoding mechanism for the "Flex" variant of FlexFec.
type FlexEncoder03 struct {
fecBaseSn uint16
payloadType uint8
ssrc uint32
coverage *ProtectionCoverage
}
// FlexEncoder03Factory is a factory for FlexFEC-03 encoders.
type FlexEncoder03Factory struct{}
// NewEncoder creates new FlexFEC-03 encoder.
func (f FlexEncoder03Factory) NewEncoder(payloadType uint8, ssrc uint32) FlexEncoder {
return NewFlexEncoder03(payloadType, ssrc)
}
// NewFlexEncoder03 creates new FlexFEC-03 encoder.
func NewFlexEncoder03(payloadType uint8, ssrc uint32) *FlexEncoder03 {
return &FlexEncoder03{
payloadType: payloadType,
ssrc: ssrc,
fecBaseSn: uint16(1000),
}
}
// EncodeFec returns a list of generated RTP packets with FEC payloads that protect the specified mediaPackets.
// This method returns nil in case of missing RTP packets in the mediaPackets array or packets passed out of order.
func (flex *FlexEncoder03) EncodeFec(mediaPackets []rtp.Packet, numFecPackets uint32) []rtp.Packet {
// Check if mediaPackets is empty
if len(mediaPackets) == 0 {
return nil
}
// Check if RTP packets are in order by comparing sequence numbers
for i := 1; i < len(mediaPackets); i++ {
if mediaPackets[i].SequenceNumber != mediaPackets[i-1].SequenceNumber+1 {
// Packets are not in order or there are missing packets
return nil
}
}
// Start by defining which FEC packets cover which media packets
if flex.coverage == nil {
flex.coverage = NewCoverage(mediaPackets, numFecPackets)
} else {
flex.coverage.UpdateCoverage(mediaPackets, numFecPackets)
}
if flex.coverage == nil {
return nil
}
// Generate FEC payloads
fecPackets := make([]rtp.Packet, numFecPackets)
for fecPacketIndex := uint32(0); fecPacketIndex < numFecPackets; fecPacketIndex++ {
fecPackets[fecPacketIndex] = flex.encodeFlexFecPacket(fecPacketIndex, mediaPackets[0].SequenceNumber)
}
return fecPackets
}
func (flex *FlexEncoder03) encodeFlexFecPacket(fecPacketIndex uint32, mediaBaseSn uint16) rtp.Packet {
mediaPacketsIt := flex.coverage.GetCoveredBy(fecPacketIndex)
flexFecHeader := flex.encodeFlexFecHeader(
mediaPacketsIt,
flex.coverage.ExtractMask1(fecPacketIndex),
flex.coverage.ExtractMask2(fecPacketIndex),
flex.coverage.ExtractMask3_03(fecPacketIndex),
mediaBaseSn,
)
flexFecRepairPayload := flex.encodeFlexFecRepairPayload(mediaPacketsIt.Reset())
packet := rtp.Packet{
Header: rtp.Header{
Version: 2,
Padding: false,
Extension: false,
Marker: false,
PayloadType: flex.payloadType,
SequenceNumber: flex.fecBaseSn,
Timestamp: 54243243,
SSRC: flex.ssrc,
CSRC: []uint32{},
},
Payload: append(flexFecHeader, flexFecRepairPayload...),
}
flex.fecBaseSn++
return packet
}
func (flex *FlexEncoder03) encodeFlexFecHeader( //nolint:cyclop
mediaPackets *util.MediaPacketIterator,
mask1 uint16,
optionalMask2 uint32,
optionalMask3 uint64,
mediaBaseSn uint16,
) []byte {
/*
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|0|0| P|X| CC |M| PT recovery | length recovery |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| TS recovery |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| SSRCCount | reserved |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| SSRC_i |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| SN base_i |k| Mask [0-14] |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|k| Mask [15-45] (optional) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|k| |
+-+ Mask [46-108] (optional) |
| |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| ... next in SSRC_i ... |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
*/
// Get header size - This depends on the size of the bitmask.
headerSize := BaseFec03HeaderSize
if optionalMask2 > 0 || optionalMask3 > 0 {
headerSize += 4
}
if optionalMask3 > 0 {
headerSize += 8
}
// Allocate the FlexFec header
flexFecHeader := make([]byte, headerSize)
// We allocate a single temporary buffer to store the mediaPacket bytes. This reduces
// overall allocations.
tmpMediaPacketBuf := make([]byte, 0)
for mediaPackets.HasNext() {
mediaPacket := mediaPackets.Next()
if mediaPacket.MarshalSize() > len(tmpMediaPacketBuf) {
// The temporary buffer is too small, we need to resize.
tmpMediaPacketBuf = make([]byte, mediaPacket.MarshalSize())
}
n, err := mediaPacket.MarshalTo(tmpMediaPacketBuf)
if n == 0 || err != nil {
return nil
}
// XOR the first 2 bytes of the header: V, P, X, CC, M, PT fields
flexFecHeader[0] ^= tmpMediaPacketBuf[0]
flexFecHeader[1] ^= tmpMediaPacketBuf[1]
// Clear the first 2 bits
flexFecHeader[0] &= 0b00111111
// XOR the length recovery field
lengthRecoveryVal := uint16(mediaPacket.MarshalSize() - BaseRTPHeaderSize) //nolint:gosec // G115
flexFecHeader[2] ^= uint8(lengthRecoveryVal >> 8) //nolint:gosec // G115
flexFecHeader[3] ^= uint8(lengthRecoveryVal) //nolint:gosec // G115
// XOR the 5th to 8th bytes of the header: the timestamp field
flexFecHeader[4] ^= tmpMediaPacketBuf[4]
flexFecHeader[5] ^= tmpMediaPacketBuf[5]
flexFecHeader[6] ^= tmpMediaPacketBuf[6]
flexFecHeader[7] ^= tmpMediaPacketBuf[7]
}
// Write the SSRC count
flexFecHeader[8] = 1
// Write 0s in reserved
flexFecHeader[9] = 0
flexFecHeader[10] = 0
flexFecHeader[11] = 0
// Write the SSRC of media packets protected by this FEC packet
binary.BigEndian.PutUint32(flexFecHeader[12:16], mediaPackets.First().SSRC)
// Write the base SN for the batch of media packets
binary.BigEndian.PutUint16(flexFecHeader[16:18], mediaBaseSn)
// Write the bitmasks to the header
binary.BigEndian.PutUint16(flexFecHeader[18:20], mask1)
if optionalMask2 == 0 && optionalMask3 == 0 {
flexFecHeader[18] |= 0b10000000
return flexFecHeader
}
binary.BigEndian.PutUint32(flexFecHeader[20:24], optionalMask2)
if optionalMask3 == 0 {
flexFecHeader[20] |= 0b10000000
} else {
binary.BigEndian.PutUint64(flexFecHeader[24:32], optionalMask3)
flexFecHeader[24] |= 0b10000000
}
return flexFecHeader
}
func (flex *FlexEncoder03) encodeFlexFecRepairPayload(mediaPackets *util.MediaPacketIterator) []byte {
flexFecPayload := make([]byte, mediaPackets.First().MarshalSize()-BaseRTPHeaderSize)
tmpMediaPacketBuf := make([]byte, 0)
for mediaPackets.HasNext() {
mediaPacket := mediaPackets.Next()
if mediaPacket.MarshalSize() > len(tmpMediaPacketBuf) {
tmpMediaPacketBuf = make([]byte, mediaPacket.MarshalSize())
}
n, err := mediaPacket.MarshalTo(tmpMediaPacketBuf)
if n == 0 || err != nil {
return nil
}
if len(flexFecPayload) < mediaPacket.MarshalSize()-BaseRTPHeaderSize {
// Expected FEC packet payload is bigger that what we can currently store,
// we need to resize.
flexFecPayloadTmp := make([]byte, mediaPacket.MarshalSize()-BaseRTPHeaderSize)
copy(flexFecPayloadTmp, flexFecPayload)
flexFecPayload = flexFecPayloadTmp
}
for byteIndex := 0; byteIndex < mediaPacket.MarshalSize()-BaseRTPHeaderSize; byteIndex++ {
flexFecPayload[byteIndex] ^= tmpMediaPacketBuf[byteIndex+BaseRTPHeaderSize]
}
}
return flexFecPayload
}

View File

@ -0,0 +1,34 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package flexfec
// FecOption can be used to set initial options on Fec encoder interceptors.
type FecOption func(d *FecInterceptor) error
// NumMediaPackets sets the number of media packets to accumulate before generating another FEC packets batch.
func NumMediaPackets(numMediaPackets uint32) FecOption {
return func(f *FecInterceptor) error {
f.numMediaPackets = numMediaPackets
return nil
}
}
// NumFECPackets sets the number of FEC packets to generate for each batch of media packets.
func NumFECPackets(numFecPackets uint32) FecOption {
return func(f *FecInterceptor) error {
f.numFecPackets = numFecPackets
return nil
}
}
// FECEncoderFactory sets the custom factory for constructing the FEC Encoders.
func FECEncoderFactory(factory EncoderFactory) FecOption {
return func(f *FecInterceptor) error {
f.encoderFactory = factory
return nil
}
}

View File

@ -0,0 +1,47 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
// Package util implements utilities to better support Fec decoding / encoding.
package util
// BitArray provides support for bitmask manipulations.
type BitArray struct {
Lo uint64 // leftmost 64 bits
Hi uint64 // rightmost 64 bits
}
// SetBit sets a bit to the specified bit value on the bitmask.
func (b *BitArray) SetBit(bitIndex uint32) {
if bitIndex < 64 {
b.Lo |= uint64(0b1) << (63 - bitIndex)
} else {
hiBitIndex := bitIndex - 64
b.Hi |= uint64(0b1) << (63 - hiBitIndex)
}
}
// Reset clears the bitmask.
func (b *BitArray) Reset() {
b.Lo = 0
b.Hi = 0
}
// GetBit returns the bit value at a specified index of the bitmask.
func (b *BitArray) GetBit(bitIndex uint32) uint8 {
if bitIndex < 64 {
result := (b.Lo & (uint64(0b1) << (63 - bitIndex)))
if result > 0 {
return 1
}
return 0
}
hiBitIndex := bitIndex - 64
result := (b.Hi & (uint64(0b1) << (63 - hiBitIndex)))
if result > 0 {
return 1
}
return 0
}

View File

@ -0,0 +1,56 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package util
import "github.com/pion/rtp"
// MediaPacketIterator supports iterating through a list of media packets protected by
// a specific Fec packet.
type MediaPacketIterator struct {
mediaPackets []rtp.Packet
coveredIndices []uint32
nextIndex int
}
// NewMediaPacketIterator returns a new MediaPacketIterator.
func NewMediaPacketIterator(mediaPackets []rtp.Packet, coveredIndices []uint32) *MediaPacketIterator {
return &MediaPacketIterator{
mediaPackets: mediaPackets,
coveredIndices: coveredIndices,
nextIndex: 0,
}
}
// Reset sets the starting iterating index back to 0.
func (m *MediaPacketIterator) Reset() *MediaPacketIterator {
m.nextIndex = 0
return m
}
// HasNext indicates whether or not there are more media packets
// that can be iterated through.
func (m *MediaPacketIterator) HasNext() bool {
return m.nextIndex < len(m.coveredIndices)
}
// Next returns the next media packet to iterate through.
func (m *MediaPacketIterator) Next() *rtp.Packet {
if m.nextIndex == len(m.coveredIndices) {
return nil
}
packet := m.mediaPackets[m.coveredIndices[m.nextIndex]]
m.nextIndex++
return &packet
}
// First returns the first media packet to iterate through.
func (m *MediaPacketIterator) First() *rtp.Packet {
if len(m.coveredIndices) == 0 {
return nil
}
return &m.mediaPackets[m.coveredIndices[0]]
}

View File

@ -3,13 +3,7 @@
package nack
import "errors"
import "github.com/pion/interceptor/internal/rtpbuffer"
// ErrInvalidSize is returned by newReceiveLog/newSendBuffer, when an incorrect buffer size is supplied.
var ErrInvalidSize = errors.New("invalid buffer size")
var (
errPacketReleased = errors.New("could not retain packet, already released")
errFailedToCastHeaderPool = errors.New("could not access header pool, failed cast")
errFailedToCastPayloadPool = errors.New("could not access payload pool, failed cast")
)
// ErrInvalidSize is returned by newReceiveLog/newRTPBuffer, when an incorrect buffer size is supplied.
var ErrInvalidSize = rtpbuffer.ErrInvalidSize

View File

@ -13,14 +13,14 @@ import (
"github.com/pion/rtcp"
)
// GeneratorInterceptorFactory is a interceptor.Factory for a GeneratorInterceptor
// GeneratorInterceptorFactory is a interceptor.Factory for a GeneratorInterceptor.
type GeneratorInterceptorFactory struct {
opts []GeneratorOption
}
// NewInterceptor constructs a new ReceiverInterceptor
// NewInterceptor constructs a new ReceiverInterceptor.
func (g *GeneratorInterceptorFactory) NewInterceptor(_ string) (interceptor.Interceptor, error) {
i := &GeneratorInterceptor{
generatorInterceptor := &GeneratorInterceptor{
streamsFilter: streamSupportNack,
size: 512,
skipLastN: 0,
@ -33,16 +33,16 @@ func (g *GeneratorInterceptorFactory) NewInterceptor(_ string) (interceptor.Inte
}
for _, opt := range g.opts {
if err := opt(i); err != nil {
if err := opt(generatorInterceptor); err != nil {
return nil, err
}
}
if _, err := newReceiveLog(i.size); err != nil {
if _, err := newReceiveLog(generatorInterceptor.size); err != nil {
return nil, err
}
return i, nil
return generatorInterceptor, nil
}
// GeneratorInterceptor interceptor generates nack feedback messages.
@ -63,13 +63,13 @@ type GeneratorInterceptor struct {
receiveLogsMu sync.Mutex
}
// NewGeneratorInterceptor returns a new GeneratorInterceptorFactory
// NewGeneratorInterceptor returns a new GeneratorInterceptorFactory.
func NewGeneratorInterceptor(opts ...GeneratorOption) (*GeneratorInterceptorFactory, error) {
return &GeneratorInterceptorFactory{opts}, nil
}
// BindRTCPWriter lets you modify any outgoing RTCP packets. It is called once per PeerConnection. The returned method
// will be called once per packet batch.
// BindRTCPWriter lets you modify any outgoing RTCP packets. It is called once per PeerConnection.
// The returned method will be called once per packet batch.
func (n *GeneratorInterceptor) BindRTCPWriter(writer interceptor.RTCPWriter) interceptor.RTCPWriter {
n.m.Lock()
defer n.m.Unlock()
@ -85,9 +85,11 @@ func (n *GeneratorInterceptor) BindRTCPWriter(writer interceptor.RTCPWriter) int
return writer
}
// BindRemoteStream lets you modify any incoming RTP packets. It is called once for per RemoteStream. The returned method
// will be called once per rtp packet.
func (n *GeneratorInterceptor) BindRemoteStream(info *interceptor.StreamInfo, reader interceptor.RTPReader) interceptor.RTPReader {
// BindRemoteStream lets you modify any incoming RTP packets. It is called once for per RemoteStream.
// The returned method will be called once per rtp packet.
func (n *GeneratorInterceptor) BindRemoteStream(
info *interceptor.StreamInfo, reader interceptor.RTPReader,
) interceptor.RTPReader {
if !n.streamsFilter(info) {
return reader
}
@ -124,7 +126,7 @@ func (n *GeneratorInterceptor) UnbindRemoteStream(info *interceptor.StreamInfo)
n.receiveLogsMu.Unlock()
}
// Close closes the interceptor
// Close closes the interceptor.
func (n *GeneratorInterceptor) Close() error {
defer n.wg.Wait()
n.m.Lock()
@ -137,12 +139,15 @@ func (n *GeneratorInterceptor) Close() error {
return nil
}
// nolint:gocognit
// nolint:gocognit,cyclop
func (n *GeneratorInterceptor) loop(rtcpWriter interceptor.RTCPWriter) {
defer n.wg.Done()
senderSSRC := rand.Uint32() // #nosec
missingPacketSeqNums := make([]uint16, n.size)
filteredMissingPacket := make([]uint16, n.size)
ticker := time.NewTicker(n.interval)
defer ticker.Stop()
for {
@ -153,7 +158,7 @@ func (n *GeneratorInterceptor) loop(rtcpWriter interceptor.RTCPWriter) {
defer n.receiveLogsMu.Unlock()
for ssrc, receiveLog := range n.receiveLogs {
missing := receiveLog.missingSeqNumbers(n.skipLastN)
missing := receiveLog.missingSeqNumbers(n.skipLastN, missingPacketSeqNums)
if len(missing) == 0 || n.nackCountLogs[ssrc] == nil {
n.nackCountLogs[ssrc] = map[uint16]uint16{}
@ -162,22 +167,33 @@ func (n *GeneratorInterceptor) loop(rtcpWriter interceptor.RTCPWriter) {
continue
}
filteredMissing := []uint16{}
nack := &rtcp.TransportLayerNack{} // nolint:ineffassign,wastedassign
c := 0 // nolint:varnamelen,
if n.maxNacksPerPacket > 0 {
for _, missingSeq := range missing {
if n.nackCountLogs[ssrc][missingSeq] < n.maxNacksPerPacket {
filteredMissing = append(filteredMissing, missingSeq)
filteredMissingPacket[c] = missingSeq
c++
}
n.nackCountLogs[ssrc][missingSeq]++
}
} else {
filteredMissing = missing
}
nack := &rtcp.TransportLayerNack{
SenderSSRC: senderSSRC,
MediaSSRC: ssrc,
Nacks: rtcp.NackPairsFromSequenceNumbers(filteredMissing),
if c == 0 {
continue
}
nack = &rtcp.TransportLayerNack{
SenderSSRC: senderSSRC,
MediaSSRC: ssrc,
Nacks: rtcp.NackPairsFromSequenceNumbers(filteredMissingPacket[:c]),
}
} else {
nack = &rtcp.TransportLayerNack{
SenderSSRC: senderSSRC,
MediaSSRC: ssrc,
Nacks: rtcp.NackPairsFromSequenceNumbers(missing),
}
}
for nackSeq := range n.nackCountLogs[ssrc] {
@ -185,6 +201,7 @@ func (n *GeneratorInterceptor) loop(rtcpWriter interceptor.RTCPWriter) {
for _, missingSeq := range missing {
if missingSeq == nackSeq {
isMissing = true
break
}
}
@ -193,10 +210,6 @@ func (n *GeneratorInterceptor) loop(rtcpWriter interceptor.RTCPWriter) {
}
}
if len(filteredMissing) == 0 {
continue
}
if _, err := rtcpWriter.Write([]rtcp.Packet{nack}, interceptor.Attributes{}); err != nil {
n.log.Warnf("failed sending nack: %+v", err)
}

View File

@ -10,56 +10,63 @@ import (
"github.com/pion/logging"
)
// GeneratorOption can be used to configure GeneratorInterceptor
// GeneratorOption can be used to configure GeneratorInterceptor.
type GeneratorOption func(r *GeneratorInterceptor) error
// GeneratorSize sets the size of the interceptor.
// Size must be one of: 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768
// Size must be one of: 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768.
func GeneratorSize(size uint16) GeneratorOption {
return func(r *GeneratorInterceptor) error {
r.size = size
return nil
}
}
// GeneratorSkipLastN sets the number of packets (n-1 packets before the last received packets) to ignore when generating
// nack requests.
// GeneratorSkipLastN sets the number of packets (n-1 packets before the last received packets)
//
// to ignore when generating nack requests.
func GeneratorSkipLastN(skipLastN uint16) GeneratorOption {
return func(r *GeneratorInterceptor) error {
r.skipLastN = skipLastN
return nil
}
}
// GeneratorMaxNacksPerPacket sets the maximum number of NACKs sent per missing packet, e.g. if set to 2, a missing
// packet will only be NACKed at most twice. If set to 0 (default), max number of NACKs is unlimited
// packet will only be NACKed at most twice. If set to 0 (default), max number of NACKs is unlimited.
func GeneratorMaxNacksPerPacket(maxNacks uint16) GeneratorOption {
return func(r *GeneratorInterceptor) error {
r.maxNacksPerPacket = maxNacks
return nil
}
}
// GeneratorLog sets a logger for the interceptor
// GeneratorLog sets a logger for the interceptor.
func GeneratorLog(log logging.LeveledLogger) GeneratorOption {
return func(r *GeneratorInterceptor) error {
r.log = log
return nil
}
}
// GeneratorInterval sets the nack send interval for the interceptor
// GeneratorInterval sets the nack send interval for the interceptor.
func GeneratorInterval(interval time.Duration) GeneratorOption {
return func(r *GeneratorInterceptor) error {
r.interval = interval
return nil
}
}
// GeneratorStreamsFilter sets filter for generator streams
// GeneratorStreamsFilter sets filter for generator streams.
func GeneratorStreamsFilter(filter func(info *interceptor.StreamInfo) bool) GeneratorOption {
return func(r *GeneratorInterceptor) error {
r.streamsFilter = filter
return nil
}
}

View File

@ -6,6 +6,8 @@ package nack
import (
"fmt"
"sync"
"github.com/pion/interceptor/internal/rtpbuffer"
)
type receiveLog struct {
@ -23,6 +25,7 @@ func newReceiveLog(size uint16) (*receiveLog, error) {
for i := 6; i < 16; i++ {
if size == 1<<i {
correctSize = true
break
}
allowedSizes = append(allowedSizes, 1<<i)
@ -47,6 +50,7 @@ func (s *receiveLog) add(seq uint16) {
s.end = seq
s.started = true
s.lastConsecutive = seq
return
}
@ -54,7 +58,7 @@ func (s *receiveLog) add(seq uint16) {
switch {
case diff == 0:
return
case diff < uint16SizeHalf:
case diff < rtpbuffer.Uint16SizeHalf:
// this means a positive diff, in other words seq > end (with counting for rollovers)
for i := s.end + 1; i != seq; i++ {
// clear packets between end and seq (these may contain packets from a "size" ago)
@ -82,7 +86,7 @@ func (s *receiveLog) get(seq uint16) bool {
defer s.m.RUnlock()
diff := s.end - seq
if diff >= uint16SizeHalf {
if diff >= rtpbuffer.Uint16SizeHalf {
return false
}
@ -93,24 +97,25 @@ func (s *receiveLog) get(seq uint16) bool {
return s.getReceived(seq)
}
func (s *receiveLog) missingSeqNumbers(skipLastN uint16) []uint16 {
func (s *receiveLog) missingSeqNumbers(skipLastN uint16, missingPacketSeqNums []uint16) []uint16 {
s.m.RLock()
defer s.m.RUnlock()
until := s.end - skipLastN
if until-s.lastConsecutive >= uint16SizeHalf {
if until-s.lastConsecutive >= rtpbuffer.Uint16SizeHalf {
// until < s.lastConsecutive (counting for rollover)
return nil
}
missingPacketSeqNums := make([]uint16, 0)
c := 0
for i := s.lastConsecutive + 1; i != until+1; i++ {
if !s.getReceived(i) {
missingPacketSeqNums = append(missingPacketSeqNums, i)
missingPacketSeqNums[c] = i
c++
}
}
return missingPacketSeqNums
return missingPacketSeqNums[:c]
}
func (s *receiveLog) setReceived(seq uint16) {
@ -125,6 +130,7 @@ func (s *receiveLog) delReceived(seq uint16) {
func (s *receiveLog) getReceived(seq uint16) bool {
pos := seq % s.size
return (s.packets[pos/64] & (1 << (pos % 64))) != 0
}

View File

@ -7,23 +7,20 @@ import (
"sync"
"github.com/pion/interceptor"
"github.com/pion/interceptor/internal/rtpbuffer"
"github.com/pion/logging"
"github.com/pion/rtcp"
"github.com/pion/rtp"
)
// ResponderInterceptorFactory is a interceptor.Factory for a ResponderInterceptor
// ResponderInterceptorFactory is a interceptor.Factory for a ResponderInterceptor.
type ResponderInterceptorFactory struct {
opts []ResponderOption
}
type packetFactory interface {
NewPacket(header *rtp.Header, payload []byte, rtxSsrc uint32, rtxPayloadType uint8) (*retainablePacket, error)
}
// NewInterceptor constructs a new ResponderInterceptor
// NewInterceptor constructs a new ResponderInterceptor.
func (r *ResponderInterceptorFactory) NewInterceptor(_ string) (interceptor.Interceptor, error) {
i := &ResponderInterceptor{
responderInterceptor := &ResponderInterceptor{
streamsFilter: streamSupportNack,
size: 1024,
log: logging.NewDefaultLoggerFactory().NewLogger("nack_responder"),
@ -31,40 +28,41 @@ func (r *ResponderInterceptorFactory) NewInterceptor(_ string) (interceptor.Inte
}
for _, opt := range r.opts {
if err := opt(i); err != nil {
if err := opt(responderInterceptor); err != nil {
return nil, err
}
}
if i.packetFactory == nil {
i.packetFactory = newPacketManager()
if responderInterceptor.packetFactory == nil {
responderInterceptor.packetFactory = rtpbuffer.NewPacketFactoryCopy()
}
if _, err := newSendBuffer(i.size); err != nil {
if _, err := rtpbuffer.NewRTPBuffer(responderInterceptor.size); err != nil {
return nil, err
}
return i, nil
return responderInterceptor, nil
}
// ResponderInterceptor responds to nack feedback messages
// ResponderInterceptor responds to nack feedback messages.
type ResponderInterceptor struct {
interceptor.NoOp
streamsFilter func(info *interceptor.StreamInfo) bool
size uint16
log logging.LeveledLogger
packetFactory packetFactory
packetFactory rtpbuffer.PacketFactory
streams map[uint32]*localStream
streamsMu sync.Mutex
}
type localStream struct {
sendBuffer *sendBuffer
rtpWriter interceptor.RTPWriter
rtpBuffer *rtpbuffer.RTPBuffer
rtpBufferMutex sync.RWMutex
rtpWriter interceptor.RTPWriter
}
// NewResponderInterceptor returns a new ResponderInterceptorFactor
// NewResponderInterceptor returns a new ResponderInterceptorFactor.
func NewResponderInterceptor(opts ...ResponderOption) (*ResponderInterceptorFactory, error) {
return &ResponderInterceptorFactory{opts}, nil
}
@ -98,30 +96,44 @@ func (n *ResponderInterceptor) BindRTCPReader(reader interceptor.RTCPReader) int
})
}
// BindLocalStream lets you modify any outgoing RTP packets. It is called once for per LocalStream. The returned method
// will be called once per rtp packet.
func (n *ResponderInterceptor) BindLocalStream(info *interceptor.StreamInfo, writer interceptor.RTPWriter) interceptor.RTPWriter {
// BindLocalStream lets you modify any outgoing RTP packets. It is called once for per LocalStream.
// The returned method will be called once per rtp packet.
func (n *ResponderInterceptor) BindLocalStream(
info *interceptor.StreamInfo, writer interceptor.RTPWriter,
) interceptor.RTPWriter {
if !n.streamsFilter(info) {
return writer
}
// error is already checked in NewGeneratorInterceptor
sendBuffer, _ := newSendBuffer(n.size)
n.streamsMu.Lock()
n.streams[info.SSRC] = &localStream{
sendBuffer: sendBuffer,
rtpWriter: writer,
rtpBuffer, _ := rtpbuffer.NewRTPBuffer(n.size)
stream := &localStream{
rtpBuffer: rtpBuffer,
rtpWriter: writer,
}
n.streamsMu.Lock()
n.streams[info.SSRC] = stream
n.streamsMu.Unlock()
return interceptor.RTPWriterFunc(func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) {
pkt, err := n.packetFactory.NewPacket(header, payload, info.SSRCRetransmission, info.PayloadTypeRetransmission)
if err != nil {
return 0, err
}
sendBuffer.add(pkt)
return writer.Write(header, payload, attributes)
})
return interceptor.RTPWriterFunc(
func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) {
// If this packet doesn't belong to the main SSRC, do not add it to rtpBuffer
if header.SSRC != info.SSRC {
return writer.Write(header, payload, attributes)
}
pkt, err := n.packetFactory.NewPacket(header, payload, info.SSRCRetransmission, info.PayloadTypeRetransmission)
if err != nil {
return 0, err
}
stream.rtpBufferMutex.Lock()
defer stream.rtpBufferMutex.Unlock()
rtpBuffer.Add(pkt)
return writer.Write(header, payload, attributes)
},
)
}
// UnbindLocalStream is called when the Stream is removed. It can be used to clean up any data related to that track.
@ -141,7 +153,10 @@ func (n *ResponderInterceptor) resendPackets(nack *rtcp.TransportLayerNack) {
for i := range nack.Nacks {
nack.Nacks[i].Range(func(seq uint16) bool {
if p := stream.sendBuffer.get(seq); p != nil {
stream.rtpBufferMutex.Lock()
defer stream.rtpBufferMutex.Unlock()
if p := stream.rtpBuffer.Get(seq); p != nil {
if _, err := stream.rtpWriter.Write(p.Header(), p.Payload(), interceptor.Attributes{}); err != nil {
n.log.Warnf("failed resending nacked packet: %+v", err)
}

View File

@ -5,42 +5,47 @@ package nack
import (
"github.com/pion/interceptor"
"github.com/pion/interceptor/internal/rtpbuffer"
"github.com/pion/logging"
)
// ResponderOption can be used to configure ResponderInterceptor
// ResponderOption can be used to configure ResponderInterceptor.
type ResponderOption func(s *ResponderInterceptor) error
// ResponderSize sets the size of the interceptor.
// Size must be one of: 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768
// Size must be one of: 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768.
func ResponderSize(size uint16) ResponderOption {
return func(r *ResponderInterceptor) error {
r.size = size
return nil
}
}
// ResponderLog sets a logger for the interceptor
// ResponderLog sets a logger for the interceptor.
func ResponderLog(log logging.LeveledLogger) ResponderOption {
return func(r *ResponderInterceptor) error {
r.log = log
return nil
}
}
// DisableCopy bypasses copy of underlying packets. It should be used when
// you are not re-using underlying buffers of packets that have been written
// you are not re-using underlying buffers of packets that have been written.
func DisableCopy() ResponderOption {
return func(s *ResponderInterceptor) error {
s.packetFactory = &noOpPacketFactory{}
s.packetFactory = &rtpbuffer.PacketFactoryNoOp{}
return nil
}
}
// ResponderStreamsFilter sets filter for local streams
// ResponderStreamsFilter sets filter for local streams.
func ResponderStreamsFilter(filter func(info *interceptor.StreamInfo) bool) ResponderOption {
return func(r *ResponderInterceptor) error {
r.streamsFilter = filter
return nil
}
}

View File

@ -1,162 +0,0 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package nack
import (
"encoding/binary"
"io"
"sync"
"github.com/pion/rtp"
)
const maxPayloadLen = 1460
type packetManager struct {
headerPool *sync.Pool
payloadPool *sync.Pool
rtxSequencer rtp.Sequencer
}
func newPacketManager() *packetManager {
return &packetManager{
headerPool: &sync.Pool{
New: func() interface{} {
return &rtp.Header{}
},
},
payloadPool: &sync.Pool{
New: func() interface{} {
buf := make([]byte, maxPayloadLen)
return &buf
},
},
rtxSequencer: rtp.NewRandomSequencer(),
}
}
func (m *packetManager) NewPacket(header *rtp.Header, payload []byte, rtxSsrc uint32, rtxPayloadType uint8) (*retainablePacket, error) {
if len(payload) > maxPayloadLen {
return nil, io.ErrShortBuffer
}
p := &retainablePacket{
onRelease: m.releasePacket,
sequenceNumber: header.SequenceNumber,
// new packets have retain count of 1
count: 1,
}
var ok bool
p.header, ok = m.headerPool.Get().(*rtp.Header)
if !ok {
return nil, errFailedToCastHeaderPool
}
*p.header = header.Clone()
if payload != nil {
p.buffer, ok = m.payloadPool.Get().(*[]byte)
if !ok {
return nil, errFailedToCastPayloadPool
}
size := copy(*p.buffer, payload)
p.payload = (*p.buffer)[:size]
}
if rtxSsrc != 0 && rtxPayloadType != 0 {
// Store the original sequence number and rewrite the sequence number.
originalSequenceNumber := p.header.SequenceNumber
p.header.SequenceNumber = m.rtxSequencer.NextSequenceNumber()
// Rewrite the SSRC.
p.header.SSRC = rtxSsrc
// Rewrite the payload type.
p.header.PayloadType = rtxPayloadType
// Remove padding if present.
paddingLength := 0
if p.header.Padding && p.payload != nil && len(p.payload) > 0 {
paddingLength = int(p.payload[len(p.payload)-1])
p.header.Padding = false
}
// Write the original sequence number at the beginning of the payload.
payload := make([]byte, 2)
binary.BigEndian.PutUint16(payload, originalSequenceNumber)
p.payload = append(payload, p.payload[:len(p.payload)-paddingLength]...)
}
return p, nil
}
func (m *packetManager) releasePacket(header *rtp.Header, payload *[]byte) {
m.headerPool.Put(header)
if payload != nil {
m.payloadPool.Put(payload)
}
}
type noOpPacketFactory struct{}
func (f *noOpPacketFactory) NewPacket(header *rtp.Header, payload []byte, _ uint32, _ uint8) (*retainablePacket, error) {
return &retainablePacket{
onRelease: f.releasePacket,
count: 1,
header: header,
payload: payload,
sequenceNumber: header.SequenceNumber,
}, nil
}
func (f *noOpPacketFactory) releasePacket(_ *rtp.Header, _ *[]byte) {
// no-op
}
type retainablePacket struct {
onRelease func(*rtp.Header, *[]byte)
countMu sync.Mutex
count int
header *rtp.Header
buffer *[]byte
payload []byte
sequenceNumber uint16
}
func (p *retainablePacket) Header() *rtp.Header {
return p.header
}
func (p *retainablePacket) Payload() []byte {
return p.payload
}
func (p *retainablePacket) Retain() error {
p.countMu.Lock()
defer p.countMu.Unlock()
if p.count == 0 {
// already released
return errPacketReleased
}
p.count++
return nil
}
func (p *retainablePacket) Release() {
p.countMu.Lock()
defer p.countMu.Unlock()
p.count--
if p.count == 0 {
// release back to pool
p.onRelease(p.header, p.buffer)
p.header = nil
p.buffer = nil
p.payload = nil
}
}

View File

@ -1,104 +0,0 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package nack
import (
"fmt"
"sync"
)
const (
uint16SizeHalf = 1 << 15
)
type sendBuffer struct {
packets []*retainablePacket
size uint16
lastAdded uint16
started bool
m sync.RWMutex
}
func newSendBuffer(size uint16) (*sendBuffer, error) {
allowedSizes := make([]uint16, 0)
correctSize := false
for i := 0; i < 16; i++ {
if size == 1<<i {
correctSize = true
break
}
allowedSizes = append(allowedSizes, 1<<i)
}
if !correctSize {
return nil, fmt.Errorf("%w: %d is not a valid size, allowed sizes: %v", ErrInvalidSize, size, allowedSizes)
}
return &sendBuffer{
packets: make([]*retainablePacket, size),
size: size,
}, nil
}
func (s *sendBuffer) add(packet *retainablePacket) {
s.m.Lock()
defer s.m.Unlock()
seq := packet.sequenceNumber
if !s.started {
s.packets[seq%s.size] = packet
s.lastAdded = seq
s.started = true
return
}
diff := seq - s.lastAdded
if diff == 0 {
return
} else if diff < uint16SizeHalf {
for i := s.lastAdded + 1; i != seq; i++ {
idx := i % s.size
prevPacket := s.packets[idx]
if prevPacket != nil {
prevPacket.Release()
}
s.packets[idx] = nil
}
}
idx := seq % s.size
prevPacket := s.packets[idx]
if prevPacket != nil {
prevPacket.Release()
}
s.packets[idx] = packet
s.lastAdded = seq
}
func (s *sendBuffer) get(seq uint16) *retainablePacket {
s.m.RLock()
defer s.m.RUnlock()
diff := s.lastAdded - seq
if diff >= uint16SizeHalf {
return nil
}
if diff >= s.size {
return nil
}
pkt := s.packets[seq%s.size]
if pkt != nil {
if pkt.sequenceNumber != seq {
return nil
}
// already released
if err := pkt.Retain(); err != nil {
return nil
}
}
return pkt
}

View File

@ -12,14 +12,14 @@ import (
"github.com/pion/rtcp"
)
// ReceiverInterceptorFactory is a interceptor.Factory for a ReceiverInterceptor
// ReceiverInterceptorFactory is a interceptor.Factory for a ReceiverInterceptor.
type ReceiverInterceptorFactory struct {
opts []ReceiverOption
}
// NewInterceptor constructs a new ReceiverInterceptor
// NewInterceptor constructs a new ReceiverInterceptor.
func (r *ReceiverInterceptorFactory) NewInterceptor(_ string) (interceptor.Interceptor, error) {
i := &ReceiverInterceptor{
receiverInterceptor := &ReceiverInterceptor{
interval: 1 * time.Second,
now: time.Now,
log: logging.NewDefaultLoggerFactory().NewLogger("receiver_interceptor"),
@ -27,15 +27,15 @@ func (r *ReceiverInterceptorFactory) NewInterceptor(_ string) (interceptor.Inter
}
for _, opt := range r.opts {
if err := opt(i); err != nil {
if err := opt(receiverInterceptor); err != nil {
return nil, err
}
}
return i, nil
return receiverInterceptor, nil
}
// NewReceiverInterceptor returns a new ReceiverInterceptorFactory
// NewReceiverInterceptor returns a new ReceiverInterceptorFactory.
func NewReceiverInterceptor(opts ...ReceiverOption) (*ReceiverInterceptorFactory, error) {
return &ReceiverInterceptorFactory{opts}, nil
}
@ -103,7 +103,9 @@ func (r *ReceiverInterceptor) loop(rtcpWriter interceptor.RTCPWriter) {
r.streams.Range(func(_, value interface{}) bool {
if stream, ok := value.(*receiverStream); !ok {
r.log.Warnf("failed to cast ReceiverInterceptor stream")
} else if _, err := rtcpWriter.Write([]rtcp.Packet{stream.generateReport(now)}, interceptor.Attributes{}); err != nil {
} else if _, err := rtcpWriter.Write(
[]rtcp.Packet{stream.generateReport(now)}, interceptor.Attributes{},
); err != nil {
r.log.Warnf("failed sending: %+v", err)
}
@ -116,9 +118,11 @@ func (r *ReceiverInterceptor) loop(rtcpWriter interceptor.RTCPWriter) {
}
}
// BindRemoteStream lets you modify any incoming RTP packets. It is called once for per RemoteStream. The returned method
// will be called once per rtp packet.
func (r *ReceiverInterceptor) BindRemoteStream(info *interceptor.StreamInfo, reader interceptor.RTPReader) interceptor.RTPReader {
// BindRemoteStream lets you modify any incoming RTP packets. It is called once for per RemoteStream.
// The returned method will be called once per rtp packet.
func (r *ReceiverInterceptor) BindRemoteStream(
info *interceptor.StreamInfo, reader interceptor.RTPReader,
) interceptor.RTPReader {
stream := newReceiverStream(info.SSRC, info.ClockRate)
r.streams.Store(info.SSRC, stream)

View File

@ -16,6 +16,7 @@ type ReceiverOption func(r *ReceiverInterceptor) error
func ReceiverLog(log logging.LeveledLogger) ReceiverOption {
return func(r *ReceiverInterceptor) error {
r.log = log
return nil
}
}
@ -24,6 +25,7 @@ func ReceiverLog(log logging.LeveledLogger) ReceiverOption {
func ReceiverInterval(interval time.Duration) ReceiverOption {
return func(r *ReceiverInterceptor) error {
r.interval = interval
return nil
}
}
@ -32,6 +34,7 @@ func ReceiverInterval(interval time.Duration) ReceiverOption {
func ReceiverNow(f func() time.Time) ReceiverOption {
return func(r *ReceiverInterceptor) error {
r.now = f
return nil
}
}

View File

@ -41,6 +41,7 @@ type receiverStream struct {
func newReceiverStream(ssrc uint32, clockRate uint32) *receiverStream {
receiverSSRC := rand.Uint32() // #nosec
return &receiverStream{
ssrc: ssrc,
receiverSSRC: receiverSSRC,
@ -54,6 +55,7 @@ func (stream *receiverStream) processRTP(now time.Time, pktHeader *rtp.Header) {
stream.m.Lock()
defer stream.m.Unlock()
//nolint:nestif
if !stream.started { // first frame
stream.started = true
stream.setReceived(pktHeader.SequenceNumber)
@ -104,6 +106,7 @@ func (stream *receiverStream) delReceived(seq uint16) {
func (stream *receiverStream) getReceived(seq uint16) bool {
pos := seq % (stream.size * packetsPerHistoryEntry)
return (stream.packets[pos/packetsPerHistoryEntry] & (1 << (pos % packetsPerHistoryEntry))) != 0
}
@ -111,7 +114,7 @@ func (stream *receiverStream) processSenderReport(now time.Time, sr *rtcp.Sender
stream.m.Lock()
defer stream.m.Unlock()
stream.lastSenderReport = uint32(sr.NTPTime >> 16)
stream.lastSenderReport = uint32(sr.NTPTime >> 16) //nolint:gosec // G115
stream.lastSenderReportTime = now
}
@ -131,6 +134,7 @@ func (stream *receiverStream) generateReport(now time.Time) *rtcp.ReceiverReport
ret++
}
}
return ret
}()
stream.totalLost += totalLostSinceReport
@ -143,7 +147,7 @@ func (stream *receiverStream) generateReport(now time.Time) *rtcp.ReceiverReport
stream.totalLost = 0xFFFFFF
}
r := &rtcp.ReceiverReport{
receiverReport := &rtcp.ReceiverReport{
SSRC: stream.receiverSSRC,
Reports: []rtcp.ReceptionReport{
{
@ -156,6 +160,7 @@ func (stream *receiverStream) generateReport(now time.Time) *rtcp.ReceiverReport
if stream.lastSenderReportTime.IsZero() {
return 0
}
return uint32(now.Sub(stream.lastSenderReportTime).Seconds() * 65536)
}(),
Jitter: uint32(stream.jitter),
@ -165,5 +170,5 @@ func (stream *receiverStream) generateReport(now time.Time) *rtcp.ReceiverReport
stream.lastReportSeqnum = stream.lastSeqnum
return r
return receiverReport
}

View File

@ -13,17 +13,17 @@ import (
"github.com/pion/rtp"
)
// TickerFactory is a factory to create new tickers
// TickerFactory is a factory to create new tickers.
type TickerFactory func(d time.Duration) Ticker
// SenderInterceptorFactory is a interceptor.Factory for a SenderInterceptor
// SenderInterceptorFactory is a interceptor.Factory for a SenderInterceptor.
type SenderInterceptorFactory struct {
opts []SenderOption
}
// NewInterceptor constructs a new SenderInterceptor
// NewInterceptor constructs a new SenderInterceptor.
func (s *SenderInterceptorFactory) NewInterceptor(_ string) (interceptor.Interceptor, error) {
i := &SenderInterceptor{
senderInterceptor := &SenderInterceptor{
interval: 1 * time.Second,
now: time.Now,
newTicker: func(d time.Duration) Ticker {
@ -34,15 +34,15 @@ func (s *SenderInterceptorFactory) NewInterceptor(_ string) (interceptor.Interce
}
for _, opt := range s.opts {
if err := opt(i); err != nil {
if err := opt(senderInterceptor); err != nil {
return nil, err
}
}
return i, nil
return senderInterceptor, nil
}
// NewSenderInterceptor returns a new SenderInterceptorFactory
// NewSenderInterceptor returns a new SenderInterceptorFactory.
func NewSenderInterceptor(opts ...SenderOption) (*SenderInterceptorFactory, error) {
return &SenderInterceptorFactory{opts}, nil
}
@ -119,7 +119,9 @@ func (s *SenderInterceptor) loop(rtcpWriter interceptor.RTCPWriter) {
s.streams.Range(func(_, value interface{}) bool {
if stream, ok := value.(*senderStream); !ok {
s.log.Warnf("failed to cast SenderInterceptor stream")
} else if _, err := rtcpWriter.Write([]rtcp.Packet{stream.generateReport(now)}, interceptor.Attributes{}); err != nil {
} else if _, err := rtcpWriter.Write(
[]rtcp.Packet{stream.generateReport(now)}, interceptor.Attributes{},
); err != nil {
s.log.Warnf("failed sending: %+v", err)
}
@ -134,7 +136,9 @@ func (s *SenderInterceptor) loop(rtcpWriter interceptor.RTCPWriter) {
// BindLocalStream lets you modify any outgoing RTP packets. It is called once for per LocalStream. The returned method
// will be called once per rtp packet.
func (s *SenderInterceptor) BindLocalStream(info *interceptor.StreamInfo, writer interceptor.RTPWriter) interceptor.RTPWriter {
func (s *SenderInterceptor) BindLocalStream(
info *interceptor.StreamInfo, writer interceptor.RTPWriter,
) interceptor.RTPWriter {
stream := newSenderStream(info.SSRC, info.ClockRate, s.useLatestPacket)
s.streams.Store(info.SSRC, stream)

View File

@ -16,6 +16,7 @@ type SenderOption func(r *SenderInterceptor) error
func SenderLog(log logging.LeveledLogger) SenderOption {
return func(r *SenderInterceptor) error {
r.log = log
return nil
}
}
@ -24,6 +25,7 @@ func SenderLog(log logging.LeveledLogger) SenderOption {
func SenderInterval(interval time.Duration) SenderOption {
return func(r *SenderInterceptor) error {
r.interval = interval
return nil
}
}
@ -32,6 +34,7 @@ func SenderInterval(interval time.Duration) SenderOption {
func SenderNow(f func() time.Time) SenderOption {
return func(r *SenderInterceptor) error {
r.now = f
return nil
}
}
@ -40,6 +43,7 @@ func SenderNow(f func() time.Time) SenderOption {
func SenderTicker(f TickerFactory) SenderOption {
return func(r *SenderInterceptor) error {
r.newTicker = f
return nil
}
}
@ -49,6 +53,7 @@ func SenderTicker(f TickerFactory) SenderOption {
func SenderUseLatestPacket() SenderOption {
return func(r *SenderInterceptor) error {
r.useLatestPacket = true
return nil
}
}
@ -58,6 +63,7 @@ func SenderUseLatestPacket() SenderOption {
func enableStartTracking(startedCh chan struct{}) SenderOption {
return func(r *SenderInterceptor) error {
r.started = startedCh
return nil
}
}

View File

@ -48,7 +48,7 @@ func (stream *senderStream) processRTP(now time.Time, header *rtp.Header, payloa
}
stream.packetCount++
stream.octetCount += uint32(len(payload))
stream.octetCount += uint32(len(payload)) //nolint:gosec // G115
}
func (stream *senderStream) generateReport(now time.Time) *rtcp.SenderReport {

View File

@ -14,17 +14,17 @@ import (
"github.com/pion/rtcp"
)
// TickerFactory is a factory to create new tickers
// TickerFactory is a factory to create new tickers.
type TickerFactory func(d time.Duration) ticker
// SenderInterceptorFactory is a interceptor.Factory for a SenderInterceptor
// SenderInterceptorFactory is a interceptor.Factory for a SenderInterceptor.
type SenderInterceptorFactory struct {
opts []Option
}
// NewInterceptor constructs a new SenderInterceptor
// NewInterceptor constructs a new SenderInterceptor.
func (s *SenderInterceptorFactory) NewInterceptor(_ string) (interceptor.Interceptor, error) {
i := &SenderInterceptor{
senderInterceptor := &SenderInterceptor{
NoOp: interceptor.NoOp{},
log: logging.NewDefaultLoggerFactory().NewLogger("rfc8888_interceptor"),
lock: sync.Mutex{},
@ -40,12 +40,13 @@ func (s *SenderInterceptorFactory) NewInterceptor(_ string) (interceptor.Interce
close: make(chan struct{}),
}
for _, opt := range s.opts {
err := opt(i)
err := opt(senderInterceptor)
if err != nil {
return nil, err
}
}
return i, nil
return senderInterceptor, nil
}
// NewSenderInterceptor returns a new SenderInterceptorFactory configured with the given options.
@ -91,9 +92,12 @@ func (s *SenderInterceptor) BindRTCPWriter(writer interceptor.RTCPWriter) interc
return writer
}
// BindRemoteStream lets you modify any incoming RTP packets. It is called once for per RemoteStream. The returned method
// will be called once per rtp packet.
func (s *SenderInterceptor) BindRemoteStream(_ *interceptor.StreamInfo, reader interceptor.RTPReader) interceptor.RTPReader {
// BindRemoteStream lets you modify any incoming RTP packets.
// It is called once for per RemoteStream. The returned method
// will be called once per rtp packet..
func (s *SenderInterceptor) BindRemoteStream(
_ *interceptor.StreamInfo, reader interceptor.RTPReader,
) interceptor.RTPReader {
return interceptor.RTPReaderFunc(func(b []byte, a interceptor.Attributes) (int, interceptor.Attributes, error) {
i, attr, err := reader.Read(b, a)
if err != nil {
@ -115,6 +119,7 @@ func (s *SenderInterceptor) BindRemoteStream(_ *interceptor.StreamInfo, reader i
ecn: 0, // ECN is not supported (yet).
}
s.packetChan <- p
return i, attr, nil
})
}
@ -157,16 +162,19 @@ func (s *SenderInterceptor) loop(writer interceptor.RTCPWriter) {
select {
case <-s.close:
t.Stop()
return
case pkt := <-s.packetChan:
s.log.Tracef("got packet: %v", pkt)
s.recorder.AddPacket(pkt.arrival, pkt.ssrc, pkt.sequenceNumber, pkt.ecn)
case now := <-t.Ch():
case <-t.Ch():
now := s.now()
s.log.Tracef("report triggered at %v", now)
if writer == nil {
s.log.Trace("no writer added, continue")
continue
}
pkts := s.recorder.BuildReport(now, int(s.maxReportSize))

View File

@ -5,13 +5,14 @@ package rfc8888
import "time"
// An Option is a function that can be used to configure a SenderInterceptor
// An Option is a function that can be used to configure a SenderInterceptor.
type Option func(*SenderInterceptor) error
// SenderTicker sets an alternative for time.Ticker.
func SenderTicker(f TickerFactory) Option {
return func(i *SenderInterceptor) error {
i.newTicker = f
return nil
}
}
@ -20,14 +21,16 @@ func SenderTicker(f TickerFactory) Option {
func SenderNow(f func() time.Time) Option {
return func(i *SenderInterceptor) error {
i.now = f
return nil
}
}
// SendInterval sets the feedback send interval for the interceptor
// SendInterval sets the feedback send interval for the interceptor.
func SendInterval(interval time.Duration) Option {
return func(s *SenderInterceptor) error {
s.interval = interval
return nil
}
}

View File

@ -6,6 +6,7 @@ package rfc8888
import (
"time"
"github.com/pion/interceptor/internal/ntp"
"github.com/pion/rtcp"
)
@ -21,7 +22,7 @@ type Recorder struct {
streams map[uint32]*streamLog
}
// NewRecorder creates a new Recorder
// NewRecorder creates a new Recorder.
func NewRecorder() *Recorder {
return &Recorder{
streams: map[uint32]*streamLog{},
@ -44,7 +45,7 @@ func (r *Recorder) BuildReport(now time.Time, maxSize int) *rtcp.CCFeedbackRepor
report := &rtcp.CCFeedbackReport{
SenderSSRC: r.ssrc,
ReportBlocks: []rtcp.CCFeedbackReportBlock{},
ReportTimestamp: ntpTime32(now),
ReportTimestamp: ntp.ToNTP32(now),
}
maxReportBlocks := (maxSize - 12 - (8 * len(r.streams))) / 2
@ -65,14 +66,3 @@ func (r *Recorder) BuildReport(now time.Time, maxSize int) *rtcp.CCFeedbackRepor
return report
}
func ntpTime32(t time.Time) uint32 {
// seconds since 1st January 1900
s := (float64(t.UnixNano()) / 1000000000.0) + 2208988800
integerPart := uint32(s)
fractionalPart := uint32((s - float64(integerPart)) * 0xFFFFFFFF)
// higher 32 bits are the integer part, lower 32 bits are the fractional part
return uint32(((uint64(integerPart)<<32 | uint64(fractionalPart)) >> 16) & 0xFFFFFFFF)
}

View File

@ -6,6 +6,7 @@ package rfc8888
import (
"time"
"github.com/pion/interceptor/internal/sequencenumber"
"github.com/pion/rtcp"
)
@ -13,7 +14,7 @@ const maxReportsPerReportBlock = 16384
type streamLog struct {
ssrc uint32
sequence unwrapper
sequence sequencenumber.Unwrapper
init bool
nextSequenceNumberToReport int64 // next to report
lastSequenceNumberReceived int64 // highest received
@ -23,7 +24,7 @@ type streamLog struct {
func newStreamLog(ssrc uint32) *streamLog {
return &streamLog{
ssrc: ssrc,
sequence: unwrapper{},
sequence: sequencenumber.Unwrapper{},
init: false,
nextSequenceNumberToReport: 0,
lastSequenceNumberReceived: 0,
@ -32,7 +33,7 @@ func newStreamLog(ssrc uint32) *streamLog {
}
func (l *streamLog) add(ts time.Time, sequenceNumber uint16, ecn uint8) {
unwrappedSequenceNumber := l.sequence.unwrap(sequenceNumber)
unwrappedSequenceNumber := l.sequence.Unwrap(sequenceNumber)
if !l.init {
l.init = true
l.nextSequenceNumberToReport = unwrappedSequenceNumber
@ -52,7 +53,7 @@ func (l *streamLog) metricsAfter(reference time.Time, maxReportBlocks int64) rtc
if len(l.log) == 0 {
return rtcp.CCFeedbackReportBlock{
MediaSSRC: l.ssrc,
BeginSequence: uint16(l.nextSequenceNumberToReport),
BeginSequence: uint16(l.nextSequenceNumberToReport), //nolint:gosec // G115
MetricBlocks: []rtcp.CCFeedbackMetricBlock{},
}
}
@ -65,7 +66,7 @@ func (l *streamLog) metricsAfter(reference time.Time, maxReportBlocks int64) rtc
offset := l.nextSequenceNumberToReport
lastReceived := l.nextSequenceNumberToReport
gapDetected := false
for i := offset; i <= l.lastSequenceNumberReceived; i++ {
for i := offset; i <= l.lastSequenceNumberReceived; i++ { //nolint:varnamelen // i int64
received := false
ecn := uint8(0)
ato := uint16(0)
@ -91,9 +92,10 @@ func (l *streamLog) metricsAfter(reference time.Time, maxReportBlocks int64) rtc
}
}
}
return rtcp.CCFeedbackReportBlock{
MediaSSRC: l.ssrc,
BeginSequence: uint16(offset),
BeginSequence: uint16(offset), //nolint:gosec // G115
MetricBlocks: metricBlocks,
}
}
@ -106,5 +108,6 @@ func getArrivalTimeOffset(base time.Time, arrival time.Time) uint16 {
if ato > 0x1FFD {
return 0x1FFE
}
return ato
}

View File

@ -1,42 +0,0 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package rfc8888
const (
maxSequenceNumberPlusOne = int64(65536)
breakpoint = 32768 // half of max uint16
)
type unwrapper struct {
init bool
lastUnwrapped int64
}
func isNewer(value, previous uint16) bool {
if value-previous == breakpoint {
return value > previous
}
return value != previous && (value-previous) < breakpoint
}
func (u *unwrapper) unwrap(i uint16) int64 {
if !u.init {
u.init = true
u.lastUnwrapped = int64(i)
return u.lastUnwrapped
}
lastWrapped := uint16(u.lastUnwrapped)
delta := int64(i - lastWrapped)
if isNewer(i, lastWrapped) {
if delta < 0 {
delta += maxSequenceNumberPlusOne
}
} else if delta > 0 && u.lastUnwrapped+delta-maxSequenceNumberPlusOne >= 0 {
delta -= maxSequenceNumberPlusOne
}
u.lastUnwrapped += delta
return u.lastUnwrapped
}

View File

@ -12,6 +12,8 @@ const (
// of the arrival times of packets. It is used by the TWCC interceptor to build feedback
// packets.
// See https://source.chromium.org/chromium/chromium/src/+/refs/heads/main:third_party/webrtc/modules/remote_bitrate_estimator/packet_arrival_map.h;drc=b5cd13bb6d5d157a5fbe3628b2dd1c1e106203c6
//
//nolint:lll
type packetArrivalTimeMap struct {
// arrivalTimes is a circular buffer, where the packet with sequence number sn is stored
// in slot sn % len(arrivalTimes)
@ -31,12 +33,14 @@ func (m *packetArrivalTimeMap) AddPacket(sequenceNumber int64, arrivalTime int64
m.beginSequenceNumber = sequenceNumber
m.endSequenceNumber = sequenceNumber + 1
m.arrivalTimes[m.index(sequenceNumber)] = arrivalTime
return
}
if sequenceNumber >= m.beginSequenceNumber && sequenceNumber < m.endSequenceNumber {
// The packet is within the buffer, no need to resize.
m.arrivalTimes[m.index(sequenceNumber)] = arrivalTime
return
}
@ -53,6 +57,7 @@ func (m *packetArrivalTimeMap) AddPacket(sequenceNumber int64, arrivalTime int64
m.arrivalTimes[m.index(sequenceNumber)] = arrivalTime
m.setNotReceived(sequenceNumber+1, m.beginSequenceNumber)
m.beginSequenceNumber = sequenceNumber
return
}
@ -64,6 +69,7 @@ func (m *packetArrivalTimeMap) AddPacket(sequenceNumber int64, arrivalTime int64
m.beginSequenceNumber = sequenceNumber
m.endSequenceNumber = newEndSequenceNumber
m.arrivalTimes[m.index(sequenceNumber)] = arrivalTime
return
}
@ -99,12 +105,15 @@ func (m *packetArrivalTimeMap) EndSequenceNumber() int64 {
// FindNextAtOrAfter returns the sequence number and timestamp of the first received packet that has a sequence number
// greator or equal to sequenceNumber.
func (m *packetArrivalTimeMap) FindNextAtOrAfter(sequenceNumber int64) (foundSequenceNumber int64, arrivalTime int64, ok bool) {
func (m *packetArrivalTimeMap) FindNextAtOrAfter(sequenceNumber int64) (
foundSequenceNumber int64, arrivalTime int64, ok bool,
) {
for sequenceNumber = m.Clamp(sequenceNumber); sequenceNumber < m.endSequenceNumber; sequenceNumber++ {
if t := m.get(sequenceNumber); t >= 0 {
return sequenceNumber, t, true
}
}
return -1, -1, false
}
@ -116,6 +125,7 @@ func (m *packetArrivalTimeMap) EraseTo(sequenceNumber int64) {
if sequenceNumber >= m.endSequenceNumber {
// Erase all.
m.beginSequenceNumber = m.endSequenceNumber
return
}
// Remove some
@ -138,7 +148,7 @@ func (m *packetArrivalTimeMap) HasReceived(sequenceNumber int64) bool {
return m.get(sequenceNumber) >= 0
}
// Clamp returns sequenceNumber clamped to [beginSequenceNumber, endSequenceNumber]
// Clamp returns sequenceNumber clamped to [beginSequenceNumber, endSequenceNumber].
func (m *packetArrivalTimeMap) Clamp(sequenceNumber int64) int64 {
if sequenceNumber < m.beginSequenceNumber {
return m.beginSequenceNumber
@ -146,6 +156,7 @@ func (m *packetArrivalTimeMap) Clamp(sequenceNumber int64) int64 {
if m.endSequenceNumber < sequenceNumber {
return m.endSequenceNumber
}
return sequenceNumber
}
@ -153,6 +164,7 @@ func (m *packetArrivalTimeMap) get(sequenceNumber int64) int64 {
if sequenceNumber < m.beginSequenceNumber || sequenceNumber >= m.endSequenceNumber {
return -1
}
return m.arrivalTimes[m.index(sequenceNumber)]
}

View File

@ -13,20 +13,20 @@ import (
var errHeaderIsNil = errors.New("header is nil")
// HeaderExtensionInterceptorFactory is a interceptor.Factory for a HeaderExtensionInterceptor
// HeaderExtensionInterceptorFactory is a interceptor.Factory for a HeaderExtensionInterceptor.
type HeaderExtensionInterceptorFactory struct{}
// NewInterceptor constructs a new HeaderExtensionInterceptor
// NewInterceptor constructs a new HeaderExtensionInterceptor.
func (h *HeaderExtensionInterceptorFactory) NewInterceptor(_ string) (interceptor.Interceptor, error) {
return &HeaderExtensionInterceptor{}, nil
}
// NewHeaderExtensionInterceptor returns a HeaderExtensionInterceptorFactory
// NewHeaderExtensionInterceptor returns a HeaderExtensionInterceptorFactory.
func NewHeaderExtensionInterceptor() (*HeaderExtensionInterceptorFactory, error) {
return &HeaderExtensionInterceptorFactory{}, nil
}
// HeaderExtensionInterceptor adds transport wide sequence numbers as header extension to each RTP packet
// HeaderExtensionInterceptor adds transport wide sequence numbers as header extension to each RTP packet.
type HeaderExtensionInterceptor struct {
interceptor.NoOp
nextSequenceNr uint32
@ -36,31 +36,39 @@ const transportCCURI = "http://www.ietf.org/id/draft-holmer-rmcat-transport-wide
// BindLocalStream returns a writer that adds a rtp.TransportCCExtension
// header with increasing sequence numbers to each outgoing packet.
func (h *HeaderExtensionInterceptor) BindLocalStream(info *interceptor.StreamInfo, writer interceptor.RTPWriter) interceptor.RTPWriter {
func (h *HeaderExtensionInterceptor) BindLocalStream(
info *interceptor.StreamInfo,
writer interceptor.RTPWriter,
) interceptor.RTPWriter {
var hdrExtID uint8
for _, e := range info.RTPHeaderExtensions {
if e.URI == transportCCURI {
hdrExtID = uint8(e.ID)
hdrExtID = uint8(e.ID) //nolint:gosec // G115
break
}
}
if hdrExtID == 0 { // Don't add header extension if ID is 0, because 0 is an invalid extension ID
return writer
}
return interceptor.RTPWriterFunc(func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) {
sequenceNumber := atomic.AddUint32(&h.nextSequenceNr, 1) - 1
tcc, err := (&rtp.TransportCCExtension{TransportSequence: uint16(sequenceNumber)}).Marshal()
if err != nil {
return 0, err
}
if header == nil {
return 0, errHeaderIsNil
}
err = header.SetExtension(hdrExtID, tcc)
if err != nil {
return 0, err
}
return writer.Write(header, payload, attributes)
})
return interceptor.RTPWriterFunc(
func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) {
sequenceNumber := atomic.AddUint32(&h.nextSequenceNr, 1) - 1
//nolint:gosec // G115
tcc, err := (&rtp.TransportCCExtension{TransportSequence: uint16(sequenceNumber)}).Marshal()
if err != nil {
return 0, err
}
if header == nil {
return 0, errHeaderIsNil
}
err = header.SetExtension(hdrExtID, tcc)
if err != nil {
return 0, err
}
return writer.Write(header, payload, attributes)
},
)
}

View File

@ -14,16 +14,16 @@ import (
"github.com/pion/rtp"
)
// SenderInterceptorFactory is a interceptor.Factory for a SenderInterceptor
// SenderInterceptorFactory is a interceptor.Factory for a SenderInterceptor.
type SenderInterceptorFactory struct {
opts []Option
}
var errClosed = errors.New("interceptor is closed")
// NewInterceptor constructs a new SenderInterceptor
// NewInterceptor constructs a new SenderInterceptor.
func (s *SenderInterceptorFactory) NewInterceptor(_ string) (interceptor.Interceptor, error) {
i := &SenderInterceptor{
senderInterceptor := &SenderInterceptor{
log: logging.NewDefaultLoggerFactory().NewLogger("twcc_sender_interceptor"),
packetChan: make(chan packet),
close: make(chan struct{}),
@ -32,13 +32,13 @@ func (s *SenderInterceptorFactory) NewInterceptor(_ string) (interceptor.Interce
}
for _, opt := range s.opts {
err := opt(i)
err := opt(senderInterceptor)
if err != nil {
return nil, err
}
}
return i, nil
return senderInterceptor, nil
}
// NewSenderInterceptor returns a new SenderInterceptorFactory configured with the given options.
@ -64,7 +64,7 @@ type SenderInterceptor struct {
packetChan chan packet
}
// An Option is a function that can be used to configure a SenderInterceptor
// An Option is a function that can be used to configure a SenderInterceptor.
type Option func(*SenderInterceptor) error
// SendInterval sets the interval at which the interceptor
@ -72,6 +72,7 @@ type Option func(*SenderInterceptor) error
func SendInterval(interval time.Duration) Option {
return func(s *SenderInterceptor) error {
s.interval = interval
return nil
}
}
@ -102,54 +103,63 @@ type packet struct {
ssrc uint32
}
// BindRemoteStream lets you modify any incoming RTP packets. It is called once for per RemoteStream. The returned method
// BindRemoteStream lets you modify any incoming RTP packets.
// It is called once for per RemoteStream. The returned method
// will be called once per rtp packet.
func (s *SenderInterceptor) BindRemoteStream(info *interceptor.StreamInfo, reader interceptor.RTPReader) interceptor.RTPReader {
//
//nolint:cyclop
func (s *SenderInterceptor) BindRemoteStream(
info *interceptor.StreamInfo, reader interceptor.RTPReader,
) interceptor.RTPReader {
var hdrExtID uint8
for _, e := range info.RTPHeaderExtensions {
if e.URI == transportCCURI {
hdrExtID = uint8(e.ID)
hdrExtID = uint8(e.ID) //nolint:gosec // G115
break
}
}
if hdrExtID == 0 { // Don't try to read header extension if ID is 0, because 0 is an invalid extension ID
return reader
}
return interceptor.RTPReaderFunc(func(buf []byte, attributes interceptor.Attributes) (int, interceptor.Attributes, error) {
i, attr, err := reader.Read(buf, attributes)
if err != nil {
return 0, nil, err
}
if attr == nil {
attr = make(interceptor.Attributes)
}
header, err := attr.GetRTPHeader(buf[:i])
if err != nil {
return 0, nil, err
}
var tccExt rtp.TransportCCExtension
if ext := header.GetExtension(hdrExtID); ext != nil {
err = tccExt.Unmarshal(ext)
return interceptor.RTPReaderFunc(
func(buf []byte, attributes interceptor.Attributes) (int, interceptor.Attributes, error) {
i, attr, err := reader.Read(buf, attributes)
if err != nil {
return 0, nil, err
}
p := packet{
hdr: header,
sequenceNumber: tccExt.TransportSequence,
arrivalTime: time.Since(s.startTime).Microseconds(),
ssrc: info.SSRC,
if attr == nil {
attr = make(interceptor.Attributes)
}
select {
case <-s.close:
return 0, nil, errClosed
case s.packetChan <- p:
header, err := attr.GetRTPHeader(buf[:i])
if err != nil {
return 0, nil, err
}
}
var tccExt rtp.TransportCCExtension
if ext := header.GetExtension(hdrExtID); ext != nil {
err = tccExt.Unmarshal(ext)
if err != nil {
return 0, nil, err
}
return i, attr, nil
})
p := packet{
hdr: header,
sequenceNumber: tccExt.TransportSequence,
arrivalTime: time.Since(s.startTime).Microseconds(),
ssrc: info.SSRC,
}
select {
case <-s.close:
return 0, nil, errClosed
case s.packetChan <- p:
}
}
return i, attr, nil
},
)
}
// Close closes the interceptor.
@ -174,7 +184,7 @@ func (s *SenderInterceptor) isClosed() bool {
}
}
func (s *SenderInterceptor) loop(w interceptor.RTCPWriter) {
func (s *SenderInterceptor) loop(writer interceptor.RTCPWriter) {
defer s.wg.Done()
select {
@ -189,6 +199,7 @@ func (s *SenderInterceptor) loop(w interceptor.RTCPWriter) {
select {
case <-s.close:
ticker.Stop()
return
case p := <-s.packetChan:
s.recorder.Record(p.ssrc, p.sequenceNumber, p.arrivalTime)
@ -199,7 +210,7 @@ func (s *SenderInterceptor) loop(w interceptor.RTCPWriter) {
if len(pkts) == 0 {
continue
}
if _, err := w.Write(pkts, nil); err != nil {
if _, err := writer.Write(pkts, nil); err != nil {
s.log.Error(err.Error())
}
}

View File

@ -71,12 +71,13 @@ func (r *Recorder) Record(mediaSSRC uint32, sequenceNumber uint16, arrivalTime i
}
func (r *Recorder) maybeCullOldPackets(sequenceNumber int64, arrivalTime int64) {
if r.startSequenceNumber != nil && *r.startSequenceNumber >= r.arrivalTimeMap.EndSequenceNumber() && arrivalTime >= packetWindowMicroseconds {
if r.startSequenceNumber != nil && *r.startSequenceNumber >= r.arrivalTimeMap.EndSequenceNumber() &&
arrivalTime >= packetWindowMicroseconds {
r.arrivalTimeMap.RemoveOldPackets(sequenceNumber, arrivalTime-packetWindowMicroseconds)
}
}
// PacketsHeld returns the number of received packets currently held by the recorder
// PacketsHeld returns the number of received packets currently held by the recorder.
func (r *Recorder) PacketsHeld() int {
return r.packetsHeld
}
@ -101,6 +102,7 @@ func (r *Recorder) BuildFeedbackPacket() []rtcp.Packet {
// old.
}
r.packetsHeld = 0
return feedbacks
}
@ -109,6 +111,7 @@ func (r *Recorder) BuildFeedbackPacket() []rtcp.Packet {
func (r *Recorder) maybeBuildFeedbackPacket(beginSeqNumInclusive, endSeqNumExclusive int64) *feedback {
// NOTE: The logic of this method is inspired by the implementation in Chrome.
// See https://source.chromium.org/chromium/chromium/src/+/refs/heads/main:third_party/webrtc/modules/remote_bitrate_estimator/remote_estimator_proxy.cc;l=276;drc=b5cd13bb6d5d157a5fbe3628b2dd1c1e106203c6
//nolint:lll
startSNInclusive, endSNExclusive := r.arrivalTimeMap.Clamp(beginSeqNumInclusive), r.arrivalTimeMap.Clamp(endSeqNumExclusive)
// Create feedback on demand, as we don't yet know if there are packets in the range that have been
@ -136,18 +139,19 @@ func (r *Recorder) maybeBuildFeedbackPacket(beginSeqNumInclusive, endSeqNumExclu
// baseSequenceNumber is the expected first sequence number. This is known,
// but we may not have actually received it, so the base time should be the time
// of the first received packet in the feedback.
fb.setBase(uint16(baseSequenceNumber), arrivalTime)
fb.setBase(uint16(baseSequenceNumber), arrivalTime) //nolint:gosec // G115
if !fb.addReceived(uint16(seq), arrivalTime) {
if !fb.addReceived(uint16(seq), arrivalTime) { //nolint:gosec // G115
// Could not add a single received packet to the feedback.
// This is unexpected to actually occur, but if it does, we'll
// try again after skipping any missing packets.
// NOTE: It's fine that we already incremented fbPktCnt, as in essence
// we did actually "skip" a feedback (and this matches Chrome's behavior).
r.startSequenceNumber = &seq
return nil
}
} else if !fb.addReceived(uint16(seq), arrivalTime) {
} else if !fb.addReceived(uint16(seq), arrivalTime) { //nolint:gosec // G115
// Could not add timestamp. Packet may be full. Return
// and try again with a fresh packet.
break
@ -157,6 +161,7 @@ func (r *Recorder) maybeBuildFeedbackPacket(beginSeqNumInclusive, endSeqNumExclu
}
r.startSequenceNumber = &nextSequenceNumber
return fb
}
@ -192,7 +197,7 @@ func (f *feedback) setBase(sequenceNumber uint16, timeUS int64) {
func (f *feedback) getRTCP() *rtcp.TransportLayerCC {
f.rtcp.PacketStatusCount = f.sequenceNumberCount
f.rtcp.ReferenceTime = uint32(f.refTimestamp64MS)
f.rtcp.ReferenceTime = uint32(f.refTimestamp64MS) //nolint:gosec // G115
f.rtcp.BaseSequenceNumber = f.baseSequenceNumber
for len(f.lastChunk.deltas) > 0 {
f.chunks = append(f.chunks, f.lastChunk.encode())
@ -200,7 +205,8 @@ func (f *feedback) getRTCP() *rtcp.TransportLayerCC {
f.rtcp.PacketChunks = append(f.rtcp.PacketChunks, f.chunks...)
f.rtcp.RecvDeltas = f.deltas
padLen := 20 + len(f.rtcp.PacketChunks)*2 + f.len // 4 bytes header + 16 bytes twcc header + 2 bytes for each chunk + length of deltas
// 4 bytes header + 16 bytes twcc header + 2 bytes for each chunk + length of deltas
padLen := 20 + len(f.rtcp.PacketChunks)*2 + f.len
padding := padLen%4 != 0
for padLen%4 != 0 {
padLen++
@ -209,7 +215,7 @@ func (f *feedback) getRTCP() *rtcp.TransportLayerCC {
Count: rtcp.FormatTCC,
Type: rtcp.TypeTransportSpecificFeedback,
Padding: padding,
Length: uint16((padLen / 4) - 1),
Length: uint16((padLen / 4) - 1), //nolint:gosec // G115
}
return f.rtcp
@ -223,7 +229,8 @@ func (f *feedback) addReceived(sequenceNumber uint16, timestampUS int64) bool {
} else {
delta250US = (deltaUS - rtcp.TypeTCCDeltaScaleFactor/2) / rtcp.TypeTCCDeltaScaleFactor
}
if delta250US < math.MinInt16 || delta250US > math.MaxInt16 { // delta doesn't fit into 16 bit, need to create new packet
// delta doesn't fit into 16 bit, need to create new packet
if delta250US < math.MinInt16 || delta250US > math.MaxInt16 {
return false
}
deltaUSRounded := delta250US * rtcp.TypeTCCDeltaScaleFactor
@ -257,6 +264,7 @@ func (f *feedback) addReceived(sequenceNumber uint16, timestampUS int64) bool {
f.lastTimestampUS += deltaUSRounded
f.sequenceNumberCount++
f.nextSequenceNumber++
return true
}
@ -282,6 +290,7 @@ func (c *chunk) canAdd(delta uint16) bool {
if len(c.deltas) < maxRunLengthCap && !c.hasDifferentTypes && delta == c.deltas[0] {
return true
}
return false
}
@ -294,13 +303,15 @@ func (c *chunk) add(delta uint16) {
func (c *chunk) encode() rtcp.PacketStatusChunk {
if !c.hasDifferentTypes {
defer c.reset()
return &rtcp.RunLengthChunk{
PacketStatusSymbol: c.deltas[0],
RunLength: uint16(len(c.deltas)),
RunLength: uint16(len(c.deltas)), //nolint:gosec // G115
}
}
if len(c.deltas) == maxOneBitCap {
defer c.reset()
return &rtcp.StatusVectorChunk{
SymbolSize: rtcp.TypeTCCSymbolSizeOneBit,
SymbolList: c.deltas,
@ -341,6 +352,7 @@ func maxInt(a, b int) int {
if a > b {
return a
}
return b
}
@ -348,6 +360,7 @@ func minInt(a, b int) int {
if a < b {
return a
}
return b
}
@ -355,6 +368,7 @@ func max64(a, b int64) int64 {
if a > b {
return a
}
return b
}
@ -362,5 +376,6 @@ func min64(a, b int64) int64 {
if a < b {
return a
}
return b
}

View File

@ -13,7 +13,7 @@ func (r *Registry) Add(f Factory) {
r.factories = append(r.factories, f)
}
// Build constructs a single Interceptor from a InterceptorRegistry
// Build constructs a single Interceptor from a InterceptorRegistry.
func (r *Registry) Build(id string) (Interceptor, error) {
if len(r.factories) == 0 {
return &NoOp{}, nil

View File

@ -9,7 +9,7 @@ type RTPHeaderExtension struct {
ID int
}
// StreamInfo is the Context passed when a StreamLocal or StreamRemote has been Binded or Unbinded
// StreamInfo is the Context passed when a StreamLocal or StreamRemote has been Binded or Unbinded.
type StreamInfo struct {
ID string
Attributes Attributes

View File

@ -19,12 +19,16 @@ linters-settings:
recommendations:
- errors
forbidigo:
analyze-types: true
forbid:
- ^fmt.Print(f|ln)?$
- ^log.(Panic|Fatal|Print)(f|ln)?$
- ^os.Exit$
- ^panic$
- ^print(ln)?$
- p: ^testing.T.(Error|Errorf|Fatal|Fatalf|Fail|FailNow)$
pkg: ^testing$
msg: "use testify/assert instead"
varnamelen:
max-distance: 12
min-name-length: 2
@ -37,6 +41,12 @@ linters-settings:
- w io.Writer
- r io.Reader
- b []byte
revive:
rules:
# Prefer 'any' type alias over 'interface{}' for Go 1.18+ compatibility
- name: use-any
severity: warning
disabled: false
linters:
enable:
@ -59,7 +69,6 @@ linters:
- exportloopref # checks for pointers to enclosing loop variables
- forbidigo # Forbids identifiers
- forcetypeassert # finds forced type assertions
- funlen # Tool for detection of long functions
- gci # Gci control golang package import order and make it always deterministic.
- gochecknoglobals # Checks that no globals are present in Go code
- gocognit # Computes and checks the cognitive complexity of functions
@ -106,6 +115,7 @@ linters:
- whitespace # Tool for detection of leading and trailing whitespace
disable:
- depguard # Go linter that checks if package imports are in a list of acceptable packages
- funlen # Tool for detection of long functions
- gochecknoinits # Checks that no init functions are present in Go code
- gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations.
- interfacebloat # A linter that checks length of interface.
@ -127,9 +137,12 @@ issues:
exclude-dirs-use-default: false
exclude-rules:
# Allow complex tests and examples, better to be self contained
- path: (examples|main\.go|_test\.go)
- path: (examples|main\.go)
linters:
- gocognit
- forbidigo
- path: _test\.go
linters:
- gocognit
# Allow forbidden identifiers in CLI commands

View File

@ -6,7 +6,7 @@
<h4 align="center">The Pion logging library</h4>
<p align="center">
<a href="https://pion.ly"><img src="https://img.shields.io/badge/pion-logging-gray.svg?longCache=true&colorB=brightgreen" alt="Pion transport"></a>
<a href="http://gophers.slack.com/messages/pion"><img src="https://img.shields.io/badge/join-us%20on%20slack-gray.svg?longCache=true&logo=slack&colorB=brightgreen" alt="Slack Widget"></a>
<a href="https://discord.gg/PngbdqpFbt"><img src="https://img.shields.io/badge/join-us%20on%20discord-gray.svg?longCache=true&logo=discord&colorB=brightblue" alt="join us on Discord"></a> <a href="https://bsky.app/profile/pion.ly"><img src="https://img.shields.io/badge/follow-us%20on%20bluesky-gray.svg?longCache=true&logo=bluesky&colorB=brightblue" alt="Follow us on Bluesky"></a>
<br>
<img alt="GitHub Workflow Status" src="https://img.shields.io/github/actions/workflow/status/pion/logging/test.yaml">
<a href="https://pkg.go.dev/github.com/pion/logging"><img src="https://pkg.go.dev/badge/github.com/pion/logging.svg" alt="Go Reference"></a>
@ -20,9 +20,9 @@
The library is used as a part of our WebRTC implementation. Please refer to that [roadmap](https://github.com/pion/webrtc/issues/9) to track our major milestones.
### Community
Pion has an active community on the [Slack](https://pion.ly/slack).
Pion has an active community on the [Discord](https://discord.gg/PngbdqpFbt).
Follow the [Pion Twitter](https://twitter.com/_pion) for project updates and important WebRTC news.
Follow the [Pion Bluesky](https://bsky.app/profile/pion.ly) or [Pion Twitter](https://twitter.com/_pion) for project updates and important WebRTC news.
We are always looking to support **your projects**. Please reach out if you have something to build!
If you need commercial support or don't want to use public methods you can contact us at [team@pion.ly](mailto:team@pion.ly)

View File

@ -93,7 +93,7 @@ func (ll *DefaultLeveledLogger) WithOutput(output io.Writer) *DefaultLeveledLogg
return ll
}
func (ll *DefaultLeveledLogger) logf(logger *log.Logger, level LogLevel, format string, args ...interface{}) {
func (ll *DefaultLeveledLogger) logf(logger *log.Logger, level LogLevel, format string, args ...any) {
if ll.level.Get() < level {
return
}
@ -116,7 +116,7 @@ func (ll *DefaultLeveledLogger) Trace(msg string) {
}
// Tracef formats and emits a message if the logger is at or below LogLevelTrace.
func (ll *DefaultLeveledLogger) Tracef(format string, args ...interface{}) {
func (ll *DefaultLeveledLogger) Tracef(format string, args ...any) {
ll.logf(ll.trace, LogLevelTrace, format, args...)
}
@ -126,7 +126,7 @@ func (ll *DefaultLeveledLogger) Debug(msg string) {
}
// Debugf formats and emits a message if the logger is at or below LogLevelDebug.
func (ll *DefaultLeveledLogger) Debugf(format string, args ...interface{}) {
func (ll *DefaultLeveledLogger) Debugf(format string, args ...any) {
ll.logf(ll.debug, LogLevelDebug, format, args...)
}
@ -136,7 +136,7 @@ func (ll *DefaultLeveledLogger) Info(msg string) {
}
// Infof formats and emits a message if the logger is at or below LogLevelInfo.
func (ll *DefaultLeveledLogger) Infof(format string, args ...interface{}) {
func (ll *DefaultLeveledLogger) Infof(format string, args ...any) {
ll.logf(ll.info, LogLevelInfo, format, args...)
}
@ -146,7 +146,7 @@ func (ll *DefaultLeveledLogger) Warn(msg string) {
}
// Warnf formats and emits a message if the logger is at or below LogLevelWarn.
func (ll *DefaultLeveledLogger) Warnf(format string, args ...interface{}) {
func (ll *DefaultLeveledLogger) Warnf(format string, args ...any) {
ll.logf(ll.warn, LogLevelWarn, format, args...)
}
@ -156,7 +156,7 @@ func (ll *DefaultLeveledLogger) Error(msg string) {
}
// Errorf formats and emits a message if the logger is at or below LogLevelError.
func (ll *DefaultLeveledLogger) Errorf(format string, args ...interface{}) {
func (ll *DefaultLeveledLogger) Errorf(format string, args ...any) {
ll.logf(ll.err, LogLevelError, format, args...)
}

View File

@ -58,15 +58,15 @@ const (
// LeveledLogger is the basic pion Logger interface.
type LeveledLogger interface {
Trace(msg string)
Tracef(format string, args ...interface{})
Tracef(format string, args ...any)
Debug(msg string)
Debugf(format string, args ...interface{})
Debugf(format string, args ...any)
Info(msg string)
Infof(format string, args ...interface{})
Infof(format string, args ...any)
Warn(msg string)
Warnf(format string, args ...interface{})
Warnf(format string, args ...any)
Error(msg string)
Errorf(format string, args ...interface{})
Errorf(format string, args ...any)
}
// LoggerFactory is the basic pion LoggerFactory interface.

View File

@ -41,6 +41,12 @@ linters-settings:
- w io.Writer
- r io.Reader
- b []byte
revive:
rules:
# Prefer 'any' type alias over 'interface{}' for Go 1.18+ compatibility
- name: use-any
severity: warning
disabled: false
linters:
enable:

View File

@ -15,8 +15,9 @@ import (
//
var (
errH265CorruptedPacket = errors.New("corrupted h265 packet")
errInvalidH265PacketType = errors.New("invalid h265 packet type")
errH265CorruptedPacket = errors.New("corrupted h265 packet")
errInvalidH265PacketType = errors.New("invalid h265 packet type")
errExpectFragmentationStartUnit = errors.New("expecting a fragmentation start unit")
)
//
@ -192,6 +193,15 @@ func (p *H265SingleNALUnitPacket) Payload() []byte {
func (p *H265SingleNALUnitPacket) isH265Packet() {}
func (p *H265SingleNALUnitPacket) doPackaging(buf []byte) []byte {
buf = append(buf, annexbNALUStartCode...)
buf = append(buf, byte(p.payloadHeader>>8), byte(p.payloadHeader&0xFF))
buf = append(buf, p.payload...)
return buf
}
//
// Aggregation Packets implementation
//
@ -399,6 +409,21 @@ func (p *H265AggregationPacket) OtherUnits() []H265AggregationUnit {
func (p *H265AggregationPacket) isH265Packet() {}
func (p *H265AggregationPacket) doPackaging(buf []byte) []byte {
if p.firstUnit == nil {
return buf
}
buf = append(buf, annexbNALUStartCode...)
buf = append(buf, p.firstUnit.nalUnit...)
for _, unit := range p.otherUnits {
buf = append(buf, annexbNALUStartCode...)
buf = append(buf, unit.nalUnit...)
}
return buf
}
//
// Fragmentation Unit implementation
//
@ -536,6 +561,64 @@ func (p *H265FragmentationUnitPacket) Payload() []byte {
func (p *H265FragmentationUnitPacket) isH265Packet() {}
// H265FragmentationPacket represents a Fragmentation packet, which contains one or more Fragmentation Units.
type H265FragmentationPacket struct {
payloadHeader H265NALUHeader
donl *uint16
units []*H265FragmentationUnitPacket
payload []byte
}
func NewH265FragmentationPacket(startUnit *H265FragmentationUnitPacket) *H265FragmentationPacket {
return &H265FragmentationPacket{
payloadHeader: (startUnit.payloadHeader & 0x81FF) | (H265NALUHeader(startUnit.FuHeader().FuType()) << 9),
donl: startUnit.donl,
units: []*H265FragmentationUnitPacket{startUnit},
}
}
// PayloadHeader returns the NALU header of the packet.
func (p *H265FragmentationPacket) PayloadHeader() H265NALUHeader {
return p.payloadHeader
}
// DONL returns the DONL of the packet.
func (p *H265FragmentationPacket) DONL() *uint16 {
return p.donl
}
// Payload returns the Fragmentation packet payload.
func (p *H265FragmentationPacket) Payload() []byte {
return p.payload
}
func (p *H265FragmentationPacket) isH265Packet() {}
func (p *H265FragmentationPacket) doPackaging(buf []byte) []byte {
if len(p.payload) == 0 {
return buf
}
buf = append(buf, annexbNALUStartCode...)
buf = append(buf, byte(p.payloadHeader>>8), byte(p.payloadHeader&0xFF))
buf = append(buf, p.payload...)
return buf
}
func (p *H265FragmentationPacket) appendUnit(unit *H265FragmentationUnitPacket) {
if len(p.payload) > 0 {
// already have end unit
return
}
p.units = append(p.units, unit)
if unit.FuHeader().E() {
for _, u := range p.units {
p.payload = append(p.payload, u.payload...)
}
}
}
//
// PACI implementation
//
@ -691,6 +774,21 @@ func (p *H265PACIPacket) Unmarshal(payload []byte) ([]byte, error) {
func (p *H265PACIPacket) isH265Packet() {}
func (p *H265PACIPacket) doPackaging(buf []byte) []byte {
buf = append(buf, annexbNALUStartCode...)
buf = append(buf, byte(p.payloadHeader>>8), byte(p.payloadHeader&0xFF))
buf = binary.BigEndian.AppendUint16(buf, p.paciHeaderFields)
if len(p.phes) > 0 {
buf = append(buf, p.phes...)
}
buf = append(buf, p.payload...)
return buf
}
//
// Temporal Scalability Control Information
//
@ -745,10 +843,11 @@ func (h H265TSCI) RES() uint8 {
type isH265Packet interface {
isH265Packet()
doPackaging([]byte) []byte
}
var (
_ isH265Packet = (*H265FragmentationUnitPacket)(nil)
_ isH265Packet = (*H265FragmentationPacket)(nil)
_ isH265Packet = (*H265PACIPacket)(nil)
_ isH265Packet = (*H265SingleNALUnitPacket)(nil)
_ isH265Packet = (*H265AggregationPacket)(nil)
@ -802,7 +901,15 @@ func (p *H265Packet) Unmarshal(payload []byte) ([]byte, error) { // nolint:cyclo
return nil, err
}
p.packet = decoded
if decoded.FuHeader().S() {
p.packet = NewH265FragmentationPacket(decoded)
} else {
if fu, ok := p.packet.(*H265FragmentationPacket); !ok {
return nil, errExpectFragmentationStartUnit
} else {
fu.appendUnit(decoded)
}
}
case payloadHeader.IsAggregationPacket():
decoded := &H265AggregationPacket{}
@ -825,7 +932,7 @@ func (p *H265Packet) Unmarshal(payload []byte) ([]byte, error) { // nolint:cyclo
p.packet = decoded
}
return nil, nil
return p.packet.doPackaging(nil), nil
}
// Packet returns the populated packet.

View File

@ -10,9 +10,7 @@ import (
)
const (
headerExtensionProfileOneByte = 0xBEDE
headerExtensionProfileTwoByte = 0x1000
headerExtensionIDReserved = 0xF
headerExtensionIDReserved = 0xF
)
// HeaderExtension represents an RTP extension header.
@ -140,7 +138,7 @@ func (e *OneByteHeaderExtension) Del(id uint8) error {
// Unmarshal parses the extension payload.
func (e *OneByteHeaderExtension) Unmarshal(buf []byte) (int, error) {
profile := binary.BigEndian.Uint16(buf[0:2])
if profile != headerExtensionProfileOneByte {
if profile != ExtensionProfileOneByte {
return 0, fmt.Errorf("%w actual(%x)", errHeaderExtensionNotFound, buf[0:2])
}
e.payload = buf
@ -283,7 +281,7 @@ func (e *TwoByteHeaderExtension) Del(id uint8) error {
// Unmarshal parses the extension payload.
func (e *TwoByteHeaderExtension) Unmarshal(buf []byte) (int, error) {
profile := binary.BigEndian.Uint16(buf[0:2])
if profile != headerExtensionProfileTwoByte {
if profile != ExtensionProfileTwoByte {
return 0, fmt.Errorf("%w actual(%x)", errHeaderExtensionNotFound, buf[0:2])
}
e.payload = buf
@ -354,7 +352,7 @@ func (e *RawExtension) Del(id uint8) error {
// Unmarshal parses the extension from the given buffer.
func (e *RawExtension) Unmarshal(buf []byte) (int, error) {
profile := binary.BigEndian.Uint16(buf[0:2])
if profile == headerExtensionProfileOneByte || profile == headerExtensionProfileTwoByte {
if profile == ExtensionProfileOneByte || profile == ExtensionProfileTwoByte {
return 0, fmt.Errorf("%w actual(%x)", errHeaderExtensionNotFound, buf[0:2])
}
e.payload = buf

View File

@ -29,6 +29,10 @@ type Header struct {
ExtensionProfile uint16
Extensions []Extension
// PaddingLength is the length of the padding in bytes. It is not part of the RTP header
// (it is sent in the last byte of RTP packet padding), but logically it belongs here.
PaddingSize byte
// Deprecated: will be removed in a future version.
PayloadOffset int
}
@ -36,36 +40,50 @@ type Header struct {
// Packet represents an RTP Packet.
type Packet struct {
Header
Payload []byte
PaddingSize byte
Payload []byte
PaddingSize byte // Deprecated: will be removed in a future version. Use Header.PaddingSize instead.
// Deprecated: will be removed in a future version.
Raw []byte
// Please do not add any new field directly to Packet struct unless you know that it is safe.
// pion internally passes Header and Payload separately, what causes bugs like
// https://github.com/pion/webrtc/issues/2403 .
}
const (
headerLength = 4
versionShift = 6
versionMask = 0x3
paddingShift = 5
paddingMask = 0x1
extensionShift = 4
extensionMask = 0x1
extensionProfileOneByte = 0xBEDE
extensionProfileTwoByte = 0x1000
extensionIDReserved = 0xF
ccMask = 0xF
markerShift = 7
markerMask = 0x1
ptMask = 0x7F
seqNumOffset = 2
seqNumLength = 2
timestampOffset = 4
timestampLength = 4
ssrcOffset = 8
ssrcLength = 4
csrcOffset = 12
csrcLength = 4
// ExtensionProfileOneByte is the RTP One Byte Header Extension Profile, defined in RFC 8285.
ExtensionProfileOneByte = 0xBEDE
// ExtensionProfileTwoByte is the RTP Two Byte Header Extension Profile, defined in RFC 8285.
ExtensionProfileTwoByte = 0x1000
// CryptexProfileOneByte is the Cryptex One Byte Header Extension Profile, defined in RFC 9335.
CryptexProfileOneByte = 0xC0DE
// CryptexProfileTwoByte is the Cryptex Two Byte Header Extension Profile, defined in RFC 9335.
CryptexProfileTwoByte = 0xC2DE
)
const (
headerLength = 4
versionShift = 6
versionMask = 0x3
paddingShift = 5
paddingMask = 0x1
extensionShift = 4
extensionMask = 0x1
extensionIDReserved = 0xF
ccMask = 0xF
markerShift = 7
markerMask = 0x1
ptMask = 0x7F
seqNumOffset = 2
seqNumLength = 2
timestampOffset = 4
timestampLength = 4
ssrcOffset = 8
ssrcLength = 4
csrcOffset = 12
csrcLength = 4
)
// String helps with debugging by printing packet information in a readable way.
@ -155,7 +173,7 @@ func (h *Header) Unmarshal(buf []byte) (n int, err error) { //nolint:gocognit,cy
return n, fmt.Errorf("size %d < %d: %w", len(buf), extensionEnd, errHeaderSizeInsufficientForExtension)
}
if h.ExtensionProfile == extensionProfileOneByte || h.ExtensionProfile == extensionProfileTwoByte {
if h.ExtensionProfile == ExtensionProfileOneByte || h.ExtensionProfile == ExtensionProfileTwoByte {
var (
extid uint8
payloadLen int
@ -168,7 +186,7 @@ func (h *Header) Unmarshal(buf []byte) (n int, err error) { //nolint:gocognit,cy
continue
}
if h.ExtensionProfile == extensionProfileOneByte {
if h.ExtensionProfile == ExtensionProfileOneByte {
extid = buf[n] >> 4
payloadLen = int(buf[n]&^0xF0 + 1)
n++
@ -219,11 +237,12 @@ func (p *Packet) Unmarshal(buf []byte) error {
if end <= n {
return errTooSmall
}
p.PaddingSize = buf[end-1]
end -= int(p.PaddingSize)
p.Header.PaddingSize = buf[end-1]
end -= int(p.Header.PaddingSize)
} else {
p.PaddingSize = 0
p.Header.PaddingSize = 0
}
p.PaddingSize = p.Header.PaddingSize
if end < n {
return errTooSmall
}
@ -302,14 +321,14 @@ func (h Header) MarshalTo(buf []byte) (n int, err error) { //nolint:cyclop
switch h.ExtensionProfile {
// RFC 8285 RTP One Byte Header Extension
case extensionProfileOneByte:
case ExtensionProfileOneByte:
for _, extension := range h.Extensions {
buf[n] = extension.id<<4 | (uint8(len(extension.payload)) - 1) // nolint: gosec // G115
n++
n += copy(buf[n:], extension.payload)
}
// RFC 8285 RTP Two Byte Header Extension
case extensionProfileTwoByte:
case ExtensionProfileTwoByte:
for _, extension := range h.Extensions {
buf[n] = extension.id
n++
@ -353,12 +372,12 @@ func (h Header) MarshalSize() int {
switch h.ExtensionProfile {
// RFC 8285 RTP One Byte Header Extension
case extensionProfileOneByte:
case ExtensionProfileOneByte:
for _, extension := range h.Extensions {
extSize += 1 + len(extension.payload)
}
// RFC 8285 RTP Two Byte Header Extension
case extensionProfileTwoByte:
case ExtensionProfileTwoByte:
for _, extension := range h.Extensions {
extSize += 2 + len(extension.payload)
}
@ -378,7 +397,7 @@ func (h *Header) SetExtension(id uint8, payload []byte) error { //nolint:gocogni
if h.Extension { // nolint: nestif
switch h.ExtensionProfile {
// RFC 8285 RTP One Byte Header Extension
case extensionProfileOneByte:
case ExtensionProfileOneByte:
if id < 1 || id > 14 {
return fmt.Errorf("%w actual(%d)", errRFC8285OneByteHeaderIDRange, id)
}
@ -386,7 +405,7 @@ func (h *Header) SetExtension(id uint8, payload []byte) error { //nolint:gocogni
return fmt.Errorf("%w actual(%d)", errRFC8285OneByteHeaderSize, len(payload))
}
// RFC 8285 RTP Two Byte Header Extension
case extensionProfileTwoByte:
case ExtensionProfileTwoByte:
if id < 1 {
return fmt.Errorf("%w actual(%d)", errRFC8285TwoByteHeaderIDRange, id)
}
@ -418,9 +437,9 @@ func (h *Header) SetExtension(id uint8, payload []byte) error { //nolint:gocogni
switch payloadLen := len(payload); {
case payloadLen <= 16:
h.ExtensionProfile = extensionProfileOneByte
h.ExtensionProfile = ExtensionProfileOneByte
case payloadLen > 16 && payloadLen < 256:
h.ExtensionProfile = extensionProfileTwoByte
h.ExtensionProfile = ExtensionProfileTwoByte
}
h.Extensions = append(h.Extensions, Extension{id: id, payload: payload})
@ -490,7 +509,7 @@ func (p Packet) Marshal() (buf []byte, err error) {
// MarshalTo serializes the packet and writes to the buffer.
func (p *Packet) MarshalTo(buf []byte) (n int, err error) {
if p.Header.Padding && p.PaddingSize == 0 {
if p.Header.Padding && p.paddingSize() == 0 {
return 0, errInvalidRTPPadding
}
@ -499,23 +518,28 @@ func (p *Packet) MarshalTo(buf []byte) (n int, err error) {
return 0, err
}
return marshalPayloadAndPaddingTo(buf, n, &p.Header, p.Payload, p.paddingSize())
}
func marshalPayloadAndPaddingTo(buf []byte, offset int, header *Header, payload []byte, paddingSize byte,
) (n int, err error) {
// Make sure the buffer is large enough to hold the packet.
if n+len(p.Payload)+int(p.PaddingSize) > len(buf) {
if offset+len(payload)+int(paddingSize) > len(buf) {
return 0, io.ErrShortBuffer
}
m := copy(buf[n:], p.Payload)
m := copy(buf[offset:], payload)
if p.Header.Padding {
buf[n+m+int(p.PaddingSize-1)] = p.PaddingSize
if header.Padding {
buf[offset+m+int(paddingSize-1)] = paddingSize
}
return n + m + int(p.PaddingSize), nil
return offset + m + int(paddingSize), nil
}
// MarshalSize returns the size of the packet once marshaled.
func (p Packet) MarshalSize() int {
return p.Header.MarshalSize() + len(p.Payload) + int(p.PaddingSize)
return p.Header.MarshalSize() + len(p.Payload) + int(p.paddingSize())
}
// Clone returns a deep copy of p.
@ -552,3 +576,45 @@ func (h Header) Clone() Header {
return clone
}
func (p *Packet) paddingSize() byte {
if p.Header.PaddingSize > 0 {
return p.Header.PaddingSize
}
return p.PaddingSize
}
// MarshalPacketTo serializes the header and payload into bytes.
// Parts of pion code passes RTP header and payload separately, so this function
// is provided to help with that.
//
// Deprecated: this function is a temporary workaround and will be removed in pion/webrtc v5.
func MarshalPacketTo(buf []byte, header *Header, payload []byte) (int, error) {
n, err := header.MarshalTo(buf)
if err != nil {
return 0, err
}
return marshalPayloadAndPaddingTo(buf, n, header, payload, header.PaddingSize)
}
// PacketMarshalSize returns the size of the header and payload once marshaled.
// Parts of pion code passes RTP header and payload separately, so this function
// is provided to help with that.
//
// Deprecated: this function is a temporary workaround and will be removed in pion/webrtc v5.
func PacketMarshalSize(header *Header, payload []byte) int {
return header.MarshalSize() + len(payload) + int(header.PaddingSize)
}
// HeaderAndPacketMarshalSize returns the size of the header and full packet once marshaled.
// Parts of pion code passes RTP header and payload separately, so this function
// is provided to help with that.
//
// Deprecated: this function is a temporary workaround and will be removed in pion/webrtc v5.
func HeaderAndPacketMarshalSize(header *Header, payload []byte) (headerSize int, packetSize int) {
headerSize = header.MarshalSize()
return headerSize, headerSize + len(payload) + int(header.PaddingSize)
}

View File

@ -165,9 +165,6 @@ func (p *packetizer) GeneratePadding(samples uint32) []*Packet {
packets := make([]*Packet, samples)
for i := 0; i < int(samples); i++ {
pp := make([]byte, 255)
pp[254] = 255
packets[i] = &Packet{
Header: Header{
Version: 2,
@ -179,8 +176,8 @@ func (p *packetizer) GeneratePadding(samples uint32) []*Packet {
Timestamp: p.Timestamp, // Use latest timestamp
SSRC: p.SSRC,
CSRC: []uint32{},
PaddingSize: 255,
},
Payload: pp,
}
}

View File

@ -19,12 +19,16 @@ linters-settings:
recommendations:
- errors
forbidigo:
analyze-types: true
forbid:
- ^fmt.Print(f|ln)?$
- ^log.(Panic|Fatal|Print)(f|ln)?$
- ^os.Exit$
- ^panic$
- ^print(ln)?$
- p: ^testing.T.(Error|Errorf|Fatal|Fatalf|Fail|FailNow)$
pkg: ^testing$
msg: "use testify/assert instead"
varnamelen:
max-distance: 12
min-name-length: 2
@ -127,9 +131,12 @@ issues:
exclude-dirs-use-default: false
exclude-rules:
# Allow complex tests and examples, better to be self contained
- path: (examples|main\.go|_test\.go)
- path: (examples|main\.go)
linters:
- gocognit
- forbidigo
- path: _test\.go
linters:
- gocognit
# Allow forbidden identifiers in CLI commands

View File

@ -7,10 +7,10 @@
<p align="center">
<a href="https://pion.ly"><img src="https://img.shields.io/badge/pion-sdp-gray.svg?longCache=true&colorB=brightgreen" alt="Pion SDP"></a>
<a href="https://sourcegraph.com/github.com/pion/sdp?badge"><img src="https://sourcegraph.com/github.com/pion/sdp/-/badge.svg" alt="Sourcegraph Widget"></a>
<a href="https://pion.ly/slack"><img src="https://img.shields.io/badge/join-us%20on%20slack-gray.svg?longCache=true&logo=slack&colorB=brightgreen" alt="Slack Widget"></a>
<a href="https://discord.gg/PngbdqpFbt"><img src="https://img.shields.io/badge/join-us%20on%20discord-gray.svg?longCache=true&logo=discord&colorB=brightblue" alt="join us on Discord"></a> <a href="https://bsky.app/profile/pion.ly"><img src="https://img.shields.io/badge/follow-us%20on%20bluesky-gray.svg?longCache=true&logo=bluesky&colorB=brightblue" alt="Follow us on Bluesky"></a>
<br>
<img alt="GitHub Workflow Status" src="https://img.shields.io/github/actions/workflow/status/pion/sdp/test.yaml">
<a href="https://pkg.go.dev/github.com/pion/sdp/v2"><img src="https://pkg.go.dev/badge/github.com/pion/sdp/v2.svg" alt="Go Reference"></a>
<a href="https://pkg.go.dev/github.com/pion/sdp/v3"><img src="https://pkg.go.dev/badge/github.com/pion/sdp/v3.svg" alt="Go Reference"></a>
<a href="https://codecov.io/gh/pion/sdp"><img src="https://codecov.io/gh/pion/sdp/branch/master/graph/badge.svg" alt="Coverage Status"></a>
<a href="https://goreportcard.com/report/github.com/pion/sdp"><img src="https://goreportcard.com/badge/github.com/pion/sdp" alt="Go Report Card"></a>
<a href="LICENSE"><img src="https://img.shields.io/badge/License-MIT-yellow.svg" alt="License: MIT"></a>
@ -21,9 +21,9 @@
The library is used as a part of our WebRTC implementation. Please refer to that [roadmap](https://github.com/pion/webrtc/issues/9) to track our major milestones.
### Community
Pion has an active community on the [Slack](https://pion.ly/slack).
Pion has an active community on the [Discord](https://discord.gg/PngbdqpFbt).
Follow the [Pion Twitter](https://twitter.com/_pion) for project updates and important WebRTC news.
Follow the [Pion Bluesky](https://bsky.app/profile/pion.ly) or [Pion Twitter](https://twitter.com/_pion) for project updates and important WebRTC news.
We are always looking to support **your projects**. Please reach out if you have something to build!
If you need commercial support or don't want to use public methods you can contact us at [team@pion.ly](mailto:team@pion.ly)
@ -32,4 +32,4 @@ If you need commercial support or don't want to use public methods you can conta
Check out the [contributing wiki](https://github.com/pion/webrtc/wiki/Contributing) to join the group of amazing people making this project possible
### License
MIT License - see [LICENSE](LICENSE) for full text
MIT License - see [LICENSE](LICENSE) for full text

View File

@ -23,6 +23,7 @@ const (
AttrKeyConnectionSetup = "setup"
AttrKeyMID = "mid"
AttrKeyICELite = "ice-lite"
AttrKeyICEOptions = "ice-options"
AttrKeyRTCPMux = "rtcp-mux"
AttrKeyRTCPRsize = "rtcp-rsize"
AttrKeyInactive = "inactive"
@ -31,6 +32,7 @@ const (
AttrKeySendRecv = "sendrecv"
AttrKeyExtMap = "extmap"
AttrKeyExtMapAllowMixed = "extmap-allow-mixed"
AttrKeyCryptex = "cryptex"
)
// Constants for semantic tokens used in JSEP.
@ -38,7 +40,9 @@ const (
SemanticTokenLipSynchronization = "LS"
SemanticTokenFlowIdentification = "FID"
SemanticTokenForwardErrorCorrection = "FEC"
SemanticTokenWebRTCMediaStreams = "WMS"
// https://datatracker.ietf.org/doc/html/rfc5956#section-4.1
SemanticTokenForwardErrorCorrectionFramework = "FEC-FR"
SemanticTokenWebRTCMediaStreams = "WMS"
)
// Constants for extmap key.
@ -113,6 +117,12 @@ func (s *SessionDescription) WithValueAttribute(key, value string) *SessionDescr
return s
}
// WithICETrickleAdvertised advertises ICE trickle support in the session description.
// See https://datatracker.ietf.org/doc/html/rfc9429#section-5.2.1
func (s *SessionDescription) WithICETrickleAdvertised() *SessionDescription {
return s.WithValueAttribute(AttrKeyICEOptions, "trickle")
}
// WithFingerprint adds a fingerprint to the session description.
func (s *SessionDescription) WithFingerprint(algorithm, value string) *SessionDescription {
return s.WithValueAttribute("fingerprint", algorithm+" "+value)

View File

@ -986,10 +986,10 @@ func timeShorthand(b byte) int64 {
func parsePort(value string) (int, error) {
port, err := strconv.Atoi(value)
if err != nil {
return 0, fmt.Errorf("%w `%v`", errSDPInvalidPortValue, port)
return 0, fmt.Errorf("%w `%v`", errSDPInvalidPortValue, value)
}
if port < 0 || port > 65536 {
if port < 0 || port > 65535 {
return 0, fmt.Errorf("%w -- out of range `%v`", errSDPInvalidPortValue, port)
}

View File

@ -19,16 +19,12 @@ const (
)
var (
errExtractCodecRtpmap = errors.New("could not extract codec from rtpmap")
errExtractCodecFmtp = errors.New("could not extract codec from fmtp")
errExtractCodecRtcpFb = errors.New("could not extract codec from rtcp-fb")
errMultipleName = errors.New("codec has multiple names defined")
errMultipleClockRate = errors.New("codec has multiple clock rates")
errMultipleEncodingParameters = errors.New("codec has multiple encoding parameters")
errMultipleFmtp = errors.New("codec has multiple fmtp values")
errPayloadTypeNotFound = errors.New("payload type not found")
errCodecNotFound = errors.New("codec not found")
errSyntaxError = errors.New("SyntaxError")
errExtractCodecRtpmap = errors.New("could not extract codec from rtpmap")
errExtractCodecFmtp = errors.New("could not extract codec from fmtp")
errExtractCodecRtcpFb = errors.New("could not extract codec from rtcp-fb")
errPayloadTypeNotFound = errors.New("payload type not found")
errCodecNotFound = errors.New("codec not found")
errSyntaxError = errors.New("SyntaxError")
)
// ConnectionRole indicates which of the end points should initiate the connection establishment.
@ -207,49 +203,30 @@ func parseRtcpFb(rtcpFb string) (codec Codec, isWildcard bool, err error) {
return codec, isWildcard, nil
}
func mergeCodecs(codec Codec, codecs map[uint8]Codec) error { // nolint: cyclop
func mergeCodecs(codec Codec, codecs map[uint8]Codec) {
savedCodec := codecs[codec.PayloadType]
savedCodec.PayloadType = codec.PayloadType
if codec.Name != "" {
if savedCodec.Name != "" && savedCodec.Name != codec.Name {
return errMultipleName
}
if savedCodec.PayloadType == 0 {
savedCodec.PayloadType = codec.PayloadType
}
if savedCodec.Name == "" {
savedCodec.Name = codec.Name
}
if codec.ClockRate != 0 {
if savedCodec.ClockRate != 0 && savedCodec.ClockRate != codec.ClockRate {
return errMultipleClockRate
}
if savedCodec.ClockRate == 0 {
savedCodec.ClockRate = codec.ClockRate
}
if codec.EncodingParameters != "" {
if savedCodec.EncodingParameters != "" && savedCodec.EncodingParameters != codec.EncodingParameters {
return errMultipleEncodingParameters
}
if savedCodec.EncodingParameters == "" {
savedCodec.EncodingParameters = codec.EncodingParameters
}
if codec.Fmtp != "" {
if savedCodec.Fmtp != "" && savedCodec.Fmtp != codec.Fmtp {
return errMultipleFmtp
}
if savedCodec.Fmtp == "" {
savedCodec.Fmtp = codec.Fmtp
}
savedCodec.RTCPFeedback = append(savedCodec.RTCPFeedback, codec.RTCPFeedback...)
codecs[savedCodec.PayloadType] = savedCodec
return nil
codecs[savedCodec.PayloadType] = savedCodec
}
func (s *SessionDescription) buildCodecMap() (map[uint8]Codec, error) { //nolint:cyclop, gocognit
func (s *SessionDescription) buildCodecMap() map[uint8]Codec { //nolint:cyclop
codecs := map[uint8]Codec{
// static codecs that do not require a rtpmap
0: {
@ -272,16 +249,12 @@ func (s *SessionDescription) buildCodecMap() (map[uint8]Codec, error) { //nolint
case strings.HasPrefix(attr, "rtpmap:"):
codec, err := parseRtpmap(attr)
if err == nil {
if err = mergeCodecs(codec, codecs); err != nil {
return nil, err
}
mergeCodecs(codec, codecs)
}
case strings.HasPrefix(attr, "fmtp:"):
codec, err := parseFmtp(attr)
if err == nil {
if err = mergeCodecs(codec, codecs); err != nil {
return nil, err
}
mergeCodecs(codec, codecs)
}
case strings.HasPrefix(attr, "rtcp-fb:"):
codec, isWildcard, err := parseRtcpFb(attr)
@ -290,9 +263,7 @@ func (s *SessionDescription) buildCodecMap() (map[uint8]Codec, error) { //nolint
case isWildcard:
wildcardRTCPFeedback = append(wildcardRTCPFeedback, codec.RTCPFeedback...)
default:
if err = mergeCodecs(codec, codecs); err != nil {
return nil, err
}
mergeCodecs(codec, codecs)
}
}
}
@ -306,7 +277,7 @@ func (s *SessionDescription) buildCodecMap() (map[uint8]Codec, error) { //nolint
codecs[i] = codec
}
return codecs, nil
return codecs
}
func equivalentFmtp(want, got string) bool {
@ -350,10 +321,7 @@ func codecsMatch(wanted, got Codec) bool {
// GetCodecForPayloadType scans the SessionDescription for the given payload type and returns the codec.
func (s *SessionDescription) GetCodecForPayloadType(payloadType uint8) (Codec, error) {
codecs, err := s.buildCodecMap()
if err != nil {
return Codec{}, err
}
codecs := s.buildCodecMap()
codec, ok := codecs[payloadType]
if ok {
@ -366,10 +334,7 @@ func (s *SessionDescription) GetCodecForPayloadType(payloadType uint8) (Codec, e
// GetPayloadTypeForCodec scans the SessionDescription for a codec that matches the provided codec
// as closely as possible and returns its payload type.
func (s *SessionDescription) GetPayloadTypeForCodec(wanted Codec) (uint8, error) {
codecs, err := s.buildCodecMap()
if err != nil {
return 0, err
}
codecs := s.buildCodecMap()
for payloadType, codec := range codecs {
if codecsMatch(wanted, codec) {

View File

@ -19,23 +19,42 @@ linters-settings:
recommendations:
- errors
forbidigo:
analyze-types: true
forbid:
- ^fmt.Print(f|ln)?$
- ^log.(Panic|Fatal|Print)(f|ln)?$
- ^os.Exit$
- ^panic$
- ^print(ln)?$
- p: ^testing.T.(Error|Errorf|Fatal|Fatalf|Fail|FailNow)$
pkg: ^testing$
msg: "use testify/assert instead"
varnamelen:
max-distance: 12
min-name-length: 2
ignore-type-assert-ok: true
ignore-map-index-ok: true
ignore-chan-recv-ok: true
ignore-decls:
- i int
- n int
- w io.Writer
- r io.Reader
- b []byte
linters:
enable:
- asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers
- bidichk # Checks for dangerous unicode character sequences
- bodyclose # checks whether HTTP response body is closed successfully
- containedctx # containedctx is a linter that detects struct contained context.Context field
- contextcheck # check the function whether use a non-inherited context
- cyclop # checks function and package cyclomatic complexity
- decorder # check declaration order and count of types, constants, variables and functions
- dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f())
- dupl # Tool for code clone detection
- durationcheck # check for two durations multiplied together
- err113 # Golang linter to check the errors handling expressions
- errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases
- errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted.
- errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`.
@ -46,18 +65,17 @@ linters:
- forcetypeassert # finds forced type assertions
- gci # Gci control golang package import order and make it always deterministic.
- gochecknoglobals # Checks that no globals are present in Go code
- gochecknoinits # Checks that no init functions are present in Go code
- gocognit # Computes and checks the cognitive complexity of functions
- goconst # Finds repeated strings that could be replaced by a constant
- gocritic # The most opinionated Go source code linter
- gocyclo # Computes and checks the cyclomatic complexity of functions
- godot # Check if comments end in a period
- godox # Tool for detection of FIXME, TODO and other comment keywords
- err113 # Golang linter to check the errors handling expressions
- gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification
- gofumpt # Gofumpt checks whether code was gofumpt-ed.
- goheader # Checks is file header matches to pattern
- goimports # Goimports does everything that gofmt does. Additionally it checks unused imports
- gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod.
- gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations.
- goprintffuncname # Checks that printf-like functions are named with `f` at the end
- gosec # Inspects source code for security problems
- gosimple # Linter for Go source code that specializes in simplifying a code
@ -65,9 +83,15 @@ linters:
- grouper # An analyzer to analyze expression groups.
- importas # Enforces consistent import aliases
- ineffassign # Detects when assignments to existing variables are not used
- lll # Reports long lines
- maintidx # maintidx measures the maintainability index of each function.
- makezero # Finds slice declarations with non-zero initial length
- misspell # Finds commonly misspelled English words in comments
- nakedret # Finds naked returns in functions greater than a specified function length
- nestif # Reports deeply nested if statements
- nilerr # Finds the code that returns nil even if it checks that the error is not nil.
- nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value.
- nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity
- noctx # noctx finds sending http request without context.Context
- predeclared # find code that shadows one of Go's predeclared identifiers
- revive # golint replacement, finds style mistakes
@ -75,28 +99,22 @@ linters:
- stylecheck # Stylecheck is a replacement for golint
- tagliatelle # Checks the struct tags.
- tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17
- tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes
- thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers
- typecheck # Like the front-end of a Go compiler, parses and type-checks Go code
- unconvert # Remove unnecessary type conversions
- unparam # Reports unused function parameters
- unused # Checks Go code for unused constants, variables, functions and types
- varnamelen # checks that the length of a variable's name matches its scope
- wastedassign # wastedassign finds wasted assignment statements
- whitespace # Tool for detection of leading and trailing whitespace
disable:
- depguard # Go linter that checks if package imports are in a list of acceptable packages
- containedctx # containedctx is a linter that detects struct contained context.Context field
- cyclop # checks function and package cyclomatic complexity
- funlen # Tool for detection of long functions
- gocyclo # Computes and checks the cyclomatic complexity of functions
- godot # Check if comments end in a period
- gomnd # An analyzer to detect magic numbers.
- gochecknoinits # Checks that no init functions are present in Go code
- gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations.
- interfacebloat # A linter that checks length of interface.
- ireturn # Accept Interfaces, Return Concrete Types
- lll # Reports long lines
- maintidx # maintidx measures the maintainability index of each function.
- makezero # Finds slice declarations with non-zero initial length
- nakedret # Finds naked returns in functions greater than a specified function length
- nestif # Reports deeply nested if statements
- nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity
- mnd # An analyzer to detect magic numbers
- nolintlint # Reports ill-formed or insufficient nolint directives
- paralleltest # paralleltest detects missing usage of t.Parallel() method in your Go test
- prealloc # Finds slice declarations that could potentially be preallocated
@ -104,8 +122,7 @@ linters:
- rowserrcheck # checks whether Err of rows is checked successfully
- sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed.
- testpackage # linter that makes you use a separate _test package
- thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers
- varnamelen # checks that the length of a variable's name matches its scope
- tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes
- wrapcheck # Checks that errors returned from external packages are wrapped
- wsl # Whitespace Linter - Forces you to use empty lines!
@ -114,9 +131,12 @@ issues:
exclude-dirs-use-default: false
exclude-rules:
# Allow complex tests and examples, better to be self contained
- path: (examples|main\.go|_test\.go)
- path: (examples|main\.go)
linters:
- gocognit
- forbidigo
- path: _test\.go
linters:
- gocognit
# Allow forbidden identifiers in CLI commands

View File

@ -7,7 +7,7 @@
<p align="center">
<a href="https://pion.ly"><img src="https://img.shields.io/badge/pion-srtp-gray.svg?longCache=true&colorB=brightgreen" alt="Pion SRTP"></a>
<a href="https://sourcegraph.com/github.com/pion/srtp?badge"><img src="https://sourcegraph.com/github.com/pion/srtp/-/badge.svg" alt="Sourcegraph Widget"></a>
<a href="https://pion.ly/slack"><img src="https://img.shields.io/badge/join-us%20on%20slack-gray.svg?longCache=true&logo=slack&colorB=brightgreen" alt="Slack Widget"></a>
<a href="https://discord.gg/PngbdqpFbt"><img src="https://img.shields.io/badge/join-us%20on%20discord-gray.svg?longCache=true&logo=discord&colorB=brightblue" alt="join us on Discord"></a> <a href="https://bsky.app/profile/pion.ly"><img src="https://img.shields.io/badge/follow-us%20on%20bluesky-gray.svg?longCache=true&logo=bluesky&colorB=brightblue" alt="Follow us on Bluesky"></a>
<br>
<img alt="GitHub Workflow Status" src="https://img.shields.io/github/actions/workflow/status/pion/srtp/test.yaml">
<a href="https://pkg.go.dev/github.com/pion/srtp"><img src="https://pkg.go.dev/badge/github.com/pion/srtp.svg" alt="Go Reference"></a>
@ -21,9 +21,9 @@
The library is used as a part of our WebRTC implementation. Please refer to that [roadmap](https://github.com/pion/webrtc/issues/9) to track our major milestones.
### Community
Pion has an active community on the [Slack](https://pion.ly/slack).
Pion has an active community on the [Discord](https://discord.gg/PngbdqpFbt).
Follow the [Pion Twitter](https://twitter.com/_pion) for project updates and important WebRTC news.
Follow the [Pion Bluesky](https://bsky.app/profile/pion.ly) or [Pion Twitter](https://twitter.com/_pion) for project updates and important WebRTC news.
We are always looking to support **your projects**. Please reach out if you have something to build!
If you need commercial support or don't want to use public methods you can contact us at [team@pion.ly](mailto:team@pion.ly)

View File

@ -24,11 +24,9 @@ const (
seqNumMedian = 1 << 15
seqNumMax = 1 << 16
srtcpIndexSize = 4
)
// Encrypt/Decrypt state for a single SRTP SSRC
// Encrypt/Decrypt state for a single SRTP SSRC.
type srtpSSRCState struct {
ssrc uint32
rolloverHasProcessed bool
@ -36,13 +34,30 @@ type srtpSSRCState struct {
replayDetector replaydetector.ReplayDetector
}
// Encrypt/Decrypt state for a single SRTCP SSRC
// Encrypt/Decrypt state for a single SRTCP SSRC.
type srtcpSSRCState struct {
srtcpIndex uint32
ssrc uint32
replayDetector replaydetector.ReplayDetector
}
// RCCMode is the mode of Roll-over Counter Carrying Transform from RFC 4771.
type RCCMode int
const (
// RCCModeNone is the default mode.
RCCModeNone RCCMode = iota
// RCCMode1 is RCCm1 mode from RFC 4771. In this mode ROC and truncated auth tag is sent every R-th packet,
// and no auth tag in other ones. This mode is not supported by pion/srtp.
RCCMode1
// RCCMode2 is RCCm2 mode from RFC 4771. In this mode ROC and truncated auth tag is sent every R-th packet,
// and full auth tag in other ones. This mode is supported for AES-CM and NULL profiles only.
RCCMode2
// RCCMode3 is RCCm3 mode from RFC 4771. In this mode ROC is sent every R-th packet (without truncated auth tag),
// and no auth tag in other ones. This mode is supported for AES-GCM profiles only.
RCCMode3
)
// Context represents a SRTP cryptographic context.
// Context can only be used for one-way operations.
// it must either used ONLY for encryption or ONLY for decryption.
@ -60,11 +75,18 @@ type Context struct {
profile ProtectionProfile
sendMKI []byte // Master Key Identifier used for encrypting RTP/RTCP packets. Set to nil if MKI is not enabled.
mkis map[string]srtpCipher // Master Key Identifier to cipher mapping. Used for decrypting packets. Empty if MKI is not enabled.
// Master Key Identifier used for encrypting RTP/RTCP packets. Set to nil if MKI is not enabled.
sendMKI []byte
// Master Key Identifier to cipher mapping. Used for decrypting packets. Empty if MKI is not enabled.
mkis map[string]srtpCipher
encryptSRTP bool
encryptSRTCP bool
rccMode RCCMode
rocTransmitRate uint16
authTagRTPLen *int
}
// CreateContext creates a new SRTP Context.
@ -74,7 +96,11 @@ type Context struct {
// Following example create SRTP Context with replay protection with window size of 256.
//
// decCtx, err := srtp.CreateContext(key, salt, profile, srtp.SRTPReplayProtection(256))
func CreateContext(masterKey, masterSalt []byte, profile ProtectionProfile, opts ...ContextOption) (c *Context, err error) {
func CreateContext(
masterKey, masterSalt []byte,
profile ProtectionProfile,
opts ...ContextOption,
) (c *Context, err error) {
c = &Context{
srtpSSRCStates: map[uint32]*srtpSSRCState{},
srtcpSSRCStates: map[uint32]*srtcpSSRCState{},
@ -96,6 +122,21 @@ func CreateContext(masterKey, masterSalt []byte, profile ProtectionProfile, opts
}
}
if err = c.checkRCCMode(); err != nil {
return nil, err
}
if c.authTagRTPLen != nil {
var authKeyLen int
authKeyLen, err = c.profile.AuthKeyLen()
if err != nil {
return nil, err
}
if *c.authTagRTPLen > authKeyLen {
return nil, errTooLongSRTPAuthTag
}
}
c.cipher, err = c.createCipher(c.sendMKI, masterKey, masterSalt, c.encryptSRTP, c.encryptSRTCP)
if err != nil {
return nil, err
@ -107,7 +148,8 @@ func CreateContext(masterKey, masterSalt []byte, profile ProtectionProfile, opts
return c, nil
}
// AddCipherForMKI adds new MKI with associated masker key and salt. Context must be created with MasterKeyIndicator option
// AddCipherForMKI adds new MKI with associated masker key and salt.
// Context must be created with MasterKeyIndicator option
// to enable MKI support. MKI must be unique and have the same length as the one used for creating Context.
// Operation is not thread-safe, you need to provide synchronization with decrypting packets.
func (c *Context) AddCipherForMKI(mki, masterKey, masterSalt []byte) error {
@ -126,6 +168,7 @@ func (c *Context) AddCipherForMKI(mki, masterKey, masterSalt []byte) error {
return err
}
c.mkis[string(mki)] = cipher
return nil
}
@ -141,18 +184,26 @@ func (c *Context) createCipher(mki, masterKey, masterSalt []byte, encryptSRTP, e
}
if masterKeyLen := len(masterKey); masterKeyLen != keyLen {
return nil, fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterKey, masterKey, keyLen)
return nil, fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterKey, keyLen, masterKey)
} else if masterSaltLen := len(masterSalt); masterSaltLen != saltLen {
return nil, fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterSalt, saltLen, masterSaltLen)
}
profileWithArgs := protectionProfileWithArgs{
ProtectionProfile: c.profile,
authTagRTPLen: c.authTagRTPLen,
}
switch c.profile {
case ProtectionProfileAeadAes128Gcm, ProtectionProfileAeadAes256Gcm:
return newSrtpCipherAeadAesGcm(c.profile, masterKey, masterSalt, mki, encryptSRTP, encryptSRTCP)
case ProtectionProfileAes128CmHmacSha1_32, ProtectionProfileAes128CmHmacSha1_80, ProtectionProfileAes256CmHmacSha1_32, ProtectionProfileAes256CmHmacSha1_80:
return newSrtpCipherAesCmHmacSha1(c.profile, masterKey, masterSalt, mki, encryptSRTP, encryptSRTCP)
return newSrtpCipherAeadAesGcm(profileWithArgs, masterKey, masterSalt, mki, encryptSRTP, encryptSRTCP)
case ProtectionProfileAes128CmHmacSha1_32,
ProtectionProfileAes128CmHmacSha1_80,
ProtectionProfileAes256CmHmacSha1_32,
ProtectionProfileAes256CmHmacSha1_80:
return newSrtpCipherAesCmHmacSha1(profileWithArgs, masterKey, masterSalt, mki, encryptSRTP, encryptSRTCP)
case ProtectionProfileNullHmacSha1_32, ProtectionProfileNullHmacSha1_80:
return newSrtpCipherAesCmHmacSha1(c.profile, masterKey, masterSalt, mki, false, false)
return newSrtpCipherAesCmHmacSha1(profileWithArgs, masterKey, masterSalt, mki, false, false)
default:
return nil, fmt.Errorf("%w: %#v", errNoSuchSRTPProfile, c.profile)
}
@ -168,6 +219,7 @@ func (c *Context) RemoveMKI(mki []byte) error {
return errMKIAlreadyInUse
}
delete(c.mkis, string(mki))
return nil
}
@ -180,19 +232,20 @@ func (c *Context) SetSendMKI(mki []byte) error {
}
c.sendMKI = mki
c.cipher = cipher
return nil
}
// https://tools.ietf.org/html/rfc3550#appendix-A.1
func (s *srtpSSRCState) nextRolloverCount(sequenceNumber uint16) (roc uint32, diff int32, overflow bool) {
func (s *srtpSSRCState) nextRolloverCount(sequenceNumber uint16) (roc uint32, diff int64, overflow bool) {
seq := int32(sequenceNumber)
localRoc := uint32(s.index >> 16)
localSeq := int32(s.index & (seqNumMax - 1))
localRoc := uint32(s.index >> 16) //nolint:gosec // G115
localSeq := int32(s.index & (seqNumMax - 1)) //nolint:gosec // G115
guessRoc := localRoc
var difference int32
if s.rolloverHasProcessed {
if s.rolloverHasProcessed { //nolint:nestif
// When localROC is equal to 0, and entering seq-localSeq > seqNumMedian
// judgment, it will cause guessRoc calculation error
if s.index > seqNumMedian {
@ -219,16 +272,20 @@ func (s *srtpSSRCState) nextRolloverCount(sequenceNumber uint16) (roc uint32, di
}
}
return guessRoc, difference, (guessRoc == 0 && localRoc == maxROC)
return guessRoc, int64(difference), (guessRoc == 0 && localRoc == maxROC)
}
func (s *srtpSSRCState) updateRolloverCount(sequenceNumber uint16, difference int32) {
if !s.rolloverHasProcessed {
func (s *srtpSSRCState) updateRolloverCount(sequenceNumber uint16, difference int64, hasRemoteRoc bool,
remoteRoc uint32,
) {
switch {
case hasRemoteRoc:
s.index = (uint64(remoteRoc) << 16) | uint64(sequenceNumber)
s.rolloverHasProcessed = true
case !s.rolloverHasProcessed:
s.index |= uint64(sequenceNumber)
s.rolloverHasProcessed = true
return
}
if difference > 0 {
case difference > 0:
s.index += uint64(difference)
}
}
@ -244,6 +301,7 @@ func (c *Context) getSRTPSSRCState(ssrc uint32) *srtpSSRCState {
replayDetector: c.newSRTPReplayDetector(),
}
c.srtpSSRCStates[ssrc] = s
return s
}
@ -258,6 +316,7 @@ func (c *Context) getSRTCPSSRCState(ssrc uint32) *srtcpSSRCState {
replayDetector: c.newSRTCPReplayDetector(),
}
c.srtcpSSRCStates[ssrc] = s
return s
}
@ -267,7 +326,8 @@ func (c *Context) ROC(ssrc uint32) (uint32, bool) {
if !ok {
return 0, false
}
return uint32(s.index >> 16), true
return uint32(s.index >> 16), true //nolint:gosec // G115
}
// SetROC sets SRTP rollover counter value of specified SSRC.
@ -283,6 +343,7 @@ func (c *Context) Index(ssrc uint32) (uint32, bool) {
if !ok {
return 0, false
}
return s.srtcpIndex, true
}
@ -291,3 +352,49 @@ func (c *Context) SetIndex(ssrc uint32, index uint32) {
s := c.getSRTCPSSRCState(ssrc)
s.srtcpIndex = index % (maxSRTCPIndex + 1)
}
//nolint:cyclop
func (c *Context) checkRCCMode() error {
if c.rccMode == RCCModeNone {
return nil
}
if c.rocTransmitRate == 0 {
return errZeroRocTransmitRate
}
switch c.profile {
case ProtectionProfileAeadAes128Gcm, ProtectionProfileAeadAes256Gcm:
// AEAD profiles support RCCMode3 only
if c.rccMode != RCCMode3 {
return errUnsupportedRccMode
}
case ProtectionProfileAes128CmHmacSha1_32,
ProtectionProfileAes256CmHmacSha1_32,
ProtectionProfileNullHmacSha1_32:
if c.authTagRTPLen == nil {
// ROC completely replaces auth tag for _32 profiles. If you really want to use 4-byte
// SRTP auth tag with RCC, use SRTPAuthenticationTagLength(4) option.
return errTooShortSRTPAuthTag
}
fallthrough // Checks below are common for _32 and _80 profiles.
case ProtectionProfileAes128CmHmacSha1_80,
ProtectionProfileAes256CmHmacSha1_80,
ProtectionProfileNullHmacSha1_80:
// AES-CM and NULL profiles support RCCMode2 only
if c.rccMode != RCCMode2 {
return errUnsupportedRccMode
}
if c.authTagRTPLen != nil && *c.authTagRTPLen < 4 {
return errTooShortSRTPAuthTag
}
default:
return errUnsupportedRccMode
}
return nil
}

View File

@ -55,5 +55,6 @@ func xorBytesCTR(block cipher.Block, iv []byte, dst, src []byte) error {
}
i += n
}
return nil
}

View File

@ -9,9 +9,9 @@ import (
)
var (
// ErrFailedToVerifyAuthTag is returned when decryption fails due to invalid authentication tag
// ErrFailedToVerifyAuthTag is returned when decryption fails due to invalid authentication tag.
ErrFailedToVerifyAuthTag = errors.New("failed to verify auth tag")
// ErrMKINotFound is returned when decryption fails due to unknown MKI value in packet
// ErrMKINotFound is returned when decryption fails due to unknown MKI value in packet.
ErrMKINotFound = errors.New("MKI not found")
errDuplicated = errors.New("duplicated packet")
@ -31,11 +31,16 @@ var (
errMKIAlreadyInUse = errors.New("MKI already in use")
errMKIIsNotEnabled = errors.New("MKI is not enabled")
errInvalidMKILength = errors.New("invalid MKI length")
errTooLongSRTPAuthTag = errors.New("SRTP auth tag is too long")
errTooShortSRTPAuthTag = errors.New("SRTP auth tag is too short")
errStreamNotInited = errors.New("stream has not been inited, unable to close")
errStreamAlreadyClosed = errors.New("stream is already closed")
errStreamAlreadyInited = errors.New("stream is already inited")
errFailedTypeAssertion = errors.New("failed to cast child")
errZeroRocTransmitRate = errors.New("ROC transmit rate is zero")
errUnsupportedRccMode = errors.New("unsupported RCC mode")
)
type duplicatedError struct {

View File

@ -40,6 +40,7 @@ func aesCmKeyDerivation(label byte, masterKey, masterSalt []byte, indexOverKdr i
block.Encrypt(out[n:n+nBlockSize], prfIn)
i++
}
return out[:outLen], nil
}
@ -50,8 +51,12 @@ func aesCmKeyDerivation(label byte, masterKey, masterSalt []byte, indexOverKdr i
// - times the 16-bit RTP sequence number has been reset to zero after
// - passing through 65,535
// i = 2^16 * ROC + SEQ
// IV = (salt*2 ^ 16) | (ssrc*2 ^ 64) | (i*2 ^ 16)
func generateCounter(sequenceNumber uint16, rolloverCounter uint32, ssrc uint32, sessionSalt []byte) (counter [16]byte) {
// IV = (salt*2 ^ 16) | (ssrc*2 ^ 64) | (i*2 ^ 16).
func generateCounter(
sequenceNumber uint16,
rolloverCounter uint32,
ssrc uint32, sessionSalt []byte,
) (counter [16]byte) {
copy(counter[:], sessionSalt)
counter[4] ^= byte(ssrc >> 24)

View File

@ -5,7 +5,7 @@ package srtp
const labelExtractorDtlsSrtp = "EXTRACTOR-dtls_srtp"
// KeyingMaterialExporter allows package SRTP to extract keying material
// KeyingMaterialExporter allows package SRTP to extract keying material.
type KeyingMaterialExporter interface {
ExportKeyingMaterial(label string, context []byte, length int) ([]byte, error)
}
@ -46,6 +46,7 @@ func (c *Config) ExtractSessionKeysFromDTLS(exporter KeyingMaterialExporter, isC
c.Keys.LocalMasterSalt = clientWriteKey[keyLen:]
c.Keys.RemoteMasterKey = serverWriteKey[0:keyLen]
c.Keys.RemoteMasterSalt = serverWriteKey[keyLen:]
return nil
}
@ -53,5 +54,6 @@ func (c *Config) ExtractSessionKeysFromDTLS(exporter KeyingMaterialExporter, isC
c.Keys.LocalMasterSalt = serverWriteKey[keyLen:]
c.Keys.RemoteMasterKey = clientWriteKey[0:keyLen]
c.Keys.RemoteMasterSalt = clientWriteKey[keyLen:]
return nil
}

View File

@ -16,6 +16,7 @@ func SRTPReplayProtection(windowSize uint) ContextOption { // nolint:revive
c.newSRTPReplayDetector = func() replaydetector.ReplayDetector {
return replaydetector.New(windowSize, maxROC<<16|maxSequenceNumber)
}
return nil
}
}
@ -26,6 +27,7 @@ func SRTCPReplayProtection(windowSize uint) ContextOption {
c.newSRTCPReplayDetector = func() replaydetector.ReplayDetector {
return replaydetector.New(windowSize, maxSRTCPIndex)
}
return nil
}
}
@ -36,6 +38,7 @@ func SRTPNoReplayProtection() ContextOption { // nolint:revive
c.newSRTPReplayDetector = func() replaydetector.ReplayDetector {
return &nopReplayDetector{}
}
return nil
}
}
@ -46,6 +49,7 @@ func SRTCPNoReplayProtection() ContextOption {
c.newSRTCPReplayDetector = func() replaydetector.ReplayDetector {
return &nopReplayDetector{}
}
return nil
}
}
@ -54,6 +58,7 @@ func SRTCPNoReplayProtection() ContextOption {
func SRTPReplayDetectorFactory(fn func() replaydetector.ReplayDetector) ContextOption { // nolint:revive
return func(c *Context) error {
c.newSRTPReplayDetector = fn
return nil
}
}
@ -62,6 +67,7 @@ func SRTPReplayDetectorFactory(fn func() replaydetector.ReplayDetector) ContextO
func SRTCPReplayDetectorFactory(fn func() replaydetector.ReplayDetector) ContextOption {
return func(c *Context) error {
c.newSRTCPReplayDetector = fn
return nil
}
}
@ -81,6 +87,7 @@ func MasterKeyIndicator(mki []byte) ContextOption {
c.sendMKI = make([]byte, len(mki))
copy(c.sendMKI, mki)
}
return nil
}
}
@ -89,15 +96,20 @@ func MasterKeyIndicator(mki []byte) ContextOption {
func SRTPEncryption() ContextOption { // nolint:revive
return func(c *Context) error {
c.encryptSRTP = true
return nil
}
}
// SRTPNoEncryption disables SRTP encryption. This option is useful when you want to use NullCipher for SRTP and keep authentication only.
// SRTPNoEncryption disables SRTP encryption.
// This option is useful when you want to use NullCipher for SRTP and keep authentication only.
// It simplifies debugging and testing, but it is not recommended for production use.
//
// Note: you can also use SRTPAuthenticationTagLength(0) to disable authentication tag too.
func SRTPNoEncryption() ContextOption { // nolint:revive
return func(c *Context) error {
c.encryptSRTP = false
return nil
}
}
@ -106,15 +118,58 @@ func SRTPNoEncryption() ContextOption { // nolint:revive
func SRTCPEncryption() ContextOption {
return func(c *Context) error {
c.encryptSRTCP = true
return nil
}
}
// SRTCPNoEncryption disables SRTCP encryption. This option is useful when you want to use NullCipher for SRTCP and keep authentication only.
// SRTCPNoEncryption disables SRTCP encryption.
// This option is useful when you want to use NullCipher for SRTCP and keep authentication only.
// It simplifies debugging and testing, but it is not recommended for production use.
func SRTCPNoEncryption() ContextOption {
return func(c *Context) error {
c.encryptSRTCP = false
return nil
}
}
// RolloverCounterCarryingTransform enables Rollover Counter Carrying Transform from RFC 4771.
// ROC value is sent in Authentication Tag of SRTP packets every rocTransmitRate packets.
//
// RFC 4771 defines 3 RCC modes. pion/srtp supports mode RCCm2 for AES-CM and NULL profiles,
// and mode RCCm3 for AES-GCM (AEAD) profiles.
//
// From RFC 4771: "[For modes RCCm1 and and RCCm3] the length of the MAC is shorter than the length
// of the authentication tag. To achieve the same (or less) MAC forgery success probability on all
// packets when using RCCm1 or RCCm2, as with the default integrity transform in RFC 3711,
// the tag-length must be set to 14 octets, which means that the length of MAC_tr is 10 octets."
//
// Protection profiles ProtectionProfile*CmHmacSha1_32 uses 4-byte SRTP auth tag, so in RCCm2 mode
// SRTP packets with ROC will not be integrity protected.
//
// You can increase the length of the authentication tag using SRTPAuthenticationTagLength option
// to mitigate this issue.
func RolloverCounterCarryingTransform(mode RCCMode, rocTransmitRate uint16) ContextOption {
return func(c *Context) error {
c.rccMode = mode
c.rocTransmitRate = rocTransmitRate
return nil
}
}
// SRTPAuthenticationTagLength sets length of SRTP authentication tag in bytes for AES-CM protection
// profiles. Decreasing the length of the authentication tag is not recommended for production use,
// as it decreases integrity protection.
//
// Zero value means that there is no authentication tag, what may be useful for debugging and testing.
//
// This option is ignored for AEAD profiles.
func SRTPAuthenticationTagLength(authTagRTPLen int) ContextOption { // nolint:revive
return func(c *Context) error {
c.authTagRTPLen = &authTagRTPLen
return nil
}
}

View File

@ -5,19 +5,25 @@ package srtp
import "fmt"
// ProtectionProfile specifies Cipher and AuthTag details, similar to TLS cipher suite
// ProtectionProfile specifies Cipher and AuthTag details, similar to TLS cipher suite.
type ProtectionProfile uint16
// Supported protection profiles
// See https://www.iana.org/assignments/srtp-protection/srtp-protection.xhtml
//
// AES128_CM_HMAC_SHA1_80 and AES128_CM_HMAC_SHA1_32 are valid SRTP profiles, but they do not have an DTLS-SRTP Protection Profiles ID assigned
// in RFC 5764. They were in earlier draft of this RFC: https://datatracker.ietf.org/doc/html/draft-ietf-avt-dtls-srtp-03#section-4.1.2
// AES128_CM_HMAC_SHA1_80 and AES128_CM_HMAC_SHA1_32 are valid SRTP profiles,
// but they do not have an DTLS-SRTP Protection Profiles ID assigned
// in RFC 5764. They were in earlier draft of this RFC:
// https://datatracker.ietf.org/doc/html/draft-ietf-avt-dtls-srtp-03#section-4.1.2
// Their IDs are now marked as reserved in the IANA registry. Despite this Chrome supports them:
// https://chromium.googlesource.com/chromium/deps/libsrtp/+/84122798bb16927b1e676bd4f938a6e48e5bf2fe/srtp/include/srtp.h#694
//
// Null profiles disable encryption, they are used for debugging and testing. They are not recommended for production use.
// Use of them is equivalent to using ProtectionProfileAes128CmHmacSha1_NN profile with SRTPNoEncryption and SRTCPNoEncryption options.
// Null profiles disable encryption, they are used for debugging and testing.
// They are not recommended for production use.
// Use of them is equivalent to using ProtectionProfileAes128CmHmacSha1_NN
// profile with SRTPNoEncryption and SRTCPNoEncryption options.
//
//nolint:lll
const (
ProtectionProfileAes128CmHmacSha1_80 ProtectionProfile = 0x0001
ProtectionProfileAes128CmHmacSha1_32 ProtectionProfile = 0x0002
@ -29,10 +35,16 @@ const (
ProtectionProfileAeadAes256Gcm ProtectionProfile = 0x0008
)
// KeyLen returns length of encryption key in bytes. For all profiles except NullHmacSha1_32 and NullHmacSha1_80 is is also the length of the session key.
// KeyLen returns length of encryption key in bytes.
// For all profiles except NullHmacSha1_32 and NullHmacSha1_80 is
// also the length of the session key.
func (p ProtectionProfile) KeyLen() (int, error) {
switch p {
case ProtectionProfileAes128CmHmacSha1_32, ProtectionProfileAes128CmHmacSha1_80, ProtectionProfileAeadAes128Gcm, ProtectionProfileNullHmacSha1_32, ProtectionProfileNullHmacSha1_80:
case ProtectionProfileAes128CmHmacSha1_32,
ProtectionProfileAes128CmHmacSha1_80,
ProtectionProfileAeadAes128Gcm,
ProtectionProfileNullHmacSha1_32,
ProtectionProfileNullHmacSha1_80:
return 16, nil
case ProtectionProfileAeadAes256Gcm, ProtectionProfileAes256CmHmacSha1_32, ProtectionProfileAes256CmHmacSha1_80:
return 32, nil
@ -41,10 +53,17 @@ func (p ProtectionProfile) KeyLen() (int, error) {
}
}
// SaltLen returns length of salt key in bytes. For all profiles except NullHmacSha1_32 and NullHmacSha1_80 is is also the length of the session salt.
// SaltLen returns length of salt key in bytes.
// For all profiles except NullHmacSha1_32 and NullHmacSha1_80
// is also the length of the session salt.
func (p ProtectionProfile) SaltLen() (int, error) {
switch p {
case ProtectionProfileAes128CmHmacSha1_32, ProtectionProfileAes128CmHmacSha1_80, ProtectionProfileAes256CmHmacSha1_32, ProtectionProfileAes256CmHmacSha1_80, ProtectionProfileNullHmacSha1_32, ProtectionProfileNullHmacSha1_80:
case ProtectionProfileAes128CmHmacSha1_32,
ProtectionProfileAes128CmHmacSha1_80,
ProtectionProfileAes256CmHmacSha1_32,
ProtectionProfileAes256CmHmacSha1_80,
ProtectionProfileNullHmacSha1_32,
ProtectionProfileNullHmacSha1_80:
return 14, nil
case ProtectionProfileAeadAes128Gcm, ProtectionProfileAeadAes256Gcm:
return 12, nil
@ -53,7 +72,8 @@ func (p ProtectionProfile) SaltLen() (int, error) {
}
}
// AuthTagRTPLen returns length of RTP authentication tag in bytes for AES protection profiles. For AEAD ones it returns zero.
// AuthTagRTPLen returns length of RTP authentication tag in bytes for AES protection profiles.
// For AEAD ones it returns zero.
func (p ProtectionProfile) AuthTagRTPLen() (int, error) {
switch p {
case ProtectionProfileAes128CmHmacSha1_80, ProtectionProfileAes256CmHmacSha1_80, ProtectionProfileNullHmacSha1_80:
@ -67,10 +87,17 @@ func (p ProtectionProfile) AuthTagRTPLen() (int, error) {
}
}
// AuthTagRTCPLen returns length of RTCP authentication tag in bytes for AES protection profiles. For AEAD ones it returns zero.
// AuthTagRTCPLen returns length of RTCP authentication tag in bytes for AES protection profiles.
//
// For AEAD ones it returns zero.
func (p ProtectionProfile) AuthTagRTCPLen() (int, error) {
switch p {
case ProtectionProfileAes128CmHmacSha1_32, ProtectionProfileAes128CmHmacSha1_80, ProtectionProfileAes256CmHmacSha1_32, ProtectionProfileAes256CmHmacSha1_80, ProtectionProfileNullHmacSha1_32, ProtectionProfileNullHmacSha1_80:
case ProtectionProfileAes128CmHmacSha1_32,
ProtectionProfileAes128CmHmacSha1_80,
ProtectionProfileAes256CmHmacSha1_32,
ProtectionProfileAes256CmHmacSha1_80,
ProtectionProfileNullHmacSha1_32,
ProtectionProfileNullHmacSha1_80:
return 10, nil
case ProtectionProfileAeadAes128Gcm, ProtectionProfileAeadAes256Gcm:
return 0, nil
@ -79,10 +106,16 @@ func (p ProtectionProfile) AuthTagRTCPLen() (int, error) {
}
}
// AEADAuthTagLen returns length of authentication tag in bytes for AEAD protection profiles. For AES ones it returns zero.
// AEADAuthTagLen returns length of authentication tag in bytes for AEAD protection profiles.
// For AES ones it returns zero.
func (p ProtectionProfile) AEADAuthTagLen() (int, error) {
switch p {
case ProtectionProfileAes128CmHmacSha1_32, ProtectionProfileAes128CmHmacSha1_80, ProtectionProfileAes256CmHmacSha1_32, ProtectionProfileAes256CmHmacSha1_80, ProtectionProfileNullHmacSha1_32, ProtectionProfileNullHmacSha1_80:
case ProtectionProfileAes128CmHmacSha1_32,
ProtectionProfileAes128CmHmacSha1_80,
ProtectionProfileAes256CmHmacSha1_32,
ProtectionProfileAes256CmHmacSha1_80,
ProtectionProfileNullHmacSha1_32,
ProtectionProfileNullHmacSha1_80:
return 0, nil
case ProtectionProfileAeadAes128Gcm, ProtectionProfileAeadAes256Gcm:
return 16, nil
@ -91,10 +124,16 @@ func (p ProtectionProfile) AEADAuthTagLen() (int, error) {
}
}
// AuthKeyLen returns length of authentication key in bytes for AES protection profiles. For AEAD ones it returns zero.
// AuthKeyLen returns length of authentication key in bytes for AES protection profiles.
// For AEAD ones it returns zero.
func (p ProtectionProfile) AuthKeyLen() (int, error) {
switch p {
case ProtectionProfileAes128CmHmacSha1_32, ProtectionProfileAes128CmHmacSha1_80, ProtectionProfileAes256CmHmacSha1_32, ProtectionProfileAes256CmHmacSha1_80, ProtectionProfileNullHmacSha1_32, ProtectionProfileNullHmacSha1_80:
case ProtectionProfileAes128CmHmacSha1_32,
ProtectionProfileAes128CmHmacSha1_80,
ProtectionProfileAes256CmHmacSha1_32,
ProtectionProfileAes256CmHmacSha1_80,
ProtectionProfileNullHmacSha1_32,
ProtectionProfileNullHmacSha1_80:
return 20, nil
case ProtectionProfileAeadAes128Gcm, ProtectionProfileAeadAes256Gcm:
return 0, nil

View File

@ -0,0 +1,21 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package srtp
// protectionProfileWithArgs is a wrapper around ProtectionProfile that allows to
// specify additional arguments for the profile.
type protectionProfileWithArgs struct {
ProtectionProfile
authTagRTPLen *int
}
// AuthTagRTPLen returns length of RTP authentication tag in bytes for AES protection profiles.
// For AEAD ones it returns zero.
func (p protectionProfileWithArgs) AuthTagRTPLen() (int, error) {
if p.authTagRTPLen != nil {
return *p.authTagRTPLen, nil
}
return p.ProtectionProfile.AuthTagRTPLen()
}

View File

@ -58,7 +58,7 @@ type Config struct {
LocalOptions, RemoteOptions []ContextOption
}
// SessionKeys bundles the keys required to setup an SRTP session
// SessionKeys bundles the keys required to setup an SRTP session.
type SessionKeys struct {
LocalMasterKey []byte
LocalMasterSalt []byte
@ -74,20 +74,21 @@ func (s *session) getOrCreateReadStream(ssrc uint32, child streamSession, proto
return nil, false
}
r, ok := s.readStreams[ssrc]
rStream, ok := s.readStreams[ssrc]
if ok {
return r, false
return rStream, false
}
// Create the readStream.
r = proto()
rStream = proto()
if err := r.init(child, ssrc); err != nil {
if err := rStream.init(child, ssrc); err != nil {
return nil, false
}
s.readStreams[ssrc] = r
return r, true
s.readStreams[ssrc] = rStream
return rStream, true
}
func (s *session) removeReadStream(ssrc uint32) {
@ -109,10 +110,15 @@ func (s *session) close() error {
}
<-s.closed
return nil
}
func (s *session) start(localMasterKey, localMasterSalt, remoteMasterKey, remoteMasterSalt []byte, profile ProtectionProfile, child streamSession) error {
func (s *session) start(
localMasterKey, localMasterSalt, remoteMasterKey, remoteMasterSalt []byte,
profile ProtectionProfile,
child streamSession,
) error {
var err error
s.localContext, err = CreateContext(localMasterKey, localMasterSalt, profile, s.localOptions...)
if err != nil {
@ -146,6 +152,7 @@ func (s *session) start(localMasterKey, localMasterSalt, remoteMasterKey, remote
if !errors.Is(err, io.EOF) {
s.log.Error(err.Error())
}
return
}

View File

@ -16,7 +16,7 @@ const defaultSessionSRTCPReplayProtectionWindow = 64
// SessionSRTCP implements io.ReadWriteCloser and provides a bi-directional SRTCP session
// SRTCP itself does not have a design like this, but it is common in most applications
// for local/remote to each have their own keying material. This provides those patterns
// instead of making everyone re-implement
// instead of making everyone re-implement.
type SessionSRTCP struct {
session
writeStream *WriteStreamSRTCP
@ -47,7 +47,7 @@ func NewSessionSRTCP(conn net.Conn, config *Config) (*SessionSRTCP, error) { //n
config.RemoteOptions...,
)
s := &SessionSRTCP{
srtcpSession := &SessionSRTCP{
session: session{
nextConn: conn,
localOptions: localOpts,
@ -61,37 +61,39 @@ func NewSessionSRTCP(conn net.Conn, config *Config) (*SessionSRTCP, error) { //n
log: loggerFactory.NewLogger("srtp"),
},
}
s.writeStream = &WriteStreamSRTCP{s}
srtcpSession.writeStream = &WriteStreamSRTCP{srtcpSession}
err := s.session.start(
err := srtcpSession.session.start(
config.Keys.LocalMasterKey, config.Keys.LocalMasterSalt,
config.Keys.RemoteMasterKey, config.Keys.RemoteMasterSalt,
config.Profile,
s,
srtcpSession,
)
if err != nil {
return nil, err
}
return s, nil
return srtcpSession, nil
}
// OpenWriteStream returns the global write stream for the Session
// OpenWriteStream returns the global write stream for the Session.
func (s *SessionSRTCP) OpenWriteStream() (*WriteStreamSRTCP, error) {
return s.writeStream, nil
}
// OpenReadStream opens a read stream for the given SSRC, it can be used
// if you want a certain SSRC, but don't want to wait for AcceptStream
// if you want a certain SSRC, but don't want to wait for AcceptStream.
func (s *SessionSRTCP) OpenReadStream(ssrc uint32) (*ReadStreamSRTCP, error) {
r, _ := s.session.getOrCreateReadStream(ssrc, s, newReadStreamSRTCP)
if readStream, ok := r.(*ReadStreamSRTCP); ok {
return readStream, nil
}
return nil, errFailedTypeAssertion
}
// AcceptStream returns a stream to handle RTCP for a single SSRC
// AcceptStream returns a stream to handle RTCP for a single SSRC.
func (s *SessionSRTCP) AcceptStream() (*ReadStreamSRTCP, uint32, error) {
stream, ok := <-s.newStream
if !ok {
@ -106,7 +108,7 @@ func (s *SessionSRTCP) AcceptStream() (*ReadStreamSRTCP, uint32, error) {
return readStream, stream.GetSSRC(), nil
}
// Close ends the session
// Close ends the session.
func (s *SessionSRTCP) Close() error {
return s.session.close()
}
@ -122,12 +124,13 @@ func (s *SessionSRTCP) write(buf []byte) (int, error) {
defer bufferpool.Put(ibuf)
s.session.localContextMutex.Lock()
encrypted, err := s.localContext.EncryptRTCP(ibuf.([]byte), buf, nil)
encrypted, err := s.localContext.EncryptRTCP(ibuf.([]byte), buf, nil) //nolint:forcetypeassert
s.session.localContextMutex.Unlock()
if err != nil {
return 0, err
}
return s.session.nextConn.Write(encrypted)
}

View File

@ -17,7 +17,7 @@ const defaultSessionSRTPReplayProtectionWindow = 64
// SessionSRTP implements io.ReadWriteCloser and provides a bi-directional SRTP session
// SRTP itself does not have a design like this, but it is common in most applications
// for local/remote to each have their own keying material. This provides those patterns
// instead of making everyone re-implement
// instead of making everyone re-implement.
type SessionSRTP struct {
session
writeStream *WriteStreamSRTP
@ -48,7 +48,7 @@ func NewSessionSRTP(conn net.Conn, config *Config) (*SessionSRTP, error) { //nol
config.RemoteOptions...,
)
s := &SessionSRTP{
srtpSession := &SessionSRTP{
session: session{
nextConn: conn,
localOptions: localOpts,
@ -62,27 +62,28 @@ func NewSessionSRTP(conn net.Conn, config *Config) (*SessionSRTP, error) { //nol
log: loggerFactory.NewLogger("srtp"),
},
}
s.writeStream = &WriteStreamSRTP{s}
srtpSession.writeStream = &WriteStreamSRTP{srtpSession}
err := s.session.start(
err := srtpSession.session.start(
config.Keys.LocalMasterKey, config.Keys.LocalMasterSalt,
config.Keys.RemoteMasterKey, config.Keys.RemoteMasterSalt,
config.Profile,
s,
srtpSession,
)
if err != nil {
return nil, err
}
return s, nil
return srtpSession, nil
}
// OpenWriteStream returns the global write stream for the Session
// OpenWriteStream returns the global write stream for the Session.
func (s *SessionSRTP) OpenWriteStream() (*WriteStreamSRTP, error) {
return s.writeStream, nil
}
// OpenReadStream opens a read stream for the given SSRC, it can be used
// if you want a certain SSRC, but don't want to wait for AcceptStream
// if you want a certain SSRC, but don't want to wait for AcceptStream.
func (s *SessionSRTP) OpenReadStream(ssrc uint32) (*ReadStreamSRTP, error) {
r, _ := s.session.getOrCreateReadStream(ssrc, s, newReadStreamSRTP)
@ -93,7 +94,7 @@ func (s *SessionSRTP) OpenReadStream(ssrc uint32) (*ReadStreamSRTP, error) {
return nil, errFailedTypeAssertion
}
// AcceptStream returns a stream to handle RTCP for a single SSRC
// AcceptStream returns a stream to handle RTCP for a single SSRC.
func (s *SessionSRTP) AcceptStream() (*ReadStreamSRTP, uint32, error) {
stream, ok := <-s.newStream
if !ok {
@ -108,7 +109,7 @@ func (s *SessionSRTP) AcceptStream() (*ReadStreamSRTP, uint32, error) {
return readStream, stream.GetSSRC(), nil
}
// Close ends the session
// Close ends the session.
func (s *SessionSRTP) Close() error {
return s.session.close()
}
@ -149,8 +150,20 @@ func (s *SessionSRTP) writeRTP(header *rtp.Header, payload []byte) (int, error)
ibuf := bufferpool.Get()
defer bufferpool.Put(ibuf)
buf := ibuf.([]byte) // nolint:forcetypeassert
headerLen, marshalSize := rtp.HeaderAndPacketMarshalSize(header, payload) // nolint:staticcheck
if len(buf) < marshalSize+20 {
// The buffer is too small, so we need to allocate a new one. Add 20 bytes for auth tag like
// for bufferpool above.
buf = make([]byte, marshalSize+20)
}
_, err := rtp.MarshalPacketTo(buf, header, payload) // nolint:staticcheck
if err != nil {
return 0, err
}
s.session.localContextMutex.Lock()
encrypted, err := s.localContext.encryptRTP(ibuf.([]byte), header, payload)
encrypted, err := s.localContext.encryptRTP(buf, header, headerLen, buf[:marshalSize])
s.session.localContextMutex.Unlock()
if err != nil {
@ -165,13 +178,13 @@ func (s *SessionSRTP) setWriteDeadline(t time.Time) error {
}
func (s *SessionSRTP) decrypt(buf []byte) error {
h := &rtp.Header{}
headerLen, err := h.Unmarshal(buf)
header := &rtp.Header{}
headerLen, err := header.Unmarshal(buf)
if err != nil {
return err
}
r, isNew := s.session.getOrCreateReadStream(h.SSRC, s, newReadStreamSRTP)
r, isNew := s.session.getOrCreateReadStream(header.SSRC, s, newReadStreamSRTP)
if r == nil {
return nil // Session has been closed
} else if isNew {
@ -186,7 +199,7 @@ func (s *SessionSRTP) decrypt(buf []byte) error {
return errFailedTypeAssertion
}
decrypted, err := s.remoteContext.decryptRTP(buf, buf, h, headerLen)
decrypted, err := s.remoteContext.decryptRTP(buf, buf, header, headerLen)
if err != nil {
return err
}

View File

@ -10,9 +10,23 @@ import (
"github.com/pion/rtcp"
)
const maxSRTCPIndex = 0x7FFFFFFF
/*
Simplified structure of SRTCP Packets:
- RTCP Header
- Payload
- AEAD Auth Tag - used by AEAD profiles only
- E flag and SRTCP Index
- MKI (optional)
- Auth Tag - used by non-AEAD profiles only
*/
const srtcpHeaderSize = 8
const (
maxSRTCPIndex = 0x7FFFFFFF
srtcpHeaderSize = 8
srtcpIndexSize = 4
srtcpEncryptionFlag = 0x80
)
func (c *Context) decryptRTCP(dst, encrypted []byte) ([]byte, error) {
authTagLen, err := c.cipher.AuthTagRTCPLen()
@ -42,25 +56,24 @@ func (c *Context) decryptRTCP(dst, encrypted []byte) ([]byte, error) {
cipher := c.cipher
if len(c.mkis) > 0 {
// Find cipher for MKI
actualMKI := c.cipher.getMKI(encrypted, false)
actualMKI := encrypted[len(encrypted)-mkiLen-authTagLen : len(encrypted)-authTagLen]
cipher, ok = c.mkis[string(actualMKI)]
if !ok {
return nil, ErrMKINotFound
}
}
out := allocateIfMismatch(dst, encrypted)
out, err = cipher.decryptRTCP(out, encrypted, index, ssrc)
out, err := cipher.decryptRTCP(dst, encrypted, index, ssrc)
if err != nil {
return nil, err
}
markAsValid()
return out, nil
}
// DecryptRTCP decrypts a buffer that contains a RTCP packet
// DecryptRTCP decrypts a buffer that contains a RTCP packet.
func (c *Context) DecryptRTCP(dst, encrypted []byte, header *rtcp.Header) ([]byte, error) {
if header == nil {
header = &rtcp.Header{}
@ -79,9 +92,9 @@ func (c *Context) encryptRTCP(dst, decrypted []byte) ([]byte, error) {
}
ssrc := binary.BigEndian.Uint32(decrypted[4:])
s := c.getSRTCPSSRCState(ssrc)
ssrcState := c.getSRTCPSSRCState(ssrc)
if s.srtcpIndex >= maxSRTCPIndex {
if ssrcState.srtcpIndex >= maxSRTCPIndex {
// ... when 2^48 SRTP packets or 2^31 SRTCP packets have been secured with the same key
// (whichever occurs before), the key management MUST be called to provide new master key(s)
// (previously stored and used keys MUST NOT be used again), or the session MUST be terminated.
@ -90,12 +103,12 @@ func (c *Context) encryptRTCP(dst, decrypted []byte) ([]byte, error) {
}
// We roll over early because MSB is used for marking as encrypted
s.srtcpIndex++
ssrcState.srtcpIndex++
return c.cipher.encryptRTCP(dst, decrypted, s.srtcpIndex, ssrc)
return c.cipher.encryptRTCP(dst, decrypted, ssrcState.srtcpIndex, ssrc)
}
// EncryptRTCP Encrypts a RTCP packet
// EncryptRTCP Encrypts a RTCP packet.
func (c *Context) EncryptRTCP(dst, decrypted []byte, header *rtcp.Header) ([]byte, error) {
if header == nil {
header = &rtcp.Header{}

View File

@ -5,11 +5,21 @@
package srtp
import (
"encoding/binary"
"fmt"
"github.com/pion/rtp"
)
/*
Simplified structure of SRTP Packets:
- RTP Header (with optional RTP Header Extension)
- Payload (with optional padding)
- AEAD Auth Tag - used by AEAD profiles only
- MKI (optional)
- Auth Tag - used by non-AEAD profiles only. When RCC is used with AEAD profiles, the ROC is sent here.
*/
func (c *Context) decryptRTP(dst, ciphertext []byte, header *rtp.Header, headerLen int) ([]byte, error) {
authTagLen, err := c.cipher.AuthTagRTPLen()
if err != nil {
@ -21,17 +31,31 @@ func (c *Context) decryptRTP(dst, ciphertext []byte, header *rtp.Header, headerL
}
mkiLen := len(c.sendMKI)
var hasRocInPacket bool
hasRocInPacket, authTagLen = c.hasROCInPacket(header, authTagLen)
// Verify that encrypted packet is long enough
if len(ciphertext) < (headerLen + aeadAuthTagLen + mkiLen + authTagLen) {
return nil, fmt.Errorf("%w: %d", errTooShortRTP, len(ciphertext))
}
s := c.getSRTPSSRCState(header.SSRC)
ssrcState := c.getSRTPSSRCState(header.SSRC)
roc, diff, _ := s.nextRolloverCount(header.SequenceNumber)
markAsValid, ok := s.replayDetector.Check(
(uint64(roc) << 16) | uint64(header.SequenceNumber),
)
var roc uint32
var diff int64
var index uint64
if !hasRocInPacket {
// The ROC is not sent in the packet. We need to guess it.
roc, diff, _ = ssrcState.nextRolloverCount(header.SequenceNumber)
index = (uint64(roc) << 16) | uint64(header.SequenceNumber)
} else {
// Extract ROC from the packet. The ROC is sent in the first 4 bytes of the auth tag.
roc = binary.BigEndian.Uint32(ciphertext[len(ciphertext)-authTagLen:])
index = (uint64(roc) << 16) | uint64(header.SequenceNumber)
diff = int64(ssrcState.index) - int64(index) //nolint:gosec
}
markAsValid, ok := ssrcState.replayDetector.Check(index)
if !ok {
return nil, &duplicatedError{
Proto: "srtp", SSRC: header.SSRC, Index: uint32(header.SequenceNumber),
@ -41,7 +65,7 @@ func (c *Context) decryptRTP(dst, ciphertext []byte, header *rtp.Header, headerL
cipher := c.cipher
if len(c.mkis) > 0 {
// Find cipher for MKI
actualMKI := c.cipher.getMKI(ciphertext, true)
actualMKI := ciphertext[len(ciphertext)-mkiLen-authTagLen : len(ciphertext)-authTagLen]
cipher, ok = c.mkis[string(actualMKI)]
if !ok {
return nil, ErrMKINotFound
@ -50,17 +74,18 @@ func (c *Context) decryptRTP(dst, ciphertext []byte, header *rtp.Header, headerL
dst = growBufferSize(dst, len(ciphertext)-authTagLen-len(c.sendMKI))
dst, err = cipher.decryptRTP(dst, ciphertext, header, headerLen, roc)
dst, err = cipher.decryptRTP(dst, ciphertext, header, headerLen, roc, hasRocInPacket)
if err != nil {
return nil, err
}
markAsValid()
s.updateRolloverCount(header.SequenceNumber, diff)
ssrcState.updateRolloverCount(header.SequenceNumber, diff, hasRocInPacket, roc)
return dst, nil
}
// DecryptRTP decrypts a RTP packet with an encrypted payload
// DecryptRTP decrypts a RTP packet with an encrypted payload.
func (c *Context) DecryptRTP(dst, encrypted []byte, header *rtp.Header) ([]byte, error) {
if header == nil {
header = &rtp.Header{}
@ -75,7 +100,8 @@ func (c *Context) DecryptRTP(dst, encrypted []byte, header *rtp.Header) ([]byte,
}
// EncryptRTP marshals and encrypts an RTP packet, writing to the dst buffer provided.
// If the dst buffer does not have the capacity to hold `len(plaintext) + 10` bytes, a new one will be allocated and returned.
// If the dst buffer does not have the capacity to hold `len(plaintext) + 10` bytes,
// a new one will be allocated and returned.
// If a rtp.Header is provided, it will be Unmarshaled using the plaintext.
func (c *Context) EncryptRTP(dst []byte, plaintext []byte, header *rtp.Header) ([]byte, error) {
if header == nil {
@ -87,13 +113,14 @@ func (c *Context) EncryptRTP(dst []byte, plaintext []byte, header *rtp.Header) (
return nil, err
}
return c.encryptRTP(dst, header, plaintext[headerLen:])
return c.encryptRTP(dst, header, headerLen, plaintext)
}
// encryptRTP marshals and encrypts an RTP packet, writing to the dst buffer provided.
// If the dst buffer does not have the capacity, a new one will be allocated and returned.
// Similar to above but faster because it can avoid unmarshaling the header and marshaling the payload.
func (c *Context) encryptRTP(dst []byte, header *rtp.Header, payload []byte) (ciphertext []byte, err error) {
func (c *Context) encryptRTP(dst []byte, header *rtp.Header, headerLen int, plaintext []byte,
) (ciphertext []byte, err error) {
s := c.getSRTPSSRCState(header.SSRC)
roc, diff, ovf := s.nextRolloverCount(header.SequenceNumber)
if ovf {
@ -103,7 +130,30 @@ func (c *Context) encryptRTP(dst []byte, header *rtp.Header, payload []byte) (ci
// https://www.rfc-editor.org/rfc/rfc3711#section-9.2
return nil, errExceededMaxPackets
}
s.updateRolloverCount(header.SequenceNumber, diff)
s.updateRolloverCount(header.SequenceNumber, diff, false, 0)
return c.cipher.encryptRTP(dst, header, payload, roc)
rocInPacket := false
if c.rccMode != RCCModeNone && header.SequenceNumber%c.rocTransmitRate == 0 {
rocInPacket = true
}
return c.cipher.encryptRTP(dst, header, headerLen, plaintext, roc, rocInPacket)
}
func (c *Context) hasROCInPacket(header *rtp.Header, authTagLen int) (bool, int) {
hasRocInPacket := false
switch c.rccMode {
case RCCMode2:
// This mode is supported for AES-CM and NULL profiles only. The ROC is sent in the first 4 bytes of the auth tag.
hasRocInPacket = header.SequenceNumber%c.rocTransmitRate == 0
case RCCMode3:
// This mode is supported for AES-GCM only. The ROC is sent as 4-byte auth tag.
hasRocInPacket = header.SequenceNumber%c.rocTransmitRate == 0
if hasRocInPacket {
authTagLen = 4
}
default:
}
return hasRocInPacket, authTagLen
}

View File

@ -6,7 +6,7 @@ package srtp
import "github.com/pion/rtp"
// cipher represents a implementation of one
// of the SRTP Specific ciphers
// of the SRTP Specific ciphers.
type srtpCipher interface {
// AuthTagRTPLen/AuthTagRTCPLen return auth key length of the cipher.
// See the note below.
@ -16,12 +16,11 @@ type srtpCipher interface {
// See the note below.
AEADAuthTagLen() (int, error)
getRTCPIndex([]byte) uint32
getMKI([]byte, bool) []byte
encryptRTP([]byte, *rtp.Header, []byte, uint32) ([]byte, error)
encryptRTP([]byte, *rtp.Header, int, []byte, uint32, bool) ([]byte, error)
encryptRTCP([]byte, []byte, uint32, uint32) ([]byte, error)
decryptRTP([]byte, []byte, *rtp.Header, int, uint32) ([]byte, error)
decryptRTP([]byte, []byte, *rtp.Header, int, uint32, bool) ([]byte, error)
decryptRTCP([]byte, []byte, uint32, uint32) ([]byte, error)
}

View File

@ -12,12 +12,8 @@ import (
"github.com/pion/rtp"
)
const (
rtcpEncryptionFlag = 0x80
)
type srtpCipherAeadAesGcm struct {
ProtectionProfile
protectionProfileWithArgs
srtpCipher, srtcpCipher cipher.AEAD
@ -28,11 +24,15 @@ type srtpCipherAeadAesGcm struct {
srtpEncrypted, srtcpEncrypted bool
}
func newSrtpCipherAeadAesGcm(profile ProtectionProfile, masterKey, masterSalt, mki []byte, encryptSRTP, encryptSRTCP bool) (*srtpCipherAeadAesGcm, error) {
s := &srtpCipherAeadAesGcm{
ProtectionProfile: profile,
srtpEncrypted: encryptSRTP,
srtcpEncrypted: encryptSRTCP,
func newSrtpCipherAeadAesGcm(
profile protectionProfileWithArgs,
masterKey, masterSalt, mki []byte,
encryptSRTP, encryptSRTCP bool,
) (*srtpCipherAeadAesGcm, error) {
srtpCipher := &srtpCipherAeadAesGcm{
protectionProfileWithArgs: profile,
srtpEncrypted: encryptSRTP,
srtcpEncrypted: encryptSRTCP,
}
srtpSessionKey, err := aesCmKeyDerivation(labelSRTPEncryption, masterKey, masterSalt, 0, len(masterKey))
@ -45,7 +45,7 @@ func newSrtpCipherAeadAesGcm(profile ProtectionProfile, masterKey, masterSalt, m
return nil, err
}
s.srtpCipher, err = cipher.NewGCM(srtpBlock)
srtpCipher.srtpCipher, err = cipher.NewGCM(srtpBlock)
if err != nil {
return nil, err
}
@ -60,72 +60,109 @@ func newSrtpCipherAeadAesGcm(profile ProtectionProfile, masterKey, masterSalt, m
return nil, err
}
s.srtcpCipher, err = cipher.NewGCM(srtcpBlock)
srtpCipher.srtcpCipher, err = cipher.NewGCM(srtcpBlock)
if err != nil {
return nil, err
}
if s.srtpSessionSalt, err = aesCmKeyDerivation(labelSRTPSalt, masterKey, masterSalt, 0, len(masterSalt)); err != nil {
if srtpCipher.srtpSessionSalt, err = aesCmKeyDerivation(
labelSRTPSalt, masterKey, masterSalt, 0, len(masterSalt),
); err != nil {
return nil, err
} else if s.srtcpSessionSalt, err = aesCmKeyDerivation(labelSRTCPSalt, masterKey, masterSalt, 0, len(masterSalt)); err != nil {
} else if srtpCipher.srtcpSessionSalt, err = aesCmKeyDerivation(
labelSRTCPSalt, masterKey, masterSalt, 0, len(masterSalt),
); err != nil {
return nil, err
}
mkiLen := len(mki)
if mkiLen > 0 {
s.mki = make([]byte, mkiLen)
copy(s.mki, mki)
srtpCipher.mki = make([]byte, mkiLen)
copy(srtpCipher.mki, mki)
}
return s, nil
return srtpCipher, nil
}
func (s *srtpCipherAeadAesGcm) encryptRTP(dst []byte, header *rtp.Header, payload []byte, roc uint32) (ciphertext []byte, err error) {
func (s *srtpCipherAeadAesGcm) encryptRTP(
dst []byte,
header *rtp.Header,
headerLen int,
plaintext []byte,
roc uint32,
rocInAuthTag bool,
) (ciphertext []byte, err error) {
payload := plaintext[headerLen:]
payloadLen := len(payload)
// Grow the given buffer to fit the output.
authTagLen, err := s.AEADAuthTagLen()
if err != nil {
return nil, err
}
dst = growBufferSize(dst, header.MarshalSize()+len(payload)+authTagLen+len(s.mki))
authPartLen := header.MarshalSize() + len(payload) + authTagLen
dstLen := authPartLen + len(s.mki)
if rocInAuthTag {
dstLen += 4
}
dst = growBufferSize(dst, dstLen)
sameBuffer := isSameBuffer(dst, plaintext)
n, err := header.MarshalTo(dst)
if err != nil {
return nil, err
// Copy the header unencrypted.
if !sameBuffer {
copy(dst, plaintext[:headerLen])
}
iv := s.rtpInitializationVector(header, roc)
if s.srtpEncrypted {
s.srtpCipher.Seal(dst[n:n], iv[:], payload, dst[:n])
s.srtpCipher.Seal(dst[headerLen:headerLen], iv[:], payload, dst[:headerLen])
} else {
clearLen := n + len(payload)
copy(dst[n:], payload)
clearLen := headerLen + payloadLen
if !sameBuffer {
copy(dst[headerLen:], payload)
}
s.srtpCipher.Seal(dst[clearLen:clearLen], iv[:], nil, dst[:clearLen])
}
// Add MKI after the encrypted payload
if len(s.mki) > 0 {
copy(dst[len(dst)-len(s.mki):], s.mki)
copy(dst[authPartLen:], s.mki)
}
if rocInAuthTag {
binary.BigEndian.PutUint32(dst[len(dst)-4:], roc)
}
return dst, nil
}
func (s *srtpCipherAeadAesGcm) decryptRTP(dst, ciphertext []byte, header *rtp.Header, headerLen int, roc uint32) ([]byte, error) {
func (s *srtpCipherAeadAesGcm) decryptRTP(
dst, ciphertext []byte,
header *rtp.Header,
headerLen int,
roc uint32,
rocInAuthTag bool,
) ([]byte, error) {
// Grow the given buffer to fit the output.
authTagLen, err := s.AEADAuthTagLen()
if err != nil {
return nil, err
}
nDst := len(ciphertext) - authTagLen - len(s.mki)
rocLen := 0
if rocInAuthTag {
rocLen = 4
}
nDst := len(ciphertext) - authTagLen - len(s.mki) - rocLen
if nDst < headerLen {
// Size of ciphertext is shorter than AEAD auth tag len.
return nil, ErrFailedToVerifyAuthTag
}
dst = growBufferSize(dst, nDst)
sameBuffer := isSameBuffer(dst, ciphertext)
iv := s.rtpInitializationVector(header, roc)
nEnd := len(ciphertext) - len(s.mki)
nEnd := len(ciphertext) - len(s.mki) - rocLen
if s.srtpEncrypted {
if _, err := s.srtpCipher.Open(
dst[headerLen:headerLen], iv[:], ciphertext[headerLen:nEnd], ciphertext[:headerLen],
@ -139,10 +176,16 @@ func (s *srtpCipherAeadAesGcm) decryptRTP(dst, ciphertext []byte, header *rtp.He
); err != nil {
return nil, fmt.Errorf("%w: %w", ErrFailedToVerifyAuthTag, err)
}
copy(dst[headerLen:], ciphertext[headerLen:nDataEnd])
if !sameBuffer {
copy(dst[headerLen:], ciphertext[headerLen:nDataEnd])
}
}
// Copy the header unencrypted.
if !sameBuffer {
copy(dst[:headerLen], ciphertext[:headerLen])
}
copy(dst[:headerLen], ciphertext[:headerLen])
return dst, nil
}
@ -154,28 +197,36 @@ func (s *srtpCipherAeadAesGcm) encryptRTCP(dst, decrypted []byte, srtcpIndex uin
aadPos := len(decrypted) + authTagLen
// Grow the given buffer to fit the output.
dst = growBufferSize(dst, aadPos+srtcpIndexSize+len(s.mki))
sameBuffer := isSameBuffer(dst, decrypted)
iv := s.rtcpInitializationVector(srtcpIndex, ssrc)
if s.srtcpEncrypted {
aad := s.rtcpAdditionalAuthenticatedData(decrypted, srtcpIndex)
copy(dst[:8], decrypted[:8])
copy(dst[aadPos:aadPos+4], aad[8:12])
s.srtcpCipher.Seal(dst[8:8], iv[:], decrypted[8:], aad[:])
if !sameBuffer {
// Copy the header unencrypted.
copy(dst[:srtcpHeaderSize], decrypted[:srtcpHeaderSize])
}
// Copy index to the proper place.
copy(dst[aadPos:aadPos+srtcpIndexSize], aad[8:12])
s.srtcpCipher.Seal(dst[srtcpHeaderSize:srtcpHeaderSize], iv[:], decrypted[srtcpHeaderSize:], aad[:])
} else {
// Copy the packet unencrypted.
copy(dst, decrypted)
if !sameBuffer {
copy(dst, decrypted)
}
// Append the SRTCP index to the end of the packet - this will form the AAD.
binary.BigEndian.PutUint32(dst[len(decrypted):], srtcpIndex)
// Generate the authentication tag.
tag := make([]byte, authTagLen)
s.srtcpCipher.Seal(tag[0:0], iv[:], nil, dst[:len(decrypted)+4])
s.srtcpCipher.Seal(tag[0:0], iv[:], nil, dst[:len(decrypted)+srtcpIndexSize])
// Copy index to the proper place.
copy(dst[aadPos:], dst[len(decrypted):len(decrypted)+4])
copy(dst[aadPos:], dst[len(decrypted):len(decrypted)+srtcpIndexSize])
// Copy the auth tag after RTCP payload.
copy(dst[len(decrypted):], tag)
}
copy(dst[aadPos+4:], s.mki)
copy(dst[aadPos+srtcpIndexSize:], s.mki)
return dst, nil
}
@ -192,12 +243,14 @@ func (s *srtpCipherAeadAesGcm) decryptRTCP(dst, encrypted []byte, srtcpIndex, ss
return nil, ErrFailedToVerifyAuthTag
}
dst = growBufferSize(dst, nDst)
sameBuffer := isSameBuffer(dst, encrypted)
isEncrypted := encrypted[aadPos]>>7 != 0
isEncrypted := encrypted[aadPos]&srtcpEncryptionFlag != 0
iv := s.rtcpInitializationVector(srtcpIndex, ssrc)
if isEncrypted {
aad := s.rtcpAdditionalAuthenticatedData(encrypted, srtcpIndex)
if _, err := s.srtcpCipher.Open(dst[8:8], iv[:], encrypted[8:aadPos], aad[:]); err != nil {
if _, err := s.srtcpCipher.Open(dst[srtcpHeaderSize:srtcpHeaderSize], iv[:], encrypted[srtcpHeaderSize:aadPos],
aad[:]); err != nil {
return nil, fmt.Errorf("%w: %w", ErrFailedToVerifyAuthTag, err)
}
} else {
@ -211,10 +264,16 @@ func (s *srtpCipherAeadAesGcm) decryptRTCP(dst, encrypted []byte, srtcpIndex, ss
return nil, fmt.Errorf("%w: %w", ErrFailedToVerifyAuthTag, err)
}
// Copy the unencrypted payload.
copy(dst[8:], encrypted[8:dataEnd])
if !sameBuffer {
copy(dst[srtcpHeaderSize:], encrypted[srtcpHeaderSize:dataEnd])
}
}
// Copy the header unencrypted.
if !sameBuffer {
copy(dst[:srtcpHeaderSize], encrypted[:srtcpHeaderSize])
}
copy(dst[:8], encrypted[:8])
return dst, nil
}
@ -233,6 +292,7 @@ func (s *srtpCipherAeadAesGcm) rtpInitializationVector(header *rtp.Header, roc u
for i := range iv {
iv[i] ^= s.srtpSessionSalt[i]
}
return iv
}
@ -252,6 +312,7 @@ func (s *srtpCipherAeadAesGcm) rtcpInitializationVector(srtcpIndex uint32, ssrc
for i := range iv {
iv[i] ^= s.srtcpSessionSalt[i]
}
return iv
}
@ -265,21 +326,11 @@ func (s *srtpCipherAeadAesGcm) rtcpAdditionalAuthenticatedData(rtcpPacket []byte
copy(aad[:], rtcpPacket[:8])
binary.BigEndian.PutUint32(aad[8:], srtcpIndex)
aad[8] |= rtcpEncryptionFlag
aad[8] |= srtcpEncryptionFlag
return aad
}
func (s *srtpCipherAeadAesGcm) getRTCPIndex(in []byte) uint32 {
return binary.BigEndian.Uint32(in[len(in)-len(s.mki)-4:]) &^ (rtcpEncryptionFlag << 24)
}
func (s *srtpCipherAeadAesGcm) getMKI(in []byte, _ bool) []byte {
mkiLen := len(s.mki)
if mkiLen == 0 {
return nil
}
tailOffset := len(in) - mkiLen
return in[tailOffset:]
return binary.BigEndian.Uint32(in[len(in)-len(s.mki)-srtcpIndexSize:]) &^ (srtcpEncryptionFlag << 24)
}

View File

@ -16,7 +16,7 @@ import ( //nolint:gci
)
type srtpCipherAesCmHmacSha1 struct {
ProtectionProfile
protectionProfileWithArgs
srtpSessionSalt []byte
srtpSessionAuth hash.Hash
@ -31,35 +31,46 @@ type srtpCipherAesCmHmacSha1 struct {
mki []byte
}
func newSrtpCipherAesCmHmacSha1(profile ProtectionProfile, masterKey, masterSalt, mki []byte, encryptSRTP, encryptSRTCP bool) (*srtpCipherAesCmHmacSha1, error) {
if profile == ProtectionProfileNullHmacSha1_80 || profile == ProtectionProfileNullHmacSha1_32 {
//nolint:cyclop
func newSrtpCipherAesCmHmacSha1(
profile protectionProfileWithArgs,
masterKey, masterSalt, mki []byte,
encryptSRTP, encryptSRTCP bool,
) (*srtpCipherAesCmHmacSha1, error) {
switch profile.ProtectionProfile {
case ProtectionProfileNullHmacSha1_80, ProtectionProfileNullHmacSha1_32:
encryptSRTP = false
encryptSRTCP = false
default:
}
s := &srtpCipherAesCmHmacSha1{
ProtectionProfile: profile,
srtpEncrypted: encryptSRTP,
srtcpEncrypted: encryptSRTCP,
srtpCipher := &srtpCipherAesCmHmacSha1{
protectionProfileWithArgs: profile,
srtpEncrypted: encryptSRTP,
srtcpEncrypted: encryptSRTCP,
}
srtpSessionKey, err := aesCmKeyDerivation(labelSRTPEncryption, masterKey, masterSalt, 0, len(masterKey))
if err != nil {
return nil, err
} else if s.srtpBlock, err = aes.NewCipher(srtpSessionKey); err != nil {
} else if srtpCipher.srtpBlock, err = aes.NewCipher(srtpSessionKey); err != nil {
return nil, err
}
srtcpSessionKey, err := aesCmKeyDerivation(labelSRTCPEncryption, masterKey, masterSalt, 0, len(masterKey))
if err != nil {
return nil, err
} else if s.srtcpBlock, err = aes.NewCipher(srtcpSessionKey); err != nil {
} else if srtpCipher.srtcpBlock, err = aes.NewCipher(srtcpSessionKey); err != nil {
return nil, err
}
if s.srtpSessionSalt, err = aesCmKeyDerivation(labelSRTPSalt, masterKey, masterSalt, 0, len(masterSalt)); err != nil {
if srtpCipher.srtpSessionSalt, err = aesCmKeyDerivation(
labelSRTPSalt, masterKey, masterSalt, 0, len(masterSalt),
); err != nil {
return nil, err
} else if s.srtcpSessionSalt, err = aesCmKeyDerivation(labelSRTCPSalt, masterKey, masterSalt, 0, len(masterSalt)); err != nil {
} else if srtpCipher.srtcpSessionSalt, err = aesCmKeyDerivation(
labelSRTCPSalt, masterKey, masterSalt, 0, len(masterSalt),
); err != nil {
return nil, err
}
@ -78,45 +89,55 @@ func newSrtpCipherAesCmHmacSha1(profile ProtectionProfile, masterKey, masterSalt
return nil, err
}
s.srtcpSessionAuth = hmac.New(sha1.New, srtcpSessionAuthTag)
s.srtpSessionAuth = hmac.New(sha1.New, srtpSessionAuthTag)
srtpCipher.srtcpSessionAuth = hmac.New(sha1.New, srtcpSessionAuthTag)
srtpCipher.srtpSessionAuth = hmac.New(sha1.New, srtpSessionAuthTag)
mkiLen := len(mki)
if mkiLen > 0 {
s.mki = make([]byte, mkiLen)
copy(s.mki, mki)
srtpCipher.mki = make([]byte, mkiLen)
copy(srtpCipher.mki, mki)
}
return s, nil
return srtpCipher, nil
}
func (s *srtpCipherAesCmHmacSha1) encryptRTP(dst []byte, header *rtp.Header, payload []byte, roc uint32) (ciphertext []byte, err error) {
func (s *srtpCipherAesCmHmacSha1) encryptRTP(
dst []byte,
header *rtp.Header,
headerLen int,
plaintext []byte,
roc uint32,
rocInAuthTag bool,
) (ciphertext []byte, err error) {
payload := plaintext[headerLen:]
payloadLen := len(payload)
// Grow the given buffer to fit the output.
authTagLen, err := s.AuthTagRTPLen()
if err != nil {
return nil, err
}
dst = growBufferSize(dst, header.MarshalSize()+len(payload)+len(s.mki)+authTagLen)
dst = growBufferSize(dst, headerLen+payloadLen+len(s.mki)+authTagLen)
sameBuffer := isSameBuffer(dst, plaintext)
// Copy the header unencrypted.
n, err := header.MarshalTo(dst)
if err != nil {
return nil, err
if !sameBuffer {
copy(dst, plaintext[:headerLen])
}
// Encrypt the payload
if s.srtpEncrypted {
counter := generateCounter(header.SequenceNumber, roc, header.SSRC, s.srtpSessionSalt)
if err = xorBytesCTR(s.srtpBlock, counter[:], dst[n:], payload); err != nil {
if err = xorBytesCTR(s.srtpBlock, counter[:], dst[headerLen:], payload); err != nil {
return nil, err
}
} else {
copy(dst[n:], payload)
} else if !sameBuffer {
copy(dst[headerLen:], payload)
}
n += len(payload)
n := headerLen + payloadLen
// Generate the auth tag.
authTag, err := s.generateSrtpAuthTag(dst[:n], roc)
authTag, err := s.generateSrtpAuthTag(dst[:n], roc, rocInAuthTag)
if err != nil {
return nil, err
}
@ -133,7 +154,13 @@ func (s *srtpCipherAesCmHmacSha1) encryptRTP(dst []byte, header *rtp.Header, pay
return dst, nil
}
func (s *srtpCipherAesCmHmacSha1) decryptRTP(dst, ciphertext []byte, header *rtp.Header, headerLen int, roc uint32) ([]byte, error) {
func (s *srtpCipherAesCmHmacSha1) decryptRTP(
dst, ciphertext []byte,
header *rtp.Header,
headerLen int,
roc uint32,
rocInAuthTag bool,
) ([]byte, error) {
// Split the auth tag and the cipher text into two parts.
authTagLen, err := s.AuthTagRTPLen()
if err != nil {
@ -145,7 +172,7 @@ func (s *srtpCipherAesCmHmacSha1) decryptRTP(dst, ciphertext []byte, header *rtp
ciphertext = ciphertext[:len(ciphertext)-len(s.mki)-authTagLen]
// Generate the auth tag we expect to see from the ciphertext.
expectedTag, err := s.generateSrtpAuthTag(ciphertext, roc)
expectedTag, err := s.generateSrtpAuthTag(ciphertext, roc, rocInAuthTag)
if err != nil {
return nil, err
}
@ -156,8 +183,12 @@ func (s *srtpCipherAesCmHmacSha1) decryptRTP(dst, ciphertext []byte, header *rtp
return nil, ErrFailedToVerifyAuthTag
}
sameBuffer := isSameBuffer(dst, ciphertext)
// Write the plaintext header to the destination buffer.
copy(dst, ciphertext[:headerLen])
if !sameBuffer {
copy(dst, ciphertext[:headerLen])
}
// Decrypt the ciphertext for the payload.
if s.srtpEncrypted {
@ -168,83 +199,110 @@ func (s *srtpCipherAesCmHmacSha1) decryptRTP(dst, ciphertext []byte, header *rtp
if err != nil {
return nil, err
}
} else {
} else if !sameBuffer {
copy(dst[headerLen:], ciphertext[headerLen:])
}
return dst, nil
}
func (s *srtpCipherAesCmHmacSha1) encryptRTCP(dst, decrypted []byte, srtcpIndex uint32, ssrc uint32) ([]byte, error) {
dst = allocateIfMismatch(dst, decrypted)
authTagLen, err := s.AuthTagRTCPLen()
if err != nil {
return nil, err
}
mkiLen := len(s.mki)
decryptedLen := len(decrypted)
encryptedLen := decryptedLen + authTagLen + mkiLen + srtcpIndexSize
dst = growBufferSize(dst, encryptedLen)
sameBuffer := isSameBuffer(dst, decrypted)
if !sameBuffer {
copy(dst, decrypted[:srtcpHeaderSize]) // Copy the first 8 bytes (RTCP header)
}
// Encrypt everything after header
if s.srtcpEncrypted {
counter := generateCounter(uint16(srtcpIndex&0xffff), srtcpIndex>>16, ssrc, s.srtcpSessionSalt)
if err := xorBytesCTR(s.srtcpBlock, counter[:], dst[8:], dst[8:]); err != nil {
counter := generateCounter(uint16(srtcpIndex&0xffff), srtcpIndex>>16, ssrc, s.srtcpSessionSalt) //nolint:gosec // G115
if err = xorBytesCTR(s.srtcpBlock, counter[:], dst[srtcpHeaderSize:], decrypted[srtcpHeaderSize:]); err != nil {
return nil, err
}
// Add SRTCP Index and set Encryption bit
dst = append(dst, make([]byte, 4)...)
binary.BigEndian.PutUint32(dst[len(dst)-4:], srtcpIndex)
dst[len(dst)-4] |= 0x80
binary.BigEndian.PutUint32(dst[decryptedLen:], srtcpIndex)
dst[decryptedLen] |= srtcpEncryptionFlag
} else {
// Copy the decrypted payload as is
copy(dst[8:], decrypted[8:])
if !sameBuffer {
copy(dst[srtcpHeaderSize:], decrypted[srtcpHeaderSize:])
}
// Add SRTCP Index with Encryption bit cleared
dst = append(dst, make([]byte, 4)...)
binary.BigEndian.PutUint32(dst[len(dst)-4:], srtcpIndex)
binary.BigEndian.PutUint32(dst[decryptedLen:], srtcpIndex)
}
n := decryptedLen + srtcpIndexSize
// Generate the authentication tag
authTag, err := s.generateSrtcpAuthTag(dst)
authTag, err := s.generateSrtcpAuthTag(dst[:n])
if err != nil {
return nil, err
}
// Include the MKI if provided
if len(s.mki) > 0 {
dst = append(dst, s.mki...)
copy(dst[n:], s.mki)
n += mkiLen
}
// Append the auth tag at the end of the buffer
return append(dst, authTag...), nil
copy(dst[n:], authTag)
return dst, nil
}
func (s *srtpCipherAesCmHmacSha1) decryptRTCP(out, encrypted []byte, index, ssrc uint32) ([]byte, error) {
func (s *srtpCipherAesCmHmacSha1) decryptRTCP(dst, encrypted []byte, index, ssrc uint32) ([]byte, error) {
authTagLen, err := s.AuthTagRTCPLen()
if err != nil {
return nil, err
}
tailOffset := len(encrypted) - (authTagLen + len(s.mki) + srtcpIndexSize)
if tailOffset < 8 {
mkiLen := len(s.mki)
encryptedLen := len(encrypted)
decryptedLen := encryptedLen - (authTagLen + mkiLen + srtcpIndexSize)
if decryptedLen < 8 {
return nil, errTooShortRTCP
}
out = out[0:tailOffset]
expectedTag, err := s.generateSrtcpAuthTag(encrypted[:len(encrypted)-len(s.mki)-authTagLen])
expectedTag, err := s.generateSrtcpAuthTag(encrypted[:encryptedLen-mkiLen-authTagLen])
if err != nil {
return nil, err
}
actualTag := encrypted[len(encrypted)-authTagLen:]
actualTag := encrypted[encryptedLen-authTagLen:]
if subtle.ConstantTimeCompare(actualTag, expectedTag) != 1 {
return nil, ErrFailedToVerifyAuthTag
}
isEncrypted := encrypted[tailOffset]>>7 != 0
if isEncrypted {
counter := generateCounter(uint16(index&0xffff), index>>16, ssrc, s.srtcpSessionSalt)
err = xorBytesCTR(s.srtcpBlock, counter[:], out[8:], out[8:])
} else {
copy(out[8:], encrypted[8:])
dst = growBufferSize(dst, decryptedLen)
sameBuffer := isSameBuffer(dst, encrypted)
if !sameBuffer {
copy(dst, encrypted[:srtcpHeaderSize]) // Copy the first 8 bytes (RTCP header)
}
return out, err
isEncrypted := encrypted[decryptedLen]&srtcpEncryptionFlag != 0
if isEncrypted {
counter := generateCounter(uint16(index&0xffff), index>>16, ssrc, s.srtcpSessionSalt) //nolint:gosec // G115
err = xorBytesCTR(s.srtcpBlock, counter[:], dst[srtcpHeaderSize:], encrypted[srtcpHeaderSize:decryptedLen])
} else if !sameBuffer {
copy(dst[srtcpHeaderSize:], encrypted[srtcpHeaderSize:])
}
return dst, err
}
func (s *srtpCipherAesCmHmacSha1) generateSrtpAuthTag(buf []byte, roc uint32) ([]byte, error) {
func (s *srtpCipherAesCmHmacSha1) generateSrtpAuthTag(buf []byte, roc uint32, rocInAuthTag bool) ([]byte, error) {
// https://tools.ietf.org/html/rfc3711#section-4.2
// In the case of SRTP, M SHALL consist of the Authenticated
// Portion of the packet (as specified in Figure 1) concatenated with
@ -279,7 +337,13 @@ func (s *srtpCipherAesCmHmacSha1) generateSrtpAuthTag(buf []byte, roc uint32) ([
if err != nil {
return nil, err
}
return s.srtpSessionAuth.Sum(nil)[0:authTagLen], nil
var authTag []byte
if rocInAuthTag {
authTag = append(authTag, rocRaw[:]...)
}
return s.srtpSessionAuth.Sum(authTag)[0:authTagLen], nil
}
func (s *srtpCipherAesCmHmacSha1) generateSrtcpAuthTag(buf []byte) ([]byte, error) {
@ -311,21 +375,6 @@ func (s *srtpCipherAesCmHmacSha1) getRTCPIndex(in []byte) uint32 {
authTagLen, _ := s.AuthTagRTCPLen()
tailOffset := len(in) - (authTagLen + srtcpIndexSize + len(s.mki))
srtcpIndexBuffer := in[tailOffset : tailOffset+srtcpIndexSize]
return binary.BigEndian.Uint32(srtcpIndexBuffer) &^ (1 << 31)
}
func (s *srtpCipherAesCmHmacSha1) getMKI(in []byte, rtp bool) []byte {
mkiLen := len(s.mki)
if mkiLen == 0 {
return nil
}
var authTagLen int
if rtp {
authTagLen, _ = s.AuthTagRTPLen()
} else {
authTagLen, _ = s.AuthTagRTCPLen()
}
tailOffset := len(in) - (authTagLen + mkiLen)
return in[tailOffset : tailOffset+mkiLen]
}

View File

@ -13,10 +13,10 @@ import (
"github.com/pion/transport/v3/packetio"
)
// Limit the buffer size to 100KB
// Limit the buffer size to 100KB.
const srtcpBufferSize = 100 * 1000
// ReadStreamSRTCP handles decryption for a single RTCP SSRC
// ReadStreamSRTCP handles decryption for a single RTCP SSRC.
type ReadStreamSRTCP struct {
mu sync.Mutex
@ -40,12 +40,12 @@ func (r *ReadStreamSRTCP) write(buf []byte) (n int, err error) {
return n, err
}
// Used by getOrCreateReadStream
// Used by getOrCreateReadStream.
func newReadStreamSRTCP() readStream {
return &ReadStreamSRTCP{}
}
// ReadRTCP reads and decrypts full RTCP packet and its header from the nextConn
// ReadRTCP reads and decrypts full RTCP packet and its header from the nextConn.
func (r *ReadStreamSRTCP) ReadRTCP(buf []byte) (int, *rtcp.Header, error) {
n, err := r.Read(buf)
if err != nil {
@ -61,7 +61,7 @@ func (r *ReadStreamSRTCP) ReadRTCP(buf []byte) (int, *rtcp.Header, error) {
return n, header, nil
}
// Read reads and decrypts full RTCP packet from the nextConn
// Read reads and decrypts full RTCP packet from the nextConn.
func (r *ReadStreamSRTCP) Read(buf []byte) (int, error) {
return r.buffer.Read(buf)
}
@ -74,10 +74,11 @@ func (r *ReadStreamSRTCP) SetReadDeadline(t time.Time) error {
}); ok {
return b.SetReadDeadline(t)
}
return nil
}
// Close removes the ReadStream from the session and cleans up any associated state
// Close removes the ReadStream from the session and cleans up any associated state.
func (r *ReadStreamSRTCP) Close() error {
r.mu.Lock()
defer r.mu.Unlock()
@ -96,6 +97,7 @@ func (r *ReadStreamSRTCP) Close() error {
}
r.session.removeReadStream(r.ssrc)
return nil
}
}
@ -128,17 +130,17 @@ func (r *ReadStreamSRTCP) init(child streamSession, ssrc uint32) error {
return nil
}
// GetSSRC returns the SSRC we are demuxing for
// GetSSRC returns the SSRC we are demuxing for.
func (r *ReadStreamSRTCP) GetSSRC() uint32 {
return r.ssrc
}
// WriteStreamSRTCP is stream for a single Session that is used to encrypt RTCP
// WriteStreamSRTCP is stream for a single Session that is used to encrypt RTCP.
type WriteStreamSRTCP struct {
session *SessionSRTCP
}
// WriteRTCP encrypts a RTCP header and its payload to the nextConn
// WriteRTCP encrypts a RTCP header and its payload to the nextConn.
func (w *WriteStreamSRTCP) WriteRTCP(header *rtcp.Header, payload []byte) (int, error) {
headerRaw, err := header.Marshal()
if err != nil {
@ -148,7 +150,7 @@ func (w *WriteStreamSRTCP) WriteRTCP(header *rtcp.Header, payload []byte) (int,
return w.session.write(append(headerRaw, payload...))
}
// Write encrypts and writes a full RTCP packets to the nextConn
// Write encrypts and writes a full RTCP packets to the nextConn.
func (w *WriteStreamSRTCP) Write(b []byte) (int, error) {
return w.session.write(b)
}

View File

@ -13,10 +13,10 @@ import (
"github.com/pion/transport/v3/packetio"
)
// Limit the buffer size to 1MB
// Limit the buffer size to 1MB.
const srtpBufferSize = 1000 * 1000
// ReadStreamSRTP handles decryption for a single RTP SSRC
// ReadStreamSRTP handles decryption for a single RTP SSRC.
type ReadStreamSRTP struct {
mu sync.Mutex
@ -29,7 +29,7 @@ type ReadStreamSRTP struct {
buffer io.ReadWriteCloser
}
// Used by getOrCreateReadStream
// Used by getOrCreateReadStream.
func newReadStreamSRTP() readStream {
return &ReadStreamSRTP{}
}
@ -74,12 +74,12 @@ func (r *ReadStreamSRTP) write(buf []byte) (n int, err error) {
return n, err
}
// Read reads and decrypts full RTP packet from the nextConn
// Read reads and decrypts full RTP packet from the nextConn.
func (r *ReadStreamSRTP) Read(buf []byte) (int, error) {
return r.buffer.Read(buf)
}
// ReadRTP reads and decrypts full RTP packet and its header from the nextConn
// ReadRTP reads and decrypts full RTP packet and its header from the nextConn.
func (r *ReadStreamSRTP) ReadRTP(buf []byte) (int, *rtp.Header, error) {
n, err := r.Read(buf)
if err != nil {
@ -104,10 +104,11 @@ func (r *ReadStreamSRTP) SetReadDeadline(t time.Time) error {
}); ok {
return b.SetReadDeadline(t)
}
return nil
}
// Close removes the ReadStream from the session and cleans up any associated state
// Close removes the ReadStream from the session and cleans up any associated state.
func (r *ReadStreamSRTP) Close() error {
r.mu.Lock()
defer r.mu.Unlock()
@ -126,26 +127,27 @@ func (r *ReadStreamSRTP) Close() error {
}
r.session.removeReadStream(r.ssrc)
return nil
}
}
// GetSSRC returns the SSRC we are demuxing for
// GetSSRC returns the SSRC we are demuxing for.
func (r *ReadStreamSRTP) GetSSRC() uint32 {
return r.ssrc
}
// WriteStreamSRTP is stream for a single Session that is used to encrypt RTP
// WriteStreamSRTP is stream for a single Session that is used to encrypt RTP.
type WriteStreamSRTP struct {
session *SessionSRTP
}
// WriteRTP encrypts a RTP packet and writes to the connection
// WriteRTP encrypts a RTP packet and writes to the connection.
func (w *WriteStreamSRTP) WriteRTP(header *rtp.Header, payload []byte) (int, error) {
return w.session.writeRTP(header, payload)
}
// Write encrypts and writes a full RTP packets to the nextConn
// Write encrypts and writes a full RTP packets to the nextConn.
func (w *WriteStreamSRTP) Write(b []byte) (int, error) {
return w.session.write(b)
}

View File

@ -3,9 +3,11 @@
package srtp
import "bytes"
import (
"unsafe"
)
// Grow the buffer size to the given number of bytes.
// growBufferSize grows the buffer size to the given number of bytes.
func growBufferSize(buf []byte, size int) []byte {
if size <= cap(buf) {
return buf[:size]
@ -13,24 +15,25 @@ func growBufferSize(buf []byte, size int) []byte {
buf2 := make([]byte, size)
copy(buf2, buf)
return buf2
}
// Check if buffers match, if not allocate a new buffer and return it
func allocateIfMismatch(dst, src []byte) []byte {
if dst == nil {
dst = make([]byte, len(src))
copy(dst, src)
} else if !bytes.Equal(dst, src) { // bytes.Equal returns on ref equality, no optimization needed
extraNeeded := len(src) - len(dst)
if extraNeeded > 0 {
dst = append(dst, make([]byte, extraNeeded)...)
} else if extraNeeded < 0 {
dst = dst[:len(dst)+extraNeeded]
}
copy(dst, src)
// isSameBuffer returns true if slices a and b share the same underlying buffer.
func isSameBuffer(a, b []byte) bool {
// If both are nil, they are technically the same (no buffer)
if a == nil && b == nil {
return true
}
return dst
// If either is nil, or both have 0 capacity, they can't share backing buffer
if cap(a) == 0 || cap(b) == 0 {
return false
}
// Create a slice of length 1 from each if possible
aPtr := unsafe.Pointer(&a[:1][0]) // nolint:gosec
bPtr := unsafe.Pointer(&b[:1][0]) // nolint:gosec
return aPtr == bPtr
}

View File

@ -1,6 +1,9 @@
# SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
# SPDX-License-Identifier: MIT
run:
timeout: 5m
linters-settings:
govet:
enable:
@ -16,23 +19,42 @@ linters-settings:
recommendations:
- errors
forbidigo:
analyze-types: true
forbid:
- ^fmt.Print(f|ln)?$
- ^log.(Panic|Fatal|Print)(f|ln)?$
- ^os.Exit$
- ^panic$
- ^print(ln)?$
- p: ^testing.T.(Error|Errorf|Fatal|Fatalf|Fail|FailNow)$
pkg: ^testing$
msg: "use testify/assert instead"
varnamelen:
max-distance: 12
min-name-length: 2
ignore-type-assert-ok: true
ignore-map-index-ok: true
ignore-chan-recv-ok: true
ignore-decls:
- i int
- n int
- w io.Writer
- r io.Reader
- b []byte
linters:
enable:
- asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers
- bidichk # Checks for dangerous unicode character sequences
- bodyclose # checks whether HTTP response body is closed successfully
- containedctx # containedctx is a linter that detects struct contained context.Context field
- contextcheck # check the function whether use a non-inherited context
- cyclop # checks function and package cyclomatic complexity
- decorder # check declaration order and count of types, constants, variables and functions
- dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f())
- dupl # Tool for code clone detection
- durationcheck # check for two durations multiplied together
- err113 # Golang linter to check the errors handling expressions
- errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases
- errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted.
- errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`.
@ -43,18 +65,17 @@ linters:
- forcetypeassert # finds forced type assertions
- gci # Gci control golang package import order and make it always deterministic.
- gochecknoglobals # Checks that no globals are present in Go code
- gochecknoinits # Checks that no init functions are present in Go code
- gocognit # Computes and checks the cognitive complexity of functions
- goconst # Finds repeated strings that could be replaced by a constant
- gocritic # The most opinionated Go source code linter
- gocyclo # Computes and checks the cyclomatic complexity of functions
- godot # Check if comments end in a period
- godox # Tool for detection of FIXME, TODO and other comment keywords
- goerr113 # Golang linter to check the errors handling expressions
- gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification
- gofumpt # Gofumpt checks whether code was gofumpt-ed.
- goheader # Checks is file header matches to pattern
- goimports # Goimports does everything that gofmt does. Additionally it checks unused imports
- gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod.
- gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations.
- goprintffuncname # Checks that printf-like functions are named with `f` at the end
- gosec # Inspects source code for security problems
- gosimple # Linter for Go source code that specializes in simplifying a code
@ -62,9 +83,15 @@ linters:
- grouper # An analyzer to analyze expression groups.
- importas # Enforces consistent import aliases
- ineffassign # Detects when assignments to existing variables are not used
- lll # Reports long lines
- maintidx # maintidx measures the maintainability index of each function.
- makezero # Finds slice declarations with non-zero initial length
- misspell # Finds commonly misspelled English words in comments
- nakedret # Finds naked returns in functions greater than a specified function length
- nestif # Reports deeply nested if statements
- nilerr # Finds the code that returns nil even if it checks that the error is not nil.
- nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value.
- nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity
- noctx # noctx finds sending http request without context.Context
- predeclared # find code that shadows one of Go's predeclared identifiers
- revive # golint replacement, finds style mistakes
@ -72,31 +99,22 @@ linters:
- stylecheck # Stylecheck is a replacement for golint
- tagliatelle # Checks the struct tags.
- tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17
- tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes
- thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers
- typecheck # Like the front-end of a Go compiler, parses and type-checks Go code
- unconvert # Remove unnecessary type conversions
- unparam # Reports unused function parameters
- unused # Checks Go code for unused constants, variables, functions and types
- varnamelen # checks that the length of a variable's name matches its scope
- wastedassign # wastedassign finds wasted assignment statements
- whitespace # Tool for detection of leading and trailing whitespace
disable:
- depguard # Go linter that checks if package imports are in a list of acceptable packages
- containedctx # containedctx is a linter that detects struct contained context.Context field
- cyclop # checks function and package cyclomatic complexity
- exhaustivestruct # Checks if all struct's fields are initialized
- funlen # Tool for detection of long functions
- gocyclo # Computes and checks the cyclomatic complexity of functions
- godot # Check if comments end in a period
- gomnd # An analyzer to detect magic numbers.
- ifshort # Checks that your code uses short syntax for if-statements whenever possible
- gochecknoinits # Checks that no init functions are present in Go code
- gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations.
- interfacebloat # A linter that checks length of interface.
- ireturn # Accept Interfaces, Return Concrete Types
- lll # Reports long lines
- maintidx # maintidx measures the maintainability index of each function.
- makezero # Finds slice declarations with non-zero initial length
- maligned # Tool to detect Go structs that would take less memory if their fields were sorted
- nakedret # Finds naked returns in functions greater than a specified function length
- nestif # Reports deeply nested if statements
- nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity
- mnd # An analyzer to detect magic numbers
- nolintlint # Reports ill-formed or insufficient nolint directives
- paralleltest # paralleltest detects missing usage of t.Parallel() method in your Go test
- prealloc # Finds slice declarations that could potentially be preallocated
@ -104,8 +122,7 @@ linters:
- rowserrcheck # checks whether Err of rows is checked successfully
- sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed.
- testpackage # linter that makes you use a separate _test package
- thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers
- varnamelen # checks that the length of a variable's name matches its scope
- tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes
- wrapcheck # Checks that errors returned from external packages are wrapped
- wsl # Whitespace Linter - Forces you to use empty lines!
@ -114,9 +131,12 @@ issues:
exclude-dirs-use-default: false
exclude-rules:
# Allow complex tests and examples, better to be self contained
- path: (examples|main\.go|_test\.go)
- path: (examples|main\.go)
linters:
- gocognit
- forbidigo
- path: _test\.go
linters:
- gocognit
# Allow forbidden identifiers in CLI commands

View File

@ -7,13 +7,13 @@
<h4 align="center">A toolkit for building TURN clients and servers in Go</h4>
<p align="center">
<a href="https://pion.ly"><img src="https://img.shields.io/badge/pion-turn-gray.svg?longCache=true&colorB=brightgreen" alt="Pion TURN"></a>
<a href="http://gophers.slack.com/messages/pion"><img src="https://img.shields.io/badge/join-us%20on%20slack-gray.svg?longCache=true&logo=slack&colorB=brightgreen" alt="Slack Widget"></a>
<a href="https://discord.gg/PngbdqpFbt"><img src="https://img.shields.io/badge/join-us%20on%20discord-gray.svg?longCache=true&logo=discord&colorB=brightblue" alt="join us on Discord"></a> <a href="https://bsky.app/profile/pion.ly"><img src="https://img.shields.io/badge/follow-us%20on%20bluesky-gray.svg?longCache=true&logo=bluesky&colorB=brightblue" alt="Follow us on Bluesky"></a>
<a href="https://github.com/pion/awesome-pion" alt="Awesome Pion"><img src="https://cdn.rawgit.com/sindresorhus/awesome/d7305f38d29fed78fa85652e3a63e154dd8e8829/media/badge.svg"></a>
<br>
<img alt="GitHub Workflow Status" src="https://img.shields.io/github/actions/workflow/status/pion/turn/test.yaml">
<a href="https://pkg.go.dev/github.com/pion/turn/v3"><img src="https://pkg.go.dev/badge/github.com/pion/turn/v3.svg" alt="Go Reference"></a>
<a href="https://pkg.go.dev/github.com/pion/turn/v4"><img src="https://pkg.go.dev/badge/github.com/pion/turn/v4.svg" alt="Go Reference"></a>
<a href="https://codecov.io/gh/pion/turn"><img src="https://codecov.io/gh/pion/turn/branch/master/graph/badge.svg" alt="Coverage Status"></a>
<a href="https://goreportcard.com/report/github.com/pion/turn/v3"><img src="https://goreportcard.com/badge/github.com/pion/turn/v3" alt="Go Report Card"></a>
<a href="https://goreportcard.com/report/github.com/pion/turn/v4"><img src="https://goreportcard.com/badge/github.com/pion/turn/v4" alt="Go Report Card"></a>
<a href="LICENSE"><img src="https://img.shields.io/badge/License-MIT-yellow.svg" alt="License: MIT"></a>
</p>
<br>
@ -79,9 +79,9 @@ Yes.
The library is used as a part of our WebRTC implementation. Please refer to that [roadmap](https://github.com/pion/webrtc/issues/9) to track our major milestones.
### Community
Pion has an active community on the [Slack](https://pion.ly/slack).
Pion has an active community on the [Discord](https://discord.gg/PngbdqpFbt).
Follow the [Pion Twitter](https://twitter.com/_pion) for project updates and important WebRTC news.
Follow the [Pion Bluesky](https://bsky.app/profile/pion.ly) or [Pion Twitter](https://twitter.com/_pion) for project updates and important WebRTC news.
We are always looking to support **your projects**. Please reach out if you have something to build!
If you need commercial support or don't want to use public methods you can contact us at [team@pion.ly](mailto:team@pion.ly)

View File

@ -49,7 +49,7 @@ type ClientConfig struct {
LoggerFactory logging.LoggerFactory
}
// Client is a STUN server client
// Client is a STUN server client.
type Client struct {
conn net.PacketConn // Read-only
net transport.Net // Read-only
@ -72,7 +72,8 @@ type Client struct {
log logging.LeveledLogger // Read-only
}
// NewClient returns a new Client instance. listeningAddress is the address and port to listen on, default "0.0.0.0:0"
// NewClient returns a new Client instance. listeningAddress is the address and port to listen on,
// default "0.0.0.0:0".
func NewClient(config *ClientConfig) (*Client, error) {
loggerFactory := config.LoggerFactory
if loggerFactory == nil {
@ -119,7 +120,7 @@ func NewClient(config *ClientConfig) (*Client, error) {
log.Debugf("Resolved TURN server %s to %s", config.TURNServerAddr, turnServ)
}
c := &Client{
client := &Client{
conn: config.Conn,
stunServerAddr: stunServ,
turnServerAddr: turnServ,
@ -133,25 +134,25 @@ func NewClient(config *ClientConfig) (*Client, error) {
log: log,
}
return c, nil
return client, nil
}
// TURNServerAddr return the TURN server address
// TURNServerAddr return the TURN server address.
func (c *Client) TURNServerAddr() net.Addr {
return c.turnServerAddr
}
// STUNServerAddr return the STUN server address
// STUNServerAddr return the STUN server address.
func (c *Client) STUNServerAddr() net.Addr {
return c.stunServerAddr
}
// Username returns username
// Username returns username.
func (c *Client) Username() stun.Username {
return c.username
}
// Realm return realm
// Realm return realm.
func (c *Client) Realm() stun.Realm {
return c.realm
}
@ -175,12 +176,14 @@ func (c *Client) Listen() error {
n, from, err := c.conn.ReadFrom(buf)
if err != nil {
c.log.Debugf("Failed to read: %s. Exiting loop", err)
break
}
_, err = c.HandleInbound(buf[:n], from)
if err != nil {
c.log.Debugf("Failed to handle inbound message: %s. Exiting loop", err)
break
}
}
@ -191,7 +194,7 @@ func (c *Client) Listen() error {
return nil
}
// Close closes this client
// Close closes this client.
func (c *Client) Close() {
c.mutexTrMap.Lock()
defer c.mutexTrMap.Unlock()
@ -201,7 +204,7 @@ func (c *Client) Close() {
// TransactionID & Base64: https://play.golang.org/p/EEgmJDI971P
// SendBindingRequestTo sends a new STUN request to the given transport address
// SendBindingRequestTo sends a new STUN request to the given transport address.
func (c *Client) SendBindingRequestTo(to net.Addr) (net.Addr, error) {
attrs := []stun.Setter{stun.TransactionID, stun.BindingRequest}
if len(c.software) > 0 {
@ -228,15 +231,21 @@ func (c *Client) SendBindingRequestTo(to net.Addr) (net.Addr, error) {
}, nil
}
// SendBindingRequest sends a new STUN request to the STUN server
// SendBindingRequest sends a new STUN request to the STUN server.
func (c *Client) SendBindingRequest() (net.Addr, error) {
if c.stunServerAddr == nil {
return nil, errSTUNServerAddressNotSet
}
return c.SendBindingRequestTo(c.stunServerAddr)
}
func (c *Client) sendAllocateRequest(protocol proto.Protocol) (proto.RelayedAddress, proto.Lifetime, stun.Nonce, error) {
func (c *Client) sendAllocateRequest(protocol proto.Protocol) ( //nolint:cyclop
proto.RelayedAddress,
proto.Lifetime,
stun.Nonce,
error,
) {
var relayed proto.RelayedAddress
var lifetime proto.Lifetime
var nonce stun.Nonce
@ -295,6 +304,7 @@ func (c *Client) sendAllocateRequest(protocol proto.Protocol) (proto.RelayedAddr
if err = code.GetFrom(res); err == nil {
return relayed, lifetime, nonce, fmt.Errorf("%s (error %s)", res.Type, code) //nolint:goerr113
}
return relayed, lifetime, nonce, fmt.Errorf("%s", res.Type) //nolint:goerr113
}
@ -307,10 +317,11 @@ func (c *Client) sendAllocateRequest(protocol proto.Protocol) (proto.RelayedAddr
if err := lifetime.GetFrom(res); err != nil {
return relayed, lifetime, nonce, err
}
return relayed, lifetime, nonce, nil
}
// Allocate sends a TURN allocation request to the given transport address
// Allocate sends a TURN allocation request to the given transport address.
func (c *Client) Allocate() (net.PacketConn, error) {
if err := c.allocTryLock.Lock(); err != nil {
return nil, fmt.Errorf("%w: %s", errOneAllocateOnly, err.Error())
@ -403,10 +414,11 @@ func (c *Client) CreatePermission(addrs ...net.Addr) error {
return err
}
}
return nil
}
// PerformTransaction performs STUN transaction
// PerformTransaction performs STUN transaction.
func (c *Client) PerformTransaction(msg *stun.Message, to net.Addr, ignoreResult bool) (client.TransactionResult,
error,
) {
@ -442,11 +454,12 @@ func (c *Client) PerformTransaction(msg *stun.Message, to net.Addr, ignoreResult
if res.Err != nil {
return res, res.Err
}
return res, nil
}
// OnDeallocated is called when de-allocation of relay address has been complete.
// (Called by UDPConn)
// (Called by UDPConn).
func (c *Client) OnDeallocated(net.Addr) {
c.setRelayedUDPConn(nil)
c.setTCPAllocation(nil)
@ -494,7 +507,7 @@ func (c *Client) HandleInbound(data []byte, from net.Addr) (bool, error) {
return false, nil
}
func (c *Client) handleSTUNMessage(data []byte, from net.Addr) error {
func (c *Client) handleSTUNMessage(data []byte, from net.Addr) error { //nolint:cyclop
raw := make([]byte, len(data))
copy(raw, data)
@ -507,7 +520,7 @@ func (c *Client) handleSTUNMessage(data []byte, from net.Addr) error {
return fmt.Errorf("%w : %s", errUnexpectedSTUNRequestMessage, msg.String())
}
if msg.Type.Class == stun.ClassIndication {
if msg.Type.Class == stun.ClassIndication { // nolint:nestif
switch msg.Type.Method {
case stun.MethodData:
var peerAddr proto.PeerAddress
@ -529,6 +542,7 @@ func (c *Client) handleSTUNMessage(data []byte, from net.Addr) error {
relayedConn := c.relayedUDPConn()
if relayedConn == nil {
c.log.Debug("No relayed conn allocated")
return nil // Silently discard
}
relayedConn.HandleInbound(data, from)
@ -553,6 +567,7 @@ func (c *Client) handleSTUNMessage(data []byte, from net.Addr) error {
allocation := c.getTCPAllocation()
if allocation == nil {
c.log.Debug("No TCP allocation exists")
return nil // Silently discard
}
@ -560,6 +575,7 @@ func (c *Client) handleSTUNMessage(data []byte, from net.Addr) error {
default:
c.log.Debug("Received unsupported STUN method")
}
return nil
}
@ -576,6 +592,7 @@ func (c *Client) handleSTUNMessage(data []byte, from net.Addr) error {
c.mutexTrMap.Unlock()
// Silently discard
c.log.Debugf("No transaction for %s", msg)
return nil
}
@ -607,6 +624,7 @@ func (c *Client) handleChannelData(data []byte) error {
relayedConn := c.relayedUDPConn()
if relayedConn == nil {
c.log.Debug("No relayed conn allocated")
return nil // Silently discard
}
@ -618,6 +636,7 @@ func (c *Client) handleChannelData(data []byte) error {
c.log.Tracef("Channel data received from %s (ch=%d)", addr.String(), int(chData.Number))
relayedConn.HandleInbound(chData.Data, addr)
return nil
}
@ -638,11 +657,12 @@ func (c *Client) onRtxTimeout(trKey string, nRtx int) {
}) {
c.log.Debug("No listener for transaction")
}
return
}
c.log.Tracef("Retransmitting transaction %s to %s (nRtx=%d)",
trKey, tr.To.String(), nRtx)
trKey, tr.To, nRtx)
_, err := c.conn.WriteTo(tr.Raw, tr.To)
if err != nil {
c.trMap.Delete(trKey)
@ -651,6 +671,7 @@ func (c *Client) onRtxTimeout(trKey string, nRtx int) {
}) {
c.log.Debug("No listener for transaction")
}
return
}
tr.StartRtxTimer(c.onRtxTimeout)

View File

@ -22,7 +22,7 @@ type allocationResponse struct {
}
// Allocation is tied to a FiveTuple and relays traffic
// use CreateAllocation and GetAllocation to operate
// use CreateAllocation and GetAllocation to operate.
type Allocation struct {
RelayAddr net.Addr
Protocol Protocol
@ -55,7 +55,7 @@ func NewAllocation(turnSocket net.PacketConn, fiveTuple *FiveTuple, log logging.
}
}
// GetPermission gets the Permission from the allocation
// GetPermission gets the Permission from the allocation.
func (a *Allocation) GetPermission(addr net.Addr) *Permission {
a.permissionsLock.RLock()
defer a.permissionsLock.RUnlock()
@ -63,9 +63,9 @@ func (a *Allocation) GetPermission(addr net.Addr) *Permission {
return a.permissions[ipnet.FingerprintAddr(addr)]
}
// AddPermission adds a new permission to the allocation
func (a *Allocation) AddPermission(p *Permission) {
fingerprint := ipnet.FingerprintAddr(p.Addr)
// AddPermission adds a new permission to the allocation.
func (a *Allocation) AddPermission(perms *Permission) {
fingerprint := ipnet.FingerprintAddr(perms.Addr)
a.permissionsLock.RLock()
existedPermission, ok := a.permissions[fingerprint]
@ -73,18 +73,19 @@ func (a *Allocation) AddPermission(p *Permission) {
if ok {
existedPermission.refresh(permissionTimeout)
return
}
p.allocation = a
perms.allocation = a
a.permissionsLock.Lock()
a.permissions[fingerprint] = p
a.permissions[fingerprint] = perms
a.permissionsLock.Unlock()
p.start(permissionTimeout)
perms.start(permissionTimeout)
}
// RemovePermission removes the net.Addr's fingerprint from the allocation's permissions
// RemovePermission removes the net.Addr's fingerprint from the allocation's permissions.
func (a *Allocation) RemovePermission(addr net.Addr) {
a.permissionsLock.Lock()
defer a.permissionsLock.Unlock()
@ -92,13 +93,13 @@ func (a *Allocation) RemovePermission(addr net.Addr) {
}
// AddChannelBind adds a new ChannelBind to the allocation, it also updates the
// permissions needed for this ChannelBind
func (a *Allocation) AddChannelBind(c *ChannelBind, lifetime time.Duration) error {
// permissions needed for this ChannelBind.
func (a *Allocation) AddChannelBind(chanBind *ChannelBind, lifetime time.Duration) error {
// Check that this channel id isn't bound to another transport address, and
// that this transport address isn't bound to another channel number.
channelByNumber := a.GetChannelByNumber(c.Number)
channelByNumber := a.GetChannelByNumber(chanBind.Number)
if channelByNumber != a.GetChannelByAddr(c.Peer) {
if channelByNumber != a.GetChannelByAddr(chanBind.Peer) {
return errSameChannelDifferentPeer
}
@ -107,12 +108,12 @@ func (a *Allocation) AddChannelBind(c *ChannelBind, lifetime time.Duration) erro
a.channelBindingsLock.Lock()
defer a.channelBindingsLock.Unlock()
c.allocation = a
a.channelBindings = append(a.channelBindings, c)
c.start(lifetime)
chanBind.allocation = a
a.channelBindings = append(a.channelBindings, chanBind)
chanBind.start(lifetime)
// Channel binds also refresh permissions.
a.AddPermission(NewPermission(c.Peer, a.log))
a.AddPermission(NewPermission(chanBind.Peer, a.log))
} else {
channelByNumber.refresh(lifetime)
@ -123,7 +124,7 @@ func (a *Allocation) AddChannelBind(c *ChannelBind, lifetime time.Duration) erro
return nil
}
// RemoveChannelBind removes the ChannelBind from this allocation by id
// RemoveChannelBind removes the ChannelBind from this allocation by id.
func (a *Allocation) RemoveChannelBind(number proto.ChannelNumber) bool {
a.channelBindingsLock.Lock()
defer a.channelBindingsLock.Unlock()
@ -131,6 +132,7 @@ func (a *Allocation) RemoveChannelBind(number proto.ChannelNumber) bool {
for i := len(a.channelBindings) - 1; i >= 0; i-- {
if a.channelBindings[i].Number == number {
a.channelBindings = append(a.channelBindings[:i], a.channelBindings[i+1:]...)
return true
}
}
@ -138,7 +140,7 @@ func (a *Allocation) RemoveChannelBind(number proto.ChannelNumber) bool {
return false
}
// GetChannelByNumber gets the ChannelBind from this allocation by id
// GetChannelByNumber gets the ChannelBind from this allocation by id.
func (a *Allocation) GetChannelByNumber(number proto.ChannelNumber) *ChannelBind {
a.channelBindingsLock.RLock()
defer a.channelBindingsLock.RUnlock()
@ -147,10 +149,11 @@ func (a *Allocation) GetChannelByNumber(number proto.ChannelNumber) *ChannelBind
return cb
}
}
return nil
}
// GetChannelByAddr gets the ChannelBind from this allocation by net.Addr
// GetChannelByAddr gets the ChannelBind from this allocation by net.Addr.
func (a *Allocation) GetChannelByAddr(addr net.Addr) *ChannelBind {
a.channelBindingsLock.RLock()
defer a.channelBindingsLock.RUnlock()
@ -159,17 +162,18 @@ func (a *Allocation) GetChannelByAddr(addr net.Addr) *ChannelBind {
return cb
}
}
return nil
}
// Refresh updates the allocations lifetime
// Refresh updates the allocations lifetime.
func (a *Allocation) Refresh(lifetime time.Duration) {
if !a.lifetimeTimer.Reset(lifetime) {
a.log.Errorf("Failed to reset allocation timer for %v", a.fiveTuple)
}
}
// SetResponseCache cache allocation response for retransmit allocation request
// SetResponseCache cache allocation response for retransmit allocation request.
func (a *Allocation) SetResponseCache(transactionID [stun.TransactionIDSize]byte, attrs []stun.Setter) {
a.responseCache.Store(&allocationResponse{
transactionID: transactionID,
@ -177,15 +181,16 @@ func (a *Allocation) SetResponseCache(transactionID [stun.TransactionIDSize]byte
})
}
// GetResponseCache return response cache for retransmit allocation request
// GetResponseCache return response cache for retransmit allocation request.
func (a *Allocation) GetResponseCache() (id [stun.TransactionIDSize]byte, attrs []stun.Setter) {
if res, ok := a.responseCache.Load().(*allocationResponse); ok && res != nil {
id, attrs = res.transactionID, res.responseAttrs
}
return
}
// Close closes the allocation
// Close closes the allocation.
func (a *Allocation) Close() error {
select {
case <-a.closed:
@ -233,13 +238,14 @@ func (a *Allocation) Close() error {
const rtpMTU = 1600
func (a *Allocation) packetHandler(m *Manager) {
func (a *Allocation) packetHandler(manager *Manager) {
buffer := make([]byte, rtpMTU)
for {
n, srcAddr, err := a.RelaySocket.ReadFrom(buffer)
if err != nil {
m.DeleteAllocation(a.fiveTuple)
manager.DeleteAllocation(a.fiveTuple)
return
}
@ -248,7 +254,7 @@ func (a *Allocation) packetHandler(m *Manager) {
n,
srcAddr)
if channel := a.GetChannelByAddr(srcAddr); channel != nil {
if channel := a.GetChannelByAddr(srcAddr); channel != nil { // nolint:nestif
channelData := &proto.ChannelData{
Data: buffer[:n],
Number: channel.Number,
@ -262,15 +268,22 @@ func (a *Allocation) packetHandler(m *Manager) {
udpAddr, ok := srcAddr.(*net.UDPAddr)
if !ok {
a.log.Errorf("Failed to send DataIndication from allocation %v %v", srcAddr, err)
return
}
peerAddressAttr := proto.PeerAddress{IP: udpAddr.IP, Port: udpAddr.Port}
dataAttr := proto.Data(buffer[:n])
msg, err := stun.Build(stun.TransactionID, stun.NewType(stun.MethodData, stun.ClassIndication), peerAddressAttr, dataAttr)
msg, err := stun.Build(
stun.TransactionID,
stun.NewType(stun.MethodData, stun.ClassIndication),
peerAddressAttr,
dataAttr,
)
if err != nil {
a.log.Errorf("Failed to send DataIndication from allocation %v %v", srcAddr, err)
return
}
a.log.Debugf("Relaying message from %s to client at %s",

View File

@ -25,7 +25,7 @@ type reservation struct {
port int
}
// Manager is used to hold active allocations
// Manager is used to hold active allocations.
type Manager struct {
lock sync.RWMutex
log logging.LeveledLogger
@ -58,21 +58,23 @@ func NewManager(config ManagerConfig) (*Manager, error) {
}, nil
}
// GetAllocation fetches the allocation matching the passed FiveTuple
// GetAllocation fetches the allocation matching the passed FiveTuple.
func (m *Manager) GetAllocation(fiveTuple *FiveTuple) *Allocation {
m.lock.RLock()
defer m.lock.RUnlock()
return m.allocations[fiveTuple.Fingerprint()]
}
// AllocationCount returns the number of existing allocations
// AllocationCount returns the number of existing allocations.
func (m *Manager) AllocationCount() int {
m.lock.RLock()
defer m.lock.RUnlock()
return len(m.allocations)
}
// Close closes the manager and closes all allocations it manages
// Close closes the manager and closes all allocations it manages.
func (m *Manager) Close() error {
m.lock.Lock()
defer m.lock.Unlock()
@ -82,11 +84,17 @@ func (m *Manager) Close() error {
return err
}
}
return nil
}
// CreateAllocation creates a new allocation and starts relaying
func (m *Manager) CreateAllocation(fiveTuple *FiveTuple, turnSocket net.PacketConn, requestedPort int, lifetime time.Duration) (*Allocation, error) {
// CreateAllocation creates a new allocation and starts relaying.
func (m *Manager) CreateAllocation(
fiveTuple *FiveTuple,
turnSocket net.PacketConn,
requestedPort int,
lifetime time.Duration,
) (*Allocation, error) {
switch {
case fiveTuple == nil:
return nil, errNilFiveTuple
@ -100,34 +108,35 @@ func (m *Manager) CreateAllocation(fiveTuple *FiveTuple, turnSocket net.PacketCo
return nil, errLifetimeZero
}
if a := m.GetAllocation(fiveTuple); a != nil {
if alloc := m.GetAllocation(fiveTuple); alloc != nil {
return nil, fmt.Errorf("%w: %v", errDupeFiveTuple, fiveTuple)
}
a := NewAllocation(turnSocket, fiveTuple, m.log)
alloc := NewAllocation(turnSocket, fiveTuple, m.log)
conn, relayAddr, err := m.allocatePacketConn("udp4", requestedPort)
if err != nil {
return nil, err
}
a.RelaySocket = conn
a.RelayAddr = relayAddr
alloc.RelaySocket = conn
alloc.RelayAddr = relayAddr
m.log.Debugf("Listening on relay address: %s", a.RelayAddr)
m.log.Debugf("Listening on relay address: %s", alloc.RelayAddr)
a.lifetimeTimer = time.AfterFunc(lifetime, func() {
m.DeleteAllocation(a.fiveTuple)
alloc.lifetimeTimer = time.AfterFunc(lifetime, func() {
m.DeleteAllocation(alloc.fiveTuple)
})
m.lock.Lock()
m.allocations[fiveTuple.Fingerprint()] = a
m.allocations[fiveTuple.Fingerprint()] = alloc
m.lock.Unlock()
go a.packetHandler(m)
return a, nil
go alloc.packetHandler(m)
return alloc, nil
}
// DeleteAllocation removes an allocation
// DeleteAllocation removes an allocation.
func (m *Manager) DeleteAllocation(fiveTuple *FiveTuple) {
fingerprint := fiveTuple.Fingerprint()
@ -145,7 +154,7 @@ func (m *Manager) DeleteAllocation(fiveTuple *FiveTuple) {
}
}
// CreateReservation stores the reservation for the token+port
// CreateReservation stores the reservation for the token+port.
func (m *Manager) CreateReservation(reservationToken string, port int) {
time.AfterFunc(30*time.Second, func() {
m.lock.Lock()
@ -153,6 +162,7 @@ func (m *Manager) CreateReservation(reservationToken string, port int) {
for i := len(m.reservations) - 1; i >= 0; i-- {
if m.reservations[i].token == reservationToken {
m.reservations = append(m.reservations[:i], m.reservations[i+1:]...)
return
}
}
@ -166,7 +176,7 @@ func (m *Manager) CreateReservation(reservationToken string, port int) {
m.lock.Unlock()
}
// GetReservation returns the port for a given reservation if it exists
// GetReservation returns the port for a given reservation if it exists.
func (m *Manager) GetReservation(reservationToken string) (int, bool) {
m.lock.RLock()
defer m.lock.RUnlock()
@ -176,10 +186,11 @@ func (m *Manager) GetReservation(reservationToken string) (int, bool) {
return r.port, true
}
}
return 0, false
}
// GetRandomEvenPort returns a random un-allocated udp4 port
// GetRandomEvenPort returns a random un-allocated udp4 port.
func (m *Manager) GetRandomEvenPort() (int, error) {
for i := 0; i < 128; i++ {
conn, addr, err := m.allocatePacketConn("udp4", 0)
@ -199,11 +210,12 @@ func (m *Manager) GetRandomEvenPort() (int, error) {
return udpAddr.Port, nil
}
}
return 0, errFailedToAllocateEvenPort
}
// GrantPermission handles permission requests by calling the permission handler callback
// associated with the TURN server listener socket
// associated with the TURN server listener socket.
func (m *Manager) GrantPermission(sourceAddr net.Addr, peerIP net.IP) error {
// No permission handler: open
if m.permissionHandler == nil {

View File

@ -22,7 +22,7 @@ type ChannelBind struct {
log logging.LeveledLogger
}
// NewChannelBind creates a new ChannelBind
// NewChannelBind creates a new ChannelBind.
func NewChannelBind(number proto.ChannelNumber, peer net.Addr, log logging.LeveledLogger) *ChannelBind {
return &ChannelBind{
Number: number,

View File

@ -7,10 +7,10 @@ import (
"net"
)
// Protocol is an enum for relay protocol
// Protocol is an enum for relay protocol.
type Protocol uint8
// Network protocols for relay
// Network protocols for relay.
const (
UDP Protocol = iota
TCP
@ -27,19 +27,19 @@ type FiveTuple struct {
SrcAddr, DstAddr net.Addr
}
// Equal asserts if two FiveTuples are equal
// Equal asserts if two FiveTuples are equal.
func (f *FiveTuple) Equal(b *FiveTuple) bool {
return f.Fingerprint() == b.Fingerprint()
}
// FiveTupleFingerprint is a comparable representation of a FiveTuple
// FiveTupleFingerprint is a comparable representation of a FiveTuple.
type FiveTupleFingerprint struct {
srcIP, dstIP [16]byte
srcPort, dstPort uint16
protocol Protocol
}
// Fingerprint is the identity of a FiveTuple
// Fingerprint is the identity of a FiveTuple.
func (f *FiveTuple) Fingerprint() (fp FiveTupleFingerprint) {
srcIP, srcPort := netAddrIPAndPort(f.SrcAddr)
copy(fp.srcIP[:], srcIP)
@ -48,15 +48,16 @@ func (f *FiveTuple) Fingerprint() (fp FiveTupleFingerprint) {
copy(fp.dstIP[:], dstIP)
fp.dstPort = dstPort
fp.protocol = f.Protocol
return
}
func netAddrIPAndPort(addr net.Addr) (net.IP, uint16) {
switch a := addr.(type) {
case *net.UDPAddr:
return a.IP.To16(), uint16(a.Port)
return a.IP.To16(), uint16(a.Port) // nolint:gosec // G115
case *net.TCPAddr:
return a.IP.To16(), uint16(a.Port)
return a.IP.To16(), uint16(a.Port) // nolint:gosec // G115
default:
return nil, 0
}

View File

@ -22,7 +22,7 @@ type Permission struct {
log logging.LeveledLogger
}
// NewPermission create a new Permission
// NewPermission create a new Permission.
func NewPermission(addr net.Addr, log logging.LeveledLogger) *Permission {
return &Permission{
Addr: addr,

View File

@ -16,7 +16,7 @@ import (
"github.com/pion/turn/v4/internal/proto"
)
// AllocationConfig is a set of configuration params use by NewUDPConn and NewTCPAllocation
// AllocationConfig is a set of configuration params use by NewUDPConn and NewTCPAllocation.
type AllocationConfig struct {
Client Client
RelayedAddr net.Addr
@ -82,6 +82,7 @@ func (a *allocation) refreshAllocation(lifetime time.Duration, dontWait bool) er
if dontWait {
a.log.Debug("Refresh request sent")
return nil
}
@ -93,10 +94,13 @@ func (a *allocation) refreshAllocation(lifetime time.Duration, dontWait bool) er
if err = code.GetFrom(res); err == nil {
if code.Code == stun.CodeStaleNonce {
a.setNonceFromMsg(res)
return errTryAgain
}
return err
}
return fmt.Errorf("%s", res.Type) //nolint:goerr113
}
@ -108,6 +112,7 @@ func (a *allocation) refreshAllocation(lifetime time.Duration, dontWait bool) er
a.setLifetime(updatedLifetime.Duration)
a.log.Debugf("Updated lifetime: %d seconds", int(a.lifetime().Seconds()))
return nil
}
@ -115,6 +120,7 @@ func (a *allocation) refreshPermissions() error {
addrs := a.permMap.addrs()
if len(addrs) == 0 {
a.log.Debug("No permission to refresh")
return nil
}
if err := a.CreatePermissions(addrs...); err != nil {
@ -122,9 +128,11 @@ func (a *allocation) refreshPermissions() error {
return errTryAgain
}
a.log.Errorf("Fail to refresh permissions: %s", err)
return err
}
a.log.Debug("Refresh permissions successful")
return nil
}

View File

@ -61,7 +61,7 @@ func (b *binding) refreshedAt() time.Time {
return b._refreshedAt
}
// Thread-safe binding map
// Thread-safe binding map.
type bindingManager struct {
chanMap map[uint16]*binding
addrMap map[string]*binding
@ -84,6 +84,7 @@ func (mgr *bindingManager) assignChannelNumber() uint16 {
} else {
mgr.next++
}
return n
}
@ -100,6 +101,7 @@ func (mgr *bindingManager) create(addr net.Addr) *binding {
mgr.chanMap[b.number] = b
mgr.addrMap[b.addr.String()] = b
return b
}
@ -108,6 +110,7 @@ func (mgr *bindingManager) findByAddr(addr net.Addr) (*binding, bool) {
defer mgr.mutex.RUnlock()
b, ok := mgr.addrMap[addr.String()]
return b, ok
}
@ -116,6 +119,7 @@ func (mgr *bindingManager) findByNumber(number uint16) (*binding, bool) {
defer mgr.mutex.RUnlock()
b, ok := mgr.chanMap[number]
return b, ok
}
@ -130,6 +134,7 @@ func (mgr *bindingManager) deleteByAddr(addr net.Addr) bool {
delete(mgr.addrMap, addr.String())
delete(mgr.chanMap, b.number)
return true
}
@ -144,6 +149,7 @@ func (mgr *bindingManager) deleteByNumber(number uint16) bool {
delete(mgr.addrMap, b.addr.String())
delete(mgr.chanMap, number)
return true
}

View File

@ -10,7 +10,7 @@ import (
"github.com/pion/stun/v3"
)
// Client is an interface for the public turn.Client in order to break cyclic dependencies
// Client is an interface for the public turn.Client in order to break cyclic dependencies.
type Client interface {
WriteTo(data []byte, to net.Addr) (int, error)
PerformTransaction(msg *stun.Message, to net.Addr, dontWait bool) (TransactionResult, error)

View File

@ -8,10 +8,10 @@ import (
"time"
)
// PeriodicTimerTimeoutHandler is a handler called on timeout
// PeriodicTimerTimeoutHandler is a handler called on timeout.
type PeriodicTimerTimeoutHandler func(timerID int)
// PeriodicTimer is a periodic timer
// PeriodicTimer is a periodic timer.
type PeriodicTimer struct {
id int
interval time.Duration
@ -20,7 +20,7 @@ type PeriodicTimer struct {
mutex sync.RWMutex
}
// NewPeriodicTimer create a new timer
// NewPeriodicTimer create a new timer.
func NewPeriodicTimer(id int, timeoutHandler PeriodicTimerTimeoutHandler, interval time.Duration) *PeriodicTimer {
return &PeriodicTimer{
id: id,
@ -76,7 +76,7 @@ func (t *PeriodicTimer) Stop() {
}
// IsRunning tests if the timer is running.
// Debug purpose only
// Debug purpose only.
func (t *PeriodicTimer) IsRunning() bool {
t.mutex.RLock()
defer t.mutex.RUnlock()

View File

@ -32,7 +32,7 @@ func (p *permission) state() permState {
return permState(atomic.LoadInt32((*int32)(&p.st)))
}
// Thread-safe permission map
// Thread-safe permission map.
type permissionMap struct {
permMap map[string]*permission
mutex sync.RWMutex
@ -43,6 +43,7 @@ func (m *permissionMap) insert(addr net.Addr, p *permission) bool {
defer m.mutex.Unlock()
p.addr = addr
m.permMap[ipnet.FingerprintAddr(addr)] = p
return true
}
@ -50,6 +51,7 @@ func (m *permissionMap) find(addr net.Addr) (*permission, bool) {
m.mutex.RLock()
defer m.mutex.RUnlock()
p, ok := m.permMap[ipnet.FingerprintAddr(addr)]
return p, ok
}
@ -67,6 +69,7 @@ func (m *permissionMap) addrs() []net.Addr {
for _, p := range m.permMap {
addrs = append(addrs, p.addr)
}
return addrs
}

Some files were not shown because too many files have changed in this diff Show More