diff --git a/.openclaw/memory/srs-codebase-map.md b/.openclaw/memory/srs-codebase-map.md index 3f100c956..0beae9cae 100644 --- a/.openclaw/memory/srs-codebase-map.md +++ b/.openclaw/memory/srs-codebase-map.md @@ -215,7 +215,7 @@ The next-generation server (`cmd/` + `internal/`) is written in Go and maintaine `internal/bootstrap` — Server startup and lifecycle orchestration. Sets up logging context, signal handlers, loads environment, installs force-quit timer, optionally starts pprof, initializes the load balancer (memory or Redis based on `PROXY_LOAD_BALANCER_TYPE`), then starts all six servers sequentially (RTMP, WebRTC, HTTP API, SRT, System API, HTTP Stream) and blocks until context is cancelled. Deferred `Close()` on each server ensures graceful shutdown. -`internal/protocol` — Protocol proxy servers. Each server accepts client connections, parses just enough of the protocol to extract the stream URL, picks a backend via the load balancer, and proxies traffic bidirectionally. Contains five proxy servers: (1) **RTMP proxy** (`rtmp.go`) — TCP listener, simple handshake, parses connect/publish/play to get stream URL, bidirectional RTMP message copying, stateless. (2) **HTTP stream proxy** (`http.go`) — serves static files, proxies HTTP-FLV/TS via reverse-proxy, proxies HLS m3u8 with `spbhid` rewriting so TS segment requests route to the same backend. (3) **WebRTC proxy** (`rtc.go`) — two-phase: WHIP/WHEP signaling (SDP rewrite to replace backend UDP port with proxy's) + UDP media transport (identifies connections by STUN ufrag, supports address migration), stateful. (4) **SRT proxy** (`srt.go`) — intercepts SRT 4-step handshake locally, parses stream ID on handshake 2, replays full handshake with backend, then proxies UDP bidirectionally, stateful per-connection. (5) **HTTP API + System API** (`api.go`) — HTTP API delegates WHIP/WHEP to WebRTC server; System API provides `/api/v1/srs/register` where backend SRS C++ servers register themselves so the load balancer knows about them. +`internal/server` — Proxy server implementations. Each server accepts client connections, parses just enough of the protocol to extract the stream URL, picks a backend via the load balancer, and proxies traffic bidirectionally. Contains five proxy servers: (1) **RTMP proxy** (`rtmp.go`) — TCP listener, simple handshake, parses connect/publish/play to get stream URL, bidirectional RTMP message copying, stateless. (2) **HTTP stream proxy** (`http.go`) — serves static files, proxies HTTP-FLV/TS via reverse-proxy, proxies HLS m3u8 with `spbhid` rewriting so TS segment requests route to the same backend. (3) **WebRTC proxy** (`rtc.go`) — two-phase: WHIP/WHEP signaling (SDP rewrite to replace backend UDP port with proxy's) + UDP media transport (identifies connections by STUN ufrag, supports address migration), stateful. (4) **SRT proxy** (`srt.go`) — intercepts SRT 4-step handshake locally, parses stream ID on handshake 2, replays full handshake with backend, then proxies UDP bidirectionally, stateful per-connection. (5) **HTTP API + System API** (`api.go`) — HTTP API delegates WHIP/WHEP to WebRTC server; System API provides `/api/v1/srs/register` where backend SRS C++ servers register themselves so the load balancer knows about them. `internal/rtmp` — RTMP protocol implementation (parsing, not proxying). Full RTMP chunk stream and message protocol: simple handshake (C0/C1/C2), chunk stream reader/writer with all four format types, extended timestamp, message reassembly from chunks. Defines all RTMP message types, chunk stream IDs, and command names. Packet types include ConnectApp, CreateStream, Publish, Play, Call, SetChunkSize, WindowAcknowledgementSize, SetPeerBandwidth, UserControl. Uses Go generics (`ExpectPacket[T]`) to read until a specific packet type arrives. Also includes full AMF0 encoder/decoder supporting Number, Boolean, String, Object, Null, Undefined, EcmaArray, StrictArray, Date, LongString — with ordered key-value maps, auto-type-discovery, and safe type converters. @@ -301,6 +301,9 @@ The knowledge base (`memory/srs-*.md`) captures William's knowledge about SRS - `proxy-load-balancer.md` — Load balancer design: memory vs Redis implementations, stream-to-server mapping, server health via heartbeats, protocol-specific state - `proxy-origin-cluster.md` — Origin cluster tutorial: build proxy + SRS, configure multi-origin with proxy, stream publishing and playback verification +**Next-Generation Server API Examples** — Executable API documentation: +- `internal/rtmp/example_test.go` — RTMP API examples: AMF0, handshake, and protocol workflow + ## Testing and Verification Structure How to verify SRS works correctly. @@ -343,6 +346,13 @@ How to verify SRS works correctly. - Reconnecting Load Test - Janus +`.openclaw/skills/srs-develop/scripts/` — Go proxy verification scripts: +- `proxy-utest.sh` — Runs Go proxy unit tests with optional coverage. +- `proxy-e2e-test.sh` — Single-origin RTMP proxy E2E test. +- `proxy-e2e-cluster-test.sh` — Multi-origin memory load-balancer E2E test. +- `proxy-e2e-redis-test.sh` — Multi-proxy Redis load-balancer E2E test. +- `proxy-e2e-transmux-test.sh` — RTMP publish through proxy, then verify RTMP, HTTP-FLV, HLS, and WebRTC playback. + **Summary: The Key Differences** | | Unit Tests | Black-box | E2E | Benchmark | diff --git a/.openclaw/memory/srs-overview.md b/.openclaw/memory/srs-overview.md index 18a30596e..4d19fe304 100644 --- a/.openclaw/memory/srs-overview.md +++ b/.openclaw/memory/srs-overview.md @@ -127,6 +127,12 @@ Which underlying transport (TCP or UDP) each protocol uses in SRS: - **RTSP** — TCP. SRS only supports TCP transport (no UDP/RTP interleaved). - **GB28181** — TCP. PS stream over TCP. +Related transport protocols that are important in the media industry but **not supported by SRS or Oryx**: + +- **RIST** — UDP. Reliable Internet Stream Transport. Similar to SRT: a reliable, low-latency media transport over UDP, with retransmission and encryption options. +- **MoQ** — UDP/QUIC. Media over QUIC. An IETF effort for low-latency media ingest and delivery over QUIC, usually over UDP and optionally through WebTransport. +- **WebTransport** — UDP/QUIC, HTTP/3-based. A browser and network transport API/protocol that can carry media data, and one possible substrate for MoQ. + ## Most Common Usage The simplest way to use SRS: publish an RTMP stream and play it. @@ -292,4 +298,3 @@ Config files are in the `conf/` folder. Key files: - Other files exist for specific features like clustering, DVR, or different protocols. SRS also supports configuration via environment variables. This is especially useful for Docker and cloud-native deployments — you can set environment variables in YAML files or other platforms without needing a separate config file. It's convenient to copy and paste, making documentation clearer. In the SRS docs, environment variables are often used to show how to run SRS with different configurations. - diff --git a/.openclaw/skills/srs-develop/SKILL.md b/.openclaw/skills/srs-develop/SKILL.md index 9733c6740..4b444521a 100644 --- a/.openclaw/skills/srs-develop/SKILL.md +++ b/.openclaw/skills/srs-develop/SKILL.md @@ -1,6 +1,6 @@ --- name: srs-develop -description: Develop, modify, debug, and maintain the next-generation SRS media server written in Go — including the proxy, origin, and edge servers. This is the AI-maintained successor to the first-generation C++ SRS server. Use for all development tasks, for example, adding features, fixing bugs, refactoring code, understanding code architecture, reviewing changes, and writing tests for the Go codebase. NOT for end-user support, usage questions, configuration help, or learning how to use SRS — use the srs-support skill for those. Only activate when the task is explicitly about developing or modifying the Go SRS codebase. +description: Develop, modify, debug, and maintain the next-generation SRS media server written in Go. This is the AI-maintained successor to the first-generation C++ SRS server. Currently, planned changes are supported for the Go proxy server only; the next-generation Go origin and edge server workflows are not yet supported. Use for all development tasks, for example, adding features, fixing bugs, refactoring code, understanding code architecture, reviewing changes, and writing tests for the Go codebase. NOT for end-user support, usage questions, configuration help, or learning how to use SRS — use the srs-support skill for those. Only activate when the task is explicitly about developing or modifying the Go SRS codebase. --- # SRS Development @@ -100,7 +100,7 @@ Do NOT attempt unsupported tasks. | Service | Route To | Status | |---|---|---| | **Proxy server** | → [Proxy Server](#proxy-server) | ✅ Supported | -| **Origin server** | → [Origin Server](#origin-server) | ✅ Supported | +| **Origin server** | → [Origin Server](#origin-server) | ❌ Not yet supported | | **Edge server** | → [Edge Server](#edge-server) | ❌ Not yet supported | **If the routed service is not yet supported**, stop and tell the user: @@ -143,17 +143,30 @@ Only after the user confirms the routing do you proceed to Step 2. ``` bash scripts/proxy-utest.sh --coverage ``` -4. Run the proxy E2E test (starts proxy + SRS origin, publishes RTMP, verifies playback): +4. Run the proxy E2E tests: + - Single-origin RTMP proxy test (starts proxy + one SRS origin, publishes RTMP, verifies playback): ``` bash scripts/proxy-e2e-test.sh ``` + - Multi-origin cluster routing test (starts proxy + two SRS origins, publishes multiple streams, verifies streams are assigned to different origins): + ``` + bash scripts/proxy-e2e-cluster-test.sh + ``` + - Redis multi-proxy routing test (requires local Redis; starts two proxy instances with Redis LB, publishes through one proxy, verifies playback through the other): + ``` + bash scripts/proxy-e2e-redis-test.sh + ``` + - RTMP transmuxing test (starts proxy + one SRS origin, publishes RTMP, verifies RTMP/HTTP-FLV/HLS playback, and verifies WebRTC WHEP playback when `PROXY_TRANSMUX_TEST_RTC=on`): + ``` + bash scripts/proxy-e2e-transmux-test.sh + ``` 5. If any tests fail, fix the issues and re-run until all tests pass. All script paths are relative to this skill's directory. ### Origin Server -*(workflow steps to be defined)* +**Not yet supported.** This refers to the next-generation Go origin server workflow. The first-generation C++ origin server still exists, but it is in maintenance mode and only bug fixes are accepted there. ### Edge Server diff --git a/.openclaw/skills/srs-develop/scripts/proxy-e2e-cluster-test.sh b/.openclaw/skills/srs-develop/scripts/proxy-e2e-cluster-test.sh new file mode 100755 index 000000000..eba2f5077 --- /dev/null +++ b/.openclaw/skills/srs-develop/scripts/proxy-e2e-cluster-test.sh @@ -0,0 +1,341 @@ +#!/bin/bash +# E2E test for RTMP proxy origin-cluster routing: starts one proxy + two SRS +# origins, publishes multiple RTMP streams, verifies playback via proxy, and +# verifies that different streams are assigned to different origin servers. +set -e + +SCRIPT_DIR="$(cd -P "$(dirname "$0")" && pwd)" +# Navigate: scripts/ -> srs-develop/ -> skills/ -> .openclaw/ -> srs +WORKSPACE="$(cd -P "$SCRIPT_DIR/../../../.." && pwd)" + +if [[ ! -f "$WORKSPACE/go.mod" ]]; then + echo "Error: go.mod not found in WORKSPACE: $WORKSPACE" >&2 + exit 1 +fi + +# Ports — use high ports to avoid conflicts with running services. +# The proxy starts ALL servers, so we must assign unique ports for each. +PROXY_RTMP_PORT=11935 +PROXY_HTTP_API_PORT=11985 +PROXY_HTTP_SERVER_PORT=18080 +PROXY_WEBRTC_PORT=18000 +PROXY_SRT_PORT=20080 +PROXY_SYSTEM_API_PORT=12025 + +SOURCE_FLV="$WORKSPACE/trunk/doc/source.flv" +SRS_BINARY="$WORKSPACE/trunk/objs/srs" + +# Origin ports from origin1-for-proxy.conf and origin2-for-proxy.conf. +ORIGIN1_RTMP_PORT=19351 +ORIGIN1_HTTP_PORT=8081 +ORIGIN1_API_PORT=19851 +ORIGIN1_RTC_PORT=8001 +ORIGIN1_SRT_PORT=10081 + +ORIGIN2_RTMP_PORT=19352 +ORIGIN2_HTTP_PORT=8082 +ORIGIN2_API_PORT=19853 +ORIGIN2_RTC_PORT=8002 +ORIGIN2_SRT_PORT=10082 + +# PIDs to clean up on exit. +PROXY_PID="" +ORIGIN_PIDS=() +FFMPEG_PIDS=() + +cleanup() { + echo "" + echo "=== Cleaning up ===" + for pid in $PROXY_PID "${ORIGIN_PIDS[@]}" "${FFMPEG_PIDS[@]}"; do + if [[ -n "$pid" ]] && kill -0 "$pid" 2>/dev/null; then + kill "$pid" 2>/dev/null || true + fi + done + sleep 1 + for pid in $PROXY_PID "${ORIGIN_PIDS[@]}" "${FFMPEG_PIDS[@]}"; do + if [[ -n "$pid" ]] && kill -0 "$pid" 2>/dev/null; then + kill -9 "$pid" 2>/dev/null || true + fi + done + echo "Cleanup done." +} +trap cleanup EXIT + +wait_for_http() { + local url=$1 + local name=$2 + local i + + for i in $(seq 1 30); do + if curl -fsS --max-time 2 "$url" >/dev/null 2>&1; then + echo "$name is ready." + return 0 + fi + sleep 1 + done + + echo "Error: $name is not ready after 30s: $url" >&2 + return 1 +} + +origin_has_stream() { + local api_port=$1 + local stream=$2 + + curl -fsS --max-time 3 "http://127.0.0.1:$api_port/api/v1/streams/" 2>/dev/null | grep -q "$stream" +} + +detect_origin_for_stream() { + local stream=$1 + local i + + for i in $(seq 1 10); do + local on_origin1=0 + local on_origin2=0 + + if origin_has_stream "$ORIGIN1_API_PORT" "$stream"; then + on_origin1=1 + fi + if origin_has_stream "$ORIGIN2_API_PORT" "$stream"; then + on_origin2=1 + fi + + if [[ $on_origin1 -eq 1 && $on_origin2 -eq 0 ]]; then + echo "origin1" + return 0 + fi + if [[ $on_origin1 -eq 0 && $on_origin2 -eq 1 ]]; then + echo "origin2" + return 0 + fi + if [[ $on_origin1 -eq 1 && $on_origin2 -eq 1 ]]; then + echo "Error: stream $stream exists on both origins; expected exactly one owner" >&2 + return 1 + fi + + sleep 1 + done + + echo "Error: stream $stream was not found on either origin" >&2 + return 1 +} + +verify_probe_has_av() { + local url=$1 + local label=$2 + local probe_output + + probe_output=$(ffprobe -v error -rw_timeout 5000000 -show_streams "$url" 2>&1 || true) + + if ! echo "$probe_output" | grep -q "codec_type=video"; then + echo "FAIL: No video stream detected for $label." >&2 + echo "ffprobe output:" >&2 + echo "$probe_output" >&2 + exit 1 + fi + + if ! echo "$probe_output" | grep -q "codec_type=audio"; then + echo "FAIL: No audio stream detected for $label." >&2 + echo "ffprobe output:" >&2 + echo "$probe_output" >&2 + exit 1 + fi + + echo "PASS: Audio/video detected for $label." +} + +echo "=== E2E RTMP Proxy Origin Cluster Test ===" +echo "Workspace: $WORKSPACE" +echo "" + +# --- Pre-checks --- +if [[ ! -f "$SOURCE_FLV" ]]; then + echo "Error: test source not found: $SOURCE_FLV" >&2 + exit 1 +fi +if ! command -v ffmpeg &>/dev/null; then + echo "Error: ffmpeg not found in PATH" >&2 + exit 1 +fi +if ! command -v ffprobe &>/dev/null; then + echo "Error: ffprobe not found in PATH" >&2 + exit 1 +fi +if ! command -v curl &>/dev/null; then + echo "Error: curl not found in PATH" >&2 + exit 1 +fi + +# --- Step 0: Clean up stale state --- +# Remove stale SRS PID files that prevent restart. +rm -f "$WORKSPACE/trunk/objs/origin1.pid" "$WORKSPACE/trunk/objs/origin2.pid" +# Kill any leftover processes on our ports (proxy + origins). +ALL_PORTS="$PROXY_RTMP_PORT $PROXY_HTTP_API_PORT $PROXY_HTTP_SERVER_PORT $PROXY_WEBRTC_PORT $PROXY_SRT_PORT $PROXY_SYSTEM_API_PORT" +ALL_PORTS="$ALL_PORTS $ORIGIN1_RTMP_PORT $ORIGIN1_HTTP_PORT $ORIGIN1_API_PORT $ORIGIN1_RTC_PORT $ORIGIN1_SRT_PORT" +ALL_PORTS="$ALL_PORTS $ORIGIN2_RTMP_PORT $ORIGIN2_HTTP_PORT $ORIGIN2_API_PORT $ORIGIN2_RTC_PORT $ORIGIN2_SRT_PORT" +for port in $ALL_PORTS; do + lsof -ti :"$port" 2>/dev/null | xargs kill 2>/dev/null || true +done +sleep 1 + +# --- Step 1: Build proxy --- +echo "=== Step 1: Building proxy ===" +cd "$WORKSPACE" +make -s 2>&1 +echo "Proxy built: $WORKSPACE/bin/srs-proxy" + +# --- Step 2: Build SRS origins (if not already built) --- +if [[ ! -f "$SRS_BINARY" ]]; then + echo "=== Step 2: Building SRS origins ===" + cd "$WORKSPACE/trunk" + ./configure && make 2>&1 | tail -3 + echo "SRS origins built: $SRS_BINARY" +else + echo "=== Step 2: SRS origins already built ===" +fi + +# --- Step 3: Start proxy --- +echo "=== Step 3: Starting proxy (RTMP :$PROXY_RTMP_PORT, System API :$PROXY_SYSTEM_API_PORT) ===" +cd "$WORKSPACE" +env PROXY_RTMP_SERVER=$PROXY_RTMP_PORT \ + PROXY_HTTP_API=$PROXY_HTTP_API_PORT \ + PROXY_HTTP_SERVER=$PROXY_HTTP_SERVER_PORT \ + PROXY_WEBRTC_SERVER=$PROXY_WEBRTC_PORT \ + PROXY_SRT_SERVER=$PROXY_SRT_PORT \ + PROXY_SYSTEM_API=$PROXY_SYSTEM_API_PORT \ + PROXY_LOAD_BALANCER_TYPE=memory \ + ./bin/srs-proxy >/tmp/srs-proxy-cluster-e2e.log 2>&1 & +PROXY_PID=$! +echo "Proxy PID: $PROXY_PID" + +wait_for_http "http://127.0.0.1:$PROXY_SYSTEM_API_PORT/api/v1/versions" "Proxy System API" + +if ! kill -0 "$PROXY_PID" 2>/dev/null; then + echo "Error: proxy failed to start. Logs:" >&2 + cat /tmp/srs-proxy-cluster-e2e.log >&2 + exit 1 +fi +echo "Proxy started." + +# --- Step 4: Start two SRS origins --- +echo "=== Step 4: Starting two SRS origins ===" +ulimit -n 10000 2>/dev/null || true +cd "$WORKSPACE/trunk" + +./objs/srs -c conf/origin1-for-proxy.conf >/tmp/srs-origin1-cluster-e2e.log 2>&1 & +ORIGIN_PIDS+=($!) +echo "SRS origin1 PID: ${ORIGIN_PIDS[0]}" + +./objs/srs -c conf/origin2-for-proxy.conf >/tmp/srs-origin2-cluster-e2e.log 2>&1 & +ORIGIN_PIDS+=($!) +echo "SRS origin2 PID: ${ORIGIN_PIDS[1]}" + +wait_for_http "http://127.0.0.1:$ORIGIN1_API_PORT/api/v1/versions" "SRS origin1 HTTP API" +wait_for_http "http://127.0.0.1:$ORIGIN2_API_PORT/api/v1/versions" "SRS origin2 HTTP API" + +# Wait for both SRS origins to register with proxy (heartbeat interval is 9s). +echo "Waiting for both SRS origins to register with proxy (up to 20s)..." +for i in $(seq 1 20); do + registered=$(grep -c "Register SRS media server" /tmp/srs-proxy-cluster-e2e.log 2>/dev/null || true) + if [[ $registered -ge 2 ]]; then + echo "Both origins registered." + break + fi + sleep 1 +done + +registered=$(grep -c "Register SRS media server" /tmp/srs-proxy-cluster-e2e.log 2>/dev/null || true) +if [[ $registered -lt 2 ]]; then + echo "Error: expected two origin registrations, got $registered. Proxy logs:" >&2 + cat /tmp/srs-proxy-cluster-e2e.log >&2 + exit 1 +fi + +for pid in "${ORIGIN_PIDS[@]}"; do + if ! kill -0 "$pid" 2>/dev/null; then + echo "Error: SRS origin failed to start. Origin1 logs:" >&2 + cat /tmp/srs-origin1-cluster-e2e.log >&2 + echo "Origin2 logs:" >&2 + cat /tmp/srs-origin2-cluster-e2e.log >&2 + exit 1 + fi +done +echo "Two SRS origins started and registered." + +# --- Step 5: Publish RTMP streams until both origins own at least one stream --- +echo "=== Step 5: Publishing multiple RTMP streams to proxy ===" +STREAM_PREFIX="cluster$(date +%s)" +STREAMS=() +STREAM_ORIGINS=() +origin1_count=0 +origin2_count=0 + +# The memory load balancer picks a random healthy origin for each new stream and +# keeps that stream sticky. Publish several unique streams to verify distribution +# while keeping the test extremely unlikely to fail from random selection alone. +for i in $(seq 1 20); do + stream="${STREAM_PREFIX}_$i" + STREAMS+=("$stream") + + ffmpeg -stream_loop -1 -re -i "$SOURCE_FLV" -c copy -f flv \ + "rtmp://localhost:$PROXY_RTMP_PORT/live/$stream" >/tmp/srs-ffmpeg-cluster-e2e-$i.log 2>&1 & + FFMPEG_PIDS+=($!) + echo "Started publisher for live/$stream, PID: ${FFMPEG_PIDS[$((${#FFMPEG_PIDS[@]} - 1))]}" + + sleep 3 + + if ! kill -0 "${FFMPEG_PIDS[$((${#FFMPEG_PIDS[@]} - 1))]}" 2>/dev/null; then + echo "Error: FFmpeg publisher failed for $stream. Logs:" >&2 + cat "/tmp/srs-ffmpeg-cluster-e2e-$i.log" >&2 + exit 1 + fi + + owner=$(detect_origin_for_stream "$stream") + STREAM_ORIGINS+=("$owner") + echo "Stream live/$stream is owned by $owner." + + if [[ "$owner" == "origin1" ]]; then + origin1_count=$((origin1_count + 1)) + else + origin2_count=$((origin2_count + 1)) + fi + + if [[ $origin1_count -gt 0 && $origin2_count -gt 0 ]]; then + break + fi +done + +if [[ $origin1_count -eq 0 || $origin2_count -eq 0 ]]; then + echo "FAIL: streams were not distributed to both origins." >&2 + echo "origin1_count=$origin1_count, origin2_count=$origin2_count" >&2 + exit 1 +fi + +echo "PASS: stream distribution detected: origin1=$origin1_count, origin2=$origin2_count." + +# --- Step 6: Verify playback via proxy and direct owning origins --- +echo "=== Step 6: Verifying playback and sticky origin ownership ===" +for i in "${!STREAMS[@]}"; do + stream="${STREAMS[$i]}" + owner="${STREAM_ORIGINS[$i]}" + + # Verify proxy playback works for every published stream. + verify_probe_has_av "rtmp://localhost:$PROXY_RTMP_PORT/live/$stream" "proxy live/$stream" + + # Verify the stream is on exactly one origin, and verify direct playback from + # the owning origin to prove the proxy really published to that backend. + current_owner=$(detect_origin_for_stream "$stream") + if [[ "$current_owner" != "$owner" ]]; then + echo "FAIL: stream live/$stream moved from $owner to $current_owner; expected sticky routing." >&2 + exit 1 + fi + + if [[ "$owner" == "origin1" ]]; then + verify_probe_has_av "rtmp://localhost:$ORIGIN1_RTMP_PORT/live/$stream" "origin1 live/$stream" + else + verify_probe_has_av "rtmp://localhost:$ORIGIN2_RTMP_PORT/live/$stream" "origin2 live/$stream" + fi +done + +echo "" +echo "=== E2E RTMP Proxy Origin Cluster Test PASSED ===" diff --git a/.openclaw/skills/srs-develop/scripts/proxy-e2e-redis-test.sh b/.openclaw/skills/srs-develop/scripts/proxy-e2e-redis-test.sh new file mode 100755 index 000000000..148a81d18 --- /dev/null +++ b/.openclaw/skills/srs-develop/scripts/proxy-e2e-redis-test.sh @@ -0,0 +1,324 @@ +#!/bin/bash +# E2E test for RTMP proxy Redis load balancer: starts two proxy instances with +# Redis-backed shared state + one SRS origin. The origin registers through proxy A, +# publishing goes through proxy A, and playback goes through proxy B. Playback must +# succeed because proxy B resolves the stream-to-origin mapping from Redis. +set -e + +SCRIPT_DIR="$(cd -P "$(dirname "$0")" && pwd)" +# Navigate: scripts/ -> srs-develop/ -> skills/ -> .openclaw/ -> srs +WORKSPACE="$(cd -P "$SCRIPT_DIR/../../../.." && pwd)" + +if [[ ! -f "$WORKSPACE/go.mod" ]]; then + echo "Error: go.mod not found in WORKSPACE: $WORKSPACE" >&2 + exit 1 +fi + +# Ports — use high ports to avoid conflicts with running services. +# Each proxy starts ALL servers, so each proxy needs a unique full port set. +PROXY_A_RTMP_PORT=11935 +PROXY_A_HTTP_API_PORT=11985 +PROXY_A_HTTP_SERVER_PORT=18080 +PROXY_A_WEBRTC_PORT=18000 +PROXY_A_SRT_PORT=20080 +PROXY_A_SYSTEM_API_PORT=12025 + +PROXY_B_RTMP_PORT=11936 +PROXY_B_HTTP_API_PORT=11986 +PROXY_B_HTTP_SERVER_PORT=18081 +PROXY_B_WEBRTC_PORT=18001 +PROXY_B_SRT_PORT=20081 +PROXY_B_SYSTEM_API_PORT=12026 + +REDIS_HOST="${PROXY_REDIS_HOST:-127.0.0.1}" +REDIS_PORT="${PROXY_REDIS_PORT:-6379}" +REDIS_PASSWORD="${PROXY_REDIS_PASSWORD:-}" +REDIS_DB="${PROXY_REDIS_DB:-0}" +PYTHON_BIN="${PYTHON_BIN:-python3}" + +SOURCE_FLV="$WORKSPACE/trunk/doc/source.flv" +SRS_BINARY="$WORKSPACE/trunk/objs/srs" +TEST_STREAM_URL="__defaultVhost__/live/livestream" + +# PIDs to clean up on exit. +PROXY_A_PID="" +PROXY_B_PID="" +ORIGIN_PID="" +FFMPEG_PID="" + +redis_cli() { + if [[ -n "$REDIS_PASSWORD" ]]; then + redis-cli -h "$REDIS_HOST" -p "$REDIS_PORT" -a "$REDIS_PASSWORD" -n "$REDIS_DB" "$@" + else + redis-cli -h "$REDIS_HOST" -p "$REDIS_PORT" -n "$REDIS_DB" "$@" + fi +} + +cleanup_redis_state() { + # Remove only the Redis records created by this E2E test. Never flush the DB, + # and never delete every srs-proxy-* key because the same Redis DB may be used + # by another proxy/origin-cluster test or by a developer's local proxy. + if ! command -v redis-cli &>/dev/null; then + return + fi + if ! command -v "$PYTHON_BIN" &>/dev/null; then + echo "Skip Redis cleanup: $PYTHON_BIN is not available" + return + fi + if ! redis_cli ping 2>/dev/null | grep -q "PONG"; then + echo "Skip Redis cleanup: Redis is not available at $REDIS_HOST:$REDIS_PORT db=$REDIS_DB" + return + fi + + local count=0 + local stream_key="srs-proxy-url:$TEST_STREAM_URL" + if [[ "$(redis_cli exists "$stream_key" 2>/dev/null || echo 0)" != "0" ]]; then + redis_cli del "$stream_key" >/dev/null 2>&1 || true + count=$((count + 1)) + fi + + # The origin server generates its server key from runtime IDs, so discover only + # server records that match this test origin's identity and configured ports. + local server_keys=() + local key value + while IFS= read -r key; do + [[ -z "$key" ]] && continue + value="$(redis_cli get "$key" 2>/dev/null || true)" + if [[ "$value" == *'"device_id":"origin1"'* && \ + "$value" == *'"rtmp":["19351"]'* && \ + "$value" == *'"http":["8081"]'* && \ + "$value" == *'"api":["19851"]'* ]]; then + server_keys+=("$key") + redis_cli del "$key" >/dev/null 2>&1 || true + count=$((count + 1)) + fi + done < <(redis_cli --scan --pattern 'srs-proxy-server:*' 2>/dev/null || true) + + # Keep the shared server index, but remove only the test origin server keys. + if [[ ${#server_keys[@]} -gt 0 ]]; then + local servers_json updated_json + servers_json="$(redis_cli get srs-proxy-all-servers 2>/dev/null || true)" + if [[ -n "$servers_json" ]]; then + updated_json="$($PYTHON_BIN - "$servers_json" "${server_keys[@]}" <<'PY' +import json, sys +servers = json.loads(sys.argv[1]) if sys.argv[1] else [] +remove = set(sys.argv[2:]) +servers = [server for server in servers if server not in remove] +print(json.dumps(servers, separators=(",", ":"))) +PY +)" + if [[ "$updated_json" == "[]" ]]; then + redis_cli del srs-proxy-all-servers >/dev/null 2>&1 || true + else + redis_cli set srs-proxy-all-servers "$updated_json" >/dev/null 2>&1 || true + fi + fi + fi + + echo "Cleaned $count Redis proxy test key(s)." +} +cleanup() { + echo "" + echo "=== Cleaning up ===" + for pid in $PROXY_A_PID $PROXY_B_PID $ORIGIN_PID $FFMPEG_PID; do + if kill -0 "$pid" 2>/dev/null; then + kill "$pid" 2>/dev/null || true + fi + done + sleep 1 + for pid in $PROXY_A_PID $PROXY_B_PID $ORIGIN_PID $FFMPEG_PID; do + if kill -0 "$pid" 2>/dev/null; then + kill -9 "$pid" 2>/dev/null || true + fi + done + cleanup_redis_state + echo "Cleanup done." +} +trap cleanup EXIT + +echo "=== E2E RTMP Proxy Redis Load Balancer Test ===" +echo "Workspace: $WORKSPACE" +echo "Redis: $REDIS_HOST:$REDIS_PORT db=$REDIS_DB" +echo "" + +# --- Pre-checks --- +if [[ ! -f "$SOURCE_FLV" ]]; then + echo "Error: test source not found: $SOURCE_FLV" >&2 + exit 1 +fi +if ! command -v ffmpeg &>/dev/null; then + echo "Error: ffmpeg not found in PATH" >&2 + exit 1 +fi +if ! command -v ffprobe &>/dev/null; then + echo "Error: ffprobe not found in PATH" >&2 + exit 1 +fi +if ! command -v redis-cli &>/dev/null; then + echo "Error: redis-cli not found in PATH" >&2 + echo "Install Redis on macOS with: brew install redis" >&2 + exit 1 +fi +if ! redis_cli ping 2>/dev/null | grep -q "PONG"; then + echo "Error: Redis is not available at $REDIS_HOST:$REDIS_PORT db=$REDIS_DB" >&2 + echo "Start Redis on macOS with: brew services start redis" >&2 + echo "Or run a foreground Redis with: redis-server" >&2 + exit 1 +fi + +# Origin ports (from origin1-for-proxy.conf). +ORIGIN_RTMP_PORT=19351 +ORIGIN_HTTP_PORT=8081 +ORIGIN_API_PORT=19851 +ORIGIN_RTC_PORT=8001 +ORIGIN_SRT_PORT=10081 + +# --- Step 0: Clean up stale state --- +# Remove stale SRS PID file that prevents restart. +rm -f "$WORKSPACE/trunk/objs/origin1.pid" +cleanup_redis_state +# Kill any leftover processes on our ports (proxy A + proxy B + origin). +ALL_PORTS="$PROXY_A_RTMP_PORT $PROXY_A_HTTP_API_PORT $PROXY_A_HTTP_SERVER_PORT $PROXY_A_WEBRTC_PORT $PROXY_A_SRT_PORT $PROXY_A_SYSTEM_API_PORT $PROXY_B_RTMP_PORT $PROXY_B_HTTP_API_PORT $PROXY_B_HTTP_SERVER_PORT $PROXY_B_WEBRTC_PORT $PROXY_B_SRT_PORT $PROXY_B_SYSTEM_API_PORT $ORIGIN_RTMP_PORT $ORIGIN_HTTP_PORT $ORIGIN_API_PORT $ORIGIN_RTC_PORT $ORIGIN_SRT_PORT" +for port in $ALL_PORTS; do + lsof -ti :"$port" 2>/dev/null | xargs kill 2>/dev/null || true +done +sleep 1 + +# --- Step 1: Build proxy --- +echo "=== Step 1: Building proxy ===" +cd "$WORKSPACE" +make -s 2>&1 +echo "Proxy built: $WORKSPACE/bin/srs-proxy" + +# --- Step 2: Build SRS origin (if not already built) --- +if [[ ! -f "$SRS_BINARY" ]]; then + echo "=== Step 2: Building SRS origin ===" + cd "$WORKSPACE/trunk" + ./configure && make 2>&1 | tail -3 + echo "SRS origin built: $SRS_BINARY" +else + echo "=== Step 2: SRS origin already built ===" +fi + +# --- Step 3: Start proxy A --- +echo "=== Step 3: Starting proxy A (RTMP :$PROXY_A_RTMP_PORT, System API :$PROXY_A_SYSTEM_API_PORT) ===" +cd "$WORKSPACE" +env PROXY_RTMP_SERVER=$PROXY_A_RTMP_PORT \ + PROXY_HTTP_API=$PROXY_A_HTTP_API_PORT \ + PROXY_HTTP_SERVER=$PROXY_A_HTTP_SERVER_PORT \ + PROXY_WEBRTC_SERVER=$PROXY_A_WEBRTC_PORT \ + PROXY_SRT_SERVER=$PROXY_A_SRT_PORT \ + PROXY_SYSTEM_API=$PROXY_A_SYSTEM_API_PORT \ + PROXY_LOAD_BALANCER_TYPE=redis \ + PROXY_REDIS_HOST="$REDIS_HOST" \ + PROXY_REDIS_PORT="$REDIS_PORT" \ + PROXY_REDIS_PASSWORD="$REDIS_PASSWORD" \ + PROXY_REDIS_DB="$REDIS_DB" \ + ./bin/srs-proxy >/tmp/srs-proxy-redis-a-e2e.log 2>&1 & +PROXY_A_PID=$! +echo "Proxy A PID: $PROXY_A_PID" +sleep 1 + +if ! kill -0 "$PROXY_A_PID" 2>/dev/null; then + echo "Error: proxy A failed to start. Logs:" >&2 + cat /tmp/srs-proxy-redis-a-e2e.log >&2 + exit 1 +fi +echo "Proxy A started." + +# --- Step 4: Start proxy B --- +echo "=== Step 4: Starting proxy B (RTMP :$PROXY_B_RTMP_PORT, System API :$PROXY_B_SYSTEM_API_PORT) ===" +cd "$WORKSPACE" +env PROXY_RTMP_SERVER=$PROXY_B_RTMP_PORT \ + PROXY_HTTP_API=$PROXY_B_HTTP_API_PORT \ + PROXY_HTTP_SERVER=$PROXY_B_HTTP_SERVER_PORT \ + PROXY_WEBRTC_SERVER=$PROXY_B_WEBRTC_PORT \ + PROXY_SRT_SERVER=$PROXY_B_SRT_PORT \ + PROXY_SYSTEM_API=$PROXY_B_SYSTEM_API_PORT \ + PROXY_LOAD_BALANCER_TYPE=redis \ + PROXY_REDIS_HOST="$REDIS_HOST" \ + PROXY_REDIS_PORT="$REDIS_PORT" \ + PROXY_REDIS_PASSWORD="$REDIS_PASSWORD" \ + PROXY_REDIS_DB="$REDIS_DB" \ + ./bin/srs-proxy >/tmp/srs-proxy-redis-b-e2e.log 2>&1 & +PROXY_B_PID=$! +echo "Proxy B PID: $PROXY_B_PID" +sleep 1 + +if ! kill -0 "$PROXY_B_PID" 2>/dev/null; then + echo "Error: proxy B failed to start. Logs:" >&2 + cat /tmp/srs-proxy-redis-b-e2e.log >&2 + exit 1 +fi +echo "Proxy B started." + +# --- Step 5: Start SRS origin --- +echo "=== Step 5: Starting SRS origin ===" +ulimit -n 10000 2>/dev/null || true +cd "$WORKSPACE/trunk" +./objs/srs -c conf/origin1-for-proxy.conf >/tmp/srs-origin-redis-e2e.log 2>&1 & +ORIGIN_PID=$! +echo "SRS origin PID: $ORIGIN_PID" + +# Wait for SRS to start and register with proxy A (heartbeat interval is 9s). +echo "Waiting for SRS origin to register with proxy A and Redis (up to 15s)..." +sleep 12 + +if ! kill -0 "$ORIGIN_PID" 2>/dev/null; then + echo "Error: SRS origin failed to start. Logs:" >&2 + cat /tmp/srs-origin-redis-e2e.log >&2 + exit 1 +fi +if ! redis_cli --scan --pattern 'srs-proxy-server:*' | grep -q 'srs-proxy-server:'; then + echo "Error: SRS origin did not register in Redis. Proxy A logs:" >&2 + cat /tmp/srs-proxy-redis-a-e2e.log >&2 + exit 1 +fi +echo "SRS origin started and registered in Redis." + +# --- Step 6: Publish RTMP stream to proxy A --- +echo "=== Step 6: Publishing RTMP stream to proxy A ===" +ffmpeg -stream_loop -1 -re -i "$SOURCE_FLV" -c copy -f flv \ + "rtmp://localhost:$PROXY_A_RTMP_PORT/live/livestream" >/tmp/srs-ffmpeg-redis-e2e.log 2>&1 & +FFMPEG_PID=$! +echo "FFmpeg publisher PID: $FFMPEG_PID" + +# Wait for stream to stabilize. +sleep 5 + +if ! kill -0 "$FFMPEG_PID" 2>/dev/null; then + echo "Error: FFmpeg publisher failed. Logs:" >&2 + cat /tmp/srs-ffmpeg-redis-e2e.log >&2 + exit 1 +fi +echo "Stream publishing through proxy A." + +# --- Step 7: Verify RTMP playback through proxy B --- +echo "=== Step 7: Verifying RTMP playback through proxy B ===" +PROBE_OUTPUT=$(ffprobe -v error -show_streams \ + "rtmp://localhost:$PROXY_B_RTMP_PORT/live/livestream" 2>&1 || true) + +if echo "$PROBE_OUTPUT" | grep -q "codec_type=video"; then + echo "PASS: Video stream detected through proxy B." +else + echo "FAIL: No video stream detected through proxy B." >&2 + echo "ffprobe output:" >&2 + echo "$PROBE_OUTPUT" >&2 + echo "Proxy B logs:" >&2 + cat /tmp/srs-proxy-redis-b-e2e.log >&2 + exit 1 +fi + +if echo "$PROBE_OUTPUT" | grep -q "codec_type=audio"; then + echo "PASS: Audio stream detected through proxy B." +else + echo "FAIL: No audio stream detected through proxy B." >&2 + echo "ffprobe output:" >&2 + echo "$PROBE_OUTPUT" >&2 + echo "Proxy B logs:" >&2 + cat /tmp/srs-proxy-redis-b-e2e.log >&2 + exit 1 +fi + +echo "" +echo "=== E2E RTMP Proxy Redis Load Balancer Test PASSED ===" diff --git a/.openclaw/skills/srs-develop/scripts/proxy-e2e-transmux-test.sh b/.openclaw/skills/srs-develop/scripts/proxy-e2e-transmux-test.sh new file mode 100755 index 000000000..523284e98 --- /dev/null +++ b/.openclaw/skills/srs-develop/scripts/proxy-e2e-transmux-test.sh @@ -0,0 +1,261 @@ +#!/bin/bash +# E2E test for RTMP-to-multiple-protocol transmuxing through the proxy: +# starts one proxy with memory load balancer + one SRS origin, publishes one +# RTMP stream, then verifies RTMP, HTTP-FLV, HLS, and optional WebRTC WHEP +# playback through the proxy. +set -e + +SCRIPT_DIR="$(cd -P "$(dirname "$0")" && pwd)" +# Navigate: scripts/ -> srs-develop/ -> skills/ -> .openclaw/ -> srs +WORKSPACE="$(cd -P "$SCRIPT_DIR/../../../.." && pwd)" + +if [[ ! -f "$WORKSPACE/go.mod" ]]; then + echo "Error: go.mod not found in WORKSPACE: $WORKSPACE" >&2 + exit 1 +fi + +# Ports — use the same high ports as proxy-e2e-test.sh. +# The proxy starts ALL servers, so we must assign unique ports for each. +PROXY_RTMP_PORT=11935 +PROXY_HTTP_API_PORT=11985 +PROXY_HTTP_SERVER_PORT=18080 +PROXY_WEBRTC_PORT=18000 +PROXY_SRT_PORT=20080 +PROXY_SYSTEM_API_PORT=12025 + +# Origin ports (from origin1-for-proxy.conf). +ORIGIN_RTMP_PORT=19351 +ORIGIN_HTTP_PORT=8081 +ORIGIN_API_PORT=19851 +ORIGIN_RTC_PORT=8001 +ORIGIN_SRT_PORT=10081 + +SOURCE_FLV="$WORKSPACE/trunk/doc/source.flv" +SRS_BINARY="$WORKSPACE/trunk/objs/srs" +SRS_TEST_BINARY="$WORKSPACE/trunk/3rdparty/srs-bench/objs/srs_test" +STREAM_URL="live/livestream" + +# WebRTC requires the srs-bench regression test binary. Keep it enabled by +# default because it exercises the proxy WHEP API path; set to "off" to skip. +PROXY_TRANSMUX_TEST_RTC="${PROXY_TRANSMUX_TEST_RTC:-on}" + +# PIDs to clean up on exit. +PROXY_PID="" +ORIGIN_PID="" +FFMPEG_PID="" + +cleanup() { + echo "" + echo "=== Cleaning up ===" + for pid in $PROXY_PID $ORIGIN_PID $FFMPEG_PID; do + if [[ -n "$pid" ]] && kill -0 "$pid" 2>/dev/null; then + kill "$pid" 2>/dev/null || true + fi + done + sleep 1 + for pid in $PROXY_PID $ORIGIN_PID $FFMPEG_PID; do + if [[ -n "$pid" ]] && kill -0 "$pid" 2>/dev/null; then + kill -9 "$pid" 2>/dev/null || true + fi + done + echo "Cleanup done." +} +trap cleanup EXIT + +probe_has_audio_video() { + local name="$1" + local url="$2" + + echo "Verifying $name playback: $url" + local output + output=$(ffprobe -v error -show_streams "$url" 2>&1 || true) + + if echo "$output" | grep -q "codec_type=video"; then + echo "PASS: $name video stream detected." + else + echo "FAIL: $name no video stream detected." >&2 + echo "ffprobe output:" >&2 + echo "$output" >&2 + exit 1 + fi + + if echo "$output" | grep -q "codec_type=audio"; then + echo "PASS: $name audio stream detected." + else + echo "FAIL: $name no audio stream detected." >&2 + echo "ffprobe output:" >&2 + echo "$output" >&2 + exit 1 + fi +} + +wait_for_hls_playlist() { + local url="$1" + local deadline=45 + + echo "Waiting for HLS playlist to be generated (up to ${deadline}s): $url" + for ((i = 1; i <= deadline; i++)); do + if curl -fsS "$url" 2>/dev/null | grep -q "#EXTM3U"; then + echo "HLS playlist is ready." + return + fi + sleep 1 + done + + echo "FAIL: HLS playlist was not generated in ${deadline}s." >&2 + echo "Last HLS response:" >&2 + curl -v "$url" 2>&1 || true + exit 1 +} + +echo "=== E2E RTMP Transmux Proxy Test ===" +echo "Workspace: $WORKSPACE" +echo "Stream: $STREAM_URL" +echo "" + +# --- Pre-checks --- +if [[ ! -f "$SOURCE_FLV" ]]; then + echo "Error: test source not found: $SOURCE_FLV" >&2 + exit 1 +fi +if ! command -v ffmpeg &>/dev/null; then + echo "Error: ffmpeg not found in PATH" >&2 + exit 1 +fi +if ! command -v ffprobe &>/dev/null; then + echo "Error: ffprobe not found in PATH" >&2 + exit 1 +fi +if ! command -v curl &>/dev/null; then + echo "Error: curl not found in PATH" >&2 + exit 1 +fi + +# --- Step 0: Clean up stale state --- +rm -f "$WORKSPACE/trunk/objs/origin1.pid" +ALL_PORTS="$PROXY_RTMP_PORT $PROXY_HTTP_API_PORT $PROXY_HTTP_SERVER_PORT $PROXY_WEBRTC_PORT $PROXY_SRT_PORT $PROXY_SYSTEM_API_PORT $ORIGIN_RTMP_PORT $ORIGIN_HTTP_PORT $ORIGIN_API_PORT $ORIGIN_RTC_PORT $ORIGIN_SRT_PORT" +for port in $ALL_PORTS; do + lsof -ti :"$port" 2>/dev/null | xargs kill 2>/dev/null || true +done +sleep 1 + +# --- Step 1: Build proxy --- +echo "=== Step 1: Building proxy ===" +cd "$WORKSPACE" +make -s 2>&1 +echo "Proxy built: $WORKSPACE/bin/srs-proxy" + +# --- Step 2: Build SRS origin (if not already built) --- +if [[ ! -f "$SRS_BINARY" ]]; then + echo "=== Step 2: Building SRS origin ===" + cd "$WORKSPACE/trunk" + ./configure && make 2>&1 | tail -3 + echo "SRS origin built: $SRS_BINARY" +else + echo "=== Step 2: SRS origin already built ===" +fi + +# --- Step 3: Build WebRTC regression tool (if enabled and needed) --- +if [[ "$PROXY_TRANSMUX_TEST_RTC" == "on" && ! -x "$SRS_TEST_BINARY" ]]; then + echo "=== Step 3: Building WebRTC regression tool ===" + cd "$WORKSPACE/trunk/3rdparty/srs-bench" + make ./objs/srs_test + echo "WebRTC regression tool built: $SRS_TEST_BINARY" +else + echo "=== Step 3: WebRTC regression tool build skipped ===" +fi + +# --- Step 4: Start proxy --- +echo "=== Step 4: Starting proxy (memory LB) ===" +cd "$WORKSPACE" +env PROXY_RTMP_SERVER=$PROXY_RTMP_PORT \ + PROXY_HTTP_API=$PROXY_HTTP_API_PORT \ + PROXY_HTTP_SERVER=$PROXY_HTTP_SERVER_PORT \ + PROXY_WEBRTC_SERVER=$PROXY_WEBRTC_PORT \ + PROXY_SRT_SERVER=$PROXY_SRT_PORT \ + PROXY_SYSTEM_API=$PROXY_SYSTEM_API_PORT \ + PROXY_LOAD_BALANCER_TYPE=memory \ + ./bin/srs-proxy >/tmp/srs-proxy-transmux-e2e.log 2>&1 & +PROXY_PID=$! +echo "Proxy PID: $PROXY_PID" +sleep 1 + +if ! kill -0 "$PROXY_PID" 2>/dev/null; then + echo "Error: proxy failed to start. Logs:" >&2 + cat /tmp/srs-proxy-transmux-e2e.log >&2 + exit 1 +fi +echo "Proxy started." + +# --- Step 5: Start SRS origin --- +echo "=== Step 5: Starting SRS origin ===" +ulimit -n 10000 2>/dev/null || true +cd "$WORKSPACE/trunk" +./objs/srs -c conf/origin1-for-proxy.conf >/tmp/srs-origin-transmux-e2e.log 2>&1 & +ORIGIN_PID=$! +echo "SRS origin PID: $ORIGIN_PID" + +# Wait for SRS to start and register with proxy (heartbeat interval is 9s). +echo "Waiting for SRS origin to register with proxy (up to 15s)..." +sleep 12 + +if ! kill -0 "$ORIGIN_PID" 2>/dev/null; then + echo "Error: SRS origin failed to start. Logs:" >&2 + cat /tmp/srs-origin-transmux-e2e.log >&2 + exit 1 +fi +echo "SRS origin started and registered." + +# --- Step 6: Publish RTMP stream --- +echo "=== Step 6: Publishing RTMP stream to proxy ===" +ffmpeg -stream_loop -1 -re -i "$SOURCE_FLV" -c copy -f flv \ + "rtmp://localhost:$PROXY_RTMP_PORT/$STREAM_URL" >/tmp/srs-ffmpeg-transmux-e2e.log 2>&1 & +FFMPEG_PID=$! +echo "FFmpeg publisher PID: $FFMPEG_PID" + +# Wait for stream to stabilize and for the origin to start muxing HTTP/HLS. +sleep 5 + +if ! kill -0 "$FFMPEG_PID" 2>/dev/null; then + echo "Error: FFmpeg publisher failed. Logs:" >&2 + cat /tmp/srs-ffmpeg-transmux-e2e.log >&2 + exit 1 +fi +echo "Stream publishing." + +# --- Step 7: Verify RTMP playback --- +echo "=== Step 7: Verifying RTMP playback via proxy ===" +probe_has_audio_video "RTMP" "rtmp://localhost:$PROXY_RTMP_PORT/$STREAM_URL" + +# --- Step 8: Verify HTTP-FLV playback --- +echo "=== Step 8: Verifying HTTP-FLV playback via proxy ===" +probe_has_audio_video "HTTP-FLV" "http://localhost:$PROXY_HTTP_SERVER_PORT/$STREAM_URL.flv" + +# --- Step 9: Verify HLS playback --- +echo "=== Step 9: Verifying HLS playback via proxy ===" +HLS_URL="http://localhost:$PROXY_HTTP_SERVER_PORT/$STREAM_URL.m3u8" +wait_for_hls_playlist "$HLS_URL" +probe_has_audio_video "HLS" "$HLS_URL" + +# --- Step 10: Verify WebRTC WHEP signaling via proxy --- +echo "=== Step 10: Verifying WebRTC WHEP signaling via proxy ===" +if [[ "$PROXY_TRANSMUX_TEST_RTC" == "on" ]]; then + if [[ ! -x "$SRS_TEST_BINARY" ]]; then + echo "FAIL: WebRTC regression tool not found: $SRS_TEST_BINARY" >&2 + exit 1 + fi + + cd "$WORKSPACE/trunk/3rdparty/srs-bench" + "$SRS_TEST_BINARY" \ + -test.run '^TestBugfix2371_RTMP2RTC_PlayWithNack$' \ + -srs-server "127.0.0.1:$PROXY_HTTP_API_PORT" \ + -srs-stream "/$STREAM_URL" \ + -srs-timeout 10000 + echo "PASS: WebRTC WHEP signaling succeeded." +else + echo "SKIP: WebRTC WHEP test disabled by PROXY_TRANSMUX_TEST_RTC=$PROXY_TRANSMUX_TEST_RTC." +fi + +echo "" +echo "NOTE: RTSP is not tested here because the Go proxy currently has no RTSP listener." +echo "=== E2E RTMP Transmux Proxy Test PASSED ===" diff --git a/internal/bootstrap/proxy.go b/internal/bootstrap/proxy.go index 5bd8fd7e1..f59522b6d 100644 --- a/internal/bootstrap/proxy.go +++ b/internal/bootstrap/proxy.go @@ -12,7 +12,7 @@ import ( "srsx/internal/errors" "srsx/internal/lb" "srsx/internal/logger" - "srsx/internal/protocol" + "srsx/internal/server" "srsx/internal/signal" "srsx/internal/version" ) @@ -99,46 +99,46 @@ func (b *proxyBootstrap) initializeLoadBalancer(ctx context.Context, environment // startServers initializes and starts all protocol servers. func (b *proxyBootstrap) startServers(ctx context.Context, environment env.ProxyEnvironment, gracefulQuitTimeout time.Duration) error { // Start the RTMP server. - srsRTMPServer := protocol.NewSRSRTMPServer(environment) - if err := srsRTMPServer.Run(ctx); err != nil { + rtmpServer := server.NewRTMPServer(environment) + if err := rtmpServer.Run(ctx); err != nil { return errors.Wrapf(err, "rtmp server") } - defer srsRTMPServer.Close() + defer rtmpServer.Close() // Start the WebRTC server. - srsWebRTCServer := protocol.NewSRSWebRTCServer(environment) - if err := srsWebRTCServer.Run(ctx); err != nil { + webRTCServer := server.NewWebRTCServer(environment) + if err := webRTCServer.Run(ctx); err != nil { return errors.Wrapf(err, "rtc server") } - defer srsWebRTCServer.Close() + defer webRTCServer.Close() // Start the HTTP API server. - srsHTTPAPIServer := protocol.NewSRSHTTPAPIServer(environment, gracefulQuitTimeout, srsWebRTCServer) - if err := srsHTTPAPIServer.Run(ctx); err != nil { + httpAPIServer := server.NewHTTPAPIServer(environment, gracefulQuitTimeout, webRTCServer) + if err := httpAPIServer.Run(ctx); err != nil { return errors.Wrapf(err, "http api server") } - defer srsHTTPAPIServer.Close() + defer httpAPIServer.Close() // Start the SRT server. - srsSRTServer := protocol.NewSRSSRTServer(environment) + srsSRTServer := server.NewSRSSRTServer(environment) if err := srsSRTServer.Run(ctx); err != nil { return errors.Wrapf(err, "srt server") } defer srsSRTServer.Close() // Start the System API server. - systemAPI := protocol.NewSystemAPI(environment, gracefulQuitTimeout) + systemAPI := server.NewSystemAPI(environment, gracefulQuitTimeout) if err := systemAPI.Run(ctx); err != nil { return errors.Wrapf(err, "system api server") } defer systemAPI.Close() // Start the HTTP web server. - srsHTTPStreamServer := protocol.NewSRSHTTPStreamServer(environment, gracefulQuitTimeout) - if err := srsHTTPStreamServer.Run(ctx); err != nil { + httpStreamServer := server.NewHTTPStreamServer(environment, gracefulQuitTimeout) + if err := httpStreamServer.Run(ctx); err != nil { return errors.Wrapf(err, "http server") } - defer srsHTTPStreamServer.Close() + defer httpStreamServer.Close() // Wait for the main loop to quit. <-ctx.Done() diff --git a/internal/rtmp/amf0.go b/internal/rtmp/amf0.go index 86a476308..7fd2c7a3d 100644 --- a/internal/rtmp/amf0.go +++ b/internal/rtmp/amf0.go @@ -95,7 +95,7 @@ var createBuffer = func() amf0Buffer { } // All AMF0 things. -type amf0Any interface { +type Amf0Any interface { // Binary marshaler and unmarshaler. encoding.BinaryUnmarshaler encoding.BinaryMarshaler @@ -106,59 +106,83 @@ type amf0Any interface { amf0Marker() amf0Marker } -type amf0Converter struct { - from amf0Any +type Amf0Converter interface { + ToNumber() Amf0Number + ToBoolean() Amf0Boolean + ToString() Amf0String + ToObject() Amf0Object + ToNull() Amf0Null + ToUndefined() Amf0Undefined + ToEcmaArray() Amf0EcmaArray + ToStrictArray() Amf0StrictArray } -func NewAmf0Converter(from amf0Any) *amf0Converter { +type amf0Converter struct { + from Amf0Any +} + +func NewAmf0Converter(from Amf0Any) Amf0Converter { return &amf0Converter{from: from} } -func (v *amf0Converter) ToNumber() *amf0Number { - return amf0AnyTo[*amf0Number](v.from) -} - -func (v *amf0Converter) ToBoolean() *amf0Boolean { - return amf0AnyTo[*amf0Boolean](v.from) -} - -func (v *amf0Converter) ToString() *amf0String { - return amf0AnyTo[*amf0String](v.from) -} - -func (v *amf0Converter) ToObject() *amf0Object { - return amf0AnyTo[*amf0Object](v.from) -} - -func (v *amf0Converter) ToNull() *amf0Null { - return amf0AnyTo[*amf0Null](v.from) -} - -func (v *amf0Converter) ToUndefined() *amf0Undefined { - return amf0AnyTo[*amf0Undefined](v.from) -} - -func (v *amf0Converter) ToEcmaArray() *amf0EcmaArray { - return amf0AnyTo[*amf0EcmaArray](v.from) -} - -func (v *amf0Converter) ToStrictArray() *amf0StrictArray { - return amf0AnyTo[*amf0StrictArray](v.from) -} - -// Convert any to specified object. -func amf0AnyTo[T amf0Any](a amf0Any) T { - var to T - if a != nil { - if v, ok := a.(T); ok { - return v - } +func (v *amf0Converter) ToNumber() Amf0Number { + if r, ok := v.from.(Amf0Number); ok { + return r } - return to + return nil +} + +func (v *amf0Converter) ToBoolean() Amf0Boolean { + if r, ok := v.from.(Amf0Boolean); ok { + return r + } + return nil +} + +func (v *amf0Converter) ToString() Amf0String { + if r, ok := v.from.(Amf0String); ok { + return r + } + return nil +} + +func (v *amf0Converter) ToObject() Amf0Object { + if r, ok := v.from.(Amf0Object); ok { + return r + } + return nil +} + +func (v *amf0Converter) ToNull() Amf0Null { + if r, ok := v.from.(Amf0Null); ok { + return r + } + return nil +} + +func (v *amf0Converter) ToUndefined() Amf0Undefined { + if r, ok := v.from.(Amf0Undefined); ok { + return r + } + return nil +} + +func (v *amf0Converter) ToEcmaArray() Amf0EcmaArray { + if r, ok := v.from.(Amf0EcmaArray); ok { + return r + } + return nil +} + +func (v *amf0Converter) ToStrictArray() Amf0StrictArray { + if r, ok := v.from.(Amf0StrictArray); ok { + return r + } + return nil } // Discovery the amf0 object from the bytes b. -func Amf0Discovery(p []byte) (a amf0Any, err error) { +func Amf0Discovery(p []byte) (a Amf0Any, err error) { if len(p) < 1 { return nil, errors.Errorf("require 1 bytes only %v", len(p)) } @@ -228,14 +252,24 @@ func (v *amf0UTF8) MarshalBinary() (data []byte, err error) { return } +// Amf0Number is the AMF0 number type. +type Amf0Number interface { + Amf0Any + Float64() float64 +} + // The number object, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.2 Number Type type amf0Number float64 -func NewAmf0Number(f float64) *amf0Number { +func NewAmf0Number(f float64) Amf0Number { v := amf0Number(f) return &v } +func (v *amf0Number) Float64() float64 { + return float64(*v) +} + func (v *amf0Number) amf0Marker() amf0Marker { return amf0MarkerNumber } @@ -266,14 +300,28 @@ func (v *amf0Number) MarshalBinary() (data []byte, err error) { return } +// Amf0String is the AMF0 string type. +type Amf0String interface { + Amf0Any + String() string +} + // The string objet, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.4 String Type type amf0String string -func NewAmf0String(s string) *amf0String { +func NewAmf0String(s string) Amf0String { + return newAmf0String(s) +} + +func newAmf0String(s string) *amf0String { v := amf0String(s) return &v } +func (v *amf0String) String() string { + return string(*v) +} + func (v *amf0String) amf0Marker() amf0Marker { return amf0MarkerString } @@ -344,7 +392,7 @@ func (v *amf0ObjectEOF) MarshalBinary() (data []byte, err error) { // Use array for object and ecma array, to keep the original order. type amf0Property struct { key amf0UTF8 - value amf0Any + value Amf0Any } // The object-like AMF0 structure, like object and ecma array and strict array. @@ -367,7 +415,7 @@ func (v *amf0ObjectBase) Size() int { return size } -func (v *amf0ObjectBase) Get(key string) amf0Any { +func (v *amf0ObjectBase) Get(key string) Amf0Any { v.lock.Lock() defer v.lock.Unlock() @@ -380,7 +428,7 @@ func (v *amf0ObjectBase) Get(key string) amf0Any { return nil } -func (v *amf0ObjectBase) Set(key string, value amf0Any) *amf0ObjectBase { +func (v *amf0ObjectBase) Set(key string, value Amf0Any) *amf0ObjectBase { v.lock.Lock() defer v.lock.Unlock() @@ -411,21 +459,21 @@ func (v *amf0ObjectBase) unmarshal(p []byte, eof bool, maxElems int) (err error) return errors.Errorf("maxElems=%v with eof", maxElems) } - readOne := func() (amf0UTF8, amf0Any, error) { + readOne := func() (amf0UTF8, Amf0Any, error) { var u amf0UTF8 if err = u.UnmarshalBinary(p); err != nil { return "", nil, errors.WithMessage(err, "prop name") } p = p[u.Size():] - var a amf0Any + var a Amf0Any if a, err = Amf0Discovery(p); err != nil { return "", nil, errors.WithMessage(err, fmt.Sprintf("discover prop %v", string(u))) } return u, a, nil } - pushOne := func(u amf0UTF8, a amf0Any) error { + pushOne := func(u amf0UTF8, a Amf0Any) error { // For object property, consume the whole bytes. if err = a.UnmarshalBinary(p); err != nil { return errors.WithMessage(err, fmt.Sprintf("unmarshal prop %v", string(u))) @@ -494,13 +542,24 @@ func (v *amf0ObjectBase) marshal(b amf0Buffer) (err error) { return } +// Amf0Object is the AMF0 object type. +type Amf0Object interface { + Amf0Any + Get(key string) Amf0Any + Set(key string, value Amf0Any) Amf0Object +} + // The AMF0 object, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.5 Object Type type amf0Object struct { amf0ObjectBase eof amf0ObjectEOF } -func NewAmf0Object() *amf0Object { +func NewAmf0Object() Amf0Object { + return newAmf0Object() +} + +func newAmf0Object() *amf0Object { v := &amf0Object{} v.properties = []*amf0Property{} return v @@ -510,6 +569,15 @@ func (v *amf0Object) amf0Marker() amf0Marker { return amf0MarkerObject } +func (v *amf0Object) Get(key string) Amf0Any { + return v.amf0ObjectBase.Get(key) +} + +func (v *amf0Object) Set(key string, value Amf0Any) Amf0Object { + v.amf0ObjectBase.Set(key, value) + return v +} + func (v *amf0Object) Size() int { return int(1) + v.eof.Size() + v.amf0ObjectBase.Size() } @@ -542,17 +610,22 @@ func (v *amf0Object) MarshalBinary() (data []byte, err error) { return nil, errors.WithMessage(err, "marshal") } - var pb []byte - if pb, err = v.eof.MarshalBinary(); err != nil { + if pb, err := v.eof.MarshalBinary(); err != nil { return nil, errors.WithMessage(err, "marshal") - } - if _, err = b.Write(pb); err != nil { + } else if _, err = b.Write(pb); err != nil { return nil, errors.Wrap(err, "marshal") } return b.Bytes(), nil } +// Amf0EcmaArray is the AMF0 ECMA array type. +type Amf0EcmaArray interface { + Amf0Any + Get(key string) Amf0Any + Set(key string, value Amf0Any) Amf0EcmaArray +} + // The AMF0 ecma array, please read @doc amf0_spec_121207.pdf, @page 6, @section 2.10 ECMA Array Type type amf0EcmaArray struct { amf0ObjectBase @@ -560,7 +633,11 @@ type amf0EcmaArray struct { eof amf0ObjectEOF } -func NewAmf0EcmaArray() *amf0EcmaArray { +func NewAmf0EcmaArray() Amf0EcmaArray { + return newAmf0EcmaArray() +} + +func newAmf0EcmaArray() *amf0EcmaArray { v := &amf0EcmaArray{} v.properties = []*amf0Property{} return v @@ -570,6 +647,15 @@ func (v *amf0EcmaArray) amf0Marker() amf0Marker { return amf0MarkerEcmaArray } +func (v *amf0EcmaArray) Get(key string) Amf0Any { + return v.amf0ObjectBase.Get(key) +} + +func (v *amf0EcmaArray) Set(key string, value Amf0Any) Amf0EcmaArray { + v.amf0ObjectBase.Set(key, value) + return v +} + func (v *amf0EcmaArray) Size() int { return int(1) + 4 + v.eof.Size() + v.amf0ObjectBase.Size() } @@ -606,24 +692,29 @@ func (v *amf0EcmaArray) MarshalBinary() (data []byte, err error) { return nil, errors.WithMessage(err, "marshal") } - var pb []byte - if pb, err = v.eof.MarshalBinary(); err != nil { + if pb, err := v.eof.MarshalBinary(); err != nil { return nil, errors.WithMessage(err, "marshal") - } - if _, err = b.Write(pb); err != nil { + } else if _, err = b.Write(pb); err != nil { return nil, errors.Wrap(err, "marshal") } return b.Bytes(), nil } +// Amf0StrictArray is the AMF0 strict array type. +type Amf0StrictArray interface { + Amf0Any + Get(key string) Amf0Any + Set(key string, value Amf0Any) Amf0StrictArray +} + // The AMF0 strict array, please read @doc amf0_spec_121207.pdf, @page 7, @section 2.12 Strict Array Type type amf0StrictArray struct { amf0ObjectBase count uint32 } -func NewAmf0StrictArray() *amf0StrictArray { +func NewAmf0StrictArray() Amf0StrictArray { v := &amf0StrictArray{} v.properties = []*amf0Property{} return v @@ -633,6 +724,15 @@ func (v *amf0StrictArray) amf0Marker() amf0Marker { return amf0MarkerStrictArray } +func (v *amf0StrictArray) Get(key string) Amf0Any { + return v.amf0ObjectBase.Get(key) +} + +func (v *amf0StrictArray) Set(key string, value Amf0Any) Amf0StrictArray { + v.amf0ObjectBase.Set(key, value) + return v +} + func (v *amf0StrictArray) Size() int { return int(1) + 4 + v.amf0ObjectBase.Size() } @@ -708,36 +808,56 @@ func (v *amf0SingleMarkerObject) MarshalBinary() (data []byte, err error) { return []byte{byte(v.target)}, nil } +// Amf0Null is the AMF0 null type. +type Amf0Null interface { + Amf0Any +} + // The AMF0 null, please read @doc amf0_spec_121207.pdf, @page 6, @section 2.7 null Type type amf0Null struct { amf0SingleMarkerObject } -func NewAmf0Null() *amf0Null { +func NewAmf0Null() Amf0Null { v := amf0Null{} v.amf0SingleMarkerObject = newAmf0SingleMarkerObject(amf0MarkerNull) return &v } +// Amf0Undefined is the AMF0 undefined type. +type Amf0Undefined interface { + Amf0Any +} + // The AMF0 undefined, please read @doc amf0_spec_121207.pdf, @page 6, @section 2.8 undefined Type type amf0Undefined struct { amf0SingleMarkerObject } -func NewAmf0Undefined() amf0Any { +func NewAmf0Undefined() Amf0Undefined { v := amf0Undefined{} v.amf0SingleMarkerObject = newAmf0SingleMarkerObject(amf0MarkerUndefined) return &v } +// Amf0Boolean is the public typed view of an AMF0 boolean. +type Amf0Boolean interface { + Amf0Any + Bool() bool +} + // The AMF0 boolean, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.3 Boolean Type type amf0Boolean bool -func NewAmf0Boolean(b bool) amf0Any { +func NewAmf0Boolean(b bool) Amf0Boolean { v := amf0Boolean(b) return &v } +func (v *amf0Boolean) Bool() bool { + return bool(*v) +} + func (v *amf0Boolean) amf0Marker() amf0Marker { return amf0MarkerBoolean } diff --git a/internal/rtmp/amf0_test.go b/internal/rtmp/amf0_test.go new file mode 100644 index 000000000..a2c240360 --- /dev/null +++ b/internal/rtmp/amf0_test.go @@ -0,0 +1,509 @@ +// Copyright (c) 2026 Winlin +// +// SPDX-License-Identifier: MIT +package rtmp + +import ( + "bytes" + "fmt" + "math" + "strings" + "testing" +) + +func TestAmf0MarkerString(t *testing.T) { + for _, tt := range []struct { + marker amf0Marker + want string + }{ + {amf0MarkerNumber, "Amf0Number"}, + {amf0MarkerBoolean, "amf0Boolean"}, + {amf0MarkerString, "Amf0String"}, + {amf0MarkerObject, "Amf0Object"}, + {amf0MarkerMovieClip, "MovieClip"}, + {amf0MarkerNull, "Null"}, + {amf0MarkerUndefined, "Undefined"}, + {amf0MarkerReference, "Reference"}, + {amf0MarkerEcmaArray, "EcmaArray"}, + {amf0MarkerObjectEnd, "ObjectEnd"}, + {amf0MarkerStrictArray, "StrictArray"}, + {amf0MarkerDate, "Date"}, + {amf0MarkerLongString, "LongString"}, + {amf0MarkerUnsupported, "Unsupported"}, + {amf0MarkerRecordSet, "RecordSet"}, + {amf0MarkerXmlDocument, "XmlDocument"}, + {amf0MarkerTypedObject, "TypedObject"}, + {amf0MarkerAvmPlusObject, "AvmPlusObject"}, + {amf0MarkerForbidden, "Forbidden"}, + {amf0Marker(0xee), "Forbidden"}, + } { + if got := tt.marker.String(); got != tt.want { + t.Fatalf("marker=%#x String()=%v, want %v", byte(tt.marker), got, tt.want) + } + } +} + +func TestAmf0Discovery(t *testing.T) { + for _, tt := range []struct { + name string + data []byte + ok func(Amf0Any) bool + }{ + {"number", []byte{byte(amf0MarkerNumber)}, func(v Amf0Any) bool { _, ok := v.(Amf0Number); return ok }}, + {"boolean", []byte{byte(amf0MarkerBoolean)}, func(v Amf0Any) bool { _, ok := v.(Amf0Boolean); return ok }}, + {"string", []byte{byte(amf0MarkerString)}, func(v Amf0Any) bool { _, ok := v.(Amf0String); return ok }}, + {"object", []byte{byte(amf0MarkerObject)}, func(v Amf0Any) bool { _, ok := v.(Amf0Object); return ok }}, + {"null", []byte{byte(amf0MarkerNull)}, func(v Amf0Any) bool { _, ok := v.(Amf0Null); return ok }}, + {"undefined", []byte{byte(amf0MarkerUndefined)}, func(v Amf0Any) bool { _, ok := v.(Amf0Undefined); return ok }}, + {"ecma-array", []byte{byte(amf0MarkerEcmaArray)}, func(v Amf0Any) bool { _, ok := v.(Amf0EcmaArray); return ok }}, + {"object-end", []byte{byte(amf0MarkerObjectEnd)}, func(v Amf0Any) bool { _, ok := v.(*amf0ObjectEOF); return ok }}, + {"strict-array", []byte{byte(amf0MarkerStrictArray)}, func(v Amf0Any) bool { _, ok := v.(Amf0StrictArray); return ok }}, + } { + t.Run(tt.name, func(t *testing.T) { + value, err := Amf0Discovery(tt.data) + if err != nil { + t.Fatalf("Amf0Discovery() err=%v", err) + } + if !tt.ok(value) { + t.Fatalf("Amf0Discovery()=%T", value) + } + }) + } + + for _, data := range [][]byte{{}, {byte(amf0MarkerReference)}, {byte(amf0MarkerDate)}, {byte(amf0MarkerForbidden)}} { + if value, err := Amf0Discovery(data); err == nil || value != nil { + t.Fatalf("Amf0Discovery(%v) value=%T, err=%v, want error", data, value, err) + } + } +} + +func TestAmf0Converter(t *testing.T) { + values := []struct { + name string + in Amf0Any + ok func(Amf0Converter) bool + }{ + {"number", NewAmf0Number(1), func(c Amf0Converter) bool { return c.ToNumber() != nil }}, + {"boolean", NewAmf0Boolean(true), func(c Amf0Converter) bool { return c.ToBoolean() != nil }}, + {"string", NewAmf0String("v"), func(c Amf0Converter) bool { return c.ToString() != nil }}, + {"object", NewAmf0Object(), func(c Amf0Converter) bool { return c.ToObject() != nil }}, + {"null", NewAmf0Null(), func(c Amf0Converter) bool { return c.ToNull() != nil }}, + {"undefined", NewAmf0Undefined(), func(c Amf0Converter) bool { return c.ToUndefined() != nil }}, + {"ecma-array", NewAmf0EcmaArray(), func(c Amf0Converter) bool { return c.ToEcmaArray() != nil }}, + {"strict-array", NewAmf0StrictArray(), func(c Amf0Converter) bool { return c.ToStrictArray() != nil }}, + } + + for _, tt := range values { + t.Run(tt.name, func(t *testing.T) { + converter := NewAmf0Converter(tt.in) + if !tt.ok(converter) { + t.Fatalf("expected successful conversion for %T", tt.in) + } + }) + } + + nilConverter := NewAmf0Converter(nil) + if nilConverter.ToNumber() != nil || nilConverter.ToBoolean() != nil || nilConverter.ToString() != nil || + nilConverter.ToObject() != nil || nilConverter.ToNull() != nil || nilConverter.ToUndefined() != nil || + nilConverter.ToEcmaArray() != nil || nilConverter.ToStrictArray() != nil { + t.Fatal("nil converter should not convert") + } +} + +func TestAmf0UTF8(t *testing.T) { + var value amf0UTF8 = "hello" + b, err := value.MarshalBinary() + if err != nil { + t.Fatalf("MarshalBinary() err=%v", err) + } + if value.Size() != len(b) { + t.Fatalf("Size()=%v, len=%v", value.Size(), len(b)) + } + + var decoded amf0UTF8 + if err := decoded.UnmarshalBinary(b); err != nil { + t.Fatalf("UnmarshalBinary() err=%v", err) + } + if decoded != value { + t.Fatalf("decoded=%v, want %v", decoded, value) + } + + for _, data := range [][]byte{{0x00}, {0x00, 0x05, 'h'}} { + if err := decoded.UnmarshalBinary(data); err == nil { + t.Fatalf("UnmarshalBinary(%v) should fail", data) + } + } +} + +func TestAmf0Number(t *testing.T) { + number := NewAmf0Number(math.Pi) + if number.Size() != 9 || number.(*amf0Number).amf0Marker() != amf0MarkerNumber { + t.Fatalf("unexpected number metadata") + } + + b, err := number.MarshalBinary() + if err != nil { + t.Fatalf("MarshalBinary() err=%v", err) + } + + decoded := NewAmf0Number(0) + if err := decoded.UnmarshalBinary(b); err != nil { + t.Fatalf("UnmarshalBinary() err=%v", err) + } + if got := decoded.Float64(); got != math.Pi { + t.Fatalf("Float64()=%v, want %v", got, math.Pi) + } + + for _, data := range [][]byte{{byte(amf0MarkerNumber)}, append([]byte{byte(amf0MarkerString)}, b[1:]...)} { + if err := decoded.UnmarshalBinary(data); err == nil { + t.Fatalf("UnmarshalBinary(%v) should fail", data) + } + } +} + +func TestAmf0Boolean(t *testing.T) { + for _, want := range []bool{false, true} { + boolean := NewAmf0Boolean(want) + if boolean.Size() != 2 || boolean.(*amf0Boolean).amf0Marker() != amf0MarkerBoolean { + t.Fatalf("unexpected boolean metadata") + } + + b, err := boolean.MarshalBinary() + if err != nil { + t.Fatalf("MarshalBinary() err=%v", err) + } + + decoded := NewAmf0Boolean(!want) + if err := decoded.UnmarshalBinary(b); err != nil { + t.Fatalf("UnmarshalBinary() err=%v", err) + } + if got := decoded.Bool(); got != want { + t.Fatalf("Bool()=%v, want %v", got, want) + } + } + + decoded := NewAmf0Boolean(false) + for _, data := range [][]byte{{byte(amf0MarkerBoolean)}, {byte(amf0MarkerNumber), 1}} { + if err := decoded.UnmarshalBinary(data); err == nil { + t.Fatalf("UnmarshalBinary(%v) should fail", data) + } + } +} + +func TestAmf0String(t *testing.T) { + value := NewAmf0String("hello") + if value.Size() != 8 || value.(*amf0String).amf0Marker() != amf0MarkerString { + t.Fatalf("unexpected string metadata") + } + + b, err := value.MarshalBinary() + if err != nil { + t.Fatalf("MarshalBinary() err=%v", err) + } + + decoded := NewAmf0String("") + if err := decoded.UnmarshalBinary(b); err != nil { + t.Fatalf("UnmarshalBinary() err=%v", err) + } + if got := decoded.String(); got != "hello" { + t.Fatalf("String()=%v, want hello", got) + } + + for _, data := range [][]byte{{}, {byte(amf0MarkerNumber), 0, 0}, {byte(amf0MarkerString), 0, 5, 'h'}} { + if err := decoded.UnmarshalBinary(data); err == nil { + t.Fatalf("UnmarshalBinary(%v) should fail", data) + } + } +} + +func TestAmf0ObjectEOF(t *testing.T) { + eof := &amf0ObjectEOF{} + if eof.Size() != 3 || eof.amf0Marker() != amf0MarkerObjectEnd { + t.Fatalf("unexpected eof metadata") + } + + b, err := eof.MarshalBinary() + if err != nil { + t.Fatalf("MarshalBinary() err=%v", err) + } + if !bytes.Equal(b, []byte{0, 0, 9}) { + t.Fatalf("MarshalBinary()=%v", b) + } + for _, data := range [][]byte{b, {0, 0, 9, 1}} { + if err := eof.UnmarshalBinary(data); err != nil { + t.Fatalf("UnmarshalBinary(%v) err=%v", data, err) + } + } + for _, data := range [][]byte{{0, 0}, {0, 1, 9}} { + if err := eof.UnmarshalBinary(data); err == nil { + t.Fatalf("UnmarshalBinary(%v) should fail", data) + } + } +} + +func TestAmf0Object(t *testing.T) { + object := NewAmf0Object(). + Set("name", NewAmf0String("stream")). + Set("code", NewAmf0Number(100)). + Set("ok", NewAmf0Boolean(true)) + object.Set("code", NewAmf0Number(200)) + + if object.(*amf0Object).amf0Marker() != amf0MarkerObject || object.Size() == 0 { + t.Fatalf("unexpected object metadata") + } + if object.Get("missing") != nil { + t.Fatal("missing property should be nil") + } + + b, err := object.MarshalBinary() + if err != nil { + t.Fatalf("MarshalBinary() err=%v", err) + } + + decoded := NewAmf0Object() + if err := decoded.UnmarshalBinary(b); err != nil { + t.Fatalf("UnmarshalBinary() err=%v", err) + } + if got := NewAmf0Converter(decoded.Get("name")).ToString().String(); got != "stream" { + t.Fatalf("name=%v", got) + } + if got := NewAmf0Converter(decoded.Get("code")).ToNumber().Float64(); got != 200 { + t.Fatalf("code=%v", got) + } + if got := NewAmf0Converter(decoded.Get("ok")).ToBoolean().Bool(); !got { + t.Fatalf("ok=%v", got) + } + + for _, data := range [][]byte{{}, {byte(amf0MarkerString)}, {byte(amf0MarkerObject), 0, 4, 'n'}} { + if err := decoded.UnmarshalBinary(data); err == nil { + t.Fatalf("UnmarshalBinary(%v) should fail", data) + } + } + + base := &amf0ObjectBase{} + if err := base.unmarshal(nil, false, -1); err == nil { + t.Fatal("unmarshal without eof and negative maxElems should fail") + } + if err := base.unmarshal(nil, true, 0); err == nil { + t.Fatal("unmarshal with eof and non-negative maxElems should fail") + } +} + +func TestAmf0EcmaArray(t *testing.T) { + array := NewAmf0EcmaArray(). + Set("name", NewAmf0String("stream")). + Set("code", NewAmf0Number(100)) + + if array.(*amf0EcmaArray).amf0Marker() != amf0MarkerEcmaArray || array.Size() == 0 { + t.Fatalf("unexpected ecma array metadata") + } + if array.Get("missing") != nil { + t.Fatal("missing property should be nil") + } + + b, err := array.MarshalBinary() + if err != nil { + t.Fatalf("MarshalBinary() err=%v", err) + } + + decoded := NewAmf0EcmaArray() + if err := decoded.UnmarshalBinary(b); err != nil { + t.Fatalf("UnmarshalBinary() err=%v", err) + } + if got := NewAmf0Converter(decoded.Get("name")).ToString().String(); got != "stream" { + t.Fatalf("name=%v", got) + } + if got := NewAmf0Converter(decoded.Get("code")).ToNumber().Float64(); got != 100 { + t.Fatalf("code=%v", got) + } + + for _, data := range [][]byte{{}, {byte(amf0MarkerEcmaArray), 0}, {byte(amf0MarkerString), 0, 0, 0, 0}, {byte(amf0MarkerEcmaArray), 0, 0, 0, 0, 0, 4, 'n'}} { + if err := decoded.UnmarshalBinary(data); err == nil { + t.Fatalf("UnmarshalBinary(%v) should fail", data) + } + } +} + +func TestAmf0StrictArray(t *testing.T) { + array := NewAmf0StrictArray(). + Set("name", NewAmf0String("stream")). + Set("code", NewAmf0Number(100)) + array.(*amf0StrictArray).count = 2 + + if array.(*amf0StrictArray).amf0Marker() != amf0MarkerStrictArray || array.Size() == 0 { + t.Fatalf("unexpected strict array metadata") + } + if array.Get("missing") != nil { + t.Fatal("missing property should be nil") + } + + b, err := array.MarshalBinary() + if err != nil { + t.Fatalf("MarshalBinary() err=%v", err) + } + + decoded := NewAmf0StrictArray() + if err := decoded.UnmarshalBinary(b); err != nil { + t.Fatalf("UnmarshalBinary() err=%v", err) + } + if got := NewAmf0Converter(decoded.Get("name")).ToString().String(); got != "stream" { + t.Fatalf("name=%v", got) + } + if got := NewAmf0Converter(decoded.Get("code")).ToNumber().Float64(); got != 100 { + t.Fatalf("code=%v", got) + } + + empty := append([]byte{byte(amf0MarkerStrictArray)}, 0, 0, 0, 0) + if err := decoded.UnmarshalBinary(empty); err != nil { + t.Fatalf("UnmarshalBinary(empty) err=%v", err) + } + for _, data := range [][]byte{{}, {byte(amf0MarkerStrictArray), 0}, {byte(amf0MarkerString), 0, 0, 0, 0}, {byte(amf0MarkerStrictArray), 0, 0, 0, 1, 0, 4, 'n'}} { + if err := NewAmf0StrictArray().UnmarshalBinary(data); err == nil { + t.Fatalf("UnmarshalBinary(%v) should fail", data) + } + } +} + +func TestAmf0SingleMarkerObjects(t *testing.T) { + for _, tt := range []struct { + name string + value Amf0Any + marker amf0Marker + }{ + {"null", NewAmf0Null(), amf0MarkerNull}, + {"undefined", NewAmf0Undefined(), amf0MarkerUndefined}, + } { + t.Run(tt.name, func(t *testing.T) { + if tt.value.Size() != 1 || tt.value.amf0Marker() != tt.marker { + t.Fatalf("unexpected metadata") + } + b, err := tt.value.MarshalBinary() + if err != nil { + t.Fatalf("MarshalBinary() err=%v", err) + } + if err := tt.value.UnmarshalBinary(b); err != nil { + t.Fatalf("UnmarshalBinary() err=%v", err) + } + for _, data := range [][]byte{{}, {byte(amf0MarkerString)}} { + if err := tt.value.UnmarshalBinary(data); err == nil { + t.Fatalf("UnmarshalBinary(%v) should fail", data) + } + } + }) + } +} + +type errorAmf0Buffer struct { + writeByteErr bool + writeErr bool +} + +func (v *errorAmf0Buffer) Bytes() []byte { + return nil +} + +func (v *errorAmf0Buffer) WriteByte(byte) error { + if v.writeByteErr { + return fmt.Errorf("write byte") + } + return nil +} + +func (v *errorAmf0Buffer) Write([]byte) (int, error) { + if v.writeErr { + return 0, fmt.Errorf("write") + } + return 0, nil +} + +type errorAmf0Any struct { + Amf0Any +} + +func (v *errorAmf0Any) Size() int { + return 1 +} + +func (v *errorAmf0Any) MarshalBinary() ([]byte, error) { + return nil, fmt.Errorf("marshal") +} + +func (v *errorAmf0Any) UnmarshalBinary([]byte) error { + return nil +} + +func (v *errorAmf0Any) amf0Marker() amf0Marker { + return amf0MarkerNumber +} + +func TestAmf0MarshalErrors(t *testing.T) { + originalCreateBuffer := createBuffer + defer func() { createBuffer = originalCreateBuffer }() + + for _, tt := range []struct { + name string + make func() Amf0Any + }{ + {"object", func() Amf0Any { return NewAmf0Object() }}, + {"ecma-array", func() Amf0Any { return NewAmf0EcmaArray() }}, + {"strict-array", func() Amf0Any { return NewAmf0StrictArray() }}, + } { + t.Run(tt.name+" write-byte", func(t *testing.T) { + createBuffer = func() amf0Buffer { return &errorAmf0Buffer{writeByteErr: true} } + if _, err := tt.make().MarshalBinary(); err == nil { + t.Fatal("MarshalBinary() should fail") + } + }) + + t.Run(tt.name+" write-prop", func(t *testing.T) { + createBuffer = func() amf0Buffer { return &errorAmf0Buffer{writeErr: true} } + value := tt.make() + switch v := value.(type) { + case Amf0Object: + v.Set("name", NewAmf0String("stream")) + case Amf0EcmaArray: + v.Set("name", NewAmf0String("stream")) + case Amf0StrictArray: + v.Set("name", NewAmf0String("stream")) + v.(*amf0StrictArray).count = 1 + } + if _, err := value.MarshalBinary(); err == nil { + t.Fatal("MarshalBinary() should fail") + } + }) + } + + createBuffer = originalCreateBuffer + for _, tt := range []struct { + name string + make func() Amf0Any + }{ + {"object", func() Amf0Any { return NewAmf0Object().Set("bad", &errorAmf0Any{}) }}, + {"ecma-array", func() Amf0Any { return NewAmf0EcmaArray().Set("bad", &errorAmf0Any{}) }}, + {"strict-array", func() Amf0Any { + value := NewAmf0StrictArray().Set("bad", &errorAmf0Any{}) + value.(*amf0StrictArray).count = 1 + return value + }}, + } { + t.Run(tt.name+" marshal-value", func(t *testing.T) { + if _, err := tt.make().MarshalBinary(); err == nil { + t.Fatal("MarshalBinary() should fail") + } + }) + } +} + +func TestAmf0UnmarshalNestedErrors(t *testing.T) { + // Object property with unsupported marker. + data := []byte{byte(amf0MarkerObject), 0, 3, 'b', 'a', 'd', byte(amf0MarkerDate)} + if err := NewAmf0Object().UnmarshalBinary(data); err == nil || !strings.Contains(err.Error(), "discover prop bad") { + t.Fatalf("err=%v, want discover prop bad", err) + } + + // Object property with invalid payload size. + data = []byte{byte(amf0MarkerObject), 0, 3, 'b', 'a', 'd', byte(amf0MarkerNumber), 0} + if err := NewAmf0Object().UnmarshalBinary(data); err == nil || !strings.Contains(err.Error(), "unmarshal prop bad") { + t.Fatalf("err=%v, want unmarshal prop bad", err) + } +} diff --git a/internal/rtmp/example_test.go b/internal/rtmp/example_test.go new file mode 100644 index 000000000..4cc299d24 --- /dev/null +++ b/internal/rtmp/example_test.go @@ -0,0 +1,313 @@ +// Copyright (c) 2026 Winlin +// +// SPDX-License-Identifier: MIT +package rtmp_test + +import ( + "bytes" + "context" + "fmt" + "net" + "time" + + "srsx/internal/rtmp" +) + +func ExampleAmf0Number() { + number := rtmp.NewAmf0Number(3.14) + b, err := number.MarshalBinary() + if err != nil { + panic(err) + } + + value, err := rtmp.Amf0Discovery(b) + if err != nil { + panic(err) + } + if err := value.UnmarshalBinary(b); err != nil { + panic(err) + } + + converter := rtmp.NewAmf0Converter(value) + fmt.Println("number:", converter.ToNumber().Float64()) + fmt.Println("is string:", converter.ToString() != nil) + + // Output: + // number: 3.14 + // is string: false +} + +func ExampleAmf0Object() { + object := rtmp.NewAmf0Object(). + Set("code", rtmp.NewAmf0Number(100)). + Set("level", rtmp.NewAmf0String("status")) + b, err := object.MarshalBinary() + if err != nil { + panic(err) + } + + value, err := rtmp.Amf0Discovery(b) + if err != nil { + panic(err) + } + if err := value.UnmarshalBinary(b); err != nil { + panic(err) + } + + converter := rtmp.NewAmf0Converter(value) + fmt.Println("code:", rtmp.NewAmf0Converter(converter.ToObject().Get("code")).ToNumber().Float64()) + fmt.Println("level:", rtmp.NewAmf0Converter(converter.ToObject().Get("level")).ToString().String()) + fmt.Println("is number:", converter.ToNumber() != nil) + + // Output: + // code: 100 + // level: status + // is number: false +} + +func ExampleNewHandshake() { + client := rtmp.NewHandshake() + server := rtmp.NewHandshake() + + var clientToServer bytes.Buffer + if err := client.WriteC0S0(&clientToServer); err != nil { + panic(err) + } + if err := client.WriteC1S1(&clientToServer); err != nil { + panic(err) + } + + c0, err := server.ReadC0S0(&clientToServer) + if err != nil { + panic(err) + } + c1, err := server.ReadC1S1(&clientToServer) + if err != nil { + panic(err) + } + + var serverToClient bytes.Buffer + if err := server.WriteC0S0(&serverToClient); err != nil { + panic(err) + } + if err := server.WriteC1S1(&serverToClient); err != nil { + panic(err) + } + if err := server.WriteC2S2(&serverToClient, c1); err != nil { + panic(err) + } + + s0, err := client.ReadC0S0(&serverToClient) + if err != nil { + panic(err) + } + s1, err := client.ReadC1S1(&serverToClient) + if err != nil { + panic(err) + } + s2, err := client.ReadC2S2(&serverToClient) + if err != nil { + panic(err) + } + + if err := client.WriteC2S2(&clientToServer, s1); err != nil { + panic(err) + } + c2, err := server.ReadC2S2(&clientToServer) + if err != nil { + panic(err) + } + + fmt.Println("client version:", c0[0]) + fmt.Println("server version:", s0[0]) + fmt.Println("c1 bytes:", len(c1)) + fmt.Println("s1 bytes:", len(s1)) + fmt.Println("s2 echoes c1:", bytes.Equal(s2, c1)) + fmt.Println("c2 echoes s1:", bytes.Equal(c2, s1)) + fmt.Println("server cached c1:", bytes.Equal(server.C1S1(), c1)) + + // Output: + // client version: 3 + // server version: 3 + // c1 bytes: 1536 + // s1 bytes: 1536 + // s2 echoes c1: true + // c2 echoes s1: true + // server cached c1: true +} + +func ExampleNewProtocol() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + defer serverConn.Close() + + client := rtmp.NewProtocol(clientConn) + server := rtmp.NewProtocol(serverConn) + + backendConn, upstreamConn := net.Pipe() + defer backendConn.Close() + defer upstreamConn.Close() + + backend := rtmp.NewProtocol(backendConn) + upstream := rtmp.NewProtocol(upstreamConn) + + done := make(chan error, 1) + go func() { + err := func() error { + // The server can read a raw message first, then decode it explicitly. + m, err := server.ExpectMessage(ctx, rtmp.MessageTypeSetChunkSize) + if err != nil { + return err + } + pkt, err := server.DecodeMessage(m) + if err != nil { + return err + } + chunkSize := pkt.(*rtmp.SetChunkSize) + + // ExpectPacket reads and decodes messages until it finds the requested packet type. + var connectReq *rtmp.ConnectAppPacket + if _, err := rtmp.ExpectPacket(ctx, server, &connectReq); err != nil { + return err + } + + ack := rtmp.NewWindowAcknowledgementSize() + ack.AckSize = 2500000 + if err := server.WritePacket(ctx, ack, 0); err != nil { + return err + } + + serverChunk := rtmp.NewSetChunkSize() + serverChunk.ChunkSize = chunkSize.ChunkSize + if err := server.WritePacket(ctx, serverChunk, 0); err != nil { + return err + } + + connectRes := rtmp.NewConnectAppResPacket(connectReq.TransactionID) + connectRes.CommandObject.Set("fmsVer", rtmp.NewAmf0String("FMS/3,5,3,888")) + connectRes.Args.Set("level", rtmp.NewAmf0String("status")) + connectRes.Args.Set("code", rtmp.NewAmf0String("NetConnection.Connect.Success")) + if err := server.WritePacket(ctx, connectRes, 0); err != nil { + return err + } + + var createStream *rtmp.CallPacket + if _, err := rtmp.ExpectPacket(ctx, server, &createStream); err != nil { + return err + } + createStreamRes := rtmp.NewCreateStreamResPacket(createStream.TransactionID) + createStreamRes.SetStreamID(1) + if err := server.WritePacket(ctx, createStreamRes, 0); err != nil { + return err + } + + // For media forwarding, the proxy reads a complete RTMP message and writes + // that same message to another protocol connection without decoding/repacking. + publishMessage, err := server.ReadMessage(ctx) + if err != nil { + return err + } + if err := backend.WriteMessage(ctx, publishMessage); err != nil { + return err + } + + return nil + }() + + select { + case done <- err: + case <-ctx.Done(): + } + }() + + // Client runs the normal RTMP command workflow: configure chunk size, + // connect to an app, create a stream, publish it, then verify the proxy + // forwarded the publish message to the upstream side. + if err := func() error { + clientChunk := rtmp.NewSetChunkSize() + clientChunk.ChunkSize = 128 + if err := client.WritePacket(ctx, clientChunk, 0); err != nil { + return err + } + + connectReq := rtmp.NewConnectAppPacket() + connectReq.CommandObject.Set("tcUrl", rtmp.NewAmf0String("rtmp://example.com/live")) + if err := client.WritePacket(ctx, connectReq, 0); err != nil { + return err + } + + ackMessage, err := client.ExpectMessage(ctx, rtmp.MessageTypeWindowAcknowledgementSize) + if err != nil { + return err + } + ackPacket, err := client.DecodeMessage(ackMessage) + if err != nil { + return err + } + ack, ok := ackPacket.(*rtmp.WindowAcknowledgementSize) + if !ok { + return fmt.Errorf("unexpected ack packet %T", ackPacket) + } + + var serverChunk *rtmp.SetChunkSize + if _, err := rtmp.ExpectPacket(ctx, client, &serverChunk); err != nil { + return err + } + + var connectRes *rtmp.ConnectAppResPacket + if _, err := rtmp.ExpectPacket(ctx, client, &connectRes); err != nil { + return err + } + + createStream := rtmp.NewCreateStreamPacket() + if err := client.WritePacket(ctx, createStream, 0); err != nil { + return err + } + var createStreamRes *rtmp.CreateStreamResPacket + if _, err := rtmp.ExpectPacket(ctx, client, &createStreamRes); err != nil { + return err + } + + publish := rtmp.NewPublishPacket() + publish.TransactionID = 5 + publish.StreamName = rtmp.NewAmf0String("livestream") + publish.StreamType = rtmp.NewAmf0String("live") + if err := client.WritePacket(ctx, publish, int(createStreamRes.StreamID)); err != nil { + return err + } + + var upstreamPublish *rtmp.PublishPacket + if _, err := rtmp.ExpectPacket(ctx, upstream, &upstreamPublish); err != nil { + return err + } + + select { + case err := <-done: + if err != nil { + return err + } + case <-ctx.Done(): + return ctx.Err() + } + + fmt.Println("ack size:", ack.AckSize) + fmt.Println("chunk size:", serverChunk.ChunkSize) + fmt.Println("connect:", rtmp.NewAmf0Converter(connectRes.Args.Get("code")).ToString().String()) + fmt.Println("stream id:", int(createStreamRes.StreamID)) + fmt.Println("forward publish:", upstreamPublish.StreamName.String()) + + return nil + }(); err != nil { + panic(err) + } + + // Output: + // ack size: 2500000 + // chunk size: 128 + // connect: NetConnection.Connect.Success + // stream id: 1 + // forward publish: livestream +} diff --git a/internal/rtmp/gen.go b/internal/rtmp/gen.go new file mode 100644 index 000000000..cb997e5c7 --- /dev/null +++ b/internal/rtmp/gen.go @@ -0,0 +1,7 @@ +// Copyright (c) 2026 Winlin +// +// SPDX-License-Identifier: MIT +package rtmp + +//go:generate go tool counterfeiter -o rtmpfakes/fake_handshake.go . Handshake +//go:generate go tool counterfeiter -o rtmpfakes/fake_protocol.go . Protocol diff --git a/internal/rtmp/rtmp.go b/internal/rtmp/rtmp.go index b24a12de5..988804b3e 100644 --- a/internal/rtmp/rtmp.go +++ b/internal/rtmp/rtmp.go @@ -17,21 +17,31 @@ import ( "srsx/internal/errors" ) -// The handshake implements the RTMP handshake protocol. -type Handshake struct { +// Handshake implements the RTMP handshake protocol. +type Handshake interface { + C1S1() []byte + WriteC0S0(w io.Writer) error + ReadC0S0(r io.Reader) ([]byte, error) + WriteC1S1(w io.Writer) error + ReadC1S1(r io.Reader) ([]byte, error) + WriteC2S2(w io.Writer, s1c1 []byte) error + ReadC2S2(r io.Reader) ([]byte, error) +} + +type handshake struct { // The c1s1 cache. c1s1 []byte } -func NewHandshake() *Handshake { - return &Handshake{} +func NewHandshake() Handshake { + return &handshake{} } -func (v *Handshake) C1S1() []byte { +func (v *handshake) C1S1() []byte { return v.c1s1 } -func (v *Handshake) WriteC0S0(w io.Writer) (err error) { +func (v *handshake) WriteC0S0(w io.Writer) (err error) { r := bytes.NewReader([]byte{0x03}) if _, err = io.Copy(w, r); err != nil { return errors.Wrap(err, "write c0s0") @@ -40,7 +50,7 @@ func (v *Handshake) WriteC0S0(w io.Writer) (err error) { return } -func (v *Handshake) ReadC0S0(r io.Reader) (c0 []byte, err error) { +func (v *handshake) ReadC0S0(r io.Reader) (c0 []byte, err error) { b := &bytes.Buffer{} if _, err = io.CopyN(b, r, 1); err != nil { return nil, errors.Wrap(err, "read c0s0") @@ -51,7 +61,7 @@ func (v *Handshake) ReadC0S0(r io.Reader) (c0 []byte, err error) { return } -func (v *Handshake) WriteC1S1(w io.Writer) (err error) { +func (v *handshake) WriteC1S1(w io.Writer) (err error) { p := make([]byte, 1536) // Use crypto/rand for thread-safe random generation @@ -67,7 +77,7 @@ func (v *Handshake) WriteC1S1(w io.Writer) (err error) { return } -func (v *Handshake) ReadC1S1(r io.Reader) (c1s1 []byte, err error) { +func (v *handshake) ReadC1S1(r io.Reader) (c1s1 []byte, err error) { b := &bytes.Buffer{} if _, err = io.CopyN(b, r, 1536); err != nil { return nil, errors.Wrap(err, "read c1s1") @@ -79,7 +89,7 @@ func (v *Handshake) ReadC1S1(r io.Reader) (c1s1 []byte, err error) { return } -func (v *Handshake) WriteC2S2(w io.Writer, s1c1 []byte) (err error) { +func (v *handshake) WriteC2S2(w io.Writer, s1c1 []byte) (err error) { r := bytes.NewReader(s1c1[:]) if _, err = io.Copy(w, r); err != nil { return errors.Wrap(err, "write c2s2") @@ -88,7 +98,7 @@ func (v *Handshake) WriteC2S2(w io.Writer, s1c1 []byte) (err error) { return } -func (v *Handshake) ReadC2S2(r io.Reader) (c2 []byte, err error) { +func (v *handshake) ReadC2S2(r io.Reader) (c2 []byte, err error) { b := &bytes.Buffer{} if _, err = io.CopyN(b, r, 1536); err != nil { return nil, errors.Wrap(err, "read c2s2") @@ -129,7 +139,7 @@ type chunkStream struct { format formatType cid chunkID header messageHeader - message *Message + message *message count uint64 extendedTimestamp bool } @@ -138,8 +148,18 @@ func newChunkStream() *chunkStream { return &chunkStream{} } -// The protocol implements the RTMP command and chunk stack. -type Protocol struct { +// Protocol implements the RTMP command and chunk stack. +type Protocol interface { + // Deprecated: Go does not support generic methods. Please use rtmp.ExpectPacket instead. + ExpectPacket(ctx context.Context, ppkt any) (Message, error) + ExpectMessage(ctx context.Context, types ...MessageType) (Message, error) + DecodeMessage(m Message) (Packet, error) + ReadMessage(ctx context.Context) (Message, error) + WritePacket(ctx context.Context, pkt Packet, streamID int) error + WriteMessage(ctx context.Context, m Message) error +} + +type protocol struct { r *bufio.Reader w *bufio.Writer input struct { @@ -154,8 +174,8 @@ type Protocol struct { } } -func NewProtocol(rw io.ReadWriter) *Protocol { - v := &Protocol{ +func NewProtocol(rw io.ReadWriter) Protocol { + v := &protocol{ r: bufio.NewReader(rw), w: bufio.NewWriter(rw), } @@ -169,7 +189,11 @@ func NewProtocol(rw io.ReadWriter) *Protocol { return v } -func ExpectPacket[T Packet](ctx context.Context, v *Protocol, ppkt *T) (m *Message, err error) { +// ExpectPacket reads and decodes RTMP messages until it finds a packet of type T. +// +// Messages with other packet types are consumed and ignored. On success, ppkt is +// set to the decoded packet and the Message carrying that packet is returned. +func ExpectPacket[T Packet](ctx context.Context, v Protocol, ppkt *T) (m Message, err error) { for { if m, err = v.ReadMessage(ctx); err != nil { return nil, errors.WithMessage(err, "read message") @@ -189,12 +213,12 @@ func ExpectPacket[T Packet](ctx context.Context, v *Protocol, ppkt *T) (m *Messa return } -// Deprecated: Please use rtmp.ExpectPacket instead. -func (v *Protocol) ExpectPacket(ctx context.Context, ppkt any) (m *Message, err error) { - panic("Please use rtmp.ExpectPacket instead") +// Deprecated: Go does not support generic methods. Please use rtmp.ExpectPacket instead. +func (v *protocol) ExpectPacket(ctx context.Context, ppkt any) (m Message, err error) { + panic("Go does not support generic methods; please use rtmp.ExpectPacket instead") } -func (v *Protocol) ExpectMessage(ctx context.Context, types ...MessageType) (m *Message, err error) { +func (v *protocol) ExpectMessage(ctx context.Context, types ...MessageType) (m Message, err error) { for { if m, err = v.ReadMessage(ctx); err != nil { return nil, errors.WithMessage(err, "read message") @@ -205,16 +229,14 @@ func (v *Protocol) ExpectMessage(ctx context.Context, types ...MessageType) (m * } for _, t := range types { - if m.MessageType == t { + if m.MessageType() == t { return } } } - - return } -func (v *Protocol) parseAMFObject(p []byte) (pkt Packet, err error) { +func (v *protocol) parseAMFObject(p []byte) (pkt Packet, err error) { var commandName amf0String if err = commandName.UnmarshalBinary(p); err != nil { return nil, errors.WithMessage(err, "unmarshal command name") @@ -266,18 +288,18 @@ func (v *Protocol) parseAMFObject(p []byte) (pkt Packet, err error) { } } -func (v *Protocol) DecodeMessage(m *Message) (pkt Packet, err error) { - p := m.Payload[:] +func (v *protocol) DecodeMessage(m Message) (pkt Packet, err error) { + p := m.Payload()[:] if len(p) == 0 { return nil, errors.New("Empty packet") } - switch m.MessageType { + switch m.MessageType() { case MessageTypeAMF3Command, MessageTypeAMF3Data: p = p[1:] } - switch m.MessageType { + switch m.MessageType() { case MessageTypeSetChunkSize: pkt = NewSetChunkSize() case MessageTypeWindowAcknowledgementSize: @@ -286,22 +308,22 @@ func (v *Protocol) DecodeMessage(m *Message) (pkt Packet, err error) { pkt = NewSetPeerBandwidth() case MessageTypeAMF0Command, MessageTypeAMF3Command, MessageTypeAMF0Data, MessageTypeAMF3Data: if pkt, err = v.parseAMFObject(p); err != nil { - return nil, errors.WithMessage(err, fmt.Sprintf("Parse AMF %v", m.MessageType)) + return nil, errors.WithMessage(err, fmt.Sprintf("Parse AMF %v", m.MessageType())) } case MessageTypeUserControl: pkt = NewUserControl() default: - return nil, errors.Errorf("Unknown message %v", m.MessageType) + return nil, errors.Errorf("Unknown message %v", m.MessageType()) } if err = pkt.UnmarshalBinary(p); err != nil { - return nil, errors.WithMessage(err, fmt.Sprintf("Unmarshal %v", m.MessageType)) + return nil, errors.WithMessage(err, fmt.Sprintf("Unmarshal %v", m.MessageType())) } return } -func (v *Protocol) ReadMessage(ctx context.Context) (m *Message, err error) { +func (v *protocol) ReadMessage(ctx context.Context) (m Message, err error) { for m == nil { // TODO: We should convert buffered io to async io, because we will be stuck in block io here, // TODO: but the risk is acceptable because we literally will set the underlay io timeout. @@ -331,15 +353,17 @@ func (v *Protocol) ReadMessage(ctx context.Context) (m *Message, err error) { return nil, errors.WithMessage(err, "read message payload") } - if err = v.onMessageArrivated(m); err != nil { - return nil, errors.WithMessage(err, "on message") + if m != nil { + if err = v.onMessageArrivated(m.asMessage()); err != nil { + return nil, errors.WithMessage(err, "on message") + } } } return } -func (v *Protocol) readMessagePayload(ctx context.Context, chunk *chunkStream) (m *Message, err error) { +func (v *protocol) readMessagePayload(ctx context.Context, chunk *chunkStream) (m Message, err error) { // Empty payload message. if chunk.message.payloadLength == 0 { m = chunk.message @@ -348,7 +372,7 @@ func (v *Protocol) readMessagePayload(ctx context.Context, chunk *chunkStream) ( } // Calculate the chunk payload size. - chunkedPayloadSize := int(chunk.message.payloadLength) - len(chunk.message.Payload) + chunkedPayloadSize := int(chunk.message.payloadLength) - len(chunk.message.payload) if chunkedPayloadSize > int(v.input.opt.chunkSize) { chunkedPayloadSize = int(v.input.opt.chunkSize) } @@ -357,10 +381,10 @@ func (v *Protocol) readMessagePayload(ctx context.Context, chunk *chunkStream) ( if _, err = io.ReadFull(v.r, b); err != nil { return nil, errors.Wrapf(err, "read chunk %vB", chunkedPayloadSize) } - chunk.message.Payload = append(chunk.message.Payload, b...) + chunk.message.payload = append(chunk.message.payload, b...) // Got entire RTMP message? - if int(chunk.message.payloadLength) == len(chunk.message.Payload) { + if int(chunk.message.payloadLength) == len(chunk.message.payload) { m = chunk.message chunk.message = nil } @@ -426,7 +450,7 @@ var messageHeaderSizes = []int{11, 7, 3, 0} // fmt=1, 0x4X // fmt=2, 0x8X // fmt=3, 0xCX -func (v *Protocol) readMessageHeader(ctx context.Context, chunk *chunkStream, format formatType) (err error) { +func (v *protocol) readMessageHeader(ctx context.Context, chunk *chunkStream, format formatType) (err error) { // We should not assert anything about fmt, for the first packet. // (when first packet, the chunk.message is nil). // the fmt maybe 0/1/2/3, the FMLE will send a 0xC4 for some audio packet. @@ -480,7 +504,7 @@ func (v *Protocol) readMessageHeader(ctx context.Context, chunk *chunkStream, fo // Create msg when new chunk stream start if chunk.message == nil { - chunk.message = NewMessage() + chunk.message = newMessage() } // Read the message header. @@ -659,7 +683,7 @@ func (v *Protocol) readMessageHeader(ctx context.Context, chunk *chunkStream, fo // // Chunk stream IDs with values 64-319 could be represented by both 2- // byte version and 3-byte version of this field. -func (v *Protocol) readBasicHeader(ctx context.Context) (format formatType, cid chunkID, err error) { +func (v *protocol) readBasicHeader(ctx context.Context) (format formatType, cid chunkID, err error) { // 2-63, 1B chunk header var t uint8 if err = binary.Read(v.r, binary.BigEndian, &t); err != nil { @@ -689,14 +713,14 @@ func (v *Protocol) readBasicHeader(ctx context.Context) (format formatType, cid return } -func (v *Protocol) WritePacket(ctx context.Context, pkt Packet, streamID int) (err error) { - m := NewMessage() +func (v *protocol) WritePacket(ctx context.Context, pkt Packet, streamID int) (err error) { + m := newMessage() - if m.Payload, err = pkt.MarshalBinary(); err != nil { + if m.payload, err = pkt.MarshalBinary(); err != nil { return errors.WithMessage(err, "marshal payload") } - m.MessageType = pkt.Type() + m.messageHeader.MessageType = pkt.Type() m.streamID = uint32(streamID) m.betterCid = pkt.BetterCid() @@ -711,7 +735,7 @@ func (v *Protocol) WritePacket(ctx context.Context, pkt Packet, streamID int) (e return } -func (v *Protocol) onPacketWriten(m *Message, pkt Packet) (err error) { +func (v *protocol) onPacketWriten(m *message, pkt Packet) (err error) { var tid amf0Number var name amf0String @@ -734,16 +758,16 @@ func (v *Protocol) onPacketWriten(m *Message, pkt Packet) (err error) { return } -func (v *Protocol) onMessageArrivated(m *Message) (err error) { +func (v *protocol) onMessageArrivated(m *message) (err error) { if m == nil { return } var pkt Packet - switch m.MessageType { + switch m.MessageType() { case MessageTypeSetChunkSize, MessageTypeUserControl, MessageTypeWindowAcknowledgementSize: if pkt, err = v.DecodeMessage(m); err != nil { - return errors.Errorf("decode message %v", m.MessageType) + return errors.Errorf("decode message %v", m.MessageType()) } } @@ -755,19 +779,20 @@ func (v *Protocol) onMessageArrivated(m *Message) (err error) { return } -func (v *Protocol) WriteMessage(ctx context.Context, m *Message) (err error) { - m.payloadLength = uint32(len(m.Payload)) +func (v *protocol) WriteMessage(ctx context.Context, m Message) (err error) { + msg := m.asMessage() + msg.payloadLength = uint32(len(msg.payload)) var c0h, c3h []byte - if c0h, err = m.generateC0Header(); err != nil { + if c0h, err = msg.generateC0Header(); err != nil { return errors.WithMessage(err, "generate c0 header") } - if c3h, err = m.generateC3Header(); err != nil { + if c3h, err = msg.generateC3Header(); err != nil { return errors.WithMessage(err, "generate c3 header") } var h []byte - p := m.Payload + p := msg.payload for len(p) > 0 { // TODO: We should convert buffered io to async io, because we will be stuck in block io here, // TODO: but the risk is acceptable because we literally will set the underlay io timeout. @@ -899,29 +924,56 @@ type messageHeader struct { Timestamp uint64 } -// The RTMP message, transport over chunk stream in RTMP. +// Message is an RTMP message transported over a chunk stream. // Please read the cs id of @doc rtmp_specification_1.0.pdf, @page 30, @section 4.1. Message Header -type Message struct { +type Message interface { + MessageType() MessageType + Timestamp() uint64 + Payload() []byte + asMessage() *message +} + +type message struct { messageHeader // The payload which carries the RTMP packet. - Payload []byte + payload []byte } -func NewMessage() *Message { - return &Message{} +func NewMessage() Message { + return newMessage() } -func NewStreamMessage(streamID int) *Message { - v := NewMessage() +func newMessage() *message { + return &message{} +} + +func NewStreamMessage(streamID int) Message { + v := newMessage() v.streamID = uint32(streamID) v.betterCid = chunkIDOverStream return v } -func (v *Message) generateC3Header() ([]byte, error) { +func (v *message) MessageType() MessageType { + return v.messageHeader.MessageType +} + +func (v *message) Timestamp() uint64 { + return v.messageHeader.Timestamp +} + +func (v *message) Payload() []byte { + return v.payload +} + +func (v *message) asMessage() *message { + return v +} + +func (v *message) generateC3Header() ([]byte, error) { var c3h []byte - if v.Timestamp < extendedTimestamp { + if v.messageHeader.Timestamp < extendedTimestamp { c3h = make([]byte, 1) } else { c3h = make([]byte, 1+4) @@ -935,19 +987,19 @@ func (v *Message) generateC3Header() ([]byte, error) { // but actually all products from adobe, such as FMS/AMS and Flash player and FMLE, // always carry a extended timestamp in C3 header. // @see: http://blog.csdn.net/win_lin/article/details/13363699 - if v.Timestamp >= extendedTimestamp { - p[0] = byte(v.Timestamp >> 24) - p[1] = byte(v.Timestamp >> 16) - p[2] = byte(v.Timestamp >> 8) - p[3] = byte(v.Timestamp) + if v.messageHeader.Timestamp >= extendedTimestamp { + p[0] = byte(v.messageHeader.Timestamp >> 24) + p[1] = byte(v.messageHeader.Timestamp >> 16) + p[2] = byte(v.messageHeader.Timestamp >> 8) + p[3] = byte(v.messageHeader.Timestamp) } return c3h, nil } -func (v *Message) generateC0Header() ([]byte, error) { +func (v *message) generateC0Header() ([]byte, error) { var c0h []byte - if v.Timestamp < extendedTimestamp { + if v.messageHeader.Timestamp < extendedTimestamp { c0h = make([]byte, 1+3+3+1+4) } else { c0h = make([]byte, 1+3+3+1+4+4) @@ -957,10 +1009,10 @@ func (v *Message) generateC0Header() ([]byte, error) { p[0] = byte(v.betterCid) & 0x3f p = p[1:] - if v.Timestamp < extendedTimestamp { - p[0] = byte(v.Timestamp >> 16) - p[1] = byte(v.Timestamp >> 8) - p[2] = byte(v.Timestamp) + if v.messageHeader.Timestamp < extendedTimestamp { + p[0] = byte(v.messageHeader.Timestamp >> 16) + p[1] = byte(v.messageHeader.Timestamp >> 8) + p[2] = byte(v.messageHeader.Timestamp) } else { p[0] = 0xff p[1] = 0xff @@ -973,7 +1025,7 @@ func (v *Message) generateC0Header() ([]byte, error) { p[2] = byte(v.payloadLength) p = p[3:] - p[0] = byte(v.MessageType) + p[0] = byte(v.messageHeader.MessageType) p = p[1:] p[0] = byte(v.streamID) @@ -982,11 +1034,11 @@ func (v *Message) generateC0Header() ([]byte, error) { p[3] = byte(v.streamID >> 24) p = p[4:] - if v.Timestamp >= extendedTimestamp { - p[0] = byte(v.Timestamp >> 24) - p[1] = byte(v.Timestamp >> 16) - p[2] = byte(v.Timestamp >> 8) - p[3] = byte(v.Timestamp) + if v.messageHeader.Timestamp >= extendedTimestamp { + p[0] = byte(v.messageHeader.Timestamp >> 24) + p[1] = byte(v.messageHeader.Timestamp >> 16) + p[2] = byte(v.messageHeader.Timestamp >> 8) + p[3] = byte(v.messageHeader.Timestamp) } return c0h, nil @@ -1039,8 +1091,8 @@ type Packet interface { type objectCallPacket struct { CommandName amf0String TransactionID amf0Number - CommandObject *amf0Object - Args *amf0Object + CommandObject Amf0Object + Args Amf0Object } func (v *objectCallPacket) BetterCid() chunkID { @@ -1081,7 +1133,7 @@ func (v *objectCallPacket) UnmarshalBinary(data []byte) (err error) { return } - v.Args = NewAmf0Object() + v.Args = newAmf0Object() if err = v.Args.UnmarshalBinary(p); err != nil { return errors.WithMessage(err, "unmarshal args") } @@ -1149,8 +1201,8 @@ func (v *ConnectAppPacket) UnmarshalBinary(data []byte) (err error) { func (v *ConnectAppPacket) TcUrl() string { if v.CommandObject != nil { - if v, ok := v.CommandObject.Get("tcUrl").(*amf0String); ok { - return string(*v) + if v, ok := v.CommandObject.Get("tcUrl").(Amf0String); ok { + return v.String() } } return "" @@ -1172,9 +1224,9 @@ func NewConnectAppResPacket(tid amf0Number) *ConnectAppResPacket { func (v *ConnectAppResPacket) SrsID() string { if v.Args != nil { - if v, ok := v.Args.Get("data").(*amf0EcmaArray); ok { - if v, ok := v.Get("srs_id").(*amf0String); ok { - return string(*v) + if v, ok := v.Args.Get("data").(Amf0EcmaArray); ok { + if v, ok := v.Get("srs_id").(Amf0String); ok { + return v.String() } } } @@ -1197,7 +1249,7 @@ func (v *ConnectAppResPacket) UnmarshalBinary(data []byte) (err error) { type variantCallPacket struct { CommandName amf0String TransactionID amf0Number - CommandObject amf0Any // object or null + CommandObject Amf0Any // object or null } func (v *variantCallPacket) BetterCid() chunkID { @@ -1273,7 +1325,7 @@ func (v *variantCallPacket) MarshalBinary() (data []byte, err error) { // @remark onStatus packet is a call packet. type CallPacket struct { variantCallPacket - Args amf0Any // optional or object or null + Args Amf0Any // optional or object or null } func NewCallPacket() *CallPacket { @@ -1282,9 +1334,9 @@ func NewCallPacket() *CallPacket { func (v *CallPacket) ArgsCode() string { if v.Args != nil { - if v, ok := v.Args.(*amf0Object); ok { - if code, ok := v.Get("code").(*amf0String); ok { - return string(*code) + if v, ok := v.Args.(Amf0Object); ok { + if code, ok := v.Get("code").(Amf0String); ok { + return code.String() } } } @@ -1370,6 +1422,10 @@ func NewCreateStreamResPacket(tid amf0Number) *CreateStreamResPacket { return v } +func (v *CreateStreamResPacket) SetStreamID(streamID int) { + v.StreamID = amf0Number(streamID) +} + func (v *CreateStreamResPacket) Size() int { return v.variantCallPacket.Size() + v.StreamID.Size() } @@ -1407,15 +1463,16 @@ func (v *CreateStreamResPacket) MarshalBinary() (data []byte, err error) { // Please read @doc rtmp_specification_1.0.pdf, @page 64, @section 4.2.6. Publish type PublishPacket struct { variantCallPacket - StreamName amf0String - StreamType amf0String + StreamName Amf0String + StreamType Amf0String } func NewPublishPacket() *PublishPacket { v := &PublishPacket{} v.CommandName = commandPublish v.CommandObject = NewAmf0Null() - v.StreamType = "live" + v.StreamName = NewAmf0String("") + v.StreamType = NewAmf0String("live") return v } @@ -1431,11 +1488,13 @@ func (v *PublishPacket) UnmarshalBinary(data []byte) (err error) { } p = p[v.variantCallPacket.Size():] + v.StreamName = newAmf0String("") if err = v.StreamName.UnmarshalBinary(p); err != nil { return errors.WithMessage(err, "unmarshal stream name") } p = p[v.StreamName.Size():] + v.StreamType = newAmf0String("") if err = v.StreamType.UnmarshalBinary(p); err != nil { return errors.WithMessage(err, "unmarshal stream type") } @@ -1466,13 +1525,14 @@ func (v *PublishPacket) MarshalBinary() (data []byte, err error) { // Please read @doc rtmp_specification_1.0.pdf, @page 54, @section 4.2.1. play type PlayPacket struct { variantCallPacket - StreamName amf0String + StreamName Amf0String } func NewPlayPacket() *PlayPacket { v := &PlayPacket{} v.CommandName = commandPlay v.CommandObject = NewAmf0Null() + v.StreamName = NewAmf0String("") return v } @@ -1488,6 +1548,7 @@ func (v *PlayPacket) UnmarshalBinary(data []byte) (err error) { } p = p[v.variantCallPacket.Size():] + v.StreamName = newAmf0String("") if err = v.StreamName.UnmarshalBinary(p); err != nil { return errors.WithMessage(err, "unmarshal stream name") } diff --git a/internal/rtmp/rtmp_test.go b/internal/rtmp/rtmp_test.go new file mode 100644 index 000000000..9dc0013ca --- /dev/null +++ b/internal/rtmp/rtmp_test.go @@ -0,0 +1,728 @@ +// Copyright (c) 2026 Winlin +// +// SPDX-License-Identifier: MIT +package rtmp + +import ( + "bytes" + "context" + "encoding/binary" + "io" + "reflect" + "strings" + "testing" +) + +type errWriter struct{} + +func (errWriter) Write([]byte) (int, error) { return 0, io.ErrClosedPipe } + +func TestHandshakeSimpleAndErrors(t *testing.T) { + h := NewHandshake() + var b bytes.Buffer + if err := h.WriteC0S0(&b); err != nil { + t.Fatalf("WriteC0S0 err=%v", err) + } + c0, err := h.ReadC0S0(&b) + if err != nil || !bytes.Equal(c0, []byte{3}) { + t.Fatalf("ReadC0S0=%v, err=%v", c0, err) + } + if err := h.WriteC0S0(errWriter{}); err == nil { + t.Fatal("WriteC0S0 should fail") + } + if _, err := h.ReadC0S0(bytes.NewReader(nil)); err == nil { + t.Fatal("ReadC0S0 should fail") + } + + b.Reset() + if err := h.WriteC1S1(&b); err != nil { + t.Fatalf("WriteC1S1 err=%v", err) + } + if b.Len() != 1536 { + t.Fatalf("C1S1 len=%v", b.Len()) + } + c1, err := h.ReadC1S1(&b) + if err != nil || len(c1) != 1536 || !bytes.Equal(h.C1S1(), c1) { + t.Fatalf("ReadC1S1 len=%v, cached=%v, err=%v", len(c1), bytes.Equal(h.C1S1(), c1), err) + } + if err := h.WriteC1S1(errWriter{}); err == nil { + t.Fatal("WriteC1S1 should fail") + } + if _, err := h.ReadC1S1(bytes.NewReader(make([]byte, 1535))); err == nil { + t.Fatal("ReadC1S1 should fail") + } + + b.Reset() + if err := h.WriteC2S2(&b, c1); err != nil { + t.Fatalf("WriteC2S2 err=%v", err) + } + c2, err := h.ReadC2S2(&b) + if err != nil || !bytes.Equal(c2, c1) { + t.Fatalf("ReadC2S2 match=%v, err=%v", bytes.Equal(c2, c1), err) + } + if err := h.WriteC2S2(errWriter{}, c1); err == nil { + t.Fatal("WriteC2S2 should fail") + } + if _, err := h.ReadC2S2(bytes.NewReader(make([]byte, 1535))); err == nil { + t.Fatal("ReadC2S2 should fail") + } +} + +func TestSettingsChunkStreamAndMessageConstructors(t *testing.T) { + if s := newSettings(); s.chunkSize != defaultChunkSize { + t.Fatalf("chunk size=%v", s.chunkSize) + } + if c := newChunkStream(); c == nil || c.count != 0 { + t.Fatalf("chunk stream=%#v", c) + } + m := NewMessage().asMessage() + m.messageHeader.MessageType = MessageTypeAudio + m.messageHeader.Timestamp = 99 + m.payload = []byte{1, 2, 3} + if m.MessageType() != MessageTypeAudio || m.Timestamp() != 99 || !bytes.Equal(m.Payload(), []byte{1, 2, 3}) || m.asMessage() != m { + t.Fatalf("bad message accessors") + } + sm := NewStreamMessage(7).asMessage() + if sm.streamID != 7 || sm.betterCid != chunkIDOverStream { + t.Fatalf("stream message=%#v", sm.messageHeader) + } +} + +func TestBasicHeaderVariantsAndErrors(t *testing.T) { + ctx := context.Background() + cases := []struct { + name string + data []byte + fmt formatType + cid chunkID + }{ + {"one-byte", []byte{0x85}, formatType2, 5}, + {"two-byte", []byte{0x40, 0x0a}, formatType1, 74}, + {"three-byte-code-path", []byte{0xc1, 0x01, 0x02}, formatType3, 65}, + } + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + p := NewProtocol(bytes.NewBuffer(tt.data)).(*protocol) + fmt, cid, err := p.readBasicHeader(ctx) + if err != nil || fmt != tt.fmt || cid != tt.cid { + t.Fatalf("fmt=%v cid=%v err=%v", fmt, cid, err) + } + }) + } + for _, data := range [][]byte{{}, {0x00}} { + p := NewProtocol(bytes.NewBuffer(data)).(*protocol) + if _, _, err := p.readBasicHeader(ctx); err == nil { + t.Fatalf("readBasicHeader(%x) should fail", data) + } + } +} + +func TestReadMessageHeadersPayloadsAndChunks(t *testing.T) { + ctx := context.Background() + var in bytes.Buffer + // fmt0 cid=5, timestamp=10, len=3, type audio, stream=1, payload 010203. + in.Write([]byte{0x05, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x03, byte(MessageTypeAudio), 0x01, 0x00, 0x00, 0x00, 1, 2, 3}) + // fmt1 same cid, delta=5, len=2, type video, payload 0405. + in.Write([]byte{0x45, 0x00, 0x00, 0x05, 0x00, 0x00, 0x02, byte(MessageTypeVideo), 4, 5}) + // fmt2 same cid, delta=7, reuses len/type/stream, payload 0607. + in.Write([]byte{0x85, 0x00, 0x00, 0x07, 6, 7}) + // fmt3 same cid, reuses delta and advances timestamp, payload 0809. + in.Write([]byte{0xc5, 8, 9}) + + p := NewProtocol(&in).(*protocol) + for i, want := range []struct { + typ MessageType + ts uint64 + pl []byte + }{ + {MessageTypeAudio, 10, []byte{1, 2, 3}}, + {MessageTypeVideo, 15, []byte{4, 5}}, + {MessageTypeVideo, 22, []byte{6, 7}}, + {MessageTypeVideo, 29, []byte{8, 9}}, + } { + m, err := p.ReadMessage(ctx) + if err != nil { + t.Fatalf("ReadMessage #%v err=%v", i, err) + } + if m.MessageType() != want.typ || m.Timestamp() != want.ts || !bytes.Equal(m.Payload(), want.pl) { + t.Fatalf("message #%v type=%v ts=%v payload=%x", i, m.MessageType(), m.Timestamp(), m.Payload()) + } + } +} + +func TestReadMessageExtendedTimestampAndChunking(t *testing.T) { + ctx := context.Background() + var in bytes.Buffer + payload := []byte{1, 2, 3, 4, 5} + // fmt0 cid=5, normal timestamp=0xffffff, extended timestamp has high bit set and should be masked. + in.Write([]byte{0x05, 0xff, 0xff, 0xff, 0x00, 0x00, byte(len(payload)), byte(MessageTypeVideo), 0x01, 0x00, 0x00, 0x00}) + binary.Write(&in, binary.BigEndian, uint32(0x8000002a)) + in.Write(payload[:2]) + // continuation chunk has fmt3 and extended timestamp too. + in.Write([]byte{0xc5}) + binary.Write(&in, binary.BigEndian, uint32(0x8000002a)) + in.Write(payload[2:4]) + in.Write([]byte{0xc5}) + binary.Write(&in, binary.BigEndian, uint32(0x8000002a)) + in.Write(payload[4:]) + + p := NewProtocol(&in).(*protocol) + p.input.opt.chunkSize = 2 + m, err := p.ReadMessage(ctx) + if err != nil { + t.Fatalf("ReadMessage err=%v", err) + } + if m.Timestamp() != 42 || !bytes.Equal(m.Payload(), payload) { + t.Fatalf("ts=%v payload=%x", m.Timestamp(), m.Payload()) + } +} + +func TestReadMessageHeaderErrors(t *testing.T) { + ctx := context.Background() + // Fresh non-zero chunk with fmt1 is rejected. + p := NewProtocol(bytes.NewBuffer([]byte{0x45})).(*protocol) + if _, err := p.ReadMessage(ctx); err == nil || !strings.Contains(err.Error(), "fresh chunk") { + t.Fatalf("fresh fmt1 err=%v", err) + } + // Existing partial message cannot restart with fmt0. + p = NewProtocol(bytes.NewBuffer([]byte{0x05, 0, 0, 0, 0, 0, 3, byte(MessageTypeAudio), 1, 0, 0, 0, 1, 0x05})).(*protocol) + p.input.opt.chunkSize = 1 + if _, err := p.ReadMessage(ctx); err == nil || !strings.Contains(err.Error(), "exists chunk") { + t.Fatalf("restart err=%v", err) + } + // Size change in a continuation header is rejected. + p = NewProtocol(bytes.NewBuffer([]byte{0x05, 0, 0, 0, 0, 0, 3, byte(MessageTypeAudio), 1, 0, 0, 0, 1, 0x45, 0, 0, 1, 0, 0, 4, byte(MessageTypeAudio)})).(*protocol) + p.input.opt.chunkSize = 1 + if _, err := p.ReadMessage(ctx); err == nil || !strings.Contains(err.Error(), "message size") { + t.Fatalf("size change err=%v", err) + } + // Short payload and short extended timestamp. + p = NewProtocol(bytes.NewBuffer([]byte{0x05, 0, 0, 0, 0, 0, 2, byte(MessageTypeAudio), 1, 0, 0, 0, 1})).(*protocol) + if _, err := p.ReadMessage(ctx); err == nil || !strings.Contains(err.Error(), "read chunk") { + t.Fatalf("payload err=%v", err) + } + p = NewProtocol(bytes.NewBuffer([]byte{0x05, 0xff, 0xff, 0xff, 0, 0, 0, byte(MessageTypeAudio), 1, 0, 0, 0, 1, 2})).(*protocol) + if _, err := p.ReadMessage(ctx); err == nil || !strings.Contains(err.Error(), "ext-ts") { + t.Fatalf("ext-ts err=%v", err) + } +} + +func TestWriteMessageHeadersChunkingAndErrors(t *testing.T) { + ctx := context.Background() + var out bytes.Buffer + p := NewProtocol(&out).(*protocol) + p.output.opt.chunkSize = 2 + m := NewStreamMessage(7).asMessage() + m.messageHeader.MessageType = MessageTypeVideo + m.messageHeader.Timestamp = extendedTimestamp + 9 + m.payload = []byte{1, 2, 3, 4, 5} + if err := p.WriteMessage(ctx, m); err != nil { + t.Fatalf("WriteMessage err=%v", err) + } + want := []byte{0x05, 0xff, 0xff, 0xff, 0, 0, 5, byte(MessageTypeVideo), 7, 0, 0, 0, 0x01, 0x00, 0x00, 0x08, 1, 2, 0xc5, 0x01, 0x00, 0x00, 0x08, 3, 4, 0xc5, 0x01, 0x00, 0x00, 0x08, 5} + if !bytes.Equal(out.Bytes(), want) { + t.Fatalf("written=%x want=%x", out.Bytes(), want) + } + if err := p.WriteMessage(ctx, (&message{})); err != nil { + t.Fatalf("empty WriteMessage err=%v", err) + } + canceled, cancel := context.WithCancel(ctx) + cancel() + if err := p.WriteMessage(canceled, m); err != context.Canceled { + t.Fatalf("canceled WriteMessage err=%v", err) + } + p = NewProtocol(struct { + io.Reader + io.Writer + }{bytes.NewReader(nil), errWriter{}}).(*protocol) + if err := p.WriteMessage(ctx, m); err == nil { + t.Fatal("WriteMessage to bad writer should fail") + } +} + +func TestProtocolDecodeMessageAndControls(t *testing.T) { + ctx := context.Background() + p := NewProtocol(&bytes.Buffer{}).(*protocol) + if _, err := p.DecodeMessage((&message{})); err == nil || !strings.Contains(err.Error(), "Empty packet") { + t.Fatalf("empty decode err=%v", err) + } + unknown := &message{} + unknown.messageHeader.MessageType = MessageTypeAudio + unknown.payload = []byte{1} + if _, err := p.DecodeMessage(unknown); err == nil || !strings.Contains(err.Error(), "Unknown message") { + t.Fatalf("unknown err=%v", err) + } + bad := &message{} + bad.messageHeader.MessageType = MessageTypeSetChunkSize + bad.payload = []byte{1, 2} + if _, err := p.DecodeMessage(bad); err == nil || !strings.Contains(err.Error(), "Unmarshal") { + t.Fatalf("bad control err=%v", err) + } + + for _, pkt := range []Packet{ + &SetChunkSize{ChunkSize: 4096}, + &WindowAcknowledgementSize{AckSize: 2500000}, + &SetPeerBandwidth{Bandwidth: 1000, LimitType: LimitTypeSoft}, + &UserControl{EventType: EventTypePingRequest, EventData: 123}, + } { + data, err := pkt.MarshalBinary() + if err != nil { + t.Fatalf("marshal %T err=%v", pkt, err) + } + m := &message{payload: data} + m.messageHeader.MessageType = pkt.Type() + got, err := p.DecodeMessage(m) + if err != nil { + t.Fatalf("DecodeMessage %T err=%v", pkt, err) + } + if reflect.TypeOf(got) != reflect.TypeOf(pkt) { + t.Fatalf("got %T want %T", got, pkt) + } + } + + chunk := &SetChunkSize{ChunkSize: 3} + m := &message{} + m.messageHeader.MessageType = chunk.Type() + m.payload, _ = chunk.MarshalBinary() + if err := p.onMessageArrivated(m); err != nil || p.input.opt.chunkSize != 3 { + t.Fatalf("onMessageArrivated err=%v chunk=%v", err, p.input.opt.chunkSize) + } + if err := p.onMessageArrivated(nil); err != nil { + t.Fatalf("nil onMessageArrivated err=%v", err) + } + bad.Payload()[0] = 1 + if err := p.onMessageArrivated(bad); err == nil { + t.Fatal("bad onMessageArrivated should fail") + } + if _, err := p.ExpectMessage(ctx); err == nil { + t.Fatal("ExpectMessage on empty reader should fail") + } +} + +func TestProtocolPacketsAndTransactions(t *testing.T) { + ctx := context.Background() + var wire bytes.Buffer + writer := NewProtocol(&wire).(*protocol) + connect := NewConnectAppPacket() + connect.CommandObject.Set("tcUrl", NewAmf0String("rtmp://host/live")) + if connect.Size() == 0 || connect.BetterCid() != chunkIDOverConnection || connect.Type() != MessageTypeAMF0Command || connect.TcUrl() == "" { + t.Fatalf("connect metadata invalid") + } + if err := writer.WritePacket(ctx, connect, 0); err != nil { + t.Fatalf("WritePacket connect err=%v", err) + } + if _, ok := writer.input.transactions[connect.TransactionID]; !ok { + t.Fatal("connect transaction not tracked") + } + create := NewCreateStreamPacket() + if err := writer.WritePacket(ctx, create, 0); err != nil { + t.Fatalf("WritePacket create err=%v", err) + } + call := NewCallPacket() + call.CommandName = commandReleaseStream + call.TransactionID = 3 + call.CommandObject = NewAmf0Null() + if err := writer.WritePacket(ctx, call, 0); err != nil { + t.Fatalf("WritePacket call err=%v", err) + } + reader := NewProtocol(&wire) + var gotConnect *ConnectAppPacket + if _, err := ExpectPacket(ctx, reader, &gotConnect); err != nil || gotConnect.TcUrl() != "rtmp://host/live" { + t.Fatalf("gotConnect=%v err=%v", gotConnect, err) + } + var gotCreate *CallPacket + if _, err := ExpectPacket(ctx, reader, &gotCreate); err != nil || gotCreate.CommandName != commandCreateStream { + t.Fatalf("gotCreate=%v err=%v", gotCreate, err) + } + + decoder := NewProtocol(&bytes.Buffer{}).(*protocol) + decoder.input.transactions[1] = commandConnect + if pkt, err := decoder.parseAMFObject(mustPacketBytes(t, NewConnectAppResPacket(1))); err != nil || reflect.TypeOf(pkt) != reflect.TypeOf(NewConnectAppResPacket(1)) { + t.Fatalf("connect res pkt=%T err=%v", pkt, err) + } + decoder.input.transactions[2] = commandCreateStream + csr := NewCreateStreamResPacket(2) + csr.SetStreamID(99) + if pkt, err := decoder.parseAMFObject(mustPacketBytes(t, csr)); err != nil || reflect.TypeOf(pkt) != reflect.TypeOf(csr) { + t.Fatalf("create res pkt=%T err=%v", pkt, err) + } + decoder.input.transactions[3] = commandReleaseStream + res := NewCallPacket() + res.CommandName = commandResult + res.TransactionID = 3 + res.CommandObject = NewAmf0Null() + if pkt, err := decoder.parseAMFObject(mustPacketBytes(t, res)); err != nil || reflect.TypeOf(pkt) != reflect.TypeOf(res) { + t.Fatalf("call res pkt=%T err=%v", pkt, err) + } + for _, name := range []amf0String{commandPublish, commandPlay, commandOnStatus} { + pkt := NewCallPacket() + pkt.CommandName = name + pkt.TransactionID = 0 + pkt.CommandObject = NewAmf0Null() + if name == commandPublish { + pub := NewPublishPacket() + pub.TransactionID = 0 + pub.StreamName = NewAmf0String("s") + pub.StreamType = NewAmf0String("live") + if decoded, err := decoder.parseAMFObject(mustPacketBytes(t, pub)); err != nil || reflect.TypeOf(decoded) != reflect.TypeOf(pub) { + t.Fatalf("publish decoded=%T err=%v", decoded, err) + } + continue + } + if name == commandPlay { + play := NewPlayPacket() + play.TransactionID = 0 + play.StreamName = NewAmf0String("s") + if decoded, err := decoder.parseAMFObject(mustPacketBytes(t, play)); err != nil || reflect.TypeOf(decoded) != reflect.TypeOf(play) { + t.Fatalf("play decoded=%T err=%v", decoded, err) + } + continue + } + if decoded, err := decoder.parseAMFObject(mustPacketBytes(t, pkt)); err != nil || reflect.TypeOf(decoded) != reflect.TypeOf(pkt) { + t.Fatalf("call decoded=%T err=%v", decoded, err) + } + } + decoder.input.transactions[9] = commandPause + errPkt := NewCallPacket() + errPkt.CommandName = commandError + errPkt.TransactionID = 9 + errPkt.CommandObject = NewAmf0Null() + if _, err := decoder.parseAMFObject(mustPacketBytes(t, errPkt)); err == nil || !strings.Contains(err.Error(), "No request") { + t.Fatalf("unknown request err=%v", err) + } + if _, err := decoder.parseAMFObject(mustPacketBytes(t, errPkt)); err == nil || !strings.Contains(err.Error(), "No matched request") { + t.Fatalf("missing transaction err=%v", err) + } + if _, err := decoder.parseAMFObject([]byte{byte(amf0MarkerString), 0, 8, 'c'}); err == nil { + t.Fatal("bad AMF parse should fail") + } + + cctx, cancel := context.WithCancel(ctx) + cancel() + if err := writer.WritePacket(cctx, connect, 0); err == nil || !strings.Contains(err.Error(), context.Canceled.Error()) { + t.Fatalf("WritePacket canceled err=%v", err) + } +} + +func TestDeprecatedExpectPacketPanics(t *testing.T) { + defer func() { + if recover() == nil { + t.Fatal("Expected panic") + } + }() + NewProtocol(&bytes.Buffer{}).ExpectPacket(context.Background(), nil) +} + +func TestPacketRoundTripsAndErrors(t *testing.T) { + packets := []Packet{ + NewConnectAppPacket(), + NewConnectAppResPacket(7), + NewCallPacket(), + NewCreateStreamPacket(), + func() Packet { p := NewCreateStreamResPacket(2); p.SetStreamID(1); return p }(), + func() Packet { + p := NewPublishPacket() + p.TransactionID = 0 + p.StreamName = NewAmf0String("s") + return p + }(), + func() Packet { p := NewPlayPacket(); p.TransactionID = 0; p.StreamName = NewAmf0String("s"); return p }(), + &SetChunkSize{ChunkSize: 1}, + &WindowAcknowledgementSize{AckSize: 2}, + &SetPeerBandwidth{Bandwidth: 3, LimitType: LimitTypeDynamic}, + &UserControl{EventType: EventTypeFmsEvent0, EventData: 1}, + &UserControl{EventType: EventTypeSetBufferLength, EventData: 1, ExtraData: 2}, + } + // Initialize the generic call packet so it is marshalable. + packets[2].(*CallPacket).CommandName = commandOnStatus + packets[2].(*CallPacket).TransactionID = 0 + packets[2].(*CallPacket).CommandObject = NewAmf0Null() + packets[2].(*CallPacket).Args = NewAmf0Object().Set("code", NewAmf0String("ok")) + packets[1].(*ConnectAppResPacket).Args.Set("data", NewAmf0EcmaArray().Set("srs_id", NewAmf0String("sid"))) + + for _, pkt := range packets { + t.Run(reflect.TypeOf(pkt).String(), func(t *testing.T) { + data, err := pkt.MarshalBinary() + if err != nil { + t.Fatalf("MarshalBinary err=%v", err) + } + if len(data) != pkt.Size() { + t.Fatalf("len=%v Size=%v", len(data), pkt.Size()) + } + fresh := reflect.New(reflect.TypeOf(pkt).Elem()).Interface().(Packet) + switch v := fresh.(type) { + case *ConnectAppPacket: + *v = *NewConnectAppPacket() + case *ConnectAppResPacket: + *v = *NewConnectAppResPacket(0) + case *CreateStreamPacket: + *v = *NewCreateStreamPacket() + case *CreateStreamResPacket: + *v = *NewCreateStreamResPacket(0) + case *PublishPacket: + *v = *NewPublishPacket() + case *PlayPacket: + *v = *NewPlayPacket() + } + if err := fresh.UnmarshalBinary(data); err != nil { + t.Fatalf("UnmarshalBinary err=%v", err) + } + }) + } + if packets[1].(*ConnectAppResPacket).SrsID() != "sid" || packets[2].(*CallPacket).ArgsCode() != "ok" { + t.Fatalf("packet helpers failed") + } + if NewConnectAppResPacket(1).SrsID() != "" || NewCallPacket().ArgsCode() != "" || NewConnectAppPacket().TcUrl() != "" { + t.Fatalf("empty helpers failed") + } + + badConnect := NewConnectAppPacket() + badConnect.CommandName = commandPlay + if err := badConnect.UnmarshalBinary(mustPacketBytes(t, badConnect)); err == nil { + t.Fatal("bad connect name should fail") + } + badConnect = NewConnectAppPacket() + badConnect.TransactionID = 2 + if err := badConnect.UnmarshalBinary(mustPacketBytes(t, badConnect)); err == nil { + t.Fatal("bad connect tid should fail") + } + badRes := NewConnectAppResPacket(1) + badRes.CommandName = commandPlay + if err := badRes.UnmarshalBinary(mustPacketBytes(t, badRes)); err == nil { + t.Fatal("bad connect response name should fail") + } + for _, pkt := range []Packet{NewConnectAppPacket(), NewCallPacket(), NewCreateStreamResPacket(1), NewPublishPacket(), NewPlayPacket()} { + if err := pkt.UnmarshalBinary([]byte{byte(amf0MarkerString)}); err == nil { + t.Fatalf("%T short unmarshal should fail", pkt) + } + } + for _, pkt := range []Packet{&SetChunkSize{}, &WindowAcknowledgementSize{}, &SetPeerBandwidth{}, &UserControl{}} { + if err := pkt.UnmarshalBinary([]byte{0, 1}); err == nil { + t.Fatalf("%T short unmarshal should fail", pkt) + } + } + uc := &UserControl{} + if err := uc.UnmarshalBinary([]byte{0, byte(EventTypeSetBufferLength), 1, 2, 3, 4, 5}); err == nil { + t.Fatal("short set-buffer-length should fail") + } +} + +func mustPacketBytes(t *testing.T, pkt Packet) []byte { + t.Helper() + data, err := pkt.MarshalBinary() + if err != nil { + t.Fatalf("MarshalBinary %T err=%v", pkt, err) + } + return data +} + +type failPacket struct{} + +func (failPacket) Size() int { return 0 } +func (failPacket) UnmarshalBinary([]byte) error { return io.ErrUnexpectedEOF } +func (failPacket) MarshalBinary() ([]byte, error) { return nil, io.ErrClosedPipe } +func (failPacket) BetterCid() chunkID { return chunkIDOverConnection } +func (failPacket) Type() MessageType { return MessageTypeAMF0Command } + +type stepWriter struct { + writes int + failAt int +} + +func (w *stepWriter) Write(p []byte) (int, error) { + w.writes++ + if w.writes == w.failAt { + return 0, io.ErrClosedPipe + } + return len(p), nil +} + +func TestProtocolAdditionalBranches(t *testing.T) { + ctx := context.Background() + p := NewProtocol(&bytes.Buffer{}).(*protocol) + if NewSetPeerBandwidth().BetterCid() != chunkIDProtocolControl || NewUserControl().BetterCid() != chunkIDProtocolControl { + t.Fatal("control better cid failed") + } + if err := p.WritePacket(ctx, failPacket{}, 0); err == nil || !strings.Contains(err.Error(), "marshal payload") { + t.Fatalf("WritePacket marshal err=%v", err) + } + + payloadWriter := &stepWriter{failAt: 1} + p = NewProtocol(struct { + io.Reader + io.Writer + }{bytes.NewReader(nil), payloadWriter}).(*protocol) + m := NewStreamMessage(1).asMessage() + m.messageHeader.MessageType = MessageTypeVideo + m.payload = bytes.Repeat([]byte{1}, 5000) + if err := p.WriteMessage(ctx, m); err == nil || !strings.Contains(err.Error(), "write chunk payload") { + t.Fatalf("WriteMessage payload err=%v", err) + } + flushWriter := &stepWriter{failAt: 1} + p = NewProtocol(struct { + io.Reader + io.Writer + }{bytes.NewReader(nil), flushWriter}).(*protocol) + m.payload = []byte{1} + if err := p.WriteMessage(ctx, m); err == nil || !strings.Contains(err.Error(), "flush writer") { + t.Fatalf("WriteMessage flush err=%v writes=%v", err, flushWriter.writes) + } + + // Zero-length payload returns a complete message without reading chunk bytes. + in := bytes.NewBuffer([]byte{0x05, 0, 0, 1, 0, 0, 0, byte(MessageTypeAudio), 1, 0, 0, 0}) + p = NewProtocol(in).(*protocol) + if msg, err := p.ReadMessage(ctx); err != nil || msg.MessageType() != MessageTypeAudio || len(msg.Payload()) != 0 { + t.Fatalf("zero payload msg=%v err=%v", msg, err) + } + + // ExpectMessage skips unwanted message types before returning the desired one. + var wire bytes.Buffer + writer := NewProtocol(&wire).(*protocol) + am := NewStreamMessage(1).asMessage() + am.messageHeader.MessageType = MessageTypeAudio + am.payload = []byte{1} + vm := NewStreamMessage(1).asMessage() + vm.messageHeader.MessageType = MessageTypeVideo + vm.payload = []byte{2} + if err := writer.WriteMessage(ctx, am); err != nil { + t.Fatal(err) + } + if err := writer.WriteMessage(ctx, vm); err != nil { + t.Fatal(err) + } + reader := NewProtocol(&wire) + if got, err := reader.ExpectMessage(ctx, MessageTypeVideo); err != nil || got.MessageType() != MessageTypeVideo { + t.Fatalf("ExpectMessage got=%v err=%v", got, err) + } + + // Generic ExpectPacket skips non-matching packets, then returns matching; it also reports decode/read errors. + wire.Reset() + writer = NewProtocol(&wire).(*protocol) + if err := writer.WritePacket(ctx, &WindowAcknowledgementSize{AckSize: 1}, 0); err != nil { + t.Fatal(err) + } + if err := writer.WritePacket(ctx, &SetChunkSize{ChunkSize: 2}, 0); err != nil { + t.Fatal(err) + } + reader = NewProtocol(&wire) + var chunk *SetChunkSize + if _, err := ExpectPacket(ctx, reader, &chunk); err != nil || chunk.ChunkSize != 2 { + t.Fatalf("ExpectPacket chunk=%v err=%v", chunk, err) + } + reader = NewProtocol(bytes.NewBuffer([]byte{0x05, 0, 0, 0, 0, 0, 1, byte(MessageTypeSetChunkSize), 1, 0, 0, 0, 0})) + if _, err := ExpectPacket(ctx, reader, &chunk); err == nil || !strings.Contains(err.Error(), "decode message") { + t.Fatalf("ExpectPacket decode err=%v", err) + } + reader = NewProtocol(bytes.NewBuffer(nil)) + if _, err := ExpectPacket(ctx, reader, &chunk); err == nil || !strings.Contains(err.Error(), "read message") { + t.Fatalf("ExpectPacket read err=%v", err) + } + + // AMF3 strips the leading byte before AMF0 decoding. + pub := NewPublishPacket() + pub.TransactionID = 0 + pub.StreamName = NewAmf0String("stream") + data := append([]byte{0}, mustPacketBytes(t, pub)...) + msg := &message{payload: data} + msg.messageHeader.MessageType = MessageTypeAMF3Command + if pkt, err := NewProtocol(&bytes.Buffer{}).DecodeMessage(msg); err != nil || pkt.(*PublishPacket).StreamName.String() != "stream" { + t.Fatalf("AMF3 decode pkt=%T err=%v", pkt, err) + } +} + +func TestProtocolErrorBranchesForCoverage(t *testing.T) { + ctx := context.Background() + // ExpectMessage without requested types returns the first message. + var wire bytes.Buffer + w := NewProtocol(&wire).(*protocol) + msg := NewStreamMessage(1).asMessage() + msg.messageHeader.MessageType = MessageTypeAudio + msg.payload = []byte{1} + if err := w.WriteMessage(ctx, msg); err != nil { + t.Fatal(err) + } + if got, err := NewProtocol(&wire).ExpectMessage(ctx); err != nil || got.MessageType() != MessageTypeAudio { + t.Fatalf("ExpectMessage any got=%v err=%v", got, err) + } + cctx, cancel := context.WithCancel(ctx) + cancel() + if _, err := NewProtocol(bytes.NewBuffer(nil)).ReadMessage(cctx); err != context.Canceled { + t.Fatalf("ReadMessage canceled err=%v", err) + } + if err := w.WriteMessage(cctx, (&message{})); err != context.Canceled { + t.Fatalf("WriteMessage empty canceled err=%v", err) + } + badAMF := &message{payload: []byte{0xff}} + badAMF.messageHeader.MessageType = MessageTypeAMF0Command + if _, err := NewProtocol(&bytes.Buffer{}).DecodeMessage(badAMF); err == nil || !strings.Contains(err.Error(), "Parse AMF") { + t.Fatalf("bad AMF decode err=%v", err) + } + rn := commandResult + resultName, _ := (&rn).MarshalBinary() + if _, err := NewProtocol(&bytes.Buffer{}).(*protocol).parseAMFObject(append(resultName, 0)); err == nil || !strings.Contains(err.Error(), "unmarshal tid") { + t.Fatalf("bad result tid err=%v", err) + } +} + +func TestPacketUnmarshalErrorBranchesForCoverage(t *testing.T) { + cn := commandConnect + name, _ := (&cn).MarshalBinary() + tn := amf0Number(1) + tid, _ := (&tn).MarshalBinary() + obj, _ := NewAmf0Object().MarshalBinary() + base := append(append([]byte{}, name...), tid...) + oc := &objectCallPacket{CommandObject: NewAmf0Object()} + if err := oc.UnmarshalBinary(append([]byte{}, name...)); err == nil || !strings.Contains(err.Error(), "unmarshal tid") { + t.Fatalf("object tid err=%v", err) + } + if err := oc.UnmarshalBinary(append(append([]byte{}, base...), 0xff)); err == nil || !strings.Contains(err.Error(), "unmarshal command") { + t.Fatalf("object command err=%v", err) + } + withObj := append(append([]byte{}, base...), obj...) + if err := oc.UnmarshalBinary(append(withObj, 0xff)); err == nil || !strings.Contains(err.Error(), "unmarshal args") { + t.Fatalf("object args err=%v", err) + } + + vc := &variantCallPacket{} + if err := vc.UnmarshalBinary(append([]byte{}, name...)); err == nil || !strings.Contains(err.Error(), "unmarshal tid") { + t.Fatalf("variant tid err=%v", err) + } + if err := vc.UnmarshalBinary(append(append([]byte{}, base...), 0xff)); err == nil || !strings.Contains(err.Error(), "discovery command object") { + t.Fatalf("variant discovery err=%v", err) + } + if err := vc.UnmarshalBinary(append(append([]byte{}, base...), byte(amf0MarkerString), 0, 3, 'a')); err == nil || !strings.Contains(err.Error(), "unmarshal command object") { + t.Fatalf("variant command object err=%v", err) + } + + call := NewCallPacket() + call.CommandName = commandOnStatus + call.TransactionID = 0 + call.CommandObject = NewAmf0Null() + callBase := mustPacketBytes(t, call) + if err := NewCallPacket().UnmarshalBinary(append(callBase, 0xff)); err == nil || !strings.Contains(err.Error(), "discovery args") { + t.Fatalf("call discovery args err=%v", err) + } + if err := NewCallPacket().UnmarshalBinary(append(callBase, byte(amf0MarkerString), 0, 3, 'a')); err == nil || !strings.Contains(err.Error(), "unmarshal args") { + t.Fatalf("call unmarshal args err=%v", err) + } + csr := NewCreateStreamResPacket(2) + if err := NewCreateStreamResPacket(0).UnmarshalBinary(mustPacketBytes(t, &csr.variantCallPacket)); err == nil || !strings.Contains(err.Error(), "unmarshal sid") { + t.Fatalf("create stream sid err=%v", err) + } + pub := NewPublishPacket() + pub.TransactionID = 0 + pubPrefix, _ := pub.variantCallPacket.MarshalBinary() + if err := NewPublishPacket().UnmarshalBinary(append(pubPrefix, 0xff)); err == nil || !strings.Contains(err.Error(), "stream name") { + t.Fatalf("publish stream name err=%v", err) + } + streamName, _ := NewAmf0String("s").MarshalBinary() + if err := NewPublishPacket().UnmarshalBinary(append(append(pubPrefix, streamName...), 0xff)); err == nil || !strings.Contains(err.Error(), "stream type") { + t.Fatalf("publish stream type err=%v", err) + } + play := NewPlayPacket() + play.TransactionID = 0 + playPrefix, _ := play.variantCallPacket.MarshalBinary() + if err := NewPlayPacket().UnmarshalBinary(append(playPrefix, 0xff)); err == nil || !strings.Contains(err.Error(), "stream name") { + t.Fatalf("play stream name err=%v", err) + } +} diff --git a/internal/rtmp/rtmpfakes/fake_handshake.go b/internal/rtmp/rtmpfakes/fake_handshake.go new file mode 100644 index 000000000..521800811 --- /dev/null +++ b/internal/rtmp/rtmpfakes/fake_handshake.go @@ -0,0 +1,554 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package rtmpfakes + +import ( + "io" + "srsx/internal/rtmp" + "sync" +) + +type FakeHandshake struct { + C1S1Stub func() []byte + c1S1Mutex sync.RWMutex + c1S1ArgsForCall []struct { + } + c1S1Returns struct { + result1 []byte + } + c1S1ReturnsOnCall map[int]struct { + result1 []byte + } + ReadC0S0Stub func(io.Reader) ([]byte, error) + readC0S0Mutex sync.RWMutex + readC0S0ArgsForCall []struct { + arg1 io.Reader + } + readC0S0Returns struct { + result1 []byte + result2 error + } + readC0S0ReturnsOnCall map[int]struct { + result1 []byte + result2 error + } + ReadC1S1Stub func(io.Reader) ([]byte, error) + readC1S1Mutex sync.RWMutex + readC1S1ArgsForCall []struct { + arg1 io.Reader + } + readC1S1Returns struct { + result1 []byte + result2 error + } + readC1S1ReturnsOnCall map[int]struct { + result1 []byte + result2 error + } + ReadC2S2Stub func(io.Reader) ([]byte, error) + readC2S2Mutex sync.RWMutex + readC2S2ArgsForCall []struct { + arg1 io.Reader + } + readC2S2Returns struct { + result1 []byte + result2 error + } + readC2S2ReturnsOnCall map[int]struct { + result1 []byte + result2 error + } + WriteC0S0Stub func(io.Writer) error + writeC0S0Mutex sync.RWMutex + writeC0S0ArgsForCall []struct { + arg1 io.Writer + } + writeC0S0Returns struct { + result1 error + } + writeC0S0ReturnsOnCall map[int]struct { + result1 error + } + WriteC1S1Stub func(io.Writer) error + writeC1S1Mutex sync.RWMutex + writeC1S1ArgsForCall []struct { + arg1 io.Writer + } + writeC1S1Returns struct { + result1 error + } + writeC1S1ReturnsOnCall map[int]struct { + result1 error + } + WriteC2S2Stub func(io.Writer, []byte) error + writeC2S2Mutex sync.RWMutex + writeC2S2ArgsForCall []struct { + arg1 io.Writer + arg2 []byte + } + writeC2S2Returns struct { + result1 error + } + writeC2S2ReturnsOnCall map[int]struct { + result1 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeHandshake) C1S1() []byte { + fake.c1S1Mutex.Lock() + ret, specificReturn := fake.c1S1ReturnsOnCall[len(fake.c1S1ArgsForCall)] + fake.c1S1ArgsForCall = append(fake.c1S1ArgsForCall, struct { + }{}) + stub := fake.C1S1Stub + fakeReturns := fake.c1S1Returns + fake.recordInvocation("C1S1", []interface{}{}) + fake.c1S1Mutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeHandshake) C1S1CallCount() int { + fake.c1S1Mutex.RLock() + defer fake.c1S1Mutex.RUnlock() + return len(fake.c1S1ArgsForCall) +} + +func (fake *FakeHandshake) C1S1Calls(stub func() []byte) { + fake.c1S1Mutex.Lock() + defer fake.c1S1Mutex.Unlock() + fake.C1S1Stub = stub +} + +func (fake *FakeHandshake) C1S1Returns(result1 []byte) { + fake.c1S1Mutex.Lock() + defer fake.c1S1Mutex.Unlock() + fake.C1S1Stub = nil + fake.c1S1Returns = struct { + result1 []byte + }{result1} +} + +func (fake *FakeHandshake) C1S1ReturnsOnCall(i int, result1 []byte) { + fake.c1S1Mutex.Lock() + defer fake.c1S1Mutex.Unlock() + fake.C1S1Stub = nil + if fake.c1S1ReturnsOnCall == nil { + fake.c1S1ReturnsOnCall = make(map[int]struct { + result1 []byte + }) + } + fake.c1S1ReturnsOnCall[i] = struct { + result1 []byte + }{result1} +} + +func (fake *FakeHandshake) ReadC0S0(arg1 io.Reader) ([]byte, error) { + fake.readC0S0Mutex.Lock() + ret, specificReturn := fake.readC0S0ReturnsOnCall[len(fake.readC0S0ArgsForCall)] + fake.readC0S0ArgsForCall = append(fake.readC0S0ArgsForCall, struct { + arg1 io.Reader + }{arg1}) + stub := fake.ReadC0S0Stub + fakeReturns := fake.readC0S0Returns + fake.recordInvocation("ReadC0S0", []interface{}{arg1}) + fake.readC0S0Mutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeHandshake) ReadC0S0CallCount() int { + fake.readC0S0Mutex.RLock() + defer fake.readC0S0Mutex.RUnlock() + return len(fake.readC0S0ArgsForCall) +} + +func (fake *FakeHandshake) ReadC0S0Calls(stub func(io.Reader) ([]byte, error)) { + fake.readC0S0Mutex.Lock() + defer fake.readC0S0Mutex.Unlock() + fake.ReadC0S0Stub = stub +} + +func (fake *FakeHandshake) ReadC0S0ArgsForCall(i int) io.Reader { + fake.readC0S0Mutex.RLock() + defer fake.readC0S0Mutex.RUnlock() + argsForCall := fake.readC0S0ArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeHandshake) ReadC0S0Returns(result1 []byte, result2 error) { + fake.readC0S0Mutex.Lock() + defer fake.readC0S0Mutex.Unlock() + fake.ReadC0S0Stub = nil + fake.readC0S0Returns = struct { + result1 []byte + result2 error + }{result1, result2} +} + +func (fake *FakeHandshake) ReadC0S0ReturnsOnCall(i int, result1 []byte, result2 error) { + fake.readC0S0Mutex.Lock() + defer fake.readC0S0Mutex.Unlock() + fake.ReadC0S0Stub = nil + if fake.readC0S0ReturnsOnCall == nil { + fake.readC0S0ReturnsOnCall = make(map[int]struct { + result1 []byte + result2 error + }) + } + fake.readC0S0ReturnsOnCall[i] = struct { + result1 []byte + result2 error + }{result1, result2} +} + +func (fake *FakeHandshake) ReadC1S1(arg1 io.Reader) ([]byte, error) { + fake.readC1S1Mutex.Lock() + ret, specificReturn := fake.readC1S1ReturnsOnCall[len(fake.readC1S1ArgsForCall)] + fake.readC1S1ArgsForCall = append(fake.readC1S1ArgsForCall, struct { + arg1 io.Reader + }{arg1}) + stub := fake.ReadC1S1Stub + fakeReturns := fake.readC1S1Returns + fake.recordInvocation("ReadC1S1", []interface{}{arg1}) + fake.readC1S1Mutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeHandshake) ReadC1S1CallCount() int { + fake.readC1S1Mutex.RLock() + defer fake.readC1S1Mutex.RUnlock() + return len(fake.readC1S1ArgsForCall) +} + +func (fake *FakeHandshake) ReadC1S1Calls(stub func(io.Reader) ([]byte, error)) { + fake.readC1S1Mutex.Lock() + defer fake.readC1S1Mutex.Unlock() + fake.ReadC1S1Stub = stub +} + +func (fake *FakeHandshake) ReadC1S1ArgsForCall(i int) io.Reader { + fake.readC1S1Mutex.RLock() + defer fake.readC1S1Mutex.RUnlock() + argsForCall := fake.readC1S1ArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeHandshake) ReadC1S1Returns(result1 []byte, result2 error) { + fake.readC1S1Mutex.Lock() + defer fake.readC1S1Mutex.Unlock() + fake.ReadC1S1Stub = nil + fake.readC1S1Returns = struct { + result1 []byte + result2 error + }{result1, result2} +} + +func (fake *FakeHandshake) ReadC1S1ReturnsOnCall(i int, result1 []byte, result2 error) { + fake.readC1S1Mutex.Lock() + defer fake.readC1S1Mutex.Unlock() + fake.ReadC1S1Stub = nil + if fake.readC1S1ReturnsOnCall == nil { + fake.readC1S1ReturnsOnCall = make(map[int]struct { + result1 []byte + result2 error + }) + } + fake.readC1S1ReturnsOnCall[i] = struct { + result1 []byte + result2 error + }{result1, result2} +} + +func (fake *FakeHandshake) ReadC2S2(arg1 io.Reader) ([]byte, error) { + fake.readC2S2Mutex.Lock() + ret, specificReturn := fake.readC2S2ReturnsOnCall[len(fake.readC2S2ArgsForCall)] + fake.readC2S2ArgsForCall = append(fake.readC2S2ArgsForCall, struct { + arg1 io.Reader + }{arg1}) + stub := fake.ReadC2S2Stub + fakeReturns := fake.readC2S2Returns + fake.recordInvocation("ReadC2S2", []interface{}{arg1}) + fake.readC2S2Mutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeHandshake) ReadC2S2CallCount() int { + fake.readC2S2Mutex.RLock() + defer fake.readC2S2Mutex.RUnlock() + return len(fake.readC2S2ArgsForCall) +} + +func (fake *FakeHandshake) ReadC2S2Calls(stub func(io.Reader) ([]byte, error)) { + fake.readC2S2Mutex.Lock() + defer fake.readC2S2Mutex.Unlock() + fake.ReadC2S2Stub = stub +} + +func (fake *FakeHandshake) ReadC2S2ArgsForCall(i int) io.Reader { + fake.readC2S2Mutex.RLock() + defer fake.readC2S2Mutex.RUnlock() + argsForCall := fake.readC2S2ArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeHandshake) ReadC2S2Returns(result1 []byte, result2 error) { + fake.readC2S2Mutex.Lock() + defer fake.readC2S2Mutex.Unlock() + fake.ReadC2S2Stub = nil + fake.readC2S2Returns = struct { + result1 []byte + result2 error + }{result1, result2} +} + +func (fake *FakeHandshake) ReadC2S2ReturnsOnCall(i int, result1 []byte, result2 error) { + fake.readC2S2Mutex.Lock() + defer fake.readC2S2Mutex.Unlock() + fake.ReadC2S2Stub = nil + if fake.readC2S2ReturnsOnCall == nil { + fake.readC2S2ReturnsOnCall = make(map[int]struct { + result1 []byte + result2 error + }) + } + fake.readC2S2ReturnsOnCall[i] = struct { + result1 []byte + result2 error + }{result1, result2} +} + +func (fake *FakeHandshake) WriteC0S0(arg1 io.Writer) error { + fake.writeC0S0Mutex.Lock() + ret, specificReturn := fake.writeC0S0ReturnsOnCall[len(fake.writeC0S0ArgsForCall)] + fake.writeC0S0ArgsForCall = append(fake.writeC0S0ArgsForCall, struct { + arg1 io.Writer + }{arg1}) + stub := fake.WriteC0S0Stub + fakeReturns := fake.writeC0S0Returns + fake.recordInvocation("WriteC0S0", []interface{}{arg1}) + fake.writeC0S0Mutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeHandshake) WriteC0S0CallCount() int { + fake.writeC0S0Mutex.RLock() + defer fake.writeC0S0Mutex.RUnlock() + return len(fake.writeC0S0ArgsForCall) +} + +func (fake *FakeHandshake) WriteC0S0Calls(stub func(io.Writer) error) { + fake.writeC0S0Mutex.Lock() + defer fake.writeC0S0Mutex.Unlock() + fake.WriteC0S0Stub = stub +} + +func (fake *FakeHandshake) WriteC0S0ArgsForCall(i int) io.Writer { + fake.writeC0S0Mutex.RLock() + defer fake.writeC0S0Mutex.RUnlock() + argsForCall := fake.writeC0S0ArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeHandshake) WriteC0S0Returns(result1 error) { + fake.writeC0S0Mutex.Lock() + defer fake.writeC0S0Mutex.Unlock() + fake.WriteC0S0Stub = nil + fake.writeC0S0Returns = struct { + result1 error + }{result1} +} + +func (fake *FakeHandshake) WriteC0S0ReturnsOnCall(i int, result1 error) { + fake.writeC0S0Mutex.Lock() + defer fake.writeC0S0Mutex.Unlock() + fake.WriteC0S0Stub = nil + if fake.writeC0S0ReturnsOnCall == nil { + fake.writeC0S0ReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.writeC0S0ReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeHandshake) WriteC1S1(arg1 io.Writer) error { + fake.writeC1S1Mutex.Lock() + ret, specificReturn := fake.writeC1S1ReturnsOnCall[len(fake.writeC1S1ArgsForCall)] + fake.writeC1S1ArgsForCall = append(fake.writeC1S1ArgsForCall, struct { + arg1 io.Writer + }{arg1}) + stub := fake.WriteC1S1Stub + fakeReturns := fake.writeC1S1Returns + fake.recordInvocation("WriteC1S1", []interface{}{arg1}) + fake.writeC1S1Mutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeHandshake) WriteC1S1CallCount() int { + fake.writeC1S1Mutex.RLock() + defer fake.writeC1S1Mutex.RUnlock() + return len(fake.writeC1S1ArgsForCall) +} + +func (fake *FakeHandshake) WriteC1S1Calls(stub func(io.Writer) error) { + fake.writeC1S1Mutex.Lock() + defer fake.writeC1S1Mutex.Unlock() + fake.WriteC1S1Stub = stub +} + +func (fake *FakeHandshake) WriteC1S1ArgsForCall(i int) io.Writer { + fake.writeC1S1Mutex.RLock() + defer fake.writeC1S1Mutex.RUnlock() + argsForCall := fake.writeC1S1ArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeHandshake) WriteC1S1Returns(result1 error) { + fake.writeC1S1Mutex.Lock() + defer fake.writeC1S1Mutex.Unlock() + fake.WriteC1S1Stub = nil + fake.writeC1S1Returns = struct { + result1 error + }{result1} +} + +func (fake *FakeHandshake) WriteC1S1ReturnsOnCall(i int, result1 error) { + fake.writeC1S1Mutex.Lock() + defer fake.writeC1S1Mutex.Unlock() + fake.WriteC1S1Stub = nil + if fake.writeC1S1ReturnsOnCall == nil { + fake.writeC1S1ReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.writeC1S1ReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeHandshake) WriteC2S2(arg1 io.Writer, arg2 []byte) error { + var arg2Copy []byte + if arg2 != nil { + arg2Copy = make([]byte, len(arg2)) + copy(arg2Copy, arg2) + } + fake.writeC2S2Mutex.Lock() + ret, specificReturn := fake.writeC2S2ReturnsOnCall[len(fake.writeC2S2ArgsForCall)] + fake.writeC2S2ArgsForCall = append(fake.writeC2S2ArgsForCall, struct { + arg1 io.Writer + arg2 []byte + }{arg1, arg2Copy}) + stub := fake.WriteC2S2Stub + fakeReturns := fake.writeC2S2Returns + fake.recordInvocation("WriteC2S2", []interface{}{arg1, arg2Copy}) + fake.writeC2S2Mutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeHandshake) WriteC2S2CallCount() int { + fake.writeC2S2Mutex.RLock() + defer fake.writeC2S2Mutex.RUnlock() + return len(fake.writeC2S2ArgsForCall) +} + +func (fake *FakeHandshake) WriteC2S2Calls(stub func(io.Writer, []byte) error) { + fake.writeC2S2Mutex.Lock() + defer fake.writeC2S2Mutex.Unlock() + fake.WriteC2S2Stub = stub +} + +func (fake *FakeHandshake) WriteC2S2ArgsForCall(i int) (io.Writer, []byte) { + fake.writeC2S2Mutex.RLock() + defer fake.writeC2S2Mutex.RUnlock() + argsForCall := fake.writeC2S2ArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeHandshake) WriteC2S2Returns(result1 error) { + fake.writeC2S2Mutex.Lock() + defer fake.writeC2S2Mutex.Unlock() + fake.WriteC2S2Stub = nil + fake.writeC2S2Returns = struct { + result1 error + }{result1} +} + +func (fake *FakeHandshake) WriteC2S2ReturnsOnCall(i int, result1 error) { + fake.writeC2S2Mutex.Lock() + defer fake.writeC2S2Mutex.Unlock() + fake.WriteC2S2Stub = nil + if fake.writeC2S2ReturnsOnCall == nil { + fake.writeC2S2ReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.writeC2S2ReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeHandshake) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeHandshake) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ rtmp.Handshake = new(FakeHandshake) diff --git a/internal/rtmp/rtmpfakes/fake_protocol.go b/internal/rtmp/rtmpfakes/fake_protocol.go new file mode 100644 index 000000000..abbb5d3d3 --- /dev/null +++ b/internal/rtmp/rtmpfakes/fake_protocol.go @@ -0,0 +1,499 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package rtmpfakes + +import ( + "context" + "srsx/internal/rtmp" + "sync" +) + +type FakeProtocol struct { + DecodeMessageStub func(rtmp.Message) (rtmp.Packet, error) + decodeMessageMutex sync.RWMutex + decodeMessageArgsForCall []struct { + arg1 rtmp.Message + } + decodeMessageReturns struct { + result1 rtmp.Packet + result2 error + } + decodeMessageReturnsOnCall map[int]struct { + result1 rtmp.Packet + result2 error + } + ExpectMessageStub func(context.Context, ...rtmp.MessageType) (rtmp.Message, error) + expectMessageMutex sync.RWMutex + expectMessageArgsForCall []struct { + arg1 context.Context + arg2 []rtmp.MessageType + } + expectMessageReturns struct { + result1 rtmp.Message + result2 error + } + expectMessageReturnsOnCall map[int]struct { + result1 rtmp.Message + result2 error + } + ExpectPacketStub func(context.Context, any) (rtmp.Message, error) + expectPacketMutex sync.RWMutex + expectPacketArgsForCall []struct { + arg1 context.Context + arg2 any + } + expectPacketReturns struct { + result1 rtmp.Message + result2 error + } + expectPacketReturnsOnCall map[int]struct { + result1 rtmp.Message + result2 error + } + ReadMessageStub func(context.Context) (rtmp.Message, error) + readMessageMutex sync.RWMutex + readMessageArgsForCall []struct { + arg1 context.Context + } + readMessageReturns struct { + result1 rtmp.Message + result2 error + } + readMessageReturnsOnCall map[int]struct { + result1 rtmp.Message + result2 error + } + WriteMessageStub func(context.Context, rtmp.Message) error + writeMessageMutex sync.RWMutex + writeMessageArgsForCall []struct { + arg1 context.Context + arg2 rtmp.Message + } + writeMessageReturns struct { + result1 error + } + writeMessageReturnsOnCall map[int]struct { + result1 error + } + WritePacketStub func(context.Context, rtmp.Packet, int) error + writePacketMutex sync.RWMutex + writePacketArgsForCall []struct { + arg1 context.Context + arg2 rtmp.Packet + arg3 int + } + writePacketReturns struct { + result1 error + } + writePacketReturnsOnCall map[int]struct { + result1 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeProtocol) DecodeMessage(arg1 rtmp.Message) (rtmp.Packet, error) { + fake.decodeMessageMutex.Lock() + ret, specificReturn := fake.decodeMessageReturnsOnCall[len(fake.decodeMessageArgsForCall)] + fake.decodeMessageArgsForCall = append(fake.decodeMessageArgsForCall, struct { + arg1 rtmp.Message + }{arg1}) + stub := fake.DecodeMessageStub + fakeReturns := fake.decodeMessageReturns + fake.recordInvocation("DecodeMessage", []interface{}{arg1}) + fake.decodeMessageMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeProtocol) DecodeMessageCallCount() int { + fake.decodeMessageMutex.RLock() + defer fake.decodeMessageMutex.RUnlock() + return len(fake.decodeMessageArgsForCall) +} + +func (fake *FakeProtocol) DecodeMessageCalls(stub func(rtmp.Message) (rtmp.Packet, error)) { + fake.decodeMessageMutex.Lock() + defer fake.decodeMessageMutex.Unlock() + fake.DecodeMessageStub = stub +} + +func (fake *FakeProtocol) DecodeMessageArgsForCall(i int) rtmp.Message { + fake.decodeMessageMutex.RLock() + defer fake.decodeMessageMutex.RUnlock() + argsForCall := fake.decodeMessageArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeProtocol) DecodeMessageReturns(result1 rtmp.Packet, result2 error) { + fake.decodeMessageMutex.Lock() + defer fake.decodeMessageMutex.Unlock() + fake.DecodeMessageStub = nil + fake.decodeMessageReturns = struct { + result1 rtmp.Packet + result2 error + }{result1, result2} +} + +func (fake *FakeProtocol) DecodeMessageReturnsOnCall(i int, result1 rtmp.Packet, result2 error) { + fake.decodeMessageMutex.Lock() + defer fake.decodeMessageMutex.Unlock() + fake.DecodeMessageStub = nil + if fake.decodeMessageReturnsOnCall == nil { + fake.decodeMessageReturnsOnCall = make(map[int]struct { + result1 rtmp.Packet + result2 error + }) + } + fake.decodeMessageReturnsOnCall[i] = struct { + result1 rtmp.Packet + result2 error + }{result1, result2} +} + +func (fake *FakeProtocol) ExpectMessage(arg1 context.Context, arg2 ...rtmp.MessageType) (rtmp.Message, error) { + fake.expectMessageMutex.Lock() + ret, specificReturn := fake.expectMessageReturnsOnCall[len(fake.expectMessageArgsForCall)] + fake.expectMessageArgsForCall = append(fake.expectMessageArgsForCall, struct { + arg1 context.Context + arg2 []rtmp.MessageType + }{arg1, arg2}) + stub := fake.ExpectMessageStub + fakeReturns := fake.expectMessageReturns + fake.recordInvocation("ExpectMessage", []interface{}{arg1, arg2}) + fake.expectMessageMutex.Unlock() + if stub != nil { + return stub(arg1, arg2...) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeProtocol) ExpectMessageCallCount() int { + fake.expectMessageMutex.RLock() + defer fake.expectMessageMutex.RUnlock() + return len(fake.expectMessageArgsForCall) +} + +func (fake *FakeProtocol) ExpectMessageCalls(stub func(context.Context, ...rtmp.MessageType) (rtmp.Message, error)) { + fake.expectMessageMutex.Lock() + defer fake.expectMessageMutex.Unlock() + fake.ExpectMessageStub = stub +} + +func (fake *FakeProtocol) ExpectMessageArgsForCall(i int) (context.Context, []rtmp.MessageType) { + fake.expectMessageMutex.RLock() + defer fake.expectMessageMutex.RUnlock() + argsForCall := fake.expectMessageArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeProtocol) ExpectMessageReturns(result1 rtmp.Message, result2 error) { + fake.expectMessageMutex.Lock() + defer fake.expectMessageMutex.Unlock() + fake.ExpectMessageStub = nil + fake.expectMessageReturns = struct { + result1 rtmp.Message + result2 error + }{result1, result2} +} + +func (fake *FakeProtocol) ExpectMessageReturnsOnCall(i int, result1 rtmp.Message, result2 error) { + fake.expectMessageMutex.Lock() + defer fake.expectMessageMutex.Unlock() + fake.ExpectMessageStub = nil + if fake.expectMessageReturnsOnCall == nil { + fake.expectMessageReturnsOnCall = make(map[int]struct { + result1 rtmp.Message + result2 error + }) + } + fake.expectMessageReturnsOnCall[i] = struct { + result1 rtmp.Message + result2 error + }{result1, result2} +} + +func (fake *FakeProtocol) ExpectPacket(arg1 context.Context, arg2 any) (rtmp.Message, error) { + fake.expectPacketMutex.Lock() + ret, specificReturn := fake.expectPacketReturnsOnCall[len(fake.expectPacketArgsForCall)] + fake.expectPacketArgsForCall = append(fake.expectPacketArgsForCall, struct { + arg1 context.Context + arg2 any + }{arg1, arg2}) + stub := fake.ExpectPacketStub + fakeReturns := fake.expectPacketReturns + fake.recordInvocation("ExpectPacket", []interface{}{arg1, arg2}) + fake.expectPacketMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeProtocol) ExpectPacketCallCount() int { + fake.expectPacketMutex.RLock() + defer fake.expectPacketMutex.RUnlock() + return len(fake.expectPacketArgsForCall) +} + +func (fake *FakeProtocol) ExpectPacketCalls(stub func(context.Context, any) (rtmp.Message, error)) { + fake.expectPacketMutex.Lock() + defer fake.expectPacketMutex.Unlock() + fake.ExpectPacketStub = stub +} + +func (fake *FakeProtocol) ExpectPacketArgsForCall(i int) (context.Context, any) { + fake.expectPacketMutex.RLock() + defer fake.expectPacketMutex.RUnlock() + argsForCall := fake.expectPacketArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeProtocol) ExpectPacketReturns(result1 rtmp.Message, result2 error) { + fake.expectPacketMutex.Lock() + defer fake.expectPacketMutex.Unlock() + fake.ExpectPacketStub = nil + fake.expectPacketReturns = struct { + result1 rtmp.Message + result2 error + }{result1, result2} +} + +func (fake *FakeProtocol) ExpectPacketReturnsOnCall(i int, result1 rtmp.Message, result2 error) { + fake.expectPacketMutex.Lock() + defer fake.expectPacketMutex.Unlock() + fake.ExpectPacketStub = nil + if fake.expectPacketReturnsOnCall == nil { + fake.expectPacketReturnsOnCall = make(map[int]struct { + result1 rtmp.Message + result2 error + }) + } + fake.expectPacketReturnsOnCall[i] = struct { + result1 rtmp.Message + result2 error + }{result1, result2} +} + +func (fake *FakeProtocol) ReadMessage(arg1 context.Context) (rtmp.Message, error) { + fake.readMessageMutex.Lock() + ret, specificReturn := fake.readMessageReturnsOnCall[len(fake.readMessageArgsForCall)] + fake.readMessageArgsForCall = append(fake.readMessageArgsForCall, struct { + arg1 context.Context + }{arg1}) + stub := fake.ReadMessageStub + fakeReturns := fake.readMessageReturns + fake.recordInvocation("ReadMessage", []interface{}{arg1}) + fake.readMessageMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeProtocol) ReadMessageCallCount() int { + fake.readMessageMutex.RLock() + defer fake.readMessageMutex.RUnlock() + return len(fake.readMessageArgsForCall) +} + +func (fake *FakeProtocol) ReadMessageCalls(stub func(context.Context) (rtmp.Message, error)) { + fake.readMessageMutex.Lock() + defer fake.readMessageMutex.Unlock() + fake.ReadMessageStub = stub +} + +func (fake *FakeProtocol) ReadMessageArgsForCall(i int) context.Context { + fake.readMessageMutex.RLock() + defer fake.readMessageMutex.RUnlock() + argsForCall := fake.readMessageArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeProtocol) ReadMessageReturns(result1 rtmp.Message, result2 error) { + fake.readMessageMutex.Lock() + defer fake.readMessageMutex.Unlock() + fake.ReadMessageStub = nil + fake.readMessageReturns = struct { + result1 rtmp.Message + result2 error + }{result1, result2} +} + +func (fake *FakeProtocol) ReadMessageReturnsOnCall(i int, result1 rtmp.Message, result2 error) { + fake.readMessageMutex.Lock() + defer fake.readMessageMutex.Unlock() + fake.ReadMessageStub = nil + if fake.readMessageReturnsOnCall == nil { + fake.readMessageReturnsOnCall = make(map[int]struct { + result1 rtmp.Message + result2 error + }) + } + fake.readMessageReturnsOnCall[i] = struct { + result1 rtmp.Message + result2 error + }{result1, result2} +} + +func (fake *FakeProtocol) WriteMessage(arg1 context.Context, arg2 rtmp.Message) error { + fake.writeMessageMutex.Lock() + ret, specificReturn := fake.writeMessageReturnsOnCall[len(fake.writeMessageArgsForCall)] + fake.writeMessageArgsForCall = append(fake.writeMessageArgsForCall, struct { + arg1 context.Context + arg2 rtmp.Message + }{arg1, arg2}) + stub := fake.WriteMessageStub + fakeReturns := fake.writeMessageReturns + fake.recordInvocation("WriteMessage", []interface{}{arg1, arg2}) + fake.writeMessageMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeProtocol) WriteMessageCallCount() int { + fake.writeMessageMutex.RLock() + defer fake.writeMessageMutex.RUnlock() + return len(fake.writeMessageArgsForCall) +} + +func (fake *FakeProtocol) WriteMessageCalls(stub func(context.Context, rtmp.Message) error) { + fake.writeMessageMutex.Lock() + defer fake.writeMessageMutex.Unlock() + fake.WriteMessageStub = stub +} + +func (fake *FakeProtocol) WriteMessageArgsForCall(i int) (context.Context, rtmp.Message) { + fake.writeMessageMutex.RLock() + defer fake.writeMessageMutex.RUnlock() + argsForCall := fake.writeMessageArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeProtocol) WriteMessageReturns(result1 error) { + fake.writeMessageMutex.Lock() + defer fake.writeMessageMutex.Unlock() + fake.WriteMessageStub = nil + fake.writeMessageReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeProtocol) WriteMessageReturnsOnCall(i int, result1 error) { + fake.writeMessageMutex.Lock() + defer fake.writeMessageMutex.Unlock() + fake.WriteMessageStub = nil + if fake.writeMessageReturnsOnCall == nil { + fake.writeMessageReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.writeMessageReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeProtocol) WritePacket(arg1 context.Context, arg2 rtmp.Packet, arg3 int) error { + fake.writePacketMutex.Lock() + ret, specificReturn := fake.writePacketReturnsOnCall[len(fake.writePacketArgsForCall)] + fake.writePacketArgsForCall = append(fake.writePacketArgsForCall, struct { + arg1 context.Context + arg2 rtmp.Packet + arg3 int + }{arg1, arg2, arg3}) + stub := fake.WritePacketStub + fakeReturns := fake.writePacketReturns + fake.recordInvocation("WritePacket", []interface{}{arg1, arg2, arg3}) + fake.writePacketMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeProtocol) WritePacketCallCount() int { + fake.writePacketMutex.RLock() + defer fake.writePacketMutex.RUnlock() + return len(fake.writePacketArgsForCall) +} + +func (fake *FakeProtocol) WritePacketCalls(stub func(context.Context, rtmp.Packet, int) error) { + fake.writePacketMutex.Lock() + defer fake.writePacketMutex.Unlock() + fake.WritePacketStub = stub +} + +func (fake *FakeProtocol) WritePacketArgsForCall(i int) (context.Context, rtmp.Packet, int) { + fake.writePacketMutex.RLock() + defer fake.writePacketMutex.RUnlock() + argsForCall := fake.writePacketArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeProtocol) WritePacketReturns(result1 error) { + fake.writePacketMutex.Lock() + defer fake.writePacketMutex.Unlock() + fake.WritePacketStub = nil + fake.writePacketReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeProtocol) WritePacketReturnsOnCall(i int, result1 error) { + fake.writePacketMutex.Lock() + defer fake.writePacketMutex.Unlock() + fake.WritePacketStub = nil + if fake.writePacketReturnsOnCall == nil { + fake.writePacketReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.writePacketReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeProtocol) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeProtocol) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ rtmp.Protocol = new(FakeProtocol) diff --git a/internal/protocol/api.go b/internal/server/api.go similarity index 87% rename from internal/protocol/api.go rename to internal/server/api.go index d1f3ef26d..a69353ee5 100644 --- a/internal/protocol/api.go +++ b/internal/server/api.go @@ -1,7 +1,7 @@ // Copyright (c) 2026 Winlin // // SPDX-License-Identifier: MIT -package protocol +package server import ( "context" @@ -20,23 +20,28 @@ import ( "srsx/internal/version" ) -// srsHTTPAPIServer is the proxy for SRS HTTP API, to proxy the WebRTC HTTP API like WHIP and WHEP, +// HTTPAPIServer is the proxy for SRS HTTP API, to proxy the WebRTC HTTP API like WHIP and WHEP, // to proxy other HTTP API of SRS like the streams and clients, etc. -type srsHTTPAPIServer struct { +type HTTPAPIServer interface { + Run(ctx context.Context) error + Close() error +} + +type httpAPIServer struct { // The environment interface. environment env.ProxyEnvironment // The underlayer HTTP server. server *http.Server // The WebRTC server. - rtc *srsWebRTCServer + rtc WebRTCServer // The gracefully quit timeout, wait server to quit. gracefulQuitTimeout time.Duration // The wait group for all goroutines. wg sync.WaitGroup } -func NewSRSHTTPAPIServer(environment env.ProxyEnvironment, gracefulQuitTimeout time.Duration, rtc *srsWebRTCServer) *srsHTTPAPIServer { - v := &srsHTTPAPIServer{ +func NewHTTPAPIServer(environment env.ProxyEnvironment, gracefulQuitTimeout time.Duration, rtc WebRTCServer) HTTPAPIServer { + v := &httpAPIServer{ environment: environment, gracefulQuitTimeout: gracefulQuitTimeout, rtc: rtc, @@ -44,7 +49,7 @@ func NewSRSHTTPAPIServer(environment env.ProxyEnvironment, gracefulQuitTimeout t return v } -func (v *srsHTTPAPIServer) Close() error { +func (v *httpAPIServer) Close() error { ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout) defer cancel() v.server.Shutdown(ctx) @@ -53,7 +58,7 @@ func (v *srsHTTPAPIServer) Close() error { return nil } -func (v *srsHTTPAPIServer) Run(ctx context.Context) error { +func (v *httpAPIServer) Run(ctx context.Context) error { // Parse address to listen. addr := v.environment.HttpAPI() if !strings.Contains(addr, ":") { @@ -92,6 +97,13 @@ func (v *srsHTTPAPIServer) Run(ctx context.Context) error { utils.ApiError(ctx, w, r, err) } }) + // Keep compatibility with the legacy SRS WebRTC publish API used by srs-bench. + logger.Debug(ctx, "Handle /rtc/v1/publish/ by %v", addr) + mux.HandleFunc("/rtc/v1/publish/", func(w http.ResponseWriter, r *http.Request) { + if err := v.rtc.HandleApiForWHIP(ctx, w, r); err != nil { + utils.ApiError(ctx, w, r, err) + } + }) // The WebRTC WHEP API handler. logger.Debug(ctx, "Handle /rtc/v1/whep/ by %v", addr) @@ -100,6 +112,13 @@ func (v *srsHTTPAPIServer) Run(ctx context.Context) error { utils.ApiError(ctx, w, r, err) } }) + // Keep compatibility with the legacy SRS WebRTC play API used by srs-bench. + logger.Debug(ctx, "Handle /rtc/v1/play/ by %v", addr) + mux.HandleFunc("/rtc/v1/play/", func(w http.ResponseWriter, r *http.Request) { + if err := v.rtc.HandleApiForWHEP(ctx, w, r); err != nil { + utils.ApiError(ctx, w, r, err) + } + }) // Run HTTP API server. v.wg.Add(1) diff --git a/internal/protocol/http.go b/internal/server/http.go similarity index 87% rename from internal/protocol/http.go rename to internal/server/http.go index a145c551e..21db47741 100644 --- a/internal/protocol/http.go +++ b/internal/server/http.go @@ -1,7 +1,7 @@ // Copyright (c) 2026 Winlin // // SPDX-License-Identifier: MIT -package protocol +package server import ( "context" @@ -23,10 +23,15 @@ import ( "srsx/internal/version" ) -// srsHTTPStreamServer is the proxy server for SRS HTTP stream server, for HTTP-FLV, HTTP-TS, +// HTTPStreamServer is the proxy server for SRS HTTP stream server, for HTTP-FLV, HTTP-TS, // HLS, etc. The proxy server will figure out which SRS origin server to proxy to, then proxy // the request to the origin server. -type srsHTTPStreamServer struct { +type HTTPStreamServer interface { + Run(ctx context.Context) error + Close() error +} + +type httpStreamServer struct { // The environment interface. environment env.ProxyEnvironment // The underlayer HTTP server. @@ -37,15 +42,15 @@ type srsHTTPStreamServer struct { wg stdSync.WaitGroup } -func NewSRSHTTPStreamServer(environment env.ProxyEnvironment, gracefulQuitTimeout time.Duration) *srsHTTPStreamServer { - v := &srsHTTPStreamServer{ +func NewHTTPStreamServer(environment env.ProxyEnvironment, gracefulQuitTimeout time.Duration) HTTPStreamServer { + v := &httpStreamServer{ environment: environment, gracefulQuitTimeout: gracefulQuitTimeout, } return v } -func (v *srsHTTPStreamServer) Close() error { +func (v *httpStreamServer) Close() error { ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout) defer cancel() v.server.Shutdown(ctx) @@ -54,7 +59,7 @@ func (v *srsHTTPStreamServer) Close() error { return nil } -func (v *srsHTTPStreamServer) Run(ctx context.Context) error { +func (v *httpStreamServer) Run(ctx context.Context) error { // Parse address to listen. addr := v.environment.HttpServer() if !strings.Contains(addr, ":") { @@ -123,12 +128,12 @@ func (v *srsHTTPStreamServer) Run(ctx context.Context) error { return } - stream, _ := lb.SrsLoadBalancer.LoadOrStoreHLS(ctx, streamURL, NewHLSPlayStream(func(s *HLSPlayStream) { + stream, _ := lb.SrsLoadBalancer.LoadOrStoreHLS(ctx, streamURL, newHLSPlayStream(func(s *hlsPlayStream) { s.SRSProxyBackendHLSID = logger.GenerateContextID() s.StreamURL, s.FullURL = streamURL, fullURL })) - stream.Initialize(ctx).(*HLSPlayStream).ServeHTTP(w, r) + stream.Initialize(ctx).(*hlsPlayStream).ServeHTTP(w, r) return } @@ -140,13 +145,13 @@ func (v *srsHTTPStreamServer) Run(ctx context.Context) error { if stream, err := lb.SrsLoadBalancer.LoadHLSBySPBHID(ctx, srsProxyBackendID); err != nil { http.Error(w, fmt.Sprintf("load stream by spbhid %v", srsProxyBackendID), http.StatusBadRequest) } else { - stream.Initialize(ctx).(*HLSPlayStream).ServeHTTP(w, r) + stream.Initialize(ctx).(*hlsPlayStream).ServeHTTP(w, r) } return } // Use HTTP pseudo streaming to proxy the request. - NewHTTPFlvTsConnection(func(c *HTTPFlvTsConnection) { + newHTTPFlvTsConnection(func(c *httpFlvTsConnection) { c.ctx = ctx }).ServeHTTP(w, r) return @@ -182,26 +187,26 @@ func (v *srsHTTPStreamServer) Run(ctx context.Context) error { return nil } -// HTTPFlvTsConnection is an HTTP pseudo streaming connection, such as an HTTP-FLV or HTTP-TS +// httpFlvTsConnection is an HTTP pseudo streaming connection, such as an HTTP-FLV or HTTP-TS // connection. There is no state need to be sync between proxy servers. // // When we got an HTTP FLV or TS request, we will parse the stream URL from the HTTP request, // then proxy to the corresponding backend server. All state is in the HTTP request, so this // connection is stateless. -type HTTPFlvTsConnection struct { +type httpFlvTsConnection struct { // The context for HTTP streaming. ctx context.Context } -func NewHTTPFlvTsConnection(opts ...func(*HTTPFlvTsConnection)) *HTTPFlvTsConnection { - v := &HTTPFlvTsConnection{} +func newHTTPFlvTsConnection(opts ...func(*httpFlvTsConnection)) *httpFlvTsConnection { + v := &httpFlvTsConnection{} for _, opt := range opts { opt(v) } return v } -func (v *HTTPFlvTsConnection) ServeHTTP(w http.ResponseWriter, r *http.Request) { +func (v *httpFlvTsConnection) ServeHTTP(w http.ResponseWriter, r *http.Request) { defer r.Body.Close() ctx := logger.WithContext(v.ctx) @@ -212,7 +217,7 @@ func (v *HTTPFlvTsConnection) ServeHTTP(w http.ResponseWriter, r *http.Request) } } -func (v *HTTPFlvTsConnection) serve(ctx context.Context, w http.ResponseWriter, r *http.Request) error { +func (v *httpFlvTsConnection) serve(ctx context.Context, w http.ResponseWriter, r *http.Request) error { // Always allow CORS for all requests. if ok := utils.ApiCORS(ctx, w, r); ok { return nil @@ -240,7 +245,7 @@ func (v *HTTPFlvTsConnection) serve(ctx context.Context, w http.ResponseWriter, return nil } -func (v *HTTPFlvTsConnection) serveByBackend(ctx context.Context, w http.ResponseWriter, r *http.Request, backend *lb.SRSServer) error { +func (v *httpFlvTsConnection) serveByBackend(ctx context.Context, w http.ResponseWriter, r *http.Request, backend *lb.SRSServer) error { // Parse HTTP port from backend. if len(backend.HTTP) == 0 { return errors.Errorf("no http stream server") @@ -288,14 +293,14 @@ func (v *HTTPFlvTsConnection) serveByBackend(ctx context.Context, w http.Respons return nil } -// HLSPlayStream is an HLS stream proxy, which represents the stream level object. This means multiple HLS +// hlsPlayStream is an HLS stream proxy, which represents the stream level object. This means multiple HLS // clients will share this object, and they do not use the same ctx among proxy servers. // // Unlike the HTTP FLV or TS connection, HLS client may request the m3u8 or ts via different HTTP connections. // Especially for requesting ts, we need to identify the stream URl or backend server for it. So we create // the spbhid which can be seen as the hash of stream URL or backend server. The spbhid enable us to convert // to the stream URL and then query the backend server to serve it. -type HLSPlayStream struct { +type hlsPlayStream struct { // The context for HLS streaming. ctx context.Context @@ -307,26 +312,26 @@ type HLSPlayStream struct { FullURL string `json:"full_url"` } -func NewHLSPlayStream(opts ...func(*HLSPlayStream)) *HLSPlayStream { - v := &HLSPlayStream{} +func newHLSPlayStream(opts ...func(*hlsPlayStream)) *hlsPlayStream { + v := &hlsPlayStream{} for _, opt := range opts { opt(v) } return v } -func (v *HLSPlayStream) Initialize(ctx context.Context) lb.HLSPlayStream { +func (v *hlsPlayStream) Initialize(ctx context.Context) lb.HLSPlayStream { if v.ctx == nil { v.ctx = logger.WithContext(ctx) } return v } -func (v *HLSPlayStream) GetSPBHID() string { +func (v *hlsPlayStream) GetSPBHID() string { return v.SRSProxyBackendHLSID } -func (v *HLSPlayStream) ServeHTTP(w http.ResponseWriter, r *http.Request) { +func (v *hlsPlayStream) ServeHTTP(w http.ResponseWriter, r *http.Request) { defer r.Body.Close() if err := v.serve(v.ctx, w, r); err != nil { @@ -337,7 +342,7 @@ func (v *HLSPlayStream) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } -func (v *HLSPlayStream) serve(ctx context.Context, w http.ResponseWriter, r *http.Request) error { +func (v *hlsPlayStream) serve(ctx context.Context, w http.ResponseWriter, r *http.Request) error { ctx, streamURL, fullURL := v.ctx, v.StreamURL, v.FullURL // Always allow CORS for all requests. @@ -358,7 +363,7 @@ func (v *HLSPlayStream) serve(ctx context.Context, w http.ResponseWriter, r *htt return nil } -func (v *HLSPlayStream) serveByBackend(ctx context.Context, w http.ResponseWriter, r *http.Request, backend *lb.SRSServer) error { +func (v *hlsPlayStream) serveByBackend(ctx context.Context, w http.ResponseWriter, r *http.Request, backend *lb.SRSServer) error { // Parse HTTP port from backend. if len(backend.HTTP) == 0 { return errors.Errorf("no rtmp server") diff --git a/internal/protocol/rtc.go b/internal/server/rtc.go similarity index 86% rename from internal/protocol/rtc.go rename to internal/server/rtc.go index 51792f9ca..7a85e0bbb 100644 --- a/internal/protocol/rtc.go +++ b/internal/server/rtc.go @@ -1,7 +1,7 @@ // Copyright (c) 2026 Winlin // // SPDX-License-Identifier: MIT -package protocol +package server import ( "context" @@ -23,10 +23,17 @@ import ( "srsx/internal/utils" ) -// srsWebRTCServer is the proxy for SRS WebRTC server via WHIP or WHEP protocol. It will figure out +// WebRTCServer is the proxy for SRS WebRTC server via WHIP or WHEP protocol. It will figure out // which backend server to proxy to. It will also replace the UDP port to the proxy server's in the // SDP answer. -type srsWebRTCServer struct { +type WebRTCServer interface { + Run(ctx context.Context) error + Close() error + HandleApiForWHIP(ctx context.Context, w http.ResponseWriter, r *http.Request) error + HandleApiForWHEP(ctx context.Context, w http.ResponseWriter, r *http.Request) error +} + +type webRTCServer struct { // The environment interface. environment env.ProxyEnvironment // The UDP listener for WebRTC server. @@ -34,21 +41,21 @@ type srsWebRTCServer struct { // Fast cache for the username to identify the connection. // The key is username, the value is the UDP address. - usernames sync.Map[string, *RTCConnection] + usernames sync.Map[string, *rtcConnection] // Fast cache for the udp address to identify the connection. // The key is UDP address, the value is the username. // TODO: Support fast earch by uint64 address. - addresses sync.Map[string, *RTCConnection] + addresses sync.Map[string, *rtcConnection] // The wait group for server. wg stdSync.WaitGroup } -func NewSRSWebRTCServer(environment env.ProxyEnvironment, opts ...func(*srsWebRTCServer)) *srsWebRTCServer { - v := &srsWebRTCServer{ +func NewWebRTCServer(environment env.ProxyEnvironment, opts ...func(*webRTCServer)) WebRTCServer { + v := &webRTCServer{ environment: environment, - usernames: sync.NewMap[string, *RTCConnection](), - addresses: sync.NewMap[string, *RTCConnection](), + usernames: sync.NewMap[string, *rtcConnection](), + addresses: sync.NewMap[string, *rtcConnection](), } for _, opt := range opts { opt(v) @@ -56,7 +63,7 @@ func NewSRSWebRTCServer(environment env.ProxyEnvironment, opts ...func(*srsWebRT return v } -func (v *srsWebRTCServer) Close() error { +func (v *webRTCServer) Close() error { if v.listener != nil { _ = v.listener.Close() } @@ -65,7 +72,7 @@ func (v *srsWebRTCServer) Close() error { return nil } -func (v *srsWebRTCServer) HandleApiForWHIP(ctx context.Context, w http.ResponseWriter, r *http.Request) error { +func (v *webRTCServer) HandleApiForWHIP(ctx context.Context, w http.ResponseWriter, r *http.Request) error { defer r.Body.Close() ctx = logger.WithContext(ctx) @@ -102,7 +109,7 @@ func (v *srsWebRTCServer) HandleApiForWHIP(ctx context.Context, w http.ResponseW return nil } -func (v *srsWebRTCServer) HandleApiForWHEP(ctx context.Context, w http.ResponseWriter, r *http.Request) error { +func (v *webRTCServer) HandleApiForWHEP(ctx context.Context, w http.ResponseWriter, r *http.Request) error { defer r.Body.Close() ctx = logger.WithContext(ctx) @@ -139,7 +146,7 @@ func (v *srsWebRTCServer) HandleApiForWHEP(ctx context.Context, w http.ResponseW return nil } -func (v *srsWebRTCServer) proxyApiToBackend( +func (v *webRTCServer) proxyApiToBackend( ctx context.Context, w http.ResponseWriter, r *http.Request, backend *lb.SRSServer, remoteSDPOffer string, streamURL string, ) error { @@ -215,11 +222,11 @@ func (v *srsWebRTCServer) proxyApiToBackend( } // Save the new WebRTC connection to LB. - icePair := &RTCICEPair{ + icePair := &rtcICEPair{ RemoteICEUfrag: remoteICEUfrag, RemoteICEPwd: remoteICEPwd, LocalICEUfrag: localICEUfrag, LocalICEPwd: localICEPwd, } - if err := lb.SrsLoadBalancer.StoreWebRTC(ctx, streamURL, NewRTCConnection(func(c *RTCConnection) { + if err := lb.SrsLoadBalancer.StoreWebRTC(ctx, streamURL, newRTCConnection(func(c *rtcConnection) { c.StreamURL, c.Ufrag = streamURL, icePair.Ufrag() c.Initialize(ctx, v.listener) @@ -239,7 +246,7 @@ func (v *srsWebRTCServer) proxyApiToBackend( return nil } -func (v *srsWebRTCServer) Run(ctx context.Context) error { +func (v *webRTCServer) Run(ctx context.Context) error { // Parse address to listen. endpoint := v.environment.WebRTCServer() if !strings.Contains(endpoint, ":") { @@ -287,8 +294,8 @@ func (v *srsWebRTCServer) Run(ctx context.Context) error { return nil } -func (v *srsWebRTCServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr, data []byte) error { - var connection *RTCConnection +func (v *webRTCServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr, data []byte) error { + var connection *rtcConnection // If STUN binding request, parse the ufrag and identify the connection. if err := func() error { @@ -296,7 +303,7 @@ func (v *srsWebRTCServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr return nil } - var pkt RTCStunPacket + var pkt rtcStunPacket if err := pkt.UnmarshalBinary(data); err != nil { return errors.Wrapf(err, "unmarshal stun packet") } @@ -311,7 +318,7 @@ func (v *srsWebRTCServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr if s, err := lb.SrsLoadBalancer.LoadWebRTCByUfrag(ctx, pkt.Username); err != nil { return errors.Wrapf(err, "load webrtc by ufrag %v", pkt.Username) } else { - connection = s.(*RTCConnection).Initialize(ctx, v.listener) + connection = s.(*rtcConnection).Initialize(ctx, v.listener) logger.Debug(ctx, "Create WebRTC connection by ufrag=%v, stream=%v", pkt.Username, connection.StreamURL) } @@ -346,17 +353,17 @@ func (v *srsWebRTCServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr return nil } -// RTCConnection is a WebRTC connection proxy, for both WHIP and WHEP. It represents a WebRTC +// rtcConnection is a WebRTC connection proxy, for both WHIP and WHEP. It represents a WebRTC // connection, identify by the ufrag in sdp offer/answer and ICE binding request. // // It's not like RTMP or HTTP FLV/TS proxy connection, which are stateless and all state is -// in the client request. The RTCConnection is stateful, and need to sync the ufrag between +// in the client request. The rtcConnection is stateful, and need to sync the ufrag between // proxy servers. // // The media transport is UDP, which is also a special thing for WebRTC. So if the client switch // to another UDP address, it may connect to another WebRTC proxy, then we should discover the -// RTCConnection by the ufrag from the ICE binding request. -type RTCConnection struct { +// rtcConnection by the ufrag from the ICE binding request. +type rtcConnection struct { // The stream context for WebRTC streaming. ctx context.Context @@ -373,15 +380,15 @@ type RTCConnection struct { listenerUDP *net.UDPConn } -func NewRTCConnection(opts ...func(*RTCConnection)) *RTCConnection { - v := &RTCConnection{} +func newRTCConnection(opts ...func(*rtcConnection)) *rtcConnection { + v := &rtcConnection{} for _, opt := range opts { opt(v) } return v } -func (v *RTCConnection) Initialize(ctx context.Context, listener *net.UDPConn) *RTCConnection { +func (v *rtcConnection) Initialize(ctx context.Context, listener *net.UDPConn) *rtcConnection { if v.ctx == nil { v.ctx = logger.WithContext(ctx) } @@ -391,11 +398,11 @@ func (v *RTCConnection) Initialize(ctx context.Context, listener *net.UDPConn) * return v } -func (v *RTCConnection) GetUfrag() string { +func (v *rtcConnection) GetUfrag() string { return v.Ufrag } -func (v *RTCConnection) HandlePacket(addr *net.UDPAddr, data []byte) error { +func (v *rtcConnection) HandlePacket(addr *net.UDPAddr, data []byte) error { ctx := v.ctx // Update the current UDP address. @@ -437,7 +444,7 @@ func (v *RTCConnection) HandlePacket(addr *net.UDPAddr, data []byte) error { return nil } -func (v *RTCConnection) connectBackend(ctx context.Context) error { +func (v *rtcConnection) connectBackend(ctx context.Context) error { if v.backendUDP != nil { return nil } @@ -470,7 +477,7 @@ func (v *RTCConnection) connectBackend(ctx context.Context) error { return nil } -type RTCICEPair struct { +type rtcICEPair struct { // The remote ufrag, used for ICE username and session id. RemoteICEUfrag string `json:"remote_ufrag"` // The remote pwd, used for ICE password. @@ -482,18 +489,18 @@ type RTCICEPair struct { } // Generate the ICE ufrag for the WebRTC streaming, format is remote-ufrag:local-ufrag. -func (v *RTCICEPair) Ufrag() string { +func (v *rtcICEPair) Ufrag() string { return fmt.Sprintf("%v:%v", v.LocalICEUfrag, v.RemoteICEUfrag) } -type RTCStunPacket struct { +type rtcStunPacket struct { // The stun message type. MessageType uint16 // The stun username, or ufrag. Username string } -func (v *RTCStunPacket) UnmarshalBinary(data []byte) error { +func (v *rtcStunPacket) UnmarshalBinary(data []byte) error { if len(data) < 20 { return errors.Errorf("stun packet too short %v", len(data)) } diff --git a/internal/protocol/rtmp.go b/internal/server/rtmp.go similarity index 90% rename from internal/protocol/rtmp.go rename to internal/server/rtmp.go index d5c554b7f..b787e99c8 100644 --- a/internal/protocol/rtmp.go +++ b/internal/server/rtmp.go @@ -1,7 +1,7 @@ // Copyright (c) 2026 Winlin // // SPDX-License-Identifier: MIT -package protocol +package server import ( "context" @@ -20,10 +20,15 @@ import ( "srsx/internal/version" ) -// srsRTMPServer is the proxy for SRS RTMP server, to proxy the RTMP stream to backend SRS +// RTMPServer is the proxy for SRS RTMP server, to proxy the RTMP stream to backend SRS // server. It will figure out the backend server to proxy to. Unlike the edge server, it will // not cache the stream, but just proxy the stream to backend. -type srsRTMPServer struct { +type RTMPServer interface { + Run(ctx context.Context) error + Close() error +} + +type rtmpServer struct { // The environment interface. environment env.ProxyEnvironment // The TCP listener for RTMP server. @@ -32,15 +37,15 @@ type srsRTMPServer struct { wg sync.WaitGroup } -func NewSRSRTMPServer(environment env.ProxyEnvironment, opts ...func(*srsRTMPServer)) *srsRTMPServer { - v := &srsRTMPServer{environment: environment} +func NewRTMPServer(environment env.ProxyEnvironment, opts ...func(*rtmpServer)) RTMPServer { + v := &rtmpServer{environment: environment} for _, opt := range opts { opt(v) } return v } -func (v *srsRTMPServer) Close() error { +func (v *rtmpServer) Close() error { if v.listener != nil { v.listener.Close() } @@ -49,7 +54,7 @@ func (v *srsRTMPServer) Close() error { return nil } -func (v *srsRTMPServer) Run(ctx context.Context) error { +func (v *rtmpServer) Run(ctx context.Context) error { endpoint := v.environment.RtmpServer() if !strings.Contains(endpoint, ":") { endpoint = ":" + endpoint @@ -97,7 +102,7 @@ func (v *srsRTMPServer) Run(ctx context.Context) error { } } - rc := NewRTMPConnection() + rc := newRTMPConnection() if err := rc.serve(ctx, conn); err != nil { handleErr(err) } else { @@ -110,24 +115,24 @@ func (v *srsRTMPServer) Run(ctx context.Context) error { return nil } -// RTMPConnection is an RTMP streaming connection. There is no state need to be sync between +// rtmpConnection is an RTMP streaming connection. There is no state need to be sync between // proxy servers. // // When we got an RTMP request, we will parse the stream URL from the RTMP publish or play request, // then proxy to the corresponding backend server. All state is in the RTMP request, so this // connection is stateless. -type RTMPConnection struct { +type rtmpConnection struct { } -func NewRTMPConnection(opts ...func(*RTMPConnection)) *RTMPConnection { - v := &RTMPConnection{} +func newRTMPConnection(opts ...func(*rtmpConnection)) *rtmpConnection { + v := &rtmpConnection{} for _, opt := range opts { opt(v) } return v } -func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error { +func (v *rtmpConnection) serve(ctx context.Context, conn *net.TCPConn) error { logger.Debug(ctx, "Got RTMP client from %v", conn.RemoteAddr()) // If any goroutine quit, cancel another one. @@ -135,7 +140,7 @@ func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error { ctx, cancel := context.WithCancel(ctx) defer cancel() - var backend *RTMPClientToBackend + var backend *rtmpClientToBackend if true { go func() { <-ctx.Done() @@ -229,7 +234,7 @@ func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error { response = identifyRes nextStreamID = 1 - identifyRes.StreamID = *rtmp.NewAmf0Number(float64(nextStreamID)) + identifyRes.SetStreamID(nextStreamID) } else if pkt.CommandName == "getStreamLength" { // Ignore and do not reply these packets. } else { @@ -243,7 +248,7 @@ func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error { identifyRes.Args = rtmp.NewAmf0Undefined() } case *rtmp.PublishPacket: - streamName = string(pkt.StreamName) + streamName = pkt.StreamName.String() clientType = RTMPClientTypePublisher identifyRes := rtmp.NewCallPacket() @@ -257,7 +262,7 @@ func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error { data.Set("description", rtmp.NewAmf0String("Started publishing stream.")) identifyRes.Args = data case *rtmp.PlayPacket: - streamName = string(pkt.StreamName) + streamName = pkt.StreamName.String() clientType = RTMPClientTypeViewer identifyRes := rtmp.NewCallPacket() @@ -289,7 +294,7 @@ func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error { tcUrl, streamName, currentStreamID, clientType) // Find a backend SRS server to proxy the RTMP stream. - backend = NewRTMPClientToBackend(func(client *RTMPClientToBackend) { + backend = newRTMPClientToBackend(func(client *rtmpClientToBackend) { client.typ = clientType }) defer backend.Close() @@ -352,7 +357,7 @@ func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error { if err != nil { return errors.Wrapf(err, "read message") } - //logger.Debug(ctx, "client<- %v %v %vB", m.MessageType, m.Timestamp, len(m.Payload)) + //logger.Debug(ctx, "client<- %v %v %vB", m.MessageType(), m.Timestamp(), len(m.Payload())) // TODO: Update the stream ID if not the same. if err := client.WriteMessage(ctx, m); err != nil { @@ -375,7 +380,7 @@ func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error { if err != nil { return errors.Wrapf(err, "read message") } - //logger.Debug(ctx, "client-> %v %v %vB", m.MessageType, m.Timestamp, len(m.Payload)) + //logger.Debug(ctx, "client-> %v %v %vB", m.MessageType(), m.Timestamp(), len(m.Payload())) // TODO: Update the stream ID if not the same. if err := backend.client.WriteMessage(ctx, m); err != nil { @@ -416,32 +421,32 @@ const ( RTMPClientTypeViewer RTMPClientType = "viewer" ) -// RTMPClientToBackend is a RTMP client to proxy the RTMP stream to backend. -type RTMPClientToBackend struct { +// rtmpClientToBackend is an RTMP client to proxy the RTMP stream to backend. +type rtmpClientToBackend struct { // The underlayer tcp client. tcpConn *net.TCPConn // The RTMP protocol client. - client *rtmp.Protocol + client rtmp.Protocol // The stream type. typ RTMPClientType } -func NewRTMPClientToBackend(opts ...func(*RTMPClientToBackend)) *RTMPClientToBackend { - v := &RTMPClientToBackend{} +func newRTMPClientToBackend(opts ...func(*rtmpClientToBackend)) *rtmpClientToBackend { + v := &rtmpClientToBackend{} for _, opt := range opts { opt(v) } return v } -func (v *RTMPClientToBackend) Close() error { +func (v *rtmpClientToBackend) Close() error { if v.tcpConn != nil { v.tcpConn.Close() } return nil } -func (v *RTMPClientToBackend) Connect(ctx context.Context, tcUrl, streamName string) error { +func (v *rtmpClientToBackend) Connect(ctx context.Context, tcUrl, streamName string) error { // Build the stream URL in vhost/app/stream schema. streamURL, err := utils.BuildStreamURL(fmt.Sprintf("%v/%v", tcUrl, streamName)) if err != nil { @@ -527,7 +532,7 @@ func (v *RTMPClientToBackend) Connect(ctx context.Context, tcUrl, streamName str return v.publish(ctx, client, streamName) } -func (v *RTMPClientToBackend) publish(ctx context.Context, client *rtmp.Protocol, streamName string) error { +func (v *rtmpClientToBackend) publish(ctx context.Context, client rtmp.Protocol, streamName string) error { if true { identifyReq := rtmp.NewCallPacket() identifyReq.CommandName = "releaseStream" @@ -592,8 +597,8 @@ func (v *RTMPClientToBackend) publish(ctx context.Context, client *rtmp.Protocol publishStream := rtmp.NewPublishPacket() publishStream.TransactionID = 5 publishStream.CommandObject = rtmp.NewAmf0Null() - publishStream.StreamName = *rtmp.NewAmf0String(streamName) - publishStream.StreamType = *rtmp.NewAmf0String("live") + publishStream.StreamName = rtmp.NewAmf0String(streamName) + publishStream.StreamType = rtmp.NewAmf0String("live") if err := client.WritePacket(ctx, publishStream, currentStreamID); err != nil { return errors.Wrapf(err, "publish") } @@ -609,8 +614,8 @@ func (v *RTMPClientToBackend) publish(ctx context.Context, client *rtmp.Protocol return errors.Errorf("onStatus args not object") } else if code := rtmp.NewAmf0Converter(data.Get("code")).ToString(); code == nil { return errors.Errorf("onStatus code not string") - } else if *code != "NetStream.Publish.Start" { - return errors.Errorf("onStatus code=%v not NetStream.Publish.Start", *code) + } else if code.String() != "NetStream.Publish.Start" { + return errors.Errorf("onStatus code=%v not NetStream.Publish.Start", code.String()) } break } @@ -620,7 +625,7 @@ func (v *RTMPClientToBackend) publish(ctx context.Context, client *rtmp.Protocol return nil } -func (v *RTMPClientToBackend) play(ctx context.Context, client *rtmp.Protocol, streamName string) error { +func (v *rtmpClientToBackend) play(ctx context.Context, client rtmp.Protocol, streamName string) error { var currentStreamID int if true { createStream := rtmp.NewCreateStreamPacket() @@ -642,7 +647,7 @@ func (v *RTMPClientToBackend) play(ctx context.Context, client *rtmp.Protocol, s } playStream := rtmp.NewPlayPacket() - playStream.StreamName = *rtmp.NewAmf0String(streamName) + playStream.StreamName = rtmp.NewAmf0String(streamName) if err := client.WritePacket(ctx, playStream, currentStreamID); err != nil { return errors.Wrapf(err, "play") } diff --git a/internal/protocol/srt.go b/internal/server/srt.go similarity index 99% rename from internal/protocol/srt.go rename to internal/server/srt.go index cc9324f69..0da23f51f 100644 --- a/internal/protocol/srt.go +++ b/internal/server/srt.go @@ -1,7 +1,7 @@ // Copyright (c) 2026 Winlin // // SPDX-License-Identifier: MIT -package protocol +package server import ( "bytes" diff --git a/internal/version/version.go b/internal/version/version.go index a4e649d1f..f71511505 100644 --- a/internal/version/version.go +++ b/internal/version/version.go @@ -15,7 +15,7 @@ func VersionMinor() int { } func VersionRevision() int { - return 146 + return 147 } func Version() string { diff --git a/trunk/doc/CHANGELOG.md b/trunk/doc/CHANGELOG.md index 9c10f1e7c..ae3c9bc0c 100644 --- a/trunk/doc/CHANGELOG.md +++ b/trunk/doc/CHANGELOG.md @@ -7,6 +7,7 @@ The changelog for SRS. ## SRS 7.0 Changelog +* v7.0, 2026-05-02, Merge [#4672](https://github.com/ossrs/srs/pull/4672): Proxy: Refactor server APIs and expand RTMP test coverage. v7.0.147 (#4672) * v7.0, 2026-04-28, Merge [#4670](https://github.com/ossrs/srs/pull/4670): Proxy: Refine logger and environment APIs. v7.0.146 (#4670) * v7.0, 2026-04-23, Merge [#4667](https://github.com/ossrs/srs/pull/4667): Proxy: Refactor internal/errors and internal/sync, and add unit tests across internal/*. v7.0.145 (#4667) * v7.0, 2026-04-18, Merge [#4665](https://github.com/ossrs/srs/pull/4665): Proxy: Harden internal/env tests and add counterfeiter fake generation. v7.0.144 (#4665) diff --git a/trunk/src/core/srs_core_version7.hpp b/trunk/src/core/srs_core_version7.hpp index cc5ef866e..55b2d2e04 100644 --- a/trunk/src/core/srs_core_version7.hpp +++ b/trunk/src/core/srs_core_version7.hpp @@ -9,6 +9,6 @@ #define VERSION_MAJOR 7 #define VERSION_MINOR 0 -#define VERSION_REVISION 146 +#define VERSION_REVISION 147 #endif