Setup Frontend Livestream site and Backend HTTP Server by python
This commit is contained in:
parent
bf4ea37ea3
commit
c467522328
|
|
@ -1,37 +1,186 @@
|
|||
# no-daemon and write log to console config for srs.
|
||||
# @see full.conf for detail config.
|
||||
# docker config for srs.
|
||||
# @see full.conf for detail config explanation.
|
||||
#############################################################################################
|
||||
# Last modified: 2025-05-10 Jason Yang
|
||||
# Contributor: SRS Team, Jason Yang, Jasper, HyperKNF, Hayden, NPL ITP Team
|
||||
#############################################################################################
|
||||
|
||||
listen 1935;
|
||||
max_connections 1000;
|
||||
daemon off;
|
||||
srs_log_tank all;
|
||||
#############################################################################################
|
||||
# Global sections
|
||||
#############################################################################################
|
||||
ff_log_dir ./objs;
|
||||
ff_log_level info;
|
||||
|
||||
srs_log_tank all;
|
||||
srs_log_file ./objs/srs.log;
|
||||
# TRACE, DEBUG, INFO, WARN, ERROR
|
||||
srs_log_level_v2 info;
|
||||
|
||||
max_connections 1000;
|
||||
daemon off;
|
||||
utc_time off;
|
||||
|
||||
#############################################################################################
|
||||
# RTMP sections
|
||||
#############################################################################################
|
||||
listen 1935;
|
||||
chunk_size 60000;
|
||||
|
||||
#############################################################################################
|
||||
# HTTP sections
|
||||
#############################################################################################
|
||||
http_api {
|
||||
enabled on;
|
||||
listen 1985;
|
||||
enabled on;
|
||||
listen 1985;
|
||||
crossdomain on;
|
||||
auth {
|
||||
enabled on;
|
||||
username python_stats;
|
||||
password wMePq3ahpoLRzgsVg7BY9eE82uuJHT0YukD2ZE1JfMY2RjP4e6QnUaKg3V9x5s9M;
|
||||
}
|
||||
https {
|
||||
enabled off;
|
||||
listen 1990;
|
||||
key ./conf/server.key;
|
||||
cert ./conf/server.crt;
|
||||
}
|
||||
}
|
||||
http_server {
|
||||
enabled on;
|
||||
listen 8080;
|
||||
enabled on;
|
||||
listen 8080;
|
||||
dir ./objs/nginx/html;
|
||||
crossdomain on;
|
||||
https {
|
||||
enabled off;
|
||||
listen 8088;
|
||||
key ./conf/server.key;
|
||||
cert ./conf/server.crt;
|
||||
}
|
||||
}
|
||||
|
||||
#############################################################################################
|
||||
# SRT server section
|
||||
#############################################################################################
|
||||
srt_server {
|
||||
# whether SRT server is enabled.
|
||||
enabled off;
|
||||
listen 10080;
|
||||
|
||||
maxbw 1000000000;
|
||||
mss 1500;
|
||||
|
||||
connect_timeout 4000;
|
||||
peer_idle_timeout 8000;
|
||||
|
||||
default_app live;
|
||||
peerlatency 0;
|
||||
recvlatency 0;
|
||||
latency 0;
|
||||
|
||||
tsbpdmode off;
|
||||
tlpktdrop off;
|
||||
sendbuf 2000000;
|
||||
recvbuf 2000000;
|
||||
}
|
||||
|
||||
#############################################################################################
|
||||
# WebRTC server section
|
||||
#############################################################################################
|
||||
rtc_server {
|
||||
enabled on;
|
||||
listen 8000; # UDP port
|
||||
# @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#config-candidate
|
||||
candidate $CANDIDATE;
|
||||
enabled on;
|
||||
listen 8000;
|
||||
|
||||
protocol udp;
|
||||
tcp {
|
||||
enabled off;
|
||||
listen 8000;
|
||||
}
|
||||
candidate $CANDIDATE;
|
||||
ip_family ipv4;
|
||||
api_as_candidates on;
|
||||
resolve_api_domain on;
|
||||
|
||||
ecdsa on;
|
||||
encrypt on;
|
||||
}
|
||||
|
||||
#############################################################################################
|
||||
# PYTHON_ADDONS sections
|
||||
#############################################################################################
|
||||
python_addons {
|
||||
# Enable or disable Python addon management
|
||||
enabled on;
|
||||
|
||||
# Python addon definitions
|
||||
# Each addon block defines a Python script to run
|
||||
addon {
|
||||
script "./python/httpbackend_server.py";
|
||||
}
|
||||
}
|
||||
|
||||
#############################################################################################
|
||||
# VHOST sections
|
||||
#############################################################################################
|
||||
vhost __defaultVhost__ {
|
||||
enabled on;
|
||||
|
||||
hls {
|
||||
enabled on;
|
||||
enabled on;
|
||||
}
|
||||
srt {
|
||||
# Whether enable SRT on this vhost.
|
||||
enabled on;
|
||||
srt_to_rtmp on;
|
||||
}
|
||||
|
||||
http_remux {
|
||||
enabled on;
|
||||
mount [vhost]/[app]/[stream].flv;
|
||||
enabled on;
|
||||
mount [vhost]/[app]/[stream].flv;
|
||||
}
|
||||
|
||||
rtc {
|
||||
enabled on;
|
||||
enabled on;
|
||||
|
||||
# @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtmp-to-rtc
|
||||
rtmp_to_rtc on;
|
||||
rtmp_to_rtc on;
|
||||
keep_bframe on;
|
||||
# opus_bitrate 48000;
|
||||
|
||||
# @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtc-to-rtmp
|
||||
rtc_to_rtmp on;
|
||||
rtc_to_rtmp on;
|
||||
pli_for_rtmp 6.0;
|
||||
aac_bitrate 48000;
|
||||
}
|
||||
}
|
||||
|
||||
chunk_size 128;
|
||||
tcp_nodelay on;
|
||||
min_latency on;
|
||||
play {
|
||||
gop_cache off;
|
||||
queue_length 10;
|
||||
mw_latency 100;
|
||||
mw_msgs 0;
|
||||
}
|
||||
publish {
|
||||
mr off;
|
||||
mr_latency 300;
|
||||
firstpkt_timeout 20000;
|
||||
normal_timeout 5000;
|
||||
}
|
||||
|
||||
security {
|
||||
enabled off;
|
||||
allow play all;
|
||||
allow publish all;
|
||||
}
|
||||
|
||||
dvr {
|
||||
enabled on;
|
||||
dvr_path ./DVR_Record/[stream]/[2006].[01].[02].[15].[04].[05].mp4;
|
||||
dvr_plan segment;
|
||||
dvr_duration 30;
|
||||
dvr_apply /^live\/.*/;
|
||||
dvr_wait_keyframe on;
|
||||
time_jitter full;
|
||||
}
|
||||
}
|
||||
1
trunk/livestream_site
Submodule
1
trunk/livestream_site
Submodule
|
|
@ -0,0 +1 @@
|
|||
Subproject commit 348b347bd7e1cf9d59acff0ba1bdf5390f0751d7
|
||||
|
|
@ -1,122 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
SRS Analytics Script
|
||||
This demonstrates another Python process that can run alongside SRS.
|
||||
"""
|
||||
|
||||
import time
|
||||
import signal
|
||||
import sys
|
||||
import argparse
|
||||
import logging
|
||||
import json
|
||||
from http.server import HTTPServer, BaseHTTPRequestHandler
|
||||
from threading import Thread
|
||||
from datetime import datetime
|
||||
|
||||
class AnalyticsHandler(BaseHTTPRequestHandler):
|
||||
def do_GET(self):
|
||||
"""Handle GET requests for analytics data"""
|
||||
if self.path == '/stats':
|
||||
# Return sample analytics data
|
||||
stats = {
|
||||
'timestamp': datetime.now().isoformat(),
|
||||
'connections': 42,
|
||||
'streams': 5,
|
||||
'bandwidth': '1.2 Mbps',
|
||||
'uptime': '2h 30m'
|
||||
}
|
||||
|
||||
self.send_response(200)
|
||||
self.send_header('Content-type', 'application/json')
|
||||
self.end_headers()
|
||||
self.wfile.write(json.dumps(stats, indent=2).encode())
|
||||
else:
|
||||
self.send_response(404)
|
||||
self.end_headers()
|
||||
|
||||
def log_message(self, format, *args):
|
||||
"""Override to use our logger"""
|
||||
pass
|
||||
|
||||
class SRSAnalytics:
|
||||
def __init__(self, port=8888):
|
||||
self.port = port
|
||||
self.running = True
|
||||
self.server = None
|
||||
self.server_thread = None
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='[%(asctime)s] [Python Analytics] %(levelname)s: %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
)
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# Set up signal handlers
|
||||
signal.signal(signal.SIGTERM, self.signal_handler)
|
||||
signal.signal(signal.SIGINT, self.signal_handler)
|
||||
|
||||
def signal_handler(self, signum, frame):
|
||||
"""Handle shutdown signals from SRS"""
|
||||
self.logger.info(f"Received signal {signum}, shutting down gracefully...")
|
||||
self.running = False
|
||||
if self.server:
|
||||
self.server.shutdown()
|
||||
|
||||
def start_http_server(self):
|
||||
"""Start the HTTP analytics server"""
|
||||
try:
|
||||
self.server = HTTPServer(('localhost', self.port), AnalyticsHandler)
|
||||
self.logger.info(f"Analytics server started on http://localhost:{self.port}")
|
||||
self.server.serve_forever()
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error starting HTTP server: {e}")
|
||||
|
||||
def run(self):
|
||||
"""Main analytics loop"""
|
||||
self.logger.info(f"SRS Analytics started on port {self.port}")
|
||||
|
||||
try:
|
||||
# Start HTTP server in a separate thread
|
||||
self.server_thread = Thread(target=self.start_http_server)
|
||||
self.server_thread.daemon = True
|
||||
self.server_thread.start()
|
||||
|
||||
# Main analytics loop
|
||||
while self.running:
|
||||
# Simulate analytics work
|
||||
self.logger.debug("Processing analytics data...")
|
||||
|
||||
# You can add your analytics logic here:
|
||||
# - Collect stream metrics
|
||||
# - Process viewer statistics
|
||||
# - Generate reports
|
||||
# - Store data to database
|
||||
|
||||
time.sleep(10) # Process every 10 seconds
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in analytics loop: {e}")
|
||||
finally:
|
||||
self.cleanup()
|
||||
|
||||
def cleanup(self):
|
||||
"""Cleanup before shutdown"""
|
||||
self.logger.info("Cleaning up Analytics server...")
|
||||
if self.server:
|
||||
self.server.shutdown()
|
||||
self.logger.info("Analytics server stopped")
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='SRS Analytics Server')
|
||||
parser.add_argument('--port', type=int, default=8888, help='HTTP server port')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
analytics = SRSAnalytics(args.port)
|
||||
analytics.run()
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
616
trunk/python/avatar_module.py
Normal file
616
trunk/python/avatar_module.py
Normal file
|
|
@ -0,0 +1,616 @@
|
|||
import math
|
||||
from typing import Optional, List
|
||||
|
||||
class AvatarGenerator:
|
||||
DEFAULT_COLORS = ["#000000", "#8f1414", "#e50e0e", "#f3450f", "#fcac03"]
|
||||
DEFAULT_SIZE = 80
|
||||
DEFAULT_VARIANT = 'marble'
|
||||
VALID_VARIANTS = {'marble', 'beam', 'pixel', 'sunset', 'ring', 'bauhaus'}
|
||||
|
||||
def __init__(self, default_colors: Optional[List[str]] = None, default_size: Optional[int] = None, default_variant: Optional[str] = None):
|
||||
self.colors = default_colors if default_colors else self.DEFAULT_COLORS
|
||||
self.size = default_size if default_size else self.DEFAULT_SIZE
|
||||
self.variant = default_variant if default_variant else self.DEFAULT_VARIANT
|
||||
|
||||
self.AVATAR_GENERATORS = {
|
||||
'marble': self._generate_marble_avatar,
|
||||
'beam': self._generate_beam_avatar,
|
||||
'pixel': self._generate_pixel_avatar,
|
||||
'sunset': self._generate_sunset_avatar,
|
||||
'ring': self._generate_ring_avatar,
|
||||
'bauhaus': self._generate_bauhaus_avatar,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _hash_code(name: str) -> int:
|
||||
"""Generate hash from name string - exact port of hashCode function"""
|
||||
try:
|
||||
if not name or not isinstance(name, str):
|
||||
# Consider raising a TypeError or returning a default hash for invalid input
|
||||
return 12345 # Fallback for invalid input
|
||||
hash_val = 0
|
||||
for char in name:
|
||||
hash_val = (hash_val << 5) - hash_val + ord(char)
|
||||
hash_val &= hash_val # Convert to 32bit integer
|
||||
return abs(hash_val)
|
||||
except Exception:
|
||||
return 12345 # fallback hash value
|
||||
|
||||
@staticmethod
|
||||
def _get_modulus(num: int, max_val: int) -> int:
|
||||
"""Get modulus"""
|
||||
return num % max_val
|
||||
|
||||
@staticmethod
|
||||
def _get_digit(number: int, ntn: int) -> int:
|
||||
"""Get digit at position"""
|
||||
try:
|
||||
if not isinstance(number, int) or not isinstance(ntn, int):
|
||||
return 0 # Fallback for invalid input type
|
||||
return int((number / (10 ** ntn)) % 10)
|
||||
except (ZeroDivisionError, ValueError, OverflowError):
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
def _get_boolean(number: int, ntn: int) -> bool:
|
||||
"""Get boolean from digit"""
|
||||
return not ((AvatarGenerator._get_digit(number, ntn)) % 2)
|
||||
|
||||
@staticmethod
|
||||
def _get_angle(x: float, y: float) -> float:
|
||||
"""Get angle from coordinates"""
|
||||
return math.atan2(y, x) * 180 / math.pi
|
||||
|
||||
@staticmethod
|
||||
def _get_unit(number: int, range_val: int, index: int = None) -> int:
|
||||
"""Get unit value with optional negative"""
|
||||
try:
|
||||
if not isinstance(number, int) or not isinstance(range_val, int) or range_val == 0:
|
||||
return 0 # Fallback for invalid input
|
||||
value = number % range_val
|
||||
if index and ((AvatarGenerator._get_digit(number, index) % 2) == 0):
|
||||
return -value
|
||||
return value
|
||||
except (ZeroDivisionError, ValueError, TypeError):
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
def _get_random_color(number: int, colors: List[str], range_val: int) -> str:
|
||||
"""Get random color from palette"""
|
||||
try:
|
||||
if not colors or range_val == 0:
|
||||
return AvatarGenerator.DEFAULT_COLORS[0] # Fallback
|
||||
return colors[number % range_val]
|
||||
except (IndexError, TypeError, ZeroDivisionError):
|
||||
return AvatarGenerator.DEFAULT_COLORS[0]
|
||||
|
||||
@staticmethod
|
||||
def _get_contrast(hexcolor: str) -> str:
|
||||
"""Get contrasting color (black or white)"""
|
||||
try:
|
||||
if not hexcolor or not isinstance(hexcolor, str):
|
||||
return '#000000' # Fallback
|
||||
|
||||
# Remove leading # if present
|
||||
if hexcolor.startswith('#'):
|
||||
hexcolor = hexcolor[1:]
|
||||
|
||||
# Clean up any remaining quotes or whitespace
|
||||
hexcolor = hexcolor.strip().replace('"', '').replace("'", "")
|
||||
|
||||
# Ensure we have exactly 6 hex characters
|
||||
if len(hexcolor) != 6:
|
||||
return '#000000' # Fallback
|
||||
|
||||
# Validate hex characters
|
||||
try:
|
||||
int(hexcolor, 16)
|
||||
except ValueError:
|
||||
return '#000000' # Fallback
|
||||
|
||||
# Convert to RGB
|
||||
r = int(hexcolor[0:2], 16)
|
||||
g = int(hexcolor[2:4], 16)
|
||||
b = int(hexcolor[4:6], 16)
|
||||
|
||||
# Get YIQ ratio
|
||||
yiq = ((r * 299) + (g * 587) + (b * 114)) / 1000
|
||||
|
||||
# Check contrast
|
||||
return '#000000' if yiq >= 128 else '#FFFFFF'
|
||||
except Exception:
|
||||
return '#000000'
|
||||
|
||||
@staticmethod
|
||||
def normalize_colors(colors_param: Optional[str]) -> List[str]:
|
||||
"""Normalize color palette from query parameter"""
|
||||
try:
|
||||
if not colors_param:
|
||||
return AvatarGenerator.DEFAULT_COLORS
|
||||
|
||||
# Clean up URL encoding
|
||||
cleaned = colors_param.replace('%22', '"').replace('%20', ' ').replace('%5B', '[').replace('%5D', ']').strip()
|
||||
|
||||
# Extract colors from different formats
|
||||
color_list = AvatarGenerator._extract_color_strings(cleaned)
|
||||
|
||||
# Validate and normalize colors
|
||||
return AvatarGenerator._validate_color_list(color_list)
|
||||
|
||||
except Exception:
|
||||
return AvatarGenerator.DEFAULT_COLORS
|
||||
|
||||
@staticmethod
|
||||
def _extract_color_strings(cleaned_colors: str) -> List[str]:
|
||||
"""Extract color strings from cleaned input"""
|
||||
# Handle array format: ["#xxxxxx", "#xxxxxx", ...]
|
||||
if cleaned_colors.startswith('[') and cleaned_colors.endswith(']'):
|
||||
inner_content = cleaned_colors[1:-1].strip()
|
||||
if not inner_content:
|
||||
return []
|
||||
return [part.strip().replace('"', '').replace("'", '') for part in inner_content.split(',')]
|
||||
|
||||
# Handle comma-separated format: #xxxxxx, #xxxxxx, ...
|
||||
return [part.strip().replace('"', '').replace("'", '') for part in cleaned_colors.split(',')]
|
||||
|
||||
@staticmethod
|
||||
def _validate_color_list(color_list: List[str]) -> List[str]:
|
||||
"""Validate and normalize hex colors"""
|
||||
normalized_colors = []
|
||||
|
||||
for color in color_list:
|
||||
color = color.strip()
|
||||
if not color:
|
||||
continue
|
||||
if not color.startswith('#'):
|
||||
color = '#' + color
|
||||
if AvatarGenerator._is_valid_hex_color(color):
|
||||
normalized_colors.append(color)
|
||||
|
||||
return normalized_colors if normalized_colors else AvatarGenerator.DEFAULT_COLORS
|
||||
|
||||
@staticmethod
|
||||
def _is_valid_hex_color(color: str) -> bool:
|
||||
"""Check if string is valid hex color"""
|
||||
return len(color) == 7 and all(c in '0123456789ABCDEFabcdef' for c in color[1:])
|
||||
|
||||
def _generate_marble_avatar(self, name: str, colors: List[str], size: int, square: bool, title: bool) -> str:
|
||||
"""Generate marble variant avatar"""
|
||||
ELEMENTS = 3
|
||||
SIZE = 80 # Intrinsic size of the SVG design
|
||||
|
||||
def generate_colors_marble(name: str, current_colors: List[str]):
|
||||
num_from_name = self._hash_code(name)
|
||||
range_val = len(current_colors) if current_colors else len(self.DEFAULT_COLORS)
|
||||
|
||||
elements_properties = []
|
||||
for i in range(ELEMENTS):
|
||||
elements_properties.append({
|
||||
'color': self._get_random_color(num_from_name + i, current_colors, range_val),
|
||||
'translateX': self._get_unit(num_from_name * (i + 1), SIZE // 10, 1),
|
||||
'translateY': self._get_unit(num_from_name * (i + 1), SIZE // 10, 2),
|
||||
'scale': 1.2 + self._get_unit(num_from_name * (i + 1), SIZE // 20) / 10,
|
||||
'rotate': self._get_unit(num_from_name * (i + 1), 360, 1)
|
||||
})
|
||||
|
||||
return elements_properties
|
||||
|
||||
properties = generate_colors_marble(name, colors)
|
||||
mask_id = f"mask_{self._hash_code(name)}"
|
||||
filter_id = f"filter_{self._hash_code(name)}"
|
||||
|
||||
# Use the passed 'size' for width/height, internal 'SIZE' for viewBox
|
||||
svg = f'''<svg viewBox="0 0 {SIZE} {SIZE}" fill="none" role="img" xmlns="http://www.w3.org/2000/svg" width="{size}" height="{size}">'''
|
||||
|
||||
if title:
|
||||
svg += f'<title>{name}</title>'
|
||||
|
||||
svg += f'''
|
||||
<mask id="{mask_id}" maskUnits="userSpaceOnUse" x="0" y="0" width="{SIZE}" height="{SIZE}">
|
||||
<rect width="{SIZE}" height="{SIZE}" {"" if square else f'rx="{SIZE * 2}"'} fill="#FFFFFF" />
|
||||
</mask>
|
||||
<g mask="url(#{mask_id})">
|
||||
<rect width="{SIZE}" height="{SIZE}" fill="{properties[0]['color']}" />
|
||||
<path
|
||||
filter="url(#{filter_id})"
|
||||
d="M32.414 59.35L50.376 70.5H72.5v-71H33.728L26.5 13.381l19.057 27.08L32.414 59.35z"
|
||||
fill="{properties[1]['color']}"
|
||||
transform="translate({properties[1]['translateX']} {properties[1]['translateY']}) rotate({properties[1]['rotate']} {SIZE // 2} {SIZE // 2}) scale({properties[1]['scale']})"
|
||||
/>
|
||||
<path
|
||||
filter="url(#{filter_id})"
|
||||
style="mix-blend-mode: overlay;"
|
||||
d="M22.216 24L0 46.75l14.108 38.129L78 86l-3.081-59.276-22.378 4.005 12.972 20.186-23.35 27.395L22.215 24z"
|
||||
fill="{properties[2]['color']}"
|
||||
transform="translate({properties[2]['translateX']} {properties[2]['translateY']}) rotate({properties[2]['rotate']} {SIZE // 2} {SIZE // 2}) scale({properties[2]['scale']})"
|
||||
/>
|
||||
</g>
|
||||
<defs>
|
||||
<filter
|
||||
id="{filter_id}"
|
||||
filterUnits="userSpaceOnUse"
|
||||
colorInterpolationFilters="sRGB"
|
||||
>
|
||||
<feFlood floodOpacity="0" result="BackgroundImageFix" />
|
||||
<feBlend in="SourceGraphic" in2="BackgroundImageFix" result="shape" />
|
||||
<feGaussianBlur stdDeviation="7" result="effect1_foregroundBlur" />
|
||||
</filter>
|
||||
</defs>
|
||||
</svg>'''
|
||||
|
||||
return svg
|
||||
|
||||
def _generate_beam_avatar(self, name: str, colors: List[str], size: int, square: bool, title: bool) -> str:
|
||||
"""Generate beam variant avatar"""
|
||||
SIZE = 36 # Intrinsic size of the SVG design
|
||||
|
||||
def generate_data_beam(name: str, current_colors: List[str]):
|
||||
num_from_name = self._hash_code(name)
|
||||
range_val = len(current_colors) if current_colors else len(self.DEFAULT_COLORS)
|
||||
wrapper_color = self._get_random_color(num_from_name, current_colors, range_val)
|
||||
pre_translate_x = self._get_unit(num_from_name, 10, 1)
|
||||
wrapper_translate_x = pre_translate_x + SIZE // 9 if pre_translate_x < 5 else pre_translate_x
|
||||
pre_translate_y = self._get_unit(num_from_name, 10, 2)
|
||||
wrapper_translate_y = pre_translate_y + SIZE // 9 if pre_translate_y < 5 else pre_translate_y
|
||||
|
||||
data = {
|
||||
'wrapperColor': wrapper_color,
|
||||
'faceColor': self._get_contrast(wrapper_color),
|
||||
'backgroundColor': self._get_random_color(num_from_name + 13, current_colors, range_val),
|
||||
'wrapperTranslateX': wrapper_translate_x,
|
||||
'wrapperTranslateY': wrapper_translate_y,
|
||||
'wrapperRotate': self._get_unit(num_from_name, 360),
|
||||
'wrapperScale': 1 + self._get_unit(num_from_name, SIZE // 12) / 10,
|
||||
'isMouthOpen': self._get_boolean(num_from_name, 2),
|
||||
'isCircle': self._get_boolean(num_from_name, 1),
|
||||
'eyeSpread': self._get_unit(num_from_name, 5),
|
||||
'mouthSpread': self._get_unit(num_from_name, 3),
|
||||
'faceRotate': self._get_unit(num_from_name, 10, 3),
|
||||
'faceTranslateX': wrapper_translate_x // 2 if wrapper_translate_x > SIZE // 6 else self._get_unit(num_from_name, 8, 1),
|
||||
'faceTranslateY': wrapper_translate_y // 2 if wrapper_translate_y > SIZE // 6 else self._get_unit(num_from_name, 7, 2),
|
||||
}
|
||||
|
||||
return data
|
||||
|
||||
data = generate_data_beam(name, colors)
|
||||
mask_id = f"mask_{self._hash_code(name)}"
|
||||
|
||||
svg = f'''<svg viewBox="0 0 {SIZE} {SIZE}" fill="none" role="img" xmlns="http://www.w3.org/2000/svg" width="{size}" height="{size}">'''
|
||||
|
||||
if title:
|
||||
svg += f'<title>{name}</title>'
|
||||
|
||||
rx_value = SIZE if data['isCircle'] else SIZE // 6
|
||||
|
||||
svg += f'''
|
||||
<mask id="{mask_id}" maskUnits="userSpaceOnUse" x="0" y="0" width="{SIZE}" height="{SIZE}">
|
||||
<rect width="{SIZE}" height="{SIZE}" {"" if square else f'rx="{SIZE * 2}"'} fill="#FFFFFF" />
|
||||
</mask>
|
||||
<g mask="url(#{mask_id})">
|
||||
<rect width="{SIZE}" height="{SIZE}" fill="{data['backgroundColor']}" />
|
||||
<rect
|
||||
x="0"
|
||||
y="0"
|
||||
width="{SIZE}"
|
||||
height="{SIZE}"
|
||||
transform="translate({data['wrapperTranslateX']} {data['wrapperTranslateY']}) rotate({data['wrapperRotate']} {SIZE // 2} {SIZE // 2}) scale({data['wrapperScale']})"
|
||||
fill="{data['wrapperColor']}"
|
||||
rx="{rx_value}"
|
||||
/>
|
||||
<g transform="translate({data['faceTranslateX']} {data['faceTranslateY']}) rotate({data['faceRotate']} {SIZE // 2} {SIZE // 2})">'''
|
||||
|
||||
if data['isMouthOpen']:
|
||||
svg += f'''
|
||||
<path
|
||||
d="M15 {19 + data['mouthSpread']}c2 1 4 1 6 0"
|
||||
stroke="{data['faceColor']}"
|
||||
fill="none"
|
||||
strokeLinecap="round"
|
||||
/>'''
|
||||
else:
|
||||
svg += f'''
|
||||
<path
|
||||
d="M13,{19 + data['mouthSpread']} a1,0.75 0 0,0 10,0"
|
||||
fill="{data['faceColor']}"
|
||||
/>'''
|
||||
|
||||
svg += f'''
|
||||
<rect
|
||||
x="{14 - data['eyeSpread']}"
|
||||
y="14"
|
||||
width="1.5"
|
||||
height="2"
|
||||
rx="1"
|
||||
stroke="none"
|
||||
fill="{data['faceColor']}"
|
||||
/>
|
||||
<rect
|
||||
x="{20 + data['eyeSpread']}"
|
||||
y="14"
|
||||
width="1.5"
|
||||
height="2"
|
||||
rx="1"
|
||||
stroke="none"
|
||||
fill="{data['faceColor']}"
|
||||
/>
|
||||
</g>
|
||||
</g>
|
||||
</svg>'''
|
||||
|
||||
return svg
|
||||
|
||||
def _generate_pixel_avatar(self, name: str, colors: List[str], size: int, square: bool, title: bool) -> str:
|
||||
"""Generate pixel variant avatar"""
|
||||
ELEMENTS = 64 # 8x8 grid
|
||||
SIZE = 80 # Intrinsic SVG size
|
||||
|
||||
def generate_colors_pixel(name: str, current_colors: List[str]):
|
||||
num_from_name = self._hash_code(name)
|
||||
# Use current_colors if provided, otherwise fallback to instance default, then class default
|
||||
final_colors = current_colors if current_colors else self.colors
|
||||
range_val = len(final_colors)
|
||||
|
||||
color_list = []
|
||||
for i in range(ELEMENTS):
|
||||
color_list.append(self._get_random_color(num_from_name + i, final_colors, range_val))
|
||||
return color_list
|
||||
|
||||
pixel_colors = generate_colors_pixel(name, colors)
|
||||
mask_id = f"mask_{self._hash_code(name)}"
|
||||
|
||||
svg = f'''<svg viewBox="0 0 {SIZE} {SIZE}" fill="none" role="img" xmlns="http://www.w3.org/2000/svg" width="{size}" height="{size}">'''
|
||||
|
||||
if title:
|
||||
svg += f'<title>{name}</title>'
|
||||
|
||||
svg += f'''
|
||||
<mask id="{mask_id}" mask-type="alpha" maskUnits="userSpaceOnUse" x="0" y="0" width="{SIZE}" height="{SIZE}">
|
||||
<rect width="{SIZE}" height="{SIZE}" {"" if square else f'rx="{SIZE * 2}"'} fill="#FFFFFF" />
|
||||
</mask>
|
||||
<g mask="url(#{mask_id})">'''
|
||||
|
||||
# Generate 8x8 grid of pixels
|
||||
idx = 0
|
||||
pixel_size = SIZE // 8 # Each pixel is 10x10 if SIZE is 80
|
||||
for row in range(8):
|
||||
for col in range(8):
|
||||
svg += f'''
|
||||
<rect x="{col * pixel_size}" y="{row * pixel_size}" width="{pixel_size}" height="{pixel_size}" fill="{pixel_colors[idx]}" />'''
|
||||
idx +=1
|
||||
|
||||
svg += '''
|
||||
</g>
|
||||
</svg>'''
|
||||
|
||||
return svg
|
||||
|
||||
def _generate_sunset_avatar(self, name: str, colors: List[str], size: int, square: bool, title: bool) -> str:
|
||||
"""Generate sunset variant avatar"""
|
||||
ELEMENTS = 4 # For 4 gradient stops
|
||||
SIZE = 80 # Intrinsic SVG size
|
||||
|
||||
def generate_colors_sunset(name: str, current_colors: List[str]):
|
||||
num_from_name = self._hash_code(name)
|
||||
# Use current_colors if provided, otherwise fallback to instance default, then class default
|
||||
final_colors = current_colors if current_colors else self.colors
|
||||
# Ensure we have at least 4 colors for sunset, repeat if necessary
|
||||
if len(final_colors) < ELEMENTS:
|
||||
final_colors = (final_colors * (ELEMENTS // len(final_colors) + 1))[:ELEMENTS]
|
||||
|
||||
range_val = len(final_colors)
|
||||
|
||||
color_list = []
|
||||
for i in range(ELEMENTS):
|
||||
color_list.append(self._get_random_color(num_from_name + i, final_colors, range_val))
|
||||
return color_list
|
||||
|
||||
sunset_colors = generate_colors_sunset(name, colors)
|
||||
# name_without_space = name.replace(' ', '').replace('-', '').replace('_', '') # Not used
|
||||
hash_val = abs(self._hash_code(name))
|
||||
mask_id = f"mask_{hash_val}"
|
||||
gradient_id_1 = f"gradient_paint0_linear_{hash_val}"
|
||||
gradient_id_2 = f"gradient_paint1_linear_{hash_val}"
|
||||
|
||||
rx_value = '' if square else f'rx="{SIZE * 2}"'
|
||||
|
||||
svg = f'''<svg viewBox="0 0 {SIZE} {SIZE}" fill="none" role="img" xmlns="http://www.w3.org/2000/svg" width="{size}" height="{size}">'''
|
||||
|
||||
if title:
|
||||
svg += f'<title>{name}</title>'
|
||||
|
||||
svg += f'''
|
||||
<defs>
|
||||
<linearGradient
|
||||
id="{gradient_id_1}"
|
||||
x1="{SIZE // 2}"
|
||||
y1="0"
|
||||
x2="{SIZE // 2}"
|
||||
y2="{SIZE // 2}"
|
||||
gradientUnits="userSpaceOnUse"
|
||||
>
|
||||
<stop stop-color="{sunset_colors[0]}" />
|
||||
<stop offset="1" stop-color="{sunset_colors[1]}" />
|
||||
</linearGradient>
|
||||
<linearGradient
|
||||
id="{gradient_id_2}"
|
||||
x1="{SIZE // 2}"
|
||||
y1="{SIZE // 2}"
|
||||
x2="{SIZE // 2}"
|
||||
y2="{SIZE}"
|
||||
gradientUnits="userSpaceOnUse"
|
||||
>
|
||||
<stop stop-color="{sunset_colors[2]}" />
|
||||
<stop offset="1" stop-color="{sunset_colors[3]}" />
|
||||
</linearGradient>
|
||||
</defs>
|
||||
<mask id="{mask_id}" maskUnits="userSpaceOnUse" x="0" y="0" width="{SIZE}" height="{SIZE}">
|
||||
<rect width="{SIZE}" height="{SIZE}" {rx_value} fill="#FFFFFF" />
|
||||
</mask>
|
||||
<g mask="url(#{mask_id})">
|
||||
<path fill="url(#{gradient_id_1})" d="M0 0h{SIZE}v{SIZE//2}H0z" />
|
||||
<path fill="url(#{gradient_id_2})" d="M0 {SIZE//2}h{SIZE}v{SIZE//2}H0z" />
|
||||
</g>
|
||||
</svg>'''
|
||||
|
||||
return svg
|
||||
|
||||
def _generate_ring_avatar(self, name: str, colors: List[str], size: int, square: bool, title: bool) -> str:
|
||||
"""Generate ring variant avatar"""
|
||||
SIZE = 90 # Intrinsic SVG size
|
||||
COLORS_NEEDED = 9 # Number of colors used in the SVG paths
|
||||
|
||||
def generate_colors_ring(name: str, current_colors: List[str]):
|
||||
num_from_name = self._hash_code(name)
|
||||
# Use current_colors if provided, otherwise fallback to instance default, then class default
|
||||
final_colors = current_colors if current_colors else self.colors
|
||||
# Ensure we have enough colors, repeat if necessary
|
||||
if len(final_colors) < COLORS_NEEDED:
|
||||
final_colors = (final_colors * (COLORS_NEEDED // len(final_colors) + 1))[:COLORS_NEEDED]
|
||||
|
||||
range_val = len(final_colors)
|
||||
|
||||
ring_palette = []
|
||||
for i in range(COLORS_NEEDED):
|
||||
ring_palette.append(self._get_random_color(num_from_name + i, final_colors, range_val))
|
||||
return ring_palette
|
||||
|
||||
ring_colors = generate_colors_ring(name, colors)
|
||||
mask_id = f"mask_{self._hash_code(name)}"
|
||||
|
||||
svg = f'''<svg viewBox="0 0 {SIZE} {SIZE}" fill="none" role="img" xmlns="http://www.w3.org/2000/svg" width="{size}" height="{size}">'''
|
||||
|
||||
if title:
|
||||
svg += f'<title>{name}</title>'
|
||||
|
||||
svg += f'''
|
||||
<mask id="{mask_id}" maskUnits="userSpaceOnUse" x="0" y="0" width="{SIZE}" height="{SIZE}">
|
||||
<rect width="{SIZE}" height="{SIZE}" {"" if square else f'rx="{SIZE * 2}"'} fill="#FFFFFF" />
|
||||
</mask>
|
||||
<g mask="url(#{mask_id})">
|
||||
<path d="M0 0h{SIZE}v{SIZE//2}H0z" fill="{ring_colors[0]}" />
|
||||
<path d="M0 {SIZE//2}h{SIZE}v{SIZE//2}H0z" fill="{ring_colors[1]}" />
|
||||
<path d="M{SIZE-7} {SIZE//2}a{SIZE//2-2} {SIZE//2-2} 0 00-{SIZE-14} 0h{SIZE-14}z" fill="{ring_colors[2]}" />
|
||||
<path d="M{SIZE-7} {SIZE//2}a{SIZE//2-2} {SIZE//2-2} 0 01-{SIZE-14} 0h{SIZE-14}z" fill="{ring_colors[3]}" />
|
||||
<path d="M{SIZE-13} {SIZE//2}a{SIZE//2-8} {SIZE//2-8} 0 10-{SIZE-26} 0h{SIZE-26}z" fill="{ring_colors[4]}" />
|
||||
<path d="M{SIZE-13} {SIZE//2}a{SIZE//2-8} {SIZE//2-8} 0 11-{SIZE-26} 0h{SIZE-26}z" fill="{ring_colors[5]}" />
|
||||
<path d="M{SIZE-19} {SIZE//2}a{SIZE//2-14} {SIZE//2-14} 0 00-{SIZE-38} 0h{SIZE-38}z" fill="{ring_colors[6]}" />
|
||||
<path d="M{SIZE-19} {SIZE//2}a{SIZE//2-14} {SIZE//2-14} 0 01-{SIZE-38} 0h{SIZE-38}z" fill="{ring_colors[7]}" />
|
||||
<circle cx="{SIZE//2}" cy="{SIZE//2}" r="{SIZE//2-22}" fill="{ring_colors[8]}" />
|
||||
</g>
|
||||
</svg>'''
|
||||
|
||||
return svg
|
||||
|
||||
def _generate_bauhaus_avatar(self, name: str, colors: List[str], size: int, square: bool, title: bool) -> str:
|
||||
"""Generate bauhaus variant avatar"""
|
||||
ELEMENTS = 4 # Number of geometric elements
|
||||
SIZE = 80 # Intrinsic SVG size
|
||||
|
||||
def generate_colors_bauhaus(name: str, current_colors: List[str]):
|
||||
num_from_name = self._hash_code(name)
|
||||
# Use current_colors if provided, otherwise fallback to instance default, then class default
|
||||
final_colors = current_colors if current_colors else self.colors
|
||||
range_val = len(final_colors)
|
||||
|
||||
properties = []
|
||||
for i in range(ELEMENTS):
|
||||
properties.append({
|
||||
'color': self._get_random_color(num_from_name + i, final_colors, range_val),
|
||||
'translateX': self._get_unit(num_from_name * (i + 1), SIZE // 2 - (i + 17), 1),
|
||||
'translateY': self._get_unit(num_from_name * (i + 1), SIZE // 2 - (i + 17), 2),
|
||||
'rotate': self._get_unit(num_from_name * (i + 1), 360),
|
||||
'isSquare': self._get_boolean(num_from_name, 2)
|
||||
})
|
||||
return properties
|
||||
|
||||
properties = generate_colors_bauhaus(name, colors)
|
||||
mask_id = f"mask_{self._hash_code(name)}"
|
||||
|
||||
svg = f'''<svg viewBox="0 0 {SIZE} {SIZE}" fill="none" role="img" xmlns="http://www.w3.org/2000/svg" width="{size}" height="{size}">'''
|
||||
|
||||
if title:
|
||||
svg += f'<title>{name}</title>'
|
||||
|
||||
svg += f'''
|
||||
<mask id="{mask_id}" maskUnits="userSpaceOnUse" x="0" y="0" width="{SIZE}" height="{SIZE}">
|
||||
<rect width="{SIZE}" height="{SIZE}" {"" if square else f'rx="{SIZE * 2}"'} fill="#FFFFFF" />
|
||||
</mask>
|
||||
<g mask="url(#{mask_id})">
|
||||
<rect width="{SIZE}" height="{SIZE}" fill="{properties[0]['color']}" />
|
||||
<rect
|
||||
x="{(SIZE - 60) // 2}"
|
||||
y="{(SIZE - 20) // 2}"
|
||||
width="{SIZE}"
|
||||
height="{SIZE if properties[1]['isSquare'] else SIZE // 8}"
|
||||
fill="{properties[1]['color']}"
|
||||
transform="translate({properties[1]['translateX']} {properties[1]['translateY']}) rotate({properties[1]['rotate']} {SIZE // 2} {SIZE // 2})"
|
||||
/>
|
||||
<circle
|
||||
cx="{SIZE // 2}"
|
||||
cy="{SIZE // 2}"
|
||||
fill="{properties[2]['color']}"
|
||||
r="{SIZE // 5}"
|
||||
transform="translate({properties[2]['translateX']} {properties[2]['translateY']})"
|
||||
/>
|
||||
<line
|
||||
x1="0"
|
||||
y1="{SIZE // 2}"
|
||||
x2="{SIZE}"
|
||||
y2="{SIZE // 2}"
|
||||
strokeWidth="2"
|
||||
stroke="{properties[3]['color']}"
|
||||
transform="translate({properties[3]['translateX']} {properties[3]['translateY']}) rotate({properties[3]['rotate']} {SIZE // 2} {SIZE // 2})"
|
||||
/>
|
||||
</g>
|
||||
</svg>'''
|
||||
|
||||
return svg
|
||||
|
||||
def generate_avatar(self,
|
||||
name: str,
|
||||
variant: Optional[str] = None,
|
||||
colors: Optional[List[str]] = None, # Allow passing colors as list of strings
|
||||
colors_param: Optional[str] = None, # Allow passing colors as a string parameter
|
||||
size: Optional[int] = None,
|
||||
square: bool = False,
|
||||
title: bool = True) -> str:
|
||||
"""
|
||||
Generate an SVG avatar.
|
||||
|
||||
Args:
|
||||
name: The name to generate the avatar for.
|
||||
variant: The avatar variant (e.g., 'marble', 'beam'). Uses instance default if None.
|
||||
colors: A list of hex color strings. Overrides colors_param if provided.
|
||||
colors_param: A string representation of colors (e.g., "['#FF0000', '#00FF00']" or "#FF0000,#00FF00").
|
||||
Used if 'colors' list is not provided.
|
||||
size: The desired size of the avatar in pixels. Uses instance default if None.
|
||||
square: If True, generates a square avatar. Otherwise, a circular one.
|
||||
title: If True, includes a <title> tag with the name in the SVG.
|
||||
|
||||
Returns:
|
||||
An SVG string representing the avatar.
|
||||
"""
|
||||
current_variant = variant if variant and variant in self.VALID_VARIANTS else self.variant
|
||||
current_size = size if size else self.size
|
||||
|
||||
# Determine colors: prioritize direct list, then string param, then instance default
|
||||
if colors:
|
||||
# Validate and normalize if a list is directly passed
|
||||
final_colors = self._validate_color_list(colors)
|
||||
elif colors_param:
|
||||
final_colors = self.normalize_colors(colors_param)
|
||||
else:
|
||||
final_colors = self.colors # Use instance default colors
|
||||
|
||||
if not final_colors: # Ensure there's always a fallback
|
||||
final_colors = self.DEFAULT_COLORS
|
||||
|
||||
|
||||
generator_func = self.AVATAR_GENERATORS.get(current_variant)
|
||||
|
||||
if generator_func:
|
||||
return generator_func(name, final_colors, current_size, square, title)
|
||||
else:
|
||||
# Fallback to default variant if the chosen one is somehow invalid after checks
|
||||
default_gen_func = self.AVATAR_GENERATORS.get(self.DEFAULT_VARIANT)
|
||||
return default_gen_func(name, final_colors, current_size, square, title)
|
||||
890
trunk/python/database.py
Normal file
890
trunk/python/database.py
Normal file
|
|
@ -0,0 +1,890 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Database module for JWT authentication system
|
||||
Provides SQLite database initialization and management for user authentication
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
import os
|
||||
import json
|
||||
import re
|
||||
import threading
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, List, Dict, Any, Union
|
||||
from contextlib import contextmanager
|
||||
|
||||
import random
|
||||
import string
|
||||
|
||||
# Import SRS logger
|
||||
from srs_logger import get_logger
|
||||
|
||||
class Database:
|
||||
|
||||
_instance = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __new__(cls, db_path: str = "./objs/srs_database.db"):
|
||||
"""Singleton pattern to ensure only one database instance"""
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, db_path: str = "./objs/srs_database.db"):
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
self._initialized = True
|
||||
self.db_path = db_path
|
||||
self.logger = get_logger()
|
||||
self._connection_lock = threading.Lock()
|
||||
|
||||
# Initialize database
|
||||
self._init_database()
|
||||
|
||||
def _init_database(self):
|
||||
"""Initialize database and create tables if they don't exist"""
|
||||
try:
|
||||
# Create objs directory if it doesn't exist
|
||||
db_dir = os.path.dirname(self.db_path)
|
||||
if db_dir and not os.path.exists(db_dir):
|
||||
os.makedirs(db_dir, exist_ok=True)
|
||||
self.logger.info(f"Created database directory: {db_dir}")
|
||||
|
||||
# Check if database exists
|
||||
db_exists = os.path.exists(self.db_path)
|
||||
if not db_exists:
|
||||
self.logger.warn(f"Database file not found, creating new database: {self.db_path}")
|
||||
else:
|
||||
self.logger.info(f"Using existing database: {self.db_path}")
|
||||
|
||||
# Create tables
|
||||
self._create_tables()
|
||||
|
||||
if not db_exists:
|
||||
self.logger.info("Database initialized successfully with all tables created")
|
||||
else:
|
||||
self.logger.info("Database connection established and tables verified")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Failed to initialize database: {e}")
|
||||
raise
|
||||
|
||||
def _create_tables(self):
|
||||
"""Create all required tables"""
|
||||
try:
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Create users table
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
user_id TEXT PRIMARY KEY,
|
||||
username TEXT UNIQUE NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
email TEXT UNIQUE,
|
||||
hashed_password TEXT NOT NULL,
|
||||
salt TEXT NOT NULL,
|
||||
user_group TEXT NOT NULL DEFAULT '["user"]',
|
||||
last_active TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
is_activated BOOLEAN NOT NULL DEFAULT 1
|
||||
)
|
||||
''')
|
||||
|
||||
# Create auth_session table
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS auth_session (
|
||||
session_id TEXT PRIMARY KEY,
|
||||
user_id TEXT NOT NULL,
|
||||
hashed_authkey TEXT NOT NULL,
|
||||
salt TEXT NOT NULL,
|
||||
expire_time TIMESTAMP NOT NULL,
|
||||
FOREIGN KEY (user_id) REFERENCES users (user_id) ON DELETE CASCADE
|
||||
)
|
||||
''')
|
||||
|
||||
# Create refresh_session table
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS refresh_session (
|
||||
session_id TEXT PRIMARY KEY,
|
||||
user_id TEXT NOT NULL,
|
||||
hashed_refreshkey TEXT NOT NULL,
|
||||
salt TEXT NOT NULL,
|
||||
expire_time TIMESTAMP NOT NULL,
|
||||
FOREIGN KEY (user_id) REFERENCES users (user_id) ON DELETE CASCADE
|
||||
)
|
||||
''')
|
||||
|
||||
# Create streams table
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS streams (
|
||||
stream_id TEXT PRIMARY KEY,
|
||||
stream_code TEXT UNIQUE NOT NULL,
|
||||
streamer_id TEXT NOT NULL,
|
||||
stream_title TEXT NOT NULL,
|
||||
stream_description TEXT,
|
||||
stream_tags TEXT DEFAULT '[]',
|
||||
stream_visibility TEXT NOT NULL DEFAULT 'public',
|
||||
quality_info TEXT,
|
||||
active_time TIMESTAMP,
|
||||
stream_status TEXT NOT NULL DEFAULT 'planned',
|
||||
FOREIGN KEY (streamer_id) REFERENCES users (user_id) ON DELETE CASCADE
|
||||
)
|
||||
''')
|
||||
|
||||
# Create system_stats table
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS system_stats (
|
||||
timestamp TIMESTAMP PRIMARY KEY,
|
||||
srs_uptime REAL NOT NULL,
|
||||
srs_cpu_percent REAL NOT NULL,
|
||||
srs_memory_percent REAL NOT NULL,
|
||||
srs_recv_KBps REAL NOT NULL,
|
||||
srs_send_KBps REAL NOT NULL,
|
||||
disk_read_KBps REAL NOT NULL,
|
||||
disk_write_KBps REAL NOT NULL,
|
||||
os_uptime REAL NOT NULL,
|
||||
os_cpu_percent REAL NOT NULL,
|
||||
os_memory_percent REAL NOT NULL
|
||||
)
|
||||
''')
|
||||
|
||||
# Create indexes for better performance
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_users_username ON users(username)')
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_users_email ON users(email)')
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_auth_session_user_id ON auth_session(user_id)')
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_auth_session_expire ON auth_session(expire_time)')
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_refresh_session_user_id ON refresh_session(user_id)')
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_refresh_session_expire ON refresh_session(expire_time)')
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_system_stats_timestamp ON system_stats(timestamp)')
|
||||
|
||||
conn.commit()
|
||||
self.logger.debug("All database tables and indexes created/verified successfully")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Failed to create database tables: {e}")
|
||||
raise
|
||||
|
||||
@contextmanager
|
||||
def get_connection(self):
|
||||
"""Get a database connection with proper error handling and cleanup"""
|
||||
conn = None
|
||||
try:
|
||||
conn = sqlite3.connect(self.db_path, timeout=30.0)
|
||||
conn.row_factory = sqlite3.Row # Enable column access by name
|
||||
conn.execute('PRAGMA foreign_keys = ON') # Enable foreign key constraints
|
||||
# Configure datetime adapter for Python 3.12+ compatibility
|
||||
sqlite3.register_adapter(datetime, lambda dt: dt.isoformat())
|
||||
sqlite3.register_converter("TIMESTAMP", lambda b: datetime.fromisoformat(b.decode()))
|
||||
yield conn
|
||||
except Exception as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
self.logger.error(f"Database connection error: {e}")
|
||||
raise
|
||||
finally:
|
||||
if conn:
|
||||
conn.close()
|
||||
|
||||
# ============================================================================
|
||||
# Users management functions
|
||||
# ============================================================================
|
||||
|
||||
def cleanup_expired_sessions(self):
|
||||
"""Remove expired auth and refresh sessions, and sessions for deactivated users"""
|
||||
try:
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
current_time = datetime.now()
|
||||
|
||||
# Clean up expired auth sessions
|
||||
cursor.execute('''
|
||||
DELETE FROM auth_session
|
||||
WHERE expire_time < ?
|
||||
''', (current_time,))
|
||||
auth_expired_deleted = cursor.rowcount
|
||||
|
||||
# Clean up expired refresh sessions
|
||||
cursor.execute('''
|
||||
DELETE FROM refresh_session
|
||||
WHERE expire_time < ?
|
||||
''', (current_time,))
|
||||
refresh_expired_deleted = cursor.rowcount
|
||||
|
||||
# Clean up sessions for deactivated users
|
||||
cursor.execute('''
|
||||
DELETE FROM auth_session
|
||||
WHERE user_id IN (SELECT user_id FROM users WHERE is_activated = 0)
|
||||
''')
|
||||
auth_deactivated_deleted = cursor.rowcount
|
||||
|
||||
cursor.execute('''
|
||||
DELETE FROM refresh_session
|
||||
WHERE user_id IN (SELECT user_id FROM users WHERE is_activated = 0)
|
||||
''')
|
||||
refresh_deactivated_deleted = cursor.rowcount
|
||||
|
||||
conn.commit()
|
||||
|
||||
total_auth_deleted = auth_expired_deleted + auth_deactivated_deleted
|
||||
total_refresh_deleted = refresh_expired_deleted + refresh_deactivated_deleted
|
||||
|
||||
if total_auth_deleted > 0 or total_refresh_deleted > 0:
|
||||
self.logger.info(f"Cleaned up sessions: {total_auth_deleted} auth ({auth_expired_deleted} expired, {auth_deactivated_deleted} deactivated), {total_refresh_deleted} refresh ({refresh_expired_deleted} expired, {refresh_deactivated_deleted} deactivated)")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Failed to cleanup expired sessions: {e}")
|
||||
|
||||
def get_user(self, username_or_email: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get user by username or email"""
|
||||
if not username_or_email:
|
||||
self.logger.warn("get_user called with None or empty value")
|
||||
return None
|
||||
|
||||
try:
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('''
|
||||
SELECT user_id, username, name, email, hashed_password, salt,
|
||||
user_group, last_active, is_activated
|
||||
FROM users WHERE username = ? OR email = ?
|
||||
''', (username_or_email, username_or_email))
|
||||
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
user_data = dict(row)
|
||||
user_data['user_group'] = json.loads(user_data['user_group'])
|
||||
return user_data
|
||||
else:
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Failed to get user by username/email '{username_or_email}': {e}")
|
||||
return None
|
||||
|
||||
def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get user by user_id"""
|
||||
try:
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('''
|
||||
SELECT user_id, username, name, email, hashed_password, salt,
|
||||
user_group, last_active, is_activated
|
||||
FROM users WHERE user_id = ?
|
||||
''', (user_id,))
|
||||
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
user_data = dict(row)
|
||||
user_data['user_group'] = json.loads(user_data['user_group'])
|
||||
return user_data
|
||||
else:
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Failed to get user by user_id {user_id}: {e}")
|
||||
return None
|
||||
|
||||
def create_user(self, username: str, name: str, email: Optional[str],
|
||||
hashed_password: str, salt: str, user_group: List[str] = None,
|
||||
is_activated: bool = True) -> Optional[Dict[str, Any]]:
|
||||
"""Create a new user and return user data"""
|
||||
try:
|
||||
# Validate username format
|
||||
if not re.match(r'^[a-zA-Z0-9_-]+$', username):
|
||||
self.logger.warn(f"Invalid username format: '{username}'. Only a-z, A-Z, 0-9, '_', '-' are allowed")
|
||||
return None
|
||||
|
||||
# Validate email format if provided
|
||||
if email and not re.match(r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$', email):
|
||||
self.logger.warn(f"Invalid email format: '{email}'")
|
||||
return None
|
||||
|
||||
# Ensure user_group is not empty
|
||||
if user_group is None or len(user_group) == 0:
|
||||
user_group = ["user"]
|
||||
|
||||
# Generate unique user_id (UUID collisions are extremely rare, but we'll still check)
|
||||
user_id = str(uuid.uuid4())
|
||||
|
||||
user_group_json = json.dumps(user_group)
|
||||
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('''
|
||||
INSERT INTO users (user_id, username, name, email, hashed_password, salt, user_group, is_activated)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
''', (user_id, username, name, email, hashed_password, salt, user_group_json, is_activated))
|
||||
|
||||
conn.commit()
|
||||
|
||||
self.logger.info(f"Created new user: {username} (ID: {user_id}, activated: {is_activated})")
|
||||
|
||||
# Return the created user data directly to avoid deadlock
|
||||
return {
|
||||
'user_id': user_id,
|
||||
'username': username,
|
||||
'name': name,
|
||||
'email': email,
|
||||
'hashed_password': hashed_password,
|
||||
'salt': salt,
|
||||
'user_group': user_group, # Already parsed list
|
||||
'last_active': datetime.now().isoformat(),
|
||||
'is_activated': is_activated
|
||||
}
|
||||
|
||||
except sqlite3.IntegrityError as e:
|
||||
if 'username' in str(e):
|
||||
self.logger.warn(f"Username '{username}' already exists")
|
||||
elif 'email' in str(e):
|
||||
self.logger.warn(f"Email '{email}' already exists")
|
||||
else:
|
||||
self.logger.warn(f"User creation failed due to constraint: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Failed to create user '{username}': {e}")
|
||||
return None
|
||||
|
||||
def update_user_last_active(self, user_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Update user's last active timestamp and return updated user data"""
|
||||
try:
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
current_time = datetime.now()
|
||||
cursor.execute('''
|
||||
UPDATE users
|
||||
SET last_active = ?
|
||||
WHERE user_id = ?
|
||||
''', (current_time, user_id))
|
||||
conn.commit()
|
||||
|
||||
self.logger.debug(f"Updated last_active for user_id: {user_id}")
|
||||
# Get the updated user data in the same connection to avoid deadlock
|
||||
cursor.execute('''
|
||||
SELECT user_id, username, name, email, hashed_password, salt, user_group, last_active, is_activated
|
||||
FROM users
|
||||
WHERE user_id = ?
|
||||
''', (user_id,))
|
||||
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
user_data = dict(row)
|
||||
user_data['user_group'] = json.loads(user_data['user_group'])
|
||||
return user_data
|
||||
else:
|
||||
self.logger.warn(f"No user found with user_id: {user_id}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Failed to update last_active for user_id {user_id}: {e}")
|
||||
return None
|
||||
|
||||
def update_user_activation(self, user_id: str, activation: bool) -> Optional[Dict[str, Any]]:
|
||||
"""Update user's activation status and cleanup sessions if deactivated, return updated user data"""
|
||||
try:
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('''
|
||||
UPDATE users
|
||||
SET is_activated = ?
|
||||
WHERE user_id = ?
|
||||
''', (activation, user_id))
|
||||
conn.commit()
|
||||
|
||||
self.logger.info(f"Updated activation status for user_id {user_id}: {activation}")
|
||||
|
||||
# Get the updated user data in the same connection to avoid deadlock
|
||||
cursor.execute('''
|
||||
SELECT user_id, username, name, email, hashed_password, salt, user_group, last_active, is_activated
|
||||
FROM users
|
||||
WHERE user_id = ?
|
||||
''', (user_id,))
|
||||
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
# Clean up sessions after updating activation status (outside the transaction)
|
||||
if not activation: self.cleanup_expired_sessions()
|
||||
|
||||
user_data = dict(row)
|
||||
user_data['user_group'] = json.loads(user_data['user_group'])
|
||||
return user_data
|
||||
else:
|
||||
self.logger.warn(f"No user found with user_id: {user_id}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Failed to update activation for user_id {user_id}: {e}")
|
||||
return None
|
||||
|
||||
def update_user_password(self, user_id: str, hashed_password: str, salt: str) -> Optional[Dict[str, Any]]:
|
||||
"""Update user's password and salt, return updated user data"""
|
||||
try:
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('''
|
||||
UPDATE users
|
||||
SET hashed_password = ?, salt = ?
|
||||
WHERE user_id = ?
|
||||
''', (hashed_password, salt, user_id))
|
||||
conn.commit()
|
||||
|
||||
if cursor.rowcount == 0:
|
||||
self.logger.warn(f"No user found with user_id: {user_id}")
|
||||
return None
|
||||
|
||||
self.logger.info(f"Updated password for user_id: {user_id}")
|
||||
|
||||
# Get the updated user data in the same connection to avoid deadlock
|
||||
cursor.execute('''
|
||||
SELECT user_id, username, name, email, hashed_password, salt, user_group, last_active, is_activated
|
||||
FROM users
|
||||
WHERE user_id = ?
|
||||
''', (user_id,))
|
||||
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
user_data = dict(row)
|
||||
user_data['user_group'] = json.loads(user_data['user_group'])
|
||||
return user_data
|
||||
else:
|
||||
self.logger.warn(f"No user found with user_id: {user_id}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Failed to update password for user_id {user_id}: {e}")
|
||||
return None
|
||||
|
||||
def delete_user(self, user_id: str) -> bool:
|
||||
"""Delete user and all associated data from all tables"""
|
||||
try:
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Check if user exists first
|
||||
cursor.execute('SELECT username FROM users WHERE user_id = ?', (user_id,))
|
||||
user_row = cursor.fetchone()
|
||||
if not user_row:
|
||||
self.logger.warn(f"No user found with user_id: {user_id}")
|
||||
return False
|
||||
|
||||
username = user_row[0]
|
||||
|
||||
# Delete auth sessions (foreign key constraints will handle this automatically, but we'll be explicit)
|
||||
cursor.execute('DELETE FROM auth_session WHERE user_id = ?', (user_id,))
|
||||
auth_deleted = cursor.rowcount
|
||||
|
||||
# Delete refresh sessions
|
||||
cursor.execute('DELETE FROM refresh_session WHERE user_id = ?', (user_id,))
|
||||
refresh_deleted = cursor.rowcount
|
||||
|
||||
# Delete user
|
||||
cursor.execute('DELETE FROM users WHERE user_id = ?', (user_id,))
|
||||
user_deleted = cursor.rowcount
|
||||
|
||||
conn.commit()
|
||||
|
||||
if user_deleted > 0:
|
||||
self.logger.info(f"Deleted user '{username}' (ID: {user_id}) and all associated data: {auth_deleted} auth sessions, {refresh_deleted} refresh sessions")
|
||||
return True
|
||||
else:
|
||||
self.logger.warn(f"Failed to delete user with user_id: {user_id}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Failed to delete user with user_id {user_id}: {e}")
|
||||
return False
|
||||
|
||||
# ============================================================================
|
||||
# Session management functions
|
||||
# ============================================================================
|
||||
|
||||
def create_auth_session(self, user_id: str, hashed_authkey: str, salt: str,
|
||||
expire_time: datetime) -> Optional[Dict[str, Any]]:
|
||||
"""Create an auth session and return session data"""
|
||||
try:
|
||||
# Generate unique session_id (UUID collisions are extremely rare)
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('''
|
||||
INSERT INTO auth_session (session_id, user_id, hashed_authkey, salt, expire_time)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
''', (session_id, user_id, hashed_authkey, salt, expire_time))
|
||||
|
||||
conn.commit()
|
||||
|
||||
self.logger.debug(f"Created auth session for user_id {user_id}, session_id: {session_id}")
|
||||
|
||||
# Return the created session data
|
||||
return {
|
||||
'session_id': session_id,
|
||||
'user_id': user_id,
|
||||
'hashed_authkey': hashed_authkey,
|
||||
'salt': salt,
|
||||
'expire_time': expire_time.isoformat()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Failed to create auth session for user_id {user_id}: {e}")
|
||||
return None
|
||||
|
||||
def create_refresh_session(self, user_id: str, hashed_refreshkey: str, salt: str,
|
||||
expire_time: datetime) -> Optional[Dict[str, Any]]:
|
||||
"""Create a refresh session and return session data"""
|
||||
try:
|
||||
# Generate unique session_id (UUID collisions are extremely rare)
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('''
|
||||
INSERT INTO refresh_session (session_id, user_id, hashed_refreshkey, salt, expire_time)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
''', (session_id, user_id, hashed_refreshkey, salt, expire_time))
|
||||
|
||||
conn.commit()
|
||||
|
||||
self.logger.debug(f"Created refresh session for user_id {user_id}, session_id: {session_id}")
|
||||
|
||||
# Return the created session data
|
||||
return {
|
||||
'session_id': session_id,
|
||||
'user_id': user_id,
|
||||
'hashed_refreshkey': hashed_refreshkey,
|
||||
'salt': salt,
|
||||
'expire_time': expire_time.isoformat()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Failed to create refresh session for user_id {user_id}: {e}")
|
||||
return None
|
||||
|
||||
def get_auth_session(self, session_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get valid auth session by session_id"""
|
||||
try:
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('''
|
||||
SELECT session_id, user_id, hashed_authkey, salt, expire_time
|
||||
FROM auth_session
|
||||
WHERE session_id = ? AND expire_time > ?
|
||||
''', (session_id, datetime.now()))
|
||||
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
return dict(row)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Failed to get auth session for session_id {session_id}: {e}")
|
||||
return None
|
||||
|
||||
def get_refresh_session(self, session_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get valid refresh session by session_id"""
|
||||
try:
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('''
|
||||
SELECT session_id, user_id, hashed_refreshkey, salt, expire_time
|
||||
FROM refresh_session
|
||||
WHERE session_id = ? AND expire_time > ?
|
||||
''', (session_id, datetime.now()))
|
||||
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
return dict(row)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Failed to get refresh session for session_id {session_id}: {e}")
|
||||
return None
|
||||
|
||||
def delete_auth_session(self, session_id: str) -> bool:
|
||||
"""Delete an auth session"""
|
||||
try:
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('DELETE FROM auth_session WHERE session_id = ?', (session_id,))
|
||||
conn.commit()
|
||||
|
||||
if cursor.rowcount > 0:
|
||||
self.logger.debug(f"Deleted auth session_id: {session_id}")
|
||||
return True
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Failed to delete auth session {session_id}: {e}")
|
||||
return False
|
||||
|
||||
def delete_refresh_session(self, session_id: str) -> bool:
|
||||
"""Delete a refresh session"""
|
||||
try:
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('DELETE FROM refresh_session WHERE session_id = ?', (session_id,))
|
||||
conn.commit()
|
||||
|
||||
if cursor.rowcount > 0:
|
||||
self.logger.debug(f"Deleted refresh session_id: {session_id}")
|
||||
return True
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Failed to delete refresh session {session_id}: {e}")
|
||||
return False
|
||||
|
||||
def delete_user_sessions(self, user_id: str) -> bool:
|
||||
"""Delete all sessions for a user (logout)"""
|
||||
try:
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
# Delete auth sessions
|
||||
cursor.execute('DELETE FROM auth_session WHERE user_id = ?', (user_id,))
|
||||
auth_deleted = cursor.rowcount
|
||||
|
||||
# Delete refresh sessions
|
||||
cursor.execute('DELETE FROM refresh_session WHERE user_id = ?', (user_id,))
|
||||
refresh_deleted = cursor.rowcount
|
||||
conn.commit()
|
||||
|
||||
self.logger.info(f"Deleted all sessions for user_id {user_id}: {auth_deleted} auth, {refresh_deleted} refresh")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Failed to delete sessions for user_id {user_id}: {e}")
|
||||
return False
|
||||
|
||||
# ============================================================================
|
||||
# System statistics functions
|
||||
# ============================================================================
|
||||
|
||||
def insert_system_stats(self,
|
||||
srs_uptime: float, srs_cpu_percent: float,
|
||||
srs_memory_percent: float, srs_recv_KBps: float,
|
||||
srs_send_KBps: float, disk_read_KBps: float,
|
||||
disk_write_KBps: float, os_uptime: float,
|
||||
os_cpu_percent: float, os_memory_percent: float) -> bool:
|
||||
"""Insert system statistics into the database"""
|
||||
try:
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
timestamp = datetime.now()
|
||||
|
||||
cursor.execute('''
|
||||
INSERT INTO system_stats (timestamp,
|
||||
srs_uptime, srs_cpu_percent, srs_memory_percent,
|
||||
srs_recv_KBps, srs_send_KBps,
|
||||
disk_read_KBps, disk_write_KBps,
|
||||
os_uptime, os_cpu_percent, os_memory_percent)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
''', (timestamp,
|
||||
srs_uptime, srs_cpu_percent, srs_memory_percent,
|
||||
srs_recv_KBps, srs_send_KBps,
|
||||
disk_read_KBps, disk_write_KBps,
|
||||
os_uptime, os_cpu_percent, os_memory_percent))
|
||||
|
||||
conn.commit()
|
||||
self.logger.debug(f"Inserted system stats at {timestamp.isoformat()}")
|
||||
return True
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Failed to insert system stats: {e}")
|
||||
return False
|
||||
|
||||
def get_system_stats(self, time_delta: int) -> List[Dict[str, Any]]:
|
||||
"""Get system statistics from the database"""
|
||||
try:
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('''
|
||||
SELECT * FROM system_stats WHERE timestamp >= ?
|
||||
''', (datetime.now() - timedelta(seconds=time_delta),))
|
||||
rows = cursor.fetchall()
|
||||
columns = [column[0] for column in cursor.description]
|
||||
return [dict(zip(columns, row)) for row in rows]
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Failed to get system stats: {e}")
|
||||
return []
|
||||
|
||||
def delete_expired_system_stats(self) -> bool:
|
||||
"""Delete expired system statistics older than 20 minutes"""
|
||||
try:
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('''
|
||||
DELETE FROM system_stats WHERE timestamp < ?
|
||||
''', (datetime.now() - timedelta(minutes=20),))
|
||||
conn.commit()
|
||||
|
||||
if cursor.rowcount > 0:
|
||||
self.logger.info(f"Deleted {cursor.rowcount} expired system stats")
|
||||
return True
|
||||
else:
|
||||
self.logger.debug("No expired system stats to delete")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Failed to delete expired system stats: {e}")
|
||||
return False
|
||||
|
||||
# ============================================================================
|
||||
# Streams management functions
|
||||
# ============================================================================
|
||||
|
||||
def create_stream(self, streamer_id: str, stream_title: str, stream_description: Optional[str] = None, stream_tags: Optional[List[str]] = [], stream_visibility: str = 'public') -> Dict[str, Any]:
|
||||
"""Create a new stream and return stream data"""
|
||||
try:
|
||||
# Validate streamer_id exists
|
||||
if not self.get_user_by_id(streamer_id):
|
||||
self.logger.warn(f"Streamer with user_id {streamer_id} does not exist")
|
||||
return None
|
||||
|
||||
# Generate unique stream_id (UUID collisions are extremely rare)
|
||||
stream_id = str(uuid.uuid4())
|
||||
# Generate stream_code in form of xxx-xxxx (only include letters and numbers)
|
||||
def random_code():
|
||||
part1 = ''.join(random.choices(string.ascii_lowercase + string.digits, k=3))
|
||||
part2 = ''.join(random.choices(string.ascii_lowercase + string.digits, k=4))
|
||||
return f"{part1}-{part2}"
|
||||
stream_code = random_code()
|
||||
|
||||
# Convert tags to JSON string
|
||||
stream_tags_json = json.dumps(stream_tags)
|
||||
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('''
|
||||
INSERT INTO streams (stream_id, stream_code, streamer_id, stream_title,
|
||||
stream_description, stream_tags,
|
||||
stream_visibility)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
''', (stream_id, stream_code, streamer_id, stream_title,
|
||||
stream_description, stream_tags_json,
|
||||
stream_visibility))
|
||||
|
||||
conn.commit()
|
||||
|
||||
self.logger.info(f"Created new stream: {stream_title} (ID: {stream_id}, Streamer ID: {streamer_id})")
|
||||
|
||||
return {
|
||||
'stream_id': stream_id,
|
||||
'stream_code': stream_code,
|
||||
'streamer_id': streamer_id,
|
||||
'stream_title': stream_title,
|
||||
'stream_description': stream_description,
|
||||
'stream_tags': stream_tags,
|
||||
'stream_visibility': stream_visibility,
|
||||
'stream_status': 'planned', # Default status
|
||||
}
|
||||
|
||||
except sqlite3.IntegrityError as e:
|
||||
self.logger.warn(f"Stream creation failed due to constraint: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Failed to create stream: {e}")
|
||||
return None
|
||||
|
||||
def get_stream_by_vis(self, stream_visibility: str = 'public') -> Optional[List[Dict[str, Any]]]:
|
||||
"""Get streams by visibility"""
|
||||
try:
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('''
|
||||
SELECT * FROM streams
|
||||
WHERE stream_visibility = ? AND stream_status = 'streaming'
|
||||
''', (stream_visibility,))
|
||||
|
||||
rows = cursor.fetchall()
|
||||
columns = [column[0] for column in cursor.description]
|
||||
return [dict(zip(columns, row)) for row in rows]
|
||||
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Failed to get streams by visibility '{stream_visibility}': {e}")
|
||||
return None
|
||||
|
||||
# ============================================================================
|
||||
# Default database statistics functions
|
||||
# ============================================================================
|
||||
|
||||
def get_database_stats(self) -> Dict[str, int]:
|
||||
"""Get database statistics"""
|
||||
try:
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
stats = {}
|
||||
|
||||
# Count users
|
||||
cursor.execute('SELECT COUNT(*) FROM users')
|
||||
stats['users'] = cursor.fetchone()[0]
|
||||
|
||||
# Count active auth sessions
|
||||
cursor.execute('SELECT COUNT(*) FROM auth_session WHERE expire_time > ?', (datetime.now(),))
|
||||
stats['active_auth_sessions'] = cursor.fetchone()[0]
|
||||
|
||||
# Count active refresh sessions
|
||||
cursor.execute('SELECT COUNT(*) FROM refresh_session WHERE expire_time > ?', (datetime.now(),))
|
||||
stats['active_refresh_sessions'] = cursor.fetchone()[0]
|
||||
|
||||
# Count expired auth sessions
|
||||
cursor.execute('SELECT COUNT(*) FROM auth_session WHERE expire_time <= ?', (datetime.now(),))
|
||||
stats['expired_auth_sessions'] = cursor.fetchone()[0]
|
||||
|
||||
# Count expired refresh sessions
|
||||
cursor.execute('SELECT COUNT(*) FROM refresh_session WHERE expire_time <= ?', (datetime.now(),))
|
||||
stats['expired_refresh_sessions'] = cursor.fetchone()[0]
|
||||
|
||||
return stats
|
||||
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Failed to get database statistics: {e}")
|
||||
return {}
|
||||
|
||||
# Global database instance
|
||||
_global_db: Optional[Database] = None
|
||||
|
||||
def get_database(db_path: str = "./objs/srs_database.db") -> Database:
|
||||
"""
|
||||
Get the global database instance
|
||||
|
||||
Args:
|
||||
db_path: Path to SQLite database file
|
||||
|
||||
Returns:
|
||||
Database instance
|
||||
"""
|
||||
global _global_db
|
||||
|
||||
if _global_db is None:
|
||||
# run_comprehensive_tests()
|
||||
_global_db = Database(db_path)
|
||||
|
||||
return _global_db
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Comprehensive test suite for database functionality
|
||||
import argparse
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from srs_logger import init_logger
|
||||
|
||||
logger = init_logger()
|
||||
|
||||
# Main execution
|
||||
parser = argparse.ArgumentParser(description="Database Module Test")
|
||||
parser.add_argument("--db-path", default="./objs/srs_database.db", help="Database file path")
|
||||
args = parser.parse_args()
|
||||
|
||||
db = get_database(args.db_path)
|
||||
|
||||
# Show database statistics
|
||||
stats = db.get_database_stats()
|
||||
logger.info(f"Database statistics: {stats}")
|
||||
|
||||
# Clean up expired sessions
|
||||
db.cleanup_expired_sessions()
|
||||
|
|
@ -1,142 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
SRS Python Addon - Simple HTTP Server
|
||||
This addon demonstrates how to create a simple HTTP server as an SRS addon.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import signal
|
||||
import time
|
||||
import logging
|
||||
import argparse
|
||||
import threading
|
||||
from http.server import HTTPServer, BaseHTTPRequestHandler
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger('srs_http_addon')
|
||||
|
||||
class SRSAddonHandler(BaseHTTPRequestHandler):
|
||||
"""Simple HTTP request handler for SRS addon."""
|
||||
|
||||
def do_GET(self):
|
||||
"""Handle GET requests."""
|
||||
if self.path == '/':
|
||||
self.send_response(200)
|
||||
self.send_header('Content-type', 'text/html')
|
||||
self.end_headers()
|
||||
response = '''
|
||||
<html>
|
||||
<body>
|
||||
<h1>SRS Python Addon - HTTP Server</h1>
|
||||
<p>This is a simple HTTP server running as an SRS addon.</p>
|
||||
<p>Status: Running</p>
|
||||
<p>Time: {}</p>
|
||||
</body>
|
||||
</html>
|
||||
'''.format(time.strftime('%Y-%m-%d %H:%M:%S'))
|
||||
self.wfile.write(response.encode())
|
||||
elif self.path == '/status':
|
||||
self.send_response(200)
|
||||
self.send_header('Content-type', 'application/json')
|
||||
self.end_headers()
|
||||
response = '{"status": "running", "time": "%s"}' % time.strftime('%Y-%m-%d %H:%M:%S')
|
||||
self.wfile.write(response.encode())
|
||||
else:
|
||||
self.send_response(404)
|
||||
self.end_headers()
|
||||
|
||||
def log_message(self, format, *args):
|
||||
"""Override to use our logger."""
|
||||
logger.info("%s - %s" % (self.address_string(), format % args))
|
||||
|
||||
class SRSHTTPAddon:
|
||||
"""SRS HTTP Addon main class."""
|
||||
|
||||
def __init__(self, port=8888):
|
||||
self.port = port
|
||||
self.server = None
|
||||
self.running = False
|
||||
self.server_thread = None
|
||||
|
||||
def start(self):
|
||||
"""Start the HTTP server."""
|
||||
try:
|
||||
self.server = HTTPServer(('', self.port), SRSAddonHandler)
|
||||
self.running = True
|
||||
|
||||
# Start server in a separate thread
|
||||
self.server_thread = threading.Thread(target=self._run_server)
|
||||
self.server_thread.daemon = True
|
||||
self.server_thread.start()
|
||||
|
||||
logger.info(f"SRS HTTP addon started on port {self.port}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start HTTP addon: {e}")
|
||||
return False
|
||||
|
||||
def stop(self):
|
||||
"""Stop the HTTP server."""
|
||||
if self.server and self.running:
|
||||
self.running = False
|
||||
self.server.shutdown()
|
||||
self.server.server_close()
|
||||
if self.server_thread:
|
||||
self.server_thread.join(timeout=5)
|
||||
logger.info("SRS HTTP addon stopped")
|
||||
|
||||
def _run_server(self):
|
||||
"""Run the HTTP server."""
|
||||
try:
|
||||
self.server.serve_forever()
|
||||
except Exception as e:
|
||||
if self.running:
|
||||
logger.error(f"HTTP server error: {e}")
|
||||
|
||||
def signal_handler(signum, frame):
|
||||
"""Handle termination signals."""
|
||||
logger.info(f"Received signal {signum}, shutting down...")
|
||||
if 'addon' in globals():
|
||||
addon.stop()
|
||||
sys.exit(0)
|
||||
|
||||
def main():
|
||||
"""Main function."""
|
||||
parser = argparse.ArgumentParser(description='SRS HTTP Addon')
|
||||
parser.add_argument('--port', type=int, default=8888, help='HTTP server port')
|
||||
parser.add_argument('--verbose', action='store_true', help='Enable verbose logging')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.verbose:
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
# Set up signal handlers
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
# Create and start the addon
|
||||
global addon
|
||||
addon = SRSHTTPAddon(args.port)
|
||||
|
||||
if addon.start():
|
||||
logger.info("SRS HTTP addon is running, press Ctrl+C to stop")
|
||||
try:
|
||||
while addon.running:
|
||||
time.sleep(1)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Interrupted by user")
|
||||
finally:
|
||||
addon.stop()
|
||||
else:
|
||||
logger.error("Failed to start SRS HTTP addon")
|
||||
sys.exit(1)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
845
trunk/python/httpbackend_server.py
Normal file
845
trunk/python/httpbackend_server.py
Normal file
|
|
@ -0,0 +1,845 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
FastAPI Authentication API Server
|
||||
Provides secure authentication endpoints with token-based authentication
|
||||
"""
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any, List
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Depends, Cookie, Response, Query
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from pydantic import BaseModel, EmailStr, Field, field_validator
|
||||
import uvicorn
|
||||
|
||||
# Import our database and logger
|
||||
from database import get_database, Database
|
||||
from srs_logger import get_logger
|
||||
|
||||
# Import Function Class
|
||||
from security_module import AuthManager
|
||||
from system_stats_module import SystemStatsManager
|
||||
from avatar_module import AvatarGenerator
|
||||
|
||||
# ============================================================================
|
||||
# Pydantic Models for Request/Response
|
||||
# ============================================================================
|
||||
|
||||
class RegisterRequest(BaseModel):
|
||||
"""用户注册请求模型"""
|
||||
username: str = Field(..., min_length=1, max_length=50,
|
||||
description="用户名,只能包含字母、数字、下划线和连字符")
|
||||
name: str = Field(..., min_length=1, max_length=100, description="用户真实姓名")
|
||||
email: Optional[EmailStr] = Field(None, description="用户邮箱(可选)")
|
||||
password: str = Field(..., min_length=6, max_length=128, description="用户密码")
|
||||
|
||||
@field_validator('username')
|
||||
def validate_username(cls, v):
|
||||
if not v.replace('_', '').replace('-', '').isalnum():
|
||||
raise ValueError('用户名只能包含字母、数字、下划线和连字符')
|
||||
return v
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
"""用户登录请求模型"""
|
||||
username_or_email: str = Field(..., min_length=1, description="用户名或邮箱")
|
||||
password: str = Field(..., min_length=1, description="用户密码")
|
||||
remember_me: Optional[bool] = Field(True, description="是否记住登录状态(默认为False)")
|
||||
|
||||
class RefreshRequest(BaseModel):
|
||||
"""令牌刷新请求模型"""
|
||||
auth_key_session_id: Optional[str] = Field(None, description="认证令牌会话ID(可选,用于删除旧的认证会话)")
|
||||
|
||||
class UpdatePasswordRequest(BaseModel):
|
||||
"""更新密码请求模型"""
|
||||
original_password: str = Field(..., min_length=1, description="原密码")
|
||||
password: str = Field(..., min_length=6, max_length=128, description="新密码")
|
||||
|
||||
class UserIDRequest(BaseModel):
|
||||
"""用户ID请求模型"""
|
||||
user_id: str = Field(..., min_length=1, description="传入的用户ID")
|
||||
|
||||
class AuthResponse(BaseModel):
|
||||
"""认证响应模型"""
|
||||
success: bool = Field(..., description="操作是否成功")
|
||||
message: str = Field(..., description="响应消息")
|
||||
auth_key: Optional[str] = Field(None, description="认证令牌")
|
||||
auth_key_session_id: Optional[str] = Field(None, description="认证令牌会话ID")
|
||||
refresh_key: Optional[str] = Field(None, description="刷新令牌")
|
||||
refresh_key_session_id: Optional[str] = Field(None, description="刷新令牌会话ID")
|
||||
|
||||
class SimpleResponse(BaseModel):
|
||||
"""简单响应模型"""
|
||||
success: bool = Field(..., description="操作是否成功")
|
||||
message: str = Field(..., description="响应消息")
|
||||
|
||||
class SystemStatsResponse(BaseModel):
|
||||
"""系统统计响应模型"""
|
||||
system_stats: List[Dict[str, Any]] = Field(..., description="系统统计信息列表,每个元素包含时间戳和统计数据")
|
||||
|
||||
# ============================================================================
|
||||
# FastAPI Application Setup
|
||||
# ============================================================================
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""应用程序生命周期管理"""
|
||||
logger = get_logger()
|
||||
logger.info("Starting FastAPI Authentication Server...")
|
||||
|
||||
# 初始化数据库
|
||||
db = get_database()
|
||||
logger.info("Database initialized")
|
||||
|
||||
# 创建认证管理器
|
||||
auth_manager = AuthManager(db)
|
||||
system_stats_manager = SystemStatsManager(db)
|
||||
avatar_generator = AvatarGenerator()
|
||||
|
||||
app.state.auth_manager = auth_manager
|
||||
logger.info("Auth manager initialized")
|
||||
system_stats_manager.start_polling() # 启动系统统计轮询
|
||||
app.state.system_stats_manager = system_stats_manager
|
||||
logger.info("System stats manager initialized")
|
||||
app.state.avatar_generator = avatar_generator
|
||||
logger.info("Avatar generator initialized")
|
||||
app.state.db = db
|
||||
logger.info("Database connection established")
|
||||
|
||||
yield
|
||||
|
||||
logger.info("Shutting down FastAPI Authentication Server...")
|
||||
|
||||
# 创建FastAPI应用
|
||||
app = FastAPI(
|
||||
title="认证API服务器",
|
||||
description="基于FastAPI的安全认证系统,提供用户注册、登录、令牌刷新和登出功能",
|
||||
version="1.0.0",
|
||||
docs_url="/api/docs",
|
||||
redoc_url="/api/redoc",
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
# 添加CORS中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # 在生产环境中应该限制允许的源
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 获取认证管理器
|
||||
def get_auth_manager() -> AuthManager:
|
||||
return app.state.auth_manager
|
||||
|
||||
def get_system_stats_manager() -> SystemStatsManager:
|
||||
return app.state.system_stats_manager
|
||||
|
||||
def get_avatar_generator() -> AvatarGenerator:
|
||||
return app.state.avatar_generator
|
||||
|
||||
def get_db() -> Database:
|
||||
return app.state.db
|
||||
|
||||
# HTTP Bearer认证方案
|
||||
security = HTTPBearer()
|
||||
|
||||
# 认证依赖项
|
||||
async def verify_auth_token(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
auth_manager: AuthManager = Depends(get_auth_manager)
|
||||
) -> Dict[str, Any]:
|
||||
"""验证Bearer令牌并返回认证会话信息"""
|
||||
try:
|
||||
# Bearer token格式: "auth_key:auth_key_session_id"
|
||||
token_parts = credentials.credentials.split(':')
|
||||
if len(token_parts) != 2:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="无效的Bearer令牌格式,应为 'auth_key:auth_key_session_id'"
|
||||
)
|
||||
|
||||
auth_key, auth_key_session_id = token_parts
|
||||
|
||||
# 获取认证会话
|
||||
auth_manager.db.cleanup_expired_sessions() # 清理过期会话
|
||||
auth_session = auth_manager.db.get_auth_session(auth_key_session_id)
|
||||
if not auth_session:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="认证会话无效或已过期"
|
||||
)
|
||||
|
||||
# 验证认证令牌
|
||||
if not auth_manager.verify_token(auth_key, auth_session['salt'], auth_session['hashed_authkey']):
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="认证令牌无效"
|
||||
)
|
||||
|
||||
# 获取用户信息以获取用户组
|
||||
user_data = auth_manager.db.get_user_by_id(auth_session['user_id'])
|
||||
auth_session['user_group'] = user_data.get('user_group', ['user'])
|
||||
|
||||
# 更新用户最后活跃时间
|
||||
auth_manager.db.update_user_last_active(auth_session['user_id'])
|
||||
|
||||
return auth_session
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
auth_manager.logger.exception(f"Token verification error: {e}")
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="令牌验证失败"
|
||||
)
|
||||
|
||||
# ============================================================================
|
||||
# Cookie Configuration
|
||||
# ============================================================================
|
||||
|
||||
def set_auth_cookies(response: Response, tokens: Dict[str, str]) -> None:
|
||||
"""设置安全的认证cookie"""
|
||||
response.set_cookie(
|
||||
key="refresh_key",
|
||||
value=tokens['refresh_key'],
|
||||
httponly=True,
|
||||
secure=False,
|
||||
samesite="strict",
|
||||
max_age=604800 # 7天
|
||||
)
|
||||
|
||||
response.set_cookie(
|
||||
key="refresh_key_session_id",
|
||||
value=tokens['refresh_key_session_id'],
|
||||
httponly=True,
|
||||
secure=False,
|
||||
samesite="strict",
|
||||
max_age=604800 # 7天
|
||||
)
|
||||
|
||||
def clear_auth_cookies(response: Response) -> None:
|
||||
"""清除认证cookie"""
|
||||
response.delete_cookie(key="refresh_key", httponly=True, secure=True, samesite="strict")
|
||||
response.delete_cookie(key="refresh_key_session_id", httponly=True, secure=True, samesite="strict")
|
||||
|
||||
# ============================================================================
|
||||
# Authentication API Endpoints
|
||||
# ============================================================================
|
||||
|
||||
@app.post("/api/register", response_model=SimpleResponse,
|
||||
tags=["Authentication API Endpoints"], summary="用户注册", description="注册新用户账户")
|
||||
async def register(request: RegisterRequest, auth_manager: AuthManager = Depends(get_auth_manager)):
|
||||
"""
|
||||
用户注册端点
|
||||
|
||||
- **username**: 用户名(必需,只能包含字母、数字、下划线和连字符)
|
||||
- **name**: 用户真实姓名(必需)
|
||||
- **email**: 用户邮箱(可选)
|
||||
- **password**: 用户密码(必需,最少6个字符)
|
||||
|
||||
返回注册是否成功的布尔值
|
||||
"""
|
||||
try:
|
||||
success = auth_manager.register_user(
|
||||
username=request.username,
|
||||
name=request.name,
|
||||
email=request.email,
|
||||
password=request.password
|
||||
)
|
||||
|
||||
if success:
|
||||
return SimpleResponse(success=True, message="用户注册成功")
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="用户注册失败,用户名或邮箱可能已存在"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
auth_manager.logger.exception(f"Register endpoint error: {e}")
|
||||
raise HTTPException(status_code=500, detail="服务器内部错误")
|
||||
|
||||
|
||||
@app.post("/api/login", response_model=AuthResponse,
|
||||
tags=["Authentication API Endpoints"], summary="用户登录", description="用户登录并获取认证令牌")
|
||||
async def login(request: LoginRequest, response: Response,
|
||||
auth_manager: AuthManager = Depends(get_auth_manager)):
|
||||
"""
|
||||
用户登录端点
|
||||
|
||||
- **username_or_email**: 用户名或邮箱(必需)
|
||||
- **password**: 用户密码(必需)
|
||||
|
||||
成功登录后返回认证令牌和刷新令牌,并设置安全cookie
|
||||
"""
|
||||
try:
|
||||
# 认证用户
|
||||
user_data = auth_manager.authenticate_user(
|
||||
username_or_email=request.username_or_email,
|
||||
password=request.password
|
||||
)
|
||||
|
||||
if not user_data:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="用户名/邮箱或密码错误"
|
||||
)
|
||||
|
||||
# 创建认证令牌
|
||||
tokens = auth_manager.create_auth_tokens(user_data['user_id'])
|
||||
if not tokens:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="创建认证令牌失败"
|
||||
)
|
||||
|
||||
# 设置安全cookie
|
||||
if request.remember_me:
|
||||
set_auth_cookies(response, tokens)
|
||||
|
||||
return AuthResponse(
|
||||
success=True,
|
||||
message="登录成功",
|
||||
auth_key=tokens['auth_key'],
|
||||
auth_key_session_id=tokens['auth_key_session_id'],
|
||||
refresh_key=tokens['refresh_key'],
|
||||
refresh_key_session_id=tokens['refresh_key_session_id']
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
auth_manager.logger.exception(f"Login endpoint error: {e}")
|
||||
raise HTTPException(status_code=500, detail="服务器内部错误")
|
||||
|
||||
|
||||
@app.post("/api/refresh", response_model=AuthResponse,
|
||||
tags=["Authentication API Endpoints"], summary="刷新令牌", description="使用刷新令牌获取新的认证令牌")
|
||||
async def refresh_tokens(request: RefreshRequest, response: Response,
|
||||
refresh_key: Optional[str] = Cookie(None),
|
||||
refresh_key_session_id: Optional[str] = Cookie(None),
|
||||
auth_manager: AuthManager = Depends(get_auth_manager)):
|
||||
"""
|
||||
令牌刷新端点
|
||||
|
||||
- **auth_key_session_id**: 认证令牌会话ID(可选,用于删除旧的认证会话)
|
||||
|
||||
refresh_key和refresh_key_session_id从httponly Cookie中自动获取
|
||||
验证刷新令牌后返回新的认证令牌和刷新令牌,并更新cookie
|
||||
"""
|
||||
try:
|
||||
# 检查Cookie中的刷新令牌信息
|
||||
if not refresh_key or not refresh_key_session_id:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="未找到刷新令牌,请重新登录"
|
||||
)
|
||||
|
||||
# 刷新令牌
|
||||
new_tokens = auth_manager.refresh_tokens(
|
||||
refresh_key_session_id=refresh_key_session_id,
|
||||
refresh_key=refresh_key,
|
||||
auth_key_session_id=request.auth_key_session_id
|
||||
)
|
||||
|
||||
if not new_tokens:
|
||||
# 清除cookie
|
||||
clear_auth_cookies(response)
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="刷新令牌无效或已过期"
|
||||
)
|
||||
|
||||
# 设置新的安全cookie
|
||||
set_auth_cookies(response, new_tokens)
|
||||
|
||||
return AuthResponse(
|
||||
success=True,
|
||||
message="令牌刷新成功",
|
||||
auth_key=new_tokens['auth_key'],
|
||||
auth_key_session_id=new_tokens['auth_key_session_id'],
|
||||
refresh_key=new_tokens['refresh_key'],
|
||||
refresh_key_session_id=new_tokens['refresh_key_session_id']
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
auth_manager.logger.exception(f"Refresh endpoint error: {e}")
|
||||
raise HTTPException(status_code=500, detail="服务器内部错误")
|
||||
|
||||
|
||||
@app.post("/api/logout", response_model=SimpleResponse,
|
||||
tags=["Authentication API Endpoints"], summary="登出", description="登出当前会话")
|
||||
async def logout(response: Response,
|
||||
refresh_key_session_id: Optional[str] = Cookie(None),
|
||||
auth_session: Dict[str, Any] = Depends(verify_auth_token),
|
||||
auth_manager: AuthManager = Depends(get_auth_manager)):
|
||||
"""
|
||||
登出端点
|
||||
|
||||
需要在Authorization头中提供Bearer令牌,格式为: Bearer auth_key:auth_key_session_id
|
||||
refresh_key_session_id从httponly Cookie中自动获取
|
||||
|
||||
登出指定的会话并清除相关cookie
|
||||
"""
|
||||
try:
|
||||
# 检查Cookie中的刷新令牌会话ID
|
||||
if not refresh_key_session_id:
|
||||
# 清除cookie(如果有的话)
|
||||
clear_auth_cookies(response)
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="未找到刷新令牌会话ID"
|
||||
)
|
||||
|
||||
success = auth_manager.logout_session(
|
||||
refresh_key_session_id=refresh_key_session_id,
|
||||
auth_key_session_id=auth_session['session_id']
|
||||
)
|
||||
|
||||
# 清除cookie
|
||||
clear_auth_cookies(response)
|
||||
|
||||
if success:
|
||||
return SimpleResponse(success=True, message="登出成功")
|
||||
else:
|
||||
return SimpleResponse(success=False, message="会话不存在或已过期")
|
||||
|
||||
except Exception as e:
|
||||
auth_manager.logger.exception(f"Logout endpoint error: {e}")
|
||||
raise HTTPException(status_code=500, detail="服务器内部错误")
|
||||
|
||||
|
||||
@app.post("/api/logout-all", response_model=SimpleResponse,
|
||||
tags=["Authentication API Endpoints"], summary="登出所有会话", description="登出用户的所有会话")
|
||||
async def logout_all(response: Response,
|
||||
auth_session: Dict[str, Any] = Depends(verify_auth_token),
|
||||
auth_manager: AuthManager = Depends(get_auth_manager)):
|
||||
"""
|
||||
登出所有会话端点
|
||||
|
||||
需要在Authorization头中提供Bearer令牌,格式为: Bearer auth_key:auth_key_session_id
|
||||
|
||||
登出用户的所有会话并清除相关cookie
|
||||
"""
|
||||
try:
|
||||
success = auth_manager.logout_all_sessions(auth_session['user_id'])
|
||||
|
||||
# 清除cookie
|
||||
clear_auth_cookies(response)
|
||||
|
||||
if success:
|
||||
return SimpleResponse(success=True, message="所有会话已登出")
|
||||
else:
|
||||
return SimpleResponse(success=False, message="会话不存在或已过期")
|
||||
|
||||
except Exception as e:
|
||||
auth_manager.logger.exception(f"Logout all endpoint error: {e}")
|
||||
raise HTTPException(status_code=500, detail="服务器内部错误")
|
||||
|
||||
|
||||
@app.get("/api/profile", response_model=Dict[str, Any],
|
||||
tags=["Authentication API Endpoints"], summary="获取用户信息", description="获取当前认证用户的基本信息")
|
||||
async def get_profile(
|
||||
auth_session: Dict[str, Any] = Depends(verify_auth_token),
|
||||
auth_manager: AuthManager = Depends(get_auth_manager)
|
||||
):
|
||||
"""
|
||||
获取用户信息端点
|
||||
|
||||
需要在Authorization头中提供Bearer令牌,格式为: Bearer auth_key:auth_key_session_id
|
||||
|
||||
返回当前认证用户的基本信息
|
||||
"""
|
||||
try:
|
||||
user_data = auth_manager.db.get_user_by_id(auth_session['user_id'])
|
||||
if not user_data:
|
||||
raise HTTPException(status_code=404, detail="用户未找到")
|
||||
|
||||
return {
|
||||
"user_id": user_data['user_id'],
|
||||
"username": user_data['username'],
|
||||
"name": user_data['name'],
|
||||
"email": user_data.get('email', None),
|
||||
'user_group': user_data.get('user_group', ['user']),
|
||||
"is_activated": user_data.get('is_activated', False),
|
||||
"last_active": user_data.get('last_active', None)
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
auth_manager.logger.exception(f"Get profile endpoint error: {e}")
|
||||
raise HTTPException(status_code=500, detail="服务器内部错误")
|
||||
|
||||
|
||||
@app.put("/api/update-password", response_model=SimpleResponse,
|
||||
tags=["Authentication API Endpoints"], summary="更新密码", description="更新当前用户的密码")
|
||||
async def update_password(
|
||||
request: UpdatePasswordRequest,
|
||||
response: Response,
|
||||
auth_session: Dict[str, Any] = Depends(verify_auth_token),
|
||||
auth_manager: AuthManager = Depends(get_auth_manager)
|
||||
):
|
||||
"""
|
||||
更新密码端点
|
||||
|
||||
需要在Authorization头中提供Bearer令牌,格式为: Bearer auth_key:auth_key_session_id
|
||||
|
||||
- **original_password**: 原密码(必需)
|
||||
- **password**: 新密码(必需,最少6个字符)
|
||||
|
||||
验证原密码后更新为新密码,并自动登出所有会话
|
||||
"""
|
||||
try:
|
||||
user_id = auth_session['user_id']
|
||||
|
||||
# 调用AuthManager处理密码更新逻辑
|
||||
result = auth_manager.update_user_password_with_verification(
|
||||
user_id=user_id,
|
||||
original_password=request.original_password,
|
||||
new_password=request.password
|
||||
)
|
||||
|
||||
if result["success"]:
|
||||
# 如果密码更新成功且包含logout_all标志,清除cookie
|
||||
if result.get("logout_all", False):
|
||||
clear_auth_cookies(response)
|
||||
|
||||
return SimpleResponse(success=True, message=result["message"])
|
||||
else:
|
||||
# 根据错误类型返回相应的HTTP状态码
|
||||
if "用户未找到" in result["message"]:
|
||||
raise HTTPException(status_code=404, detail=result["message"])
|
||||
elif "原密码错误" in result["message"]:
|
||||
raise HTTPException(status_code=401, detail=result["message"])
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail=result["message"])
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
auth_manager.logger.exception(f"Update password endpoint error: {e}")
|
||||
raise HTTPException(status_code=500, detail="服务器内部错误")
|
||||
|
||||
@app.delete("/api/delete-user", response_model=SimpleResponse,
|
||||
tags=["Authentication API Endpoints"], summary="删除自己的账户", description="删除当前登录用户的账户")
|
||||
async def delete_own_account(
|
||||
response: Response,
|
||||
auth_session: Dict[str, Any] = Depends(verify_auth_token),
|
||||
auth_manager: AuthManager = Depends(get_auth_manager)
|
||||
):
|
||||
"""
|
||||
删除自己账户端点
|
||||
|
||||
需要在Authorization头中提供Bearer令牌,格式为: Bearer auth_key:auth_key_session_id
|
||||
只能删除当前登录用户本人账户,删除成功后将清除所有认证cookie
|
||||
"""
|
||||
try:
|
||||
user_id = auth_session['user_id']
|
||||
# 调用AuthManager处理自己账户删除逻辑
|
||||
result = auth_manager.delete_own_account(user_id)
|
||||
|
||||
if result["success"]:
|
||||
# 删除成功后清除认证cookie
|
||||
clear_auth_cookies(response)
|
||||
return SimpleResponse(success=True, message=result["message"])
|
||||
else:
|
||||
# 根据错误类型返回相应的HTTP状态码
|
||||
if "未找到" in result["message"]:
|
||||
raise HTTPException(status_code=404, detail=result["message"])
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail=result["message"])
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
auth_manager.logger.exception(f"Delete own account endpoint error: {e}")
|
||||
raise HTTPException(status_code=500, detail="服务器内部错误")
|
||||
|
||||
|
||||
@app.delete("/api/admin/delete-user", response_model=SimpleResponse,
|
||||
tags=["Admin API Endpoints"], summary="管理员删除用户", description="管理员删除指定用户账户")
|
||||
async def admin_delete_user(
|
||||
request: UserIDRequest,
|
||||
auth_session: Dict[str, Any] = Depends(verify_auth_token),
|
||||
auth_manager: AuthManager = Depends(get_auth_manager)
|
||||
):
|
||||
"""
|
||||
管理员删除用户端点
|
||||
|
||||
需要在Authorization头中提供Bearer令牌,格式为: Bearer auth_key:auth_key_session_id
|
||||
需要管理员权限(manager或admin用户组)
|
||||
- **user_id**: 要删除的用户ID(必需)
|
||||
"""
|
||||
try:
|
||||
admin_user_id = auth_session['user_id']
|
||||
target_user_id = request.user_id
|
||||
|
||||
# 调用AuthManager处理管理员删除用户逻辑
|
||||
result = auth_manager.admin_delete_user(
|
||||
admin_user_id=admin_user_id,
|
||||
admin_user_groups=auth_session.get('user_group', ["user"]),
|
||||
target_user_id=target_user_id
|
||||
)
|
||||
|
||||
if result["success"]:
|
||||
return SimpleResponse(success=True, message=result["message"])
|
||||
else:
|
||||
# 根据错误类型返回相应的HTTP状态码
|
||||
if "未找到" in result["message"]:
|
||||
raise HTTPException(status_code=404, detail=result["message"])
|
||||
elif "权限不足" in result["message"]:
|
||||
raise HTTPException(status_code=403, detail=result["message"])
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail=result["message"])
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
auth_manager.logger.exception(f"Admin delete user endpoint error: {e}")
|
||||
raise HTTPException(status_code=500, detail="服务器内部错误")
|
||||
|
||||
# ============================================================================
|
||||
# Avatar Generation Endpoint
|
||||
# ============================================================================
|
||||
|
||||
@app.get("/avatar/{variant}",
|
||||
tags=["Avatar API Endpoints"], summary="生成头像", description="根据变体生成头像")
|
||||
async def generate_avatar_variant_only(
|
||||
variant: str,
|
||||
auth_session: Dict[str, Any] = Depends(verify_auth_token),
|
||||
name: Optional[str] = Query(default="John Doe", description="Name to generate avatar for"),
|
||||
colors: Optional[str] = Query(default=None, description="Comma-separated hex colors"),
|
||||
size: int = Query(default=80, description="Avatar size"),
|
||||
square: bool = Query(default=False, description="Square avatar"),
|
||||
title: bool = Query(default=False, description="Include title element"),
|
||||
|
||||
avatar_generator: AvatarGenerator = Depends(get_avatar_generator),
|
||||
):
|
||||
"""Generate avatar with variant only"""
|
||||
try:
|
||||
if variant not in avatar_generator.VALID_VARIANTS:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid variant. Must be one of: {', '.join(avatar_generator.VALID_VARIANTS)}")
|
||||
|
||||
# Validate and sanitize inputs
|
||||
if not name or not isinstance(name, str):
|
||||
name = auth_session['user_id']
|
||||
|
||||
if not isinstance(size, int) or size < 10 or size > 1000:
|
||||
size = avatar_generator.DEFAULT_SIZE
|
||||
|
||||
color_palette = avatar_generator.normalize_colors(colors)
|
||||
generator = avatar_generator.AVATAR_GENERATORS[variant]
|
||||
svg_content = generator(name, color_palette, size, square, title)
|
||||
|
||||
return Response(
|
||||
content=svg_content,
|
||||
media_type="image/svg+xml",
|
||||
headers={
|
||||
"Cache-Control": "s-max-age=1, stale-while-revalidate",
|
||||
"Content-Type": "image/svg+xml"
|
||||
}
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Avatar generation failed: {str(e)}")
|
||||
|
||||
@app.get("/avatar/{variant}/{size}",
|
||||
tags=["Avatar API Endpoints"], summary="生成头像", description="根据变体生成头像")
|
||||
async def generate_avatar_variant_size(
|
||||
variant: str,
|
||||
size: int,
|
||||
auth_session: Dict[str, Any] = Depends(verify_auth_token),
|
||||
name: Optional[str] = Query(default="Clara Barton", description="Name to generate avatar for"),
|
||||
colors: Optional[str] = Query(default=None, description="Comma-separated hex colors"),
|
||||
square: bool = Query(default=False, description="Square avatar"),
|
||||
title: bool = Query(default=False, description="Include title element"),
|
||||
avatar_generator: AvatarGenerator = Depends(get_avatar_generator)
|
||||
):
|
||||
"""Generate avatar with variant and size"""
|
||||
try:
|
||||
if variant not in avatar_generator.VALID_VARIANTS:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid variant. Must be one of: {', '.join(avatar_generator.VALID_VARIANTS)}")
|
||||
|
||||
# Validate and sanitize inputs
|
||||
if not name or not isinstance(name, str):
|
||||
name = auth_session['user_id']
|
||||
|
||||
if not isinstance(size, int) or size < 10 or size > 1000:
|
||||
size = avatar_generator.DEFAULT_SIZE
|
||||
|
||||
color_palette = avatar_generator.normalize_colors(colors)
|
||||
generator = avatar_generator.AVATAR_GENERATORS[variant]
|
||||
svg_content = generator(name, color_palette, size, square, title)
|
||||
|
||||
return Response(
|
||||
content=svg_content,
|
||||
media_type="image/svg+xml",
|
||||
headers={
|
||||
"Cache-Control": "s-max-age=1, stale-while-revalidate",
|
||||
"Content-Type": "image/svg+xml"
|
||||
}
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Avatar generation failed: {str(e)}")
|
||||
|
||||
@app.get("/avatar/{variant}/{size}/{name}",
|
||||
tags=["Avatar API Endpoints"], summary="生成头像", description="根据变体生成头像")
|
||||
async def generate_avatar_full_path(
|
||||
variant: str,
|
||||
size: int,
|
||||
name: str,
|
||||
colors: Optional[str] = Query(default=None, description="Comma-separated hex colors"),
|
||||
square: bool = Query(default=False, description="Square avatar"),
|
||||
title: bool = Query(default=False, description="Include title element"),
|
||||
avatar_generator: AvatarGenerator = Depends(get_avatar_generator),
|
||||
|
||||
auth_session: Dict[str, Any] = Depends(verify_auth_token),
|
||||
):
|
||||
"""Generate avatar with variant, size, and name in path"""
|
||||
try:
|
||||
if variant not in avatar_generator.VALID_VARIANTS:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid variant. Must be one of: {', '.join(avatar_generator.VALID_VARIANTS)}")
|
||||
|
||||
# Validate and sanitize inputs
|
||||
if not name or not isinstance(name, str):
|
||||
name = auth_session['user_id']
|
||||
|
||||
if not isinstance(size, int) or size < 10 or size > 1000:
|
||||
size = avatar_generator.DEFAULT_SIZE
|
||||
|
||||
color_palette = avatar_generator.normalize_colors(colors)
|
||||
generator = avatar_generator.AVATAR_GENERATORS[variant]
|
||||
svg_content = generator(name, color_palette, size, square, title)
|
||||
|
||||
return Response(
|
||||
content=svg_content,
|
||||
media_type="image/svg+xml",
|
||||
headers={
|
||||
"Cache-Control": "s-max-age=1, stale-while-revalidate",
|
||||
"Content-Type": "image/svg+xml"
|
||||
}
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Avatar generation failed: {str(e)}")
|
||||
|
||||
# ============================================================================
|
||||
# System Statistics API Endpoints
|
||||
# ============================================================================
|
||||
@app.get("/api/system-stats", response_model=SystemStatsResponse,
|
||||
tags=["System Statistics API Endpoints"], summary="获取系统统计信息", description="获取服务器的系统统计信息")
|
||||
async def get_system_stats(
|
||||
auth_session: Dict[str, Any] = Depends(verify_auth_token),
|
||||
time_delta: Optional[int] = Query(default=5, description="时间范围(秒),默认为5秒"),
|
||||
time_interval: Optional[int] = Query(default=1, description="时间间隔(秒),默认为1秒"),
|
||||
system_stats_manager: SystemStatsManager = Depends(get_system_stats_manager)
|
||||
):
|
||||
"""
|
||||
获取系统统计信息端点
|
||||
|
||||
需要在Authorization头中提供Bearer令牌,格式为: Bearer auth_key:auth_key_session_id
|
||||
- **time_delta**: 时间范围(秒),默认为5秒
|
||||
返回服务器的系统统计信息
|
||||
"""
|
||||
try:
|
||||
stats = system_stats_manager.get_recent_stats(
|
||||
auth_session.get('user_group', ['user']),
|
||||
time_delta=time_delta,
|
||||
time_interval=time_interval
|
||||
)
|
||||
if stats["success"]:
|
||||
return stats
|
||||
else:
|
||||
if "未找到" in stats["message"]:
|
||||
raise HTTPException(status_code=404, detail=stats["message"])
|
||||
elif "权限不足" in stats["message"]:
|
||||
raise HTTPException(status_code=403, detail=stats["message"])
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail=stats["message"])
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
system_stats_manager.logger.exception(f"Get system stats endpoint error: {e}")
|
||||
raise HTTPException(status_code=500, detail="服务器内部错误")
|
||||
|
||||
# ============================================================================
|
||||
# Additional Utility API Endpoints
|
||||
# ============================================================================
|
||||
|
||||
@app.get("/api/health", response_model=Dict[str, Any],
|
||||
tags=["Additional Utility API Endpoints"], summary="健康检查", description="检查API服务器状态")
|
||||
async def health_check(db: Database = Depends(get_db)):
|
||||
"""健康检查端点"""
|
||||
try:
|
||||
stats = db.get_database_stats()
|
||||
return {
|
||||
"status": "healthy",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"database_stats": stats
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=503, detail=f"服务不可用: {str(e)}")
|
||||
|
||||
@app.post("/api/clean-cookies", response_model=SimpleResponse,
|
||||
tags=["Additional Utility API Endpoints"], summary="清除认证Cookie", description="清除认证相关的Cookie")
|
||||
async def clean_cookies(response: Response):
|
||||
"""清除认证Cookie端点"""
|
||||
try:
|
||||
clear_auth_cookies(response)
|
||||
return SimpleResponse(success=True, message="认证Cookie已清除")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"清除Cookie失败: {str(e)}")
|
||||
|
||||
# ============================================================================
|
||||
# Static Files Configuration
|
||||
# ============================================================================
|
||||
# 在所有API路由定义之后挂载静态文件,确保API路由有更高优先级
|
||||
|
||||
# Frontend路径配置
|
||||
frontend_path = r"./livestream_site" # 前端静态文件目录
|
||||
app.mount("/", StaticFiles(directory=frontend_path, html=True), name="frontend")
|
||||
|
||||
# ============================================================================
|
||||
# Main Entry Point
|
||||
# ============================================================================
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger = get_logger()
|
||||
logger.info("Starting FastAPI Authentication Server on port 8000...")
|
||||
|
||||
# 检查是否在PyInstaller打包环境中运行
|
||||
import sys
|
||||
if getattr(sys, 'frozen', False):
|
||||
# 在打包环境中,直接传递app对象
|
||||
uvicorn.run(
|
||||
app,
|
||||
host="127.0.0.1",
|
||||
port=8000,
|
||||
reload=False,
|
||||
log_level="info"
|
||||
)
|
||||
else:
|
||||
# 在开发环境中,使用字符串导入(支持热重载)
|
||||
uvicorn.run(
|
||||
"httpbackend_server:app",
|
||||
host="127.0.0.1",
|
||||
port=8000,
|
||||
reload=True,
|
||||
reload_delay=5.0,
|
||||
log_level="info"
|
||||
)
|
||||
|
|
@ -1,80 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
SRS Python Monitor Script
|
||||
This is an example Python script that runs alongside SRS server.
|
||||
It demonstrates how Python processes can be managed by SRS.
|
||||
"""
|
||||
|
||||
import time
|
||||
import signal
|
||||
import sys
|
||||
import argparse
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
class SRSMonitor:
|
||||
def __init__(self, config_file=None, verbose=False):
|
||||
self.running = True
|
||||
self.config_file = config_file
|
||||
self.verbose = verbose
|
||||
|
||||
# Set up logging
|
||||
level = logging.DEBUG if verbose else logging.INFO
|
||||
logging.basicConfig(
|
||||
level=level,
|
||||
format='[%(asctime)s] [Python Monitor] %(levelname)s: %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
)
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# Set up signal handlers
|
||||
signal.signal(signal.SIGTERM, self.signal_handler)
|
||||
signal.signal(signal.SIGINT, self.signal_handler)
|
||||
|
||||
def signal_handler(self, signum, frame):
|
||||
"""Handle shutdown signals from SRS"""
|
||||
self.logger.info(f"Received signal {signum}, shutting down gracefully...")
|
||||
self.running = False
|
||||
|
||||
def run(self):
|
||||
"""Main monitoring loop"""
|
||||
self.logger.info("SRS Python Monitor started")
|
||||
if self.config_file:
|
||||
self.logger.info(f"Using config file: {self.config_file}")
|
||||
|
||||
try:
|
||||
while self.running:
|
||||
# Simulate monitoring work
|
||||
self.logger.debug(f"Monitor heartbeat at {datetime.now()}")
|
||||
|
||||
# You can add your monitoring logic here:
|
||||
# - Check stream status
|
||||
# - Monitor server health
|
||||
# - Send alerts
|
||||
# - Log analytics
|
||||
|
||||
time.sleep(5) # Check every 5 seconds
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in monitor loop: {e}")
|
||||
finally:
|
||||
self.cleanup()
|
||||
|
||||
def cleanup(self):
|
||||
"""Cleanup before shutdown"""
|
||||
self.logger.info("Cleaning up Python Monitor...")
|
||||
# Add any cleanup logic here
|
||||
self.logger.info("Python Monitor stopped")
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='SRS Python Monitor')
|
||||
parser.add_argument('--config', help='Configuration file path')
|
||||
parser.add_argument('--verbose', action='store_true', help='Enable verbose logging')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
monitor = SRSMonitor(args.config, args.verbose)
|
||||
monitor.run()
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
@ -1,7 +1,70 @@
|
|||
# Python packages required for SRS Python addons
|
||||
requests>=2.25.0
|
||||
psutil>=5.8.0
|
||||
pyyaml>=5.4.0
|
||||
websockets>=10.0
|
||||
aiohttp>=3.8.0
|
||||
numpy>=1.20.0
|
||||
aiofiles==23.2.0
|
||||
aiohappyeyeballs==2.6.1
|
||||
aiohttp==3.12.13
|
||||
aiosignal==1.3.2
|
||||
altgraph==0.17.4
|
||||
annotated-types==0.7.0
|
||||
anyio==3.7.1
|
||||
attrs==25.3.0
|
||||
bcrypt==4.3.0
|
||||
blinker==1.9.0
|
||||
certifi==2025.4.26
|
||||
cffi==1.17.1
|
||||
charset-normalizer==3.4.2
|
||||
click==8.2.1
|
||||
colorama==0.4.6
|
||||
cryptography==45.0.3
|
||||
dnspython==2.7.0
|
||||
ecdsa==0.19.1
|
||||
email-validator==2.1.0
|
||||
fastapi==0.104.1
|
||||
Flask==3.1.1
|
||||
Flask-SQLAlchemy==3.1.1
|
||||
frozenlist==1.7.0
|
||||
greenlet==3.2.3
|
||||
h11==0.16.0
|
||||
httpcore==1.0.9
|
||||
httptools==0.6.4
|
||||
httpx==0.25.2
|
||||
idna==3.10
|
||||
iniconfig==2.1.0
|
||||
itsdangerous==2.2.0
|
||||
Jinja2==3.1.2
|
||||
jwt==1.3.1
|
||||
MarkupSafe==3.0.2
|
||||
multidict==6.4.4
|
||||
packaging==25.0
|
||||
passlib==1.7.4
|
||||
pefile==2023.2.7
|
||||
pluggy==1.6.0
|
||||
propcache==0.3.2
|
||||
pyasn1==0.6.1
|
||||
pycparser==2.22
|
||||
pydantic==2.5.0
|
||||
pydantic_core==2.14.1
|
||||
pyinstaller==6.14.1
|
||||
pyinstaller-hooks-contrib==2025.5
|
||||
PyJWT==2.8.0
|
||||
pytest==7.4.3
|
||||
pytest-asyncio==0.21.1
|
||||
python-dotenv==1.0.0
|
||||
python-jose==3.3.0
|
||||
python-multipart==0.0.6
|
||||
pywin32-ctypes==0.2.3
|
||||
PyYAML==6.0.2
|
||||
requests==2.32.4
|
||||
rsa==4.9.1
|
||||
setuptools==78.1.1
|
||||
six==1.17.0
|
||||
sniffio==1.3.1
|
||||
SQLAlchemy==2.0.41
|
||||
starlette==0.27.0
|
||||
typing-inspection==0.4.1
|
||||
typing_extensions==4.14.0
|
||||
urllib3==2.4.0
|
||||
uvicorn==0.24.0
|
||||
watchfiles==1.0.5
|
||||
websockets==15.0.1
|
||||
Werkzeug==3.1.3
|
||||
wheel==0.45.1
|
||||
yarl==1.20.1
|
||||
|
|
|
|||
317
trunk/python/security_module.py
Normal file
317
trunk/python/security_module.py
Normal file
|
|
@ -0,0 +1,317 @@
|
|||
import hashlib
|
||||
import secrets
|
||||
import hmac
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from database import Database
|
||||
from srs_logger import get_logger
|
||||
|
||||
# ============================================================================
|
||||
# Security and Utility Functions
|
||||
# ============================================================================
|
||||
|
||||
class AuthManager:
|
||||
"""认证管理器"""
|
||||
|
||||
def __init__(self, db: Database):
|
||||
self.db = db
|
||||
self.logger = get_logger()
|
||||
|
||||
# Token 配置
|
||||
self.AUTH_TOKEN_EXPIRE_HOURS = 1 # 认证令牌1小时过期
|
||||
self.REFRESH_TOKEN_EXPIRE_DAYS = 7 # 刷新令牌7天过期
|
||||
self.TOKEN_LENGTH = 128 # 令牌长度
|
||||
|
||||
def generate_secure_token(self) -> str:
|
||||
"""生成安全的随机令牌"""
|
||||
return secrets.token_urlsafe(self.TOKEN_LENGTH)
|
||||
|
||||
def generate_salt(self) -> str:
|
||||
"""生成盐值"""
|
||||
return secrets.token_hex(16)
|
||||
|
||||
def hash_password(self, password: str, salt: str) -> str:
|
||||
"""对密码进行加盐哈希"""
|
||||
return hashlib.pbkdf2_hmac('sha256', password.encode(), salt.encode(), 100000).hex()
|
||||
|
||||
def verify_password(self, password: str, salt: str, hashed_password: str) -> bool:
|
||||
"""验证密码"""
|
||||
return hmac.compare_digest(
|
||||
self.hash_password(password, salt),
|
||||
hashed_password
|
||||
)
|
||||
|
||||
def hash_token(self, token: str, salt: str) -> str:
|
||||
"""对令牌进行加盐哈希"""
|
||||
return hashlib.pbkdf2_hmac('sha256', token.encode(), salt.encode(), 100000).hex()
|
||||
|
||||
def verify_token(self, token: str, salt: str, hashed_token: str) -> bool:
|
||||
"""验证令牌"""
|
||||
return hmac.compare_digest(
|
||||
self.hash_token(token, salt),
|
||||
hashed_token
|
||||
)
|
||||
|
||||
def register_user(self, username: str, name: str, email: Optional[str], password: str) -> bool:
|
||||
"""注册新用户"""
|
||||
try:
|
||||
# 生成密码盐值和哈希
|
||||
salt = self.generate_salt()
|
||||
hashed_password = self.hash_password(password, salt)
|
||||
|
||||
# 创建用户
|
||||
user_data = self.db.create_user(
|
||||
username=username,
|
||||
name=name,
|
||||
email=email,
|
||||
hashed_password=hashed_password,
|
||||
salt=salt,
|
||||
user_group=["user"],
|
||||
is_activated=True
|
||||
)
|
||||
|
||||
if user_data:
|
||||
self.logger.info(f"User registered successfully: {username}")
|
||||
return True
|
||||
else:
|
||||
self.logger.warn(f"Failed to register user: {username}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Error during user registration: {e}")
|
||||
return False
|
||||
|
||||
def authenticate_user(self, username_or_email: str, password: str) -> Optional[Dict[str, Any]]:
|
||||
"""认证用户并返回用户数据"""
|
||||
try:
|
||||
# 获取用户信息
|
||||
user_data = self.db.get_user(username_or_email)
|
||||
if not user_data:
|
||||
self.logger.warn(f"User not found: {username_or_email}")
|
||||
return None
|
||||
|
||||
# 检查用户是否激活
|
||||
if not user_data.get('is_activated', False):
|
||||
self.logger.warn(f"User account deactivated: {username_or_email}")
|
||||
return None
|
||||
|
||||
# 验证密码
|
||||
if self.verify_password(password, user_data['salt'], user_data['hashed_password']):
|
||||
self.logger.info(f"User authenticated successfully: {username_or_email}")
|
||||
return user_data
|
||||
else:
|
||||
self.logger.warn(f"Invalid password for user: {username_or_email}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Error during user authentication: {e}")
|
||||
return None
|
||||
|
||||
def create_auth_tokens(self, user_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""为用户创建认证和刷新令牌"""
|
||||
try:
|
||||
# 清理过期会话
|
||||
self.db.cleanup_expired_sessions()
|
||||
|
||||
# 生成认证令牌
|
||||
auth_key = self.generate_secure_token()
|
||||
auth_salt = self.generate_salt()
|
||||
hashed_auth_key = self.hash_token(auth_key, auth_salt)
|
||||
auth_expire_time = datetime.now() + timedelta(hours=self.AUTH_TOKEN_EXPIRE_HOURS)
|
||||
|
||||
# 生成刷新令牌
|
||||
refresh_key = self.generate_secure_token()
|
||||
refresh_salt = self.generate_salt()
|
||||
hashed_refresh_key = self.hash_token(refresh_key, refresh_salt)
|
||||
refresh_expire_time = datetime.now() + timedelta(days=self.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
|
||||
# 创建认证会话
|
||||
auth_session = self.db.create_auth_session(
|
||||
user_id=user_id,
|
||||
hashed_authkey=hashed_auth_key,
|
||||
salt=auth_salt,
|
||||
expire_time=auth_expire_time
|
||||
)
|
||||
|
||||
if not auth_session:
|
||||
self.logger.error(f"Failed to create auth session for user: {user_id}")
|
||||
return None
|
||||
|
||||
# 创建刷新会话
|
||||
refresh_session = self.db.create_refresh_session(
|
||||
user_id=user_id,
|
||||
hashed_refreshkey=hashed_refresh_key,
|
||||
salt=refresh_salt,
|
||||
expire_time=refresh_expire_time
|
||||
)
|
||||
|
||||
if not refresh_session:
|
||||
self.logger.error(f"Failed to create refresh session for user: {user_id}")
|
||||
# 清理已创建的认证会话
|
||||
self.db.delete_auth_session(auth_session['session_id'])
|
||||
return None
|
||||
|
||||
# 更新用户最后活跃时间
|
||||
self.db.update_user_last_active(user_id)
|
||||
|
||||
return {
|
||||
'auth_key': auth_key,
|
||||
'auth_key_session_id': auth_session['session_id'],
|
||||
'refresh_key': refresh_key,
|
||||
'refresh_key_session_id': refresh_session['session_id']
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Error creating auth tokens for user: {user_id}")
|
||||
return None
|
||||
|
||||
def refresh_tokens(self, refresh_key_session_id: str, refresh_key: str,
|
||||
auth_key_session_id: Optional[str] = None) -> Optional[Dict[str, Any]]:
|
||||
"""刷新认证令牌"""
|
||||
try:
|
||||
# 清理过期会话
|
||||
self.db.cleanup_expired_sessions()
|
||||
|
||||
# 验证刷新会话
|
||||
refresh_session = self.db.get_refresh_session(refresh_key_session_id)
|
||||
if not refresh_session:
|
||||
self.logger.warn(f"Invalid or expired refresh session: {refresh_key_session_id}")
|
||||
return None
|
||||
|
||||
# 验证刷新令牌
|
||||
if not self.verify_token(refresh_key, refresh_session['salt'], refresh_session['hashed_refreshkey']):
|
||||
self.logger.warn(f"Invalid refresh token for session: {refresh_key_session_id}")
|
||||
return None
|
||||
|
||||
user_id = refresh_session['user_id']
|
||||
|
||||
# 删除旧的认证会话(如果提供了session_id)
|
||||
if auth_key_session_id:
|
||||
self.db.delete_auth_session(auth_key_session_id)
|
||||
|
||||
# 删除旧的刷新会话
|
||||
self.db.delete_refresh_session(refresh_key_session_id)
|
||||
|
||||
# 创建新的令牌
|
||||
new_tokens = self.create_auth_tokens(user_id)
|
||||
if new_tokens:
|
||||
self.logger.info(f"Tokens refreshed successfully for user: {user_id}")
|
||||
return new_tokens
|
||||
else:
|
||||
self.logger.error(f"Failed to create new tokens during refresh for user: {user_id}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Error during token refresh: {e}")
|
||||
return None
|
||||
|
||||
def logout_session(self, refresh_key_session_id: str, auth_key_session_id: str = None) -> bool:
|
||||
"""登出单个会话,通过删除指定的refresh session和对应的auth session"""
|
||||
try:
|
||||
# 删除指定的刷新会话
|
||||
refresh_success = self.db.delete_refresh_session(refresh_key_session_id)
|
||||
if refresh_success:
|
||||
self.logger.info(f"Refresh session logged out successfully: {refresh_key_session_id}")
|
||||
|
||||
auth_success = self.db.delete_auth_session(auth_key_session_id)
|
||||
if auth_success:
|
||||
self.logger.info(f"Auth session logged out successfully: {auth_key_session_id}")
|
||||
else:
|
||||
self.logger.warn(f"Failed to delete auth session: {auth_key_session_id}")
|
||||
|
||||
return refresh_success and auth_success
|
||||
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Error during session logout: {e}")
|
||||
return False
|
||||
|
||||
def logout_all_sessions(self, user_id: str) -> bool:
|
||||
"""登出用户的所有会话,通过refresh session确定用户"""
|
||||
try:
|
||||
# 删除用户的所有会话
|
||||
success = self.db.delete_user_sessions(user_id)
|
||||
if success:
|
||||
self.logger.info(f"All sessions logged out successfully for user: {user_id}")
|
||||
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Error during logout all sessions: {e}")
|
||||
return False
|
||||
|
||||
def update_user_password_with_verification(self, user_id: str, original_password: str, new_password: str) -> Dict[str, Any]:
|
||||
"""验证原密码后更新用户密码,并登出所有会话"""
|
||||
try:
|
||||
# 获取用户数据
|
||||
user_data = self.db.get_user_by_id(user_id)
|
||||
if not user_data:
|
||||
return {"success": False, "message": "用户未找到"}
|
||||
|
||||
# 验证原密码
|
||||
if not self.verify_password(original_password, user_data['salt'], user_data['hashed_password']):
|
||||
return {"success": False, "message": "原密码错误"}
|
||||
|
||||
# 生成新的盐值和密码哈希
|
||||
new_salt = self.generate_salt()
|
||||
new_hashed_password = self.hash_password(new_password, new_salt)
|
||||
|
||||
# 更新密码
|
||||
updated_user = self.db.update_user_password(user_id, new_hashed_password, new_salt)
|
||||
if not updated_user:
|
||||
return {"success": False, "message": "密码更新失败"}
|
||||
|
||||
# 密码更新成功后,登出用户的所有会话(强制重新登录)
|
||||
logout_success = self.db.delete_user_sessions(user_id)
|
||||
if logout_success:
|
||||
self.logger.info(f"All sessions logged out after password update for user: {user_data['username']}")
|
||||
else:
|
||||
self.logger.warn(f"Failed to logout all sessions after password update for user: {user_data['username']}")
|
||||
|
||||
self.logger.info(f"Password updated successfully for user: {user_data['username']}")
|
||||
return {"success": True, "message": "密码更新成功,请重新登录", "logout_all": True}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Error updating user password: {e}")
|
||||
return {"success": False, "message": "服务器内部错误"}
|
||||
|
||||
def delete_own_account(self, user_id: str) -> Dict[str, Any]:
|
||||
"""删除用户自己的账户"""
|
||||
try:
|
||||
success = self.db.delete_user(user_id)
|
||||
if not success:
|
||||
return {"success": False, "message": "账户删除失败"}
|
||||
|
||||
self.logger.info(f"User-ID: '{user_id}' deleted his/her own account")
|
||||
return {"success": True, "message": "您的账户已成功删除"}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Error deleting own account: {e}")
|
||||
return {"success": False, "message": "服务器内部错误"}
|
||||
|
||||
def admin_delete_user(self, admin_user_id: str, admin_user_groups: list, target_user_id: str, ) -> Dict[str, Any]:
|
||||
"""管理员删除指定用户账户"""
|
||||
try:
|
||||
# 检查管理员权限
|
||||
if not any(group in admin_user_groups for group in ['manager', 'admin']):
|
||||
return {
|
||||
"success": False,
|
||||
"message": "权限不足,只有管理员可以删除其他用户账户"
|
||||
}
|
||||
|
||||
# 获取目标用户数据
|
||||
target_user = self.db.get_user_by_id(target_user_id)
|
||||
if not target_user:
|
||||
return {"success": False, "message": "目标删除的用户未找到"}
|
||||
|
||||
# 执行删除操作
|
||||
success = self.db.delete_user(target_user_id)
|
||||
if not success:
|
||||
return {"success": False, "message": "用户删除失败"}
|
||||
|
||||
self.logger.info(f"User-ID '{target_user_id}' deleted by admin-ID: '{admin_user_id}'")
|
||||
return {"success": True, "message": f"用户 '{target_user_id}' 已成功删除"}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Error in admin delete user: {e}")
|
||||
return {"success": False, "message": "服务器内部错误"}
|
||||
405
trunk/python/srs_logger.py
Normal file
405
trunk/python/srs_logger.py
Normal file
|
|
@ -0,0 +1,405 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
SRS Python Logger Module
|
||||
Provides synchronized logging with SRS main process.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
import argparse
|
||||
|
||||
class SRSLogFormatter(logging.Formatter):
|
||||
"""
|
||||
Custom formatter that matches SRS log format with color support:
|
||||
[2025-05-30 21:54:18.835][WARN][3448][python_analytics] message
|
||||
"""
|
||||
|
||||
# ANSI color codes
|
||||
COLORS = {
|
||||
'RESET': '\033[0m',
|
||||
'RED': '\033[31m',
|
||||
'YELLOW': '\033[33m',
|
||||
'WHITE': '\033[37m',
|
||||
'CYAN': '\033[36m',
|
||||
'GREEN': '\033[32m'
|
||||
}
|
||||
|
||||
def __init__(self, use_colors=True):
|
||||
super().__init__()
|
||||
self.pid = os.getpid()
|
||||
self.session_id = self._generate_session_id()
|
||||
self.use_colors = use_colors and self._supports_color()
|
||||
|
||||
def _supports_color(self) -> bool:
|
||||
"""Check if the terminal supports color output"""
|
||||
import sys
|
||||
|
||||
# Check if we're in a terminal
|
||||
if not hasattr(sys.stdout, 'isatty') or not sys.stdout.isatty():
|
||||
return False
|
||||
|
||||
# Check for Windows terminal support
|
||||
if os.name == 'nt':
|
||||
try:
|
||||
import colorama
|
||||
colorama.init()
|
||||
return True
|
||||
except ImportError:
|
||||
# Try to enable ANSI escape sequence support on Windows
|
||||
try:
|
||||
import ctypes
|
||||
kernel32 = ctypes.windll.kernel32
|
||||
kernel32.SetConsoleMode(kernel32.GetStdHandle(-11), 7)
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
# Unix-like systems usually support colors
|
||||
return True
|
||||
|
||||
def _generate_session_id(self) -> str:
|
||||
"""Generate a meaningful session ID based on the calling script"""
|
||||
import inspect
|
||||
|
||||
# Try to get the main script name
|
||||
main_module = sys.modules.get('__main__')
|
||||
if main_module and hasattr(main_module, '__file__'):
|
||||
script_path = main_module.__file__
|
||||
if script_path:
|
||||
script_name = os.path.splitext(os.path.basename(script_path))[0]
|
||||
return f"python_{script_name}"
|
||||
|
||||
# Fallback to inspect the call stack to find the calling script
|
||||
try:
|
||||
for frame_info in inspect.stack():
|
||||
filename = frame_info.filename
|
||||
if filename and not filename.endswith('srs_logger.py') and not filename.endswith('logging/__init__.py'):
|
||||
script_name = os.path.splitext(os.path.basename(filename))[0]
|
||||
if script_name != '<string>' and script_name != '<stdin>':
|
||||
return f"python_{script_name}"
|
||||
except:
|
||||
pass
|
||||
|
||||
# Final fallback
|
||||
return "python_unknown"
|
||||
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
"""Format log record to match SRS format with color support"""
|
||||
# Get current timestamp with milliseconds
|
||||
now = datetime.now()
|
||||
timestamp = now.strftime("%Y-%m-%d %H:%M:%S.") + f"{now.microsecond // 1000:03d}"
|
||||
|
||||
# Map Python log levels to SRS log levels
|
||||
level_mapping = {
|
||||
'DEBUG': 'TRACE',
|
||||
'INFO': 'INFO',
|
||||
'WARNING': 'WARN',
|
||||
'ERROR': 'ERROR',
|
||||
'CRITICAL': 'ERROR'
|
||||
}
|
||||
|
||||
srs_level = level_mapping.get(record.levelname, 'INFO')
|
||||
|
||||
# Format: [timestamp][level][pid][session] message
|
||||
formatted_msg = f"[{timestamp}][{srs_level}][{self.pid}][{self.session_id}] {record.getMessage()}"
|
||||
|
||||
if record.exc_info:
|
||||
formatted_msg += f"\n{self.formatException(record.exc_info)}"
|
||||
|
||||
# Apply colors if supported and enabled
|
||||
if self.use_colors:
|
||||
if srs_level == 'ERROR':
|
||||
formatted_msg = f"{self.COLORS['RED']}{formatted_msg}{self.COLORS['RESET']}"
|
||||
elif srs_level == 'WARN':
|
||||
formatted_msg = f"{self.COLORS['YELLOW']}{formatted_msg}{self.COLORS['RESET']}"
|
||||
# INFO, TRACE levels remain uncolored (default terminal color)
|
||||
|
||||
return formatted_msg
|
||||
|
||||
class SRSConfigParser:
|
||||
"""Parser for SRS configuration files"""
|
||||
|
||||
@staticmethod
|
||||
def parse_config_file(config_path: str) -> Dict[str, Any]:
|
||||
"""Parse SRS configuration file and extract logging settings"""
|
||||
config = {
|
||||
'log_tank': 'all',
|
||||
'log_file': './objs/srs.log',
|
||||
'log_level': 'info'
|
||||
}
|
||||
|
||||
if not os.path.exists(config_path):
|
||||
return config
|
||||
|
||||
try:
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
|
||||
# Parse srs_log_tank
|
||||
match = re.search(r'srs_log_tank\s+([^;]+);', content)
|
||||
if match:
|
||||
config['log_tank'] = match.group(1).strip()
|
||||
|
||||
# Parse srs_log_file
|
||||
match = re.search(r'srs_log_file\s+([^;]+);', content)
|
||||
if match:
|
||||
config['log_file'] = match.group(1).strip()
|
||||
|
||||
# Parse srs_log_level_v2
|
||||
match = re.search(r'srs_log_level_v2\s+([^;]+);', content)
|
||||
if match:
|
||||
config['log_level'] = match.group(1).strip().lower()
|
||||
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to parse config file {config_path}: {e}", file=sys.stderr)
|
||||
|
||||
return config
|
||||
|
||||
class SRSLogger:
|
||||
"""
|
||||
SRS-compatible Python logger that synchronizes with SRS main process logging
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __new__(cls, config_path: Optional[str] = None):
|
||||
"""Singleton pattern to ensure only one logger instance"""
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, config_path: Optional[str] = None):
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
self._initialized = True
|
||||
self.config_path = config_path
|
||||
self.logger = None
|
||||
self.config = {}
|
||||
self._setup_logger()
|
||||
|
||||
def _setup_logger(self):
|
||||
"""Setup the logger based on SRS configuration"""
|
||||
try:
|
||||
# Parse SRS config if provided
|
||||
if self.config_path and os.path.exists(self.config_path):
|
||||
self.config = SRSConfigParser.parse_config_file(self.config_path)
|
||||
else:
|
||||
# Default configuration
|
||||
self.config = {
|
||||
'log_tank': 'all',
|
||||
'log_file': './objs/srs.log',
|
||||
'log_level': 'info'
|
||||
}
|
||||
|
||||
# Create logger
|
||||
self.logger = logging.getLogger('srs_python')
|
||||
self.logger.setLevel(self._get_python_log_level(self.config['log_level']))
|
||||
|
||||
# Clear existing handlers
|
||||
self.logger.handlers.clear()
|
||||
|
||||
# Setup handlers based on log_tank setting
|
||||
log_tank = self.config['log_tank'].lower()
|
||||
|
||||
if log_tank in ['all', 'console']:
|
||||
# Console handler with colors
|
||||
console_handler = logging.StreamHandler(sys.stdout)
|
||||
console_formatter = SRSLogFormatter(use_colors=True)
|
||||
console_handler.setFormatter(console_formatter)
|
||||
self.logger.addHandler(console_handler)
|
||||
|
||||
if log_tank in ['all', 'file']:
|
||||
# File handler without colors
|
||||
log_file = self.config['log_file']
|
||||
# Create directory if it doesn't exist
|
||||
log_dir = os.path.dirname(log_file)
|
||||
if log_dir and not os.path.exists(log_dir):
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
|
||||
file_handler = logging.FileHandler(log_file, encoding='utf-8')
|
||||
file_formatter = SRSLogFormatter(use_colors=False)
|
||||
file_handler.setFormatter(file_formatter)
|
||||
self.logger.addHandler(file_handler)
|
||||
|
||||
# Prevent propagation to root logger
|
||||
self.logger.propagate = False
|
||||
|
||||
self.info(f"SRS Python logger initialized with config: {self.config}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"ERROR: Failed to setup SRS logger: {e}", file=sys.stderr)
|
||||
# Setup a basic console logger as fallback
|
||||
self._setup_fallback_logger()
|
||||
|
||||
def _setup_fallback_logger(self):
|
||||
"""Setup a basic fallback logger if main setup fails"""
|
||||
self.logger = logging.getLogger('srs_python_fallback')
|
||||
self.logger.setLevel(logging.INFO)
|
||||
|
||||
console_handler = logging.StreamHandler(sys.stderr)
|
||||
formatter = logging.Formatter('%(asctime)s [ERROR] [Python] %(message)s')
|
||||
console_handler.setFormatter(formatter)
|
||||
|
||||
self.logger.addHandler(console_handler)
|
||||
self.logger.propagate = False
|
||||
|
||||
def _get_python_log_level(self, srs_level: str) -> int:
|
||||
"""Convert SRS log level to Python log level"""
|
||||
level_mapping = {
|
||||
'trace': logging.DEBUG,
|
||||
'debug': logging.DEBUG,
|
||||
'info': logging.INFO,
|
||||
'warn': logging.WARNING,
|
||||
'warning': logging.WARNING,
|
||||
'error': logging.ERROR
|
||||
}
|
||||
return level_mapping.get(srs_level.lower(), logging.INFO)
|
||||
|
||||
def trace(self, message: str, *args, **kwargs):
|
||||
"""Log trace message (mapped to DEBUG in Python)"""
|
||||
if self.logger:
|
||||
self.logger.debug(message, *args, **kwargs)
|
||||
|
||||
def debug(self, message: str, *args, **kwargs):
|
||||
"""Log debug message"""
|
||||
if self.logger:
|
||||
self.logger.debug(message, *args, **kwargs)
|
||||
|
||||
def info(self, message: str, *args, **kwargs):
|
||||
"""Log info message"""
|
||||
if self.logger:
|
||||
self.logger.info(message, *args, **kwargs)
|
||||
|
||||
def warn(self, message: str, *args, **kwargs):
|
||||
"""Log warning message"""
|
||||
if self.logger:
|
||||
self.logger.warning(message, *args, **kwargs)
|
||||
|
||||
def warning(self, message: str, *args, **kwargs):
|
||||
"""Log warning message"""
|
||||
if self.logger:
|
||||
self.logger.warning(message, *args, **kwargs)
|
||||
|
||||
def error(self, message: str, *args, **kwargs):
|
||||
"""Log error message"""
|
||||
if self.logger:
|
||||
self.logger.error(message, *args, **kwargs)
|
||||
|
||||
def critical(self, message: str, *args, **kwargs):
|
||||
"""Log critical message (mapped to ERROR in SRS)"""
|
||||
if self.logger:
|
||||
self.logger.critical(message, *args, **kwargs)
|
||||
|
||||
def exception(self, message: str, *args, **kwargs):
|
||||
"""Log exception with traceback"""
|
||||
if self.logger:
|
||||
self.logger.exception(message, *args, **kwargs)
|
||||
|
||||
# Global logger instance
|
||||
_global_logger: Optional[SRSLogger] = None
|
||||
|
||||
def get_logger(config_path: Optional[str] = None) -> SRSLogger:
|
||||
"""
|
||||
Get the global SRS logger instance
|
||||
|
||||
Args:
|
||||
config_path: Path to SRS configuration file
|
||||
|
||||
Returns:
|
||||
SRSLogger instance
|
||||
"""
|
||||
global _global_logger
|
||||
|
||||
if _global_logger is None:
|
||||
_global_logger = SRSLogger(config_path)
|
||||
|
||||
return _global_logger
|
||||
|
||||
def init_logger(config_path: Optional[str] = None) -> SRSLogger:
|
||||
"""
|
||||
Initialize the SRS logger with configuration
|
||||
|
||||
Args:
|
||||
config_path: Path to SRS configuration file
|
||||
|
||||
Returns:
|
||||
SRSLogger instance
|
||||
"""
|
||||
return get_logger(config_path)
|
||||
|
||||
# Convenience functions for direct logging
|
||||
def trace(message: str, *args, **kwargs):
|
||||
"""Log trace message"""
|
||||
get_logger().trace(message, *args, **kwargs)
|
||||
|
||||
def debug(message: str, *args, **kwargs):
|
||||
"""Log debug message"""
|
||||
get_logger().debug(message, *args, **kwargs)
|
||||
|
||||
def info(message: str, *args, **kwargs):
|
||||
"""Log info message"""
|
||||
get_logger().info(message, *args, **kwargs)
|
||||
|
||||
def warn(message: str, *args, **kwargs):
|
||||
"""Log warning message"""
|
||||
get_logger().warn(message, *args, **kwargs)
|
||||
|
||||
def warning(message: str, *args, **kwargs):
|
||||
"""Log warning message"""
|
||||
get_logger().warning(message, *args, **kwargs)
|
||||
|
||||
def error(message: str, *args, **kwargs):
|
||||
"""Log error message"""
|
||||
get_logger().error(message, *args, **kwargs)
|
||||
|
||||
def critical(message: str, *args, **kwargs):
|
||||
"""Log critical message"""
|
||||
get_logger().critical(message, *args, **kwargs)
|
||||
|
||||
def exception(message: str, *args, **kwargs):
|
||||
"""Log exception with traceback"""
|
||||
get_logger().exception(message, *args, **kwargs)
|
||||
|
||||
def log_error_and_exit(message: str, exit_code: int = 1):
|
||||
"""
|
||||
Log an error message and exit the process
|
||||
Used when Python process encounters fatal errors
|
||||
"""
|
||||
error(f"FATAL: {message}")
|
||||
sys.exit(exit_code)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Test the logger
|
||||
parser = argparse.ArgumentParser(description="SRS Python Logger Test")
|
||||
parser.add_argument("--config", help="Path to SRS config file")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Initialize logger
|
||||
logger = init_logger(args.config)
|
||||
|
||||
# Test different log levels
|
||||
logger.trace("This is a trace message")
|
||||
logger.debug("This is a debug message")
|
||||
logger.info("SRS Python logger test started")
|
||||
logger.warn("This is a warning message")
|
||||
logger.error("This is an error message")
|
||||
|
||||
# Test exception logging
|
||||
try:
|
||||
raise ValueError("Test exception")
|
||||
except Exception:
|
||||
logger.exception("Caught test exception")
|
||||
|
||||
logger.info("SRS Python logger test completed")
|
||||
146
trunk/python/srs_logger_example.py
Normal file
146
trunk/python/srs_logger_example.py
Normal file
|
|
@ -0,0 +1,146 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Example usage of SRS Logger module
|
||||
This file demonstrates how to import and use the SRS logger in other Python files.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
import time
|
||||
import traceback
|
||||
|
||||
# Import the SRS logger module
|
||||
from srs_logger import get_logger, init_logger, log_error_and_exit
|
||||
import srs_logger
|
||||
|
||||
def example_basic_usage():
|
||||
"""Basic usage example"""
|
||||
logger = get_logger()
|
||||
|
||||
logger.info("Starting basic usage example")
|
||||
logger.debug("This is a debug message")
|
||||
logger.warn("This is a warning message")
|
||||
logger.info("Basic usage example completed")
|
||||
|
||||
def example_with_config():
|
||||
"""Example with config file parsing"""
|
||||
# This would typically be passed via --config argument
|
||||
config_path = "./conf/console.conf"
|
||||
|
||||
logger = init_logger(config_path)
|
||||
logger.info(f"Logger initialized with config file: {config_path}")
|
||||
|
||||
# Demonstrate various log levels
|
||||
logger.trace("Trace level message")
|
||||
logger.debug("Debug level message")
|
||||
logger.info("Info level message")
|
||||
logger.warning("Warning level message")
|
||||
logger.error("Error level message")
|
||||
|
||||
def example_error_handling():
|
||||
"""Example of error handling and logging"""
|
||||
logger = get_logger()
|
||||
|
||||
try:
|
||||
# Simulate some work that might fail
|
||||
logger.info("Starting risky operation")
|
||||
|
||||
# Simulate an error condition
|
||||
if True: # This would be your actual error condition
|
||||
raise RuntimeError("Simulated error in Python process")
|
||||
|
||||
logger.info("Risky operation completed successfully")
|
||||
|
||||
except Exception as e:
|
||||
# Log the error with full traceback
|
||||
logger.exception(f"Error in risky operation: {e}")
|
||||
|
||||
# For fatal errors, you can use log_error_and_exit
|
||||
# log_error_and_exit(f"Fatal error occurred: {e}")
|
||||
|
||||
# Or just log and continue
|
||||
logger.error(f"Continuing after error: {e}")
|
||||
|
||||
def example_convenience_functions():
|
||||
"""Example using convenience functions"""
|
||||
# These functions use the global logger instance
|
||||
srs_logger.info("Using convenience function for info")
|
||||
srs_logger.warn("Using convenience function for warning")
|
||||
srs_logger.error("Using convenience function for error")
|
||||
|
||||
def analytics_simulation():
|
||||
"""Simulate analytics processing with logging"""
|
||||
logger = get_logger()
|
||||
|
||||
logger.info("Analytics process started")
|
||||
|
||||
try:
|
||||
# Simulate analytics work
|
||||
for i in range(5):
|
||||
logger.debug(f"Processing analytics batch {i+1}/5")
|
||||
time.sleep(0.5) # Simulate processing time
|
||||
|
||||
if i == 3: # Simulate a warning condition
|
||||
logger.warn(f"High CPU usage detected during batch {i+1}")
|
||||
|
||||
logger.info("Analytics processing completed successfully")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.warn("Analytics process interrupted by user")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.exception(f"Analytics process failed: {e}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def main():
|
||||
"""Main function demonstrating different usage patterns"""
|
||||
parser = argparse.ArgumentParser(description="SRS Logger Usage Examples")
|
||||
parser.add_argument("--config", help="Path to SRS config file")
|
||||
parser.add_argument("--port", type=int, default=8888, help="Port number (example parameter)")
|
||||
parser.add_argument("--example", choices=['basic', 'config', 'error', 'convenience', 'analytics'],
|
||||
default='analytics', help="Example to run")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
# Initialize logger with config if provided
|
||||
if args.config:
|
||||
logger = init_logger(args.config)
|
||||
logger.info(f"SRS Logger example started with config: {args.config}")
|
||||
else:
|
||||
logger = get_logger()
|
||||
logger.info("SRS Logger example started without config file")
|
||||
|
||||
logger.info(f"Example parameters: port={args.port}, example={args.example}")
|
||||
|
||||
# Run the selected example
|
||||
if args.example == 'basic':
|
||||
example_basic_usage()
|
||||
elif args.example == 'config':
|
||||
example_with_config()
|
||||
elif args.example == 'error':
|
||||
example_error_handling()
|
||||
elif args.example == 'convenience':
|
||||
example_convenience_functions()
|
||||
elif args.example == 'analytics':
|
||||
success = analytics_simulation()
|
||||
if not success:
|
||||
log_error_and_exit("Analytics simulation failed")
|
||||
|
||||
logger.info("SRS Logger example completed successfully")
|
||||
|
||||
except Exception as e:
|
||||
# Use the logger for any unhandled exceptions
|
||||
if 'logger' in locals():
|
||||
logger.exception(f"Unhandled exception in main: {e}")
|
||||
else:
|
||||
print(f"FATAL ERROR: {e}", file=sys.stderr)
|
||||
traceback.print_exc()
|
||||
|
||||
sys.exit(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
402
trunk/python/system_stats_module.py
Normal file
402
trunk/python/system_stats_module.py
Normal file
|
|
@ -0,0 +1,402 @@
|
|||
from database import Database
|
||||
from srs_logger import get_logger
|
||||
import requests
|
||||
import time
|
||||
import threading
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
# ============================================================================
|
||||
# System Statistics Management Module
|
||||
# ============================================================================
|
||||
|
||||
class SystemStatsManager:
|
||||
"""系统统计管理器"""
|
||||
|
||||
def __init__(self, db: Database):
|
||||
self.db = db
|
||||
self.logger = get_logger()
|
||||
self.api_url = "http://python_stats:wMePq3ahpoLRzgsVg7BY9eE82uuJHT0YukD2ZE1JfMY2RjP4e6QnUaKg3V9x5s9M@localhost:1985/api/v1/summaries"
|
||||
self.previous_data: Optional[Dict[str, Any]] = None
|
||||
self.previous_timestamp: Optional[float] = None
|
||||
self.is_running = False
|
||||
self.poll_thread: Optional[threading.Thread] = None
|
||||
self.cleanup_thread: Optional[threading.Thread] = None
|
||||
|
||||
# 启动定期清理过期数据的线程
|
||||
self._start_cleanup_thread()
|
||||
|
||||
def _fetch_system_stats(self):
|
||||
"""获取系统统计数据并插入数据库"""
|
||||
max_retries = 5
|
||||
retry_count = 0
|
||||
|
||||
while retry_count < max_retries:
|
||||
try:
|
||||
# 发送HTTP请求获取SRS统计数据
|
||||
response = requests.get(self.api_url, timeout=10)
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
|
||||
# 检查返回状态
|
||||
if data.get('code') != 0:
|
||||
self.logger.error(f"SRS API returned error code: {data.get('code')}")
|
||||
retry_count += 1
|
||||
time.sleep(2 ** retry_count) # 指数退避
|
||||
continue
|
||||
|
||||
# 提取有用的信息
|
||||
stats_data = self._extract_stats_data(data)
|
||||
if stats_data:
|
||||
# 插入数据库
|
||||
success = self.db.insert_system_stats(**stats_data)
|
||||
if success:
|
||||
self.logger.debug("Successfully inserted system stats")
|
||||
return True
|
||||
else:
|
||||
self.logger.error("Failed to insert system stats to database")
|
||||
retry_count += 1
|
||||
else:
|
||||
self.logger.error("Failed to extract stats data")
|
||||
retry_count += 1
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
self.logger.error(f"HTTP request failed: {e}")
|
||||
retry_count += 1
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Unexpected error in fetch_system_stats: {e}")
|
||||
retry_count += 1
|
||||
|
||||
if retry_count < max_retries:
|
||||
time.sleep(2 ** retry_count) # 指数退避重试
|
||||
|
||||
self.logger.error(f"Failed to fetch system stats after {max_retries} retries")
|
||||
return False
|
||||
|
||||
def _extract_stats_data(self, api_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""从API返回数据中提取有用的统计信息"""
|
||||
try:
|
||||
data_section = api_data.get('data', {})
|
||||
self_section = data_section.get('self', {})
|
||||
system_section = data_section.get('system', {})
|
||||
|
||||
current_timestamp = data_section.get('now_ms', 0) / 1000.0 # 转换为秒
|
||||
|
||||
# SRS相关数据
|
||||
srs_uptime = self_section.get('srs_uptime', 0.0)
|
||||
srs_cpu_percent = self_section.get('cpu_percent', 0.0)
|
||||
srs_memory_percent = self_section.get('mem_percent', 0.0)
|
||||
|
||||
# 网络流量数据(需要计算KBps)
|
||||
srs_recv_bytes = system_section.get('srs_recv_bytes', 0)
|
||||
srs_send_bytes = system_section.get('srs_send_bytes', 0)
|
||||
srs_sample_time = system_section.get('srs_sample_time', 0) / 1000.0 # 转换为秒
|
||||
|
||||
# 计算KBps(需要与上一次的数据比较)
|
||||
srs_recv_KBps = 0.0
|
||||
srs_send_KBps = 0.0
|
||||
|
||||
if (self.previous_data and
|
||||
self.previous_timestamp and
|
||||
current_timestamp > self.previous_timestamp):
|
||||
|
||||
time_diff = current_timestamp - self.previous_timestamp
|
||||
prev_recv = self.previous_data.get('srs_recv_bytes', 0)
|
||||
prev_send = self.previous_data.get('srs_send_bytes', 0)
|
||||
|
||||
if time_diff > 0:
|
||||
# 计算字节差值并转换为KBps
|
||||
srs_recv_KBps = max(0, (srs_recv_bytes - prev_recv) / time_diff / 1024)
|
||||
srs_send_KBps = max(0, (srs_send_bytes - prev_send) / time_diff / 1024)
|
||||
|
||||
# print(srs_recv_KBps, srs_send_KBps) # DEBUG
|
||||
|
||||
# 磁盘IO数据
|
||||
disk_read_KBps = system_section.get('disk_read_KBps', 0.0)
|
||||
disk_write_KBps = system_section.get('disk_write_KBps', 0.0)
|
||||
|
||||
# 操作系统相关数据
|
||||
os_uptime = system_section.get('uptime', 0.0)
|
||||
os_cpu_percent = system_section.get('cpu_percent', 0.0)
|
||||
os_memory_percent = system_section.get('mem_ram_percent', 0.0)
|
||||
|
||||
# 更新上一次的数据
|
||||
self.previous_data = {
|
||||
'srs_recv_bytes': srs_recv_bytes,
|
||||
'srs_send_bytes': srs_send_bytes,
|
||||
'srs_sample_time': srs_sample_time
|
||||
}
|
||||
self.previous_timestamp = current_timestamp
|
||||
|
||||
return {
|
||||
'srs_uptime': srs_uptime,
|
||||
'srs_cpu_percent': srs_cpu_percent,
|
||||
'srs_memory_percent': srs_memory_percent,
|
||||
'srs_recv_KBps': srs_recv_KBps,
|
||||
'srs_send_KBps': srs_send_KBps,
|
||||
'disk_read_KBps': disk_read_KBps,
|
||||
'disk_write_KBps': disk_write_KBps,
|
||||
'os_uptime': os_uptime,
|
||||
'os_cpu_percent': os_cpu_percent,
|
||||
'os_memory_percent': os_memory_percent
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Failed to extract stats data: {e}")
|
||||
return None
|
||||
|
||||
def start_polling(self, interval: int = 3):
|
||||
"""启动长轮询统计数据收集"""
|
||||
if self.is_running:
|
||||
self.logger.warn("Polling is already running")
|
||||
return
|
||||
|
||||
self.is_running = True
|
||||
self.poll_thread = threading.Thread(target=self._polling_loop, args=(interval,))
|
||||
self.poll_thread.daemon = True
|
||||
self.poll_thread.start()
|
||||
self.logger.info(f"Started system stats polling with {interval}s interval")
|
||||
|
||||
def stop_polling(self):
|
||||
"""停止长轮询"""
|
||||
if not self.is_running:
|
||||
self.logger.warn("Polling is not running")
|
||||
return
|
||||
|
||||
self.is_running = False
|
||||
if self.poll_thread:
|
||||
self.poll_thread.join(timeout=5)
|
||||
self.logger.info("Stopped system stats polling")
|
||||
|
||||
def _polling_loop(self, interval: int):
|
||||
"""轮询循环"""
|
||||
while self.is_running:
|
||||
try:
|
||||
self._fetch_system_stats()
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Error in polling loop: {e}")
|
||||
|
||||
# 等待指定间隔,但允许提前停止
|
||||
for _ in range(interval):
|
||||
if not self.is_running:
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
def _start_cleanup_thread(self):
|
||||
"""启动清理过期数据的线程"""
|
||||
self.cleanup_thread = threading.Thread(target=self._cleanup_loop)
|
||||
self.cleanup_thread.daemon = True
|
||||
self.cleanup_thread.start()
|
||||
self.logger.info("Started system stats cleanup thread (5 minute interval)")
|
||||
|
||||
def _cleanup_loop(self):
|
||||
"""定期清理过期数据的循环"""
|
||||
while True:
|
||||
try:
|
||||
# 执行清理
|
||||
success = self.db.delete_expired_system_stats()
|
||||
if success:
|
||||
self.logger.debug("Successfully cleaned up expired system stats")
|
||||
else:
|
||||
self.logger.warn("Failed to clean up expired system stats")
|
||||
|
||||
time.sleep(300)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Error in cleanup loop: {e}")
|
||||
# 发生错误后等待一分钟再继续
|
||||
time.sleep(60)
|
||||
|
||||
def _process_stats_data(self, raw_stats: list, time_delta: int, time_interval: int) -> list:
|
||||
"""
|
||||
处理原始统计数据,按指定时间间隔进行插值和过滤
|
||||
|
||||
Args:
|
||||
raw_stats: 原始统计数据列表
|
||||
time_delta: 时间范围(秒)
|
||||
time_interval: 时间间隔(秒)
|
||||
|
||||
Returns:
|
||||
处理后的统计数据列表
|
||||
"""
|
||||
try:
|
||||
if not raw_stats:
|
||||
return []
|
||||
|
||||
# 参数验证
|
||||
if time_interval <= 0:
|
||||
self.logger.warn("Invalid time_interval, using 1 second")
|
||||
time_interval = 1
|
||||
if time_delta <= 0:
|
||||
self.logger.warn("Invalid time_delta, using 60 seconds")
|
||||
time_delta = 60
|
||||
|
||||
# 转换时间戳并排序
|
||||
for stat in raw_stats:
|
||||
if isinstance(stat['timestamp'], str):
|
||||
stat['timestamp'] = datetime.fromisoformat(stat['timestamp'])
|
||||
elif not isinstance(stat['timestamp'], datetime):
|
||||
# 如果是其他格式,尝试解析
|
||||
stat['timestamp'] = datetime.fromisoformat(str(stat['timestamp']))
|
||||
|
||||
# 按时间戳排序
|
||||
raw_stats.sort(key=lambda x: x['timestamp'])
|
||||
|
||||
# 确定时间范围
|
||||
now = datetime.now().replace(microsecond=0) # 取整到秒
|
||||
start_time = now - timedelta(seconds=time_delta)
|
||||
|
||||
# 过滤掉超出时间范围的数据
|
||||
filtered_stats = [stat for stat in raw_stats if stat['timestamp'] >= start_time]
|
||||
|
||||
if not filtered_stats:
|
||||
self.logger.warn("No data points in specified time range")
|
||||
return []
|
||||
|
||||
# 生成目标时间点列表(从最近一次记录向前)
|
||||
target_times = []
|
||||
current_time = now
|
||||
while current_time >= start_time:
|
||||
target_times.append(current_time)
|
||||
current_time -= timedelta(seconds=time_interval)
|
||||
|
||||
# 反转列表,使其按时间顺序排列
|
||||
target_times.reverse()
|
||||
|
||||
# 对每个目标时间点进行插值
|
||||
processed_stats = []
|
||||
numeric_fields = [
|
||||
'srs_uptime', 'srs_cpu_percent', 'srs_memory_percent',
|
||||
'srs_recv_KBps', 'srs_send_KBps', 'disk_read_KBps',
|
||||
'disk_write_KBps', 'os_uptime', 'os_cpu_percent', 'os_memory_percent'
|
||||
]
|
||||
|
||||
for target_time in target_times:
|
||||
interpolated_data = self._interpolate_data_point(filtered_stats, target_time, numeric_fields)
|
||||
if interpolated_data:
|
||||
processed_stats.append(interpolated_data)
|
||||
|
||||
self.logger.debug(f"Processed {len(raw_stats)} raw points into {len(processed_stats)} interpolated points")
|
||||
return processed_stats
|
||||
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Error processing stats data: {e}")
|
||||
return raw_stats # 如果处理失败,返回原始数据
|
||||
|
||||
def _interpolate_data_point(self, raw_stats: list, target_time: datetime, numeric_fields: list) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
为指定时间点插值数据
|
||||
|
||||
Args:
|
||||
raw_stats: 原始统计数据列表(已排序)
|
||||
target_time: 目标时间点
|
||||
numeric_fields: 需要插值的数值字段列表
|
||||
|
||||
Returns:
|
||||
插值后的数据点
|
||||
"""
|
||||
try:
|
||||
# 查找最接近的前后两个数据点
|
||||
before_point = None
|
||||
after_point = None
|
||||
|
||||
for i, stat in enumerate(raw_stats):
|
||||
stat_time = stat['timestamp']
|
||||
|
||||
if stat_time <= target_time:
|
||||
before_point = stat
|
||||
elif stat_time > target_time:
|
||||
after_point = stat
|
||||
break
|
||||
|
||||
# 如果没有找到合适的数据点,使用最近的点
|
||||
if before_point is None and after_point is None:
|
||||
return None
|
||||
elif before_point is None:
|
||||
# 只有后面的点,直接使用
|
||||
result = after_point.copy()
|
||||
result['timestamp'] = target_time.replace(microsecond=0)
|
||||
return result
|
||||
elif after_point is None:
|
||||
# 只有前面的点,直接使用
|
||||
result = before_point.copy()
|
||||
result['timestamp'] = target_time.replace(microsecond=0)
|
||||
return result
|
||||
|
||||
# 计算时间差和权重
|
||||
before_time = before_point['timestamp']
|
||||
after_time = after_point['timestamp']
|
||||
|
||||
# 如果时间点完全匹配,直接返回
|
||||
if before_time == target_time:
|
||||
return before_point.copy()
|
||||
if after_time == target_time:
|
||||
return after_point.copy()
|
||||
|
||||
# 计算插值权重
|
||||
total_diff = (after_time - before_time).total_seconds()
|
||||
if total_diff == 0:
|
||||
# 两个点时间相同,使用前一个点
|
||||
result = before_point.copy()
|
||||
result['timestamp'] = target_time.replace(microsecond=0)
|
||||
return result
|
||||
|
||||
target_diff = (target_time - before_time).total_seconds()
|
||||
weight = target_diff / total_diff
|
||||
|
||||
# 执行线性插值
|
||||
result = {'timestamp': target_time.replace(microsecond=0)} # 确保时间戳取整到秒
|
||||
|
||||
for field in numeric_fields:
|
||||
if field in before_point and field in after_point:
|
||||
before_val = float(before_point[field]) if before_point[field] is not None else 0.0
|
||||
after_val = float(after_point[field]) if after_point[field] is not None else 0.0
|
||||
|
||||
# 线性插值
|
||||
interpolated_val = before_val + (after_val - before_val) * weight
|
||||
result[field] = round(interpolated_val, 4) # 保留4位小数
|
||||
else:
|
||||
# 如果字段不存在,使用前一个点的值
|
||||
result[field] = before_point.get(field, 0.0)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Error interpolating data point: {e}")
|
||||
return None
|
||||
|
||||
def get_recent_stats(self, user_group: list, time_delta: int = 5, time_interval: int = 1):
|
||||
"""获取最近的统计数据"""
|
||||
try:
|
||||
if not any(group in user_group for group in ['streamer', 'manager', 'admin']):
|
||||
return {"success": False, "message": "权限不足,无法获取统计数据"}
|
||||
|
||||
# 获取原始数据,多获取一些以便插值
|
||||
raw_stats = self.db.get_system_stats(time_delta + 5) # 多获取60秒数据用于插值
|
||||
if not raw_stats:
|
||||
return {"success": False, "message": "没有找到相关的统计数据"}
|
||||
|
||||
# 处理数据,按指定间隔进行插值和过滤
|
||||
processed_stats = self._process_stats_data(raw_stats, time_delta, time_interval)
|
||||
|
||||
# 转换时间戳为字符串格式,便于JSON序列化
|
||||
for stat in processed_stats:
|
||||
if isinstance(stat['timestamp'], datetime):
|
||||
stat['timestamp'] = stat['timestamp'].isoformat()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"system_stats": processed_stats,
|
||||
"metadata": {
|
||||
"time_delta": time_delta,
|
||||
"time_interval": time_interval,
|
||||
"data_points": len(processed_stats),
|
||||
"original_points": len(raw_stats)
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Error in get_recent_stats: {e}")
|
||||
return {"success": False, "message": "服务器内部错误"}
|
||||
0
trunk/python/test_stats_processing.py
Normal file
0
trunk/python/test_stats_processing.py
Normal file
|
|
@ -1,52 +0,0 @@
|
|||
#!/usr/bin/python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import sys, shlex, time, subprocess
|
||||
|
||||
cmd = "./python.subprocess 0 80000"
|
||||
args = shlex.split(str(cmd))
|
||||
print "cmd: %s, args: %s"%(cmd, args)
|
||||
|
||||
# the communicate will read all data and wait for sub process to quit.
|
||||
def use_communicate(args, fout, ferr):
|
||||
process = subprocess.Popen(args, stdout=fout, stderr=ferr)
|
||||
(stdout_str, stderr_str) = process.communicate()
|
||||
return (stdout_str, stderr_str)
|
||||
|
||||
# if use subprocess.PIPE, the pipe will full about 50KB data,
|
||||
# and sub process will blocked, then timeout will kill it.
|
||||
def use_poll(args, fout, ferr, timeout):
|
||||
(stdout_str, stderr_str) = (None, None)
|
||||
process = subprocess.Popen(args, stdout=fout, stderr=ferr)
|
||||
starttime = time.time()
|
||||
while True:
|
||||
process.poll()
|
||||
if process.returncode is not None:
|
||||
(stdout_str, stderr_str) = process.communicate()
|
||||
break
|
||||
if timeout > 0 and time.time() - starttime >= timeout:
|
||||
print "timeout, kill process. timeout=%s"%(timeout)
|
||||
process.kill()
|
||||
break
|
||||
time.sleep(1)
|
||||
process.wait()
|
||||
return (stdout_str, stderr_str)
|
||||
|
||||
# stdout/stderr can be fd, fileobject, subprocess.PIPE, None
|
||||
fnull = open("/dev/null", "rw")
|
||||
fout = fnull.fileno()#subprocess.PIPE#fnull#fnull.fileno()
|
||||
ferr = fnull.fileno()#subprocess.PIPE#fnull#fnull.fileno()
|
||||
print "fout=%s, ferr=%s"%(fout, ferr)
|
||||
#(stdout_str, stderr_str) = use_communicate(args, fout, ferr)
|
||||
(stdout_str, stderr_str) = use_poll(args, fout, ferr, 10)
|
||||
|
||||
def print_result(stdout_str, stderr_str):
|
||||
if stdout_str is None:
|
||||
stdout_str = ""
|
||||
if stderr_str is None:
|
||||
stderr_str = ""
|
||||
print "terminated, size of stdout=%s, stderr=%s"%(len(stdout_str), len(stderr_str))
|
||||
while True:
|
||||
time.sleep(1)
|
||||
|
||||
print_result(stdout_str, stderr_str)
|
||||
|
|
@ -1,32 +0,0 @@
|
|||
#include <stdio.h>
|
||||
#include <unistd.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
/**
|
||||
# always to print to stdout and stderr.
|
||||
g++ python.subprocess.cpp -o python.subprocess
|
||||
*/
|
||||
int main(int argc, char** argv) {
|
||||
if (argc <= 2) {
|
||||
printf("Usage: <%s> <interval_ms> <max_loop>\n"
|
||||
" %s 50 100000\n", argv[0], argv[0]);
|
||||
exit(-1);
|
||||
return -1;
|
||||
}
|
||||
|
||||
int interval_ms = ::atoi(argv[1]);
|
||||
int max_loop = ::atoi(argv[2]);
|
||||
printf("always to print to stdout and stderr.\n");
|
||||
printf("interval: %d ms\n", interval_ms);
|
||||
printf("max_loop: %d\n", max_loop);
|
||||
|
||||
for (int i = 0; i < max_loop; i++) {
|
||||
fprintf(stdout, "always to print to stdout and stderr. interval=%dms, max=%d, current=%d\n", interval_ms, max_loop, i);
|
||||
fprintf(stderr, "always to print to stdout and stderr. interval=%dms, max=%d, current=%d\n", interval_ms, max_loop, i);
|
||||
if (interval_ms > 0) {
|
||||
usleep(interval_ms * 1000);
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user