import asyncio
from collections import defaultdict
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional

import aiosqlite
import orjson


def get_db() -> "TokenDatabase":
    """Return the live TokenDatabase instance held by router.py.

    Resolved lazily so submodules can access it without import cycles, and so
    test patches of ``router.db`` flow through to all callers.
    """
    import router  # lazy to avoid module-load circular import
    return router.db


class TokenDatabase:
    """Async SQLite store for token usage counters and Responses-API state.

    A single persistent ``aiosqlite`` connection is opened on first use.
    Every statement runs while holding ``_operation_lock`` so operations
    issued from concurrent coroutines do not interleave on that connection.

    Tables managed here:

    * ``token_counts`` -- running totals per (endpoint, model).
    * ``token_time_series`` -- per-minute samples; has a foreign key to
      ``token_counts(endpoint, model)`` and foreign keys are enforced, so the
      matching ``token_counts`` row must exist before a sample is inserted.
    * ``token_time_series_daily`` -- daily roll-ups, created on demand by
      :meth:`aggregate_time_series_older_than`.
    * ``stored_responses`` -- stored Responses-API responses and their status.
    """

    # Shared by update_token_counts() and update_batched_counts(): insert a new
    # (endpoint, model) row or add the supplied amounts to the existing one.
    _UPSERT_TOKEN_COUNTS_SQL = '''
        INSERT INTO token_counts (endpoint, model, input_tokens, output_tokens, total_tokens)
        VALUES (?, ?, ?, ?, ?)
        ON CONFLICT (endpoint, model) DO UPDATE SET
            input_tokens = input_tokens + excluded.input_tokens,
            output_tokens = output_tokens + excluded.output_tokens,
            total_tokens = total_tokens + excluded.total_tokens
    '''

    # Shared by add_time_series_entry() and add_batched_time_series().
    _INSERT_TIME_SERIES_SQL = '''
        INSERT INTO token_time_series (endpoint, model, input_tokens, output_tokens, total_tokens, timestamp)
        VALUES (?, ?, ?, ?, ?, ?)
    '''

    _SELECT_RESPONSE_SQL = '''
        SELECT response_id, previous_response_id, model, status, created_at,
               input_messages, output_items, usage, instructions, error
        FROM stored_responses
        WHERE response_id = ?
    '''

    def __init__(self, db_path: str = "token_counts.db"):
        self.db_path = db_path
        self._db: Optional[aiosqlite.Connection] = None
        # Guards creation of the shared connection.
        self._db_lock = asyncio.Lock()
        # Serializes every statement issued on the shared connection.
        self._operation_lock = asyncio.Lock()

    # -----------------------------------------------------------------
    # Connection handling and internal helpers
    # -----------------------------------------------------------------
    def _ensure_db_directory(self):
        """Ensure the directory for the database exists."""
        Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)

    async def _get_connection(self) -> aiosqlite.Connection:
        """Return a persistent connection with WAL mode and FK enforcement enabled."""
        if self._db is None:
            async with self._db_lock:
                # Re-check under the lock: another coroutine may have opened
                # the connection while this one was waiting.
                if self._db is None:
                    self._ensure_db_directory()
                    conn = await aiosqlite.connect(self.db_path)
                    try:
                        # Enable WAL and foreign keys for reliability and integrity
                        await conn.execute("PRAGMA journal_mode=WAL;")
                        await conn.execute("PRAGMA foreign_keys = ON;")
                        await conn.commit()
                    except Exception:
                        # Never publish (or leak) a half-configured connection.
                        await conn.close()
                        raise
                    # Publish only once fully configured so no caller can use
                    # the connection before foreign keys are switched on.
                    self._db = conn
        return self._db

    async def close(self):
        """Close the persistent database connection, if open."""
        # Detach first so the handle is cleared even if close() raises and a
        # later call transparently reconnects instead of reusing a dead handle.
        conn, self._db = self._db, None
        if conn is not None:
            await conn.close()

    @asynccontextmanager
    async def _write_transaction(self):
        """Yield the connection under the operation lock; commit or roll back.

        Commits when the body finishes normally. On any error (or task
        cancellation) the pending changes are rolled back so a failed
        operation cannot leave a transaction open for the next one to commit.
        """
        db = await self._get_connection()
        async with self._operation_lock:
            try:
                yield db
                await db.commit()
            except (Exception, asyncio.CancelledError):
                try:
                    await db.rollback()
                except Exception:
                    # Best effort only: the original error is the one the
                    # caller needs to see, so a failing rollback must not
                    # replace it.
                    pass
                raise

    async def _fetch_all(self, sql: str, params: tuple = ()) -> list:
        """Run a read-only query under the operation lock and return all rows."""
        db = await self._get_connection()
        async with self._operation_lock:
            async with db.execute(sql, params) as cursor:
                return await cursor.fetchall()

    async def _fetch_one(self, sql: str, params: tuple = ()):
        """Run a read-only query under the operation lock and return the first row (or None)."""
        db = await self._get_connection()
        async with self._operation_lock:
            async with db.execute(sql, params) as cursor:
                return await cursor.fetchone()

    @staticmethod
    def _dumps_or_none(value) -> Optional[str]:
        """JSON-encode ``value`` to text, passing ``None`` through as SQL NULL."""
        if value is None:
            return None
        return orjson.dumps(value).decode("utf-8")

    # -----------------------------------------------------------------
    # Schema
    # -----------------------------------------------------------------
    async def init_db(self):
        """Initialize the database tables."""
        async with self._write_transaction() as db:
            await db.execute('''
                CREATE TABLE IF NOT EXISTS token_counts (
                    endpoint TEXT,
                    model TEXT,
                    input_tokens INTEGER DEFAULT 0,
                    output_tokens INTEGER DEFAULT 0,
                    total_tokens INTEGER DEFAULT 0,
                    PRIMARY KEY(endpoint, model)
                )
            ''')
            await db.execute('''
                CREATE TABLE IF NOT EXISTS token_time_series (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    endpoint TEXT,
                    model TEXT,
                    input_tokens INTEGER,
                    output_tokens INTEGER,
                    total_tokens INTEGER,
                    timestamp INTEGER,
                    FOREIGN KEY(endpoint, model) REFERENCES token_counts(endpoint, model)
                )
            ''')
            await db.execute('CREATE INDEX IF NOT EXISTS idx_token_time_series_timestamp ON token_time_series(timestamp)')
            await db.execute('CREATE INDEX IF NOT EXISTS idx_token_time_series_model_ts ON token_time_series(model, timestamp)')
            # Responses API state — the router owns conversation state for the
            # /v1/responses family (store / previous_response_id) and tracks
            # background-task status here so polling survives across workers.
            await db.execute('''
                CREATE TABLE IF NOT EXISTS stored_responses (
                    response_id TEXT PRIMARY KEY,
                    previous_response_id TEXT,
                    model TEXT,
                    status TEXT,
                    created_at INTEGER,
                    input_messages TEXT,
                    output_items TEXT,
                    usage TEXT,
                    instructions TEXT,
                    error TEXT
                )
            ''')
            await db.execute('CREATE INDEX IF NOT EXISTS idx_stored_responses_prev ON stored_responses(previous_response_id)')

    # -----------------------------------------------------------------
    # Token counters and time series
    # -----------------------------------------------------------------
    async def update_token_counts(self, endpoint: str, model: str, input_tokens: int, output_tokens: int):
        """Add the given token amounts to the totals for an endpoint/model pair.

        The row is created if it does not exist yet; otherwise the amounts are
        added to the stored totals.
        """
        total_tokens = input_tokens + output_tokens
        async with self._write_transaction() as db:
            await db.execute(
                self._UPSERT_TOKEN_COUNTS_SQL,
                (endpoint, model, input_tokens, output_tokens, total_tokens),
            )

    async def add_time_series_entry(self, endpoint: str, model: str, input_tokens: int, output_tokens: int):
        """Add a time series entry stamped with the current UTC minute.

        Because of the foreign key on ``token_time_series``, a ``token_counts``
        row for (endpoint, model) must already exist.
        """
        total_tokens = input_tokens + output_tokens
        # Truncate the *aware* UTC datetime to the minute. Rebuilding a naive
        # datetime from the UTC fields would make .timestamp() interpret them
        # as local time and shift the bucket by the host's UTC offset.
        now = datetime.now(tz=timezone.utc)
        timestamp = int(now.replace(second=0, microsecond=0).timestamp())
        async with self._write_transaction() as db:
            await db.execute(
                self._INSERT_TIME_SERIES_SQL,
                (endpoint, model, input_tokens, output_tokens, total_tokens, timestamp),
            )

    async def update_batched_counts(self, counts: dict):
        """Update multiple token counts in a single transaction.

        ``counts`` maps ``endpoint -> {model: (input_tokens, output_tokens)}``.
        Either every pair is applied or, on error, none is.
        """
        if not counts:
            return
        rows = [
            (endpoint, model, input_tokens, output_tokens, input_tokens + output_tokens)
            for endpoint, models in counts.items()
            for model, (input_tokens, output_tokens) in models.items()
        ]
        if not rows:
            return
        async with self._write_transaction() as db:
            # sqlite3 opens the transaction implicitly before the first
            # INSERT, so the whole batch is committed (or rolled back) as one.
            await db.executemany(self._UPSERT_TOKEN_COUNTS_SQL, rows)

    async def add_batched_time_series(self, entries: list):
        """Add multiple time series entries in a single transaction.

        Each entry is a mapping with the keys ``endpoint``, ``model``,
        ``input_tokens``, ``output_tokens``, ``total_tokens`` and
        ``timestamp``. Either every entry is inserted or, on error, none is.
        """
        if not entries:
            return
        # Built up front so a malformed entry raises KeyError before any
        # statement touches the database.
        rows = [
            (
                entry['endpoint'],
                entry['model'],
                entry['input_tokens'],
                entry['output_tokens'],
                entry['total_tokens'],
                entry['timestamp'],
            )
            for entry in entries
        ]
        async with self._write_transaction() as db:
            await db.executemany(self._INSERT_TIME_SERIES_SQL, rows)

    async def load_token_counts(self):
        """Yield every ``token_counts`` row as a dict.

        Rows are fetched eagerly and yielded after the operation lock has been
        released, so the consumer may call other database methods while
        iterating without deadlocking on the (non-reentrant) lock.
        """
        rows = await self._fetch_all(
            'SELECT endpoint, model, input_tokens, output_tokens, total_tokens FROM token_counts'
        )
        for endpoint, model, input_tokens, output_tokens, total_tokens in rows:
            yield {
                'endpoint': endpoint,
                'model': model,
                'input_tokens': input_tokens,
                'output_tokens': output_tokens,
                'total_tokens': total_tokens
            }

    async def get_latest_time_series(self, limit: int = 100):
        """Yield the latest time series entries, newest first.

        At most ``limit`` rows are fetched eagerly; the operation lock is not
        held while the consumer iterates.
        """
        rows = await self._fetch_all('''
            SELECT endpoint, model, input_tokens, output_tokens, total_tokens, timestamp
            FROM token_time_series
            ORDER BY timestamp DESC
            LIMIT ?
        ''', (limit,))
        for endpoint, model, input_tokens, output_tokens, total_tokens, timestamp in rows:
            yield {
                'endpoint': endpoint,
                'model': model,
                'input_tokens': input_tokens,
                'output_tokens': output_tokens,
                'total_tokens': total_tokens,
                'timestamp': timestamp
            }

    async def get_time_series_for_model(self, model: str, limit: int = 50000):
        """Get time series entries for a specific model, newest first.

        Uses the (model, timestamp) composite index so the DB does the
        filtering instead of returning all rows and discarding non-matching
        ones in Python. At most ``limit`` rows are fetched eagerly; the
        operation lock is not held while the consumer iterates.
        """
        rows = await self._fetch_all('''
            SELECT endpoint, input_tokens, output_tokens, total_tokens, timestamp
            FROM token_time_series
            WHERE model = ?
            ORDER BY timestamp DESC
            LIMIT ?
        ''', (model, limit))
        for endpoint, input_tokens, output_tokens, total_tokens, timestamp in rows:
            yield {
                'endpoint': endpoint,
                'input_tokens': input_tokens,
                'output_tokens': output_tokens,
                'total_tokens': total_tokens,
                'timestamp': timestamp,
            }

    async def get_endpoint_distribution_for_model(self, model: str) -> dict:
        """Return total tokens per endpoint for a specific model as a plain dict.

        Computed entirely in SQL so no Python-side aggregation is needed.
        """
        rows = await self._fetch_all('''
            SELECT endpoint, SUM(total_tokens)
            FROM token_time_series
            WHERE model = ?
            GROUP BY endpoint
        ''', (model,))
        return {endpoint: total for endpoint, total in rows}

    async def get_token_counts_for_model(self, model: str) -> Optional[dict]:
        """Get token counts for a specific model, aggregated across all endpoints.

        Returns a dict whose ``endpoint`` is the literal ``'aggregated'``, or
        ``None`` when no ``token_counts`` row exists for ``model``.
        """
        # An aggregate query without GROUP BY always yields exactly one row,
        # even when nothing matches, so COUNT(*) is what tells "no data" apart
        # from real totals.
        row = await self._fetch_one('''
            SELECT COUNT(*), SUM(input_tokens), SUM(output_tokens), SUM(total_tokens)
            FROM token_counts
            WHERE model = ?
        ''', (model,))
        if row is None or not row[0]:
            return None
        _, input_tokens, output_tokens, total_tokens = row
        return {
            'endpoint': 'aggregated',
            'model': model,
            'input_tokens': input_tokens,
            'output_tokens': output_tokens,
            'total_tokens': total_tokens
        }

    async def aggregate_time_series_older_than(self, days: int, trim_old: bool = False) -> int:
        """
        Aggregate time_series entries older than 'days' days into daily aggregates
        by endpoint, model and UTC date (YYYY-MM-DD).

        The results are stored in token_time_series_daily with a UNIQUE
        constraint on (endpoint, model, date); sums are *added* to any existing
        daily row. A non-int or non-positive ``days`` falls back to 30.

        When ``trim_old`` is true the aggregated source rows are deleted in the
        same transaction.

        NOTE: with ``trim_old=False`` the source rows stay in place, so calling
        this again adds the same rows to the daily totals a second time.

        Returns the number of aggregated groups (distinct (endpoint, model,
        date) tuples) that were created/updated.
        """
        if not isinstance(days, int) or days <= 0:
            days = 30
        cutoff_ts = int(datetime.now(tz=timezone.utc).timestamp()) - (days * 86400)

        async with self._write_transaction() as db:
            # Ensure daily table exists
            await db.execute('''
                CREATE TABLE IF NOT EXISTS token_time_series_daily (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    endpoint TEXT,
                    model TEXT,
                    date TEXT,
                    input_tokens INTEGER DEFAULT 0,
                    output_tokens INTEGER DEFAULT 0,
                    total_tokens INTEGER DEFAULT 0,
                    UNIQUE(endpoint, model, date)
                )
            ''')
            async with db.execute('''
                SELECT endpoint, model, date(timestamp, 'unixepoch') as day,
                       SUM(input_tokens) as in_sum,
                       SUM(output_tokens) as out_sum,
                       SUM(total_tokens) as tot_sum
                FROM token_time_series
                WHERE timestamp < ?
                GROUP BY endpoint, model, day
            ''', (cutoff_ts,)) as cursor:
                groups = await cursor.fetchall()

            # SUM() yields NULL when every summed value is NULL; store 0 then.
            daily_rows = [
                (endpoint, model, day, int(in_sum or 0), int(out_sum or 0), int(tot_sum or 0))
                for endpoint, model, day, in_sum, out_sum, tot_sum in groups
            ]
            if daily_rows:
                await db.executemany('''
                    INSERT INTO token_time_series_daily (endpoint, model, date, input_tokens, output_tokens, total_tokens)
                    VALUES (?, ?, ?, ?, ?, ?)
                    ON CONFLICT (endpoint, model, date) DO UPDATE SET
                        input_tokens = input_tokens + excluded.input_tokens,
                        output_tokens = output_tokens + excluded.output_tokens,
                        total_tokens = total_tokens + excluded.total_tokens
                ''', daily_rows)

            # Trim old entries if requested. Done inside the same transaction
            # so a failure can never delete rows that were not aggregated.
            if trim_old:
                await db.execute('DELETE FROM token_time_series WHERE timestamp < ?', (cutoff_ts,))

        return len(daily_rows)

    # -----------------------------------------------------------------
    # Responses API state (store / previous_response_id / background)
    # -----------------------------------------------------------------
    @staticmethod
    def _row_to_response(row) -> dict:
        """Map a stored_responses row to a plain dict, decoding JSON columns.

        JSON columns that are NULL or fail to decode are returned as ``None``.
        """
        def _loads(val):
            if val is None:
                return None
            try:
                return orjson.loads(val)
            except (orjson.JSONDecodeError, TypeError):
                return None

        return {
            'response_id': row[0],
            'previous_response_id': row[1],
            'model': row[2],
            'status': row[3],
            'created_at': row[4],
            'input_messages': _loads(row[5]),
            'output_items': _loads(row[6]),
            'usage': _loads(row[7]),
            'instructions': row[8],
            'error': _loads(row[9]),
        }

    async def store_response(
        self,
        response_id: str,
        *,
        previous_response_id: Optional[str],
        model: str,
        status: str,
        created_at: int,
        input_messages: list,
        output_items: Optional[list] = None,
        usage: Optional[dict] = None,
        instructions: Optional[str] = None,
        error: Optional[dict] = None,
    ):
        """Insert or replace a stored Responses-API response row.

        ``input_messages`` is always JSON-encoded; ``output_items``, ``usage``
        and ``error`` are JSON-encoded when given and stored as NULL otherwise.
        An existing row with the same ``response_id`` has every column
        overwritten.
        """
        async with self._write_transaction() as db:
            await db.execute('''
                INSERT INTO stored_responses
                    (response_id, previous_response_id, model, status, created_at,
                     input_messages, output_items, usage, instructions, error)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                ON CONFLICT (response_id) DO UPDATE SET
                    previous_response_id = excluded.previous_response_id,
                    model = excluded.model,
                    status = excluded.status,
                    created_at = excluded.created_at,
                    input_messages = excluded.input_messages,
                    output_items = excluded.output_items,
                    usage = excluded.usage,
                    instructions = excluded.instructions,
                    error = excluded.error
            ''', (
                response_id,
                previous_response_id,
                model,
                status,
                created_at,
                orjson.dumps(input_messages).decode("utf-8"),
                self._dumps_or_none(output_items),
                self._dumps_or_none(usage),
                instructions,
                self._dumps_or_none(error),
            ))

    async def update_response_status(
        self,
        response_id: str,
        status: str,
        *,
        output_items: Optional[list] = None,
        usage: Optional[dict] = None,
        error: Optional[dict] = None,
    ):
        """Update the status (and optionally output/usage/error) of a stored response.

        Arguments left as ``None`` keep the value already stored (COALESCE),
        so this method cannot reset a column back to NULL.
        """
        async with self._write_transaction() as db:
            await db.execute('''
                UPDATE stored_responses
                SET status = ?,
                    output_items = COALESCE(?, output_items),
                    usage = COALESCE(?, usage),
                    error = COALESCE(?, error)
                WHERE response_id = ?
            ''', (
                status,
                self._dumps_or_none(output_items),
                self._dumps_or_none(usage),
                self._dumps_or_none(error),
                response_id,
            ))

    async def get_response(self, response_id: str) -> Optional[dict]:
        """Return a stored response as a dict, or None if not found."""
        row = await self._fetch_one(self._SELECT_RESPONSE_SQL, (response_id,))
        return self._row_to_response(row) if row is not None else None

    async def delete_response(self, response_id: str) -> bool:
        """Delete a stored response. Returns True if a row was removed."""
        async with self._write_transaction() as db:
            cursor = await db.execute(
                'DELETE FROM stored_responses WHERE response_id = ?', (response_id,)
            )
        return cursor.rowcount > 0

    async def get_response_chain(self, response_id: str, max_turns: int = 50) -> list:
        """Walk previous_response_id back to the root, returned oldest-first.

        Bounded to ``max_turns`` so a pathological chain cannot stall a
        request. Missing links terminate the walk gracefully, and an id that
        was already visited stops the walk so a cycle cannot loop.
        """
        chain: list = []
        seen: set = set()
        current = response_id
        while current and current not in seen and len(chain) < max_turns:
            seen.add(current)
            resp = await self.get_response(current)
            if resp is None:
                break
            chain.append(resp)
            current = resp.get('previous_response_id')
        chain.reverse()
        return chain

    async def fail_orphaned_responses(self) -> int:
        """Mark non-terminal responses as failed (called on startup).

        A background task lives in a worker's event loop; a process restart
        loses it while the DB row stays ``queued``/``in_progress`` forever.
        Reconcile those to ``failed`` so polling clients get a terminal state.

        Returns the number of rows that were updated.
        """
        error_json = orjson.dumps(
            {"message": "Response interrupted by server restart", "type": "server_error"}
        ).decode("utf-8")
        async with self._write_transaction() as db:
            cursor = await db.execute('''
                UPDATE stored_responses
                SET status = 'failed', error = ?
                WHERE status IN ('queued', 'in_progress')
            ''', (error_json,))
        return cursor.rowcount