Proxy: Refactor server APIs and expand RTMP test coverage. v7.0.147 (#4672)

This PR refactors the Go proxy server internals and significantly
expands RTMP/proxy verification coverage.

- Rename internal/protocol to internal/server to better describe the
package responsibility.
- Refactor proxy server constructors and types toward cleaner exported
interfaces:
      - NewRTMPServer
      - NewWebRTCServer
      - NewHTTPAPIServer
      - NewHTTPStreamServer
      - NewSystemAPI
  - Expose RTMP protocol interfaces for better testability:
      - Handshake
      - Protocol
      - Message
- AMF0 public interfaces such as Amf0Any, Amf0Number, Amf0String,
Amf0Object, etc.
- Add RTMP unit tests covering AMF0, handshake, protocol messages,
packet encoding/decoding, and API examples.
  - Add generated RTMP fakes for interface-based tests.
  - Add proxy E2E scripts for:
      - multi-origin memory load-balancer routing
      - Redis multi-proxy routing
- RTMP transmuxing verification across RTMP, HTTP-FLV, HLS, and optional
WebRTC WHEP
- Update OpenClaw/SRSBot development docs and memory to reflect the new
package layout, new verification scripts, and unsupported origin/edge
development scope.

---------

Co-authored-by: chatgpt-codex-connector[bot] <199175422+chatgpt-codex-connector[bot]@users.noreply.github.com>
This commit is contained in:
Winlin 2026-05-02 09:36:55 -04:00 committed by GitHub
parent d8696434cb
commit 3663a8e38f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 4074 additions and 292 deletions

View File

@ -215,7 +215,7 @@ The next-generation server (`cmd/` + `internal/`) is written in Go and maintaine
`internal/bootstrap` — Server startup and lifecycle orchestration. Sets up logging context, signal handlers, loads environment, installs force-quit timer, optionally starts pprof, initializes the load balancer (memory or Redis based on `PROXY_LOAD_BALANCER_TYPE`), then starts all six servers sequentially (RTMP, WebRTC, HTTP API, SRT, System API, HTTP Stream) and blocks until context is cancelled. Deferred `Close()` on each server ensures graceful shutdown.
`internal/protocol` — Protocol proxy servers. Each server accepts client connections, parses just enough of the protocol to extract the stream URL, picks a backend via the load balancer, and proxies traffic bidirectionally. Contains five proxy servers: (1) **RTMP proxy** (`rtmp.go`) — TCP listener, simple handshake, parses connect/publish/play to get stream URL, bidirectional RTMP message copying, stateless. (2) **HTTP stream proxy** (`http.go`) — serves static files, proxies HTTP-FLV/TS via reverse-proxy, proxies HLS m3u8 with `spbhid` rewriting so TS segment requests route to the same backend. (3) **WebRTC proxy** (`rtc.go`) — two-phase: WHIP/WHEP signaling (SDP rewrite to replace backend UDP port with proxy's) + UDP media transport (identifies connections by STUN ufrag, supports address migration), stateful. (4) **SRT proxy** (`srt.go`) — intercepts SRT 4-step handshake locally, parses stream ID on handshake 2, replays full handshake with backend, then proxies UDP bidirectionally, stateful per-connection. (5) **HTTP API + System API** (`api.go`) — HTTP API delegates WHIP/WHEP to WebRTC server; System API provides `/api/v1/srs/register` where backend SRS C++ servers register themselves so the load balancer knows about them.
`internal/server` — Proxy server implementations. Each server accepts client connections, parses just enough of the protocol to extract the stream URL, picks a backend via the load balancer, and proxies traffic bidirectionally. Contains five proxy servers: (1) **RTMP proxy** (`rtmp.go`) — TCP listener, simple handshake, parses connect/publish/play to get stream URL, bidirectional RTMP message copying, stateless. (2) **HTTP stream proxy** (`http.go`) — serves static files, proxies HTTP-FLV/TS via reverse-proxy, proxies HLS m3u8 with `spbhid` rewriting so TS segment requests route to the same backend. (3) **WebRTC proxy** (`rtc.go`) — two-phase: WHIP/WHEP signaling (SDP rewrite to replace backend UDP port with proxy's) + UDP media transport (identifies connections by STUN ufrag, supports address migration), stateful. (4) **SRT proxy** (`srt.go`) — intercepts SRT 4-step handshake locally, parses stream ID on handshake 2, replays full handshake with backend, then proxies UDP bidirectionally, stateful per-connection. (5) **HTTP API + System API** (`api.go`) — HTTP API delegates WHIP/WHEP to WebRTC server; System API provides `/api/v1/srs/register` where backend SRS C++ servers register themselves so the load balancer knows about them.
`internal/rtmp` — RTMP protocol implementation (parsing, not proxying). Full RTMP chunk stream and message protocol: simple handshake (C0/C1/C2), chunk stream reader/writer with all four format types, extended timestamp, message reassembly from chunks. Defines all RTMP message types, chunk stream IDs, and command names. Packet types include ConnectApp, CreateStream, Publish, Play, Call, SetChunkSize, WindowAcknowledgementSize, SetPeerBandwidth, UserControl. Uses Go generics (`ExpectPacket[T]`) to read until a specific packet type arrives. Also includes full AMF0 encoder/decoder supporting Number, Boolean, String, Object, Null, Undefined, EcmaArray, StrictArray, Date, LongString — with ordered key-value maps, auto-type-discovery, and safe type converters.
@ -301,6 +301,9 @@ The knowledge base (`memory/srs-*.md`) captures William's knowledge about SRS
- `proxy-load-balancer.md` — Load balancer design: memory vs Redis implementations, stream-to-server mapping, server health via heartbeats, protocol-specific state
- `proxy-origin-cluster.md` — Origin cluster tutorial: build proxy + SRS, configure multi-origin with proxy, stream publishing and playback verification
**Next-Generation Server API Examples** — Executable API documentation:
- `internal/rtmp/example_test.go` — RTMP API examples: AMF0, handshake, and protocol workflow
## Testing and Verification Structure
How to verify SRS works correctly.
@ -343,6 +346,13 @@ How to verify SRS works correctly.
- Reconnecting Load Test
- Janus
`.openclaw/skills/srs-develop/scripts/` — Go proxy verification scripts:
- `proxy-utest.sh` — Runs Go proxy unit tests with optional coverage.
- `proxy-e2e-test.sh` — Single-origin RTMP proxy E2E test.
- `proxy-e2e-cluster-test.sh` — Multi-origin memory load-balancer E2E test.
- `proxy-e2e-redis-test.sh` — Multi-proxy Redis load-balancer E2E test.
- `proxy-e2e-transmux-test.sh` — RTMP publish through proxy, then verify RTMP, HTTP-FLV, HLS, and WebRTC playback.
**Summary: The Key Differences**
| | Unit Tests | Black-box | E2E | Benchmark |

View File

@ -127,6 +127,12 @@ Which underlying transport (TCP or UDP) each protocol uses in SRS:
- **RTSP** — TCP. SRS only supports TCP transport (no UDP/RTP interleaved).
- **GB28181** — TCP. PS stream over TCP.
Related transport protocols that are important in the media industry but **not supported by SRS or Oryx**:
- **RIST** — UDP. Reliable Internet Stream Transport. Similar to SRT: a reliable, low-latency media transport over UDP, with retransmission and encryption options.
- **MoQ** — UDP/QUIC. Media over QUIC. An IETF effort for low-latency media ingest and delivery over QUIC, usually over UDP and optionally through WebTransport.
- **WebTransport** — UDP/QUIC, HTTP/3-based. A browser and network transport API/protocol that can carry media data, and one possible substrate for MoQ.
## Most Common Usage
The simplest way to use SRS: publish an RTMP stream and play it.
@ -292,4 +298,3 @@ Config files are in the `conf/` folder. Key files:
- Other files exist for specific features like clustering, DVR, or different protocols.
SRS also supports configuration via environment variables. This is especially useful for Docker and cloud-native deployments — you can set environment variables in YAML files or other platforms without needing a separate config file. It's convenient to copy and paste, making documentation clearer. In the SRS docs, environment variables are often used to show how to run SRS with different configurations.

View File

@ -1,6 +1,6 @@
---
name: srs-develop
description: Develop, modify, debug, and maintain the next-generation SRS media server written in Go — including the proxy, origin, and edge servers. This is the AI-maintained successor to the first-generation C++ SRS server. Use for all development tasks, for example, adding features, fixing bugs, refactoring code, understanding code architecture, reviewing changes, and writing tests for the Go codebase. NOT for end-user support, usage questions, configuration help, or learning how to use SRS — use the srs-support skill for those. Only activate when the task is explicitly about developing or modifying the Go SRS codebase.
description: Develop, modify, debug, and maintain the next-generation SRS media server written in Go. This is the AI-maintained successor to the first-generation C++ SRS server. Currently, planned changes are supported for the Go proxy server only; the next-generation Go origin and edge server workflows are not yet supported. Use for all development tasks, for example, adding features, fixing bugs, refactoring code, understanding code architecture, reviewing changes, and writing tests for the Go codebase. NOT for end-user support, usage questions, configuration help, or learning how to use SRS — use the srs-support skill for those. Only activate when the task is explicitly about developing or modifying the Go SRS codebase.
---
# SRS Development
@ -100,7 +100,7 @@ Do NOT attempt unsupported tasks.
| Service | Route To | Status |
|---|---|---|
| **Proxy server** | → [Proxy Server](#proxy-server) | ✅ Supported |
| **Origin server** | → [Origin Server](#origin-server) | ✅ Supported |
| **Origin server** | → [Origin Server](#origin-server) | ❌ Not yet supported |
| **Edge server** | → [Edge Server](#edge-server) | ❌ Not yet supported |
**If the routed service is not yet supported**, stop and tell the user:
@ -143,17 +143,30 @@ Only after the user confirms the routing do you proceed to Step 2.
```
bash scripts/proxy-utest.sh --coverage
```
4. Run the proxy E2E test (starts proxy + SRS origin, publishes RTMP, verifies playback):
4. Run the proxy E2E tests:
- Single-origin RTMP proxy test (starts proxy + one SRS origin, publishes RTMP, verifies playback):
```
bash scripts/proxy-e2e-test.sh
```
- Multi-origin cluster routing test (starts proxy + two SRS origins, publishes multiple streams, verifies streams are assigned to different origins):
```
bash scripts/proxy-e2e-cluster-test.sh
```
- Redis multi-proxy routing test (requires local Redis; starts two proxy instances with Redis LB, publishes through one proxy, verifies playback through the other):
```
bash scripts/proxy-e2e-redis-test.sh
```
- RTMP transmuxing test (starts proxy + one SRS origin, publishes RTMP, verifies RTMP/HTTP-FLV/HLS playback, and verifies WebRTC WHEP playback when `PROXY_TRANSMUX_TEST_RTC=on`):
```
bash scripts/proxy-e2e-transmux-test.sh
```
5. If any tests fail, fix the issues and re-run until all tests pass.
All script paths are relative to this skill's directory.
### Origin Server
*(workflow steps to be defined)*
**Not yet supported.** This refers to the next-generation Go origin server workflow. The first-generation C++ origin server still exists, but it is in maintenance mode and only bug fixes are accepted there.
### Edge Server

View File

@ -0,0 +1,341 @@
#!/bin/bash
# E2E test for RTMP proxy origin-cluster routing: starts one proxy + two SRS
# origins, publishes multiple RTMP streams, verifies playback via proxy, and
# verifies that different streams are assigned to different origin servers.
set -e
SCRIPT_DIR="$(cd -P "$(dirname "$0")" && pwd)"
# Navigate: scripts/ -> srs-develop/ -> skills/ -> .openclaw/ -> srs
WORKSPACE="$(cd -P "$SCRIPT_DIR/../../../.." && pwd)"
if [[ ! -f "$WORKSPACE/go.mod" ]]; then
echo "Error: go.mod not found in WORKSPACE: $WORKSPACE" >&2
exit 1
fi
# Ports — use high ports to avoid conflicts with running services.
# The proxy starts ALL servers, so we must assign unique ports for each.
PROXY_RTMP_PORT=11935
PROXY_HTTP_API_PORT=11985
PROXY_HTTP_SERVER_PORT=18080
PROXY_WEBRTC_PORT=18000
PROXY_SRT_PORT=20080
PROXY_SYSTEM_API_PORT=12025
SOURCE_FLV="$WORKSPACE/trunk/doc/source.flv"
SRS_BINARY="$WORKSPACE/trunk/objs/srs"
# Origin ports from origin1-for-proxy.conf and origin2-for-proxy.conf.
ORIGIN1_RTMP_PORT=19351
ORIGIN1_HTTP_PORT=8081
ORIGIN1_API_PORT=19851
ORIGIN1_RTC_PORT=8001
ORIGIN1_SRT_PORT=10081
ORIGIN2_RTMP_PORT=19352
ORIGIN2_HTTP_PORT=8082
ORIGIN2_API_PORT=19853
ORIGIN2_RTC_PORT=8002
ORIGIN2_SRT_PORT=10082
# PIDs to clean up on exit.
PROXY_PID=""
ORIGIN_PIDS=()
FFMPEG_PIDS=()
cleanup() {
echo ""
echo "=== Cleaning up ==="
for pid in $PROXY_PID "${ORIGIN_PIDS[@]}" "${FFMPEG_PIDS[@]}"; do
if [[ -n "$pid" ]] && kill -0 "$pid" 2>/dev/null; then
kill "$pid" 2>/dev/null || true
fi
done
sleep 1
for pid in $PROXY_PID "${ORIGIN_PIDS[@]}" "${FFMPEG_PIDS[@]}"; do
if [[ -n "$pid" ]] && kill -0 "$pid" 2>/dev/null; then
kill -9 "$pid" 2>/dev/null || true
fi
done
echo "Cleanup done."
}
trap cleanup EXIT
wait_for_http() {
local url=$1
local name=$2
local i
for i in $(seq 1 30); do
if curl -fsS --max-time 2 "$url" >/dev/null 2>&1; then
echo "$name is ready."
return 0
fi
sleep 1
done
echo "Error: $name is not ready after 30s: $url" >&2
return 1
}
origin_has_stream() {
local api_port=$1
local stream=$2
curl -fsS --max-time 3 "http://127.0.0.1:$api_port/api/v1/streams/" 2>/dev/null | grep -q "$stream"
}
detect_origin_for_stream() {
local stream=$1
local i
for i in $(seq 1 10); do
local on_origin1=0
local on_origin2=0
if origin_has_stream "$ORIGIN1_API_PORT" "$stream"; then
on_origin1=1
fi
if origin_has_stream "$ORIGIN2_API_PORT" "$stream"; then
on_origin2=1
fi
if [[ $on_origin1 -eq 1 && $on_origin2 -eq 0 ]]; then
echo "origin1"
return 0
fi
if [[ $on_origin1 -eq 0 && $on_origin2 -eq 1 ]]; then
echo "origin2"
return 0
fi
if [[ $on_origin1 -eq 1 && $on_origin2 -eq 1 ]]; then
echo "Error: stream $stream exists on both origins; expected exactly one owner" >&2
return 1
fi
sleep 1
done
echo "Error: stream $stream was not found on either origin" >&2
return 1
}
verify_probe_has_av() {
local url=$1
local label=$2
local probe_output
probe_output=$(ffprobe -v error -rw_timeout 5000000 -show_streams "$url" 2>&1 || true)
if ! echo "$probe_output" | grep -q "codec_type=video"; then
echo "FAIL: No video stream detected for $label." >&2
echo "ffprobe output:" >&2
echo "$probe_output" >&2
exit 1
fi
if ! echo "$probe_output" | grep -q "codec_type=audio"; then
echo "FAIL: No audio stream detected for $label." >&2
echo "ffprobe output:" >&2
echo "$probe_output" >&2
exit 1
fi
echo "PASS: Audio/video detected for $label."
}
echo "=== E2E RTMP Proxy Origin Cluster Test ==="
echo "Workspace: $WORKSPACE"
echo ""
# --- Pre-checks ---
if [[ ! -f "$SOURCE_FLV" ]]; then
echo "Error: test source not found: $SOURCE_FLV" >&2
exit 1
fi
if ! command -v ffmpeg &>/dev/null; then
echo "Error: ffmpeg not found in PATH" >&2
exit 1
fi
if ! command -v ffprobe &>/dev/null; then
echo "Error: ffprobe not found in PATH" >&2
exit 1
fi
if ! command -v curl &>/dev/null; then
echo "Error: curl not found in PATH" >&2
exit 1
fi
# --- Step 0: Clean up stale state ---
# Remove stale SRS PID files that prevent restart.
rm -f "$WORKSPACE/trunk/objs/origin1.pid" "$WORKSPACE/trunk/objs/origin2.pid"
# Kill any leftover processes on our ports (proxy + origins).
ALL_PORTS="$PROXY_RTMP_PORT $PROXY_HTTP_API_PORT $PROXY_HTTP_SERVER_PORT $PROXY_WEBRTC_PORT $PROXY_SRT_PORT $PROXY_SYSTEM_API_PORT"
ALL_PORTS="$ALL_PORTS $ORIGIN1_RTMP_PORT $ORIGIN1_HTTP_PORT $ORIGIN1_API_PORT $ORIGIN1_RTC_PORT $ORIGIN1_SRT_PORT"
ALL_PORTS="$ALL_PORTS $ORIGIN2_RTMP_PORT $ORIGIN2_HTTP_PORT $ORIGIN2_API_PORT $ORIGIN2_RTC_PORT $ORIGIN2_SRT_PORT"
for port in $ALL_PORTS; do
lsof -ti :"$port" 2>/dev/null | xargs kill 2>/dev/null || true
done
sleep 1
# --- Step 1: Build proxy ---
echo "=== Step 1: Building proxy ==="
cd "$WORKSPACE"
make -s 2>&1
echo "Proxy built: $WORKSPACE/bin/srs-proxy"
# --- Step 2: Build SRS origins (if not already built) ---
if [[ ! -f "$SRS_BINARY" ]]; then
echo "=== Step 2: Building SRS origins ==="
cd "$WORKSPACE/trunk"
./configure && make 2>&1 | tail -3
echo "SRS origins built: $SRS_BINARY"
else
echo "=== Step 2: SRS origins already built ==="
fi
# --- Step 3: Start proxy ---
echo "=== Step 3: Starting proxy (RTMP :$PROXY_RTMP_PORT, System API :$PROXY_SYSTEM_API_PORT) ==="
cd "$WORKSPACE"
env PROXY_RTMP_SERVER=$PROXY_RTMP_PORT \
PROXY_HTTP_API=$PROXY_HTTP_API_PORT \
PROXY_HTTP_SERVER=$PROXY_HTTP_SERVER_PORT \
PROXY_WEBRTC_SERVER=$PROXY_WEBRTC_PORT \
PROXY_SRT_SERVER=$PROXY_SRT_PORT \
PROXY_SYSTEM_API=$PROXY_SYSTEM_API_PORT \
PROXY_LOAD_BALANCER_TYPE=memory \
./bin/srs-proxy >/tmp/srs-proxy-cluster-e2e.log 2>&1 &
PROXY_PID=$!
echo "Proxy PID: $PROXY_PID"
wait_for_http "http://127.0.0.1:$PROXY_SYSTEM_API_PORT/api/v1/versions" "Proxy System API"
if ! kill -0 "$PROXY_PID" 2>/dev/null; then
echo "Error: proxy failed to start. Logs:" >&2
cat /tmp/srs-proxy-cluster-e2e.log >&2
exit 1
fi
echo "Proxy started."
# --- Step 4: Start two SRS origins ---
echo "=== Step 4: Starting two SRS origins ==="
ulimit -n 10000 2>/dev/null || true
cd "$WORKSPACE/trunk"
./objs/srs -c conf/origin1-for-proxy.conf >/tmp/srs-origin1-cluster-e2e.log 2>&1 &
ORIGIN_PIDS+=($!)
echo "SRS origin1 PID: ${ORIGIN_PIDS[0]}"
./objs/srs -c conf/origin2-for-proxy.conf >/tmp/srs-origin2-cluster-e2e.log 2>&1 &
ORIGIN_PIDS+=($!)
echo "SRS origin2 PID: ${ORIGIN_PIDS[1]}"
wait_for_http "http://127.0.0.1:$ORIGIN1_API_PORT/api/v1/versions" "SRS origin1 HTTP API"
wait_for_http "http://127.0.0.1:$ORIGIN2_API_PORT/api/v1/versions" "SRS origin2 HTTP API"
# Wait for both SRS origins to register with proxy (heartbeat interval is 9s).
echo "Waiting for both SRS origins to register with proxy (up to 20s)..."
for i in $(seq 1 20); do
registered=$(grep -c "Register SRS media server" /tmp/srs-proxy-cluster-e2e.log 2>/dev/null || true)
if [[ $registered -ge 2 ]]; then
echo "Both origins registered."
break
fi
sleep 1
done
registered=$(grep -c "Register SRS media server" /tmp/srs-proxy-cluster-e2e.log 2>/dev/null || true)
if [[ $registered -lt 2 ]]; then
echo "Error: expected two origin registrations, got $registered. Proxy logs:" >&2
cat /tmp/srs-proxy-cluster-e2e.log >&2
exit 1
fi
for pid in "${ORIGIN_PIDS[@]}"; do
if ! kill -0 "$pid" 2>/dev/null; then
echo "Error: SRS origin failed to start. Origin1 logs:" >&2
cat /tmp/srs-origin1-cluster-e2e.log >&2
echo "Origin2 logs:" >&2
cat /tmp/srs-origin2-cluster-e2e.log >&2
exit 1
fi
done
echo "Two SRS origins started and registered."
# --- Step 5: Publish RTMP streams until both origins own at least one stream ---
echo "=== Step 5: Publishing multiple RTMP streams to proxy ==="
STREAM_PREFIX="cluster$(date +%s)"
STREAMS=()
STREAM_ORIGINS=()
origin1_count=0
origin2_count=0
# The memory load balancer picks a random healthy origin for each new stream and
# keeps that stream sticky. Publish several unique streams to verify distribution
# while keeping the test extremely unlikely to fail from random selection alone.
for i in $(seq 1 20); do
stream="${STREAM_PREFIX}_$i"
STREAMS+=("$stream")
ffmpeg -stream_loop -1 -re -i "$SOURCE_FLV" -c copy -f flv \
"rtmp://localhost:$PROXY_RTMP_PORT/live/$stream" >/tmp/srs-ffmpeg-cluster-e2e-$i.log 2>&1 &
FFMPEG_PIDS+=($!)
echo "Started publisher for live/$stream, PID: ${FFMPEG_PIDS[$((${#FFMPEG_PIDS[@]} - 1))]}"
sleep 3
if ! kill -0 "${FFMPEG_PIDS[$((${#FFMPEG_PIDS[@]} - 1))]}" 2>/dev/null; then
echo "Error: FFmpeg publisher failed for $stream. Logs:" >&2
cat "/tmp/srs-ffmpeg-cluster-e2e-$i.log" >&2
exit 1
fi
owner=$(detect_origin_for_stream "$stream")
STREAM_ORIGINS+=("$owner")
echo "Stream live/$stream is owned by $owner."
if [[ "$owner" == "origin1" ]]; then
origin1_count=$((origin1_count + 1))
else
origin2_count=$((origin2_count + 1))
fi
if [[ $origin1_count -gt 0 && $origin2_count -gt 0 ]]; then
break
fi
done
if [[ $origin1_count -eq 0 || $origin2_count -eq 0 ]]; then
echo "FAIL: streams were not distributed to both origins." >&2
echo "origin1_count=$origin1_count, origin2_count=$origin2_count" >&2
exit 1
fi
echo "PASS: stream distribution detected: origin1=$origin1_count, origin2=$origin2_count."
# --- Step 6: Verify playback via proxy and direct owning origins ---
echo "=== Step 6: Verifying playback and sticky origin ownership ==="
for i in "${!STREAMS[@]}"; do
stream="${STREAMS[$i]}"
owner="${STREAM_ORIGINS[$i]}"
# Verify proxy playback works for every published stream.
verify_probe_has_av "rtmp://localhost:$PROXY_RTMP_PORT/live/$stream" "proxy live/$stream"
# Verify the stream is on exactly one origin, and verify direct playback from
# the owning origin to prove the proxy really published to that backend.
current_owner=$(detect_origin_for_stream "$stream")
if [[ "$current_owner" != "$owner" ]]; then
echo "FAIL: stream live/$stream moved from $owner to $current_owner; expected sticky routing." >&2
exit 1
fi
if [[ "$owner" == "origin1" ]]; then
verify_probe_has_av "rtmp://localhost:$ORIGIN1_RTMP_PORT/live/$stream" "origin1 live/$stream"
else
verify_probe_has_av "rtmp://localhost:$ORIGIN2_RTMP_PORT/live/$stream" "origin2 live/$stream"
fi
done
echo ""
echo "=== E2E RTMP Proxy Origin Cluster Test PASSED ==="

View File

@ -0,0 +1,324 @@
#!/bin/bash
# E2E test for RTMP proxy Redis load balancer: starts two proxy instances with
# Redis-backed shared state + one SRS origin. The origin registers through proxy A,
# publishing goes through proxy A, and playback goes through proxy B. Playback must
# succeed because proxy B resolves the stream-to-origin mapping from Redis.
set -e
SCRIPT_DIR="$(cd -P "$(dirname "$0")" && pwd)"
# Navigate: scripts/ -> srs-develop/ -> skills/ -> .openclaw/ -> srs
WORKSPACE="$(cd -P "$SCRIPT_DIR/../../../.." && pwd)"
if [[ ! -f "$WORKSPACE/go.mod" ]]; then
echo "Error: go.mod not found in WORKSPACE: $WORKSPACE" >&2
exit 1
fi
# Ports — use high ports to avoid conflicts with running services.
# Each proxy starts ALL servers, so each proxy needs a unique full port set.
PROXY_A_RTMP_PORT=11935
PROXY_A_HTTP_API_PORT=11985
PROXY_A_HTTP_SERVER_PORT=18080
PROXY_A_WEBRTC_PORT=18000
PROXY_A_SRT_PORT=20080
PROXY_A_SYSTEM_API_PORT=12025
PROXY_B_RTMP_PORT=11936
PROXY_B_HTTP_API_PORT=11986
PROXY_B_HTTP_SERVER_PORT=18081
PROXY_B_WEBRTC_PORT=18001
PROXY_B_SRT_PORT=20081
PROXY_B_SYSTEM_API_PORT=12026
REDIS_HOST="${PROXY_REDIS_HOST:-127.0.0.1}"
REDIS_PORT="${PROXY_REDIS_PORT:-6379}"
REDIS_PASSWORD="${PROXY_REDIS_PASSWORD:-}"
REDIS_DB="${PROXY_REDIS_DB:-0}"
PYTHON_BIN="${PYTHON_BIN:-python3}"
SOURCE_FLV="$WORKSPACE/trunk/doc/source.flv"
SRS_BINARY="$WORKSPACE/trunk/objs/srs"
TEST_STREAM_URL="__defaultVhost__/live/livestream"
# PIDs to clean up on exit.
PROXY_A_PID=""
PROXY_B_PID=""
ORIGIN_PID=""
FFMPEG_PID=""
redis_cli() {
if [[ -n "$REDIS_PASSWORD" ]]; then
redis-cli -h "$REDIS_HOST" -p "$REDIS_PORT" -a "$REDIS_PASSWORD" -n "$REDIS_DB" "$@"
else
redis-cli -h "$REDIS_HOST" -p "$REDIS_PORT" -n "$REDIS_DB" "$@"
fi
}
cleanup_redis_state() {
# Remove only the Redis records created by this E2E test. Never flush the DB,
# and never delete every srs-proxy-* key because the same Redis DB may be used
# by another proxy/origin-cluster test or by a developer's local proxy.
if ! command -v redis-cli &>/dev/null; then
return
fi
if ! command -v "$PYTHON_BIN" &>/dev/null; then
echo "Skip Redis cleanup: $PYTHON_BIN is not available"
return
fi
if ! redis_cli ping 2>/dev/null | grep -q "PONG"; then
echo "Skip Redis cleanup: Redis is not available at $REDIS_HOST:$REDIS_PORT db=$REDIS_DB"
return
fi
local count=0
local stream_key="srs-proxy-url:$TEST_STREAM_URL"
if [[ "$(redis_cli exists "$stream_key" 2>/dev/null || echo 0)" != "0" ]]; then
redis_cli del "$stream_key" >/dev/null 2>&1 || true
count=$((count + 1))
fi
# The origin server generates its server key from runtime IDs, so discover only
# server records that match this test origin's identity and configured ports.
local server_keys=()
local key value
while IFS= read -r key; do
[[ -z "$key" ]] && continue
value="$(redis_cli get "$key" 2>/dev/null || true)"
if [[ "$value" == *'"device_id":"origin1"'* && \
"$value" == *'"rtmp":["19351"]'* && \
"$value" == *'"http":["8081"]'* && \
"$value" == *'"api":["19851"]'* ]]; then
server_keys+=("$key")
redis_cli del "$key" >/dev/null 2>&1 || true
count=$((count + 1))
fi
done < <(redis_cli --scan --pattern 'srs-proxy-server:*' 2>/dev/null || true)
# Keep the shared server index, but remove only the test origin server keys.
if [[ ${#server_keys[@]} -gt 0 ]]; then
local servers_json updated_json
servers_json="$(redis_cli get srs-proxy-all-servers 2>/dev/null || true)"
if [[ -n "$servers_json" ]]; then
updated_json="$($PYTHON_BIN - "$servers_json" "${server_keys[@]}" <<'PY'
import json, sys
servers = json.loads(sys.argv[1]) if sys.argv[1] else []
remove = set(sys.argv[2:])
servers = [server for server in servers if server not in remove]
print(json.dumps(servers, separators=(",", ":")))
PY
)"
if [[ "$updated_json" == "[]" ]]; then
redis_cli del srs-proxy-all-servers >/dev/null 2>&1 || true
else
redis_cli set srs-proxy-all-servers "$updated_json" >/dev/null 2>&1 || true
fi
fi
fi
echo "Cleaned $count Redis proxy test key(s)."
}
cleanup() {
echo ""
echo "=== Cleaning up ==="
for pid in $PROXY_A_PID $PROXY_B_PID $ORIGIN_PID $FFMPEG_PID; do
if kill -0 "$pid" 2>/dev/null; then
kill "$pid" 2>/dev/null || true
fi
done
sleep 1
for pid in $PROXY_A_PID $PROXY_B_PID $ORIGIN_PID $FFMPEG_PID; do
if kill -0 "$pid" 2>/dev/null; then
kill -9 "$pid" 2>/dev/null || true
fi
done
cleanup_redis_state
echo "Cleanup done."
}
trap cleanup EXIT
echo "=== E2E RTMP Proxy Redis Load Balancer Test ==="
echo "Workspace: $WORKSPACE"
echo "Redis: $REDIS_HOST:$REDIS_PORT db=$REDIS_DB"
echo ""
# --- Pre-checks ---
if [[ ! -f "$SOURCE_FLV" ]]; then
echo "Error: test source not found: $SOURCE_FLV" >&2
exit 1
fi
if ! command -v ffmpeg &>/dev/null; then
echo "Error: ffmpeg not found in PATH" >&2
exit 1
fi
if ! command -v ffprobe &>/dev/null; then
echo "Error: ffprobe not found in PATH" >&2
exit 1
fi
if ! command -v redis-cli &>/dev/null; then
echo "Error: redis-cli not found in PATH" >&2
echo "Install Redis on macOS with: brew install redis" >&2
exit 1
fi
if ! redis_cli ping 2>/dev/null | grep -q "PONG"; then
echo "Error: Redis is not available at $REDIS_HOST:$REDIS_PORT db=$REDIS_DB" >&2
echo "Start Redis on macOS with: brew services start redis" >&2
echo "Or run a foreground Redis with: redis-server" >&2
exit 1
fi
# Origin ports (from origin1-for-proxy.conf).
ORIGIN_RTMP_PORT=19351
ORIGIN_HTTP_PORT=8081
ORIGIN_API_PORT=19851
ORIGIN_RTC_PORT=8001
ORIGIN_SRT_PORT=10081
# --- Step 0: Clean up stale state ---
# Remove stale SRS PID file that prevents restart.
rm -f "$WORKSPACE/trunk/objs/origin1.pid"
cleanup_redis_state
# Kill any leftover processes on our ports (proxy A + proxy B + origin).
ALL_PORTS="$PROXY_A_RTMP_PORT $PROXY_A_HTTP_API_PORT $PROXY_A_HTTP_SERVER_PORT $PROXY_A_WEBRTC_PORT $PROXY_A_SRT_PORT $PROXY_A_SYSTEM_API_PORT $PROXY_B_RTMP_PORT $PROXY_B_HTTP_API_PORT $PROXY_B_HTTP_SERVER_PORT $PROXY_B_WEBRTC_PORT $PROXY_B_SRT_PORT $PROXY_B_SYSTEM_API_PORT $ORIGIN_RTMP_PORT $ORIGIN_HTTP_PORT $ORIGIN_API_PORT $ORIGIN_RTC_PORT $ORIGIN_SRT_PORT"
for port in $ALL_PORTS; do
lsof -ti :"$port" 2>/dev/null | xargs kill 2>/dev/null || true
done
sleep 1
# --- Step 1: Build proxy ---
echo "=== Step 1: Building proxy ==="
cd "$WORKSPACE"
make -s 2>&1
echo "Proxy built: $WORKSPACE/bin/srs-proxy"
# --- Step 2: Build SRS origin (if not already built) ---
if [[ ! -f "$SRS_BINARY" ]]; then
echo "=== Step 2: Building SRS origin ==="
cd "$WORKSPACE/trunk"
./configure && make 2>&1 | tail -3
echo "SRS origin built: $SRS_BINARY"
else
echo "=== Step 2: SRS origin already built ==="
fi
# --- Step 3: Start proxy A ---
echo "=== Step 3: Starting proxy A (RTMP :$PROXY_A_RTMP_PORT, System API :$PROXY_A_SYSTEM_API_PORT) ==="
cd "$WORKSPACE"
env PROXY_RTMP_SERVER=$PROXY_A_RTMP_PORT \
PROXY_HTTP_API=$PROXY_A_HTTP_API_PORT \
PROXY_HTTP_SERVER=$PROXY_A_HTTP_SERVER_PORT \
PROXY_WEBRTC_SERVER=$PROXY_A_WEBRTC_PORT \
PROXY_SRT_SERVER=$PROXY_A_SRT_PORT \
PROXY_SYSTEM_API=$PROXY_A_SYSTEM_API_PORT \
PROXY_LOAD_BALANCER_TYPE=redis \
PROXY_REDIS_HOST="$REDIS_HOST" \
PROXY_REDIS_PORT="$REDIS_PORT" \
PROXY_REDIS_PASSWORD="$REDIS_PASSWORD" \
PROXY_REDIS_DB="$REDIS_DB" \
./bin/srs-proxy >/tmp/srs-proxy-redis-a-e2e.log 2>&1 &
PROXY_A_PID=$!
echo "Proxy A PID: $PROXY_A_PID"
sleep 1
if ! kill -0 "$PROXY_A_PID" 2>/dev/null; then
echo "Error: proxy A failed to start. Logs:" >&2
cat /tmp/srs-proxy-redis-a-e2e.log >&2
exit 1
fi
echo "Proxy A started."
# --- Step 4: Start proxy B ---
echo "=== Step 4: Starting proxy B (RTMP :$PROXY_B_RTMP_PORT, System API :$PROXY_B_SYSTEM_API_PORT) ==="
cd "$WORKSPACE"
env PROXY_RTMP_SERVER=$PROXY_B_RTMP_PORT \
PROXY_HTTP_API=$PROXY_B_HTTP_API_PORT \
PROXY_HTTP_SERVER=$PROXY_B_HTTP_SERVER_PORT \
PROXY_WEBRTC_SERVER=$PROXY_B_WEBRTC_PORT \
PROXY_SRT_SERVER=$PROXY_B_SRT_PORT \
PROXY_SYSTEM_API=$PROXY_B_SYSTEM_API_PORT \
PROXY_LOAD_BALANCER_TYPE=redis \
PROXY_REDIS_HOST="$REDIS_HOST" \
PROXY_REDIS_PORT="$REDIS_PORT" \
PROXY_REDIS_PASSWORD="$REDIS_PASSWORD" \
PROXY_REDIS_DB="$REDIS_DB" \
./bin/srs-proxy >/tmp/srs-proxy-redis-b-e2e.log 2>&1 &
PROXY_B_PID=$!
echo "Proxy B PID: $PROXY_B_PID"
sleep 1
if ! kill -0 "$PROXY_B_PID" 2>/dev/null; then
echo "Error: proxy B failed to start. Logs:" >&2
cat /tmp/srs-proxy-redis-b-e2e.log >&2
exit 1
fi
echo "Proxy B started."
# --- Step 5: Start SRS origin ---
echo "=== Step 5: Starting SRS origin ==="
ulimit -n 10000 2>/dev/null || true
cd "$WORKSPACE/trunk"
./objs/srs -c conf/origin1-for-proxy.conf >/tmp/srs-origin-redis-e2e.log 2>&1 &
ORIGIN_PID=$!
echo "SRS origin PID: $ORIGIN_PID"
# Wait for SRS to start and register with proxy A (heartbeat interval is 9s).
echo "Waiting for SRS origin to register with proxy A and Redis (up to 15s)..."
sleep 12
if ! kill -0 "$ORIGIN_PID" 2>/dev/null; then
echo "Error: SRS origin failed to start. Logs:" >&2
cat /tmp/srs-origin-redis-e2e.log >&2
exit 1
fi
if ! redis_cli --scan --pattern 'srs-proxy-server:*' | grep -q 'srs-proxy-server:'; then
echo "Error: SRS origin did not register in Redis. Proxy A logs:" >&2
cat /tmp/srs-proxy-redis-a-e2e.log >&2
exit 1
fi
echo "SRS origin started and registered in Redis."
# --- Step 6: Publish RTMP stream to proxy A ---
echo "=== Step 6: Publishing RTMP stream to proxy A ==="
ffmpeg -stream_loop -1 -re -i "$SOURCE_FLV" -c copy -f flv \
"rtmp://localhost:$PROXY_A_RTMP_PORT/live/livestream" >/tmp/srs-ffmpeg-redis-e2e.log 2>&1 &
FFMPEG_PID=$!
echo "FFmpeg publisher PID: $FFMPEG_PID"
# Wait for stream to stabilize.
sleep 5
if ! kill -0 "$FFMPEG_PID" 2>/dev/null; then
echo "Error: FFmpeg publisher failed. Logs:" >&2
cat /tmp/srs-ffmpeg-redis-e2e.log >&2
exit 1
fi
echo "Stream publishing through proxy A."
# --- Step 7: Verify RTMP playback through proxy B ---
echo "=== Step 7: Verifying RTMP playback through proxy B ==="
PROBE_OUTPUT=$(ffprobe -v error -show_streams \
"rtmp://localhost:$PROXY_B_RTMP_PORT/live/livestream" 2>&1 || true)
if echo "$PROBE_OUTPUT" | grep -q "codec_type=video"; then
echo "PASS: Video stream detected through proxy B."
else
echo "FAIL: No video stream detected through proxy B." >&2
echo "ffprobe output:" >&2
echo "$PROBE_OUTPUT" >&2
echo "Proxy B logs:" >&2
cat /tmp/srs-proxy-redis-b-e2e.log >&2
exit 1
fi
if echo "$PROBE_OUTPUT" | grep -q "codec_type=audio"; then
echo "PASS: Audio stream detected through proxy B."
else
echo "FAIL: No audio stream detected through proxy B." >&2
echo "ffprobe output:" >&2
echo "$PROBE_OUTPUT" >&2
echo "Proxy B logs:" >&2
cat /tmp/srs-proxy-redis-b-e2e.log >&2
exit 1
fi
echo ""
echo "=== E2E RTMP Proxy Redis Load Balancer Test PASSED ==="

View File

@ -0,0 +1,261 @@
#!/bin/bash
# E2E test for RTMP-to-multiple-protocol transmuxing through the proxy:
# starts one proxy with memory load balancer + one SRS origin, publishes one
# RTMP stream, then verifies RTMP, HTTP-FLV, HLS, and optional WebRTC WHEP
# playback through the proxy.
set -e
SCRIPT_DIR="$(cd -P "$(dirname "$0")" && pwd)"
# Navigate: scripts/ -> srs-develop/ -> skills/ -> .openclaw/ -> srs
WORKSPACE="$(cd -P "$SCRIPT_DIR/../../../.." && pwd)"
if [[ ! -f "$WORKSPACE/go.mod" ]]; then
echo "Error: go.mod not found in WORKSPACE: $WORKSPACE" >&2
exit 1
fi
# Ports — use the same high ports as proxy-e2e-test.sh.
# The proxy starts ALL servers, so we must assign unique ports for each.
PROXY_RTMP_PORT=11935
PROXY_HTTP_API_PORT=11985
PROXY_HTTP_SERVER_PORT=18080
PROXY_WEBRTC_PORT=18000
PROXY_SRT_PORT=20080
PROXY_SYSTEM_API_PORT=12025
# Origin ports (from origin1-for-proxy.conf).
ORIGIN_RTMP_PORT=19351
ORIGIN_HTTP_PORT=8081
ORIGIN_API_PORT=19851
ORIGIN_RTC_PORT=8001
ORIGIN_SRT_PORT=10081
SOURCE_FLV="$WORKSPACE/trunk/doc/source.flv"
SRS_BINARY="$WORKSPACE/trunk/objs/srs"
SRS_TEST_BINARY="$WORKSPACE/trunk/3rdparty/srs-bench/objs/srs_test"
STREAM_URL="live/livestream"
# WebRTC requires the srs-bench regression test binary. Keep it enabled by
# default because it exercises the proxy WHEP API path; set to "off" to skip.
PROXY_TRANSMUX_TEST_RTC="${PROXY_TRANSMUX_TEST_RTC:-on}"
# PIDs to clean up on exit.
PROXY_PID=""
ORIGIN_PID=""
FFMPEG_PID=""
cleanup() {
echo ""
echo "=== Cleaning up ==="
for pid in $PROXY_PID $ORIGIN_PID $FFMPEG_PID; do
if [[ -n "$pid" ]] && kill -0 "$pid" 2>/dev/null; then
kill "$pid" 2>/dev/null || true
fi
done
sleep 1
for pid in $PROXY_PID $ORIGIN_PID $FFMPEG_PID; do
if [[ -n "$pid" ]] && kill -0 "$pid" 2>/dev/null; then
kill -9 "$pid" 2>/dev/null || true
fi
done
echo "Cleanup done."
}
trap cleanup EXIT
probe_has_audio_video() {
local name="$1"
local url="$2"
echo "Verifying $name playback: $url"
local output
output=$(ffprobe -v error -show_streams "$url" 2>&1 || true)
if echo "$output" | grep -q "codec_type=video"; then
echo "PASS: $name video stream detected."
else
echo "FAIL: $name no video stream detected." >&2
echo "ffprobe output:" >&2
echo "$output" >&2
exit 1
fi
if echo "$output" | grep -q "codec_type=audio"; then
echo "PASS: $name audio stream detected."
else
echo "FAIL: $name no audio stream detected." >&2
echo "ffprobe output:" >&2
echo "$output" >&2
exit 1
fi
}
wait_for_hls_playlist() {
local url="$1"
local deadline=45
echo "Waiting for HLS playlist to be generated (up to ${deadline}s): $url"
for ((i = 1; i <= deadline; i++)); do
if curl -fsS "$url" 2>/dev/null | grep -q "#EXTM3U"; then
echo "HLS playlist is ready."
return
fi
sleep 1
done
echo "FAIL: HLS playlist was not generated in ${deadline}s." >&2
echo "Last HLS response:" >&2
curl -v "$url" 2>&1 || true
exit 1
}
echo "=== E2E RTMP Transmux Proxy Test ==="
echo "Workspace: $WORKSPACE"
echo "Stream: $STREAM_URL"
echo ""
# --- Pre-checks ---
if [[ ! -f "$SOURCE_FLV" ]]; then
echo "Error: test source not found: $SOURCE_FLV" >&2
exit 1
fi
if ! command -v ffmpeg &>/dev/null; then
echo "Error: ffmpeg not found in PATH" >&2
exit 1
fi
if ! command -v ffprobe &>/dev/null; then
echo "Error: ffprobe not found in PATH" >&2
exit 1
fi
if ! command -v curl &>/dev/null; then
echo "Error: curl not found in PATH" >&2
exit 1
fi
# --- Step 0: Clean up stale state ---
rm -f "$WORKSPACE/trunk/objs/origin1.pid"
ALL_PORTS="$PROXY_RTMP_PORT $PROXY_HTTP_API_PORT $PROXY_HTTP_SERVER_PORT $PROXY_WEBRTC_PORT $PROXY_SRT_PORT $PROXY_SYSTEM_API_PORT $ORIGIN_RTMP_PORT $ORIGIN_HTTP_PORT $ORIGIN_API_PORT $ORIGIN_RTC_PORT $ORIGIN_SRT_PORT"
for port in $ALL_PORTS; do
lsof -ti :"$port" 2>/dev/null | xargs kill 2>/dev/null || true
done
sleep 1
# --- Step 1: Build proxy ---
echo "=== Step 1: Building proxy ==="
cd "$WORKSPACE"
make -s 2>&1
echo "Proxy built: $WORKSPACE/bin/srs-proxy"
# --- Step 2: Build SRS origin (if not already built) ---
if [[ ! -f "$SRS_BINARY" ]]; then
echo "=== Step 2: Building SRS origin ==="
cd "$WORKSPACE/trunk"
./configure && make 2>&1 | tail -3
echo "SRS origin built: $SRS_BINARY"
else
echo "=== Step 2: SRS origin already built ==="
fi
# --- Step 3: Build WebRTC regression tool (if enabled and needed) ---
if [[ "$PROXY_TRANSMUX_TEST_RTC" == "on" && ! -x "$SRS_TEST_BINARY" ]]; then
echo "=== Step 3: Building WebRTC regression tool ==="
cd "$WORKSPACE/trunk/3rdparty/srs-bench"
make ./objs/srs_test
echo "WebRTC regression tool built: $SRS_TEST_BINARY"
else
echo "=== Step 3: WebRTC regression tool build skipped ==="
fi
# --- Step 4: Start proxy ---
echo "=== Step 4: Starting proxy (memory LB) ==="
cd "$WORKSPACE"
env PROXY_RTMP_SERVER=$PROXY_RTMP_PORT \
PROXY_HTTP_API=$PROXY_HTTP_API_PORT \
PROXY_HTTP_SERVER=$PROXY_HTTP_SERVER_PORT \
PROXY_WEBRTC_SERVER=$PROXY_WEBRTC_PORT \
PROXY_SRT_SERVER=$PROXY_SRT_PORT \
PROXY_SYSTEM_API=$PROXY_SYSTEM_API_PORT \
PROXY_LOAD_BALANCER_TYPE=memory \
./bin/srs-proxy >/tmp/srs-proxy-transmux-e2e.log 2>&1 &
PROXY_PID=$!
echo "Proxy PID: $PROXY_PID"
sleep 1
if ! kill -0 "$PROXY_PID" 2>/dev/null; then
echo "Error: proxy failed to start. Logs:" >&2
cat /tmp/srs-proxy-transmux-e2e.log >&2
exit 1
fi
echo "Proxy started."
# --- Step 5: Start SRS origin ---
echo "=== Step 5: Starting SRS origin ==="
ulimit -n 10000 2>/dev/null || true
cd "$WORKSPACE/trunk"
./objs/srs -c conf/origin1-for-proxy.conf >/tmp/srs-origin-transmux-e2e.log 2>&1 &
ORIGIN_PID=$!
echo "SRS origin PID: $ORIGIN_PID"
# Wait for SRS to start and register with proxy (heartbeat interval is 9s).
echo "Waiting for SRS origin to register with proxy (up to 15s)..."
sleep 12
if ! kill -0 "$ORIGIN_PID" 2>/dev/null; then
echo "Error: SRS origin failed to start. Logs:" >&2
cat /tmp/srs-origin-transmux-e2e.log >&2
exit 1
fi
echo "SRS origin started and registered."
# --- Step 6: Publish RTMP stream ---
echo "=== Step 6: Publishing RTMP stream to proxy ==="
ffmpeg -stream_loop -1 -re -i "$SOURCE_FLV" -c copy -f flv \
"rtmp://localhost:$PROXY_RTMP_PORT/$STREAM_URL" >/tmp/srs-ffmpeg-transmux-e2e.log 2>&1 &
FFMPEG_PID=$!
echo "FFmpeg publisher PID: $FFMPEG_PID"
# Wait for stream to stabilize and for the origin to start muxing HTTP/HLS.
sleep 5
if ! kill -0 "$FFMPEG_PID" 2>/dev/null; then
echo "Error: FFmpeg publisher failed. Logs:" >&2
cat /tmp/srs-ffmpeg-transmux-e2e.log >&2
exit 1
fi
echo "Stream publishing."
# --- Step 7: Verify RTMP playback ---
echo "=== Step 7: Verifying RTMP playback via proxy ==="
probe_has_audio_video "RTMP" "rtmp://localhost:$PROXY_RTMP_PORT/$STREAM_URL"
# --- Step 8: Verify HTTP-FLV playback ---
echo "=== Step 8: Verifying HTTP-FLV playback via proxy ==="
probe_has_audio_video "HTTP-FLV" "http://localhost:$PROXY_HTTP_SERVER_PORT/$STREAM_URL.flv"
# --- Step 9: Verify HLS playback ---
echo "=== Step 9: Verifying HLS playback via proxy ==="
HLS_URL="http://localhost:$PROXY_HTTP_SERVER_PORT/$STREAM_URL.m3u8"
wait_for_hls_playlist "$HLS_URL"
probe_has_audio_video "HLS" "$HLS_URL"
# --- Step 10: Verify WebRTC WHEP signaling via proxy ---
echo "=== Step 10: Verifying WebRTC WHEP signaling via proxy ==="
if [[ "$PROXY_TRANSMUX_TEST_RTC" == "on" ]]; then
if [[ ! -x "$SRS_TEST_BINARY" ]]; then
echo "FAIL: WebRTC regression tool not found: $SRS_TEST_BINARY" >&2
exit 1
fi
cd "$WORKSPACE/trunk/3rdparty/srs-bench"
"$SRS_TEST_BINARY" \
-test.run '^TestBugfix2371_RTMP2RTC_PlayWithNack$' \
-srs-server "127.0.0.1:$PROXY_HTTP_API_PORT" \
-srs-stream "/$STREAM_URL" \
-srs-timeout 10000
echo "PASS: WebRTC WHEP signaling succeeded."
else
echo "SKIP: WebRTC WHEP test disabled by PROXY_TRANSMUX_TEST_RTC=$PROXY_TRANSMUX_TEST_RTC."
fi
echo ""
echo "NOTE: RTSP is not tested here because the Go proxy currently has no RTSP listener."
echo "=== E2E RTMP Transmux Proxy Test PASSED ==="

View File

@ -12,7 +12,7 @@ import (
"srsx/internal/errors"
"srsx/internal/lb"
"srsx/internal/logger"
"srsx/internal/protocol"
"srsx/internal/server"
"srsx/internal/signal"
"srsx/internal/version"
)
@ -99,46 +99,46 @@ func (b *proxyBootstrap) initializeLoadBalancer(ctx context.Context, environment
// startServers initializes and starts all protocol servers.
func (b *proxyBootstrap) startServers(ctx context.Context, environment env.ProxyEnvironment, gracefulQuitTimeout time.Duration) error {
// Start the RTMP server.
srsRTMPServer := protocol.NewSRSRTMPServer(environment)
if err := srsRTMPServer.Run(ctx); err != nil {
rtmpServer := server.NewRTMPServer(environment)
if err := rtmpServer.Run(ctx); err != nil {
return errors.Wrapf(err, "rtmp server")
}
defer srsRTMPServer.Close()
defer rtmpServer.Close()
// Start the WebRTC server.
srsWebRTCServer := protocol.NewSRSWebRTCServer(environment)
if err := srsWebRTCServer.Run(ctx); err != nil {
webRTCServer := server.NewWebRTCServer(environment)
if err := webRTCServer.Run(ctx); err != nil {
return errors.Wrapf(err, "rtc server")
}
defer srsWebRTCServer.Close()
defer webRTCServer.Close()
// Start the HTTP API server.
srsHTTPAPIServer := protocol.NewSRSHTTPAPIServer(environment, gracefulQuitTimeout, srsWebRTCServer)
if err := srsHTTPAPIServer.Run(ctx); err != nil {
httpAPIServer := server.NewHTTPAPIServer(environment, gracefulQuitTimeout, webRTCServer)
if err := httpAPIServer.Run(ctx); err != nil {
return errors.Wrapf(err, "http api server")
}
defer srsHTTPAPIServer.Close()
defer httpAPIServer.Close()
// Start the SRT server.
srsSRTServer := protocol.NewSRSSRTServer(environment)
srsSRTServer := server.NewSRSSRTServer(environment)
if err := srsSRTServer.Run(ctx); err != nil {
return errors.Wrapf(err, "srt server")
}
defer srsSRTServer.Close()
// Start the System API server.
systemAPI := protocol.NewSystemAPI(environment, gracefulQuitTimeout)
systemAPI := server.NewSystemAPI(environment, gracefulQuitTimeout)
if err := systemAPI.Run(ctx); err != nil {
return errors.Wrapf(err, "system api server")
}
defer systemAPI.Close()
// Start the HTTP web server.
srsHTTPStreamServer := protocol.NewSRSHTTPStreamServer(environment, gracefulQuitTimeout)
if err := srsHTTPStreamServer.Run(ctx); err != nil {
httpStreamServer := server.NewHTTPStreamServer(environment, gracefulQuitTimeout)
if err := httpStreamServer.Run(ctx); err != nil {
return errors.Wrapf(err, "http server")
}
defer srsHTTPStreamServer.Close()
defer httpStreamServer.Close()
// Wait for the main loop to quit.
<-ctx.Done()

View File

@ -95,7 +95,7 @@ var createBuffer = func() amf0Buffer {
}
// All AMF0 things.
type amf0Any interface {
type Amf0Any interface {
// Binary marshaler and unmarshaler.
encoding.BinaryUnmarshaler
encoding.BinaryMarshaler
@ -106,59 +106,83 @@ type amf0Any interface {
amf0Marker() amf0Marker
}
type amf0Converter struct {
from amf0Any
type Amf0Converter interface {
ToNumber() Amf0Number
ToBoolean() Amf0Boolean
ToString() Amf0String
ToObject() Amf0Object
ToNull() Amf0Null
ToUndefined() Amf0Undefined
ToEcmaArray() Amf0EcmaArray
ToStrictArray() Amf0StrictArray
}
func NewAmf0Converter(from amf0Any) *amf0Converter {
type amf0Converter struct {
from Amf0Any
}
func NewAmf0Converter(from Amf0Any) Amf0Converter {
return &amf0Converter{from: from}
}
func (v *amf0Converter) ToNumber() *amf0Number {
return amf0AnyTo[*amf0Number](v.from)
}
func (v *amf0Converter) ToBoolean() *amf0Boolean {
return amf0AnyTo[*amf0Boolean](v.from)
}
func (v *amf0Converter) ToString() *amf0String {
return amf0AnyTo[*amf0String](v.from)
}
func (v *amf0Converter) ToObject() *amf0Object {
return amf0AnyTo[*amf0Object](v.from)
}
func (v *amf0Converter) ToNull() *amf0Null {
return amf0AnyTo[*amf0Null](v.from)
}
func (v *amf0Converter) ToUndefined() *amf0Undefined {
return amf0AnyTo[*amf0Undefined](v.from)
}
func (v *amf0Converter) ToEcmaArray() *amf0EcmaArray {
return amf0AnyTo[*amf0EcmaArray](v.from)
}
func (v *amf0Converter) ToStrictArray() *amf0StrictArray {
return amf0AnyTo[*amf0StrictArray](v.from)
}
// Convert any to specified object.
func amf0AnyTo[T amf0Any](a amf0Any) T {
var to T
if a != nil {
if v, ok := a.(T); ok {
return v
func (v *amf0Converter) ToNumber() Amf0Number {
if r, ok := v.from.(Amf0Number); ok {
return r
}
return nil
}
func (v *amf0Converter) ToBoolean() Amf0Boolean {
if r, ok := v.from.(Amf0Boolean); ok {
return r
}
return to
return nil
}
func (v *amf0Converter) ToString() Amf0String {
if r, ok := v.from.(Amf0String); ok {
return r
}
return nil
}
func (v *amf0Converter) ToObject() Amf0Object {
if r, ok := v.from.(Amf0Object); ok {
return r
}
return nil
}
func (v *amf0Converter) ToNull() Amf0Null {
if r, ok := v.from.(Amf0Null); ok {
return r
}
return nil
}
func (v *amf0Converter) ToUndefined() Amf0Undefined {
if r, ok := v.from.(Amf0Undefined); ok {
return r
}
return nil
}
func (v *amf0Converter) ToEcmaArray() Amf0EcmaArray {
if r, ok := v.from.(Amf0EcmaArray); ok {
return r
}
return nil
}
func (v *amf0Converter) ToStrictArray() Amf0StrictArray {
if r, ok := v.from.(Amf0StrictArray); ok {
return r
}
return nil
}
// Discovery the amf0 object from the bytes b.
func Amf0Discovery(p []byte) (a amf0Any, err error) {
func Amf0Discovery(p []byte) (a Amf0Any, err error) {
if len(p) < 1 {
return nil, errors.Errorf("require 1 bytes only %v", len(p))
}
@ -228,14 +252,24 @@ func (v *amf0UTF8) MarshalBinary() (data []byte, err error) {
return
}
// Amf0Number is the AMF0 number type.
type Amf0Number interface {
Amf0Any
Float64() float64
}
// The number object, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.2 Number Type
type amf0Number float64
func NewAmf0Number(f float64) *amf0Number {
func NewAmf0Number(f float64) Amf0Number {
v := amf0Number(f)
return &v
}
func (v *amf0Number) Float64() float64 {
return float64(*v)
}
func (v *amf0Number) amf0Marker() amf0Marker {
return amf0MarkerNumber
}
@ -266,14 +300,28 @@ func (v *amf0Number) MarshalBinary() (data []byte, err error) {
return
}
// Amf0String is the AMF0 string type.
type Amf0String interface {
Amf0Any
String() string
}
// The string objet, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.4 String Type
type amf0String string
func NewAmf0String(s string) *amf0String {
func NewAmf0String(s string) Amf0String {
return newAmf0String(s)
}
func newAmf0String(s string) *amf0String {
v := amf0String(s)
return &v
}
func (v *amf0String) String() string {
return string(*v)
}
func (v *amf0String) amf0Marker() amf0Marker {
return amf0MarkerString
}
@ -344,7 +392,7 @@ func (v *amf0ObjectEOF) MarshalBinary() (data []byte, err error) {
// Use array for object and ecma array, to keep the original order.
type amf0Property struct {
key amf0UTF8
value amf0Any
value Amf0Any
}
// The object-like AMF0 structure, like object and ecma array and strict array.
@ -367,7 +415,7 @@ func (v *amf0ObjectBase) Size() int {
return size
}
func (v *amf0ObjectBase) Get(key string) amf0Any {
func (v *amf0ObjectBase) Get(key string) Amf0Any {
v.lock.Lock()
defer v.lock.Unlock()
@ -380,7 +428,7 @@ func (v *amf0ObjectBase) Get(key string) amf0Any {
return nil
}
func (v *amf0ObjectBase) Set(key string, value amf0Any) *amf0ObjectBase {
func (v *amf0ObjectBase) Set(key string, value Amf0Any) *amf0ObjectBase {
v.lock.Lock()
defer v.lock.Unlock()
@ -411,21 +459,21 @@ func (v *amf0ObjectBase) unmarshal(p []byte, eof bool, maxElems int) (err error)
return errors.Errorf("maxElems=%v with eof", maxElems)
}
readOne := func() (amf0UTF8, amf0Any, error) {
readOne := func() (amf0UTF8, Amf0Any, error) {
var u amf0UTF8
if err = u.UnmarshalBinary(p); err != nil {
return "", nil, errors.WithMessage(err, "prop name")
}
p = p[u.Size():]
var a amf0Any
var a Amf0Any
if a, err = Amf0Discovery(p); err != nil {
return "", nil, errors.WithMessage(err, fmt.Sprintf("discover prop %v", string(u)))
}
return u, a, nil
}
pushOne := func(u amf0UTF8, a amf0Any) error {
pushOne := func(u amf0UTF8, a Amf0Any) error {
// For object property, consume the whole bytes.
if err = a.UnmarshalBinary(p); err != nil {
return errors.WithMessage(err, fmt.Sprintf("unmarshal prop %v", string(u)))
@ -494,13 +542,24 @@ func (v *amf0ObjectBase) marshal(b amf0Buffer) (err error) {
return
}
// Amf0Object is the AMF0 object type.
type Amf0Object interface {
Amf0Any
Get(key string) Amf0Any
Set(key string, value Amf0Any) Amf0Object
}
// The AMF0 object, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.5 Object Type
type amf0Object struct {
amf0ObjectBase
eof amf0ObjectEOF
}
func NewAmf0Object() *amf0Object {
func NewAmf0Object() Amf0Object {
return newAmf0Object()
}
func newAmf0Object() *amf0Object {
v := &amf0Object{}
v.properties = []*amf0Property{}
return v
@ -510,6 +569,15 @@ func (v *amf0Object) amf0Marker() amf0Marker {
return amf0MarkerObject
}
func (v *amf0Object) Get(key string) Amf0Any {
return v.amf0ObjectBase.Get(key)
}
func (v *amf0Object) Set(key string, value Amf0Any) Amf0Object {
v.amf0ObjectBase.Set(key, value)
return v
}
func (v *amf0Object) Size() int {
return int(1) + v.eof.Size() + v.amf0ObjectBase.Size()
}
@ -542,17 +610,22 @@ func (v *amf0Object) MarshalBinary() (data []byte, err error) {
return nil, errors.WithMessage(err, "marshal")
}
var pb []byte
if pb, err = v.eof.MarshalBinary(); err != nil {
if pb, err := v.eof.MarshalBinary(); err != nil {
return nil, errors.WithMessage(err, "marshal")
}
if _, err = b.Write(pb); err != nil {
} else if _, err = b.Write(pb); err != nil {
return nil, errors.Wrap(err, "marshal")
}
return b.Bytes(), nil
}
// Amf0EcmaArray is the AMF0 ECMA array type.
type Amf0EcmaArray interface {
Amf0Any
Get(key string) Amf0Any
Set(key string, value Amf0Any) Amf0EcmaArray
}
// The AMF0 ecma array, please read @doc amf0_spec_121207.pdf, @page 6, @section 2.10 ECMA Array Type
type amf0EcmaArray struct {
amf0ObjectBase
@ -560,7 +633,11 @@ type amf0EcmaArray struct {
eof amf0ObjectEOF
}
func NewAmf0EcmaArray() *amf0EcmaArray {
func NewAmf0EcmaArray() Amf0EcmaArray {
return newAmf0EcmaArray()
}
func newAmf0EcmaArray() *amf0EcmaArray {
v := &amf0EcmaArray{}
v.properties = []*amf0Property{}
return v
@ -570,6 +647,15 @@ func (v *amf0EcmaArray) amf0Marker() amf0Marker {
return amf0MarkerEcmaArray
}
func (v *amf0EcmaArray) Get(key string) Amf0Any {
return v.amf0ObjectBase.Get(key)
}
func (v *amf0EcmaArray) Set(key string, value Amf0Any) Amf0EcmaArray {
v.amf0ObjectBase.Set(key, value)
return v
}
func (v *amf0EcmaArray) Size() int {
return int(1) + 4 + v.eof.Size() + v.amf0ObjectBase.Size()
}
@ -606,24 +692,29 @@ func (v *amf0EcmaArray) MarshalBinary() (data []byte, err error) {
return nil, errors.WithMessage(err, "marshal")
}
var pb []byte
if pb, err = v.eof.MarshalBinary(); err != nil {
if pb, err := v.eof.MarshalBinary(); err != nil {
return nil, errors.WithMessage(err, "marshal")
}
if _, err = b.Write(pb); err != nil {
} else if _, err = b.Write(pb); err != nil {
return nil, errors.Wrap(err, "marshal")
}
return b.Bytes(), nil
}
// Amf0StrictArray is the AMF0 strict array type.
type Amf0StrictArray interface {
Amf0Any
Get(key string) Amf0Any
Set(key string, value Amf0Any) Amf0StrictArray
}
// The AMF0 strict array, please read @doc amf0_spec_121207.pdf, @page 7, @section 2.12 Strict Array Type
type amf0StrictArray struct {
amf0ObjectBase
count uint32
}
func NewAmf0StrictArray() *amf0StrictArray {
func NewAmf0StrictArray() Amf0StrictArray {
v := &amf0StrictArray{}
v.properties = []*amf0Property{}
return v
@ -633,6 +724,15 @@ func (v *amf0StrictArray) amf0Marker() amf0Marker {
return amf0MarkerStrictArray
}
func (v *amf0StrictArray) Get(key string) Amf0Any {
return v.amf0ObjectBase.Get(key)
}
func (v *amf0StrictArray) Set(key string, value Amf0Any) Amf0StrictArray {
v.amf0ObjectBase.Set(key, value)
return v
}
func (v *amf0StrictArray) Size() int {
return int(1) + 4 + v.amf0ObjectBase.Size()
}
@ -708,36 +808,56 @@ func (v *amf0SingleMarkerObject) MarshalBinary() (data []byte, err error) {
return []byte{byte(v.target)}, nil
}
// Amf0Null is the AMF0 null type.
type Amf0Null interface {
Amf0Any
}
// The AMF0 null, please read @doc amf0_spec_121207.pdf, @page 6, @section 2.7 null Type
type amf0Null struct {
amf0SingleMarkerObject
}
func NewAmf0Null() *amf0Null {
func NewAmf0Null() Amf0Null {
v := amf0Null{}
v.amf0SingleMarkerObject = newAmf0SingleMarkerObject(amf0MarkerNull)
return &v
}
// Amf0Undefined is the AMF0 undefined type.
type Amf0Undefined interface {
Amf0Any
}
// The AMF0 undefined, please read @doc amf0_spec_121207.pdf, @page 6, @section 2.8 undefined Type
type amf0Undefined struct {
amf0SingleMarkerObject
}
func NewAmf0Undefined() amf0Any {
func NewAmf0Undefined() Amf0Undefined {
v := amf0Undefined{}
v.amf0SingleMarkerObject = newAmf0SingleMarkerObject(amf0MarkerUndefined)
return &v
}
// Amf0Boolean is the public typed view of an AMF0 boolean.
type Amf0Boolean interface {
Amf0Any
Bool() bool
}
// The AMF0 boolean, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.3 Boolean Type
type amf0Boolean bool
func NewAmf0Boolean(b bool) amf0Any {
func NewAmf0Boolean(b bool) Amf0Boolean {
v := amf0Boolean(b)
return &v
}
func (v *amf0Boolean) Bool() bool {
return bool(*v)
}
func (v *amf0Boolean) amf0Marker() amf0Marker {
return amf0MarkerBoolean
}

509
internal/rtmp/amf0_test.go Normal file
View File

@ -0,0 +1,509 @@
// Copyright (c) 2026 Winlin
//
// SPDX-License-Identifier: MIT
package rtmp
import (
"bytes"
"fmt"
"math"
"strings"
"testing"
)
func TestAmf0MarkerString(t *testing.T) {
for _, tt := range []struct {
marker amf0Marker
want string
}{
{amf0MarkerNumber, "Amf0Number"},
{amf0MarkerBoolean, "amf0Boolean"},
{amf0MarkerString, "Amf0String"},
{amf0MarkerObject, "Amf0Object"},
{amf0MarkerMovieClip, "MovieClip"},
{amf0MarkerNull, "Null"},
{amf0MarkerUndefined, "Undefined"},
{amf0MarkerReference, "Reference"},
{amf0MarkerEcmaArray, "EcmaArray"},
{amf0MarkerObjectEnd, "ObjectEnd"},
{amf0MarkerStrictArray, "StrictArray"},
{amf0MarkerDate, "Date"},
{amf0MarkerLongString, "LongString"},
{amf0MarkerUnsupported, "Unsupported"},
{amf0MarkerRecordSet, "RecordSet"},
{amf0MarkerXmlDocument, "XmlDocument"},
{amf0MarkerTypedObject, "TypedObject"},
{amf0MarkerAvmPlusObject, "AvmPlusObject"},
{amf0MarkerForbidden, "Forbidden"},
{amf0Marker(0xee), "Forbidden"},
} {
if got := tt.marker.String(); got != tt.want {
t.Fatalf("marker=%#x String()=%v, want %v", byte(tt.marker), got, tt.want)
}
}
}
func TestAmf0Discovery(t *testing.T) {
for _, tt := range []struct {
name string
data []byte
ok func(Amf0Any) bool
}{
{"number", []byte{byte(amf0MarkerNumber)}, func(v Amf0Any) bool { _, ok := v.(Amf0Number); return ok }},
{"boolean", []byte{byte(amf0MarkerBoolean)}, func(v Amf0Any) bool { _, ok := v.(Amf0Boolean); return ok }},
{"string", []byte{byte(amf0MarkerString)}, func(v Amf0Any) bool { _, ok := v.(Amf0String); return ok }},
{"object", []byte{byte(amf0MarkerObject)}, func(v Amf0Any) bool { _, ok := v.(Amf0Object); return ok }},
{"null", []byte{byte(amf0MarkerNull)}, func(v Amf0Any) bool { _, ok := v.(Amf0Null); return ok }},
{"undefined", []byte{byte(amf0MarkerUndefined)}, func(v Amf0Any) bool { _, ok := v.(Amf0Undefined); return ok }},
{"ecma-array", []byte{byte(amf0MarkerEcmaArray)}, func(v Amf0Any) bool { _, ok := v.(Amf0EcmaArray); return ok }},
{"object-end", []byte{byte(amf0MarkerObjectEnd)}, func(v Amf0Any) bool { _, ok := v.(*amf0ObjectEOF); return ok }},
{"strict-array", []byte{byte(amf0MarkerStrictArray)}, func(v Amf0Any) bool { _, ok := v.(Amf0StrictArray); return ok }},
} {
t.Run(tt.name, func(t *testing.T) {
value, err := Amf0Discovery(tt.data)
if err != nil {
t.Fatalf("Amf0Discovery() err=%v", err)
}
if !tt.ok(value) {
t.Fatalf("Amf0Discovery()=%T", value)
}
})
}
for _, data := range [][]byte{{}, {byte(amf0MarkerReference)}, {byte(amf0MarkerDate)}, {byte(amf0MarkerForbidden)}} {
if value, err := Amf0Discovery(data); err == nil || value != nil {
t.Fatalf("Amf0Discovery(%v) value=%T, err=%v, want error", data, value, err)
}
}
}
func TestAmf0Converter(t *testing.T) {
values := []struct {
name string
in Amf0Any
ok func(Amf0Converter) bool
}{
{"number", NewAmf0Number(1), func(c Amf0Converter) bool { return c.ToNumber() != nil }},
{"boolean", NewAmf0Boolean(true), func(c Amf0Converter) bool { return c.ToBoolean() != nil }},
{"string", NewAmf0String("v"), func(c Amf0Converter) bool { return c.ToString() != nil }},
{"object", NewAmf0Object(), func(c Amf0Converter) bool { return c.ToObject() != nil }},
{"null", NewAmf0Null(), func(c Amf0Converter) bool { return c.ToNull() != nil }},
{"undefined", NewAmf0Undefined(), func(c Amf0Converter) bool { return c.ToUndefined() != nil }},
{"ecma-array", NewAmf0EcmaArray(), func(c Amf0Converter) bool { return c.ToEcmaArray() != nil }},
{"strict-array", NewAmf0StrictArray(), func(c Amf0Converter) bool { return c.ToStrictArray() != nil }},
}
for _, tt := range values {
t.Run(tt.name, func(t *testing.T) {
converter := NewAmf0Converter(tt.in)
if !tt.ok(converter) {
t.Fatalf("expected successful conversion for %T", tt.in)
}
})
}
nilConverter := NewAmf0Converter(nil)
if nilConverter.ToNumber() != nil || nilConverter.ToBoolean() != nil || nilConverter.ToString() != nil ||
nilConverter.ToObject() != nil || nilConverter.ToNull() != nil || nilConverter.ToUndefined() != nil ||
nilConverter.ToEcmaArray() != nil || nilConverter.ToStrictArray() != nil {
t.Fatal("nil converter should not convert")
}
}
func TestAmf0UTF8(t *testing.T) {
var value amf0UTF8 = "hello"
b, err := value.MarshalBinary()
if err != nil {
t.Fatalf("MarshalBinary() err=%v", err)
}
if value.Size() != len(b) {
t.Fatalf("Size()=%v, len=%v", value.Size(), len(b))
}
var decoded amf0UTF8
if err := decoded.UnmarshalBinary(b); err != nil {
t.Fatalf("UnmarshalBinary() err=%v", err)
}
if decoded != value {
t.Fatalf("decoded=%v, want %v", decoded, value)
}
for _, data := range [][]byte{{0x00}, {0x00, 0x05, 'h'}} {
if err := decoded.UnmarshalBinary(data); err == nil {
t.Fatalf("UnmarshalBinary(%v) should fail", data)
}
}
}
func TestAmf0Number(t *testing.T) {
number := NewAmf0Number(math.Pi)
if number.Size() != 9 || number.(*amf0Number).amf0Marker() != amf0MarkerNumber {
t.Fatalf("unexpected number metadata")
}
b, err := number.MarshalBinary()
if err != nil {
t.Fatalf("MarshalBinary() err=%v", err)
}
decoded := NewAmf0Number(0)
if err := decoded.UnmarshalBinary(b); err != nil {
t.Fatalf("UnmarshalBinary() err=%v", err)
}
if got := decoded.Float64(); got != math.Pi {
t.Fatalf("Float64()=%v, want %v", got, math.Pi)
}
for _, data := range [][]byte{{byte(amf0MarkerNumber)}, append([]byte{byte(amf0MarkerString)}, b[1:]...)} {
if err := decoded.UnmarshalBinary(data); err == nil {
t.Fatalf("UnmarshalBinary(%v) should fail", data)
}
}
}
func TestAmf0Boolean(t *testing.T) {
for _, want := range []bool{false, true} {
boolean := NewAmf0Boolean(want)
if boolean.Size() != 2 || boolean.(*amf0Boolean).amf0Marker() != amf0MarkerBoolean {
t.Fatalf("unexpected boolean metadata")
}
b, err := boolean.MarshalBinary()
if err != nil {
t.Fatalf("MarshalBinary() err=%v", err)
}
decoded := NewAmf0Boolean(!want)
if err := decoded.UnmarshalBinary(b); err != nil {
t.Fatalf("UnmarshalBinary() err=%v", err)
}
if got := decoded.Bool(); got != want {
t.Fatalf("Bool()=%v, want %v", got, want)
}
}
decoded := NewAmf0Boolean(false)
for _, data := range [][]byte{{byte(amf0MarkerBoolean)}, {byte(amf0MarkerNumber), 1}} {
if err := decoded.UnmarshalBinary(data); err == nil {
t.Fatalf("UnmarshalBinary(%v) should fail", data)
}
}
}
func TestAmf0String(t *testing.T) {
value := NewAmf0String("hello")
if value.Size() != 8 || value.(*amf0String).amf0Marker() != amf0MarkerString {
t.Fatalf("unexpected string metadata")
}
b, err := value.MarshalBinary()
if err != nil {
t.Fatalf("MarshalBinary() err=%v", err)
}
decoded := NewAmf0String("")
if err := decoded.UnmarshalBinary(b); err != nil {
t.Fatalf("UnmarshalBinary() err=%v", err)
}
if got := decoded.String(); got != "hello" {
t.Fatalf("String()=%v, want hello", got)
}
for _, data := range [][]byte{{}, {byte(amf0MarkerNumber), 0, 0}, {byte(amf0MarkerString), 0, 5, 'h'}} {
if err := decoded.UnmarshalBinary(data); err == nil {
t.Fatalf("UnmarshalBinary(%v) should fail", data)
}
}
}
func TestAmf0ObjectEOF(t *testing.T) {
eof := &amf0ObjectEOF{}
if eof.Size() != 3 || eof.amf0Marker() != amf0MarkerObjectEnd {
t.Fatalf("unexpected eof metadata")
}
b, err := eof.MarshalBinary()
if err != nil {
t.Fatalf("MarshalBinary() err=%v", err)
}
if !bytes.Equal(b, []byte{0, 0, 9}) {
t.Fatalf("MarshalBinary()=%v", b)
}
for _, data := range [][]byte{b, {0, 0, 9, 1}} {
if err := eof.UnmarshalBinary(data); err != nil {
t.Fatalf("UnmarshalBinary(%v) err=%v", data, err)
}
}
for _, data := range [][]byte{{0, 0}, {0, 1, 9}} {
if err := eof.UnmarshalBinary(data); err == nil {
t.Fatalf("UnmarshalBinary(%v) should fail", data)
}
}
}
func TestAmf0Object(t *testing.T) {
object := NewAmf0Object().
Set("name", NewAmf0String("stream")).
Set("code", NewAmf0Number(100)).
Set("ok", NewAmf0Boolean(true))
object.Set("code", NewAmf0Number(200))
if object.(*amf0Object).amf0Marker() != amf0MarkerObject || object.Size() == 0 {
t.Fatalf("unexpected object metadata")
}
if object.Get("missing") != nil {
t.Fatal("missing property should be nil")
}
b, err := object.MarshalBinary()
if err != nil {
t.Fatalf("MarshalBinary() err=%v", err)
}
decoded := NewAmf0Object()
if err := decoded.UnmarshalBinary(b); err != nil {
t.Fatalf("UnmarshalBinary() err=%v", err)
}
if got := NewAmf0Converter(decoded.Get("name")).ToString().String(); got != "stream" {
t.Fatalf("name=%v", got)
}
if got := NewAmf0Converter(decoded.Get("code")).ToNumber().Float64(); got != 200 {
t.Fatalf("code=%v", got)
}
if got := NewAmf0Converter(decoded.Get("ok")).ToBoolean().Bool(); !got {
t.Fatalf("ok=%v", got)
}
for _, data := range [][]byte{{}, {byte(amf0MarkerString)}, {byte(amf0MarkerObject), 0, 4, 'n'}} {
if err := decoded.UnmarshalBinary(data); err == nil {
t.Fatalf("UnmarshalBinary(%v) should fail", data)
}
}
base := &amf0ObjectBase{}
if err := base.unmarshal(nil, false, -1); err == nil {
t.Fatal("unmarshal without eof and negative maxElems should fail")
}
if err := base.unmarshal(nil, true, 0); err == nil {
t.Fatal("unmarshal with eof and non-negative maxElems should fail")
}
}
func TestAmf0EcmaArray(t *testing.T) {
array := NewAmf0EcmaArray().
Set("name", NewAmf0String("stream")).
Set("code", NewAmf0Number(100))
if array.(*amf0EcmaArray).amf0Marker() != amf0MarkerEcmaArray || array.Size() == 0 {
t.Fatalf("unexpected ecma array metadata")
}
if array.Get("missing") != nil {
t.Fatal("missing property should be nil")
}
b, err := array.MarshalBinary()
if err != nil {
t.Fatalf("MarshalBinary() err=%v", err)
}
decoded := NewAmf0EcmaArray()
if err := decoded.UnmarshalBinary(b); err != nil {
t.Fatalf("UnmarshalBinary() err=%v", err)
}
if got := NewAmf0Converter(decoded.Get("name")).ToString().String(); got != "stream" {
t.Fatalf("name=%v", got)
}
if got := NewAmf0Converter(decoded.Get("code")).ToNumber().Float64(); got != 100 {
t.Fatalf("code=%v", got)
}
for _, data := range [][]byte{{}, {byte(amf0MarkerEcmaArray), 0}, {byte(amf0MarkerString), 0, 0, 0, 0}, {byte(amf0MarkerEcmaArray), 0, 0, 0, 0, 0, 4, 'n'}} {
if err := decoded.UnmarshalBinary(data); err == nil {
t.Fatalf("UnmarshalBinary(%v) should fail", data)
}
}
}
func TestAmf0StrictArray(t *testing.T) {
array := NewAmf0StrictArray().
Set("name", NewAmf0String("stream")).
Set("code", NewAmf0Number(100))
array.(*amf0StrictArray).count = 2
if array.(*amf0StrictArray).amf0Marker() != amf0MarkerStrictArray || array.Size() == 0 {
t.Fatalf("unexpected strict array metadata")
}
if array.Get("missing") != nil {
t.Fatal("missing property should be nil")
}
b, err := array.MarshalBinary()
if err != nil {
t.Fatalf("MarshalBinary() err=%v", err)
}
decoded := NewAmf0StrictArray()
if err := decoded.UnmarshalBinary(b); err != nil {
t.Fatalf("UnmarshalBinary() err=%v", err)
}
if got := NewAmf0Converter(decoded.Get("name")).ToString().String(); got != "stream" {
t.Fatalf("name=%v", got)
}
if got := NewAmf0Converter(decoded.Get("code")).ToNumber().Float64(); got != 100 {
t.Fatalf("code=%v", got)
}
empty := append([]byte{byte(amf0MarkerStrictArray)}, 0, 0, 0, 0)
if err := decoded.UnmarshalBinary(empty); err != nil {
t.Fatalf("UnmarshalBinary(empty) err=%v", err)
}
for _, data := range [][]byte{{}, {byte(amf0MarkerStrictArray), 0}, {byte(amf0MarkerString), 0, 0, 0, 0}, {byte(amf0MarkerStrictArray), 0, 0, 0, 1, 0, 4, 'n'}} {
if err := NewAmf0StrictArray().UnmarshalBinary(data); err == nil {
t.Fatalf("UnmarshalBinary(%v) should fail", data)
}
}
}
func TestAmf0SingleMarkerObjects(t *testing.T) {
for _, tt := range []struct {
name string
value Amf0Any
marker amf0Marker
}{
{"null", NewAmf0Null(), amf0MarkerNull},
{"undefined", NewAmf0Undefined(), amf0MarkerUndefined},
} {
t.Run(tt.name, func(t *testing.T) {
if tt.value.Size() != 1 || tt.value.amf0Marker() != tt.marker {
t.Fatalf("unexpected metadata")
}
b, err := tt.value.MarshalBinary()
if err != nil {
t.Fatalf("MarshalBinary() err=%v", err)
}
if err := tt.value.UnmarshalBinary(b); err != nil {
t.Fatalf("UnmarshalBinary() err=%v", err)
}
for _, data := range [][]byte{{}, {byte(amf0MarkerString)}} {
if err := tt.value.UnmarshalBinary(data); err == nil {
t.Fatalf("UnmarshalBinary(%v) should fail", data)
}
}
})
}
}
type errorAmf0Buffer struct {
writeByteErr bool
writeErr bool
}
func (v *errorAmf0Buffer) Bytes() []byte {
return nil
}
func (v *errorAmf0Buffer) WriteByte(byte) error {
if v.writeByteErr {
return fmt.Errorf("write byte")
}
return nil
}
func (v *errorAmf0Buffer) Write([]byte) (int, error) {
if v.writeErr {
return 0, fmt.Errorf("write")
}
return 0, nil
}
type errorAmf0Any struct {
Amf0Any
}
func (v *errorAmf0Any) Size() int {
return 1
}
func (v *errorAmf0Any) MarshalBinary() ([]byte, error) {
return nil, fmt.Errorf("marshal")
}
func (v *errorAmf0Any) UnmarshalBinary([]byte) error {
return nil
}
func (v *errorAmf0Any) amf0Marker() amf0Marker {
return amf0MarkerNumber
}
func TestAmf0MarshalErrors(t *testing.T) {
originalCreateBuffer := createBuffer
defer func() { createBuffer = originalCreateBuffer }()
for _, tt := range []struct {
name string
make func() Amf0Any
}{
{"object", func() Amf0Any { return NewAmf0Object() }},
{"ecma-array", func() Amf0Any { return NewAmf0EcmaArray() }},
{"strict-array", func() Amf0Any { return NewAmf0StrictArray() }},
} {
t.Run(tt.name+" write-byte", func(t *testing.T) {
createBuffer = func() amf0Buffer { return &errorAmf0Buffer{writeByteErr: true} }
if _, err := tt.make().MarshalBinary(); err == nil {
t.Fatal("MarshalBinary() should fail")
}
})
t.Run(tt.name+" write-prop", func(t *testing.T) {
createBuffer = func() amf0Buffer { return &errorAmf0Buffer{writeErr: true} }
value := tt.make()
switch v := value.(type) {
case Amf0Object:
v.Set("name", NewAmf0String("stream"))
case Amf0EcmaArray:
v.Set("name", NewAmf0String("stream"))
case Amf0StrictArray:
v.Set("name", NewAmf0String("stream"))
v.(*amf0StrictArray).count = 1
}
if _, err := value.MarshalBinary(); err == nil {
t.Fatal("MarshalBinary() should fail")
}
})
}
createBuffer = originalCreateBuffer
for _, tt := range []struct {
name string
make func() Amf0Any
}{
{"object", func() Amf0Any { return NewAmf0Object().Set("bad", &errorAmf0Any{}) }},
{"ecma-array", func() Amf0Any { return NewAmf0EcmaArray().Set("bad", &errorAmf0Any{}) }},
{"strict-array", func() Amf0Any {
value := NewAmf0StrictArray().Set("bad", &errorAmf0Any{})
value.(*amf0StrictArray).count = 1
return value
}},
} {
t.Run(tt.name+" marshal-value", func(t *testing.T) {
if _, err := tt.make().MarshalBinary(); err == nil {
t.Fatal("MarshalBinary() should fail")
}
})
}
}
func TestAmf0UnmarshalNestedErrors(t *testing.T) {
// Object property with unsupported marker.
data := []byte{byte(amf0MarkerObject), 0, 3, 'b', 'a', 'd', byte(amf0MarkerDate)}
if err := NewAmf0Object().UnmarshalBinary(data); err == nil || !strings.Contains(err.Error(), "discover prop bad") {
t.Fatalf("err=%v, want discover prop bad", err)
}
// Object property with invalid payload size.
data = []byte{byte(amf0MarkerObject), 0, 3, 'b', 'a', 'd', byte(amf0MarkerNumber), 0}
if err := NewAmf0Object().UnmarshalBinary(data); err == nil || !strings.Contains(err.Error(), "unmarshal prop bad") {
t.Fatalf("err=%v, want unmarshal prop bad", err)
}
}

View File

@ -0,0 +1,313 @@
// Copyright (c) 2026 Winlin
//
// SPDX-License-Identifier: MIT
package rtmp_test
import (
"bytes"
"context"
"fmt"
"net"
"time"
"srsx/internal/rtmp"
)
func ExampleAmf0Number() {
number := rtmp.NewAmf0Number(3.14)
b, err := number.MarshalBinary()
if err != nil {
panic(err)
}
value, err := rtmp.Amf0Discovery(b)
if err != nil {
panic(err)
}
if err := value.UnmarshalBinary(b); err != nil {
panic(err)
}
converter := rtmp.NewAmf0Converter(value)
fmt.Println("number:", converter.ToNumber().Float64())
fmt.Println("is string:", converter.ToString() != nil)
// Output:
// number: 3.14
// is string: false
}
func ExampleAmf0Object() {
object := rtmp.NewAmf0Object().
Set("code", rtmp.NewAmf0Number(100)).
Set("level", rtmp.NewAmf0String("status"))
b, err := object.MarshalBinary()
if err != nil {
panic(err)
}
value, err := rtmp.Amf0Discovery(b)
if err != nil {
panic(err)
}
if err := value.UnmarshalBinary(b); err != nil {
panic(err)
}
converter := rtmp.NewAmf0Converter(value)
fmt.Println("code:", rtmp.NewAmf0Converter(converter.ToObject().Get("code")).ToNumber().Float64())
fmt.Println("level:", rtmp.NewAmf0Converter(converter.ToObject().Get("level")).ToString().String())
fmt.Println("is number:", converter.ToNumber() != nil)
// Output:
// code: 100
// level: status
// is number: false
}
func ExampleNewHandshake() {
client := rtmp.NewHandshake()
server := rtmp.NewHandshake()
var clientToServer bytes.Buffer
if err := client.WriteC0S0(&clientToServer); err != nil {
panic(err)
}
if err := client.WriteC1S1(&clientToServer); err != nil {
panic(err)
}
c0, err := server.ReadC0S0(&clientToServer)
if err != nil {
panic(err)
}
c1, err := server.ReadC1S1(&clientToServer)
if err != nil {
panic(err)
}
var serverToClient bytes.Buffer
if err := server.WriteC0S0(&serverToClient); err != nil {
panic(err)
}
if err := server.WriteC1S1(&serverToClient); err != nil {
panic(err)
}
if err := server.WriteC2S2(&serverToClient, c1); err != nil {
panic(err)
}
s0, err := client.ReadC0S0(&serverToClient)
if err != nil {
panic(err)
}
s1, err := client.ReadC1S1(&serverToClient)
if err != nil {
panic(err)
}
s2, err := client.ReadC2S2(&serverToClient)
if err != nil {
panic(err)
}
if err := client.WriteC2S2(&clientToServer, s1); err != nil {
panic(err)
}
c2, err := server.ReadC2S2(&clientToServer)
if err != nil {
panic(err)
}
fmt.Println("client version:", c0[0])
fmt.Println("server version:", s0[0])
fmt.Println("c1 bytes:", len(c1))
fmt.Println("s1 bytes:", len(s1))
fmt.Println("s2 echoes c1:", bytes.Equal(s2, c1))
fmt.Println("c2 echoes s1:", bytes.Equal(c2, s1))
fmt.Println("server cached c1:", bytes.Equal(server.C1S1(), c1))
// Output:
// client version: 3
// server version: 3
// c1 bytes: 1536
// s1 bytes: 1536
// s2 echoes c1: true
// c2 echoes s1: true
// server cached c1: true
}
func ExampleNewProtocol() {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
clientConn, serverConn := net.Pipe()
defer clientConn.Close()
defer serverConn.Close()
client := rtmp.NewProtocol(clientConn)
server := rtmp.NewProtocol(serverConn)
backendConn, upstreamConn := net.Pipe()
defer backendConn.Close()
defer upstreamConn.Close()
backend := rtmp.NewProtocol(backendConn)
upstream := rtmp.NewProtocol(upstreamConn)
done := make(chan error, 1)
go func() {
err := func() error {
// The server can read a raw message first, then decode it explicitly.
m, err := server.ExpectMessage(ctx, rtmp.MessageTypeSetChunkSize)
if err != nil {
return err
}
pkt, err := server.DecodeMessage(m)
if err != nil {
return err
}
chunkSize := pkt.(*rtmp.SetChunkSize)
// ExpectPacket reads and decodes messages until it finds the requested packet type.
var connectReq *rtmp.ConnectAppPacket
if _, err := rtmp.ExpectPacket(ctx, server, &connectReq); err != nil {
return err
}
ack := rtmp.NewWindowAcknowledgementSize()
ack.AckSize = 2500000
if err := server.WritePacket(ctx, ack, 0); err != nil {
return err
}
serverChunk := rtmp.NewSetChunkSize()
serverChunk.ChunkSize = chunkSize.ChunkSize
if err := server.WritePacket(ctx, serverChunk, 0); err != nil {
return err
}
connectRes := rtmp.NewConnectAppResPacket(connectReq.TransactionID)
connectRes.CommandObject.Set("fmsVer", rtmp.NewAmf0String("FMS/3,5,3,888"))
connectRes.Args.Set("level", rtmp.NewAmf0String("status"))
connectRes.Args.Set("code", rtmp.NewAmf0String("NetConnection.Connect.Success"))
if err := server.WritePacket(ctx, connectRes, 0); err != nil {
return err
}
var createStream *rtmp.CallPacket
if _, err := rtmp.ExpectPacket(ctx, server, &createStream); err != nil {
return err
}
createStreamRes := rtmp.NewCreateStreamResPacket(createStream.TransactionID)
createStreamRes.SetStreamID(1)
if err := server.WritePacket(ctx, createStreamRes, 0); err != nil {
return err
}
// For media forwarding, the proxy reads a complete RTMP message and writes
// that same message to another protocol connection without decoding/repacking.
publishMessage, err := server.ReadMessage(ctx)
if err != nil {
return err
}
if err := backend.WriteMessage(ctx, publishMessage); err != nil {
return err
}
return nil
}()
select {
case done <- err:
case <-ctx.Done():
}
}()
// Client runs the normal RTMP command workflow: configure chunk size,
// connect to an app, create a stream, publish it, then verify the proxy
// forwarded the publish message to the upstream side.
if err := func() error {
clientChunk := rtmp.NewSetChunkSize()
clientChunk.ChunkSize = 128
if err := client.WritePacket(ctx, clientChunk, 0); err != nil {
return err
}
connectReq := rtmp.NewConnectAppPacket()
connectReq.CommandObject.Set("tcUrl", rtmp.NewAmf0String("rtmp://example.com/live"))
if err := client.WritePacket(ctx, connectReq, 0); err != nil {
return err
}
ackMessage, err := client.ExpectMessage(ctx, rtmp.MessageTypeWindowAcknowledgementSize)
if err != nil {
return err
}
ackPacket, err := client.DecodeMessage(ackMessage)
if err != nil {
return err
}
ack, ok := ackPacket.(*rtmp.WindowAcknowledgementSize)
if !ok {
return fmt.Errorf("unexpected ack packet %T", ackPacket)
}
var serverChunk *rtmp.SetChunkSize
if _, err := rtmp.ExpectPacket(ctx, client, &serverChunk); err != nil {
return err
}
var connectRes *rtmp.ConnectAppResPacket
if _, err := rtmp.ExpectPacket(ctx, client, &connectRes); err != nil {
return err
}
createStream := rtmp.NewCreateStreamPacket()
if err := client.WritePacket(ctx, createStream, 0); err != nil {
return err
}
var createStreamRes *rtmp.CreateStreamResPacket
if _, err := rtmp.ExpectPacket(ctx, client, &createStreamRes); err != nil {
return err
}
publish := rtmp.NewPublishPacket()
publish.TransactionID = 5
publish.StreamName = rtmp.NewAmf0String("livestream")
publish.StreamType = rtmp.NewAmf0String("live")
if err := client.WritePacket(ctx, publish, int(createStreamRes.StreamID)); err != nil {
return err
}
var upstreamPublish *rtmp.PublishPacket
if _, err := rtmp.ExpectPacket(ctx, upstream, &upstreamPublish); err != nil {
return err
}
select {
case err := <-done:
if err != nil {
return err
}
case <-ctx.Done():
return ctx.Err()
}
fmt.Println("ack size:", ack.AckSize)
fmt.Println("chunk size:", serverChunk.ChunkSize)
fmt.Println("connect:", rtmp.NewAmf0Converter(connectRes.Args.Get("code")).ToString().String())
fmt.Println("stream id:", int(createStreamRes.StreamID))
fmt.Println("forward publish:", upstreamPublish.StreamName.String())
return nil
}(); err != nil {
panic(err)
}
// Output:
// ack size: 2500000
// chunk size: 128
// connect: NetConnection.Connect.Success
// stream id: 1
// forward publish: livestream
}

7
internal/rtmp/gen.go Normal file
View File

@ -0,0 +1,7 @@
// Copyright (c) 2026 Winlin
//
// SPDX-License-Identifier: MIT
package rtmp
//go:generate go tool counterfeiter -o rtmpfakes/fake_handshake.go . Handshake
//go:generate go tool counterfeiter -o rtmpfakes/fake_protocol.go . Protocol

View File

@ -17,21 +17,31 @@ import (
"srsx/internal/errors"
)
// The handshake implements the RTMP handshake protocol.
type Handshake struct {
// Handshake implements the RTMP handshake protocol.
type Handshake interface {
C1S1() []byte
WriteC0S0(w io.Writer) error
ReadC0S0(r io.Reader) ([]byte, error)
WriteC1S1(w io.Writer) error
ReadC1S1(r io.Reader) ([]byte, error)
WriteC2S2(w io.Writer, s1c1 []byte) error
ReadC2S2(r io.Reader) ([]byte, error)
}
type handshake struct {
// The c1s1 cache.
c1s1 []byte
}
func NewHandshake() *Handshake {
return &Handshake{}
func NewHandshake() Handshake {
return &handshake{}
}
func (v *Handshake) C1S1() []byte {
func (v *handshake) C1S1() []byte {
return v.c1s1
}
func (v *Handshake) WriteC0S0(w io.Writer) (err error) {
func (v *handshake) WriteC0S0(w io.Writer) (err error) {
r := bytes.NewReader([]byte{0x03})
if _, err = io.Copy(w, r); err != nil {
return errors.Wrap(err, "write c0s0")
@ -40,7 +50,7 @@ func (v *Handshake) WriteC0S0(w io.Writer) (err error) {
return
}
func (v *Handshake) ReadC0S0(r io.Reader) (c0 []byte, err error) {
func (v *handshake) ReadC0S0(r io.Reader) (c0 []byte, err error) {
b := &bytes.Buffer{}
if _, err = io.CopyN(b, r, 1); err != nil {
return nil, errors.Wrap(err, "read c0s0")
@ -51,7 +61,7 @@ func (v *Handshake) ReadC0S0(r io.Reader) (c0 []byte, err error) {
return
}
func (v *Handshake) WriteC1S1(w io.Writer) (err error) {
func (v *handshake) WriteC1S1(w io.Writer) (err error) {
p := make([]byte, 1536)
// Use crypto/rand for thread-safe random generation
@ -67,7 +77,7 @@ func (v *Handshake) WriteC1S1(w io.Writer) (err error) {
return
}
func (v *Handshake) ReadC1S1(r io.Reader) (c1s1 []byte, err error) {
func (v *handshake) ReadC1S1(r io.Reader) (c1s1 []byte, err error) {
b := &bytes.Buffer{}
if _, err = io.CopyN(b, r, 1536); err != nil {
return nil, errors.Wrap(err, "read c1s1")
@ -79,7 +89,7 @@ func (v *Handshake) ReadC1S1(r io.Reader) (c1s1 []byte, err error) {
return
}
func (v *Handshake) WriteC2S2(w io.Writer, s1c1 []byte) (err error) {
func (v *handshake) WriteC2S2(w io.Writer, s1c1 []byte) (err error) {
r := bytes.NewReader(s1c1[:])
if _, err = io.Copy(w, r); err != nil {
return errors.Wrap(err, "write c2s2")
@ -88,7 +98,7 @@ func (v *Handshake) WriteC2S2(w io.Writer, s1c1 []byte) (err error) {
return
}
func (v *Handshake) ReadC2S2(r io.Reader) (c2 []byte, err error) {
func (v *handshake) ReadC2S2(r io.Reader) (c2 []byte, err error) {
b := &bytes.Buffer{}
if _, err = io.CopyN(b, r, 1536); err != nil {
return nil, errors.Wrap(err, "read c2s2")
@ -129,7 +139,7 @@ type chunkStream struct {
format formatType
cid chunkID
header messageHeader
message *Message
message *message
count uint64
extendedTimestamp bool
}
@ -138,8 +148,18 @@ func newChunkStream() *chunkStream {
return &chunkStream{}
}
// The protocol implements the RTMP command and chunk stack.
type Protocol struct {
// Protocol implements the RTMP command and chunk stack.
type Protocol interface {
// Deprecated: Go does not support generic methods. Please use rtmp.ExpectPacket instead.
ExpectPacket(ctx context.Context, ppkt any) (Message, error)
ExpectMessage(ctx context.Context, types ...MessageType) (Message, error)
DecodeMessage(m Message) (Packet, error)
ReadMessage(ctx context.Context) (Message, error)
WritePacket(ctx context.Context, pkt Packet, streamID int) error
WriteMessage(ctx context.Context, m Message) error
}
type protocol struct {
r *bufio.Reader
w *bufio.Writer
input struct {
@ -154,8 +174,8 @@ type Protocol struct {
}
}
func NewProtocol(rw io.ReadWriter) *Protocol {
v := &Protocol{
func NewProtocol(rw io.ReadWriter) Protocol {
v := &protocol{
r: bufio.NewReader(rw),
w: bufio.NewWriter(rw),
}
@ -169,7 +189,11 @@ func NewProtocol(rw io.ReadWriter) *Protocol {
return v
}
func ExpectPacket[T Packet](ctx context.Context, v *Protocol, ppkt *T) (m *Message, err error) {
// ExpectPacket reads and decodes RTMP messages until it finds a packet of type T.
//
// Messages with other packet types are consumed and ignored. On success, ppkt is
// set to the decoded packet and the Message carrying that packet is returned.
func ExpectPacket[T Packet](ctx context.Context, v Protocol, ppkt *T) (m Message, err error) {
for {
if m, err = v.ReadMessage(ctx); err != nil {
return nil, errors.WithMessage(err, "read message")
@ -189,12 +213,12 @@ func ExpectPacket[T Packet](ctx context.Context, v *Protocol, ppkt *T) (m *Messa
return
}
// Deprecated: Please use rtmp.ExpectPacket instead.
func (v *Protocol) ExpectPacket(ctx context.Context, ppkt any) (m *Message, err error) {
panic("Please use rtmp.ExpectPacket instead")
// Deprecated: Go does not support generic methods. Please use rtmp.ExpectPacket instead.
func (v *protocol) ExpectPacket(ctx context.Context, ppkt any) (m Message, err error) {
panic("Go does not support generic methods; please use rtmp.ExpectPacket instead")
}
func (v *Protocol) ExpectMessage(ctx context.Context, types ...MessageType) (m *Message, err error) {
func (v *protocol) ExpectMessage(ctx context.Context, types ...MessageType) (m Message, err error) {
for {
if m, err = v.ReadMessage(ctx); err != nil {
return nil, errors.WithMessage(err, "read message")
@ -205,16 +229,14 @@ func (v *Protocol) ExpectMessage(ctx context.Context, types ...MessageType) (m *
}
for _, t := range types {
if m.MessageType == t {
if m.MessageType() == t {
return
}
}
}
return
}
func (v *Protocol) parseAMFObject(p []byte) (pkt Packet, err error) {
func (v *protocol) parseAMFObject(p []byte) (pkt Packet, err error) {
var commandName amf0String
if err = commandName.UnmarshalBinary(p); err != nil {
return nil, errors.WithMessage(err, "unmarshal command name")
@ -266,18 +288,18 @@ func (v *Protocol) parseAMFObject(p []byte) (pkt Packet, err error) {
}
}
func (v *Protocol) DecodeMessage(m *Message) (pkt Packet, err error) {
p := m.Payload[:]
func (v *protocol) DecodeMessage(m Message) (pkt Packet, err error) {
p := m.Payload()[:]
if len(p) == 0 {
return nil, errors.New("Empty packet")
}
switch m.MessageType {
switch m.MessageType() {
case MessageTypeAMF3Command, MessageTypeAMF3Data:
p = p[1:]
}
switch m.MessageType {
switch m.MessageType() {
case MessageTypeSetChunkSize:
pkt = NewSetChunkSize()
case MessageTypeWindowAcknowledgementSize:
@ -286,22 +308,22 @@ func (v *Protocol) DecodeMessage(m *Message) (pkt Packet, err error) {
pkt = NewSetPeerBandwidth()
case MessageTypeAMF0Command, MessageTypeAMF3Command, MessageTypeAMF0Data, MessageTypeAMF3Data:
if pkt, err = v.parseAMFObject(p); err != nil {
return nil, errors.WithMessage(err, fmt.Sprintf("Parse AMF %v", m.MessageType))
return nil, errors.WithMessage(err, fmt.Sprintf("Parse AMF %v", m.MessageType()))
}
case MessageTypeUserControl:
pkt = NewUserControl()
default:
return nil, errors.Errorf("Unknown message %v", m.MessageType)
return nil, errors.Errorf("Unknown message %v", m.MessageType())
}
if err = pkt.UnmarshalBinary(p); err != nil {
return nil, errors.WithMessage(err, fmt.Sprintf("Unmarshal %v", m.MessageType))
return nil, errors.WithMessage(err, fmt.Sprintf("Unmarshal %v", m.MessageType()))
}
return
}
func (v *Protocol) ReadMessage(ctx context.Context) (m *Message, err error) {
func (v *protocol) ReadMessage(ctx context.Context) (m Message, err error) {
for m == nil {
// TODO: We should convert buffered io to async io, because we will be stuck in block io here,
// TODO: but the risk is acceptable because we literally will set the underlay io timeout.
@ -331,15 +353,17 @@ func (v *Protocol) ReadMessage(ctx context.Context) (m *Message, err error) {
return nil, errors.WithMessage(err, "read message payload")
}
if err = v.onMessageArrivated(m); err != nil {
if m != nil {
if err = v.onMessageArrivated(m.asMessage()); err != nil {
return nil, errors.WithMessage(err, "on message")
}
}
}
return
}
func (v *Protocol) readMessagePayload(ctx context.Context, chunk *chunkStream) (m *Message, err error) {
func (v *protocol) readMessagePayload(ctx context.Context, chunk *chunkStream) (m Message, err error) {
// Empty payload message.
if chunk.message.payloadLength == 0 {
m = chunk.message
@ -348,7 +372,7 @@ func (v *Protocol) readMessagePayload(ctx context.Context, chunk *chunkStream) (
}
// Calculate the chunk payload size.
chunkedPayloadSize := int(chunk.message.payloadLength) - len(chunk.message.Payload)
chunkedPayloadSize := int(chunk.message.payloadLength) - len(chunk.message.payload)
if chunkedPayloadSize > int(v.input.opt.chunkSize) {
chunkedPayloadSize = int(v.input.opt.chunkSize)
}
@ -357,10 +381,10 @@ func (v *Protocol) readMessagePayload(ctx context.Context, chunk *chunkStream) (
if _, err = io.ReadFull(v.r, b); err != nil {
return nil, errors.Wrapf(err, "read chunk %vB", chunkedPayloadSize)
}
chunk.message.Payload = append(chunk.message.Payload, b...)
chunk.message.payload = append(chunk.message.payload, b...)
// Got entire RTMP message?
if int(chunk.message.payloadLength) == len(chunk.message.Payload) {
if int(chunk.message.payloadLength) == len(chunk.message.payload) {
m = chunk.message
chunk.message = nil
}
@ -426,7 +450,7 @@ var messageHeaderSizes = []int{11, 7, 3, 0}
// fmt=1, 0x4X
// fmt=2, 0x8X
// fmt=3, 0xCX
func (v *Protocol) readMessageHeader(ctx context.Context, chunk *chunkStream, format formatType) (err error) {
func (v *protocol) readMessageHeader(ctx context.Context, chunk *chunkStream, format formatType) (err error) {
// We should not assert anything about fmt, for the first packet.
// (when first packet, the chunk.message is nil).
// the fmt maybe 0/1/2/3, the FMLE will send a 0xC4 for some audio packet.
@ -480,7 +504,7 @@ func (v *Protocol) readMessageHeader(ctx context.Context, chunk *chunkStream, fo
// Create msg when new chunk stream start
if chunk.message == nil {
chunk.message = NewMessage()
chunk.message = newMessage()
}
// Read the message header.
@ -659,7 +683,7 @@ func (v *Protocol) readMessageHeader(ctx context.Context, chunk *chunkStream, fo
//
// Chunk stream IDs with values 64-319 could be represented by both 2-
// byte version and 3-byte version of this field.
func (v *Protocol) readBasicHeader(ctx context.Context) (format formatType, cid chunkID, err error) {
func (v *protocol) readBasicHeader(ctx context.Context) (format formatType, cid chunkID, err error) {
// 2-63, 1B chunk header
var t uint8
if err = binary.Read(v.r, binary.BigEndian, &t); err != nil {
@ -689,14 +713,14 @@ func (v *Protocol) readBasicHeader(ctx context.Context) (format formatType, cid
return
}
func (v *Protocol) WritePacket(ctx context.Context, pkt Packet, streamID int) (err error) {
m := NewMessage()
func (v *protocol) WritePacket(ctx context.Context, pkt Packet, streamID int) (err error) {
m := newMessage()
if m.Payload, err = pkt.MarshalBinary(); err != nil {
if m.payload, err = pkt.MarshalBinary(); err != nil {
return errors.WithMessage(err, "marshal payload")
}
m.MessageType = pkt.Type()
m.messageHeader.MessageType = pkt.Type()
m.streamID = uint32(streamID)
m.betterCid = pkt.BetterCid()
@ -711,7 +735,7 @@ func (v *Protocol) WritePacket(ctx context.Context, pkt Packet, streamID int) (e
return
}
func (v *Protocol) onPacketWriten(m *Message, pkt Packet) (err error) {
func (v *protocol) onPacketWriten(m *message, pkt Packet) (err error) {
var tid amf0Number
var name amf0String
@ -734,16 +758,16 @@ func (v *Protocol) onPacketWriten(m *Message, pkt Packet) (err error) {
return
}
func (v *Protocol) onMessageArrivated(m *Message) (err error) {
func (v *protocol) onMessageArrivated(m *message) (err error) {
if m == nil {
return
}
var pkt Packet
switch m.MessageType {
switch m.MessageType() {
case MessageTypeSetChunkSize, MessageTypeUserControl, MessageTypeWindowAcknowledgementSize:
if pkt, err = v.DecodeMessage(m); err != nil {
return errors.Errorf("decode message %v", m.MessageType)
return errors.Errorf("decode message %v", m.MessageType())
}
}
@ -755,19 +779,20 @@ func (v *Protocol) onMessageArrivated(m *Message) (err error) {
return
}
func (v *Protocol) WriteMessage(ctx context.Context, m *Message) (err error) {
m.payloadLength = uint32(len(m.Payload))
func (v *protocol) WriteMessage(ctx context.Context, m Message) (err error) {
msg := m.asMessage()
msg.payloadLength = uint32(len(msg.payload))
var c0h, c3h []byte
if c0h, err = m.generateC0Header(); err != nil {
if c0h, err = msg.generateC0Header(); err != nil {
return errors.WithMessage(err, "generate c0 header")
}
if c3h, err = m.generateC3Header(); err != nil {
if c3h, err = msg.generateC3Header(); err != nil {
return errors.WithMessage(err, "generate c3 header")
}
var h []byte
p := m.Payload
p := msg.payload
for len(p) > 0 {
// TODO: We should convert buffered io to async io, because we will be stuck in block io here,
// TODO: but the risk is acceptable because we literally will set the underlay io timeout.
@ -899,29 +924,56 @@ type messageHeader struct {
Timestamp uint64
}
// The RTMP message, transport over chunk stream in RTMP.
// Message is an RTMP message transported over a chunk stream.
// Please read the cs id of @doc rtmp_specification_1.0.pdf, @page 30, @section 4.1. Message Header
type Message struct {
type Message interface {
MessageType() MessageType
Timestamp() uint64
Payload() []byte
asMessage() *message
}
type message struct {
messageHeader
// The payload which carries the RTMP packet.
Payload []byte
payload []byte
}
func NewMessage() *Message {
return &Message{}
func NewMessage() Message {
return newMessage()
}
func NewStreamMessage(streamID int) *Message {
v := NewMessage()
func newMessage() *message {
return &message{}
}
func NewStreamMessage(streamID int) Message {
v := newMessage()
v.streamID = uint32(streamID)
v.betterCid = chunkIDOverStream
return v
}
func (v *Message) generateC3Header() ([]byte, error) {
func (v *message) MessageType() MessageType {
return v.messageHeader.MessageType
}
func (v *message) Timestamp() uint64 {
return v.messageHeader.Timestamp
}
func (v *message) Payload() []byte {
return v.payload
}
func (v *message) asMessage() *message {
return v
}
func (v *message) generateC3Header() ([]byte, error) {
var c3h []byte
if v.Timestamp < extendedTimestamp {
if v.messageHeader.Timestamp < extendedTimestamp {
c3h = make([]byte, 1)
} else {
c3h = make([]byte, 1+4)
@ -935,19 +987,19 @@ func (v *Message) generateC3Header() ([]byte, error) {
// but actually all products from adobe, such as FMS/AMS and Flash player and FMLE,
// always carry a extended timestamp in C3 header.
// @see: http://blog.csdn.net/win_lin/article/details/13363699
if v.Timestamp >= extendedTimestamp {
p[0] = byte(v.Timestamp >> 24)
p[1] = byte(v.Timestamp >> 16)
p[2] = byte(v.Timestamp >> 8)
p[3] = byte(v.Timestamp)
if v.messageHeader.Timestamp >= extendedTimestamp {
p[0] = byte(v.messageHeader.Timestamp >> 24)
p[1] = byte(v.messageHeader.Timestamp >> 16)
p[2] = byte(v.messageHeader.Timestamp >> 8)
p[3] = byte(v.messageHeader.Timestamp)
}
return c3h, nil
}
func (v *Message) generateC0Header() ([]byte, error) {
func (v *message) generateC0Header() ([]byte, error) {
var c0h []byte
if v.Timestamp < extendedTimestamp {
if v.messageHeader.Timestamp < extendedTimestamp {
c0h = make([]byte, 1+3+3+1+4)
} else {
c0h = make([]byte, 1+3+3+1+4+4)
@ -957,10 +1009,10 @@ func (v *Message) generateC0Header() ([]byte, error) {
p[0] = byte(v.betterCid) & 0x3f
p = p[1:]
if v.Timestamp < extendedTimestamp {
p[0] = byte(v.Timestamp >> 16)
p[1] = byte(v.Timestamp >> 8)
p[2] = byte(v.Timestamp)
if v.messageHeader.Timestamp < extendedTimestamp {
p[0] = byte(v.messageHeader.Timestamp >> 16)
p[1] = byte(v.messageHeader.Timestamp >> 8)
p[2] = byte(v.messageHeader.Timestamp)
} else {
p[0] = 0xff
p[1] = 0xff
@ -973,7 +1025,7 @@ func (v *Message) generateC0Header() ([]byte, error) {
p[2] = byte(v.payloadLength)
p = p[3:]
p[0] = byte(v.MessageType)
p[0] = byte(v.messageHeader.MessageType)
p = p[1:]
p[0] = byte(v.streamID)
@ -982,11 +1034,11 @@ func (v *Message) generateC0Header() ([]byte, error) {
p[3] = byte(v.streamID >> 24)
p = p[4:]
if v.Timestamp >= extendedTimestamp {
p[0] = byte(v.Timestamp >> 24)
p[1] = byte(v.Timestamp >> 16)
p[2] = byte(v.Timestamp >> 8)
p[3] = byte(v.Timestamp)
if v.messageHeader.Timestamp >= extendedTimestamp {
p[0] = byte(v.messageHeader.Timestamp >> 24)
p[1] = byte(v.messageHeader.Timestamp >> 16)
p[2] = byte(v.messageHeader.Timestamp >> 8)
p[3] = byte(v.messageHeader.Timestamp)
}
return c0h, nil
@ -1039,8 +1091,8 @@ type Packet interface {
type objectCallPacket struct {
CommandName amf0String
TransactionID amf0Number
CommandObject *amf0Object
Args *amf0Object
CommandObject Amf0Object
Args Amf0Object
}
func (v *objectCallPacket) BetterCid() chunkID {
@ -1081,7 +1133,7 @@ func (v *objectCallPacket) UnmarshalBinary(data []byte) (err error) {
return
}
v.Args = NewAmf0Object()
v.Args = newAmf0Object()
if err = v.Args.UnmarshalBinary(p); err != nil {
return errors.WithMessage(err, "unmarshal args")
}
@ -1149,8 +1201,8 @@ func (v *ConnectAppPacket) UnmarshalBinary(data []byte) (err error) {
func (v *ConnectAppPacket) TcUrl() string {
if v.CommandObject != nil {
if v, ok := v.CommandObject.Get("tcUrl").(*amf0String); ok {
return string(*v)
if v, ok := v.CommandObject.Get("tcUrl").(Amf0String); ok {
return v.String()
}
}
return ""
@ -1172,9 +1224,9 @@ func NewConnectAppResPacket(tid amf0Number) *ConnectAppResPacket {
func (v *ConnectAppResPacket) SrsID() string {
if v.Args != nil {
if v, ok := v.Args.Get("data").(*amf0EcmaArray); ok {
if v, ok := v.Get("srs_id").(*amf0String); ok {
return string(*v)
if v, ok := v.Args.Get("data").(Amf0EcmaArray); ok {
if v, ok := v.Get("srs_id").(Amf0String); ok {
return v.String()
}
}
}
@ -1197,7 +1249,7 @@ func (v *ConnectAppResPacket) UnmarshalBinary(data []byte) (err error) {
type variantCallPacket struct {
CommandName amf0String
TransactionID amf0Number
CommandObject amf0Any // object or null
CommandObject Amf0Any // object or null
}
func (v *variantCallPacket) BetterCid() chunkID {
@ -1273,7 +1325,7 @@ func (v *variantCallPacket) MarshalBinary() (data []byte, err error) {
// @remark onStatus packet is a call packet.
type CallPacket struct {
variantCallPacket
Args amf0Any // optional or object or null
Args Amf0Any // optional or object or null
}
func NewCallPacket() *CallPacket {
@ -1282,9 +1334,9 @@ func NewCallPacket() *CallPacket {
func (v *CallPacket) ArgsCode() string {
if v.Args != nil {
if v, ok := v.Args.(*amf0Object); ok {
if code, ok := v.Get("code").(*amf0String); ok {
return string(*code)
if v, ok := v.Args.(Amf0Object); ok {
if code, ok := v.Get("code").(Amf0String); ok {
return code.String()
}
}
}
@ -1370,6 +1422,10 @@ func NewCreateStreamResPacket(tid amf0Number) *CreateStreamResPacket {
return v
}
func (v *CreateStreamResPacket) SetStreamID(streamID int) {
v.StreamID = amf0Number(streamID)
}
func (v *CreateStreamResPacket) Size() int {
return v.variantCallPacket.Size() + v.StreamID.Size()
}
@ -1407,15 +1463,16 @@ func (v *CreateStreamResPacket) MarshalBinary() (data []byte, err error) {
// Please read @doc rtmp_specification_1.0.pdf, @page 64, @section 4.2.6. Publish
type PublishPacket struct {
variantCallPacket
StreamName amf0String
StreamType amf0String
StreamName Amf0String
StreamType Amf0String
}
func NewPublishPacket() *PublishPacket {
v := &PublishPacket{}
v.CommandName = commandPublish
v.CommandObject = NewAmf0Null()
v.StreamType = "live"
v.StreamName = NewAmf0String("")
v.StreamType = NewAmf0String("live")
return v
}
@ -1431,11 +1488,13 @@ func (v *PublishPacket) UnmarshalBinary(data []byte) (err error) {
}
p = p[v.variantCallPacket.Size():]
v.StreamName = newAmf0String("")
if err = v.StreamName.UnmarshalBinary(p); err != nil {
return errors.WithMessage(err, "unmarshal stream name")
}
p = p[v.StreamName.Size():]
v.StreamType = newAmf0String("")
if err = v.StreamType.UnmarshalBinary(p); err != nil {
return errors.WithMessage(err, "unmarshal stream type")
}
@ -1466,13 +1525,14 @@ func (v *PublishPacket) MarshalBinary() (data []byte, err error) {
// Please read @doc rtmp_specification_1.0.pdf, @page 54, @section 4.2.1. play
type PlayPacket struct {
variantCallPacket
StreamName amf0String
StreamName Amf0String
}
func NewPlayPacket() *PlayPacket {
v := &PlayPacket{}
v.CommandName = commandPlay
v.CommandObject = NewAmf0Null()
v.StreamName = NewAmf0String("")
return v
}
@ -1488,6 +1548,7 @@ func (v *PlayPacket) UnmarshalBinary(data []byte) (err error) {
}
p = p[v.variantCallPacket.Size():]
v.StreamName = newAmf0String("")
if err = v.StreamName.UnmarshalBinary(p); err != nil {
return errors.WithMessage(err, "unmarshal stream name")
}

728
internal/rtmp/rtmp_test.go Normal file
View File

@ -0,0 +1,728 @@
// Copyright (c) 2026 Winlin
//
// SPDX-License-Identifier: MIT
package rtmp
import (
"bytes"
"context"
"encoding/binary"
"io"
"reflect"
"strings"
"testing"
)
type errWriter struct{}
func (errWriter) Write([]byte) (int, error) { return 0, io.ErrClosedPipe }
func TestHandshakeSimpleAndErrors(t *testing.T) {
h := NewHandshake()
var b bytes.Buffer
if err := h.WriteC0S0(&b); err != nil {
t.Fatalf("WriteC0S0 err=%v", err)
}
c0, err := h.ReadC0S0(&b)
if err != nil || !bytes.Equal(c0, []byte{3}) {
t.Fatalf("ReadC0S0=%v, err=%v", c0, err)
}
if err := h.WriteC0S0(errWriter{}); err == nil {
t.Fatal("WriteC0S0 should fail")
}
if _, err := h.ReadC0S0(bytes.NewReader(nil)); err == nil {
t.Fatal("ReadC0S0 should fail")
}
b.Reset()
if err := h.WriteC1S1(&b); err != nil {
t.Fatalf("WriteC1S1 err=%v", err)
}
if b.Len() != 1536 {
t.Fatalf("C1S1 len=%v", b.Len())
}
c1, err := h.ReadC1S1(&b)
if err != nil || len(c1) != 1536 || !bytes.Equal(h.C1S1(), c1) {
t.Fatalf("ReadC1S1 len=%v, cached=%v, err=%v", len(c1), bytes.Equal(h.C1S1(), c1), err)
}
if err := h.WriteC1S1(errWriter{}); err == nil {
t.Fatal("WriteC1S1 should fail")
}
if _, err := h.ReadC1S1(bytes.NewReader(make([]byte, 1535))); err == nil {
t.Fatal("ReadC1S1 should fail")
}
b.Reset()
if err := h.WriteC2S2(&b, c1); err != nil {
t.Fatalf("WriteC2S2 err=%v", err)
}
c2, err := h.ReadC2S2(&b)
if err != nil || !bytes.Equal(c2, c1) {
t.Fatalf("ReadC2S2 match=%v, err=%v", bytes.Equal(c2, c1), err)
}
if err := h.WriteC2S2(errWriter{}, c1); err == nil {
t.Fatal("WriteC2S2 should fail")
}
if _, err := h.ReadC2S2(bytes.NewReader(make([]byte, 1535))); err == nil {
t.Fatal("ReadC2S2 should fail")
}
}
func TestSettingsChunkStreamAndMessageConstructors(t *testing.T) {
if s := newSettings(); s.chunkSize != defaultChunkSize {
t.Fatalf("chunk size=%v", s.chunkSize)
}
if c := newChunkStream(); c == nil || c.count != 0 {
t.Fatalf("chunk stream=%#v", c)
}
m := NewMessage().asMessage()
m.messageHeader.MessageType = MessageTypeAudio
m.messageHeader.Timestamp = 99
m.payload = []byte{1, 2, 3}
if m.MessageType() != MessageTypeAudio || m.Timestamp() != 99 || !bytes.Equal(m.Payload(), []byte{1, 2, 3}) || m.asMessage() != m {
t.Fatalf("bad message accessors")
}
sm := NewStreamMessage(7).asMessage()
if sm.streamID != 7 || sm.betterCid != chunkIDOverStream {
t.Fatalf("stream message=%#v", sm.messageHeader)
}
}
func TestBasicHeaderVariantsAndErrors(t *testing.T) {
ctx := context.Background()
cases := []struct {
name string
data []byte
fmt formatType
cid chunkID
}{
{"one-byte", []byte{0x85}, formatType2, 5},
{"two-byte", []byte{0x40, 0x0a}, formatType1, 74},
{"three-byte-code-path", []byte{0xc1, 0x01, 0x02}, formatType3, 65},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
p := NewProtocol(bytes.NewBuffer(tt.data)).(*protocol)
fmt, cid, err := p.readBasicHeader(ctx)
if err != nil || fmt != tt.fmt || cid != tt.cid {
t.Fatalf("fmt=%v cid=%v err=%v", fmt, cid, err)
}
})
}
for _, data := range [][]byte{{}, {0x00}} {
p := NewProtocol(bytes.NewBuffer(data)).(*protocol)
if _, _, err := p.readBasicHeader(ctx); err == nil {
t.Fatalf("readBasicHeader(%x) should fail", data)
}
}
}
func TestReadMessageHeadersPayloadsAndChunks(t *testing.T) {
ctx := context.Background()
var in bytes.Buffer
// fmt0 cid=5, timestamp=10, len=3, type audio, stream=1, payload 010203.
in.Write([]byte{0x05, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x03, byte(MessageTypeAudio), 0x01, 0x00, 0x00, 0x00, 1, 2, 3})
// fmt1 same cid, delta=5, len=2, type video, payload 0405.
in.Write([]byte{0x45, 0x00, 0x00, 0x05, 0x00, 0x00, 0x02, byte(MessageTypeVideo), 4, 5})
// fmt2 same cid, delta=7, reuses len/type/stream, payload 0607.
in.Write([]byte{0x85, 0x00, 0x00, 0x07, 6, 7})
// fmt3 same cid, reuses delta and advances timestamp, payload 0809.
in.Write([]byte{0xc5, 8, 9})
p := NewProtocol(&in).(*protocol)
for i, want := range []struct {
typ MessageType
ts uint64
pl []byte
}{
{MessageTypeAudio, 10, []byte{1, 2, 3}},
{MessageTypeVideo, 15, []byte{4, 5}},
{MessageTypeVideo, 22, []byte{6, 7}},
{MessageTypeVideo, 29, []byte{8, 9}},
} {
m, err := p.ReadMessage(ctx)
if err != nil {
t.Fatalf("ReadMessage #%v err=%v", i, err)
}
if m.MessageType() != want.typ || m.Timestamp() != want.ts || !bytes.Equal(m.Payload(), want.pl) {
t.Fatalf("message #%v type=%v ts=%v payload=%x", i, m.MessageType(), m.Timestamp(), m.Payload())
}
}
}
func TestReadMessageExtendedTimestampAndChunking(t *testing.T) {
ctx := context.Background()
var in bytes.Buffer
payload := []byte{1, 2, 3, 4, 5}
// fmt0 cid=5, normal timestamp=0xffffff, extended timestamp has high bit set and should be masked.
in.Write([]byte{0x05, 0xff, 0xff, 0xff, 0x00, 0x00, byte(len(payload)), byte(MessageTypeVideo), 0x01, 0x00, 0x00, 0x00})
binary.Write(&in, binary.BigEndian, uint32(0x8000002a))
in.Write(payload[:2])
// continuation chunk has fmt3 and extended timestamp too.
in.Write([]byte{0xc5})
binary.Write(&in, binary.BigEndian, uint32(0x8000002a))
in.Write(payload[2:4])
in.Write([]byte{0xc5})
binary.Write(&in, binary.BigEndian, uint32(0x8000002a))
in.Write(payload[4:])
p := NewProtocol(&in).(*protocol)
p.input.opt.chunkSize = 2
m, err := p.ReadMessage(ctx)
if err != nil {
t.Fatalf("ReadMessage err=%v", err)
}
if m.Timestamp() != 42 || !bytes.Equal(m.Payload(), payload) {
t.Fatalf("ts=%v payload=%x", m.Timestamp(), m.Payload())
}
}
func TestReadMessageHeaderErrors(t *testing.T) {
ctx := context.Background()
// Fresh non-zero chunk with fmt1 is rejected.
p := NewProtocol(bytes.NewBuffer([]byte{0x45})).(*protocol)
if _, err := p.ReadMessage(ctx); err == nil || !strings.Contains(err.Error(), "fresh chunk") {
t.Fatalf("fresh fmt1 err=%v", err)
}
// Existing partial message cannot restart with fmt0.
p = NewProtocol(bytes.NewBuffer([]byte{0x05, 0, 0, 0, 0, 0, 3, byte(MessageTypeAudio), 1, 0, 0, 0, 1, 0x05})).(*protocol)
p.input.opt.chunkSize = 1
if _, err := p.ReadMessage(ctx); err == nil || !strings.Contains(err.Error(), "exists chunk") {
t.Fatalf("restart err=%v", err)
}
// Size change in a continuation header is rejected.
p = NewProtocol(bytes.NewBuffer([]byte{0x05, 0, 0, 0, 0, 0, 3, byte(MessageTypeAudio), 1, 0, 0, 0, 1, 0x45, 0, 0, 1, 0, 0, 4, byte(MessageTypeAudio)})).(*protocol)
p.input.opt.chunkSize = 1
if _, err := p.ReadMessage(ctx); err == nil || !strings.Contains(err.Error(), "message size") {
t.Fatalf("size change err=%v", err)
}
// Short payload and short extended timestamp.
p = NewProtocol(bytes.NewBuffer([]byte{0x05, 0, 0, 0, 0, 0, 2, byte(MessageTypeAudio), 1, 0, 0, 0, 1})).(*protocol)
if _, err := p.ReadMessage(ctx); err == nil || !strings.Contains(err.Error(), "read chunk") {
t.Fatalf("payload err=%v", err)
}
p = NewProtocol(bytes.NewBuffer([]byte{0x05, 0xff, 0xff, 0xff, 0, 0, 0, byte(MessageTypeAudio), 1, 0, 0, 0, 1, 2})).(*protocol)
if _, err := p.ReadMessage(ctx); err == nil || !strings.Contains(err.Error(), "ext-ts") {
t.Fatalf("ext-ts err=%v", err)
}
}
func TestWriteMessageHeadersChunkingAndErrors(t *testing.T) {
ctx := context.Background()
var out bytes.Buffer
p := NewProtocol(&out).(*protocol)
p.output.opt.chunkSize = 2
m := NewStreamMessage(7).asMessage()
m.messageHeader.MessageType = MessageTypeVideo
m.messageHeader.Timestamp = extendedTimestamp + 9
m.payload = []byte{1, 2, 3, 4, 5}
if err := p.WriteMessage(ctx, m); err != nil {
t.Fatalf("WriteMessage err=%v", err)
}
want := []byte{0x05, 0xff, 0xff, 0xff, 0, 0, 5, byte(MessageTypeVideo), 7, 0, 0, 0, 0x01, 0x00, 0x00, 0x08, 1, 2, 0xc5, 0x01, 0x00, 0x00, 0x08, 3, 4, 0xc5, 0x01, 0x00, 0x00, 0x08, 5}
if !bytes.Equal(out.Bytes(), want) {
t.Fatalf("written=%x want=%x", out.Bytes(), want)
}
if err := p.WriteMessage(ctx, (&message{})); err != nil {
t.Fatalf("empty WriteMessage err=%v", err)
}
canceled, cancel := context.WithCancel(ctx)
cancel()
if err := p.WriteMessage(canceled, m); err != context.Canceled {
t.Fatalf("canceled WriteMessage err=%v", err)
}
p = NewProtocol(struct {
io.Reader
io.Writer
}{bytes.NewReader(nil), errWriter{}}).(*protocol)
if err := p.WriteMessage(ctx, m); err == nil {
t.Fatal("WriteMessage to bad writer should fail")
}
}
func TestProtocolDecodeMessageAndControls(t *testing.T) {
ctx := context.Background()
p := NewProtocol(&bytes.Buffer{}).(*protocol)
if _, err := p.DecodeMessage((&message{})); err == nil || !strings.Contains(err.Error(), "Empty packet") {
t.Fatalf("empty decode err=%v", err)
}
unknown := &message{}
unknown.messageHeader.MessageType = MessageTypeAudio
unknown.payload = []byte{1}
if _, err := p.DecodeMessage(unknown); err == nil || !strings.Contains(err.Error(), "Unknown message") {
t.Fatalf("unknown err=%v", err)
}
bad := &message{}
bad.messageHeader.MessageType = MessageTypeSetChunkSize
bad.payload = []byte{1, 2}
if _, err := p.DecodeMessage(bad); err == nil || !strings.Contains(err.Error(), "Unmarshal") {
t.Fatalf("bad control err=%v", err)
}
for _, pkt := range []Packet{
&SetChunkSize{ChunkSize: 4096},
&WindowAcknowledgementSize{AckSize: 2500000},
&SetPeerBandwidth{Bandwidth: 1000, LimitType: LimitTypeSoft},
&UserControl{EventType: EventTypePingRequest, EventData: 123},
} {
data, err := pkt.MarshalBinary()
if err != nil {
t.Fatalf("marshal %T err=%v", pkt, err)
}
m := &message{payload: data}
m.messageHeader.MessageType = pkt.Type()
got, err := p.DecodeMessage(m)
if err != nil {
t.Fatalf("DecodeMessage %T err=%v", pkt, err)
}
if reflect.TypeOf(got) != reflect.TypeOf(pkt) {
t.Fatalf("got %T want %T", got, pkt)
}
}
chunk := &SetChunkSize{ChunkSize: 3}
m := &message{}
m.messageHeader.MessageType = chunk.Type()
m.payload, _ = chunk.MarshalBinary()
if err := p.onMessageArrivated(m); err != nil || p.input.opt.chunkSize != 3 {
t.Fatalf("onMessageArrivated err=%v chunk=%v", err, p.input.opt.chunkSize)
}
if err := p.onMessageArrivated(nil); err != nil {
t.Fatalf("nil onMessageArrivated err=%v", err)
}
bad.Payload()[0] = 1
if err := p.onMessageArrivated(bad); err == nil {
t.Fatal("bad onMessageArrivated should fail")
}
if _, err := p.ExpectMessage(ctx); err == nil {
t.Fatal("ExpectMessage on empty reader should fail")
}
}
func TestProtocolPacketsAndTransactions(t *testing.T) {
ctx := context.Background()
var wire bytes.Buffer
writer := NewProtocol(&wire).(*protocol)
connect := NewConnectAppPacket()
connect.CommandObject.Set("tcUrl", NewAmf0String("rtmp://host/live"))
if connect.Size() == 0 || connect.BetterCid() != chunkIDOverConnection || connect.Type() != MessageTypeAMF0Command || connect.TcUrl() == "" {
t.Fatalf("connect metadata invalid")
}
if err := writer.WritePacket(ctx, connect, 0); err != nil {
t.Fatalf("WritePacket connect err=%v", err)
}
if _, ok := writer.input.transactions[connect.TransactionID]; !ok {
t.Fatal("connect transaction not tracked")
}
create := NewCreateStreamPacket()
if err := writer.WritePacket(ctx, create, 0); err != nil {
t.Fatalf("WritePacket create err=%v", err)
}
call := NewCallPacket()
call.CommandName = commandReleaseStream
call.TransactionID = 3
call.CommandObject = NewAmf0Null()
if err := writer.WritePacket(ctx, call, 0); err != nil {
t.Fatalf("WritePacket call err=%v", err)
}
reader := NewProtocol(&wire)
var gotConnect *ConnectAppPacket
if _, err := ExpectPacket(ctx, reader, &gotConnect); err != nil || gotConnect.TcUrl() != "rtmp://host/live" {
t.Fatalf("gotConnect=%v err=%v", gotConnect, err)
}
var gotCreate *CallPacket
if _, err := ExpectPacket(ctx, reader, &gotCreate); err != nil || gotCreate.CommandName != commandCreateStream {
t.Fatalf("gotCreate=%v err=%v", gotCreate, err)
}
decoder := NewProtocol(&bytes.Buffer{}).(*protocol)
decoder.input.transactions[1] = commandConnect
if pkt, err := decoder.parseAMFObject(mustPacketBytes(t, NewConnectAppResPacket(1))); err != nil || reflect.TypeOf(pkt) != reflect.TypeOf(NewConnectAppResPacket(1)) {
t.Fatalf("connect res pkt=%T err=%v", pkt, err)
}
decoder.input.transactions[2] = commandCreateStream
csr := NewCreateStreamResPacket(2)
csr.SetStreamID(99)
if pkt, err := decoder.parseAMFObject(mustPacketBytes(t, csr)); err != nil || reflect.TypeOf(pkt) != reflect.TypeOf(csr) {
t.Fatalf("create res pkt=%T err=%v", pkt, err)
}
decoder.input.transactions[3] = commandReleaseStream
res := NewCallPacket()
res.CommandName = commandResult
res.TransactionID = 3
res.CommandObject = NewAmf0Null()
if pkt, err := decoder.parseAMFObject(mustPacketBytes(t, res)); err != nil || reflect.TypeOf(pkt) != reflect.TypeOf(res) {
t.Fatalf("call res pkt=%T err=%v", pkt, err)
}
for _, name := range []amf0String{commandPublish, commandPlay, commandOnStatus} {
pkt := NewCallPacket()
pkt.CommandName = name
pkt.TransactionID = 0
pkt.CommandObject = NewAmf0Null()
if name == commandPublish {
pub := NewPublishPacket()
pub.TransactionID = 0
pub.StreamName = NewAmf0String("s")
pub.StreamType = NewAmf0String("live")
if decoded, err := decoder.parseAMFObject(mustPacketBytes(t, pub)); err != nil || reflect.TypeOf(decoded) != reflect.TypeOf(pub) {
t.Fatalf("publish decoded=%T err=%v", decoded, err)
}
continue
}
if name == commandPlay {
play := NewPlayPacket()
play.TransactionID = 0
play.StreamName = NewAmf0String("s")
if decoded, err := decoder.parseAMFObject(mustPacketBytes(t, play)); err != nil || reflect.TypeOf(decoded) != reflect.TypeOf(play) {
t.Fatalf("play decoded=%T err=%v", decoded, err)
}
continue
}
if decoded, err := decoder.parseAMFObject(mustPacketBytes(t, pkt)); err != nil || reflect.TypeOf(decoded) != reflect.TypeOf(pkt) {
t.Fatalf("call decoded=%T err=%v", decoded, err)
}
}
decoder.input.transactions[9] = commandPause
errPkt := NewCallPacket()
errPkt.CommandName = commandError
errPkt.TransactionID = 9
errPkt.CommandObject = NewAmf0Null()
if _, err := decoder.parseAMFObject(mustPacketBytes(t, errPkt)); err == nil || !strings.Contains(err.Error(), "No request") {
t.Fatalf("unknown request err=%v", err)
}
if _, err := decoder.parseAMFObject(mustPacketBytes(t, errPkt)); err == nil || !strings.Contains(err.Error(), "No matched request") {
t.Fatalf("missing transaction err=%v", err)
}
if _, err := decoder.parseAMFObject([]byte{byte(amf0MarkerString), 0, 8, 'c'}); err == nil {
t.Fatal("bad AMF parse should fail")
}
cctx, cancel := context.WithCancel(ctx)
cancel()
if err := writer.WritePacket(cctx, connect, 0); err == nil || !strings.Contains(err.Error(), context.Canceled.Error()) {
t.Fatalf("WritePacket canceled err=%v", err)
}
}
func TestDeprecatedExpectPacketPanics(t *testing.T) {
defer func() {
if recover() == nil {
t.Fatal("Expected panic")
}
}()
NewProtocol(&bytes.Buffer{}).ExpectPacket(context.Background(), nil)
}
func TestPacketRoundTripsAndErrors(t *testing.T) {
packets := []Packet{
NewConnectAppPacket(),
NewConnectAppResPacket(7),
NewCallPacket(),
NewCreateStreamPacket(),
func() Packet { p := NewCreateStreamResPacket(2); p.SetStreamID(1); return p }(),
func() Packet {
p := NewPublishPacket()
p.TransactionID = 0
p.StreamName = NewAmf0String("s")
return p
}(),
func() Packet { p := NewPlayPacket(); p.TransactionID = 0; p.StreamName = NewAmf0String("s"); return p }(),
&SetChunkSize{ChunkSize: 1},
&WindowAcknowledgementSize{AckSize: 2},
&SetPeerBandwidth{Bandwidth: 3, LimitType: LimitTypeDynamic},
&UserControl{EventType: EventTypeFmsEvent0, EventData: 1},
&UserControl{EventType: EventTypeSetBufferLength, EventData: 1, ExtraData: 2},
}
// Initialize the generic call packet so it is marshalable.
packets[2].(*CallPacket).CommandName = commandOnStatus
packets[2].(*CallPacket).TransactionID = 0
packets[2].(*CallPacket).CommandObject = NewAmf0Null()
packets[2].(*CallPacket).Args = NewAmf0Object().Set("code", NewAmf0String("ok"))
packets[1].(*ConnectAppResPacket).Args.Set("data", NewAmf0EcmaArray().Set("srs_id", NewAmf0String("sid")))
for _, pkt := range packets {
t.Run(reflect.TypeOf(pkt).String(), func(t *testing.T) {
data, err := pkt.MarshalBinary()
if err != nil {
t.Fatalf("MarshalBinary err=%v", err)
}
if len(data) != pkt.Size() {
t.Fatalf("len=%v Size=%v", len(data), pkt.Size())
}
fresh := reflect.New(reflect.TypeOf(pkt).Elem()).Interface().(Packet)
switch v := fresh.(type) {
case *ConnectAppPacket:
*v = *NewConnectAppPacket()
case *ConnectAppResPacket:
*v = *NewConnectAppResPacket(0)
case *CreateStreamPacket:
*v = *NewCreateStreamPacket()
case *CreateStreamResPacket:
*v = *NewCreateStreamResPacket(0)
case *PublishPacket:
*v = *NewPublishPacket()
case *PlayPacket:
*v = *NewPlayPacket()
}
if err := fresh.UnmarshalBinary(data); err != nil {
t.Fatalf("UnmarshalBinary err=%v", err)
}
})
}
if packets[1].(*ConnectAppResPacket).SrsID() != "sid" || packets[2].(*CallPacket).ArgsCode() != "ok" {
t.Fatalf("packet helpers failed")
}
if NewConnectAppResPacket(1).SrsID() != "" || NewCallPacket().ArgsCode() != "" || NewConnectAppPacket().TcUrl() != "" {
t.Fatalf("empty helpers failed")
}
badConnect := NewConnectAppPacket()
badConnect.CommandName = commandPlay
if err := badConnect.UnmarshalBinary(mustPacketBytes(t, badConnect)); err == nil {
t.Fatal("bad connect name should fail")
}
badConnect = NewConnectAppPacket()
badConnect.TransactionID = 2
if err := badConnect.UnmarshalBinary(mustPacketBytes(t, badConnect)); err == nil {
t.Fatal("bad connect tid should fail")
}
badRes := NewConnectAppResPacket(1)
badRes.CommandName = commandPlay
if err := badRes.UnmarshalBinary(mustPacketBytes(t, badRes)); err == nil {
t.Fatal("bad connect response name should fail")
}
for _, pkt := range []Packet{NewConnectAppPacket(), NewCallPacket(), NewCreateStreamResPacket(1), NewPublishPacket(), NewPlayPacket()} {
if err := pkt.UnmarshalBinary([]byte{byte(amf0MarkerString)}); err == nil {
t.Fatalf("%T short unmarshal should fail", pkt)
}
}
for _, pkt := range []Packet{&SetChunkSize{}, &WindowAcknowledgementSize{}, &SetPeerBandwidth{}, &UserControl{}} {
if err := pkt.UnmarshalBinary([]byte{0, 1}); err == nil {
t.Fatalf("%T short unmarshal should fail", pkt)
}
}
uc := &UserControl{}
if err := uc.UnmarshalBinary([]byte{0, byte(EventTypeSetBufferLength), 1, 2, 3, 4, 5}); err == nil {
t.Fatal("short set-buffer-length should fail")
}
}
func mustPacketBytes(t *testing.T, pkt Packet) []byte {
t.Helper()
data, err := pkt.MarshalBinary()
if err != nil {
t.Fatalf("MarshalBinary %T err=%v", pkt, err)
}
return data
}
type failPacket struct{}
func (failPacket) Size() int { return 0 }
func (failPacket) UnmarshalBinary([]byte) error { return io.ErrUnexpectedEOF }
func (failPacket) MarshalBinary() ([]byte, error) { return nil, io.ErrClosedPipe }
func (failPacket) BetterCid() chunkID { return chunkIDOverConnection }
func (failPacket) Type() MessageType { return MessageTypeAMF0Command }
type stepWriter struct {
writes int
failAt int
}
func (w *stepWriter) Write(p []byte) (int, error) {
w.writes++
if w.writes == w.failAt {
return 0, io.ErrClosedPipe
}
return len(p), nil
}
func TestProtocolAdditionalBranches(t *testing.T) {
ctx := context.Background()
p := NewProtocol(&bytes.Buffer{}).(*protocol)
if NewSetPeerBandwidth().BetterCid() != chunkIDProtocolControl || NewUserControl().BetterCid() != chunkIDProtocolControl {
t.Fatal("control better cid failed")
}
if err := p.WritePacket(ctx, failPacket{}, 0); err == nil || !strings.Contains(err.Error(), "marshal payload") {
t.Fatalf("WritePacket marshal err=%v", err)
}
payloadWriter := &stepWriter{failAt: 1}
p = NewProtocol(struct {
io.Reader
io.Writer
}{bytes.NewReader(nil), payloadWriter}).(*protocol)
m := NewStreamMessage(1).asMessage()
m.messageHeader.MessageType = MessageTypeVideo
m.payload = bytes.Repeat([]byte{1}, 5000)
if err := p.WriteMessage(ctx, m); err == nil || !strings.Contains(err.Error(), "write chunk payload") {
t.Fatalf("WriteMessage payload err=%v", err)
}
flushWriter := &stepWriter{failAt: 1}
p = NewProtocol(struct {
io.Reader
io.Writer
}{bytes.NewReader(nil), flushWriter}).(*protocol)
m.payload = []byte{1}
if err := p.WriteMessage(ctx, m); err == nil || !strings.Contains(err.Error(), "flush writer") {
t.Fatalf("WriteMessage flush err=%v writes=%v", err, flushWriter.writes)
}
// Zero-length payload returns a complete message without reading chunk bytes.
in := bytes.NewBuffer([]byte{0x05, 0, 0, 1, 0, 0, 0, byte(MessageTypeAudio), 1, 0, 0, 0})
p = NewProtocol(in).(*protocol)
if msg, err := p.ReadMessage(ctx); err != nil || msg.MessageType() != MessageTypeAudio || len(msg.Payload()) != 0 {
t.Fatalf("zero payload msg=%v err=%v", msg, err)
}
// ExpectMessage skips unwanted message types before returning the desired one.
var wire bytes.Buffer
writer := NewProtocol(&wire).(*protocol)
am := NewStreamMessage(1).asMessage()
am.messageHeader.MessageType = MessageTypeAudio
am.payload = []byte{1}
vm := NewStreamMessage(1).asMessage()
vm.messageHeader.MessageType = MessageTypeVideo
vm.payload = []byte{2}
if err := writer.WriteMessage(ctx, am); err != nil {
t.Fatal(err)
}
if err := writer.WriteMessage(ctx, vm); err != nil {
t.Fatal(err)
}
reader := NewProtocol(&wire)
if got, err := reader.ExpectMessage(ctx, MessageTypeVideo); err != nil || got.MessageType() != MessageTypeVideo {
t.Fatalf("ExpectMessage got=%v err=%v", got, err)
}
// Generic ExpectPacket skips non-matching packets, then returns matching; it also reports decode/read errors.
wire.Reset()
writer = NewProtocol(&wire).(*protocol)
if err := writer.WritePacket(ctx, &WindowAcknowledgementSize{AckSize: 1}, 0); err != nil {
t.Fatal(err)
}
if err := writer.WritePacket(ctx, &SetChunkSize{ChunkSize: 2}, 0); err != nil {
t.Fatal(err)
}
reader = NewProtocol(&wire)
var chunk *SetChunkSize
if _, err := ExpectPacket(ctx, reader, &chunk); err != nil || chunk.ChunkSize != 2 {
t.Fatalf("ExpectPacket chunk=%v err=%v", chunk, err)
}
reader = NewProtocol(bytes.NewBuffer([]byte{0x05, 0, 0, 0, 0, 0, 1, byte(MessageTypeSetChunkSize), 1, 0, 0, 0, 0}))
if _, err := ExpectPacket(ctx, reader, &chunk); err == nil || !strings.Contains(err.Error(), "decode message") {
t.Fatalf("ExpectPacket decode err=%v", err)
}
reader = NewProtocol(bytes.NewBuffer(nil))
if _, err := ExpectPacket(ctx, reader, &chunk); err == nil || !strings.Contains(err.Error(), "read message") {
t.Fatalf("ExpectPacket read err=%v", err)
}
// AMF3 strips the leading byte before AMF0 decoding.
pub := NewPublishPacket()
pub.TransactionID = 0
pub.StreamName = NewAmf0String("stream")
data := append([]byte{0}, mustPacketBytes(t, pub)...)
msg := &message{payload: data}
msg.messageHeader.MessageType = MessageTypeAMF3Command
if pkt, err := NewProtocol(&bytes.Buffer{}).DecodeMessage(msg); err != nil || pkt.(*PublishPacket).StreamName.String() != "stream" {
t.Fatalf("AMF3 decode pkt=%T err=%v", pkt, err)
}
}
func TestProtocolErrorBranchesForCoverage(t *testing.T) {
ctx := context.Background()
// ExpectMessage without requested types returns the first message.
var wire bytes.Buffer
w := NewProtocol(&wire).(*protocol)
msg := NewStreamMessage(1).asMessage()
msg.messageHeader.MessageType = MessageTypeAudio
msg.payload = []byte{1}
if err := w.WriteMessage(ctx, msg); err != nil {
t.Fatal(err)
}
if got, err := NewProtocol(&wire).ExpectMessage(ctx); err != nil || got.MessageType() != MessageTypeAudio {
t.Fatalf("ExpectMessage any got=%v err=%v", got, err)
}
cctx, cancel := context.WithCancel(ctx)
cancel()
if _, err := NewProtocol(bytes.NewBuffer(nil)).ReadMessage(cctx); err != context.Canceled {
t.Fatalf("ReadMessage canceled err=%v", err)
}
if err := w.WriteMessage(cctx, (&message{})); err != context.Canceled {
t.Fatalf("WriteMessage empty canceled err=%v", err)
}
badAMF := &message{payload: []byte{0xff}}
badAMF.messageHeader.MessageType = MessageTypeAMF0Command
if _, err := NewProtocol(&bytes.Buffer{}).DecodeMessage(badAMF); err == nil || !strings.Contains(err.Error(), "Parse AMF") {
t.Fatalf("bad AMF decode err=%v", err)
}
rn := commandResult
resultName, _ := (&rn).MarshalBinary()
if _, err := NewProtocol(&bytes.Buffer{}).(*protocol).parseAMFObject(append(resultName, 0)); err == nil || !strings.Contains(err.Error(), "unmarshal tid") {
t.Fatalf("bad result tid err=%v", err)
}
}
func TestPacketUnmarshalErrorBranchesForCoverage(t *testing.T) {
cn := commandConnect
name, _ := (&cn).MarshalBinary()
tn := amf0Number(1)
tid, _ := (&tn).MarshalBinary()
obj, _ := NewAmf0Object().MarshalBinary()
base := append(append([]byte{}, name...), tid...)
oc := &objectCallPacket{CommandObject: NewAmf0Object()}
if err := oc.UnmarshalBinary(append([]byte{}, name...)); err == nil || !strings.Contains(err.Error(), "unmarshal tid") {
t.Fatalf("object tid err=%v", err)
}
if err := oc.UnmarshalBinary(append(append([]byte{}, base...), 0xff)); err == nil || !strings.Contains(err.Error(), "unmarshal command") {
t.Fatalf("object command err=%v", err)
}
withObj := append(append([]byte{}, base...), obj...)
if err := oc.UnmarshalBinary(append(withObj, 0xff)); err == nil || !strings.Contains(err.Error(), "unmarshal args") {
t.Fatalf("object args err=%v", err)
}
vc := &variantCallPacket{}
if err := vc.UnmarshalBinary(append([]byte{}, name...)); err == nil || !strings.Contains(err.Error(), "unmarshal tid") {
t.Fatalf("variant tid err=%v", err)
}
if err := vc.UnmarshalBinary(append(append([]byte{}, base...), 0xff)); err == nil || !strings.Contains(err.Error(), "discovery command object") {
t.Fatalf("variant discovery err=%v", err)
}
if err := vc.UnmarshalBinary(append(append([]byte{}, base...), byte(amf0MarkerString), 0, 3, 'a')); err == nil || !strings.Contains(err.Error(), "unmarshal command object") {
t.Fatalf("variant command object err=%v", err)
}
call := NewCallPacket()
call.CommandName = commandOnStatus
call.TransactionID = 0
call.CommandObject = NewAmf0Null()
callBase := mustPacketBytes(t, call)
if err := NewCallPacket().UnmarshalBinary(append(callBase, 0xff)); err == nil || !strings.Contains(err.Error(), "discovery args") {
t.Fatalf("call discovery args err=%v", err)
}
if err := NewCallPacket().UnmarshalBinary(append(callBase, byte(amf0MarkerString), 0, 3, 'a')); err == nil || !strings.Contains(err.Error(), "unmarshal args") {
t.Fatalf("call unmarshal args err=%v", err)
}
csr := NewCreateStreamResPacket(2)
if err := NewCreateStreamResPacket(0).UnmarshalBinary(mustPacketBytes(t, &csr.variantCallPacket)); err == nil || !strings.Contains(err.Error(), "unmarshal sid") {
t.Fatalf("create stream sid err=%v", err)
}
pub := NewPublishPacket()
pub.TransactionID = 0
pubPrefix, _ := pub.variantCallPacket.MarshalBinary()
if err := NewPublishPacket().UnmarshalBinary(append(pubPrefix, 0xff)); err == nil || !strings.Contains(err.Error(), "stream name") {
t.Fatalf("publish stream name err=%v", err)
}
streamName, _ := NewAmf0String("s").MarshalBinary()
if err := NewPublishPacket().UnmarshalBinary(append(append(pubPrefix, streamName...), 0xff)); err == nil || !strings.Contains(err.Error(), "stream type") {
t.Fatalf("publish stream type err=%v", err)
}
play := NewPlayPacket()
play.TransactionID = 0
playPrefix, _ := play.variantCallPacket.MarshalBinary()
if err := NewPlayPacket().UnmarshalBinary(append(playPrefix, 0xff)); err == nil || !strings.Contains(err.Error(), "stream name") {
t.Fatalf("play stream name err=%v", err)
}
}

View File

@ -0,0 +1,554 @@
// Code generated by counterfeiter. DO NOT EDIT.
package rtmpfakes
import (
"io"
"srsx/internal/rtmp"
"sync"
)
type FakeHandshake struct {
C1S1Stub func() []byte
c1S1Mutex sync.RWMutex
c1S1ArgsForCall []struct {
}
c1S1Returns struct {
result1 []byte
}
c1S1ReturnsOnCall map[int]struct {
result1 []byte
}
ReadC0S0Stub func(io.Reader) ([]byte, error)
readC0S0Mutex sync.RWMutex
readC0S0ArgsForCall []struct {
arg1 io.Reader
}
readC0S0Returns struct {
result1 []byte
result2 error
}
readC0S0ReturnsOnCall map[int]struct {
result1 []byte
result2 error
}
ReadC1S1Stub func(io.Reader) ([]byte, error)
readC1S1Mutex sync.RWMutex
readC1S1ArgsForCall []struct {
arg1 io.Reader
}
readC1S1Returns struct {
result1 []byte
result2 error
}
readC1S1ReturnsOnCall map[int]struct {
result1 []byte
result2 error
}
ReadC2S2Stub func(io.Reader) ([]byte, error)
readC2S2Mutex sync.RWMutex
readC2S2ArgsForCall []struct {
arg1 io.Reader
}
readC2S2Returns struct {
result1 []byte
result2 error
}
readC2S2ReturnsOnCall map[int]struct {
result1 []byte
result2 error
}
WriteC0S0Stub func(io.Writer) error
writeC0S0Mutex sync.RWMutex
writeC0S0ArgsForCall []struct {
arg1 io.Writer
}
writeC0S0Returns struct {
result1 error
}
writeC0S0ReturnsOnCall map[int]struct {
result1 error
}
WriteC1S1Stub func(io.Writer) error
writeC1S1Mutex sync.RWMutex
writeC1S1ArgsForCall []struct {
arg1 io.Writer
}
writeC1S1Returns struct {
result1 error
}
writeC1S1ReturnsOnCall map[int]struct {
result1 error
}
WriteC2S2Stub func(io.Writer, []byte) error
writeC2S2Mutex sync.RWMutex
writeC2S2ArgsForCall []struct {
arg1 io.Writer
arg2 []byte
}
writeC2S2Returns struct {
result1 error
}
writeC2S2ReturnsOnCall map[int]struct {
result1 error
}
invocations map[string][][]interface{}
invocationsMutex sync.RWMutex
}
func (fake *FakeHandshake) C1S1() []byte {
fake.c1S1Mutex.Lock()
ret, specificReturn := fake.c1S1ReturnsOnCall[len(fake.c1S1ArgsForCall)]
fake.c1S1ArgsForCall = append(fake.c1S1ArgsForCall, struct {
}{})
stub := fake.C1S1Stub
fakeReturns := fake.c1S1Returns
fake.recordInvocation("C1S1", []interface{}{})
fake.c1S1Mutex.Unlock()
if stub != nil {
return stub()
}
if specificReturn {
return ret.result1
}
return fakeReturns.result1
}
func (fake *FakeHandshake) C1S1CallCount() int {
fake.c1S1Mutex.RLock()
defer fake.c1S1Mutex.RUnlock()
return len(fake.c1S1ArgsForCall)
}
func (fake *FakeHandshake) C1S1Calls(stub func() []byte) {
fake.c1S1Mutex.Lock()
defer fake.c1S1Mutex.Unlock()
fake.C1S1Stub = stub
}
func (fake *FakeHandshake) C1S1Returns(result1 []byte) {
fake.c1S1Mutex.Lock()
defer fake.c1S1Mutex.Unlock()
fake.C1S1Stub = nil
fake.c1S1Returns = struct {
result1 []byte
}{result1}
}
func (fake *FakeHandshake) C1S1ReturnsOnCall(i int, result1 []byte) {
fake.c1S1Mutex.Lock()
defer fake.c1S1Mutex.Unlock()
fake.C1S1Stub = nil
if fake.c1S1ReturnsOnCall == nil {
fake.c1S1ReturnsOnCall = make(map[int]struct {
result1 []byte
})
}
fake.c1S1ReturnsOnCall[i] = struct {
result1 []byte
}{result1}
}
func (fake *FakeHandshake) ReadC0S0(arg1 io.Reader) ([]byte, error) {
fake.readC0S0Mutex.Lock()
ret, specificReturn := fake.readC0S0ReturnsOnCall[len(fake.readC0S0ArgsForCall)]
fake.readC0S0ArgsForCall = append(fake.readC0S0ArgsForCall, struct {
arg1 io.Reader
}{arg1})
stub := fake.ReadC0S0Stub
fakeReturns := fake.readC0S0Returns
fake.recordInvocation("ReadC0S0", []interface{}{arg1})
fake.readC0S0Mutex.Unlock()
if stub != nil {
return stub(arg1)
}
if specificReturn {
return ret.result1, ret.result2
}
return fakeReturns.result1, fakeReturns.result2
}
func (fake *FakeHandshake) ReadC0S0CallCount() int {
fake.readC0S0Mutex.RLock()
defer fake.readC0S0Mutex.RUnlock()
return len(fake.readC0S0ArgsForCall)
}
func (fake *FakeHandshake) ReadC0S0Calls(stub func(io.Reader) ([]byte, error)) {
fake.readC0S0Mutex.Lock()
defer fake.readC0S0Mutex.Unlock()
fake.ReadC0S0Stub = stub
}
func (fake *FakeHandshake) ReadC0S0ArgsForCall(i int) io.Reader {
fake.readC0S0Mutex.RLock()
defer fake.readC0S0Mutex.RUnlock()
argsForCall := fake.readC0S0ArgsForCall[i]
return argsForCall.arg1
}
func (fake *FakeHandshake) ReadC0S0Returns(result1 []byte, result2 error) {
fake.readC0S0Mutex.Lock()
defer fake.readC0S0Mutex.Unlock()
fake.ReadC0S0Stub = nil
fake.readC0S0Returns = struct {
result1 []byte
result2 error
}{result1, result2}
}
func (fake *FakeHandshake) ReadC0S0ReturnsOnCall(i int, result1 []byte, result2 error) {
fake.readC0S0Mutex.Lock()
defer fake.readC0S0Mutex.Unlock()
fake.ReadC0S0Stub = nil
if fake.readC0S0ReturnsOnCall == nil {
fake.readC0S0ReturnsOnCall = make(map[int]struct {
result1 []byte
result2 error
})
}
fake.readC0S0ReturnsOnCall[i] = struct {
result1 []byte
result2 error
}{result1, result2}
}
func (fake *FakeHandshake) ReadC1S1(arg1 io.Reader) ([]byte, error) {
fake.readC1S1Mutex.Lock()
ret, specificReturn := fake.readC1S1ReturnsOnCall[len(fake.readC1S1ArgsForCall)]
fake.readC1S1ArgsForCall = append(fake.readC1S1ArgsForCall, struct {
arg1 io.Reader
}{arg1})
stub := fake.ReadC1S1Stub
fakeReturns := fake.readC1S1Returns
fake.recordInvocation("ReadC1S1", []interface{}{arg1})
fake.readC1S1Mutex.Unlock()
if stub != nil {
return stub(arg1)
}
if specificReturn {
return ret.result1, ret.result2
}
return fakeReturns.result1, fakeReturns.result2
}
func (fake *FakeHandshake) ReadC1S1CallCount() int {
fake.readC1S1Mutex.RLock()
defer fake.readC1S1Mutex.RUnlock()
return len(fake.readC1S1ArgsForCall)
}
func (fake *FakeHandshake) ReadC1S1Calls(stub func(io.Reader) ([]byte, error)) {
fake.readC1S1Mutex.Lock()
defer fake.readC1S1Mutex.Unlock()
fake.ReadC1S1Stub = stub
}
func (fake *FakeHandshake) ReadC1S1ArgsForCall(i int) io.Reader {
fake.readC1S1Mutex.RLock()
defer fake.readC1S1Mutex.RUnlock()
argsForCall := fake.readC1S1ArgsForCall[i]
return argsForCall.arg1
}
func (fake *FakeHandshake) ReadC1S1Returns(result1 []byte, result2 error) {
fake.readC1S1Mutex.Lock()
defer fake.readC1S1Mutex.Unlock()
fake.ReadC1S1Stub = nil
fake.readC1S1Returns = struct {
result1 []byte
result2 error
}{result1, result2}
}
func (fake *FakeHandshake) ReadC1S1ReturnsOnCall(i int, result1 []byte, result2 error) {
fake.readC1S1Mutex.Lock()
defer fake.readC1S1Mutex.Unlock()
fake.ReadC1S1Stub = nil
if fake.readC1S1ReturnsOnCall == nil {
fake.readC1S1ReturnsOnCall = make(map[int]struct {
result1 []byte
result2 error
})
}
fake.readC1S1ReturnsOnCall[i] = struct {
result1 []byte
result2 error
}{result1, result2}
}
func (fake *FakeHandshake) ReadC2S2(arg1 io.Reader) ([]byte, error) {
fake.readC2S2Mutex.Lock()
ret, specificReturn := fake.readC2S2ReturnsOnCall[len(fake.readC2S2ArgsForCall)]
fake.readC2S2ArgsForCall = append(fake.readC2S2ArgsForCall, struct {
arg1 io.Reader
}{arg1})
stub := fake.ReadC2S2Stub
fakeReturns := fake.readC2S2Returns
fake.recordInvocation("ReadC2S2", []interface{}{arg1})
fake.readC2S2Mutex.Unlock()
if stub != nil {
return stub(arg1)
}
if specificReturn {
return ret.result1, ret.result2
}
return fakeReturns.result1, fakeReturns.result2
}
func (fake *FakeHandshake) ReadC2S2CallCount() int {
fake.readC2S2Mutex.RLock()
defer fake.readC2S2Mutex.RUnlock()
return len(fake.readC2S2ArgsForCall)
}
func (fake *FakeHandshake) ReadC2S2Calls(stub func(io.Reader) ([]byte, error)) {
fake.readC2S2Mutex.Lock()
defer fake.readC2S2Mutex.Unlock()
fake.ReadC2S2Stub = stub
}
func (fake *FakeHandshake) ReadC2S2ArgsForCall(i int) io.Reader {
fake.readC2S2Mutex.RLock()
defer fake.readC2S2Mutex.RUnlock()
argsForCall := fake.readC2S2ArgsForCall[i]
return argsForCall.arg1
}
func (fake *FakeHandshake) ReadC2S2Returns(result1 []byte, result2 error) {
fake.readC2S2Mutex.Lock()
defer fake.readC2S2Mutex.Unlock()
fake.ReadC2S2Stub = nil
fake.readC2S2Returns = struct {
result1 []byte
result2 error
}{result1, result2}
}
func (fake *FakeHandshake) ReadC2S2ReturnsOnCall(i int, result1 []byte, result2 error) {
fake.readC2S2Mutex.Lock()
defer fake.readC2S2Mutex.Unlock()
fake.ReadC2S2Stub = nil
if fake.readC2S2ReturnsOnCall == nil {
fake.readC2S2ReturnsOnCall = make(map[int]struct {
result1 []byte
result2 error
})
}
fake.readC2S2ReturnsOnCall[i] = struct {
result1 []byte
result2 error
}{result1, result2}
}
func (fake *FakeHandshake) WriteC0S0(arg1 io.Writer) error {
fake.writeC0S0Mutex.Lock()
ret, specificReturn := fake.writeC0S0ReturnsOnCall[len(fake.writeC0S0ArgsForCall)]
fake.writeC0S0ArgsForCall = append(fake.writeC0S0ArgsForCall, struct {
arg1 io.Writer
}{arg1})
stub := fake.WriteC0S0Stub
fakeReturns := fake.writeC0S0Returns
fake.recordInvocation("WriteC0S0", []interface{}{arg1})
fake.writeC0S0Mutex.Unlock()
if stub != nil {
return stub(arg1)
}
if specificReturn {
return ret.result1
}
return fakeReturns.result1
}
func (fake *FakeHandshake) WriteC0S0CallCount() int {
fake.writeC0S0Mutex.RLock()
defer fake.writeC0S0Mutex.RUnlock()
return len(fake.writeC0S0ArgsForCall)
}
func (fake *FakeHandshake) WriteC0S0Calls(stub func(io.Writer) error) {
fake.writeC0S0Mutex.Lock()
defer fake.writeC0S0Mutex.Unlock()
fake.WriteC0S0Stub = stub
}
func (fake *FakeHandshake) WriteC0S0ArgsForCall(i int) io.Writer {
fake.writeC0S0Mutex.RLock()
defer fake.writeC0S0Mutex.RUnlock()
argsForCall := fake.writeC0S0ArgsForCall[i]
return argsForCall.arg1
}
func (fake *FakeHandshake) WriteC0S0Returns(result1 error) {
fake.writeC0S0Mutex.Lock()
defer fake.writeC0S0Mutex.Unlock()
fake.WriteC0S0Stub = nil
fake.writeC0S0Returns = struct {
result1 error
}{result1}
}
func (fake *FakeHandshake) WriteC0S0ReturnsOnCall(i int, result1 error) {
fake.writeC0S0Mutex.Lock()
defer fake.writeC0S0Mutex.Unlock()
fake.WriteC0S0Stub = nil
if fake.writeC0S0ReturnsOnCall == nil {
fake.writeC0S0ReturnsOnCall = make(map[int]struct {
result1 error
})
}
fake.writeC0S0ReturnsOnCall[i] = struct {
result1 error
}{result1}
}
func (fake *FakeHandshake) WriteC1S1(arg1 io.Writer) error {
fake.writeC1S1Mutex.Lock()
ret, specificReturn := fake.writeC1S1ReturnsOnCall[len(fake.writeC1S1ArgsForCall)]
fake.writeC1S1ArgsForCall = append(fake.writeC1S1ArgsForCall, struct {
arg1 io.Writer
}{arg1})
stub := fake.WriteC1S1Stub
fakeReturns := fake.writeC1S1Returns
fake.recordInvocation("WriteC1S1", []interface{}{arg1})
fake.writeC1S1Mutex.Unlock()
if stub != nil {
return stub(arg1)
}
if specificReturn {
return ret.result1
}
return fakeReturns.result1
}
func (fake *FakeHandshake) WriteC1S1CallCount() int {
fake.writeC1S1Mutex.RLock()
defer fake.writeC1S1Mutex.RUnlock()
return len(fake.writeC1S1ArgsForCall)
}
func (fake *FakeHandshake) WriteC1S1Calls(stub func(io.Writer) error) {
fake.writeC1S1Mutex.Lock()
defer fake.writeC1S1Mutex.Unlock()
fake.WriteC1S1Stub = stub
}
func (fake *FakeHandshake) WriteC1S1ArgsForCall(i int) io.Writer {
fake.writeC1S1Mutex.RLock()
defer fake.writeC1S1Mutex.RUnlock()
argsForCall := fake.writeC1S1ArgsForCall[i]
return argsForCall.arg1
}
func (fake *FakeHandshake) WriteC1S1Returns(result1 error) {
fake.writeC1S1Mutex.Lock()
defer fake.writeC1S1Mutex.Unlock()
fake.WriteC1S1Stub = nil
fake.writeC1S1Returns = struct {
result1 error
}{result1}
}
func (fake *FakeHandshake) WriteC1S1ReturnsOnCall(i int, result1 error) {
fake.writeC1S1Mutex.Lock()
defer fake.writeC1S1Mutex.Unlock()
fake.WriteC1S1Stub = nil
if fake.writeC1S1ReturnsOnCall == nil {
fake.writeC1S1ReturnsOnCall = make(map[int]struct {
result1 error
})
}
fake.writeC1S1ReturnsOnCall[i] = struct {
result1 error
}{result1}
}
func (fake *FakeHandshake) WriteC2S2(arg1 io.Writer, arg2 []byte) error {
var arg2Copy []byte
if arg2 != nil {
arg2Copy = make([]byte, len(arg2))
copy(arg2Copy, arg2)
}
fake.writeC2S2Mutex.Lock()
ret, specificReturn := fake.writeC2S2ReturnsOnCall[len(fake.writeC2S2ArgsForCall)]
fake.writeC2S2ArgsForCall = append(fake.writeC2S2ArgsForCall, struct {
arg1 io.Writer
arg2 []byte
}{arg1, arg2Copy})
stub := fake.WriteC2S2Stub
fakeReturns := fake.writeC2S2Returns
fake.recordInvocation("WriteC2S2", []interface{}{arg1, arg2Copy})
fake.writeC2S2Mutex.Unlock()
if stub != nil {
return stub(arg1, arg2)
}
if specificReturn {
return ret.result1
}
return fakeReturns.result1
}
func (fake *FakeHandshake) WriteC2S2CallCount() int {
fake.writeC2S2Mutex.RLock()
defer fake.writeC2S2Mutex.RUnlock()
return len(fake.writeC2S2ArgsForCall)
}
func (fake *FakeHandshake) WriteC2S2Calls(stub func(io.Writer, []byte) error) {
fake.writeC2S2Mutex.Lock()
defer fake.writeC2S2Mutex.Unlock()
fake.WriteC2S2Stub = stub
}
func (fake *FakeHandshake) WriteC2S2ArgsForCall(i int) (io.Writer, []byte) {
fake.writeC2S2Mutex.RLock()
defer fake.writeC2S2Mutex.RUnlock()
argsForCall := fake.writeC2S2ArgsForCall[i]
return argsForCall.arg1, argsForCall.arg2
}
func (fake *FakeHandshake) WriteC2S2Returns(result1 error) {
fake.writeC2S2Mutex.Lock()
defer fake.writeC2S2Mutex.Unlock()
fake.WriteC2S2Stub = nil
fake.writeC2S2Returns = struct {
result1 error
}{result1}
}
func (fake *FakeHandshake) WriteC2S2ReturnsOnCall(i int, result1 error) {
fake.writeC2S2Mutex.Lock()
defer fake.writeC2S2Mutex.Unlock()
fake.WriteC2S2Stub = nil
if fake.writeC2S2ReturnsOnCall == nil {
fake.writeC2S2ReturnsOnCall = make(map[int]struct {
result1 error
})
}
fake.writeC2S2ReturnsOnCall[i] = struct {
result1 error
}{result1}
}
func (fake *FakeHandshake) Invocations() map[string][][]interface{} {
fake.invocationsMutex.RLock()
defer fake.invocationsMutex.RUnlock()
copiedInvocations := map[string][][]interface{}{}
for key, value := range fake.invocations {
copiedInvocations[key] = value
}
return copiedInvocations
}
func (fake *FakeHandshake) recordInvocation(key string, args []interface{}) {
fake.invocationsMutex.Lock()
defer fake.invocationsMutex.Unlock()
if fake.invocations == nil {
fake.invocations = map[string][][]interface{}{}
}
if fake.invocations[key] == nil {
fake.invocations[key] = [][]interface{}{}
}
fake.invocations[key] = append(fake.invocations[key], args)
}
var _ rtmp.Handshake = new(FakeHandshake)

View File

@ -0,0 +1,499 @@
// Code generated by counterfeiter. DO NOT EDIT.
package rtmpfakes
import (
"context"
"srsx/internal/rtmp"
"sync"
)
type FakeProtocol struct {
DecodeMessageStub func(rtmp.Message) (rtmp.Packet, error)
decodeMessageMutex sync.RWMutex
decodeMessageArgsForCall []struct {
arg1 rtmp.Message
}
decodeMessageReturns struct {
result1 rtmp.Packet
result2 error
}
decodeMessageReturnsOnCall map[int]struct {
result1 rtmp.Packet
result2 error
}
ExpectMessageStub func(context.Context, ...rtmp.MessageType) (rtmp.Message, error)
expectMessageMutex sync.RWMutex
expectMessageArgsForCall []struct {
arg1 context.Context
arg2 []rtmp.MessageType
}
expectMessageReturns struct {
result1 rtmp.Message
result2 error
}
expectMessageReturnsOnCall map[int]struct {
result1 rtmp.Message
result2 error
}
ExpectPacketStub func(context.Context, any) (rtmp.Message, error)
expectPacketMutex sync.RWMutex
expectPacketArgsForCall []struct {
arg1 context.Context
arg2 any
}
expectPacketReturns struct {
result1 rtmp.Message
result2 error
}
expectPacketReturnsOnCall map[int]struct {
result1 rtmp.Message
result2 error
}
ReadMessageStub func(context.Context) (rtmp.Message, error)
readMessageMutex sync.RWMutex
readMessageArgsForCall []struct {
arg1 context.Context
}
readMessageReturns struct {
result1 rtmp.Message
result2 error
}
readMessageReturnsOnCall map[int]struct {
result1 rtmp.Message
result2 error
}
WriteMessageStub func(context.Context, rtmp.Message) error
writeMessageMutex sync.RWMutex
writeMessageArgsForCall []struct {
arg1 context.Context
arg2 rtmp.Message
}
writeMessageReturns struct {
result1 error
}
writeMessageReturnsOnCall map[int]struct {
result1 error
}
WritePacketStub func(context.Context, rtmp.Packet, int) error
writePacketMutex sync.RWMutex
writePacketArgsForCall []struct {
arg1 context.Context
arg2 rtmp.Packet
arg3 int
}
writePacketReturns struct {
result1 error
}
writePacketReturnsOnCall map[int]struct {
result1 error
}
invocations map[string][][]interface{}
invocationsMutex sync.RWMutex
}
func (fake *FakeProtocol) DecodeMessage(arg1 rtmp.Message) (rtmp.Packet, error) {
fake.decodeMessageMutex.Lock()
ret, specificReturn := fake.decodeMessageReturnsOnCall[len(fake.decodeMessageArgsForCall)]
fake.decodeMessageArgsForCall = append(fake.decodeMessageArgsForCall, struct {
arg1 rtmp.Message
}{arg1})
stub := fake.DecodeMessageStub
fakeReturns := fake.decodeMessageReturns
fake.recordInvocation("DecodeMessage", []interface{}{arg1})
fake.decodeMessageMutex.Unlock()
if stub != nil {
return stub(arg1)
}
if specificReturn {
return ret.result1, ret.result2
}
return fakeReturns.result1, fakeReturns.result2
}
func (fake *FakeProtocol) DecodeMessageCallCount() int {
fake.decodeMessageMutex.RLock()
defer fake.decodeMessageMutex.RUnlock()
return len(fake.decodeMessageArgsForCall)
}
func (fake *FakeProtocol) DecodeMessageCalls(stub func(rtmp.Message) (rtmp.Packet, error)) {
fake.decodeMessageMutex.Lock()
defer fake.decodeMessageMutex.Unlock()
fake.DecodeMessageStub = stub
}
func (fake *FakeProtocol) DecodeMessageArgsForCall(i int) rtmp.Message {
fake.decodeMessageMutex.RLock()
defer fake.decodeMessageMutex.RUnlock()
argsForCall := fake.decodeMessageArgsForCall[i]
return argsForCall.arg1
}
func (fake *FakeProtocol) DecodeMessageReturns(result1 rtmp.Packet, result2 error) {
fake.decodeMessageMutex.Lock()
defer fake.decodeMessageMutex.Unlock()
fake.DecodeMessageStub = nil
fake.decodeMessageReturns = struct {
result1 rtmp.Packet
result2 error
}{result1, result2}
}
func (fake *FakeProtocol) DecodeMessageReturnsOnCall(i int, result1 rtmp.Packet, result2 error) {
fake.decodeMessageMutex.Lock()
defer fake.decodeMessageMutex.Unlock()
fake.DecodeMessageStub = nil
if fake.decodeMessageReturnsOnCall == nil {
fake.decodeMessageReturnsOnCall = make(map[int]struct {
result1 rtmp.Packet
result2 error
})
}
fake.decodeMessageReturnsOnCall[i] = struct {
result1 rtmp.Packet
result2 error
}{result1, result2}
}
func (fake *FakeProtocol) ExpectMessage(arg1 context.Context, arg2 ...rtmp.MessageType) (rtmp.Message, error) {
fake.expectMessageMutex.Lock()
ret, specificReturn := fake.expectMessageReturnsOnCall[len(fake.expectMessageArgsForCall)]
fake.expectMessageArgsForCall = append(fake.expectMessageArgsForCall, struct {
arg1 context.Context
arg2 []rtmp.MessageType
}{arg1, arg2})
stub := fake.ExpectMessageStub
fakeReturns := fake.expectMessageReturns
fake.recordInvocation("ExpectMessage", []interface{}{arg1, arg2})
fake.expectMessageMutex.Unlock()
if stub != nil {
return stub(arg1, arg2...)
}
if specificReturn {
return ret.result1, ret.result2
}
return fakeReturns.result1, fakeReturns.result2
}
func (fake *FakeProtocol) ExpectMessageCallCount() int {
fake.expectMessageMutex.RLock()
defer fake.expectMessageMutex.RUnlock()
return len(fake.expectMessageArgsForCall)
}
func (fake *FakeProtocol) ExpectMessageCalls(stub func(context.Context, ...rtmp.MessageType) (rtmp.Message, error)) {
fake.expectMessageMutex.Lock()
defer fake.expectMessageMutex.Unlock()
fake.ExpectMessageStub = stub
}
func (fake *FakeProtocol) ExpectMessageArgsForCall(i int) (context.Context, []rtmp.MessageType) {
fake.expectMessageMutex.RLock()
defer fake.expectMessageMutex.RUnlock()
argsForCall := fake.expectMessageArgsForCall[i]
return argsForCall.arg1, argsForCall.arg2
}
func (fake *FakeProtocol) ExpectMessageReturns(result1 rtmp.Message, result2 error) {
fake.expectMessageMutex.Lock()
defer fake.expectMessageMutex.Unlock()
fake.ExpectMessageStub = nil
fake.expectMessageReturns = struct {
result1 rtmp.Message
result2 error
}{result1, result2}
}
func (fake *FakeProtocol) ExpectMessageReturnsOnCall(i int, result1 rtmp.Message, result2 error) {
fake.expectMessageMutex.Lock()
defer fake.expectMessageMutex.Unlock()
fake.ExpectMessageStub = nil
if fake.expectMessageReturnsOnCall == nil {
fake.expectMessageReturnsOnCall = make(map[int]struct {
result1 rtmp.Message
result2 error
})
}
fake.expectMessageReturnsOnCall[i] = struct {
result1 rtmp.Message
result2 error
}{result1, result2}
}
func (fake *FakeProtocol) ExpectPacket(arg1 context.Context, arg2 any) (rtmp.Message, error) {
fake.expectPacketMutex.Lock()
ret, specificReturn := fake.expectPacketReturnsOnCall[len(fake.expectPacketArgsForCall)]
fake.expectPacketArgsForCall = append(fake.expectPacketArgsForCall, struct {
arg1 context.Context
arg2 any
}{arg1, arg2})
stub := fake.ExpectPacketStub
fakeReturns := fake.expectPacketReturns
fake.recordInvocation("ExpectPacket", []interface{}{arg1, arg2})
fake.expectPacketMutex.Unlock()
if stub != nil {
return stub(arg1, arg2)
}
if specificReturn {
return ret.result1, ret.result2
}
return fakeReturns.result1, fakeReturns.result2
}
func (fake *FakeProtocol) ExpectPacketCallCount() int {
fake.expectPacketMutex.RLock()
defer fake.expectPacketMutex.RUnlock()
return len(fake.expectPacketArgsForCall)
}
func (fake *FakeProtocol) ExpectPacketCalls(stub func(context.Context, any) (rtmp.Message, error)) {
fake.expectPacketMutex.Lock()
defer fake.expectPacketMutex.Unlock()
fake.ExpectPacketStub = stub
}
func (fake *FakeProtocol) ExpectPacketArgsForCall(i int) (context.Context, any) {
fake.expectPacketMutex.RLock()
defer fake.expectPacketMutex.RUnlock()
argsForCall := fake.expectPacketArgsForCall[i]
return argsForCall.arg1, argsForCall.arg2
}
func (fake *FakeProtocol) ExpectPacketReturns(result1 rtmp.Message, result2 error) {
fake.expectPacketMutex.Lock()
defer fake.expectPacketMutex.Unlock()
fake.ExpectPacketStub = nil
fake.expectPacketReturns = struct {
result1 rtmp.Message
result2 error
}{result1, result2}
}
func (fake *FakeProtocol) ExpectPacketReturnsOnCall(i int, result1 rtmp.Message, result2 error) {
fake.expectPacketMutex.Lock()
defer fake.expectPacketMutex.Unlock()
fake.ExpectPacketStub = nil
if fake.expectPacketReturnsOnCall == nil {
fake.expectPacketReturnsOnCall = make(map[int]struct {
result1 rtmp.Message
result2 error
})
}
fake.expectPacketReturnsOnCall[i] = struct {
result1 rtmp.Message
result2 error
}{result1, result2}
}
func (fake *FakeProtocol) ReadMessage(arg1 context.Context) (rtmp.Message, error) {
fake.readMessageMutex.Lock()
ret, specificReturn := fake.readMessageReturnsOnCall[len(fake.readMessageArgsForCall)]
fake.readMessageArgsForCall = append(fake.readMessageArgsForCall, struct {
arg1 context.Context
}{arg1})
stub := fake.ReadMessageStub
fakeReturns := fake.readMessageReturns
fake.recordInvocation("ReadMessage", []interface{}{arg1})
fake.readMessageMutex.Unlock()
if stub != nil {
return stub(arg1)
}
if specificReturn {
return ret.result1, ret.result2
}
return fakeReturns.result1, fakeReturns.result2
}
func (fake *FakeProtocol) ReadMessageCallCount() int {
fake.readMessageMutex.RLock()
defer fake.readMessageMutex.RUnlock()
return len(fake.readMessageArgsForCall)
}
func (fake *FakeProtocol) ReadMessageCalls(stub func(context.Context) (rtmp.Message, error)) {
fake.readMessageMutex.Lock()
defer fake.readMessageMutex.Unlock()
fake.ReadMessageStub = stub
}
func (fake *FakeProtocol) ReadMessageArgsForCall(i int) context.Context {
fake.readMessageMutex.RLock()
defer fake.readMessageMutex.RUnlock()
argsForCall := fake.readMessageArgsForCall[i]
return argsForCall.arg1
}
func (fake *FakeProtocol) ReadMessageReturns(result1 rtmp.Message, result2 error) {
fake.readMessageMutex.Lock()
defer fake.readMessageMutex.Unlock()
fake.ReadMessageStub = nil
fake.readMessageReturns = struct {
result1 rtmp.Message
result2 error
}{result1, result2}
}
func (fake *FakeProtocol) ReadMessageReturnsOnCall(i int, result1 rtmp.Message, result2 error) {
fake.readMessageMutex.Lock()
defer fake.readMessageMutex.Unlock()
fake.ReadMessageStub = nil
if fake.readMessageReturnsOnCall == nil {
fake.readMessageReturnsOnCall = make(map[int]struct {
result1 rtmp.Message
result2 error
})
}
fake.readMessageReturnsOnCall[i] = struct {
result1 rtmp.Message
result2 error
}{result1, result2}
}
func (fake *FakeProtocol) WriteMessage(arg1 context.Context, arg2 rtmp.Message) error {
fake.writeMessageMutex.Lock()
ret, specificReturn := fake.writeMessageReturnsOnCall[len(fake.writeMessageArgsForCall)]
fake.writeMessageArgsForCall = append(fake.writeMessageArgsForCall, struct {
arg1 context.Context
arg2 rtmp.Message
}{arg1, arg2})
stub := fake.WriteMessageStub
fakeReturns := fake.writeMessageReturns
fake.recordInvocation("WriteMessage", []interface{}{arg1, arg2})
fake.writeMessageMutex.Unlock()
if stub != nil {
return stub(arg1, arg2)
}
if specificReturn {
return ret.result1
}
return fakeReturns.result1
}
func (fake *FakeProtocol) WriteMessageCallCount() int {
fake.writeMessageMutex.RLock()
defer fake.writeMessageMutex.RUnlock()
return len(fake.writeMessageArgsForCall)
}
func (fake *FakeProtocol) WriteMessageCalls(stub func(context.Context, rtmp.Message) error) {
fake.writeMessageMutex.Lock()
defer fake.writeMessageMutex.Unlock()
fake.WriteMessageStub = stub
}
func (fake *FakeProtocol) WriteMessageArgsForCall(i int) (context.Context, rtmp.Message) {
fake.writeMessageMutex.RLock()
defer fake.writeMessageMutex.RUnlock()
argsForCall := fake.writeMessageArgsForCall[i]
return argsForCall.arg1, argsForCall.arg2
}
func (fake *FakeProtocol) WriteMessageReturns(result1 error) {
fake.writeMessageMutex.Lock()
defer fake.writeMessageMutex.Unlock()
fake.WriteMessageStub = nil
fake.writeMessageReturns = struct {
result1 error
}{result1}
}
func (fake *FakeProtocol) WriteMessageReturnsOnCall(i int, result1 error) {
fake.writeMessageMutex.Lock()
defer fake.writeMessageMutex.Unlock()
fake.WriteMessageStub = nil
if fake.writeMessageReturnsOnCall == nil {
fake.writeMessageReturnsOnCall = make(map[int]struct {
result1 error
})
}
fake.writeMessageReturnsOnCall[i] = struct {
result1 error
}{result1}
}
func (fake *FakeProtocol) WritePacket(arg1 context.Context, arg2 rtmp.Packet, arg3 int) error {
fake.writePacketMutex.Lock()
ret, specificReturn := fake.writePacketReturnsOnCall[len(fake.writePacketArgsForCall)]
fake.writePacketArgsForCall = append(fake.writePacketArgsForCall, struct {
arg1 context.Context
arg2 rtmp.Packet
arg3 int
}{arg1, arg2, arg3})
stub := fake.WritePacketStub
fakeReturns := fake.writePacketReturns
fake.recordInvocation("WritePacket", []interface{}{arg1, arg2, arg3})
fake.writePacketMutex.Unlock()
if stub != nil {
return stub(arg1, arg2, arg3)
}
if specificReturn {
return ret.result1
}
return fakeReturns.result1
}
func (fake *FakeProtocol) WritePacketCallCount() int {
fake.writePacketMutex.RLock()
defer fake.writePacketMutex.RUnlock()
return len(fake.writePacketArgsForCall)
}
func (fake *FakeProtocol) WritePacketCalls(stub func(context.Context, rtmp.Packet, int) error) {
fake.writePacketMutex.Lock()
defer fake.writePacketMutex.Unlock()
fake.WritePacketStub = stub
}
func (fake *FakeProtocol) WritePacketArgsForCall(i int) (context.Context, rtmp.Packet, int) {
fake.writePacketMutex.RLock()
defer fake.writePacketMutex.RUnlock()
argsForCall := fake.writePacketArgsForCall[i]
return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3
}
func (fake *FakeProtocol) WritePacketReturns(result1 error) {
fake.writePacketMutex.Lock()
defer fake.writePacketMutex.Unlock()
fake.WritePacketStub = nil
fake.writePacketReturns = struct {
result1 error
}{result1}
}
func (fake *FakeProtocol) WritePacketReturnsOnCall(i int, result1 error) {
fake.writePacketMutex.Lock()
defer fake.writePacketMutex.Unlock()
fake.WritePacketStub = nil
if fake.writePacketReturnsOnCall == nil {
fake.writePacketReturnsOnCall = make(map[int]struct {
result1 error
})
}
fake.writePacketReturnsOnCall[i] = struct {
result1 error
}{result1}
}
func (fake *FakeProtocol) Invocations() map[string][][]interface{} {
fake.invocationsMutex.RLock()
defer fake.invocationsMutex.RUnlock()
copiedInvocations := map[string][][]interface{}{}
for key, value := range fake.invocations {
copiedInvocations[key] = value
}
return copiedInvocations
}
func (fake *FakeProtocol) recordInvocation(key string, args []interface{}) {
fake.invocationsMutex.Lock()
defer fake.invocationsMutex.Unlock()
if fake.invocations == nil {
fake.invocations = map[string][][]interface{}{}
}
if fake.invocations[key] == nil {
fake.invocations[key] = [][]interface{}{}
}
fake.invocations[key] = append(fake.invocations[key], args)
}
var _ rtmp.Protocol = new(FakeProtocol)

View File

@ -1,7 +1,7 @@
// Copyright (c) 2026 Winlin
//
// SPDX-License-Identifier: MIT
package protocol
package server
import (
"context"
@ -20,23 +20,28 @@ import (
"srsx/internal/version"
)
// srsHTTPAPIServer is the proxy for SRS HTTP API, to proxy the WebRTC HTTP API like WHIP and WHEP,
// HTTPAPIServer is the proxy for SRS HTTP API, to proxy the WebRTC HTTP API like WHIP and WHEP,
// to proxy other HTTP API of SRS like the streams and clients, etc.
type srsHTTPAPIServer struct {
type HTTPAPIServer interface {
Run(ctx context.Context) error
Close() error
}
type httpAPIServer struct {
// The environment interface.
environment env.ProxyEnvironment
// The underlayer HTTP server.
server *http.Server
// The WebRTC server.
rtc *srsWebRTCServer
rtc WebRTCServer
// The gracefully quit timeout, wait server to quit.
gracefulQuitTimeout time.Duration
// The wait group for all goroutines.
wg sync.WaitGroup
}
func NewSRSHTTPAPIServer(environment env.ProxyEnvironment, gracefulQuitTimeout time.Duration, rtc *srsWebRTCServer) *srsHTTPAPIServer {
v := &srsHTTPAPIServer{
func NewHTTPAPIServer(environment env.ProxyEnvironment, gracefulQuitTimeout time.Duration, rtc WebRTCServer) HTTPAPIServer {
v := &httpAPIServer{
environment: environment,
gracefulQuitTimeout: gracefulQuitTimeout,
rtc: rtc,
@ -44,7 +49,7 @@ func NewSRSHTTPAPIServer(environment env.ProxyEnvironment, gracefulQuitTimeout t
return v
}
func (v *srsHTTPAPIServer) Close() error {
func (v *httpAPIServer) Close() error {
ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout)
defer cancel()
v.server.Shutdown(ctx)
@ -53,7 +58,7 @@ func (v *srsHTTPAPIServer) Close() error {
return nil
}
func (v *srsHTTPAPIServer) Run(ctx context.Context) error {
func (v *httpAPIServer) Run(ctx context.Context) error {
// Parse address to listen.
addr := v.environment.HttpAPI()
if !strings.Contains(addr, ":") {
@ -92,6 +97,13 @@ func (v *srsHTTPAPIServer) Run(ctx context.Context) error {
utils.ApiError(ctx, w, r, err)
}
})
// Keep compatibility with the legacy SRS WebRTC publish API used by srs-bench.
logger.Debug(ctx, "Handle /rtc/v1/publish/ by %v", addr)
mux.HandleFunc("/rtc/v1/publish/", func(w http.ResponseWriter, r *http.Request) {
if err := v.rtc.HandleApiForWHIP(ctx, w, r); err != nil {
utils.ApiError(ctx, w, r, err)
}
})
// The WebRTC WHEP API handler.
logger.Debug(ctx, "Handle /rtc/v1/whep/ by %v", addr)
@ -100,6 +112,13 @@ func (v *srsHTTPAPIServer) Run(ctx context.Context) error {
utils.ApiError(ctx, w, r, err)
}
})
// Keep compatibility with the legacy SRS WebRTC play API used by srs-bench.
logger.Debug(ctx, "Handle /rtc/v1/play/ by %v", addr)
mux.HandleFunc("/rtc/v1/play/", func(w http.ResponseWriter, r *http.Request) {
if err := v.rtc.HandleApiForWHEP(ctx, w, r); err != nil {
utils.ApiError(ctx, w, r, err)
}
})
// Run HTTP API server.
v.wg.Add(1)

View File

@ -1,7 +1,7 @@
// Copyright (c) 2026 Winlin
//
// SPDX-License-Identifier: MIT
package protocol
package server
import (
"context"
@ -23,10 +23,15 @@ import (
"srsx/internal/version"
)
// srsHTTPStreamServer is the proxy server for SRS HTTP stream server, for HTTP-FLV, HTTP-TS,
// HTTPStreamServer is the proxy server for SRS HTTP stream server, for HTTP-FLV, HTTP-TS,
// HLS, etc. The proxy server will figure out which SRS origin server to proxy to, then proxy
// the request to the origin server.
type srsHTTPStreamServer struct {
type HTTPStreamServer interface {
Run(ctx context.Context) error
Close() error
}
type httpStreamServer struct {
// The environment interface.
environment env.ProxyEnvironment
// The underlayer HTTP server.
@ -37,15 +42,15 @@ type srsHTTPStreamServer struct {
wg stdSync.WaitGroup
}
func NewSRSHTTPStreamServer(environment env.ProxyEnvironment, gracefulQuitTimeout time.Duration) *srsHTTPStreamServer {
v := &srsHTTPStreamServer{
func NewHTTPStreamServer(environment env.ProxyEnvironment, gracefulQuitTimeout time.Duration) HTTPStreamServer {
v := &httpStreamServer{
environment: environment,
gracefulQuitTimeout: gracefulQuitTimeout,
}
return v
}
func (v *srsHTTPStreamServer) Close() error {
func (v *httpStreamServer) Close() error {
ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout)
defer cancel()
v.server.Shutdown(ctx)
@ -54,7 +59,7 @@ func (v *srsHTTPStreamServer) Close() error {
return nil
}
func (v *srsHTTPStreamServer) Run(ctx context.Context) error {
func (v *httpStreamServer) Run(ctx context.Context) error {
// Parse address to listen.
addr := v.environment.HttpServer()
if !strings.Contains(addr, ":") {
@ -123,12 +128,12 @@ func (v *srsHTTPStreamServer) Run(ctx context.Context) error {
return
}
stream, _ := lb.SrsLoadBalancer.LoadOrStoreHLS(ctx, streamURL, NewHLSPlayStream(func(s *HLSPlayStream) {
stream, _ := lb.SrsLoadBalancer.LoadOrStoreHLS(ctx, streamURL, newHLSPlayStream(func(s *hlsPlayStream) {
s.SRSProxyBackendHLSID = logger.GenerateContextID()
s.StreamURL, s.FullURL = streamURL, fullURL
}))
stream.Initialize(ctx).(*HLSPlayStream).ServeHTTP(w, r)
stream.Initialize(ctx).(*hlsPlayStream).ServeHTTP(w, r)
return
}
@ -140,13 +145,13 @@ func (v *srsHTTPStreamServer) Run(ctx context.Context) error {
if stream, err := lb.SrsLoadBalancer.LoadHLSBySPBHID(ctx, srsProxyBackendID); err != nil {
http.Error(w, fmt.Sprintf("load stream by spbhid %v", srsProxyBackendID), http.StatusBadRequest)
} else {
stream.Initialize(ctx).(*HLSPlayStream).ServeHTTP(w, r)
stream.Initialize(ctx).(*hlsPlayStream).ServeHTTP(w, r)
}
return
}
// Use HTTP pseudo streaming to proxy the request.
NewHTTPFlvTsConnection(func(c *HTTPFlvTsConnection) {
newHTTPFlvTsConnection(func(c *httpFlvTsConnection) {
c.ctx = ctx
}).ServeHTTP(w, r)
return
@ -182,26 +187,26 @@ func (v *srsHTTPStreamServer) Run(ctx context.Context) error {
return nil
}
// HTTPFlvTsConnection is an HTTP pseudo streaming connection, such as an HTTP-FLV or HTTP-TS
// httpFlvTsConnection is an HTTP pseudo streaming connection, such as an HTTP-FLV or HTTP-TS
// connection. There is no state need to be sync between proxy servers.
//
// When we got an HTTP FLV or TS request, we will parse the stream URL from the HTTP request,
// then proxy to the corresponding backend server. All state is in the HTTP request, so this
// connection is stateless.
type HTTPFlvTsConnection struct {
type httpFlvTsConnection struct {
// The context for HTTP streaming.
ctx context.Context
}
func NewHTTPFlvTsConnection(opts ...func(*HTTPFlvTsConnection)) *HTTPFlvTsConnection {
v := &HTTPFlvTsConnection{}
func newHTTPFlvTsConnection(opts ...func(*httpFlvTsConnection)) *httpFlvTsConnection {
v := &httpFlvTsConnection{}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *HTTPFlvTsConnection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func (v *httpFlvTsConnection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close()
ctx := logger.WithContext(v.ctx)
@ -212,7 +217,7 @@ func (v *HTTPFlvTsConnection) ServeHTTP(w http.ResponseWriter, r *http.Request)
}
}
func (v *HTTPFlvTsConnection) serve(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
func (v *httpFlvTsConnection) serve(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
// Always allow CORS for all requests.
if ok := utils.ApiCORS(ctx, w, r); ok {
return nil
@ -240,7 +245,7 @@ func (v *HTTPFlvTsConnection) serve(ctx context.Context, w http.ResponseWriter,
return nil
}
func (v *HTTPFlvTsConnection) serveByBackend(ctx context.Context, w http.ResponseWriter, r *http.Request, backend *lb.SRSServer) error {
func (v *httpFlvTsConnection) serveByBackend(ctx context.Context, w http.ResponseWriter, r *http.Request, backend *lb.SRSServer) error {
// Parse HTTP port from backend.
if len(backend.HTTP) == 0 {
return errors.Errorf("no http stream server")
@ -288,14 +293,14 @@ func (v *HTTPFlvTsConnection) serveByBackend(ctx context.Context, w http.Respons
return nil
}
// HLSPlayStream is an HLS stream proxy, which represents the stream level object. This means multiple HLS
// hlsPlayStream is an HLS stream proxy, which represents the stream level object. This means multiple HLS
// clients will share this object, and they do not use the same ctx among proxy servers.
//
// Unlike the HTTP FLV or TS connection, HLS client may request the m3u8 or ts via different HTTP connections.
// Especially for requesting ts, we need to identify the stream URl or backend server for it. So we create
// the spbhid which can be seen as the hash of stream URL or backend server. The spbhid enable us to convert
// to the stream URL and then query the backend server to serve it.
type HLSPlayStream struct {
type hlsPlayStream struct {
// The context for HLS streaming.
ctx context.Context
@ -307,26 +312,26 @@ type HLSPlayStream struct {
FullURL string `json:"full_url"`
}
func NewHLSPlayStream(opts ...func(*HLSPlayStream)) *HLSPlayStream {
v := &HLSPlayStream{}
func newHLSPlayStream(opts ...func(*hlsPlayStream)) *hlsPlayStream {
v := &hlsPlayStream{}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *HLSPlayStream) Initialize(ctx context.Context) lb.HLSPlayStream {
func (v *hlsPlayStream) Initialize(ctx context.Context) lb.HLSPlayStream {
if v.ctx == nil {
v.ctx = logger.WithContext(ctx)
}
return v
}
func (v *HLSPlayStream) GetSPBHID() string {
func (v *hlsPlayStream) GetSPBHID() string {
return v.SRSProxyBackendHLSID
}
func (v *HLSPlayStream) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func (v *hlsPlayStream) ServeHTTP(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close()
if err := v.serve(v.ctx, w, r); err != nil {
@ -337,7 +342,7 @@ func (v *HLSPlayStream) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}
func (v *HLSPlayStream) serve(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
func (v *hlsPlayStream) serve(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
ctx, streamURL, fullURL := v.ctx, v.StreamURL, v.FullURL
// Always allow CORS for all requests.
@ -358,7 +363,7 @@ func (v *HLSPlayStream) serve(ctx context.Context, w http.ResponseWriter, r *htt
return nil
}
func (v *HLSPlayStream) serveByBackend(ctx context.Context, w http.ResponseWriter, r *http.Request, backend *lb.SRSServer) error {
func (v *hlsPlayStream) serveByBackend(ctx context.Context, w http.ResponseWriter, r *http.Request, backend *lb.SRSServer) error {
// Parse HTTP port from backend.
if len(backend.HTTP) == 0 {
return errors.Errorf("no rtmp server")

View File

@ -1,7 +1,7 @@
// Copyright (c) 2026 Winlin
//
// SPDX-License-Identifier: MIT
package protocol
package server
import (
"context"
@ -23,10 +23,17 @@ import (
"srsx/internal/utils"
)
// srsWebRTCServer is the proxy for SRS WebRTC server via WHIP or WHEP protocol. It will figure out
// WebRTCServer is the proxy for SRS WebRTC server via WHIP or WHEP protocol. It will figure out
// which backend server to proxy to. It will also replace the UDP port to the proxy server's in the
// SDP answer.
type srsWebRTCServer struct {
type WebRTCServer interface {
Run(ctx context.Context) error
Close() error
HandleApiForWHIP(ctx context.Context, w http.ResponseWriter, r *http.Request) error
HandleApiForWHEP(ctx context.Context, w http.ResponseWriter, r *http.Request) error
}
type webRTCServer struct {
// The environment interface.
environment env.ProxyEnvironment
// The UDP listener for WebRTC server.
@ -34,21 +41,21 @@ type srsWebRTCServer struct {
// Fast cache for the username to identify the connection.
// The key is username, the value is the UDP address.
usernames sync.Map[string, *RTCConnection]
usernames sync.Map[string, *rtcConnection]
// Fast cache for the udp address to identify the connection.
// The key is UDP address, the value is the username.
// TODO: Support fast earch by uint64 address.
addresses sync.Map[string, *RTCConnection]
addresses sync.Map[string, *rtcConnection]
// The wait group for server.
wg stdSync.WaitGroup
}
func NewSRSWebRTCServer(environment env.ProxyEnvironment, opts ...func(*srsWebRTCServer)) *srsWebRTCServer {
v := &srsWebRTCServer{
func NewWebRTCServer(environment env.ProxyEnvironment, opts ...func(*webRTCServer)) WebRTCServer {
v := &webRTCServer{
environment: environment,
usernames: sync.NewMap[string, *RTCConnection](),
addresses: sync.NewMap[string, *RTCConnection](),
usernames: sync.NewMap[string, *rtcConnection](),
addresses: sync.NewMap[string, *rtcConnection](),
}
for _, opt := range opts {
opt(v)
@ -56,7 +63,7 @@ func NewSRSWebRTCServer(environment env.ProxyEnvironment, opts ...func(*srsWebRT
return v
}
func (v *srsWebRTCServer) Close() error {
func (v *webRTCServer) Close() error {
if v.listener != nil {
_ = v.listener.Close()
}
@ -65,7 +72,7 @@ func (v *srsWebRTCServer) Close() error {
return nil
}
func (v *srsWebRTCServer) HandleApiForWHIP(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
func (v *webRTCServer) HandleApiForWHIP(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
defer r.Body.Close()
ctx = logger.WithContext(ctx)
@ -102,7 +109,7 @@ func (v *srsWebRTCServer) HandleApiForWHIP(ctx context.Context, w http.ResponseW
return nil
}
func (v *srsWebRTCServer) HandleApiForWHEP(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
func (v *webRTCServer) HandleApiForWHEP(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
defer r.Body.Close()
ctx = logger.WithContext(ctx)
@ -139,7 +146,7 @@ func (v *srsWebRTCServer) HandleApiForWHEP(ctx context.Context, w http.ResponseW
return nil
}
func (v *srsWebRTCServer) proxyApiToBackend(
func (v *webRTCServer) proxyApiToBackend(
ctx context.Context, w http.ResponseWriter, r *http.Request, backend *lb.SRSServer,
remoteSDPOffer string, streamURL string,
) error {
@ -215,11 +222,11 @@ func (v *srsWebRTCServer) proxyApiToBackend(
}
// Save the new WebRTC connection to LB.
icePair := &RTCICEPair{
icePair := &rtcICEPair{
RemoteICEUfrag: remoteICEUfrag, RemoteICEPwd: remoteICEPwd,
LocalICEUfrag: localICEUfrag, LocalICEPwd: localICEPwd,
}
if err := lb.SrsLoadBalancer.StoreWebRTC(ctx, streamURL, NewRTCConnection(func(c *RTCConnection) {
if err := lb.SrsLoadBalancer.StoreWebRTC(ctx, streamURL, newRTCConnection(func(c *rtcConnection) {
c.StreamURL, c.Ufrag = streamURL, icePair.Ufrag()
c.Initialize(ctx, v.listener)
@ -239,7 +246,7 @@ func (v *srsWebRTCServer) proxyApiToBackend(
return nil
}
func (v *srsWebRTCServer) Run(ctx context.Context) error {
func (v *webRTCServer) Run(ctx context.Context) error {
// Parse address to listen.
endpoint := v.environment.WebRTCServer()
if !strings.Contains(endpoint, ":") {
@ -287,8 +294,8 @@ func (v *srsWebRTCServer) Run(ctx context.Context) error {
return nil
}
func (v *srsWebRTCServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr, data []byte) error {
var connection *RTCConnection
func (v *webRTCServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr, data []byte) error {
var connection *rtcConnection
// If STUN binding request, parse the ufrag and identify the connection.
if err := func() error {
@ -296,7 +303,7 @@ func (v *srsWebRTCServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr
return nil
}
var pkt RTCStunPacket
var pkt rtcStunPacket
if err := pkt.UnmarshalBinary(data); err != nil {
return errors.Wrapf(err, "unmarshal stun packet")
}
@ -311,7 +318,7 @@ func (v *srsWebRTCServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr
if s, err := lb.SrsLoadBalancer.LoadWebRTCByUfrag(ctx, pkt.Username); err != nil {
return errors.Wrapf(err, "load webrtc by ufrag %v", pkt.Username)
} else {
connection = s.(*RTCConnection).Initialize(ctx, v.listener)
connection = s.(*rtcConnection).Initialize(ctx, v.listener)
logger.Debug(ctx, "Create WebRTC connection by ufrag=%v, stream=%v", pkt.Username, connection.StreamURL)
}
@ -346,17 +353,17 @@ func (v *srsWebRTCServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr
return nil
}
// RTCConnection is a WebRTC connection proxy, for both WHIP and WHEP. It represents a WebRTC
// rtcConnection is a WebRTC connection proxy, for both WHIP and WHEP. It represents a WebRTC
// connection, identify by the ufrag in sdp offer/answer and ICE binding request.
//
// It's not like RTMP or HTTP FLV/TS proxy connection, which are stateless and all state is
// in the client request. The RTCConnection is stateful, and need to sync the ufrag between
// in the client request. The rtcConnection is stateful, and need to sync the ufrag between
// proxy servers.
//
// The media transport is UDP, which is also a special thing for WebRTC. So if the client switch
// to another UDP address, it may connect to another WebRTC proxy, then we should discover the
// RTCConnection by the ufrag from the ICE binding request.
type RTCConnection struct {
// rtcConnection by the ufrag from the ICE binding request.
type rtcConnection struct {
// The stream context for WebRTC streaming.
ctx context.Context
@ -373,15 +380,15 @@ type RTCConnection struct {
listenerUDP *net.UDPConn
}
func NewRTCConnection(opts ...func(*RTCConnection)) *RTCConnection {
v := &RTCConnection{}
func newRTCConnection(opts ...func(*rtcConnection)) *rtcConnection {
v := &rtcConnection{}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *RTCConnection) Initialize(ctx context.Context, listener *net.UDPConn) *RTCConnection {
func (v *rtcConnection) Initialize(ctx context.Context, listener *net.UDPConn) *rtcConnection {
if v.ctx == nil {
v.ctx = logger.WithContext(ctx)
}
@ -391,11 +398,11 @@ func (v *RTCConnection) Initialize(ctx context.Context, listener *net.UDPConn) *
return v
}
func (v *RTCConnection) GetUfrag() string {
func (v *rtcConnection) GetUfrag() string {
return v.Ufrag
}
func (v *RTCConnection) HandlePacket(addr *net.UDPAddr, data []byte) error {
func (v *rtcConnection) HandlePacket(addr *net.UDPAddr, data []byte) error {
ctx := v.ctx
// Update the current UDP address.
@ -437,7 +444,7 @@ func (v *RTCConnection) HandlePacket(addr *net.UDPAddr, data []byte) error {
return nil
}
func (v *RTCConnection) connectBackend(ctx context.Context) error {
func (v *rtcConnection) connectBackend(ctx context.Context) error {
if v.backendUDP != nil {
return nil
}
@ -470,7 +477,7 @@ func (v *RTCConnection) connectBackend(ctx context.Context) error {
return nil
}
type RTCICEPair struct {
type rtcICEPair struct {
// The remote ufrag, used for ICE username and session id.
RemoteICEUfrag string `json:"remote_ufrag"`
// The remote pwd, used for ICE password.
@ -482,18 +489,18 @@ type RTCICEPair struct {
}
// Generate the ICE ufrag for the WebRTC streaming, format is remote-ufrag:local-ufrag.
func (v *RTCICEPair) Ufrag() string {
func (v *rtcICEPair) Ufrag() string {
return fmt.Sprintf("%v:%v", v.LocalICEUfrag, v.RemoteICEUfrag)
}
type RTCStunPacket struct {
type rtcStunPacket struct {
// The stun message type.
MessageType uint16
// The stun username, or ufrag.
Username string
}
func (v *RTCStunPacket) UnmarshalBinary(data []byte) error {
func (v *rtcStunPacket) UnmarshalBinary(data []byte) error {
if len(data) < 20 {
return errors.Errorf("stun packet too short %v", len(data))
}

View File

@ -1,7 +1,7 @@
// Copyright (c) 2026 Winlin
//
// SPDX-License-Identifier: MIT
package protocol
package server
import (
"context"
@ -20,10 +20,15 @@ import (
"srsx/internal/version"
)
// srsRTMPServer is the proxy for SRS RTMP server, to proxy the RTMP stream to backend SRS
// RTMPServer is the proxy for SRS RTMP server, to proxy the RTMP stream to backend SRS
// server. It will figure out the backend server to proxy to. Unlike the edge server, it will
// not cache the stream, but just proxy the stream to backend.
type srsRTMPServer struct {
type RTMPServer interface {
Run(ctx context.Context) error
Close() error
}
type rtmpServer struct {
// The environment interface.
environment env.ProxyEnvironment
// The TCP listener for RTMP server.
@ -32,15 +37,15 @@ type srsRTMPServer struct {
wg sync.WaitGroup
}
func NewSRSRTMPServer(environment env.ProxyEnvironment, opts ...func(*srsRTMPServer)) *srsRTMPServer {
v := &srsRTMPServer{environment: environment}
func NewRTMPServer(environment env.ProxyEnvironment, opts ...func(*rtmpServer)) RTMPServer {
v := &rtmpServer{environment: environment}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *srsRTMPServer) Close() error {
func (v *rtmpServer) Close() error {
if v.listener != nil {
v.listener.Close()
}
@ -49,7 +54,7 @@ func (v *srsRTMPServer) Close() error {
return nil
}
func (v *srsRTMPServer) Run(ctx context.Context) error {
func (v *rtmpServer) Run(ctx context.Context) error {
endpoint := v.environment.RtmpServer()
if !strings.Contains(endpoint, ":") {
endpoint = ":" + endpoint
@ -97,7 +102,7 @@ func (v *srsRTMPServer) Run(ctx context.Context) error {
}
}
rc := NewRTMPConnection()
rc := newRTMPConnection()
if err := rc.serve(ctx, conn); err != nil {
handleErr(err)
} else {
@ -110,24 +115,24 @@ func (v *srsRTMPServer) Run(ctx context.Context) error {
return nil
}
// RTMPConnection is an RTMP streaming connection. There is no state need to be sync between
// rtmpConnection is an RTMP streaming connection. There is no state need to be sync between
// proxy servers.
//
// When we got an RTMP request, we will parse the stream URL from the RTMP publish or play request,
// then proxy to the corresponding backend server. All state is in the RTMP request, so this
// connection is stateless.
type RTMPConnection struct {
type rtmpConnection struct {
}
func NewRTMPConnection(opts ...func(*RTMPConnection)) *RTMPConnection {
v := &RTMPConnection{}
func newRTMPConnection(opts ...func(*rtmpConnection)) *rtmpConnection {
v := &rtmpConnection{}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error {
func (v *rtmpConnection) serve(ctx context.Context, conn *net.TCPConn) error {
logger.Debug(ctx, "Got RTMP client from %v", conn.RemoteAddr())
// If any goroutine quit, cancel another one.
@ -135,7 +140,7 @@ func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
var backend *RTMPClientToBackend
var backend *rtmpClientToBackend
if true {
go func() {
<-ctx.Done()
@ -229,7 +234,7 @@ func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error {
response = identifyRes
nextStreamID = 1
identifyRes.StreamID = *rtmp.NewAmf0Number(float64(nextStreamID))
identifyRes.SetStreamID(nextStreamID)
} else if pkt.CommandName == "getStreamLength" {
// Ignore and do not reply these packets.
} else {
@ -243,7 +248,7 @@ func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error {
identifyRes.Args = rtmp.NewAmf0Undefined()
}
case *rtmp.PublishPacket:
streamName = string(pkt.StreamName)
streamName = pkt.StreamName.String()
clientType = RTMPClientTypePublisher
identifyRes := rtmp.NewCallPacket()
@ -257,7 +262,7 @@ func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error {
data.Set("description", rtmp.NewAmf0String("Started publishing stream."))
identifyRes.Args = data
case *rtmp.PlayPacket:
streamName = string(pkt.StreamName)
streamName = pkt.StreamName.String()
clientType = RTMPClientTypeViewer
identifyRes := rtmp.NewCallPacket()
@ -289,7 +294,7 @@ func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error {
tcUrl, streamName, currentStreamID, clientType)
// Find a backend SRS server to proxy the RTMP stream.
backend = NewRTMPClientToBackend(func(client *RTMPClientToBackend) {
backend = newRTMPClientToBackend(func(client *rtmpClientToBackend) {
client.typ = clientType
})
defer backend.Close()
@ -352,7 +357,7 @@ func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error {
if err != nil {
return errors.Wrapf(err, "read message")
}
//logger.Debug(ctx, "client<- %v %v %vB", m.MessageType, m.Timestamp, len(m.Payload))
//logger.Debug(ctx, "client<- %v %v %vB", m.MessageType(), m.Timestamp(), len(m.Payload()))
// TODO: Update the stream ID if not the same.
if err := client.WriteMessage(ctx, m); err != nil {
@ -375,7 +380,7 @@ func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error {
if err != nil {
return errors.Wrapf(err, "read message")
}
//logger.Debug(ctx, "client-> %v %v %vB", m.MessageType, m.Timestamp, len(m.Payload))
//logger.Debug(ctx, "client-> %v %v %vB", m.MessageType(), m.Timestamp(), len(m.Payload()))
// TODO: Update the stream ID if not the same.
if err := backend.client.WriteMessage(ctx, m); err != nil {
@ -416,32 +421,32 @@ const (
RTMPClientTypeViewer RTMPClientType = "viewer"
)
// RTMPClientToBackend is a RTMP client to proxy the RTMP stream to backend.
type RTMPClientToBackend struct {
// rtmpClientToBackend is an RTMP client to proxy the RTMP stream to backend.
type rtmpClientToBackend struct {
// The underlayer tcp client.
tcpConn *net.TCPConn
// The RTMP protocol client.
client *rtmp.Protocol
client rtmp.Protocol
// The stream type.
typ RTMPClientType
}
func NewRTMPClientToBackend(opts ...func(*RTMPClientToBackend)) *RTMPClientToBackend {
v := &RTMPClientToBackend{}
func newRTMPClientToBackend(opts ...func(*rtmpClientToBackend)) *rtmpClientToBackend {
v := &rtmpClientToBackend{}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *RTMPClientToBackend) Close() error {
func (v *rtmpClientToBackend) Close() error {
if v.tcpConn != nil {
v.tcpConn.Close()
}
return nil
}
func (v *RTMPClientToBackend) Connect(ctx context.Context, tcUrl, streamName string) error {
func (v *rtmpClientToBackend) Connect(ctx context.Context, tcUrl, streamName string) error {
// Build the stream URL in vhost/app/stream schema.
streamURL, err := utils.BuildStreamURL(fmt.Sprintf("%v/%v", tcUrl, streamName))
if err != nil {
@ -527,7 +532,7 @@ func (v *RTMPClientToBackend) Connect(ctx context.Context, tcUrl, streamName str
return v.publish(ctx, client, streamName)
}
func (v *RTMPClientToBackend) publish(ctx context.Context, client *rtmp.Protocol, streamName string) error {
func (v *rtmpClientToBackend) publish(ctx context.Context, client rtmp.Protocol, streamName string) error {
if true {
identifyReq := rtmp.NewCallPacket()
identifyReq.CommandName = "releaseStream"
@ -592,8 +597,8 @@ func (v *RTMPClientToBackend) publish(ctx context.Context, client *rtmp.Protocol
publishStream := rtmp.NewPublishPacket()
publishStream.TransactionID = 5
publishStream.CommandObject = rtmp.NewAmf0Null()
publishStream.StreamName = *rtmp.NewAmf0String(streamName)
publishStream.StreamType = *rtmp.NewAmf0String("live")
publishStream.StreamName = rtmp.NewAmf0String(streamName)
publishStream.StreamType = rtmp.NewAmf0String("live")
if err := client.WritePacket(ctx, publishStream, currentStreamID); err != nil {
return errors.Wrapf(err, "publish")
}
@ -609,8 +614,8 @@ func (v *RTMPClientToBackend) publish(ctx context.Context, client *rtmp.Protocol
return errors.Errorf("onStatus args not object")
} else if code := rtmp.NewAmf0Converter(data.Get("code")).ToString(); code == nil {
return errors.Errorf("onStatus code not string")
} else if *code != "NetStream.Publish.Start" {
return errors.Errorf("onStatus code=%v not NetStream.Publish.Start", *code)
} else if code.String() != "NetStream.Publish.Start" {
return errors.Errorf("onStatus code=%v not NetStream.Publish.Start", code.String())
}
break
}
@ -620,7 +625,7 @@ func (v *RTMPClientToBackend) publish(ctx context.Context, client *rtmp.Protocol
return nil
}
func (v *RTMPClientToBackend) play(ctx context.Context, client *rtmp.Protocol, streamName string) error {
func (v *rtmpClientToBackend) play(ctx context.Context, client rtmp.Protocol, streamName string) error {
var currentStreamID int
if true {
createStream := rtmp.NewCreateStreamPacket()
@ -642,7 +647,7 @@ func (v *RTMPClientToBackend) play(ctx context.Context, client *rtmp.Protocol, s
}
playStream := rtmp.NewPlayPacket()
playStream.StreamName = *rtmp.NewAmf0String(streamName)
playStream.StreamName = rtmp.NewAmf0String(streamName)
if err := client.WritePacket(ctx, playStream, currentStreamID); err != nil {
return errors.Wrapf(err, "play")
}

View File

@ -1,7 +1,7 @@
// Copyright (c) 2026 Winlin
//
// SPDX-License-Identifier: MIT
package protocol
package server
import (
"bytes"

View File

@ -15,7 +15,7 @@ func VersionMinor() int {
}
func VersionRevision() int {
return 146
return 147
}
func Version() string {

View File

@ -7,6 +7,7 @@ The changelog for SRS.
<a name="v7-changes"></a>
## SRS 7.0 Changelog
* v7.0, 2026-05-02, Merge [#4672](https://github.com/ossrs/srs/pull/4672): Proxy: Refactor server APIs and expand RTMP test coverage. v7.0.147 (#4672)
* v7.0, 2026-04-28, Merge [#4670](https://github.com/ossrs/srs/pull/4670): Proxy: Refine logger and environment APIs. v7.0.146 (#4670)
* v7.0, 2026-04-23, Merge [#4667](https://github.com/ossrs/srs/pull/4667): Proxy: Refactor internal/errors and internal/sync, and add unit tests across internal/*. v7.0.145 (#4667)
* v7.0, 2026-04-18, Merge [#4665](https://github.com/ossrs/srs/pull/4665): Proxy: Harden internal/env tests and add counterfeiter fake generation. v7.0.144 (#4665)

View File

@ -9,6 +9,6 @@
#define VERSION_MAJOR 7
#define VERSION_MINOR 0
#define VERSION_REVISION 146
#define VERSION_REVISION 147
#endif