1. Fix test expectation: change && to & in the m3u8 rewrite test. 2. Add TestHLSPlayStream_ServeByBackend_HeadersCopiedFromBackend to verify that backend headers reach the client for m3u8 responses. 3. Add TestHLSPlayStream_ServeByBackend_TSHeadersCopiedFromBackend to verify header copying for .ts file responses. 4. Add TestHTTPFlvTsConn_ServeByBackend_HeadersCopiedFromBackend to verify header copying for FLV/TS streaming responses. These tests protect against a regression in which calling WriteHeader() before Header().Add() causes the headers to be silently discarded by Go's http.ResponseWriter.
1375 lines
47 KiB
Go
1375 lines
47 KiB
Go
// Copyright (c) 2026 Winlin
|
|
//
|
|
// SPDX-License-Identifier: MIT
|
|
package proxy
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"os"
|
|
"path/filepath"
|
|
"strconv"
|
|
"strings"
|
|
stdSync "sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"srsx/internal/env/envfakes"
|
|
"srsx/internal/lb"
|
|
"srsx/internal/lb/lbfakes"
|
|
)
|
|
|
|
// httptestHostPort parses ts.URL and returns its hostname and port as two
// separate strings. The test is aborted via t.Fatalf if the URL cannot be
// parsed.
func httptestHostPort(t *testing.T, ts *httptest.Server) (string, string) {
	t.Helper()

	parsed, parseErr := url.Parse(ts.URL)
	if parseErr != nil {
		t.Fatalf("parse httptest URL %q: %v", ts.URL, parseErr)
	}

	host := parsed.Hostname()
	port := parsed.Port()
	return host, port
}
|
|
|
|
// reservedClosedPort listens on an ephemeral 127.0.0.1 TCP port, closes the
// listener right away, and returns the IP and port it had been bound to as
// strings. Tests use the result as an address nothing is listening on.
// Listen or close failures abort the test via t.Fatalf.
func reservedClosedPort(t *testing.T) (string, string) {
	t.Helper()

	listener, listenErr := net.Listen("tcp", "127.0.0.1:0")
	if listenErr != nil {
		t.Fatalf("reserve port: %v", listenErr)
	}

	// Capture the bound address before the listener goes away.
	tcpAddr := listener.Addr().(*net.TCPAddr)

	closeErr := listener.Close()
	if closeErr != nil {
		t.Fatalf("close listener: %v", closeErr)
	}

	ip := tcpAddr.IP.String()
	port := strconv.Itoa(tcpAddr.Port)
	return ip, port
}
|
|
|
|
// =============================================================================
|
|
// newHLSPlayStream
|
|
// =============================================================================
|
|
|
|
func TestHLSPlayStream_New_DefaultsBuildBackendURL(t *testing.T) {
|
|
v := newHLSPlayStream()
|
|
if v.buildBackendURL == nil {
|
|
t.Fatal("buildBackendURL should default to non-nil")
|
|
}
|
|
if got := v.buildBackendURL("1.2.3.4", 8080, "/live.ts"); got != "http://1.2.3.4:8080/live.ts" {
|
|
t.Fatalf("default buildBackendURL produced %q", got)
|
|
}
|
|
}
|
|
|
|
func TestHLSPlayStream_New_AppliesOpts(t *testing.T) {
|
|
ctx := context.Background()
|
|
lbStub := &lbfakes.FakeOriginLoadBalancer{}
|
|
v := newHLSPlayStream(func(s *hlsPlayStream) {
|
|
s.ctx = ctx
|
|
s.loadBalancer = lbStub
|
|
s.SRSProxyBackendHLSID = "spb-id"
|
|
s.StreamURL = "vhost/app/stream"
|
|
s.FullURL = "http://example.com/live.m3u8"
|
|
})
|
|
if v.ctx != ctx {
|
|
t.Error("ctx not applied")
|
|
}
|
|
if v.loadBalancer != lbStub {
|
|
t.Error("loadBalancer not applied")
|
|
}
|
|
if v.SRSProxyBackendHLSID != "spb-id" {
|
|
t.Errorf("SRSProxyBackendHLSID = %q", v.SRSProxyBackendHLSID)
|
|
}
|
|
if v.StreamURL != "vhost/app/stream" {
|
|
t.Errorf("StreamURL = %q", v.StreamURL)
|
|
}
|
|
if v.FullURL != "http://example.com/live.m3u8" {
|
|
t.Errorf("FullURL = %q", v.FullURL)
|
|
}
|
|
}
|
|
|
|
func TestHLSPlayStream_New_OptCanOverrideBuildBackendURL(t *testing.T) {
|
|
v := newHLSPlayStream(func(s *hlsPlayStream) {
|
|
s.buildBackendURL = func(string, int, string) string { return "custom" }
|
|
})
|
|
if got := v.buildBackendURL("", 0, ""); got != "custom" {
|
|
t.Fatalf("override not applied: got %q", got)
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// Initialize
|
|
// =============================================================================
|
|
|
|
func TestHLSPlayStream_Initialize_SetsCtxWhenNil(t *testing.T) {
|
|
v := newHLSPlayStream()
|
|
ret := v.Initialize(context.Background())
|
|
if v.ctx == nil {
|
|
t.Fatal("Initialize should set v.ctx when nil")
|
|
}
|
|
if ret != lb.HLSPlayStream(v) {
|
|
t.Fatal("Initialize should return v")
|
|
}
|
|
}
|
|
|
|
func TestHLSPlayStream_Initialize_PreservesExistingCtx(t *testing.T) {
|
|
type ctxKey struct{}
|
|
existing := context.WithValue(context.Background(), ctxKey{}, "sentinel")
|
|
v := newHLSPlayStream(func(s *hlsPlayStream) { s.ctx = existing })
|
|
v.Initialize(context.Background())
|
|
if got, _ := v.ctx.Value(ctxKey{}).(string); got != "sentinel" {
|
|
t.Fatalf("Initialize should not replace existing ctx, value=%q", got)
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// GetSPBHID
|
|
// =============================================================================
|
|
|
|
func TestHLSPlayStream_GetSPBHID(t *testing.T) {
|
|
v := newHLSPlayStream(func(s *hlsPlayStream) { s.SRSProxyBackendHLSID = "spb-xyz" })
|
|
if v.GetSPBHID() != "spb-xyz" {
|
|
t.Fatalf("GetSPBHID = %q", v.GetSPBHID())
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// ServeHTTP / serve / CORS
|
|
// =============================================================================
|
|
|
|
func TestHLSPlayStream_ServeHTTP_CORSPreflightShortCircuits(t *testing.T) {
|
|
lbFake := &lbfakes.FakeOriginLoadBalancer{}
|
|
v := newHLSPlayStream(func(s *hlsPlayStream) {
|
|
s.ctx = context.Background()
|
|
s.loadBalancer = lbFake
|
|
})
|
|
req := httptest.NewRequest(http.MethodOptions, "http://example.com/live.m3u8", nil)
|
|
rec := httptest.NewRecorder()
|
|
v.ServeHTTP(rec, req)
|
|
if lbFake.PickCallCount() != 0 {
|
|
t.Fatalf("Pick should not be called on CORS preflight, calls=%d", lbFake.PickCallCount())
|
|
}
|
|
}
|
|
|
|
func TestHLSPlayStream_ServeHTTP_ErrorBranchInvokesApiError(t *testing.T) {
|
|
lbFake := &lbfakes.FakeOriginLoadBalancer{}
|
|
lbFake.PickReturns(nil, errors.New("pick-fail"))
|
|
v := newHLSPlayStream(func(s *hlsPlayStream) {
|
|
s.ctx = context.Background()
|
|
s.loadBalancer = lbFake
|
|
s.StreamURL = "vhost/app/stream"
|
|
})
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil)
|
|
rec := httptest.NewRecorder()
|
|
v.ServeHTTP(rec, req)
|
|
// ApiError writes a JSON error response. Verify the body is non-empty
|
|
// and the status is not the default 200 (or that some response was made).
|
|
if rec.Body.Len() == 0 {
|
|
t.Fatal("ServeHTTP error branch should produce a response body")
|
|
}
|
|
}
|
|
|
|
func TestHLSPlayStream_Serve_PickError(t *testing.T) {
|
|
lbFake := &lbfakes.FakeOriginLoadBalancer{}
|
|
lbFake.PickReturns(nil, errors.New("pick-fail"))
|
|
v := newHLSPlayStream(func(s *hlsPlayStream) {
|
|
s.ctx = context.Background()
|
|
s.loadBalancer = lbFake
|
|
s.StreamURL = "vhost/app/stream"
|
|
})
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil)
|
|
rec := httptest.NewRecorder()
|
|
err := v.serve(v.ctx, rec, req)
|
|
if err == nil || !strings.Contains(err.Error(), "pick backend for vhost/app/stream") {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestHLSPlayStream_Serve_WrapsServeByBackendError(t *testing.T) {
|
|
// Backend with empty HTTP slice triggers serveByBackend's "no http server"
|
|
// error, which serve() then wraps with "serve %v with %v by backend %+v".
|
|
lbFake := &lbfakes.FakeOriginLoadBalancer{}
|
|
lbFake.PickReturns(&lb.OriginServer{IP: "127.0.0.1"}, nil)
|
|
v := newHLSPlayStream(func(s *hlsPlayStream) {
|
|
s.ctx = context.Background()
|
|
s.loadBalancer = lbFake
|
|
s.StreamURL = "vhost/app/stream"
|
|
s.FullURL = "http://example.com/live.m3u8"
|
|
})
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil)
|
|
rec := httptest.NewRecorder()
|
|
err := v.serve(v.ctx, rec, req)
|
|
if err == nil || !strings.Contains(err.Error(), "serve http://example.com/live.m3u8 with vhost/app/stream by backend") {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestHLSPlayStream_Serve_HappyPathRewritesM3U8(t *testing.T) {
|
|
m3u8 := "#EXTM3U\n#EXT-X-VERSION:3\nlive-0.ts\n"
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, _ = io.WriteString(w, m3u8)
|
|
}))
|
|
defer ts.Close()
|
|
host, port := httptestHostPort(t, ts)
|
|
|
|
lbFake := &lbfakes.FakeOriginLoadBalancer{}
|
|
lbFake.PickReturns(&lb.OriginServer{IP: host, HTTP: []string{port}}, nil)
|
|
|
|
v := newHLSPlayStream(func(s *hlsPlayStream) {
|
|
s.ctx = context.Background()
|
|
s.loadBalancer = lbFake
|
|
s.StreamURL = "vhost/app/stream"
|
|
s.FullURL = "http://example.com/live.m3u8"
|
|
s.SRSProxyBackendHLSID = "spb-1"
|
|
})
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil)
|
|
rec := httptest.NewRecorder()
|
|
if err := v.serve(v.ctx, rec, req); err != nil {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
if !strings.Contains(rec.Body.String(), "live-0.ts?spbhid=spb-1") {
|
|
t.Fatalf("body missing spbhid rewrite: %q", rec.Body.String())
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// serveByBackend — error paths (no HTTP round-trip needed)
|
|
// =============================================================================
|
|
|
|
func TestHLSPlayStream_ServeByBackend_NoHTTPEndpoint(t *testing.T) {
|
|
v := newHLSPlayStream()
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil)
|
|
rec := httptest.NewRecorder()
|
|
err := v.serveByBackend(context.Background(), rec, req, &lb.OriginServer{IP: "127.0.0.1"})
|
|
if err == nil || !strings.Contains(err.Error(), "no http server") {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestHLSPlayStream_ServeByBackend_BadPort(t *testing.T) {
|
|
v := newHLSPlayStream()
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil)
|
|
rec := httptest.NewRecorder()
|
|
err := v.serveByBackend(context.Background(), rec, req,
|
|
&lb.OriginServer{IP: "127.0.0.1", HTTP: []string{"not-a-port"}})
|
|
if err == nil || !strings.Contains(err.Error(), "parse http port not-a-port") {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestHLSPlayStream_ServeByBackend_RequestBuildError(t *testing.T) {
|
|
v := newHLSPlayStream(func(s *hlsPlayStream) {
|
|
s.buildBackendURL = func(string, int, string) string { return "://invalid-url" }
|
|
})
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil)
|
|
rec := httptest.NewRecorder()
|
|
err := v.serveByBackend(context.Background(), rec, req,
|
|
&lb.OriginServer{IP: "127.0.0.1", HTTP: []string{"8080"}})
|
|
if err == nil || !strings.Contains(err.Error(), "create request to ://invalid-url") {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestHLSPlayStream_ServeByBackend_DialError(t *testing.T) {
|
|
host, port := reservedClosedPort(t)
|
|
v := newHLSPlayStream()
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil)
|
|
rec := httptest.NewRecorder()
|
|
err := v.serveByBackend(context.Background(), rec, req,
|
|
&lb.OriginServer{IP: host, HTTP: []string{port}})
|
|
if err == nil || !strings.HasSuffix(err.Error(), "EOF") {
|
|
t.Fatalf("expected error suffixed with 'EOF', got: %v", err)
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// serveByBackend — HTTP round-trip via httptest.Server
|
|
// =============================================================================
|
|
|
|
func TestHLSPlayStream_ServeByBackend_NonOKStatus(t *testing.T) {
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusNotFound)
|
|
}))
|
|
defer ts.Close()
|
|
host, port := httptestHostPort(t, ts)
|
|
|
|
v := newHLSPlayStream()
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil)
|
|
rec := httptest.NewRecorder()
|
|
err := v.serveByBackend(context.Background(), rec, req,
|
|
&lb.OriginServer{IP: host, HTTP: []string{port}})
|
|
if err == nil || !strings.Contains(err.Error(), "status=404") {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestHLSPlayStream_ServeByBackend_TSPassthrough(t *testing.T) {
|
|
payload := []byte{0x47, 0x00, 0x01, 0x02, 0x03, 0x04}
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "video/mp2t")
|
|
_, _ = w.Write(payload)
|
|
}))
|
|
defer ts.Close()
|
|
host, port := httptestHostPort(t, ts)
|
|
|
|
v := newHLSPlayStream()
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.ts", nil)
|
|
rec := httptest.NewRecorder()
|
|
if err := v.serveByBackend(context.Background(), rec, req,
|
|
&lb.OriginServer{IP: host, HTTP: []string{port}}); err != nil {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
if got := rec.Body.Bytes(); !bytes.Equal(got, payload) {
|
|
t.Fatalf("body mismatch: got=%v want=%v", got, payload)
|
|
}
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("status = %d, want 200", rec.Code)
|
|
}
|
|
}
|
|
|
|
func TestHLSPlayStream_ServeByBackend_M3U8RewritesTSWithoutQuery(t *testing.T) {
|
|
m3u8 := "#EXTM3U\n#EXT-X-VERSION:3\nlive-0.ts\nlive-1.ts\n"
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, _ = io.WriteString(w, m3u8)
|
|
}))
|
|
defer ts.Close()
|
|
host, port := httptestHostPort(t, ts)
|
|
|
|
v := newHLSPlayStream(func(s *hlsPlayStream) { s.SRSProxyBackendHLSID = "ABC" })
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil)
|
|
rec := httptest.NewRecorder()
|
|
if err := v.serveByBackend(context.Background(), rec, req,
|
|
&lb.OriginServer{IP: host, HTTP: []string{port}}); err != nil {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
body := rec.Body.String()
|
|
for _, want := range []string{"live-0.ts?spbhid=ABC", "live-1.ts?spbhid=ABC"} {
|
|
if !strings.Contains(body, want) {
|
|
t.Fatalf("missing %q in body: %q", want, body)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestHLSPlayStream_ServeByBackend_M3U8RewritesTSWithQuery(t *testing.T) {
|
|
m3u8 := "#EXTM3U\nlive-0.ts?token=foo\n"
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, _ = io.WriteString(w, m3u8)
|
|
}))
|
|
defer ts.Close()
|
|
host, port := httptestHostPort(t, ts)
|
|
|
|
v := newHLSPlayStream(func(s *hlsPlayStream) { s.SRSProxyBackendHLSID = "ABC" })
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil)
|
|
rec := httptest.NewRecorder()
|
|
if err := v.serveByBackend(context.Background(), rec, req,
|
|
&lb.OriginServer{IP: host, HTTP: []string{port}}); err != nil {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
if want := "live-0.ts?spbhid=ABC&token=foo"; !strings.Contains(rec.Body.String(), want) {
|
|
t.Fatalf("missing %q in body: %q", want, rec.Body.String())
|
|
}
|
|
}
|
|
|
|
func TestHLSPlayStream_ServeByBackend_AppendsRawQueryOnTS(t *testing.T) {
|
|
var seenURL string
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
seenURL = r.URL.String()
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer ts.Close()
|
|
host, port := httptestHostPort(t, ts)
|
|
|
|
v := newHLSPlayStream()
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.ts?token=foo", nil)
|
|
rec := httptest.NewRecorder()
|
|
if err := v.serveByBackend(context.Background(), rec, req,
|
|
&lb.OriginServer{IP: host, HTTP: []string{port}}); err != nil {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
if !strings.Contains(seenURL, "token=foo") {
|
|
t.Fatalf("backend should see raw query, got %q", seenURL)
|
|
}
|
|
}
|
|
|
|
func TestHLSPlayStream_ServeByBackend_HeadersCopiedFromBackend(t *testing.T) {
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/vnd.apple.mpegurl")
|
|
w.Header().Set("Cache-Control", "no-cache")
|
|
w.Header().Set("X-Custom-Header", "custom-value")
|
|
_, _ = io.WriteString(w, "#EXTM3U\nlive-0.ts\n")
|
|
}))
|
|
defer ts.Close()
|
|
host, port := httptestHostPort(t, ts)
|
|
|
|
v := newHLSPlayStream()
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil)
|
|
rec := httptest.NewRecorder()
|
|
if err := v.serveByBackend(context.Background(), rec, req,
|
|
&lb.OriginServer{IP: host, HTTP: []string{port}}); err != nil {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
|
|
// Verify headers are properly copied (not lost due to WriteHeader order)
|
|
if got := rec.Header().Get("Content-Type"); got != "application/vnd.apple.mpegurl" {
|
|
t.Errorf("Content-Type = %q, want application/vnd.apple.mpegurl", got)
|
|
}
|
|
if got := rec.Header().Get("Cache-Control"); got != "no-cache" {
|
|
t.Errorf("Cache-Control = %q, want no-cache", got)
|
|
}
|
|
if got := rec.Header().Get("X-Custom-Header"); got != "custom-value" {
|
|
t.Errorf("X-Custom-Header = %q, want custom-value", got)
|
|
}
|
|
}
|
|
|
|
func TestHLSPlayStream_ServeByBackend_TSHeadersCopiedFromBackend(t *testing.T) {
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "video/mp2t")
|
|
w.Header().Set("Cache-Control", "max-age=3600")
|
|
_, _ = w.Write([]byte{0x47, 0x00, 0x01})
|
|
}))
|
|
defer ts.Close()
|
|
host, port := httptestHostPort(t, ts)
|
|
|
|
v := newHLSPlayStream()
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.ts", nil)
|
|
rec := httptest.NewRecorder()
|
|
if err := v.serveByBackend(context.Background(), rec, req,
|
|
&lb.OriginServer{IP: host, HTTP: []string{port}}); err != nil {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
|
|
if got := rec.Header().Get("Content-Type"); got != "video/mp2t" {
|
|
t.Errorf("Content-Type = %q, want video/mp2t", got)
|
|
}
|
|
if got := rec.Header().Get("Cache-Control"); got != "max-age=3600" {
|
|
t.Errorf("Cache-Control = %q, want max-age=3600", got)
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// httpFlvTsConnection
|
|
// =============================================================================
|
|
|
|
// =============================================================================
|
|
// newHTTPFlvTsConnection
|
|
// =============================================================================
|
|
|
|
func TestHTTPFlvTsConn_New_DefaultsBuildBackendURL(t *testing.T) {
|
|
v := newHTTPFlvTsConnection()
|
|
if v.buildBackendURL == nil {
|
|
t.Fatal("buildBackendURL should default to non-nil")
|
|
}
|
|
if got := v.buildBackendURL("1.2.3.4", 8080, "/live.flv"); got != "http://1.2.3.4:8080/live.flv" {
|
|
t.Fatalf("default buildBackendURL produced %q", got)
|
|
}
|
|
}
|
|
|
|
func TestHTTPFlvTsConn_New_AppliesOpts(t *testing.T) {
|
|
ctx := context.Background()
|
|
lbStub := &lbfakes.FakeOriginLoadBalancer{}
|
|
v := newHTTPFlvTsConnection(func(c *httpFlvTsConnection) {
|
|
c.ctx = ctx
|
|
c.loadBalancer = lbStub
|
|
})
|
|
if v.ctx != ctx {
|
|
t.Error("ctx not applied")
|
|
}
|
|
if v.loadBalancer != lbStub {
|
|
t.Error("loadBalancer not applied")
|
|
}
|
|
}
|
|
|
|
func TestHTTPFlvTsConn_New_OptCanOverrideBuildBackendURL(t *testing.T) {
|
|
v := newHTTPFlvTsConnection(func(c *httpFlvTsConnection) {
|
|
c.buildBackendURL = func(string, int, string) string { return "custom" }
|
|
})
|
|
if got := v.buildBackendURL("", 0, ""); got != "custom" {
|
|
t.Fatalf("override not applied: got %q", got)
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// ServeHTTP / serve / CORS
|
|
// =============================================================================
|
|
|
|
func TestHTTPFlvTsConn_ServeHTTP_CORSPreflightShortCircuits(t *testing.T) {
|
|
lbFake := &lbfakes.FakeOriginLoadBalancer{}
|
|
v := newHTTPFlvTsConnection(func(c *httpFlvTsConnection) {
|
|
c.ctx = context.Background()
|
|
c.loadBalancer = lbFake
|
|
})
|
|
req := httptest.NewRequest(http.MethodOptions, "http://example.com/live.flv", nil)
|
|
rec := httptest.NewRecorder()
|
|
v.ServeHTTP(rec, req)
|
|
if lbFake.PickCallCount() != 0 {
|
|
t.Fatalf("Pick should not be called on CORS preflight, calls=%d", lbFake.PickCallCount())
|
|
}
|
|
}
|
|
|
|
func TestHTTPFlvTsConn_ServeHTTP_ErrorBranchInvokesApiError(t *testing.T) {
|
|
lbFake := &lbfakes.FakeOriginLoadBalancer{}
|
|
lbFake.PickReturns(nil, errors.New("pick-fail"))
|
|
v := newHTTPFlvTsConnection(func(c *httpFlvTsConnection) {
|
|
c.ctx = context.Background()
|
|
c.loadBalancer = lbFake
|
|
})
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.flv", nil)
|
|
rec := httptest.NewRecorder()
|
|
v.ServeHTTP(rec, req)
|
|
if rec.Body.Len() == 0 {
|
|
t.Fatal("ServeHTTP error branch should produce a response body")
|
|
}
|
|
}
|
|
|
|
func TestHTTPFlvTsConn_Serve_PickError(t *testing.T) {
|
|
lbFake := &lbfakes.FakeOriginLoadBalancer{}
|
|
lbFake.PickReturns(nil, errors.New("pick-fail"))
|
|
v := newHTTPFlvTsConnection(func(c *httpFlvTsConnection) {
|
|
c.ctx = context.Background()
|
|
c.loadBalancer = lbFake
|
|
})
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live/stream.flv", nil)
|
|
rec := httptest.NewRecorder()
|
|
err := v.serve(context.Background(), rec, req)
|
|
if err == nil || !strings.Contains(err.Error(), "pick backend for") {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestHTTPFlvTsConn_Serve_WrapsServeByBackendError(t *testing.T) {
|
|
// Empty HTTP slice on backend triggers serveByBackend's "no http stream
|
|
// server" error, which serve() wraps with "serve <fullURL> with <streamURL>".
|
|
lbFake := &lbfakes.FakeOriginLoadBalancer{}
|
|
lbFake.PickReturns(&lb.OriginServer{IP: "127.0.0.1"}, nil)
|
|
v := newHTTPFlvTsConnection(func(c *httpFlvTsConnection) {
|
|
c.ctx = context.Background()
|
|
c.loadBalancer = lbFake
|
|
})
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live/stream.flv", nil)
|
|
rec := httptest.NewRecorder()
|
|
err := v.serve(context.Background(), rec, req)
|
|
if err == nil || !strings.Contains(err.Error(), "serve ") || !strings.Contains(err.Error(), " by backend ") {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestHTTPFlvTsConn_Serve_HappyPath(t *testing.T) {
|
|
payload := []byte{0x46, 0x4c, 0x56, 0x01} // "FLV\x01"
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, _ = w.Write(payload)
|
|
}))
|
|
defer ts.Close()
|
|
host, port := httptestHostPort(t, ts)
|
|
|
|
lbFake := &lbfakes.FakeOriginLoadBalancer{}
|
|
lbFake.PickReturns(&lb.OriginServer{IP: host, HTTP: []string{port}}, nil)
|
|
|
|
v := newHTTPFlvTsConnection(func(c *httpFlvTsConnection) {
|
|
c.ctx = context.Background()
|
|
c.loadBalancer = lbFake
|
|
})
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live/stream.flv", nil)
|
|
rec := httptest.NewRecorder()
|
|
if err := v.serve(context.Background(), rec, req); err != nil {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
if !bytes.Equal(rec.Body.Bytes(), payload) {
|
|
t.Fatalf("body mismatch: got=%v want=%v", rec.Body.Bytes(), payload)
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// serveByBackend — error paths
|
|
// =============================================================================
|
|
|
|
func TestHTTPFlvTsConn_ServeByBackend_NoHTTPEndpoint(t *testing.T) {
|
|
v := newHTTPFlvTsConnection()
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.flv", nil)
|
|
rec := httptest.NewRecorder()
|
|
err := v.serveByBackend(context.Background(), rec, req, &lb.OriginServer{IP: "127.0.0.1"})
|
|
if err == nil || !strings.Contains(err.Error(), "no http stream server") {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestHTTPFlvTsConn_ServeByBackend_BadPort(t *testing.T) {
|
|
v := newHTTPFlvTsConnection()
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.flv", nil)
|
|
rec := httptest.NewRecorder()
|
|
err := v.serveByBackend(context.Background(), rec, req,
|
|
&lb.OriginServer{IP: "127.0.0.1", HTTP: []string{"not-a-port"}})
|
|
if err == nil || !strings.Contains(err.Error(), "parse http port not-a-port") {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestHTTPFlvTsConn_ServeByBackend_RequestBuildError(t *testing.T) {
|
|
v := newHTTPFlvTsConnection(func(c *httpFlvTsConnection) {
|
|
c.buildBackendURL = func(string, int, string) string { return "://invalid-url" }
|
|
})
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.flv", nil)
|
|
rec := httptest.NewRecorder()
|
|
err := v.serveByBackend(context.Background(), rec, req,
|
|
&lb.OriginServer{IP: "127.0.0.1", HTTP: []string{"8080"}})
|
|
if err == nil || !strings.Contains(err.Error(), "create request to ://invalid-url") {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestHTTPFlvTsConn_ServeByBackend_DialError(t *testing.T) {
|
|
host, port := reservedClosedPort(t)
|
|
v := newHTTPFlvTsConnection()
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.flv", nil)
|
|
rec := httptest.NewRecorder()
|
|
err := v.serveByBackend(context.Background(), rec, req,
|
|
&lb.OriginServer{IP: host, HTTP: []string{port}})
|
|
if err == nil || !strings.Contains(err.Error(), "do request to") {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// serveByBackend — HTTP round-trip
|
|
// =============================================================================
|
|
|
|
func TestHTTPFlvTsConn_ServeByBackend_NonOKStatus(t *testing.T) {
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusNotFound)
|
|
}))
|
|
defer ts.Close()
|
|
host, port := httptestHostPort(t, ts)
|
|
|
|
v := newHTTPFlvTsConnection()
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.flv", nil)
|
|
rec := httptest.NewRecorder()
|
|
err := v.serveByBackend(context.Background(), rec, req,
|
|
&lb.OriginServer{IP: host, HTTP: []string{port}})
|
|
if err == nil || !strings.Contains(err.Error(), "status=404") {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestHTTPFlvTsConn_ServeByBackend_BodyPassthrough(t *testing.T) {
|
|
payload := []byte{0x46, 0x4c, 0x56, 0x01, 0x05, 0x00, 0x00, 0x00}
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, _ = w.Write(payload)
|
|
}))
|
|
defer ts.Close()
|
|
host, port := httptestHostPort(t, ts)
|
|
|
|
v := newHTTPFlvTsConnection()
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.flv", nil)
|
|
rec := httptest.NewRecorder()
|
|
if err := v.serveByBackend(context.Background(), rec, req,
|
|
&lb.OriginServer{IP: host, HTTP: []string{port}}); err != nil {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
if !bytes.Equal(rec.Body.Bytes(), payload) {
|
|
t.Fatalf("body mismatch: got=%v want=%v", rec.Body.Bytes(), payload)
|
|
}
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("status = %d, want 200", rec.Code)
|
|
}
|
|
}
|
|
|
|
func TestHTTPFlvTsConn_ServeByBackend_DropsRawQuery(t *testing.T) {
|
|
// Unlike hlsPlayStream.serveByBackend, the FLV/TS path forwards only
|
|
// r.URL.Path — it does NOT append RawQuery to the backend request.
|
|
var seenRawQuery string
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
seenRawQuery = r.URL.RawQuery
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer ts.Close()
|
|
host, port := httptestHostPort(t, ts)
|
|
|
|
v := newHTTPFlvTsConnection()
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.flv?token=foo", nil)
|
|
rec := httptest.NewRecorder()
|
|
if err := v.serveByBackend(context.Background(), rec, req,
|
|
&lb.OriginServer{IP: host, HTTP: []string{port}}); err != nil {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
if seenRawQuery != "" {
|
|
t.Fatalf("backend should NOT see raw query, got %q", seenRawQuery)
|
|
}
|
|
}
|
|
|
|
func TestHTTPFlvTsConn_ServeByBackend_PreservesMethod(t *testing.T) {
|
|
var seenMethod string
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
seenMethod = r.Method
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer ts.Close()
|
|
host, port := httptestHostPort(t, ts)
|
|
|
|
v := newHTTPFlvTsConnection()
|
|
req := httptest.NewRequest(http.MethodHead, "http://example.com/live.flv", nil)
|
|
rec := httptest.NewRecorder()
|
|
if err := v.serveByBackend(context.Background(), rec, req,
|
|
&lb.OriginServer{IP: host, HTTP: []string{port}}); err != nil {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
if seenMethod != http.MethodHead {
|
|
t.Fatalf("backend method = %q, want HEAD", seenMethod)
|
|
}
|
|
}
|
|
|
|
func TestHTTPFlvTsConn_ServeByBackend_HeadersCopiedFromBackend(t *testing.T) {
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "video/x-flv")
|
|
w.Header().Set("Cache-Control", "no-store")
|
|
w.Header().Set("X-Custom-Header", "flv-value")
|
|
_, _ = w.Write([]byte("FLV\x01\x05\x00\x00\x00\x09"))
|
|
}))
|
|
defer ts.Close()
|
|
host, port := httptestHostPort(t, ts)
|
|
|
|
v := newHTTPFlvTsConnection()
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.flv", nil)
|
|
rec := httptest.NewRecorder()
|
|
if err := v.serveByBackend(context.Background(), rec, req,
|
|
&lb.OriginServer{IP: host, HTTP: []string{port}}); err != nil {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
|
|
// Verify headers are properly copied (not lost due to WriteHeader order)
|
|
if got := rec.Header().Get("Content-Type"); got != "video/x-flv" {
|
|
t.Errorf("Content-Type = %q, want video/x-flv", got)
|
|
}
|
|
if got := rec.Header().Get("Cache-Control"); got != "no-store" {
|
|
t.Errorf("Cache-Control = %q, want no-store", got)
|
|
}
|
|
if got := rec.Header().Get("X-Custom-Header"); got != "flv-value" {
|
|
t.Errorf("X-Custom-Header = %q, want flv-value", got)
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// httpStreamProxyServer
|
|
// =============================================================================
|
|
|
|
// fakeHTTPProxyServer is a test double with ListenAndServe and Shutdown
// methods. ListenAndServe parks on the block channel until Shutdown closes
// it, so a server lifecycle can be exercised without binding a port.
type fakeHTTPProxyServer struct {
	// listenCalls and shutdownCalls count invocations of the two methods.
	listenCalls   atomic.Int32
	shutdownCalls atomic.Int32
	// listenReturn and shutdownReturn are the errors the methods hand back.
	listenReturn   error
	shutdownReturn error
	// block is closed (at most once, guarded by once) to release
	// ListenAndServe.
	block chan struct{}
	once  stdSync.Once
}

// newFakeHTTPProxyServer returns a fake whose ListenAndServe reports
// http.ErrServerClosed once released and whose Shutdown reports nil.
func newFakeHTTPProxyServer() *fakeHTTPProxyServer {
	fake := new(fakeHTTPProxyServer)
	fake.listenReturn = http.ErrServerClosed
	fake.block = make(chan struct{})
	return fake
}

// ListenAndServe records the call, waits until the block channel is closed,
// and then returns listenReturn.
func (s *fakeHTTPProxyServer) ListenAndServe() error {
	s.listenCalls.Add(1)
	<-s.block
	return s.listenReturn
}

// Shutdown records the call, closes the block channel on the first
// invocation only, and returns shutdownReturn. The context is ignored.
func (s *fakeHTTPProxyServer) Shutdown(_ context.Context) error {
	s.shutdownCalls.Add(1)
	release := func() { close(s.block) }
	s.once.Do(release)
	return s.shutdownReturn
}
|
|
|
|
// captureMuxFromRun calls Run with a fake server that captures the registered
|
|
// mux. Returns the mux and the fake server for further assertions. Caller is
|
|
// responsible for cancelling ctx to trigger shutdown.
|
|
func captureMuxFromRun(t *testing.T, env *envfakes.FakeProxyEnvironment,
|
|
lbFake *lbfakes.FakeOriginLoadBalancer, ctx context.Context,
|
|
opts ...func(*httpStreamProxyServer)) (*http.ServeMux, *fakeHTTPProxyServer, *httpStreamProxyServer) {
|
|
t.Helper()
|
|
|
|
fakeSrv := newFakeHTTPProxyServer()
|
|
var capturedMux *http.ServeMux
|
|
|
|
baseOpts := []func(*httpStreamProxyServer){
|
|
func(s *httpStreamProxyServer) {
|
|
s.newServer = func(addr string) (httpServer, *http.ServeMux) {
|
|
mux := http.NewServeMux()
|
|
capturedMux = mux
|
|
return fakeSrv, mux
|
|
}
|
|
},
|
|
}
|
|
srvIface := NewHTTPStreamProxyServer(env, lbFake, 50*time.Millisecond, append(baseOpts, opts...)...)
|
|
srv := srvIface.(*httpStreamProxyServer)
|
|
|
|
if err := srv.Run(ctx); err != nil {
|
|
t.Fatalf("Run: %v", err)
|
|
}
|
|
if capturedMux == nil {
|
|
t.Fatal("newServer was not called by Run")
|
|
}
|
|
return capturedMux, fakeSrv, srv
|
|
}
|
|
|
|
// =============================================================================
// NewHTTPStreamProxyServer
// =============================================================================
|
|
|
|
func TestHTTPStreamProxyServer_New_StoresFieldsAndDefaultsSeams(t *testing.T) {
|
|
env := &envfakes.FakeProxyEnvironment{}
|
|
lbFake := &lbfakes.FakeOriginLoadBalancer{}
|
|
timeout := 2 * time.Second
|
|
srv := NewHTTPStreamProxyServer(env, lbFake, timeout).(*httpStreamProxyServer)
|
|
|
|
if srv.environment != env {
|
|
t.Error("environment not stored")
|
|
}
|
|
if srv.loadBalancer != lbFake {
|
|
t.Error("loadBalancer not stored")
|
|
}
|
|
if srv.gracefulQuitTimeout != timeout {
|
|
t.Errorf("gracefulQuitTimeout = %v, want %v", srv.gracefulQuitTimeout, timeout)
|
|
}
|
|
if srv.shutdown == nil {
|
|
t.Error("shutdown seam should default to non-nil")
|
|
}
|
|
if srv.newServer == nil {
|
|
t.Error("newServer seam should default to non-nil")
|
|
}
|
|
if srv.newHLSStream == nil {
|
|
t.Error("newHLSStream seam should default to non-nil")
|
|
}
|
|
if srv.newFlvTsConn == nil {
|
|
t.Error("newFlvTsConn seam should default to non-nil")
|
|
}
|
|
}
|
|
|
|
func TestHTTPStreamProxyServer_New_AppliesOpts(t *testing.T) {
|
|
var optCalled bool
|
|
srv := NewHTTPStreamProxyServer(
|
|
&envfakes.FakeProxyEnvironment{},
|
|
&lbfakes.FakeOriginLoadBalancer{},
|
|
time.Second,
|
|
func(s *httpStreamProxyServer) { optCalled = true },
|
|
).(*httpStreamProxyServer)
|
|
if !optCalled {
|
|
t.Fatal("opt was not invoked")
|
|
}
|
|
if srv.shutdown == nil {
|
|
t.Error("default seams should still be set when opts don't override them")
|
|
}
|
|
}
|
|
|
|
func TestHTTPStreamProxyServer_New_OptCanOverrideAllSeams(t *testing.T) {
|
|
customShutdown := func(context.Context) error { return errors.New("custom") }
|
|
customNewServer := func(string) (httpServer, *http.ServeMux) { return nil, nil }
|
|
customNewHLS := func(string, string) *hlsPlayStream { return nil }
|
|
customNewFlv := func(context.Context) *httpFlvTsConnection { return nil }
|
|
|
|
srv := NewHTTPStreamProxyServer(
|
|
&envfakes.FakeProxyEnvironment{},
|
|
&lbfakes.FakeOriginLoadBalancer{},
|
|
time.Second,
|
|
func(s *httpStreamProxyServer) {
|
|
s.shutdown = customShutdown
|
|
s.newServer = customNewServer
|
|
s.newHLSStream = customNewHLS
|
|
s.newFlvTsConn = customNewFlv
|
|
},
|
|
).(*httpStreamProxyServer)
|
|
|
|
if err := srv.shutdown(context.Background()); err == nil || err.Error() != "custom" {
|
|
t.Errorf("custom shutdown not applied: %v", err)
|
|
}
|
|
// Pointer comparison on func values isn't supported by ==; call them and
|
|
// check the override took effect via observable behavior.
|
|
if got, _ := srv.newServer(""); got != nil {
|
|
t.Error("custom newServer not applied")
|
|
}
|
|
if srv.newHLSStream("", "") != nil {
|
|
t.Error("custom newHLSStream not applied")
|
|
}
|
|
if srv.newFlvTsConn(context.Background()) != nil {
|
|
t.Error("custom newFlvTsConn not applied")
|
|
}
|
|
}
|
|
|
|
// =============================================================================
// Default factory behavior
// =============================================================================
|
|
|
|
func TestHTTPStreamProxyServer_DefaultNewServer_BuildsRealServerAndMux(t *testing.T) {
|
|
srv := NewHTTPStreamProxyServer(&envfakes.FakeProxyEnvironment{},
|
|
&lbfakes.FakeOriginLoadBalancer{}, time.Second).(*httpStreamProxyServer)
|
|
|
|
got, mux := srv.newServer(":12345")
|
|
if mux == nil {
|
|
t.Fatal("mux is nil")
|
|
}
|
|
real, ok := got.(*http.Server)
|
|
if !ok {
|
|
t.Fatalf("expected *http.Server, got %T", got)
|
|
}
|
|
if real.Addr != ":12345" {
|
|
t.Errorf("Addr = %q, want :12345", real.Addr)
|
|
}
|
|
if real.Handler != mux {
|
|
t.Error("Handler should be the returned mux")
|
|
}
|
|
}
|
|
|
|
func TestHTTPStreamProxyServer_DefaultNewHLSStream_WiresFields(t *testing.T) {
|
|
lbFake := &lbfakes.FakeOriginLoadBalancer{}
|
|
srv := NewHTTPStreamProxyServer(&envfakes.FakeProxyEnvironment{}, lbFake,
|
|
time.Second).(*httpStreamProxyServer)
|
|
|
|
got := srv.newHLSStream("vhost/app/stream", "http://example.com/live.m3u8")
|
|
if got.loadBalancer != lbFake {
|
|
t.Error("loadBalancer not wired")
|
|
}
|
|
if got.StreamURL != "vhost/app/stream" {
|
|
t.Errorf("StreamURL = %q", got.StreamURL)
|
|
}
|
|
if got.FullURL != "http://example.com/live.m3u8" {
|
|
t.Errorf("FullURL = %q", got.FullURL)
|
|
}
|
|
if got.SRSProxyBackendHLSID == "" {
|
|
t.Error("SRSProxyBackendHLSID should be auto-generated")
|
|
}
|
|
if got.buildBackendURL == nil {
|
|
t.Error("buildBackendURL default should be propagated")
|
|
}
|
|
}
|
|
|
|
func TestHTTPStreamProxyServer_DefaultNewFlvTsConn_WiresFields(t *testing.T) {
|
|
lbFake := &lbfakes.FakeOriginLoadBalancer{}
|
|
srv := NewHTTPStreamProxyServer(&envfakes.FakeProxyEnvironment{}, lbFake,
|
|
time.Second).(*httpStreamProxyServer)
|
|
|
|
type ctxKey struct{}
|
|
ctx := context.WithValue(context.Background(), ctxKey{}, "sentinel")
|
|
got := srv.newFlvTsConn(ctx)
|
|
if got.ctx != ctx {
|
|
t.Error("ctx not wired")
|
|
}
|
|
if got.loadBalancer != lbFake {
|
|
t.Error("loadBalancer not wired")
|
|
}
|
|
}
|
|
|
|
func TestHTTPStreamProxyServer_DefaultShutdown_DelegatesToServer(t *testing.T) {
|
|
fakeSrv := newFakeHTTPProxyServer()
|
|
srv := NewHTTPStreamProxyServer(&envfakes.FakeProxyEnvironment{},
|
|
&lbfakes.FakeOriginLoadBalancer{}, time.Second).(*httpStreamProxyServer)
|
|
srv.server = fakeSrv // simulate what Run() would assign
|
|
if err := srv.shutdown(context.Background()); err != nil {
|
|
t.Fatalf("shutdown: %v", err)
|
|
}
|
|
if fakeSrv.shutdownCalls.Load() != 1 {
|
|
t.Fatalf("shutdown was not delegated to server, calls=%d", fakeSrv.shutdownCalls.Load())
|
|
}
|
|
}
|
|
|
|
// =============================================================================
// Close
// =============================================================================
|
|
|
|
func TestHTTPStreamProxyServer_Close_InvokesShutdownWithDeadline(t *testing.T) {
|
|
var gotCtx context.Context
|
|
var calls int
|
|
srv := NewHTTPStreamProxyServer(
|
|
&envfakes.FakeProxyEnvironment{},
|
|
&lbfakes.FakeOriginLoadBalancer{},
|
|
50*time.Millisecond,
|
|
func(s *httpStreamProxyServer) {
|
|
s.shutdown = func(ctx context.Context) error {
|
|
gotCtx = ctx
|
|
calls++
|
|
return nil
|
|
}
|
|
},
|
|
).(*httpStreamProxyServer)
|
|
|
|
if err := srv.Close(); err != nil {
|
|
t.Fatalf("Close: %v", err)
|
|
}
|
|
if calls != 1 {
|
|
t.Fatalf("shutdown calls = %d, want 1", calls)
|
|
}
|
|
if _, ok := gotCtx.Deadline(); !ok {
|
|
t.Error("Close should pass a deadline-bearing ctx to shutdown")
|
|
}
|
|
}
|
|
|
|
// =============================================================================
// Run — lifecycle
// =============================================================================
|
|
|
|
func TestHTTPStreamProxyServer_Run_AddrWithoutColonPrependsIt(t *testing.T) {
|
|
env := &envfakes.FakeProxyEnvironment{}
|
|
env.HttpServerReturns("8080")
|
|
|
|
var capturedAddr string
|
|
fakeSrv := newFakeHTTPProxyServer()
|
|
srvIface := NewHTTPStreamProxyServer(env, &lbfakes.FakeOriginLoadBalancer{},
|
|
50*time.Millisecond, func(s *httpStreamProxyServer) {
|
|
s.newServer = func(addr string) (httpServer, *http.ServeMux) {
|
|
capturedAddr = addr
|
|
return fakeSrv, http.NewServeMux()
|
|
}
|
|
})
|
|
defer srvIface.Close()
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
if err := srvIface.Run(ctx); err != nil {
|
|
t.Fatalf("Run: %v", err)
|
|
}
|
|
if capturedAddr != ":8080" {
|
|
t.Fatalf("newServer addr = %q, want :8080", capturedAddr)
|
|
}
|
|
}
|
|
|
|
func TestHTTPStreamProxyServer_Run_AddrWithColonUnchanged(t *testing.T) {
|
|
env := &envfakes.FakeProxyEnvironment{}
|
|
env.HttpServerReturns("127.0.0.1:9999")
|
|
|
|
var capturedAddr string
|
|
fakeSrv := newFakeHTTPProxyServer()
|
|
srvIface := NewHTTPStreamProxyServer(env, &lbfakes.FakeOriginLoadBalancer{},
|
|
50*time.Millisecond, func(s *httpStreamProxyServer) {
|
|
s.newServer = func(addr string) (httpServer, *http.ServeMux) {
|
|
capturedAddr = addr
|
|
return fakeSrv, http.NewServeMux()
|
|
}
|
|
})
|
|
defer srvIface.Close()
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
if err := srvIface.Run(ctx); err != nil {
|
|
t.Fatalf("Run: %v", err)
|
|
}
|
|
if capturedAddr != "127.0.0.1:9999" {
|
|
t.Fatalf("newServer addr = %q", capturedAddr)
|
|
}
|
|
}
|
|
|
|
func TestHTTPStreamProxyServer_Run_StaticFilesInvalidPath(t *testing.T) {
|
|
env := &envfakes.FakeProxyEnvironment{}
|
|
env.HttpServerReturns(":0")
|
|
env.StaticFilesReturns("/no/such/path/exists/__srsbot_test__")
|
|
|
|
fakeSrv := newFakeHTTPProxyServer()
|
|
srvIface := NewHTTPStreamProxyServer(env, &lbfakes.FakeOriginLoadBalancer{},
|
|
50*time.Millisecond, func(s *httpStreamProxyServer) {
|
|
s.newServer = func(addr string) (httpServer, *http.ServeMux) {
|
|
return fakeSrv, http.NewServeMux()
|
|
}
|
|
})
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
err := srvIface.Run(ctx)
|
|
if err == nil || !strings.Contains(err.Error(), "invalid static files") {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestHTTPStreamProxyServer_Run_CtxCancelTriggersShutdown(t *testing.T) {
|
|
env := &envfakes.FakeProxyEnvironment{}
|
|
env.HttpServerReturns(":0")
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
_, fakeSrv, _ := captureMuxFromRun(t, env, &lbfakes.FakeOriginLoadBalancer{}, ctx)
|
|
|
|
// Wait briefly for ListenAndServe goroutine to be running.
|
|
deadline := time.Now().Add(time.Second)
|
|
for fakeSrv.listenCalls.Load() == 0 && time.Now().Before(deadline) {
|
|
time.Sleep(time.Millisecond)
|
|
}
|
|
if fakeSrv.listenCalls.Load() == 0 {
|
|
t.Fatal("ListenAndServe goroutine did not start")
|
|
}
|
|
|
|
cancel()
|
|
|
|
// Wait for Shutdown to be invoked by the watcher goroutine.
|
|
deadline = time.Now().Add(time.Second)
|
|
for fakeSrv.shutdownCalls.Load() == 0 && time.Now().Before(deadline) {
|
|
time.Sleep(time.Millisecond)
|
|
}
|
|
if fakeSrv.shutdownCalls.Load() == 0 {
|
|
t.Fatal("Shutdown was not invoked after ctx cancel")
|
|
}
|
|
}
|
|
|
|
// =============================================================================
// Run — handler dispatch
// =============================================================================
|
|
|
|
func TestHTTPStreamProxyServer_Run_HandlerVersionsReturnsJSON(t *testing.T) {
|
|
env := &envfakes.FakeProxyEnvironment{}
|
|
env.HttpServerReturns(":0")
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
mux, _, _ := captureMuxFromRun(t, env, &lbfakes.FakeOriginLoadBalancer{}, ctx)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/versions", nil)
|
|
rec := httptest.NewRecorder()
|
|
mux.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("status = %d, want 200", rec.Code)
|
|
}
|
|
var body struct {
|
|
Code int `json:"code"`
|
|
PID string `json:"pid"`
|
|
Data struct {
|
|
Major int `json:"major"`
|
|
Minor int `json:"minor"`
|
|
Revision int `json:"revision"`
|
|
Version string `json:"version"`
|
|
} `json:"data"`
|
|
}
|
|
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
|
|
t.Fatalf("json: %v\nbody=%s", err, rec.Body.String())
|
|
}
|
|
if body.Code != 0 {
|
|
t.Errorf("Code = %d, want 0", body.Code)
|
|
}
|
|
if body.PID == "" {
|
|
t.Error("PID should be populated")
|
|
}
|
|
if body.Data.Version == "" {
|
|
t.Error("Version should be populated")
|
|
}
|
|
}
|
|
|
|
func TestHTTPStreamProxyServer_Run_HandlerM3U8InvokesNewHLSStream(t *testing.T) {
|
|
env := &envfakes.FakeProxyEnvironment{}
|
|
env.HttpServerReturns(":0")
|
|
lbFake := &lbfakes.FakeOriginLoadBalancer{}
|
|
|
|
// Make LoadOrStoreHLS return whatever was passed in (the stream from newHLSStream).
|
|
lbFake.LoadOrStoreHLSStub = func(_ context.Context, _ string, s lb.HLSPlayStream) (lb.HLSPlayStream, error) {
|
|
return s, nil
|
|
}
|
|
|
|
var capturedStreamURL, capturedFullURL string
|
|
var newHLSCalls int
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
mux, _, _ := captureMuxFromRun(t, env, lbFake, ctx, func(s *httpStreamProxyServer) {
|
|
// Wrap default newHLSStream to capture args, but return a real
|
|
// hlsPlayStream so the .(*hlsPlayStream) cast inside Run's handler works.
|
|
// The returned stream has a fake loadBalancer; ServeHTTP will short-circuit
|
|
// on the OPTIONS preflight we send below.
|
|
s.newHLSStream = func(streamURL, fullURL string) *hlsPlayStream {
|
|
newHLSCalls++
|
|
capturedStreamURL, capturedFullURL = streamURL, fullURL
|
|
return newHLSPlayStream(func(h *hlsPlayStream) {
|
|
h.loadBalancer = lbFake
|
|
h.StreamURL, h.FullURL = streamURL, fullURL
|
|
})
|
|
}
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodOptions, "http://example.com/live.m3u8", nil)
|
|
rec := httptest.NewRecorder()
|
|
mux.ServeHTTP(rec, req)
|
|
|
|
if newHLSCalls != 1 {
|
|
t.Fatalf("newHLSStream calls = %d, want 1", newHLSCalls)
|
|
}
|
|
if !strings.HasSuffix(capturedStreamURL, "/live") {
|
|
t.Errorf("captured streamURL %q should end with /live", capturedStreamURL)
|
|
}
|
|
if !strings.Contains(capturedFullURL, "live.m3u8") {
|
|
t.Errorf("captured fullURL %q should contain live.m3u8", capturedFullURL)
|
|
}
|
|
}
|
|
|
|
func TestHTTPStreamProxyServer_Run_HandlerM3U8LoadOrStoreErrorReturns400(t *testing.T) {
|
|
env := &envfakes.FakeProxyEnvironment{}
|
|
env.HttpServerReturns(":0")
|
|
lbFake := &lbfakes.FakeOriginLoadBalancer{}
|
|
lbFake.LoadOrStoreHLSReturns(nil, errors.New("redis down"))
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
mux, _, _ := captureMuxFromRun(t, env, lbFake, ctx)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil)
|
|
rec := httptest.NewRecorder()
|
|
mux.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusBadRequest {
|
|
t.Fatalf("status = %d, want 400; body=%s", rec.Code, rec.Body.String())
|
|
}
|
|
if !strings.Contains(rec.Body.String(), "load or store hls") {
|
|
t.Errorf("body should mention 'load or store hls', got %q", rec.Body.String())
|
|
}
|
|
}
|
|
|
|
func TestHTTPStreamProxyServer_Run_HandlerFlvInvokesNewFlvTsConn(t *testing.T) {
|
|
env := &envfakes.FakeProxyEnvironment{}
|
|
env.HttpServerReturns(":0")
|
|
lbFake := &lbfakes.FakeOriginLoadBalancer{}
|
|
|
|
var newFlvCalls int
|
|
var capturedCtx context.Context
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
mux, _, _ := captureMuxFromRun(t, env, lbFake, ctx, func(s *httpStreamProxyServer) {
|
|
s.newFlvTsConn = func(reqCtx context.Context) *httpFlvTsConnection {
|
|
newFlvCalls++
|
|
capturedCtx = reqCtx
|
|
return newHTTPFlvTsConnection(func(c *httpFlvTsConnection) {
|
|
c.ctx = reqCtx
|
|
c.loadBalancer = lbFake
|
|
})
|
|
}
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodOptions, "http://example.com/live.flv", nil)
|
|
rec := httptest.NewRecorder()
|
|
mux.ServeHTTP(rec, req)
|
|
|
|
if newFlvCalls != 1 {
|
|
t.Fatalf("newFlvTsConn calls = %d, want 1", newFlvCalls)
|
|
}
|
|
if capturedCtx == nil {
|
|
t.Error("captured ctx should be non-nil")
|
|
}
|
|
}
|
|
|
|
func TestHTTPStreamProxyServer_Run_HandlerTsInvokesNewFlvTsConn(t *testing.T) {
|
|
// Same dispatch as .flv but for .ts (without spbhid).
|
|
env := &envfakes.FakeProxyEnvironment{}
|
|
env.HttpServerReturns(":0")
|
|
lbFake := &lbfakes.FakeOriginLoadBalancer{}
|
|
|
|
var newFlvCalls int
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
mux, _, _ := captureMuxFromRun(t, env, lbFake, ctx, func(s *httpStreamProxyServer) {
|
|
s.newFlvTsConn = func(reqCtx context.Context) *httpFlvTsConnection {
|
|
newFlvCalls++
|
|
return newHTTPFlvTsConnection(func(c *httpFlvTsConnection) {
|
|
c.ctx = reqCtx
|
|
c.loadBalancer = lbFake
|
|
})
|
|
}
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodOptions, "http://example.com/live.ts", nil)
|
|
rec := httptest.NewRecorder()
|
|
mux.ServeHTTP(rec, req)
|
|
|
|
if newFlvCalls != 1 {
|
|
t.Fatalf("newFlvTsConn calls = %d, want 1", newFlvCalls)
|
|
}
|
|
}
|
|
|
|
func TestHTTPStreamProxyServer_Run_HandlerTsWithSPBHIDLoadsByID(t *testing.T) {
|
|
env := &envfakes.FakeProxyEnvironment{}
|
|
env.HttpServerReturns(":0")
|
|
lbFake := &lbfakes.FakeOriginLoadBalancer{}
|
|
|
|
stub := newHLSPlayStream(func(h *hlsPlayStream) {
|
|
h.loadBalancer = lbFake
|
|
h.SRSProxyBackendHLSID = "ABC"
|
|
})
|
|
lbFake.LoadHLSBySPBHIDReturns(stub, nil)
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
mux, _, _ := captureMuxFromRun(t, env, lbFake, ctx)
|
|
|
|
req := httptest.NewRequest(http.MethodOptions, "http://example.com/live-0.ts?spbhid=ABC", nil)
|
|
rec := httptest.NewRecorder()
|
|
mux.ServeHTTP(rec, req)
|
|
|
|
if lbFake.LoadHLSBySPBHIDCallCount() != 1 {
|
|
t.Fatalf("LoadHLSBySPBHID calls = %d, want 1", lbFake.LoadHLSBySPBHIDCallCount())
|
|
}
|
|
_, gotID := lbFake.LoadHLSBySPBHIDArgsForCall(0)
|
|
if gotID != "ABC" {
|
|
t.Errorf("LoadHLSBySPBHID id = %q, want ABC", gotID)
|
|
}
|
|
}
|
|
|
|
func TestHTTPStreamProxyServer_Run_HandlerTsWithSPBHIDErrorReturns400(t *testing.T) {
|
|
env := &envfakes.FakeProxyEnvironment{}
|
|
env.HttpServerReturns(":0")
|
|
lbFake := &lbfakes.FakeOriginLoadBalancer{}
|
|
lbFake.LoadHLSBySPBHIDReturns(nil, errors.New("not found"))
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
mux, _, _ := captureMuxFromRun(t, env, lbFake, ctx)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live-0.ts?spbhid=missing", nil)
|
|
rec := httptest.NewRecorder()
|
|
mux.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusBadRequest {
|
|
t.Fatalf("status = %d, want 400; body=%s", rec.Code, rec.Body.String())
|
|
}
|
|
}
|
|
|
|
func TestHTTPStreamProxyServer_Run_HandlerUnmatchedReturns404(t *testing.T) {
|
|
env := &envfakes.FakeProxyEnvironment{}
|
|
env.HttpServerReturns(":0")
|
|
// StaticFiles unset, no .m3u8/.flv/.ts suffix → 404.
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
mux, _, _ := captureMuxFromRun(t, env, &lbfakes.FakeOriginLoadBalancer{}, ctx)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/random/path", nil)
|
|
rec := httptest.NewRecorder()
|
|
mux.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusNotFound {
|
|
t.Fatalf("status = %d, want 404", rec.Code)
|
|
}
|
|
}
|
|
|
|
func TestHTTPStreamProxyServer_Run_HandlerServesStaticFiles(t *testing.T) {
|
|
dir := t.TempDir()
|
|
if err := os.WriteFile(filepath.Join(dir, "hello.txt"), []byte("hi"), 0644); err != nil {
|
|
t.Fatalf("write: %v", err)
|
|
}
|
|
|
|
env := &envfakes.FakeProxyEnvironment{}
|
|
env.HttpServerReturns(":0")
|
|
env.StaticFilesReturns(dir)
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
mux, _, _ := captureMuxFromRun(t, env, &lbfakes.FakeOriginLoadBalancer{}, ctx)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/hello.txt", nil)
|
|
rec := httptest.NewRecorder()
|
|
mux.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("status = %d, want 200", rec.Code)
|
|
}
|
|
if rec.Body.String() != "hi" {
|
|
t.Errorf("body = %q, want hi", rec.Body.String())
|
|
}
|
|
}
|