Codex: Remove global proxy load balancer.

Create the origin load balancer in proxy bootstrap and pass it explicitly into protocol servers and system API handlers. Also keep load balancer implementations package-private and rename the default debugging origin helper.

Co-authored-by: chatgpt-codex-connector[bot] <199175422+chatgpt-codex-connector[bot]@users.noreply.github.com>
This commit is contained in:
winlin 2026-05-10 21:00:17 -04:00
parent 3b93ddfddf
commit f45bf30b46
11 changed files with 111 additions and 73 deletions

5
.gitignore vendored
View File

@ -47,4 +47,9 @@ cmake-build-debug
# For AI
/*personal*
/support*
/*srs-consults*
/*workspace*
/skills/llm-switcher
/skills/*workspace*
/memory/202*.md

View File

@ -65,8 +65,9 @@ func (b *proxyBootstrap) run(ctx context.Context) error {
// Start the Go pprof if enabled.
debug.HandleGoPprof(ctx, environment)
// Initialize the load balancer.
if err := b.initializeLoadBalancer(ctx, environment); err != nil {
// Create and initialize the load balancer.
loadBalancer, err := b.initializeLoadBalancer(ctx, environment)
if err != nil {
return err
}
@ -77,36 +78,37 @@ func (b *proxyBootstrap) run(ctx context.Context) error {
}
// Start all servers and block until context is cancelled.
return b.startServers(ctx, environment, gracefulQuitTimeout)
return b.startServers(ctx, environment, loadBalancer, gracefulQuitTimeout)
}
// initializeLoadBalancer sets up the load balancer based on configuration.
func (b *proxyBootstrap) initializeLoadBalancer(ctx context.Context, environment env.ProxyEnvironment) error {
func (b *proxyBootstrap) initializeLoadBalancer(ctx context.Context, environment env.ProxyEnvironment) (lb.OriginLoadBalancer, error) {
var loadBalancer lb.OriginLoadBalancer
switch environment.LoadBalancerType() {
case "redis":
lb.SrsLoadBalancer = lb.NewRedisLoadBalancer(environment)
loadBalancer = lb.NewRedisLoadBalancer(environment)
default:
lb.SrsLoadBalancer = lb.NewMemoryLoadBalancer(environment)
loadBalancer = lb.NewMemoryLoadBalancer(environment)
}
if err := lb.SrsLoadBalancer.Initialize(ctx); err != nil {
return errors.Wrapf(err, "initialize srs load balancer")
if err := loadBalancer.Initialize(ctx); err != nil {
return nil, errors.Wrapf(err, "initialize srs load balancer")
}
return nil
return loadBalancer, nil
}
// startServers initializes and starts all protocol servers.
func (b *proxyBootstrap) startServers(ctx context.Context, environment env.ProxyEnvironment, gracefulQuitTimeout time.Duration) error {
func (b *proxyBootstrap) startServers(ctx context.Context, environment env.ProxyEnvironment, loadBalancer lb.OriginLoadBalancer, gracefulQuitTimeout time.Duration) error {
// Start the RTMP server.
rtmpProxyServer := proxy.NewRTMPProxyServer(environment)
rtmpProxyServer := proxy.NewRTMPProxyServer(environment, loadBalancer)
if err := rtmpProxyServer.Run(ctx); err != nil {
return errors.Wrapf(err, "rtmp server")
}
defer rtmpProxyServer.Close()
// Start the WebRTC server.
webRTCProxyServer := proxy.NewWebRTCProxyServer(environment)
webRTCProxyServer := proxy.NewWebRTCProxyServer(environment, loadBalancer)
if err := webRTCProxyServer.Run(ctx); err != nil {
return errors.Wrapf(err, "rtc server")
}
@ -120,21 +122,21 @@ func (b *proxyBootstrap) startServers(ctx context.Context, environment env.Proxy
defer httpAPIProxyServer.Close()
// Start the SRT server.
srsSRTProxyServer := proxy.NewSRSSRTProxyServer(environment)
srsSRTProxyServer := proxy.NewSRSSRTProxyServer(environment, loadBalancer)
if err := srsSRTProxyServer.Run(ctx); err != nil {
return errors.Wrapf(err, "srt server")
}
defer srsSRTProxyServer.Close()
// Start the System API server.
systemAPI := proxy.NewSystemAPI(environment, gracefulQuitTimeout)
systemAPI := proxy.NewSystemAPI(environment, loadBalancer, gracefulQuitTimeout)
if err := systemAPI.Run(ctx); err != nil {
return errors.Wrapf(err, "system api server")
}
defer systemAPI.Close()
// Start the HTTP web server.
httpStreamProxyServer := proxy.NewHTTPStreamProxyServer(environment, gracefulQuitTimeout)
httpStreamProxyServer := proxy.NewHTTPStreamProxyServer(environment, loadBalancer, gracefulQuitTimeout)
if err := httpStreamProxyServer.Run(ctx); err != nil {
return errors.Wrapf(err, "http server")
}

View File

@ -12,8 +12,8 @@ import (
"srsx/internal/logger"
)
// NewDefaultSRSForDebugging initialize the default SRS media server, for debugging only.
func NewDefaultSRSForDebugging(environment env.ProxyEnvironment) (*OriginServer, error) {
// NewDefaultOriginServerForDebugging initializes the default origin server, for debugging only.
func NewDefaultOriginServerForDebugging(environment env.ProxyEnvironment) (*OriginServer, error) {
if environment.DefaultBackendEnabled() != "on" {
return nil, nil
}

View File

@ -126,6 +126,3 @@ type OriginLoadBalancer interface {
// Load the WebRTC streaming by ufrag, the ICE username.
LoadWebRTCByUfrag(ctx context.Context, ufrag string) (RTCConnection, error)
}
// SrsLoadBalancer is the global SRS load balancer instance.
var SrsLoadBalancer OriginLoadBalancer

View File

@ -15,8 +15,8 @@ import (
"srsx/internal/sync"
)
// MemoryLoadBalancer stores state in memory.
type MemoryLoadBalancer struct {
// memoryLoadBalancer stores state in memory.
type memoryLoadBalancer struct {
// The environment interface.
environment env.ProxyEnvironment
// All available SRS servers, key is server ID.
@ -35,7 +35,7 @@ type MemoryLoadBalancer struct {
// NewMemoryLoadBalancer creates a new memory-based load balancer.
func NewMemoryLoadBalancer(environment env.ProxyEnvironment) OriginLoadBalancer {
return &MemoryLoadBalancer{
return &memoryLoadBalancer{
environment: environment,
servers: sync.NewMap[string, *OriginServer](),
picked: sync.NewMap[string, *OriginServer](),
@ -46,8 +46,8 @@ func NewMemoryLoadBalancer(environment env.ProxyEnvironment) OriginLoadBalancer
}
}
func (v *MemoryLoadBalancer) Initialize(ctx context.Context) error {
server, err := NewDefaultSRSForDebugging(v.environment)
func (v *memoryLoadBalancer) Initialize(ctx context.Context) error {
server, err := NewDefaultOriginServerForDebugging(v.environment)
if err != nil {
return errors.Wrapf(err, "initialize default SRS")
}
@ -75,12 +75,12 @@ func (v *MemoryLoadBalancer) Initialize(ctx context.Context) error {
return nil
}
func (v *MemoryLoadBalancer) Update(ctx context.Context, server *OriginServer) error {
func (v *memoryLoadBalancer) Update(ctx context.Context, server *OriginServer) error {
v.servers.Store(server.ID(), server)
return nil
}
func (v *MemoryLoadBalancer) Pick(ctx context.Context, streamURL string) (*OriginServer, error) {
func (v *memoryLoadBalancer) Pick(ctx context.Context, streamURL string) (*OriginServer, error) {
// Always proxy to the same server for the same stream URL.
if server, ok := v.picked.Load(streamURL); ok {
return server, nil
@ -115,7 +115,7 @@ func (v *MemoryLoadBalancer) Pick(ctx context.Context, streamURL string) (*Origi
return server, nil
}
func (v *MemoryLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid string) (HLSPlayStream, error) {
func (v *memoryLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid string) (HLSPlayStream, error) {
// Load the HLS streaming for the SPBHID, for TS files.
if actual, ok := v.hlsSPBHID.Load(spbhid); !ok {
return nil, errors.Errorf("no HLS streaming for SPBHID %v", spbhid)
@ -124,7 +124,7 @@ func (v *MemoryLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid string)
}
}
func (v *MemoryLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL string, value HLSPlayStream) (HLSPlayStream, error) {
func (v *memoryLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL string, value HLSPlayStream) (HLSPlayStream, error) {
// Update the HLS streaming for the stream URL, for M3u8.
actual, _ := v.hlsStreamURL.LoadOrStore(streamURL, value)
if actual == nil {
@ -137,7 +137,7 @@ func (v *MemoryLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL strin
return actual, nil
}
func (v *MemoryLoadBalancer) StoreWebRTC(ctx context.Context, streamURL string, value RTCConnection) error {
func (v *memoryLoadBalancer) StoreWebRTC(ctx context.Context, streamURL string, value RTCConnection) error {
// Update the WebRTC streaming for the stream URL.
v.rtcStreamURL.Store(streamURL, value)
@ -146,7 +146,7 @@ func (v *MemoryLoadBalancer) StoreWebRTC(ctx context.Context, streamURL string,
return nil
}
func (v *MemoryLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag string) (RTCConnection, error) {
func (v *memoryLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag string) (RTCConnection, error) {
if actual, ok := v.rtcUfrag.Load(ufrag); !ok {
return nil, errors.Errorf("no WebRTC streaming for ufrag %v", ufrag)
} else {

View File

@ -19,8 +19,8 @@ import (
"srsx/internal/logger"
)
// RedisLoadBalancer stores state in Redis.
type RedisLoadBalancer struct {
// redisLoadBalancer stores state in Redis.
type redisLoadBalancer struct {
// The environment interface.
environment env.ProxyEnvironment
// The redis client sdk.
@ -29,12 +29,12 @@ type RedisLoadBalancer struct {
// NewRedisLoadBalancer creates a new Redis-based load balancer.
func NewRedisLoadBalancer(environment env.ProxyEnvironment) OriginLoadBalancer {
return &RedisLoadBalancer{
return &redisLoadBalancer{
environment: environment,
}
}
func (v *RedisLoadBalancer) Initialize(ctx context.Context) error {
func (v *redisLoadBalancer) Initialize(ctx context.Context) error {
redisDatabase, err := strconv.Atoi(v.environment.RedisDB())
if err != nil {
return errors.Wrapf(err, "invalid PROXY_REDIS_DB %v", v.environment.RedisDB())
@ -52,7 +52,7 @@ func (v *RedisLoadBalancer) Initialize(ctx context.Context) error {
}
logger.Debug(ctx, "RedisLB: connected to redis %v ok", rdb.String())
server, err := NewDefaultSRSForDebugging(v.environment)
server, err := NewDefaultOriginServerForDebugging(v.environment)
if err != nil {
return errors.Wrapf(err, "initialize default SRS")
}
@ -80,7 +80,7 @@ func (v *RedisLoadBalancer) Initialize(ctx context.Context) error {
return nil
}
func (v *RedisLoadBalancer) Update(ctx context.Context, server *OriginServer) error {
func (v *redisLoadBalancer) Update(ctx context.Context, server *OriginServer) error {
b, err := json.Marshal(server)
if err != nil {
return errors.Wrapf(err, "marshal server %+v", server)
@ -130,7 +130,7 @@ func (v *RedisLoadBalancer) Update(ctx context.Context, server *OriginServer) er
return nil
}
func (v *RedisLoadBalancer) Pick(ctx context.Context, streamURL string) (*OriginServer, error) {
func (v *redisLoadBalancer) Pick(ctx context.Context, streamURL string) (*OriginServer, error) {
key := fmt.Sprintf("srs-proxy-url:%v", streamURL)
// Always proxy to the same server for the same stream URL.
@ -188,7 +188,7 @@ func (v *RedisLoadBalancer) Pick(ctx context.Context, streamURL string) (*Origin
return &server, nil
}
func (v *RedisLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid string) (HLSPlayStream, error) {
func (v *redisLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid string) (HLSPlayStream, error) {
key := v.redisKeySPBHID(spbhid)
b, err := v.rdb.Get(ctx, key).Bytes()
@ -208,7 +208,7 @@ func (v *RedisLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid string)
return nil, errors.Errorf("Redis load balancer cannot deserialize interface types")
}
func (v *RedisLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL string, value HLSPlayStream) (HLSPlayStream, error) {
func (v *redisLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL string, value HLSPlayStream) (HLSPlayStream, error) {
b, err := json.Marshal(value)
if err != nil {
return nil, errors.Wrapf(err, "marshal HLS %v", value)
@ -229,7 +229,7 @@ func (v *RedisLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL string
return value, nil
}
func (v *RedisLoadBalancer) StoreWebRTC(ctx context.Context, streamURL string, value RTCConnection) error {
func (v *redisLoadBalancer) StoreWebRTC(ctx context.Context, streamURL string, value RTCConnection) error {
b, err := json.Marshal(value)
if err != nil {
return errors.Wrapf(err, "marshal WebRTC %v", value)
@ -249,7 +249,7 @@ func (v *RedisLoadBalancer) StoreWebRTC(ctx context.Context, streamURL string, v
return nil
}
func (v *RedisLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag string) (RTCConnection, error) {
func (v *redisLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag string) (RTCConnection, error) {
key := v.redisKeyUfrag(ufrag)
b, err := v.rdb.Get(ctx, key).Bytes()
@ -267,26 +267,26 @@ func (v *RedisLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag string)
return nil, errors.Errorf("Redis load balancer cannot deserialize interface types")
}
func (v *RedisLoadBalancer) redisKeyUfrag(ufrag string) string {
func (v *redisLoadBalancer) redisKeyUfrag(ufrag string) string {
return fmt.Sprintf("srs-proxy-ufrag:%v", ufrag)
}
func (v *RedisLoadBalancer) redisKeyRTC(streamURL string) string {
func (v *redisLoadBalancer) redisKeyRTC(streamURL string) string {
return fmt.Sprintf("srs-proxy-rtc:%v", streamURL)
}
func (v *RedisLoadBalancer) redisKeySPBHID(spbhid string) string {
func (v *redisLoadBalancer) redisKeySPBHID(spbhid string) string {
return fmt.Sprintf("srs-proxy-spbhid:%v", spbhid)
}
func (v *RedisLoadBalancer) redisKeyHLS(streamURL string) string {
func (v *redisLoadBalancer) redisKeyHLS(streamURL string) string {
return fmt.Sprintf("srs-proxy-hls:%v", streamURL)
}
func (v *RedisLoadBalancer) redisKeyServer(serverID string) string {
func (v *redisLoadBalancer) redisKeyServer(serverID string) string {
return fmt.Sprintf("srs-proxy-server:%v", serverID)
}
func (v *RedisLoadBalancer) redisKeyServers() string {
func (v *redisLoadBalancer) redisKeyServers() string {
return fmt.Sprintf("srs-proxy-all-servers")
}

View File

@ -147,6 +147,8 @@ func (v *httpAPIProxyServer) Run(ctx context.Context) error {
type systemAPI struct {
// The environment interface.
environment env.ProxyEnvironment
// The load balancer for origin servers.
loadBalancer lb.OriginLoadBalancer
// The underlayer HTTP server.
server *http.Server
// The gracefully quit timeout, wait server to quit.
@ -155,9 +157,10 @@ type systemAPI struct {
wg sync.WaitGroup
}
func NewSystemAPI(environment env.ProxyEnvironment, gracefulQuitTimeout time.Duration) *systemAPI {
func NewSystemAPI(environment env.ProxyEnvironment, loadBalancer lb.OriginLoadBalancer, gracefulQuitTimeout time.Duration) *systemAPI {
v := &systemAPI{
environment: environment,
loadBalancer: loadBalancer,
gracefulQuitTimeout: gracefulQuitTimeout,
}
return v
@ -262,7 +265,7 @@ func (v *systemAPI) Run(ctx context.Context) error {
srs.SRT, srs.RTC = srt, rtc
srs.UpdatedAt = time.Now()
})
if err := lb.SrsLoadBalancer.Update(ctx, server); err != nil {
if err := v.loadBalancer.Update(ctx, server); err != nil {
return errors.Wrapf(err, "update SRS server %+v", server)
}

View File

@ -34,6 +34,8 @@ type HTTPStreamProxyServer interface {
type httpStreamProxyServer struct {
// The environment interface.
environment env.ProxyEnvironment
// The load balancer for origin servers.
loadBalancer lb.OriginLoadBalancer
// The underlayer HTTP server.
server *http.Server
// The gracefully quit timeout, wait server to quit.
@ -42,9 +44,10 @@ type httpStreamProxyServer struct {
wg stdSync.WaitGroup
}
func NewHTTPStreamProxyServer(environment env.ProxyEnvironment, gracefulQuitTimeout time.Duration) HTTPStreamProxyServer {
func NewHTTPStreamProxyServer(environment env.ProxyEnvironment, loadBalancer lb.OriginLoadBalancer, gracefulQuitTimeout time.Duration) HTTPStreamProxyServer {
v := &httpStreamProxyServer{
environment: environment,
loadBalancer: loadBalancer,
gracefulQuitTimeout: gracefulQuitTimeout,
}
return v
@ -128,7 +131,8 @@ func (v *httpStreamProxyServer) Run(ctx context.Context) error {
return
}
stream, _ := lb.SrsLoadBalancer.LoadOrStoreHLS(ctx, streamURL, newHLSPlayStream(func(s *hlsPlayStream) {
stream, _ := v.loadBalancer.LoadOrStoreHLS(ctx, streamURL, newHLSPlayStream(func(s *hlsPlayStream) {
s.loadBalancer = v.loadBalancer
s.SRSProxyBackendHLSID = logger.GenerateContextID()
s.StreamURL, s.FullURL = streamURL, fullURL
}))
@ -142,7 +146,7 @@ func (v *httpStreamProxyServer) Run(ctx context.Context) error {
strings.HasSuffix(r.URL.Path, ".ts") {
// If SPBHID is specified, it must be a HLS stream client.
if srsProxyBackendID := r.URL.Query().Get("spbhid"); srsProxyBackendID != "" {
if stream, err := lb.SrsLoadBalancer.LoadHLSBySPBHID(ctx, srsProxyBackendID); err != nil {
if stream, err := v.loadBalancer.LoadHLSBySPBHID(ctx, srsProxyBackendID); err != nil {
http.Error(w, fmt.Sprintf("load stream by spbhid %v", srsProxyBackendID), http.StatusBadRequest)
} else {
stream.Initialize(ctx).(*hlsPlayStream).ServeHTTP(w, r)
@ -153,6 +157,7 @@ func (v *httpStreamProxyServer) Run(ctx context.Context) error {
// Use HTTP pseudo streaming to proxy the request.
newHTTPFlvTsConnection(func(c *httpFlvTsConnection) {
c.ctx = ctx
c.loadBalancer = v.loadBalancer
}).ServeHTTP(w, r)
return
}
@ -196,6 +201,8 @@ func (v *httpStreamProxyServer) Run(ctx context.Context) error {
type httpFlvTsConnection struct {
// The context for HTTP streaming.
ctx context.Context
// The load balancer for origin servers.
loadBalancer lb.OriginLoadBalancer
}
func newHTTPFlvTsConnection(opts ...func(*httpFlvTsConnection)) *httpFlvTsConnection {
@ -233,7 +240,7 @@ func (v *httpFlvTsConnection) serve(ctx context.Context, w http.ResponseWriter,
}
// Pick a backend SRS server to proxy the RTMP stream.
backend, err := lb.SrsLoadBalancer.Pick(ctx, streamURL)
backend, err := v.loadBalancer.Pick(ctx, streamURL)
if err != nil {
return errors.Wrapf(err, "pick backend for %v", streamURL)
}
@ -303,6 +310,8 @@ func (v *httpFlvTsConnection) serveByBackend(ctx context.Context, w http.Respons
type hlsPlayStream struct {
// The context for HLS streaming.
ctx context.Context
// The load balancer for origin servers.
loadBalancer lb.OriginLoadBalancer
// The spbhid, used to identify the backend server.
SRSProxyBackendHLSID string `json:"spbhid"`
@ -351,7 +360,7 @@ func (v *hlsPlayStream) serve(ctx context.Context, w http.ResponseWriter, r *htt
}
// Pick a backend SRS server to proxy the RTMP stream.
backend, err := lb.SrsLoadBalancer.Pick(ctx, streamURL)
backend, err := v.loadBalancer.Pick(ctx, streamURL)
if err != nil {
return errors.Wrapf(err, "pick backend for %v", streamURL)
}

View File

@ -36,6 +36,8 @@ type WebRTCProxyServer interface {
type webRTCProxyServer struct {
// The environment interface.
environment env.ProxyEnvironment
// The load balancer for origin servers.
loadBalancer lb.OriginLoadBalancer
// The UDP listener for WebRTC server.
listener *net.UDPConn
@ -51,11 +53,12 @@ type webRTCProxyServer struct {
wg stdSync.WaitGroup
}
func NewWebRTCProxyServer(environment env.ProxyEnvironment, opts ...func(*webRTCProxyServer)) WebRTCProxyServer {
func NewWebRTCProxyServer(environment env.ProxyEnvironment, loadBalancer lb.OriginLoadBalancer, opts ...func(*webRTCProxyServer)) WebRTCProxyServer {
v := &webRTCProxyServer{
environment: environment,
usernames: sync.NewMap[string, *rtcConnection](),
addresses: sync.NewMap[string, *rtcConnection](),
environment: environment,
loadBalancer: loadBalancer,
usernames: sync.NewMap[string, *rtcConnection](),
addresses: sync.NewMap[string, *rtcConnection](),
}
for _, opt := range opts {
opt(v)
@ -97,7 +100,7 @@ func (v *webRTCProxyServer) HandleApiForWHIP(ctx context.Context, w http.Respons
}
// Pick a backend SRS server to proxy the RTMP stream.
backend, err := lb.SrsLoadBalancer.Pick(ctx, streamURL)
backend, err := v.loadBalancer.Pick(ctx, streamURL)
if err != nil {
return errors.Wrapf(err, "pick backend for %v", streamURL)
}
@ -134,7 +137,7 @@ func (v *webRTCProxyServer) HandleApiForWHEP(ctx context.Context, w http.Respons
}
// Pick a backend SRS server to proxy the RTMP stream.
backend, err := lb.SrsLoadBalancer.Pick(ctx, streamURL)
backend, err := v.loadBalancer.Pick(ctx, streamURL)
if err != nil {
return errors.Wrapf(err, "pick backend for %v", streamURL)
}
@ -226,7 +229,8 @@ func (v *webRTCProxyServer) proxyApiToBackend(
RemoteICEUfrag: remoteICEUfrag, RemoteICEPwd: remoteICEPwd,
LocalICEUfrag: localICEUfrag, LocalICEPwd: localICEPwd,
}
if err := lb.SrsLoadBalancer.StoreWebRTC(ctx, streamURL, newRTCConnection(func(c *rtcConnection) {
if err := v.loadBalancer.StoreWebRTC(ctx, streamURL, newRTCConnection(func(c *rtcConnection) {
c.loadBalancer = v.loadBalancer
c.StreamURL, c.Ufrag = streamURL, icePair.Ufrag()
c.Initialize(ctx, v.listener)
@ -315,10 +319,11 @@ func (v *webRTCProxyServer) handleClientUDP(ctx context.Context, addr *net.UDPAd
}
// Load connection by username.
if s, err := lb.SrsLoadBalancer.LoadWebRTCByUfrag(ctx, pkt.Username); err != nil {
if s, err := v.loadBalancer.LoadWebRTCByUfrag(ctx, pkt.Username); err != nil {
return errors.Wrapf(err, "load webrtc by ufrag %v", pkt.Username)
} else {
connection = s.(*rtcConnection).Initialize(ctx, v.listener)
connection.loadBalancer = v.loadBalancer
logger.Debug(ctx, "Create WebRTC connection by ufrag=%v, stream=%v", pkt.Username, connection.StreamURL)
}
@ -366,6 +371,8 @@ func (v *webRTCProxyServer) handleClientUDP(ctx context.Context, addr *net.UDPAd
type rtcConnection struct {
// The stream context for WebRTC streaming.
ctx context.Context
// The load balancer for origin servers.
loadBalancer lb.OriginLoadBalancer
// The stream URL in vhost/app/stream schema.
StreamURL string `json:"stream_url"`
@ -450,7 +457,7 @@ func (v *rtcConnection) connectBackend(ctx context.Context) error {
}
// Pick a backend SRS server to proxy the RTC stream.
backend, err := lb.SrsLoadBalancer.Pick(ctx, v.StreamURL)
backend, err := v.loadBalancer.Pick(ctx, v.StreamURL)
if err != nil {
return errors.Wrapf(err, "pick backend")
}

View File

@ -31,14 +31,16 @@ type RTMPProxyServer interface {
type rtmpProxyServer struct {
// The environment interface.
environment env.ProxyEnvironment
// The load balancer for origin servers.
loadBalancer lb.OriginLoadBalancer
// The TCP listener for RTMP server.
listener *net.TCPListener
// The wait group for all goroutines.
wg sync.WaitGroup
}
func NewRTMPProxyServer(environment env.ProxyEnvironment, opts ...func(*rtmpProxyServer)) RTMPProxyServer {
v := &rtmpProxyServer{environment: environment}
func NewRTMPProxyServer(environment env.ProxyEnvironment, loadBalancer lb.OriginLoadBalancer, opts ...func(*rtmpProxyServer)) RTMPProxyServer {
v := &rtmpProxyServer{environment: environment, loadBalancer: loadBalancer}
for _, opt := range opts {
opt(v)
}
@ -102,7 +104,9 @@ func (v *rtmpProxyServer) Run(ctx context.Context) error {
}
}
rc := newRTMPConnection()
rc := newRTMPConnection(func(c *rtmpConnection) {
c.loadBalancer = v.loadBalancer
})
if err := rc.serve(ctx, conn); err != nil {
handleErr(err)
} else {
@ -122,6 +126,8 @@ func (v *rtmpProxyServer) Run(ctx context.Context) error {
// then proxy to the corresponding backend server. All state is in the RTMP request, so this
// connection is stateless.
type rtmpConnection struct {
// The load balancer for origin servers.
loadBalancer lb.OriginLoadBalancer
}
func newRTMPConnection(opts ...func(*rtmpConnection)) *rtmpConnection {
@ -296,6 +302,7 @@ func (v *rtmpConnection) serve(ctx context.Context, conn *net.TCPConn) error {
// Find a backend SRS server to proxy the RTMP stream.
backend = newRTMPClientToBackend(func(client *rtmpClientToBackend) {
client.typ = clientType
client.loadBalancer = v.loadBalancer
})
defer backend.Close()
@ -429,6 +436,8 @@ type rtmpClientToBackend struct {
client rtmp.Protocol
// The stream type.
typ RTMPClientType
// The load balancer for origin servers.
loadBalancer lb.OriginLoadBalancer
}
func newRTMPClientToBackend(opts ...func(*rtmpClientToBackend)) *rtmpClientToBackend {
@ -454,7 +463,7 @@ func (v *rtmpClientToBackend) Connect(ctx context.Context, tcUrl, streamName str
}
// Pick a backend SRS server to proxy the RTMP stream.
backend, err := lb.SrsLoadBalancer.Pick(ctx, streamURL)
backend, err := v.loadBalancer.Pick(ctx, streamURL)
if err != nil {
return errors.Wrapf(err, "pick backend for %v", streamURL)
}

View File

@ -27,6 +27,8 @@ import (
type srsSRTProxyServer struct {
// The environment interface.
environment env.ProxyEnvironment
// The load balancer for origin servers.
loadBalancer lb.OriginLoadBalancer
// The UDP listener for SRT server.
listener *net.UDPConn
@ -39,11 +41,12 @@ type srsSRTProxyServer struct {
wg stdSync.WaitGroup
}
func NewSRSSRTProxyServer(environment env.ProxyEnvironment, opts ...func(*srsSRTProxyServer)) *srsSRTProxyServer {
func NewSRSSRTProxyServer(environment env.ProxyEnvironment, loadBalancer lb.OriginLoadBalancer, opts ...func(*srsSRTProxyServer)) *srsSRTProxyServer {
v := &srsSRTProxyServer{
environment: environment,
start: time.Now(),
sockets: sync.NewMap[uint32, *SRTConnection](),
environment: environment,
loadBalancer: loadBalancer,
start: time.Now(),
sockets: sync.NewMap[uint32, *SRTConnection](),
}
for _, opt := range opts {
@ -127,6 +130,7 @@ func (v *srsSRTProxyServer) handleClientUDP(ctx context.Context, addr *net.UDPAd
conn, ok := v.sockets.LoadOrStore(socketID, NewSRTConnection(func(c *SRTConnection) {
c.ctx = logger.WithContext(ctx)
c.listenerUDP, c.socketID = v.listener, socketID
c.loadBalancer = v.loadBalancer
c.start = v.start
}))
@ -158,6 +162,8 @@ func (v *srsSRTProxyServer) handleClientUDP(ctx context.Context, addr *net.UDPAd
type SRTConnection struct {
// The stream context for SRT connection.
ctx context.Context
// The load balancer for origin servers.
loadBalancer lb.OriginLoadBalancer
// The current socket ID.
socketID uint32
@ -356,7 +362,7 @@ func (v *SRTConnection) connectBackend(ctx context.Context, streamID string) err
}
// Pick a backend SRS server to proxy the SRT stream.
backend, err := lb.SrsLoadBalancer.Pick(ctx, streamURL)
backend, err := v.loadBalancer.Pick(ctx, streamURL)
if err != nil {
return errors.Wrapf(err, "pick backend for %v", streamURL)
}