- Refactor the Go proxy for dependency injection: every proxy server, the bootstrap, the signal handler, the load balancers, and AMF0 now accept functional-option seams (factories/closures) so tests can inject fakes without binding real sockets, talking to real Redis, or racing on package globals. - Drop the package-global `lb.SrsLoadBalancer`. The bootstrap creates the LB locally and threads it through every proxy server constructor. Two old global indirections in `internal/signal` and `internal/rtmp/amf0` are likewise replaced by per-instance fields. - Rename `internal/server` → `internal/proxy` and rename the `lb` public surface for clarity: `SRSLoadBalancer` is split into `OriginService` / `HLSService` / `RTCService` and recomposed as `OriginLoadBalancer`; `SRSServer` → `OriginServer`; all proxy server types gain a `Proxy` qualifier (e.g. `RTMPServer` → `RTMPProxyServer`). - Extract the Redis client behind a new `internal/redisclient` package with a minimal `RedisClient` interface and a counterfeiter fake. - Add counterfeiter fakes (`proxyfakes`, `lbfakes`, `redisclientfakes`) and ~7.5k lines of unit tests covering bootstrap, memory + Redis LBs, all five proxy servers, the signal handler, and AMF0. - Add two new E2E flows — `proxy-e2e-srt-test.sh` (SRT publish through proxy, verify SRT/RTMP/HTTP-FLV/HLS playback) and `proxy-e2e-whip-test.sh` (WHIP publish, verify RTMP/HTTP-FLV/HLS via origin `rtc_to_rtmp`) — plus `setup-ffmpeg-with-whip.sh`, a macOS builder for an ffmpeg with openssl-DTLS WHIP and SRT support that the two scripts auto-invoke when needed. - Workspace reorg: move `memory/` and `skills/` to the repo root so all agent tools (Claude / Codex / Kiro / OpenClaw) share one source of truth via symlinks. Sync `docs/proxy/proxy-load-balancer.md` and `memory/srs-codebase-map.md` with the new names. No protocol, log, HTTP API, or wire-format changes. Refactor only — all externally observable proxy behavior is unchanged. --------- Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Co-authored-by: chatgpt-codex-connector[bot] <199175422+chatgpt-codex-connector[bot]@users.noreply.github.com>
300 lines
9.2 KiB
Go
300 lines
9.2 KiB
Go
// Copyright (c) 2026 Winlin
|
|
//
|
|
// SPDX-License-Identifier: MIT
|
|
package lb
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"math/rand"
|
|
"strconv"
|
|
"time"
|
|
|
|
"srsx/internal/env"
|
|
"srsx/internal/errors"
|
|
"srsx/internal/logger"
|
|
"srsx/internal/redisclient"
|
|
)
|
|
|
|
// redisLoadBalancer stores state in Redis.
|
|
type redisLoadBalancer struct {
|
|
// The environment interface.
|
|
environment env.ProxyEnvironment
|
|
// The redis client.
|
|
rdb redisclient.RedisClient
|
|
// newClient is the factory used by Initialize to build the Redis client.
|
|
// A struct field (rather than a package global) so concurrent tests can
|
|
// each supply their own without racing on shared state.
|
|
newClient func(addr, password string, db int) redisclient.RedisClient
|
|
// keepaliveInterval is the period at which the default-backend keep-alive
|
|
// goroutine re-Updates its registration. Struct field for test injection.
|
|
keepaliveInterval time.Duration
|
|
}
|
|
|
|
// NewRedisLoadBalancer creates a new Redis-based load balancer.
|
|
func NewRedisLoadBalancer(environment env.ProxyEnvironment) OriginLoadBalancer {
|
|
return &redisLoadBalancer{
|
|
environment: environment,
|
|
newClient: redisclient.New,
|
|
keepaliveInterval: 30 * time.Second,
|
|
}
|
|
}
|
|
|
|
func (v *redisLoadBalancer) Initialize(ctx context.Context) error {
|
|
redisDatabase, err := strconv.Atoi(v.environment.RedisDB())
|
|
if err != nil {
|
|
return errors.Wrapf(err, "invalid PROXY_REDIS_DB %v", v.environment.RedisDB())
|
|
}
|
|
|
|
rdb := v.newClient(
|
|
fmt.Sprintf("%v:%v", v.environment.RedisHost(), v.environment.RedisPort()),
|
|
v.environment.RedisPassword(),
|
|
redisDatabase,
|
|
)
|
|
v.rdb = rdb
|
|
|
|
if err := rdb.Ping(ctx).Err(); err != nil {
|
|
return errors.Wrapf(err, "unable to connect to redis %v", rdb.String())
|
|
}
|
|
logger.Debug(ctx, "RedisLB: connected to redis %v ok", rdb.String())
|
|
|
|
server, err := NewDefaultOriginServerForDebugging(v.environment)
|
|
if err != nil {
|
|
return errors.Wrapf(err, "initialize default SRS")
|
|
}
|
|
|
|
if server != nil {
|
|
if err := v.Update(ctx, server); err != nil {
|
|
return errors.Wrapf(err, "update default SRS %+v", server)
|
|
}
|
|
|
|
// Keep alive.
|
|
go func() {
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-time.After(v.keepaliveInterval):
|
|
if err := v.Update(ctx, server); err != nil {
|
|
logger.Warn(ctx, "update default SRS %+v failed, %+v", server, err)
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
logger.Debug(ctx, "RedisLB: Initialize default SRS media server, %+v", server)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (v *redisLoadBalancer) Update(ctx context.Context, server *OriginServer) error {
|
|
b, err := json.Marshal(server)
|
|
if err != nil {
|
|
return errors.Wrapf(err, "marshal server %+v", server)
|
|
}
|
|
|
|
key := v.redisKeyServer(server.ID())
|
|
if err = v.rdb.Set(ctx, key, b, ServerAliveDuration).Err(); err != nil {
|
|
return errors.Wrapf(err, "set key=%v server %+v", key, server)
|
|
}
|
|
|
|
// Query all servers from redis, in json string.
|
|
var serverKeys []string
|
|
if b, err := v.rdb.Get(ctx, v.redisKeyServers()).Bytes(); err == nil {
|
|
if err := json.Unmarshal(b, &serverKeys); err != nil {
|
|
return errors.Wrapf(err, "unmarshal key=%v servers %v", v.redisKeyServers(), string(b))
|
|
}
|
|
}
|
|
|
|
// Check each server expiration, if not exists in redis, remove from servers.
|
|
for i := len(serverKeys) - 1; i >= 0; i-- {
|
|
if _, err := v.rdb.Get(ctx, serverKeys[i]).Bytes(); err != nil {
|
|
serverKeys = append(serverKeys[:i], serverKeys[i+1:]...)
|
|
}
|
|
}
|
|
|
|
// Add server to servers if not exists.
|
|
var found bool
|
|
for _, serverKey := range serverKeys {
|
|
if serverKey == key {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
serverKeys = append(serverKeys, key)
|
|
}
|
|
|
|
// Update all servers to redis.
|
|
b, err = json.Marshal(serverKeys)
|
|
if err != nil {
|
|
return errors.Wrapf(err, "marshal servers %+v", serverKeys)
|
|
}
|
|
if err = v.rdb.Set(ctx, v.redisKeyServers(), b, 0).Err(); err != nil {
|
|
return errors.Wrapf(err, "set key=%v servers %+v", v.redisKeyServers(), serverKeys)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (v *redisLoadBalancer) Pick(ctx context.Context, streamURL string) (*OriginServer, error) {
|
|
key := fmt.Sprintf("srs-proxy-url:%v", streamURL)
|
|
|
|
// Always proxy to the same server for the same stream URL.
|
|
if serverKey, err := v.rdb.Get(ctx, key).Result(); err == nil {
|
|
// If server not exists, ignore and pick another server for the stream URL.
|
|
if b, err := v.rdb.Get(ctx, serverKey).Bytes(); err == nil && len(b) > 0 {
|
|
var server OriginServer
|
|
if err := json.Unmarshal(b, &server); err != nil {
|
|
return nil, errors.Wrapf(err, "unmarshal key=%v server %v", key, string(b))
|
|
}
|
|
|
|
// TODO: If server fail, we should migrate the streams to another server.
|
|
return &server, nil
|
|
}
|
|
}
|
|
|
|
// Query all servers from redis, in json string.
|
|
var serverKeys []string
|
|
if b, err := v.rdb.Get(ctx, v.redisKeyServers()).Bytes(); err == nil {
|
|
if err := json.Unmarshal(b, &serverKeys); err != nil {
|
|
return nil, errors.Wrapf(err, "unmarshal key=%v servers %v", v.redisKeyServers(), string(b))
|
|
}
|
|
}
|
|
|
|
// No server found, failed.
|
|
if len(serverKeys) == 0 {
|
|
return nil, fmt.Errorf("no server available for %v", streamURL)
|
|
}
|
|
|
|
// All server should be alive, if not, should have been removed by redis. So we only
|
|
// random pick one that is always available. Use global rand which is thread-safe since Go 1.20.
|
|
var serverKey string
|
|
var server OriginServer
|
|
for i := 0; i < 3; i++ {
|
|
tryServerKey := serverKeys[rand.Intn(len(serverKeys))]
|
|
b, err := v.rdb.Get(ctx, tryServerKey).Bytes()
|
|
if err == nil && len(b) > 0 {
|
|
if err := json.Unmarshal(b, &server); err != nil {
|
|
return nil, errors.Wrapf(err, "unmarshal key=%v server %v", serverKey, string(b))
|
|
}
|
|
|
|
serverKey = tryServerKey
|
|
break
|
|
}
|
|
}
|
|
if serverKey == "" {
|
|
return nil, errors.Errorf("no server available in %v for %v", serverKeys, streamURL)
|
|
}
|
|
|
|
// Update the picked server for the stream URL.
|
|
if err := v.rdb.Set(ctx, key, []byte(serverKey), 0).Err(); err != nil {
|
|
return nil, errors.Wrapf(err, "set key=%v server %v", key, serverKey)
|
|
}
|
|
|
|
return &server, nil
|
|
}
|
|
|
|
func (v *redisLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid string) (HLSPlayStream, error) {
|
|
key := v.redisKeySPBHID(spbhid)
|
|
|
|
b, err := v.rdb.Get(ctx, key).Bytes()
|
|
if err != nil {
|
|
return nil, errors.Wrapf(err, "get key=%v HLS", key)
|
|
}
|
|
|
|
// Store the raw JSON bytes that will be unmarshaled by the concrete type
|
|
// The caller will need to handle the deserialization
|
|
var actual map[string]interface{}
|
|
if err := json.Unmarshal(b, &actual); err != nil {
|
|
return nil, errors.Wrapf(err, "unmarshal key=%v HLS %v", key, string(b))
|
|
}
|
|
|
|
// Return nil for now - Redis LB needs the concrete type to properly deserialize
|
|
// This is a limitation of using Redis with interfaces
|
|
return nil, errors.Errorf("Redis load balancer cannot deserialize interface types")
|
|
}
|
|
|
|
func (v *redisLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL string, value HLSPlayStream) (HLSPlayStream, error) {
|
|
b, err := json.Marshal(value)
|
|
if err != nil {
|
|
return nil, errors.Wrapf(err, "marshal HLS %v", value)
|
|
}
|
|
|
|
key := v.redisKeyHLS(streamURL)
|
|
if err = v.rdb.Set(ctx, key, b, HLSAliveDuration).Err(); err != nil {
|
|
return nil, errors.Wrapf(err, "set key=%v HLS %v", key, value)
|
|
}
|
|
|
|
// Get SPBHID from value
|
|
key2 := v.redisKeySPBHID(value.GetSPBHID())
|
|
if err := v.rdb.Set(ctx, key2, b, HLSAliveDuration).Err(); err != nil {
|
|
return nil, errors.Wrapf(err, "set key=%v HLS %v", key2, value)
|
|
}
|
|
|
|
// Return the same value since we just stored it
|
|
return value, nil
|
|
}
|
|
|
|
func (v *redisLoadBalancer) StoreWebRTC(ctx context.Context, streamURL string, value RTCConnection) error {
|
|
b, err := json.Marshal(value)
|
|
if err != nil {
|
|
return errors.Wrapf(err, "marshal WebRTC %v", value)
|
|
}
|
|
|
|
key := v.redisKeyRTC(streamURL)
|
|
if err = v.rdb.Set(ctx, key, b, RTCAliveDuration).Err(); err != nil {
|
|
return errors.Wrapf(err, "set key=%v WebRTC %v", key, value)
|
|
}
|
|
|
|
// Get Ufrag from value
|
|
key2 := v.redisKeyUfrag(value.GetUfrag())
|
|
if err := v.rdb.Set(ctx, key2, b, RTCAliveDuration).Err(); err != nil {
|
|
return errors.Wrapf(err, "set key=%v WebRTC %v", key2, value)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (v *redisLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag string) (RTCConnection, error) {
|
|
key := v.redisKeyUfrag(ufrag)
|
|
|
|
b, err := v.rdb.Get(ctx, key).Bytes()
|
|
if err != nil {
|
|
return nil, errors.Wrapf(err, "get key=%v WebRTC", key)
|
|
}
|
|
|
|
// Return nil for now - Redis LB needs the concrete type to properly deserialize
|
|
// This is a limitation of using Redis with interfaces
|
|
var actual map[string]interface{}
|
|
if err := json.Unmarshal(b, &actual); err != nil {
|
|
return nil, errors.Wrapf(err, "unmarshal key=%v WebRTC %v", key, string(b))
|
|
}
|
|
|
|
return nil, errors.Errorf("Redis load balancer cannot deserialize interface types")
|
|
}
|
|
|
|
func (v *redisLoadBalancer) redisKeyUfrag(ufrag string) string {
|
|
return fmt.Sprintf("srs-proxy-ufrag:%v", ufrag)
|
|
}
|
|
|
|
func (v *redisLoadBalancer) redisKeyRTC(streamURL string) string {
|
|
return fmt.Sprintf("srs-proxy-rtc:%v", streamURL)
|
|
}
|
|
|
|
func (v *redisLoadBalancer) redisKeySPBHID(spbhid string) string {
|
|
return fmt.Sprintf("srs-proxy-spbhid:%v", spbhid)
|
|
}
|
|
|
|
func (v *redisLoadBalancer) redisKeyHLS(streamURL string) string {
|
|
return fmt.Sprintf("srs-proxy-hls:%v", streamURL)
|
|
}
|
|
|
|
func (v *redisLoadBalancer) redisKeyServer(serverID string) string {
|
|
return fmt.Sprintf("srs-proxy-server:%v", serverID)
|
|
}
|
|
|
|
func (v *redisLoadBalancer) redisKeyServers() string {
|
|
return fmt.Sprintf("srs-proxy-all-servers")
|
|
}
|