This commit is contained in:
sevico 2026-05-30 08:23:53 -04:00 committed by GitHub
commit d5bff0a354
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 93 additions and 6 deletions

View File

@ -347,13 +347,14 @@ func (v *httpFlvTsConnection) serveByBackend(ctx context.Context, w http.Respons
return errors.Errorf("proxy stream to %v failed, status=%v", backendURL, resp.Status)
}
// Copy all headers from backend to client.
w.WriteHeader(resp.StatusCode)
// Copy all headers from backend to client before WriteHeader,
// because headers set after WriteHeader are silently ignored.
for k, v := range resp.Header {
for _, vv := range v {
w.Header().Add(k, vv)
}
}
w.WriteHeader(resp.StatusCode)
logger.Debug(ctx, "HTTP start streaming")
@ -476,13 +477,14 @@ func (v *hlsPlayStream) serveByBackend(ctx context.Context, w http.ResponseWrite
return errors.Errorf("proxy stream to %v failed, status=%v", backendURL, resp.Status)
}
// Copy all headers from backend to client.
w.WriteHeader(resp.StatusCode)
// Copy all headers from backend to client before WriteHeader,
// because headers set after WriteHeader are silently ignored.
for k, v := range resp.Header {
for _, vv := range v {
w.Header().Add(k, vv)
}
}
w.WriteHeader(resp.StatusCode)
// For TS file, directly copy it.
if !strings.HasSuffix(r.URL.Path, ".m3u8") {
@ -502,7 +504,7 @@ func (v *hlsPlayStream) serveByBackend(ctx context.Context, w http.ResponseWrite
m3u8 := string(b)
if strings.Contains(m3u8, ".ts?") {
m3u8 = strings.ReplaceAll(m3u8, ".ts?", fmt.Sprintf(".ts?spbhid=%v&&", v.SRSProxyBackendHLSID))
m3u8 = strings.ReplaceAll(m3u8, ".ts?", fmt.Sprintf(".ts?spbhid=%v&", v.SRSProxyBackendHLSID))
} else {
m3u8 = strings.ReplaceAll(m3u8, ".ts", fmt.Sprintf(".ts?spbhid=%v", v.SRSProxyBackendHLSID))
}

View File

@ -370,7 +370,7 @@ func TestHLSPlayStream_ServeByBackend_M3U8RewritesTSWithQuery(t *testing.T) {
&lb.OriginServer{IP: host, HTTP: []string{port}}); err != nil {
t.Fatalf("unexpected err: %v", err)
}
if want := "live-0.ts?spbhid=ABC&&token=foo"; !strings.Contains(rec.Body.String(), want) {
if want := "live-0.ts?spbhid=ABC&token=foo"; !strings.Contains(rec.Body.String(), want) {
t.Fatalf("missing %q in body: %q", want, rec.Body.String())
}
}
@ -396,6 +396,61 @@ func TestHLSPlayStream_ServeByBackend_AppendsRawQueryOnTS(t *testing.T) {
}
}
func TestHLSPlayStream_ServeByBackend_HeadersCopiedFromBackend(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/vnd.apple.mpegurl")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("X-Custom-Header", "custom-value")
_, _ = io.WriteString(w, "#EXTM3U\nlive-0.ts\n")
}))
defer ts.Close()
host, port := httptestHostPort(t, ts)
v := newHLSPlayStream()
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.m3u8", nil)
rec := httptest.NewRecorder()
if err := v.serveByBackend(context.Background(), rec, req,
&lb.OriginServer{IP: host, HTTP: []string{port}}); err != nil {
t.Fatalf("unexpected err: %v", err)
}
// Verify headers are properly copied (not lost due to WriteHeader order)
if got := rec.Header().Get("Content-Type"); got != "application/vnd.apple.mpegurl" {
t.Errorf("Content-Type = %q, want application/vnd.apple.mpegurl", got)
}
if got := rec.Header().Get("Cache-Control"); got != "no-cache" {
t.Errorf("Cache-Control = %q, want no-cache", got)
}
if got := rec.Header().Get("X-Custom-Header"); got != "custom-value" {
t.Errorf("X-Custom-Header = %q, want custom-value", got)
}
}
func TestHLSPlayStream_ServeByBackend_TSHeadersCopiedFromBackend(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "video/mp2t")
w.Header().Set("Cache-Control", "max-age=3600")
_, _ = w.Write([]byte{0x47, 0x00, 0x01})
}))
defer ts.Close()
host, port := httptestHostPort(t, ts)
v := newHLSPlayStream()
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.ts", nil)
rec := httptest.NewRecorder()
if err := v.serveByBackend(context.Background(), rec, req,
&lb.OriginServer{IP: host, HTTP: []string{port}}); err != nil {
t.Fatalf("unexpected err: %v", err)
}
if got := rec.Header().Get("Content-Type"); got != "video/mp2t" {
t.Errorf("Content-Type = %q, want video/mp2t", got)
}
if got := rec.Header().Get("Cache-Control"); got != "max-age=3600" {
t.Errorf("Cache-Control = %q, want max-age=3600", got)
}
}
// =============================================================================
// httpFlvTsConnection
// =============================================================================
@ -666,6 +721,36 @@ func TestHTTPFlvTsConn_ServeByBackend_PreservesMethod(t *testing.T) {
}
}
func TestHTTPFlvTsConn_ServeByBackend_HeadersCopiedFromBackend(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "video/x-flv")
w.Header().Set("Cache-Control", "no-store")
w.Header().Set("X-Custom-Header", "flv-value")
_, _ = w.Write([]byte("FLV\x01\x05\x00\x00\x00\x09"))
}))
defer ts.Close()
host, port := httptestHostPort(t, ts)
v := newHTTPFlvTsConnection()
req := httptest.NewRequest(http.MethodGet, "http://example.com/live.flv", nil)
rec := httptest.NewRecorder()
if err := v.serveByBackend(context.Background(), rec, req,
&lb.OriginServer{IP: host, HTTP: []string{port}}); err != nil {
t.Fatalf("unexpected err: %v", err)
}
// Verify headers are properly copied (not lost due to WriteHeader order)
if got := rec.Header().Get("Content-Type"); got != "video/x-flv" {
t.Errorf("Content-Type = %q, want video/x-flv", got)
}
if got := rec.Header().Get("Cache-Control"); got != "no-store" {
t.Errorf("Cache-Control = %q, want no-store", got)
}
if got := rec.Header().Get("X-Custom-Header"); got != "flv-value" {
t.Errorf("X-Custom-Header = %q, want flv-value", got)
}
}
// =============================================================================
// httpStreamProxyServer
// =============================================================================