diff --git a/trunk/conf/console.conf b/trunk/conf/console.conf index 7a4eb3108..20ca385f9 100644 --- a/trunk/conf/console.conf +++ b/trunk/conf/console.conf @@ -1,37 +1,186 @@ -# no-daemon and write log to console config for srs. -# @see full.conf for detail config. +# docker config for srs. +# @see full.conf for detail config explanation. +############################################################################################# +# Last modified: 2025-05-10 Jason Yang +# Contributor: SRS Team, Jason Yang, Jasper, HyperKNF, Hayden, NPL ITP Team +############################################################################################# -listen 1935; -max_connections 1000; -daemon off; -srs_log_tank all; +############################################################################################# +# Global sections +############################################################################################# +ff_log_dir ./objs; +ff_log_level info; + +srs_log_tank all; +srs_log_file ./objs/srs.log; +# TRACE, DEBUG, INFO, WARN, ERROR +srs_log_level_v2 info; + +max_connections 1000; +daemon off; +utc_time off; + +############################################################################################# +# RTMP sections +############################################################################################# +listen 1935; +chunk_size 60000; + +############################################################################################# +# HTTP sections +############################################################################################# http_api { - enabled on; - listen 1985; + enabled on; + listen 1985; + crossdomain on; + auth { + enabled on; + username python_stats; + password wMePq3ahpoLRzgsVg7BY9eE82uuJHT0YukD2ZE1JfMY2RjP4e6QnUaKg3V9x5s9M; + } + https { + enabled off; + listen 1990; + key ./conf/server.key; + cert ./conf/server.crt; + } } http_server { - enabled on; - listen 8080; + enabled on; + listen 8080; + dir ./objs/nginx/html; + crossdomain on; + https { + enabled off; + listen 8088; + key ./conf/server.key; + cert ./conf/server.crt; + } } + +############################################################################################# +# SRT server section +############################################################################################# +srt_server { + # whether SRT server is enabled. + enabled off; + listen 10080; + + maxbw 1000000000; + mss 1500; + + connect_timeout 4000; + peer_idle_timeout 8000; + + default_app live; + peerlatency 0; + recvlatency 0; + latency 0; + + tsbpdmode off; + tlpktdrop off; + sendbuf 2000000; + recvbuf 2000000; +} + +############################################################################################# +# WebRTC server section +############################################################################################# rtc_server { - enabled on; - listen 8000; # UDP port - # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#config-candidate - candidate $CANDIDATE; + enabled on; + listen 8000; + + protocol udp; + tcp { + enabled off; + listen 8000; + } + candidate $CANDIDATE; + ip_family ipv4; + api_as_candidates on; + resolve_api_domain on; + + ecdsa on; + encrypt on; } + +############################################################################################# +# PYTHON_ADDONS sections +############################################################################################# +python_addons { + # Enable or disable Python addon management + enabled on; + + # Python addon definitions + # Each addon block defines a Python script to run + addon { + script "./python/httpbackend_server.py"; + } +} + +############################################################################################# +# VHOST sections +############################################################################################# vhost __defaultVhost__ { + enabled on; + hls { - enabled on; + enabled on; } + srt { + # Whether enable SRT on this vhost. + enabled on; + srt_to_rtmp on; + } + http_remux { - enabled on; - mount [vhost]/[app]/[stream].flv; + enabled on; + mount [vhost]/[app]/[stream].flv; } + rtc { - enabled on; + enabled on; + # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtmp-to-rtc - rtmp_to_rtc on; + rtmp_to_rtc on; + keep_bframe on; + # opus_bitrate 48000; + # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtc-to-rtmp - rtc_to_rtmp on; + rtc_to_rtmp on; + pli_for_rtmp 6.0; + aac_bitrate 48000; } -} + + chunk_size 128; + tcp_nodelay on; + min_latency on; + play { + gop_cache off; + queue_length 10; + mw_latency 100; + mw_msgs 0; + } + publish { + mr off; + mr_latency 300; + firstpkt_timeout 20000; + normal_timeout 5000; + } + + security { + enabled off; + allow play all; + allow publish all; + } + + dvr { + enabled on; + dvr_path ./DVR_Record/[stream]/[2006].[01].[02].[15].[04].[05].mp4; + dvr_plan segment; + dvr_duration 30; + dvr_apply /^live\/.*/; + dvr_wait_keyframe on; + time_jitter full; + } +} \ No newline at end of file diff --git a/trunk/livestream_site b/trunk/livestream_site new file mode 160000 index 000000000..348b347bd --- /dev/null +++ b/trunk/livestream_site @@ -0,0 +1 @@ +Subproject commit 348b347bd7e1cf9d59acff0ba1bdf5390f0751d7 diff --git a/trunk/python/analytics.py b/trunk/python/analytics.py deleted file mode 100644 index 7b36439b9..000000000 --- a/trunk/python/analytics.py +++ /dev/null @@ -1,122 +0,0 @@ -#!/usr/bin/env python3 -""" -SRS Analytics Script -This demonstrates another Python process that can run alongside SRS. -""" - -import time -import signal -import sys -import argparse -import logging -import json -from http.server import HTTPServer, BaseHTTPRequestHandler -from threading import Thread -from datetime import datetime - -class AnalyticsHandler(BaseHTTPRequestHandler): - def do_GET(self): - """Handle GET requests for analytics data""" - if self.path == '/stats': - # Return sample analytics data - stats = { - 'timestamp': datetime.now().isoformat(), - 'connections': 42, - 'streams': 5, - 'bandwidth': '1.2 Mbps', - 'uptime': '2h 30m' - } - - self.send_response(200) - self.send_header('Content-type', 'application/json') - self.end_headers() - self.wfile.write(json.dumps(stats, indent=2).encode()) - else: - self.send_response(404) - self.end_headers() - - def log_message(self, format, *args): - """Override to use our logger""" - pass - -class SRSAnalytics: - def __init__(self, port=8888): - self.port = port - self.running = True - self.server = None - self.server_thread = None - - # Set up logging - logging.basicConfig( - level=logging.INFO, - format='[%(asctime)s] [Python Analytics] %(levelname)s: %(message)s', - datefmt='%Y-%m-%d %H:%M:%S' - ) - self.logger = logging.getLogger(__name__) - - # Set up signal handlers - signal.signal(signal.SIGTERM, self.signal_handler) - signal.signal(signal.SIGINT, self.signal_handler) - - def signal_handler(self, signum, frame): - """Handle shutdown signals from SRS""" - self.logger.info(f"Received signal {signum}, shutting down gracefully...") - self.running = False - if self.server: - self.server.shutdown() - - def start_http_server(self): - """Start the HTTP analytics server""" - try: - self.server = HTTPServer(('localhost', self.port), AnalyticsHandler) - self.logger.info(f"Analytics server started on http://localhost:{self.port}") - self.server.serve_forever() - except Exception as e: - self.logger.error(f"Error starting HTTP server: {e}") - - def run(self): - """Main analytics loop""" - self.logger.info(f"SRS Analytics started on port {self.port}") - - try: - # Start HTTP server in a separate thread - self.server_thread = Thread(target=self.start_http_server) - self.server_thread.daemon = True - self.server_thread.start() - - # Main analytics loop - while self.running: - # Simulate analytics work - self.logger.debug("Processing analytics data...") - - # You can add your analytics logic here: - # - Collect stream metrics - # - Process viewer statistics - # - Generate reports - # - Store data to database - - time.sleep(10) # Process every 10 seconds - - except Exception as e: - self.logger.error(f"Error in analytics loop: {e}") - finally: - self.cleanup() - - def cleanup(self): - """Cleanup before shutdown""" - self.logger.info("Cleaning up Analytics server...") - if self.server: - self.server.shutdown() - self.logger.info("Analytics server stopped") - -def main(): - parser = argparse.ArgumentParser(description='SRS Analytics Server') - parser.add_argument('--port', type=int, default=8888, help='HTTP server port') - - args = parser.parse_args() - - analytics = SRSAnalytics(args.port) - analytics.run() - -if __name__ == '__main__': - main() diff --git a/trunk/python/avatar_module.py b/trunk/python/avatar_module.py new file mode 100644 index 000000000..fb97a8c14 --- /dev/null +++ b/trunk/python/avatar_module.py @@ -0,0 +1,616 @@ +import math +from typing import Optional, List + +class AvatarGenerator: + DEFAULT_COLORS = ["#000000", "#8f1414", "#e50e0e", "#f3450f", "#fcac03"] + DEFAULT_SIZE = 80 + DEFAULT_VARIANT = 'marble' + VALID_VARIANTS = {'marble', 'beam', 'pixel', 'sunset', 'ring', 'bauhaus'} + + def __init__(self, default_colors: Optional[List[str]] = None, default_size: Optional[int] = None, default_variant: Optional[str] = None): + self.colors = default_colors if default_colors else self.DEFAULT_COLORS + self.size = default_size if default_size else self.DEFAULT_SIZE + self.variant = default_variant if default_variant else self.DEFAULT_VARIANT + + self.AVATAR_GENERATORS = { + 'marble': self._generate_marble_avatar, + 'beam': self._generate_beam_avatar, + 'pixel': self._generate_pixel_avatar, + 'sunset': self._generate_sunset_avatar, + 'ring': self._generate_ring_avatar, + 'bauhaus': self._generate_bauhaus_avatar, + } + + @staticmethod + def _hash_code(name: str) -> int: + """Generate hash from name string - exact port of hashCode function""" + try: + if not name or not isinstance(name, str): + # Consider raising a TypeError or returning a default hash for invalid input + return 12345 # Fallback for invalid input + hash_val = 0 + for char in name: + hash_val = (hash_val << 5) - hash_val + ord(char) + hash_val &= hash_val # Convert to 32bit integer + return abs(hash_val) + except Exception: + return 12345 # fallback hash value + + @staticmethod + def _get_modulus(num: int, max_val: int) -> int: + """Get modulus""" + return num % max_val + + @staticmethod + def _get_digit(number: int, ntn: int) -> int: + """Get digit at position""" + try: + if not isinstance(number, int) or not isinstance(ntn, int): + return 0 # Fallback for invalid input type + return int((number / (10 ** ntn)) % 10) + except (ZeroDivisionError, ValueError, OverflowError): + return 0 + + @staticmethod + def _get_boolean(number: int, ntn: int) -> bool: + """Get boolean from digit""" + return not ((AvatarGenerator._get_digit(number, ntn)) % 2) + + @staticmethod + def _get_angle(x: float, y: float) -> float: + """Get angle from coordinates""" + return math.atan2(y, x) * 180 / math.pi + + @staticmethod + def _get_unit(number: int, range_val: int, index: int = None) -> int: + """Get unit value with optional negative""" + try: + if not isinstance(number, int) or not isinstance(range_val, int) or range_val == 0: + return 0 # Fallback for invalid input + value = number % range_val + if index and ((AvatarGenerator._get_digit(number, index) % 2) == 0): + return -value + return value + except (ZeroDivisionError, ValueError, TypeError): + return 0 + + @staticmethod + def _get_random_color(number: int, colors: List[str], range_val: int) -> str: + """Get random color from palette""" + try: + if not colors or range_val == 0: + return AvatarGenerator.DEFAULT_COLORS[0] # Fallback + return colors[number % range_val] + except (IndexError, TypeError, ZeroDivisionError): + return AvatarGenerator.DEFAULT_COLORS[0] + + @staticmethod + def _get_contrast(hexcolor: str) -> str: + """Get contrasting color (black or white)""" + try: + if not hexcolor or not isinstance(hexcolor, str): + return '#000000' # Fallback + + # Remove leading # if present + if hexcolor.startswith('#'): + hexcolor = hexcolor[1:] + + # Clean up any remaining quotes or whitespace + hexcolor = hexcolor.strip().replace('"', '').replace("'", "") + + # Ensure we have exactly 6 hex characters + if len(hexcolor) != 6: + return '#000000' # Fallback + + # Validate hex characters + try: + int(hexcolor, 16) + except ValueError: + return '#000000' # Fallback + + # Convert to RGB + r = int(hexcolor[0:2], 16) + g = int(hexcolor[2:4], 16) + b = int(hexcolor[4:6], 16) + + # Get YIQ ratio + yiq = ((r * 299) + (g * 587) + (b * 114)) / 1000 + + # Check contrast + return '#000000' if yiq >= 128 else '#FFFFFF' + except Exception: + return '#000000' + + @staticmethod + def normalize_colors(colors_param: Optional[str]) -> List[str]: + """Normalize color palette from query parameter""" + try: + if not colors_param: + return AvatarGenerator.DEFAULT_COLORS + + # Clean up URL encoding + cleaned = colors_param.replace('%22', '"').replace('%20', ' ').replace('%5B', '[').replace('%5D', ']').strip() + + # Extract colors from different formats + color_list = AvatarGenerator._extract_color_strings(cleaned) + + # Validate and normalize colors + return AvatarGenerator._validate_color_list(color_list) + + except Exception: + return AvatarGenerator.DEFAULT_COLORS + + @staticmethod + def _extract_color_strings(cleaned_colors: str) -> List[str]: + """Extract color strings from cleaned input""" + # Handle array format: ["#xxxxxx", "#xxxxxx", ...] + if cleaned_colors.startswith('[') and cleaned_colors.endswith(']'): + inner_content = cleaned_colors[1:-1].strip() + if not inner_content: + return [] + return [part.strip().replace('"', '').replace("'", '') for part in inner_content.split(',')] + + # Handle comma-separated format: #xxxxxx, #xxxxxx, ... + return [part.strip().replace('"', '').replace("'", '') for part in cleaned_colors.split(',')] + + @staticmethod + def _validate_color_list(color_list: List[str]) -> List[str]: + """Validate and normalize hex colors""" + normalized_colors = [] + + for color in color_list: + color = color.strip() + if not color: + continue + if not color.startswith('#'): + color = '#' + color + if AvatarGenerator._is_valid_hex_color(color): + normalized_colors.append(color) + + return normalized_colors if normalized_colors else AvatarGenerator.DEFAULT_COLORS + + @staticmethod + def _is_valid_hex_color(color: str) -> bool: + """Check if string is valid hex color""" + return len(color) == 7 and all(c in '0123456789ABCDEFabcdef' for c in color[1:]) + + def _generate_marble_avatar(self, name: str, colors: List[str], size: int, square: bool, title: bool) -> str: + """Generate marble variant avatar""" + ELEMENTS = 3 + SIZE = 80 # Intrinsic size of the SVG design + + def generate_colors_marble(name: str, current_colors: List[str]): + num_from_name = self._hash_code(name) + range_val = len(current_colors) if current_colors else len(self.DEFAULT_COLORS) + + elements_properties = [] + for i in range(ELEMENTS): + elements_properties.append({ + 'color': self._get_random_color(num_from_name + i, current_colors, range_val), + 'translateX': self._get_unit(num_from_name * (i + 1), SIZE // 10, 1), + 'translateY': self._get_unit(num_from_name * (i + 1), SIZE // 10, 2), + 'scale': 1.2 + self._get_unit(num_from_name * (i + 1), SIZE // 20) / 10, + 'rotate': self._get_unit(num_from_name * (i + 1), 360, 1) + }) + + return elements_properties + + properties = generate_colors_marble(name, colors) + mask_id = f"mask_{self._hash_code(name)}" + filter_id = f"filter_{self._hash_code(name)}" + + # Use the passed 'size' for width/height, internal 'SIZE' for viewBox + svg = f'''''' + + if title: + svg += f'{name}' + + svg += f''' + + + + + + + + + + + + + + + +''' + + return svg + + def _generate_beam_avatar(self, name: str, colors: List[str], size: int, square: bool, title: bool) -> str: + """Generate beam variant avatar""" + SIZE = 36 # Intrinsic size of the SVG design + + def generate_data_beam(name: str, current_colors: List[str]): + num_from_name = self._hash_code(name) + range_val = len(current_colors) if current_colors else len(self.DEFAULT_COLORS) + wrapper_color = self._get_random_color(num_from_name, current_colors, range_val) + pre_translate_x = self._get_unit(num_from_name, 10, 1) + wrapper_translate_x = pre_translate_x + SIZE // 9 if pre_translate_x < 5 else pre_translate_x + pre_translate_y = self._get_unit(num_from_name, 10, 2) + wrapper_translate_y = pre_translate_y + SIZE // 9 if pre_translate_y < 5 else pre_translate_y + + data = { + 'wrapperColor': wrapper_color, + 'faceColor': self._get_contrast(wrapper_color), + 'backgroundColor': self._get_random_color(num_from_name + 13, current_colors, range_val), + 'wrapperTranslateX': wrapper_translate_x, + 'wrapperTranslateY': wrapper_translate_y, + 'wrapperRotate': self._get_unit(num_from_name, 360), + 'wrapperScale': 1 + self._get_unit(num_from_name, SIZE // 12) / 10, + 'isMouthOpen': self._get_boolean(num_from_name, 2), + 'isCircle': self._get_boolean(num_from_name, 1), + 'eyeSpread': self._get_unit(num_from_name, 5), + 'mouthSpread': self._get_unit(num_from_name, 3), + 'faceRotate': self._get_unit(num_from_name, 10, 3), + 'faceTranslateX': wrapper_translate_x // 2 if wrapper_translate_x > SIZE // 6 else self._get_unit(num_from_name, 8, 1), + 'faceTranslateY': wrapper_translate_y // 2 if wrapper_translate_y > SIZE // 6 else self._get_unit(num_from_name, 7, 2), + } + + return data + + data = generate_data_beam(name, colors) + mask_id = f"mask_{self._hash_code(name)}" + + svg = f'''''' + + if title: + svg += f'{name}' + + rx_value = SIZE if data['isCircle'] else SIZE // 6 + + svg += f''' + + + + + + + ''' + + if data['isMouthOpen']: + svg += f''' + ''' + else: + svg += f''' + ''' + + svg += f''' + + + + +''' + + return svg + + def _generate_pixel_avatar(self, name: str, colors: List[str], size: int, square: bool, title: bool) -> str: + """Generate pixel variant avatar""" + ELEMENTS = 64 # 8x8 grid + SIZE = 80 # Intrinsic SVG size + + def generate_colors_pixel(name: str, current_colors: List[str]): + num_from_name = self._hash_code(name) + # Use current_colors if provided, otherwise fallback to instance default, then class default + final_colors = current_colors if current_colors else self.colors + range_val = len(final_colors) + + color_list = [] + for i in range(ELEMENTS): + color_list.append(self._get_random_color(num_from_name + i, final_colors, range_val)) + return color_list + + pixel_colors = generate_colors_pixel(name, colors) + mask_id = f"mask_{self._hash_code(name)}" + + svg = f'''''' + + if title: + svg += f'{name}' + + svg += f''' + + + + ''' + + # Generate 8x8 grid of pixels + idx = 0 + pixel_size = SIZE // 8 # Each pixel is 10x10 if SIZE is 80 + for row in range(8): + for col in range(8): + svg += f''' + ''' + idx +=1 + + svg += ''' + +''' + + return svg + + def _generate_sunset_avatar(self, name: str, colors: List[str], size: int, square: bool, title: bool) -> str: + """Generate sunset variant avatar""" + ELEMENTS = 4 # For 4 gradient stops + SIZE = 80 # Intrinsic SVG size + + def generate_colors_sunset(name: str, current_colors: List[str]): + num_from_name = self._hash_code(name) + # Use current_colors if provided, otherwise fallback to instance default, then class default + final_colors = current_colors if current_colors else self.colors + # Ensure we have at least 4 colors for sunset, repeat if necessary + if len(final_colors) < ELEMENTS: + final_colors = (final_colors * (ELEMENTS // len(final_colors) + 1))[:ELEMENTS] + + range_val = len(final_colors) + + color_list = [] + for i in range(ELEMENTS): + color_list.append(self._get_random_color(num_from_name + i, final_colors, range_val)) + return color_list + + sunset_colors = generate_colors_sunset(name, colors) + # name_without_space = name.replace(' ', '').replace('-', '').replace('_', '') # Not used + hash_val = abs(self._hash_code(name)) + mask_id = f"mask_{hash_val}" + gradient_id_1 = f"gradient_paint0_linear_{hash_val}" + gradient_id_2 = f"gradient_paint1_linear_{hash_val}" + + rx_value = '' if square else f'rx="{SIZE * 2}"' + + svg = f'''''' + + if title: + svg += f'{name}' + + svg += f''' + + + + + + + + + + + + + + + + + +''' + + return svg + + def _generate_ring_avatar(self, name: str, colors: List[str], size: int, square: bool, title: bool) -> str: + """Generate ring variant avatar""" + SIZE = 90 # Intrinsic SVG size + COLORS_NEEDED = 9 # Number of colors used in the SVG paths + + def generate_colors_ring(name: str, current_colors: List[str]): + num_from_name = self._hash_code(name) + # Use current_colors if provided, otherwise fallback to instance default, then class default + final_colors = current_colors if current_colors else self.colors + # Ensure we have enough colors, repeat if necessary + if len(final_colors) < COLORS_NEEDED: + final_colors = (final_colors * (COLORS_NEEDED // len(final_colors) + 1))[:COLORS_NEEDED] + + range_val = len(final_colors) + + ring_palette = [] + for i in range(COLORS_NEEDED): + ring_palette.append(self._get_random_color(num_from_name + i, final_colors, range_val)) + return ring_palette + + ring_colors = generate_colors_ring(name, colors) + mask_id = f"mask_{self._hash_code(name)}" + + svg = f'''''' + + if title: + svg += f'{name}' + + svg += f''' + + + + + + + + + + + + + + +''' + + return svg + + def _generate_bauhaus_avatar(self, name: str, colors: List[str], size: int, square: bool, title: bool) -> str: + """Generate bauhaus variant avatar""" + ELEMENTS = 4 # Number of geometric elements + SIZE = 80 # Intrinsic SVG size + + def generate_colors_bauhaus(name: str, current_colors: List[str]): + num_from_name = self._hash_code(name) + # Use current_colors if provided, otherwise fallback to instance default, then class default + final_colors = current_colors if current_colors else self.colors + range_val = len(final_colors) + + properties = [] + for i in range(ELEMENTS): + properties.append({ + 'color': self._get_random_color(num_from_name + i, final_colors, range_val), + 'translateX': self._get_unit(num_from_name * (i + 1), SIZE // 2 - (i + 17), 1), + 'translateY': self._get_unit(num_from_name * (i + 1), SIZE // 2 - (i + 17), 2), + 'rotate': self._get_unit(num_from_name * (i + 1), 360), + 'isSquare': self._get_boolean(num_from_name, 2) + }) + return properties + + properties = generate_colors_bauhaus(name, colors) + mask_id = f"mask_{self._hash_code(name)}" + + svg = f'''''' + + if title: + svg += f'{name}' + + svg += f''' + + + + + + + + + +''' + + return svg + + def generate_avatar(self, + name: str, + variant: Optional[str] = None, + colors: Optional[List[str]] = None, # Allow passing colors as list of strings + colors_param: Optional[str] = None, # Allow passing colors as a string parameter + size: Optional[int] = None, + square: bool = False, + title: bool = True) -> str: + """ + Generate an SVG avatar. + + Args: + name: The name to generate the avatar for. + variant: The avatar variant (e.g., 'marble', 'beam'). Uses instance default if None. + colors: A list of hex color strings. Overrides colors_param if provided. + colors_param: A string representation of colors (e.g., "['#FF0000', '#00FF00']" or "#FF0000,#00FF00"). + Used if 'colors' list is not provided. + size: The desired size of the avatar in pixels. Uses instance default if None. + square: If True, generates a square avatar. Otherwise, a circular one. + title: If True, includes a tag with the name in the SVG. + + Returns: + An SVG string representing the avatar. + """ + current_variant = variant if variant and variant in self.VALID_VARIANTS else self.variant + current_size = size if size else self.size + + # Determine colors: prioritize direct list, then string param, then instance default + if colors: + # Validate and normalize if a list is directly passed + final_colors = self._validate_color_list(colors) + elif colors_param: + final_colors = self.normalize_colors(colors_param) + else: + final_colors = self.colors # Use instance default colors + + if not final_colors: # Ensure there's always a fallback + final_colors = self.DEFAULT_COLORS + + + generator_func = self.AVATAR_GENERATORS.get(current_variant) + + if generator_func: + return generator_func(name, final_colors, current_size, square, title) + else: + # Fallback to default variant if the chosen one is somehow invalid after checks + default_gen_func = self.AVATAR_GENERATORS.get(self.DEFAULT_VARIANT) + return default_gen_func(name, final_colors, current_size, square, title) diff --git a/trunk/python/database.py b/trunk/python/database.py new file mode 100644 index 000000000..d8b6f79b6 --- /dev/null +++ b/trunk/python/database.py @@ -0,0 +1,890 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Database module for JWT authentication system +Provides SQLite database initialization and management for user authentication +""" + +import sqlite3 +import os +import json +import re +import threading +import uuid +from datetime import datetime, timedelta +from typing import Optional, List, Dict, Any, Union +from contextlib import contextmanager + +import random +import string + +# Import SRS logger +from srs_logger import get_logger + +class Database: + + _instance = None + _lock = threading.Lock() + + def __new__(cls, db_path: str = "./objs/srs_database.db"): + """Singleton pattern to ensure only one database instance""" + if cls._instance is None: + with cls._lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._initialized = False + return cls._instance + + def __init__(self, db_path: str = "./objs/srs_database.db"): + if self._initialized: + return + + self._initialized = True + self.db_path = db_path + self.logger = get_logger() + self._connection_lock = threading.Lock() + + # Initialize database + self._init_database() + + def _init_database(self): + """Initialize database and create tables if they don't exist""" + try: + # Create objs directory if it doesn't exist + db_dir = os.path.dirname(self.db_path) + if db_dir and not os.path.exists(db_dir): + os.makedirs(db_dir, exist_ok=True) + self.logger.info(f"Created database directory: {db_dir}") + + # Check if database exists + db_exists = os.path.exists(self.db_path) + if not db_exists: + self.logger.warn(f"Database file not found, creating new database: {self.db_path}") + else: + self.logger.info(f"Using existing database: {self.db_path}") + + # Create tables + self._create_tables() + + if not db_exists: + self.logger.info("Database initialized successfully with all tables created") + else: + self.logger.info("Database connection established and tables verified") + + except Exception as e: + self.logger.exception(f"Failed to initialize database: {e}") + raise + + def _create_tables(self): + """Create all required tables""" + try: + with self.get_connection() as conn: + cursor = conn.cursor() + + # Create users table + cursor.execute(''' + CREATE TABLE IF NOT EXISTS users ( + user_id TEXT PRIMARY KEY, + username TEXT UNIQUE NOT NULL, + name TEXT NOT NULL, + email TEXT UNIQUE, + hashed_password TEXT NOT NULL, + salt TEXT NOT NULL, + user_group TEXT NOT NULL DEFAULT '["user"]', + last_active TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + is_activated BOOLEAN NOT NULL DEFAULT 1 + ) + ''') + + # Create auth_session table + cursor.execute(''' + CREATE TABLE IF NOT EXISTS auth_session ( + session_id TEXT PRIMARY KEY, + user_id TEXT NOT NULL, + hashed_authkey TEXT NOT NULL, + salt TEXT NOT NULL, + expire_time TIMESTAMP NOT NULL, + FOREIGN KEY (user_id) REFERENCES users (user_id) ON DELETE CASCADE + ) + ''') + + # Create refresh_session table + cursor.execute(''' + CREATE TABLE IF NOT EXISTS refresh_session ( + session_id TEXT PRIMARY KEY, + user_id TEXT NOT NULL, + hashed_refreshkey TEXT NOT NULL, + salt TEXT NOT NULL, + expire_time TIMESTAMP NOT NULL, + FOREIGN KEY (user_id) REFERENCES users (user_id) ON DELETE CASCADE + ) + ''') + + # Create streams table + cursor.execute(''' + CREATE TABLE IF NOT EXISTS streams ( + stream_id TEXT PRIMARY KEY, + stream_code TEXT UNIQUE NOT NULL, + streamer_id TEXT NOT NULL, + stream_title TEXT NOT NULL, + stream_description TEXT, + stream_tags TEXT DEFAULT '[]', + stream_visibility TEXT NOT NULL DEFAULT 'public', + quality_info TEXT, + active_time TIMESTAMP, + stream_status TEXT NOT NULL DEFAULT 'planned', + FOREIGN KEY (streamer_id) REFERENCES users (user_id) ON DELETE CASCADE + ) + ''') + + # Create system_stats table + cursor.execute(''' + CREATE TABLE IF NOT EXISTS system_stats ( + timestamp TIMESTAMP PRIMARY KEY, + srs_uptime REAL NOT NULL, + srs_cpu_percent REAL NOT NULL, + srs_memory_percent REAL NOT NULL, + srs_recv_KBps REAL NOT NULL, + srs_send_KBps REAL NOT NULL, + disk_read_KBps REAL NOT NULL, + disk_write_KBps REAL NOT NULL, + os_uptime REAL NOT NULL, + os_cpu_percent REAL NOT NULL, + os_memory_percent REAL NOT NULL + ) + ''') + + # Create indexes for better performance + cursor.execute('CREATE INDEX IF NOT EXISTS idx_users_username ON users(username)') + cursor.execute('CREATE INDEX IF NOT EXISTS idx_users_email ON users(email)') + cursor.execute('CREATE INDEX IF NOT EXISTS idx_auth_session_user_id ON auth_session(user_id)') + cursor.execute('CREATE INDEX IF NOT EXISTS idx_auth_session_expire ON auth_session(expire_time)') + cursor.execute('CREATE INDEX IF NOT EXISTS idx_refresh_session_user_id ON refresh_session(user_id)') + cursor.execute('CREATE INDEX IF NOT EXISTS idx_refresh_session_expire ON refresh_session(expire_time)') + cursor.execute('CREATE INDEX IF NOT EXISTS idx_system_stats_timestamp ON system_stats(timestamp)') + + conn.commit() + self.logger.debug("All database tables and indexes created/verified successfully") + + except Exception as e: + self.logger.exception(f"Failed to create database tables: {e}") + raise + + @contextmanager + def get_connection(self): + """Get a database connection with proper error handling and cleanup""" + conn = None + try: + conn = sqlite3.connect(self.db_path, timeout=30.0) + conn.row_factory = sqlite3.Row # Enable column access by name + conn.execute('PRAGMA foreign_keys = ON') # Enable foreign key constraints + # Configure datetime adapter for Python 3.12+ compatibility + sqlite3.register_adapter(datetime, lambda dt: dt.isoformat()) + sqlite3.register_converter("TIMESTAMP", lambda b: datetime.fromisoformat(b.decode())) + yield conn + except Exception as e: + if conn: + conn.rollback() + self.logger.error(f"Database connection error: {e}") + raise + finally: + if conn: + conn.close() + + # ============================================================================ + # Users management functions + # ============================================================================ + + def cleanup_expired_sessions(self): + """Remove expired auth and refresh sessions, and sessions for deactivated users""" + try: + with self.get_connection() as conn: + cursor = conn.cursor() + current_time = datetime.now() + + # Clean up expired auth sessions + cursor.execute(''' + DELETE FROM auth_session + WHERE expire_time < ? + ''', (current_time,)) + auth_expired_deleted = cursor.rowcount + + # Clean up expired refresh sessions + cursor.execute(''' + DELETE FROM refresh_session + WHERE expire_time < ? + ''', (current_time,)) + refresh_expired_deleted = cursor.rowcount + + # Clean up sessions for deactivated users + cursor.execute(''' + DELETE FROM auth_session + WHERE user_id IN (SELECT user_id FROM users WHERE is_activated = 0) + ''') + auth_deactivated_deleted = cursor.rowcount + + cursor.execute(''' + DELETE FROM refresh_session + WHERE user_id IN (SELECT user_id FROM users WHERE is_activated = 0) + ''') + refresh_deactivated_deleted = cursor.rowcount + + conn.commit() + + total_auth_deleted = auth_expired_deleted + auth_deactivated_deleted + total_refresh_deleted = refresh_expired_deleted + refresh_deactivated_deleted + + if total_auth_deleted > 0 or total_refresh_deleted > 0: + self.logger.info(f"Cleaned up sessions: {total_auth_deleted} auth ({auth_expired_deleted} expired, {auth_deactivated_deleted} deactivated), {total_refresh_deleted} refresh ({refresh_expired_deleted} expired, {refresh_deactivated_deleted} deactivated)") + + except Exception as e: + self.logger.exception(f"Failed to cleanup expired sessions: {e}") + + def get_user(self, username_or_email: str) -> Optional[Dict[str, Any]]: + """Get user by username or email""" + if not username_or_email: + self.logger.warn("get_user called with None or empty value") + return None + + try: + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute(''' + SELECT user_id, username, name, email, hashed_password, salt, + user_group, last_active, is_activated + FROM users WHERE username = ? OR email = ? + ''', (username_or_email, username_or_email)) + + row = cursor.fetchone() + if row: + user_data = dict(row) + user_data['user_group'] = json.loads(user_data['user_group']) + return user_data + else: + return None + + except Exception as e: + self.logger.exception(f"Failed to get user by username/email '{username_or_email}': {e}") + return None + + def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: + """Get user by user_id""" + try: + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute(''' + SELECT user_id, username, name, email, hashed_password, salt, + user_group, last_active, is_activated + FROM users WHERE user_id = ? + ''', (user_id,)) + + row = cursor.fetchone() + if row: + user_data = dict(row) + user_data['user_group'] = json.loads(user_data['user_group']) + return user_data + else: + return None + + except Exception as e: + self.logger.exception(f"Failed to get user by user_id {user_id}: {e}") + return None + + def create_user(self, username: str, name: str, email: Optional[str], + hashed_password: str, salt: str, user_group: List[str] = None, + is_activated: bool = True) -> Optional[Dict[str, Any]]: + """Create a new user and return user data""" + try: + # Validate username format + if not re.match(r'^[a-zA-Z0-9_-]+$', username): + self.logger.warn(f"Invalid username format: '{username}'. Only a-z, A-Z, 0-9, '_', '-' are allowed") + return None + + # Validate email format if provided + if email and not re.match(r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$', email): + self.logger.warn(f"Invalid email format: '{email}'") + return None + + # Ensure user_group is not empty + if user_group is None or len(user_group) == 0: + user_group = ["user"] + + # Generate unique user_id (UUID collisions are extremely rare, but we'll still check) + user_id = str(uuid.uuid4()) + + user_group_json = json.dumps(user_group) + + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute(''' + INSERT INTO users (user_id, username, name, email, hashed_password, salt, user_group, is_activated) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + ''', (user_id, username, name, email, hashed_password, salt, user_group_json, is_activated)) + + conn.commit() + + self.logger.info(f"Created new user: {username} (ID: {user_id}, activated: {is_activated})") + + # Return the created user data directly to avoid deadlock + return { + 'user_id': user_id, + 'username': username, + 'name': name, + 'email': email, + 'hashed_password': hashed_password, + 'salt': salt, + 'user_group': user_group, # Already parsed list + 'last_active': datetime.now().isoformat(), + 'is_activated': is_activated + } + + except sqlite3.IntegrityError as e: + if 'username' in str(e): + self.logger.warn(f"Username '{username}' already exists") + elif 'email' in str(e): + self.logger.warn(f"Email '{email}' already exists") + else: + self.logger.warn(f"User creation failed due to constraint: {e}") + return None + except Exception as e: + self.logger.exception(f"Failed to create user '{username}': {e}") + return None + + def update_user_last_active(self, user_id: str) -> Optional[Dict[str, Any]]: + """Update user's last active timestamp and return updated user data""" + try: + with self.get_connection() as conn: + cursor = conn.cursor() + current_time = datetime.now() + cursor.execute(''' + UPDATE users + SET last_active = ? + WHERE user_id = ? + ''', (current_time, user_id)) + conn.commit() + + self.logger.debug(f"Updated last_active for user_id: {user_id}") + # Get the updated user data in the same connection to avoid deadlock + cursor.execute(''' + SELECT user_id, username, name, email, hashed_password, salt, user_group, last_active, is_activated + FROM users + WHERE user_id = ? + ''', (user_id,)) + + row = cursor.fetchone() + if row: + user_data = dict(row) + user_data['user_group'] = json.loads(user_data['user_group']) + return user_data + else: + self.logger.warn(f"No user found with user_id: {user_id}") + return None + + except Exception as e: + self.logger.exception(f"Failed to update last_active for user_id {user_id}: {e}") + return None + + def update_user_activation(self, user_id: str, activation: bool) -> Optional[Dict[str, Any]]: + """Update user's activation status and cleanup sessions if deactivated, return updated user data""" + try: + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute(''' + UPDATE users + SET is_activated = ? + WHERE user_id = ? + ''', (activation, user_id)) + conn.commit() + + self.logger.info(f"Updated activation status for user_id {user_id}: {activation}") + + # Get the updated user data in the same connection to avoid deadlock + cursor.execute(''' + SELECT user_id, username, name, email, hashed_password, salt, user_group, last_active, is_activated + FROM users + WHERE user_id = ? + ''', (user_id,)) + + row = cursor.fetchone() + if row: + # Clean up sessions after updating activation status (outside the transaction) + if not activation: self.cleanup_expired_sessions() + + user_data = dict(row) + user_data['user_group'] = json.loads(user_data['user_group']) + return user_data + else: + self.logger.warn(f"No user found with user_id: {user_id}") + return None + + except Exception as e: + self.logger.exception(f"Failed to update activation for user_id {user_id}: {e}") + return None + + def update_user_password(self, user_id: str, hashed_password: str, salt: str) -> Optional[Dict[str, Any]]: + """Update user's password and salt, return updated user data""" + try: + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute(''' + UPDATE users + SET hashed_password = ?, salt = ? + WHERE user_id = ? + ''', (hashed_password, salt, user_id)) + conn.commit() + + if cursor.rowcount == 0: + self.logger.warn(f"No user found with user_id: {user_id}") + return None + + self.logger.info(f"Updated password for user_id: {user_id}") + + # Get the updated user data in the same connection to avoid deadlock + cursor.execute(''' + SELECT user_id, username, name, email, hashed_password, salt, user_group, last_active, is_activated + FROM users + WHERE user_id = ? + ''', (user_id,)) + + row = cursor.fetchone() + if row: + user_data = dict(row) + user_data['user_group'] = json.loads(user_data['user_group']) + return user_data + else: + self.logger.warn(f"No user found with user_id: {user_id}") + return None + + except Exception as e: + self.logger.exception(f"Failed to update password for user_id {user_id}: {e}") + return None + + def delete_user(self, user_id: str) -> bool: + """Delete user and all associated data from all tables""" + try: + with self.get_connection() as conn: + cursor = conn.cursor() + + # Check if user exists first + cursor.execute('SELECT username FROM users WHERE user_id = ?', (user_id,)) + user_row = cursor.fetchone() + if not user_row: + self.logger.warn(f"No user found with user_id: {user_id}") + return False + + username = user_row[0] + + # Delete auth sessions (foreign key constraints will handle this automatically, but we'll be explicit) + cursor.execute('DELETE FROM auth_session WHERE user_id = ?', (user_id,)) + auth_deleted = cursor.rowcount + + # Delete refresh sessions + cursor.execute('DELETE FROM refresh_session WHERE user_id = ?', (user_id,)) + refresh_deleted = cursor.rowcount + + # Delete user + cursor.execute('DELETE FROM users WHERE user_id = ?', (user_id,)) + user_deleted = cursor.rowcount + + conn.commit() + + if user_deleted > 0: + self.logger.info(f"Deleted user '{username}' (ID: {user_id}) and all associated data: {auth_deleted} auth sessions, {refresh_deleted} refresh sessions") + return True + else: + self.logger.warn(f"Failed to delete user with user_id: {user_id}") + return False + + except Exception as e: + self.logger.exception(f"Failed to delete user with user_id {user_id}: {e}") + return False + + # ============================================================================ + # Session management functions + # ============================================================================ + + def create_auth_session(self, user_id: str, hashed_authkey: str, salt: str, + expire_time: datetime) -> Optional[Dict[str, Any]]: + """Create an auth session and return session data""" + try: + # Generate unique session_id (UUID collisions are extremely rare) + session_id = str(uuid.uuid4()) + + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute(''' + INSERT INTO auth_session (session_id, user_id, hashed_authkey, salt, expire_time) + VALUES (?, ?, ?, ?, ?) + ''', (session_id, user_id, hashed_authkey, salt, expire_time)) + + conn.commit() + + self.logger.debug(f"Created auth session for user_id {user_id}, session_id: {session_id}") + + # Return the created session data + return { + 'session_id': session_id, + 'user_id': user_id, + 'hashed_authkey': hashed_authkey, + 'salt': salt, + 'expire_time': expire_time.isoformat() + } + + except Exception as e: + self.logger.exception(f"Failed to create auth session for user_id {user_id}: {e}") + return None + + def create_refresh_session(self, user_id: str, hashed_refreshkey: str, salt: str, + expire_time: datetime) -> Optional[Dict[str, Any]]: + """Create a refresh session and return session data""" + try: + # Generate unique session_id (UUID collisions are extremely rare) + session_id = str(uuid.uuid4()) + + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute(''' + INSERT INTO refresh_session (session_id, user_id, hashed_refreshkey, salt, expire_time) + VALUES (?, ?, ?, ?, ?) + ''', (session_id, user_id, hashed_refreshkey, salt, expire_time)) + + conn.commit() + + self.logger.debug(f"Created refresh session for user_id {user_id}, session_id: {session_id}") + + # Return the created session data + return { + 'session_id': session_id, + 'user_id': user_id, + 'hashed_refreshkey': hashed_refreshkey, + 'salt': salt, + 'expire_time': expire_time.isoformat() + } + + except Exception as e: + self.logger.exception(f"Failed to create refresh session for user_id {user_id}: {e}") + return None + + def get_auth_session(self, session_id: str) -> Optional[Dict[str, Any]]: + """Get valid auth session by session_id""" + try: + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute(''' + SELECT session_id, user_id, hashed_authkey, salt, expire_time + FROM auth_session + WHERE session_id = ? AND expire_time > ? + ''', (session_id, datetime.now())) + + row = cursor.fetchone() + if row: + return dict(row) + return None + + except Exception as e: + self.logger.exception(f"Failed to get auth session for session_id {session_id}: {e}") + return None + + def get_refresh_session(self, session_id: str) -> Optional[Dict[str, Any]]: + """Get valid refresh session by session_id""" + try: + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute(''' + SELECT session_id, user_id, hashed_refreshkey, salt, expire_time + FROM refresh_session + WHERE session_id = ? AND expire_time > ? + ''', (session_id, datetime.now())) + + row = cursor.fetchone() + if row: + return dict(row) + return None + + except Exception as e: + self.logger.exception(f"Failed to get refresh session for session_id {session_id}: {e}") + return None + + def delete_auth_session(self, session_id: str) -> bool: + """Delete an auth session""" + try: + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute('DELETE FROM auth_session WHERE session_id = ?', (session_id,)) + conn.commit() + + if cursor.rowcount > 0: + self.logger.debug(f"Deleted auth session_id: {session_id}") + return True + return False + + except Exception as e: + self.logger.exception(f"Failed to delete auth session {session_id}: {e}") + return False + + def delete_refresh_session(self, session_id: str) -> bool: + """Delete a refresh session""" + try: + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute('DELETE FROM refresh_session WHERE session_id = ?', (session_id,)) + conn.commit() + + if cursor.rowcount > 0: + self.logger.debug(f"Deleted refresh session_id: {session_id}") + return True + return False + + except Exception as e: + self.logger.exception(f"Failed to delete refresh session {session_id}: {e}") + return False + + def delete_user_sessions(self, user_id: str) -> bool: + """Delete all sessions for a user (logout)""" + try: + with self.get_connection() as conn: + cursor = conn.cursor() + # Delete auth sessions + cursor.execute('DELETE FROM auth_session WHERE user_id = ?', (user_id,)) + auth_deleted = cursor.rowcount + + # Delete refresh sessions + cursor.execute('DELETE FROM refresh_session WHERE user_id = ?', (user_id,)) + refresh_deleted = cursor.rowcount + conn.commit() + + self.logger.info(f"Deleted all sessions for user_id {user_id}: {auth_deleted} auth, {refresh_deleted} refresh") + return True + + except Exception as e: + self.logger.exception(f"Failed to delete sessions for user_id {user_id}: {e}") + return False + + # ============================================================================ + # System statistics functions + # ============================================================================ + + def insert_system_stats(self, + srs_uptime: float, srs_cpu_percent: float, + srs_memory_percent: float, srs_recv_KBps: float, + srs_send_KBps: float, disk_read_KBps: float, + disk_write_KBps: float, os_uptime: float, + os_cpu_percent: float, os_memory_percent: float) -> bool: + """Insert system statistics into the database""" + try: + with self.get_connection() as conn: + cursor = conn.cursor() + timestamp = datetime.now() + + cursor.execute(''' + INSERT INTO system_stats (timestamp, + srs_uptime, srs_cpu_percent, srs_memory_percent, + srs_recv_KBps, srs_send_KBps, + disk_read_KBps, disk_write_KBps, + os_uptime, os_cpu_percent, os_memory_percent) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ''', (timestamp, + srs_uptime, srs_cpu_percent, srs_memory_percent, + srs_recv_KBps, srs_send_KBps, + disk_read_KBps, disk_write_KBps, + os_uptime, os_cpu_percent, os_memory_percent)) + + conn.commit() + self.logger.debug(f"Inserted system stats at {timestamp.isoformat()}") + return True + except Exception as e: + self.logger.exception(f"Failed to insert system stats: {e}") + return False + + def get_system_stats(self, time_delta: int) -> List[Dict[str, Any]]: + """Get system statistics from the database""" + try: + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute(''' + SELECT * FROM system_stats WHERE timestamp >= ? + ''', (datetime.now() - timedelta(seconds=time_delta),)) + rows = cursor.fetchall() + columns = [column[0] for column in cursor.description] + return [dict(zip(columns, row)) for row in rows] + except Exception as e: + self.logger.exception(f"Failed to get system stats: {e}") + return [] + + def delete_expired_system_stats(self) -> bool: + """Delete expired system statistics older than 20 minutes""" + try: + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute(''' + DELETE FROM system_stats WHERE timestamp < ? + ''', (datetime.now() - timedelta(minutes=20),)) + conn.commit() + + if cursor.rowcount > 0: + self.logger.info(f"Deleted {cursor.rowcount} expired system stats") + return True + else: + self.logger.debug("No expired system stats to delete") + return False + + except Exception as e: + self.logger.exception(f"Failed to delete expired system stats: {e}") + return False + + # ============================================================================ + # Streams management functions + # ============================================================================ + + def create_stream(self, streamer_id: str, stream_title: str, stream_description: Optional[str] = None, stream_tags: Optional[List[str]] = [], stream_visibility: str = 'public') -> Dict[str, Any]: + """Create a new stream and return stream data""" + try: + # Validate streamer_id exists + if not self.get_user_by_id(streamer_id): + self.logger.warn(f"Streamer with user_id {streamer_id} does not exist") + return None + + # Generate unique stream_id (UUID collisions are extremely rare) + stream_id = str(uuid.uuid4()) + # Generate stream_code in form of xxx-xxxx (only include letters and numbers) + def random_code(): + part1 = ''.join(random.choices(string.ascii_lowercase + string.digits, k=3)) + part2 = ''.join(random.choices(string.ascii_lowercase + string.digits, k=4)) + return f"{part1}-{part2}" + stream_code = random_code() + + # Convert tags to JSON string + stream_tags_json = json.dumps(stream_tags) + + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute(''' + INSERT INTO streams (stream_id, stream_code, streamer_id, stream_title, + stream_description, stream_tags, + stream_visibility) + VALUES (?, ?, ?, ?, ?, ?, ?) + ''', (stream_id, stream_code, streamer_id, stream_title, + stream_description, stream_tags_json, + stream_visibility)) + + conn.commit() + + self.logger.info(f"Created new stream: {stream_title} (ID: {stream_id}, Streamer ID: {streamer_id})") + + return { + 'stream_id': stream_id, + 'stream_code': stream_code, + 'streamer_id': streamer_id, + 'stream_title': stream_title, + 'stream_description': stream_description, + 'stream_tags': stream_tags, + 'stream_visibility': stream_visibility, + 'stream_status': 'planned', # Default status + } + + except sqlite3.IntegrityError as e: + self.logger.warn(f"Stream creation failed due to constraint: {e}") + return None + except Exception as e: + self.logger.exception(f"Failed to create stream: {e}") + return None + + def get_stream_by_vis(self, stream_visibility: str = 'public') -> Optional[List[Dict[str, Any]]]: + """Get streams by visibility""" + try: + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute(''' + SELECT * FROM streams + WHERE stream_visibility = ? AND stream_status = 'streaming' + ''', (stream_visibility,)) + + rows = cursor.fetchall() + columns = [column[0] for column in cursor.description] + return [dict(zip(columns, row)) for row in rows] + + except Exception as e: + self.logger.exception(f"Failed to get streams by visibility '{stream_visibility}': {e}") + return None + + # ============================================================================ + # Default database statistics functions + # ============================================================================ + + def get_database_stats(self) -> Dict[str, int]: + """Get database statistics""" + try: + with self.get_connection() as conn: + cursor = conn.cursor() + + stats = {} + + # Count users + cursor.execute('SELECT COUNT(*) FROM users') + stats['users'] = cursor.fetchone()[0] + + # Count active auth sessions + cursor.execute('SELECT COUNT(*) FROM auth_session WHERE expire_time > ?', (datetime.now(),)) + stats['active_auth_sessions'] = cursor.fetchone()[0] + + # Count active refresh sessions + cursor.execute('SELECT COUNT(*) FROM refresh_session WHERE expire_time > ?', (datetime.now(),)) + stats['active_refresh_sessions'] = cursor.fetchone()[0] + + # Count expired auth sessions + cursor.execute('SELECT COUNT(*) FROM auth_session WHERE expire_time <= ?', (datetime.now(),)) + stats['expired_auth_sessions'] = cursor.fetchone()[0] + + # Count expired refresh sessions + cursor.execute('SELECT COUNT(*) FROM refresh_session WHERE expire_time <= ?', (datetime.now(),)) + stats['expired_refresh_sessions'] = cursor.fetchone()[0] + + return stats + + except Exception as e: + self.logger.exception(f"Failed to get database statistics: {e}") + return {} + +# Global database instance +_global_db: Optional[Database] = None + +def get_database(db_path: str = "./objs/srs_database.db") -> Database: + """ + Get the global database instance + + Args: + db_path: Path to SQLite database file + + Returns: + Database instance + """ + global _global_db + + if _global_db is None: + # run_comprehensive_tests() + _global_db = Database(db_path) + + return _global_db + +if __name__ == "__main__": + # Comprehensive test suite for database functionality + import argparse + import uuid + from datetime import datetime, timedelta + from srs_logger import init_logger + + logger = init_logger() + + # Main execution + parser = argparse.ArgumentParser(description="Database Module Test") + parser.add_argument("--db-path", default="./objs/srs_database.db", help="Database file path") + args = parser.parse_args() + + db = get_database(args.db_path) + + # Show database statistics + stats = db.get_database_stats() + logger.info(f"Database statistics: {stats}") + + # Clean up expired sessions + db.cleanup_expired_sessions() \ No newline at end of file diff --git a/trunk/python/http_addon.py b/trunk/python/http_addon.py deleted file mode 100644 index b57c18818..000000000 --- a/trunk/python/http_addon.py +++ /dev/null @@ -1,142 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -""" -SRS Python Addon - Simple HTTP Server -This addon demonstrates how to create a simple HTTP server as an SRS addon. -""" - -import sys -import signal -import time -import logging -import argparse -import threading -from http.server import HTTPServer, BaseHTTPRequestHandler - -# Set up logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger('srs_http_addon') - -class SRSAddonHandler(BaseHTTPRequestHandler): - """Simple HTTP request handler for SRS addon.""" - - def do_GET(self): - """Handle GET requests.""" - if self.path == '/': - self.send_response(200) - self.send_header('Content-type', 'text/html') - self.end_headers() - response = ''' - <html> - <body> - <h1>SRS Python Addon - HTTP Server</h1> - <p>This is a simple HTTP server running as an SRS addon.</p> - <p>Status: Running</p> - <p>Time: {}</p> - </body> - </html> - '''.format(time.strftime('%Y-%m-%d %H:%M:%S')) - self.wfile.write(response.encode()) - elif self.path == '/status': - self.send_response(200) - self.send_header('Content-type', 'application/json') - self.end_headers() - response = '{"status": "running", "time": "%s"}' % time.strftime('%Y-%m-%d %H:%M:%S') - self.wfile.write(response.encode()) - else: - self.send_response(404) - self.end_headers() - - def log_message(self, format, *args): - """Override to use our logger.""" - logger.info("%s - %s" % (self.address_string(), format % args)) - -class SRSHTTPAddon: - """SRS HTTP Addon main class.""" - - def __init__(self, port=8888): - self.port = port - self.server = None - self.running = False - self.server_thread = None - - def start(self): - """Start the HTTP server.""" - try: - self.server = HTTPServer(('', self.port), SRSAddonHandler) - self.running = True - - # Start server in a separate thread - self.server_thread = threading.Thread(target=self._run_server) - self.server_thread.daemon = True - self.server_thread.start() - - logger.info(f"SRS HTTP addon started on port {self.port}") - return True - except Exception as e: - logger.error(f"Failed to start HTTP addon: {e}") - return False - - def stop(self): - """Stop the HTTP server.""" - if self.server and self.running: - self.running = False - self.server.shutdown() - self.server.server_close() - if self.server_thread: - self.server_thread.join(timeout=5) - logger.info("SRS HTTP addon stopped") - - def _run_server(self): - """Run the HTTP server.""" - try: - self.server.serve_forever() - except Exception as e: - if self.running: - logger.error(f"HTTP server error: {e}") - -def signal_handler(signum, frame): - """Handle termination signals.""" - logger.info(f"Received signal {signum}, shutting down...") - if 'addon' in globals(): - addon.stop() - sys.exit(0) - -def main(): - """Main function.""" - parser = argparse.ArgumentParser(description='SRS HTTP Addon') - parser.add_argument('--port', type=int, default=8888, help='HTTP server port') - parser.add_argument('--verbose', action='store_true', help='Enable verbose logging') - - args = parser.parse_args() - - if args.verbose: - logger.setLevel(logging.DEBUG) - - # Set up signal handlers - signal.signal(signal.SIGTERM, signal_handler) - signal.signal(signal.SIGINT, signal_handler) - - # Create and start the addon - global addon - addon = SRSHTTPAddon(args.port) - - if addon.start(): - logger.info("SRS HTTP addon is running, press Ctrl+C to stop") - try: - while addon.running: - time.sleep(1) - except KeyboardInterrupt: - logger.info("Interrupted by user") - finally: - addon.stop() - else: - logger.error("Failed to start SRS HTTP addon") - sys.exit(1) - -if __name__ == '__main__': - main() \ No newline at end of file diff --git a/trunk/python/httpbackend_server.py b/trunk/python/httpbackend_server.py new file mode 100644 index 000000000..27849e318 --- /dev/null +++ b/trunk/python/httpbackend_server.py @@ -0,0 +1,845 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +FastAPI Authentication API Server +Provides secure authentication endpoints with token-based authentication +""" +from datetime import datetime +from typing import Optional, Dict, Any, List +from contextlib import asynccontextmanager + +from fastapi import FastAPI, HTTPException, Depends, Cookie, Response, Query +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +from fastapi.middleware.cors import CORSMiddleware +from fastapi.staticfiles import StaticFiles +from pydantic import BaseModel, EmailStr, Field, field_validator +import uvicorn + +# Import our database and logger +from database import get_database, Database +from srs_logger import get_logger + +# Import Function Class +from security_module import AuthManager +from system_stats_module import SystemStatsManager +from avatar_module import AvatarGenerator + +# ============================================================================ +# Pydantic Models for Request/Response +# ============================================================================ + +class RegisterRequest(BaseModel): + """用户注册请求模型""" + username: str = Field(..., min_length=1, max_length=50, + description="用户名,只能包含字母、数字、下划线和连字符") + name: str = Field(..., min_length=1, max_length=100, description="用户真实姓名") + email: Optional[EmailStr] = Field(None, description="用户邮箱(可选)") + password: str = Field(..., min_length=6, max_length=128, description="用户密码") + + @field_validator('username') + def validate_username(cls, v): + if not v.replace('_', '').replace('-', '').isalnum(): + raise ValueError('用户名只能包含字母、数字、下划线和连字符') + return v + +class LoginRequest(BaseModel): + """用户登录请求模型""" + username_or_email: str = Field(..., min_length=1, description="用户名或邮箱") + password: str = Field(..., min_length=1, description="用户密码") + remember_me: Optional[bool] = Field(True, description="是否记住登录状态(默认为False)") + +class RefreshRequest(BaseModel): + """令牌刷新请求模型""" + auth_key_session_id: Optional[str] = Field(None, description="认证令牌会话ID(可选,用于删除旧的认证会话)") + +class UpdatePasswordRequest(BaseModel): + """更新密码请求模型""" + original_password: str = Field(..., min_length=1, description="原密码") + password: str = Field(..., min_length=6, max_length=128, description="新密码") + +class UserIDRequest(BaseModel): + """用户ID请求模型""" + user_id: str = Field(..., min_length=1, description="传入的用户ID") + +class AuthResponse(BaseModel): + """认证响应模型""" + success: bool = Field(..., description="操作是否成功") + message: str = Field(..., description="响应消息") + auth_key: Optional[str] = Field(None, description="认证令牌") + auth_key_session_id: Optional[str] = Field(None, description="认证令牌会话ID") + refresh_key: Optional[str] = Field(None, description="刷新令牌") + refresh_key_session_id: Optional[str] = Field(None, description="刷新令牌会话ID") + +class SimpleResponse(BaseModel): + """简单响应模型""" + success: bool = Field(..., description="操作是否成功") + message: str = Field(..., description="响应消息") + +class SystemStatsResponse(BaseModel): + """系统统计响应模型""" + system_stats: List[Dict[str, Any]] = Field(..., description="系统统计信息列表,每个元素包含时间戳和统计数据") + +# ============================================================================ +# FastAPI Application Setup +# ============================================================================ + +@asynccontextmanager +async def lifespan(app: FastAPI): + """应用程序生命周期管理""" + logger = get_logger() + logger.info("Starting FastAPI Authentication Server...") + + # 初始化数据库 + db = get_database() + logger.info("Database initialized") + + # 创建认证管理器 + auth_manager = AuthManager(db) + system_stats_manager = SystemStatsManager(db) + avatar_generator = AvatarGenerator() + + app.state.auth_manager = auth_manager + logger.info("Auth manager initialized") + system_stats_manager.start_polling() # 启动系统统计轮询 + app.state.system_stats_manager = system_stats_manager + logger.info("System stats manager initialized") + app.state.avatar_generator = avatar_generator + logger.info("Avatar generator initialized") + app.state.db = db + logger.info("Database connection established") + + yield + + logger.info("Shutting down FastAPI Authentication Server...") + +# 创建FastAPI应用 +app = FastAPI( + title="认证API服务器", + description="基于FastAPI的安全认证系统,提供用户注册、登录、令牌刷新和登出功能", + version="1.0.0", + docs_url="/api/docs", + redoc_url="/api/redoc", + lifespan=lifespan +) + +# 添加CORS中间件 +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], # 在生产环境中应该限制允许的源 + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# 获取认证管理器 +def get_auth_manager() -> AuthManager: + return app.state.auth_manager + +def get_system_stats_manager() -> SystemStatsManager: + return app.state.system_stats_manager + +def get_avatar_generator() -> AvatarGenerator: + return app.state.avatar_generator + +def get_db() -> Database: + return app.state.db + +# HTTP Bearer认证方案 +security = HTTPBearer() + +# 认证依赖项 +async def verify_auth_token( + credentials: HTTPAuthorizationCredentials = Depends(security), + auth_manager: AuthManager = Depends(get_auth_manager) +) -> Dict[str, Any]: + """验证Bearer令牌并返回认证会话信息""" + try: + # Bearer token格式: "auth_key:auth_key_session_id" + token_parts = credentials.credentials.split(':') + if len(token_parts) != 2: + raise HTTPException( + status_code=401, + detail="无效的Bearer令牌格式,应为 'auth_key:auth_key_session_id'" + ) + + auth_key, auth_key_session_id = token_parts + + # 获取认证会话 + auth_manager.db.cleanup_expired_sessions() # 清理过期会话 + auth_session = auth_manager.db.get_auth_session(auth_key_session_id) + if not auth_session: + raise HTTPException( + status_code=401, + detail="认证会话无效或已过期" + ) + + # 验证认证令牌 + if not auth_manager.verify_token(auth_key, auth_session['salt'], auth_session['hashed_authkey']): + raise HTTPException( + status_code=401, + detail="认证令牌无效" + ) + + # 获取用户信息以获取用户组 + user_data = auth_manager.db.get_user_by_id(auth_session['user_id']) + auth_session['user_group'] = user_data.get('user_group', ['user']) + + # 更新用户最后活跃时间 + auth_manager.db.update_user_last_active(auth_session['user_id']) + + return auth_session + + except HTTPException: + raise + except Exception as e: + auth_manager.logger.exception(f"Token verification error: {e}") + raise HTTPException( + status_code=401, + detail="令牌验证失败" + ) + +# ============================================================================ +# Cookie Configuration +# ============================================================================ + +def set_auth_cookies(response: Response, tokens: Dict[str, str]) -> None: + """设置安全的认证cookie""" + response.set_cookie( + key="refresh_key", + value=tokens['refresh_key'], + httponly=True, + secure=False, + samesite="strict", + max_age=604800 # 7天 + ) + + response.set_cookie( + key="refresh_key_session_id", + value=tokens['refresh_key_session_id'], + httponly=True, + secure=False, + samesite="strict", + max_age=604800 # 7天 + ) + +def clear_auth_cookies(response: Response) -> None: + """清除认证cookie""" + response.delete_cookie(key="refresh_key", httponly=True, secure=True, samesite="strict") + response.delete_cookie(key="refresh_key_session_id", httponly=True, secure=True, samesite="strict") + +# ============================================================================ +# Authentication API Endpoints +# ============================================================================ + +@app.post("/api/register", response_model=SimpleResponse, + tags=["Authentication API Endpoints"], summary="用户注册", description="注册新用户账户") +async def register(request: RegisterRequest, auth_manager: AuthManager = Depends(get_auth_manager)): + """ + 用户注册端点 + + - **username**: 用户名(必需,只能包含字母、数字、下划线和连字符) + - **name**: 用户真实姓名(必需) + - **email**: 用户邮箱(可选) + - **password**: 用户密码(必需,最少6个字符) + + 返回注册是否成功的布尔值 + """ + try: + success = auth_manager.register_user( + username=request.username, + name=request.name, + email=request.email, + password=request.password + ) + + if success: + return SimpleResponse(success=True, message="用户注册成功") + else: + raise HTTPException( + status_code=400, + detail="用户注册失败,用户名或邮箱可能已存在" + ) + + except HTTPException: + raise + except Exception as e: + auth_manager.logger.exception(f"Register endpoint error: {e}") + raise HTTPException(status_code=500, detail="服务器内部错误") + + +@app.post("/api/login", response_model=AuthResponse, + tags=["Authentication API Endpoints"], summary="用户登录", description="用户登录并获取认证令牌") +async def login(request: LoginRequest, response: Response, + auth_manager: AuthManager = Depends(get_auth_manager)): + """ + 用户登录端点 + + - **username_or_email**: 用户名或邮箱(必需) + - **password**: 用户密码(必需) + + 成功登录后返回认证令牌和刷新令牌,并设置安全cookie + """ + try: + # 认证用户 + user_data = auth_manager.authenticate_user( + username_or_email=request.username_or_email, + password=request.password + ) + + if not user_data: + raise HTTPException( + status_code=401, + detail="用户名/邮箱或密码错误" + ) + + # 创建认证令牌 + tokens = auth_manager.create_auth_tokens(user_data['user_id']) + if not tokens: + raise HTTPException( + status_code=500, + detail="创建认证令牌失败" + ) + + # 设置安全cookie + if request.remember_me: + set_auth_cookies(response, tokens) + + return AuthResponse( + success=True, + message="登录成功", + auth_key=tokens['auth_key'], + auth_key_session_id=tokens['auth_key_session_id'], + refresh_key=tokens['refresh_key'], + refresh_key_session_id=tokens['refresh_key_session_id'] + ) + + except HTTPException: + raise + except Exception as e: + auth_manager.logger.exception(f"Login endpoint error: {e}") + raise HTTPException(status_code=500, detail="服务器内部错误") + + +@app.post("/api/refresh", response_model=AuthResponse, + tags=["Authentication API Endpoints"], summary="刷新令牌", description="使用刷新令牌获取新的认证令牌") +async def refresh_tokens(request: RefreshRequest, response: Response, + refresh_key: Optional[str] = Cookie(None), + refresh_key_session_id: Optional[str] = Cookie(None), + auth_manager: AuthManager = Depends(get_auth_manager)): + """ + 令牌刷新端点 + + - **auth_key_session_id**: 认证令牌会话ID(可选,用于删除旧的认证会话) + + refresh_key和refresh_key_session_id从httponly Cookie中自动获取 + 验证刷新令牌后返回新的认证令牌和刷新令牌,并更新cookie + """ + try: + # 检查Cookie中的刷新令牌信息 + if not refresh_key or not refresh_key_session_id: + raise HTTPException( + status_code=401, + detail="未找到刷新令牌,请重新登录" + ) + + # 刷新令牌 + new_tokens = auth_manager.refresh_tokens( + refresh_key_session_id=refresh_key_session_id, + refresh_key=refresh_key, + auth_key_session_id=request.auth_key_session_id + ) + + if not new_tokens: + # 清除cookie + clear_auth_cookies(response) + raise HTTPException( + status_code=401, + detail="刷新令牌无效或已过期" + ) + + # 设置新的安全cookie + set_auth_cookies(response, new_tokens) + + return AuthResponse( + success=True, + message="令牌刷新成功", + auth_key=new_tokens['auth_key'], + auth_key_session_id=new_tokens['auth_key_session_id'], + refresh_key=new_tokens['refresh_key'], + refresh_key_session_id=new_tokens['refresh_key_session_id'] + ) + + except HTTPException: + raise + except Exception as e: + auth_manager.logger.exception(f"Refresh endpoint error: {e}") + raise HTTPException(status_code=500, detail="服务器内部错误") + + +@app.post("/api/logout", response_model=SimpleResponse, + tags=["Authentication API Endpoints"], summary="登出", description="登出当前会话") +async def logout(response: Response, + refresh_key_session_id: Optional[str] = Cookie(None), + auth_session: Dict[str, Any] = Depends(verify_auth_token), + auth_manager: AuthManager = Depends(get_auth_manager)): + """ + 登出端点 + + 需要在Authorization头中提供Bearer令牌,格式为: Bearer auth_key:auth_key_session_id + refresh_key_session_id从httponly Cookie中自动获取 + + 登出指定的会话并清除相关cookie + """ + try: + # 检查Cookie中的刷新令牌会话ID + if not refresh_key_session_id: + # 清除cookie(如果有的话) + clear_auth_cookies(response) + raise HTTPException( + status_code=401, + detail="未找到刷新令牌会话ID" + ) + + success = auth_manager.logout_session( + refresh_key_session_id=refresh_key_session_id, + auth_key_session_id=auth_session['session_id'] + ) + + # 清除cookie + clear_auth_cookies(response) + + if success: + return SimpleResponse(success=True, message="登出成功") + else: + return SimpleResponse(success=False, message="会话不存在或已过期") + + except Exception as e: + auth_manager.logger.exception(f"Logout endpoint error: {e}") + raise HTTPException(status_code=500, detail="服务器内部错误") + + +@app.post("/api/logout-all", response_model=SimpleResponse, + tags=["Authentication API Endpoints"], summary="登出所有会话", description="登出用户的所有会话") +async def logout_all(response: Response, + auth_session: Dict[str, Any] = Depends(verify_auth_token), + auth_manager: AuthManager = Depends(get_auth_manager)): + """ + 登出所有会话端点 + + 需要在Authorization头中提供Bearer令牌,格式为: Bearer auth_key:auth_key_session_id + + 登出用户的所有会话并清除相关cookie + """ + try: + success = auth_manager.logout_all_sessions(auth_session['user_id']) + + # 清除cookie + clear_auth_cookies(response) + + if success: + return SimpleResponse(success=True, message="所有会话已登出") + else: + return SimpleResponse(success=False, message="会话不存在或已过期") + + except Exception as e: + auth_manager.logger.exception(f"Logout all endpoint error: {e}") + raise HTTPException(status_code=500, detail="服务器内部错误") + + +@app.get("/api/profile", response_model=Dict[str, Any], + tags=["Authentication API Endpoints"], summary="获取用户信息", description="获取当前认证用户的基本信息") +async def get_profile( + auth_session: Dict[str, Any] = Depends(verify_auth_token), + auth_manager: AuthManager = Depends(get_auth_manager) +): + """ + 获取用户信息端点 + + 需要在Authorization头中提供Bearer令牌,格式为: Bearer auth_key:auth_key_session_id + + 返回当前认证用户的基本信息 + """ + try: + user_data = auth_manager.db.get_user_by_id(auth_session['user_id']) + if not user_data: + raise HTTPException(status_code=404, detail="用户未找到") + + return { + "user_id": user_data['user_id'], + "username": user_data['username'], + "name": user_data['name'], + "email": user_data.get('email', None), + 'user_group': user_data.get('user_group', ['user']), + "is_activated": user_data.get('is_activated', False), + "last_active": user_data.get('last_active', None) + } + + except HTTPException: + raise + except Exception as e: + auth_manager.logger.exception(f"Get profile endpoint error: {e}") + raise HTTPException(status_code=500, detail="服务器内部错误") + + +@app.put("/api/update-password", response_model=SimpleResponse, + tags=["Authentication API Endpoints"], summary="更新密码", description="更新当前用户的密码") +async def update_password( + request: UpdatePasswordRequest, + response: Response, + auth_session: Dict[str, Any] = Depends(verify_auth_token), + auth_manager: AuthManager = Depends(get_auth_manager) +): + """ + 更新密码端点 + + 需要在Authorization头中提供Bearer令牌,格式为: Bearer auth_key:auth_key_session_id + + - **original_password**: 原密码(必需) + - **password**: 新密码(必需,最少6个字符) + + 验证原密码后更新为新密码,并自动登出所有会话 + """ + try: + user_id = auth_session['user_id'] + + # 调用AuthManager处理密码更新逻辑 + result = auth_manager.update_user_password_with_verification( + user_id=user_id, + original_password=request.original_password, + new_password=request.password + ) + + if result["success"]: + # 如果密码更新成功且包含logout_all标志,清除cookie + if result.get("logout_all", False): + clear_auth_cookies(response) + + return SimpleResponse(success=True, message=result["message"]) + else: + # 根据错误类型返回相应的HTTP状态码 + if "用户未找到" in result["message"]: + raise HTTPException(status_code=404, detail=result["message"]) + elif "原密码错误" in result["message"]: + raise HTTPException(status_code=401, detail=result["message"]) + else: + raise HTTPException(status_code=500, detail=result["message"]) + + except HTTPException: + raise + except Exception as e: + auth_manager.logger.exception(f"Update password endpoint error: {e}") + raise HTTPException(status_code=500, detail="服务器内部错误") + +@app.delete("/api/delete-user", response_model=SimpleResponse, + tags=["Authentication API Endpoints"], summary="删除自己的账户", description="删除当前登录用户的账户") +async def delete_own_account( + response: Response, + auth_session: Dict[str, Any] = Depends(verify_auth_token), + auth_manager: AuthManager = Depends(get_auth_manager) +): + """ + 删除自己账户端点 + + 需要在Authorization头中提供Bearer令牌,格式为: Bearer auth_key:auth_key_session_id + 只能删除当前登录用户本人账户,删除成功后将清除所有认证cookie + """ + try: + user_id = auth_session['user_id'] + # 调用AuthManager处理自己账户删除逻辑 + result = auth_manager.delete_own_account(user_id) + + if result["success"]: + # 删除成功后清除认证cookie + clear_auth_cookies(response) + return SimpleResponse(success=True, message=result["message"]) + else: + # 根据错误类型返回相应的HTTP状态码 + if "未找到" in result["message"]: + raise HTTPException(status_code=404, detail=result["message"]) + else: + raise HTTPException(status_code=500, detail=result["message"]) + + except HTTPException: + raise + except Exception as e: + auth_manager.logger.exception(f"Delete own account endpoint error: {e}") + raise HTTPException(status_code=500, detail="服务器内部错误") + + +@app.delete("/api/admin/delete-user", response_model=SimpleResponse, + tags=["Admin API Endpoints"], summary="管理员删除用户", description="管理员删除指定用户账户") +async def admin_delete_user( + request: UserIDRequest, + auth_session: Dict[str, Any] = Depends(verify_auth_token), + auth_manager: AuthManager = Depends(get_auth_manager) +): + """ + 管理员删除用户端点 + + 需要在Authorization头中提供Bearer令牌,格式为: Bearer auth_key:auth_key_session_id + 需要管理员权限(manager或admin用户组) + - **user_id**: 要删除的用户ID(必需) + """ + try: + admin_user_id = auth_session['user_id'] + target_user_id = request.user_id + + # 调用AuthManager处理管理员删除用户逻辑 + result = auth_manager.admin_delete_user( + admin_user_id=admin_user_id, + admin_user_groups=auth_session.get('user_group', ["user"]), + target_user_id=target_user_id + ) + + if result["success"]: + return SimpleResponse(success=True, message=result["message"]) + else: + # 根据错误类型返回相应的HTTP状态码 + if "未找到" in result["message"]: + raise HTTPException(status_code=404, detail=result["message"]) + elif "权限不足" in result["message"]: + raise HTTPException(status_code=403, detail=result["message"]) + else: + raise HTTPException(status_code=400, detail=result["message"]) + + except HTTPException: + raise + except Exception as e: + auth_manager.logger.exception(f"Admin delete user endpoint error: {e}") + raise HTTPException(status_code=500, detail="服务器内部错误") + +# ============================================================================ +# Avatar Generation Endpoint +# ============================================================================ + +@app.get("/avatar/{variant}", + tags=["Avatar API Endpoints"], summary="生成头像", description="根据变体生成头像") +async def generate_avatar_variant_only( + variant: str, + auth_session: Dict[str, Any] = Depends(verify_auth_token), + name: Optional[str] = Query(default="John Doe", description="Name to generate avatar for"), + colors: Optional[str] = Query(default=None, description="Comma-separated hex colors"), + size: int = Query(default=80, description="Avatar size"), + square: bool = Query(default=False, description="Square avatar"), + title: bool = Query(default=False, description="Include title element"), + + avatar_generator: AvatarGenerator = Depends(get_avatar_generator), +): + """Generate avatar with variant only""" + try: + if variant not in avatar_generator.VALID_VARIANTS: + raise HTTPException(status_code=400, detail=f"Invalid variant. Must be one of: {', '.join(avatar_generator.VALID_VARIANTS)}") + + # Validate and sanitize inputs + if not name or not isinstance(name, str): + name = auth_session['user_id'] + + if not isinstance(size, int) or size < 10 or size > 1000: + size = avatar_generator.DEFAULT_SIZE + + color_palette = avatar_generator.normalize_colors(colors) + generator = avatar_generator.AVATAR_GENERATORS[variant] + svg_content = generator(name, color_palette, size, square, title) + + return Response( + content=svg_content, + media_type="image/svg+xml", + headers={ + "Cache-Control": "s-max-age=1, stale-while-revalidate", + "Content-Type": "image/svg+xml" + } + ) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Avatar generation failed: {str(e)}") + +@app.get("/avatar/{variant}/{size}", + tags=["Avatar API Endpoints"], summary="生成头像", description="根据变体生成头像") +async def generate_avatar_variant_size( + variant: str, + size: int, + auth_session: Dict[str, Any] = Depends(verify_auth_token), + name: Optional[str] = Query(default="Clara Barton", description="Name to generate avatar for"), + colors: Optional[str] = Query(default=None, description="Comma-separated hex colors"), + square: bool = Query(default=False, description="Square avatar"), + title: bool = Query(default=False, description="Include title element"), + avatar_generator: AvatarGenerator = Depends(get_avatar_generator) +): + """Generate avatar with variant and size""" + try: + if variant not in avatar_generator.VALID_VARIANTS: + raise HTTPException(status_code=400, detail=f"Invalid variant. Must be one of: {', '.join(avatar_generator.VALID_VARIANTS)}") + + # Validate and sanitize inputs + if not name or not isinstance(name, str): + name = auth_session['user_id'] + + if not isinstance(size, int) or size < 10 or size > 1000: + size = avatar_generator.DEFAULT_SIZE + + color_palette = avatar_generator.normalize_colors(colors) + generator = avatar_generator.AVATAR_GENERATORS[variant] + svg_content = generator(name, color_palette, size, square, title) + + return Response( + content=svg_content, + media_type="image/svg+xml", + headers={ + "Cache-Control": "s-max-age=1, stale-while-revalidate", + "Content-Type": "image/svg+xml" + } + ) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Avatar generation failed: {str(e)}") + +@app.get("/avatar/{variant}/{size}/{name}", + tags=["Avatar API Endpoints"], summary="生成头像", description="根据变体生成头像") +async def generate_avatar_full_path( + variant: str, + size: int, + name: str, + colors: Optional[str] = Query(default=None, description="Comma-separated hex colors"), + square: bool = Query(default=False, description="Square avatar"), + title: bool = Query(default=False, description="Include title element"), + avatar_generator: AvatarGenerator = Depends(get_avatar_generator), + + auth_session: Dict[str, Any] = Depends(verify_auth_token), +): + """Generate avatar with variant, size, and name in path""" + try: + if variant not in avatar_generator.VALID_VARIANTS: + raise HTTPException(status_code=400, detail=f"Invalid variant. Must be one of: {', '.join(avatar_generator.VALID_VARIANTS)}") + + # Validate and sanitize inputs + if not name or not isinstance(name, str): + name = auth_session['user_id'] + + if not isinstance(size, int) or size < 10 or size > 1000: + size = avatar_generator.DEFAULT_SIZE + + color_palette = avatar_generator.normalize_colors(colors) + generator = avatar_generator.AVATAR_GENERATORS[variant] + svg_content = generator(name, color_palette, size, square, title) + + return Response( + content=svg_content, + media_type="image/svg+xml", + headers={ + "Cache-Control": "s-max-age=1, stale-while-revalidate", + "Content-Type": "image/svg+xml" + } + ) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Avatar generation failed: {str(e)}") + +# ============================================================================ +# System Statistics API Endpoints +# ============================================================================ +@app.get("/api/system-stats", response_model=SystemStatsResponse, + tags=["System Statistics API Endpoints"], summary="获取系统统计信息", description="获取服务器的系统统计信息") +async def get_system_stats( + auth_session: Dict[str, Any] = Depends(verify_auth_token), + time_delta: Optional[int] = Query(default=5, description="时间范围(秒),默认为5秒"), + time_interval: Optional[int] = Query(default=1, description="时间间隔(秒),默认为1秒"), + system_stats_manager: SystemStatsManager = Depends(get_system_stats_manager) +): + """ + 获取系统统计信息端点 + + 需要在Authorization头中提供Bearer令牌,格式为: Bearer auth_key:auth_key_session_id + - **time_delta**: 时间范围(秒),默认为5秒 + 返回服务器的系统统计信息 + """ + try: + stats = system_stats_manager.get_recent_stats( + auth_session.get('user_group', ['user']), + time_delta=time_delta, + time_interval=time_interval + ) + if stats["success"]: + return stats + else: + if "未找到" in stats["message"]: + raise HTTPException(status_code=404, detail=stats["message"]) + elif "权限不足" in stats["message"]: + raise HTTPException(status_code=403, detail=stats["message"]) + else: + raise HTTPException(status_code=500, detail=stats["message"]) + + except HTTPException: + raise + except Exception as e: + system_stats_manager.logger.exception(f"Get system stats endpoint error: {e}") + raise HTTPException(status_code=500, detail="服务器内部错误") + +# ============================================================================ +# Additional Utility API Endpoints +# ============================================================================ + +@app.get("/api/health", response_model=Dict[str, Any], + tags=["Additional Utility API Endpoints"], summary="健康检查", description="检查API服务器状态") +async def health_check(db: Database = Depends(get_db)): + """健康检查端点""" + try: + stats = db.get_database_stats() + return { + "status": "healthy", + "timestamp": datetime.now().isoformat(), + "database_stats": stats + } + except Exception as e: + raise HTTPException(status_code=503, detail=f"服务不可用: {str(e)}") + +@app.post("/api/clean-cookies", response_model=SimpleResponse, + tags=["Additional Utility API Endpoints"], summary="清除认证Cookie", description="清除认证相关的Cookie") +async def clean_cookies(response: Response): + """清除认证Cookie端点""" + try: + clear_auth_cookies(response) + return SimpleResponse(success=True, message="认证Cookie已清除") + except Exception as e: + raise HTTPException(status_code=500, detail=f"清除Cookie失败: {str(e)}") + +# ============================================================================ +# Static Files Configuration +# ============================================================================ +# 在所有API路由定义之后挂载静态文件,确保API路由有更高优先级 + +# Frontend路径配置 +frontend_path = r"./livestream_site" # 前端静态文件目录 +app.mount("/", StaticFiles(directory=frontend_path, html=True), name="frontend") + +# ============================================================================ +# Main Entry Point +# ============================================================================ + +if __name__ == "__main__": + logger = get_logger() + logger.info("Starting FastAPI Authentication Server on port 8000...") + + # 检查是否在PyInstaller打包环境中运行 + import sys + if getattr(sys, 'frozen', False): + # 在打包环境中,直接传递app对象 + uvicorn.run( + app, + host="127.0.0.1", + port=8000, + reload=False, + log_level="info" + ) + else: + # 在开发环境中,使用字符串导入(支持热重载) + uvicorn.run( + "httpbackend_server:app", + host="127.0.0.1", + port=8000, + reload=True, + reload_delay=5.0, + log_level="info" + ) \ No newline at end of file diff --git a/trunk/python/monitor.py b/trunk/python/monitor.py deleted file mode 100644 index e718f53ff..000000000 --- a/trunk/python/monitor.py +++ /dev/null @@ -1,80 +0,0 @@ -#!/usr/bin/env python3 -""" -SRS Python Monitor Script -This is an example Python script that runs alongside SRS server. -It demonstrates how Python processes can be managed by SRS. -""" - -import time -import signal -import sys -import argparse -import logging -from datetime import datetime - -class SRSMonitor: - def __init__(self, config_file=None, verbose=False): - self.running = True - self.config_file = config_file - self.verbose = verbose - - # Set up logging - level = logging.DEBUG if verbose else logging.INFO - logging.basicConfig( - level=level, - format='[%(asctime)s] [Python Monitor] %(levelname)s: %(message)s', - datefmt='%Y-%m-%d %H:%M:%S' - ) - self.logger = logging.getLogger(__name__) - - # Set up signal handlers - signal.signal(signal.SIGTERM, self.signal_handler) - signal.signal(signal.SIGINT, self.signal_handler) - - def signal_handler(self, signum, frame): - """Handle shutdown signals from SRS""" - self.logger.info(f"Received signal {signum}, shutting down gracefully...") - self.running = False - - def run(self): - """Main monitoring loop""" - self.logger.info("SRS Python Monitor started") - if self.config_file: - self.logger.info(f"Using config file: {self.config_file}") - - try: - while self.running: - # Simulate monitoring work - self.logger.debug(f"Monitor heartbeat at {datetime.now()}") - - # You can add your monitoring logic here: - # - Check stream status - # - Monitor server health - # - Send alerts - # - Log analytics - - time.sleep(5) # Check every 5 seconds - - except Exception as e: - self.logger.error(f"Error in monitor loop: {e}") - finally: - self.cleanup() - - def cleanup(self): - """Cleanup before shutdown""" - self.logger.info("Cleaning up Python Monitor...") - # Add any cleanup logic here - self.logger.info("Python Monitor stopped") - -def main(): - parser = argparse.ArgumentParser(description='SRS Python Monitor') - parser.add_argument('--config', help='Configuration file path') - parser.add_argument('--verbose', action='store_true', help='Enable verbose logging') - - args = parser.parse_args() - - monitor = SRSMonitor(args.config, args.verbose) - monitor.run() - -if __name__ == '__main__': - main() diff --git a/trunk/python/requirements.txt b/trunk/python/requirements.txt index 2438804f1..4bf721606 100644 --- a/trunk/python/requirements.txt +++ b/trunk/python/requirements.txt @@ -1,7 +1,70 @@ -# Python packages required for SRS Python addons -requests>=2.25.0 -psutil>=5.8.0 -pyyaml>=5.4.0 -websockets>=10.0 -aiohttp>=3.8.0 -numpy>=1.20.0 +aiofiles==23.2.0 +aiohappyeyeballs==2.6.1 +aiohttp==3.12.13 +aiosignal==1.3.2 +altgraph==0.17.4 +annotated-types==0.7.0 +anyio==3.7.1 +attrs==25.3.0 +bcrypt==4.3.0 +blinker==1.9.0 +certifi==2025.4.26 +cffi==1.17.1 +charset-normalizer==3.4.2 +click==8.2.1 +colorama==0.4.6 +cryptography==45.0.3 +dnspython==2.7.0 +ecdsa==0.19.1 +email-validator==2.1.0 +fastapi==0.104.1 +Flask==3.1.1 +Flask-SQLAlchemy==3.1.1 +frozenlist==1.7.0 +greenlet==3.2.3 +h11==0.16.0 +httpcore==1.0.9 +httptools==0.6.4 +httpx==0.25.2 +idna==3.10 +iniconfig==2.1.0 +itsdangerous==2.2.0 +Jinja2==3.1.2 +jwt==1.3.1 +MarkupSafe==3.0.2 +multidict==6.4.4 +packaging==25.0 +passlib==1.7.4 +pefile==2023.2.7 +pluggy==1.6.0 +propcache==0.3.2 +pyasn1==0.6.1 +pycparser==2.22 +pydantic==2.5.0 +pydantic_core==2.14.1 +pyinstaller==6.14.1 +pyinstaller-hooks-contrib==2025.5 +PyJWT==2.8.0 +pytest==7.4.3 +pytest-asyncio==0.21.1 +python-dotenv==1.0.0 +python-jose==3.3.0 +python-multipart==0.0.6 +pywin32-ctypes==0.2.3 +PyYAML==6.0.2 +requests==2.32.4 +rsa==4.9.1 +setuptools==78.1.1 +six==1.17.0 +sniffio==1.3.1 +SQLAlchemy==2.0.41 +starlette==0.27.0 +typing-inspection==0.4.1 +typing_extensions==4.14.0 +urllib3==2.4.0 +uvicorn==0.24.0 +watchfiles==1.0.5 +websockets==15.0.1 +Werkzeug==3.1.3 +wheel==0.45.1 +yarl==1.20.1 diff --git a/trunk/python/security_module.py b/trunk/python/security_module.py new file mode 100644 index 000000000..7f5005d6a --- /dev/null +++ b/trunk/python/security_module.py @@ -0,0 +1,317 @@ +import hashlib +import secrets +import hmac +from datetime import datetime, timedelta +from typing import Optional, Dict, Any + +from database import Database +from srs_logger import get_logger + +# ============================================================================ +# Security and Utility Functions +# ============================================================================ + +class AuthManager: + """认证管理器""" + + def __init__(self, db: Database): + self.db = db + self.logger = get_logger() + + # Token 配置 + self.AUTH_TOKEN_EXPIRE_HOURS = 1 # 认证令牌1小时过期 + self.REFRESH_TOKEN_EXPIRE_DAYS = 7 # 刷新令牌7天过期 + self.TOKEN_LENGTH = 128 # 令牌长度 + + def generate_secure_token(self) -> str: + """生成安全的随机令牌""" + return secrets.token_urlsafe(self.TOKEN_LENGTH) + + def generate_salt(self) -> str: + """生成盐值""" + return secrets.token_hex(16) + + def hash_password(self, password: str, salt: str) -> str: + """对密码进行加盐哈希""" + return hashlib.pbkdf2_hmac('sha256', password.encode(), salt.encode(), 100000).hex() + + def verify_password(self, password: str, salt: str, hashed_password: str) -> bool: + """验证密码""" + return hmac.compare_digest( + self.hash_password(password, salt), + hashed_password + ) + + def hash_token(self, token: str, salt: str) -> str: + """对令牌进行加盐哈希""" + return hashlib.pbkdf2_hmac('sha256', token.encode(), salt.encode(), 100000).hex() + + def verify_token(self, token: str, salt: str, hashed_token: str) -> bool: + """验证令牌""" + return hmac.compare_digest( + self.hash_token(token, salt), + hashed_token + ) + + def register_user(self, username: str, name: str, email: Optional[str], password: str) -> bool: + """注册新用户""" + try: + # 生成密码盐值和哈希 + salt = self.generate_salt() + hashed_password = self.hash_password(password, salt) + + # 创建用户 + user_data = self.db.create_user( + username=username, + name=name, + email=email, + hashed_password=hashed_password, + salt=salt, + user_group=["user"], + is_activated=True + ) + + if user_data: + self.logger.info(f"User registered successfully: {username}") + return True + else: + self.logger.warn(f"Failed to register user: {username}") + return False + + except Exception as e: + self.logger.exception(f"Error during user registration: {e}") + return False + + def authenticate_user(self, username_or_email: str, password: str) -> Optional[Dict[str, Any]]: + """认证用户并返回用户数据""" + try: + # 获取用户信息 + user_data = self.db.get_user(username_or_email) + if not user_data: + self.logger.warn(f"User not found: {username_or_email}") + return None + + # 检查用户是否激活 + if not user_data.get('is_activated', False): + self.logger.warn(f"User account deactivated: {username_or_email}") + return None + + # 验证密码 + if self.verify_password(password, user_data['salt'], user_data['hashed_password']): + self.logger.info(f"User authenticated successfully: {username_or_email}") + return user_data + else: + self.logger.warn(f"Invalid password for user: {username_or_email}") + return None + + except Exception as e: + self.logger.exception(f"Error during user authentication: {e}") + return None + + def create_auth_tokens(self, user_id: str) -> Optional[Dict[str, Any]]: + """为用户创建认证和刷新令牌""" + try: + # 清理过期会话 + self.db.cleanup_expired_sessions() + + # 生成认证令牌 + auth_key = self.generate_secure_token() + auth_salt = self.generate_salt() + hashed_auth_key = self.hash_token(auth_key, auth_salt) + auth_expire_time = datetime.now() + timedelta(hours=self.AUTH_TOKEN_EXPIRE_HOURS) + + # 生成刷新令牌 + refresh_key = self.generate_secure_token() + refresh_salt = self.generate_salt() + hashed_refresh_key = self.hash_token(refresh_key, refresh_salt) + refresh_expire_time = datetime.now() + timedelta(days=self.REFRESH_TOKEN_EXPIRE_DAYS) + + # 创建认证会话 + auth_session = self.db.create_auth_session( + user_id=user_id, + hashed_authkey=hashed_auth_key, + salt=auth_salt, + expire_time=auth_expire_time + ) + + if not auth_session: + self.logger.error(f"Failed to create auth session for user: {user_id}") + return None + + # 创建刷新会话 + refresh_session = self.db.create_refresh_session( + user_id=user_id, + hashed_refreshkey=hashed_refresh_key, + salt=refresh_salt, + expire_time=refresh_expire_time + ) + + if not refresh_session: + self.logger.error(f"Failed to create refresh session for user: {user_id}") + # 清理已创建的认证会话 + self.db.delete_auth_session(auth_session['session_id']) + return None + + # 更新用户最后活跃时间 + self.db.update_user_last_active(user_id) + + return { + 'auth_key': auth_key, + 'auth_key_session_id': auth_session['session_id'], + 'refresh_key': refresh_key, + 'refresh_key_session_id': refresh_session['session_id'] + } + + except Exception as e: + self.logger.exception(f"Error creating auth tokens for user: {user_id}") + return None + + def refresh_tokens(self, refresh_key_session_id: str, refresh_key: str, + auth_key_session_id: Optional[str] = None) -> Optional[Dict[str, Any]]: + """刷新认证令牌""" + try: + # 清理过期会话 + self.db.cleanup_expired_sessions() + + # 验证刷新会话 + refresh_session = self.db.get_refresh_session(refresh_key_session_id) + if not refresh_session: + self.logger.warn(f"Invalid or expired refresh session: {refresh_key_session_id}") + return None + + # 验证刷新令牌 + if not self.verify_token(refresh_key, refresh_session['salt'], refresh_session['hashed_refreshkey']): + self.logger.warn(f"Invalid refresh token for session: {refresh_key_session_id}") + return None + + user_id = refresh_session['user_id'] + + # 删除旧的认证会话(如果提供了session_id) + if auth_key_session_id: + self.db.delete_auth_session(auth_key_session_id) + + # 删除旧的刷新会话 + self.db.delete_refresh_session(refresh_key_session_id) + + # 创建新的令牌 + new_tokens = self.create_auth_tokens(user_id) + if new_tokens: + self.logger.info(f"Tokens refreshed successfully for user: {user_id}") + return new_tokens + else: + self.logger.error(f"Failed to create new tokens during refresh for user: {user_id}") + return None + + except Exception as e: + self.logger.exception(f"Error during token refresh: {e}") + return None + + def logout_session(self, refresh_key_session_id: str, auth_key_session_id: str = None) -> bool: + """登出单个会话,通过删除指定的refresh session和对应的auth session""" + try: + # 删除指定的刷新会话 + refresh_success = self.db.delete_refresh_session(refresh_key_session_id) + if refresh_success: + self.logger.info(f"Refresh session logged out successfully: {refresh_key_session_id}") + + auth_success = self.db.delete_auth_session(auth_key_session_id) + if auth_success: + self.logger.info(f"Auth session logged out successfully: {auth_key_session_id}") + else: + self.logger.warn(f"Failed to delete auth session: {auth_key_session_id}") + + return refresh_success and auth_success + + except Exception as e: + self.logger.exception(f"Error during session logout: {e}") + return False + + def logout_all_sessions(self, user_id: str) -> bool: + """登出用户的所有会话,通过refresh session确定用户""" + try: + # 删除用户的所有会话 + success = self.db.delete_user_sessions(user_id) + if success: + self.logger.info(f"All sessions logged out successfully for user: {user_id}") + + return success + + except Exception as e: + self.logger.exception(f"Error during logout all sessions: {e}") + return False + + def update_user_password_with_verification(self, user_id: str, original_password: str, new_password: str) -> Dict[str, Any]: + """验证原密码后更新用户密码,并登出所有会话""" + try: + # 获取用户数据 + user_data = self.db.get_user_by_id(user_id) + if not user_data: + return {"success": False, "message": "用户未找到"} + + # 验证原密码 + if not self.verify_password(original_password, user_data['salt'], user_data['hashed_password']): + return {"success": False, "message": "原密码错误"} + + # 生成新的盐值和密码哈希 + new_salt = self.generate_salt() + new_hashed_password = self.hash_password(new_password, new_salt) + + # 更新密码 + updated_user = self.db.update_user_password(user_id, new_hashed_password, new_salt) + if not updated_user: + return {"success": False, "message": "密码更新失败"} + + # 密码更新成功后,登出用户的所有会话(强制重新登录) + logout_success = self.db.delete_user_sessions(user_id) + if logout_success: + self.logger.info(f"All sessions logged out after password update for user: {user_data['username']}") + else: + self.logger.warn(f"Failed to logout all sessions after password update for user: {user_data['username']}") + + self.logger.info(f"Password updated successfully for user: {user_data['username']}") + return {"success": True, "message": "密码更新成功,请重新登录", "logout_all": True} + + except Exception as e: + self.logger.exception(f"Error updating user password: {e}") + return {"success": False, "message": "服务器内部错误"} + + def delete_own_account(self, user_id: str) -> Dict[str, Any]: + """删除用户自己的账户""" + try: + success = self.db.delete_user(user_id) + if not success: + return {"success": False, "message": "账户删除失败"} + + self.logger.info(f"User-ID: '{user_id}' deleted his/her own account") + return {"success": True, "message": "您的账户已成功删除"} + + except Exception as e: + self.logger.exception(f"Error deleting own account: {e}") + return {"success": False, "message": "服务器内部错误"} + + def admin_delete_user(self, admin_user_id: str, admin_user_groups: list, target_user_id: str, ) -> Dict[str, Any]: + """管理员删除指定用户账户""" + try: + # 检查管理员权限 + if not any(group in admin_user_groups for group in ['manager', 'admin']): + return { + "success": False, + "message": "权限不足,只有管理员可以删除其他用户账户" + } + + # 获取目标用户数据 + target_user = self.db.get_user_by_id(target_user_id) + if not target_user: + return {"success": False, "message": "目标删除的用户未找到"} + + # 执行删除操作 + success = self.db.delete_user(target_user_id) + if not success: + return {"success": False, "message": "用户删除失败"} + + self.logger.info(f"User-ID '{target_user_id}' deleted by admin-ID: '{admin_user_id}'") + return {"success": True, "message": f"用户 '{target_user_id}' 已成功删除"} + + except Exception as e: + self.logger.exception(f"Error in admin delete user: {e}") + return {"success": False, "message": "服务器内部错误"} \ No newline at end of file diff --git a/trunk/python/srs_logger.py b/trunk/python/srs_logger.py new file mode 100644 index 000000000..0a780cbcb --- /dev/null +++ b/trunk/python/srs_logger.py @@ -0,0 +1,405 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +SRS Python Logger Module +Provides synchronized logging with SRS main process. +""" + +import logging +import os +import sys +import threading +import time +import re +from datetime import datetime +from typing import Optional, Dict, Any +import argparse + +class SRSLogFormatter(logging.Formatter): + """ + Custom formatter that matches SRS log format with color support: + [2025-05-30 21:54:18.835][WARN][3448][python_analytics] message + """ + + # ANSI color codes + COLORS = { + 'RESET': '\033[0m', + 'RED': '\033[31m', + 'YELLOW': '\033[33m', + 'WHITE': '\033[37m', + 'CYAN': '\033[36m', + 'GREEN': '\033[32m' + } + + def __init__(self, use_colors=True): + super().__init__() + self.pid = os.getpid() + self.session_id = self._generate_session_id() + self.use_colors = use_colors and self._supports_color() + + def _supports_color(self) -> bool: + """Check if the terminal supports color output""" + import sys + + # Check if we're in a terminal + if not hasattr(sys.stdout, 'isatty') or not sys.stdout.isatty(): + return False + + # Check for Windows terminal support + if os.name == 'nt': + try: + import colorama + colorama.init() + return True + except ImportError: + # Try to enable ANSI escape sequence support on Windows + try: + import ctypes + kernel32 = ctypes.windll.kernel32 + kernel32.SetConsoleMode(kernel32.GetStdHandle(-11), 7) + return True + except: + return False + + # Unix-like systems usually support colors + return True + + def _generate_session_id(self) -> str: + """Generate a meaningful session ID based on the calling script""" + import inspect + + # Try to get the main script name + main_module = sys.modules.get('__main__') + if main_module and hasattr(main_module, '__file__'): + script_path = main_module.__file__ + if script_path: + script_name = os.path.splitext(os.path.basename(script_path))[0] + return f"python_{script_name}" + + # Fallback to inspect the call stack to find the calling script + try: + for frame_info in inspect.stack(): + filename = frame_info.filename + if filename and not filename.endswith('srs_logger.py') and not filename.endswith('logging/__init__.py'): + script_name = os.path.splitext(os.path.basename(filename))[0] + if script_name != '<string>' and script_name != '<stdin>': + return f"python_{script_name}" + except: + pass + + # Final fallback + return "python_unknown" + + def format(self, record: logging.LogRecord) -> str: + """Format log record to match SRS format with color support""" + # Get current timestamp with milliseconds + now = datetime.now() + timestamp = now.strftime("%Y-%m-%d %H:%M:%S.") + f"{now.microsecond // 1000:03d}" + + # Map Python log levels to SRS log levels + level_mapping = { + 'DEBUG': 'TRACE', + 'INFO': 'INFO', + 'WARNING': 'WARN', + 'ERROR': 'ERROR', + 'CRITICAL': 'ERROR' + } + + srs_level = level_mapping.get(record.levelname, 'INFO') + + # Format: [timestamp][level][pid][session] message + formatted_msg = f"[{timestamp}][{srs_level}][{self.pid}][{self.session_id}] {record.getMessage()}" + + if record.exc_info: + formatted_msg += f"\n{self.formatException(record.exc_info)}" + + # Apply colors if supported and enabled + if self.use_colors: + if srs_level == 'ERROR': + formatted_msg = f"{self.COLORS['RED']}{formatted_msg}{self.COLORS['RESET']}" + elif srs_level == 'WARN': + formatted_msg = f"{self.COLORS['YELLOW']}{formatted_msg}{self.COLORS['RESET']}" + # INFO, TRACE levels remain uncolored (default terminal color) + + return formatted_msg + +class SRSConfigParser: + """Parser for SRS configuration files""" + + @staticmethod + def parse_config_file(config_path: str) -> Dict[str, Any]: + """Parse SRS configuration file and extract logging settings""" + config = { + 'log_tank': 'all', + 'log_file': './objs/srs.log', + 'log_level': 'info' + } + + if not os.path.exists(config_path): + return config + + try: + with open(config_path, 'r', encoding='utf-8') as f: + content = f.read() + + # Parse srs_log_tank + match = re.search(r'srs_log_tank\s+([^;]+);', content) + if match: + config['log_tank'] = match.group(1).strip() + + # Parse srs_log_file + match = re.search(r'srs_log_file\s+([^;]+);', content) + if match: + config['log_file'] = match.group(1).strip() + + # Parse srs_log_level_v2 + match = re.search(r'srs_log_level_v2\s+([^;]+);', content) + if match: + config['log_level'] = match.group(1).strip().lower() + + except Exception as e: + print(f"Warning: Failed to parse config file {config_path}: {e}", file=sys.stderr) + + return config + +class SRSLogger: + """ + SRS-compatible Python logger that synchronizes with SRS main process logging + """ + + _instance = None + _lock = threading.Lock() + + def __new__(cls, config_path: Optional[str] = None): + """Singleton pattern to ensure only one logger instance""" + if cls._instance is None: + with cls._lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._initialized = False + return cls._instance + + def __init__(self, config_path: Optional[str] = None): + if self._initialized: + return + + self._initialized = True + self.config_path = config_path + self.logger = None + self.config = {} + self._setup_logger() + + def _setup_logger(self): + """Setup the logger based on SRS configuration""" + try: + # Parse SRS config if provided + if self.config_path and os.path.exists(self.config_path): + self.config = SRSConfigParser.parse_config_file(self.config_path) + else: + # Default configuration + self.config = { + 'log_tank': 'all', + 'log_file': './objs/srs.log', + 'log_level': 'info' + } + + # Create logger + self.logger = logging.getLogger('srs_python') + self.logger.setLevel(self._get_python_log_level(self.config['log_level'])) + + # Clear existing handlers + self.logger.handlers.clear() + + # Setup handlers based on log_tank setting + log_tank = self.config['log_tank'].lower() + + if log_tank in ['all', 'console']: + # Console handler with colors + console_handler = logging.StreamHandler(sys.stdout) + console_formatter = SRSLogFormatter(use_colors=True) + console_handler.setFormatter(console_formatter) + self.logger.addHandler(console_handler) + + if log_tank in ['all', 'file']: + # File handler without colors + log_file = self.config['log_file'] + # Create directory if it doesn't exist + log_dir = os.path.dirname(log_file) + if log_dir and not os.path.exists(log_dir): + os.makedirs(log_dir, exist_ok=True) + + file_handler = logging.FileHandler(log_file, encoding='utf-8') + file_formatter = SRSLogFormatter(use_colors=False) + file_handler.setFormatter(file_formatter) + self.logger.addHandler(file_handler) + + # Prevent propagation to root logger + self.logger.propagate = False + + self.info(f"SRS Python logger initialized with config: {self.config}") + + except Exception as e: + print(f"ERROR: Failed to setup SRS logger: {e}", file=sys.stderr) + # Setup a basic console logger as fallback + self._setup_fallback_logger() + + def _setup_fallback_logger(self): + """Setup a basic fallback logger if main setup fails""" + self.logger = logging.getLogger('srs_python_fallback') + self.logger.setLevel(logging.INFO) + + console_handler = logging.StreamHandler(sys.stderr) + formatter = logging.Formatter('%(asctime)s [ERROR] [Python] %(message)s') + console_handler.setFormatter(formatter) + + self.logger.addHandler(console_handler) + self.logger.propagate = False + + def _get_python_log_level(self, srs_level: str) -> int: + """Convert SRS log level to Python log level""" + level_mapping = { + 'trace': logging.DEBUG, + 'debug': logging.DEBUG, + 'info': logging.INFO, + 'warn': logging.WARNING, + 'warning': logging.WARNING, + 'error': logging.ERROR + } + return level_mapping.get(srs_level.lower(), logging.INFO) + + def trace(self, message: str, *args, **kwargs): + """Log trace message (mapped to DEBUG in Python)""" + if self.logger: + self.logger.debug(message, *args, **kwargs) + + def debug(self, message: str, *args, **kwargs): + """Log debug message""" + if self.logger: + self.logger.debug(message, *args, **kwargs) + + def info(self, message: str, *args, **kwargs): + """Log info message""" + if self.logger: + self.logger.info(message, *args, **kwargs) + + def warn(self, message: str, *args, **kwargs): + """Log warning message""" + if self.logger: + self.logger.warning(message, *args, **kwargs) + + def warning(self, message: str, *args, **kwargs): + """Log warning message""" + if self.logger: + self.logger.warning(message, *args, **kwargs) + + def error(self, message: str, *args, **kwargs): + """Log error message""" + if self.logger: + self.logger.error(message, *args, **kwargs) + + def critical(self, message: str, *args, **kwargs): + """Log critical message (mapped to ERROR in SRS)""" + if self.logger: + self.logger.critical(message, *args, **kwargs) + + def exception(self, message: str, *args, **kwargs): + """Log exception with traceback""" + if self.logger: + self.logger.exception(message, *args, **kwargs) + +# Global logger instance +_global_logger: Optional[SRSLogger] = None + +def get_logger(config_path: Optional[str] = None) -> SRSLogger: + """ + Get the global SRS logger instance + + Args: + config_path: Path to SRS configuration file + + Returns: + SRSLogger instance + """ + global _global_logger + + if _global_logger is None: + _global_logger = SRSLogger(config_path) + + return _global_logger + +def init_logger(config_path: Optional[str] = None) -> SRSLogger: + """ + Initialize the SRS logger with configuration + + Args: + config_path: Path to SRS configuration file + + Returns: + SRSLogger instance + """ + return get_logger(config_path) + +# Convenience functions for direct logging +def trace(message: str, *args, **kwargs): + """Log trace message""" + get_logger().trace(message, *args, **kwargs) + +def debug(message: str, *args, **kwargs): + """Log debug message""" + get_logger().debug(message, *args, **kwargs) + +def info(message: str, *args, **kwargs): + """Log info message""" + get_logger().info(message, *args, **kwargs) + +def warn(message: str, *args, **kwargs): + """Log warning message""" + get_logger().warn(message, *args, **kwargs) + +def warning(message: str, *args, **kwargs): + """Log warning message""" + get_logger().warning(message, *args, **kwargs) + +def error(message: str, *args, **kwargs): + """Log error message""" + get_logger().error(message, *args, **kwargs) + +def critical(message: str, *args, **kwargs): + """Log critical message""" + get_logger().critical(message, *args, **kwargs) + +def exception(message: str, *args, **kwargs): + """Log exception with traceback""" + get_logger().exception(message, *args, **kwargs) + +def log_error_and_exit(message: str, exit_code: int = 1): + """ + Log an error message and exit the process + Used when Python process encounters fatal errors + """ + error(f"FATAL: {message}") + sys.exit(exit_code) + +if __name__ == "__main__": + # Test the logger + parser = argparse.ArgumentParser(description="SRS Python Logger Test") + parser.add_argument("--config", help="Path to SRS config file") + args = parser.parse_args() + + # Initialize logger + logger = init_logger(args.config) + + # Test different log levels + logger.trace("This is a trace message") + logger.debug("This is a debug message") + logger.info("SRS Python logger test started") + logger.warn("This is a warning message") + logger.error("This is an error message") + + # Test exception logging + try: + raise ValueError("Test exception") + except Exception: + logger.exception("Caught test exception") + + logger.info("SRS Python logger test completed") diff --git a/trunk/python/srs_logger_example.py b/trunk/python/srs_logger_example.py new file mode 100644 index 000000000..2564216ce --- /dev/null +++ b/trunk/python/srs_logger_example.py @@ -0,0 +1,146 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Example usage of SRS Logger module +This file demonstrates how to import and use the SRS logger in other Python files. +""" + +import sys +import argparse +import time +import traceback + +# Import the SRS logger module +from srs_logger import get_logger, init_logger, log_error_and_exit +import srs_logger + +def example_basic_usage(): + """Basic usage example""" + logger = get_logger() + + logger.info("Starting basic usage example") + logger.debug("This is a debug message") + logger.warn("This is a warning message") + logger.info("Basic usage example completed") + +def example_with_config(): + """Example with config file parsing""" + # This would typically be passed via --config argument + config_path = "./conf/console.conf" + + logger = init_logger(config_path) + logger.info(f"Logger initialized with config file: {config_path}") + + # Demonstrate various log levels + logger.trace("Trace level message") + logger.debug("Debug level message") + logger.info("Info level message") + logger.warning("Warning level message") + logger.error("Error level message") + +def example_error_handling(): + """Example of error handling and logging""" + logger = get_logger() + + try: + # Simulate some work that might fail + logger.info("Starting risky operation") + + # Simulate an error condition + if True: # This would be your actual error condition + raise RuntimeError("Simulated error in Python process") + + logger.info("Risky operation completed successfully") + + except Exception as e: + # Log the error with full traceback + logger.exception(f"Error in risky operation: {e}") + + # For fatal errors, you can use log_error_and_exit + # log_error_and_exit(f"Fatal error occurred: {e}") + + # Or just log and continue + logger.error(f"Continuing after error: {e}") + +def example_convenience_functions(): + """Example using convenience functions""" + # These functions use the global logger instance + srs_logger.info("Using convenience function for info") + srs_logger.warn("Using convenience function for warning") + srs_logger.error("Using convenience function for error") + +def analytics_simulation(): + """Simulate analytics processing with logging""" + logger = get_logger() + + logger.info("Analytics process started") + + try: + # Simulate analytics work + for i in range(5): + logger.debug(f"Processing analytics batch {i+1}/5") + time.sleep(0.5) # Simulate processing time + + if i == 3: # Simulate a warning condition + logger.warn(f"High CPU usage detected during batch {i+1}") + + logger.info("Analytics processing completed successfully") + + except KeyboardInterrupt: + logger.warn("Analytics process interrupted by user") + return False + except Exception as e: + logger.exception(f"Analytics process failed: {e}") + return False + + return True + +def main(): + """Main function demonstrating different usage patterns""" + parser = argparse.ArgumentParser(description="SRS Logger Usage Examples") + parser.add_argument("--config", help="Path to SRS config file") + parser.add_argument("--port", type=int, default=8888, help="Port number (example parameter)") + parser.add_argument("--example", choices=['basic', 'config', 'error', 'convenience', 'analytics'], + default='analytics', help="Example to run") + + args = parser.parse_args() + + try: + # Initialize logger with config if provided + if args.config: + logger = init_logger(args.config) + logger.info(f"SRS Logger example started with config: {args.config}") + else: + logger = get_logger() + logger.info("SRS Logger example started without config file") + + logger.info(f"Example parameters: port={args.port}, example={args.example}") + + # Run the selected example + if args.example == 'basic': + example_basic_usage() + elif args.example == 'config': + example_with_config() + elif args.example == 'error': + example_error_handling() + elif args.example == 'convenience': + example_convenience_functions() + elif args.example == 'analytics': + success = analytics_simulation() + if not success: + log_error_and_exit("Analytics simulation failed") + + logger.info("SRS Logger example completed successfully") + + except Exception as e: + # Use the logger for any unhandled exceptions + if 'logger' in locals(): + logger.exception(f"Unhandled exception in main: {e}") + else: + print(f"FATAL ERROR: {e}", file=sys.stderr) + traceback.print_exc() + + sys.exit(1) + +if __name__ == "__main__": + main() diff --git a/trunk/python/system_stats_module.py b/trunk/python/system_stats_module.py new file mode 100644 index 000000000..15699d7d4 --- /dev/null +++ b/trunk/python/system_stats_module.py @@ -0,0 +1,402 @@ +from database import Database +from srs_logger import get_logger +import requests +import time +import threading +from datetime import datetime, timedelta +from typing import Optional, Dict, Any +from datetime import datetime, timedelta + +# ============================================================================ +# System Statistics Management Module +# ============================================================================ + +class SystemStatsManager: + """系统统计管理器""" + + def __init__(self, db: Database): + self.db = db + self.logger = get_logger() + self.api_url = "http://python_stats:wMePq3ahpoLRzgsVg7BY9eE82uuJHT0YukD2ZE1JfMY2RjP4e6QnUaKg3V9x5s9M@localhost:1985/api/v1/summaries" + self.previous_data: Optional[Dict[str, Any]] = None + self.previous_timestamp: Optional[float] = None + self.is_running = False + self.poll_thread: Optional[threading.Thread] = None + self.cleanup_thread: Optional[threading.Thread] = None + + # 启动定期清理过期数据的线程 + self._start_cleanup_thread() + + def _fetch_system_stats(self): + """获取系统统计数据并插入数据库""" + max_retries = 5 + retry_count = 0 + + while retry_count < max_retries: + try: + # 发送HTTP请求获取SRS统计数据 + response = requests.get(self.api_url, timeout=10) + response.raise_for_status() + + data = response.json() + + # 检查返回状态 + if data.get('code') != 0: + self.logger.error(f"SRS API returned error code: {data.get('code')}") + retry_count += 1 + time.sleep(2 ** retry_count) # 指数退避 + continue + + # 提取有用的信息 + stats_data = self._extract_stats_data(data) + if stats_data: + # 插入数据库 + success = self.db.insert_system_stats(**stats_data) + if success: + self.logger.debug("Successfully inserted system stats") + return True + else: + self.logger.error("Failed to insert system stats to database") + retry_count += 1 + else: + self.logger.error("Failed to extract stats data") + retry_count += 1 + + except requests.exceptions.RequestException as e: + self.logger.error(f"HTTP request failed: {e}") + retry_count += 1 + except Exception as e: + self.logger.exception(f"Unexpected error in fetch_system_stats: {e}") + retry_count += 1 + + if retry_count < max_retries: + time.sleep(2 ** retry_count) # 指数退避重试 + + self.logger.error(f"Failed to fetch system stats after {max_retries} retries") + return False + + def _extract_stats_data(self, api_data: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """从API返回数据中提取有用的统计信息""" + try: + data_section = api_data.get('data', {}) + self_section = data_section.get('self', {}) + system_section = data_section.get('system', {}) + + current_timestamp = data_section.get('now_ms', 0) / 1000.0 # 转换为秒 + + # SRS相关数据 + srs_uptime = self_section.get('srs_uptime', 0.0) + srs_cpu_percent = self_section.get('cpu_percent', 0.0) + srs_memory_percent = self_section.get('mem_percent', 0.0) + + # 网络流量数据(需要计算KBps) + srs_recv_bytes = system_section.get('srs_recv_bytes', 0) + srs_send_bytes = system_section.get('srs_send_bytes', 0) + srs_sample_time = system_section.get('srs_sample_time', 0) / 1000.0 # 转换为秒 + + # 计算KBps(需要与上一次的数据比较) + srs_recv_KBps = 0.0 + srs_send_KBps = 0.0 + + if (self.previous_data and + self.previous_timestamp and + current_timestamp > self.previous_timestamp): + + time_diff = current_timestamp - self.previous_timestamp + prev_recv = self.previous_data.get('srs_recv_bytes', 0) + prev_send = self.previous_data.get('srs_send_bytes', 0) + + if time_diff > 0: + # 计算字节差值并转换为KBps + srs_recv_KBps = max(0, (srs_recv_bytes - prev_recv) / time_diff / 1024) + srs_send_KBps = max(0, (srs_send_bytes - prev_send) / time_diff / 1024) + + # print(srs_recv_KBps, srs_send_KBps) # DEBUG + + # 磁盘IO数据 + disk_read_KBps = system_section.get('disk_read_KBps', 0.0) + disk_write_KBps = system_section.get('disk_write_KBps', 0.0) + + # 操作系统相关数据 + os_uptime = system_section.get('uptime', 0.0) + os_cpu_percent = system_section.get('cpu_percent', 0.0) + os_memory_percent = system_section.get('mem_ram_percent', 0.0) + + # 更新上一次的数据 + self.previous_data = { + 'srs_recv_bytes': srs_recv_bytes, + 'srs_send_bytes': srs_send_bytes, + 'srs_sample_time': srs_sample_time + } + self.previous_timestamp = current_timestamp + + return { + 'srs_uptime': srs_uptime, + 'srs_cpu_percent': srs_cpu_percent, + 'srs_memory_percent': srs_memory_percent, + 'srs_recv_KBps': srs_recv_KBps, + 'srs_send_KBps': srs_send_KBps, + 'disk_read_KBps': disk_read_KBps, + 'disk_write_KBps': disk_write_KBps, + 'os_uptime': os_uptime, + 'os_cpu_percent': os_cpu_percent, + 'os_memory_percent': os_memory_percent + } + + except Exception as e: + self.logger.exception(f"Failed to extract stats data: {e}") + return None + + def start_polling(self, interval: int = 3): + """启动长轮询统计数据收集""" + if self.is_running: + self.logger.warn("Polling is already running") + return + + self.is_running = True + self.poll_thread = threading.Thread(target=self._polling_loop, args=(interval,)) + self.poll_thread.daemon = True + self.poll_thread.start() + self.logger.info(f"Started system stats polling with {interval}s interval") + + def stop_polling(self): + """停止长轮询""" + if not self.is_running: + self.logger.warn("Polling is not running") + return + + self.is_running = False + if self.poll_thread: + self.poll_thread.join(timeout=5) + self.logger.info("Stopped system stats polling") + + def _polling_loop(self, interval: int): + """轮询循环""" + while self.is_running: + try: + self._fetch_system_stats() + except Exception as e: + self.logger.exception(f"Error in polling loop: {e}") + + # 等待指定间隔,但允许提前停止 + for _ in range(interval): + if not self.is_running: + break + time.sleep(1) + + def _start_cleanup_thread(self): + """启动清理过期数据的线程""" + self.cleanup_thread = threading.Thread(target=self._cleanup_loop) + self.cleanup_thread.daemon = True + self.cleanup_thread.start() + self.logger.info("Started system stats cleanup thread (5 minute interval)") + + def _cleanup_loop(self): + """定期清理过期数据的循环""" + while True: + try: + # 执行清理 + success = self.db.delete_expired_system_stats() + if success: + self.logger.debug("Successfully cleaned up expired system stats") + else: + self.logger.warn("Failed to clean up expired system stats") + + time.sleep(300) + + except Exception as e: + self.logger.exception(f"Error in cleanup loop: {e}") + # 发生错误后等待一分钟再继续 + time.sleep(60) + + def _process_stats_data(self, raw_stats: list, time_delta: int, time_interval: int) -> list: + """ + 处理原始统计数据,按指定时间间隔进行插值和过滤 + + Args: + raw_stats: 原始统计数据列表 + time_delta: 时间范围(秒) + time_interval: 时间间隔(秒) + + Returns: + 处理后的统计数据列表 + """ + try: + if not raw_stats: + return [] + + # 参数验证 + if time_interval <= 0: + self.logger.warn("Invalid time_interval, using 1 second") + time_interval = 1 + if time_delta <= 0: + self.logger.warn("Invalid time_delta, using 60 seconds") + time_delta = 60 + + # 转换时间戳并排序 + for stat in raw_stats: + if isinstance(stat['timestamp'], str): + stat['timestamp'] = datetime.fromisoformat(stat['timestamp']) + elif not isinstance(stat['timestamp'], datetime): + # 如果是其他格式,尝试解析 + stat['timestamp'] = datetime.fromisoformat(str(stat['timestamp'])) + + # 按时间戳排序 + raw_stats.sort(key=lambda x: x['timestamp']) + + # 确定时间范围 + now = datetime.now().replace(microsecond=0) # 取整到秒 + start_time = now - timedelta(seconds=time_delta) + + # 过滤掉超出时间范围的数据 + filtered_stats = [stat for stat in raw_stats if stat['timestamp'] >= start_time] + + if not filtered_stats: + self.logger.warn("No data points in specified time range") + return [] + + # 生成目标时间点列表(从最近一次记录向前) + target_times = [] + current_time = now + while current_time >= start_time: + target_times.append(current_time) + current_time -= timedelta(seconds=time_interval) + + # 反转列表,使其按时间顺序排列 + target_times.reverse() + + # 对每个目标时间点进行插值 + processed_stats = [] + numeric_fields = [ + 'srs_uptime', 'srs_cpu_percent', 'srs_memory_percent', + 'srs_recv_KBps', 'srs_send_KBps', 'disk_read_KBps', + 'disk_write_KBps', 'os_uptime', 'os_cpu_percent', 'os_memory_percent' + ] + + for target_time in target_times: + interpolated_data = self._interpolate_data_point(filtered_stats, target_time, numeric_fields) + if interpolated_data: + processed_stats.append(interpolated_data) + + self.logger.debug(f"Processed {len(raw_stats)} raw points into {len(processed_stats)} interpolated points") + return processed_stats + + except Exception as e: + self.logger.exception(f"Error processing stats data: {e}") + return raw_stats # 如果处理失败,返回原始数据 + + def _interpolate_data_point(self, raw_stats: list, target_time: datetime, numeric_fields: list) -> Optional[Dict[str, Any]]: + """ + 为指定时间点插值数据 + + Args: + raw_stats: 原始统计数据列表(已排序) + target_time: 目标时间点 + numeric_fields: 需要插值的数值字段列表 + + Returns: + 插值后的数据点 + """ + try: + # 查找最接近的前后两个数据点 + before_point = None + after_point = None + + for i, stat in enumerate(raw_stats): + stat_time = stat['timestamp'] + + if stat_time <= target_time: + before_point = stat + elif stat_time > target_time: + after_point = stat + break + + # 如果没有找到合适的数据点,使用最近的点 + if before_point is None and after_point is None: + return None + elif before_point is None: + # 只有后面的点,直接使用 + result = after_point.copy() + result['timestamp'] = target_time.replace(microsecond=0) + return result + elif after_point is None: + # 只有前面的点,直接使用 + result = before_point.copy() + result['timestamp'] = target_time.replace(microsecond=0) + return result + + # 计算时间差和权重 + before_time = before_point['timestamp'] + after_time = after_point['timestamp'] + + # 如果时间点完全匹配,直接返回 + if before_time == target_time: + return before_point.copy() + if after_time == target_time: + return after_point.copy() + + # 计算插值权重 + total_diff = (after_time - before_time).total_seconds() + if total_diff == 0: + # 两个点时间相同,使用前一个点 + result = before_point.copy() + result['timestamp'] = target_time.replace(microsecond=0) + return result + + target_diff = (target_time - before_time).total_seconds() + weight = target_diff / total_diff + + # 执行线性插值 + result = {'timestamp': target_time.replace(microsecond=0)} # 确保时间戳取整到秒 + + for field in numeric_fields: + if field in before_point and field in after_point: + before_val = float(before_point[field]) if before_point[field] is not None else 0.0 + after_val = float(after_point[field]) if after_point[field] is not None else 0.0 + + # 线性插值 + interpolated_val = before_val + (after_val - before_val) * weight + result[field] = round(interpolated_val, 4) # 保留4位小数 + else: + # 如果字段不存在,使用前一个点的值 + result[field] = before_point.get(field, 0.0) + + return result + + except Exception as e: + self.logger.exception(f"Error interpolating data point: {e}") + return None + + def get_recent_stats(self, user_group: list, time_delta: int = 5, time_interval: int = 1): + """获取最近的统计数据""" + try: + if not any(group in user_group for group in ['streamer', 'manager', 'admin']): + return {"success": False, "message": "权限不足,无法获取统计数据"} + + # 获取原始数据,多获取一些以便插值 + raw_stats = self.db.get_system_stats(time_delta + 5) # 多获取60秒数据用于插值 + if not raw_stats: + return {"success": False, "message": "没有找到相关的统计数据"} + + # 处理数据,按指定间隔进行插值和过滤 + processed_stats = self._process_stats_data(raw_stats, time_delta, time_interval) + + # 转换时间戳为字符串格式,便于JSON序列化 + for stat in processed_stats: + if isinstance(stat['timestamp'], datetime): + stat['timestamp'] = stat['timestamp'].isoformat() + + return { + "success": True, + "system_stats": processed_stats, + "metadata": { + "time_delta": time_delta, + "time_interval": time_interval, + "data_points": len(processed_stats), + "original_points": len(raw_stats) + } + } + + except Exception as e: + self.logger.exception(f"Error in get_recent_stats: {e}") + return {"success": False, "message": "服务器内部错误"} \ No newline at end of file diff --git a/trunk/python/test_stats_processing.py b/trunk/python/test_stats_processing.py new file mode 100644 index 000000000..e69de29bb diff --git a/trunk/research/python-subprocess/main.py b/trunk/research/python-subprocess/main.py deleted file mode 100755 index 81df4270b..000000000 --- a/trunk/research/python-subprocess/main.py +++ /dev/null @@ -1,52 +0,0 @@ -#!/usr/bin/python -# -*- coding: utf-8 -*- - -import sys, shlex, time, subprocess - -cmd = "./python.subprocess 0 80000" -args = shlex.split(str(cmd)) -print "cmd: %s, args: %s"%(cmd, args) - -# the communicate will read all data and wait for sub process to quit. -def use_communicate(args, fout, ferr): - process = subprocess.Popen(args, stdout=fout, stderr=ferr) - (stdout_str, stderr_str) = process.communicate() - return (stdout_str, stderr_str) - -# if use subprocess.PIPE, the pipe will full about 50KB data, -# and sub process will blocked, then timeout will kill it. -def use_poll(args, fout, ferr, timeout): - (stdout_str, stderr_str) = (None, None) - process = subprocess.Popen(args, stdout=fout, stderr=ferr) - starttime = time.time() - while True: - process.poll() - if process.returncode is not None: - (stdout_str, stderr_str) = process.communicate() - break - if timeout > 0 and time.time() - starttime >= timeout: - print "timeout, kill process. timeout=%s"%(timeout) - process.kill() - break - time.sleep(1) - process.wait() - return (stdout_str, stderr_str) - -# stdout/stderr can be fd, fileobject, subprocess.PIPE, None -fnull = open("/dev/null", "rw") -fout = fnull.fileno()#subprocess.PIPE#fnull#fnull.fileno() -ferr = fnull.fileno()#subprocess.PIPE#fnull#fnull.fileno() -print "fout=%s, ferr=%s"%(fout, ferr) -#(stdout_str, stderr_str) = use_communicate(args, fout, ferr) -(stdout_str, stderr_str) = use_poll(args, fout, ferr, 10) - -def print_result(stdout_str, stderr_str): - if stdout_str is None: - stdout_str = "" - if stderr_str is None: - stderr_str = "" - print "terminated, size of stdout=%s, stderr=%s"%(len(stdout_str), len(stderr_str)) - while True: - time.sleep(1) - -print_result(stdout_str, stderr_str) diff --git a/trunk/research/python-subprocess/python.subprocess.cpp b/trunk/research/python-subprocess/python.subprocess.cpp deleted file mode 100644 index 6b409b028..000000000 --- a/trunk/research/python-subprocess/python.subprocess.cpp +++ /dev/null @@ -1,32 +0,0 @@ -#include <stdio.h> -#include <unistd.h> -#include <stdlib.h> - -/** -# always to print to stdout and stderr. -g++ python.subprocess.cpp -o python.subprocess -*/ -int main(int argc, char** argv) { - if (argc <= 2) { - printf("Usage: <%s> <interval_ms> <max_loop>\n" - " %s 50 100000\n", argv[0], argv[0]); - exit(-1); - return -1; - } - - int interval_ms = ::atoi(argv[1]); - int max_loop = ::atoi(argv[2]); - printf("always to print to stdout and stderr.\n"); - printf("interval: %d ms\n", interval_ms); - printf("max_loop: %d\n", max_loop); - - for (int i = 0; i < max_loop; i++) { - fprintf(stdout, "always to print to stdout and stderr. interval=%dms, max=%d, current=%d\n", interval_ms, max_loop, i); - fprintf(stderr, "always to print to stdout and stderr. interval=%dms, max=%d, current=%d\n", interval_ms, max_loop, i); - if (interval_ms > 0) { - usleep(interval_ms * 1000); - } - } - - return 0; -}