Setup Frontend Livestream site and Backend HTTP Server by python

This commit is contained in:
Jason-JP-Yang 2025-06-22 12:33:50 +08:00
parent bf4ea37ea3
commit c467522328
16 changed files with 3862 additions and 456 deletions

View File

@ -1,37 +1,186 @@
# no-daemon and write log to console config for srs.
# @see full.conf for detail config.
# docker config for srs.
# @see full.conf for detail config explanation.
#############################################################################################
# Last modified: 2025-05-10 Jason Yang
# Contributor: SRS Team, Jason Yang, Jasper, HyperKNF, Hayden, NPL ITP Team
#############################################################################################
listen 1935;
max_connections 1000;
daemon off;
srs_log_tank all;
#############################################################################################
# Global sections
#############################################################################################
ff_log_dir ./objs;
ff_log_level info;
srs_log_tank all;
srs_log_file ./objs/srs.log;
# TRACE, DEBUG, INFO, WARN, ERROR
srs_log_level_v2 info;
max_connections 1000;
daemon off;
utc_time off;
#############################################################################################
# RTMP sections
#############################################################################################
listen 1935;
chunk_size 60000;
#############################################################################################
# HTTP sections
#############################################################################################
http_api {
enabled on;
listen 1985;
enabled on;
listen 1985;
crossdomain on;
auth {
enabled on;
username python_stats;
password wMePq3ahpoLRzgsVg7BY9eE82uuJHT0YukD2ZE1JfMY2RjP4e6QnUaKg3V9x5s9M;
}
https {
enabled off;
listen 1990;
key ./conf/server.key;
cert ./conf/server.crt;
}
}
http_server {
enabled on;
listen 8080;
enabled on;
listen 8080;
dir ./objs/nginx/html;
crossdomain on;
https {
enabled off;
listen 8088;
key ./conf/server.key;
cert ./conf/server.crt;
}
}
#############################################################################################
# SRT server section
#############################################################################################
srt_server {
# whether SRT server is enabled.
enabled off;
listen 10080;
maxbw 1000000000;
mss 1500;
connect_timeout 4000;
peer_idle_timeout 8000;
default_app live;
peerlatency 0;
recvlatency 0;
latency 0;
tsbpdmode off;
tlpktdrop off;
sendbuf 2000000;
recvbuf 2000000;
}
#############################################################################################
# WebRTC server section
#############################################################################################
rtc_server {
enabled on;
listen 8000; # UDP port
# @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#config-candidate
candidate $CANDIDATE;
enabled on;
listen 8000;
protocol udp;
tcp {
enabled off;
listen 8000;
}
candidate $CANDIDATE;
ip_family ipv4;
api_as_candidates on;
resolve_api_domain on;
ecdsa on;
encrypt on;
}
#############################################################################################
# PYTHON_ADDONS sections
#############################################################################################
python_addons {
# Enable or disable Python addon management
enabled on;
# Python addon definitions
# Each addon block defines a Python script to run
addon {
script "./python/httpbackend_server.py";
}
}
#############################################################################################
# VHOST sections
#############################################################################################
vhost __defaultVhost__ {
enabled on;
hls {
enabled on;
enabled on;
}
srt {
# Whether enable SRT on this vhost.
enabled on;
srt_to_rtmp on;
}
http_remux {
enabled on;
mount [vhost]/[app]/[stream].flv;
enabled on;
mount [vhost]/[app]/[stream].flv;
}
rtc {
enabled on;
enabled on;
# @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtmp-to-rtc
rtmp_to_rtc on;
rtmp_to_rtc on;
keep_bframe on;
# opus_bitrate 48000;
# @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtc-to-rtmp
rtc_to_rtmp on;
rtc_to_rtmp on;
pli_for_rtmp 6.0;
aac_bitrate 48000;
}
}
chunk_size 128;
tcp_nodelay on;
min_latency on;
play {
gop_cache off;
queue_length 10;
mw_latency 100;
mw_msgs 0;
}
publish {
mr off;
mr_latency 300;
firstpkt_timeout 20000;
normal_timeout 5000;
}
security {
enabled off;
allow play all;
allow publish all;
}
dvr {
enabled on;
dvr_path ./DVR_Record/[stream]/[2006].[01].[02].[15].[04].[05].mp4;
dvr_plan segment;
dvr_duration 30;
dvr_apply /^live\/.*/;
dvr_wait_keyframe on;
time_jitter full;
}
}

1
trunk/livestream_site Submodule

@ -0,0 +1 @@
Subproject commit 348b347bd7e1cf9d59acff0ba1bdf5390f0751d7

View File

