494 lines
21 KiB
Python
494 lines
21 KiB
Python
import aiosqlite, asyncio, orjson
|
|
from typing import Optional
|
|
from pathlib import Path
|
|
from datetime import datetime, timezone
|
|
from collections import defaultdict
|
|
|
|
|
|
def get_db() -> "TokenDatabase":
    """Fetch the TokenDatabase currently bound to ``router.db``.

    The ``router`` module is imported inside the function body rather than at
    module load, which sidesteps a circular import between the two modules.
    Because the attribute is looked up on every call, replacing ``router.db``
    (for example from a test) is visible to every caller of this function.
    """
    import router as router_module  # deferred: a top-level import would be circular

    return router_module.db
|
|
|
|
|
|
class TokenDatabase:
    """Async SQLite store for token accounting and Responses-API state.

    One ``aiosqlite`` connection is opened lazily and reused for the lifetime
    of the instance. Every statement runs while ``_operation_lock`` is held,
    so operations issued through this object never interleave on the shared
    connection. ``_db_lock`` only guards the one-time connection setup.
    """

    # Insert a counter row, or add the new values onto the existing row.
    _UPSERT_TOKEN_COUNTS_SQL = '''
        INSERT INTO token_counts (endpoint, model, input_tokens, output_tokens, total_tokens)
        VALUES (?, ?, ?, ?, ?)
        ON CONFLICT (endpoint, model) DO UPDATE SET
            input_tokens = input_tokens + excluded.input_tokens,
            output_tokens = output_tokens + excluded.output_tokens,
            total_tokens = total_tokens + excluded.total_tokens
    '''

    _INSERT_TIME_SERIES_SQL = '''
        INSERT INTO token_time_series (endpoint, model, input_tokens, output_tokens, total_tokens, timestamp)
        VALUES (?, ?, ?, ?, ?, ?)
    '''

    # Statements executed, in order, by init_db().
    _SCHEMA_STATEMENTS = (
        '''
        CREATE TABLE IF NOT EXISTS token_counts (
            endpoint TEXT,
            model TEXT,
            input_tokens INTEGER DEFAULT 0,
            output_tokens INTEGER DEFAULT 0,
            total_tokens INTEGER DEFAULT 0,
            PRIMARY KEY(endpoint, model)
        )
        ''',
        '''
        CREATE TABLE IF NOT EXISTS token_time_series (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            endpoint TEXT,
            model TEXT,
            input_tokens INTEGER,
            output_tokens INTEGER,
            total_tokens INTEGER,
            timestamp INTEGER,
            FOREIGN KEY(endpoint, model) REFERENCES token_counts(endpoint, model)
        )
        ''',
        'CREATE INDEX IF NOT EXISTS idx_token_time_series_timestamp ON token_time_series(timestamp)',
        'CREATE INDEX IF NOT EXISTS idx_token_time_series_model_ts ON token_time_series(model, timestamp)',
        # Responses API state — the router owns conversation state for the
        # /v1/responses family (store / previous_response_id) and tracks
        # background-task status here so polling survives across workers.
        '''
        CREATE TABLE IF NOT EXISTS stored_responses (
            response_id TEXT PRIMARY KEY,
            previous_response_id TEXT,
            model TEXT,
            status TEXT,
            created_at INTEGER,
            input_messages TEXT,
            output_items TEXT,
            usage TEXT,
            instructions TEXT,
            error TEXT
        )
        ''',
        'CREATE INDEX IF NOT EXISTS idx_stored_responses_prev ON stored_responses(previous_response_id)',
    )

    _DAILY_TABLE_SQL = '''
        CREATE TABLE IF NOT EXISTS token_time_series_daily (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            endpoint TEXT,
            model TEXT,
            date TEXT,
            input_tokens INTEGER DEFAULT 0,
            output_tokens INTEGER DEFAULT 0,
            total_tokens INTEGER DEFAULT 0,
            UNIQUE(endpoint, model, date)
        )
    '''

    # Used when the source rows are deleted afterwards: each run contributes
    # rows that no earlier run has seen, so the sums are added.
    _DAILY_UPSERT_ADD_SQL = '''
        INSERT INTO token_time_series_daily (endpoint, model, date, input_tokens, output_tokens, total_tokens)
        VALUES (?, ?, ?, ?, ?, ?)
        ON CONFLICT (endpoint, model, date) DO UPDATE SET
            input_tokens = input_tokens + excluded.input_tokens,
            output_tokens = output_tokens + excluded.output_tokens,
            total_tokens = total_tokens + excluded.total_tokens
    '''

    # Used when the source rows are kept: the freshly computed sum already
    # covers everything an earlier run counted, so it replaces the old value.
    _DAILY_UPSERT_REPLACE_SQL = '''
        INSERT INTO token_time_series_daily (endpoint, model, date, input_tokens, output_tokens, total_tokens)
        VALUES (?, ?, ?, ?, ?, ?)
        ON CONFLICT (endpoint, model, date) DO UPDATE SET
            input_tokens = excluded.input_tokens,
            output_tokens = excluded.output_tokens,
            total_tokens = excluded.total_tokens
    '''

    _SELECT_RESPONSE_SQL = '''
        SELECT response_id, previous_response_id, model, status, created_at,
               input_messages, output_items, usage, instructions, error
        FROM stored_responses WHERE response_id = ?
    '''

    def __init__(self, db_path: str = "token_counts.db"):
        """Remember the database path; the connection itself is opened lazily.

        Args:
            db_path: Filesystem path of the SQLite database file.
        """
        self.db_path = db_path
        self._db: Optional[aiosqlite.Connection] = None
        # Guards the one-time connection setup in _get_connection().
        self._db_lock = asyncio.Lock()
        # Serialises every statement issued on the shared connection.
        self._operation_lock = asyncio.Lock()

    def _ensure_db_directory(self):
        """Create the parent directory of the database file if it is missing."""
        Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)

    async def _get_connection(self) -> "aiosqlite.Connection":
        """Return the shared connection, opening and configuring it on first use.

        The connection is only stored on ``self._db`` after WAL mode and
        foreign-key enforcement have been applied. Storing it earlier would
        let a concurrent caller (which skips the lock once ``self._db`` is
        set) run statements on a connection that is not configured yet, and
        would keep a half-configured connection around if a PRAGMA failed.
        """
        if self._db is None:
            async with self._db_lock:
                # Re-check: another task may have finished setup while we waited.
                if self._db is None:
                    self._ensure_db_directory()
                    conn = await aiosqlite.connect(self.db_path)
                    try:
                        # Enable WAL and foreign keys for reliability and integrity
                        await conn.execute("PRAGMA journal_mode=WAL;")
                        await conn.execute("PRAGMA foreign_keys = ON;")
                        await conn.commit()
                    except BaseException:
                        # Includes cancellation: never leak an unpublished connection.
                        await conn.close()
                        raise
                    self._db = conn
        return self._db

    @staticmethod
    async def _rollback_quietly(db) -> None:
        """Roll back the open transaction, ignoring a failure of the rollback.

        Callers re-raise the error that triggered the rollback; a secondary
        failure here (for instance when no transaction is active) must not
        replace it.
        """
        try:
            await db.execute('ROLLBACK')
        except Exception:
            pass

    @staticmethod
    def _to_json(value) -> Optional[str]:
        """Serialise ``value`` to a JSON string, passing ``None`` through as NULL."""
        if value is None:
            return None
        return orjson.dumps(value).decode("utf-8")

    async def close(self):
        """Close the persistent database connection, if open."""
        if self._db is None:
            return
        await self._db.close()
        self._db = None

    async def init_db(self):
        """Create all tables and indexes if they do not exist yet."""
        db = await self._get_connection()
        async with self._operation_lock:
            for statement in self._SCHEMA_STATEMENTS:
                await db.execute(statement)
            await db.commit()

    async def update_token_counts(self, endpoint: str, model: str, input_tokens: int, output_tokens: int):
        """Add token usage onto the running totals of one (endpoint, model) pair.

        A row is created on first use; afterwards the values are accumulated.
        """
        total_tokens = input_tokens + output_tokens
        db = await self._get_connection()
        async with self._operation_lock:
            await db.execute(
                self._UPSERT_TOKEN_COUNTS_SQL,
                (endpoint, model, input_tokens, output_tokens, total_tokens),
            )
            await db.commit()

    async def add_time_series_entry(self, endpoint: str, model: str, input_tokens: int, output_tokens: int):
        """Record one usage sample, stamped with the start of the current UTC minute.

        The timestamp is derived from the timezone-aware ``now`` value. A naive
        ``datetime`` built from the UTC fields would be interpreted as local
        time by ``timestamp()`` and be off by the host's UTC offset.
        """
        total_tokens = input_tokens + output_tokens
        minute_start = datetime.now(tz=timezone.utc).replace(second=0, microsecond=0)
        timestamp = int(minute_start.timestamp())

        db = await self._get_connection()
        async with self._operation_lock:
            await db.execute(
                self._INSERT_TIME_SERIES_SQL,
                (endpoint, model, input_tokens, output_tokens, total_tokens, timestamp),
            )
            await db.commit()

    async def update_batched_counts(self, counts: dict):
        """Apply many counter updates in a single transaction.

        Args:
            counts: Mapping ``{endpoint: {model: (input_tokens, output_tokens)}}``.

        Raises:
            Exception: Whatever the database raised; the transaction is rolled
                back first so no partial batch is left pending.
        """
        if not counts:
            return
        db = await self._get_connection()
        async with self._operation_lock:
            try:
                await db.execute('BEGIN')
                for endpoint, models in counts.items():
                    for model, (input_tokens, output_tokens) in models.items():
                        await db.execute(
                            self._UPSERT_TOKEN_COUNTS_SQL,
                            (endpoint, model, input_tokens, output_tokens,
                             input_tokens + output_tokens),
                        )
                await db.commit()
            except Exception:
                # Rollback on error to maintain consistency
                await self._rollback_quietly(db)
                raise

    async def add_batched_time_series(self, entries: list):
        """Insert many time-series samples in a single transaction.

        Args:
            entries: Dicts with the keys ``endpoint``, ``model``,
                ``input_tokens``, ``output_tokens``, ``total_tokens`` and
                ``timestamp``. An empty list is a no-op.

        Raises:
            Exception: Whatever the insert raised (including ``KeyError`` for a
                malformed entry); the transaction is rolled back first.
        """
        if not entries:
            return
        db = await self._get_connection()
        async with self._operation_lock:
            try:
                await db.execute('BEGIN')
                for entry in entries:
                    await db.execute(
                        self._INSERT_TIME_SERIES_SQL,
                        (entry['endpoint'], entry['model'], entry['input_tokens'],
                         entry['output_tokens'], entry['total_tokens'], entry['timestamp']),
                    )
                await db.commit()
            except Exception:
                await self._rollback_quietly(db)
                raise

    async def load_token_counts(self):
        """Yield every row of ``token_counts`` as a dict.

        Rows are fetched while the operation lock is held and yielded after it
        is released, so the consumer may call other methods of this object
        while iterating without deadlocking on the non-reentrant lock.
        """
        db = await self._get_connection()
        async with self._operation_lock:
            async with db.execute(
                'SELECT endpoint, model, input_tokens, output_tokens, total_tokens FROM token_counts'
            ) as cursor:
                rows = await cursor.fetchall()
        for endpoint, model, input_tokens, output_tokens, total_tokens in rows:
            yield {
                'endpoint': endpoint,
                'model': model,
                'input_tokens': input_tokens,
                'output_tokens': output_tokens,
                'total_tokens': total_tokens,
            }

    async def get_latest_time_series(self, limit: int = 100):
        """Yield up to ``limit`` time-series rows, newest first.

        The rows are read under the operation lock and yielded after it has
        been released (see ``load_token_counts``).
        """
        db = await self._get_connection()
        async with self._operation_lock:
            async with db.execute('''
                SELECT endpoint, model, input_tokens, output_tokens, total_tokens, timestamp
                FROM token_time_series
                ORDER BY timestamp DESC
                LIMIT ?
            ''', (limit,)) as cursor:
                rows = await cursor.fetchall()
        for endpoint, model, input_tokens, output_tokens, total_tokens, timestamp in rows:
            yield {
                'endpoint': endpoint,
                'model': model,
                'input_tokens': input_tokens,
                'output_tokens': output_tokens,
                'total_tokens': total_tokens,
                'timestamp': timestamp,
            }

    async def get_time_series_for_model(self, model: str, limit: int = 50000):
        """Yield up to ``limit`` time-series rows of one model, newest first.

        Filtering happens in SQL (``WHERE model = ?``) instead of discarding
        non-matching rows in Python. The yielded dicts carry no ``model`` key.
        The rows are read under the operation lock and yielded after it has
        been released (see ``load_token_counts``).
        """
        db = await self._get_connection()
        async with self._operation_lock:
            async with db.execute('''
                SELECT endpoint, input_tokens, output_tokens, total_tokens, timestamp
                FROM token_time_series
                WHERE model = ?
                ORDER BY timestamp DESC
                LIMIT ?
            ''', (model, limit)) as cursor:
                rows = await cursor.fetchall()
        for endpoint, input_tokens, output_tokens, total_tokens, timestamp in rows:
            yield {
                'endpoint': endpoint,
                'input_tokens': input_tokens,
                'output_tokens': output_tokens,
                'total_tokens': total_tokens,
                'timestamp': timestamp,
            }

    async def get_endpoint_distribution_for_model(self, model: str) -> dict:
        """Return ``{endpoint: summed total_tokens}`` for one model.

        The aggregation is done in SQL over ``token_time_series``; a model
        without rows yields an empty dict.
        """
        db = await self._get_connection()
        async with self._operation_lock:
            async with db.execute('''
                SELECT endpoint, SUM(total_tokens)
                FROM token_time_series
                WHERE model = ?
                GROUP BY endpoint
            ''', (model,)) as cursor:
                rows = await cursor.fetchall()
        return {endpoint: total for endpoint, total in rows}

    async def get_token_counts_for_model(self, model):
        """Return one model's counters summed across all endpoints.

        Returns:
            A dict with ``endpoint`` fixed to ``'aggregated'`` plus ``model``,
            ``input_tokens``, ``output_tokens`` and ``total_tokens``, or
            ``None`` when ``token_counts`` has no row for the model.
        """
        db = await self._get_connection()
        async with self._operation_lock:
            async with db.execute('''
                SELECT
                    'aggregated' as endpoint,
                    ? as model,
                    SUM(input_tokens) as input_tokens,
                    SUM(output_tokens) as output_tokens,
                    SUM(total_tokens) as total_tokens
                FROM token_counts
                WHERE model = ?
            ''', (model, model)) as cursor:
                row = await cursor.fetchone()
        # An aggregate without GROUP BY always produces exactly one row; when
        # nothing matched, every SUM is NULL. Treat that as "not found".
        if row is None or all(value is None for value in row[2:5]):
            return None
        return {
            'endpoint': row[0],
            'model': row[1],
            'input_tokens': row[2],
            'output_tokens': row[3],
            'total_tokens': row[4],
        }

    async def aggregate_time_series_older_than(self, days: int, trim_old: bool = False) -> int:
        """Roll time-series rows older than ``days`` days up into daily totals.

        Rows are grouped by endpoint, model and UTC date (YYYY-MM-DD) and
        written to ``token_time_series_daily``, which has a UNIQUE constraint
        on (endpoint, model, date).

        How an existing daily row is updated depends on ``trim_old``:

        * ``trim_old=True``: the aggregated source rows are deleted, so every
          run sees only rows no earlier run counted and the sums are *added*.
        * ``trim_old=False``: the source rows stay in place, so every run sees
          all of them again and the freshly computed sums *replace* the stored
          ones. Adding here would count the same rows once per run.

        NOTE(review): alternating between the two modes for the same data can
        still miscount the day that straddles the cutoff; pick one mode per
        deployment.

        Args:
            days: Age threshold in days. Non-integer or non-positive values
                fall back to 30.
            trim_old: Delete the aggregated rows from ``token_time_series``.

        Returns:
            The number of (endpoint, model, date) groups created or updated.

        Raises:
            Exception: Whatever the database raised; pending changes are
                rolled back first.
        """
        if not isinstance(days, int) or days <= 0:
            days = 30

        cutoff_ts = int(datetime.now(tz=timezone.utc).timestamp()) - (days * 86400)
        upsert_sql = self._DAILY_UPSERT_ADD_SQL if trim_old else self._DAILY_UPSERT_REPLACE_SQL

        db = await self._get_connection()
        async with self._operation_lock:
            # Ensure daily table exists
            await db.execute(self._DAILY_TABLE_SQL)
            await db.commit()

            try:
                async with db.execute('''
                    SELECT endpoint, model, date(timestamp, 'unixepoch') as day,
                           SUM(input_tokens) as in_sum,
                           SUM(output_tokens) as out_sum,
                           SUM(total_tokens) as tot_sum
                    FROM token_time_series
                    WHERE timestamp < ?
                    GROUP BY endpoint, model, day
                ''', (cutoff_ts,)) as cursor:
                    groups = await cursor.fetchall()

                for endpoint, model, day, in_sum, out_sum, tot_sum in groups:
                    # SUM() is NULL when every summed value is NULL; store 0 then.
                    await db.execute(
                        upsert_sql,
                        (endpoint, model, day,
                         int(in_sum or 0), int(out_sum or 0), int(tot_sum or 0)),
                    )

                # Trim old entries if requested
                if trim_old:
                    await db.execute('DELETE FROM token_time_series WHERE timestamp < ?', (cutoff_ts,))

                await db.commit()
            except Exception:
                # Without this, half-written aggregates would be committed by
                # whichever operation commits next on the shared connection.
                await self._rollback_quietly(db)
                raise

        return len(groups)

    # -----------------------------------------------------------------
    # Responses API state (store / previous_response_id / background)
    # -----------------------------------------------------------------
    @staticmethod
    def _row_to_response(row) -> dict:
        """Map a ``stored_responses`` row to a plain dict, decoding JSON columns.

        ``input_messages``, ``output_items``, ``usage`` and ``error`` are
        parsed as JSON; a NULL or undecodable value becomes ``None``. Column
        order must match ``_SELECT_RESPONSE_SQL``.
        """
        def _loads(val):
            if val is None:
                return None
            try:
                return orjson.loads(val)
            except (orjson.JSONDecodeError, TypeError):
                return None

        return {
            'response_id': row[0],
            'previous_response_id': row[1],
            'model': row[2],
            'status': row[3],
            'created_at': row[4],
            'input_messages': _loads(row[5]),
            'output_items': _loads(row[6]),
            'usage': _loads(row[7]),
            'instructions': row[8],
            'error': _loads(row[9]),
        }

    async def store_response(
        self,
        response_id: str,
        *,
        previous_response_id: Optional[str],
        model: str,
        status: str,
        created_at: int,
        input_messages: list,
        output_items: Optional[list] = None,
        usage: Optional[dict] = None,
        instructions: Optional[str] = None,
        error: Optional[dict] = None,
    ):
        """Insert a stored response, overwriting every column if the id exists.

        ``input_messages`` is always JSON-encoded; ``output_items``, ``usage``
        and ``error`` are JSON-encoded when given and stored as NULL otherwise.
        """
        db = await self._get_connection()
        async with self._operation_lock:
            await db.execute('''
                INSERT INTO stored_responses
                    (response_id, previous_response_id, model, status, created_at,
                     input_messages, output_items, usage, instructions, error)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                ON CONFLICT (response_id) DO UPDATE SET
                    previous_response_id = excluded.previous_response_id,
                    model = excluded.model,
                    status = excluded.status,
                    created_at = excluded.created_at,
                    input_messages = excluded.input_messages,
                    output_items = excluded.output_items,
                    usage = excluded.usage,
                    instructions = excluded.instructions,
                    error = excluded.error
            ''', (
                response_id, previous_response_id, model, status, created_at,
                orjson.dumps(input_messages).decode("utf-8"),
                self._to_json(output_items),
                self._to_json(usage),
                instructions,
                self._to_json(error),
            ))
            await db.commit()

    async def update_response_status(
        self,
        response_id: str,
        status: str,
        *,
        output_items: Optional[list] = None,
        usage: Optional[dict] = None,
        error: Optional[dict] = None,
    ):
        """Set the status of a stored response and optionally its payload columns.

        ``output_items``, ``usage`` and ``error`` are only overwritten when a
        value is passed; ``None`` keeps the stored value (SQL ``COALESCE``).
        An unknown ``response_id`` updates nothing and raises no error.
        """
        db = await self._get_connection()
        async with self._operation_lock:
            await db.execute('''
                UPDATE stored_responses
                SET status = ?,
                    output_items = COALESCE(?, output_items),
                    usage = COALESCE(?, usage),
                    error = COALESCE(?, error)
                WHERE response_id = ?
            ''', (
                status,
                self._to_json(output_items),
                self._to_json(usage),
                self._to_json(error),
                response_id,
            ))
            await db.commit()

    async def get_response(self, response_id: str) -> Optional[dict]:
        """Return a stored response as a dict, or None if not found."""
        db = await self._get_connection()
        async with self._operation_lock:
            async with db.execute(self._SELECT_RESPONSE_SQL, (response_id,)) as cursor:
                row = await cursor.fetchone()
        if row is None:
            return None
        return self._row_to_response(row)

    async def delete_response(self, response_id: str) -> bool:
        """Delete a stored response. Returns True if a row was removed."""
        db = await self._get_connection()
        async with self._operation_lock:
            cursor = await db.execute(
                'DELETE FROM stored_responses WHERE response_id = ?', (response_id,)
            )
            await db.commit()
            return cursor.rowcount > 0

    async def get_response_chain(self, response_id: str, max_turns: int = 50) -> list:
        """Walk ``previous_response_id`` back towards the root, oldest first.

        The walk stops at a missing link, at an id that was already visited
        (cycle guard), or after ``max_turns`` responses, so a pathological
        chain cannot stall a request.
        """
        chain: list = []
        visited: set = set()
        current = response_id
        while current and current not in visited and len(chain) < max_turns:
            visited.add(current)
            response = await self.get_response(current)
            if response is None:
                break
            chain.append(response)
            current = response.get('previous_response_id')
        chain.reverse()
        return chain

    async def fail_orphaned_responses(self) -> int:
        """Mark non-terminal responses as failed (called on startup).

        A background task lives in a worker's event loop; a process restart
        loses it while the DB row stays ``queued``/``in_progress`` forever.
        Reconcile those to ``failed`` so polling clients get a terminal state.

        Returns:
            The number of rows that were switched to ``failed``.
        """
        interrupted = orjson.dumps(
            {"message": "Response interrupted by server restart", "type": "server_error"}
        ).decode("utf-8")
        db = await self._get_connection()
        async with self._operation_lock:
            cursor = await db.execute('''
                UPDATE stored_responses
                SET status = 'failed',
                    error = ?
                WHERE status IN ('queued', 'in_progress')
            ''', (interrupted,))
            await db.commit()
            return cursor.rowcount
|