#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Database module for JWT authentication system
Provides SQLite database initialization and management for user authentication
"""

import sqlite3
import os
import json
import re
import threading
import uuid
from datetime import datetime, timedelta
from typing import Optional, List, Dict, Any, Union
from contextlib import contextmanager
import random
import string

# Import SRS logger
from srs_logger import get_logger

# Column list shared by every query that returns a full user record.
_USER_COLUMNS = (
    "user_id, username, name, email, hashed_password, salt, "
    "user_group, last_active, is_activated"
)

# Validation patterns. They are applied with ``fullmatch`` because a
# ``^...$`` pattern used with ``re.match`` also accepts a trailing newline.
_USERNAME_RE = re.compile(r'[a-zA-Z0-9_-]+')
_EMAIL_RE = re.compile(r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}')

# Alphabet used for the random part of a stream code.
_STREAM_CODE_ALPHABET = string.ascii_lowercase + string.digits


class Database:
    """SQLite-backed store for users, sessions, streams and system stats.

    The class is a process-wide singleton: every ``Database(...)`` call
    returns the same object, and only the first successful construction
    runs the initialization (directory creation, table creation).

    Every public method opens its own short-lived connection through
    :meth:`get_connection`. Except for :meth:`get_connection` and the
    initialization helpers, methods do not raise: they log the error and
    return ``None`` / ``False`` / an empty container instead.
    """

    _instance = None
    _lock = threading.Lock()
    # Set once the process-global sqlite3 datetime adapter/converter exist.
    _datetime_handlers_registered = False

    def __new__(cls, db_path: str = "./objs/srs_database.db"):
        """Singleton pattern to ensure only one database instance"""
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
                    cls._instance._initialized = False
        return cls._instance

    def __init__(self, db_path: str = "./objs/srs_database.db"):
        """Initialize the singleton on first use.

        Args:
            db_path: Path to the SQLite database file. Only honoured by the
                first successful initialization; later calls return the
                already-initialized instance unchanged.

        Raises:
            Exception: Whatever ``_init_database`` raises. In that case the
                instance stays uninitialized, so a later call retries
                instead of silently returning a half-built object.
        """
        if self._initialized:
            return

        with type(self)._lock:
            # Re-check under the lock: another thread may have finished
            # the initialization while this one was waiting.
            if self._initialized:
                return

            self.db_path = db_path
            self.logger = get_logger()
            self._connection_lock = threading.Lock()

            # Initialize database
            self._init_database()

            # Only mark the instance ready once initialization succeeded.
            self._initialized = True

    def _init_database(self):
        """Initialize database and create tables if they don't exist"""
        try:
            # Create objs directory if it doesn't exist
            db_dir = os.path.dirname(self.db_path)
            if db_dir and not os.path.exists(db_dir):
                os.makedirs(db_dir, exist_ok=True)
                self.logger.info(f"Created database directory: {db_dir}")

            # Check if database exists (before connecting, since connecting
            # creates the file)
            db_exists = os.path.exists(self.db_path)

            if not db_exists:
                self.logger.warn(f"Database file not found, creating new database: {self.db_path}")
            else:
                self.logger.info(f"Using existing database: {self.db_path}")

            # Create tables
            self._create_tables()

            if not db_exists:
                self.logger.info("Database initialized successfully with all tables created")
            else:
                self.logger.info("Database connection established and tables verified")

        except Exception as e:
            self.logger.exception(f"Failed to initialize database: {e}")
            raise

    def _create_tables(self):
        """Create all required tables"""
        try:
            with self.get_connection() as conn:
                cursor = conn.cursor()

                # Create users table
                cursor.execute('''
                    CREATE TABLE IF NOT EXISTS users (
                        user_id TEXT PRIMARY KEY,
                        username TEXT UNIQUE NOT NULL,
                        name TEXT NOT NULL,
                        email TEXT UNIQUE,
                        hashed_password TEXT NOT NULL,
                        salt TEXT NOT NULL,
                        user_group TEXT NOT NULL DEFAULT '["user"]',
                        last_active TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        is_activated BOOLEAN NOT NULL DEFAULT 1
                    )
                ''')

                # Create auth_session table
                cursor.execute('''
                    CREATE TABLE IF NOT EXISTS auth_session (
                        session_id TEXT PRIMARY KEY,
                        user_id TEXT NOT NULL,
                        hashed_authkey TEXT NOT NULL,
                        salt TEXT NOT NULL,
                        expire_time TIMESTAMP NOT NULL,
                        FOREIGN KEY (user_id) REFERENCES users (user_id) ON DELETE CASCADE
                    )
                ''')

                # Create refresh_session table
                cursor.execute('''
                    CREATE TABLE IF NOT EXISTS refresh_session (
                        session_id TEXT PRIMARY KEY,
                        user_id TEXT NOT NULL,
                        hashed_refreshkey TEXT NOT NULL,
                        salt TEXT NOT NULL,
                        expire_time TIMESTAMP NOT NULL,
                        FOREIGN KEY (user_id) REFERENCES users (user_id) ON DELETE CASCADE
                    )
                ''')

                # Create streams table
                cursor.execute('''
                    CREATE TABLE IF NOT EXISTS streams (
                        stream_id TEXT PRIMARY KEY,
                        stream_code TEXT UNIQUE NOT NULL,
                        streamer_id TEXT NOT NULL,
                        stream_title TEXT NOT NULL,
                        stream_description TEXT,
                        stream_tags TEXT DEFAULT '[]',
                        stream_visibility TEXT NOT NULL DEFAULT 'public',
                        quality_info TEXT,
                        active_time TIMESTAMP,
                        stream_status TEXT NOT NULL DEFAULT 'planned',
                        FOREIGN KEY (streamer_id) REFERENCES users (user_id) ON DELETE CASCADE
                    )
                ''')

                # Create system_stats table
                cursor.execute('''
                    CREATE TABLE IF NOT EXISTS system_stats (
                        timestamp TIMESTAMP PRIMARY KEY,
                        srs_uptime REAL NOT NULL,
                        srs_cpu_percent REAL NOT NULL,
                        srs_memory_percent REAL NOT NULL,
                        srs_recv_KBps REAL NOT NULL,
                        srs_send_KBps REAL NOT NULL,
                        disk_read_KBps REAL NOT NULL,
                        disk_write_KBps REAL NOT NULL,
                        os_uptime REAL NOT NULL,
                        os_cpu_percent REAL NOT NULL,
                        os_memory_percent REAL NOT NULL
                    )
                ''')

                # Create indexes for better performance
                cursor.execute('CREATE INDEX IF NOT EXISTS idx_users_username ON users(username)')
                cursor.execute('CREATE INDEX IF NOT EXISTS idx_users_email ON users(email)')
                cursor.execute('CREATE INDEX IF NOT EXISTS idx_auth_session_user_id ON auth_session(user_id)')
                cursor.execute('CREATE INDEX IF NOT EXISTS idx_auth_session_expire ON auth_session(expire_time)')
                cursor.execute('CREATE INDEX IF NOT EXISTS idx_refresh_session_user_id ON refresh_session(user_id)')
                cursor.execute('CREATE INDEX IF NOT EXISTS idx_refresh_session_expire ON refresh_session(expire_time)')
                cursor.execute('CREATE INDEX IF NOT EXISTS idx_system_stats_timestamp ON system_stats(timestamp)')

                conn.commit()
                self.logger.debug("All database tables and indexes created/verified successfully")

        except Exception as e:
            self.logger.exception(f"Failed to create database tables: {e}")
            raise

    @classmethod
    def _register_datetime_handlers(cls):
        """Register the sqlite3 datetime adapter/converter once per process.

        ``sqlite3.register_adapter`` / ``register_converter`` mutate
        module-global state, so repeating them for every connection is
        redundant. Registration is idempotent, which makes a race between
        two threads on the flag harmless.
        """
        if cls._datetime_handlers_registered:
            return

        # Configure datetime adapter for Python 3.12+ compatibility:
        # datetime parameters are stored as ISO-8601 text.
        sqlite3.register_adapter(datetime, lambda dt: dt.isoformat())
        # NOTE: connections are opened without ``detect_types``, so this
        # converter is not applied and TIMESTAMP columns are read back as
        # plain strings.
        sqlite3.register_converter("TIMESTAMP", lambda b: datetime.fromisoformat(b.decode()))
        cls._datetime_handlers_registered = True

    @contextmanager
    def get_connection(self):
        """Get a database connection with proper error handling and cleanup

        Yields:
            sqlite3.Connection: A new connection with ``sqlite3.Row`` as row
            factory and foreign-key enforcement enabled. The caller is
            responsible for calling ``commit()``.

        Raises:
            Exception: Any error raised while connecting or inside the
                ``with`` body is logged, the connection is rolled back and
                the error is re-raised. The connection is always closed.
        """
        conn = None
        try:
            self._register_datetime_handlers()
            conn = sqlite3.connect(self.db_path, timeout=30.0)
            conn.row_factory = sqlite3.Row  # Enable column access by name
            conn.execute('PRAGMA foreign_keys = ON')  # Enable foreign key constraints
            yield conn
        except Exception as e:
            if conn:
                conn.rollback()
            self.logger.error(f"Database connection error: {e}")
            raise
        finally:
            if conn:
                conn.close()

    # ============================================================================
    # Internal helpers
    # ============================================================================

    @staticmethod
    def _row_to_user(row) -> Dict[str, Any]:
        """Convert a ``users`` row to a dict, decoding the JSON ``user_group`` column."""
        user_data = dict(row)
        user_data['user_group'] = json.loads(user_data['user_group'])
        return user_data

    def _fetch_user_by_id(self, cursor, user_id: str) -> Optional[Dict[str, Any]]:
        """Load one user through an already-open cursor; ``None`` if absent."""
        cursor.execute(f'SELECT {_USER_COLUMNS} FROM users WHERE user_id = ?', (user_id,))
        row = cursor.fetchone()
        return self._row_to_user(row) if row else None

    # ============================================================================
    # Users management functions
    # ============================================================================

    def cleanup_expired_sessions(self):
        """Remove expired auth and refresh sessions, and sessions for deactivated users

        Errors are logged and swallowed; nothing is returned.
        """
        try:
            with self.get_connection() as conn:
                cursor = conn.cursor()
                current_time = datetime.now()

                # Clean up expired auth sessions
                cursor.execute('DELETE FROM auth_session WHERE expire_time < ?', (current_time,))
                auth_expired_deleted = cursor.rowcount

                # Clean up expired refresh sessions
                cursor.execute('DELETE FROM refresh_session WHERE expire_time < ?', (current_time,))
                refresh_expired_deleted = cursor.rowcount

                # Clean up sessions for deactivated users
                cursor.execute('''
                    DELETE FROM auth_session
                    WHERE user_id IN (SELECT user_id FROM users WHERE is_activated = 0)
                ''')
                auth_deactivated_deleted = cursor.rowcount

                cursor.execute('''
                    DELETE FROM refresh_session
                    WHERE user_id IN (SELECT user_id FROM users WHERE is_activated = 0)
                ''')
                refresh_deactivated_deleted = cursor.rowcount

                conn.commit()

                total_auth_deleted = auth_expired_deleted + auth_deactivated_deleted
                total_refresh_deleted = refresh_expired_deleted + refresh_deactivated_deleted

                if total_auth_deleted > 0 or total_refresh_deleted > 0:
                    self.logger.info(f"Cleaned up sessions: {total_auth_deleted} auth ({auth_expired_deleted} expired, {auth_deactivated_deleted} deactivated), {total_refresh_deleted} refresh ({refresh_expired_deleted} expired, {refresh_deactivated_deleted} deactivated)")

        except Exception as e:
            self.logger.exception(f"Failed to cleanup expired sessions: {e}")

    def get_user(self, username_or_email: str) -> Optional[Dict[str, Any]]:
        """Get user by username or email

        Args:
            username_or_email: Value matched against both the ``username``
                and the ``email`` column.

        Returns:
            The user record with ``user_group`` decoded to a list, or
            ``None`` if the value is empty, nothing matches or the query
            fails.
        """
        if not username_or_email:
            self.logger.warn("get_user called with None or empty value")
            return None

        try:
            with self.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute(
                    f'SELECT {_USER_COLUMNS} FROM users WHERE username = ? OR email = ?',
                    (username_or_email, username_or_email),
                )
                row = cursor.fetchone()
                return self._row_to_user(row) if row else None

        except Exception as e:
            self.logger.exception(f"Failed to get user by username/email '{username_or_email}': {e}")
            return None

    def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
        """Get user by user_id

        Returns:
            The user record with ``user_group`` decoded to a list, or
            ``None`` if no such user exists or the query fails.
        """
        try:
            with self.get_connection() as conn:
                return self._fetch_user_by_id(conn.cursor(), user_id)

        except Exception as e:
            self.logger.exception(f"Failed to get user by user_id {user_id}: {e}")
            return None

    def create_user(self, username: str, name: str, email: Optional[str], hashed_password: str,
                    salt: str, user_group: Optional[List[str]] = None,
                    is_activated: bool = True) -> Optional[Dict[str, Any]]:
        """Create a new user and return user data

        Args:
            username: Login name; only a-z, A-Z, 0-9, '_' and '-' are allowed.
            name: Display name.
            email: Optional e-mail address; validated when provided.
            hashed_password: Password hash, stored as given.
            salt: Salt belonging to ``hashed_password``.
            user_group: Groups of the user; ``None`` or empty means ``["user"]``.
            is_activated: Whether the account starts activated.

        Returns:
            The created user record, or ``None`` when validation fails, a
            uniqueness constraint is violated or the insert fails.
        """
        try:
            # Validate username format. fullmatch (rather than match with a
            # trailing ``$``) also rejects values ending in a newline.
            if not _USERNAME_RE.fullmatch(username):
                self.logger.warn(f"Invalid username format: '{username}'. Only a-z, A-Z, 0-9, '_', '-' are allowed")
                return None

            # Validate email format if provided
            if email and not _EMAIL_RE.fullmatch(email):
                self.logger.warn(f"Invalid email format: '{email}'")
                return None

            # Ensure user_group is not empty
            if not user_group:
                user_group = ["user"]

            # Generate unique user_id (a collision would violate the primary
            # key and be reported through the IntegrityError branch below)
            user_id = str(uuid.uuid4())
            user_group_json = json.dumps(user_group)

            # Written explicitly so the stored value and the returned value
            # agree and use the same format as update_user_last_active
            # (the column default CURRENT_TIMESTAMP has a different format).
            created_at = datetime.now()

            with self.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute('''
                    INSERT INTO users (user_id, username, name, email, hashed_password, salt,
                                       user_group, last_active, is_activated)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
                ''', (user_id, username, name, email, hashed_password, salt,
                      user_group_json, created_at, is_activated))
                conn.commit()

            self.logger.info(f"Created new user: {username} (ID: {user_id}, activated: {is_activated})")

            # Return the created user data directly instead of re-querying
            return {
                'user_id': user_id,
                'username': username,
                'name': name,
                'email': email,
                'hashed_password': hashed_password,
                'salt': salt,
                'user_group': user_group,  # Already parsed list
                'last_active': created_at.isoformat(),
                'is_activated': is_activated
            }

        except sqlite3.IntegrityError as e:
            if 'username' in str(e):
                self.logger.warn(f"Username '{username}' already exists")
            elif 'email' in str(e):
                self.logger.warn(f"Email '{email}' already exists")
            else:
                self.logger.warn(f"User creation failed due to constraint: {e}")
            return None
        except Exception as e:
            self.logger.exception(f"Failed to create user '{username}': {e}")
            return None

    def update_user_last_active(self, user_id: str) -> Optional[Dict[str, Any]]:
        """Update user's last active timestamp and return updated user data

        Returns:
            The updated user record, or ``None`` if the user does not exist
            or the update fails.
        """
        try:
            with self.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute(
                    'UPDATE users SET last_active = ? WHERE user_id = ?',
                    (datetime.now(), user_id),
                )
                matched = cursor.rowcount
                conn.commit()

                # Do not report success when no row matched.
                if matched == 0:
                    self.logger.warn(f"No user found with user_id: {user_id}")
                    return None

                self.logger.debug(f"Updated last_active for user_id: {user_id}")

                # Re-read through the same connection
                user_data = self._fetch_user_by_id(cursor, user_id)
                if user_data is None:
                    self.logger.warn(f"No user found with user_id: {user_id}")
                return user_data

        except Exception as e:
            self.logger.exception(f"Failed to update last_active for user_id {user_id}: {e}")
            return None

    def update_user_activation(self, user_id: str, activation: bool) -> Optional[Dict[str, Any]]:
        """Update user's activation status and cleanup sessions if deactivated, return updated user data

        Returns:
            The updated user record, or ``None`` if the user does not exist
            or the update fails.
        """
        try:
            with self.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute(
                    'UPDATE users SET is_activated = ? WHERE user_id = ?',
                    (activation, user_id),
                )
                matched = cursor.rowcount
                conn.commit()

                # Do not report success when no row matched.
                if matched == 0:
                    self.logger.warn(f"No user found with user_id: {user_id}")
                    return None

                self.logger.info(f"Updated activation status for user_id {user_id}: {activation}")

                # Re-read through the same connection
                user_data = self._fetch_user_by_id(cursor, user_id)

            if user_data is None:
                self.logger.warn(f"No user found with user_id: {user_id}")
                return None

            # Runs after the connection above is closed, because the cleanup
            # opens its own connection. It also removes the sessions of
            # every deactivated user.
            if not activation:
                self.cleanup_expired_sessions()

            return user_data

        except Exception as e:
            self.logger.exception(f"Failed to update activation for user_id {user_id}: {e}")
            return None

    def update_user_password(self, user_id: str, hashed_password: str, salt: str) -> Optional[Dict[str, Any]]:
        """Update user's password and salt, return updated user data

        Returns:
            The updated user record, or ``None`` if the user does not exist
            or the update fails.
        """
        try:
            with self.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute(
                    'UPDATE users SET hashed_password = ?, salt = ? WHERE user_id = ?',
                    (hashed_password, salt, user_id),
                )
                matched = cursor.rowcount
                conn.commit()

                if matched == 0:
                    self.logger.warn(f"No user found with user_id: {user_id}")
                    return None

                self.logger.info(f"Updated password for user_id: {user_id}")

                # Re-read through the same connection
                user_data = self._fetch_user_by_id(cursor, user_id)
                if user_data is None:
                    self.logger.warn(f"No user found with user_id: {user_id}")
                return user_data

        except Exception as e:
            self.logger.exception(f"Failed to update password for user_id {user_id}: {e}")
            return None

    def delete_user(self, user_id: str) -> bool:
        """Delete user and all associated data from all tables

        Returns:
            ``True`` if the user row was deleted, ``False`` if the user does
            not exist or the deletion fails.
        """
        try:
            with self.get_connection() as conn:
                cursor = conn.cursor()

                # Check if user exists first
                cursor.execute('SELECT username FROM users WHERE user_id = ?', (user_id,))
                user_row = cursor.fetchone()
                if not user_row:
                    self.logger.warn(f"No user found with user_id: {user_id}")
                    return False

                username = user_row[0]

                # Delete auth sessions (ON DELETE CASCADE would cover this,
                # but doing it explicitly yields the counts for the log)
                cursor.execute('DELETE FROM auth_session WHERE user_id = ?', (user_id,))
                auth_deleted = cursor.rowcount

                # Delete refresh sessions
                cursor.execute('DELETE FROM refresh_session WHERE user_id = ?', (user_id,))
                refresh_deleted = cursor.rowcount

                # Delete user
                cursor.execute('DELETE FROM users WHERE user_id = ?', (user_id,))
                user_deleted = cursor.rowcount

                conn.commit()

                if user_deleted > 0:
                    self.logger.info(f"Deleted user '{username}' (ID: {user_id}) and all associated data: {auth_deleted} auth sessions, {refresh_deleted} refresh sessions")
                    return True

                self.logger.warn(f"Failed to delete user with user_id: {user_id}")
                return False

        except Exception as e:
            self.logger.exception(f"Failed to delete user with user_id {user_id}: {e}")
            return False

    # ============================================================================
    # Session management functions
    # ============================================================================

    def create_auth_session(self, user_id: str, hashed_authkey: str, salt: str,
                            expire_time: datetime) -> Optional[Dict[str, Any]]:
        """Create an auth session and return session data

        Returns:
            The created session, with ``expire_time`` as an ISO-8601 string,
            or ``None`` if the insert fails.
        """
        try:
            # Generate unique session_id (UUID collisions are extremely rare)
            session_id = str(uuid.uuid4())

            with self.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute('''
                    INSERT INTO auth_session (session_id, user_id, hashed_authkey, salt, expire_time)
                    VALUES (?, ?, ?, ?, ?)
                ''', (session_id, user_id, hashed_authkey, salt, expire_time))
                conn.commit()

            self.logger.debug(f"Created auth session for user_id {user_id}, session_id: {session_id}")

            # Return the created session data
            return {
                'session_id': session_id,
                'user_id': user_id,
                'hashed_authkey': hashed_authkey,
                'salt': salt,
                'expire_time': expire_time.isoformat()
            }

        except Exception as e:
            self.logger.exception(f"Failed to create auth session for user_id {user_id}: {e}")
            return None

    def create_refresh_session(self, user_id: str, hashed_refreshkey: str, salt: str,
                               expire_time: datetime) -> Optional[Dict[str, Any]]:
        """Create a refresh session and return session data

        Returns:
            The created session, with ``expire_time`` as an ISO-8601 string,
            or ``None`` if the insert fails.
        """
        try:
            # Generate unique session_id (UUID collisions are extremely rare)
            session_id = str(uuid.uuid4())

            with self.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute('''
                    INSERT INTO refresh_session (session_id, user_id, hashed_refreshkey, salt, expire_time)
                    VALUES (?, ?, ?, ?, ?)
                ''', (session_id, user_id, hashed_refreshkey, salt, expire_time))
                conn.commit()

            self.logger.debug(f"Created refresh session for user_id {user_id}, session_id: {session_id}")

            # Return the created session data
            return {
                'session_id': session_id,
                'user_id': user_id,
                'hashed_refreshkey': hashed_refreshkey,
                'salt': salt,
                'expire_time': expire_time.isoformat()
            }

        except Exception as e:
            self.logger.exception(f"Failed to create refresh session for user_id {user_id}: {e}")
            return None

    def get_auth_session(self, session_id: str) -> Optional[Dict[str, Any]]:
        """Get valid auth session by session_id

        Returns:
            The session row as a dict if it exists and has not expired,
            otherwise ``None`` (also on error).
        """
        try:
            with self.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute('''
                    SELECT session_id, user_id, hashed_authkey, salt, expire_time
                    FROM auth_session
                    WHERE session_id = ? AND expire_time > ?
                ''', (session_id, datetime.now()))
                row = cursor.fetchone()
                return dict(row) if row else None

        except Exception as e:
            self.logger.exception(f"Failed to get auth session for session_id {session_id}: {e}")
            return None

    def get_refresh_session(self, session_id: str) -> Optional[Dict[str, Any]]:
        """Get valid refresh session by session_id

        Returns:
            The session row as a dict if it exists and has not expired,
            otherwise ``None`` (also on error).
        """
        try:
            with self.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute('''
                    SELECT session_id, user_id, hashed_refreshkey, salt, expire_time
                    FROM refresh_session
                    WHERE session_id = ? AND expire_time > ?
                ''', (session_id, datetime.now()))
                row = cursor.fetchone()
                return dict(row) if row else None

        except Exception as e:
            self.logger.exception(f"Failed to get refresh session for session_id {session_id}: {e}")
            return None

    def delete_auth_session(self, session_id: str) -> bool:
        """Delete an auth session

        Returns:
            ``True`` if a session was deleted, ``False`` if none matched or
            the deletion failed.
        """
        try:
            with self.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute('DELETE FROM auth_session WHERE session_id = ?', (session_id,))
                conn.commit()

                if cursor.rowcount > 0:
                    self.logger.debug(f"Deleted auth session_id: {session_id}")
                    return True
                return False

        except Exception as e:
            self.logger.exception(f"Failed to delete auth session {session_id}: {e}")
            return False

    def delete_refresh_session(self, session_id: str) -> bool:
        """Delete a refresh session

        Returns:
            ``True`` if a session was deleted, ``False`` if none matched or
            the deletion failed.
        """
        try:
            with self.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute('DELETE FROM refresh_session WHERE session_id = ?', (session_id,))
                conn.commit()

                if cursor.rowcount > 0:
                    self.logger.debug(f"Deleted refresh session_id: {session_id}")
                    return True
                return False

        except Exception as e:
            self.logger.exception(f"Failed to delete refresh session {session_id}: {e}")
            return False

    def delete_user_sessions(self, user_id: str) -> bool:
        """Delete all sessions for a user (logout)

        Returns:
            ``True`` when the statements ran (even if no session existed),
            ``False`` on error.
        """
        try:
            with self.get_connection() as conn:
                cursor = conn.cursor()

                # Delete auth sessions
                cursor.execute('DELETE FROM auth_session WHERE user_id = ?', (user_id,))
                auth_deleted = cursor.rowcount

                # Delete refresh sessions
                cursor.execute('DELETE FROM refresh_session WHERE user_id = ?', (user_id,))
                refresh_deleted = cursor.rowcount

                conn.commit()

            self.logger.info(f"Deleted all sessions for user_id {user_id}: {auth_deleted} auth, {refresh_deleted} refresh")
            return True

        except Exception as e:
            self.logger.exception(f"Failed to delete sessions for user_id {user_id}: {e}")
            return False

    # ============================================================================
    # System statistics functions
    # ============================================================================

    def insert_system_stats(self, srs_uptime: float, srs_cpu_percent: float, srs_memory_percent: float,
                            srs_recv_KBps: float, srs_send_KBps: float,
                            disk_read_KBps: float, disk_write_KBps: float,
                            os_uptime: float, os_cpu_percent: float, os_memory_percent: float) -> bool:
        """Insert system statistics into the database

        The row is keyed by the current local time (``datetime.now()``).

        Returns:
            ``True`` on success, ``False`` if the insert fails.
        """
        try:
            timestamp = datetime.now()

            with self.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute('''
                    INSERT INTO system_stats (timestamp, srs_uptime, srs_cpu_percent, srs_memory_percent,
                                              srs_recv_KBps, srs_send_KBps, disk_read_KBps, disk_write_KBps,
                                              os_uptime, os_cpu_percent, os_memory_percent)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                ''', (timestamp, srs_uptime, srs_cpu_percent, srs_memory_percent,
                      srs_recv_KBps, srs_send_KBps, disk_read_KBps, disk_write_KBps,
                      os_uptime, os_cpu_percent, os_memory_percent))
                conn.commit()

            self.logger.debug(f"Inserted system stats at {timestamp.isoformat()}")
            return True

        except Exception as e:
            self.logger.exception(f"Failed to insert system stats: {e}")
            return False

    def get_system_stats(self, time_delta: int) -> List[Dict[str, Any]]:
        """Get system statistics from the database

        Args:
            time_delta: Size of the look-back window in seconds.

        Returns:
            One dict per row recorded within the window; an empty list on
            error.
        """
        try:
            window_start = datetime.now() - timedelta(seconds=time_delta)

            with self.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute('SELECT * FROM system_stats WHERE timestamp >= ?', (window_start,))
                return [dict(row) for row in cursor.fetchall()]

        except Exception as e:
            self.logger.exception(f"Failed to get system stats: {e}")
            return []

    def delete_expired_system_stats(self) -> bool:
        """Delete expired system statistics older than 20 minutes

        Returns:
            ``True`` if at least one row was deleted; ``False`` if nothing
            was old enough or the deletion failed.
        """
        try:
            cutoff = datetime.now() - timedelta(minutes=20)

            with self.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute('DELETE FROM system_stats WHERE timestamp < ?', (cutoff,))
                conn.commit()

                if cursor.rowcount > 0:
                    self.logger.info(f"Deleted {cursor.rowcount} expired system stats")
                    return True

                self.logger.debug("No expired system stats to delete")
                return False

        except Exception as e:
            self.logger.exception(f"Failed to delete expired system stats: {e}")
            return False

    # ============================================================================
    # Streams management functions
    # ============================================================================

    def create_stream(self, streamer_id: str, stream_title: str, stream_description: Optional[str] = None,
                      stream_tags: Optional[List[str]] = None,
                      stream_visibility: str = 'public') -> Optional[Dict[str, Any]]:
        """Create a new stream and return stream data

        Args:
            streamer_id: ``user_id`` of the owning user; must exist.
            stream_title: Title of the stream.
            stream_description: Optional free-text description.
            stream_tags: Tags of the stream; ``None`` means no tags.
            stream_visibility: Visibility label stored with the stream.

        Returns:
            The created stream record (``stream_tags`` as a list), or
            ``None`` if the streamer does not exist, a constraint is
            violated (e.g. a ``stream_code`` collision) or the insert fails.
        """
        try:
            # Validate streamer_id exists
            if not self.get_user_by_id(streamer_id):
                self.logger.warn(f"Streamer with user_id {streamer_id} does not exist")
                return None

            # Work on a private copy: a shared default list, or the caller's
            # own list, must not be aliased by the returned record.
            tags = list(stream_tags) if stream_tags is not None else []

            # Generate unique stream_id (UUID collisions are extremely rare)
            stream_id = str(uuid.uuid4())

            # Generate stream_code in form of xxx-xxxx (lowercase letters and digits only)
            code_head = ''.join(random.choices(_STREAM_CODE_ALPHABET, k=3))
            code_tail = ''.join(random.choices(_STREAM_CODE_ALPHABET, k=4))
            stream_code = f"{code_head}-{code_tail}"

            # Convert tags to JSON string
            stream_tags_json = json.dumps(tags)

            with self.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute('''
                    INSERT INTO streams (stream_id, stream_code, streamer_id, stream_title,
                                         stream_description, stream_tags, stream_visibility)
                    VALUES (?, ?, ?, ?, ?, ?, ?)
                ''', (stream_id, stream_code, streamer_id, stream_title,
                      stream_description, stream_tags_json, stream_visibility))
                conn.commit()

            self.logger.info(f"Created new stream: {stream_title} (ID: {stream_id}, Streamer ID: {streamer_id})")

            return {
                'stream_id': stream_id,
                'stream_code': stream_code,
                'streamer_id': streamer_id,
                'stream_title': stream_title,
                'stream_description': stream_description,
                'stream_tags': tags,
                'stream_visibility': stream_visibility,
                'stream_status': 'planned',  # Column default
            }

        except sqlite3.IntegrityError as e:
            self.logger.warn(f"Stream creation failed due to constraint: {e}")
            return None
        except Exception as e:
            self.logger.exception(f"Failed to create stream: {e}")
            return None

    def get_stream_by_vis(self, stream_visibility: str = 'public') -> Optional[List[Dict[str, Any]]]:
        """Get streams by visibility

        Only streams whose ``stream_status`` is ``'streaming'`` are returned.

        Returns:
            One dict per matching row, with columns as stored (so
            ``stream_tags`` is the JSON text, not a list); ``None`` on error.
        """
        try:
            with self.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute('''
                    SELECT * FROM streams
                    WHERE stream_visibility = ? AND stream_status = 'streaming'
                ''', (stream_visibility,))
                return [dict(row) for row in cursor.fetchall()]

        except Exception as e:
            self.logger.exception(f"Failed to get streams by visibility '{stream_visibility}': {e}")
            return None

    # ============================================================================
    # Default database statistics functions
    # ============================================================================

    def get_database_stats(self) -> Dict[str, int]:
        """Get database statistics

        Returns:
            Counts for ``users``, ``active_auth_sessions``,
            ``active_refresh_sessions``, ``expired_auth_sessions`` and
            ``expired_refresh_sessions``; an empty dict on error.
        """
        try:
            # One reference time for all counts, so that "active" and
            # "expired" partition the sessions consistently.
            now = datetime.now()

            with self.get_connection() as conn:
                cursor = conn.cursor()
                stats = {}

                # Count users
                cursor.execute('SELECT COUNT(*) FROM users')
                stats['users'] = cursor.fetchone()[0]

                # Count active auth sessions
                cursor.execute('SELECT COUNT(*) FROM auth_session WHERE expire_time > ?', (now,))
                stats['active_auth_sessions'] = cursor.fetchone()[0]

                # Count active refresh sessions
                cursor.execute('SELECT COUNT(*) FROM refresh_session WHERE expire_time > ?', (now,))
                stats['active_refresh_sessions'] = cursor.fetchone()[0]

                # Count expired auth sessions
                cursor.execute('SELECT COUNT(*) FROM auth_session WHERE expire_time <= ?', (now,))
                stats['expired_auth_sessions'] = cursor.fetchone()[0]

                # Count expired refresh sessions
                cursor.execute('SELECT COUNT(*) FROM refresh_session WHERE expire_time <= ?', (now,))
                stats['expired_refresh_sessions'] = cursor.fetchone()[0]

                return stats

        except Exception as e:
            self.logger.exception(f"Failed to get database statistics: {e}")
            return {}


# Global database instance
_global_db: Optional[Database] = None


def get_database(db_path: str = "./objs/srs_database.db") -> Database:
    """
    Get the global database instance

    Args:
        db_path: Path to SQLite database file. Only used when the instance
            is created; once it exists the argument is ignored.

    Returns:
        Database instance
    """
    global _global_db
    if _global_db is None:
        _global_db = Database(db_path)
    return _global_db


def _main() -> None:
    """Command-line entry point: print database statistics and purge expired sessions."""
    import argparse
    from srs_logger import init_logger

    logger = init_logger()

    parser = argparse.ArgumentParser(description="Database Module Test")
    parser.add_argument("--db-path", default="./objs/srs_database.db", help="Database file path")
    args = parser.parse_args()

    db = get_database(args.db_path)

    # Show database statistics
    stats = db.get_database_stats()
    logger.info(f"Database statistics: {stats}")

    # Clean up expired sessions
    db.cleanup_expired_sessions()


if __name__ == "__main__":
    _main()