diff --git a/proxy/.gitignore b/proxy/.gitignore deleted file mode 100644 index c20f4b678..000000000 --- a/proxy/.gitignore +++ /dev/null @@ -1,4 +0,0 @@ -.idea -srs-proxy -.env -.go-formarted \ No newline at end of file diff --git a/proxy/Makefile b/proxy/Makefile deleted file mode 100644 index 29084d5b7..000000000 --- a/proxy/Makefile +++ /dev/null @@ -1,23 +0,0 @@ -.PHONY: all build test fmt clean run - -all: build - -build: fmt ./srs-proxy - -./srs-proxy: *.go - go build -o srs-proxy . - -test: - go test ./... - -fmt: ./.go-formarted - -./.go-formarted: *.go - touch .go-formarted - go fmt ./... - -clean: - rm -f srs-proxy .go-formarted - -run: fmt - go run . diff --git a/proxy/README.md b/proxy/README.md new file mode 100644 index 000000000..5b455e829 --- /dev/null +++ b/proxy/README.md @@ -0,0 +1,6 @@ +# Proxy + +Migrated to below repositoties: + +* [proxy-go](https://github.com/ossrs/proxy-go) An common proxy server for any media servers with RTMP/SRT/HLS/HTTP-FLV and WebRTC/WHIP/WHEP protocols support. + diff --git a/proxy/api.go b/proxy/api.go deleted file mode 100644 index 9b08c8fe2..000000000 --- a/proxy/api.go +++ /dev/null @@ -1,272 +0,0 @@ -// Copyright (c) 2025 Winlin -// -// SPDX-License-Identifier: MIT -package main - -import ( - "context" - "fmt" - "net/http" - "os" - "strings" - "sync" - "time" - - "srs-proxy/errors" - "srs-proxy/logger" -) - -// srsHTTPAPIServer is the proxy for SRS HTTP API, to proxy the WebRTC HTTP API like WHIP and WHEP, -// to proxy other HTTP API of SRS like the streams and clients, etc. -type srsHTTPAPIServer struct { - // The underlayer HTTP server. - server *http.Server - // The WebRTC server. - rtc *srsWebRTCServer - // The gracefully quit timeout, wait server to quit. - gracefulQuitTimeout time.Duration - // The wait group for all goroutines. - wg sync.WaitGroup -} - -func NewSRSHTTPAPIServer(opts ...func(*srsHTTPAPIServer)) *srsHTTPAPIServer { - v := &srsHTTPAPIServer{} - for _, opt := range opts { - opt(v) - } - return v -} - -func (v *srsHTTPAPIServer) Close() error { - ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout) - defer cancel() - v.server.Shutdown(ctx) - - v.wg.Wait() - return nil -} - -func (v *srsHTTPAPIServer) Run(ctx context.Context) error { - // Parse address to listen. - addr := envHttpAPI() - if !strings.Contains(addr, ":") { - addr = ":" + addr - } - - // Create server and handler. - mux := http.NewServeMux() - v.server = &http.Server{Addr: addr, Handler: mux} - logger.Df(ctx, "HTTP API server listen at %v", addr) - - // Shutdown the server gracefully when quiting. - go func() { - ctxParent := ctx - <-ctxParent.Done() - - ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout) - defer cancel() - - v.server.Shutdown(ctx) - }() - - // The basic version handler, also can be used as health check API. - logger.Df(ctx, "Handle /api/v1/versions by %v", addr) - mux.HandleFunc("/api/v1/versions", func(w http.ResponseWriter, r *http.Request) { - apiResponse(ctx, w, r, map[string]string{ - "signature": Signature(), - "version": Version(), - }) - }) - - // The WebRTC WHIP API handler. - logger.Df(ctx, "Handle /rtc/v1/whip/ by %v", addr) - mux.HandleFunc("/rtc/v1/whip/", func(w http.ResponseWriter, r *http.Request) { - if err := v.rtc.HandleApiForWHIP(ctx, w, r); err != nil { - apiError(ctx, w, r, err) - } - }) - - // The WebRTC WHEP API handler. - logger.Df(ctx, "Handle /rtc/v1/whep/ by %v", addr) - mux.HandleFunc("/rtc/v1/whep/", func(w http.ResponseWriter, r *http.Request) { - if err := v.rtc.HandleApiForWHEP(ctx, w, r); err != nil { - apiError(ctx, w, r, err) - } - }) - - // Run HTTP API server. - v.wg.Add(1) - go func() { - defer v.wg.Done() - - err := v.server.ListenAndServe() - if err != nil { - if ctx.Err() != context.Canceled { - // TODO: If HTTP API server closed unexpectedly, we should notice the main loop to quit. - logger.Wf(ctx, "HTTP API accept err %+v", err) - } else { - logger.Df(ctx, "HTTP API server done") - } - } - }() - - return nil -} - -// systemAPI is the system HTTP API of the proxy server, for SRS media server to register the service -// to proxy server. It also provides some other system APIs like the status of proxy server, like exporter -// for Prometheus metrics. -type systemAPI struct { - // The underlayer HTTP server. - server *http.Server - // The gracefully quit timeout, wait server to quit. - gracefulQuitTimeout time.Duration - // The wait group for all goroutines. - wg sync.WaitGroup -} - -func NewSystemAPI(opts ...func(*systemAPI)) *systemAPI { - v := &systemAPI{} - for _, opt := range opts { - opt(v) - } - return v -} - -func (v *systemAPI) Close() error { - ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout) - defer cancel() - v.server.Shutdown(ctx) - - v.wg.Wait() - return nil -} - -func (v *systemAPI) Run(ctx context.Context) error { - // Parse address to listen. - addr := envSystemAPI() - if !strings.Contains(addr, ":") { - addr = ":" + addr - } - - // Create server and handler. - mux := http.NewServeMux() - v.server = &http.Server{Addr: addr, Handler: mux} - logger.Df(ctx, "System API server listen at %v", addr) - - // Shutdown the server gracefully when quiting. - go func() { - ctxParent := ctx - <-ctxParent.Done() - - ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout) - defer cancel() - - v.server.Shutdown(ctx) - }() - - // The basic version handler, also can be used as health check API. - logger.Df(ctx, "Handle /api/v1/versions by %v", addr) - mux.HandleFunc("/api/v1/versions", func(w http.ResponseWriter, r *http.Request) { - apiResponse(ctx, w, r, map[string]string{ - "signature": Signature(), - "version": Version(), - }) - }) - - // The register service for SRS media servers. - logger.Df(ctx, "Handle /api/v1/srs/register by %v", addr) - mux.HandleFunc("/api/v1/srs/register", func(w http.ResponseWriter, r *http.Request) { - if err := func() error { - var deviceID, ip, serverID, serviceID, pid string - var rtmp, stream, api, srt, rtc []string - if err := ParseBody(r.Body, &struct { - // The IP of SRS, mandatory. - IP *string `json:"ip"` - // The server id of SRS, store in file, may not change, mandatory. - ServerID *string `json:"server"` - // The service id of SRS, always change when restarted, mandatory. - ServiceID *string `json:"service"` - // The process id of SRS, always change when restarted, mandatory. - PID *string `json:"pid"` - // The RTMP listen endpoints, mandatory. - RTMP *[]string `json:"rtmp"` - // The HTTP Stream listen endpoints, optional. - HTTP *[]string `json:"http"` - // The API listen endpoints, optional. - API *[]string `json:"api"` - // The SRT listen endpoints, optional. - SRT *[]string `json:"srt"` - // The RTC listen endpoints, optional. - RTC *[]string `json:"rtc"` - // The device id of SRS, optional. - DeviceID *string `json:"device_id"` - }{ - IP: &ip, DeviceID: &deviceID, - ServerID: &serverID, ServiceID: &serviceID, PID: &pid, - RTMP: &rtmp, HTTP: &stream, API: &api, SRT: &srt, RTC: &rtc, - }); err != nil { - return errors.Wrapf(err, "parse body") - } - - if ip == "" { - return errors.Errorf("empty ip") - } - if serverID == "" { - return errors.Errorf("empty server") - } - if serviceID == "" { - return errors.Errorf("empty service") - } - if pid == "" { - return errors.Errorf("empty pid") - } - if len(rtmp) == 0 { - return errors.Errorf("empty rtmp") - } - - server := NewSRSServer(func(srs *SRSServer) { - srs.IP, srs.DeviceID = ip, deviceID - srs.ServerID, srs.ServiceID, srs.PID = serverID, serviceID, pid - srs.RTMP, srs.HTTP, srs.API = rtmp, stream, api - srs.SRT, srs.RTC = srt, rtc - srs.UpdatedAt = time.Now() - }) - if err := srsLoadBalancer.Update(ctx, server); err != nil { - return errors.Wrapf(err, "update SRS server %+v", server) - } - - logger.Df(ctx, "Register SRS media server, %+v", server) - return nil - }(); err != nil { - apiError(ctx, w, r, err) - } - - type Response struct { - Code int `json:"code"` - PID string `json:"pid"` - } - - apiResponse(ctx, w, r, &Response{ - Code: 0, PID: fmt.Sprintf("%v", os.Getpid()), - }) - }) - - // Run System API server. - v.wg.Add(1) - go func() { - defer v.wg.Done() - - err := v.server.ListenAndServe() - if err != nil { - if ctx.Err() != context.Canceled { - // TODO: If System API server closed unexpectedly, we should notice the main loop to quit. - logger.Wf(ctx, "System API accept err %+v", err) - } else { - logger.Df(ctx, "System API server done") - } - } - }() - - return nil -} diff --git a/proxy/debug.go b/proxy/debug.go deleted file mode 100644 index e9cc7f98a..000000000 --- a/proxy/debug.go +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright (c) 2025 Winlin -// -// SPDX-License-Identifier: MIT -package main - -import ( - "context" - "net/http" - - "srs-proxy/logger" -) - -func handleGoPprof(ctx context.Context) { - if addr := envGoPprof(); addr != "" { - go func() { - logger.Df(ctx, "Start Go pprof at %v", addr) - http.ListenAndServe(addr, nil) - }() - } -} diff --git a/proxy/env.go b/proxy/env.go deleted file mode 100644 index 3bb8b0b4d..000000000 --- a/proxy/env.go +++ /dev/null @@ -1,226 +0,0 @@ -// Copyright (c) 2025 Winlin -// -// SPDX-License-Identifier: MIT -package main - -import ( - "context" - "io" - "os" - "path" - "strings" - - "srs-proxy/errors" - "srs-proxy/logger" -) - -// loadEnvFile loads the environment variables from file. Note that we only use .env file. -func loadEnvFile(ctx context.Context) error { - workDir, err := os.Getwd() - if err != nil { - return errors.Wrapf(err, "getpwd") - } - - envFile := path.Join(workDir, ".env") - if _, err := os.Stat(envFile); err != nil { - return nil - } - - file, err := os.Open(envFile) - if err != nil { - return errors.Wrapf(err, "open %v", envFile) - } - defer file.Close() - - b, err := io.ReadAll(file) - if err != nil { - return errors.Wrapf(err, "read %v", envFile) - } - - lines := strings.Split(strings.Replace(string(b), "\r\n", "\n", -1), "\n") - logger.Df(ctx, "load env file %v, lines=%v", envFile, len(lines)) - - for _, line := range lines { - if strings.HasPrefix(strings.TrimSpace(line), "#") { - continue - } - - if pos := strings.IndexByte(line, '='); pos > 0 { - key := strings.TrimSpace(line[:pos]) - value := strings.TrimSpace(line[pos+1:]) - if v := os.Getenv(key); v != "" { - continue - } - - os.Setenv(key, value) - } - } - - return nil -} - -// buildDefaultEnvironmentVariables setups the default environment variables. -func buildDefaultEnvironmentVariables(ctx context.Context) { - // Whether enable the Go pprof. - setEnvDefault("GO_PPROF", "") - // Force shutdown timeout. - setEnvDefault("PROXY_FORCE_QUIT_TIMEOUT", "30s") - // Graceful quit timeout. - setEnvDefault("PROXY_GRACE_QUIT_TIMEOUT", "20s") - - // The HTTP API server. - setEnvDefault("PROXY_HTTP_API", "11985") - // The HTTP web server. - setEnvDefault("PROXY_HTTP_SERVER", "18080") - // The RTMP media server. - setEnvDefault("PROXY_RTMP_SERVER", "11935") - // The WebRTC media server, via UDP protocol. - setEnvDefault("PROXY_WEBRTC_SERVER", "18000") - // The SRT media server, via UDP protocol. - setEnvDefault("PROXY_SRT_SERVER", "20080") - // The API server of proxy itself. - setEnvDefault("PROXY_SYSTEM_API", "12025") - // The static directory for web server. - setEnvDefault("PROXY_STATIC_FILES", "../trunk/research") - - // The load balancer, use redis or memory. - setEnvDefault("PROXY_LOAD_BALANCER_TYPE", "memory") - // The redis server host. - setEnvDefault("PROXY_REDIS_HOST", "127.0.0.1") - // The redis server port. - setEnvDefault("PROXY_REDIS_PORT", "6379") - // The redis server password. - setEnvDefault("PROXY_REDIS_PASSWORD", "") - // The redis server db. - setEnvDefault("PROXY_REDIS_DB", "0") - - // Whether enable the default backend server, for debugging. - setEnvDefault("PROXY_DEFAULT_BACKEND_ENABLED", "off") - // Default backend server IP, for debugging. - setEnvDefault("PROXY_DEFAULT_BACKEND_IP", "127.0.0.1") - // Default backend server port, for debugging. - setEnvDefault("PROXY_DEFAULT_BACKEND_RTMP", "1935") - // Default backend api port, for debugging. - setEnvDefault("PROXY_DEFAULT_BACKEND_API", "1985") - // Default backend udp rtc port, for debugging. - setEnvDefault("PROXY_DEFAULT_BACKEND_RTC", "8000") - // Default backend udp srt port, for debugging. - setEnvDefault("PROXY_DEFAULT_BACKEND_SRT", "10080") - - logger.Df(ctx, "load .env as GO_PPROF=%v, "+ - "PROXY_FORCE_QUIT_TIMEOUT=%v, PROXY_GRACE_QUIT_TIMEOUT=%v, "+ - "PROXY_HTTP_API=%v, PROXY_HTTP_SERVER=%v, PROXY_RTMP_SERVER=%v, "+ - "PROXY_WEBRTC_SERVER=%v, PROXY_SRT_SERVER=%v, "+ - "PROXY_SYSTEM_API=%v, PROXY_STATIC_FILES=%v, PROXY_DEFAULT_BACKEND_ENABLED=%v, "+ - "PROXY_DEFAULT_BACKEND_IP=%v, PROXY_DEFAULT_BACKEND_RTMP=%v, "+ - "PROXY_DEFAULT_BACKEND_HTTP=%v, PROXY_DEFAULT_BACKEND_API=%v, "+ - "PROXY_DEFAULT_BACKEND_RTC=%v, PROXY_DEFAULT_BACKEND_SRT=%v, "+ - "PROXY_LOAD_BALANCER_TYPE=%v, PROXY_REDIS_HOST=%v, PROXY_REDIS_PORT=%v, "+ - "PROXY_REDIS_PASSWORD=%v, PROXY_REDIS_DB=%v", - envGoPprof(), - envForceQuitTimeout(), envGraceQuitTimeout(), - envHttpAPI(), envHttpServer(), envRtmpServer(), - envWebRTCServer(), envSRTServer(), - envSystemAPI(), envStaticFiles(), envDefaultBackendEnabled(), - envDefaultBackendIP(), envDefaultBackendRTMP(), - envDefaultBackendHttp(), envDefaultBackendAPI(), - envDefaultBackendRTC(), envDefaultBackendSRT(), - envLoadBalancerType(), envRedisHost(), envRedisPort(), - envRedisPassword(), envRedisDB(), - ) -} - -func envStaticFiles() string { - return os.Getenv("PROXY_STATIC_FILES") -} - -func envDefaultBackendSRT() string { - return os.Getenv("PROXY_DEFAULT_BACKEND_SRT") -} - -func envDefaultBackendRTC() string { - return os.Getenv("PROXY_DEFAULT_BACKEND_RTC") -} - -func envDefaultBackendAPI() string { - return os.Getenv("PROXY_DEFAULT_BACKEND_API") -} - -func envSRTServer() string { - return os.Getenv("PROXY_SRT_SERVER") -} - -func envWebRTCServer() string { - return os.Getenv("PROXY_WEBRTC_SERVER") -} - -func envDefaultBackendHttp() string { - return os.Getenv("PROXY_DEFAULT_BACKEND_HTTP") -} - -func envRedisDB() string { - return os.Getenv("PROXY_REDIS_DB") -} - -func envRedisPassword() string { - return os.Getenv("PROXY_REDIS_PASSWORD") -} - -func envRedisPort() string { - return os.Getenv("PROXY_REDIS_PORT") -} - -func envRedisHost() string { - return os.Getenv("PROXY_REDIS_HOST") -} - -func envLoadBalancerType() string { - return os.Getenv("PROXY_LOAD_BALANCER_TYPE") -} - -func envDefaultBackendRTMP() string { - return os.Getenv("PROXY_DEFAULT_BACKEND_RTMP") -} - -func envDefaultBackendIP() string { - return os.Getenv("PROXY_DEFAULT_BACKEND_IP") -} - -func envDefaultBackendEnabled() string { - return os.Getenv("PROXY_DEFAULT_BACKEND_ENABLED") -} - -func envGraceQuitTimeout() string { - return os.Getenv("PROXY_GRACE_QUIT_TIMEOUT") -} - -func envForceQuitTimeout() string { - return os.Getenv("PROXY_FORCE_QUIT_TIMEOUT") -} - -func envGoPprof() string { - return os.Getenv("GO_PPROF") -} - -func envSystemAPI() string { - return os.Getenv("PROXY_SYSTEM_API") -} - -func envRtmpServer() string { - return os.Getenv("PROXY_RTMP_SERVER") -} - -func envHttpServer() string { - return os.Getenv("PROXY_HTTP_SERVER") -} - -func envHttpAPI() string { - return os.Getenv("PROXY_HTTP_API") -} - -// setEnvDefault set env key=value if not set. -func setEnvDefault(key, value string) { - if os.Getenv(key) == "" { - os.Setenv(key, value) - } -} diff --git a/proxy/errors/errors.go b/proxy/errors/errors.go deleted file mode 100644 index 257bc3ccd..000000000 --- a/proxy/errors/errors.go +++ /dev/null @@ -1,270 +0,0 @@ -// Package errors provides simple error handling primitives. -// -// The traditional error handling idiom in Go is roughly akin to -// -// if err != nil { -// return err -// } -// -// which applied recursively up the call stack results in error reports -// without context or debugging information. The errors package allows -// programmers to add context to the failure path in their code in a way -// that does not destroy the original value of the error. -// -// Adding context to an error -// -// The errors.Wrap function returns a new error that adds context to the -// original error by recording a stack trace at the point Wrap is called, -// and the supplied message. For example -// -// _, err := ioutil.ReadAll(r) -// if err != nil { -// return errors.Wrap(err, "read failed") -// } -// -// If additional control is required the errors.WithStack and errors.WithMessage -// functions destructure errors.Wrap into its component operations of annotating -// an error with a stack trace and an a message, respectively. -// -// Retrieving the cause of an error -// -// Using errors.Wrap constructs a stack of errors, adding context to the -// preceding error. Depending on the nature of the error it may be necessary -// to reverse the operation of errors.Wrap to retrieve the original error -// for inspection. Any error value which implements this interface -// -// type causer interface { -// Cause() error -// } -// -// can be inspected by errors.Cause. errors.Cause will recursively retrieve -// the topmost error which does not implement causer, which is assumed to be -// the original cause. For example: -// -// switch err := errors.Cause(err).(type) { -// case *MyError: -// // handle specifically -// default: -// // unknown error -// } -// -// causer interface is not exported by this package, but is considered a part -// of stable public API. -// -// Formatted printing of errors -// -// All error values returned from this package implement fmt.Formatter and can -// be formatted by the fmt package. The following verbs are supported -// -// %s print the error. If the error has a Cause it will be -// printed recursively -// %v see %s -// %+v extended format. Each Frame of the error's StackTrace will -// be printed in detail. -// -// Retrieving the stack trace of an error or wrapper -// -// New, Errorf, Wrap, and Wrapf record a stack trace at the point they are -// invoked. This information can be retrieved with the following interface. -// -// type stackTracer interface { -// StackTrace() errors.StackTrace -// } -// -// Where errors.StackTrace is defined as -// -// type StackTrace []Frame -// -// The Frame type represents a call site in the stack trace. Frame supports -// the fmt.Formatter interface that can be used for printing information about -// the stack trace of this error. For example: -// -// if err, ok := err.(stackTracer); ok { -// for _, f := range err.StackTrace() { -// fmt.Printf("%+s:%d", f) -// } -// } -// -// stackTracer interface is not exported by this package, but is considered a part -// of stable public API. -// -// See the documentation for Frame.Format for more details. -// Fork from https://github.com/pkg/errors -package errors - -import ( - "fmt" - "io" -) - -// New returns an error with the supplied message. -// New also records the stack trace at the point it was called. -func New(message string) error { - return &fundamental{ - msg: message, - stack: callers(), - } -} - -// Errorf formats according to a format specifier and returns the string -// as a value that satisfies error. -// Errorf also records the stack trace at the point it was called. -func Errorf(format string, args ...interface{}) error { - return &fundamental{ - msg: fmt.Sprintf(format, args...), - stack: callers(), - } -} - -// fundamental is an error that has a message and a stack, but no caller. -type fundamental struct { - msg string - *stack -} - -func (f *fundamental) Error() string { return f.msg } - -func (f *fundamental) Format(s fmt.State, verb rune) { - switch verb { - case 'v': - if s.Flag('+') { - io.WriteString(s, f.msg) - f.stack.Format(s, verb) - return - } - fallthrough - case 's': - io.WriteString(s, f.msg) - case 'q': - fmt.Fprintf(s, "%q", f.msg) - } -} - -// WithStack annotates err with a stack trace at the point WithStack was called. -// If err is nil, WithStack returns nil. -func WithStack(err error) error { - if err == nil { - return nil - } - return &withStack{ - err, - callers(), - } -} - -type withStack struct { - error - *stack -} - -func (w *withStack) Cause() error { return w.error } - -func (w *withStack) Format(s fmt.State, verb rune) { - switch verb { - case 'v': - if s.Flag('+') { - fmt.Fprintf(s, "%+v", w.Cause()) - w.stack.Format(s, verb) - return - } - fallthrough - case 's': - io.WriteString(s, w.Error()) - case 'q': - fmt.Fprintf(s, "%q", w.Error()) - } -} - -// Wrap returns an error annotating err with a stack trace -// at the point Wrap is called, and the supplied message. -// If err is nil, Wrap returns nil. -func Wrap(err error, message string) error { - if err == nil { - return nil - } - err = &withMessage{ - cause: err, - msg: message, - } - return &withStack{ - err, - callers(), - } -} - -// Wrapf returns an error annotating err with a stack trace -// at the point Wrapf is call, and the format specifier. -// If err is nil, Wrapf returns nil. -func Wrapf(err error, format string, args ...interface{}) error { - if err == nil { - return nil - } - err = &withMessage{ - cause: err, - msg: fmt.Sprintf(format, args...), - } - return &withStack{ - err, - callers(), - } -} - -// WithMessage annotates err with a new message. -// If err is nil, WithMessage returns nil. -func WithMessage(err error, message string) error { - if err == nil { - return nil - } - return &withMessage{ - cause: err, - msg: message, - } -} - -type withMessage struct { - cause error - msg string -} - -func (w *withMessage) Error() string { return w.msg + ": " + w.cause.Error() } -func (w *withMessage) Cause() error { return w.cause } - -func (w *withMessage) Format(s fmt.State, verb rune) { - switch verb { - case 'v': - if s.Flag('+') { - fmt.Fprintf(s, "%+v\n", w.Cause()) - io.WriteString(s, w.msg) - return - } - fallthrough - case 's', 'q': - io.WriteString(s, w.Error()) - } -} - -// Cause returns the underlying cause of the error, if possible. -// An error value has a cause if it implements the following -// interface: -// -// type causer interface { -// Cause() error -// } -// -// If the error does not implement Cause, the original error will -// be returned. If the error is nil, nil will be returned without further -// investigation. -func Cause(err error) error { - type causer interface { - Cause() error - } - - for err != nil { - cause, ok := err.(causer) - if !ok { - break - } - err = cause.Cause() - } - return err -} diff --git a/proxy/errors/stack.go b/proxy/errors/stack.go deleted file mode 100644 index 6c42db5a8..000000000 --- a/proxy/errors/stack.go +++ /dev/null @@ -1,187 +0,0 @@ -// Fork from https://github.com/pkg/errors -package errors - -import ( - "fmt" - "io" - "path" - "runtime" - "strings" -) - -// Frame represents a program counter inside a stack frame. -type Frame uintptr - -// pc returns the program counter for this frame; -// multiple frames may have the same PC value. -func (f Frame) pc() uintptr { return uintptr(f) - 1 } - -// file returns the full path to the file that contains the -// function for this Frame's pc. -func (f Frame) file() string { - fn := runtime.FuncForPC(f.pc()) - if fn == nil { - return "unknown" - } - file, _ := fn.FileLine(f.pc()) - return file -} - -// line returns the line number of source code of the -// function for this Frame's pc. -func (f Frame) line() int { - fn := runtime.FuncForPC(f.pc()) - if fn == nil { - return 0 - } - _, line := fn.FileLine(f.pc()) - return line -} - -// Format formats the frame according to the fmt.Formatter interface. -// -// %s source file -// %d source line -// %n function name -// %v equivalent to %s:%d -// -// Format accepts flags that alter the printing of some verbs, as follows: -// -// %+s path of source file relative to the compile time GOPATH -// %+v equivalent to %+s:%d -func (f Frame) Format(s fmt.State, verb rune) { - switch verb { - case 's': - switch { - case s.Flag('+'): - pc := f.pc() - fn := runtime.FuncForPC(pc) - if fn == nil { - io.WriteString(s, "unknown") - } else { - file, _ := fn.FileLine(pc) - fmt.Fprintf(s, "%s\n\t%s", fn.Name(), file) - } - default: - io.WriteString(s, path.Base(f.file())) - } - case 'd': - fmt.Fprintf(s, "%d", f.line()) - case 'n': - name := runtime.FuncForPC(f.pc()).Name() - io.WriteString(s, funcname(name)) - case 'v': - f.Format(s, 's') - io.WriteString(s, ":") - f.Format(s, 'd') - } -} - -// StackTrace is stack of Frames from innermost (newest) to outermost (oldest). -type StackTrace []Frame - -// Format formats the stack of Frames according to the fmt.Formatter interface. -// -// %s lists source files for each Frame in the stack -// %v lists the source file and line number for each Frame in the stack -// -// Format accepts flags that alter the printing of some verbs, as follows: -// -// %+v Prints filename, function, and line number for each Frame in the stack. -func (st StackTrace) Format(s fmt.State, verb rune) { - switch verb { - case 'v': - switch { - case s.Flag('+'): - for _, f := range st { - fmt.Fprintf(s, "\n%+v", f) - } - case s.Flag('#'): - fmt.Fprintf(s, "%#v", []Frame(st)) - default: - fmt.Fprintf(s, "%v", []Frame(st)) - } - case 's': - fmt.Fprintf(s, "%s", []Frame(st)) - } -} - -// stack represents a stack of program counters. -type stack []uintptr - -func (s *stack) Format(st fmt.State, verb rune) { - switch verb { - case 'v': - switch { - case st.Flag('+'): - for _, pc := range *s { - f := Frame(pc) - fmt.Fprintf(st, "\n%+v", f) - } - } - } -} - -func (s *stack) StackTrace() StackTrace { - f := make([]Frame, len(*s)) - for i := 0; i < len(f); i++ { - f[i] = Frame((*s)[i]) - } - return f -} - -func callers() *stack { - const depth = 32 - var pcs [depth]uintptr - n := runtime.Callers(3, pcs[:]) - var st stack = pcs[0:n] - return &st -} - -// funcname removes the path prefix component of a function's name reported by func.Name(). -func funcname(name string) string { - i := strings.LastIndex(name, "/") - name = name[i+1:] - i = strings.Index(name, ".") - return name[i+1:] -} - -func trimGOPATH(name, file string) string { - // Here we want to get the source file path relative to the compile time - // GOPATH. As of Go 1.6.x there is no direct way to know the compiled - // GOPATH at runtime, but we can infer the number of path segments in the - // GOPATH. We note that fn.Name() returns the function name qualified by - // the import path, which does not include the GOPATH. Thus we can trim - // segments from the beginning of the file path until the number of path - // separators remaining is one more than the number of path separators in - // the function name. For example, given: - // - // GOPATH /home/user - // file /home/user/src/pkg/sub/file.go - // fn.Name() pkg/sub.Type.Method - // - // We want to produce: - // - // pkg/sub/file.go - // - // From this we can easily see that fn.Name() has one less path separator - // than our desired output. We count separators from the end of the file - // path until it finds two more than in the function name and then move - // one character forward to preserve the initial path segment without a - // leading separator. - const sep = "/" - goal := strings.Count(name, sep) + 2 - i := len(file) - for n := 0; n < goal; n++ { - i = strings.LastIndex(file[:i], sep) - if i == -1 { - // not enough separators found, set i so that the slice expression - // below leaves file unmodified - i = -len(sep) - break - } - } - // get back to 0 or trim the leading separator - file = file[i+len(sep):] - return file -} diff --git a/proxy/go.mod b/proxy/go.mod deleted file mode 100644 index e9e196d2f..000000000 --- a/proxy/go.mod +++ /dev/null @@ -1,10 +0,0 @@ -module srs-proxy - -go 1.18 - -require github.com/go-redis/redis/v8 v8.11.5 - -require ( - github.com/cespare/xxhash/v2 v2.1.2 // indirect - github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect -) diff --git a/proxy/go.sum b/proxy/go.sum deleted file mode 100644 index 7342ff813..000000000 --- a/proxy/go.sum +++ /dev/null @@ -1,15 +0,0 @@ -github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= -github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= -github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= -github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= -github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= -github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= -github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= -github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= -github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= -golang.org/x/net v0.0.0-20210428140749-89ef3d95e781 h1:DzZ89McO9/gWPsQXS/FVKAlG02ZjaQ6AlZRBimEYOd0= -golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e h1:fLOSk5Q00efkSvAm+4xcoXD+RRmLmmulPn5I3Y9F2EM= -golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= -gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= -gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= diff --git a/proxy/http.go b/proxy/http.go deleted file mode 100644 index 0a4b0b75b..000000000 --- a/proxy/http.go +++ /dev/null @@ -1,419 +0,0 @@ -// Copyright (c) 2025 Winlin -// -// SPDX-License-Identifier: MIT -package main - -import ( - "context" - "fmt" - "io" - "io/ioutil" - "net/http" - "os" - "strconv" - "strings" - stdSync "sync" - "time" - - "srs-proxy/errors" - "srs-proxy/logger" -) - -// srsHTTPStreamServer is the proxy server for SRS HTTP stream server, for HTTP-FLV, HTTP-TS, -// HLS, etc. The proxy server will figure out which SRS origin server to proxy to, then proxy -// the request to the origin server. -type srsHTTPStreamServer struct { - // The underlayer HTTP server. - server *http.Server - // The gracefully quit timeout, wait server to quit. - gracefulQuitTimeout time.Duration - // The wait group for all goroutines. - wg stdSync.WaitGroup -} - -func NewSRSHTTPStreamServer(opts ...func(*srsHTTPStreamServer)) *srsHTTPStreamServer { - v := &srsHTTPStreamServer{} - for _, opt := range opts { - opt(v) - } - return v -} - -func (v *srsHTTPStreamServer) Close() error { - ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout) - defer cancel() - v.server.Shutdown(ctx) - - v.wg.Wait() - return nil -} - -func (v *srsHTTPStreamServer) Run(ctx context.Context) error { - // Parse address to listen. - addr := envHttpServer() - if !strings.Contains(addr, ":") { - addr = ":" + addr - } - - // Create server and handler. - mux := http.NewServeMux() - v.server = &http.Server{Addr: addr, Handler: mux} - logger.Df(ctx, "HTTP Stream server listen at %v", addr) - - // Shutdown the server gracefully when quiting. - go func() { - ctxParent := ctx - <-ctxParent.Done() - - ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout) - defer cancel() - - v.server.Shutdown(ctx) - }() - - // The basic version handler, also can be used as health check API. - logger.Df(ctx, "Handle /api/v1/versions by %v", addr) - mux.HandleFunc("/api/v1/versions", func(w http.ResponseWriter, r *http.Request) { - type Response struct { - Code int `json:"code"` - PID string `json:"pid"` - Data struct { - Major int `json:"major"` - Minor int `json:"minor"` - Revision int `json:"revision"` - Version string `json:"version"` - } `json:"data"` - } - - res := Response{Code: 0, PID: fmt.Sprintf("%v", os.Getpid())} - res.Data.Major = VersionMajor() - res.Data.Minor = VersionMinor() - res.Data.Revision = VersionRevision() - res.Data.Version = Version() - - apiResponse(ctx, w, r, &res) - }) - - // The static web server, for the web pages. - var staticServer http.Handler - if staticFiles := envStaticFiles(); staticFiles != "" { - if _, err := os.Stat(staticFiles); err != nil { - return errors.Wrapf(err, "invalid static files %v", staticFiles) - } - - staticServer = http.FileServer(http.Dir(staticFiles)) - logger.Df(ctx, "Handle static files at %v", staticFiles) - } - - // The default handler, for both static web server and streaming server. - logger.Df(ctx, "Handle / by %v", addr) - mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - // For HLS streaming, we will proxy the request to the streaming server. - if strings.HasSuffix(r.URL.Path, ".m3u8") { - unifiedURL, fullURL := convertURLToStreamURL(r) - streamURL, err := buildStreamURL(unifiedURL) - if err != nil { - http.Error(w, fmt.Sprintf("build stream url by %v from %v", unifiedURL, fullURL), http.StatusBadRequest) - return - } - - stream, _ := srsLoadBalancer.LoadOrStoreHLS(ctx, streamURL, NewHLSPlayStream(func(s *HLSPlayStream) { - s.SRSProxyBackendHLSID = logger.GenerateContextID() - s.StreamURL, s.FullURL = streamURL, fullURL - })) - - stream.Initialize(ctx).ServeHTTP(w, r) - return - } - - // For HTTP streaming, we will proxy the request to the streaming server. - if strings.HasSuffix(r.URL.Path, ".flv") || - strings.HasSuffix(r.URL.Path, ".ts") { - // If SPBHID is specified, it must be a HLS stream client. - if srsProxyBackendID := r.URL.Query().Get("spbhid"); srsProxyBackendID != "" { - if stream, err := srsLoadBalancer.LoadHLSBySPBHID(ctx, srsProxyBackendID); err != nil { - http.Error(w, fmt.Sprintf("load stream by spbhid %v", srsProxyBackendID), http.StatusBadRequest) - } else { - stream.Initialize(ctx).ServeHTTP(w, r) - } - return - } - - // Use HTTP pseudo streaming to proxy the request. - NewHTTPFlvTsConnection(func(c *HTTPFlvTsConnection) { - c.ctx = ctx - }).ServeHTTP(w, r) - return - } - - // Serve by static server. - if staticServer != nil { - staticServer.ServeHTTP(w, r) - return - } - - http.NotFound(w, r) - }) - - // Run HTTP server. - v.wg.Add(1) - go func() { - defer v.wg.Done() - - err := v.server.ListenAndServe() - if err != nil { - if ctx.Err() != context.Canceled { - // TODO: If HTTP Stream server closed unexpectedly, we should notice the main loop to quit. - logger.Wf(ctx, "HTTP Stream accept err %+v", err) - } else { - logger.Df(ctx, "HTTP Stream server done") - } - } - }() - - return nil -} - -// HTTPFlvTsConnection is an HTTP pseudo streaming connection, such as an HTTP-FLV or HTTP-TS -// connection. There is no state need to be sync between proxy servers. -// -// When we got an HTTP FLV or TS request, we will parse the stream URL from the HTTP request, -// then proxy to the corresponding backend server. All state is in the HTTP request, so this -// connection is stateless. -type HTTPFlvTsConnection struct { - // The context for HTTP streaming. - ctx context.Context -} - -func NewHTTPFlvTsConnection(opts ...func(*HTTPFlvTsConnection)) *HTTPFlvTsConnection { - v := &HTTPFlvTsConnection{} - for _, opt := range opts { - opt(v) - } - return v -} - -func (v *HTTPFlvTsConnection) ServeHTTP(w http.ResponseWriter, r *http.Request) { - defer r.Body.Close() - ctx := logger.WithContext(v.ctx) - - if err := v.serve(ctx, w, r); err != nil { - apiError(ctx, w, r, err) - } else { - logger.Df(ctx, "HTTP client done") - } -} - -func (v *HTTPFlvTsConnection) serve(ctx context.Context, w http.ResponseWriter, r *http.Request) error { - // Always allow CORS for all requests. - if ok := apiCORS(ctx, w, r); ok { - return nil - } - - // Build the stream URL in vhost/app/stream schema. - unifiedURL, fullURL := convertURLToStreamURL(r) - logger.Df(ctx, "Got HTTP client from %v for %v", r.RemoteAddr, fullURL) - - streamURL, err := buildStreamURL(unifiedURL) - if err != nil { - return errors.Wrapf(err, "build stream url %v", unifiedURL) - } - - // Pick a backend SRS server to proxy the RTMP stream. - backend, err := srsLoadBalancer.Pick(ctx, streamURL) - if err != nil { - return errors.Wrapf(err, "pick backend for %v", streamURL) - } - - if err = v.serveByBackend(ctx, w, r, backend); err != nil { - return errors.Wrapf(err, "serve %v with %v by backend %+v", fullURL, streamURL, backend) - } - - return nil -} - -func (v *HTTPFlvTsConnection) serveByBackend(ctx context.Context, w http.ResponseWriter, r *http.Request, backend *SRSServer) error { - // Parse HTTP port from backend. - if len(backend.HTTP) == 0 { - return errors.Errorf("no http stream server") - } - - var httpPort int - if iv, err := strconv.ParseInt(backend.HTTP[0], 10, 64); err != nil { - return errors.Wrapf(err, "parse http port %v", backend.HTTP[0]) - } else { - httpPort = int(iv) - } - - // Connect to backend SRS server via HTTP client. - backendURL := fmt.Sprintf("http://%v:%v%s", backend.IP, httpPort, r.URL.Path) - req, err := http.NewRequestWithContext(ctx, r.Method, backendURL, nil) - if err != nil { - return errors.Wrapf(err, "create request to %v", backendURL) - } - - resp, err := http.DefaultClient.Do(req) - if err != nil { - return errors.Wrapf(err, "do request to %v", backendURL) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return errors.Errorf("proxy stream to %v failed, status=%v", backendURL, resp.Status) - } - - // Copy all headers from backend to client. - w.WriteHeader(resp.StatusCode) - for k, v := range resp.Header { - for _, vv := range v { - w.Header().Add(k, vv) - } - } - - logger.Df(ctx, "HTTP start streaming") - - // Proxy the stream from backend to client. - if _, err := io.Copy(w, resp.Body); err != nil { - return errors.Wrapf(err, "copy stream to client, backend=%v", backendURL) - } - - return nil -} - -// HLSPlayStream is an HLS stream proxy, which represents the stream level object. This means multiple HLS -// clients will share this object, and they do not use the same ctx among proxy servers. -// -// Unlike the HTTP FLV or TS connection, HLS client may request the m3u8 or ts via different HTTP connections. -// Especially for requesting ts, we need to identify the stream URl or backend server for it. So we create -// the spbhid which can be seen as the hash of stream URL or backend server. The spbhid enable us to convert -// to the stream URL and then query the backend server to serve it. -type HLSPlayStream struct { - // The context for HLS streaming. - ctx context.Context - - // The spbhid, used to identify the backend server. - SRSProxyBackendHLSID string `json:"spbhid"` - // The stream URL in vhost/app/stream schema. - StreamURL string `json:"stream_url"` - // The full request URL for HLS streaming - FullURL string `json:"full_url"` -} - -func NewHLSPlayStream(opts ...func(*HLSPlayStream)) *HLSPlayStream { - v := &HLSPlayStream{} - for _, opt := range opts { - opt(v) - } - return v -} - -func (v *HLSPlayStream) Initialize(ctx context.Context) *HLSPlayStream { - if v.ctx == nil { - v.ctx = logger.WithContext(ctx) - } - return v -} - -func (v *HLSPlayStream) ServeHTTP(w http.ResponseWriter, r *http.Request) { - defer r.Body.Close() - - if err := v.serve(v.ctx, w, r); err != nil { - apiError(v.ctx, w, r, err) - } else { - logger.Df(v.ctx, "HLS client %v for %v with %v done", - v.SRSProxyBackendHLSID, v.StreamURL, r.URL.Path) - } -} - -func (v *HLSPlayStream) serve(ctx context.Context, w http.ResponseWriter, r *http.Request) error { - ctx, streamURL, fullURL := v.ctx, v.StreamURL, v.FullURL - - // Always allow CORS for all requests. - if ok := apiCORS(ctx, w, r); ok { - return nil - } - - // Pick a backend SRS server to proxy the RTMP stream. - backend, err := srsLoadBalancer.Pick(ctx, streamURL) - if err != nil { - return errors.Wrapf(err, "pick backend for %v", streamURL) - } - - if err = v.serveByBackend(ctx, w, r, backend); err != nil { - return errors.Wrapf(err, "serve %v with %v by backend %+v", fullURL, streamURL, backend) - } - - return nil -} - -func (v *HLSPlayStream) serveByBackend(ctx context.Context, w http.ResponseWriter, r *http.Request, backend *SRSServer) error { - // Parse HTTP port from backend. - if len(backend.HTTP) == 0 { - return errors.Errorf("no rtmp server") - } - - var httpPort int - if iv, err := strconv.ParseInt(backend.HTTP[0], 10, 64); err != nil { - return errors.Wrapf(err, "parse http port %v", backend.HTTP[0]) - } else { - httpPort = int(iv) - } - - // Connect to backend SRS server via HTTP client. - backendURL := fmt.Sprintf("http://%v:%v%s", backend.IP, httpPort, r.URL.Path) - if r.URL.RawQuery != "" { - backendURL += "?" + r.URL.RawQuery - } - - req, err := http.NewRequestWithContext(ctx, r.Method, backendURL, nil) - if err != nil { - return errors.Wrapf(err, "create request to %v", backendURL) - } - - resp, err := http.DefaultClient.Do(req) - if err != nil { - return errors.Errorf("do request to %v EOF", backendURL) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return errors.Errorf("proxy stream to %v failed, status=%v", backendURL, resp.Status) - } - - // Copy all headers from backend to client. - w.WriteHeader(resp.StatusCode) - for k, v := range resp.Header { - for _, vv := range v { - w.Header().Add(k, vv) - } - } - - // For TS file, directly copy it. - if !strings.HasSuffix(r.URL.Path, ".m3u8") { - if _, err := io.Copy(w, resp.Body); err != nil { - return errors.Wrapf(err, "copy stream to client, backend=%v", backendURL) - } - - return nil - } - - // Read all content of m3u8, append the stream ID to ts URL. Note that we only append stream ID to ts - // URL, to identify the stream to specified backend server. The spbhid is the SRS Proxy Backend HLS ID. - b, err := ioutil.ReadAll(resp.Body) - if err != nil { - return errors.Wrapf(err, "read stream from %v", backendURL) - } - - m3u8 := string(b) - if strings.Contains(m3u8, ".ts?") { - m3u8 = strings.ReplaceAll(m3u8, ".ts?", fmt.Sprintf(".ts?spbhid=%v&&", v.SRSProxyBackendHLSID)) - } else { - m3u8 = strings.ReplaceAll(m3u8, ".ts", fmt.Sprintf(".ts?spbhid=%v", v.SRSProxyBackendHLSID)) - } - - if _, err := io.Copy(w, strings.NewReader(m3u8)); err != nil { - return errors.Wrapf(err, "proxy m3u8 client to %v", backendURL) - } - - return nil -} diff --git a/proxy/logger/context.go b/proxy/logger/context.go deleted file mode 100644 index fb15b767e..000000000 --- a/proxy/logger/context.go +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright (c) 2025 Winlin -// -// SPDX-License-Identifier: MIT -package logger - -import ( - "context" - "crypto/rand" - "crypto/sha256" - "encoding/hex" -) - -type key string - -var cidKey key = "cid.proxy.ossrs.org" - -// generateContextID generates a random context id in string. -func GenerateContextID() string { - randomBytes := make([]byte, 32) - _, _ = rand.Read(randomBytes) - hash := sha256.Sum256(randomBytes) - hashString := hex.EncodeToString(hash[:]) - cid := hashString[:7] - return cid -} - -// WithContext creates a new context with cid, which will be used for log. -func WithContext(ctx context.Context) context.Context { - return WithContextID(ctx, GenerateContextID()) -} - -// WithContextID creates a new context with cid, which will be used for log. -func WithContextID(ctx context.Context, cid string) context.Context { - return context.WithValue(ctx, cidKey, cid) -} - -// ContextID returns the cid in context, or empty string if not set. -func ContextID(ctx context.Context) string { - if cid, ok := ctx.Value(cidKey).(string); ok { - return cid - } - return "" -} diff --git a/proxy/logger/log.go b/proxy/logger/log.go deleted file mode 100644 index 22c4df81d..000000000 --- a/proxy/logger/log.go +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright (c) 2025 Winlin -// -// SPDX-License-Identifier: MIT -package logger - -import ( - "context" - "io/ioutil" - stdLog "log" - "os" -) - -type logger interface { - Printf(ctx context.Context, format string, v ...any) -} - -type loggerPlus struct { - logger *stdLog.Logger - level string -} - -func newLoggerPlus(opts ...func(*loggerPlus)) *loggerPlus { - v := &loggerPlus{} - for _, opt := range opts { - opt(v) - } - return v -} - -func (v *loggerPlus) Printf(ctx context.Context, f string, a ...interface{}) { - format, args := f, a - if cid := ContextID(ctx); cid != "" { - format, args = "[%v][%v][%v] "+format, append([]interface{}{v.level, os.Getpid(), cid}, a...) - } - - v.logger.Printf(format, args...) -} - -var verboseLogger logger - -func Vf(ctx context.Context, format string, a ...interface{}) { - verboseLogger.Printf(ctx, format, a...) -} - -var debugLogger logger - -func Df(ctx context.Context, format string, a ...interface{}) { - debugLogger.Printf(ctx, format, a...) -} - -var warnLogger logger - -func Wf(ctx context.Context, format string, a ...interface{}) { - warnLogger.Printf(ctx, format, a...) -} - -var errorLogger logger - -func Ef(ctx context.Context, format string, a ...interface{}) { - errorLogger.Printf(ctx, format, a...) -} - -const ( - logVerboseLabel = "verb" - logDebugLabel = "debug" - logWarnLabel = "warn" - logErrorLabel = "error" -) - -func init() { - verboseLogger = newLoggerPlus(func(logger *loggerPlus) { - logger.logger = stdLog.New(ioutil.Discard, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds) - logger.level = logVerboseLabel - }) - debugLogger = newLoggerPlus(func(logger *loggerPlus) { - logger.logger = stdLog.New(os.Stdout, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds) - logger.level = logDebugLabel - }) - warnLogger = newLoggerPlus(func(logger *loggerPlus) { - logger.logger = stdLog.New(os.Stderr, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds) - logger.level = logWarnLabel - }) - errorLogger = newLoggerPlus(func(logger *loggerPlus) { - logger.logger = stdLog.New(os.Stderr, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds) - logger.level = logErrorLabel - }) -} diff --git a/proxy/main.go b/proxy/main.go deleted file mode 100644 index 430b2da60..000000000 --- a/proxy/main.go +++ /dev/null @@ -1,121 +0,0 @@ -// Copyright (c) 2025 Winlin -// -// SPDX-License-Identifier: MIT -package main - -import ( - "context" - "os" - - "srs-proxy/errors" - "srs-proxy/logger" -) - -func main() { - ctx := logger.WithContext(context.Background()) - logger.Df(ctx, "%v/%v started", Signature(), Version()) - - // Install signals. - ctx, cancel := context.WithCancel(ctx) - installSignals(ctx, cancel) - - // Start the main loop, ignore the user cancel error. - err := doMain(ctx) - if err != nil && ctx.Err() != context.Canceled { - logger.Ef(ctx, "main: %+v", err) - os.Exit(-1) - } - - logger.Df(ctx, "%v done", Signature()) -} - -func doMain(ctx context.Context) error { - // Setup the environment variables. - if err := loadEnvFile(ctx); err != nil { - return errors.Wrapf(err, "load env") - } - - buildDefaultEnvironmentVariables(ctx) - - // When cancelled, the program is forced to exit due to a timeout. Normally, this doesn't occur - // because the main thread exits after the context is cancelled. However, sometimes the main thread - // may be blocked for some reason, so a forced exit is necessary to ensure the program terminates. - if err := installForceQuit(ctx); err != nil { - return errors.Wrapf(err, "install force quit") - } - - // Start the Go pprof if enabled. - handleGoPprof(ctx) - - // Initialize SRS load balancers. - switch lbType := envLoadBalancerType(); lbType { - case "memory": - srsLoadBalancer = NewMemoryLoadBalancer() - case "redis": - srsLoadBalancer = NewRedisLoadBalancer() - default: - return errors.Errorf("invalid load balancer %v", lbType) - } - - if err := srsLoadBalancer.Initialize(ctx); err != nil { - return errors.Wrapf(err, "initialize srs load balancer") - } - - // Parse the gracefully quit timeout. - gracefulQuitTimeout, err := parseGracefullyQuitTimeout() - if err != nil { - return errors.Wrapf(err, "parse gracefully quit timeout") - } - - // Start the RTMP server. - srsRTMPServer := NewSRSRTMPServer() - defer srsRTMPServer.Close() - if err := srsRTMPServer.Run(ctx); err != nil { - return errors.Wrapf(err, "rtmp server") - } - - // Start the WebRTC server. - srsWebRTCServer := NewSRSWebRTCServer() - defer srsWebRTCServer.Close() - if err := srsWebRTCServer.Run(ctx); err != nil { - return errors.Wrapf(err, "rtc server") - } - - // Start the HTTP API server. - srsHTTPAPIServer := NewSRSHTTPAPIServer(func(server *srsHTTPAPIServer) { - server.gracefulQuitTimeout, server.rtc = gracefulQuitTimeout, srsWebRTCServer - }) - defer srsHTTPAPIServer.Close() - if err := srsHTTPAPIServer.Run(ctx); err != nil { - return errors.Wrapf(err, "http api server") - } - - // Start the SRT server. - srsSRTServer := NewSRSSRTServer() - defer srsSRTServer.Close() - if err := srsSRTServer.Run(ctx); err != nil { - return errors.Wrapf(err, "srt server") - } - - // Start the System API server. - systemAPI := NewSystemAPI(func(server *systemAPI) { - server.gracefulQuitTimeout = gracefulQuitTimeout - }) - defer systemAPI.Close() - if err := systemAPI.Run(ctx); err != nil { - return errors.Wrapf(err, "system api server") - } - - // Start the HTTP web server. - srsHTTPStreamServer := NewSRSHTTPStreamServer(func(server *srsHTTPStreamServer) { - server.gracefulQuitTimeout = gracefulQuitTimeout - }) - defer srsHTTPStreamServer.Close() - if err := srsHTTPStreamServer.Run(ctx); err != nil { - return errors.Wrapf(err, "http server") - } - - // Wait for the main loop to quit. - <-ctx.Done() - return nil -} diff --git a/proxy/rtc.go b/proxy/rtc.go deleted file mode 100644 index 3516751ef..000000000 --- a/proxy/rtc.go +++ /dev/null @@ -1,515 +0,0 @@ -// Copyright (c) 2025 Winlin -// -// SPDX-License-Identifier: MIT -package main - -import ( - "context" - "encoding/binary" - "fmt" - "io/ioutil" - "net" - "net/http" - "strconv" - "strings" - stdSync "sync" - - "srs-proxy/errors" - "srs-proxy/logger" - "srs-proxy/sync" -) - -// srsWebRTCServer is the proxy for SRS WebRTC server via WHIP or WHEP protocol. It will figure out -// which backend server to proxy to. It will also replace the UDP port to the proxy server's in the -// SDP answer. -type srsWebRTCServer struct { - // The UDP listener for WebRTC server. - listener *net.UDPConn - - // Fast cache for the username to identify the connection. - // The key is username, the value is the UDP address. - usernames sync.Map[string, *RTCConnection] - // Fast cache for the udp address to identify the connection. - // The key is UDP address, the value is the username. - // TODO: Support fast earch by uint64 address. - addresses sync.Map[string, *RTCConnection] - - // The wait group for server. - wg stdSync.WaitGroup -} - -func NewSRSWebRTCServer(opts ...func(*srsWebRTCServer)) *srsWebRTCServer { - v := &srsWebRTCServer{} - for _, opt := range opts { - opt(v) - } - return v -} - -func (v *srsWebRTCServer) Close() error { - if v.listener != nil { - _ = v.listener.Close() - } - - v.wg.Wait() - return nil -} - -func (v *srsWebRTCServer) HandleApiForWHIP(ctx context.Context, w http.ResponseWriter, r *http.Request) error { - defer r.Body.Close() - ctx = logger.WithContext(ctx) - - // Always allow CORS for all requests. - if ok := apiCORS(ctx, w, r); ok { - return nil - } - - // Read remote SDP offer from body. - remoteSDPOffer, err := ioutil.ReadAll(r.Body) - if err != nil { - return errors.Wrapf(err, "read remote sdp offer") - } - - // Build the stream URL in vhost/app/stream schema. - unifiedURL, fullURL := convertURLToStreamURL(r) - logger.Df(ctx, "Got WebRTC WHIP from %v with %vB offer for %v", r.RemoteAddr, len(remoteSDPOffer), fullURL) - - streamURL, err := buildStreamURL(unifiedURL) - if err != nil { - return errors.Wrapf(err, "build stream url %v", unifiedURL) - } - - // Pick a backend SRS server to proxy the RTMP stream. - backend, err := srsLoadBalancer.Pick(ctx, streamURL) - if err != nil { - return errors.Wrapf(err, "pick backend for %v", streamURL) - } - - if err = v.proxyApiToBackend(ctx, w, r, backend, string(remoteSDPOffer), streamURL); err != nil { - return errors.Wrapf(err, "serve %v with %v by backend %+v", fullURL, streamURL, backend) - } - - return nil -} - -func (v *srsWebRTCServer) HandleApiForWHEP(ctx context.Context, w http.ResponseWriter, r *http.Request) error { - defer r.Body.Close() - ctx = logger.WithContext(ctx) - - // Always allow CORS for all requests. - if ok := apiCORS(ctx, w, r); ok { - return nil - } - - // Read remote SDP offer from body. - remoteSDPOffer, err := ioutil.ReadAll(r.Body) - if err != nil { - return errors.Wrapf(err, "read remote sdp offer") - } - - // Build the stream URL in vhost/app/stream schema. - unifiedURL, fullURL := convertURLToStreamURL(r) - logger.Df(ctx, "Got WebRTC WHEP from %v with %vB offer for %v", r.RemoteAddr, len(remoteSDPOffer), fullURL) - - streamURL, err := buildStreamURL(unifiedURL) - if err != nil { - return errors.Wrapf(err, "build stream url %v", unifiedURL) - } - - // Pick a backend SRS server to proxy the RTMP stream. - backend, err := srsLoadBalancer.Pick(ctx, streamURL) - if err != nil { - return errors.Wrapf(err, "pick backend for %v", streamURL) - } - - if err = v.proxyApiToBackend(ctx, w, r, backend, string(remoteSDPOffer), streamURL); err != nil { - return errors.Wrapf(err, "serve %v with %v by backend %+v", fullURL, streamURL, backend) - } - - return nil -} - -func (v *srsWebRTCServer) proxyApiToBackend( - ctx context.Context, w http.ResponseWriter, r *http.Request, backend *SRSServer, - remoteSDPOffer string, streamURL string, -) error { - // Parse HTTP port from backend. - if len(backend.API) == 0 { - return errors.Errorf("no http api server") - } - - var apiPort int - if iv, err := strconv.ParseInt(backend.API[0], 10, 64); err != nil { - return errors.Wrapf(err, "parse http port %v", backend.API[0]) - } else { - apiPort = int(iv) - } - - // Connect to backend SRS server via HTTP client. - backendURL := fmt.Sprintf("http://%v:%v%s", backend.IP, apiPort, r.URL.Path) - if r.URL.RawQuery != "" { - backendURL += "?" + r.URL.RawQuery - } - - req, err := http.NewRequestWithContext(ctx, r.Method, backendURL, strings.NewReader(remoteSDPOffer)) - if err != nil { - return errors.Wrapf(err, "create request to %v", backendURL) - } - - resp, err := http.DefaultClient.Do(req) - if err != nil { - return errors.Errorf("do request to %v EOF", backendURL) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { - return errors.Errorf("proxy api to %v failed, status=%v", backendURL, resp.Status) - } - - // Copy all headers from backend to client. - w.WriteHeader(resp.StatusCode) - for k, v := range resp.Header { - for _, vv := range v { - w.Header().Add(k, vv) - } - } - - // Parse the local SDP answer from backend. - b, err := ioutil.ReadAll(resp.Body) - if err != nil { - return errors.Wrapf(err, "read stream from %v", backendURL) - } - - // Replace the WebRTC UDP port in answer. - localSDPAnswer := string(b) - for _, endpoint := range backend.RTC { - _, _, port, err := parseListenEndpoint(endpoint) - if err != nil { - return errors.Wrapf(err, "parse endpoint %v", endpoint) - } - - from := fmt.Sprintf(" %v typ host", port) - to := fmt.Sprintf(" %v typ host", envWebRTCServer()) - localSDPAnswer = strings.Replace(localSDPAnswer, from, to, -1) - } - - // Fetch the ice-ufrag and ice-pwd from local SDP answer. - remoteICEUfrag, remoteICEPwd, err := parseIceUfragPwd(remoteSDPOffer) - if err != nil { - return errors.Wrapf(err, "parse remote sdp offer") - } - - localICEUfrag, localICEPwd, err := parseIceUfragPwd(localSDPAnswer) - if err != nil { - return errors.Wrapf(err, "parse local sdp answer") - } - - // Save the new WebRTC connection to LB. - icePair := &RTCICEPair{ - RemoteICEUfrag: remoteICEUfrag, RemoteICEPwd: remoteICEPwd, - LocalICEUfrag: localICEUfrag, LocalICEPwd: localICEPwd, - } - if err := srsLoadBalancer.StoreWebRTC(ctx, streamURL, NewRTCConnection(func(c *RTCConnection) { - c.StreamURL, c.Ufrag = streamURL, icePair.Ufrag() - c.Initialize(ctx, v.listener) - - // Cache the connection for fast search by username. - v.usernames.Store(c.Ufrag, c) - })); err != nil { - return errors.Wrapf(err, "load or store webrtc %v", streamURL) - } - - // Response client with local answer. - if _, err = w.Write([]byte(localSDPAnswer)); err != nil { - return errors.Wrapf(err, "write local sdp answer %v", localSDPAnswer) - } - - logger.Df(ctx, "Create WebRTC connection with local answer %vB with ice-ufrag=%v, ice-pwd=%vB", - len(localSDPAnswer), localICEUfrag, len(localICEPwd)) - return nil -} - -func (v *srsWebRTCServer) Run(ctx context.Context) error { - // Parse address to listen. - endpoint := envWebRTCServer() - if !strings.Contains(endpoint, ":") { - endpoint = fmt.Sprintf(":%v", endpoint) - } - - saddr, err := net.ResolveUDPAddr("udp", endpoint) - if err != nil { - return errors.Wrapf(err, "resolve udp addr %v", endpoint) - } - - listener, err := net.ListenUDP("udp", saddr) - if err != nil { - return errors.Wrapf(err, "listen udp %v", saddr) - } - v.listener = listener - logger.Df(ctx, "WebRTC server listen at %v", saddr) - - // Consume all messages from UDP media transport. - v.wg.Add(1) - go func() { - defer v.wg.Done() - - for ctx.Err() == nil { - buf := make([]byte, 4096) - n, caddr, err := listener.ReadFromUDP(buf) - if err != nil { - // TODO: If WebRTC server closed unexpectedly, we should notice the main loop to quit. - logger.Wf(ctx, "read from udp failed, err=%+v", err) - continue - } - - if err := v.handleClientUDP(ctx, caddr, buf[:n]); err != nil { - logger.Wf(ctx, "handle udp %vB failed, addr=%v, err=%+v", n, caddr, err) - } - } - }() - - return nil -} - -func (v *srsWebRTCServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr, data []byte) error { - var connection *RTCConnection - - // If STUN binding request, parse the ufrag and identify the connection. - if err := func() error { - if rtcIsRTPOrRTCP(data) || !rtcIsSTUN(data) { - return nil - } - - var pkt RTCStunPacket - if err := pkt.UnmarshalBinary(data); err != nil { - return errors.Wrapf(err, "unmarshal stun packet") - } - - // Search the connection in fast cache. - if s, ok := v.usernames.Load(pkt.Username); ok { - connection = s - return nil - } - - // Load connection by username. - if s, err := srsLoadBalancer.LoadWebRTCByUfrag(ctx, pkt.Username); err != nil { - return errors.Wrapf(err, "load webrtc by ufrag %v", pkt.Username) - } else { - connection = s.Initialize(ctx, v.listener) - logger.Df(ctx, "Create WebRTC connection by ufrag=%v, stream=%v", pkt.Username, connection.StreamURL) - } - - // Cache connection for fast search. - if connection != nil { - v.usernames.Store(pkt.Username, connection) - } - return nil - }(); err != nil { - return err - } - - // Search the connection by addr. - if s, ok := v.addresses.Load(addr.String()); ok { - connection = s - } else if connection != nil { - // Cache the address for fast search. - v.addresses.Store(addr.String(), connection) - } - - // If connection is not found, ignore the packet. - if connection == nil { - // TODO: Should logging the dropped packet, only logging the first one for each address. - return nil - } - - // Proxy the packet to backend. - if err := connection.HandlePacket(addr, data); err != nil { - return errors.Wrapf(err, "proxy %vB for %v", len(data), connection.StreamURL) - } - - return nil -} - -// RTCConnection is a WebRTC connection proxy, for both WHIP and WHEP. It represents a WebRTC -// connection, identify by the ufrag in sdp offer/answer and ICE binding request. -// -// It's not like RTMP or HTTP FLV/TS proxy connection, which are stateless and all state is -// in the client request. The RTCConnection is stateful, and need to sync the ufrag between -// proxy servers. -// -// The media transport is UDP, which is also a special thing for WebRTC. So if the client switch -// to another UDP address, it may connect to another WebRTC proxy, then we should discover the -// RTCConnection by the ufrag from the ICE binding request. -type RTCConnection struct { - // The stream context for WebRTC streaming. - ctx context.Context - - // The stream URL in vhost/app/stream schema. - StreamURL string `json:"stream_url"` - // The ufrag for this WebRTC connection. - Ufrag string `json:"ufrag"` - - // The UDP connection proxy to backend. - backendUDP *net.UDPConn - // The client UDP address. Note that it may change. - clientUDP *net.UDPAddr - // The listener UDP connection, used to send messages to client. - listenerUDP *net.UDPConn -} - -func NewRTCConnection(opts ...func(*RTCConnection)) *RTCConnection { - v := &RTCConnection{} - for _, opt := range opts { - opt(v) - } - return v -} - -func (v *RTCConnection) Initialize(ctx context.Context, listener *net.UDPConn) *RTCConnection { - if v.ctx == nil { - v.ctx = logger.WithContext(ctx) - } - if listener != nil { - v.listenerUDP = listener - } - return v -} - -func (v *RTCConnection) HandlePacket(addr *net.UDPAddr, data []byte) error { - ctx := v.ctx - - // Update the current UDP address. - v.clientUDP = addr - - // Start the UDP proxy to backend. - if err := v.connectBackend(ctx); err != nil { - return errors.Wrapf(err, "connect backend for %v", v.StreamURL) - } - - // Proxy client message to backend. - if v.backendUDP == nil { - return nil - } - - // Proxy all messages from backend to client. - go func() { - for ctx.Err() == nil { - buf := make([]byte, 4096) - n, _, err := v.backendUDP.ReadFromUDP(buf) - if err != nil { - // TODO: If backend server closed unexpectedly, we should notice the stream to quit. - logger.Wf(ctx, "read from backend failed, err=%v", err) - break - } - - if _, err = v.listenerUDP.WriteToUDP(buf[:n], v.clientUDP); err != nil { - // TODO: If backend server closed unexpectedly, we should notice the stream to quit. - logger.Wf(ctx, "write to client failed, err=%v", err) - break - } - } - }() - - if _, err := v.backendUDP.Write(data); err != nil { - return errors.Wrapf(err, "write to backend %v", v.StreamURL) - } - - return nil -} - -func (v *RTCConnection) connectBackend(ctx context.Context) error { - if v.backendUDP != nil { - return nil - } - - // Pick a backend SRS server to proxy the RTC stream. - backend, err := srsLoadBalancer.Pick(ctx, v.StreamURL) - if err != nil { - return errors.Wrapf(err, "pick backend") - } - - // Parse UDP port from backend. - if len(backend.RTC) == 0 { - return errors.Errorf("no udp server") - } - - _, _, udpPort, err := parseListenEndpoint(backend.RTC[0]) - if err != nil { - return errors.Wrapf(err, "parse udp port %v of %v for %v", backend.RTC[0], backend, v.StreamURL) - } - - // Connect to backend SRS server via UDP client. - // TODO: FIXME: Support close the connection when timeout or DTLS alert. - backendAddr := net.UDPAddr{IP: net.ParseIP(backend.IP), Port: int(udpPort)} - if backendUDP, err := net.DialUDP("udp", nil, &backendAddr); err != nil { - return errors.Wrapf(err, "dial udp to %v", backendAddr) - } else { - v.backendUDP = backendUDP - } - - return nil -} - -type RTCICEPair struct { - // The remote ufrag, used for ICE username and session id. - RemoteICEUfrag string `json:"remote_ufrag"` - // The remote pwd, used for ICE password. - RemoteICEPwd string `json:"remote_pwd"` - // The local ufrag, used for ICE username and session id. - LocalICEUfrag string `json:"local_ufrag"` - // The local pwd, used for ICE password. - LocalICEPwd string `json:"local_pwd"` -} - -// Generate the ICE ufrag for the WebRTC streaming, format is remote-ufrag:local-ufrag. -func (v *RTCICEPair) Ufrag() string { - return fmt.Sprintf("%v:%v", v.LocalICEUfrag, v.RemoteICEUfrag) -} - -type RTCStunPacket struct { - // The stun message type. - MessageType uint16 - // The stun username, or ufrag. - Username string -} - -func (v *RTCStunPacket) UnmarshalBinary(data []byte) error { - if len(data) < 20 { - return errors.Errorf("stun packet too short %v", len(data)) - } - - p := data - v.MessageType = binary.BigEndian.Uint16(p) - messageLen := binary.BigEndian.Uint16(p[2:]) - //magicCookie := p[:8] - //transactionID := p[:20] - p = p[20:] - - if len(p) != int(messageLen) { - return errors.Errorf("stun packet length invalid %v != %v", len(data), messageLen) - } - - for len(p) > 0 { - typ := binary.BigEndian.Uint16(p) - length := binary.BigEndian.Uint16(p[2:]) - p = p[4:] - - if len(p) < int(length) { - return errors.Errorf("stun attribute length invalid %v < %v", len(p), length) - } - - value := p[:length] - p = p[length:] - - if length%4 != 0 { - p = p[4-length%4:] - } - - switch typ { - case 0x0006: - v.Username = string(value) - } - } - - return nil -} diff --git a/proxy/rtmp.go b/proxy/rtmp.go deleted file mode 100644 index 97413a199..000000000 --- a/proxy/rtmp.go +++ /dev/null @@ -1,655 +0,0 @@ -// Copyright (c) 2025 Winlin -// -// SPDX-License-Identifier: MIT -package main - -import ( - "context" - "fmt" - "math/rand" - "net" - "strconv" - "strings" - "sync" - "time" - - "srs-proxy/errors" - "srs-proxy/logger" - "srs-proxy/rtmp" -) - -// srsRTMPServer is the proxy for SRS RTMP server, to proxy the RTMP stream to backend SRS -// server. It will figure out the backend server to proxy to. Unlike the edge server, it will -// not cache the stream, but just proxy the stream to backend. -type srsRTMPServer struct { - // The TCP listener for RTMP server. - listener *net.TCPListener - // The random number generator. - rd *rand.Rand - // The wait group for all goroutines. - wg sync.WaitGroup -} - -func NewSRSRTMPServer(opts ...func(*srsRTMPServer)) *srsRTMPServer { - v := &srsRTMPServer{ - rd: rand.New(rand.NewSource(time.Now().UnixNano())), - } - for _, opt := range opts { - opt(v) - } - return v -} - -func (v *srsRTMPServer) Close() error { - if v.listener != nil { - v.listener.Close() - } - - v.wg.Wait() - return nil -} - -func (v *srsRTMPServer) Run(ctx context.Context) error { - endpoint := envRtmpServer() - if !strings.Contains(endpoint, ":") { - endpoint = ":" + endpoint - } - - addr, err := net.ResolveTCPAddr("tcp", endpoint) - if err != nil { - return errors.Wrapf(err, "resolve rtmp addr %v", endpoint) - } - - listener, err := net.ListenTCP("tcp", addr) - if err != nil { - return errors.Wrapf(err, "listen rtmp addr %v", addr) - } - v.listener = listener - logger.Df(ctx, "RTMP server listen at %v", addr) - - v.wg.Add(1) - go func() { - defer v.wg.Done() - - for { - conn, err := v.listener.AcceptTCP() - if err != nil { - if ctx.Err() != context.Canceled { - // TODO: If RTMP server closed unexpectedly, we should notice the main loop to quit. - logger.Wf(ctx, "RTMP server accept err %+v", err) - } else { - logger.Df(ctx, "RTMP server done") - } - return - } - - v.wg.Add(1) - go func(ctx context.Context, conn *net.TCPConn) { - defer v.wg.Done() - defer conn.Close() - - handleErr := func(err error) { - if isPeerClosedError(err) { - logger.Df(ctx, "RTMP peer is closed") - } else { - logger.Wf(ctx, "RTMP serve err %+v", err) - } - } - - rc := NewRTMPConnection(func(client *RTMPConnection) { - client.rd = v.rd - }) - if err := rc.serve(ctx, conn); err != nil { - handleErr(err) - } else { - logger.Df(ctx, "RTMP client done") - } - }(logger.WithContext(ctx), conn) - } - }() - - return nil -} - -// RTMPConnection is an RTMP streaming connection. There is no state need to be sync between -// proxy servers. -// -// When we got an RTMP request, we will parse the stream URL from the RTMP publish or play request, -// then proxy to the corresponding backend server. All state is in the RTMP request, so this -// connection is stateless. -type RTMPConnection struct { - // The random number generator. - rd *rand.Rand -} - -func NewRTMPConnection(opts ...func(*RTMPConnection)) *RTMPConnection { - v := &RTMPConnection{} - for _, opt := range opts { - opt(v) - } - return v -} - -func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error { - logger.Df(ctx, "Got RTMP client from %v", conn.RemoteAddr()) - - // If any goroutine quit, cancel another one. - parentCtx := ctx - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - var backend *RTMPClientToBackend - if true { - go func() { - <-ctx.Done() - conn.Close() - if backend != nil { - backend.Close() - } - }() - } - - // Simple handshake with client. - hs := rtmp.NewHandshake(v.rd) - if _, err := hs.ReadC0S0(conn); err != nil { - return errors.Wrapf(err, "read c0") - } - if _, err := hs.ReadC1S1(conn); err != nil { - return errors.Wrapf(err, "read c1") - } - if err := hs.WriteC0S0(conn); err != nil { - return errors.Wrapf(err, "write s1") - } - if err := hs.WriteC1S1(conn); err != nil { - return errors.Wrapf(err, "write s1") - } - if err := hs.WriteC2S2(conn, hs.C1S1()); err != nil { - return errors.Wrapf(err, "write s2") - } - if _, err := hs.ReadC2S2(conn); err != nil { - return errors.Wrapf(err, "read c2") - } - - client := rtmp.NewProtocol(conn) - logger.Df(ctx, "RTMP simple handshake done") - - // Expect RTMP connect command with tcUrl. - var connectReq *rtmp.ConnectAppPacket - if _, err := rtmp.ExpectPacket(ctx, client, &connectReq); err != nil { - return errors.Wrapf(err, "expect connect req") - } - - if true { - ack := rtmp.NewWindowAcknowledgementSize() - ack.AckSize = 2500000 - if err := client.WritePacket(ctx, ack, 0); err != nil { - return errors.Wrapf(err, "write set ack size") - } - } - if true { - chunk := rtmp.NewSetChunkSize() - chunk.ChunkSize = 128 - if err := client.WritePacket(ctx, chunk, 0); err != nil { - return errors.Wrapf(err, "write set chunk size") - } - } - - connectRes := rtmp.NewConnectAppResPacket(connectReq.TransactionID) - connectRes.CommandObject.Set("fmsVer", rtmp.NewAmf0String("FMS/3,5,3,888")) - connectRes.CommandObject.Set("capabilities", rtmp.NewAmf0Number(127)) - connectRes.CommandObject.Set("mode", rtmp.NewAmf0Number(1)) - connectRes.Args.Set("level", rtmp.NewAmf0String("status")) - connectRes.Args.Set("code", rtmp.NewAmf0String("NetConnection.Connect.Success")) - connectRes.Args.Set("description", rtmp.NewAmf0String("Connection succeeded")) - connectRes.Args.Set("objectEncoding", rtmp.NewAmf0Number(0)) - connectResData := rtmp.NewAmf0EcmaArray() - connectResData.Set("version", rtmp.NewAmf0String("3,5,3,888")) - connectResData.Set("srs_version", rtmp.NewAmf0String(Version())) - connectResData.Set("srs_id", rtmp.NewAmf0String(logger.ContextID(ctx))) - connectRes.Args.Set("data", connectResData) - if err := client.WritePacket(ctx, connectRes, 0); err != nil { - return errors.Wrapf(err, "write connect res") - } - - tcUrl := connectReq.TcUrl() - logger.Df(ctx, "RTMP connect app %v", tcUrl) - - // Expect RTMP command to identify the client, a publisher or viewer. - var currentStreamID, nextStreamID int - var streamName string - var clientType RTMPClientType - for clientType == "" { - var identifyReq rtmp.Packet - if _, err := rtmp.ExpectPacket(ctx, client, &identifyReq); err != nil { - return errors.Wrapf(err, "expect identify req") - } - - var response rtmp.Packet - switch pkt := identifyReq.(type) { - case *rtmp.CallPacket: - if pkt.CommandName == "createStream" { - identifyRes := rtmp.NewCreateStreamResPacket(pkt.TransactionID) - response = identifyRes - - nextStreamID = 1 - identifyRes.StreamID = *rtmp.NewAmf0Number(float64(nextStreamID)) - } else if pkt.CommandName == "getStreamLength" { - // Ignore and do not reply these packets. - } else { - // For releaseStream, FCPublish, etc. - identifyRes := rtmp.NewCallPacket() - response = identifyRes - - identifyRes.TransactionID = pkt.TransactionID - identifyRes.CommandName = "_result" - identifyRes.CommandObject = rtmp.NewAmf0Null() - identifyRes.Args = rtmp.NewAmf0Undefined() - } - case *rtmp.PublishPacket: - streamName = string(pkt.StreamName) - clientType = RTMPClientTypePublisher - - identifyRes := rtmp.NewCallPacket() - response = identifyRes - - identifyRes.CommandName = "onFCPublish" - identifyRes.CommandObject = rtmp.NewAmf0Null() - - data := rtmp.NewAmf0Object() - data.Set("code", rtmp.NewAmf0String("NetStream.Publish.Start")) - data.Set("description", rtmp.NewAmf0String("Started publishing stream.")) - identifyRes.Args = data - case *rtmp.PlayPacket: - streamName = string(pkt.StreamName) - clientType = RTMPClientTypeViewer - - identifyRes := rtmp.NewCallPacket() - response = identifyRes - - identifyRes.CommandName = "onStatus" - identifyRes.CommandObject = rtmp.NewAmf0Null() - - data := rtmp.NewAmf0Object() - data.Set("level", rtmp.NewAmf0String("status")) - data.Set("code", rtmp.NewAmf0String("NetStream.Play.Reset")) - data.Set("description", rtmp.NewAmf0String("Playing and resetting stream.")) - data.Set("details", rtmp.NewAmf0String("stream")) - data.Set("clientid", rtmp.NewAmf0String("ASAICiss")) - identifyRes.Args = data - } - - if response != nil { - if err := client.WritePacket(ctx, response, currentStreamID); err != nil { - return errors.Wrapf(err, "write identify res for req=%v, stream=%v", - identifyReq, currentStreamID) - } - } - - // Update the stream ID for next request. - currentStreamID = nextStreamID - } - logger.Df(ctx, "RTMP identify tcUrl=%v, stream=%v, id=%v, type=%v", - tcUrl, streamName, currentStreamID, clientType) - - // Find a backend SRS server to proxy the RTMP stream. - backend = NewRTMPClientToBackend(func(client *RTMPClientToBackend) { - client.rd, client.typ = v.rd, clientType - }) - defer backend.Close() - - if err := backend.Connect(ctx, tcUrl, streamName); err != nil { - return errors.Wrapf(err, "connect backend, tcUrl=%v, stream=%v", tcUrl, streamName) - } - - // Start the streaming. - if clientType == RTMPClientTypePublisher { - identifyRes := rtmp.NewCallPacket() - - identifyRes.CommandName = "onStatus" - identifyRes.CommandObject = rtmp.NewAmf0Null() - - data := rtmp.NewAmf0Object() - data.Set("level", rtmp.NewAmf0String("status")) - data.Set("code", rtmp.NewAmf0String("NetStream.Publish.Start")) - data.Set("description", rtmp.NewAmf0String("Started publishing stream.")) - data.Set("clientid", rtmp.NewAmf0String("ASAICiss")) - identifyRes.Args = data - - if err := client.WritePacket(ctx, identifyRes, currentStreamID); err != nil { - return errors.Wrapf(err, "start publish") - } - } else if clientType == RTMPClientTypeViewer { - identifyRes := rtmp.NewCallPacket() - - identifyRes.CommandName = "onStatus" - identifyRes.CommandObject = rtmp.NewAmf0Null() - - data := rtmp.NewAmf0Object() - data.Set("level", rtmp.NewAmf0String("status")) - data.Set("code", rtmp.NewAmf0String("NetStream.Play.Start")) - data.Set("description", rtmp.NewAmf0String("Started playing stream.")) - data.Set("details", rtmp.NewAmf0String("stream")) - data.Set("clientid", rtmp.NewAmf0String("ASAICiss")) - identifyRes.Args = data - - if err := client.WritePacket(ctx, identifyRes, currentStreamID); err != nil { - return errors.Wrapf(err, "start play") - } - } - logger.Df(ctx, "RTMP start streaming") - - // For all proxy goroutines. - var wg sync.WaitGroup - defer wg.Wait() - - // Proxy all message from backend to client. - wg.Add(1) - var r0 error - go func() { - defer wg.Done() - defer cancel() - - r0 = func() error { - for { - m, err := backend.client.ReadMessage(ctx) - if err != nil { - return errors.Wrapf(err, "read message") - } - //logger.Df(ctx, "client<- %v %v %vB", m.MessageType, m.Timestamp, len(m.Payload)) - - // TODO: Update the stream ID if not the same. - if err := client.WriteMessage(ctx, m); err != nil { - return errors.Wrapf(err, "write message") - } - } - }() - }() - - // Proxy all messages from client to backend. - wg.Add(1) - var r1 error - go func() { - defer wg.Done() - defer cancel() - - r1 = func() error { - for { - m, err := client.ReadMessage(ctx) - if err != nil { - return errors.Wrapf(err, "read message") - } - //logger.Df(ctx, "client-> %v %v %vB", m.MessageType, m.Timestamp, len(m.Payload)) - - // TODO: Update the stream ID if not the same. - if err := backend.client.WriteMessage(ctx, m); err != nil { - return errors.Wrapf(err, "write message") - } - } - }() - }() - - // Wait until all goroutine quit. - wg.Wait() - - // Reset the error if caused by another goroutine. - if r0 != nil { - return errors.Wrapf(r0, "proxy backend->client") - } - if r1 != nil { - return errors.Wrapf(r1, "proxy client->backend") - } - - return parentCtx.Err() -} - -type RTMPClientType string - -const ( - RTMPClientTypePublisher RTMPClientType = "publisher" - RTMPClientTypeViewer RTMPClientType = "viewer" -) - -// RTMPClientToBackend is a RTMP client to proxy the RTMP stream to backend. -type RTMPClientToBackend struct { - // The random number generator. - rd *rand.Rand - // The underlayer tcp client. - tcpConn *net.TCPConn - // The RTMP protocol client. - client *rtmp.Protocol - // The stream type. - typ RTMPClientType -} - -func NewRTMPClientToBackend(opts ...func(*RTMPClientToBackend)) *RTMPClientToBackend { - v := &RTMPClientToBackend{} - for _, opt := range opts { - opt(v) - } - return v -} - -func (v *RTMPClientToBackend) Close() error { - if v.tcpConn != nil { - v.tcpConn.Close() - } - return nil -} - -func (v *RTMPClientToBackend) Connect(ctx context.Context, tcUrl, streamName string) error { - // Build the stream URL in vhost/app/stream schema. - streamURL, err := buildStreamURL(fmt.Sprintf("%v/%v", tcUrl, streamName)) - if err != nil { - return errors.Wrapf(err, "build stream url %v/%v", tcUrl, streamName) - } - - // Pick a backend SRS server to proxy the RTMP stream. - backend, err := srsLoadBalancer.Pick(ctx, streamURL) - if err != nil { - return errors.Wrapf(err, "pick backend for %v", streamURL) - } - - // Parse RTMP port from backend. - if len(backend.RTMP) == 0 { - return errors.Errorf("no rtmp server %+v for %v", backend, streamURL) - } - - var rtmpPort int - if iv, err := strconv.ParseInt(backend.RTMP[0], 10, 64); err != nil { - return errors.Wrapf(err, "parse backend %+v rtmp port %v", backend, backend.RTMP[0]) - } else { - rtmpPort = int(iv) - } - - // Connect to backend SRS server via TCP client. - addr := &net.TCPAddr{IP: net.ParseIP(backend.IP), Port: rtmpPort} - c, err := net.DialTCP("tcp", nil, addr) - if err != nil { - return errors.Wrapf(err, "dial backend addr=%v, srs=%v", addr, backend) - } - v.tcpConn = c - - hs := rtmp.NewHandshake(v.rd) - client := rtmp.NewProtocol(c) - v.client = client - - // Simple RTMP handshake with server. - if err := hs.WriteC0S0(c); err != nil { - return errors.Wrapf(err, "write c0") - } - if err := hs.WriteC1S1(c); err != nil { - return errors.Wrapf(err, "write c1") - } - - if _, err = hs.ReadC0S0(c); err != nil { - return errors.Wrapf(err, "read s0") - } - if _, err := hs.ReadC1S1(c); err != nil { - return errors.Wrapf(err, "read s1") - } - if _, err = hs.ReadC2S2(c); err != nil { - return errors.Wrapf(err, "read c2") - } - logger.Df(ctx, "backend simple handshake done, server=%v", addr) - - if err := hs.WriteC2S2(c, hs.C1S1()); err != nil { - return errors.Wrapf(err, "write c2") - } - - // Connect RTMP app on tcUrl with server. - if true { - connectApp := rtmp.NewConnectAppPacket() - connectApp.CommandObject.Set("tcUrl", rtmp.NewAmf0String(tcUrl)) - if err := client.WritePacket(ctx, connectApp, 1); err != nil { - return errors.Wrapf(err, "write connect app") - } - } - - if true { - var connectAppRes *rtmp.ConnectAppResPacket - if _, err := rtmp.ExpectPacket(ctx, client, &connectAppRes); err != nil { - return errors.Wrapf(err, "expect connect app res") - } - logger.Df(ctx, "backend connect RTMP app, tcUrl=%v, id=%v", tcUrl, connectAppRes.SrsID()) - } - - // Play or view RTMP stream with server. - if v.typ == RTMPClientTypeViewer { - return v.play(ctx, client, streamName) - } - - // Publish RTMP stream with server. - return v.publish(ctx, client, streamName) -} - -func (v *RTMPClientToBackend) publish(ctx context.Context, client *rtmp.Protocol, streamName string) error { - if true { - identifyReq := rtmp.NewCallPacket() - identifyReq.CommandName = "releaseStream" - identifyReq.TransactionID = 2 - identifyReq.CommandObject = rtmp.NewAmf0Null() - identifyReq.Args = rtmp.NewAmf0String(streamName) - if err := client.WritePacket(ctx, identifyReq, 0); err != nil { - return errors.Wrapf(err, "releaseStream") - } - } - for { - var identifyRes *rtmp.CallPacket - if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil { - return errors.Wrapf(err, "expect releaseStream res") - } - if identifyRes.CommandName == "_result" { - break - } - } - - if true { - identifyReq := rtmp.NewCallPacket() - identifyReq.CommandName = "FCPublish" - identifyReq.TransactionID = 3 - identifyReq.CommandObject = rtmp.NewAmf0Null() - identifyReq.Args = rtmp.NewAmf0String(streamName) - if err := client.WritePacket(ctx, identifyReq, 0); err != nil { - return errors.Wrapf(err, "FCPublish") - } - } - for { - var identifyRes *rtmp.CallPacket - if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil { - return errors.Wrapf(err, "expect FCPublish res") - } - if identifyRes.CommandName == "_result" { - break - } - } - - var currentStreamID int - if true { - createStream := rtmp.NewCreateStreamPacket() - createStream.TransactionID = 4 - createStream.CommandObject = rtmp.NewAmf0Null() - if err := client.WritePacket(ctx, createStream, 0); err != nil { - return errors.Wrapf(err, "createStream") - } - } - for { - var identifyRes *rtmp.CreateStreamResPacket - if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil { - return errors.Wrapf(err, "expect createStream res") - } - if sid := identifyRes.StreamID; sid != 0 { - currentStreamID = int(sid) - break - } - } - - if true { - publishStream := rtmp.NewPublishPacket() - publishStream.TransactionID = 5 - publishStream.CommandObject = rtmp.NewAmf0Null() - publishStream.StreamName = *rtmp.NewAmf0String(streamName) - publishStream.StreamType = *rtmp.NewAmf0String("live") - if err := client.WritePacket(ctx, publishStream, currentStreamID); err != nil { - return errors.Wrapf(err, "publish") - } - } - for { - var identifyRes *rtmp.CallPacket - if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil { - return errors.Wrapf(err, "expect publish res") - } - // Ignore onFCPublish, expect onStatus(NetStream.Publish.Start). - if identifyRes.CommandName == "onStatus" { - if data := rtmp.NewAmf0Converter(identifyRes.Args).ToObject(); data == nil { - return errors.Errorf("onStatus args not object") - } else if code := rtmp.NewAmf0Converter(data.Get("code")).ToString(); code == nil { - return errors.Errorf("onStatus code not string") - } else if *code != "NetStream.Publish.Start" { - return errors.Errorf("onStatus code=%v not NetStream.Publish.Start", *code) - } - break - } - } - logger.Df(ctx, "backend publish stream=%v, sid=%v", streamName, currentStreamID) - - return nil -} - -func (v *RTMPClientToBackend) play(ctx context.Context, client *rtmp.Protocol, streamName string) error { - var currentStreamID int - if true { - createStream := rtmp.NewCreateStreamPacket() - createStream.TransactionID = 4 - createStream.CommandObject = rtmp.NewAmf0Null() - if err := client.WritePacket(ctx, createStream, 0); err != nil { - return errors.Wrapf(err, "createStream") - } - } - for { - var identifyRes *rtmp.CreateStreamResPacket - if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil { - return errors.Wrapf(err, "expect createStream res") - } - if sid := identifyRes.StreamID; sid != 0 { - currentStreamID = int(sid) - break - } - } - - playStream := rtmp.NewPlayPacket() - playStream.StreamName = *rtmp.NewAmf0String(streamName) - if err := client.WritePacket(ctx, playStream, currentStreamID); err != nil { - return errors.Wrapf(err, "play") - } - - for { - var identifyRes *rtmp.CallPacket - if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil { - return errors.Wrapf(err, "expect releaseStream res") - } - if identifyRes.CommandName == "onStatus" && identifyRes.ArgsCode() == "NetStream.Play.Start" { - break - } - } - return nil -} diff --git a/proxy/rtmp/amf0.go b/proxy/rtmp/amf0.go deleted file mode 100644 index 129b6617f..000000000 --- a/proxy/rtmp/amf0.go +++ /dev/null @@ -1,771 +0,0 @@ -// Copyright (c) 2025 Winlin -// -// SPDX-License-Identifier: MIT -package rtmp - -import ( - "bytes" - "encoding" - "encoding/binary" - "fmt" - "math" - "sync" - - "srs-proxy/errors" -) - -// Please read @doc amf0_spec_121207.pdf, @page 4, @section 2.1 Types Overview -type amf0Marker uint8 - -const ( - amf0MarkerNumber amf0Marker = iota // 0 - amf0MarkerBoolean // 1 - amf0MarkerString // 2 - amf0MarkerObject // 3 - amf0MarkerMovieClip // 4 - amf0MarkerNull // 5 - amf0MarkerUndefined // 6 - amf0MarkerReference // 7 - amf0MarkerEcmaArray // 8 - amf0MarkerObjectEnd // 9 - amf0MarkerStrictArray // 10 - amf0MarkerDate // 11 - amf0MarkerLongString // 12 - amf0MarkerUnsupported // 13 - amf0MarkerRecordSet // 14 - amf0MarkerXmlDocument // 15 - amf0MarkerTypedObject // 16 - amf0MarkerAvmPlusObject // 17 - - amf0MarkerForbidden amf0Marker = 0xff -) - -func (v amf0Marker) String() string { - switch v { - case amf0MarkerNumber: - return "Amf0Number" - case amf0MarkerBoolean: - return "amf0Boolean" - case amf0MarkerString: - return "Amf0String" - case amf0MarkerObject: - return "Amf0Object" - case amf0MarkerNull: - return "Null" - case amf0MarkerUndefined: - return "Undefined" - case amf0MarkerReference: - return "Reference" - case amf0MarkerEcmaArray: - return "EcmaArray" - case amf0MarkerObjectEnd: - return "ObjectEnd" - case amf0MarkerStrictArray: - return "StrictArray" - case amf0MarkerDate: - return "Date" - case amf0MarkerLongString: - return "LongString" - case amf0MarkerUnsupported: - return "Unsupported" - case amf0MarkerXmlDocument: - return "XmlDocument" - case amf0MarkerTypedObject: - return "TypedObject" - case amf0MarkerAvmPlusObject: - return "AvmPlusObject" - case amf0MarkerMovieClip: - return "MovieClip" - case amf0MarkerRecordSet: - return "RecordSet" - default: - return "Forbidden" - } -} - -// For utest to mock it. -type amf0Buffer interface { - Bytes() []byte - WriteByte(c byte) error - Write(p []byte) (n int, err error) -} - -var createBuffer = func() amf0Buffer { - return &bytes.Buffer{} -} - -// All AMF0 things. -type amf0Any interface { - // Binary marshaler and unmarshaler. - encoding.BinaryUnmarshaler - encoding.BinaryMarshaler - // Get the size of bytes to marshal this object. - Size() int - - // Get the Marker of any AMF0 stuff. - amf0Marker() amf0Marker -} - -type amf0Converter struct { - from amf0Any -} - -func NewAmf0Converter(from amf0Any) *amf0Converter { - return &amf0Converter{from: from} -} - -func (v *amf0Converter) ToNumber() *amf0Number { - return amf0AnyTo[*amf0Number](v.from) -} - -func (v *amf0Converter) ToBoolean() *amf0Boolean { - return amf0AnyTo[*amf0Boolean](v.from) -} - -func (v *amf0Converter) ToString() *amf0String { - return amf0AnyTo[*amf0String](v.from) -} - -func (v *amf0Converter) ToObject() *amf0Object { - return amf0AnyTo[*amf0Object](v.from) -} - -func (v *amf0Converter) ToNull() *amf0Null { - return amf0AnyTo[*amf0Null](v.from) -} - -func (v *amf0Converter) ToUndefined() *amf0Undefined { - return amf0AnyTo[*amf0Undefined](v.from) -} - -func (v *amf0Converter) ToEcmaArray() *amf0EcmaArray { - return amf0AnyTo[*amf0EcmaArray](v.from) -} - -func (v *amf0Converter) ToStrictArray() *amf0StrictArray { - return amf0AnyTo[*amf0StrictArray](v.from) -} - -// Convert any to specified object. -func amf0AnyTo[T amf0Any](a amf0Any) T { - var to T - if a != nil { - if v, ok := a.(T); ok { - return v - } - } - return to -} - -// Discovery the amf0 object from the bytes b. -func Amf0Discovery(p []byte) (a amf0Any, err error) { - if len(p) < 1 { - return nil, errors.Errorf("require 1 bytes only %v", len(p)) - } - m := amf0Marker(p[0]) - - switch m { - case amf0MarkerNumber: - return NewAmf0Number(0), nil - case amf0MarkerBoolean: - return NewAmf0Boolean(false), nil - case amf0MarkerString: - return NewAmf0String(""), nil - case amf0MarkerObject: - return NewAmf0Object(), nil - case amf0MarkerNull: - return NewAmf0Null(), nil - case amf0MarkerUndefined: - return NewAmf0Undefined(), nil - case amf0MarkerReference: - case amf0MarkerEcmaArray: - return NewAmf0EcmaArray(), nil - case amf0MarkerObjectEnd: - return &amf0ObjectEOF{}, nil - case amf0MarkerStrictArray: - return NewAmf0StrictArray(), nil - case amf0MarkerDate, amf0MarkerLongString, amf0MarkerUnsupported, amf0MarkerXmlDocument, - amf0MarkerTypedObject, amf0MarkerAvmPlusObject, amf0MarkerForbidden, amf0MarkerMovieClip, - amf0MarkerRecordSet: - return nil, errors.Errorf("Marker %v is not supported", m) - } - return nil, errors.Errorf("Marker %v is invalid", m) -} - -// The UTF8 string, please read @doc amf0_spec_121207.pdf, @page 3, @section 1.3.1 Strings and UTF-8 -type amf0UTF8 string - -func (v *amf0UTF8) Size() int { - return 2 + len(string(*v)) -} - -func (v *amf0UTF8) UnmarshalBinary(data []byte) (err error) { - var p []byte - if p = data; len(p) < 2 { - return errors.Errorf("require 2 bytes only %v", len(p)) - } - size := uint16(p[0])<<8 | uint16(p[1]) - - if p = data[2:]; len(p) < int(size) { - return errors.Errorf("require %v bytes only %v", int(size), len(p)) - } - *v = amf0UTF8(string(p[:size])) - - return -} - -func (v *amf0UTF8) MarshalBinary() (data []byte, err error) { - data = make([]byte, v.Size()) - - size := uint16(len(string(*v))) - data[0] = byte(size >> 8) - data[1] = byte(size) - - if size > 0 { - copy(data[2:], []byte(*v)) - } - - return -} - -// The number object, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.2 Number Type -type amf0Number float64 - -func NewAmf0Number(f float64) *amf0Number { - v := amf0Number(f) - return &v -} - -func (v *amf0Number) amf0Marker() amf0Marker { - return amf0MarkerNumber -} - -func (v *amf0Number) Size() int { - return 1 + 8 -} - -func (v *amf0Number) UnmarshalBinary(data []byte) (err error) { - var p []byte - if p = data; len(p) < 9 { - return errors.Errorf("require 9 bytes only %v", len(p)) - } - if m := amf0Marker(p[0]); m != amf0MarkerNumber { - return errors.Errorf("Amf0Number amf0Marker %v is illegal", m) - } - - f := binary.BigEndian.Uint64(p[1:]) - *v = amf0Number(math.Float64frombits(f)) - return -} - -func (v *amf0Number) MarshalBinary() (data []byte, err error) { - data = make([]byte, 9) - data[0] = byte(amf0MarkerNumber) - f := math.Float64bits(float64(*v)) - binary.BigEndian.PutUint64(data[1:], f) - return -} - -// The string objet, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.4 String Type -type amf0String string - -func NewAmf0String(s string) *amf0String { - v := amf0String(s) - return &v -} - -func (v *amf0String) amf0Marker() amf0Marker { - return amf0MarkerString -} - -func (v *amf0String) Size() int { - u := amf0UTF8(*v) - return 1 + u.Size() -} - -func (v *amf0String) UnmarshalBinary(data []byte) (err error) { - var p []byte - if p = data; len(p) < 1 { - return errors.Errorf("require 1 bytes only %v", len(p)) - } - if m := amf0Marker(p[0]); m != amf0MarkerString { - return errors.Errorf("Amf0String amf0Marker %v is illegal", m) - } - - var sv amf0UTF8 - if err = sv.UnmarshalBinary(p[1:]); err != nil { - return errors.WithMessage(err, "utf8") - } - *v = amf0String(string(sv)) - return -} - -func (v *amf0String) MarshalBinary() (data []byte, err error) { - u := amf0UTF8(*v) - - var pb []byte - if pb, err = u.MarshalBinary(); err != nil { - return nil, errors.WithMessage(err, "utf8") - } - - data = append([]byte{byte(amf0MarkerString)}, pb...) - return -} - -// The AMF0 object end type, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.11 Object End Type -type amf0ObjectEOF struct { -} - -func (v *amf0ObjectEOF) amf0Marker() amf0Marker { - return amf0MarkerObjectEnd -} - -func (v *amf0ObjectEOF) Size() int { - return 3 -} - -func (v *amf0ObjectEOF) UnmarshalBinary(data []byte) (err error) { - p := data - - if len(p) < 3 { - return errors.Errorf("require 3 bytes only %v", len(p)) - } - - if p[0] != 0 || p[1] != 0 || p[2] != 9 { - return errors.Errorf("EOF amf0Marker %v is illegal", p[0:3]) - } - return -} - -func (v *amf0ObjectEOF) MarshalBinary() (data []byte, err error) { - return []byte{0, 0, 9}, nil -} - -// Use array for object and ecma array, to keep the original order. -type amf0Property struct { - key amf0UTF8 - value amf0Any -} - -// The object-like AMF0 structure, like object and ecma array and strict array. -type amf0ObjectBase struct { - properties []*amf0Property - lock sync.Mutex -} - -func (v *amf0ObjectBase) Size() int { - v.lock.Lock() - defer v.lock.Unlock() - - var size int - - for _, p := range v.properties { - key, value := p.key, p.value - size += key.Size() + value.Size() - } - - return size -} - -func (v *amf0ObjectBase) Get(key string) amf0Any { - v.lock.Lock() - defer v.lock.Unlock() - - for _, p := range v.properties { - if string(p.key) == key { - return p.value - } - } - - return nil -} - -func (v *amf0ObjectBase) Set(key string, value amf0Any) *amf0ObjectBase { - v.lock.Lock() - defer v.lock.Unlock() - - prop := &amf0Property{key: amf0UTF8(key), value: value} - - var ok bool - for i, p := range v.properties { - if string(p.key) == key { - v.properties[i] = prop - ok = true - } - } - - if !ok { - v.properties = append(v.properties, prop) - } - - return v -} - -func (v *amf0ObjectBase) unmarshal(p []byte, eof bool, maxElems int) (err error) { - // if no eof, elems specified by maxElems. - if !eof && maxElems < 0 { - return errors.Errorf("maxElems=%v without eof", maxElems) - } - // if eof, maxElems must be -1. - if eof && maxElems != -1 { - return errors.Errorf("maxElems=%v with eof", maxElems) - } - - readOne := func() (amf0UTF8, amf0Any, error) { - var u amf0UTF8 - if err = u.UnmarshalBinary(p); err != nil { - return "", nil, errors.WithMessage(err, "prop name") - } - - p = p[u.Size():] - var a amf0Any - if a, err = Amf0Discovery(p); err != nil { - return "", nil, errors.WithMessage(err, fmt.Sprintf("discover prop %v", string(u))) - } - return u, a, nil - } - - pushOne := func(u amf0UTF8, a amf0Any) error { - // For object property, consume the whole bytes. - if err = a.UnmarshalBinary(p); err != nil { - return errors.WithMessage(err, fmt.Sprintf("unmarshal prop %v", string(u))) - } - - v.Set(string(u), a) - p = p[a.Size():] - return nil - } - - for eof { - u, a, err := readOne() - if err != nil { - return errors.WithMessage(err, "read") - } - - // For object EOF, we should only consume total 3bytes. - if u.Size() == 2 && a.amf0Marker() == amf0MarkerObjectEnd { - // 2 bytes is consumed by u(name), the a(eof) should only consume 1 byte. - p = p[1:] - return nil - } - - if err := pushOne(u, a); err != nil { - return errors.WithMessage(err, "push") - } - } - - for len(v.properties) < maxElems { - u, a, err := readOne() - if err != nil { - return errors.WithMessage(err, "read") - } - - if err := pushOne(u, a); err != nil { - return errors.WithMessage(err, "push") - } - } - - return -} - -func (v *amf0ObjectBase) marshal(b amf0Buffer) (err error) { - v.lock.Lock() - defer v.lock.Unlock() - - var pb []byte - for _, p := range v.properties { - key, value := p.key, p.value - - if pb, err = key.MarshalBinary(); err != nil { - return errors.WithMessage(err, fmt.Sprintf("marshal %v", string(key))) - } - if _, err = b.Write(pb); err != nil { - return errors.Wrapf(err, "write %v", string(key)) - } - - if pb, err = value.MarshalBinary(); err != nil { - return errors.WithMessage(err, fmt.Sprintf("marshal value for %v", string(key))) - } - if _, err = b.Write(pb); err != nil { - return errors.Wrapf(err, "marshal value for %v", string(key)) - } - } - - return -} - -// The AMF0 object, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.5 Object Type -type amf0Object struct { - amf0ObjectBase - eof amf0ObjectEOF -} - -func NewAmf0Object() *amf0Object { - v := &amf0Object{} - v.properties = []*amf0Property{} - return v -} - -func (v *amf0Object) amf0Marker() amf0Marker { - return amf0MarkerObject -} - -func (v *amf0Object) Size() int { - return int(1) + v.eof.Size() + v.amf0ObjectBase.Size() -} - -func (v *amf0Object) UnmarshalBinary(data []byte) (err error) { - var p []byte - if p = data; len(p) < 1 { - return errors.Errorf("require 1 byte only %v", len(p)) - } - if m := amf0Marker(p[0]); m != amf0MarkerObject { - return errors.Errorf("Amf0Object amf0Marker %v is illegal", m) - } - p = p[1:] - - if err = v.unmarshal(p, true, -1); err != nil { - return errors.WithMessage(err, "unmarshal") - } - - return -} - -func (v *amf0Object) MarshalBinary() (data []byte, err error) { - b := createBuffer() - - if err = b.WriteByte(byte(amf0MarkerObject)); err != nil { - return nil, errors.Wrap(err, "marshal") - } - - if err = v.marshal(b); err != nil { - return nil, errors.WithMessage(err, "marshal") - } - - var pb []byte - if pb, err = v.eof.MarshalBinary(); err != nil { - return nil, errors.WithMessage(err, "marshal") - } - if _, err = b.Write(pb); err != nil { - return nil, errors.Wrap(err, "marshal") - } - - return b.Bytes(), nil -} - -// The AMF0 ecma array, please read @doc amf0_spec_121207.pdf, @page 6, @section 2.10 ECMA Array Type -type amf0EcmaArray struct { - amf0ObjectBase - count uint32 - eof amf0ObjectEOF -} - -func NewAmf0EcmaArray() *amf0EcmaArray { - v := &amf0EcmaArray{} - v.properties = []*amf0Property{} - return v -} - -func (v *amf0EcmaArray) amf0Marker() amf0Marker { - return amf0MarkerEcmaArray -} - -func (v *amf0EcmaArray) Size() int { - return int(1) + 4 + v.eof.Size() + v.amf0ObjectBase.Size() -} - -func (v *amf0EcmaArray) UnmarshalBinary(data []byte) (err error) { - var p []byte - if p = data; len(p) < 5 { - return errors.Errorf("require 5 bytes only %v", len(p)) - } - if m := amf0Marker(p[0]); m != amf0MarkerEcmaArray { - return errors.Errorf("EcmaArray amf0Marker %v is illegal", m) - } - v.count = binary.BigEndian.Uint32(p[1:]) - p = p[5:] - - if err = v.unmarshal(p, true, -1); err != nil { - return errors.WithMessage(err, "unmarshal") - } - return -} - -func (v *amf0EcmaArray) MarshalBinary() (data []byte, err error) { - b := createBuffer() - - if err = b.WriteByte(byte(amf0MarkerEcmaArray)); err != nil { - return nil, errors.Wrap(err, "marshal") - } - - if err = binary.Write(b, binary.BigEndian, v.count); err != nil { - return nil, errors.Wrap(err, "marshal") - } - - if err = v.marshal(b); err != nil { - return nil, errors.WithMessage(err, "marshal") - } - - var pb []byte - if pb, err = v.eof.MarshalBinary(); err != nil { - return nil, errors.WithMessage(err, "marshal") - } - if _, err = b.Write(pb); err != nil { - return nil, errors.Wrap(err, "marshal") - } - - return b.Bytes(), nil -} - -// The AMF0 strict array, please read @doc amf0_spec_121207.pdf, @page 7, @section 2.12 Strict Array Type -type amf0StrictArray struct { - amf0ObjectBase - count uint32 -} - -func NewAmf0StrictArray() *amf0StrictArray { - v := &amf0StrictArray{} - v.properties = []*amf0Property{} - return v -} - -func (v *amf0StrictArray) amf0Marker() amf0Marker { - return amf0MarkerStrictArray -} - -func (v *amf0StrictArray) Size() int { - return int(1) + 4 + v.amf0ObjectBase.Size() -} - -func (v *amf0StrictArray) UnmarshalBinary(data []byte) (err error) { - var p []byte - if p = data; len(p) < 5 { - return errors.Errorf("require 5 bytes only %v", len(p)) - } - if m := amf0Marker(p[0]); m != amf0MarkerStrictArray { - return errors.Errorf("StrictArray amf0Marker %v is illegal", m) - } - v.count = binary.BigEndian.Uint32(p[1:]) - p = p[5:] - - if int(v.count) <= 0 { - return - } - - if err = v.unmarshal(p, false, int(v.count)); err != nil { - return errors.WithMessage(err, "unmarshal") - } - return -} - -func (v *amf0StrictArray) MarshalBinary() (data []byte, err error) { - b := createBuffer() - - if err = b.WriteByte(byte(amf0MarkerStrictArray)); err != nil { - return nil, errors.Wrap(err, "marshal") - } - - if err = binary.Write(b, binary.BigEndian, v.count); err != nil { - return nil, errors.Wrap(err, "marshal") - } - - if err = v.marshal(b); err != nil { - return nil, errors.WithMessage(err, "marshal") - } - - return b.Bytes(), nil -} - -// The single amf0Marker object, for all AMF0 which only has the amf0Marker, like null and undefined. -type amf0SingleMarkerObject struct { - target amf0Marker -} - -func newAmf0SingleMarkerObject(m amf0Marker) amf0SingleMarkerObject { - return amf0SingleMarkerObject{target: m} -} - -func (v *amf0SingleMarkerObject) amf0Marker() amf0Marker { - return v.target -} - -func (v *amf0SingleMarkerObject) Size() int { - return int(1) -} - -func (v *amf0SingleMarkerObject) UnmarshalBinary(data []byte) (err error) { - var p []byte - if p = data; len(p) < 1 { - return errors.Errorf("require 1 byte only %v", len(p)) - } - if m := amf0Marker(p[0]); m != v.target { - return errors.Errorf("%v amf0Marker %v is illegal", v.target, m) - } - return -} - -func (v *amf0SingleMarkerObject) MarshalBinary() (data []byte, err error) { - return []byte{byte(v.target)}, nil -} - -// The AMF0 null, please read @doc amf0_spec_121207.pdf, @page 6, @section 2.7 null Type -type amf0Null struct { - amf0SingleMarkerObject -} - -func NewAmf0Null() *amf0Null { - v := amf0Null{} - v.amf0SingleMarkerObject = newAmf0SingleMarkerObject(amf0MarkerNull) - return &v -} - -// The AMF0 undefined, please read @doc amf0_spec_121207.pdf, @page 6, @section 2.8 undefined Type -type amf0Undefined struct { - amf0SingleMarkerObject -} - -func NewAmf0Undefined() amf0Any { - v := amf0Undefined{} - v.amf0SingleMarkerObject = newAmf0SingleMarkerObject(amf0MarkerUndefined) - return &v -} - -// The AMF0 boolean, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.3 Boolean Type -type amf0Boolean bool - -func NewAmf0Boolean(b bool) amf0Any { - v := amf0Boolean(b) - return &v -} - -func (v *amf0Boolean) amf0Marker() amf0Marker { - return amf0MarkerBoolean -} - -func (v *amf0Boolean) Size() int { - return int(2) -} - -func (v *amf0Boolean) UnmarshalBinary(data []byte) (err error) { - var p []byte - if p = data; len(p) < 2 { - return errors.Errorf("require 2 bytes only %v", len(p)) - } - if m := amf0Marker(p[0]); m != amf0MarkerBoolean { - return errors.Errorf("BOOL amf0Marker %v is illegal", m) - } - if p[1] == 0 { - *v = false - } else { - *v = true - } - return -} - -func (v *amf0Boolean) MarshalBinary() (data []byte, err error) { - var b byte - if *v { - b = 1 - } - return []byte{byte(amf0MarkerBoolean), b}, nil -} diff --git a/proxy/rtmp/rtmp.go b/proxy/rtmp/rtmp.go deleted file mode 100644 index ff01172ac..000000000 --- a/proxy/rtmp/rtmp.go +++ /dev/null @@ -1,1792 +0,0 @@ -// Copyright (c) 2025 Winlin -// -// SPDX-License-Identifier: MIT -package rtmp - -import ( - "bufio" - "bytes" - "context" - "encoding" - "encoding/binary" - "fmt" - "io" - "math/rand" - "sync" - - "srs-proxy/errors" -) - -// The handshake implements the RTMP handshake protocol. -type Handshake struct { - // The random number generator. - r *rand.Rand - // The c1s1 cache. - c1s1 []byte -} - -func NewHandshake(r *rand.Rand) *Handshake { - return &Handshake{r: r} -} - -func (v *Handshake) C1S1() []byte { - return v.c1s1 -} - -func (v *Handshake) WriteC0S0(w io.Writer) (err error) { - r := bytes.NewReader([]byte{0x03}) - if _, err = io.Copy(w, r); err != nil { - return errors.Wrap(err, "write c0s0") - } - - return -} - -func (v *Handshake) ReadC0S0(r io.Reader) (c0 []byte, err error) { - b := &bytes.Buffer{} - if _, err = io.CopyN(b, r, 1); err != nil { - return nil, errors.Wrap(err, "read c0s0") - } - - c0 = b.Bytes() - - return -} - -func (v *Handshake) WriteC1S1(w io.Writer) (err error) { - p := make([]byte, 1536) - - for i := 8; i < len(p); i++ { - p[i] = byte(v.r.Int()) - } - - r := bytes.NewReader(p) - if _, err = io.Copy(w, r); err != nil { - return errors.Wrap(err, "write c0s1") - } - - return -} - -func (v *Handshake) ReadC1S1(r io.Reader) (c1s1 []byte, err error) { - b := &bytes.Buffer{} - if _, err = io.CopyN(b, r, 1536); err != nil { - return nil, errors.Wrap(err, "read c1s1") - } - - c1s1 = b.Bytes() - v.c1s1 = c1s1 - - return -} - -func (v *Handshake) WriteC2S2(w io.Writer, s1c1 []byte) (err error) { - r := bytes.NewReader(s1c1[:]) - if _, err = io.Copy(w, r); err != nil { - return errors.Wrap(err, "write c2s2") - } - - return -} - -func (v *Handshake) ReadC2S2(r io.Reader) (c2 []byte, err error) { - b := &bytes.Buffer{} - if _, err = io.CopyN(b, r, 1536); err != nil { - return nil, errors.Wrap(err, "read c2s2") - } - - c2 = b.Bytes() - - return -} - -// Please read @doc rtmp_specification_1.0.pdf, @page 16, @section 6.1. Chunk Format -// Extended timestamp: 0 or 4 bytes -// This field MUST be sent when the normal timsestamp is set to -// 0xffffff, it MUST NOT be sent if the normal timestamp is set to -// anything else. So for values less than 0xffffff the normal -// timestamp field SHOULD be used in which case the extended timestamp -// MUST NOT be present. For values greater than or equal to 0xffffff -// the normal timestamp field MUST NOT be used and MUST be set to -// 0xffffff and the extended timestamp MUST be sent. -const extendedTimestamp = uint64(0xffffff) - -// The default chunk size of RTMP is 128 bytes. -const defaultChunkSize = 128 - -// The intput or output settings for RTMP protocol. -type settings struct { - chunkSize uint32 -} - -func newSettings() *settings { - return &settings{ - chunkSize: defaultChunkSize, - } -} - -// The chunk stream which transport a message once. -type chunkStream struct { - format formatType - cid chunkID - header messageHeader - message *Message - count uint64 - extendedTimestamp bool -} - -func newChunkStream() *chunkStream { - return &chunkStream{} -} - -// The protocol implements the RTMP command and chunk stack. -type Protocol struct { - r *bufio.Reader - w *bufio.Writer - input struct { - opt *settings - chunks map[chunkID]*chunkStream - - transactions map[amf0Number]amf0String - ltransactions sync.Mutex - } - output struct { - opt *settings - } -} - -func NewProtocol(rw io.ReadWriter) *Protocol { - v := &Protocol{ - r: bufio.NewReader(rw), - w: bufio.NewWriter(rw), - } - - v.input.opt = newSettings() - v.input.chunks = map[chunkID]*chunkStream{} - v.input.transactions = map[amf0Number]amf0String{} - - v.output.opt = newSettings() - - return v -} - -func ExpectPacket[T Packet](ctx context.Context, v *Protocol, ppkt *T) (m *Message, err error) { - for { - if m, err = v.ReadMessage(ctx); err != nil { - return nil, errors.WithMessage(err, "read message") - } - - var pkt Packet - if pkt, err = v.DecodeMessage(m); err != nil { - return nil, errors.WithMessage(err, "decode message") - } - - if p, ok := pkt.(T); ok { - *ppkt = p - break - } - } - - return -} - -// Deprecated: Please use rtmp.ExpectPacket instead. -func (v *Protocol) ExpectPacket(ctx context.Context, ppkt any) (m *Message, err error) { - panic("Please use rtmp.ExpectPacket instead") -} - -func (v *Protocol) ExpectMessage(ctx context.Context, types ...MessageType) (m *Message, err error) { - for { - if m, err = v.ReadMessage(ctx); err != nil { - return nil, errors.WithMessage(err, "read message") - } - - if len(types) == 0 { - return - } - - for _, t := range types { - if m.MessageType == t { - return - } - } - } - - return -} - -func (v *Protocol) parseAMFObject(p []byte) (pkt Packet, err error) { - var commandName amf0String - if err = commandName.UnmarshalBinary(p); err != nil { - return nil, errors.WithMessage(err, "unmarshal command name") - } - - switch commandName { - case commandResult, commandError: - var transactionID amf0Number - if err = transactionID.UnmarshalBinary(p[commandName.Size():]); err != nil { - return nil, errors.WithMessage(err, "unmarshal tid") - } - - var requestName amf0String - if err = func() error { - v.input.ltransactions.Lock() - defer v.input.ltransactions.Unlock() - - var ok bool - if requestName, ok = v.input.transactions[transactionID]; !ok { - return errors.Errorf("No matched request for tid=%v", transactionID) - } - delete(v.input.transactions, transactionID) - - return nil - }(); err != nil { - return nil, errors.WithMessage(err, "discovery request name") - } - - switch requestName { - case commandConnect: - return NewConnectAppResPacket(transactionID), nil - case commandCreateStream: - return NewCreateStreamResPacket(transactionID), nil - case commandReleaseStream, commandFCPublish, commandFCUnpublish: - call := NewCallPacket() - call.TransactionID = transactionID - return call, nil - default: - return nil, errors.Errorf("No request for %v", string(requestName)) - } - case commandConnect: - return NewConnectAppPacket(), nil - case commandPublish: - return NewPublishPacket(), nil - case commandPlay: - return NewPlayPacket(), nil - default: - return NewCallPacket(), nil - } -} - -func (v *Protocol) DecodeMessage(m *Message) (pkt Packet, err error) { - p := m.Payload[:] - if len(p) == 0 { - return nil, errors.New("Empty packet") - } - - switch m.MessageType { - case MessageTypeAMF3Command, MessageTypeAMF3Data: - p = p[1:] - } - - switch m.MessageType { - case MessageTypeSetChunkSize: - pkt = NewSetChunkSize() - case MessageTypeWindowAcknowledgementSize: - pkt = NewWindowAcknowledgementSize() - case MessageTypeSetPeerBandwidth: - pkt = NewSetPeerBandwidth() - case MessageTypeAMF0Command, MessageTypeAMF3Command, MessageTypeAMF0Data, MessageTypeAMF3Data: - if pkt, err = v.parseAMFObject(p); err != nil { - return nil, errors.WithMessage(err, fmt.Sprintf("Parse AMF %v", m.MessageType)) - } - case MessageTypeUserControl: - pkt = NewUserControl() - default: - return nil, errors.Errorf("Unknown message %v", m.MessageType) - } - - if err = pkt.UnmarshalBinary(p); err != nil { - return nil, errors.WithMessage(err, fmt.Sprintf("Unmarshal %v", m.MessageType)) - } - - return -} - -func (v *Protocol) ReadMessage(ctx context.Context) (m *Message, err error) { - for m == nil { - // TODO: We should convert buffered io to async io, because we will be stuck in block io here, - // TODO: but the risk is acceptable because we literally will set the underlay io timeout. - if ctx.Err() != nil { - return nil, ctx.Err() - } - - var cid chunkID - var format formatType - if format, cid, err = v.readBasicHeader(ctx); err != nil { - return nil, errors.WithMessage(err, "read basic header") - } - - var ok bool - var chunk *chunkStream - if chunk, ok = v.input.chunks[cid]; !ok { - chunk = newChunkStream() - v.input.chunks[cid] = chunk - chunk.header.betterCid = cid - } - - if err = v.readMessageHeader(ctx, chunk, format); err != nil { - return nil, errors.WithMessage(err, "read message header") - } - - if m, err = v.readMessagePayload(ctx, chunk); err != nil { - return nil, errors.WithMessage(err, "read message payload") - } - - if err = v.onMessageArrivated(m); err != nil { - return nil, errors.WithMessage(err, "on message") - } - } - - return -} - -func (v *Protocol) readMessagePayload(ctx context.Context, chunk *chunkStream) (m *Message, err error) { - // Empty payload message. - if chunk.message.payloadLength == 0 { - m = chunk.message - chunk.message = nil - return - } - - // Calculate the chunk payload size. - chunkedPayloadSize := int(chunk.message.payloadLength) - len(chunk.message.Payload) - if chunkedPayloadSize > int(v.input.opt.chunkSize) { - chunkedPayloadSize = int(v.input.opt.chunkSize) - } - - b := make([]byte, chunkedPayloadSize) - if _, err = io.ReadFull(v.r, b); err != nil { - return nil, errors.Wrapf(err, "read chunk %vB", chunkedPayloadSize) - } - chunk.message.Payload = append(chunk.message.Payload, b...) - - // Got entire RTMP message? - if int(chunk.message.payloadLength) == len(chunk.message.Payload) { - m = chunk.message - chunk.message = nil - } - - return -} - -// Please read @doc rtmp_specification_1.0.pdf, @page 18, @section 6.1.2. Chunk Message Header -// There are four different formats for the chunk message header, -// selected by the "fmt" field in the chunk basic header. -type formatType uint8 - -const ( - // 6.1.2.1. Type 0 - // Chunks of Type 0 are 11 bytes long. This type MUST be used at the - // start of a chunk stream, and whenever the stream timestamp goes - // backward (e.g., because of a backward seek). - formatType0 formatType = iota - // 6.1.2.2. Type 1 - // Chunks of Type 1 are 7 bytes long. The message stream ID is not - // included; this chunk takes the same stream ID as the preceding chunk. - // Streams with variable-sized messages (for example, many video - // formats) SHOULD use this format for the first chunk of each new - // message after the first. - formatType1 - // 6.1.2.3. Type 2 - // Chunks of Type 2 are 3 bytes long. Neither the stream ID nor the - // message length is included; this chunk has the same stream ID and - // message length as the preceding chunk. Streams with constant-sized - // messages (for example, some audio and data formats) SHOULD use this - // format for the first chunk of each message after the first. - formatType2 - // 6.1.2.4. Type 3 - // Chunks of Type 3 have no header. Stream ID, message length and - // timestamp delta are not present; chunks of this type take values from - // the preceding chunk. When a single message is split into chunks, all - // chunks of a message except the first one, SHOULD use this type. Refer - // to example 2 in section 6.2.2. Stream consisting of messages of - // exactly the same size, stream ID and spacing in time SHOULD use this - // type for all chunks after chunk of Type 2. Refer to example 1 in - // section 6.2.1. If the delta between the first message and the second - // message is same as the time stamp of first message, then chunk of - // type 3 would immediately follow the chunk of type 0 as there is no - // need for a chunk of type 2 to register the delta. If Type 3 chunk - // follows a Type 0 chunk, then timestamp delta for this Type 3 chunk is - // the same as the timestamp of Type 0 chunk. - formatType3 -) - -// The message header size, index is format. -var messageHeaderSizes = []int{11, 7, 3, 0} - -// Parse the chunk message header. -// 3bytes: timestamp delta, fmt=0,1,2 -// 3bytes: payload length, fmt=0,1 -// 1bytes: message type, fmt=0,1 -// 4bytes: stream id, fmt=0 -// where: -// fmt=0, 0x0X -// fmt=1, 0x4X -// fmt=2, 0x8X -// fmt=3, 0xCX -func (v *Protocol) readMessageHeader(ctx context.Context, chunk *chunkStream, format formatType) (err error) { - // We should not assert anything about fmt, for the first packet. - // (when first packet, the chunk.message is nil). - // the fmt maybe 0/1/2/3, the FMLE will send a 0xC4 for some audio packet. - // the previous packet is: - // 04 // fmt=0, cid=4 - // 00 00 1a // timestamp=26 - // 00 00 9d // payload_length=157 - // 08 // message_type=8(audio) - // 01 00 00 00 // stream_id=1 - // the current packet maybe: - // c4 // fmt=3, cid=4 - // it's ok, for the packet is audio, and timestamp delta is 26. - // the current packet must be parsed as: - // fmt=0, cid=4 - // timestamp=26+26=52 - // payload_length=157 - // message_type=8(audio) - // stream_id=1 - // so we must update the timestamp even fmt=3 for first packet. - // - // The fresh packet used to update the timestamp even fmt=3 for first packet. - // fresh packet always means the chunk is the first one of message. - var isFirstChunkOfMsg bool - if chunk.message == nil { - isFirstChunkOfMsg = true - } - - // But, we can ensure that when a chunk stream is fresh, - // the fmt must be 0, a new stream. - if chunk.count == 0 && format != formatType0 { - // For librtmp, if ping, it will send a fresh stream with fmt=1, - // 0x42 where: fmt=1, cid=2, protocol contorl user-control message - // 0x00 0x00 0x00 where: timestamp=0 - // 0x00 0x00 0x06 where: payload_length=6 - // 0x04 where: message_type=4(protocol control user-control message) - // 0x00 0x06 where: event Ping(0x06) - // 0x00 0x00 0x0d 0x0f where: event data 4bytes ping timestamp. - // @see: https://github.com/ossrs/srs/issues/98 - if chunk.cid == chunkIDProtocolControl && format == formatType1 { - // We accept cid=2, fmt=1 to make librtmp happy. - } else { - return errors.Errorf("For fresh chunk, fmt %v != %v(required), cid is %v", format, formatType0, chunk.cid) - } - } - - // When exists cache msg, means got an partial message, - // the fmt must not be type0 which means new message. - if chunk.message != nil && format == formatType0 { - return errors.Errorf("For exists chunk, fmt is %v, cid is %v", format, chunk.cid) - } - - // Create msg when new chunk stream start - if chunk.message == nil { - chunk.message = NewMessage() - } - - // Read the message header. - p := make([]byte, messageHeaderSizes[format]) - if _, err = io.ReadFull(v.r, p); err != nil { - return errors.Wrapf(err, "read %vB message header", len(p)) - } - - // Prse the message header. - // 3bytes: timestamp delta, fmt=0,1,2 - // 3bytes: payload length, fmt=0,1 - // 1bytes: message type, fmt=0,1 - // 4bytes: stream id, fmt=0 - // where: - // fmt=0, 0x0X - // fmt=1, 0x4X - // fmt=2, 0x8X - // fmt=3, 0xCX - if format <= formatType2 { - chunk.header.timestampDelta = uint32(p[0])<<16 | uint32(p[1])<<8 | uint32(p[2]) - p = p[3:] - - // fmt: 0 - // timestamp: 3 bytes - // If the timestamp is greater than or equal to 16777215 - // (hexadecimal 0x00ffffff), this value MUST be 16777215, and the - // 'extended timestamp header' MUST be present. Otherwise, this value - // SHOULD be the entire timestamp. - // - // fmt: 1 or 2 - // timestamp delta: 3 bytes - // If the delta is greater than or equal to 16777215 (hexadecimal - // 0x00ffffff), this value MUST be 16777215, and the 'extended - // timestamp header' MUST be present. Otherwise, this value SHOULD be - // the entire delta. - chunk.extendedTimestamp = uint64(chunk.header.timestampDelta) >= extendedTimestamp - if !chunk.extendedTimestamp { - // Extended timestamp: 0 or 4 bytes - // This field MUST be sent when the normal timsestamp is set to - // 0xffffff, it MUST NOT be sent if the normal timestamp is set to - // anything else. So for values less than 0xffffff the normal - // timestamp field SHOULD be used in which case the extended timestamp - // MUST NOT be present. For values greater than or equal to 0xffffff - // the normal timestamp field MUST NOT be used and MUST be set to - // 0xffffff and the extended timestamp MUST be sent. - if format == formatType0 { - // 6.1.2.1. Type 0 - // For a type-0 chunk, the absolute timestamp of the message is sent - // here. - chunk.header.Timestamp = uint64(chunk.header.timestampDelta) - } else { - // 6.1.2.2. Type 1 - // 6.1.2.3. Type 2 - // For a type-1 or type-2 chunk, the difference between the previous - // chunk's timestamp and the current chunk's timestamp is sent here. - chunk.header.Timestamp += uint64(chunk.header.timestampDelta) - } - } - - if format <= formatType1 { - payloadLength := uint32(p[0])<<16 | uint32(p[1])<<8 | uint32(p[2]) - p = p[3:] - - // For a message, if msg exists in cache, the size must not changed. - // always use the actual msg size to compare, for the cache payload length can changed, - // for the fmt type1(stream_id not changed), user can change the payload - // length(it's not allowed in the continue chunks). - if !isFirstChunkOfMsg && chunk.header.payloadLength != payloadLength { - return errors.Errorf("Chunk message size %v != %v(required)", payloadLength, chunk.header.payloadLength) - } - chunk.header.payloadLength = payloadLength - - chunk.header.MessageType = MessageType(p[0]) - p = p[1:] - - if format == formatType0 { - chunk.header.streamID = uint32(p[0]) | uint32(p[1])<<8 | uint32(p[2])<<16 | uint32(p[3])<<24 - p = p[4:] - } - } - } else { - // Update the timestamp even fmt=3 for first chunk packet - if isFirstChunkOfMsg && !chunk.extendedTimestamp { - chunk.header.Timestamp += uint64(chunk.header.timestampDelta) - } - } - - // Read extended-timestamp - if chunk.extendedTimestamp { - var timestamp uint32 - if err = binary.Read(v.r, binary.BigEndian, ×tamp); err != nil { - return errors.Wrapf(err, "read ext-ts, pkt-ts=%v", chunk.header.Timestamp) - } - - // We always use 31bits timestamp, for some server may use 32bits extended timestamp. - // @see https://github.com/ossrs/srs/issues/111 - timestamp &= 0x7fffffff - - // TODO: FIXME: Support detect the extended timestamp. - // @see http://blog.csdn.net/win_lin/article/details/13363699 - chunk.header.Timestamp = uint64(timestamp) - } - - // The extended-timestamp must be unsigned-int, - // 24bits timestamp: 0xffffff = 16777215ms = 16777.215s = 4.66h - // 32bits timestamp: 0xffffffff = 4294967295ms = 4294967.295s = 1193.046h = 49.71d - // because the rtmp protocol says the 32bits timestamp is about "50 days": - // 3. Byte Order, Alignment, and Time Format - // Because timestamps are generally only 32 bits long, they will roll - // over after fewer than 50 days. - // - // but, its sample says the timestamp is 31bits: - // An application could assume, for example, that all - // adjacent timestamps are within 2^31 milliseconds of each other, so - // 10000 comes after 4000000000, while 3000000000 comes before - // 4000000000. - // and flv specification says timestamp is 31bits: - // Extension of the Timestamp field to form a SI32 value. This - // field represents the upper 8 bits, while the previous - // Timestamp field represents the lower 24 bits of the time in - // milliseconds. - // in a word, 31bits timestamp is ok. - // convert extended timestamp to 31bits. - chunk.header.Timestamp &= 0x7fffffff - - // Copy header to msg - chunk.message.messageHeader = chunk.header - - // Increase the msg count, the chunk stream can accept fmt=1/2/3 message now. - chunk.count++ - - return -} - -// Please read @doc rtmp_specification_1.0.pdf, @page 17, @section 6.1.1. Chunk Basic Header -// The Chunk Basic Header encodes the chunk stream ID and the chunk -// type(represented by fmt field in the figure below). Chunk type -// determines the format of the encoded message header. Chunk Basic -// Header field may be 1, 2, or 3 bytes, depending on the chunk stream -// ID. -// -// The bits 0-5 (least significant) in the chunk basic header represent -// the chunk stream ID. -// -// Chunk stream IDs 2-63 can be encoded in the 1-byte version of this -// field. -// 0 1 2 3 4 5 6 7 -// +-+-+-+-+-+-+-+-+ -// |fmt| cs id | -// +-+-+-+-+-+-+-+-+ -// Figure 6 Chunk basic header 1 -// -// Chunk stream IDs 64-319 can be encoded in the 2-byte version of this -// field. ID is computed as (the second byte + 64). -// 0 1 -// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// |fmt| 0 | cs id - 64 | -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// Figure 7 Chunk basic header 2 -// -// Chunk stream IDs 64-65599 can be encoded in the 3-byte version of -// this field. ID is computed as ((the third byte)*256 + the second byte -// + 64). -// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// |fmt| 1 | cs id - 64 | -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// Figure 8 Chunk basic header 3 -// -// cs id: 6 bits -// fmt: 2 bits -// cs id - 64: 8 or 16 bits -// -// Chunk stream IDs with values 64-319 could be represented by both 2- -// byte version and 3-byte version of this field. -func (v *Protocol) readBasicHeader(ctx context.Context) (format formatType, cid chunkID, err error) { - // 2-63, 1B chunk header - var t uint8 - if err = binary.Read(v.r, binary.BigEndian, &t); err != nil { - return format, cid, errors.Wrap(err, "read basic header") - } - cid = chunkID(t & 0x3f) - format = formatType((t >> 6) & 0x03) - - if cid > 1 { - return - } - - // 64-319, 2B chunk header - if err = binary.Read(v.r, binary.BigEndian, &t); err != nil { - return format, cid, errors.Wrapf(err, "read basic header for cid=%v", cid) - } - cid = chunkID(64 + uint32(t)) - - // 64-65599, 3B chunk header - if cid == 1 { - if err = binary.Read(v.r, binary.BigEndian, &t); err != nil { - return format, cid, errors.Wrapf(err, "read basic header for cid=%v", cid) - } - cid += chunkID(uint32(t) * 256) - } - - return -} - -func (v *Protocol) WritePacket(ctx context.Context, pkt Packet, streamID int) (err error) { - m := NewMessage() - - if m.Payload, err = pkt.MarshalBinary(); err != nil { - return errors.WithMessage(err, "marshal payload") - } - - m.MessageType = pkt.Type() - m.streamID = uint32(streamID) - m.betterCid = pkt.BetterCid() - - if err = v.WriteMessage(ctx, m); err != nil { - return errors.WithMessage(err, "write message") - } - - if err = v.onPacketWriten(m, pkt); err != nil { - return errors.WithMessage(err, "on write packet") - } - - return -} - -func (v *Protocol) onPacketWriten(m *Message, pkt Packet) (err error) { - var tid amf0Number - var name amf0String - - switch pkt := pkt.(type) { - case *ConnectAppPacket: - tid, name = pkt.TransactionID, pkt.CommandName - case *CreateStreamPacket: - tid, name = pkt.TransactionID, pkt.CommandName - case *CallPacket: - tid, name = pkt.TransactionID, pkt.CommandName - } - - if tid > 0 && len(name) > 0 { - v.input.ltransactions.Lock() - defer v.input.ltransactions.Unlock() - - v.input.transactions[tid] = name - } - - return -} - -func (v *Protocol) onMessageArrivated(m *Message) (err error) { - if m == nil { - return - } - - var pkt Packet - switch m.MessageType { - case MessageTypeSetChunkSize, MessageTypeUserControl, MessageTypeWindowAcknowledgementSize: - if pkt, err = v.DecodeMessage(m); err != nil { - return errors.Errorf("decode message %v", m.MessageType) - } - } - - switch pkt := pkt.(type) { - case *SetChunkSize: - v.input.opt.chunkSize = pkt.ChunkSize - } - - return -} - -func (v *Protocol) WriteMessage(ctx context.Context, m *Message) (err error) { - m.payloadLength = uint32(len(m.Payload)) - - var c0h, c3h []byte - if c0h, err = m.generateC0Header(); err != nil { - return errors.WithMessage(err, "generate c0 header") - } - if c3h, err = m.generateC3Header(); err != nil { - return errors.WithMessage(err, "generate c3 header") - } - - var h []byte - p := m.Payload - for len(p) > 0 { - // TODO: We should convert buffered io to async io, because we will be stuck in block io here, - // TODO: but the risk is acceptable because we literally will set the underlay io timeout. - if ctx.Err() != nil { - return ctx.Err() - } - - if h == nil { - h = c0h - } else { - h = c3h - } - - if _, err = io.Copy(v.w, bytes.NewReader(h)); err != nil { - return errors.Wrapf(err, "write c0c3 header %x", h) - } - - size := len(p) - if size > int(v.output.opt.chunkSize) { - size = int(v.output.opt.chunkSize) - } - - if _, err = io.Copy(v.w, bytes.NewReader(p[:size])); err != nil { - return errors.Wrapf(err, "write chunk payload %vB", size) - } - p = p[size:] - } - - // TODO: We should convert buffered io to async io, because we will be stuck in block io here, - // TODO: but the risk is acceptable because we literally will set the underlay io timeout. - if ctx.Err() != nil { - return ctx.Err() - } - - // TODO: FIXME: Use writev to write for high performance. - if err = v.w.Flush(); err != nil { - return errors.Wrapf(err, "flush writer") - } - - return -} - -// Please read @doc rtmp_specification_1.0.pdf, @page 30, @section 4.1. Message Header -// 1byte. One byte field to represent the message type. A range of type IDs -// (1-7) are reserved for protocol control messages. -type MessageType uint8 - -const ( - // Please read @doc rtmp_specification_1.0.pdf, @page 30, @section 5. Protocol Control Messages - // RTMP reserves message type IDs 1-7 for protocol control messages. - // These messages contain information needed by the RTM Chunk Stream - // protocol or RTMP itself. Protocol messages with IDs 1 & 2 are - // reserved for usage with RTM Chunk Stream protocol. Protocol messages - // with IDs 3-6 are reserved for usage of RTMP. Protocol message with ID - // 7 is used between edge server and origin server. - MessageTypeSetChunkSize MessageType = 0x01 - MessageTypeAbort MessageType = 0x02 // 0x02 - MessageTypeAcknowledgement MessageType = 0x03 // 0x03 - MessageTypeUserControl MessageType = 0x04 // 0x04 - MessageTypeWindowAcknowledgementSize MessageType = 0x05 // 0x05 - MessageTypeSetPeerBandwidth MessageType = 0x06 // 0x06 - MessageTypeEdgeAndOriginServerCommand MessageType = 0x07 // 0x07 - // Please read @doc rtmp_specification_1.0.pdf, @page 38, @section 3. Types of messages - // The server and the client send messages over the network to - // communicate with each other. The messages can be of any type which - // includes audio messages, video messages, command messages, shared - // object messages, data messages, and user control messages. - // - // Please read @doc rtmp_specification_1.0.pdf, @page 41, @section 3.4. Audio message - // The client or the server sends this message to send audio data to the - // peer. The message type value of 8 is reserved for audio messages. - MessageTypeAudio MessageType = 0x08 - // Please read @doc rtmp_specification_1.0.pdf, @page 41, @section 3.5. Video message - // The client or the server sends this message to send video data to the - // peer. The message type value of 9 is reserved for video messages. - // These messages are large and can delay the sending of other type of - // messages. To avoid such a situation, the video message is assigned - // the lowest priority. - MessageTypeVideo MessageType = 0x09 // 0x09 - // Please read @doc rtmp_specification_1.0.pdf, @page 38, @section 3.1. Command message - // Command messages carry the AMF-encoded commands between the client - // and the server. These messages have been assigned message type value - // of 20 for AMF0 encoding and message type value of 17 for AMF3 - // encoding. These messages are sent to perform some operations like - // connect, createStream, publish, play, pause on the peer. Command - // messages like onstatus, result etc. are used to inform the sender - // about the status of the requested commands. A command message - // consists of command name, transaction ID, and command object that - // contains related parameters. A client or a server can request Remote - // Procedure Calls (RPC) over streams that are communicated using the - // command messages to the peer. - MessageTypeAMF3Command MessageType = 17 // 0x11 - MessageTypeAMF0Command MessageType = 20 // 0x14 - // Please read @doc rtmp_specification_1.0.pdf, @page 38, @section 3.2. Data message - // The client or the server sends this message to send Metadata or any - // user data to the peer. Metadata includes details about the - // data(audio, video etc.) like creation time, duration, theme and so - // on. These messages have been assigned message type value of 18 for - // AMF0 and message type value of 15 for AMF3. - MessageTypeAMF0Data MessageType = 18 // 0x12 - MessageTypeAMF3Data MessageType = 15 // 0x0f -) - -// The header of message. -type messageHeader struct { - // 3bytes. - // Three-byte field that contains a timestamp delta of the message. - // @remark, only used for decoding message from chunk stream. - timestampDelta uint32 - // 3bytes. - // Three-byte field that represents the size of the payload in bytes. - // It is set in big-endian format. - payloadLength uint32 - // 1byte. - // One byte field to represent the message type. A range of type IDs - // (1-7) are reserved for protocol control messages. - MessageType MessageType - // 4bytes. - // Four-byte field that identifies the stream of the message. These - // bytes are set in little-endian format. - streamID uint32 - - // The chunk stream id over which transport. - betterCid chunkID - - // Four-byte field that contains a timestamp of the message. - // The 4 bytes are packed in the big-endian order. - // @remark, we use 64bits for large time for jitter detect and for large tbn like HLS. - Timestamp uint64 -} - -// The RTMP message, transport over chunk stream in RTMP. -// Please read the cs id of @doc rtmp_specification_1.0.pdf, @page 30, @section 4.1. Message Header -type Message struct { - messageHeader - - // The payload which carries the RTMP packet. - Payload []byte -} - -func NewMessage() *Message { - return &Message{} -} - -func NewStreamMessage(streamID int) *Message { - v := NewMessage() - v.streamID = uint32(streamID) - v.betterCid = chunkIDOverStream - return v -} - -func (v *Message) generateC3Header() ([]byte, error) { - var c3h []byte - if v.Timestamp < extendedTimestamp { - c3h = make([]byte, 1) - } else { - c3h = make([]byte, 1+4) - } - - p := c3h - p[0] = 0xc0 | byte(v.betterCid&0x3f) - p = p[1:] - - // In RTMP protocol, there must not any timestamp in C3 header, - // but actually all products from adobe, such as FMS/AMS and Flash player and FMLE, - // always carry a extended timestamp in C3 header. - // @see: http://blog.csdn.net/win_lin/article/details/13363699 - if v.Timestamp >= extendedTimestamp { - p[0] = byte(v.Timestamp >> 24) - p[1] = byte(v.Timestamp >> 16) - p[2] = byte(v.Timestamp >> 8) - p[3] = byte(v.Timestamp) - } - - return c3h, nil -} - -func (v *Message) generateC0Header() ([]byte, error) { - var c0h []byte - if v.Timestamp < extendedTimestamp { - c0h = make([]byte, 1+3+3+1+4) - } else { - c0h = make([]byte, 1+3+3+1+4+4) - } - - p := c0h - p[0] = byte(v.betterCid) & 0x3f - p = p[1:] - - if v.Timestamp < extendedTimestamp { - p[0] = byte(v.Timestamp >> 16) - p[1] = byte(v.Timestamp >> 8) - p[2] = byte(v.Timestamp) - } else { - p[0] = 0xff - p[1] = 0xff - p[2] = 0xff - } - p = p[3:] - - p[0] = byte(v.payloadLength >> 16) - p[1] = byte(v.payloadLength >> 8) - p[2] = byte(v.payloadLength) - p = p[3:] - - p[0] = byte(v.MessageType) - p = p[1:] - - p[0] = byte(v.streamID) - p[1] = byte(v.streamID >> 8) - p[2] = byte(v.streamID >> 16) - p[3] = byte(v.streamID >> 24) - p = p[4:] - - if v.Timestamp >= extendedTimestamp { - p[0] = byte(v.Timestamp >> 24) - p[1] = byte(v.Timestamp >> 16) - p[2] = byte(v.Timestamp >> 8) - p[3] = byte(v.Timestamp) - } - - return c0h, nil -} - -// Please read the cs id of @doc rtmp_specification_1.0.pdf, @page 17, @section 6.1.1. Chunk Basic Header -type chunkID uint32 - -const ( - chunkIDProtocolControl chunkID = 0x02 - chunkIDOverConnection chunkID = 0x03 - chunkIDOverConnection2 chunkID = 0x04 - chunkIDOverStream chunkID = 0x05 - chunkIDOverStream2 chunkID = 0x06 - chunkIDVideo chunkID = 0x07 - chunkIDAudio chunkID = 0x08 -) - -// The Command Name of message. -const ( - commandConnect amf0String = amf0String("connect") - commandCreateStream amf0String = amf0String("createStream") - commandCloseStream amf0String = amf0String("closeStream") - commandPlay amf0String = amf0String("play") - commandPause amf0String = amf0String("pause") - commandOnBWDone amf0String = amf0String("onBWDone") - commandOnStatus amf0String = amf0String("onStatus") - commandResult amf0String = amf0String("_result") - commandError amf0String = amf0String("_error") - commandReleaseStream amf0String = amf0String("releaseStream") - commandFCPublish amf0String = amf0String("FCPublish") - commandFCUnpublish amf0String = amf0String("FCUnpublish") - commandPublish amf0String = amf0String("publish") - commandRtmpSampleAccess amf0String = amf0String("|RtmpSampleAccess") -) - -// The RTMP packet, transport as payload of RTMP message. -type Packet interface { - // Marshaler and unmarshaler - Size() int - encoding.BinaryUnmarshaler - encoding.BinaryMarshaler - - // RTMP protocol fields for each packet. - BetterCid() chunkID - Type() MessageType -} - -// A Call packet, both object and args are AMF0 objects. -type objectCallPacket struct { - CommandName amf0String - TransactionID amf0Number - CommandObject *amf0Object - Args *amf0Object -} - -func (v *objectCallPacket) BetterCid() chunkID { - return chunkIDOverConnection -} - -func (v *objectCallPacket) Type() MessageType { - return MessageTypeAMF0Command -} - -func (v *objectCallPacket) Size() int { - size := v.CommandName.Size() + v.TransactionID.Size() + v.CommandObject.Size() - if v.Args != nil { - size += v.Args.Size() - } - return size -} - -func (v *objectCallPacket) UnmarshalBinary(data []byte) (err error) { - p := data - - if err = v.CommandName.UnmarshalBinary(p); err != nil { - return errors.WithMessage(err, "unmarshal command name") - } - p = p[v.CommandName.Size():] - - if err = v.TransactionID.UnmarshalBinary(p); err != nil { - return errors.WithMessage(err, "unmarshal tid") - } - p = p[v.TransactionID.Size():] - - if err = v.CommandObject.UnmarshalBinary(p); err != nil { - return errors.WithMessage(err, "unmarshal command") - } - p = p[v.CommandObject.Size():] - - if len(p) == 0 { - return - } - - v.Args = NewAmf0Object() - if err = v.Args.UnmarshalBinary(p); err != nil { - return errors.WithMessage(err, "unmarshal args") - } - - return -} - -func (v *objectCallPacket) MarshalBinary() (data []byte, err error) { - var pb []byte - if pb, err = v.CommandName.MarshalBinary(); err != nil { - return nil, errors.WithMessage(err, "marshal command name") - } - data = append(data, pb...) - - if pb, err = v.TransactionID.MarshalBinary(); err != nil { - return nil, errors.WithMessage(err, "marshal tid") - } - data = append(data, pb...) - - if pb, err = v.CommandObject.MarshalBinary(); err != nil { - return nil, errors.WithMessage(err, "marshal command object") - } - data = append(data, pb...) - - if v.Args != nil { - if pb, err = v.Args.MarshalBinary(); err != nil { - return nil, errors.WithMessage(err, "marshal args") - } - data = append(data, pb...) - } - - return -} - -// Please read @doc rtmp_specification_1.0.pdf, @page 45, @section 4.1.1. connect -// The client sends the connect command to the server to request -// connection to a server application instance. -type ConnectAppPacket struct { - objectCallPacket -} - -func NewConnectAppPacket() *ConnectAppPacket { - v := &ConnectAppPacket{} - v.CommandName = commandConnect - v.CommandObject = NewAmf0Object() - v.TransactionID = amf0Number(1.0) - return v -} - -func (v *ConnectAppPacket) UnmarshalBinary(data []byte) (err error) { - if err = v.objectCallPacket.UnmarshalBinary(data); err != nil { - return errors.WithMessage(err, "unmarshal call") - } - - if v.CommandName != commandConnect { - return errors.Errorf("Invalid command name %v", string(v.CommandName)) - } - - if v.TransactionID != 1.0 { - return errors.Errorf("Invalid transaction ID %v", float64(v.TransactionID)) - } - - return -} - -func (v *ConnectAppPacket) TcUrl() string { - if v.CommandObject != nil { - if v, ok := v.CommandObject.Get("tcUrl").(*amf0String); ok { - return string(*v) - } - } - return "" -} - -// The response for ConnectAppPacket. -type ConnectAppResPacket struct { - objectCallPacket -} - -func NewConnectAppResPacket(tid amf0Number) *ConnectAppResPacket { - v := &ConnectAppResPacket{} - v.CommandName = commandResult - v.CommandObject = NewAmf0Object() - v.Args = NewAmf0Object() - v.TransactionID = tid - return v -} - -func (v *ConnectAppResPacket) SrsID() string { - if v.Args != nil { - if v, ok := v.Args.Get("data").(*amf0EcmaArray); ok { - if v, ok := v.Get("srs_id").(*amf0String); ok { - return string(*v) - } - } - } - return "" -} - -func (v *ConnectAppResPacket) UnmarshalBinary(data []byte) (err error) { - if err = v.objectCallPacket.UnmarshalBinary(data); err != nil { - return errors.WithMessage(err, "unmarshal call") - } - - if v.CommandName != commandResult { - return errors.Errorf("Invalid command name %v", string(v.CommandName)) - } - - return -} - -// A Call object, command object is variant. -type variantCallPacket struct { - CommandName amf0String - TransactionID amf0Number - CommandObject amf0Any // object or null -} - -func (v *variantCallPacket) BetterCid() chunkID { - return chunkIDOverConnection -} - -func (v *variantCallPacket) Type() MessageType { - return MessageTypeAMF0Command -} - -func (v *variantCallPacket) Size() int { - size := v.CommandName.Size() + v.TransactionID.Size() - - if v.CommandObject != nil { - size += v.CommandObject.Size() - } - - return size -} - -func (v *variantCallPacket) UnmarshalBinary(data []byte) (err error) { - p := data - - if err = v.CommandName.UnmarshalBinary(p); err != nil { - return errors.WithMessage(err, "unmarshal command name") - } - p = p[v.CommandName.Size():] - - if err = v.TransactionID.UnmarshalBinary(p); err != nil { - return errors.WithMessage(err, "unmarshal tid") - } - p = p[v.TransactionID.Size():] - - if len(p) > 0 { - if v.CommandObject, err = Amf0Discovery(p); err != nil { - return errors.WithMessage(err, "discovery command object") - } - if err = v.CommandObject.UnmarshalBinary(p); err != nil { - return errors.WithMessage(err, "unmarshal command object") - } - p = p[v.CommandObject.Size():] - } - - return -} - -func (v *variantCallPacket) MarshalBinary() (data []byte, err error) { - var pb []byte - if pb, err = v.CommandName.MarshalBinary(); err != nil { - return nil, errors.WithMessage(err, "marshal command name") - } - data = append(data, pb...) - - if pb, err = v.TransactionID.MarshalBinary(); err != nil { - return nil, errors.WithMessage(err, "marshal tid") - } - data = append(data, pb...) - - if v.CommandObject != nil { - if pb, err = v.CommandObject.MarshalBinary(); err != nil { - return nil, errors.WithMessage(err, "marshal command object") - } - data = append(data, pb...) - } - - return -} - -// Please read @doc rtmp_specification_1.0.pdf, @page 51, @section 4.1.2. Call -// The call method of the NetConnection object runs remote procedure -// calls (RPC) at the receiving end. The called RPC name is passed as a -// parameter to the call command. -// @remark onStatus packet is a call packet. -type CallPacket struct { - variantCallPacket - Args amf0Any // optional or object or null -} - -func NewCallPacket() *CallPacket { - return &CallPacket{} -} - -func (v *CallPacket) ArgsCode() string { - if v.Args != nil { - if v, ok := v.Args.(*amf0Object); ok { - if code, ok := v.Get("code").(*amf0String); ok { - return string(*code) - } - } - } - return "" -} - -func (v *CallPacket) Size() int { - size := v.variantCallPacket.Size() - - if v.Args != nil { - size += v.Args.Size() - } - - return size -} - -func (v *CallPacket) UnmarshalBinary(data []byte) (err error) { - p := data - - if err = v.variantCallPacket.UnmarshalBinary(p); err != nil { - return errors.WithMessage(err, "unmarshal call") - } - p = p[v.variantCallPacket.Size():] - - if len(p) > 0 { - if v.Args, err = Amf0Discovery(p); err != nil { - return errors.WithMessage(err, "discovery args") - } - if err = v.Args.UnmarshalBinary(p); err != nil { - return errors.WithMessage(err, "unmarshal args") - } - } - - return -} - -func (v *CallPacket) MarshalBinary() (data []byte, err error) { - var pb []byte - if pb, err = v.variantCallPacket.MarshalBinary(); err != nil { - return nil, errors.WithMessage(err, "marshal call") - } - data = append(data, pb...) - - if v.Args != nil { - if pb, err = v.Args.MarshalBinary(); err != nil { - return nil, errors.WithMessage(err, "marshal args") - } - data = append(data, pb...) - } - - return -} - -// Please read @doc rtmp_specification_1.0.pdf, @page 52, @section 4.1.3. createStream -// The client sends this command to the server to create a logical -// channel for message communication The publishing of audio, video, and -// metadata is carried out over stream channel created using the -// createStream command. -type CreateStreamPacket struct { - variantCallPacket -} - -func NewCreateStreamPacket() *CreateStreamPacket { - v := &CreateStreamPacket{} - v.CommandName = commandCreateStream - v.TransactionID = amf0Number(2) - v.CommandObject = NewAmf0Null() - return v -} - -// The response for create stream -type CreateStreamResPacket struct { - variantCallPacket - StreamID amf0Number -} - -func NewCreateStreamResPacket(tid amf0Number) *CreateStreamResPacket { - v := &CreateStreamResPacket{} - v.CommandName = commandResult - v.TransactionID = tid - v.CommandObject = NewAmf0Null() - v.StreamID = 0 - return v -} - -func (v *CreateStreamResPacket) Size() int { - return v.variantCallPacket.Size() + v.StreamID.Size() -} - -func (v *CreateStreamResPacket) UnmarshalBinary(data []byte) (err error) { - p := data - - if err = v.variantCallPacket.UnmarshalBinary(p); err != nil { - return errors.WithMessage(err, "unmarshal call") - } - p = p[v.variantCallPacket.Size():] - - if err = v.StreamID.UnmarshalBinary(p); err != nil { - return errors.WithMessage(err, "unmarshal sid") - } - - return -} - -func (v *CreateStreamResPacket) MarshalBinary() (data []byte, err error) { - var pb []byte - if pb, err = v.variantCallPacket.MarshalBinary(); err != nil { - return nil, errors.WithMessage(err, "marshal call") - } - data = append(data, pb...) - - if pb, err = v.StreamID.MarshalBinary(); err != nil { - return nil, errors.WithMessage(err, "marshal sid") - } - data = append(data, pb...) - - return -} - -// Please read @doc rtmp_specification_1.0.pdf, @page 64, @section 4.2.6. Publish -type PublishPacket struct { - variantCallPacket - StreamName amf0String - StreamType amf0String -} - -func NewPublishPacket() *PublishPacket { - v := &PublishPacket{} - v.CommandName = commandPublish - v.CommandObject = NewAmf0Null() - v.StreamType = "live" - return v -} - -func (v *PublishPacket) Size() int { - return v.variantCallPacket.Size() + v.StreamName.Size() + v.StreamType.Size() -} - -func (v *PublishPacket) UnmarshalBinary(data []byte) (err error) { - p := data - - if err = v.variantCallPacket.UnmarshalBinary(p); err != nil { - return errors.WithMessage(err, "unmarshal call") - } - p = p[v.variantCallPacket.Size():] - - if err = v.StreamName.UnmarshalBinary(p); err != nil { - return errors.WithMessage(err, "unmarshal stream name") - } - p = p[v.StreamName.Size():] - - if err = v.StreamType.UnmarshalBinary(p); err != nil { - return errors.WithMessage(err, "unmarshal stream type") - } - - return -} - -func (v *PublishPacket) MarshalBinary() (data []byte, err error) { - var pb []byte - if pb, err = v.variantCallPacket.MarshalBinary(); err != nil { - return nil, errors.WithMessage(err, "marshal call") - } - data = append(data, pb...) - - if pb, err = v.StreamName.MarshalBinary(); err != nil { - return nil, errors.WithMessage(err, "marshal stream name") - } - data = append(data, pb...) - - if pb, err = v.StreamType.MarshalBinary(); err != nil { - return nil, errors.WithMessage(err, "marshal stream type") - } - data = append(data, pb...) - - return -} - -// Please read @doc rtmp_specification_1.0.pdf, @page 54, @section 4.2.1. play -type PlayPacket struct { - variantCallPacket - StreamName amf0String -} - -func NewPlayPacket() *PlayPacket { - v := &PlayPacket{} - v.CommandName = commandPlay - v.CommandObject = NewAmf0Null() - return v -} - -func (v *PlayPacket) Size() int { - return v.variantCallPacket.Size() + v.StreamName.Size() -} - -func (v *PlayPacket) UnmarshalBinary(data []byte) (err error) { - p := data - - if err = v.variantCallPacket.UnmarshalBinary(p); err != nil { - return errors.WithMessage(err, "unmarshal call") - } - p = p[v.variantCallPacket.Size():] - - if err = v.StreamName.UnmarshalBinary(p); err != nil { - return errors.WithMessage(err, "unmarshal stream name") - } - p = p[v.StreamName.Size():] - - return -} - -func (v *PlayPacket) MarshalBinary() (data []byte, err error) { - var pb []byte - if pb, err = v.variantCallPacket.MarshalBinary(); err != nil { - return nil, errors.WithMessage(err, "marshal call") - } - data = append(data, pb...) - - if pb, err = v.StreamName.MarshalBinary(); err != nil { - return nil, errors.WithMessage(err, "marshal stream name") - } - data = append(data, pb...) - - return -} - -// Please read @doc rtmp_specification_1.0.pdf, @page 31, @section 5.1. Set Chunk Size -// Protocol control message 1, Set Chunk Size, is used to notify the -// peer about the new maximum chunk size. -type SetChunkSize struct { - ChunkSize uint32 -} - -func NewSetChunkSize() *SetChunkSize { - return &SetChunkSize{ - ChunkSize: defaultChunkSize, - } -} - -func (v *SetChunkSize) BetterCid() chunkID { - return chunkIDProtocolControl -} - -func (v *SetChunkSize) Type() MessageType { - return MessageTypeSetChunkSize -} - -func (v *SetChunkSize) Size() int { - return 4 -} - -func (v *SetChunkSize) UnmarshalBinary(data []byte) (err error) { - if len(data) < 4 { - return errors.Errorf("requires 4 only %v bytes, %x", len(data), data) - } - v.ChunkSize = binary.BigEndian.Uint32(data) - - return -} - -func (v *SetChunkSize) MarshalBinary() (data []byte, err error) { - data = make([]byte, 4) - binary.BigEndian.PutUint32(data, v.ChunkSize) - - return -} - -// Please read @doc rtmp_specification_1.0.pdf, @page 33, @section 5.5. Window Acknowledgement Size (5) -// The client or the server sends this message to inform the peer which -// window size to use when sending acknowledgment. -type WindowAcknowledgementSize struct { - AckSize uint32 -} - -func NewWindowAcknowledgementSize() *WindowAcknowledgementSize { - return &WindowAcknowledgementSize{} -} - -func (v *WindowAcknowledgementSize) BetterCid() chunkID { - return chunkIDProtocolControl -} - -func (v *WindowAcknowledgementSize) Type() MessageType { - return MessageTypeWindowAcknowledgementSize -} - -func (v *WindowAcknowledgementSize) Size() int { - return 4 -} - -func (v *WindowAcknowledgementSize) UnmarshalBinary(data []byte) (err error) { - if len(data) < 4 { - return errors.Errorf("requires 4 only %v bytes, %x", len(data), data) - } - v.AckSize = binary.BigEndian.Uint32(data) - - return -} - -func (v *WindowAcknowledgementSize) MarshalBinary() (data []byte, err error) { - data = make([]byte, 4) - binary.BigEndian.PutUint32(data, v.AckSize) - - return -} - -// Please read @doc rtmp_specification_1.0.pdf, @page 33, @section 5.6. Set Peer Bandwidth (6) -// The sender can mark this message hard (0), soft (1), or dynamic (2) -// using the Limit type field. -type LimitType uint8 - -const ( - LimitTypeHard LimitType = iota - LimitTypeSoft - LimitTypeDynamic -) - -// Please read @doc rtmp_specification_1.0.pdf, @page 33, @section 5.6. Set Peer Bandwidth (6) -// The client or the server sends this message to update the output -// bandwidth of the peer. -type SetPeerBandwidth struct { - Bandwidth uint32 - LimitType LimitType -} - -func NewSetPeerBandwidth() *SetPeerBandwidth { - return &SetPeerBandwidth{} -} - -func (v *SetPeerBandwidth) BetterCid() chunkID { - return chunkIDProtocolControl -} - -func (v *SetPeerBandwidth) Type() MessageType { - return MessageTypeSetPeerBandwidth -} - -func (v *SetPeerBandwidth) Size() int { - return 4 + 1 -} - -func (v *SetPeerBandwidth) UnmarshalBinary(data []byte) (err error) { - if len(data) < 5 { - return errors.Errorf("requires 5 only %v bytes, %x", len(data), data) - } - v.Bandwidth = binary.BigEndian.Uint32(data) - v.LimitType = LimitType(data[4]) - - return -} - -func (v *SetPeerBandwidth) MarshalBinary() (data []byte, err error) { - data = make([]byte, 5) - binary.BigEndian.PutUint32(data, v.Bandwidth) - data[4] = byte(v.LimitType) - - return -} - -type EventType uint16 - -const ( - // Generally, 4bytes event-data - - // The server sends this event to notify the client - // that a stream has become functional and can be - // used for communication. By default, this event - // is sent on ID 0 after the application connect - // command is successfully received from the - // client. The event data is 4-byte and represents - // The stream ID of the stream that became - // Functional. - EventTypeStreamBegin = 0x00 - - // The server sends this event to notify the client - // that the playback of data is over as requested - // on this stream. No more data is sent without - // issuing additional commands. The client discards - // The messages received for the stream. The - // 4 bytes of event data represent the ID of the - // stream on which playback has ended. - EventTypeStreamEOF = 0x01 - - // The server sends this event to notify the client - // that there is no more data on the stream. If the - // server does not detect any message for a time - // period, it can notify the subscribed clients - // that the stream is dry. The 4 bytes of event - // data represent the stream ID of the dry stream. - EventTypeStreamDry = 0x02 - - // The client sends this event to inform the server - // of the buffer size (in milliseconds) that is - // used to buffer any data coming over a stream. - // This event is sent before the server starts - // processing the stream. The first 4 bytes of the - // event data represent the stream ID and the next - // 4 bytes represent the buffer length, in - // milliseconds. - EventTypeSetBufferLength = 0x03 // 8bytes event-data - - // The server sends this event to notify the client - // that the stream is a recorded stream. The - // 4 bytes event data represent the stream ID of - // The recorded stream. - EventTypeStreamIsRecorded = 0x04 - - // The server sends this event to test whether the - // client is reachable. Event data is a 4-byte - // timestamp, representing the local server time - // When the server dispatched the command. The - // client responds with kMsgPingResponse on - // receiving kMsgPingRequest. - EventTypePingRequest = 0x06 - - // The client sends this event to the server in - // Response to the ping request. The event data is - // a 4-byte timestamp, which was received with the - // kMsgPingRequest request. - EventTypePingResponse = 0x07 - - // For PCUC size=3, for example the payload is "00 1A 01", - // it's a FMS control event, where the event type is 0x001a and event data is 0x01, - // please notice that the event data is only 1 byte for this event. - EventTypeFmsEvent0 = 0x1a -) - -// Please read @doc rtmp_specification_1.0.pdf, @page 32, @5.4. User Control Message (4) -// The client or the server sends this message to notify the peer about the user control events. -// This message carries Event type and Event data. -type UserControl struct { - // Event type is followed by Event data. - // @see: SrcPCUCEventType - EventType EventType - // The event data generally in 4bytes. - // @remark for event type is 0x001a, only 1bytes. - // @see SrsPCUCFmsEvent0 - EventData int32 - // 4bytes if event_type is SetBufferLength; otherwise 0. - ExtraData int32 -} - -func NewUserControl() *UserControl { - return &UserControl{} -} - -func (v *UserControl) BetterCid() chunkID { - return chunkIDProtocolControl -} - -func (v *UserControl) Type() MessageType { - return MessageTypeUserControl -} - -func (v *UserControl) Size() int { - size := 2 - - if v.EventType == EventTypeFmsEvent0 { - size += 1 - } else { - size += 4 - } - - if v.EventType == EventTypeSetBufferLength { - size += 4 - } - - return size -} - -func (v *UserControl) UnmarshalBinary(data []byte) (err error) { - if len(data) < 3 { - return errors.Errorf("requires 5 only %v bytes, %x", len(data), data) - } - - v.EventType = EventType(binary.BigEndian.Uint16(data)) - if len(data) < v.Size() { - return errors.Errorf("requires %v only %v bytes, %x", v.Size(), len(data), data) - } - - if v.EventType == EventTypeFmsEvent0 { - v.EventData = int32(uint8(data[2])) - } else { - v.EventData = int32(binary.BigEndian.Uint32(data[2:])) - } - - if v.EventType == EventTypeSetBufferLength { - v.ExtraData = int32(binary.BigEndian.Uint32(data[6:])) - } - - return -} - -func (v *UserControl) MarshalBinary() (data []byte, err error) { - data = make([]byte, v.Size()) - binary.BigEndian.PutUint16(data, uint16(v.EventType)) - - if v.EventType == EventTypeFmsEvent0 { - data[2] = uint8(v.EventData) - } else { - binary.BigEndian.PutUint32(data[2:], uint32(v.EventData)) - } - - if v.EventType == EventTypeSetBufferLength { - binary.BigEndian.PutUint32(data[6:], uint32(v.ExtraData)) - } - - return -} diff --git a/proxy/signal.go b/proxy/signal.go deleted file mode 100644 index c77108990..000000000 --- a/proxy/signal.go +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright (c) 2025 Winlin -// -// SPDX-License-Identifier: MIT -package main - -import ( - "context" - "os" - "os/signal" - "syscall" - "time" - - "srs-proxy/errors" - "srs-proxy/logger" -) - -func installSignals(ctx context.Context, cancel context.CancelFunc) { - sc := make(chan os.Signal, 1) - signal.Notify(sc, syscall.SIGINT, syscall.SIGTERM, os.Interrupt) - - go func() { - for s := range sc { - logger.Df(ctx, "Got signal %v", s) - cancel() - } - }() -} - -func installForceQuit(ctx context.Context) error { - var forceTimeout time.Duration - if t, err := time.ParseDuration(envForceQuitTimeout()); err != nil { - return errors.Wrapf(err, "parse force timeout %v", envForceQuitTimeout()) - } else { - forceTimeout = t - } - - go func() { - <-ctx.Done() - time.Sleep(forceTimeout) - logger.Wf(ctx, "Force to exit by timeout") - os.Exit(1) - }() - return nil -} diff --git a/proxy/srs.go b/proxy/srs.go deleted file mode 100644 index 9fc66de07..000000000 --- a/proxy/srs.go +++ /dev/null @@ -1,553 +0,0 @@ -// Copyright (c) 2025 Winlin -// -// SPDX-License-Identifier: MIT -package main - -import ( - "context" - "encoding/json" - "fmt" - "math/rand" - "os" - "strconv" - "strings" - "time" - - // Use v8 because we use Go 1.16+, while v9 requires Go 1.18+ - "github.com/go-redis/redis/v8" - - "srs-proxy/errors" - "srs-proxy/logger" - "srs-proxy/sync" -) - -// If server heartbeat in this duration, it's alive. -const srsServerAliveDuration = 300 * time.Second - -// If HLS streaming update in this duration, it's alive. -const srsHLSAliveDuration = 120 * time.Second - -// If WebRTC streaming update in this duration, it's alive. -const srsRTCAliveDuration = 120 * time.Second - -type SRSServer struct { - // The server IP. - IP string `json:"ip,omitempty"` - // The server device ID, configured by user. - DeviceID string `json:"device_id,omitempty"` - // The server id of SRS, store in file, may not change, mandatory. - ServerID string `json:"server_id,omitempty"` - // The service id of SRS, always change when restarted, mandatory. - ServiceID string `json:"service_id,omitempty"` - // The process id of SRS, always change when restarted, mandatory. - PID string `json:"pid,omitempty"` - // The RTMP listen endpoints. - RTMP []string `json:"rtmp,omitempty"` - // The HTTP Stream listen endpoints. - HTTP []string `json:"http,omitempty"` - // The HTTP API listen endpoints. - API []string `json:"api,omitempty"` - // The SRT server listen endpoints. - SRT []string `json:"srt,omitempty"` - // The RTC server listen endpoints. - RTC []string `json:"rtc,omitempty"` - // Last update time. - UpdatedAt time.Time `json:"update_at,omitempty"` -} - -func (v *SRSServer) ID() string { - return fmt.Sprintf("%v-%v-%v", v.ServerID, v.ServiceID, v.PID) -} - -func (v *SRSServer) String() string { - return fmt.Sprintf("%v", v) -} - -func (v *SRSServer) Format(f fmt.State, c rune) { - switch c { - case 'v', 's': - if f.Flag('+') { - var sb strings.Builder - sb.WriteString(fmt.Sprintf("pid=%v, server=%v, service=%v", v.PID, v.ServerID, v.ServiceID)) - if v.DeviceID != "" { - sb.WriteString(fmt.Sprintf(", device=%v", v.DeviceID)) - } - if len(v.RTMP) > 0 { - sb.WriteString(fmt.Sprintf(", rtmp=[%v]", strings.Join(v.RTMP, ","))) - } - if len(v.HTTP) > 0 { - sb.WriteString(fmt.Sprintf(", http=[%v]", strings.Join(v.HTTP, ","))) - } - if len(v.API) > 0 { - sb.WriteString(fmt.Sprintf(", api=[%v]", strings.Join(v.API, ","))) - } - if len(v.SRT) > 0 { - sb.WriteString(fmt.Sprintf(", srt=[%v]", strings.Join(v.SRT, ","))) - } - if len(v.RTC) > 0 { - sb.WriteString(fmt.Sprintf(", rtc=[%v]", strings.Join(v.RTC, ","))) - } - sb.WriteString(fmt.Sprintf(", update=%v", v.UpdatedAt.Format("2006-01-02 15:04:05.999"))) - fmt.Fprintf(f, "SRS ip=%v, id=%v, %v", v.IP, v.ID(), sb.String()) - } else { - fmt.Fprintf(f, "SRS ip=%v, id=%v", v.IP, v.ID()) - } - default: - fmt.Fprintf(f, "%v, fmt=%%%c", v, c) - } -} - -func NewSRSServer(opts ...func(*SRSServer)) *SRSServer { - v := &SRSServer{} - for _, opt := range opts { - opt(v) - } - return v -} - -// NewDefaultSRSForDebugging initialize the default SRS media server, for debugging only. -func NewDefaultSRSForDebugging() (*SRSServer, error) { - if envDefaultBackendEnabled() != "on" { - return nil, nil - } - - if envDefaultBackendIP() == "" { - return nil, fmt.Errorf("empty default backend ip") - } - if envDefaultBackendRTMP() == "" { - return nil, fmt.Errorf("empty default backend rtmp") - } - - server := NewSRSServer(func(srs *SRSServer) { - srs.IP = envDefaultBackendIP() - srs.RTMP = []string{envDefaultBackendRTMP()} - srs.ServerID = fmt.Sprintf("default-%v", logger.GenerateContextID()) - srs.ServiceID = logger.GenerateContextID() - srs.PID = fmt.Sprintf("%v", os.Getpid()) - srs.UpdatedAt = time.Now() - }) - - if envDefaultBackendHttp() != "" { - server.HTTP = []string{envDefaultBackendHttp()} - } - if envDefaultBackendAPI() != "" { - server.API = []string{envDefaultBackendAPI()} - } - if envDefaultBackendRTC() != "" { - server.RTC = []string{envDefaultBackendRTC()} - } - if envDefaultBackendSRT() != "" { - server.SRT = []string{envDefaultBackendSRT()} - } - return server, nil -} - -// SRSLoadBalancer is the interface to load balance the SRS servers. -type SRSLoadBalancer interface { - // Initialize the load balancer. - Initialize(ctx context.Context) error - // Update the backer server. - Update(ctx context.Context, server *SRSServer) error - // Pick a backend server for the specified stream URL. - Pick(ctx context.Context, streamURL string) (*SRSServer, error) - // Load or store the HLS streaming for the specified stream URL. - LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSPlayStream) (*HLSPlayStream, error) - // Load the HLS streaming by SPBHID, the SRS Proxy Backend HLS ID. - LoadHLSBySPBHID(ctx context.Context, spbhid string) (*HLSPlayStream, error) - // Store the WebRTC streaming for the specified stream URL. - StoreWebRTC(ctx context.Context, streamURL string, value *RTCConnection) error - // Load the WebRTC streaming by ufrag, the ICE username. - LoadWebRTCByUfrag(ctx context.Context, ufrag string) (*RTCConnection, error) -} - -// srsLoadBalancer is the global SRS load balancer. -var srsLoadBalancer SRSLoadBalancer - -// srsMemoryLoadBalancer stores state in memory. -type srsMemoryLoadBalancer struct { - // All available SRS servers, key is server ID. - servers sync.Map[string, *SRSServer] - // The picked server to servce client by specified stream URL, key is stream url. - picked sync.Map[string, *SRSServer] - // The HLS streaming, key is stream URL. - hlsStreamURL sync.Map[string, *HLSPlayStream] - // The HLS streaming, key is SPBHID. - hlsSPBHID sync.Map[string, *HLSPlayStream] - // The WebRTC streaming, key is stream URL. - rtcStreamURL sync.Map[string, *RTCConnection] - // The WebRTC streaming, key is ufrag. - rtcUfrag sync.Map[string, *RTCConnection] -} - -func NewMemoryLoadBalancer() SRSLoadBalancer { - return &srsMemoryLoadBalancer{} -} - -func (v *srsMemoryLoadBalancer) Initialize(ctx context.Context) error { - if server, err := NewDefaultSRSForDebugging(); err != nil { - return errors.Wrapf(err, "initialize default SRS") - } else if server != nil { - if err := v.Update(ctx, server); err != nil { - return errors.Wrapf(err, "update default SRS %+v", server) - } - - // Keep alive. - go func() { - for { - select { - case <-ctx.Done(): - return - case <-time.After(30 * time.Second): - if err := v.Update(ctx, server); err != nil { - logger.Wf(ctx, "update default SRS %+v failed, %+v", server, err) - } - } - } - }() - logger.Df(ctx, "MemoryLB: Initialize default SRS media server, %+v", server) - } - return nil -} - -func (v *srsMemoryLoadBalancer) Update(ctx context.Context, server *SRSServer) error { - v.servers.Store(server.ID(), server) - return nil -} - -func (v *srsMemoryLoadBalancer) Pick(ctx context.Context, streamURL string) (*SRSServer, error) { - // Always proxy to the same server for the same stream URL. - if server, ok := v.picked.Load(streamURL); ok { - return server, nil - } - - // Gather all servers that were alive within the last few seconds. - var servers []*SRSServer - v.servers.Range(func(key string, server *SRSServer) bool { - if time.Since(server.UpdatedAt) < srsServerAliveDuration { - servers = append(servers, server) - } - return true - }) - - // If no servers available, use all possible servers. - if len(servers) == 0 { - v.servers.Range(func(key string, server *SRSServer) bool { - servers = append(servers, server) - return true - }) - } - - // No server found, failed. - if len(servers) == 0 { - return nil, fmt.Errorf("no server available for %v", streamURL) - } - - // Pick a server randomly from servers. - server := servers[rand.Intn(len(servers))] - v.picked.Store(streamURL, server) - return server, nil -} - -func (v *srsMemoryLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid string) (*HLSPlayStream, error) { - // Load the HLS streaming for the SPBHID, for TS files. - if actual, ok := v.hlsSPBHID.Load(spbhid); !ok { - return nil, errors.Errorf("no HLS streaming for SPBHID %v", spbhid) - } else { - return actual, nil - } -} - -func (v *srsMemoryLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSPlayStream) (*HLSPlayStream, error) { - // Update the HLS streaming for the stream URL, for M3u8. - actual, _ := v.hlsStreamURL.LoadOrStore(streamURL, value) - if actual == nil { - return nil, errors.Errorf("load or store HLS streaming for %v failed", streamURL) - } - - // Update the HLS streaming for the SPBHID, for TS files. - v.hlsSPBHID.Store(value.SRSProxyBackendHLSID, actual) - - return actual, nil -} - -func (v *srsMemoryLoadBalancer) StoreWebRTC(ctx context.Context, streamURL string, value *RTCConnection) error { - // Update the WebRTC streaming for the stream URL. - v.rtcStreamURL.Store(streamURL, value) - - // Update the WebRTC streaming for the ufrag. - v.rtcUfrag.Store(value.Ufrag, value) - return nil -} - -func (v *srsMemoryLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag string) (*RTCConnection, error) { - if actual, ok := v.rtcUfrag.Load(ufrag); !ok { - return nil, errors.Errorf("no WebRTC streaming for ufrag %v", ufrag) - } else { - return actual, nil - } -} - -type srsRedisLoadBalancer struct { - // The redis client sdk. - rdb *redis.Client -} - -func NewRedisLoadBalancer() SRSLoadBalancer { - return &srsRedisLoadBalancer{} -} - -func (v *srsRedisLoadBalancer) Initialize(ctx context.Context) error { - redisDatabase, err := strconv.Atoi(envRedisDB()) - if err != nil { - return errors.Wrapf(err, "invalid PROXY_REDIS_DB %v", envRedisDB()) - } - - rdb := redis.NewClient(&redis.Options{ - Addr: fmt.Sprintf("%v:%v", envRedisHost(), envRedisPort()), - Password: envRedisPassword(), - DB: redisDatabase, - }) - v.rdb = rdb - - if err := rdb.Ping(ctx).Err(); err != nil { - return errors.Wrapf(err, "unable to connect to redis %v", rdb.String()) - } - logger.Df(ctx, "RedisLB: connected to redis %v ok", rdb.String()) - - if server, err := NewDefaultSRSForDebugging(); err != nil { - return errors.Wrapf(err, "initialize default SRS") - } else if server != nil { - if err := v.Update(ctx, server); err != nil { - return errors.Wrapf(err, "update default SRS %+v", server) - } - - // Keep alive. - go func() { - for { - select { - case <-ctx.Done(): - return - case <-time.After(30 * time.Second): - if err := v.Update(ctx, server); err != nil { - logger.Wf(ctx, "update default SRS %+v failed, %+v", server, err) - } - } - } - }() - logger.Df(ctx, "RedisLB: Initialize default SRS media server, %+v", server) - } - return nil -} - -func (v *srsRedisLoadBalancer) Update(ctx context.Context, server *SRSServer) error { - b, err := json.Marshal(server) - if err != nil { - return errors.Wrapf(err, "marshal server %+v", server) - } - - key := v.redisKeyServer(server.ID()) - if err = v.rdb.Set(ctx, key, b, srsServerAliveDuration).Err(); err != nil { - return errors.Wrapf(err, "set key=%v server %+v", key, server) - } - - // Query all servers from redis, in json string. - var serverKeys []string - if b, err := v.rdb.Get(ctx, v.redisKeyServers()).Bytes(); err == nil { - if err := json.Unmarshal(b, &serverKeys); err != nil { - return errors.Wrapf(err, "unmarshal key=%v servers %v", v.redisKeyServers(), string(b)) - } - } - - // Check each server expiration, if not exists in redis, remove from servers. - for i := len(serverKeys) - 1; i >= 0; i-- { - if _, err := v.rdb.Get(ctx, serverKeys[i]).Bytes(); err != nil { - serverKeys = append(serverKeys[:i], serverKeys[i+1:]...) - } - } - - // Add server to servers if not exists. - var found bool - for _, serverKey := range serverKeys { - if serverKey == key { - found = true - break - } - } - if !found { - serverKeys = append(serverKeys, key) - } - - // Update all servers to redis. - b, err = json.Marshal(serverKeys) - if err != nil { - return errors.Wrapf(err, "marshal servers %+v", serverKeys) - } - if err = v.rdb.Set(ctx, v.redisKeyServers(), b, 0).Err(); err != nil { - return errors.Wrapf(err, "set key=%v servers %+v", v.redisKeyServers(), serverKeys) - } - - return nil -} - -func (v *srsRedisLoadBalancer) Pick(ctx context.Context, streamURL string) (*SRSServer, error) { - key := fmt.Sprintf("srs-proxy-url:%v", streamURL) - - // Always proxy to the same server for the same stream URL. - if serverKey, err := v.rdb.Get(ctx, key).Result(); err == nil { - // If server not exists, ignore and pick another server for the stream URL. - if b, err := v.rdb.Get(ctx, serverKey).Bytes(); err == nil && len(b) > 0 { - var server SRSServer - if err := json.Unmarshal(b, &server); err != nil { - return nil, errors.Wrapf(err, "unmarshal key=%v server %v", key, string(b)) - } - - // TODO: If server fail, we should migrate the streams to another server. - return &server, nil - } - } - - // Query all servers from redis, in json string. - var serverKeys []string - if b, err := v.rdb.Get(ctx, v.redisKeyServers()).Bytes(); err == nil { - if err := json.Unmarshal(b, &serverKeys); err != nil { - return nil, errors.Wrapf(err, "unmarshal key=%v servers %v", v.redisKeyServers(), string(b)) - } - } - - // No server found, failed. - if len(serverKeys) == 0 { - return nil, fmt.Errorf("no server available for %v", streamURL) - } - - // All server should be alive, if not, should have been removed by redis. So we only - // random pick one that is always available. - var serverKey string - var server SRSServer - for i := 0; i < 3; i++ { - tryServerKey := serverKeys[rand.Intn(len(serverKeys))] - b, err := v.rdb.Get(ctx, tryServerKey).Bytes() - if err == nil && len(b) > 0 { - if err := json.Unmarshal(b, &server); err != nil { - return nil, errors.Wrapf(err, "unmarshal key=%v server %v", serverKey, string(b)) - } - - serverKey = tryServerKey - break - } - } - if serverKey == "" { - return nil, errors.Errorf("no server available in %v for %v", serverKeys, streamURL) - } - - // Update the picked server for the stream URL. - if err := v.rdb.Set(ctx, key, []byte(serverKey), 0).Err(); err != nil { - return nil, errors.Wrapf(err, "set key=%v server %v", key, serverKey) - } - - return &server, nil -} - -func (v *srsRedisLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid string) (*HLSPlayStream, error) { - key := v.redisKeySPBHID(spbhid) - - b, err := v.rdb.Get(ctx, key).Bytes() - if err != nil { - return nil, errors.Wrapf(err, "get key=%v HLS", key) - } - - var actual HLSPlayStream - if err := json.Unmarshal(b, &actual); err != nil { - return nil, errors.Wrapf(err, "unmarshal key=%v HLS %v", key, string(b)) - } - - return &actual, nil -} - -func (v *srsRedisLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSPlayStream) (*HLSPlayStream, error) { - b, err := json.Marshal(value) - if err != nil { - return nil, errors.Wrapf(err, "marshal HLS %v", value) - } - - key := v.redisKeyHLS(streamURL) - if err = v.rdb.Set(ctx, key, b, srsHLSAliveDuration).Err(); err != nil { - return nil, errors.Wrapf(err, "set key=%v HLS %v", key, value) - } - - key2 := v.redisKeySPBHID(value.SRSProxyBackendHLSID) - if err := v.rdb.Set(ctx, key2, b, srsHLSAliveDuration).Err(); err != nil { - return nil, errors.Wrapf(err, "set key=%v HLS %v", key2, value) - } - - // Query the HLS streaming from redis. - b2, err := v.rdb.Get(ctx, key).Bytes() - if err != nil { - return nil, errors.Wrapf(err, "get key=%v HLS", key) - } - - var actual HLSPlayStream - if err := json.Unmarshal(b2, &actual); err != nil { - return nil, errors.Wrapf(err, "unmarshal key=%v HLS %v", key, string(b2)) - } - - return &actual, nil -} - -func (v *srsRedisLoadBalancer) StoreWebRTC(ctx context.Context, streamURL string, value *RTCConnection) error { - b, err := json.Marshal(value) - if err != nil { - return errors.Wrapf(err, "marshal WebRTC %v", value) - } - - key := v.redisKeyRTC(streamURL) - if err = v.rdb.Set(ctx, key, b, srsRTCAliveDuration).Err(); err != nil { - return errors.Wrapf(err, "set key=%v WebRTC %v", key, value) - } - - key2 := v.redisKeyUfrag(value.Ufrag) - if err := v.rdb.Set(ctx, key2, b, srsRTCAliveDuration).Err(); err != nil { - return errors.Wrapf(err, "set key=%v WebRTC %v", key2, value) - } - - return nil -} - -func (v *srsRedisLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag string) (*RTCConnection, error) { - key := v.redisKeyUfrag(ufrag) - - b, err := v.rdb.Get(ctx, key).Bytes() - if err != nil { - return nil, errors.Wrapf(err, "get key=%v WebRTC", key) - } - - var actual RTCConnection - if err := json.Unmarshal(b, &actual); err != nil { - return nil, errors.Wrapf(err, "unmarshal key=%v WebRTC %v", key, string(b)) - } - - return &actual, nil -} - -func (v *srsRedisLoadBalancer) redisKeyUfrag(ufrag string) string { - return fmt.Sprintf("srs-proxy-ufrag:%v", ufrag) -} - -func (v *srsRedisLoadBalancer) redisKeyRTC(streamURL string) string { - return fmt.Sprintf("srs-proxy-rtc:%v", streamURL) -} - -func (v *srsRedisLoadBalancer) redisKeySPBHID(spbhid string) string { - return fmt.Sprintf("srs-proxy-spbhid:%v", spbhid) -} - -func (v *srsRedisLoadBalancer) redisKeyHLS(streamURL string) string { - return fmt.Sprintf("srs-proxy-hls:%v", streamURL) -} - -func (v *srsRedisLoadBalancer) redisKeyServer(serverID string) string { - return fmt.Sprintf("srs-proxy-server:%v", serverID) -} - -func (v *srsRedisLoadBalancer) redisKeyServers() string { - return fmt.Sprintf("srs-proxy-all-servers") -} diff --git a/proxy/srt.go b/proxy/srt.go deleted file mode 100644 index fd4c2eac5..000000000 --- a/proxy/srt.go +++ /dev/null @@ -1,574 +0,0 @@ -// Copyright (c) 2025 Winlin -// -// SPDX-License-Identifier: MIT -package main - -import ( - "bytes" - "context" - "encoding/binary" - "fmt" - "net" - "strings" - stdSync "sync" - "time" - - "srs-proxy/errors" - "srs-proxy/logger" - "srs-proxy/sync" -) - -// srsSRTServer is the proxy for SRS server via SRT. It will figure out which backend server to -// proxy to. It only parses the SRT handshake messages, parses the stream id, and proxy to the -// backend server. -type srsSRTServer struct { - // The UDP listener for SRT server. - listener *net.UDPConn - - // The SRT connections, identify by the socket ID. - sockets sync.Map[uint32, *SRTConnection] - // The system start time. - start time.Time - - // The wait group for server. - wg stdSync.WaitGroup -} - -func NewSRSSRTServer(opts ...func(*srsSRTServer)) *srsSRTServer { - v := &srsSRTServer{ - start: time.Now(), - } - - for _, opt := range opts { - opt(v) - } - return v -} - -func (v *srsSRTServer) Close() error { - if v.listener != nil { - v.listener.Close() - } - - v.wg.Wait() - return nil -} - -func (v *srsSRTServer) Run(ctx context.Context) error { - // Parse address to listen. - endpoint := envSRTServer() - if !strings.Contains(endpoint, ":") { - endpoint = ":" + endpoint - } - - saddr, err := net.ResolveUDPAddr("udp", endpoint) - if err != nil { - return errors.Wrapf(err, "resolve udp addr %v", endpoint) - } - - listener, err := net.ListenUDP("udp", saddr) - if err != nil { - return errors.Wrapf(err, "listen udp %v", saddr) - } - v.listener = listener - logger.Df(ctx, "SRT server listen at %v", saddr) - - // Consume all messages from UDP media transport. - v.wg.Add(1) - go func() { - defer v.wg.Done() - - for ctx.Err() == nil { - buf := make([]byte, 4096) - n, caddr, err := v.listener.ReadFromUDP(buf) - if err != nil { - // TODO: If SRT server closed unexpectedly, we should notice the main loop to quit. - logger.Wf(ctx, "read from udp failed, err=%+v", err) - continue - } - - if err := v.handleClientUDP(ctx, caddr, buf[:n]); err != nil { - logger.Wf(ctx, "handle udp %vB failed, addr=%v, err=%+v", n, caddr, err) - } - } - }() - - return nil -} - -func (v *srsSRTServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr, data []byte) error { - socketID := srtParseSocketID(data) - - var pkt *SRTHandshakePacket - if srtIsHandshake(data) { - pkt = &SRTHandshakePacket{} - if err := pkt.UnmarshalBinary(data); err != nil { - return err - } - - if socketID == 0 { - socketID = pkt.SRTSocketID - } - } - - conn, ok := v.sockets.LoadOrStore(socketID, NewSRTConnection(func(c *SRTConnection) { - c.ctx = logger.WithContext(ctx) - c.listenerUDP, c.socketID = v.listener, socketID - c.start = v.start - })) - - ctx = conn.ctx - if !ok { - logger.Df(ctx, "Create new SRT connection skt=%v", socketID) - } - - if newSocketID, err := conn.HandlePacket(pkt, addr, data); err != nil { - return errors.Wrapf(err, "handle packet") - } else if newSocketID != 0 && newSocketID != socketID { - // The connection may use a new socket ID. - // TODO: FIXME: Should cleanup the dead SRT connection. - v.sockets.Store(newSocketID, conn) - } - - return nil -} - -// SRTConnection is an SRT connection proxy, for both caller and listener. It represents an SRT -// connection, identify by the socket ID. -// -// It's similar to RTMP or HTTP FLV/TS proxy connection, which are stateless and all state is in -// the client request. The SRTConnection is stateless, and no need to sync between proxy servers. -// -// Unlike the WebRTC connection, SRTConnection does not support address changes. This means the -// client should never switch to another network or port. If this occurs, the client may be served -// by a different proxy server and fail because the other proxy server cannot identify the client. -type SRTConnection struct { - // The stream context for SRT connection. - ctx context.Context - - // The current socket ID. - socketID uint32 - - // The UDP connection proxy to backend. - backendUDP *net.UDPConn - // The listener UDP connection, used to send messages to client. - listenerUDP *net.UDPConn - - // Listener start time. - start time.Time - - // Handshake packets with client. - handshake0 *SRTHandshakePacket - handshake1 *SRTHandshakePacket - handshake2 *SRTHandshakePacket - handshake3 *SRTHandshakePacket -} - -func NewSRTConnection(opts ...func(*SRTConnection)) *SRTConnection { - v := &SRTConnection{} - for _, opt := range opts { - opt(v) - } - return v -} - -func (v *SRTConnection) HandlePacket(pkt *SRTHandshakePacket, addr *net.UDPAddr, data []byte) (uint32, error) { - ctx := v.ctx - - // If not handshake, try to proxy to backend directly. - if pkt == nil { - // Proxy client message to backend. - if v.backendUDP != nil { - if _, err := v.backendUDP.Write(data); err != nil { - return v.socketID, errors.Wrapf(err, "write to backend") - } - } - - return v.socketID, nil - } - - // Handle handshake messages. - if err := v.handleHandshake(ctx, pkt, addr, data); err != nil { - return v.socketID, errors.Wrapf(err, "handle handshake %v", pkt) - } - - return v.socketID, nil -} - -func (v *SRTConnection) handleHandshake(ctx context.Context, pkt *SRTHandshakePacket, addr *net.UDPAddr, data []byte) error { - // Handle handshake 0 and 1 messages. - if pkt.SynCookie == 0 { - // Save handshake 0 packet. - v.handshake0 = pkt - logger.Df(ctx, "SRT Handshake 0: %v", v.handshake0) - - // Response handshake 1. - v.handshake1 = &SRTHandshakePacket{ - ControlFlag: pkt.ControlFlag, - ControlType: 0, - SubType: 0, - AdditionalInfo: 0, - Timestamp: uint32(time.Since(v.start).Microseconds()), - SocketID: pkt.SRTSocketID, - Version: 5, - EncryptionField: 0, - ExtensionField: 0x4A17, - InitSequence: pkt.InitSequence, - MTU: pkt.MTU, - FlowWindow: pkt.FlowWindow, - HandshakeType: 1, - SRTSocketID: pkt.SRTSocketID, - SynCookie: 0x418d5e4e, - PeerIP: net.ParseIP("127.0.0.1"), - } - logger.Df(ctx, "SRT Handshake 1: %v", v.handshake1) - - if b, err := v.handshake1.MarshalBinary(); err != nil { - return errors.Wrapf(err, "marshal handshake 1") - } else if _, err = v.listenerUDP.WriteToUDP(b, addr); err != nil { - return errors.Wrapf(err, "write handshake 1") - } - - return nil - } - - // Handle handshake 2 and 3 messages. - // Parse stream id from packet. - streamID, err := pkt.StreamID() - if err != nil { - return errors.Wrapf(err, "parse stream id") - } - - // Save handshake packet. - v.handshake2 = pkt - logger.Df(ctx, "SRT Handshake 2: %v, sid=%v", v.handshake2, streamID) - - // Start the UDP proxy to backend. - if err := v.connectBackend(ctx, streamID); err != nil { - return errors.Wrapf(err, "connect backend for %v", streamID) - } - - // Proxy client message to backend. - if v.backendUDP == nil { - return errors.Errorf("no backend for %v", streamID) - } - - // Proxy handshake 0 to backend server. - if b, err := v.handshake0.MarshalBinary(); err != nil { - return errors.Wrapf(err, "marshal handshake 0") - } else if _, err = v.backendUDP.Write(b); err != nil { - return errors.Wrapf(err, "write handshake 0") - } - logger.Df(ctx, "Proxy send handshake 0: %v", v.handshake0) - - // Read handshake 1 from backend server. - b := make([]byte, 4096) - handshake1p := &SRTHandshakePacket{} - if nn, err := v.backendUDP.Read(b); err != nil { - return errors.Wrapf(err, "read handshake 1") - } else if err := handshake1p.UnmarshalBinary(b[:nn]); err != nil { - return errors.Wrapf(err, "unmarshal handshake 1") - } - logger.Df(ctx, "Proxy got handshake 1: %v", handshake1p) - - // Proxy handshake 2 to backend server. - handshake2p := *v.handshake2 - handshake2p.SynCookie = handshake1p.SynCookie - if b, err := handshake2p.MarshalBinary(); err != nil { - return errors.Wrapf(err, "marshal handshake 2") - } else if _, err = v.backendUDP.Write(b); err != nil { - return errors.Wrapf(err, "write handshake 2") - } - logger.Df(ctx, "Proxy send handshake 2: %v", handshake2p) - - // Read handshake 3 from backend server. - handshake3p := &SRTHandshakePacket{} - if nn, err := v.backendUDP.Read(b); err != nil { - return errors.Wrapf(err, "read handshake 3") - } else if err := handshake3p.UnmarshalBinary(b[:nn]); err != nil { - return errors.Wrapf(err, "unmarshal handshake 3") - } - logger.Df(ctx, "Proxy got handshake 3: %v", handshake3p) - - // Response handshake 3 to client. - v.handshake3 = &*handshake3p - v.handshake3.SynCookie = v.handshake1.SynCookie - v.socketID = handshake3p.SRTSocketID - logger.Df(ctx, "Handshake 3: %v", v.handshake3) - - if b, err := v.handshake3.MarshalBinary(); err != nil { - return errors.Wrapf(err, "marshal handshake 3") - } else if _, err = v.listenerUDP.WriteToUDP(b, addr); err != nil { - return errors.Wrapf(err, "write handshake 3") - } - - // Start a goroutine to proxy message from backend to client. - // TODO: FIXME: Support close the connection when timeout or client disconnected. - go func() { - for ctx.Err() == nil { - nn, err := v.backendUDP.Read(b) - if err != nil { - // TODO: If backend server closed unexpectedly, we should notice the stream to quit. - logger.Wf(ctx, "read from backend failed, err=%v", err) - return - } - if _, err = v.listenerUDP.WriteToUDP(b[:nn], addr); err != nil { - // TODO: If backend server closed unexpectedly, we should notice the stream to quit. - logger.Wf(ctx, "write to client failed, err=%v", err) - return - } - } - }() - return nil -} - -func (v *SRTConnection) connectBackend(ctx context.Context, streamID string) error { - if v.backendUDP != nil { - return nil - } - - // Parse stream id to host and resource. - host, resource, err := parseSRTStreamID(streamID) - if err != nil { - return errors.Wrapf(err, "parse stream id %v", streamID) - } - - if host == "" { - host = "localhost" - } - - streamURL, err := buildStreamURL(fmt.Sprintf("srt://%v/%v", host, resource)) - if err != nil { - return errors.Wrapf(err, "build stream url %v", streamID) - } - - // Pick a backend SRS server to proxy the SRT stream. - backend, err := srsLoadBalancer.Pick(ctx, streamURL) - if err != nil { - return errors.Wrapf(err, "pick backend for %v", streamURL) - } - - // Parse UDP port from backend. - if len(backend.SRT) == 0 { - return errors.Errorf("no udp server %v for %v", backend, streamURL) - } - - _, _, udpPort, err := parseListenEndpoint(backend.SRT[0]) - if err != nil { - return errors.Wrapf(err, "parse udp port %v of %v for %v", backend.SRT[0], backend, streamURL) - } - - // Connect to backend SRS server via UDP client. - // TODO: FIXME: Support close the connection when timeout or client disconnected. - backendAddr := net.UDPAddr{IP: net.ParseIP(backend.IP), Port: int(udpPort)} - if backendUDP, err := net.DialUDP("udp", nil, &backendAddr); err != nil { - return errors.Wrapf(err, "dial udp to %v of %v for %v", backendAddr, backend, streamURL) - } else { - v.backendUDP = backendUDP - } - - return nil -} - -// See https://datatracker.ietf.org/doc/html/draft-sharabayko-srt-01#section-3.2 -// See https://datatracker.ietf.org/doc/html/draft-sharabayko-srt-01#section-3.2.1 -type SRTHandshakePacket struct { - // F: 1 bit. Packet Type Flag. The control packet has this flag set to - // "1". The data packet has this flag set to "0". - ControlFlag uint8 - // Control Type: 15 bits. Control Packet Type. The use of these bits - // is determined by the control packet type definition. - // Handshake control packets (Control Type = 0x0000) are used to - // exchange peer configurations, to agree on connection parameters, and - // to establish a connection. - ControlType uint16 - // Subtype: 16 bits. This field specifies an additional subtype for - // specific packets. - SubType uint16 - // Type-specific Information: 32 bits. The use of this field depends on - // the particular control packet type. Handshake packets do not use - // this field. - AdditionalInfo uint32 - // Timestamp: 32 bits. - Timestamp uint32 - // Destination Socket ID: 32 bits. - SocketID uint32 - - // Version: 32 bits. A base protocol version number. Currently used - // values are 4 and 5. Values greater than 5 are reserved for future - // use. - Version uint32 - // Encryption Field: 16 bits. Block cipher family and key size. The - // values of this field are described in Table 2. The default value - // is AES-128. - // 0 | No Encryption Advertised - // 2 | AES-128 - // 3 | AES-192 - // 4 | AES-256 - EncryptionField uint16 - // Extension Field: 16 bits. This field is message specific extension - // related to Handshake Type field. The value MUST be set to 0 - // except for the following cases. (1) If the handshake control - // packet is the INDUCTION message, this field is sent back by the - // Listener. (2) In the case of a CONCLUSION message, this field - // value should contain a combination of Extension Type values. - // 0x00000001 | HSREQ - // 0x00000002 | KMREQ - // 0x00000004 | CONFIG - // 0x4A17 if HandshakeType is INDUCTION, see https://datatracker.ietf.org/doc/html/draft-sharabayko-srt-01#section-4.3.1.1 - ExtensionField uint16 - // Initial Packet Sequence Number: 32 bits. The sequence number of the - // very first data packet to be sent. - InitSequence uint32 - // Maximum Transmission Unit Size: 32 bits. This value is typically set - // to 1500, which is the default Maximum Transmission Unit (MTU) size - // for Ethernet, but can be less. - MTU uint32 - // Maximum Flow Window Size: 32 bits. The value of this field is the - // maximum number of data packets allowed to be "in flight" (i.e. the - // number of sent packets for which an ACK control packet has not yet - // been received). - FlowWindow uint32 - // Handshake Type: 32 bits. This field indicates the handshake packet - // type. - // 0xFFFFFFFD | DONE - // 0xFFFFFFFE | AGREEMENT - // 0xFFFFFFFF | CONCLUSION - // 0x00000000 | WAVEHAND - // 0x00000001 | INDUCTION - HandshakeType uint32 - // SRT Socket ID: 32 bits. This field holds the ID of the source SRT - // socket from which a handshake packet is issued. - SRTSocketID uint32 - // SYN Cookie: 32 bits. Randomized value for processing a handshake. - // The value of this field is specified by the handshake message - // type. - SynCookie uint32 - // Peer IP Address: 128 bits. IPv4 or IPv6 address of the packet's - // sender. The value consists of four 32-bit fields. - PeerIP net.IP - // Extensions. - // Extension Type: 16 bits. The value of this field is used to process - // an integrated handshake. Each extension can have a pair of - // request and response types. - // Extension Length: 16 bits. The length of the Extension Contents - // field in four-byte blocks. - // Extension Contents: variable length. The payload of the extension. - ExtraData []byte -} - -func (v *SRTHandshakePacket) IsData() bool { - return v.ControlFlag == 0x00 -} - -func (v *SRTHandshakePacket) IsControl() bool { - return v.ControlFlag == 0x80 -} - -func (v *SRTHandshakePacket) IsHandshake() bool { - return v.IsControl() && v.ControlType == 0x00 && v.SubType == 0x00 -} - -func (v *SRTHandshakePacket) StreamID() (string, error) { - p := v.ExtraData - for { - if len(p) < 2 { - return "", errors.Errorf("Require 2 bytes, actual=%v, extra=%v", len(p), len(v.ExtraData)) - } - - extType := binary.BigEndian.Uint16(p) - extSize := binary.BigEndian.Uint16(p[2:]) - p = p[4:] - - if len(p) < int(extSize*4) { - return "", errors.Errorf("Require %v bytes, actual=%v, extra=%v", extSize*4, len(p), len(v.ExtraData)) - } - - // Ignore other packets except stream id. - if extType != 0x05 { - p = p[extSize*4:] - continue - } - - // We must copy it, because we will decode the stream id. - data := append([]byte{}, p[:extSize*4]...) - - // Reverse the stream id encoded in little-endian to big-endian. - for i := 0; i < len(data); i += 4 { - value := binary.LittleEndian.Uint32(data[i:]) - binary.BigEndian.PutUint32(data[i:], value) - } - - // Trim the trailing zero bytes. - data = bytes.TrimRight(data, "\x00") - return string(data), nil - } -} - -func (v *SRTHandshakePacket) String() string { - return fmt.Sprintf("Control=%v, CType=%v, SType=%v, Timestamp=%v, SocketID=%v, Version=%v, Encrypt=%v, Extension=%v, InitSequence=%v, MTU=%v, FlowWnd=%v, HSType=%v, SRTSocketID=%v, Cookie=%v, Peer=%vB, Extra=%vB", - v.IsControl(), v.ControlType, v.SubType, v.Timestamp, v.SocketID, v.Version, v.EncryptionField, v.ExtensionField, v.InitSequence, v.MTU, v.FlowWindow, v.HandshakeType, v.SRTSocketID, v.SynCookie, len(v.PeerIP), len(v.ExtraData)) -} - -func (v *SRTHandshakePacket) UnmarshalBinary(b []byte) error { - if len(b) < 4 { - return errors.Errorf("Invalid packet length %v", len(b)) - } - v.ControlFlag = b[0] & 0x80 - v.ControlType = binary.BigEndian.Uint16(b[0:2]) & 0x7fff - v.SubType = binary.BigEndian.Uint16(b[2:4]) - - if len(b) < 64 { - return errors.Errorf("Invalid packet length %v", len(b)) - } - v.AdditionalInfo = binary.BigEndian.Uint32(b[4:]) - v.Timestamp = binary.BigEndian.Uint32(b[8:]) - v.SocketID = binary.BigEndian.Uint32(b[12:]) - v.Version = binary.BigEndian.Uint32(b[16:]) - v.EncryptionField = binary.BigEndian.Uint16(b[20:]) - v.ExtensionField = binary.BigEndian.Uint16(b[22:]) - v.InitSequence = binary.BigEndian.Uint32(b[24:]) - v.MTU = binary.BigEndian.Uint32(b[28:]) - v.FlowWindow = binary.BigEndian.Uint32(b[32:]) - v.HandshakeType = binary.BigEndian.Uint32(b[36:]) - v.SRTSocketID = binary.BigEndian.Uint32(b[40:]) - v.SynCookie = binary.BigEndian.Uint32(b[44:]) - - // Only support IPv4. - v.PeerIP = net.IPv4(b[51], b[50], b[49], b[48]) - - v.ExtraData = b[64:] - - return nil -} - -func (v *SRTHandshakePacket) MarshalBinary() ([]byte, error) { - b := make([]byte, 64+len(v.ExtraData)) - binary.BigEndian.PutUint16(b, uint16(v.ControlFlag)<<8|v.ControlType) - binary.BigEndian.PutUint16(b[2:], v.SubType) - binary.BigEndian.PutUint32(b[4:], v.AdditionalInfo) - binary.BigEndian.PutUint32(b[8:], v.Timestamp) - binary.BigEndian.PutUint32(b[12:], v.SocketID) - binary.BigEndian.PutUint32(b[16:], v.Version) - binary.BigEndian.PutUint16(b[20:], v.EncryptionField) - binary.BigEndian.PutUint16(b[22:], v.ExtensionField) - binary.BigEndian.PutUint32(b[24:], v.InitSequence) - binary.BigEndian.PutUint32(b[28:], v.MTU) - binary.BigEndian.PutUint32(b[32:], v.FlowWindow) - binary.BigEndian.PutUint32(b[36:], v.HandshakeType) - binary.BigEndian.PutUint32(b[40:], v.SRTSocketID) - binary.BigEndian.PutUint32(b[44:], v.SynCookie) - - // Only support IPv4. - ip := v.PeerIP.To4() - b[48] = ip[3] - b[49] = ip[2] - b[50] = ip[1] - b[51] = ip[0] - - if len(v.ExtraData) > 0 { - copy(b[64:], v.ExtraData) - } - - return b, nil -} diff --git a/proxy/sync/map.go b/proxy/sync/map.go deleted file mode 100644 index f3f0d61ab..000000000 --- a/proxy/sync/map.go +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright (c) 2025 Winlin -// -// SPDX-License-Identifier: MIT -package sync - -import "sync" - -type Map[K comparable, V any] struct { - m sync.Map -} - -func (m *Map[K, V]) Delete(key K) { - m.m.Delete(key) -} - -func (m *Map[K, V]) Load(key K) (value V, ok bool) { - v, ok := m.m.Load(key) - if !ok { - return value, ok - } - return v.(V), ok -} - -func (m *Map[K, V]) LoadAndDelete(key K) (value V, loaded bool) { - v, loaded := m.m.LoadAndDelete(key) - if !loaded { - return value, loaded - } - return v.(V), loaded -} - -func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) { - a, loaded := m.m.LoadOrStore(key, value) - return a.(V), loaded -} - -func (m *Map[K, V]) Range(f func(key K, value V) bool) { - m.m.Range(func(key, value any) bool { - return f(key.(K), value.(V)) - }) -} - -func (m *Map[K, V]) Store(key K, value V) { - m.m.Store(key, value) -} diff --git a/proxy/utils.go b/proxy/utils.go deleted file mode 100644 index 275a8e9da..000000000 --- a/proxy/utils.go +++ /dev/null @@ -1,276 +0,0 @@ -// Copyright (c) 2025 Winlin -// -// SPDX-License-Identifier: MIT -package main - -import ( - "context" - "encoding/binary" - "encoding/json" - stdErr "errors" - "fmt" - "io" - "io/ioutil" - "net" - "net/http" - "net/url" - "os" - "path" - "reflect" - "regexp" - "strconv" - "strings" - "syscall" - "time" - - "srs-proxy/errors" - "srs-proxy/logger" -) - -func apiResponse(ctx context.Context, w http.ResponseWriter, r *http.Request, data any) { - w.Header().Set("Server", fmt.Sprintf("%v/%v", Signature(), Version())) - - b, err := json.Marshal(data) - if err != nil { - apiError(ctx, w, r, errors.Wrapf(err, "marshal %v %v", reflect.TypeOf(data), data)) - return - } - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - w.Write(b) -} - -func apiError(ctx context.Context, w http.ResponseWriter, r *http.Request, err error) { - logger.Wf(ctx, "HTTP API error %+v", err) - w.Header().Set("Content-Type", "text/plain; charset=utf-8") - w.WriteHeader(http.StatusInternalServerError) - fmt.Fprintln(w, fmt.Sprintf("%v", err)) -} - -func apiCORS(ctx context.Context, w http.ResponseWriter, r *http.Request) bool { - // Always support CORS. Note that browser may send origin header for m3u8, but no origin header - // for ts. So we always response CORS header. - if true { - // SRS does not need cookie or credentials, so we disable CORS credentials, and use * for CORS origin, - // headers, expose headers and methods. - w.Header().Set("Access-Control-Allow-Origin", "*") - // See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers - w.Header().Set("Access-Control-Allow-Headers", "*") - // See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods - w.Header().Set("Access-Control-Allow-Methods", "*") - } - - if r.Method == http.MethodOptions { - w.WriteHeader(http.StatusOK) - return true - } - - return false -} - -func parseGracefullyQuitTimeout() (time.Duration, error) { - if t, err := time.ParseDuration(envGraceQuitTimeout()); err != nil { - return 0, errors.Wrapf(err, "parse duration %v", envGraceQuitTimeout()) - } else { - return t, nil - } -} - -// ParseBody read the body from r, and unmarshal JSON to v. -func ParseBody(r io.ReadCloser, v interface{}) error { - b, err := ioutil.ReadAll(r) - if err != nil { - return errors.Wrapf(err, "read body") - } - defer r.Close() - - if len(b) == 0 { - return nil - } - - if err := json.Unmarshal(b, v); err != nil { - return errors.Wrapf(err, "json unmarshal %v", string(b)) - } - - return nil -} - -// buildStreamURL build as vhost/app/stream for stream URL r. -func buildStreamURL(r string) (string, error) { - u, err := url.Parse(r) - if err != nil { - return "", errors.Wrapf(err, "parse url %v", r) - } - - // If not domain or ip in hostname, it's __defaultVhost__. - defaultVhost := !strings.Contains(u.Hostname(), ".") - - // If hostname is actually an IP address, it's __defaultVhost__. - if ip := net.ParseIP(u.Hostname()); ip.To4() != nil { - defaultVhost = true - } - - if defaultVhost { - return fmt.Sprintf("__defaultVhost__%v", u.Path), nil - } - - // Ignore port, only use hostname as vhost. - return fmt.Sprintf("%v%v", u.Hostname(), u.Path), nil -} - -// isPeerClosedError indicates whether peer object closed the connection. -func isPeerClosedError(err error) bool { - causeErr := errors.Cause(err) - - if stdErr.Is(causeErr, io.EOF) { - return true - } - - if stdErr.Is(causeErr, syscall.EPIPE) { - return true - } - - if netErr, ok := causeErr.(*net.OpError); ok { - if sysErr, ok := netErr.Err.(*os.SyscallError); ok { - if stdErr.Is(sysErr.Err, syscall.ECONNRESET) { - return true - } - } - } - - return false -} - -// convertURLToStreamURL convert the URL in HTTP request to special URLs. The unifiedURL is the URL -// in unified, foramt as scheme://vhost/app/stream without extensions. While the fullURL is the unifiedURL -// with extension. -func convertURLToStreamURL(r *http.Request) (unifiedURL, fullURL string) { - scheme := "http" - if r.TLS != nil { - scheme = "https" - } - - hostname := "__defaultVhost__" - if strings.Contains(r.Host, ":") { - if v, _, err := net.SplitHostPort(r.Host); err == nil { - hostname = v - } - } - - var appStream, streamExt string - - // Parse app/stream from query string. - q := r.URL.Query() - if app := q.Get("app"); app != "" { - appStream = "/" + app - } - if stream := q.Get("stream"); stream != "" { - appStream = fmt.Sprintf("%v/%v", appStream, stream) - } - - // Parse app/stream from path. - if appStream == "" { - streamExt = path.Ext(r.URL.Path) - appStream = strings.TrimSuffix(r.URL.Path, streamExt) - } - - unifiedURL = fmt.Sprintf("%v://%v%v", scheme, hostname, appStream) - fullURL = fmt.Sprintf("%v%v", unifiedURL, streamExt) - return -} - -// rtcIsSTUN returns true if data of UDP payload is a STUN packet. -func rtcIsSTUN(data []byte) bool { - return len(data) > 0 && (data[0] == 0 || data[0] == 1) -} - -// rtcIsRTPOrRTCP returns true if data of UDP payload is a RTP or RTCP packet. -func rtcIsRTPOrRTCP(data []byte) bool { - return len(data) >= 12 && (data[0]&0xC0) == 0x80 -} - -// srtIsHandshake returns true if data of UDP payload is a SRT handshake packet. -func srtIsHandshake(data []byte) bool { - return len(data) >= 4 && binary.BigEndian.Uint32(data) == 0x80000000 -} - -// srtParseSocketID parse the socket id from the SRT packet. -func srtParseSocketID(data []byte) uint32 { - if len(data) >= 16 { - return binary.BigEndian.Uint32(data[12:]) - } - return 0 -} - -// parseIceUfragPwd parse the ice-ufrag and ice-pwd from the SDP. -func parseIceUfragPwd(sdp string) (ufrag, pwd string, err error) { - if true { - ufragRe := regexp.MustCompile(`a=ice-ufrag:([^\s]+)`) - ufragMatch := ufragRe.FindStringSubmatch(sdp) - if len(ufragMatch) <= 1 { - return "", "", errors.Errorf("no ice-ufrag in sdp %v", sdp) - } - ufrag = ufragMatch[1] - } - - if true { - pwdRe := regexp.MustCompile(`a=ice-pwd:([^\s]+)`) - pwdMatch := pwdRe.FindStringSubmatch(sdp) - if len(pwdMatch) <= 1 { - return "", "", errors.Errorf("no ice-pwd in sdp %v", sdp) - } - pwd = pwdMatch[1] - } - - return ufrag, pwd, nil -} - -// parseSRTStreamID parse the SRT stream id to host(optional) and resource(required). -// See https://ossrs.io/lts/en-us/docs/v7/doc/srt#srt-url -func parseSRTStreamID(sid string) (host, resource string, err error) { - if true { - hostRe := regexp.MustCompile(`h=([^,]+)`) - hostMatch := hostRe.FindStringSubmatch(sid) - if len(hostMatch) > 1 { - host = hostMatch[1] - } - } - - if true { - resourceRe := regexp.MustCompile(`r=([^,]+)`) - resourceMatch := resourceRe.FindStringSubmatch(sid) - if len(resourceMatch) <= 1 { - return "", "", errors.Errorf("no resource in sid %v", sid) - } - resource = resourceMatch[1] - } - - return host, resource, nil -} - -// parseListenEndpoint parse the listen endpoint as: -// port The tcp listen port, like 1935. -// protocol://ip:port The listen endpoint, like tcp://:1935 or tcp://0.0.0.0:1935 -func parseListenEndpoint(ep string) (protocol string, ip net.IP, port uint16, err error) { - // If no colon in ep, it's port in string. - if !strings.Contains(ep, ":") { - if p, err := strconv.Atoi(ep); err != nil { - return "", nil, 0, errors.Wrapf(err, "parse port %v", ep) - } else { - return "tcp", nil, uint16(p), nil - } - } - - // Must be protocol://ip:port schema. - parts := strings.Split(ep, ":") - if len(parts) != 3 { - return "", nil, 0, errors.Errorf("invalid endpoint %v", ep) - } - - if p, err := strconv.Atoi(parts[2]); err != nil { - return "", nil, 0, errors.Wrapf(err, "parse port %v", parts[2]) - } else { - return parts[0], net.ParseIP(parts[1]), uint16(p), nil - } -} diff --git a/proxy/version.go b/proxy/version.go deleted file mode 100644 index 0cb7efe31..000000000 --- a/proxy/version.go +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (c) 2025 Winlin -// -// SPDX-License-Identifier: MIT -package main - -import "fmt" - -func VersionMajor() int { - return 1 -} - -// VersionMinor specifies the typical version of SRS we adapt to. -func VersionMinor() int { - return 5 -} - -func VersionRevision() int { - return 0 -} - -func Version() string { - return fmt.Sprintf("%v.%v.%v", VersionMajor(), VersionMinor(), VersionRevision()) -} - -func Signature() string { - return "SRSProxy" -}