diff --git a/.gitignore b/.gitignore index 33e3cd817..0a30c5b80 100644 --- a/.gitignore +++ b/.gitignore @@ -47,4 +47,9 @@ cmake-build-debug # For AI /*personal* +/support* +/*srs-consults* /*workspace* +/skills/llm-switcher +/skills/*workspace* +/memory/202*.md diff --git a/internal/bootstrap/proxy.go b/internal/bootstrap/proxy.go index 985ed60b0..29667aef6 100644 --- a/internal/bootstrap/proxy.go +++ b/internal/bootstrap/proxy.go @@ -65,8 +65,9 @@ func (b *proxyBootstrap) run(ctx context.Context) error { // Start the Go pprof if enabled. debug.HandleGoPprof(ctx, environment) - // Initialize the load balancer. - if err := b.initializeLoadBalancer(ctx, environment); err != nil { + // Create and initialize the load balancer. + loadBalancer, err := b.initializeLoadBalancer(ctx, environment) + if err != nil { return err } @@ -77,36 +78,37 @@ func (b *proxyBootstrap) run(ctx context.Context) error { } // Start all servers and block until context is cancelled. - return b.startServers(ctx, environment, gracefulQuitTimeout) + return b.startServers(ctx, environment, loadBalancer, gracefulQuitTimeout) } // initializeLoadBalancer sets up the load balancer based on configuration. -func (b *proxyBootstrap) initializeLoadBalancer(ctx context.Context, environment env.ProxyEnvironment) error { +func (b *proxyBootstrap) initializeLoadBalancer(ctx context.Context, environment env.ProxyEnvironment) (lb.OriginLoadBalancer, error) { + var loadBalancer lb.OriginLoadBalancer switch environment.LoadBalancerType() { case "redis": - lb.SrsLoadBalancer = lb.NewRedisLoadBalancer(environment) + loadBalancer = lb.NewRedisLoadBalancer(environment) default: - lb.SrsLoadBalancer = lb.NewMemoryLoadBalancer(environment) + loadBalancer = lb.NewMemoryLoadBalancer(environment) } - if err := lb.SrsLoadBalancer.Initialize(ctx); err != nil { - return errors.Wrapf(err, "initialize srs load balancer") + if err := loadBalancer.Initialize(ctx); err != nil { + return nil, errors.Wrapf(err, "initialize srs load balancer") } - return nil + return loadBalancer, nil } // startServers initializes and starts all protocol servers. -func (b *proxyBootstrap) startServers(ctx context.Context, environment env.ProxyEnvironment, gracefulQuitTimeout time.Duration) error { +func (b *proxyBootstrap) startServers(ctx context.Context, environment env.ProxyEnvironment, loadBalancer lb.OriginLoadBalancer, gracefulQuitTimeout time.Duration) error { // Start the RTMP server. - rtmpProxyServer := proxy.NewRTMPProxyServer(environment) + rtmpProxyServer := proxy.NewRTMPProxyServer(environment, loadBalancer) if err := rtmpProxyServer.Run(ctx); err != nil { return errors.Wrapf(err, "rtmp server") } defer rtmpProxyServer.Close() // Start the WebRTC server. - webRTCProxyServer := proxy.NewWebRTCProxyServer(environment) + webRTCProxyServer := proxy.NewWebRTCProxyServer(environment, loadBalancer) if err := webRTCProxyServer.Run(ctx); err != nil { return errors.Wrapf(err, "rtc server") } @@ -120,21 +122,21 @@ func (b *proxyBootstrap) startServers(ctx context.Context, environment env.Proxy defer httpAPIProxyServer.Close() // Start the SRT server. - srsSRTProxyServer := proxy.NewSRSSRTProxyServer(environment) + srsSRTProxyServer := proxy.NewSRSSRTProxyServer(environment, loadBalancer) if err := srsSRTProxyServer.Run(ctx); err != nil { return errors.Wrapf(err, "srt server") } defer srsSRTProxyServer.Close() // Start the System API server. - systemAPI := proxy.NewSystemAPI(environment, gracefulQuitTimeout) + systemAPI := proxy.NewSystemAPI(environment, loadBalancer, gracefulQuitTimeout) if err := systemAPI.Run(ctx); err != nil { return errors.Wrapf(err, "system api server") } defer systemAPI.Close() // Start the HTTP web server. - httpStreamProxyServer := proxy.NewHTTPStreamProxyServer(environment, gracefulQuitTimeout) + httpStreamProxyServer := proxy.NewHTTPStreamProxyServer(environment, loadBalancer, gracefulQuitTimeout) if err := httpStreamProxyServer.Run(ctx); err != nil { return errors.Wrapf(err, "http server") } diff --git a/internal/lb/debug.go b/internal/lb/debug.go index 50cf8fd80..24cd60dee 100644 --- a/internal/lb/debug.go +++ b/internal/lb/debug.go @@ -12,8 +12,8 @@ import ( "srsx/internal/logger" ) -// NewDefaultSRSForDebugging initialize the default SRS media server, for debugging only. -func NewDefaultSRSForDebugging(environment env.ProxyEnvironment) (*OriginServer, error) { +// NewDefaultOriginServerForDebugging initializes the default origin server, for debugging only. +func NewDefaultOriginServerForDebugging(environment env.ProxyEnvironment) (*OriginServer, error) { if environment.DefaultBackendEnabled() != "on" { return nil, nil } diff --git a/internal/lb/lb.go b/internal/lb/lb.go index cb841fbc1..f15f552cf 100644 --- a/internal/lb/lb.go +++ b/internal/lb/lb.go @@ -126,6 +126,3 @@ type OriginLoadBalancer interface { // Load the WebRTC streaming by ufrag, the ICE username. LoadWebRTCByUfrag(ctx context.Context, ufrag string) (RTCConnection, error) } - -// SrsLoadBalancer is the global SRS load balancer instance. -var SrsLoadBalancer OriginLoadBalancer diff --git a/internal/lb/mem.go b/internal/lb/mem.go index 0f1702a34..1d625bc21 100644 --- a/internal/lb/mem.go +++ b/internal/lb/mem.go @@ -15,8 +15,8 @@ import ( "srsx/internal/sync" ) -// MemoryLoadBalancer stores state in memory. -type MemoryLoadBalancer struct { +// memoryLoadBalancer stores state in memory. +type memoryLoadBalancer struct { // The environment interface. environment env.ProxyEnvironment // All available SRS servers, key is server ID. @@ -35,7 +35,7 @@ type MemoryLoadBalancer struct { // NewMemoryLoadBalancer creates a new memory-based load balancer. func NewMemoryLoadBalancer(environment env.ProxyEnvironment) OriginLoadBalancer { - return &MemoryLoadBalancer{ + return &memoryLoadBalancer{ environment: environment, servers: sync.NewMap[string, *OriginServer](), picked: sync.NewMap[string, *OriginServer](), @@ -46,8 +46,8 @@ func NewMemoryLoadBalancer(environment env.ProxyEnvironment) OriginLoadBalancer } } -func (v *MemoryLoadBalancer) Initialize(ctx context.Context) error { - server, err := NewDefaultSRSForDebugging(v.environment) +func (v *memoryLoadBalancer) Initialize(ctx context.Context) error { + server, err := NewDefaultOriginServerForDebugging(v.environment) if err != nil { return errors.Wrapf(err, "initialize default SRS") } @@ -75,12 +75,12 @@ func (v *MemoryLoadBalancer) Initialize(ctx context.Context) error { return nil } -func (v *MemoryLoadBalancer) Update(ctx context.Context, server *OriginServer) error { +func (v *memoryLoadBalancer) Update(ctx context.Context, server *OriginServer) error { v.servers.Store(server.ID(), server) return nil } -func (v *MemoryLoadBalancer) Pick(ctx context.Context, streamURL string) (*OriginServer, error) { +func (v *memoryLoadBalancer) Pick(ctx context.Context, streamURL string) (*OriginServer, error) { // Always proxy to the same server for the same stream URL. if server, ok := v.picked.Load(streamURL); ok { return server, nil @@ -115,7 +115,7 @@ func (v *MemoryLoadBalancer) Pick(ctx context.Context, streamURL string) (*Origi return server, nil } -func (v *MemoryLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid string) (HLSPlayStream, error) { +func (v *memoryLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid string) (HLSPlayStream, error) { // Load the HLS streaming for the SPBHID, for TS files. if actual, ok := v.hlsSPBHID.Load(spbhid); !ok { return nil, errors.Errorf("no HLS streaming for SPBHID %v", spbhid) @@ -124,7 +124,7 @@ func (v *MemoryLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid string) } } -func (v *MemoryLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL string, value HLSPlayStream) (HLSPlayStream, error) { +func (v *memoryLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL string, value HLSPlayStream) (HLSPlayStream, error) { // Update the HLS streaming for the stream URL, for M3u8. actual, _ := v.hlsStreamURL.LoadOrStore(streamURL, value) if actual == nil { @@ -137,7 +137,7 @@ func (v *MemoryLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL strin return actual, nil } -func (v *MemoryLoadBalancer) StoreWebRTC(ctx context.Context, streamURL string, value RTCConnection) error { +func (v *memoryLoadBalancer) StoreWebRTC(ctx context.Context, streamURL string, value RTCConnection) error { // Update the WebRTC streaming for the stream URL. v.rtcStreamURL.Store(streamURL, value) @@ -146,7 +146,7 @@ func (v *MemoryLoadBalancer) StoreWebRTC(ctx context.Context, streamURL string, return nil } -func (v *MemoryLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag string) (RTCConnection, error) { +func (v *memoryLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag string) (RTCConnection, error) { if actual, ok := v.rtcUfrag.Load(ufrag); !ok { return nil, errors.Errorf("no WebRTC streaming for ufrag %v", ufrag) } else { diff --git a/internal/lb/redis.go b/internal/lb/redis.go index a3be9b961..0418e9986 100644 --- a/internal/lb/redis.go +++ b/internal/lb/redis.go @@ -19,8 +19,8 @@ import ( "srsx/internal/logger" ) -// RedisLoadBalancer stores state in Redis. -type RedisLoadBalancer struct { +// redisLoadBalancer stores state in Redis. +type redisLoadBalancer struct { // The environment interface. environment env.ProxyEnvironment // The redis client sdk. @@ -29,12 +29,12 @@ type RedisLoadBalancer struct { // NewRedisLoadBalancer creates a new Redis-based load balancer. func NewRedisLoadBalancer(environment env.ProxyEnvironment) OriginLoadBalancer { - return &RedisLoadBalancer{ + return &redisLoadBalancer{ environment: environment, } } -func (v *RedisLoadBalancer) Initialize(ctx context.Context) error { +func (v *redisLoadBalancer) Initialize(ctx context.Context) error { redisDatabase, err := strconv.Atoi(v.environment.RedisDB()) if err != nil { return errors.Wrapf(err, "invalid PROXY_REDIS_DB %v", v.environment.RedisDB()) @@ -52,7 +52,7 @@ func (v *RedisLoadBalancer) Initialize(ctx context.Context) error { } logger.Debug(ctx, "RedisLB: connected to redis %v ok", rdb.String()) - server, err := NewDefaultSRSForDebugging(v.environment) + server, err := NewDefaultOriginServerForDebugging(v.environment) if err != nil { return errors.Wrapf(err, "initialize default SRS") } @@ -80,7 +80,7 @@ func (v *RedisLoadBalancer) Initialize(ctx context.Context) error { return nil } -func (v *RedisLoadBalancer) Update(ctx context.Context, server *OriginServer) error { +func (v *redisLoadBalancer) Update(ctx context.Context, server *OriginServer) error { b, err := json.Marshal(server) if err != nil { return errors.Wrapf(err, "marshal server %+v", server) @@ -130,7 +130,7 @@ func (v *RedisLoadBalancer) Update(ctx context.Context, server *OriginServer) er return nil } -func (v *RedisLoadBalancer) Pick(ctx context.Context, streamURL string) (*OriginServer, error) { +func (v *redisLoadBalancer) Pick(ctx context.Context, streamURL string) (*OriginServer, error) { key := fmt.Sprintf("srs-proxy-url:%v", streamURL) // Always proxy to the same server for the same stream URL. @@ -188,7 +188,7 @@ func (v *RedisLoadBalancer) Pick(ctx context.Context, streamURL string) (*Origin return &server, nil } -func (v *RedisLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid string) (HLSPlayStream, error) { +func (v *redisLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid string) (HLSPlayStream, error) { key := v.redisKeySPBHID(spbhid) b, err := v.rdb.Get(ctx, key).Bytes() @@ -208,7 +208,7 @@ func (v *RedisLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid string) return nil, errors.Errorf("Redis load balancer cannot deserialize interface types") } -func (v *RedisLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL string, value HLSPlayStream) (HLSPlayStream, error) { +func (v *redisLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL string, value HLSPlayStream) (HLSPlayStream, error) { b, err := json.Marshal(value) if err != nil { return nil, errors.Wrapf(err, "marshal HLS %v", value) @@ -229,7 +229,7 @@ func (v *RedisLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL string return value, nil } -func (v *RedisLoadBalancer) StoreWebRTC(ctx context.Context, streamURL string, value RTCConnection) error { +func (v *redisLoadBalancer) StoreWebRTC(ctx context.Context, streamURL string, value RTCConnection) error { b, err := json.Marshal(value) if err != nil { return errors.Wrapf(err, "marshal WebRTC %v", value) @@ -249,7 +249,7 @@ func (v *RedisLoadBalancer) StoreWebRTC(ctx context.Context, streamURL string, v return nil } -func (v *RedisLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag string) (RTCConnection, error) { +func (v *redisLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag string) (RTCConnection, error) { key := v.redisKeyUfrag(ufrag) b, err := v.rdb.Get(ctx, key).Bytes() @@ -267,26 +267,26 @@ func (v *RedisLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag string) return nil, errors.Errorf("Redis load balancer cannot deserialize interface types") } -func (v *RedisLoadBalancer) redisKeyUfrag(ufrag string) string { +func (v *redisLoadBalancer) redisKeyUfrag(ufrag string) string { return fmt.Sprintf("srs-proxy-ufrag:%v", ufrag) } -func (v *RedisLoadBalancer) redisKeyRTC(streamURL string) string { +func (v *redisLoadBalancer) redisKeyRTC(streamURL string) string { return fmt.Sprintf("srs-proxy-rtc:%v", streamURL) } -func (v *RedisLoadBalancer) redisKeySPBHID(spbhid string) string { +func (v *redisLoadBalancer) redisKeySPBHID(spbhid string) string { return fmt.Sprintf("srs-proxy-spbhid:%v", spbhid) } -func (v *RedisLoadBalancer) redisKeyHLS(streamURL string) string { +func (v *redisLoadBalancer) redisKeyHLS(streamURL string) string { return fmt.Sprintf("srs-proxy-hls:%v", streamURL) } -func (v *RedisLoadBalancer) redisKeyServer(serverID string) string { +func (v *redisLoadBalancer) redisKeyServer(serverID string) string { return fmt.Sprintf("srs-proxy-server:%v", serverID) } -func (v *RedisLoadBalancer) redisKeyServers() string { +func (v *redisLoadBalancer) redisKeyServers() string { return fmt.Sprintf("srs-proxy-all-servers") } diff --git a/internal/proxy/api.go b/internal/proxy/api.go index c3365eec7..30189d032 100644 --- a/internal/proxy/api.go +++ b/internal/proxy/api.go @@ -147,6 +147,8 @@ func (v *httpAPIProxyServer) Run(ctx context.Context) error { type systemAPI struct { // The environment interface. environment env.ProxyEnvironment + // The load balancer for origin servers. + loadBalancer lb.OriginLoadBalancer // The underlayer HTTP server. server *http.Server // The gracefully quit timeout, wait server to quit. @@ -155,9 +157,10 @@ type systemAPI struct { wg sync.WaitGroup } -func NewSystemAPI(environment env.ProxyEnvironment, gracefulQuitTimeout time.Duration) *systemAPI { +func NewSystemAPI(environment env.ProxyEnvironment, loadBalancer lb.OriginLoadBalancer, gracefulQuitTimeout time.Duration) *systemAPI { v := &systemAPI{ environment: environment, + loadBalancer: loadBalancer, gracefulQuitTimeout: gracefulQuitTimeout, } return v @@ -262,7 +265,7 @@ func (v *systemAPI) Run(ctx context.Context) error { srs.SRT, srs.RTC = srt, rtc srs.UpdatedAt = time.Now() }) - if err := lb.SrsLoadBalancer.Update(ctx, server); err != nil { + if err := v.loadBalancer.Update(ctx, server); err != nil { return errors.Wrapf(err, "update SRS server %+v", server) } diff --git a/internal/proxy/http.go b/internal/proxy/http.go index 9ce915e1d..c37090eef 100644 --- a/internal/proxy/http.go +++ b/internal/proxy/http.go @@ -34,6 +34,8 @@ type HTTPStreamProxyServer interface { type httpStreamProxyServer struct { // The environment interface. environment env.ProxyEnvironment + // The load balancer for origin servers. + loadBalancer lb.OriginLoadBalancer // The underlayer HTTP server. server *http.Server // The gracefully quit timeout, wait server to quit. @@ -42,9 +44,10 @@ type httpStreamProxyServer struct { wg stdSync.WaitGroup } -func NewHTTPStreamProxyServer(environment env.ProxyEnvironment, gracefulQuitTimeout time.Duration) HTTPStreamProxyServer { +func NewHTTPStreamProxyServer(environment env.ProxyEnvironment, loadBalancer lb.OriginLoadBalancer, gracefulQuitTimeout time.Duration) HTTPStreamProxyServer { v := &httpStreamProxyServer{ environment: environment, + loadBalancer: loadBalancer, gracefulQuitTimeout: gracefulQuitTimeout, } return v @@ -128,7 +131,8 @@ func (v *httpStreamProxyServer) Run(ctx context.Context) error { return } - stream, _ := lb.SrsLoadBalancer.LoadOrStoreHLS(ctx, streamURL, newHLSPlayStream(func(s *hlsPlayStream) { + stream, _ := v.loadBalancer.LoadOrStoreHLS(ctx, streamURL, newHLSPlayStream(func(s *hlsPlayStream) { + s.loadBalancer = v.loadBalancer s.SRSProxyBackendHLSID = logger.GenerateContextID() s.StreamURL, s.FullURL = streamURL, fullURL })) @@ -142,7 +146,7 @@ func (v *httpStreamProxyServer) Run(ctx context.Context) error { strings.HasSuffix(r.URL.Path, ".ts") { // If SPBHID is specified, it must be a HLS stream client. if srsProxyBackendID := r.URL.Query().Get("spbhid"); srsProxyBackendID != "" { - if stream, err := lb.SrsLoadBalancer.LoadHLSBySPBHID(ctx, srsProxyBackendID); err != nil { + if stream, err := v.loadBalancer.LoadHLSBySPBHID(ctx, srsProxyBackendID); err != nil { http.Error(w, fmt.Sprintf("load stream by spbhid %v", srsProxyBackendID), http.StatusBadRequest) } else { stream.Initialize(ctx).(*hlsPlayStream).ServeHTTP(w, r) @@ -153,6 +157,7 @@ func (v *httpStreamProxyServer) Run(ctx context.Context) error { // Use HTTP pseudo streaming to proxy the request. newHTTPFlvTsConnection(func(c *httpFlvTsConnection) { c.ctx = ctx + c.loadBalancer = v.loadBalancer }).ServeHTTP(w, r) return } @@ -196,6 +201,8 @@ func (v *httpStreamProxyServer) Run(ctx context.Context) error { type httpFlvTsConnection struct { // The context for HTTP streaming. ctx context.Context + // The load balancer for origin servers. + loadBalancer lb.OriginLoadBalancer } func newHTTPFlvTsConnection(opts ...func(*httpFlvTsConnection)) *httpFlvTsConnection { @@ -233,7 +240,7 @@ func (v *httpFlvTsConnection) serve(ctx context.Context, w http.ResponseWriter, } // Pick a backend SRS server to proxy the RTMP stream. - backend, err := lb.SrsLoadBalancer.Pick(ctx, streamURL) + backend, err := v.loadBalancer.Pick(ctx, streamURL) if err != nil { return errors.Wrapf(err, "pick backend for %v", streamURL) } @@ -303,6 +310,8 @@ func (v *httpFlvTsConnection) serveByBackend(ctx context.Context, w http.Respons type hlsPlayStream struct { // The context for HLS streaming. ctx context.Context + // The load balancer for origin servers. + loadBalancer lb.OriginLoadBalancer // The spbhid, used to identify the backend server. SRSProxyBackendHLSID string `json:"spbhid"` @@ -351,7 +360,7 @@ func (v *hlsPlayStream) serve(ctx context.Context, w http.ResponseWriter, r *htt } // Pick a backend SRS server to proxy the RTMP stream. - backend, err := lb.SrsLoadBalancer.Pick(ctx, streamURL) + backend, err := v.loadBalancer.Pick(ctx, streamURL) if err != nil { return errors.Wrapf(err, "pick backend for %v", streamURL) } diff --git a/internal/proxy/rtc.go b/internal/proxy/rtc.go index 71628602e..120f4836f 100644 --- a/internal/proxy/rtc.go +++ b/internal/proxy/rtc.go @@ -36,6 +36,8 @@ type WebRTCProxyServer interface { type webRTCProxyServer struct { // The environment interface. environment env.ProxyEnvironment + // The load balancer for origin servers. + loadBalancer lb.OriginLoadBalancer // The UDP listener for WebRTC server. listener *net.UDPConn @@ -51,11 +53,12 @@ type webRTCProxyServer struct { wg stdSync.WaitGroup } -func NewWebRTCProxyServer(environment env.ProxyEnvironment, opts ...func(*webRTCProxyServer)) WebRTCProxyServer { +func NewWebRTCProxyServer(environment env.ProxyEnvironment, loadBalancer lb.OriginLoadBalancer, opts ...func(*webRTCProxyServer)) WebRTCProxyServer { v := &webRTCProxyServer{ - environment: environment, - usernames: sync.NewMap[string, *rtcConnection](), - addresses: sync.NewMap[string, *rtcConnection](), + environment: environment, + loadBalancer: loadBalancer, + usernames: sync.NewMap[string, *rtcConnection](), + addresses: sync.NewMap[string, *rtcConnection](), } for _, opt := range opts { opt(v) @@ -97,7 +100,7 @@ func (v *webRTCProxyServer) HandleApiForWHIP(ctx context.Context, w http.Respons } // Pick a backend SRS server to proxy the RTMP stream. - backend, err := lb.SrsLoadBalancer.Pick(ctx, streamURL) + backend, err := v.loadBalancer.Pick(ctx, streamURL) if err != nil { return errors.Wrapf(err, "pick backend for %v", streamURL) } @@ -134,7 +137,7 @@ func (v *webRTCProxyServer) HandleApiForWHEP(ctx context.Context, w http.Respons } // Pick a backend SRS server to proxy the RTMP stream. - backend, err := lb.SrsLoadBalancer.Pick(ctx, streamURL) + backend, err := v.loadBalancer.Pick(ctx, streamURL) if err != nil { return errors.Wrapf(err, "pick backend for %v", streamURL) } @@ -226,7 +229,8 @@ func (v *webRTCProxyServer) proxyApiToBackend( RemoteICEUfrag: remoteICEUfrag, RemoteICEPwd: remoteICEPwd, LocalICEUfrag: localICEUfrag, LocalICEPwd: localICEPwd, } - if err := lb.SrsLoadBalancer.StoreWebRTC(ctx, streamURL, newRTCConnection(func(c *rtcConnection) { + if err := v.loadBalancer.StoreWebRTC(ctx, streamURL, newRTCConnection(func(c *rtcConnection) { + c.loadBalancer = v.loadBalancer c.StreamURL, c.Ufrag = streamURL, icePair.Ufrag() c.Initialize(ctx, v.listener) @@ -315,10 +319,11 @@ func (v *webRTCProxyServer) handleClientUDP(ctx context.Context, addr *net.UDPAd } // Load connection by username. - if s, err := lb.SrsLoadBalancer.LoadWebRTCByUfrag(ctx, pkt.Username); err != nil { + if s, err := v.loadBalancer.LoadWebRTCByUfrag(ctx, pkt.Username); err != nil { return errors.Wrapf(err, "load webrtc by ufrag %v", pkt.Username) } else { connection = s.(*rtcConnection).Initialize(ctx, v.listener) + connection.loadBalancer = v.loadBalancer logger.Debug(ctx, "Create WebRTC connection by ufrag=%v, stream=%v", pkt.Username, connection.StreamURL) } @@ -366,6 +371,8 @@ func (v *webRTCProxyServer) handleClientUDP(ctx context.Context, addr *net.UDPAd type rtcConnection struct { // The stream context for WebRTC streaming. ctx context.Context + // The load balancer for origin servers. + loadBalancer lb.OriginLoadBalancer // The stream URL in vhost/app/stream schema. StreamURL string `json:"stream_url"` @@ -450,7 +457,7 @@ func (v *rtcConnection) connectBackend(ctx context.Context) error { } // Pick a backend SRS server to proxy the RTC stream. - backend, err := lb.SrsLoadBalancer.Pick(ctx, v.StreamURL) + backend, err := v.loadBalancer.Pick(ctx, v.StreamURL) if err != nil { return errors.Wrapf(err, "pick backend") } diff --git a/internal/proxy/rtmp.go b/internal/proxy/rtmp.go index 23be82416..3dea59ee8 100644 --- a/internal/proxy/rtmp.go +++ b/internal/proxy/rtmp.go @@ -31,14 +31,16 @@ type RTMPProxyServer interface { type rtmpProxyServer struct { // The environment interface. environment env.ProxyEnvironment + // The load balancer for origin servers. + loadBalancer lb.OriginLoadBalancer // The TCP listener for RTMP server. listener *net.TCPListener // The wait group for all goroutines. wg sync.WaitGroup } -func NewRTMPProxyServer(environment env.ProxyEnvironment, opts ...func(*rtmpProxyServer)) RTMPProxyServer { - v := &rtmpProxyServer{environment: environment} +func NewRTMPProxyServer(environment env.ProxyEnvironment, loadBalancer lb.OriginLoadBalancer, opts ...func(*rtmpProxyServer)) RTMPProxyServer { + v := &rtmpProxyServer{environment: environment, loadBalancer: loadBalancer} for _, opt := range opts { opt(v) } @@ -102,7 +104,9 @@ func (v *rtmpProxyServer) Run(ctx context.Context) error { } } - rc := newRTMPConnection() + rc := newRTMPConnection(func(c *rtmpConnection) { + c.loadBalancer = v.loadBalancer + }) if err := rc.serve(ctx, conn); err != nil { handleErr(err) } else { @@ -122,6 +126,8 @@ func (v *rtmpProxyServer) Run(ctx context.Context) error { // then proxy to the corresponding backend server. All state is in the RTMP request, so this // connection is stateless. type rtmpConnection struct { + // The load balancer for origin servers. + loadBalancer lb.OriginLoadBalancer } func newRTMPConnection(opts ...func(*rtmpConnection)) *rtmpConnection { @@ -296,6 +302,7 @@ func (v *rtmpConnection) serve(ctx context.Context, conn *net.TCPConn) error { // Find a backend SRS server to proxy the RTMP stream. backend = newRTMPClientToBackend(func(client *rtmpClientToBackend) { client.typ = clientType + client.loadBalancer = v.loadBalancer }) defer backend.Close() @@ -429,6 +436,8 @@ type rtmpClientToBackend struct { client rtmp.Protocol // The stream type. typ RTMPClientType + // The load balancer for origin servers. + loadBalancer lb.OriginLoadBalancer } func newRTMPClientToBackend(opts ...func(*rtmpClientToBackend)) *rtmpClientToBackend { @@ -454,7 +463,7 @@ func (v *rtmpClientToBackend) Connect(ctx context.Context, tcUrl, streamName str } // Pick a backend SRS server to proxy the RTMP stream. - backend, err := lb.SrsLoadBalancer.Pick(ctx, streamURL) + backend, err := v.loadBalancer.Pick(ctx, streamURL) if err != nil { return errors.Wrapf(err, "pick backend for %v", streamURL) } diff --git a/internal/proxy/srt.go b/internal/proxy/srt.go index 7f3ba1fee..5c11cdc11 100644 --- a/internal/proxy/srt.go +++ b/internal/proxy/srt.go @@ -27,6 +27,8 @@ import ( type srsSRTProxyServer struct { // The environment interface. environment env.ProxyEnvironment + // The load balancer for origin servers. + loadBalancer lb.OriginLoadBalancer // The UDP listener for SRT server. listener *net.UDPConn @@ -39,11 +41,12 @@ type srsSRTProxyServer struct { wg stdSync.WaitGroup } -func NewSRSSRTProxyServer(environment env.ProxyEnvironment, opts ...func(*srsSRTProxyServer)) *srsSRTProxyServer { +func NewSRSSRTProxyServer(environment env.ProxyEnvironment, loadBalancer lb.OriginLoadBalancer, opts ...func(*srsSRTProxyServer)) *srsSRTProxyServer { v := &srsSRTProxyServer{ - environment: environment, - start: time.Now(), - sockets: sync.NewMap[uint32, *SRTConnection](), + environment: environment, + loadBalancer: loadBalancer, + start: time.Now(), + sockets: sync.NewMap[uint32, *SRTConnection](), } for _, opt := range opts { @@ -127,6 +130,7 @@ func (v *srsSRTProxyServer) handleClientUDP(ctx context.Context, addr *net.UDPAd conn, ok := v.sockets.LoadOrStore(socketID, NewSRTConnection(func(c *SRTConnection) { c.ctx = logger.WithContext(ctx) c.listenerUDP, c.socketID = v.listener, socketID + c.loadBalancer = v.loadBalancer c.start = v.start })) @@ -158,6 +162,8 @@ func (v *srsSRTProxyServer) handleClientUDP(ctx context.Context, addr *net.UDPAd type SRTConnection struct { // The stream context for SRT connection. ctx context.Context + // The load balancer for origin servers. + loadBalancer lb.OriginLoadBalancer // The current socket ID. socketID uint32 @@ -356,7 +362,7 @@ func (v *SRTConnection) connectBackend(ctx context.Context, streamID string) err } // Pick a backend SRS server to proxy the SRT stream. - backend, err := lb.SrsLoadBalancer.Pick(ctx, streamURL) + backend, err := v.loadBalancer.Pick(ctx, streamURL) if err != nil { return errors.Wrapf(err, "pick backend for %v", streamURL) }