From 0a18a4a13b073bf4b40a06de3b15db177d9ae659 Mon Sep 17 00:00:00 2001 From: winlin Date: Sat, 16 May 2026 21:27:24 -0400 Subject: [PATCH] Claude: Add HTTP proxy seams and unit tests. Co-Authored-By: Claude Opus 4.7 (1M context) --- internal/proxy/http.go | 114 +++- internal/proxy/http_test.go | 1289 +++++++++++++++++++++++++++++++++++ 2 files changed, 1381 insertions(+), 22 deletions(-) create mode 100644 internal/proxy/http_test.go diff --git a/internal/proxy/http.go b/internal/proxy/http.go index c37090eef..2bf052460 100644 --- a/internal/proxy/http.go +++ b/internal/proxy/http.go @@ -7,7 +7,7 @@ import ( "context" "fmt" "io" - "io/ioutil" + "net/http" "os" "strconv" @@ -31,32 +31,95 @@ type HTTPStreamProxyServer interface { Close() error } +// httpServer is the minimal contract of an HTTP server that httpStreamProxyServer drives. +// *http.Server satisfies it. Tests may supply a fake that does not bind a real port. +type httpServer interface { + ListenAndServe() error + Shutdown(ctx context.Context) error +} + +// buildBackendHTTPURL composes the backend HTTP URL for a request path, targeting +// the given backend IP and port. Callers append query strings separately when needed. +func buildBackendHTTPURL(ip string, port int, path string) string { + return fmt.Sprintf("http://%v:%v%s", ip, port, path) +} + type httpStreamProxyServer struct { // The environment interface. environment env.ProxyEnvironment // The load balancer for origin servers. loadBalancer lb.OriginLoadBalancer // The underlayer HTTP server. - server *http.Server + server httpServer // The gracefully quit timeout, wait server to quit. gracefulQuitTimeout time.Duration // The wait group for all goroutines. wg stdSync.WaitGroup + // shutdown gracefully shuts down the underlying HTTP server. Defaults to + // v.server.Shutdown; tests may override via a functional option to verify + // the shutdown contract without binding a real socket. + shutdown func(ctx context.Context) error + // newServer constructs the underlying HTTP server bound to addr and the + // ServeMux that handlers are registered on. Defaults to a real http.Server + // and ServeMux; tests may override via a functional option to supply a fake + // server that does not bind a real port. + newServer func(addr string) (httpServer, *http.ServeMux) + // newHLSStream constructs a per-stream HLS playback object for the given + // stream URL pair. Defaults to newHLSPlayStream pre-wired with this server's + // load balancer and a fresh SPBHID; tests may override via a functional option. + newHLSStream func(streamURL, fullURL string) *hlsPlayStream + // newFlvTsConn constructs a per-request HTTP-FLV/TS connection bound to ctx. + // Defaults to newHTTPFlvTsConnection pre-wired with this server's load + // balancer; tests may override via a functional option. + newFlvTsConn func(ctx context.Context) *httpFlvTsConnection } -func NewHTTPStreamProxyServer(environment env.ProxyEnvironment, loadBalancer lb.OriginLoadBalancer, gracefulQuitTimeout time.Duration) HTTPStreamProxyServer { +func NewHTTPStreamProxyServer(environment env.ProxyEnvironment, loadBalancer lb.OriginLoadBalancer, gracefulQuitTimeout time.Duration, opts ...func(*httpStreamProxyServer)) HTTPStreamProxyServer { v := &httpStreamProxyServer{ environment: environment, loadBalancer: loadBalancer, gracefulQuitTimeout: gracefulQuitTimeout, } + + // Default shutdown: delegate to the underlying http.Server. The closure + // captures v rather than v.server so the dereference happens at call time, + // after Run() has assigned v.server. + v.shutdown = func(ctx context.Context) error { + return v.server.Shutdown(ctx) + } + // Default newServer: a real http.Server and ServeMux pair. + v.newServer = func(addr string) (httpServer, *http.ServeMux) { + mux := http.NewServeMux() + return &http.Server{Addr: addr, Handler: mux}, mux + } + // Default newHLSStream: a real hlsPlayStream wired with the server's load + // balancer and a fresh SPBHID for this stream. + v.newHLSStream = func(streamURL, fullURL string) *hlsPlayStream { + return newHLSPlayStream(func(s *hlsPlayStream) { + s.loadBalancer = v.loadBalancer + s.SRSProxyBackendHLSID = logger.GenerateContextID() + s.StreamURL, s.FullURL = streamURL, fullURL + }) + } + // Default newFlvTsConn: a real httpFlvTsConnection wired with the server's + // load balancer and the given ctx. + v.newFlvTsConn = func(ctx context.Context) *httpFlvTsConnection { + return newHTTPFlvTsConnection(func(c *httpFlvTsConnection) { + c.ctx = ctx + c.loadBalancer = v.loadBalancer + }) + } + + for _, opt := range opts { + opt(v) + } return v } func (v *httpStreamProxyServer) Close() error { ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout) defer cancel() - v.server.Shutdown(ctx) + v.shutdown(ctx) v.wg.Wait() return nil @@ -70,8 +133,8 @@ func (v *httpStreamProxyServer) Run(ctx context.Context) error { } // Create server and handler. - mux := http.NewServeMux() - v.server = &http.Server{Addr: addr, Handler: mux} + server, mux := v.newServer(addr) + v.server = server logger.Debug(ctx, "HTTP Stream server listen at %v", addr) // Shutdown the server gracefully when quiting. @@ -82,7 +145,7 @@ func (v *httpStreamProxyServer) Run(ctx context.Context) error { ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout) defer cancel() - v.server.Shutdown(ctx) + v.shutdown(ctx) }() // The basic version handler, also can be used as health check API. @@ -131,11 +194,11 @@ func (v *httpStreamProxyServer) Run(ctx context.Context) error { return } - stream, _ := v.loadBalancer.LoadOrStoreHLS(ctx, streamURL, newHLSPlayStream(func(s *hlsPlayStream) { - s.loadBalancer = v.loadBalancer - s.SRSProxyBackendHLSID = logger.GenerateContextID() - s.StreamURL, s.FullURL = streamURL, fullURL - })) + stream, err := v.loadBalancer.LoadOrStoreHLS(ctx, streamURL, v.newHLSStream(streamURL, fullURL)) + if err != nil { + http.Error(w, fmt.Sprintf("load or store hls %v", streamURL), http.StatusBadRequest) + return + } stream.Initialize(ctx).(*hlsPlayStream).ServeHTTP(w, r) return @@ -155,10 +218,7 @@ func (v *httpStreamProxyServer) Run(ctx context.Context) error { } // Use HTTP pseudo streaming to proxy the request. - newHTTPFlvTsConnection(func(c *httpFlvTsConnection) { - c.ctx = ctx - c.loadBalancer = v.loadBalancer - }).ServeHTTP(w, r) + v.newFlvTsConn(ctx).ServeHTTP(w, r) return } @@ -203,10 +263,15 @@ type httpFlvTsConnection struct { ctx context.Context // The load balancer for origin servers. loadBalancer lb.OriginLoadBalancer + // buildBackendURL composes the backend HTTP URL for a request path. Defaults + // to buildBackendHTTPURL; tests may override via a functional option. + buildBackendURL func(ip string, port int, path string) string } func newHTTPFlvTsConnection(opts ...func(*httpFlvTsConnection)) *httpFlvTsConnection { - v := &httpFlvTsConnection{} + v := &httpFlvTsConnection{ + buildBackendURL: buildBackendHTTPURL, + } for _, opt := range opts { opt(v) } @@ -266,7 +331,7 @@ func (v *httpFlvTsConnection) serveByBackend(ctx context.Context, w http.Respons } // Connect to backend SRS server via HTTP client. - backendURL := fmt.Sprintf("http://%v:%v%s", backend.IP, httpPort, r.URL.Path) + backendURL := v.buildBackendURL(backend.IP, httpPort, r.URL.Path) req, err := http.NewRequestWithContext(ctx, r.Method, backendURL, nil) if err != nil { return errors.Wrapf(err, "create request to %v", backendURL) @@ -319,10 +384,15 @@ type hlsPlayStream struct { StreamURL string `json:"stream_url"` // The full request URL for HLS streaming FullURL string `json:"full_url"` + // buildBackendURL composes the backend HTTP URL for a request path. Defaults + // to buildBackendHTTPURL; tests may override via a functional option. + buildBackendURL func(ip string, port int, path string) string `json:"-"` } func newHLSPlayStream(opts ...func(*hlsPlayStream)) *hlsPlayStream { - v := &hlsPlayStream{} + v := &hlsPlayStream{ + buildBackendURL: buildBackendHTTPURL, + } for _, opt := range opts { opt(v) } @@ -375,7 +445,7 @@ func (v *hlsPlayStream) serve(ctx context.Context, w http.ResponseWriter, r *htt func (v *hlsPlayStream) serveByBackend(ctx context.Context, w http.ResponseWriter, r *http.Request, backend *lb.OriginServer) error { // Parse HTTP port from backend. if len(backend.HTTP) == 0 { - return errors.Errorf("no rtmp server") + return errors.Errorf("no http server") } var httpPort int @@ -386,7 +456,7 @@ func (v *hlsPlayStream) serveByBackend(ctx context.Context, w http.ResponseWrite } // Connect to backend SRS server via HTTP client. - backendURL := fmt.Sprintf("http://%v:%v%s", backend.IP, httpPort, r.URL.Path) + backendURL := v.buildBackendURL(backend.IP, httpPort, r.URL.Path) if r.URL.RawQuery != "" { backendURL += "?" + r.URL.RawQuery } @@ -425,7 +495,7 @@ func (v *hlsPlayStream) serveByBackend(ctx context.Context, w http.ResponseWrite // Read all content of m3u8, append the stream ID to ts URL. Note that we only append stream ID to ts // URL, to identify the stream to specified backend server. The spbhid is the SRS Proxy Backend HLS ID. - b, err := ioutil.ReadAll(resp.Body) + b, err := io.ReadAll(resp.Body) if err != nil { return errors.Wrapf(err, "read stream from %v", backendURL) } diff --git a/internal/proxy/http_test.go b/internal/proxy/http_test.go new file mode 100644 index 000000000..fa64225c4 --- /dev/null +++ b/internal/proxy/http_test.go @@ -0,0 +1,1289 @@ +// Copyright (c) 2026 Winlin +// +// SPDX-License-Identifier: MIT +package proxy + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "io" + "net" + "net/http" + "net/http/httptest" + "net/url" + "os" + "path/filepath" + "strconv" + "strings" + stdSync "sync" + "sync/atomic" + "testing" + "time" + + "srsx/internal/env/envfakes" + "srsx/internal/lb" + "srsx/internal/lb/lbfakes" +) + +// httptestHostPort splits an httptest.Server URL into host and port strings. +func httptestHostPort(t *testing.T, ts *httptest.Server) (string, string) { + t.Helper() + u, err := url.Parse(ts.URL) + if err != nil { + t.Fatalf("parse httptest URL %q: %v", ts.URL, err) + } + return u.Hostname(), u.Port() +} + +// reservedClosedPort binds and immediately closes a TCP port, returning an +// address that is reliably refused for the lifetime of the test. +func reservedClosedPort(t *testing.T) (string, string) { + t.Helper() + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("reserve port: %v", err) + } + addr := l.Addr().(*net.TCPAddr) + if err := l.Close(); err != nil { + t.Fatalf("close listener: %v", err) + } + return addr.IP.String(), strconv.Itoa(addr.Port) +} + +// ============================================================================= +// newHLSPlayStream +// ============================================================================= + +func TestHLSPlayStream_New_DefaultsBuildBackendURL(t *testing.T) { + v := newHLSPlayStream() + if v.buildBackendURL == nil { + t.Fatal("buildBackendURL should default to non-nil") + } + if got := v.buildBackendURL("1.2.3.4", 8080, "/live.ts"); got != "http://1.2.3.4:8080/live.ts" { + t.Fatalf("default buildBackendURL produced %q", got) + } +} + +func TestHLSPlayStream_New_AppliesOpts(t *testing.T) { + ctx := context.Background() + lbStub := &lbfakes.FakeOriginLoadBalancer{} + v := newHLSPlayStream(func(s *hlsPlayStream) { + s.ctx = ctx + s.loadBalancer = lbStub + s.SRSProxyBackendHLSID = "spb-id" + s.StreamURL = "vhost/app/stream" + s.FullURL = "http://example.com/live.m3u8" + }) + if v.ctx != ctx { + t.Error("ctx not applied") + } + if v.loadBalancer != lbStub { + t.Error("loadBalancer not applied") + } + if v.SRSProxyBackendHLSID != "spb-id" { + t.Errorf("SRSProxyBackendHLSID = %q", v.SRSProxyBackendHLSID) + } + if v.StreamURL != "vhost/app/stream" { + t.Errorf("StreamURL = %q", v.StreamURL) + } + if v.FullURL != "http://example.com/live.m3u8" { + t.Errorf("FullURL = %q", v.FullURL) + } +} + +func TestHLSPlayStream_New_OptCanOverrideBuildBackendURL(t *testing.T) { + v := newHLSPlayStream(func(s *hlsPlayStream) { + s.buildBackendURL = func(string, int, string) string { return "custom" } + }) + if got := v.buildBackendURL("", 0, ""); got != "custom" { + t.Fatalf("override not applied: got %q", got) + } +} + +// ============================================================================= +// Initialize +// ============================================================================= + +func TestHLSPlayStream_Initialize_SetsCtxWhenNil(t *testing.T) { + v := newHLSPlayStream() + ret := v.Initialize(context.Background()) + if v.ctx == nil { + t.Fatal("Initialize should set v.ctx when nil") + } + if ret != lb.HLSPlayStream(v) { + t.Fatal("Initialize should return v") + } +} + +func TestHLSPlayStream_Initialize_PreservesExistingCtx(t *testing.T) { + type ctxKey struct{} + existing := context.WithValue(context.Background(), ctxKey{}, "sentinel") + v := newHLSPlayStream(func(s *hlsPlayStream) { s.ctx = existing }) + v.Initialize(context.Background()) + if got, _ := v.ctx.Value(ctxKey{}).(string); got != "sentinel" { + t.Fatalf("Initialize should not replace existing ctx, value=%q", got) + } +} + +// ============================================================================= +// GetSPBHID +// ============================================================================= + +func TestHLSPlayStream_GetSPBHID(t *testing.T) { + v := newHLSPlayStream(func(s *hlsPlayStream) { s.SRSProxyBackendHLSID = "spb-xyz" }) + if v.GetSPBHID() != "spb-xyz" { + t.Fatalf("GetSPBHID = %q", v.GetSPBHID()) + } +} + +// ============================================================================= +// ServeHTTP / serve / CORS +// ============================================================================= + +func TestHLSPlayStream_ServeHTTP_CORSPreflightShortCircuits(t *testing.T) { + lbFake := &lbfakes.FakeOriginLoadBalancer{} + v := newHLSPlayStream(func(s *hlsPlayStream) { + s.ctx = context.Background() + s.loadBalancer = lbFake + }) + req := httptest.NewRequest(http.MethodOptions, "http://example.com/live.m3u8", nil) + rec := httptest.NewRecorder() + v.ServeHTTP(rec, req) + if lbFake.PickCallCount() != 0 { + t.Fatalf("Pick should not be called on CORS preflight, calls=%d", lbFake.PickCallCount()) + } +} + +func TestHLSPlayStream_ServeHTTP_ErrorBranchInvokesApiError(t *testing.T) { + lbFake := &lbfakes.FakeOriginLoadBalancer{} + lbFake.PickReturns(nil, errors.New("pick-fail")) + v := newHLSPlayStream(func(s *hlsPlayStream) { + s.ctx = context.Background() + s.loadBalancer = lbFake + s.StreamURL = "vhost/app/stream" + }) + req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil) + rec := httptest.NewRecorder() + v.ServeHTTP(rec, req) + // ApiError writes a JSON error response. Verify the body is non-empty + // and the status is not the default 200 (or that some response was made). + if rec.Body.Len() == 0 { + t.Fatal("ServeHTTP error branch should produce a response body") + } +} + +func TestHLSPlayStream_Serve_PickError(t *testing.T) { + lbFake := &lbfakes.FakeOriginLoadBalancer{} + lbFake.PickReturns(nil, errors.New("pick-fail")) + v := newHLSPlayStream(func(s *hlsPlayStream) { + s.ctx = context.Background() + s.loadBalancer = lbFake + s.StreamURL = "vhost/app/stream" + }) + req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil) + rec := httptest.NewRecorder() + err := v.serve(v.ctx, rec, req) + if err == nil || !strings.Contains(err.Error(), "pick backend for vhost/app/stream") { + t.Fatalf("unexpected err: %v", err) + } +} + +func TestHLSPlayStream_Serve_WrapsServeByBackendError(t *testing.T) { + // Backend with empty HTTP slice triggers serveByBackend's "no http server" + // error, which serve() then wraps with "serve %v with %v by backend %+v". + lbFake := &lbfakes.FakeOriginLoadBalancer{} + lbFake.PickReturns(&lb.OriginServer{IP: "127.0.0.1"}, nil) + v := newHLSPlayStream(func(s *hlsPlayStream) { + s.ctx = context.Background() + s.loadBalancer = lbFake + s.StreamURL = "vhost/app/stream" + s.FullURL = "http://example.com/live.m3u8" + }) + req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil) + rec := httptest.NewRecorder() + err := v.serve(v.ctx, rec, req) + if err == nil || !strings.Contains(err.Error(), "serve http://example.com/live.m3u8 with vhost/app/stream by backend") { + t.Fatalf("unexpected err: %v", err) + } +} + +func TestHLSPlayStream_Serve_HappyPathRewritesM3U8(t *testing.T) { + m3u8 := "#EXTM3U\n#EXT-X-VERSION:3\nlive-0.ts\n" + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, m3u8) + })) + defer ts.Close() + host, port := httptestHostPort(t, ts) + + lbFake := &lbfakes.FakeOriginLoadBalancer{} + lbFake.PickReturns(&lb.OriginServer{IP: host, HTTP: []string{port}}, nil) + + v := newHLSPlayStream(func(s *hlsPlayStream) { + s.ctx = context.Background() + s.loadBalancer = lbFake + s.StreamURL = "vhost/app/stream" + s.FullURL = "http://example.com/live.m3u8" + s.SRSProxyBackendHLSID = "spb-1" + }) + req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil) + rec := httptest.NewRecorder() + if err := v.serve(v.ctx, rec, req); err != nil { + t.Fatalf("unexpected err: %v", err) + } + if !strings.Contains(rec.Body.String(), "live-0.ts?spbhid=spb-1") { + t.Fatalf("body missing spbhid rewrite: %q", rec.Body.String()) + } +} + +// ============================================================================= +// serveByBackend — error paths (no HTTP round-trip needed) +// ============================================================================= + +func TestHLSPlayStream_ServeByBackend_NoHTTPEndpoint(t *testing.T) { + v := newHLSPlayStream() + req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil) + rec := httptest.NewRecorder() + err := v.serveByBackend(context.Background(), rec, req, &lb.OriginServer{IP: "127.0.0.1"}) + if err == nil || !strings.Contains(err.Error(), "no http server") { + t.Fatalf("unexpected err: %v", err) + } +} + +func TestHLSPlayStream_ServeByBackend_BadPort(t *testing.T) { + v := newHLSPlayStream() + req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil) + rec := httptest.NewRecorder() + err := v.serveByBackend(context.Background(), rec, req, + &lb.OriginServer{IP: "127.0.0.1", HTTP: []string{"not-a-port"}}) + if err == nil || !strings.Contains(err.Error(), "parse http port not-a-port") { + t.Fatalf("unexpected err: %v", err) + } +} + +func TestHLSPlayStream_ServeByBackend_RequestBuildError(t *testing.T) { + v := newHLSPlayStream(func(s *hlsPlayStream) { + s.buildBackendURL = func(string, int, string) string { return "://invalid-url" } + }) + req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil) + rec := httptest.NewRecorder() + err := v.serveByBackend(context.Background(), rec, req, + &lb.OriginServer{IP: "127.0.0.1", HTTP: []string{"8080"}}) + if err == nil || !strings.Contains(err.Error(), "create request to ://invalid-url") { + t.Fatalf("unexpected err: %v", err) + } +} + +func TestHLSPlayStream_ServeByBackend_DialError(t *testing.T) { + host, port := reservedClosedPort(t) + v := newHLSPlayStream() + req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil) + rec := httptest.NewRecorder() + err := v.serveByBackend(context.Background(), rec, req, + &lb.OriginServer{IP: host, HTTP: []string{port}}) + if err == nil || !strings.HasSuffix(err.Error(), "EOF") { + t.Fatalf("expected error suffixed with 'EOF', got: %v", err) + } +} + +// ============================================================================= +// serveByBackend — HTTP round-trip via httptest.Server +// ============================================================================= + +func TestHLSPlayStream_ServeByBackend_NonOKStatus(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer ts.Close() + host, port := httptestHostPort(t, ts) + + v := newHLSPlayStream() + req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil) + rec := httptest.NewRecorder() + err := v.serveByBackend(context.Background(), rec, req, + &lb.OriginServer{IP: host, HTTP: []string{port}}) + if err == nil || !strings.Contains(err.Error(), "status=404") { + t.Fatalf("unexpected err: %v", err) + } +} + +func TestHLSPlayStream_ServeByBackend_TSPassthrough(t *testing.T) { + payload := []byte{0x47, 0x00, 0x01, 0x02, 0x03, 0x04} + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "video/mp2t") + _, _ = w.Write(payload) + })) + defer ts.Close() + host, port := httptestHostPort(t, ts) + + v := newHLSPlayStream() + req := httptest.NewRequest(http.MethodGet, "http://example.com/live.ts", nil) + rec := httptest.NewRecorder() + if err := v.serveByBackend(context.Background(), rec, req, + &lb.OriginServer{IP: host, HTTP: []string{port}}); err != nil { + t.Fatalf("unexpected err: %v", err) + } + if got := rec.Body.Bytes(); !bytes.Equal(got, payload) { + t.Fatalf("body mismatch: got=%v want=%v", got, payload) + } + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rec.Code) + } +} + +func TestHLSPlayStream_ServeByBackend_M3U8RewritesTSWithoutQuery(t *testing.T) { + m3u8 := "#EXTM3U\n#EXT-X-VERSION:3\nlive-0.ts\nlive-1.ts\n" + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, m3u8) + })) + defer ts.Close() + host, port := httptestHostPort(t, ts) + + v := newHLSPlayStream(func(s *hlsPlayStream) { s.SRSProxyBackendHLSID = "ABC" }) + req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil) + rec := httptest.NewRecorder() + if err := v.serveByBackend(context.Background(), rec, req, + &lb.OriginServer{IP: host, HTTP: []string{port}}); err != nil { + t.Fatalf("unexpected err: %v", err) + } + body := rec.Body.String() + for _, want := range []string{"live-0.ts?spbhid=ABC", "live-1.ts?spbhid=ABC"} { + if !strings.Contains(body, want) { + t.Fatalf("missing %q in body: %q", want, body) + } + } +} + +func TestHLSPlayStream_ServeByBackend_M3U8RewritesTSWithQuery(t *testing.T) { + m3u8 := "#EXTM3U\nlive-0.ts?token=foo\n" + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, m3u8) + })) + defer ts.Close() + host, port := httptestHostPort(t, ts) + + v := newHLSPlayStream(func(s *hlsPlayStream) { s.SRSProxyBackendHLSID = "ABC" }) + req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil) + rec := httptest.NewRecorder() + if err := v.serveByBackend(context.Background(), rec, req, + &lb.OriginServer{IP: host, HTTP: []string{port}}); err != nil { + t.Fatalf("unexpected err: %v", err) + } + if want := "live-0.ts?spbhid=ABC&&token=foo"; !strings.Contains(rec.Body.String(), want) { + t.Fatalf("missing %q in body: %q", want, rec.Body.String()) + } +} + +func TestHLSPlayStream_ServeByBackend_AppendsRawQueryOnTS(t *testing.T) { + var seenURL string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + seenURL = r.URL.String() + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + host, port := httptestHostPort(t, ts) + + v := newHLSPlayStream() + req := httptest.NewRequest(http.MethodGet, "http://example.com/live.ts?token=foo", nil) + rec := httptest.NewRecorder() + if err := v.serveByBackend(context.Background(), rec, req, + &lb.OriginServer{IP: host, HTTP: []string{port}}); err != nil { + t.Fatalf("unexpected err: %v", err) + } + if !strings.Contains(seenURL, "token=foo") { + t.Fatalf("backend should see raw query, got %q", seenURL) + } +} + +// ============================================================================= +// httpFlvTsConnection +// ============================================================================= + +// ============================================================================= +// newHTTPFlvTsConnection +// ============================================================================= + +func TestHTTPFlvTsConn_New_DefaultsBuildBackendURL(t *testing.T) { + v := newHTTPFlvTsConnection() + if v.buildBackendURL == nil { + t.Fatal("buildBackendURL should default to non-nil") + } + if got := v.buildBackendURL("1.2.3.4", 8080, "/live.flv"); got != "http://1.2.3.4:8080/live.flv" { + t.Fatalf("default buildBackendURL produced %q", got) + } +} + +func TestHTTPFlvTsConn_New_AppliesOpts(t *testing.T) { + ctx := context.Background() + lbStub := &lbfakes.FakeOriginLoadBalancer{} + v := newHTTPFlvTsConnection(func(c *httpFlvTsConnection) { + c.ctx = ctx + c.loadBalancer = lbStub + }) + if v.ctx != ctx { + t.Error("ctx not applied") + } + if v.loadBalancer != lbStub { + t.Error("loadBalancer not applied") + } +} + +func TestHTTPFlvTsConn_New_OptCanOverrideBuildBackendURL(t *testing.T) { + v := newHTTPFlvTsConnection(func(c *httpFlvTsConnection) { + c.buildBackendURL = func(string, int, string) string { return "custom" } + }) + if got := v.buildBackendURL("", 0, ""); got != "custom" { + t.Fatalf("override not applied: got %q", got) + } +} + +// ============================================================================= +// ServeHTTP / serve / CORS +// ============================================================================= + +func TestHTTPFlvTsConn_ServeHTTP_CORSPreflightShortCircuits(t *testing.T) { + lbFake := &lbfakes.FakeOriginLoadBalancer{} + v := newHTTPFlvTsConnection(func(c *httpFlvTsConnection) { + c.ctx = context.Background() + c.loadBalancer = lbFake + }) + req := httptest.NewRequest(http.MethodOptions, "http://example.com/live.flv", nil) + rec := httptest.NewRecorder() + v.ServeHTTP(rec, req) + if lbFake.PickCallCount() != 0 { + t.Fatalf("Pick should not be called on CORS preflight, calls=%d", lbFake.PickCallCount()) + } +} + +func TestHTTPFlvTsConn_ServeHTTP_ErrorBranchInvokesApiError(t *testing.T) { + lbFake := &lbfakes.FakeOriginLoadBalancer{} + lbFake.PickReturns(nil, errors.New("pick-fail")) + v := newHTTPFlvTsConnection(func(c *httpFlvTsConnection) { + c.ctx = context.Background() + c.loadBalancer = lbFake + }) + req := httptest.NewRequest(http.MethodGet, "http://example.com/live.flv", nil) + rec := httptest.NewRecorder() + v.ServeHTTP(rec, req) + if rec.Body.Len() == 0 { + t.Fatal("ServeHTTP error branch should produce a response body") + } +} + +func TestHTTPFlvTsConn_Serve_PickError(t *testing.T) { + lbFake := &lbfakes.FakeOriginLoadBalancer{} + lbFake.PickReturns(nil, errors.New("pick-fail")) + v := newHTTPFlvTsConnection(func(c *httpFlvTsConnection) { + c.ctx = context.Background() + c.loadBalancer = lbFake + }) + req := httptest.NewRequest(http.MethodGet, "http://example.com/live/stream.flv", nil) + rec := httptest.NewRecorder() + err := v.serve(context.Background(), rec, req) + if err == nil || !strings.Contains(err.Error(), "pick backend for") { + t.Fatalf("unexpected err: %v", err) + } +} + +func TestHTTPFlvTsConn_Serve_WrapsServeByBackendError(t *testing.T) { + // Empty HTTP slice on backend triggers serveByBackend's "no http stream + // server" error, which serve() wraps with "serve with ". + lbFake := &lbfakes.FakeOriginLoadBalancer{} + lbFake.PickReturns(&lb.OriginServer{IP: "127.0.0.1"}, nil) + v := newHTTPFlvTsConnection(func(c *httpFlvTsConnection) { + c.ctx = context.Background() + c.loadBalancer = lbFake + }) + req := httptest.NewRequest(http.MethodGet, "http://example.com/live/stream.flv", nil) + rec := httptest.NewRecorder() + err := v.serve(context.Background(), rec, req) + if err == nil || !strings.Contains(err.Error(), "serve ") || !strings.Contains(err.Error(), " by backend ") { + t.Fatalf("unexpected err: %v", err) + } +} + +func TestHTTPFlvTsConn_Serve_HappyPath(t *testing.T) { + payload := []byte{0x46, 0x4c, 0x56, 0x01} // "FLV\x01" + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write(payload) + })) + defer ts.Close() + host, port := httptestHostPort(t, ts) + + lbFake := &lbfakes.FakeOriginLoadBalancer{} + lbFake.PickReturns(&lb.OriginServer{IP: host, HTTP: []string{port}}, nil) + + v := newHTTPFlvTsConnection(func(c *httpFlvTsConnection) { + c.ctx = context.Background() + c.loadBalancer = lbFake + }) + req := httptest.NewRequest(http.MethodGet, "http://example.com/live/stream.flv", nil) + rec := httptest.NewRecorder() + if err := v.serve(context.Background(), rec, req); err != nil { + t.Fatalf("unexpected err: %v", err) + } + if !bytes.Equal(rec.Body.Bytes(), payload) { + t.Fatalf("body mismatch: got=%v want=%v", rec.Body.Bytes(), payload) + } +} + +// ============================================================================= +// serveByBackend — error paths +// ============================================================================= + +func TestHTTPFlvTsConn_ServeByBackend_NoHTTPEndpoint(t *testing.T) { + v := newHTTPFlvTsConnection() + req := httptest.NewRequest(http.MethodGet, "http://example.com/live.flv", nil) + rec := httptest.NewRecorder() + err := v.serveByBackend(context.Background(), rec, req, &lb.OriginServer{IP: "127.0.0.1"}) + if err == nil || !strings.Contains(err.Error(), "no http stream server") { + t.Fatalf("unexpected err: %v", err) + } +} + +func TestHTTPFlvTsConn_ServeByBackend_BadPort(t *testing.T) { + v := newHTTPFlvTsConnection() + req := httptest.NewRequest(http.MethodGet, "http://example.com/live.flv", nil) + rec := httptest.NewRecorder() + err := v.serveByBackend(context.Background(), rec, req, + &lb.OriginServer{IP: "127.0.0.1", HTTP: []string{"not-a-port"}}) + if err == nil || !strings.Contains(err.Error(), "parse http port not-a-port") { + t.Fatalf("unexpected err: %v", err) + } +} + +func TestHTTPFlvTsConn_ServeByBackend_RequestBuildError(t *testing.T) { + v := newHTTPFlvTsConnection(func(c *httpFlvTsConnection) { + c.buildBackendURL = func(string, int, string) string { return "://invalid-url" } + }) + req := httptest.NewRequest(http.MethodGet, "http://example.com/live.flv", nil) + rec := httptest.NewRecorder() + err := v.serveByBackend(context.Background(), rec, req, + &lb.OriginServer{IP: "127.0.0.1", HTTP: []string{"8080"}}) + if err == nil || !strings.Contains(err.Error(), "create request to ://invalid-url") { + t.Fatalf("unexpected err: %v", err) + } +} + +func TestHTTPFlvTsConn_ServeByBackend_DialError(t *testing.T) { + host, port := reservedClosedPort(t) + v := newHTTPFlvTsConnection() + req := httptest.NewRequest(http.MethodGet, "http://example.com/live.flv", nil) + rec := httptest.NewRecorder() + err := v.serveByBackend(context.Background(), rec, req, + &lb.OriginServer{IP: host, HTTP: []string{port}}) + if err == nil || !strings.Contains(err.Error(), "do request to") { + t.Fatalf("unexpected err: %v", err) + } +} + +// ============================================================================= +// serveByBackend — HTTP round-trip +// ============================================================================= + +func TestHTTPFlvTsConn_ServeByBackend_NonOKStatus(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer ts.Close() + host, port := httptestHostPort(t, ts) + + v := newHTTPFlvTsConnection() + req := httptest.NewRequest(http.MethodGet, "http://example.com/live.flv", nil) + rec := httptest.NewRecorder() + err := v.serveByBackend(context.Background(), rec, req, + &lb.OriginServer{IP: host, HTTP: []string{port}}) + if err == nil || !strings.Contains(err.Error(), "status=404") { + t.Fatalf("unexpected err: %v", err) + } +} + +func TestHTTPFlvTsConn_ServeByBackend_BodyPassthrough(t *testing.T) { + payload := []byte{0x46, 0x4c, 0x56, 0x01, 0x05, 0x00, 0x00, 0x00} + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write(payload) + })) + defer ts.Close() + host, port := httptestHostPort(t, ts) + + v := newHTTPFlvTsConnection() + req := httptest.NewRequest(http.MethodGet, "http://example.com/live.flv", nil) + rec := httptest.NewRecorder() + if err := v.serveByBackend(context.Background(), rec, req, + &lb.OriginServer{IP: host, HTTP: []string{port}}); err != nil { + t.Fatalf("unexpected err: %v", err) + } + if !bytes.Equal(rec.Body.Bytes(), payload) { + t.Fatalf("body mismatch: got=%v want=%v", rec.Body.Bytes(), payload) + } + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rec.Code) + } +} + +func TestHTTPFlvTsConn_ServeByBackend_DropsRawQuery(t *testing.T) { + // Unlike hlsPlayStream.serveByBackend, the FLV/TS path forwards only + // r.URL.Path — it does NOT append RawQuery to the backend request. + var seenRawQuery string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + seenRawQuery = r.URL.RawQuery + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + host, port := httptestHostPort(t, ts) + + v := newHTTPFlvTsConnection() + req := httptest.NewRequest(http.MethodGet, "http://example.com/live.flv?token=foo", nil) + rec := httptest.NewRecorder() + if err := v.serveByBackend(context.Background(), rec, req, + &lb.OriginServer{IP: host, HTTP: []string{port}}); err != nil { + t.Fatalf("unexpected err: %v", err) + } + if seenRawQuery != "" { + t.Fatalf("backend should NOT see raw query, got %q", seenRawQuery) + } +} + +func TestHTTPFlvTsConn_ServeByBackend_PreservesMethod(t *testing.T) { + var seenMethod string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + seenMethod = r.Method + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + host, port := httptestHostPort(t, ts) + + v := newHTTPFlvTsConnection() + req := httptest.NewRequest(http.MethodHead, "http://example.com/live.flv", nil) + rec := httptest.NewRecorder() + if err := v.serveByBackend(context.Background(), rec, req, + &lb.OriginServer{IP: host, HTTP: []string{port}}); err != nil { + t.Fatalf("unexpected err: %v", err) + } + if seenMethod != http.MethodHead { + t.Fatalf("backend method = %q, want HEAD", seenMethod) + } +} + +// ============================================================================= +// httpStreamProxyServer +// ============================================================================= + +// fakeHTTPProxyServer is an httpServer that blocks in ListenAndServe until +// Shutdown is called. Used to drive Run()'s lifecycle without binding a port. +type fakeHTTPProxyServer struct { + listenCalls atomic.Int32 + shutdownCalls atomic.Int32 + listenReturn error + shutdownReturn error + block chan struct{} + once stdSync.Once +} + +func newFakeHTTPProxyServer() *fakeHTTPProxyServer { + return &fakeHTTPProxyServer{ + listenReturn: http.ErrServerClosed, + block: make(chan struct{}), + } +} + +func (f *fakeHTTPProxyServer) ListenAndServe() error { + f.listenCalls.Add(1) + <-f.block + return f.listenReturn +} + +func (f *fakeHTTPProxyServer) Shutdown(ctx context.Context) error { + f.shutdownCalls.Add(1) + f.once.Do(func() { close(f.block) }) + return f.shutdownReturn +} + +// captureMuxFromRun calls Run with a fake server that captures the registered +// mux. Returns the mux and the fake server for further assertions. Caller is +// responsible for cancelling ctx to trigger shutdown. +func captureMuxFromRun(t *testing.T, env *envfakes.FakeProxyEnvironment, + lbFake *lbfakes.FakeOriginLoadBalancer, ctx context.Context, + opts ...func(*httpStreamProxyServer)) (*http.ServeMux, *fakeHTTPProxyServer, *httpStreamProxyServer) { + t.Helper() + + fakeSrv := newFakeHTTPProxyServer() + var capturedMux *http.ServeMux + + baseOpts := []func(*httpStreamProxyServer){ + func(s *httpStreamProxyServer) { + s.newServer = func(addr string) (httpServer, *http.ServeMux) { + mux := http.NewServeMux() + capturedMux = mux + return fakeSrv, mux + } + }, + } + srvIface := NewHTTPStreamProxyServer(env, lbFake, 50*time.Millisecond, append(baseOpts, opts...)...) + srv := srvIface.(*httpStreamProxyServer) + + if err := srv.Run(ctx); err != nil { + t.Fatalf("Run: %v", err) + } + if capturedMux == nil { + t.Fatal("newServer was not called by Run") + } + return capturedMux, fakeSrv, srv +} + +// ============================================================================= +// NewHTTPStreamProxyServer +// ============================================================================= + +func TestHTTPStreamProxyServer_New_StoresFieldsAndDefaultsSeams(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + lbFake := &lbfakes.FakeOriginLoadBalancer{} + timeout := 2 * time.Second + srv := NewHTTPStreamProxyServer(env, lbFake, timeout).(*httpStreamProxyServer) + + if srv.environment != env { + t.Error("environment not stored") + } + if srv.loadBalancer != lbFake { + t.Error("loadBalancer not stored") + } + if srv.gracefulQuitTimeout != timeout { + t.Errorf("gracefulQuitTimeout = %v, want %v", srv.gracefulQuitTimeout, timeout) + } + if srv.shutdown == nil { + t.Error("shutdown seam should default to non-nil") + } + if srv.newServer == nil { + t.Error("newServer seam should default to non-nil") + } + if srv.newHLSStream == nil { + t.Error("newHLSStream seam should default to non-nil") + } + if srv.newFlvTsConn == nil { + t.Error("newFlvTsConn seam should default to non-nil") + } +} + +func TestHTTPStreamProxyServer_New_AppliesOpts(t *testing.T) { + var optCalled bool + srv := NewHTTPStreamProxyServer( + &envfakes.FakeProxyEnvironment{}, + &lbfakes.FakeOriginLoadBalancer{}, + time.Second, + func(s *httpStreamProxyServer) { optCalled = true }, + ).(*httpStreamProxyServer) + if !optCalled { + t.Fatal("opt was not invoked") + } + if srv.shutdown == nil { + t.Error("default seams should still be set when opts don't override them") + } +} + +func TestHTTPStreamProxyServer_New_OptCanOverrideAllSeams(t *testing.T) { + customShutdown := func(context.Context) error { return errors.New("custom") } + customNewServer := func(string) (httpServer, *http.ServeMux) { return nil, nil } + customNewHLS := func(string, string) *hlsPlayStream { return nil } + customNewFlv := func(context.Context) *httpFlvTsConnection { return nil } + + srv := NewHTTPStreamProxyServer( + &envfakes.FakeProxyEnvironment{}, + &lbfakes.FakeOriginLoadBalancer{}, + time.Second, + func(s *httpStreamProxyServer) { + s.shutdown = customShutdown + s.newServer = customNewServer + s.newHLSStream = customNewHLS + s.newFlvTsConn = customNewFlv + }, + ).(*httpStreamProxyServer) + + if err := srv.shutdown(context.Background()); err == nil || err.Error() != "custom" { + t.Errorf("custom shutdown not applied: %v", err) + } + // Pointer comparison on func values isn't supported by ==; call them and + // check the override took effect via observable behavior. + if got, _ := srv.newServer(""); got != nil { + t.Error("custom newServer not applied") + } + if srv.newHLSStream("", "") != nil { + t.Error("custom newHLSStream not applied") + } + if srv.newFlvTsConn(context.Background()) != nil { + t.Error("custom newFlvTsConn not applied") + } +} + +// ============================================================================= +// Default factory behavior +// ============================================================================= + +func TestHTTPStreamProxyServer_DefaultNewServer_BuildsRealServerAndMux(t *testing.T) { + srv := NewHTTPStreamProxyServer(&envfakes.FakeProxyEnvironment{}, + &lbfakes.FakeOriginLoadBalancer{}, time.Second).(*httpStreamProxyServer) + + got, mux := srv.newServer(":12345") + if mux == nil { + t.Fatal("mux is nil") + } + real, ok := got.(*http.Server) + if !ok { + t.Fatalf("expected *http.Server, got %T", got) + } + if real.Addr != ":12345" { + t.Errorf("Addr = %q, want :12345", real.Addr) + } + if real.Handler != mux { + t.Error("Handler should be the returned mux") + } +} + +func TestHTTPStreamProxyServer_DefaultNewHLSStream_WiresFields(t *testing.T) { + lbFake := &lbfakes.FakeOriginLoadBalancer{} + srv := NewHTTPStreamProxyServer(&envfakes.FakeProxyEnvironment{}, lbFake, + time.Second).(*httpStreamProxyServer) + + got := srv.newHLSStream("vhost/app/stream", "http://example.com/live.m3u8") + if got.loadBalancer != lbFake { + t.Error("loadBalancer not wired") + } + if got.StreamURL != "vhost/app/stream" { + t.Errorf("StreamURL = %q", got.StreamURL) + } + if got.FullURL != "http://example.com/live.m3u8" { + t.Errorf("FullURL = %q", got.FullURL) + } + if got.SRSProxyBackendHLSID == "" { + t.Error("SRSProxyBackendHLSID should be auto-generated") + } + if got.buildBackendURL == nil { + t.Error("buildBackendURL default should be propagated") + } +} + +func TestHTTPStreamProxyServer_DefaultNewFlvTsConn_WiresFields(t *testing.T) { + lbFake := &lbfakes.FakeOriginLoadBalancer{} + srv := NewHTTPStreamProxyServer(&envfakes.FakeProxyEnvironment{}, lbFake, + time.Second).(*httpStreamProxyServer) + + type ctxKey struct{} + ctx := context.WithValue(context.Background(), ctxKey{}, "sentinel") + got := srv.newFlvTsConn(ctx) + if got.ctx != ctx { + t.Error("ctx not wired") + } + if got.loadBalancer != lbFake { + t.Error("loadBalancer not wired") + } +} + +func TestHTTPStreamProxyServer_DefaultShutdown_DelegatesToServer(t *testing.T) { + fakeSrv := newFakeHTTPProxyServer() + srv := NewHTTPStreamProxyServer(&envfakes.FakeProxyEnvironment{}, + &lbfakes.FakeOriginLoadBalancer{}, time.Second).(*httpStreamProxyServer) + srv.server = fakeSrv // simulate what Run() would assign + if err := srv.shutdown(context.Background()); err != nil { + t.Fatalf("shutdown: %v", err) + } + if fakeSrv.shutdownCalls.Load() != 1 { + t.Fatalf("shutdown was not delegated to server, calls=%d", fakeSrv.shutdownCalls.Load()) + } +} + +// ============================================================================= +// Close +// ============================================================================= + +func TestHTTPStreamProxyServer_Close_InvokesShutdownWithDeadline(t *testing.T) { + var gotCtx context.Context + var calls int + srv := NewHTTPStreamProxyServer( + &envfakes.FakeProxyEnvironment{}, + &lbfakes.FakeOriginLoadBalancer{}, + 50*time.Millisecond, + func(s *httpStreamProxyServer) { + s.shutdown = func(ctx context.Context) error { + gotCtx = ctx + calls++ + return nil + } + }, + ).(*httpStreamProxyServer) + + if err := srv.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + if calls != 1 { + t.Fatalf("shutdown calls = %d, want 1", calls) + } + if _, ok := gotCtx.Deadline(); !ok { + t.Error("Close should pass a deadline-bearing ctx to shutdown") + } +} + +// ============================================================================= +// Run — lifecycle +// ============================================================================= + +func TestHTTPStreamProxyServer_Run_AddrWithoutColonPrependsIt(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.HttpServerReturns("8080") + + var capturedAddr string + fakeSrv := newFakeHTTPProxyServer() + srvIface := NewHTTPStreamProxyServer(env, &lbfakes.FakeOriginLoadBalancer{}, + 50*time.Millisecond, func(s *httpStreamProxyServer) { + s.newServer = func(addr string) (httpServer, *http.ServeMux) { + capturedAddr = addr + return fakeSrv, http.NewServeMux() + } + }) + defer srvIface.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + if err := srvIface.Run(ctx); err != nil { + t.Fatalf("Run: %v", err) + } + if capturedAddr != ":8080" { + t.Fatalf("newServer addr = %q, want :8080", capturedAddr) + } +} + +func TestHTTPStreamProxyServer_Run_AddrWithColonUnchanged(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.HttpServerReturns("127.0.0.1:9999") + + var capturedAddr string + fakeSrv := newFakeHTTPProxyServer() + srvIface := NewHTTPStreamProxyServer(env, &lbfakes.FakeOriginLoadBalancer{}, + 50*time.Millisecond, func(s *httpStreamProxyServer) { + s.newServer = func(addr string) (httpServer, *http.ServeMux) { + capturedAddr = addr + return fakeSrv, http.NewServeMux() + } + }) + defer srvIface.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + if err := srvIface.Run(ctx); err != nil { + t.Fatalf("Run: %v", err) + } + if capturedAddr != "127.0.0.1:9999" { + t.Fatalf("newServer addr = %q", capturedAddr) + } +} + +func TestHTTPStreamProxyServer_Run_StaticFilesInvalidPath(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.HttpServerReturns(":0") + env.StaticFilesReturns("/no/such/path/exists/__srsbot_test__") + + fakeSrv := newFakeHTTPProxyServer() + srvIface := NewHTTPStreamProxyServer(env, &lbfakes.FakeOriginLoadBalancer{}, + 50*time.Millisecond, func(s *httpStreamProxyServer) { + s.newServer = func(addr string) (httpServer, *http.ServeMux) { + return fakeSrv, http.NewServeMux() + } + }) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + err := srvIface.Run(ctx) + if err == nil || !strings.Contains(err.Error(), "invalid static files") { + t.Fatalf("unexpected err: %v", err) + } +} + +func TestHTTPStreamProxyServer_Run_CtxCancelTriggersShutdown(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.HttpServerReturns(":0") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, fakeSrv, _ := captureMuxFromRun(t, env, &lbfakes.FakeOriginLoadBalancer{}, ctx) + + // Wait briefly for ListenAndServe goroutine to be running. + deadline := time.Now().Add(time.Second) + for fakeSrv.listenCalls.Load() == 0 && time.Now().Before(deadline) { + time.Sleep(time.Millisecond) + } + if fakeSrv.listenCalls.Load() == 0 { + t.Fatal("ListenAndServe goroutine did not start") + } + + cancel() + + // Wait for Shutdown to be invoked by the watcher goroutine. + deadline = time.Now().Add(time.Second) + for fakeSrv.shutdownCalls.Load() == 0 && time.Now().Before(deadline) { + time.Sleep(time.Millisecond) + } + if fakeSrv.shutdownCalls.Load() == 0 { + t.Fatal("Shutdown was not invoked after ctx cancel") + } +} + +// ============================================================================= +// Run — handler dispatch +// ============================================================================= + +func TestHTTPStreamProxyServer_Run_HandlerVersionsReturnsJSON(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.HttpServerReturns(":0") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + mux, _, _ := captureMuxFromRun(t, env, &lbfakes.FakeOriginLoadBalancer{}, ctx) + + req := httptest.NewRequest(http.MethodGet, "/api/v1/versions", nil) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rec.Code) + } + var body struct { + Code int `json:"code"` + PID string `json:"pid"` + Data struct { + Major int `json:"major"` + Minor int `json:"minor"` + Revision int `json:"revision"` + Version string `json:"version"` + } `json:"data"` + } + if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil { + t.Fatalf("json: %v\nbody=%s", err, rec.Body.String()) + } + if body.Code != 0 { + t.Errorf("Code = %d, want 0", body.Code) + } + if body.PID == "" { + t.Error("PID should be populated") + } + if body.Data.Version == "" { + t.Error("Version should be populated") + } +} + +func TestHTTPStreamProxyServer_Run_HandlerM3U8InvokesNewHLSStream(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.HttpServerReturns(":0") + lbFake := &lbfakes.FakeOriginLoadBalancer{} + + // Make LoadOrStoreHLS return whatever was passed in (the stream from newHLSStream). + lbFake.LoadOrStoreHLSStub = func(_ context.Context, _ string, s lb.HLSPlayStream) (lb.HLSPlayStream, error) { + return s, nil + } + + var capturedStreamURL, capturedFullURL string + var newHLSCalls int + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + mux, _, _ := captureMuxFromRun(t, env, lbFake, ctx, func(s *httpStreamProxyServer) { + // Wrap default newHLSStream to capture args, but return a real + // hlsPlayStream so the .(*hlsPlayStream) cast inside Run's handler works. + // The returned stream has a fake loadBalancer; ServeHTTP will short-circuit + // on the OPTIONS preflight we send below. + s.newHLSStream = func(streamURL, fullURL string) *hlsPlayStream { + newHLSCalls++ + capturedStreamURL, capturedFullURL = streamURL, fullURL + return newHLSPlayStream(func(h *hlsPlayStream) { + h.loadBalancer = lbFake + h.StreamURL, h.FullURL = streamURL, fullURL + }) + } + }) + + req := httptest.NewRequest(http.MethodOptions, "http://example.com/live.m3u8", nil) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + if newHLSCalls != 1 { + t.Fatalf("newHLSStream calls = %d, want 1", newHLSCalls) + } + if !strings.HasSuffix(capturedStreamURL, "/live") { + t.Errorf("captured streamURL %q should end with /live", capturedStreamURL) + } + if !strings.Contains(capturedFullURL, "live.m3u8") { + t.Errorf("captured fullURL %q should contain live.m3u8", capturedFullURL) + } +} + +func TestHTTPStreamProxyServer_Run_HandlerM3U8LoadOrStoreErrorReturns400(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.HttpServerReturns(":0") + lbFake := &lbfakes.FakeOriginLoadBalancer{} + lbFake.LoadOrStoreHLSReturns(nil, errors.New("redis down")) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + mux, _, _ := captureMuxFromRun(t, env, lbFake, ctx) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want 400; body=%s", rec.Code, rec.Body.String()) + } + if !strings.Contains(rec.Body.String(), "load or store hls") { + t.Errorf("body should mention 'load or store hls', got %q", rec.Body.String()) + } +} + +func TestHTTPStreamProxyServer_Run_HandlerFlvInvokesNewFlvTsConn(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.HttpServerReturns(":0") + lbFake := &lbfakes.FakeOriginLoadBalancer{} + + var newFlvCalls int + var capturedCtx context.Context + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + mux, _, _ := captureMuxFromRun(t, env, lbFake, ctx, func(s *httpStreamProxyServer) { + s.newFlvTsConn = func(reqCtx context.Context) *httpFlvTsConnection { + newFlvCalls++ + capturedCtx = reqCtx + return newHTTPFlvTsConnection(func(c *httpFlvTsConnection) { + c.ctx = reqCtx + c.loadBalancer = lbFake + }) + } + }) + + req := httptest.NewRequest(http.MethodOptions, "http://example.com/live.flv", nil) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + if newFlvCalls != 1 { + t.Fatalf("newFlvTsConn calls = %d, want 1", newFlvCalls) + } + if capturedCtx == nil { + t.Error("captured ctx should be non-nil") + } +} + +func TestHTTPStreamProxyServer_Run_HandlerTsInvokesNewFlvTsConn(t *testing.T) { + // Same dispatch as .flv but for .ts (without spbhid). + env := &envfakes.FakeProxyEnvironment{} + env.HttpServerReturns(":0") + lbFake := &lbfakes.FakeOriginLoadBalancer{} + + var newFlvCalls int + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + mux, _, _ := captureMuxFromRun(t, env, lbFake, ctx, func(s *httpStreamProxyServer) { + s.newFlvTsConn = func(reqCtx context.Context) *httpFlvTsConnection { + newFlvCalls++ + return newHTTPFlvTsConnection(func(c *httpFlvTsConnection) { + c.ctx = reqCtx + c.loadBalancer = lbFake + }) + } + }) + + req := httptest.NewRequest(http.MethodOptions, "http://example.com/live.ts", nil) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + if newFlvCalls != 1 { + t.Fatalf("newFlvTsConn calls = %d, want 1", newFlvCalls) + } +} + +func TestHTTPStreamProxyServer_Run_HandlerTsWithSPBHIDLoadsByID(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.HttpServerReturns(":0") + lbFake := &lbfakes.FakeOriginLoadBalancer{} + + stub := newHLSPlayStream(func(h *hlsPlayStream) { + h.loadBalancer = lbFake + h.SRSProxyBackendHLSID = "ABC" + }) + lbFake.LoadHLSBySPBHIDReturns(stub, nil) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + mux, _, _ := captureMuxFromRun(t, env, lbFake, ctx) + + req := httptest.NewRequest(http.MethodOptions, "http://example.com/live-0.ts?spbhid=ABC", nil) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + if lbFake.LoadHLSBySPBHIDCallCount() != 1 { + t.Fatalf("LoadHLSBySPBHID calls = %d, want 1", lbFake.LoadHLSBySPBHIDCallCount()) + } + _, gotID := lbFake.LoadHLSBySPBHIDArgsForCall(0) + if gotID != "ABC" { + t.Errorf("LoadHLSBySPBHID id = %q, want ABC", gotID) + } +} + +func TestHTTPStreamProxyServer_Run_HandlerTsWithSPBHIDErrorReturns400(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.HttpServerReturns(":0") + lbFake := &lbfakes.FakeOriginLoadBalancer{} + lbFake.LoadHLSBySPBHIDReturns(nil, errors.New("not found")) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + mux, _, _ := captureMuxFromRun(t, env, lbFake, ctx) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/live-0.ts?spbhid=missing", nil) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want 400; body=%s", rec.Code, rec.Body.String()) + } +} + +func TestHTTPStreamProxyServer_Run_HandlerUnmatchedReturns404(t *testing.T) { + env := &envfakes.FakeProxyEnvironment{} + env.HttpServerReturns(":0") + // StaticFiles unset, no .m3u8/.flv/.ts suffix → 404. + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + mux, _, _ := captureMuxFromRun(t, env, &lbfakes.FakeOriginLoadBalancer{}, ctx) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/random/path", nil) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusNotFound { + t.Fatalf("status = %d, want 404", rec.Code) + } +} + +func TestHTTPStreamProxyServer_Run_HandlerServesStaticFiles(t *testing.T) { + dir := t.TempDir() + if err := os.WriteFile(filepath.Join(dir, "hello.txt"), []byte("hi"), 0644); err != nil { + t.Fatalf("write: %v", err) + } + + env := &envfakes.FakeProxyEnvironment{} + env.HttpServerReturns(":0") + env.StaticFilesReturns(dir) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + mux, _, _ := captureMuxFromRun(t, env, &lbfakes.FakeOriginLoadBalancer{}, ctx) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/hello.txt", nil) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rec.Code) + } + if rec.Body.String() != "hi" { + t.Errorf("body = %q, want hi", rec.Body.String()) + } +}