diff --git a/internal/proxy/http.go b/internal/proxy/http.go index 2bf052460..54049f6cf 100644 --- a/internal/proxy/http.go +++ b/internal/proxy/http.go @@ -347,13 +347,14 @@ func (v *httpFlvTsConnection) serveByBackend(ctx context.Context, w http.Respons return errors.Errorf("proxy stream to %v failed, status=%v", backendURL, resp.Status) } - // Copy all headers from backend to client. - w.WriteHeader(resp.StatusCode) + // Copy all headers from backend to client before WriteHeader, + // because headers set after WriteHeader are silently ignored. for k, v := range resp.Header { for _, vv := range v { w.Header().Add(k, vv) } } + w.WriteHeader(resp.StatusCode) logger.Debug(ctx, "HTTP start streaming") @@ -476,13 +477,14 @@ func (v *hlsPlayStream) serveByBackend(ctx context.Context, w http.ResponseWrite return errors.Errorf("proxy stream to %v failed, status=%v", backendURL, resp.Status) } - // Copy all headers from backend to client. - w.WriteHeader(resp.StatusCode) + // Copy all headers from backend to client before WriteHeader, + // because headers set after WriteHeader are silently ignored. for k, v := range resp.Header { for _, vv := range v { w.Header().Add(k, vv) } } + w.WriteHeader(resp.StatusCode) // For TS file, directly copy it. if !strings.HasSuffix(r.URL.Path, ".m3u8") { @@ -502,7 +504,7 @@ func (v *hlsPlayStream) serveByBackend(ctx context.Context, w http.ResponseWrite m3u8 := string(b) if strings.Contains(m3u8, ".ts?") { - m3u8 = strings.ReplaceAll(m3u8, ".ts?", fmt.Sprintf(".ts?spbhid=%v&&", v.SRSProxyBackendHLSID)) + m3u8 = strings.ReplaceAll(m3u8, ".ts?", fmt.Sprintf(".ts?spbhid=%v&", v.SRSProxyBackendHLSID)) } else { m3u8 = strings.ReplaceAll(m3u8, ".ts", fmt.Sprintf(".ts?spbhid=%v", v.SRSProxyBackendHLSID)) } diff --git a/internal/proxy/http_test.go b/internal/proxy/http_test.go index fa64225c4..0025daeb8 100644 --- a/internal/proxy/http_test.go +++ b/internal/proxy/http_test.go @@ -370,7 +370,7 @@ func TestHLSPlayStream_ServeByBackend_M3U8RewritesTSWithQuery(t *testing.T) { &lb.OriginServer{IP: host, HTTP: []string{port}}); err != nil { t.Fatalf("unexpected err: %v", err) } - if want := "live-0.ts?spbhid=ABC&&token=foo"; !strings.Contains(rec.Body.String(), want) { + if want := "live-0.ts?spbhid=ABC&token=foo"; !strings.Contains(rec.Body.String(), want) { t.Fatalf("missing %q in body: %q", want, rec.Body.String()) } } @@ -396,6 +396,61 @@ func TestHLSPlayStream_ServeByBackend_AppendsRawQueryOnTS(t *testing.T) { } } +func TestHLSPlayStream_ServeByBackend_HeadersCopiedFromBackend(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/vnd.apple.mpegurl") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("X-Custom-Header", "custom-value") + _, _ = io.WriteString(w, "#EXTM3U\nlive-0.ts\n") + })) + defer ts.Close() + host, port := httptestHostPort(t, ts) + + v := newHLSPlayStream() + req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil) + rec := httptest.NewRecorder() + if err := v.serveByBackend(context.Background(), rec, req, + &lb.OriginServer{IP: host, HTTP: []string{port}}); err != nil { + t.Fatalf("unexpected err: %v", err) + } + + // Verify headers are properly copied (not lost due to WriteHeader order) + if got := rec.Header().Get("Content-Type"); got != "application/vnd.apple.mpegurl" { + t.Errorf("Content-Type = %q, want application/vnd.apple.mpegurl", got) + } + if got := rec.Header().Get("Cache-Control"); got != "no-cache" { + t.Errorf("Cache-Control = %q, want no-cache", got) + } + if got := rec.Header().Get("X-Custom-Header"); got != "custom-value" { + t.Errorf("X-Custom-Header = %q, want custom-value", got) + } +} + +func TestHLSPlayStream_ServeByBackend_TSHeadersCopiedFromBackend(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "video/mp2t") + w.Header().Set("Cache-Control", "max-age=3600") + _, _ = w.Write([]byte{0x47, 0x00, 0x01}) + })) + defer ts.Close() + host, port := httptestHostPort(t, ts) + + v := newHLSPlayStream() + req := httptest.NewRequest(http.MethodGet, "http://example.com/live.ts", nil) + rec := httptest.NewRecorder() + if err := v.serveByBackend(context.Background(), rec, req, + &lb.OriginServer{IP: host, HTTP: []string{port}}); err != nil { + t.Fatalf("unexpected err: %v", err) + } + + if got := rec.Header().Get("Content-Type"); got != "video/mp2t" { + t.Errorf("Content-Type = %q, want video/mp2t", got) + } + if got := rec.Header().Get("Cache-Control"); got != "max-age=3600" { + t.Errorf("Cache-Control = %q, want max-age=3600", got) + } +} + // ============================================================================= // httpFlvTsConnection // ============================================================================= @@ -666,6 +721,36 @@ func TestHTTPFlvTsConn_ServeByBackend_PreservesMethod(t *testing.T) { } } +func TestHTTPFlvTsConn_ServeByBackend_HeadersCopiedFromBackend(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "video/x-flv") + w.Header().Set("Cache-Control", "no-store") + w.Header().Set("X-Custom-Header", "flv-value") + _, _ = w.Write([]byte("FLV\x01\x05\x00\x00\x00\x09")) + })) + defer ts.Close() + host, port := httptestHostPort(t, ts) + + v := newHTTPFlvTsConnection() + req := httptest.NewRequest(http.MethodGet, "http://example.com/live.flv", nil) + rec := httptest.NewRecorder() + if err := v.serveByBackend(context.Background(), rec, req, + &lb.OriginServer{IP: host, HTTP: []string{port}}); err != nil { + t.Fatalf("unexpected err: %v", err) + } + + // Verify headers are properly copied (not lost due to WriteHeader order) + if got := rec.Header().Get("Content-Type"); got != "video/x-flv" { + t.Errorf("Content-Type = %q, want video/x-flv", got) + } + if got := rec.Header().Get("Cache-Control"); got != "no-store" { + t.Errorf("Cache-Control = %q, want no-store", got) + } + if got := rec.Header().Get("X-Custom-Header"); got != "flv-value" { + t.Errorf("X-Custom-Header = %q, want flv-value", got) + } +} + // ============================================================================= // httpStreamProxyServer // =============================================================================