@ -1,122 +0,0 @@
#!/usr/bin/env python3
"""
SRS Analytics Script
This demonstrates another Python process that can run alongside SRS.
"""
import time
import signal
import sys
import argparse
import logging
import json
from http.server import HTTPServer, BaseHTTPRequestHandler
from threading import Thread
from datetime import datetime
class AnalyticsHandler(BaseHTTPRequestHandler):
def do_GET(self):
"""Handle GET requests for analytics data"""
if self.path == '/stats':
# Return sample analytics data
stats = {
'timestamp': datetime.now().isoformat(),
'connections': 42,
'streams': 5,
'bandwidth': '1.2 Mbps',
'uptime': '2h 30m'
}
self.send_response(200)
self.send_header('Content-type', 'application/json')
self.end_headers()
self.wfile.write(json.dumps(stats, indent=2).encode())
else:
self.send_response(404)
self.end_headers()
def log_message(self, format, *args):
"""Override to use our logger"""
pass
class SRSAnalytics:
def __init__(self, port=8888):
self.port = port
self.running = True
self.server = None
self.server_thread = None
# Set up logging
logging.basicConfig(
level=logging.INFO,
format='[%(asctime)s] [Python Analytics] %(levelname)s: %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
self.logger = logging.getLogger(__name__)
# Set up signal handlers
signal.signal(signal.SIGTERM, self.signal_handler)
signal.signal(signal.SIGINT, self.signal_handler)
def signal_handler(self, signum, frame):
"""Handle shutdown signals from SRS"""
self.logger.info(f"Received signal {signum}, shutting down gracefully...")
self.running = False
if self.server:
self.server.shutdown()
def start_http_server(self):
"""Start the HTTP analytics server"""
try:
self.server = HTTPServer(('localhost', self.port), AnalyticsHandler)
self.logger.info(f"Analytics server started on http://localhost:{self.port}")
self.server.serve_forever()
except Exception as e:
self.logger.error(f"Error starting HTTP server: {e}")
def run(self):
"""Main analytics loop"""
self.logger.info(f"SRS Analytics started on port {self.port}")
try:
# Start HTTP server in a separate thread
self.server_thread = Thread(target=self.start_http_server)
self.server_thread.daemon = True
self.server_thread.start()
# Main analytics loop
while self.running:
# Simulate analytics work
self.logger.debug("Processing analytics data...")
# You can add your analytics logic here:
# - Collect stream metrics
# - Process viewer statistics
# - Generate reports
# - Store data to database
time.sleep(10) # Process every 10 seconds
except Exception as e:
self.logger.error(f"Error in analytics loop: {e}")
finally:
self.cleanup()
def cleanup(self):
"""Cleanup before shutdown"""
self.logger.info("Cleaning up Analytics server...")
if self.server:
self.server.shutdown()
self.logger.info("Analytics server stopped")
def main():
parser = argparse.ArgumentParser(description='SRS Analytics Server')
parser.add_argument('--port', type=int, default=8888, help='HTTP server port')
args = parser.parse_args()
analytics = SRSAnalytics(args.port)
analytics.run()
if __name__ == '__main__':
main()

View File

@ -0,0 +1,616 @@
import math
from typing import Optional, List
class AvatarGenerator:
DEFAULT_COLORS = ["#000000", "#8f1414", "#e50e0e", "#f3450f", "#fcac03"]
DEFAULT_SIZE = 80
DEFAULT_VARIANT = 'marble'
VALID_VARIANTS = {'marble', 'beam', 'pixel', 'sunset', 'ring', 'bauhaus'}
def __init__(self, default_colors: Optional[List[str]] = None, default_size: Optional[int] = None, default_variant: Optional[str] = None):
self.colors = default_colors if default_colors else self.DEFAULT_COLORS
self.size = default_size if default_size else self.DEFAULT_SIZE
self.variant = default_variant if default_variant else self.DEFAULT_VARIANT
self.AVATAR_GENERATORS = {
'marble': self._generate_marble_avatar,
'beam': self._generate_beam_avatar,
'pixel': self._generate_pixel_avatar,
'sunset': self._generate_sunset_avatar,
'ring': self._generate_ring_avatar,
'bauhaus': self._generate_bauhaus_avatar,
}
@staticmethod
def _hash_code(name: str) -> int:
"""Generate hash from name string - exact port of hashCode function"""
try:
if not name or not isinstance(name, str):
# Consider raising a TypeError or returning a default hash for invalid input
return 12345 # Fallback for invalid input
hash_val = 0
for char in name:
hash_val = (hash_val << 5) - hash_val + ord(char)
hash_val &= hash_val # Convert to 32bit integer
return abs(hash_val)
except Exception:
return 12345 # fallback hash value
@staticmethod
def _get_modulus(num: int, max_val: int) -> int:
"""Get modulus"""
return num % max_val
@staticmethod
def _get_digit(number: int, ntn: int) -> int:
"""Get digit at position"""
try:
if not isinstance(number, int) or not isinstance(ntn, int):
return 0 # Fallback for invalid input type
return int((number / (10 ** ntn)) % 10)
except (ZeroDivisionError, ValueError, OverflowError):
return 0
@staticmethod
def _get_boolean(number: int, ntn: int) -> bool:
"""Get boolean from digit"""
return not ((AvatarGenerator._get_digit(number, ntn)) % 2)
@staticmethod
def _get_angle(x: float, y: float) -> float:
"""Get angle from coordinates"""
return math.atan2(y, x) * 180 / math.pi
@staticmethod
def _get_unit(number: int, range_val: int, index: int = None) -> int:
"""Get unit value with optional negative"""
try:
if not isinstance(number, int) or not isinstance(range_val, int) or range_val == 0:
return 0 # Fallback for invalid input
value = number % range_val
if index and ((AvatarGenerator._get_digit(number, index) % 2) == 0):
return -value
return value
except (ZeroDivisionError, ValueError, TypeError):
return 0
@staticmethod
def _get_random_color(number: int, colors: List[str], range_val: int) -> str:
"""Get random color from palette"""
try:
if not colors or range_val == 0:
return AvatarGenerator.DEFAULT_COLORS[0] # Fallback
return colors[number % range_val]
except (IndexError, TypeError, ZeroDivisionError):
return AvatarGenerator.DEFAULT_COLORS[0]
@staticmethod
def _get_contrast(hexcolor: str) -> str:
"""Get contrasting color (black or white)"""
try:
if not hexcolor or not isinstance(hexcolor, str):
return '#000000' # Fallback
# Remove leading # if present
if hexcolor.startswith('#'):
hexcolor = hexcolor[1:]
# Clean up any remaining quotes or whitespace
hexcolor = hexcolor.strip().replace('"', '').replace("'", "")
# Ensure we have exactly 6 hex characters
if len(hexcolor) != 6:
return '#000000' # Fallback
# Validate hex characters
try:
int(hexcolor, 16)
except ValueError:
return '#000000' # Fallback
# Convert to RGB
r = int(hexcolor[0:2], 16)
g = int(hexcolor[2:4], 16)
b = int(hexcolor[4:6], 16)
# Get YIQ ratio
yiq = ((r * 299) + (g * 587) + (b * 114)) / 1000
# Check contrast
return '#000000' if yiq >= 128 else '#FFFFFF'
except Exception:
return '#000000'
@staticmethod
def normalize_colors(colors_param: Optional[str]) -> List[str]:
"""Normalize color palette from query parameter"""
try:
if not colors_param:
return AvatarGenerator.DEFAULT_COLORS
# Clean up URL encoding
cleaned = colors_param.replace('%22', '"').replace('%20', ' ').replace('%5B', '[').replace('%5D', ']').strip()
# Extract colors from different formats
color_list = AvatarGenerator._extract_color_strings(cleaned)
# Validate and normalize colors
return AvatarGenerator._validate_color_list(color_list)
except Exception:
return AvatarGenerator.DEFAULT_COLORS
@staticmethod
def _extract_color_strings(cleaned_colors: str) -> List[str]:
"""Extract color strings from cleaned input"""
# Handle array format: ["#xxxxxx", "#xxxxxx", ...]
if cleaned_colors.startswith('[') and cleaned_colors.endswith(']'):
inner_content = cleaned_colors[1:-1].strip()
if not inner_content:
return []
return [part.strip().replace('"', '').replace("'", '') for part in inner_content.split(',')]
# Handle comma-separated format: #xxxxxx, #xxxxxx, ...
return [part.strip().replace('"', '').replace("'", '') for part in cleaned_colors.split(',')]
@staticmethod
def _validate_color_list(color_list: List[str]) -> List[str]:
"""Validate and normalize hex colors"""
normalized_colors = []
for color in color_list:
color = color.strip()
if not color:
continue
if not color.startswith('#'):
color = '#' + color
if AvatarGenerator._is_valid_hex_color(color):
normalized_colors.append(color)
return normalized_colors if normalized_colors else AvatarGenerator.DEFAULT_COLORS
@staticmethod
def _is_valid_hex_color(color: str) -> bool:
"""Check if string is valid hex color"""
return len(color) == 7 and all(c in '0123456789ABCDEFabcdef' for c in color[1:])
def _generate_marble_avatar(self, name: str, colors: List[str], size: int, square: bool, title: bool) -> str:
"""Generate marble variant avatar"""
ELEMENTS = 3
SIZE = 80 # Intrinsic size of the SVG design
def generate_colors_marble(name: str, current_colors: List[str]):
num_from_name = self._hash_code(name)
range_val = len(current_colors) if current_colors else len(self.DEFAULT_COLORS)
elements_properties = []
for i in range(ELEMENTS):
elements_properties.append({
'color': self._get_random_color(num_from_name + i, current_colors, range_val),
'translateX': self._get_unit(num_from_name * (i + 1), SIZE // 10, 1),
'translateY': self._get_unit(num_from_name * (i + 1), SIZE // 10, 2),
'scale': 1.2 + self._get_unit(num_from_name * (i + 1), SIZE // 20) / 10,
'rotate': self._get_unit(num_from_name * (i + 1), 360, 1)
})
return elements_properties
properties = generate_colors_marble(name, colors)
mask_id = f"mask_{self._hash_code(name)}"
filter_id = f"filter_{self._hash_code(name)}"
# Use the passed 'size' for width/height, internal 'SIZE' for viewBox
svg = f'''<svg viewBox="0 0 {SIZE} {SIZE}" fill="none" role="img" xmlns="http://www.w3.org/2000/svg" width="{size}" height="{size}">'''
if title:
svg += f'<title>{name}</title>'
svg += f'''
<mask id="{mask_id}" maskUnits="userSpaceOnUse" x="0" y="0" width="{SIZE}" height="{SIZE}">
<rect width="{SIZE}" height="{SIZE}" {"" if square else f'rx="{SIZE * 2}"'} fill="#FFFFFF" />
</mask>
<g mask="url(#{mask_id})">
<rect width="{SIZE}" height="{SIZE}" fill="{properties[0]['color']}" />
<path
filter="url(#{filter_id})"
d="M32.414 59.35L50.376 70.5H72.5v-71H33.728L26.5 13.381l19.057 27.08L32.414 59.35z"
fill="{properties[1]['color']}"
transform="translate({properties[1]['translateX']} {properties[1]['translateY']}) rotate({properties[1]['rotate']} {SIZE // 2} {SIZE // 2}) scale({properties[1]['scale']})"
/>
<path
filter="url(#{filter_id})"
style="mix-blend-mode: overlay;"
d="M22.216 24L0 46.75l14.108 38.129L78 86l-3.081-59.276-22.378 4.005 12.972 20.186-23.35 27.395L22.215 24z"
fill="{properties[2]['color']}"
transform="translate({properties[2]['translateX']} {properties[2]['translateY']}) rotate({properties[2]['rotate']} {SIZE // 2} {SIZE // 2}) scale({properties[2]['scale']})"
/>
</g>
<defs>
<filter
id="{filter_id}"
filterUnits="userSpaceOnUse"
colorInterpolationFilters="sRGB"
>
<feFlood floodOpacity="0" result="BackgroundImageFix" />
<feBlend in="SourceGraphic" in2="BackgroundImageFix" result="shape" />
<feGaussianBlur stdDeviation="7" result="effect1_foregroundBlur" />
</filter>
</defs>
</svg>'''
return svg
def _generate_beam_avatar(self, name: str, colors: List[str], size: int, square: bool, title: bool) -> str:
"""Generate beam variant avatar"""
SIZE = 36 # Intrinsic size of the SVG design
def generate_data_beam(name: str, current_colors: List[str]):
num_from_name = self._hash_code(name)
range_val = len(current_colors) if current_colors else len(self.DEFAULT_COLORS)
wrapper_color = self._get_random_color(num_from_name, current_colors, range_val)
pre_translate_x = self._get_unit(num_from_name, 10, 1)
wrapper_translate_x = pre_translate_x + SIZE // 9 if pre_translate_x < 5 else pre_translate_x
pre_translate_y = self._get_unit(num_from_name, 10, 2)
wrapper_translate_y = pre_translate_y + SIZE // 9 if pre_translate_y < 5 else pre_translate_y
data = {
'wrapperColor': wrapper_color,
'faceColor': self._get_contrast(wrapper_color),
'backgroundColor': self._get_random_color(num_from_name + 13, current_colors, range_val),
'wrapperTranslateX': wrapper_translate_x,
'wrapperTranslateY': wrapper_translate_y,
'wrapperRotate': self._get_unit(num_from_name, 360),
'wrapperScale': 1 + self._get_unit(num_from_name, SIZE // 12) / 10,
'isMouthOpen': self._get_boolean(num_from_name, 2),
'isCircle': self._get_boolean(num_from_name, 1),
'eyeSpread': self._get_unit(num_from_name, 5),
'mouthSpread': self._get_unit(num_from_name, 3),
'faceRotate': self._get_unit(num_from_name, 10, 3),
'faceTranslateX': wrapper_translate_x // 2 if wrapper_translate_x > SIZE // 6 else self._get_unit(num_from_name, 8, 1),
'faceTranslateY': wrapper_translate_y // 2 if wrapper_translate_y > SIZE // 6 else self._get_unit(num_from_name, 7, 2),
}
return data
data = generate_data_beam(name, colors)
mask_id = f"mask_{self._hash_code(name)}"
svg = f'''<svg viewBox="0 0 {SIZE} {SIZE}" fill="none" role="img" xmlns="http://www.w3.org/2000/svg" width="{size}" height="{size}">'''
if title:
svg += f'<title>{name}</title>'
rx_value = SIZE if data['isCircle'] else SIZE // 6
svg += f'''
<mask id="{mask_id}" maskUnits="userSpaceOnUse" x="0" y="0" width="{SIZE}" height="{SIZE}">
<rect width="{SIZE}" height="{SIZE}" {"" if square else f'rx="{SIZE * 2}"'} fill="#FFFFFF" />
</mask>
<g mask="url(#{mask_id})">
<rect width="{SIZE}" height="{SIZE}" fill="{data['backgroundColor']}" />
<rect
x="0"
y="0"
width="{SIZE}"
height="{SIZE}"
transform="translate({data['wrapperTranslateX']} {data['wrapperTranslateY']}) rotate({data['wrapperRotate']} {SIZE // 2} {SIZE // 2}) scale({data['wrapperScale']})"
fill="{data['wrapperColor']}"
rx="{rx_value}"
/>
<g transform="translate({data['faceTranslateX']} {data['faceTranslateY']}) rotate({data['faceRotate']} {SIZE // 2} {SIZE // 2})">'''
if data['isMouthOpen']:
svg += f'''
<path
d="M15 {19 + data['mouthSpread']}c2 1 4 1 6 0"
stroke="{data['faceColor']}"
fill="none"
strokeLinecap="round"
/>'''
else:
svg += f'''
<path
d="M13,{19 + data['mouthSpread']} a1,0.75 0 0,0 10,0"
fill="{data['faceColor']}"
/>'''
svg += f'''
<rect
x="{14 - data['eyeSpread']}"
y="14"
width="1.5"
height="2"
rx="1"
stroke="none"
fill="{data['faceColor']}"
/>
<rect
x="{20 + data['eyeSpread']}"
y="14"
width="1.5"
height="2"
rx="1"
stroke="none"
fill="{data['faceColor']}"
/>
</g>
</g>
</svg>'''
return svg
def _generate_pixel_avatar(self, name: str, colors: List[str], size: int, square: bool, title: bool) -> str:
"""Generate pixel variant avatar"""
ELEMENTS = 64 # 8x8 grid
SIZE = 80 # Intrinsic SVG size
def generate_colors_pixel(name: str, current_colors: List[str]):
num_from_name = self._hash_code(name)
# Use current_colors if provided, otherwise fallback to instance default, then class default
final_colors = current_colors if current_colors else self.colors
range_val = len(final_colors)
color_list = []
for i in range(ELEMENTS):
color_list.append(self._get_random_color(num_from_name + i, final_colors, range_val))
return color_list
pixel_colors = generate_colors_pixel(name, colors)
mask_id = f"mask_{self._hash_code(name)}"
svg = f'''<svg viewBox="0 0 {SIZE} {SIZE}" fill="none" role="img" xmlns="http://www.w3.org/2000/svg" width="{size}" height="{size}">'''
if title:
svg += f'<title>{name}</title>'
svg += f'''
<mask id="{mask_id}" mask-type="alpha" maskUnits="userSpaceOnUse" x="0" y="0" width="{SIZE}" height="{SIZE}">
<rect width="{SIZE}" height="{SIZE}" {"" if square else f'rx="{SIZE * 2}"'} fill="#FFFFFF" />
</mask>
<g mask="url(#{mask_id})">'''
# Generate 8x8 grid of pixels
idx = 0
pixel_size = SIZE // 8 # Each pixel is 10x10 if SIZE is 80
for row in range(8):
for col in range(8):
svg += f'''
<rect x="{col * pixel_size}" y="{row * pixel_size}" width="{pixel_size}" height="{pixel_size}" fill="{pixel_colors[idx]}" />'''
idx +=1
svg += '''
</g>
</svg>'''
return svg
def _generate_sunset_avatar(self, name: str, colors: List[str], size: int, square: bool, title: bool) -> str:
"""Generate sunset variant avatar"""
ELEMENTS = 4 # For 4 gradient stops
SIZE = 80 # Intrinsic SVG size
def generate_colors_sunset(name: str, current_colors: List[str]):
num_from_name = self._hash_code(name)
# Use current_colors if provided, otherwise fallback to instance default, then class default
final_colors = current_colors if current_colors else self.colors
# Ensure we have at least 4 colors for sunset, repeat if necessary
if len(final_colors) < ELEMENTS:
final_colors = (final_colors * (ELEMENTS // len(final_colors) + 1))[:ELEMENTS]
range_val = len(final_colors)
color_list = []
for i in range(ELEMENTS):
color_list.append(self._get_random_color(num_from_name + i, final_colors, range_val))
return color_list
sunset_colors = generate_colors_sunset(name, colors)
# name_without_space = name.replace(' ', '').replace('-', '').replace('_', '') # Not used
hash_val = abs(self._hash_code(name))
mask_id = f"mask_{hash_val}"
gradient_id_1 = f"gradient_paint0_linear_{hash_val}"
gradient_id_2 = f"gradient_paint1_linear_{hash_val}"
rx_value = '' if square else f'rx="{SIZE * 2}"'
svg = f'''<svg viewBox="0 0 {SIZE} {SIZE}" fill="none" role="img" xmlns="http://www.w3.org/2000/svg" width="{size}" height="{size}">'''
if title:
svg += f'<title>{name}</title>'
svg += f'''
<defs>
<linearGradient
id="{gradient_id_1}"
x1="{SIZE // 2}"
y1="0"
x2="{SIZE // 2}"
y2="{SIZE // 2}"
gradientUnits="userSpaceOnUse"
>
<stop stop-color="{sunset_colors[0]}" />
<stop offset="1" stop-color="{sunset_colors[1]}" />
</linearGradient>
<linearGradient
id="{gradient_id_2}"
x1="{SIZE // 2}"
y1="{SIZE // 2}"
x2="{SIZE // 2}"
y2="{SIZE}"
gradientUnits="userSpaceOnUse"
>
<stop stop-color="{sunset_colors[2]}" />
<stop offset="1" stop-color="{sunset_colors[3]}" />
</linearGradient>
</defs>
<mask id="{mask_id}" maskUnits="userSpaceOnUse" x="0" y="0" width="{SIZE}" height="{SIZE}">
<rect width="{SIZE}" height="{SIZE}" {rx_value} fill="#FFFFFF" />
</mask>
<g mask="url(#{mask_id})">
<path fill="url(#{gradient_id_1})" d="M0 0h{SIZE}v{SIZE//2}H0z" />
<path fill="url(#{gradient_id_2})" d="M0 {SIZE//2}h{SIZE}v{SIZE//2}H0z" />
</g>
</svg>'''
return svg
def _generate_ring_avatar(self, name: str, colors: List[str], size: int, square: bool, title: bool) -> str:
"""Generate ring variant avatar"""
SIZE = 90 # Intrinsic SVG size
COLORS_NEEDED = 9 # Number of colors used in the SVG paths
def generate_colors_ring(name: str, current_colors: List[str]):
num_from_name = self._hash_code(name)
# Use current_colors if provided, otherwise fallback to instance default, then class default
final_colors = current_colors if current_colors else self.colors
# Ensure we have enough colors, repeat if necessary
if len(final_colors) < COLORS_NEEDED:
final_colors = (final_colors * (COLORS_NEEDED // len(final_colors) + 1))[:COLORS_NEEDED]
range_val = len(final_colors)
ring_palette = []
for i in range(COLORS_NEEDED):
ring_palette.append(self._get_random_color(num_from_name + i, final_colors, range_val))
return ring_palette
ring_colors = generate_colors_ring(name, colors)
mask_id = f"mask_{self._hash_code(name)}"
svg = f'''<svg viewBox="0 0 {SIZE} {SIZE}" fill="none" role="img" xmlns="http://www.w3.org/2000/svg" width="{size}" height="{size}">'''
if title:
svg += f'<title>{name}</title>'
svg += f'''
<mask id="{mask_id}" maskUnits="userSpaceOnUse" x="0" y="0" width="{SIZE}" height="{SIZE}">
<rect width="{SIZE}" height="{SIZE}" {"" if square else f'rx="{SIZE * 2}"'} fill="#FFFFFF" />
</mask>
<g mask="url(#{mask_id})">
<path d="M0 0h{SIZE}v{SIZE//2}H0z" fill="{ring_colors[0]}" />
<path d="M0 {SIZE//2}h{SIZE}v{SIZE//2}H0z" fill="{ring_colors[1]}" />
<path d="M{SIZE-7} {SIZE//2}a{SIZE//2-2} {SIZE//2-2} 0 00-{SIZE-14} 0h{SIZE-14}z" fill="{ring_colors[2]}" />
<path d="M{SIZE-7} {SIZE//2}a{SIZE//2-2} {SIZE//2-2} 0 01-{SIZE-14} 0h{SIZE-14}z" fill="{ring_colors[3]}" />
<path d="M{SIZE-13} {SIZE//2}a{SIZE//2-8} {SIZE//2-8} 0 10-{SIZE-26} 0h{SIZE-26}z" fill="{ring_colors[4]}" />
<path d="M{SIZE-13} {SIZE//2}a{SIZE//2-8} {SIZE//2-8} 0 11-{SIZE-26} 0h{SIZE-26}z" fill="{ring_colors[5]}" />
<path d="M{SIZE-19} {SIZE//2}a{SIZE//2-14} {SIZE//2-14} 0 00-{SIZE-38} 0h{SIZE-38}z" fill="{ring_colors[6]}" />
<path d="M{SIZE-19} {SIZE//2}a{SIZE//2-14} {SIZE//2-14} 0 01-{SIZE-38} 0h{SIZE-38}z" fill="{ring_colors[7]}" />
<circle cx="{SIZE//2}" cy="{SIZE//2}" r="{SIZE//2-22}" fill="{ring_colors[8]}" />
</g>
</svg>'''
return svg
def _generate_bauhaus_avatar(self, name: str, colors: List[str], size: int, square: bool, title: bool) -> str:
"""Generate bauhaus variant avatar"""
ELEMENTS = 4 # Number of geometric elements
SIZE = 80 # Intrinsic SVG size
def generate_colors_bauhaus(name: str, current_colors: List[str]):
num_from_name = self._hash_code(name)
# Use current_colors if provided, otherwise fallback to instance default, then class default
final_colors = current_colors if current_colors else self.colors
range_val = len(final_colors)
properties = []
for i in range(ELEMENTS):
properties.append({
'color': self._get_random_color(num_from_name + i, final_colors, range_val),
'translateX': self._get_unit(num_from_name * (i + 1), SIZE // 2 - (i + 17), 1),
'translateY': self._get_unit(num_from_name * (i + 1), SIZE // 2 - (i + 17), 2),
'rotate': self._get_unit(num_from_name * (i + 1), 360),
'isSquare': self._get_boolean(num_from_name, 2)
})
return properties
properties = generate_colors_bauhaus(name, colors)
mask_id = f"mask_{self._hash_code(name)}"
svg = f'''<svg viewBox="0 0 {SIZE} {SIZE}" fill="none" role="img" xmlns="http://www.w3.org/2000/svg" width="{size}" height="{size}">'''
if title:
svg += f'<title>{name}</title>'
svg += f'''
<mask id="{mask_id}" maskUnits="userSpaceOnUse" x="0" y="0" width="{SIZE}" height="{SIZE}">
<rect width="{SIZE}" height="{SIZE}" {"" if square else f'rx="{SIZE * 2}"'} fill="#FFFFFF" />
</mask>
<g mask="url(#{mask_id})">
<rect width="{SIZE}" height="{SIZE}" fill="{properties[0]['color']}" />
<rect
x="{(SIZE - 60) // 2}"
y="{(SIZE - 20) // 2}"
width="{SIZE}"
height="{SIZE if properties[1]['isSquare'] else SIZE // 8}"
fill="{properties[1]['color']}"
transform="translate({properties[1]['translateX']} {properties[1]['translateY']}) rotate({properties[1]['rotate']} {SIZE // 2} {SIZE // 2})"
/>
<circle
cx="{SIZE // 2}"
cy="{SIZE // 2}"
fill="{properties[2]['color']}"
r="{SIZE // 5}"
transform="translate({properties[2]['translateX']} {properties[2]['translateY']})"
/>
<line
x1="0"
y1="{SIZE // 2}"
x2="{SIZE}"
y2="{SIZE // 2}"
strokeWidth="2"
stroke="{properties[3]['color']}"
transform="translate({properties[3]['translateX']} {properties[3]['translateY']}) rotate({properties[3]['rotate']} {SIZE // 2} {SIZE // 2})"
/>
</g>
</svg>'''
return svg
def generate_avatar(self,
name: str,
variant: Optional[str] = None,
colors: Optional[List[str]] = None, # Allow passing colors as list of strings
colors_param: Optional[str] = None, # Allow passing colors as a string parameter
size: Optional[int] = None,
square: bool = False,
title: bool = True) -> str:
"""
Generate an SVG avatar.
Args:
name: The name to generate the avatar for.
variant: The avatar variant (e.g., 'marble', 'beam'). Uses instance default if None.
colors: A list of hex color strings. Overrides colors_param if provided.
colors_param: A string representation of colors (e.g., "['#FF0000', '#00FF00']" or "#FF0000,#00FF00").
Used if 'colors' list is not provided.
size: The desired size of the avatar in pixels. Uses instance default if None.
square: If True, generates a square avatar. Otherwise, a circular one.
title: If True, includes a <title> tag with the name in the SVG.
Returns:
An SVG string representing the avatar.
"""
current_variant = variant if variant and variant in self.VALID_VARIANTS else self.variant
current_size = size if size else self.size
# Determine colors: prioritize direct list, then string param, then instance default
if colors:
# Validate and normalize if a list is directly passed
final_colors = self._validate_color_list(colors)
elif colors_param:
final_colors = self.normalize_colors(colors_param)
else:
final_colors = self.colors # Use instance default colors
if not final_colors: # Ensure there's always a fallback
final_colors = self.DEFAULT_COLORS
generator_func = self.AVATAR_GENERATORS.get(current_variant)
if generator_func:
return generator_func(name, final_colors, current_size, square, title)
else:
# Fallback to default variant if the chosen one is somehow invalid after checks
default_gen_func = self.AVATAR_GENERATORS.get(self.DEFAULT_VARIANT)
return default_gen_func(name, final_colors, current_size, square, title)

890
trunk/python/database.py Normal file
View File

@ -0,0 +1,890 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Database module for JWT authentication system
Provides SQLite database initialization and management for user authentication
"""
import sqlite3
import os
import json
import re
import threading
import uuid
from datetime import datetime, timedelta
from typing import Optional, List, Dict, Any, Union
from contextlib import contextmanager
import random
import string
# Import SRS logger
from srs_logger import get_logger
class Database:
_instance = None
_lock = threading.Lock()
def __new__(cls, db_path: str = "./objs/srs_database.db"):
"""Singleton pattern to ensure only one database instance"""
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self, db_path: str = "./objs/srs_database.db"):
if self._initialized:
return
self._initialized = True
self.db_path = db_path
self.logger = get_logger()
self._connection_lock = threading.Lock()
# Initialize database
self._init_database()
def _init_database(self):
"""Initialize database and create tables if they don't exist"""
try:
# Create objs directory if it doesn't exist
db_dir = os.path.dirname(self.db_path)
if db_dir and not os.path.exists(db_dir):
os.makedirs(db_dir, exist_ok=True)
self.logger.info(f"Created database directory: {db_dir}")
# Check if database exists
db_exists = os.path.exists(self.db_path)
if not db_exists:
self.logger.warn(f"Database file not found, creating new database: {self.db_path}")
else:
self.logger.info(f"Using existing database: {self.db_path}")
# Create tables
self._create_tables()
if not db_exists:
self.logger.info("Database initialized successfully with all tables created")
else:
self.logger.info("Database connection established and tables verified")
except Exception as e:
self.logger.exception(f"Failed to initialize database: {e}")
raise
def _create_tables(self):
"""Create all required tables"""
try:
with self.get_connection() as conn:
cursor = conn.cursor()
# Create users table
cursor.execute('''
CREATE TABLE IF NOT EXISTS users (
user_id TEXT PRIMARY KEY,
username TEXT UNIQUE NOT NULL,
name TEXT NOT NULL,
email TEXT UNIQUE,
hashed_password TEXT NOT NULL,
salt TEXT NOT NULL,
user_group TEXT NOT NULL DEFAULT '["user"]',
last_active TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
is_activated BOOLEAN NOT NULL DEFAULT 1
)
''')
# Create auth_session table
cursor.execute('''
CREATE TABLE IF NOT EXISTS auth_session (
session_id TEXT PRIMARY KEY,
user_id TEXT NOT NULL,
hashed_authkey TEXT NOT NULL,
salt TEXT NOT NULL,
expire_time TIMESTAMP NOT NULL,
FOREIGN KEY (user_id) REFERENCES users (user_id) ON DELETE CASCADE
)
''')
# Create refresh_session table
cursor.execute('''
CREATE TABLE IF NOT EXISTS refresh_session (
session_id TEXT PRIMARY KEY,
user_id TEXT NOT NULL,
hashed_refreshkey TEXT NOT NULL,
salt TEXT NOT NULL,
expire_time TIMESTAMP NOT NULL,
FOREIGN KEY (user_id) REFERENCES users (user_id) ON DELETE CASCADE
)
''')
# Create streams table
cursor.execute('''
CREATE TABLE IF NOT EXISTS streams (
stream_id TEXT PRIMARY KEY,
stream_code TEXT UNIQUE NOT NULL,
streamer_id TEXT NOT NULL,
stream_title TEXT NOT NULL,
stream_description TEXT,
stream_tags TEXT DEFAULT '[]',
stream_visibility TEXT NOT NULL DEFAULT 'public',
quality_info TEXT,
active_time TIMESTAMP,
stream_status TEXT NOT NULL DEFAULT 'planned',
FOREIGN KEY (streamer_id) REFERENCES users (user_id) ON DELETE CASCADE
)
''')
# Create system_stats table
cursor.execute('''
CREATE TABLE IF NOT EXISTS system_stats (
timestamp TIMESTAMP PRIMARY KEY,
srs_uptime REAL NOT NULL,
srs_cpu_percent REAL NOT NULL,
srs_memory_percent REAL NOT NULL,
srs_recv_KBps REAL NOT NULL,
srs_send_KBps REAL NOT NULL,
disk_read_KBps REAL NOT NULL,
disk_write_KBps REAL NOT NULL,
os_uptime REAL NOT NULL,
os_cpu_percent REAL NOT NULL,
os_memory_percent REAL NOT NULL
)
''')
# Create indexes for better performance
cursor.execute('CREATE INDEX IF NOT EXISTS idx_users_username ON users(username)')
cursor.execute('CREATE INDEX IF NOT EXISTS idx_users_email ON users(email)')
cursor.execute('CREATE INDEX IF NOT EXISTS idx_auth_session_user_id ON auth_session(user_id)')
cursor.execute('CREATE INDEX IF NOT EXISTS idx_auth_session_expire ON auth_session(expire_time)')
cursor.execute('CREATE INDEX IF NOT EXISTS idx_refresh_session_user_id ON refresh_session(user_id)')
cursor.execute('CREATE INDEX IF NOT EXISTS idx_refresh_session_expire ON refresh_session(expire_time)')
cursor.execute('CREATE INDEX IF NOT EXISTS idx_system_stats_timestamp ON system_stats(timestamp)')
conn.commit()
self.logger.debug("All database tables and indexes created/verified successfully")
except Exception as e:
self.logger.exception(f"Failed to create database tables: {e}")
raise
@contextmanager
def get_connection(self):
"""Get a database connection with proper error handling and cleanup"""
conn = None
try:
conn = sqlite3.connect(self.db_path, timeout=30.0)
conn.row_factory = sqlite3.Row # Enable column access by name
conn.execute('PRAGMA foreign_keys = ON') # Enable foreign key constraints
# Configure datetime adapter for Python 3.12+ compatibility
sqlite3.register_adapter(datetime, lambda dt: dt.isoformat())
sqlite3.register_converter("TIMESTAMP", lambda b: datetime.fromisoformat(b.decode()))
yield conn
except Exception as e:
if conn:
conn.rollback()
self.logger.error(f"Database connection error: {e}")
raise
finally:
if conn:
conn.close()
# ============================================================================
# Users management functions
# ============================================================================
def cleanup_expired_sessions(self):
"""Remove expired auth and refresh sessions, and sessions for deactivated users"""
try:
with self.get_connection() as conn:
cursor = conn.cursor()
current_time = datetime.now()
# Clean up expired auth sessions
cursor.execute('''
DELETE FROM auth_session
WHERE expire_time < ?
''', (current_time,))
auth_expired_deleted = cursor.rowcount
# Clean up expired refresh sessions
cursor.execute('''
DELETE FROM refresh_session
WHERE expire_time < ?
''', (current_time,))
refresh_expired_deleted = cursor.rowcount
# Clean up sessions for deactivated users
cursor.execute('''
DELETE FROM auth_session
WHERE user_id IN (SELECT user_id FROM users WHERE is_activated = 0)
''')
auth_deactivated_deleted = cursor.rowcount
cursor.execute('''
DELETE FROM refresh_session
WHERE user_id IN (SELECT user_id FROM users WHERE is_activated = 0)
''')
refresh_deactivated_deleted = cursor.rowcount
conn.commit()
total_auth_deleted = auth_expired_deleted + auth_deactivated_deleted
total_refresh_deleted = refresh_expired_deleted + refresh_deactivated_deleted
if total_auth_deleted > 0 or total_refresh_deleted > 0:
self.logger.info(f"Cleaned up sessions: {total_auth_deleted} auth ({auth_expired_deleted} expired, {auth_deactivated_deleted} deactivated), {total_refresh_deleted} refresh ({refresh_expired_deleted} expired, {refresh_deactivated_deleted} deactivated)")
except Exception as e:
self.logger.exception(f"Failed to cleanup expired sessions: {e}")
def get_user(self, username_or_email: str) -> Optional[Dict[str, Any]]:
"""Get user by username or email"""
if not username_or_email:
self.logger.warn("get_user called with None or empty value")
return None
try:
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.execute('''
SELECT user_id, username, name, email, hashed_password, salt,
user_group, last_active, is_activated
FROM users WHERE username = ? OR email = ?
''', (username_or_email, username_or_email))
row = cursor.fetchone()
if row:
user_data = dict(row)
user_data['user_group'] = json.loads(user_data['user_group'])
return user_data
else:
return None
except Exception as e:
self.logger.exception(f"Failed to get user by username/email '{username_or_email}': {e}")
return None
def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
"""Get user by user_id"""
try:
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.execute('''
SELECT user_id, username, name, email, hashed_password, salt,
user_group, last_active, is_activated
FROM users WHERE user_id = ?
''', (user_id,))
row = cursor.fetchone()
if row:
user_data = dict(row)
user_data['user_group'] = json.loads(user_data['user_group'])
return user_data
else:
return None
except Exception as e:
self.logger.exception(f"Failed to get user by user_id {user_id}: {e}")
return None
def create_user(self, username: str, name: str, email: Optional[str],
hashed_password: str, salt: str, user_group: List[str] = None,
is_activated: bool = True) -> Optional[Dict[str, Any]]:
"""Create a new user and return user data"""
try:
# Validate username format
if not re.match(r'^[a-zA-Z0-9_-]+$', username):
self.logger.warn(f"Invalid username format: '{username}'. Only a-z, A-Z, 0-9, '_', '-' are allowed")
return None
# Validate email format if provided
if email and not re.match(r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$', email):
self.logger.warn(f"Invalid email format: '{email}'")
return None
# Ensure user_group is not empty
if user_group is None or len(user_group) == 0:
user_group = ["user"]
# Generate unique user_id (UUID collisions are extremely rare, but we'll still check)
user_id = str(uuid.uuid4())
user_group_json = json.dumps(user_group)
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.execute('''
INSERT INTO users (user_id, username, name, email, hashed_password, salt, user_group, is_activated)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
''', (user_id, username, name, email, hashed_password, salt, user_group_json, is_activated))
conn.commit()
self.logger.info(f"Created new user: {username} (ID: {user_id}, activated: {is_activated})")
# Return the created user data directly to avoid deadlock
return {
'user_id': user_id,
'username': username,
'name': name,
'email': email,
'hashed_password': hashed_password,
'salt': salt,
'user_group': user_group, # Already parsed list
'last_active': datetime.now().isoformat(),
'is_activated': is_activated
}
except sqlite3.IntegrityError as e:
if 'username' in str(e):
self.logger.warn(f"Username '{username}' already exists")
elif 'email' in str(e):
self.logger.warn(f"Email '{email}' already exists")
else:
self.logger.warn(f"User creation failed due to constraint: {e}")
return None
except Exception as e:
self.logger.exception(f"Failed to create user '{username}': {e}")
return None
def update_user_last_active(self, user_id: str) -> Optional[Dict[str, Any]]:
"""Update user's last active timestamp and return updated user data"""
try:
with self.get_connection() as conn:
cursor = conn.cursor()
current_time = datetime.now()
cursor.execute('''
UPDATE users
SET last_active = ?
WHERE user_id = ?
''', (current_time, user_id))
conn.commit()
self.logger.debug(f"Updated last_active for user_id: {user_id}")
# Get the updated user data in the same connection to avoid deadlock
cursor.execute('''
SELECT user_id, username, name, email, hashed_password, salt, user_group, last_active, is_activated
FROM users
WHERE user_id = ?
''', (user_id,))
row = cursor.fetchone()
if row:
user_data = dict(row)
user_data['user_group'] = json.loads(user_data['user_group'])
return user_data
else:
self.logger.warn(f"No user found with user_id: {user_id}")
return None
except Exception as e:
self.logger.exception(f"Failed to update last_active for user_id {user_id}: {e}")
return None
def update_user_activation(self, user_id: str, activation: bool) -> Optional[Dict[str, Any]]:
"""Update user's activation status and cleanup sessions if deactivated, return updated user data"""
try:
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.execute('''
UPDATE users
SET is_activated = ?
WHERE user_id = ?
''', (activation, user_id))
conn.commit()
self.logger.info(f"Updated activation status for user_id {user_id}: {activation}")
# Get the updated user data in the same connection to avoid deadlock
cursor.execute('''
SELECT user_id, username, name, email, hashed_password, salt, user_group, last_active, is_activated
FROM users
WHERE user_id = ?
''', (user_id,))
row = cursor.fetchone()
if row:
# Clean up sessions after updating activation status (outside the transaction)
if not activation: self.cleanup_expired_sessions()
user_data = dict(row)
user_data['user_group'] = json.loads(user_data['user_group'])
return user_data
else:
self.logger.warn(f"No user found with user_id: {user_id}")
return None
except Exception as e:
self.logger.exception(f"Failed to update activation for user_id {user_id}: {e}")
return None
def update_user_password(self, user_id: str, hashed_password: str, salt: str) -> Optional[Dict[str, Any]]:
"""Update user's password and salt, return updated user data"""
try:
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.execute('''
UPDATE users
SET hashed_password = ?, salt = ?
WHERE user_id = ?
''', (hashed_password, salt, user_id))
conn.commit()
if cursor.rowcount == 0:
self.logger.warn(f"No user found with user_id: {user_id}")
return None
self.logger.info(f"Updated password for user_id: {user_id}")
# Get the updated user data in the same connection to avoid deadlock
cursor.execute('''
SELECT user_id, username, name, email, hashed_password, salt, user_group, last_active, is_activated
FROM users
WHERE user_id = ?
''', (user_id,))
row = cursor.fetchone()
if row:
user_data = dict(row)
user_data['user_group'] = json.loads(user_data['user_group'])
return user_data
else:
self.logger.warn(f"No user found with user_id: {user_id}")
return None
except Exception as e:
self.logger.exception(f"Failed to update password for user_id {user_id}: {e}")
return None
def delete_user(self, user_id: str) -> bool:
"""Delete user and all associated data from all tables"""
try:
with self.get_connection() as conn:
cursor = conn.cursor()
# Check if user exists first
cursor.execute('SELECT username FROM users WHERE user_id = ?', (user_id,))
user_row = cursor.fetchone()
if not user_row:
self.logger.warn(f"No user found with user_id: {user_id}")
return False
username = user_row[0]
# Delete auth sessions (foreign key constraints will handle this automatically, but we'll be explicit)
cursor.execute('DELETE FROM auth_session WHERE user_id = ?', (user_id,))
auth_deleted = cursor.rowcount
# Delete refresh sessions
cursor.execute('DELETE FROM refresh_session WHERE user_id = ?', (user_id,))
refresh_deleted = cursor.rowcount
# Delete user
cursor.execute('DELETE FROM users WHERE user_id = ?', (user_id,))
user_deleted = cursor.rowcount
conn.commit()
if user_deleted > 0:
self.logger.info(f"Deleted user '{username}' (ID: {user_id}) and all associated data: {auth_deleted} auth sessions, {refresh_deleted} refresh sessions")
return True
else:
self.logger.warn(f"Failed to delete user with user_id: {user_id}")
return False
except Exception as e:
self.logger.exception(f"Failed to delete user with user_id {user_id}: {e}")
return False
# ============================================================================
# Session management functions
# ============================================================================
def create_auth_session(self, user_id: str, hashed_authkey: str, salt: str,
expire_time: datetime) -> Optional[Dict[str, Any]]:
"""Create an auth session and return session data"""
try:
# Generate unique session_id (UUID collisions are extremely rare)
session_id = str(uuid.uuid4())
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.execute('''
INSERT INTO auth_session (session_id, user_id, hashed_authkey, salt, expire_time)
VALUES (?, ?, ?, ?, ?)
''', (session_id, user_id, hashed_authkey, salt, expire_time))
conn.commit()
self.logger.debug(f"Created auth session for user_id {user_id}, session_id: {session_id}")
# Return the created session data
return {
'session_id': session_id,
'user_id': user_id,
'hashed_authkey': hashed_authkey,
'salt': salt,
'expire_time': expire_time.isoformat()
}
except Exception as e:
self.logger.exception(f"Failed to create auth session for user_id {user_id}: {e}")
return None
def create_refresh_session(self, user_id: str, hashed_refreshkey: str, salt: str,
expire_time: datetime) -> Optional[Dict[str, Any]]:
"""Create a refresh session and return session data"""
try:
# Generate unique session_id (UUID collisions are extremely rare)
session_id = str(uuid.uuid4())
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.execute('''
INSERT INTO refresh_session (session_id, user_id, hashed_refreshkey, salt, expire_time)
VALUES (?, ?, ?, ?, ?)
''', (session_id, user_id, hashed_refreshkey, salt, expire_time))
conn.commit()
self.logger.debug(f"Created refresh session for user_id {user_id}, session_id: {session_id}")
# Return the created session data
return {
'session_id': session_id,
'user_id': user_id,
'hashed_refreshkey': hashed_refreshkey,
'salt': salt,
'expire_time': expire_time.isoformat()
}
except Exception as e:
self.logger.exception(f"Failed to create refresh session for user_id {user_id}: {e}")
return None
def get_auth_session(self, session_id: str) -> Optional[Dict[str, Any]]:
"""Get valid auth session by session_id"""
try:
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.execute('''
SELECT session_id, user_id, hashed_authkey, salt, expire_time
FROM auth_session
WHERE session_id = ? AND expire_time > ?
''', (session_id, datetime.now()))
row = cursor.fetchone()
if row:
return dict(row)
return None
except Exception as e:
self.logger.exception(f"Failed to get auth session for session_id {session_id}: {e}")
return None
def get_refresh_session(self, session_id: str) -> Optional[Dict[str, Any]]:
"""Get valid refresh session by session_id"""
try:
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.execute('''
SELECT session_id, user_id, hashed_refreshkey, salt, expire_time
FROM refresh_session
WHERE session_id = ? AND expire_time > ?
''', (session_id, datetime.now()))
row = cursor.fetchone()
if row:
return dict(row)
return None
except Exception as e:
self.logger.exception(f"Failed to get refresh session for session_id {session_id}: {e}")
return None
def delete_auth_session(self, session_id: str) -> bool:
"""Delete an auth session"""
try:
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.execute('DELETE FROM auth_session WHERE session_id = ?', (session_id,))
conn.commit()
if cursor.rowcount > 0:
self.logger.debug(f"Deleted auth session_id: {session_id}")
return True
return False
except Exception as e:
self.logger.exception(f"Failed to delete auth session {session_id}: {e}")
return False
def delete_refresh_session(self, session_id: str) -> bool:
"""Delete a refresh session"""
try:
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.execute('DELETE FROM refresh_session WHERE session_id = ?', (session_id,))
conn.commit()
if cursor.rowcount > 0:
self.logger.debug(f"Deleted refresh session_id: {session_id}")
return True
return False
except Exception as e:
self.logger.exception(f"Failed to delete refresh session {session_id}: {e}")
return False
def delete_user_sessions(self, user_id: str) -> bool:
"""Delete all sessions for a user (logout)"""
try:
with self.get_connection() as conn:
cursor = conn.cursor()
# Delete auth sessions
cursor.execute('DELETE FROM auth_session WHERE user_id = ?', (user_id,))
auth_deleted = cursor.rowcount
# Delete refresh sessions
cursor.execute('DELETE FROM refresh_session WHERE user_id = ?', (user_id,))
refresh_deleted = cursor.rowcount
conn.commit()
self.logger.info(f"Deleted all sessions for user_id {user_id}: {auth_deleted} auth, {refresh_deleted} refresh")
return True
except Exception as e:
self.logger.exception(f"Failed to delete sessions for user_id {user_id}: {e}")
return False
# ============================================================================
# System statistics functions
# ============================================================================
def insert_system_stats(self,
srs_uptime: float, srs_cpu_percent: float,
srs_memory_percent: float, srs_recv_KBps: float,
srs_send_KBps: float, disk_read_KBps: float,
disk_write_KBps: float, os_uptime: float,
os_cpu_percent: float, os_memory_percent: float) -> bool:
"""Insert system statistics into the database"""
try:
with self.get_connection() as conn:
cursor = conn.cursor()
timestamp = datetime.now()
cursor.execute('''
INSERT INTO system_stats (timestamp,
srs_uptime, srs_cpu_percent, srs_memory_percent,
srs_recv_KBps, srs_send_KBps,
disk_read_KBps, disk_write_KBps,
os_uptime, os_cpu_percent, os_memory_percent)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
''', (timestamp,
srs_uptime, srs_cpu_percent, srs_memory_percent,
srs_recv_KBps, srs_send_KBps,
disk_read_KBps, disk_write_KBps,
os_uptime, os_cpu_percent, os_memory_percent))
conn.commit()
self.logger.debug(f"Inserted system stats at {timestamp.isoformat()}")
return True
except Exception as e:
self.logger.exception(f"Failed to insert system stats: {e}")
return False
def get_system_stats(self, time_delta: int) -> List[Dict[str, Any]]:
"""Get system statistics from the database"""
try:
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.execute('''
SELECT * FROM system_stats WHERE timestamp >= ?
''', (datetime.now() - timedelta(seconds=time_delta),))
rows = cursor.fetchall()
columns = [column[0] for column in cursor.description]
return [dict(zip(columns, row)) for row in rows]
except Exception as e:
self.logger.exception(f"Failed to get system stats: {e}")
return []
def delete_expired_system_stats(self) -> bool:
"""Delete expired system statistics older than 20 minutes"""
try:
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.execute('''
DELETE FROM system_stats WHERE timestamp < ?
''', (datetime.now() - timedelta(minutes=20),))
conn.commit()
if cursor.rowcount > 0:
self.logger.info(f"Deleted {cursor.rowcount} expired system stats")
return True
else:
self.logger.debug("No expired system stats to delete")
return False
except Exception as e:
self.logger.exception(f"Failed to delete expired system stats: {e}")
return False
# ============================================================================
# Streams management functions
# ============================================================================
def create_stream(self, streamer_id: str, stream_title: str, stream_description: Optional[str] = None, stream_tags: Optional[List[str]] = [], stream_visibility: str = 'public') -> Dict[str, Any]:
"""Create a new stream and return stream data"""
try:
# Validate streamer_id exists
if not self.get_user_by_id(streamer_id):
self.logger.warn(f"Streamer with user_id {streamer_id} does not exist")
return None
# Generate unique stream_id (UUID collisions are extremely rare)
stream_id = str(uuid.uuid4())
# Generate stream_code in form of xxx-xxxx (only include letters and numbers)
def random_code():
part1 = ''.join(random.choices(string.ascii_lowercase + string.digits, k=3))
part2 = ''.join(random.choices(string.ascii_lowercase + string.digits, k=4))
return f"{part1}-{part2}"
stream_code = random_code()
# Convert tags to JSON string
stream_tags_json = json.dumps(stream_tags)
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.execute('''
INSERT INTO streams (stream_id, stream_code, streamer_id, stream_title,
stream_description, stream_tags,
stream_visibility)
VALUES (?, ?, ?, ?, ?, ?, ?)
''', (stream_id, stream_code, streamer_id, stream_title,
stream_description, stream_tags_json,
stream_visibility))
conn.commit()
self.logger.info(f"Created new stream: {stream_title} (ID: {stream_id}, Streamer ID: {streamer_id})")
return {
'stream_id': stream_id,
'stream_code': stream_code,
'streamer_id': streamer_id,
'stream_title': stream_title,
'stream_description': stream_description,
'stream_tags': stream_tags,
'stream_visibility': stream_visibility,
'stream_status': 'planned', # Default status
}
except sqlite3.IntegrityError as e:
self.logger.warn(f"Stream creation failed due to constraint: {e}")
return None
except Exception as e:
self.logger.exception(f"Failed to create stream: {e}")
return None
def get_stream_by_vis(self, stream_visibility: str = 'public') -> Optional[List[Dict[str, Any]]]:
"""Get streams by visibility"""
try:
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.execute('''
SELECT * FROM streams
WHERE stream_visibility = ? AND stream_status = 'streaming'
''', (stream_visibility,))
rows = cursor.fetchall()
columns = [column[0] for column in cursor.description]
return [dict(zip(columns, row)) for row in rows]
except Exception as e:
self.logger.exception(f"Failed to get streams by visibility '{stream_visibility}': {e}")
return None
# ============================================================================
# Default database statistics functions
# ============================================================================
def get_database_stats(self) -> Dict[str, int]:
"""Get database statistics"""
try:
with self.get_connection() as conn:
cursor = conn.cursor()
stats = {}
# Count users
cursor.execute('SELECT COUNT(*) FROM users')
stats['users'] = cursor.fetchone()[0]
# Count active auth sessions
cursor.execute('SELECT COUNT(*) FROM auth_session WHERE expire_time > ?', (datetime.now(),))
stats['active_auth_sessions'] = cursor.fetchone()[0]
# Count active refresh sessions
cursor.execute('SELECT COUNT(*) FROM refresh_session WHERE expire_time > ?', (datetime.now(),))
stats['active_refresh_sessions'] = cursor.fetchone()[0]
# Count expired auth sessions
cursor.execute('SELECT COUNT(*) FROM auth_session WHERE expire_time <= ?', (datetime.now(),))
stats['expired_auth_sessions'] = cursor.fetchone()[0]
# Count expired refresh sessions
cursor.execute('SELECT COUNT(*) FROM refresh_session WHERE expire_time <= ?', (datetime.now(),))
stats['expired_refresh_sessions'] = cursor.fetchone()[0]
return stats
except Exception as e:
self.logger.exception(f"Failed to get database statistics: {e}")
return {}
# Global database instance
_global_db: Optional[Database] = None
def get_database(db_path: str = "./objs/srs_database.db") -> Database:
"""
Get the global database instance
Args:
db_path: Path to SQLite database file
Returns:
Database instance
"""
global _global_db
if _global_db is None:
# run_comprehensive_tests()
_global_db = Database(db_path)
return _global_db
if __name__ == "__main__":
# Comprehensive test suite for database functionality
import argparse
import uuid
from datetime import datetime, timedelta
from srs_logger import init_logger
logger = init_logger()
# Main execution
parser = argparse.ArgumentParser(description="Database Module Test")
parser.add_argument("--db-path", default="./objs/srs_database.db", help="Database file path")
args = parser.parse_args()
db = get_database(args.db_path)
# Show database statistics
stats = db.get_database_stats()
logger.info(f"Database statistics: {stats}")
# Clean up expired sessions
db.cleanup_expired_sessions()

View File

@ -1,142 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
SRS Python Addon - Simple HTTP Server
This addon demonstrates how to create a simple HTTP server as an SRS addon.
"""
import sys
import signal
import time
import logging
import argparse
import threading
from http.server import HTTPServer, BaseHTTPRequestHandler
# Set up logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger('srs_http_addon')
class SRSAddonHandler(BaseHTTPRequestHandler):
"""Simple HTTP request handler for SRS addon."""
def do_GET(self):
"""Handle GET requests."""
if self.path == '/':
self.send_response(200)
self.send_header('Content-type', 'text/html')
self.end_headers()
response = '''
<html>
<body>
<h1>SRS Python Addon - HTTP Server</h1>
<p>This is a simple HTTP server running as an SRS addon.</p>
<p>Status: Running</p>
<p>Time: {}</p>
</body>
</html>
'''.format(time.strftime('%Y-%m-%d %H:%M:%S'))
self.wfile.write(response.encode())
elif self.path == '/status':
self.send_response(200)
self.send_header('Content-type', 'application/json')
self.end_headers()
response = '{"status": "running", "time": "%s"}' % time.strftime('%Y-%m-%d %H:%M:%S')
self.wfile.write(response.encode())
else:
self.send_response(404)
self.end_headers()
def log_message(self, format, *args):
"""Override to use our logger."""
logger.info("%s - %s" % (self.address_string(), format % args))
class SRSHTTPAddon:
"""SRS HTTP Addon main class."""
def __init__(self, port=8888):
self.port = port
self.server = None
self.running = False
self.server_thread = None
def start(self):
"""Start the HTTP server."""
try:
self.server = HTTPServer(('', self.port), SRSAddonHandler)
self.running = True
# Start server in a separate thread
self.server_thread = threading.Thread(target=self._run_server)
self.server_thread.daemon = True
self.server_thread.start()
logger.info(f"SRS HTTP addon started on port {self.port}")
return True
except Exception as e:
logger.error(f"Failed to start HTTP addon: {e}")
return False
def stop(self):
"""Stop the HTTP server."""
if self.server and self.running:
self.running = False
self.server.shutdown()
self.server.server_close()
if self.server_thread:
self.server_thread.join(timeout=5)
logger.info("SRS HTTP addon stopped")
def _run_server(self):
"""Run the HTTP server."""
try:
self.server.serve_forever()
except Exception as e:
if self.running:
logger.error(f"HTTP server error: {e}")
def signal_handler(signum, frame):
"""Handle termination signals."""
logger.info(f"Received signal {signum}, shutting down...")
if 'addon' in globals():
addon.stop()
sys.exit(0)
def main():
"""Main function."""
parser = argparse.ArgumentParser(description='SRS HTTP Addon')
parser.add_argument('--port', type=int, default=8888, help='HTTP server port')
parser.add_argument('--verbose', action='store_true', help='Enable verbose logging')
args = parser.parse_args()
if args.verbose:
logger.setLevel(logging.DEBUG)
# Set up signal handlers
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
# Create and start the addon
global addon
addon = SRSHTTPAddon(args.port)
if addon.start():
logger.info("SRS HTTP addon is running, press Ctrl+C to stop")
try:
while addon.running:
time.sleep(1)
except KeyboardInterrupt:
logger.info("Interrupted by user")
finally:
addon.stop()
else:
logger.error("Failed to start SRS HTTP addon")
sys.exit(1)
if __name__ == '__main__':
main()

View File

@ -0,0 +1,845 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
FastAPI Authentication API Server
Provides secure authentication endpoints with token-based authentication
"""
from datetime import datetime
from typing import Optional, Dict, Any, List
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, Depends, Cookie, Response, Query
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, EmailStr, Field, field_validator
import uvicorn
# Import our database and logger
from database import get_database, Database
from srs_logger import get_logger
# Import Function Class
from security_module import AuthManager
from system_stats_module import SystemStatsManager
from avatar_module import AvatarGenerator
# ============================================================================
# Pydantic Models for Request/Response
# ============================================================================
class RegisterRequest(BaseModel):
"""用户注册请求模型"""
username: str = Field(..., min_length=1, max_length=50,
description="用户名,只能包含字母、数字、下划线和连字符")
name: str = Field(..., min_length=1, max_length=100, description="用户真实姓名")
email: Optional[EmailStr] = Field(None, description="用户邮箱(可选)")
password: str = Field(..., min_length=6, max_length=128, description="用户密码")
@field_validator('username')
def validate_username(cls, v):
if not v.replace('_', '').replace('-', '').isalnum():
raise ValueError('用户名只能包含字母、数字、下划线和连字符')
return v
class LoginRequest(BaseModel):
"""用户登录请求模型"""
username_or_email: str = Field(..., min_length=1, description="用户名或邮箱")
password: str = Field(..., min_length=1, description="用户密码")
remember_me: Optional[bool] = Field(True, description="是否记住登录状态默认为False")
class RefreshRequest(BaseModel):
"""令牌刷新请求模型"""
auth_key_session_id: Optional[str] = Field(None, description="认证令牌会话ID可选用于删除旧的认证会话")
class UpdatePasswordRequest(BaseModel):
"""更新密码请求模型"""
original_password: str = Field(..., min_length=1, description="原密码")
password: str = Field(..., min_length=6, max_length=128, description="新密码")
class UserIDRequest(BaseModel):
"""用户ID请求模型"""
user_id: str = Field(..., min_length=1, description="传入的用户ID")
class AuthResponse(BaseModel):
"""认证响应模型"""
success: bool = Field(..., description="操作是否成功")
message: str = Field(..., description="响应消息")
auth_key: Optional[str] = Field(None, description="认证令牌")
auth_key_session_id: Optional[str] = Field(None, description="认证令牌会话ID")
refresh_key: Optional[str] = Field(None, description="刷新令牌")
refresh_key_session_id: Optional[str] = Field(None, description="刷新令牌会话ID")
class SimpleResponse(BaseModel):
"""简单响应模型"""
success: bool = Field(..., description="操作是否成功")
message: str = Field(..., description="响应消息")
class SystemStatsResponse(BaseModel):
"""系统统计响应模型"""
system_stats: List[Dict[str, Any]] = Field(..., description="系统统计信息列表,每个元素包含时间戳和统计数据")
# ============================================================================
# FastAPI Application Setup
# ============================================================================
@asynccontextmanager
async def lifespan(app: FastAPI):
"""应用程序生命周期管理"""
logger = get_logger()
logger.info("Starting FastAPI Authentication Server...")
# 初始化数据库
db = get_database()
logger.info("Database initialized")
# 创建认证管理器
auth_manager = AuthManager(db)
system_stats_manager = SystemStatsManager(db)
avatar_generator = AvatarGenerator()
app.state.auth_manager = auth_manager
logger.info("Auth manager initialized")
system_stats_manager.start_polling() # 启动系统统计轮询
app.state.system_stats_manager = system_stats_manager
logger.info("System stats manager initialized")
app.state.avatar_generator = avatar_generator
logger.info("Avatar generator initialized")
app.state.db = db
logger.info("Database connection established")
yield
logger.info("Shutting down FastAPI Authentication Server...")
# 创建FastAPI应用
app = FastAPI(
title="认证API服务器",
description="基于FastAPI的安全认证系统提供用户注册、登录、令牌刷新和登出功能",
version="1.0.0",
docs_url="/api/docs",
redoc_url="/api/redoc",
lifespan=lifespan
)
# 添加CORS中间件
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # 在生产环境中应该限制允许的源
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 获取认证管理器
def get_auth_manager() -> AuthManager:
return app.state.auth_manager
def get_system_stats_manager() -> SystemStatsManager:
return app.state.system_stats_manager
def get_avatar_generator() -> AvatarGenerator:
return app.state.avatar_generator
def get_db() -> Database:
return app.state.db
# HTTP Bearer认证方案
security = HTTPBearer()
# 认证依赖项
async def verify_auth_token(
credentials: HTTPAuthorizationCredentials = Depends(security),
auth_manager: AuthManager = Depends(get_auth_manager)
) -> Dict[str, Any]:
"""验证Bearer令牌并返回认证会话信息"""
try:
# Bearer token格式: "auth_key:auth_key_session_id"
token_parts = credentials.credentials.split(':')
if len(token_parts) != 2:
raise HTTPException(
status_code=401,
detail="无效的Bearer令牌格式应为 'auth_key:auth_key_session_id'"
)
auth_key, auth_key_session_id = token_parts
# 获取认证会话
auth_manager.db.cleanup_expired_sessions() # 清理过期会话
auth_session = auth_manager.db.get_auth_session(auth_key_session_id)
if not auth_session:
raise HTTPException(
status_code=401,
detail="认证会话无效或已过期"
)
# 验证认证令牌
if not auth_manager.verify_token(auth_key, auth_session['salt'], auth_session['hashed_authkey']):
raise HTTPException(
status_code=401,
detail="认证令牌无效"
)
# 获取用户信息以获取用户组
user_data = auth_manager.db.get_user_by_id(auth_session['user_id'])
auth_session['user_group'] = user_data.get('user_group', ['user'])
# 更新用户最后活跃时间
auth_manager.db.update_user_last_active(auth_session['user_id'])
return auth_session
except HTTPException:
raise
except Exception as e:
auth_manager.logger.exception(f"Token verification error: {e}")
raise HTTPException(
status_code=401,
detail="令牌验证失败"
)
# ============================================================================
# Cookie Configuration
# ============================================================================
def set_auth_cookies(response: Response, tokens: Dict[str, str]) -> None:
"""设置安全的认证cookie"""
response.set_cookie(
key="refresh_key",
value=tokens['refresh_key'],
httponly=True,
secure=False,
samesite="strict",
max_age=604800 # 7天
)
response.set_cookie(
key="refresh_key_session_id",
value=tokens['refresh_key_session_id'],
httponly=True,
secure=False,
samesite="strict",
max_age=604800 # 7天
)
def clear_auth_cookies(response: Response) -> None:
"""清除认证cookie"""
response.delete_cookie(key="refresh_key", httponly=True, secure=True, samesite="strict")
response.delete_cookie(key="refresh_key_session_id", httponly=True, secure=True, samesite="strict")
# ============================================================================
# Authentication API Endpoints
# ============================================================================
@app.post("/api/register", response_model=SimpleResponse,
tags=["Authentication API Endpoints"], summary="用户注册", description="注册新用户账户")
async def register(request: RegisterRequest, auth_manager: AuthManager = Depends(get_auth_manager)):
"""
用户注册端点
- **username**: 用户名必需只能包含字母数字下划线和连字符
- **name**: 用户真实姓名必需
- **email**: 用户邮箱可选
- **password**: 用户密码必需最少6个字符
返回注册是否成功的布尔值
"""
try:
success = auth_manager.register_user(
username=request.username,
name=request.name,
email=request.email,
password=request.password
)
if success:
return SimpleResponse(success=True, message="用户注册成功")
else:
raise HTTPException(
status_code=400,
detail="用户注册失败,用户名或邮箱可能已存在"
)
except HTTPException:
raise
except Exception as e:
auth_manager.logger.exception(f"Register endpoint error: {e}")
raise HTTPException(status_code=500, detail="服务器内部错误")
@app.post("/api/login", response_model=AuthResponse,
tags=["Authentication API Endpoints"], summary="用户登录", description="用户登录并获取认证令牌")
async def login(request: LoginRequest, response: Response,
auth_manager: AuthManager = Depends(get_auth_manager)):
"""
用户登录端点
- **username_or_email**: 用户名或邮箱必需
- **password**: 用户密码必需
成功登录后返回认证令牌和刷新令牌并设置安全cookie
"""
try:
# 认证用户
user_data = auth_manager.authenticate_user(
username_or_email=request.username_or_email,
password=request.password
)
if not user_data:
raise HTTPException(
status_code=401,
detail="用户名/邮箱或密码错误"
)
# 创建认证令牌
tokens = auth_manager.create_auth_tokens(user_data['user_id'])
if not tokens:
raise HTTPException(
status_code=500,
detail="创建认证令牌失败"
)
# 设置安全cookie
if request.remember_me:
set_auth_cookies(response, tokens)
return AuthResponse(
success=True,
message="登录成功",
auth_key=tokens['auth_key'],
auth_key_session_id=tokens['auth_key_session_id'],
refresh_key=tokens['refresh_key'],
refresh_key_session_id=tokens['refresh_key_session_id']
)
except HTTPException:
raise
except Exception as e:
auth_manager.logger.exception(f"Login endpoint error: {e}")
raise HTTPException(status_code=500, detail="服务器内部错误")
@app.post("/api/refresh", response_model=AuthResponse,
tags=["Authentication API Endpoints"], summary="刷新令牌", description="使用刷新令牌获取新的认证令牌")
async def refresh_tokens(request: RefreshRequest, response: Response,
refresh_key: Optional[str] = Cookie(None),
refresh_key_session_id: Optional[str] = Cookie(None),
auth_manager: AuthManager = Depends(get_auth_manager)):
"""
令牌刷新端点
- **auth_key_session_id**: 认证令牌会话ID可选用于删除旧的认证会话
refresh_key和refresh_key_session_id从httponly Cookie中自动获取
验证刷新令牌后返回新的认证令牌和刷新令牌并更新cookie
"""
try:
# 检查Cookie中的刷新令牌信息
if not refresh_key or not refresh_key_session_id:
raise HTTPException(
status_code=401,
detail="未找到刷新令牌,请重新登录"
)
# 刷新令牌
new_tokens = auth_manager.refresh_tokens(
refresh_key_session_id=refresh_key_session_id,
refresh_key=refresh_key,
auth_key_session_id=request.auth_key_session_id
)
if not new_tokens:
# 清除cookie
clear_auth_cookies(response)
raise HTTPException(
status_code=401,
detail="刷新令牌无效或已过期"
)
# 设置新的安全cookie
set_auth_cookies(response, new_tokens)
return AuthResponse(
success=True,
message="令牌刷新成功",
auth_key=new_tokens['auth_key'],
auth_key_session_id=new_tokens['auth_key_session_id'],
refresh_key=new_tokens['refresh_key'],
refresh_key_session_id=new_tokens['refresh_key_session_id']
)
except HTTPException:
raise
except Exception as e:
auth_manager.logger.exception(f"Refresh endpoint error: {e}")
raise HTTPException(status_code=500, detail="服务器内部错误")
@app.post("/api/logout", response_model=SimpleResponse,
tags=["Authentication API Endpoints"], summary="登出", description="登出当前会话")
async def logout(response: Response,
refresh_key_session_id: Optional[str] = Cookie(None),
auth_session: Dict[str, Any] = Depends(verify_auth_token),
auth_manager: AuthManager = Depends(get_auth_manager)):
"""
登出端点
需要在Authorization头中提供Bearer令牌格式为: Bearer auth_key:auth_key_session_id
refresh_key_session_id从httponly Cookie中自动获取
登出指定的会话并清除相关cookie
"""
try:
# 检查Cookie中的刷新令牌会话ID
if not refresh_key_session_id:
# 清除cookie如果有的话
clear_auth_cookies(response)
raise HTTPException(
status_code=401,
detail="未找到刷新令牌会话ID"
)
success = auth_manager.logout_session(
refresh_key_session_id=refresh_key_session_id,
auth_key_session_id=auth_session['session_id']
)
# 清除cookie
clear_auth_cookies(response)
if success:
return SimpleResponse(success=True, message="登出成功")
else:
return SimpleResponse(success=False, message="会话不存在或已过期")
except Exception as e:
auth_manager.logger.exception(f"Logout endpoint error: {e}")
raise HTTPException(status_code=500, detail="服务器内部错误")
@app.post("/api/logout-all", response_model=SimpleResponse,
tags=["Authentication API Endpoints"], summary="登出所有会话", description="登出用户的所有会话")
async def logout_all(response: Response,
auth_session: Dict[str, Any] = Depends(verify_auth_token),
auth_manager: AuthManager = Depends(get_auth_manager)):
"""
登出所有会话端点
需要在Authorization头中提供Bearer令牌格式为: Bearer auth_key:auth_key_session_id
登出用户的所有会话并清除相关cookie
"""
try:
success = auth_manager.logout_all_sessions(auth_session['user_id'])
# 清除cookie
clear_auth_cookies(response)
if success:
return SimpleResponse(success=True, message="所有会话已登出")
else:
return SimpleResponse(success=False, message="会话不存在或已过期")
except Exception as e:
auth_manager.logger.exception(f"Logout all endpoint error: {e}")
raise HTTPException(status_code=500, detail="服务器内部错误")
@app.get("/api/profile", response_model=Dict[str, Any],
tags=["Authentication API Endpoints"], summary="获取用户信息", description="获取当前认证用户的基本信息")
async def get_profile(
auth_session: Dict[str, Any] = Depends(verify_auth_token),
auth_manager: AuthManager = Depends(get_auth_manager)
):
"""
获取用户信息端点
需要在Authorization头中提供Bearer令牌格式为: Bearer auth_key:auth_key_session_id
返回当前认证用户的基本信息
"""
try:
user_data = auth_manager.db.get_user_by_id(auth_session['user_id'])
if not user_data:
raise HTTPException(status_code=404, detail="用户未找到")
return {
"user_id": user_data['user_id'],
"username": user_data['username'],
"name": user_data['name'],
"email": user_data.get('email', None),
'user_group': user_data.get('user_group', ['user']),
"is_activated": user_data.get('is_activated', False),
"last_active": user_data.get('last_active', None)
}
except HTTPException:
raise
except Exception as e:
auth_manager.logger.exception(f"Get profile endpoint error: {e}")
raise HTTPException(status_code=500, detail="服务器内部错误")
@app.put("/api/update-password", response_model=SimpleResponse,
tags=["Authentication API Endpoints"], summary="更新密码", description="更新当前用户的密码")
async def update_password(
request: UpdatePasswordRequest,
response: Response,
auth_session: Dict[str, Any] = Depends(verify_auth_token),
auth_manager: AuthManager = Depends(get_auth_manager)
):
"""
更新密码端点
需要在Authorization头中提供Bearer令牌格式为: Bearer auth_key:auth_key_session_id
- **original_password**: 原密码必需
- **password**: 新密码必需最少6个字符
验证原密码后更新为新密码并自动登出所有会话
"""
try:
user_id = auth_session['user_id']
# 调用AuthManager处理密码更新逻辑
result = auth_manager.update_user_password_with_verification(
user_id=user_id,
original_password=request.original_password,
new_password=request.password
)
if result["success"]:
# 如果密码更新成功且包含logout_all标志清除cookie
if result.get("logout_all", False):
clear_auth_cookies(response)
return SimpleResponse(success=True, message=result["message"])
else:
# 根据错误类型返回相应的HTTP状态码
if "用户未找到" in result["message"]:
raise HTTPException(status_code=404, detail=result["message"])
elif "原密码错误" in result["message"]:
raise HTTPException(status_code=401, detail=result["message"])
else:
raise HTTPException(status_code=500, detail=result["message"])
except HTTPException:
raise
except Exception as e:
auth_manager.logger.exception(f"Update password endpoint error: {e}")
raise HTTPException(status_code=500, detail="服务器内部错误")
@app.delete("/api/delete-user", response_model=SimpleResponse,
tags=["Authentication API Endpoints"], summary="删除自己的账户", description="删除当前登录用户的账户")
async def delete_own_account(
response: Response,
auth_session: Dict[str, Any] = Depends(verify_auth_token),
auth_manager: AuthManager = Depends(get_auth_manager)
):
"""
删除自己账户端点
需要在Authorization头中提供Bearer令牌格式为: Bearer auth_key:auth_key_session_id
只能删除当前登录用户本人账户删除成功后将清除所有认证cookie
"""
try:
user_id = auth_session['user_id']
# 调用AuthManager处理自己账户删除逻辑
result = auth_manager.delete_own_account(user_id)
if result["success"]:
# 删除成功后清除认证cookie
clear_auth_cookies(response)
return SimpleResponse(success=True, message=result["message"])
else:
# 根据错误类型返回相应的HTTP状态码
if "未找到" in result["message"]:
raise HTTPException(status_code=404, detail=result["message"])
else:
raise HTTPException(status_code=500, detail=result["message"])
except HTTPException:
raise
except Exception as e:
auth_manager.logger.exception(f"Delete own account endpoint error: {e}")
raise HTTPException(status_code=500, detail="服务器内部错误")
@app.delete("/api/admin/delete-user", response_model=SimpleResponse,
tags=["Admin API Endpoints"], summary="管理员删除用户", description="管理员删除指定用户账户")
async def admin_delete_user(
request: UserIDRequest,
auth_session: Dict[str, Any] = Depends(verify_auth_token),
auth_manager: AuthManager = Depends(get_auth_manager)
):
"""
管理员删除用户端点
需要在Authorization头中提供Bearer令牌格式为: Bearer auth_key:auth_key_session_id
需要管理员权限manager或admin用户组
- **user_id**: 要删除的用户ID必需
"""
try:
admin_user_id = auth_session['user_id']
target_user_id = request.user_id
# 调用AuthManager处理管理员删除用户逻辑
result = auth_manager.admin_delete_user(
admin_user_id=admin_user_id,
admin_user_groups=auth_session.get('user_group', ["user"]),
target_user_id=target_user_id
)
if result["success"]:
return SimpleResponse(success=True, message=result["message"])
else:
# 根据错误类型返回相应的HTTP状态码
if "未找到" in result["message"]:
raise HTTPException(status_code=404, detail=result["message"])
elif "权限不足" in result["message"]:
raise HTTPException(status_code=403, detail=result["message"])
else:
raise HTTPException(status_code=400, detail=result["message"])
except HTTPException:
raise
except Exception as e:
auth_manager.logger.exception(f"Admin delete user endpoint error: {e}")
raise HTTPException(status_code=500, detail="服务器内部错误")
# ============================================================================
# Avatar Generation Endpoint
# ============================================================================
@app.get("/avatar/{variant}",
tags=["Avatar API Endpoints"], summary="生成头像", description="根据变体生成头像")
async def generate_avatar_variant_only(
variant: str,
auth_session: Dict[str, Any] = Depends(verify_auth_token),
name: Optional[str] = Query(default="John Doe", description="Name to generate avatar for"),
colors: Optional[str] = Query(default=None, description="Comma-separated hex colors"),
size: int = Query(default=80, description="Avatar size"),
square: bool = Query(default=False, description="Square avatar"),
title: bool = Query(default=False, description="Include title element"),
avatar_generator: AvatarGenerator = Depends(get_avatar_generator),
):
"""Generate avatar with variant only"""
try:
if variant not in avatar_generator.VALID_VARIANTS:
raise HTTPException(status_code=400, detail=f"Invalid variant. Must be one of: {', '.join(avatar_generator.VALID_VARIANTS)}")
# Validate and sanitize inputs
if not name or not isinstance(name, str):
name = auth_session['user_id']
if not isinstance(size, int) or size < 10 or size > 1000:
size = avatar_generator.DEFAULT_SIZE
color_palette = avatar_generator.normalize_colors(colors)
generator = avatar_generator.AVATAR_GENERATORS[variant]
svg_content = generator(name, color_palette, size, square, title)
return Response(
content=svg_content,
media_type="image/svg+xml",
headers={
"Cache-Control": "s-max-age=1, stale-while-revalidate",
"Content-Type": "image/svg+xml"
}
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Avatar generation failed: {str(e)}")
@app.get("/avatar/{variant}/{size}",
tags=["Avatar API Endpoints"], summary="生成头像", description="根据变体生成头像")
async def generate_avatar_variant_size(
variant: str,
size: int,
auth_session: Dict[str, Any] = Depends(verify_auth_token),
name: Optional[str] = Query(default="Clara Barton", description="Name to generate avatar for"),
colors: Optional[str] = Query(default=None, description="Comma-separated hex colors"),
square: bool = Query(default=False, description="Square avatar"),
title: bool = Query(default=False, description="Include title element"),
avatar_generator: AvatarGenerator = Depends(get_avatar_generator)
):
"""Generate avatar with variant and size"""
try:
if variant not in avatar_generator.VALID_VARIANTS:
raise HTTPException(status_code=400, detail=f"Invalid variant. Must be one of: {', '.join(avatar_generator.VALID_VARIANTS)}")
# Validate and sanitize inputs
if not name or not isinstance(name, str):
name = auth_session['user_id']
if not isinstance(size, int) or size < 10 or size > 1000:
size = avatar_generator.DEFAULT_SIZE
color_palette = avatar_generator.normalize_colors(colors)
generator = avatar_generator.AVATAR_GENERATORS[variant]
svg_content = generator(name, color_palette, size, square, title)
return Response(
content=svg_content,
media_type="image/svg+xml",
headers={
"Cache-Control": "s-max-age=1, stale-while-revalidate",
"Content-Type": "image/svg+xml"
}
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Avatar generation failed: {str(e)}")
@app.get("/avatar/{variant}/{size}/{name}",
tags=["Avatar API Endpoints"], summary="生成头像", description="根据变体生成头像")
async def generate_avatar_full_path(
variant: str,
size: int,
name: str,
colors: Optional[str] = Query(default=None, description="Comma-separated hex colors"),
square: bool = Query(default=False, description="Square avatar"),
title: bool = Query(default=False, description="Include title element"),
avatar_generator: AvatarGenerator = Depends(get_avatar_generator),
auth_session: Dict[str, Any] = Depends(verify_auth_token),
):
"""Generate avatar with variant, size, and name in path"""
try:
if variant not in avatar_generator.VALID_VARIANTS:
raise HTTPException(status_code=400, detail=f"Invalid variant. Must be one of: {', '.join(avatar_generator.VALID_VARIANTS)}")
# Validate and sanitize inputs
if not name or not isinstance(name, str):
name = auth_session['user_id']
if not isinstance(size, int) or size < 10 or size > 1000:
size = avatar_generator.DEFAULT_SIZE
color_palette = avatar_generator.normalize_colors(colors)
generator = avatar_generator.AVATAR_GENERATORS[variant]
svg_content = generator(name, color_palette, size, square, title)
return Response(
content=svg_content,
media_type="image/svg+xml",
headers={
"Cache-Control": "s-max-age=1, stale-while-revalidate",
"Content-Type": "image/svg+xml"
}
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Avatar generation failed: {str(e)}")
# ============================================================================
# System Statistics API Endpoints
# ============================================================================
@app.get("/api/system-stats", response_model=SystemStatsResponse,
tags=["System Statistics API Endpoints"], summary="获取系统统计信息", description="获取服务器的系统统计信息")
async def get_system_stats(
auth_session: Dict[str, Any] = Depends(verify_auth_token),
time_delta: Optional[int] = Query(default=5, description="时间范围默认为5秒"),
time_interval: Optional[int] = Query(default=1, description="时间间隔默认为1秒"),
system_stats_manager: SystemStatsManager = Depends(get_system_stats_manager)
):
"""
获取系统统计信息端点
需要在Authorization头中提供Bearer令牌格式为: Bearer auth_key:auth_key_session_id
- **time_delta**: 时间范围默认为5秒
返回服务器的系统统计信息
"""
try:
stats = system_stats_manager.get_recent_stats(
auth_session.get('user_group', ['user']),
time_delta=time_delta,
time_interval=time_interval
)
if stats["success"]:
return stats
else:
if "未找到" in stats["message"]:
raise HTTPException(status_code=404, detail=stats["message"])
elif "权限不足" in stats["message"]:
raise HTTPException(status_code=403, detail=stats["message"])
else:
raise HTTPException(status_code=500, detail=stats["message"])
except HTTPException:
raise
except Exception as e:
system_stats_manager.logger.exception(f"Get system stats endpoint error: {e}")
raise HTTPException(status_code=500, detail="服务器内部错误")
# ============================================================================
# Additional Utility API Endpoints
# ============================================================================
@app.get("/api/health", response_model=Dict[str, Any],
tags=["Additional Utility API Endpoints"], summary="健康检查", description="检查API服务器状态")
async def health_check(db: Database = Depends(get_db)):
"""健康检查端点"""
try:
stats = db.get_database_stats()
return {
"status": "healthy",
"timestamp": datetime.now().isoformat(),
"database_stats": stats
}
except Exception as e:
raise HTTPException(status_code=503, detail=f"服务不可用: {str(e)}")
@app.post("/api/clean-cookies", response_model=SimpleResponse,
tags=["Additional Utility API Endpoints"], summary="清除认证Cookie", description="清除认证相关的Cookie")
async def clean_cookies(response: Response):
"""清除认证Cookie端点"""
try:
clear_auth_cookies(response)
return SimpleResponse(success=True, message="认证Cookie已清除")
except Exception as e:
raise HTTPException(status_code=500, detail=f"清除Cookie失败: {str(e)}")
# ============================================================================
# Static Files Configuration
# ============================================================================
# 在所有API路由定义之后挂载静态文件确保API路由有更高优先级
# Frontend路径配置
frontend_path = r"./livestream_site" # 前端静态文件目录
app.mount("/", StaticFiles(directory=frontend_path, html=True), name="frontend")
# ============================================================================
# Main Entry Point
# ============================================================================
if __name__ == "__main__":
logger = get_logger()
logger.info("Starting FastAPI Authentication Server on port 8000...")
# 检查是否在PyInstaller打包环境中运行
import sys
if getattr(sys, 'frozen', False):
# 在打包环境中直接传递app对象
uvicorn.run(
app,
host="127.0.0.1",
port=8000,
reload=False,
log_level="info"
)
else:
# 在开发环境中,使用字符串导入(支持热重载)
uvicorn.run(
"httpbackend_server:app",
host="127.0.0.1",
port=8000,
reload=True,
reload_delay=5.0,
log_level="info"
)

View File

@ -1,80 +0,0 @@
#!/usr/bin/env python3
"""
SRS Python Monitor Script
This is an example Python script that runs alongside SRS server.
It demonstrates how Python processes can be managed by SRS.
"""
import time
import signal
import sys
import argparse
import logging
from datetime import datetime
class SRSMonitor:
def __init__(self, config_file=None, verbose=False):
self.running = True
self.config_file = config_file
self.verbose = verbose
# Set up logging
level = logging.DEBUG if verbose else logging.INFO
logging.basicConfig(
level=level,
format='[%(asctime)s] [Python Monitor] %(levelname)s: %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
self.logger = logging.getLogger(__name__)
# Set up signal handlers
signal.signal(signal.SIGTERM, self.signal_handler)
signal.signal(signal.SIGINT, self.signal_handler)
def signal_handler(self, signum, frame):
"""Handle shutdown signals from SRS"""
self.logger.info(f"Received signal {signum}, shutting down gracefully...")
self.running = False
def run(self):
"""Main monitoring loop"""
self.logger.info("SRS Python Monitor started")
if self.config_file:
self.logger.info(f"Using config file: {self.config_file}")
try:
while self.running:
# Simulate monitoring work
self.logger.debug(f"Monitor heartbeat at {datetime.now()}")
# You can add your monitoring logic here:
# - Check stream status
# - Monitor server health
# - Send alerts
# - Log analytics
time.sleep(5) # Check every 5 seconds
except Exception as e:
self.logger.error(f"Error in monitor loop: {e}")
finally:
self.cleanup()
def cleanup(self):
"""Cleanup before shutdown"""
self.logger.info("Cleaning up Python Monitor...")
# Add any cleanup logic here
self.logger.info("Python Monitor stopped")
def main():
parser = argparse.ArgumentParser(description='SRS Python Monitor')
parser.add_argument('--config', help='Configuration file path')
parser.add_argument('--verbose', action='store_true', help='Enable verbose logging')
args = parser.parse_args()
monitor = SRSMonitor(args.config, args.verbose)
monitor.run()
if __name__ == '__main__':
main()

View File

@ -1,7 +1,70 @@
# Python packages required for SRS Python addons
requests>=2.25.0
psutil>=5.8.0
pyyaml>=5.4.0
websockets>=10.0
aiohttp>=3.8.0
numpy>=1.20.0
aiofiles==23.2.0
aiohappyeyeballs==2.6.1
aiohttp==3.12.13
aiosignal==1.3.2
altgraph==0.17.4
annotated-types==0.7.0
anyio==3.7.1
attrs==25.3.0
bcrypt==4.3.0
blinker==1.9.0
certifi==2025.4.26
cffi==1.17.1
charset-normalizer==3.4.2
click==8.2.1
colorama==0.4.6
cryptography==45.0.3
dnspython==2.7.0
ecdsa==0.19.1
email-validator==2.1.0
fastapi==0.104.1
Flask==3.1.1
Flask-SQLAlchemy==3.1.1
frozenlist==1.7.0
greenlet==3.2.3
h11==0.16.0
httpcore==1.0.9
httptools==0.6.4
httpx==0.25.2
idna==3.10
iniconfig==2.1.0
itsdangerous==2.2.0
Jinja2==3.1.2
jwt==1.3.1
MarkupSafe==3.0.2
multidict==6.4.4
packaging==25.0
passlib==1.7.4
pefile==2023.2.7
pluggy==1.6.0
propcache==0.3.2
pyasn1==0.6.1
pycparser==2.22
pydantic==2.5.0
pydantic_core==2.14.1
pyinstaller==6.14.1
pyinstaller-hooks-contrib==2025.5
PyJWT==2.8.0
pytest==7.4.3
pytest-asyncio==0.21.1
python-dotenv==1.0.0
python-jose==3.3.0
python-multipart==0.0.6
pywin32-ctypes==0.2.3
PyYAML==6.0.2
requests==2.32.4
rsa==4.9.1
setuptools==78.1.1
six==1.17.0
sniffio==1.3.1
SQLAlchemy==2.0.41
starlette==0.27.0
typing-inspection==0.4.1
typing_extensions==4.14.0
urllib3==2.4.0
uvicorn==0.24.0
watchfiles==1.0.5
websockets==15.0.1
Werkzeug==3.1.3
wheel==0.45.1
yarl==1.20.1

View File

@ -0,0 +1,317 @@
import hashlib
import secrets
import hmac
from datetime import datetime, timedelta
from typing import Optional, Dict, Any
from database import Database
from srs_logger import get_logger
# ============================================================================
# Security and Utility Functions
# ============================================================================
class AuthManager:
"""认证管理器"""
def __init__(self, db: Database):
self.db = db
self.logger = get_logger()
# Token 配置
self.AUTH_TOKEN_EXPIRE_HOURS = 1 # 认证令牌1小时过期
self.REFRESH_TOKEN_EXPIRE_DAYS = 7 # 刷新令牌7天过期
self.TOKEN_LENGTH = 128 # 令牌长度
def generate_secure_token(self) -> str:
"""生成安全的随机令牌"""
return secrets.token_urlsafe(self.TOKEN_LENGTH)
def generate_salt(self) -> str:
"""生成盐值"""
return secrets.token_hex(16)
def hash_password(self, password: str, salt: str) -> str:
"""对密码进行加盐哈希"""
return hashlib.pbkdf2_hmac('sha256', password.encode(), salt.encode(), 100000).hex()
def verify_password(self, password: str, salt: str, hashed_password: str) -> bool:
"""验证密码"""
return hmac.compare_digest(
self.hash_password(password, salt),
hashed_password
)
def hash_token(self, token: str, salt: str) -> str:
"""对令牌进行加盐哈希"""
return hashlib.pbkdf2_hmac('sha256', token.encode(), salt.encode(), 100000).hex()
def verify_token(self, token: str, salt: str, hashed_token: str) -> bool:
"""验证令牌"""
return hmac.compare_digest(
self.hash_token(token, salt),
hashed_token
)
def register_user(self, username: str, name: str, email: Optional[str], password: str) -> bool:
"""注册新用户"""
try:
# 生成密码盐值和哈希
salt = self.generate_salt()
hashed_password = self.hash_password(password, salt)
# 创建用户
user_data = self.db.create_user(
username=username,
name=name,
email=email,
hashed_password=hashed_password,
salt=salt,
user_group=["user"],
is_activated=True
)
if user_data:
self.logger.info(f"User registered successfully: {username}")
return True
else:
self.logger.warn(f"Failed to register user: {username}")
return False
except Exception as e:
self.logger.exception(f"Error during user registration: {e}")
return False
def authenticate_user(self, username_or_email: str, password: str) -> Optional[Dict[str, Any]]:
"""认证用户并返回用户数据"""
try:
# 获取用户信息
user_data = self.db.get_user(username_or_email)
if not user_data:
self.logger.warn(f"User not found: {username_or_email}")
return None
# 检查用户是否激活
if not user_data.get('is_activated', False):
self.logger.warn(f"User account deactivated: {username_or_email}")
return None
# 验证密码
if self.verify_password(password, user_data['salt'], user_data['hashed_password']):
self.logger.info(f"User authenticated successfully: {username_or_email}")
return user_data
else:
self.logger.warn(f"Invalid password for user: {username_or_email}")
return None
except Exception as e:
self.logger.exception(f"Error during user authentication: {e}")
return None
def create_auth_tokens(self, user_id: str) -> Optional[Dict[str, Any]]:
"""为用户创建认证和刷新令牌"""
try:
# 清理过期会话
self.db.cleanup_expired_sessions()
# 生成认证令牌
auth_key = self.generate_secure_token()
auth_salt = self.generate_salt()
hashed_auth_key = self.hash_token(auth_key, auth_salt)
auth_expire_time = datetime.now() + timedelta(hours=self.AUTH_TOKEN_EXPIRE_HOURS)
# 生成刷新令牌
refresh_key = self.generate_secure_token()
refresh_salt = self.generate_salt()
hashed_refresh_key = self.hash_token(refresh_key, refresh_salt)
refresh_expire_time = datetime.now() + timedelta(days=self.REFRESH_TOKEN_EXPIRE_DAYS)
# 创建认证会话
auth_session = self.db.create_auth_session(
user_id=user_id,
hashed_authkey=hashed_auth_key,
salt=auth_salt,
expire_time=auth_expire_time
)
if not auth_session:
self.logger.error(f"Failed to create auth session for user: {user_id}")
return None
# 创建刷新会话
refresh_session = self.db.create_refresh_session(
user_id=user_id,
hashed_refreshkey=hashed_refresh_key,
salt=refresh_salt,
expire_time=refresh_expire_time
)
if not refresh_session:
self.logger.error(f"Failed to create refresh session for user: {user_id}")
# 清理已创建的认证会话
self.db.delete_auth_session(auth_session['session_id'])
return None
# 更新用户最后活跃时间
self.db.update_user_last_active(user_id)
return {
'auth_key': auth_key,
'auth_key_session_id': auth_session['session_id'],
'refresh_key': refresh_key,
'refresh_key_session_id': refresh_session['session_id']
}
except Exception as e:
self.logger.exception(f"Error creating auth tokens for user: {user_id}")
return None
def refresh_tokens(self, refresh_key_session_id: str, refresh_key: str,
auth_key_session_id: Optional[str] = None) -> Optional[Dict[str, Any]]:
"""刷新认证令牌"""
try:
# 清理过期会话
self.db.cleanup_expired_sessions()
# 验证刷新会话
refresh_session = self.db.get_refresh_session(refresh_key_session_id)
if not refresh_session:
self.logger.warn(f"Invalid or expired refresh session: {refresh_key_session_id}")
return None
# 验证刷新令牌
if not self.verify_token(refresh_key, refresh_session['salt'], refresh_session['hashed_refreshkey']):
self.logger.warn(f"Invalid refresh token for session: {refresh_key_session_id}")
return None
user_id = refresh_session['user_id']
# 删除旧的认证会话如果提供了session_id
if auth_key_session_id:
self.db.delete_auth_session(auth_key_session_id)
# 删除旧的刷新会话
self.db.delete_refresh_session(refresh_key_session_id)
# 创建新的令牌
new_tokens = self.create_auth_tokens(user_id)
if new_tokens:
self.logger.info(f"Tokens refreshed successfully for user: {user_id}")
return new_tokens
else:
self.logger.error(f"Failed to create new tokens during refresh for user: {user_id}")
return None
except Exception as e:
self.logger.exception(f"Error during token refresh: {e}")
return None
def logout_session(self, refresh_key_session_id: str, auth_key_session_id: str = None) -> bool:
"""登出单个会话通过删除指定的refresh session和对应的auth session"""
try:
# 删除指定的刷新会话
refresh_success = self.db.delete_refresh_session(refresh_key_session_id)
if refresh_success:
self.logger.info(f"Refresh session logged out successfully: {refresh_key_session_id}")
auth_success = self.db.delete_auth_session(auth_key_session_id)
if auth_success:
self.logger.info(f"Auth session logged out successfully: {auth_key_session_id}")
else:
self.logger.warn(f"Failed to delete auth session: {auth_key_session_id}")
return refresh_success and auth_success
except Exception as e:
self.logger.exception(f"Error during session logout: {e}")
return False
def logout_all_sessions(self, user_id: str) -> bool:
"""登出用户的所有会话通过refresh session确定用户"""
try:
# 删除用户的所有会话
success = self.db.delete_user_sessions(user_id)
if success:
self.logger.info(f"All sessions logged out successfully for user: {user_id}")
return success
except Exception as e:
self.logger.exception(f"Error during logout all sessions: {e}")
return False
def update_user_password_with_verification(self, user_id: str, original_password: str, new_password: str) -> Dict[str, Any]:
"""验证原密码后更新用户密码,并登出所有会话"""
try:
# 获取用户数据
user_data = self.db.get_user_by_id(user_id)
if not user_data:
return {"success": False, "message": "用户未找到"}
# 验证原密码
if not self.verify_password(original_password, user_data['salt'], user_data['hashed_password']):
return {"success": False, "message": "原密码错误"}
# 生成新的盐值和密码哈希
new_salt = self.generate_salt()
new_hashed_password = self.hash_password(new_password, new_salt)
# 更新密码
updated_user = self.db.update_user_password(user_id, new_hashed_password, new_salt)
if not updated_user:
return {"success": False, "message": "密码更新失败"}
# 密码更新成功后,登出用户的所有会话(强制重新登录)
logout_success = self.db.delete_user_sessions(user_id)
if logout_success:
self.logger.info(f"All sessions logged out after password update for user: {user_data['username']}")
else:
self.logger.warn(f"Failed to logout all sessions after password update for user: {user_data['username']}")
self.logger.info(f"Password updated successfully for user: {user_data['username']}")
return {"success": True, "message": "密码更新成功,请重新登录", "logout_all": True}
except Exception as e:
self.logger.exception(f"Error updating user password: {e}")
return {"success": False, "message": "服务器内部错误"}
def delete_own_account(self, user_id: str) -> Dict[str, Any]:
"""删除用户自己的账户"""
try:
success = self.db.delete_user(user_id)
if not success:
return {"success": False, "message": "账户删除失败"}
self.logger.info(f"User-ID: '{user_id}' deleted his/her own account")
return {"success": True, "message": "您的账户已成功删除"}
except Exception as e:
self.logger.exception(f"Error deleting own account: {e}")
return {"success": False, "message": "服务器内部错误"}
def admin_delete_user(self, admin_user_id: str, admin_user_groups: list, target_user_id: str, ) -> Dict[str, Any]:
"""管理员删除指定用户账户"""
try:
# 检查管理员权限
if not any(group in admin_user_groups for group in ['manager', 'admin']):
return {
"success": False,
"message": "权限不足,只有管理员可以删除其他用户账户"
}
# 获取目标用户数据
target_user = self.db.get_user_by_id(target_user_id)
if not target_user:
return {"success": False, "message": "目标删除的用户未找到"}
# 执行删除操作
success = self.db.delete_user(target_user_id)
if not success:
return {"success": False, "message": "用户删除失败"}
self.logger.info(f"User-ID '{target_user_id}' deleted by admin-ID: '{admin_user_id}'")
return {"success": True, "message": f"用户 '{target_user_id}' 已成功删除"}
except Exception as e:
self.logger.exception(f"Error in admin delete user: {e}")
return {"success": False, "message": "服务器内部错误"}

405
trunk/python/srs_logger.py Normal file
View File

@ -0,0 +1,405 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
SRS Python Logger Module
Provides synchronized logging with SRS main process.
"""
import logging
import os
import sys
import threading
import time
import re
from datetime import datetime
from typing import Optional, Dict, Any
import argparse
class SRSLogFormatter(logging.Formatter):
"""
Custom formatter that matches SRS log format with color support:
[2025-05-30 21:54:18.835][WARN][3448][python_analytics] message
"""
# ANSI color codes
COLORS = {
'RESET': '\033[0m',
'RED': '\033[31m',
'YELLOW': '\033[33m',
'WHITE': '\033[37m',
'CYAN': '\033[36m',
'GREEN': '\033[32m'
}
def __init__(self, use_colors=True):
super().__init__()
self.pid = os.getpid()
self.session_id = self._generate_session_id()
self.use_colors = use_colors and self._supports_color()
def _supports_color(self) -> bool:
"""Check if the terminal supports color output"""
import sys
# Check if we're in a terminal
if not hasattr(sys.stdout, 'isatty') or not sys.stdout.isatty():
return False
# Check for Windows terminal support
if os.name == 'nt':
try:
import colorama
colorama.init()
return True
except ImportError:
# Try to enable ANSI escape sequence support on Windows
try:
import ctypes
kernel32 = ctypes.windll.kernel32
kernel32.SetConsoleMode(kernel32.GetStdHandle(-11), 7)
return True
except:
return False
# Unix-like systems usually support colors
return True
def _generate_session_id(self) -> str:
"""Generate a meaningful session ID based on the calling script"""
import inspect
# Try to get the main script name
main_module = sys.modules.get('__main__')
if main_module and hasattr(main_module, '__file__'):
script_path = main_module.__file__
if script_path:
script_name = os.path.splitext(os.path.basename(script_path))[0]
return f"python_{script_name}"
# Fallback to inspect the call stack to find the calling script
try:
for frame_info in inspect.stack():
filename = frame_info.filename
if filename and not filename.endswith('srs_logger.py') and not filename.endswith('logging/__init__.py'):
script_name = os.path.splitext(os.path.basename(filename))[0]
if script_name != '<string>' and script_name != '<stdin>':
return f"python_{script_name}"
except:
pass
# Final fallback
return "python_unknown"
def format(self, record: logging.LogRecord) -> str:
"""Format log record to match SRS format with color support"""
# Get current timestamp with milliseconds
now = datetime.now()
timestamp = now.strftime("%Y-%m-%d %H:%M:%S.") + f"{now.microsecond // 1000:03d}"
# Map Python log levels to SRS log levels
level_mapping = {
'DEBUG': 'TRACE',
'INFO': 'INFO',
'WARNING': 'WARN',
'ERROR': 'ERROR',
'CRITICAL': 'ERROR'
}
srs_level = level_mapping.get(record.levelname, 'INFO')
# Format: [timestamp][level][pid][session] message
formatted_msg = f"[{timestamp}][{srs_level}][{self.pid}][{self.session_id}] {record.getMessage()}"
if record.exc_info:
formatted_msg += f"\n{self.formatException(record.exc_info)}"
# Apply colors if supported and enabled
if self.use_colors:
if srs_level == 'ERROR':
formatted_msg = f"{self.COLORS['RED']}{formatted_msg}{self.COLORS['RESET']}"
elif srs_level == 'WARN':
formatted_msg = f"{self.COLORS['YELLOW']}{formatted_msg}{self.COLORS['RESET']}"
# INFO, TRACE levels remain uncolored (default terminal color)
return formatted_msg
class SRSConfigParser:
"""Parser for SRS configuration files"""
@staticmethod
def parse_config_file(config_path: str) -> Dict[str, Any]:
"""Parse SRS configuration file and extract logging settings"""
config = {
'log_tank': 'all',
'log_file': './objs/srs.log',
'log_level': 'info'
}
if not os.path.exists(config_path):
return config
try:
with open(config_path, 'r', encoding='utf-8') as f:
content = f.read()
# Parse srs_log_tank
match = re.search(r'srs_log_tank\s+([^;]+);', content)
if match:
config['log_tank'] = match.group(1).strip()
# Parse srs_log_file
match = re.search(r'srs_log_file\s+([^;]+);', content)
if match:
config['log_file'] = match.group(1).strip()
# Parse srs_log_level_v2
match = re.search(r'srs_log_level_v2\s+([^;]+);', content)
if match:
config['log_level'] = match.group(1).strip().lower()
except Exception as e:
print(f"Warning: Failed to parse config file {config_path}: {e}", file=sys.stderr)
return config
class SRSLogger:
"""
SRS-compatible Python logger that synchronizes with SRS main process logging
"""
_instance = None
_lock = threading.Lock()
def __new__(cls, config_path: Optional[str] = None):
"""Singleton pattern to ensure only one logger instance"""
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self, config_path: Optional[str] = None):
if self._initialized:
return
self._initialized = True
self.config_path = config_path
self.logger = None
self.config = {}
self._setup_logger()
def _setup_logger(self):
"""Setup the logger based on SRS configuration"""
try:
# Parse SRS config if provided
if self.config_path and os.path.exists(self.config_path):
self.config = SRSConfigParser.parse_config_file(self.config_path)
else:
# Default configuration
self.config = {
'log_tank': 'all',
'log_file': './objs/srs.log',
'log_level': 'info'
}
# Create logger
self.logger = logging.getLogger('srs_python')
self.logger.setLevel(self._get_python_log_level(self.config['log_level']))
# Clear existing handlers
self.logger.handlers.clear()
# Setup handlers based on log_tank setting
log_tank = self.config['log_tank'].lower()
if log_tank in ['all', 'console']:
# Console handler with colors
console_handler = logging.StreamHandler(sys.stdout)
console_formatter = SRSLogFormatter(use_colors=True)
console_handler.setFormatter(console_formatter)
self.logger.addHandler(console_handler)
if log_tank in ['all', 'file']:
# File handler without colors
log_file = self.config['log_file']
# Create directory if it doesn't exist
log_dir = os.path.dirname(log_file)
if log_dir and not os.path.exists(log_dir):
os.makedirs(log_dir, exist_ok=True)
file_handler = logging.FileHandler(log_file, encoding='utf-8')
file_formatter = SRSLogFormatter(use_colors=False)
file_handler.setFormatter(file_formatter)
self.logger.addHandler(file_handler)
# Prevent propagation to root logger
self.logger.propagate = False
self.info(f"SRS Python logger initialized with config: {self.config}")
except Exception as e:
print(f"ERROR: Failed to setup SRS logger: {e}", file=sys.stderr)
# Setup a basic console logger as fallback
self._setup_fallback_logger()
def _setup_fallback_logger(self):
"""Setup a basic fallback logger if main setup fails"""
self.logger = logging.getLogger('srs_python_fallback')
self.logger.setLevel(logging.INFO)
console_handler = logging.StreamHandler(sys.stderr)
formatter = logging.Formatter('%(asctime)s [ERROR] [Python] %(message)s')
console_handler.setFormatter(formatter)
self.logger.addHandler(console_handler)
self.logger.propagate = False
def _get_python_log_level(self, srs_level: str) -> int:
"""Convert SRS log level to Python log level"""
level_mapping = {
'trace': logging.DEBUG,
'debug': logging.DEBUG,
'info': logging.INFO,
'warn': logging.WARNING,
'warning': logging.WARNING,
'error': logging.ERROR
}
return level_mapping.get(srs_level.lower(), logging.INFO)
def trace(self, message: str, *args, **kwargs):
"""Log trace message (mapped to DEBUG in Python)"""
if self.logger:
self.logger.debug(message, *args, **kwargs)
def debug(self, message: str, *args, **kwargs):
"""Log debug message"""
if self.logger:
self.logger.debug(message, *args, **kwargs)
def info(self, message: str, *args, **kwargs):
"""Log info message"""
if self.logger:
self.logger.info(message, *args, **kwargs)
def warn(self, message: str, *args, **kwargs):
"""Log warning message"""
if self.logger:
self.logger.warning(message, *args, **kwargs)
def warning(self, message: str, *args, **kwargs):
"""Log warning message"""
if self.logger:
self.logger.warning(message, *args, **kwargs)
def error(self, message: str, *args, **kwargs):
"""Log error message"""
if self.logger:
self.logger.error(message, *args, **kwargs)
def critical(self, message: str, *args, **kwargs):
"""Log critical message (mapped to ERROR in SRS)"""
if self.logger:
self.logger.critical(message, *args, **kwargs)
def exception(self, message: str, *args, **kwargs):
"""Log exception with traceback"""
if self.logger:
self.logger.exception(message, *args, **kwargs)
# Global logger instance
_global_logger: Optional[SRSLogger] = None
def get_logger(config_path: Optional[str] = None) -> SRSLogger:
"""
Get the global SRS logger instance
Args:
config_path: Path to SRS configuration file
Returns:
SRSLogger instance
"""
global _global_logger
if _global_logger is None:
_global_logger = SRSLogger(config_path)
return _global_logger
def init_logger(config_path: Optional[str] = None) -> SRSLogger:
"""
Initialize the SRS logger with configuration
Args:
config_path: Path to SRS configuration file
Returns:
SRSLogger instance
"""
return get_logger(config_path)
# Convenience functions for direct logging
def trace(message: str, *args, **kwargs):
"""Log trace message"""
get_logger().trace(message, *args, **kwargs)
def debug(message: str, *args, **kwargs):
"""Log debug message"""
get_logger().debug(message, *args, **kwargs)
def info(message: str, *args, **kwargs):
"""Log info message"""
get_logger().info(message, *args, **kwargs)
def warn(message: str, *args, **kwargs):
"""Log warning message"""
get_logger().warn(message, *args, **kwargs)
def warning(message: str, *args, **kwargs):
"""Log warning message"""
get_logger().warning(message, *args, **kwargs)
def error(message: str, *args, **kwargs):
"""Log error message"""
get_logger().error(message, *args, **kwargs)
def critical(message: str, *args, **kwargs):
"""Log critical message"""
get_logger().critical(message, *args, **kwargs)
def exception(message: str, *args, **kwargs):
"""Log exception with traceback"""
get_logger().exception(message, *args, **kwargs)
def log_error_and_exit(message: str, exit_code: int = 1):
"""
Log an error message and exit the process
Used when Python process encounters fatal errors
"""
error(f"FATAL: {message}")
sys.exit(exit_code)
if __name__ == "__main__":
# Test the logger
parser = argparse.ArgumentParser(description="SRS Python Logger Test")
parser.add_argument("--config", help="Path to SRS config file")
args = parser.parse_args()
# Initialize logger
logger = init_logger(args.config)
# Test different log levels
logger.trace("This is a trace message")
logger.debug("This is a debug message")
logger.info("SRS Python logger test started")
logger.warn("This is a warning message")
logger.error("This is an error message")
# Test exception logging
try:
raise ValueError("Test exception")
except Exception:
logger.exception("Caught test exception")
logger.info("SRS Python logger test completed")

View File

@ -0,0 +1,146 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Example usage of SRS Logger module
This file demonstrates how to import and use the SRS logger in other Python files.
"""
import sys
import argparse
import time
import traceback
# Import the SRS logger module
from srs_logger import get_logger, init_logger, log_error_and_exit
import srs_logger
def example_basic_usage():
"""Basic usage example"""
logger = get_logger()
logger.info("Starting basic usage example")
logger.debug("This is a debug message")
logger.warn("This is a warning message")
logger.info("Basic usage example completed")
def example_with_config():
"""Example with config file parsing"""
# This would typically be passed via --config argument
config_path = "./conf/console.conf"
logger = init_logger(config_path)
logger.info(f"Logger initialized with config file: {config_path}")
# Demonstrate various log levels
logger.trace("Trace level message")
logger.debug("Debug level message")
logger.info("Info level message")
logger.warning("Warning level message")
logger.error("Error level message")
def example_error_handling():
"""Example of error handling and logging"""
logger = get_logger()
try:
# Simulate some work that might fail
logger.info("Starting risky operation")
# Simulate an error condition
if True: # This would be your actual error condition
raise RuntimeError("Simulated error in Python process")
logger.info("Risky operation completed successfully")
except Exception as e:
# Log the error with full traceback
logger.exception(f"Error in risky operation: {e}")
# For fatal errors, you can use log_error_and_exit
# log_error_and_exit(f"Fatal error occurred: {e}")
# Or just log and continue
logger.error(f"Continuing after error: {e}")
def example_convenience_functions():
"""Example using convenience functions"""
# These functions use the global logger instance
srs_logger.info("Using convenience function for info")
srs_logger.warn("Using convenience function for warning")
srs_logger.error("Using convenience function for error")
def analytics_simulation():
"""Simulate analytics processing with logging"""
logger = get_logger()
logger.info("Analytics process started")
try:
# Simulate analytics work
for i in range(5):
logger.debug(f"Processing analytics batch {i+1}/5")
time.sleep(0.5) # Simulate processing time
if i == 3: # Simulate a warning condition
logger.warn(f"High CPU usage detected during batch {i+1}")
logger.info("Analytics processing completed successfully")
except KeyboardInterrupt:
logger.warn("Analytics process interrupted by user")
return False
except Exception as e:
logger.exception(f"Analytics process failed: {e}")
return False
return True
def main():
"""Main function demonstrating different usage patterns"""
parser = argparse.ArgumentParser(description="SRS Logger Usage Examples")
parser.add_argument("--config", help="Path to SRS config file")
parser.add_argument("--port", type=int, default=8888, help="Port number (example parameter)")
parser.add_argument("--example", choices=['basic', 'config', 'error', 'convenience', 'analytics'],
default='analytics', help="Example to run")
args = parser.parse_args()
try:
# Initialize logger with config if provided
if args.config:
logger = init_logger(args.config)
logger.info(f"SRS Logger example started with config: {args.config}")
else:
logger = get_logger()
logger.info("SRS Logger example started without config file")
logger.info(f"Example parameters: port={args.port}, example={args.example}")
# Run the selected example
if args.example == 'basic':
example_basic_usage()
elif args.example == 'config':
example_with_config()
elif args.example == 'error':
example_error_handling()
elif args.example == 'convenience':
example_convenience_functions()
elif args.example == 'analytics':
success = analytics_simulation()
if not success:
log_error_and_exit("Analytics simulation failed")
logger.info("SRS Logger example completed successfully")
except Exception as e:
# Use the logger for any unhandled exceptions
if 'logger' in locals():
logger.exception(f"Unhandled exception in main: {e}")
else:
print(f"FATAL ERROR: {e}", file=sys.stderr)
traceback.print_exc()
sys.exit(1)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,402 @@
from database import Database
from srs_logger import get_logger
import requests
import time
import threading
from datetime import datetime, timedelta
from typing import Optional, Dict, Any
from datetime import datetime, timedelta
# ============================================================================
# System Statistics Management Module
# ============================================================================
class SystemStatsManager:
"""系统统计管理器"""
def __init__(self, db: Database):
self.db = db
self.logger = get_logger()
self.api_url = "http://python_stats:wMePq3ahpoLRzgsVg7BY9eE82uuJHT0YukD2ZE1JfMY2RjP4e6QnUaKg3V9x5s9M@localhost:1985/api/v1/summaries"
self.previous_data: Optional[Dict[str, Any]] = None
self.previous_timestamp: Optional[float] = None
self.is_running = False
self.poll_thread: Optional[threading.Thread] = None
self.cleanup_thread: Optional[threading.Thread] = None
# 启动定期清理过期数据的线程
self._start_cleanup_thread()
def _fetch_system_stats(self):
"""获取系统统计数据并插入数据库"""
max_retries = 5
retry_count = 0
while retry_count < max_retries:
try:
# 发送HTTP请求获取SRS统计数据
response = requests.get(self.api_url, timeout=10)
response.raise_for_status()
data = response.json()
# 检查返回状态
if data.get('code') != 0:
self.logger.error(f"SRS API returned error code: {data.get('code')}")
retry_count += 1
time.sleep(2 ** retry_count) # 指数退避
continue
# 提取有用的信息
stats_data = self._extract_stats_data(data)
if stats_data:
# 插入数据库
success = self.db.insert_system_stats(**stats_data)
if success:
self.logger.debug("Successfully inserted system stats")
return True
else:
self.logger.error("Failed to insert system stats to database")
retry_count += 1
else:
self.logger.error("Failed to extract stats data")
retry_count += 1
except requests.exceptions.RequestException as e:
self.logger.error(f"HTTP request failed: {e}")
retry_count += 1
except Exception as e:
self.logger.exception(f"Unexpected error in fetch_system_stats: {e}")
retry_count += 1
if retry_count < max_retries:
time.sleep(2 ** retry_count) # 指数退避重试
self.logger.error(f"Failed to fetch system stats after {max_retries} retries")
return False
def _extract_stats_data(self, api_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""从API返回数据中提取有用的统计信息"""
try:
data_section = api_data.get('data', {})
self_section = data_section.get('self', {})
system_section = data_section.get('system', {})
current_timestamp = data_section.get('now_ms', 0) / 1000.0 # 转换为秒
# SRS相关数据
srs_uptime = self_section.get('srs_uptime', 0.0)
srs_cpu_percent = self_section.get('cpu_percent', 0.0)
srs_memory_percent = self_section.get('mem_percent', 0.0)
# 网络流量数据需要计算KBps
srs_recv_bytes = system_section.get('srs_recv_bytes', 0)
srs_send_bytes = system_section.get('srs_send_bytes', 0)
srs_sample_time = system_section.get('srs_sample_time', 0) / 1000.0 # 转换为秒
# 计算KBps需要与上一次的数据比较
srs_recv_KBps = 0.0
srs_send_KBps = 0.0
if (self.previous_data and
self.previous_timestamp and
current_timestamp > self.previous_timestamp):
time_diff = current_timestamp - self.previous_timestamp
prev_recv = self.previous_data.get('srs_recv_bytes', 0)
prev_send = self.previous_data.get('srs_send_bytes', 0)
if time_diff > 0:
# 计算字节差值并转换为KBps
srs_recv_KBps = max(0, (srs_recv_bytes - prev_recv) / time_diff / 1024)
srs_send_KBps = max(0, (srs_send_bytes - prev_send) / time_diff / 1024)
# print(srs_recv_KBps, srs_send_KBps) # DEBUG
# 磁盘IO数据
disk_read_KBps = system_section.get('disk_read_KBps', 0.0)
disk_write_KBps = system_section.get('disk_write_KBps', 0.0)
# 操作系统相关数据
os_uptime = system_section.get('uptime', 0.0)
os_cpu_percent = system_section.get('cpu_percent', 0.0)
os_memory_percent = system_section.get('mem_ram_percent', 0.0)
# 更新上一次的数据
self.previous_data = {
'srs_recv_bytes': srs_recv_bytes,
'srs_send_bytes': srs_send_bytes,
'srs_sample_time': srs_sample_time
}
self.previous_timestamp = current_timestamp
return {
'srs_uptime': srs_uptime,
'srs_cpu_percent': srs_cpu_percent,
'srs_memory_percent': srs_memory_percent,
'srs_recv_KBps': srs_recv_KBps,
'srs_send_KBps': srs_send_KBps,
'disk_read_KBps': disk_read_KBps,
'disk_write_KBps': disk_write_KBps,
'os_uptime': os_uptime,
'os_cpu_percent': os_cpu_percent,
'os_memory_percent': os_memory_percent
}
except Exception as e:
self.logger.exception(f"Failed to extract stats data: {e}")
return None
def start_polling(self, interval: int = 3):
"""启动长轮询统计数据收集"""
if self.is_running:
self.logger.warn("Polling is already running")
return
self.is_running = True
self.poll_thread = threading.Thread(target=self._polling_loop, args=(interval,))
self.poll_thread.daemon = True
self.poll_thread.start()
self.logger.info(f"Started system stats polling with {interval}s interval")
def stop_polling(self):
"""停止长轮询"""
if not self.is_running:
self.logger.warn("Polling is not running")
return
self.is_running = False
if self.poll_thread:
self.poll_thread.join(timeout=5)
self.logger.info("Stopped system stats polling")
def _polling_loop(self, interval: int):
"""轮询循环"""
while self.is_running:
try:
self._fetch_system_stats()
except Exception as e:
self.logger.exception(f"Error in polling loop: {e}")
# 等待指定间隔,但允许提前停止
for _ in range(interval):
if not self.is_running:
break
time.sleep(1)
def _start_cleanup_thread(self):
"""启动清理过期数据的线程"""
self.cleanup_thread = threading.Thread(target=self._cleanup_loop)
self.cleanup_thread.daemon = True
self.cleanup_thread.start()
self.logger.info("Started system stats cleanup thread (5 minute interval)")
def _cleanup_loop(self):
"""定期清理过期数据的循环"""
while True:
try:
# 执行清理
success = self.db.delete_expired_system_stats()
if success:
self.logger.debug("Successfully cleaned up expired system stats")
else:
self.logger.warn("Failed to clean up expired system stats")
time.sleep(300)
except Exception as e:
self.logger.exception(f"Error in cleanup loop: {e}")
# 发生错误后等待一分钟再继续
time.sleep(60)
def _process_stats_data(self, raw_stats: list, time_delta: int, time_interval: int) -> list:
"""
处理原始统计数据按指定时间间隔进行插值和过滤
Args:
raw_stats: 原始统计数据列表
time_delta: 时间范围
time_interval: 时间间隔
Returns:
处理后的统计数据列表
"""
try:
if not raw_stats:
return []
# 参数验证
if time_interval <= 0:
self.logger.warn("Invalid time_interval, using 1 second")
time_interval = 1
if time_delta <= 0:
self.logger.warn("Invalid time_delta, using 60 seconds")
time_delta = 60
# 转换时间戳并排序
for stat in raw_stats:
if isinstance(stat['timestamp'], str):
stat['timestamp'] = datetime.fromisoformat(stat['timestamp'])
elif not isinstance(stat['timestamp'], datetime):
# 如果是其他格式,尝试解析
stat['timestamp'] = datetime.fromisoformat(str(stat['timestamp']))
# 按时间戳排序
raw_stats.sort(key=lambda x: x['timestamp'])
# 确定时间范围
now = datetime.now().replace(microsecond=0) # 取整到秒
start_time = now - timedelta(seconds=time_delta)
# 过滤掉超出时间范围的数据
filtered_stats = [stat for stat in raw_stats if stat['timestamp'] >= start_time]
if not filtered_stats:
self.logger.warn("No data points in specified time range")
return []
# 生成目标时间点列表(从最近一次记录向前)
target_times = []
current_time = now
while current_time >= start_time:
target_times.append(current_time)
current_time -= timedelta(seconds=time_interval)
# 反转列表,使其按时间顺序排列
target_times.reverse()
# 对每个目标时间点进行插值
processed_stats = []
numeric_fields = [
'srs_uptime', 'srs_cpu_percent', 'srs_memory_percent',
'srs_recv_KBps', 'srs_send_KBps', 'disk_read_KBps',
'disk_write_KBps', 'os_uptime', 'os_cpu_percent', 'os_memory_percent'
]
for target_time in target_times:
interpolated_data = self._interpolate_data_point(filtered_stats, target_time, numeric_fields)
if interpolated_data:
processed_stats.append(interpolated_data)
self.logger.debug(f"Processed {len(raw_stats)} raw points into {len(processed_stats)} interpolated points")
return processed_stats
except Exception as e:
self.logger.exception(f"Error processing stats data: {e}")
return raw_stats # 如果处理失败,返回原始数据
def _interpolate_data_point(self, raw_stats: list, target_time: datetime, numeric_fields: list) -> Optional[Dict[str, Any]]:
"""
为指定时间点插值数据
Args:
raw_stats: 原始统计数据列表已排序
target_time: 目标时间点
numeric_fields: 需要插值的数值字段列表
Returns:
插值后的数据点
"""
try:
# 查找最接近的前后两个数据点
before_point = None
after_point = None
for i, stat in enumerate(raw_stats):
stat_time = stat['timestamp']
if stat_time <= target_time:
before_point = stat
elif stat_time > target_time:
after_point = stat
break
# 如果没有找到合适的数据点,使用最近的点
if before_point is None and after_point is None:
return None
elif before_point is None:
# 只有后面的点,直接使用
result = after_point.copy()
result['timestamp'] = target_time.replace(microsecond=0)
return result
elif after_point is None:
# 只有前面的点,直接使用
result = before_point.copy()
result['timestamp'] = target_time.replace(microsecond=0)
return result
# 计算时间差和权重
before_time = before_point['timestamp']
after_time = after_point['timestamp']
# 如果时间点完全匹配,直接返回
if before_time == target_time:
return before_point.copy()
if after_time == target_time:
return after_point.copy()
# 计算插值权重
total_diff = (after_time - before_time).total_seconds()
if total_diff == 0:
# 两个点时间相同,使用前一个点
result = before_point.copy()
result['timestamp'] = target_time.replace(microsecond=0)
return result
target_diff = (target_time - before_time).total_seconds()
weight = target_diff / total_diff
# 执行线性插值
result = {'timestamp': target_time.replace(microsecond=0)} # 确保时间戳取整到秒
for field in numeric_fields:
if field in before_point and field in after_point:
before_val = float(before_point[field]) if before_point[field] is not None else 0.0
after_val = float(after_point[field]) if after_point[field] is not None else 0.0
# 线性插值
interpolated_val = before_val + (after_val - before_val) * weight
result[field] = round(interpolated_val, 4) # 保留4位小数
else:
# 如果字段不存在,使用前一个点的值
result[field] = before_point.get(field, 0.0)
return result
except Exception as e:
self.logger.exception(f"Error interpolating data point: {e}")
return None
def get_recent_stats(self, user_group: list, time_delta: int = 5, time_interval: int = 1):
"""获取最近的统计数据"""
try:
if not any(group in user_group for group in ['streamer', 'manager', 'admin']):
return {"success": False, "message": "权限不足,无法获取统计数据"}
# 获取原始数据,多获取一些以便插值
raw_stats = self.db.get_system_stats(time_delta + 5) # 多获取60秒数据用于插值
if not raw_stats:
return {"success": False, "message": "没有找到相关的统计数据"}
# 处理数据,按指定间隔进行插值和过滤
processed_stats = self._process_stats_data(raw_stats, time_delta, time_interval)
# 转换时间戳为字符串格式便于JSON序列化
for stat in processed_stats:
if isinstance(stat['timestamp'], datetime):
stat['timestamp'] = stat['timestamp'].isoformat()
return {
"success": True,
"system_stats": processed_stats,
"metadata": {
"time_delta": time_delta,
"time_interval": time_interval,
"data_points": len(processed_stats),
"original_points": len(raw_stats)
}
}
except Exception as e:
self.logger.exception(f"Error in get_recent_stats: {e}")
return {"success": False, "message": "服务器内部错误"}

View File

View File

@ -1,52 +0,0 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
import sys, shlex, time, subprocess
cmd = "./python.subprocess 0 80000"
args = shlex.split(str(cmd))
print "cmd: %s, args: %s"%(cmd, args)
# the communicate will read all data and wait for sub process to quit.
def use_communicate(args, fout, ferr):
process = subprocess.Popen(args, stdout=fout, stderr=ferr)
(stdout_str, stderr_str) = process.communicate()
return (stdout_str, stderr_str)
# if use subprocess.PIPE, the pipe will full about 50KB data,
# and sub process will blocked, then timeout will kill it.
def use_poll(args, fout, ferr, timeout):
(stdout_str, stderr_str) = (None, None)
process = subprocess.Popen(args, stdout=fout, stderr=ferr)
starttime = time.time()
while True:
process.poll()
if process.returncode is not None:
(stdout_str, stderr_str) = process.communicate()
break
if timeout > 0 and time.time() - starttime >= timeout:
print "timeout, kill process. timeout=%s"%(timeout)
process.kill()
break
time.sleep(1)
process.wait()
return (stdout_str, stderr_str)
# stdout/stderr can be fd, fileobject, subprocess.PIPE, None
fnull = open("/dev/null", "rw")
fout = fnull.fileno()#subprocess.PIPE#fnull#fnull.fileno()
ferr = fnull.fileno()#subprocess.PIPE#fnull#fnull.fileno()
print "fout=%s, ferr=%s"%(fout, ferr)
#(stdout_str, stderr_str) = use_communicate(args, fout, ferr)
(stdout_str, stderr_str) = use_poll(args, fout, ferr, 10)
def print_result(stdout_str, stderr_str):
if stdout_str is None:
stdout_str = ""
if stderr_str is None:
stderr_str = ""
print "terminated, size of stdout=%s, stderr=%s"%(len(stdout_str), len(stderr_str))
while True:
time.sleep(1)
print_result(stdout_str, stderr_str)

View File

@ -1,32 +0,0 @@
#include <stdio.h>
#include <unistd.h>
#include <stdlib.h>
/**
# always to print to stdout and stderr.
g++ python.subprocess.cpp -o python.subprocess
*/
int main(int argc, char** argv) {
if (argc <= 2) {
printf("Usage: <%s> <interval_ms> <max_loop>\n"
" %s 50 100000\n", argv[0], argv[0]);
exit(-1);
return -1;
}
int interval_ms = ::atoi(argv[1]);
int max_loop = ::atoi(argv[2]);
printf("always to print to stdout and stderr.\n");
printf("interval: %d ms\n", interval_ms);
printf("max_loop: %d\n", max_loop);
for (int i = 0; i < max_loop; i++) {
fprintf(stdout, "always to print to stdout and stderr. interval=%dms, max=%d, current=%d\n", interval_ms, max_loop, i);
fprintf(stderr, "always to print to stdout and stderr. interval=%dms, max=%d, current=%d\n", interval_ms, max_loop, i);
if (interval_ms > 0) {
usleep(interval_ms * 1000);
}
}
return 0;
}