1290 lines
43 KiB
Go
1290 lines
43 KiB
Go
// Copyright (c) 2026 Winlin
|
|
//
|
|
// SPDX-License-Identifier: MIT
|
|
package proxy
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"os"
|
|
"path/filepath"
|
|
"strconv"
|
|
"strings"
|
|
stdSync "sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"srsx/internal/env/envfakes"
|
|
"srsx/internal/lb"
|
|
"srsx/internal/lb/lbfakes"
|
|
)
|
|
|
|
// httptestHostPort returns the host and port components of the URL an
// httptest.Server is serving on. A URL that cannot be parsed fails the test
// immediately.
func httptestHostPort(t *testing.T, ts *httptest.Server) (string, string) {
	t.Helper()

	parsed, parseErr := url.Parse(ts.URL)
	if parseErr == nil {
		return parsed.Hostname(), parsed.Port()
	}
	t.Fatalf("parse httptest URL %q: %v", ts.URL, parseErr)
	return "", ""
}
|
|
|
|
// reservedClosedPort asks the OS for a free loopback TCP port, closes the
// listener straight away, and returns the IP and port it was bound to.
//
// Nothing in this test is listening on the returned address afterwards, so
// dialing it normally fails with a connection-refused error. This is
// best-effort only: the port has been released back to the OS, and another
// process could bind it before the test dials. Callers should therefore not
// treat the dial failure as guaranteed.
func reservedClosedPort(t *testing.T) (string, string) {
	t.Helper()

	l, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("reserve port: %v", err)
	}
	// A "tcp" listener reports a *net.TCPAddr. Check the assertion anyway so
	// an unexpected type produces a test failure instead of a panic.
	addr, ok := l.Addr().(*net.TCPAddr)
	if !ok {
		_ = l.Close()
		t.Fatalf("unexpected listener address type %T", l.Addr())
	}
	if err := l.Close(); err != nil {
		t.Fatalf("close listener: %v", err)
	}
	return addr.IP.String(), strconv.Itoa(addr.Port)
}
|
|
|
|
// =============================================================================
|
|
// newHLSPlayStream
|
|
// =============================================================================
|
|
|
|
func TestHLSPlayStream_New_DefaultsBuildBackendURL(t *testing.T) {
|
|
v := newHLSPlayStream()
|
|
if v.buildBackendURL == nil {
|
|
t.Fatal("buildBackendURL should default to non-nil")
|
|
}
|
|
if got := v.buildBackendURL("1.2.3.4", 8080, "/live.ts"); got != "http://1.2.3.4:8080/live.ts" {
|
|
t.Fatalf("default buildBackendURL produced %q", got)
|
|
}
|
|
}
|
|
|
|
func TestHLSPlayStream_New_AppliesOpts(t *testing.T) {
|
|
ctx := context.Background()
|
|
lbStub := &lbfakes.FakeOriginLoadBalancer{}
|
|
v := newHLSPlayStream(func(s *hlsPlayStream) {
|
|
s.ctx = ctx
|
|
s.loadBalancer = lbStub
|
|
s.SRSProxyBackendHLSID = "spb-id"
|
|
s.StreamURL = "vhost/app/stream"
|
|
s.FullURL = "http://example.com/live.m3u8"
|
|
})
|
|
if v.ctx != ctx {
|
|
t.Error("ctx not applied")
|
|
}
|
|
if v.loadBalancer != lbStub {
|
|
t.Error("loadBalancer not applied")
|
|
}
|
|
if v.SRSProxyBackendHLSID != "spb-id" {
|
|
t.Errorf("SRSProxyBackendHLSID = %q", v.SRSProxyBackendHLSID)
|
|
}
|
|
if v.StreamURL != "vhost/app/stream" {
|
|
t.Errorf("StreamURL = %q", v.StreamURL)
|
|
}
|
|
if v.FullURL != "http://example.com/live.m3u8" {
|
|
t.Errorf("FullURL = %q", v.FullURL)
|
|
}
|
|
}
|
|
|
|
func TestHLSPlayStream_New_OptCanOverrideBuildBackendURL(t *testing.T) {
|
|
v := newHLSPlayStream(func(s *hlsPlayStream) {
|
|
s.buildBackendURL = func(string, int, string) string { return "custom" }
|
|
})
|
|
if got := v.buildBackendURL("", 0, ""); got != "custom" {
|
|
t.Fatalf("override not applied: got %q", got)
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// Initialize
|
|
// =============================================================================
|
|
|
|
func TestHLSPlayStream_Initialize_SetsCtxWhenNil(t *testing.T) {
|
|
v := newHLSPlayStream()
|
|
ret := v.Initialize(context.Background())
|
|
if v.ctx == nil {
|
|
t.Fatal("Initialize should set v.ctx when nil")
|
|
}
|
|
if ret != lb.HLSPlayStream(v) {
|
|
t.Fatal("Initialize should return v")
|
|
}
|
|
}
|
|
|
|
func TestHLSPlayStream_Initialize_PreservesExistingCtx(t *testing.T) {
|
|
type ctxKey struct{}
|
|
existing := context.WithValue(context.Background(), ctxKey{}, "sentinel")
|
|
v := newHLSPlayStream(func(s *hlsPlayStream) { s.ctx = existing })
|
|
v.Initialize(context.Background())
|
|
if got, _ := v.ctx.Value(ctxKey{}).(string); got != "sentinel" {
|
|
t.Fatalf("Initialize should not replace existing ctx, value=%q", got)
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// GetSPBHID
|
|
// =============================================================================
|
|
|
|
func TestHLSPlayStream_GetSPBHID(t *testing.T) {
|
|
v := newHLSPlayStream(func(s *hlsPlayStream) { s.SRSProxyBackendHLSID = "spb-xyz" })
|
|
if v.GetSPBHID() != "spb-xyz" {
|
|
t.Fatalf("GetSPBHID = %q", v.GetSPBHID())
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// ServeHTTP / serve / CORS
|
|
// =============================================================================
|
|
|
|
func TestHLSPlayStream_ServeHTTP_CORSPreflightShortCircuits(t *testing.T) {
|
|
lbFake := &lbfakes.FakeOriginLoadBalancer{}
|
|
v := newHLSPlayStream(func(s *hlsPlayStream) {
|
|
s.ctx = context.Background()
|
|
s.loadBalancer = lbFake
|
|
})
|
|
req := httptest.NewRequest(http.MethodOptions, "http://example.com/live.m3u8", nil)
|
|
rec := httptest.NewRecorder()
|
|
v.ServeHTTP(rec, req)
|
|
if lbFake.PickCallCount() != 0 {
|
|
t.Fatalf("Pick should not be called on CORS preflight, calls=%d", lbFake.PickCallCount())
|
|
}
|
|
}
|
|
|
|
func TestHLSPlayStream_ServeHTTP_ErrorBranchInvokesApiError(t *testing.T) {
|
|
lbFake := &lbfakes.FakeOriginLoadBalancer{}
|
|
lbFake.PickReturns(nil, errors.New("pick-fail"))
|
|
v := newHLSPlayStream(func(s *hlsPlayStream) {
|
|
s.ctx = context.Background()
|
|
s.loadBalancer = lbFake
|
|
s.StreamURL = "vhost/app/stream"
|
|
})
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil)
|
|
rec := httptest.NewRecorder()
|
|
v.ServeHTTP(rec, req)
|
|
// ApiError writes a JSON error response. Verify the body is non-empty
|
|
// and the status is not the default 200 (or that some response was made).
|
|
if rec.Body.Len() == 0 {
|
|
t.Fatal("ServeHTTP error branch should produce a response body")
|
|
}
|
|
}
|
|
|
|
func TestHLSPlayStream_Serve_PickError(t *testing.T) {
|
|
lbFake := &lbfakes.FakeOriginLoadBalancer{}
|
|
lbFake.PickReturns(nil, errors.New("pick-fail"))
|
|
v := newHLSPlayStream(func(s *hlsPlayStream) {
|
|
s.ctx = context.Background()
|
|
s.loadBalancer = lbFake
|
|
s.StreamURL = "vhost/app/stream"
|
|
})
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil)
|
|
rec := httptest.NewRecorder()
|
|
err := v.serve(v.ctx, rec, req)
|
|
if err == nil || !strings.Contains(err.Error(), "pick backend for vhost/app/stream") {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestHLSPlayStream_Serve_WrapsServeByBackendError(t *testing.T) {
|
|
// Backend with empty HTTP slice triggers serveByBackend's "no http server"
|
|
// error, which serve() then wraps with "serve %v with %v by backend %+v".
|
|
lbFake := &lbfakes.FakeOriginLoadBalancer{}
|
|
lbFake.PickReturns(&lb.OriginServer{IP: "127.0.0.1"}, nil)
|
|
v := newHLSPlayStream(func(s *hlsPlayStream) {
|
|
s.ctx = context.Background()
|
|
s.loadBalancer = lbFake
|
|
s.StreamURL = "vhost/app/stream"
|
|
s.FullURL = "http://example.com/live.m3u8"
|
|
})
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil)
|
|
rec := httptest.NewRecorder()
|
|
err := v.serve(v.ctx, rec, req)
|
|
if err == nil || !strings.Contains(err.Error(), "serve http://example.com/live.m3u8 with vhost/app/stream by backend") {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestHLSPlayStream_Serve_HappyPathRewritesM3U8(t *testing.T) {
|
|
m3u8 := "#EXTM3U\n#EXT-X-VERSION:3\nlive-0.ts\n"
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, _ = io.WriteString(w, m3u8)
|
|
}))
|
|
defer ts.Close()
|
|
host, port := httptestHostPort(t, ts)
|
|
|
|
lbFake := &lbfakes.FakeOriginLoadBalancer{}
|
|
lbFake.PickReturns(&lb.OriginServer{IP: host, HTTP: []string{port}}, nil)
|
|
|
|
v := newHLSPlayStream(func(s *hlsPlayStream) {
|
|
s.ctx = context.Background()
|
|
s.loadBalancer = lbFake
|
|
s.StreamURL = "vhost/app/stream"
|
|
s.FullURL = "http://example.com/live.m3u8"
|
|
s.SRSProxyBackendHLSID = "spb-1"
|
|
})
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil)
|
|
rec := httptest.NewRecorder()
|
|
if err := v.serve(v.ctx, rec, req); err != nil {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
if !strings.Contains(rec.Body.String(), "live-0.ts?spbhid=spb-1") {
|
|
t.Fatalf("body missing spbhid rewrite: %q", rec.Body.String())
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// serveByBackend — error paths (no HTTP round-trip needed)
|
|
// =============================================================================
|
|
|
|
func TestHLSPlayStream_ServeByBackend_NoHTTPEndpoint(t *testing.T) {
|
|
v := newHLSPlayStream()
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil)
|
|
rec := httptest.NewRecorder()
|
|
err := v.serveByBackend(context.Background(), rec, req, &lb.OriginServer{IP: "127.0.0.1"})
|
|
if err == nil || !strings.Contains(err.Error(), "no http server") {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestHLSPlayStream_ServeByBackend_BadPort(t *testing.T) {
|
|
v := newHLSPlayStream()
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil)
|
|
rec := httptest.NewRecorder()
|
|
err := v.serveByBackend(context.Background(), rec, req,
|
|
&lb.OriginServer{IP: "127.0.0.1", HTTP: []string{"not-a-port"}})
|
|
if err == nil || !strings.Contains(err.Error(), "parse http port not-a-port") {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestHLSPlayStream_ServeByBackend_RequestBuildError(t *testing.T) {
|
|
v := newHLSPlayStream(func(s *hlsPlayStream) {
|
|
s.buildBackendURL = func(string, int, string) string { return "://invalid-url" }
|
|
})
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil)
|
|
rec := httptest.NewRecorder()
|
|
err := v.serveByBackend(context.Background(), rec, req,
|
|
&lb.OriginServer{IP: "127.0.0.1", HTTP: []string{"8080"}})
|
|
if err == nil || !strings.Contains(err.Error(), "create request to ://invalid-url") {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestHLSPlayStream_ServeByBackend_DialError(t *testing.T) {
|
|
host, port := reservedClosedPort(t)
|
|
v := newHLSPlayStream()
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil)
|
|
rec := httptest.NewRecorder()
|
|
err := v.serveByBackend(context.Background(), rec, req,
|
|
&lb.OriginServer{IP: host, HTTP: []string{port}})
|
|
if err == nil || !strings.HasSuffix(err.Error(), "EOF") {
|
|
t.Fatalf("expected error suffixed with 'EOF', got: %v", err)
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// serveByBackend — HTTP round-trip via httptest.Server
|
|
// =============================================================================
|
|
|
|
func TestHLSPlayStream_ServeByBackend_NonOKStatus(t *testing.T) {
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusNotFound)
|
|
}))
|
|
defer ts.Close()
|
|
host, port := httptestHostPort(t, ts)
|
|
|
|
v := newHLSPlayStream()
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil)
|
|
rec := httptest.NewRecorder()
|
|
err := v.serveByBackend(context.Background(), rec, req,
|
|
&lb.OriginServer{IP: host, HTTP: []string{port}})
|
|
if err == nil || !strings.Contains(err.Error(), "status=404") {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestHLSPlayStream_ServeByBackend_TSPassthrough(t *testing.T) {
|
|
payload := []byte{0x47, 0x00, 0x01, 0x02, 0x03, 0x04}
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "video/mp2t")
|
|
_, _ = w.Write(payload)
|
|
}))
|
|
defer ts.Close()
|
|
host, port := httptestHostPort(t, ts)
|
|
|
|
v := newHLSPlayStream()
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.ts", nil)
|
|
rec := httptest.NewRecorder()
|
|
if err := v.serveByBackend(context.Background(), rec, req,
|
|
&lb.OriginServer{IP: host, HTTP: []string{port}}); err != nil {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
if got := rec.Body.Bytes(); !bytes.Equal(got, payload) {
|
|
t.Fatalf("body mismatch: got=%v want=%v", got, payload)
|
|
}
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("status = %d, want 200", rec.Code)
|
|
}
|
|
}
|
|
|
|
func TestHLSPlayStream_ServeByBackend_M3U8RewritesTSWithoutQuery(t *testing.T) {
|
|
m3u8 := "#EXTM3U\n#EXT-X-VERSION:3\nlive-0.ts\nlive-1.ts\n"
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, _ = io.WriteString(w, m3u8)
|
|
}))
|
|
defer ts.Close()
|
|
host, port := httptestHostPort(t, ts)
|
|
|
|
v := newHLSPlayStream(func(s *hlsPlayStream) { s.SRSProxyBackendHLSID = "ABC" })
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil)
|
|
rec := httptest.NewRecorder()
|
|
if err := v.serveByBackend(context.Background(), rec, req,
|
|
&lb.OriginServer{IP: host, HTTP: []string{port}}); err != nil {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
body := rec.Body.String()
|
|
for _, want := range []string{"live-0.ts?spbhid=ABC", "live-1.ts?spbhid=ABC"} {
|
|
if !strings.Contains(body, want) {
|
|
t.Fatalf("missing %q in body: %q", want, body)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestHLSPlayStream_ServeByBackend_M3U8RewritesTSWithQuery(t *testing.T) {
|
|
m3u8 := "#EXTM3U\nlive-0.ts?token=foo\n"
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, _ = io.WriteString(w, m3u8)
|
|
}))
|
|
defer ts.Close()
|
|
host, port := httptestHostPort(t, ts)
|
|
|
|
v := newHLSPlayStream(func(s *hlsPlayStream) { s.SRSProxyBackendHLSID = "ABC" })
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil)
|
|
rec := httptest.NewRecorder()
|
|
if err := v.serveByBackend(context.Background(), rec, req,
|
|
&lb.OriginServer{IP: host, HTTP: []string{port}}); err != nil {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
if want := "live-0.ts?spbhid=ABC&&token=foo"; !strings.Contains(rec.Body.String(), want) {
|
|
t.Fatalf("missing %q in body: %q", want, rec.Body.String())
|
|
}
|
|
}
|
|
|
|
func TestHLSPlayStream_ServeByBackend_AppendsRawQueryOnTS(t *testing.T) {
|
|
var seenURL string
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
seenURL = r.URL.String()
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer ts.Close()
|
|
host, port := httptestHostPort(t, ts)
|
|
|
|
v := newHLSPlayStream()
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.ts?token=foo", nil)
|
|
rec := httptest.NewRecorder()
|
|
if err := v.serveByBackend(context.Background(), rec, req,
|
|
&lb.OriginServer{IP: host, HTTP: []string{port}}); err != nil {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
if !strings.Contains(seenURL, "token=foo") {
|
|
t.Fatalf("backend should see raw query, got %q", seenURL)
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// httpFlvTsConnection
|
|
// =============================================================================
|
|
|
|
// =============================================================================
|
|
// newHTTPFlvTsConnection
|
|
// =============================================================================
|
|
|
|
func TestHTTPFlvTsConn_New_DefaultsBuildBackendURL(t *testing.T) {
|
|
v := newHTTPFlvTsConnection()
|
|
if v.buildBackendURL == nil {
|
|
t.Fatal("buildBackendURL should default to non-nil")
|
|
}
|
|
if got := v.buildBackendURL("1.2.3.4", 8080, "/live.flv"); got != "http://1.2.3.4:8080/live.flv" {
|
|
t.Fatalf("default buildBackendURL produced %q", got)
|
|
}
|
|
}
|
|
|
|
func TestHTTPFlvTsConn_New_AppliesOpts(t *testing.T) {
|
|
ctx := context.Background()
|
|
lbStub := &lbfakes.FakeOriginLoadBalancer{}
|
|
v := newHTTPFlvTsConnection(func(c *httpFlvTsConnection) {
|
|
c.ctx = ctx
|
|
c.loadBalancer = lbStub
|
|
})
|
|
if v.ctx != ctx {
|
|
t.Error("ctx not applied")
|
|
}
|
|
if v.loadBalancer != lbStub {
|
|
t.Error("loadBalancer not applied")
|
|
}
|
|
}
|
|
|
|
func TestHTTPFlvTsConn_New_OptCanOverrideBuildBackendURL(t *testing.T) {
|
|
v := newHTTPFlvTsConnection(func(c *httpFlvTsConnection) {
|
|
c.buildBackendURL = func(string, int, string) string { return "custom" }
|
|
})
|
|
if got := v.buildBackendURL("", 0, ""); got != "custom" {
|
|
t.Fatalf("override not applied: got %q", got)
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// ServeHTTP / serve / CORS
|
|
// =============================================================================
|
|
|
|
func TestHTTPFlvTsConn_ServeHTTP_CORSPreflightShortCircuits(t *testing.T) {
|
|
lbFake := &lbfakes.FakeOriginLoadBalancer{}
|
|
v := newHTTPFlvTsConnection(func(c *httpFlvTsConnection) {
|
|
c.ctx = context.Background()
|
|
c.loadBalancer = lbFake
|
|
})
|
|
req := httptest.NewRequest(http.MethodOptions, "http://example.com/live.flv", nil)
|
|
rec := httptest.NewRecorder()
|
|
v.ServeHTTP(rec, req)
|
|
if lbFake.PickCallCount() != 0 {
|
|
t.Fatalf("Pick should not be called on CORS preflight, calls=%d", lbFake.PickCallCount())
|
|
}
|
|
}
|
|
|
|
func TestHTTPFlvTsConn_ServeHTTP_ErrorBranchInvokesApiError(t *testing.T) {
|
|
lbFake := &lbfakes.FakeOriginLoadBalancer{}
|
|
lbFake.PickReturns(nil, errors.New("pick-fail"))
|
|
v := newHTTPFlvTsConnection(func(c *httpFlvTsConnection) {
|
|
c.ctx = context.Background()
|
|
c.loadBalancer = lbFake
|
|
})
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.flv", nil)
|
|
rec := httptest.NewRecorder()
|
|
v.ServeHTTP(rec, req)
|
|
if rec.Body.Len() == 0 {
|
|
t.Fatal("ServeHTTP error branch should produce a response body")
|
|
}
|
|
}
|
|
|
|
func TestHTTPFlvTsConn_Serve_PickError(t *testing.T) {
|
|
lbFake := &lbfakes.FakeOriginLoadBalancer{}
|
|
lbFake.PickReturns(nil, errors.New("pick-fail"))
|
|
v := newHTTPFlvTsConnection(func(c *httpFlvTsConnection) {
|
|
c.ctx = context.Background()
|
|
c.loadBalancer = lbFake
|
|
})
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live/stream.flv", nil)
|
|
rec := httptest.NewRecorder()
|
|
err := v.serve(context.Background(), rec, req)
|
|
if err == nil || !strings.Contains(err.Error(), "pick backend for") {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestHTTPFlvTsConn_Serve_WrapsServeByBackendError(t *testing.T) {
|
|
// Empty HTTP slice on backend triggers serveByBackend's "no http stream
|
|
// server" error, which serve() wraps with "serve <fullURL> with <streamURL>".
|
|
lbFake := &lbfakes.FakeOriginLoadBalancer{}
|
|
lbFake.PickReturns(&lb.OriginServer{IP: "127.0.0.1"}, nil)
|
|
v := newHTTPFlvTsConnection(func(c *httpFlvTsConnection) {
|
|
c.ctx = context.Background()
|
|
c.loadBalancer = lbFake
|
|
})
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live/stream.flv", nil)
|
|
rec := httptest.NewRecorder()
|
|
err := v.serve(context.Background(), rec, req)
|
|
if err == nil || !strings.Contains(err.Error(), "serve ") || !strings.Contains(err.Error(), " by backend ") {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestHTTPFlvTsConn_Serve_HappyPath(t *testing.T) {
|
|
payload := []byte{0x46, 0x4c, 0x56, 0x01} // "FLV\x01"
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, _ = w.Write(payload)
|
|
}))
|
|
defer ts.Close()
|
|
host, port := httptestHostPort(t, ts)
|
|
|
|
lbFake := &lbfakes.FakeOriginLoadBalancer{}
|
|
lbFake.PickReturns(&lb.OriginServer{IP: host, HTTP: []string{port}}, nil)
|
|
|
|
v := newHTTPFlvTsConnection(func(c *httpFlvTsConnection) {
|
|
c.ctx = context.Background()
|
|
c.loadBalancer = lbFake
|
|
})
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live/stream.flv", nil)
|
|
rec := httptest.NewRecorder()
|
|
if err := v.serve(context.Background(), rec, req); err != nil {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
if !bytes.Equal(rec.Body.Bytes(), payload) {
|
|
t.Fatalf("body mismatch: got=%v want=%v", rec.Body.Bytes(), payload)
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// serveByBackend — error paths
|
|
// =============================================================================
|
|
|
|
func TestHTTPFlvTsConn_ServeByBackend_NoHTTPEndpoint(t *testing.T) {
|
|
v := newHTTPFlvTsConnection()
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.flv", nil)
|
|
rec := httptest.NewRecorder()
|
|
err := v.serveByBackend(context.Background(), rec, req, &lb.OriginServer{IP: "127.0.0.1"})
|
|
if err == nil || !strings.Contains(err.Error(), "no http stream server") {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestHTTPFlvTsConn_ServeByBackend_BadPort(t *testing.T) {
|
|
v := newHTTPFlvTsConnection()
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.flv", nil)
|
|
rec := httptest.NewRecorder()
|
|
err := v.serveByBackend(context.Background(), rec, req,
|
|
&lb.OriginServer{IP: "127.0.0.1", HTTP: []string{"not-a-port"}})
|
|
if err == nil || !strings.Contains(err.Error(), "parse http port not-a-port") {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestHTTPFlvTsConn_ServeByBackend_RequestBuildError(t *testing.T) {
|
|
v := newHTTPFlvTsConnection(func(c *httpFlvTsConnection) {
|
|
c.buildBackendURL = func(string, int, string) string { return "://invalid-url" }
|
|
})
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.flv", nil)
|
|
rec := httptest.NewRecorder()
|
|
err := v.serveByBackend(context.Background(), rec, req,
|
|
&lb.OriginServer{IP: "127.0.0.1", HTTP: []string{"8080"}})
|
|
if err == nil || !strings.Contains(err.Error(), "create request to ://invalid-url") {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestHTTPFlvTsConn_ServeByBackend_DialError(t *testing.T) {
|
|
host, port := reservedClosedPort(t)
|
|
v := newHTTPFlvTsConnection()
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.flv", nil)
|
|
rec := httptest.NewRecorder()
|
|
err := v.serveByBackend(context.Background(), rec, req,
|
|
&lb.OriginServer{IP: host, HTTP: []string{port}})
|
|
if err == nil || !strings.Contains(err.Error(), "do request to") {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// serveByBackend — HTTP round-trip
|
|
// =============================================================================
|
|
|
|
func TestHTTPFlvTsConn_ServeByBackend_NonOKStatus(t *testing.T) {
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusNotFound)
|
|
}))
|
|
defer ts.Close()
|
|
host, port := httptestHostPort(t, ts)
|
|
|
|
v := newHTTPFlvTsConnection()
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.flv", nil)
|
|
rec := httptest.NewRecorder()
|
|
err := v.serveByBackend(context.Background(), rec, req,
|
|
&lb.OriginServer{IP: host, HTTP: []string{port}})
|
|
if err == nil || !strings.Contains(err.Error(), "status=404") {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestHTTPFlvTsConn_ServeByBackend_BodyPassthrough(t *testing.T) {
|
|
payload := []byte{0x46, 0x4c, 0x56, 0x01, 0x05, 0x00, 0x00, 0x00}
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, _ = w.Write(payload)
|
|
}))
|
|
defer ts.Close()
|
|
host, port := httptestHostPort(t, ts)
|
|
|
|
v := newHTTPFlvTsConnection()
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.flv", nil)
|
|
rec := httptest.NewRecorder()
|
|
if err := v.serveByBackend(context.Background(), rec, req,
|
|
&lb.OriginServer{IP: host, HTTP: []string{port}}); err != nil {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
if !bytes.Equal(rec.Body.Bytes(), payload) {
|
|
t.Fatalf("body mismatch: got=%v want=%v", rec.Body.Bytes(), payload)
|
|
}
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("status = %d, want 200", rec.Code)
|
|
}
|
|
}
|
|
|
|
func TestHTTPFlvTsConn_ServeByBackend_DropsRawQuery(t *testing.T) {
|
|
// Unlike hlsPlayStream.serveByBackend, the FLV/TS path forwards only
|
|
// r.URL.Path — it does NOT append RawQuery to the backend request.
|
|
var seenRawQuery string
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
seenRawQuery = r.URL.RawQuery
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer ts.Close()
|
|
host, port := httptestHostPort(t, ts)
|
|
|
|
v := newHTTPFlvTsConnection()
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.flv?token=foo", nil)
|
|
rec := httptest.NewRecorder()
|
|
if err := v.serveByBackend(context.Background(), rec, req,
|
|
&lb.OriginServer{IP: host, HTTP: []string{port}}); err != nil {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
if seenRawQuery != "" {
|
|
t.Fatalf("backend should NOT see raw query, got %q", seenRawQuery)
|
|
}
|
|
}
|
|
|
|
func TestHTTPFlvTsConn_ServeByBackend_PreservesMethod(t *testing.T) {
|
|
var seenMethod string
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
seenMethod = r.Method
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer ts.Close()
|
|
host, port := httptestHostPort(t, ts)
|
|
|
|
v := newHTTPFlvTsConnection()
|
|
req := httptest.NewRequest(http.MethodHead, "http://example.com/live.flv", nil)
|
|
rec := httptest.NewRecorder()
|
|
if err := v.serveByBackend(context.Background(), rec, req,
|
|
&lb.OriginServer{IP: host, HTTP: []string{port}}); err != nil {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
if seenMethod != http.MethodHead {
|
|
t.Fatalf("backend method = %q, want HEAD", seenMethod)
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// httpStreamProxyServer
|
|
// =============================================================================
|
|
|
|
// fakeHTTPProxyServer stands in for the httpServer seam so tests can drive
// Run()'s start/stop lifecycle without binding a real port. ListenAndServe
// parks on the block channel until Shutdown closes it.
type fakeHTTPProxyServer struct {
	// Invocation counters, atomic so tests can read them while another
	// goroutine may be inside ListenAndServe or Shutdown.
	listenCalls   atomic.Int32
	shutdownCalls atomic.Int32

	// Values handed back by ListenAndServe and Shutdown respectively.
	listenReturn   error
	shutdownReturn error

	// block releases ListenAndServe when closed; once guards against a
	// double close when Shutdown is called repeatedly.
	block chan struct{}
	once  stdSync.Once
}
|
|
|
|
func newFakeHTTPProxyServer() *fakeHTTPProxyServer {
|
|
return &fakeHTTPProxyServer{
|
|
listenReturn: http.ErrServerClosed,
|
|
block: make(chan struct{}),
|
|
}
|
|
}
|
|
|
|
func (f *fakeHTTPProxyServer) ListenAndServe() error {
|
|
f.listenCalls.Add(1)
|
|
<-f.block
|
|
return f.listenReturn
|
|
}
|
|
|
|
func (f *fakeHTTPProxyServer) Shutdown(ctx context.Context) error {
|
|
f.shutdownCalls.Add(1)
|
|
f.once.Do(func() { close(f.block) })
|
|
return f.shutdownReturn
|
|
}
|
|
|
|
// captureMuxFromRun calls Run with a fake server that captures the registered
|
|
// mux. Returns the mux and the fake server for further assertions. Caller is
|
|
// responsible for cancelling ctx to trigger shutdown.
|
|
func captureMuxFromRun(t *testing.T, env *envfakes.FakeProxyEnvironment,
|
|
lbFake *lbfakes.FakeOriginLoadBalancer, ctx context.Context,
|
|
opts ...func(*httpStreamProxyServer)) (*http.ServeMux, *fakeHTTPProxyServer, *httpStreamProxyServer) {
|
|
t.Helper()
|
|
|
|
fakeSrv := newFakeHTTPProxyServer()
|
|
var capturedMux *http.ServeMux
|
|
|
|
baseOpts := []func(*httpStreamProxyServer){
|
|
func(s *httpStreamProxyServer) {
|
|
s.newServer = func(addr string) (httpServer, *http.ServeMux) {
|
|
mux := http.NewServeMux()
|
|
capturedMux = mux
|
|
return fakeSrv, mux
|
|
}
|
|
},
|
|
}
|
|
srvIface := NewHTTPStreamProxyServer(env, lbFake, 50*time.Millisecond, append(baseOpts, opts...)...)
|
|
srv := srvIface.(*httpStreamProxyServer)
|
|
|
|
if err := srv.Run(ctx); err != nil {
|
|
t.Fatalf("Run: %v", err)
|
|
}
|
|
if capturedMux == nil {
|
|
t.Fatal("newServer was not called by Run")
|
|
}
|
|
return capturedMux, fakeSrv, srv
|
|
}
|
|
|
|
// =============================================================================
|
|
// NewHTTPStreamProxyServer
|
|
// =============================================================================
|
|
|
|
func TestHTTPStreamProxyServer_New_StoresFieldsAndDefaultsSeams(t *testing.T) {
|
|
env := &envfakes.FakeProxyEnvironment{}
|
|
lbFake := &lbfakes.FakeOriginLoadBalancer{}
|
|
timeout := 2 * time.Second
|
|
srv := NewHTTPStreamProxyServer(env, lbFake, timeout).(*httpStreamProxyServer)
|
|
|
|
if srv.environment != env {
|
|
t.Error("environment not stored")
|
|
}
|
|
if srv.loadBalancer != lbFake {
|
|
t.Error("loadBalancer not stored")
|
|
}
|
|
if srv.gracefulQuitTimeout != timeout {
|
|
t.Errorf("gracefulQuitTimeout = %v, want %v", srv.gracefulQuitTimeout, timeout)
|
|
}
|
|
if srv.shutdown == nil {
|
|
t.Error("shutdown seam should default to non-nil")
|
|
}
|
|
if srv.newServer == nil {
|
|
t.Error("newServer seam should default to non-nil")
|
|
}
|
|
if srv.newHLSStream == nil {
|
|
t.Error("newHLSStream seam should default to non-nil")
|
|
}
|
|
if srv.newFlvTsConn == nil {
|
|
t.Error("newFlvTsConn seam should default to non-nil")
|
|
}
|
|
}
|
|
|
|
func TestHTTPStreamProxyServer_New_AppliesOpts(t *testing.T) {
|
|
var optCalled bool
|
|
srv := NewHTTPStreamProxyServer(
|
|
&envfakes.FakeProxyEnvironment{},
|
|
&lbfakes.FakeOriginLoadBalancer{},
|
|
time.Second,
|
|
func(s *httpStreamProxyServer) { optCalled = true },
|
|
).(*httpStreamProxyServer)
|
|
if !optCalled {
|
|
t.Fatal("opt was not invoked")
|
|
}
|
|
if srv.shutdown == nil {
|
|
t.Error("default seams should still be set when opts don't override them")
|
|
}
|
|
}
|
|
|
|
func TestHTTPStreamProxyServer_New_OptCanOverrideAllSeams(t *testing.T) {
|
|
customShutdown := func(context.Context) error { return errors.New("custom") }
|
|
customNewServer := func(string) (httpServer, *http.ServeMux) { return nil, nil }
|
|
customNewHLS := func(string, string) *hlsPlayStream { return nil }
|
|
customNewFlv := func(context.Context) *httpFlvTsConnection { return nil }
|
|
|
|
srv := NewHTTPStreamProxyServer(
|
|
&envfakes.FakeProxyEnvironment{},
|
|
&lbfakes.FakeOriginLoadBalancer{},
|
|
time.Second,
|
|
func(s *httpStreamProxyServer) {
|
|
s.shutdown = customShutdown
|
|
s.newServer = customNewServer
|
|
s.newHLSStream = customNewHLS
|
|
s.newFlvTsConn = customNewFlv
|
|
},
|
|
).(*httpStreamProxyServer)
|
|
|
|
if err := srv.shutdown(context.Background()); err == nil || err.Error() != "custom" {
|
|
t.Errorf("custom shutdown not applied: %v", err)
|
|
}
|
|
// Pointer comparison on func values isn't supported by ==; call them and
|
|
// check the override took effect via observable behavior.
|
|
if got, _ := srv.newServer(""); got != nil {
|
|
t.Error("custom newServer not applied")
|
|
}
|
|
if srv.newHLSStream("", "") != nil {
|
|
t.Error("custom newHLSStream not applied")
|
|
}
|
|
if srv.newFlvTsConn(context.Background()) != nil {
|
|
t.Error("custom newFlvTsConn not applied")
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// Default factory behavior
|
|
// =============================================================================
|
|
|
|
func TestHTTPStreamProxyServer_DefaultNewServer_BuildsRealServerAndMux(t *testing.T) {
|
|
srv := NewHTTPStreamProxyServer(&envfakes.FakeProxyEnvironment{},
|
|
&lbfakes.FakeOriginLoadBalancer{}, time.Second).(*httpStreamProxyServer)
|
|
|
|
got, mux := srv.newServer(":12345")
|
|
if mux == nil {
|
|
t.Fatal("mux is nil")
|
|
}
|
|
real, ok := got.(*http.Server)
|
|
if !ok {
|
|
t.Fatalf("expected *http.Server, got %T", got)
|
|
}
|
|
if real.Addr != ":12345" {
|
|
t.Errorf("Addr = %q, want :12345", real.Addr)
|
|
}
|
|
if real.Handler != mux {
|
|
t.Error("Handler should be the returned mux")
|
|
}
|
|
}
|
|
|
|
func TestHTTPStreamProxyServer_DefaultNewHLSStream_WiresFields(t *testing.T) {
|
|
lbFake := &lbfakes.FakeOriginLoadBalancer{}
|
|
srv := NewHTTPStreamProxyServer(&envfakes.FakeProxyEnvironment{}, lbFake,
|
|
time.Second).(*httpStreamProxyServer)
|
|
|
|
got := srv.newHLSStream("vhost/app/stream", "http://example.com/live.m3u8")
|
|
if got.loadBalancer != lbFake {
|
|
t.Error("loadBalancer not wired")
|
|
}
|
|
if got.StreamURL != "vhost/app/stream" {
|
|
t.Errorf("StreamURL = %q", got.StreamURL)
|
|
}
|
|
if got.FullURL != "http://example.com/live.m3u8" {
|
|
t.Errorf("FullURL = %q", got.FullURL)
|
|
}
|
|
if got.SRSProxyBackendHLSID == "" {
|
|
t.Error("SRSProxyBackendHLSID should be auto-generated")
|
|
}
|
|
if got.buildBackendURL == nil {
|
|
t.Error("buildBackendURL default should be propagated")
|
|
}
|
|
}
|
|
|
|
func TestHTTPStreamProxyServer_DefaultNewFlvTsConn_WiresFields(t *testing.T) {
|
|
lbFake := &lbfakes.FakeOriginLoadBalancer{}
|
|
srv := NewHTTPStreamProxyServer(&envfakes.FakeProxyEnvironment{}, lbFake,
|
|
time.Second).(*httpStreamProxyServer)
|
|
|
|
type ctxKey struct{}
|
|
ctx := context.WithValue(context.Background(), ctxKey{}, "sentinel")
|
|
got := srv.newFlvTsConn(ctx)
|
|
if got.ctx != ctx {
|
|
t.Error("ctx not wired")
|
|
}
|
|
if got.loadBalancer != lbFake {
|
|
t.Error("loadBalancer not wired")
|
|
}
|
|
}
|
|
|
|
func TestHTTPStreamProxyServer_DefaultShutdown_DelegatesToServer(t *testing.T) {
|
|
fakeSrv := newFakeHTTPProxyServer()
|
|
srv := NewHTTPStreamProxyServer(&envfakes.FakeProxyEnvironment{},
|
|
&lbfakes.FakeOriginLoadBalancer{}, time.Second).(*httpStreamProxyServer)
|
|
srv.server = fakeSrv // simulate what Run() would assign
|
|
if err := srv.shutdown(context.Background()); err != nil {
|
|
t.Fatalf("shutdown: %v", err)
|
|
}
|
|
if fakeSrv.shutdownCalls.Load() != 1 {
|
|
t.Fatalf("shutdown was not delegated to server, calls=%d", fakeSrv.shutdownCalls.Load())
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// Close
|
|
// =============================================================================
|
|
|
|
func TestHTTPStreamProxyServer_Close_InvokesShutdownWithDeadline(t *testing.T) {
|
|
var gotCtx context.Context
|
|
var calls int
|
|
srv := NewHTTPStreamProxyServer(
|
|
&envfakes.FakeProxyEnvironment{},
|
|
&lbfakes.FakeOriginLoadBalancer{},
|
|
50*time.Millisecond,
|
|
func(s *httpStreamProxyServer) {
|
|
s.shutdown = func(ctx context.Context) error {
|
|
gotCtx = ctx
|
|
calls++
|
|
return nil
|
|
}
|
|
},
|
|
).(*httpStreamProxyServer)
|
|
|
|
if err := srv.Close(); err != nil {
|
|
t.Fatalf("Close: %v", err)
|
|
}
|
|
if calls != 1 {
|
|
t.Fatalf("shutdown calls = %d, want 1", calls)
|
|
}
|
|
if _, ok := gotCtx.Deadline(); !ok {
|
|
t.Error("Close should pass a deadline-bearing ctx to shutdown")
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// Run — lifecycle
|
|
// =============================================================================
|
|
|
|
func TestHTTPStreamProxyServer_Run_AddrWithoutColonPrependsIt(t *testing.T) {
|
|
env := &envfakes.FakeProxyEnvironment{}
|
|
env.HttpServerReturns("8080")
|
|
|
|
var capturedAddr string
|
|
fakeSrv := newFakeHTTPProxyServer()
|
|
srvIface := NewHTTPStreamProxyServer(env, &lbfakes.FakeOriginLoadBalancer{},
|
|
50*time.Millisecond, func(s *httpStreamProxyServer) {
|
|
s.newServer = func(addr string) (httpServer, *http.ServeMux) {
|
|
capturedAddr = addr
|
|
return fakeSrv, http.NewServeMux()
|
|
}
|
|
})
|
|
defer srvIface.Close()
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
if err := srvIface.Run(ctx); err != nil {
|
|
t.Fatalf("Run: %v", err)
|
|
}
|
|
if capturedAddr != ":8080" {
|
|
t.Fatalf("newServer addr = %q, want :8080", capturedAddr)
|
|
}
|
|
}
|
|
|
|
func TestHTTPStreamProxyServer_Run_AddrWithColonUnchanged(t *testing.T) {
|
|
env := &envfakes.FakeProxyEnvironment{}
|
|
env.HttpServerReturns("127.0.0.1:9999")
|
|
|
|
var capturedAddr string
|
|
fakeSrv := newFakeHTTPProxyServer()
|
|
srvIface := NewHTTPStreamProxyServer(env, &lbfakes.FakeOriginLoadBalancer{},
|
|
50*time.Millisecond, func(s *httpStreamProxyServer) {
|
|
s.newServer = func(addr string) (httpServer, *http.ServeMux) {
|
|
capturedAddr = addr
|
|
return fakeSrv, http.NewServeMux()
|
|
}
|
|
})
|
|
defer srvIface.Close()
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
if err := srvIface.Run(ctx); err != nil {
|
|
t.Fatalf("Run: %v", err)
|
|
}
|
|
if capturedAddr != "127.0.0.1:9999" {
|
|
t.Fatalf("newServer addr = %q", capturedAddr)
|
|
}
|
|
}
|
|
|
|
func TestHTTPStreamProxyServer_Run_StaticFilesInvalidPath(t *testing.T) {
|
|
env := &envfakes.FakeProxyEnvironment{}
|
|
env.HttpServerReturns(":0")
|
|
env.StaticFilesReturns("/no/such/path/exists/__srsbot_test__")
|
|
|
|
fakeSrv := newFakeHTTPProxyServer()
|
|
srvIface := NewHTTPStreamProxyServer(env, &lbfakes.FakeOriginLoadBalancer{},
|
|
50*time.Millisecond, func(s *httpStreamProxyServer) {
|
|
s.newServer = func(addr string) (httpServer, *http.ServeMux) {
|
|
return fakeSrv, http.NewServeMux()
|
|
}
|
|
})
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
err := srvIface.Run(ctx)
|
|
if err == nil || !strings.Contains(err.Error(), "invalid static files") {
|
|
t.Fatalf("unexpected err: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestHTTPStreamProxyServer_Run_CtxCancelTriggersShutdown(t *testing.T) {
|
|
env := &envfakes.FakeProxyEnvironment{}
|
|
env.HttpServerReturns(":0")
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
_, fakeSrv, _ := captureMuxFromRun(t, env, &lbfakes.FakeOriginLoadBalancer{}, ctx)
|
|
|
|
// Wait briefly for ListenAndServe goroutine to be running.
|
|
deadline := time.Now().Add(time.Second)
|
|
for fakeSrv.listenCalls.Load() == 0 && time.Now().Before(deadline) {
|
|
time.Sleep(time.Millisecond)
|
|
}
|
|
if fakeSrv.listenCalls.Load() == 0 {
|
|
t.Fatal("ListenAndServe goroutine did not start")
|
|
}
|
|
|
|
cancel()
|
|
|
|
// Wait for Shutdown to be invoked by the watcher goroutine.
|
|
deadline = time.Now().Add(time.Second)
|
|
for fakeSrv.shutdownCalls.Load() == 0 && time.Now().Before(deadline) {
|
|
time.Sleep(time.Millisecond)
|
|
}
|
|
if fakeSrv.shutdownCalls.Load() == 0 {
|
|
t.Fatal("Shutdown was not invoked after ctx cancel")
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// Run — handler dispatch
|
|
// =============================================================================
|
|
|
|
func TestHTTPStreamProxyServer_Run_HandlerVersionsReturnsJSON(t *testing.T) {
|
|
env := &envfakes.FakeProxyEnvironment{}
|
|
env.HttpServerReturns(":0")
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
mux, _, _ := captureMuxFromRun(t, env, &lbfakes.FakeOriginLoadBalancer{}, ctx)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/versions", nil)
|
|
rec := httptest.NewRecorder()
|
|
mux.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("status = %d, want 200", rec.Code)
|
|
}
|
|
var body struct {
|
|
Code int `json:"code"`
|
|
PID string `json:"pid"`
|
|
Data struct {
|
|
Major int `json:"major"`
|
|
Minor int `json:"minor"`
|
|
Revision int `json:"revision"`
|
|
Version string `json:"version"`
|
|
} `json:"data"`
|
|
}
|
|
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
|
|
t.Fatalf("json: %v\nbody=%s", err, rec.Body.String())
|
|
}
|
|
if body.Code != 0 {
|
|
t.Errorf("Code = %d, want 0", body.Code)
|
|
}
|
|
if body.PID == "" {
|
|
t.Error("PID should be populated")
|
|
}
|
|
if body.Data.Version == "" {
|
|
t.Error("Version should be populated")
|
|
}
|
|
}
|
|
|
|
func TestHTTPStreamProxyServer_Run_HandlerM3U8InvokesNewHLSStream(t *testing.T) {
|
|
env := &envfakes.FakeProxyEnvironment{}
|
|
env.HttpServerReturns(":0")
|
|
lbFake := &lbfakes.FakeOriginLoadBalancer{}
|
|
|
|
// Make LoadOrStoreHLS return whatever was passed in (the stream from newHLSStream).
|
|
lbFake.LoadOrStoreHLSStub = func(_ context.Context, _ string, s lb.HLSPlayStream) (lb.HLSPlayStream, error) {
|
|
return s, nil
|
|
}
|
|
|
|
var capturedStreamURL, capturedFullURL string
|
|
var newHLSCalls int
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
mux, _, _ := captureMuxFromRun(t, env, lbFake, ctx, func(s *httpStreamProxyServer) {
|
|
// Wrap default newHLSStream to capture args, but return a real
|
|
// hlsPlayStream so the .(*hlsPlayStream) cast inside Run's handler works.
|
|
// The returned stream has a fake loadBalancer; ServeHTTP will short-circuit
|
|
// on the OPTIONS preflight we send below.
|
|
s.newHLSStream = func(streamURL, fullURL string) *hlsPlayStream {
|
|
newHLSCalls++
|
|
capturedStreamURL, capturedFullURL = streamURL, fullURL
|
|
return newHLSPlayStream(func(h *hlsPlayStream) {
|
|
h.loadBalancer = lbFake
|
|
h.StreamURL, h.FullURL = streamURL, fullURL
|
|
})
|
|
}
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodOptions, "http://example.com/live.m3u8", nil)
|
|
rec := httptest.NewRecorder()
|
|
mux.ServeHTTP(rec, req)
|
|
|
|
if newHLSCalls != 1 {
|
|
t.Fatalf("newHLSStream calls = %d, want 1", newHLSCalls)
|
|
}
|
|
if !strings.HasSuffix(capturedStreamURL, "/live") {
|
|
t.Errorf("captured streamURL %q should end with /live", capturedStreamURL)
|
|
}
|
|
if !strings.Contains(capturedFullURL, "live.m3u8") {
|
|
t.Errorf("captured fullURL %q should contain live.m3u8", capturedFullURL)
|
|
}
|
|
}
|
|
|
|
func TestHTTPStreamProxyServer_Run_HandlerM3U8LoadOrStoreErrorReturns400(t *testing.T) {
|
|
env := &envfakes.FakeProxyEnvironment{}
|
|
env.HttpServerReturns(":0")
|
|
lbFake := &lbfakes.FakeOriginLoadBalancer{}
|
|
lbFake.LoadOrStoreHLSReturns(nil, errors.New("redis down"))
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
mux, _, _ := captureMuxFromRun(t, env, lbFake, ctx)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil)
|
|
rec := httptest.NewRecorder()
|
|
mux.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusBadRequest {
|
|
t.Fatalf("status = %d, want 400; body=%s", rec.Code, rec.Body.String())
|
|
}
|
|
if !strings.Contains(rec.Body.String(), "load or store hls") {
|
|
t.Errorf("body should mention 'load or store hls', got %q", rec.Body.String())
|
|
}
|
|
}
|
|
|
|
func TestHTTPStreamProxyServer_Run_HandlerFlvInvokesNewFlvTsConn(t *testing.T) {
|
|
env := &envfakes.FakeProxyEnvironment{}
|
|
env.HttpServerReturns(":0")
|
|
lbFake := &lbfakes.FakeOriginLoadBalancer{}
|
|
|
|
var newFlvCalls int
|
|
var capturedCtx context.Context
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
mux, _, _ := captureMuxFromRun(t, env, lbFake, ctx, func(s *httpStreamProxyServer) {
|
|
s.newFlvTsConn = func(reqCtx context.Context) *httpFlvTsConnection {
|
|
newFlvCalls++
|
|
capturedCtx = reqCtx
|
|
return newHTTPFlvTsConnection(func(c *httpFlvTsConnection) {
|
|
c.ctx = reqCtx
|
|
c.loadBalancer = lbFake
|
|
})
|
|
}
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodOptions, "http://example.com/live.flv", nil)
|
|
rec := httptest.NewRecorder()
|
|
mux.ServeHTTP(rec, req)
|
|
|
|
if newFlvCalls != 1 {
|
|
t.Fatalf("newFlvTsConn calls = %d, want 1", newFlvCalls)
|
|
}
|
|
if capturedCtx == nil {
|
|
t.Error("captured ctx should be non-nil")
|
|
}
|
|
}
|
|
|
|
func TestHTTPStreamProxyServer_Run_HandlerTsInvokesNewFlvTsConn(t *testing.T) {
|
|
// Same dispatch as .flv but for .ts (without spbhid).
|
|
env := &envfakes.FakeProxyEnvironment{}
|
|
env.HttpServerReturns(":0")
|
|
lbFake := &lbfakes.FakeOriginLoadBalancer{}
|
|
|
|
var newFlvCalls int
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
mux, _, _ := captureMuxFromRun(t, env, lbFake, ctx, func(s *httpStreamProxyServer) {
|
|
s.newFlvTsConn = func(reqCtx context.Context) *httpFlvTsConnection {
|
|
newFlvCalls++
|
|
return newHTTPFlvTsConnection(func(c *httpFlvTsConnection) {
|
|
c.ctx = reqCtx
|
|
c.loadBalancer = lbFake
|
|
})
|
|
}
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodOptions, "http://example.com/live.ts", nil)
|
|
rec := httptest.NewRecorder()
|
|
mux.ServeHTTP(rec, req)
|
|
|
|
if newFlvCalls != 1 {
|
|
t.Fatalf("newFlvTsConn calls = %d, want 1", newFlvCalls)
|
|
}
|
|
}
|
|
|
|
func TestHTTPStreamProxyServer_Run_HandlerTsWithSPBHIDLoadsByID(t *testing.T) {
|
|
env := &envfakes.FakeProxyEnvironment{}
|
|
env.HttpServerReturns(":0")
|
|
lbFake := &lbfakes.FakeOriginLoadBalancer{}
|
|
|
|
stub := newHLSPlayStream(func(h *hlsPlayStream) {
|
|
h.loadBalancer = lbFake
|
|
h.SRSProxyBackendHLSID = "ABC"
|
|
})
|
|
lbFake.LoadHLSBySPBHIDReturns(stub, nil)
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
mux, _, _ := captureMuxFromRun(t, env, lbFake, ctx)
|
|
|
|
req := httptest.NewRequest(http.MethodOptions, "http://example.com/live-0.ts?spbhid=ABC", nil)
|
|
rec := httptest.NewRecorder()
|
|
mux.ServeHTTP(rec, req)
|
|
|
|
if lbFake.LoadHLSBySPBHIDCallCount() != 1 {
|
|
t.Fatalf("LoadHLSBySPBHID calls = %d, want 1", lbFake.LoadHLSBySPBHIDCallCount())
|
|
}
|
|
_, gotID := lbFake.LoadHLSBySPBHIDArgsForCall(0)
|
|
if gotID != "ABC" {
|
|
t.Errorf("LoadHLSBySPBHID id = %q, want ABC", gotID)
|
|
}
|
|
}
|
|
|
|
func TestHTTPStreamProxyServer_Run_HandlerTsWithSPBHIDErrorReturns400(t *testing.T) {
|
|
env := &envfakes.FakeProxyEnvironment{}
|
|
env.HttpServerReturns(":0")
|
|
lbFake := &lbfakes.FakeOriginLoadBalancer{}
|
|
lbFake.LoadHLSBySPBHIDReturns(nil, errors.New("not found"))
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
mux, _, _ := captureMuxFromRun(t, env, lbFake, ctx)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/live-0.ts?spbhid=missing", nil)
|
|
rec := httptest.NewRecorder()
|
|
mux.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusBadRequest {
|
|
t.Fatalf("status = %d, want 400; body=%s", rec.Code, rec.Body.String())
|
|
}
|
|
}
|
|
|
|
func TestHTTPStreamProxyServer_Run_HandlerUnmatchedReturns404(t *testing.T) {
|
|
env := &envfakes.FakeProxyEnvironment{}
|
|
env.HttpServerReturns(":0")
|
|
// StaticFiles unset, no .m3u8/.flv/.ts suffix → 404.
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
mux, _, _ := captureMuxFromRun(t, env, &lbfakes.FakeOriginLoadBalancer{}, ctx)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/random/path", nil)
|
|
rec := httptest.NewRecorder()
|
|
mux.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusNotFound {
|
|
t.Fatalf("status = %d, want 404", rec.Code)
|
|
}
|
|
}
|
|
|
|
func TestHTTPStreamProxyServer_Run_HandlerServesStaticFiles(t *testing.T) {
|
|
dir := t.TempDir()
|
|
if err := os.WriteFile(filepath.Join(dir, "hello.txt"), []byte("hi"), 0644); err != nil {
|
|
t.Fatalf("write: %v", err)
|
|
}
|
|
|
|
env := &envfakes.FakeProxyEnvironment{}
|
|
env.HttpServerReturns(":0")
|
|
env.StaticFilesReturns(dir)
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
mux, _, _ := captureMuxFromRun(t, env, &lbfakes.FakeOriginLoadBalancer{}, ctx)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/hello.txt", nil)
|
|
rec := httptest.NewRecorder()
|
|
mux.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("status = %d, want 200", rec.Code)
|
|
}
|
|
if rec.Body.String() != "hi" {
|
|
t.Errorf("body = %q, want hi", rec.Body.String())
|
|
}
|
|
}
|