refactor: use a persistent WAL-enabled connection with async locks
- Introduce a lazily initialized, shared aiosqlite connection stored in self._db, plus two asyncio locks (_db_lock, _operation_lock) for safe concurrent access
- Ensure the database directory exists before connecting, and enable WAL journaling and foreign keys on first connect
- Add a close method to gracefully close the persistent connection
- Guard schema initialization and write operations with _operation_lock so they run one at a time on the shared connection
- Switch to ON CONFLICT upserts for token_counts updates and initialize the token_time_series table
- Add typing for _db (Optional[aiosqlite.Connection]) and adjust imports accordingly

Addition: the frontend "Stats Total" button now also triggers a time-series aggregation task, and a feedback span element keeps the user informed; aggregating and trimming old entries keeps the database footprint small.
This commit is contained in:
parent
0ffb321154
commit
59a8ef3abb
3 changed files with 278 additions and 110 deletions
321
db.py
321
db.py
|
|
@ -1,4 +1,6 @@
|
|||
import aiosqlite, os, asyncio
|
||||
import aiosqlite
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timezone
|
||||
from collections import defaultdict
|
||||
|
|
@ -6,155 +8,260 @@ from collections import defaultdict
|
|||
class TokenDatabase:
|
||||
def __init__(self, db_path: str = "token_counts.db"):
    """Remember where the database lives and set up lazy connection state.

    No connection is opened here: ``_db`` stays ``None`` until
    ``_get_connection`` is awaited for the first time.

    Args:
        db_path: Filesystem path of the SQLite database file.
    """
    self.db_path = db_path
    # Create the parent directory right away (uses self.db_path set above).
    self._ensure_db_directory()
    # Shared connection handle; populated lazily.
    self._db: Optional[aiosqlite.Connection] = None
    # _db_lock guards creation of the shared connection; _operation_lock
    # serialises the statements issued on it.
    self._db_lock, self._operation_lock = asyncio.Lock(), asyncio.Lock()
|
||||
|
||||
def _ensure_db_directory(self):
    """Create the parent directory of ``self.db_path`` when it is missing."""
    parent = Path(self.db_path).parent
    if parent.exists():
        # Nothing to do; the path is already present.
        return
    parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
async def _get_connection(self) -> "aiosqlite.Connection":
    """Return the shared connection, opening and configuring it on first use.

    The first caller creates the database directory, connects, switches the
    journal to WAL and turns on foreign-key enforcement. Every later call
    returns the same connection object.

    The connection is only stored in ``self._db`` once all setup statements
    have succeeded. If setup fails, the half-configured connection is closed
    and the error propagates, so the next call starts from scratch instead
    of reusing a connection without WAL / foreign keys.

    Returns:
        The persistent ``aiosqlite`` connection.
    """
    if self._db is None:
        async with self._db_lock:
            # Re-check under the lock: another task may have finished
            # connecting while this one was waiting.
            if self._db is None:
                self._ensure_db_directory()
                conn = await aiosqlite.connect(self.db_path)
                try:
                    # Enable WAL and foreign keys for reliability and integrity
                    await conn.execute("PRAGMA journal_mode=WAL;")
                    await conn.execute("PRAGMA foreign_keys = ON;")
                    await conn.commit()
                except BaseException:
                    # Includes cancellation: never leak or cache a
                    # connection that was not fully configured.
                    await conn.close()
                    raise
                self._db = conn
    return self._db
|
||||
|
||||
async def close(self):
    """Close the persistent database connection, if one is open.

    The handle is detached from ``self._db`` *before* awaiting the close, so
    a concurrent caller never sees (or closes again) a connection that is
    already shutting down, and a failing ``close()`` does not leave a dead
    connection cached. Calling this with no open connection is a no-op.
    """
    db, self._db = self._db, None
    if db is not None:
        await db.close()
|
||||
|
||||
async def init_db(self):
    """Create the ``token_counts`` and ``token_time_series`` tables if missing.

    Both ``CREATE TABLE IF NOT EXISTS`` statements run exactly once per call,
    under ``_operation_lock``, followed by a single commit. The statements
    are idempotent, so calling this repeatedly is safe.

    ``token_time_series.timestamp`` holds a Unix timestamp; rows reference
    ``token_counts`` through a composite (endpoint, model) foreign key.
    """
    db = await self._get_connection()
    async with self._operation_lock:
        await db.execute('''
            CREATE TABLE IF NOT EXISTS token_counts (
                endpoint TEXT,
                model TEXT,
                input_tokens INTEGER DEFAULT 0,
                output_tokens INTEGER DEFAULT 0,
                total_tokens INTEGER DEFAULT 0,
                PRIMARY KEY(endpoint, model)
            )
        ''')
        await db.execute('''
            CREATE TABLE IF NOT EXISTS token_time_series (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                endpoint TEXT,
                model TEXT,
                input_tokens INTEGER,
                output_tokens INTEGER,
                total_tokens INTEGER,
                timestamp INTEGER,
                FOREIGN KEY(endpoint, model) REFERENCES token_counts(endpoint, model)
            )
        ''')
        await db.commit()
|
||||
|
||||
async def update_token_counts(self, endpoint: str, model: str, input_tokens: int, output_tokens: int):
    """Add token usage to the running totals of one (endpoint, model) pair.

    A missing row is inserted; an existing row has the given amounts added
    to its counters. The upsert runs once, under ``_operation_lock``, and is
    committed immediately.

    Args:
        endpoint: Endpoint identifier (first half of the primary key).
        model: Model name (second half of the primary key).
        input_tokens: Input tokens to add.
        output_tokens: Output tokens to add.
    """
    total_tokens = input_tokens + output_tokens
    db = await self._get_connection()
    async with self._operation_lock:
        # ``excluded`` refers to the row that failed to insert, so each value
        # is bound once instead of being repeated for the UPDATE branch.
        await db.execute('''
            INSERT INTO token_counts (endpoint, model, input_tokens, output_tokens, total_tokens)
            VALUES (?, ?, ?, ?, ?)
            ON CONFLICT (endpoint, model) DO UPDATE SET
                input_tokens = input_tokens + excluded.input_tokens,
                output_tokens = output_tokens + excluded.output_tokens,
                total_tokens = total_tokens + excluded.total_tokens
        ''', (endpoint, model, input_tokens, output_tokens, total_tokens))
        await db.commit()
|
||||
|
||||
async def add_time_series_entry(self, endpoint: str, model: str, input_tokens: int, output_tokens: int):
    """Insert one time-series row stamped with the current UTC minute.

    The timestamp is the Unix time of "now" truncated to the minute. It is
    derived from the timezone-aware ``now`` value directly: building a naive
    ``datetime`` from the UTC fields and calling ``.timestamp()`` on it would
    be interpreted in the host's local timezone and shift the stored value
    by the local UTC offset.

    Args:
        endpoint: Endpoint identifier.
        model: Model name.
        input_tokens: Input tokens for this entry.
        output_tokens: Output tokens for this entry.
    """
    total_tokens = input_tokens + output_tokens
    # Use current minute as approximate timestamp in UTC
    now = datetime.now(tz=timezone.utc)
    timestamp = int(now.replace(second=0, microsecond=0).timestamp())

    db = await self._get_connection()
    async with self._operation_lock:
        await db.execute('''
            INSERT INTO token_time_series (endpoint, model, input_tokens, output_tokens, total_tokens, timestamp)
            VALUES (?, ?, ?, ?, ?, ?)
        ''', (endpoint, model, input_tokens, output_tokens, total_tokens, timestamp))
        await db.commit()
|
||||
|
||||
async def update_batched_counts(self, counts: dict):
    """Apply many token-count increments in a single transaction.

    Args:
        counts: Mapping ``{endpoint: {model: (input_tokens, output_tokens)}}``.
            An empty mapping is a no-op.

    Raises:
        Exception: Whatever the database (or malformed input) raised; the
            transaction is rolled back first, so no partial batch is stored.
    """
    if not counts:
        return
    db = await self._get_connection()
    async with self._operation_lock:
        try:
            await db.execute('BEGIN')
            for endpoint, models in counts.items():
                for model, (input_tokens, output_tokens) in models.items():
                    total_tokens = input_tokens + output_tokens
                    await db.execute('''
                        INSERT INTO token_counts (endpoint, model, input_tokens, output_tokens, total_tokens)
                        VALUES (?, ?, ?, ?, ?)
                        ON CONFLICT (endpoint, model) DO UPDATE SET
                            input_tokens = input_tokens + excluded.input_tokens,
                            output_tokens = output_tokens + excluded.output_tokens,
                            total_tokens = total_tokens + excluded.total_tokens
                    ''', (endpoint, model, input_tokens, output_tokens, total_tokens))
            await db.commit()
        except Exception:
            # Rollback on error to maintain consistency
            try:
                await db.execute('ROLLBACK')
            except Exception:
                # Best effort only: a failed rollback (e.g. no transaction
                # active) must not mask the original error re-raised below.
                pass
            raise
|
||||
|
||||
async def add_batched_time_series(self, entries: list):
    """Insert many time-series rows in a single transaction.

    Args:
        entries: Dicts with the keys ``endpoint``, ``model``,
            ``input_tokens``, ``output_tokens``, ``total_tokens`` and
            ``timestamp``. An empty list is a no-op (same behaviour as
            ``update_batched_counts`` for an empty mapping).

    Raises:
        Exception: Whatever the database (or a malformed entry) raised; the
            transaction is rolled back first, so no partial batch is stored.
    """
    if not entries:
        return
    db = await self._get_connection()
    async with self._operation_lock:
        try:
            await db.execute('BEGIN')
            for entry in entries:
                await db.execute('''
                    INSERT INTO token_time_series (endpoint, model, input_tokens, output_tokens, total_tokens, timestamp)
                    VALUES (?, ?, ?, ?, ?, ?)
                ''', (entry['endpoint'], entry['model'], entry['input_tokens'],
                      entry['output_tokens'], entry['total_tokens'], entry['timestamp']))
            await db.commit()
        except Exception:
            try:
                await db.execute('ROLLBACK')
            except Exception:
                # Best effort only: a failed rollback must not mask the
                # original error re-raised below.
                pass
            raise
|
||||
|
||||
async def load_token_counts(self):
    """Yield every ``token_counts`` row as a dict.

    The rows are fetched in one go while ``_operation_lock`` is held and are
    yielded only after the lock has been released. Yielding from inside the
    lock would keep it held for as long as the consumer takes (or forever if
    the consumer abandons the generator), and would deadlock any consumer
    that calls another database method between items, because
    ``asyncio.Lock`` is not reentrant.

    Yields:
        Dicts with the keys ``endpoint``, ``model``, ``input_tokens``,
        ``output_tokens`` and ``total_tokens``.
    """
    db = await self._get_connection()
    async with self._operation_lock:
        async with db.execute('SELECT endpoint, model, input_tokens, output_tokens, total_tokens FROM token_counts') as cursor:
            rows = await cursor.fetchall()
    for row in rows:
        yield {
            'endpoint': row[0],
            'model': row[1],
            'input_tokens': row[2],
            'output_tokens': row[3],
            'total_tokens': row[4]
        }
|
||||
|
||||
async def get_latest_time_series(self, limit: int = 100):
    """Yield the most recent time-series rows, newest first.

    As in ``load_token_counts``, the (bounded) result set is fetched while
    ``_operation_lock`` is held and yielded after the lock is released, so a
    slow or abandoned consumer cannot block other database operations and a
    consumer may call other database methods between items.

    Args:
        limit: Maximum number of rows to return.

    Yields:
        Dicts with the keys ``endpoint``, ``model``, ``input_tokens``,
        ``output_tokens``, ``total_tokens`` and ``timestamp``.
    """
    db = await self._get_connection()
    async with self._operation_lock:
        async with db.execute('''
            SELECT endpoint, model, input_tokens, output_tokens, total_tokens, timestamp
            FROM token_time_series
            ORDER BY timestamp DESC
            LIMIT ?
        ''', (limit,)) as cursor:
            rows = await cursor.fetchall()
    for row in rows:
        yield {
            'endpoint': row[0],
            'model': row[1],
            'input_tokens': row[2],
            'output_tokens': row[3],
            'total_tokens': row[4],
            'timestamp': row[5]
        }
|
||||
|
||||
async def get_token_counts_for_model(self, model):
    """Return the token totals of one model, summed over all endpoints.

    The summation is done by SQLite in a single query executed under
    ``_operation_lock``.

    Args:
        model: Model name to look up.

    Returns:
        A dict with ``endpoint`` set to ``'aggregated'`` plus ``model``,
        ``input_tokens``, ``output_tokens`` and ``total_tokens``, or ``None``
        when the model has no rows or both its input and output totals are
        zero.
    """
    db = await self._get_connection()
    async with self._operation_lock:
        async with db.execute('''
            SELECT SUM(input_tokens), SUM(output_tokens), SUM(total_tokens)
            FROM token_counts
            WHERE model = ?
        ''', (model,)) as cursor:
            row = await cursor.fetchone()

    # SUM() yields NULL when nothing matched (or every value is NULL).
    total_input, total_output, total_tokens = (value or 0 for value in (row or (0, 0, 0)))
    if total_input > 0 or total_output > 0:
        return {
            'endpoint': 'aggregated',
            'model': model,
            'input_tokens': total_input,
            'output_tokens': total_output,
            'total_tokens': total_tokens
        }
    return None
|
||||
|
||||
async def aggregate_time_series_older_than(self, days: int, trim_old: bool = False) -> int:
    """
    Aggregate time_series entries older than 'days' days into daily aggregates by
    endpoint, model and UTC date (YYYY-MM-DD). The results are stored in
    token_time_series_daily with a UNIQUE constraint on (endpoint, model, date).

    The upserts and the optional trim run in one transaction: if anything
    fails, the transaction is rolled back so that half-applied daily sums are
    not committed later by an unrelated operation.

    NOTE(review): daily sums are *added* to existing rows. With
    trim_old=False the source rows stay in token_time_series, so calling this
    again re-adds the same sums. Pass trim_old=True for repeated runs.

    Args:
        days: Age threshold in days. Non-int or non-positive values fall
            back to 30.
        trim_old: When true, delete the aggregated source rows afterwards.

    Returns the number of aggregated groups (distinct (endpoint, model, date) tuples)
    that were created/updated.
    """
    if not isinstance(days, int) or days <= 0:
        days = 30

    cutoff_ts = int(datetime.now(tz=timezone.utc).timestamp()) - (days * 86400)

    db = await self._get_connection()

    async with self._operation_lock:
        # Ensure daily table exists
        await db.execute('''
            CREATE TABLE IF NOT EXISTS token_time_series_daily (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                endpoint TEXT,
                model TEXT,
                date TEXT,
                input_tokens INTEGER DEFAULT 0,
                output_tokens INTEGER DEFAULT 0,
                total_tokens INTEGER DEFAULT 0,
                UNIQUE(endpoint, model, date)
            )
        ''')
        await db.commit()

        try:
            cursor = await db.execute('''
                SELECT endpoint, model, date(timestamp, 'unixepoch') as day,
                       SUM(input_tokens) as in_sum,
                       SUM(output_tokens) as out_sum,
                       SUM(total_tokens) as tot_sum
                FROM token_time_series
                WHERE timestamp < ?
                GROUP BY endpoint, model, day
            ''', (cutoff_ts,))
            rows = list(await cursor.fetchall())

            for endpoint, model, day, in_sum, out_sum, tot_sum in rows:
                # SUM() is NULL when every value in the group is NULL.
                await db.execute('''
                    INSERT INTO token_time_series_daily (endpoint, model, date, input_tokens, output_tokens, total_tokens)
                    VALUES (?, ?, ?, ?, ?, ?)
                    ON CONFLICT (endpoint, model, date) DO UPDATE SET
                        input_tokens = input_tokens + excluded.input_tokens,
                        output_tokens = output_tokens + excluded.output_tokens,
                        total_tokens = total_tokens + excluded.total_tokens
                ''', (endpoint, model, day, int(in_sum or 0), int(out_sum or 0), int(tot_sum or 0)))

            # Trim old entries if requested
            if trim_old:
                await db.execute('DELETE FROM token_time_series WHERE timestamp < ?', (cutoff_ts,))

            await db.commit()
        except Exception:
            try:
                await db.execute('ROLLBACK')
            except Exception:
                # Best effort only (there may be no open transaction); the
                # original error is what the caller needs to see.
                pass
            raise

    return len(rows)
|
||||
|
|
|
|||
46
router.py
46
router.py
|
|
@ -52,7 +52,7 @@ flush_task: asyncio.Task | None = None
|
|||
# Token Count Buffer (for write-behind pattern)
|
||||
# ------------------------------------------------------------------
|
||||
# Structure: {endpoint: {model: (input_tokens, output_tokens)}}
|
||||
token_buffer: dict[str, dict[str, tuple[int, int]]] = defaultdict(lambda: defaultdict(tuple))
|
||||
token_buffer: dict[str, dict[str, tuple[int, int]]] = defaultdict(lambda: defaultdict(lambda: (0, 0)))
|
||||
# Time series buffer with timestamp
|
||||
time_series_buffer: list[dict[str, int | str]] = []
|
||||
|
||||
|
|
@ -258,6 +258,29 @@ async def flush_buffer() -> None:
|
|||
await db.add_batched_time_series(time_series_buffer)
|
||||
time_series_buffer.clear()
|
||||
|
||||
async def flush_remaining_buffers() -> None:
    """
    Write whatever is still held in the in-memory buffers to the database.

    Intended for shutdown: token counts are written first, then time-series
    entries, and each buffer is cleared once its write has returned. Any
    exception is printed instead of raised so teardown can continue.
    """
    try:
        flushed_entries = 0

        if token_buffer:
            await db.update_batched_counts(token_buffer)
            # One entry per (endpoint, model) pair.
            for models in token_buffer.values():
                flushed_entries += len(models)
            token_buffer.clear()

        if time_series_buffer:
            await db.add_batched_time_series(time_series_buffer)
            flushed_entries += len(time_series_buffer)
            time_series_buffer.clear()

        if not flushed_entries:
            print("[shutdown] No in-memory entries to flush on shutdown.")
        else:
            print(f"[shutdown] Flushed {flushed_entries} in-memory entries to DB on shutdown.")
    except Exception as e:
        # Do not raise during shutdown – log and continue teardown
        print(f"[shutdown] Error flushing remaining buffers: {e}")
|
||||
|
||||
class fetch:
|
||||
async def available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
|
||||
"""
|
||||
|
|
@ -1179,6 +1202,26 @@ async def token_counts_proxy():
|
|||
})
|
||||
return {"total_tokens": total, "breakdown": breakdown}
|
||||
|
||||
@app.post("/api/aggregate_time_series_days")
async def aggregate_time_series_days_proxy(request: Request):
    """
    Aggregate time_series entries older than days into daily aggregates by endpoint/model/date.

    The optional JSON body may carry ``days`` (int, default 30) and
    ``trim_old`` (bool, default false). An empty, malformed or unexpected
    body falls back to both defaults, which never deletes anything.

    The response reports the values that were actually used: non-positive
    ``days`` is normalised to 30 here, mirroring the fallback in
    ``aggregate_time_series_older_than``.
    """
    days, trim_old = 30, False
    try:
        body_bytes = await request.body()
        if body_bytes:
            payload = orjson.loads(body_bytes)
            requested_days = int(payload.get("days", 30))
            raw_trim = payload.get("trim_old", False)
            if isinstance(raw_trim, str):
                # bool("false") is True; never let a string such as "false"
                # or "0" switch on the destructive trim.
                requested_trim = raw_trim.strip().lower() in ("1", "true", "yes", "on")
            else:
                requested_trim = bool(raw_trim)
            # Assign together so a bad field cannot leave a half-parsed request.
            days, trim_old = requested_days, requested_trim
    except Exception:
        # Deliberate catch-all: any unreadable request uses the safe defaults.
        days, trim_old = 30, False
    if days <= 0:
        days = 30
    aggregated = await db.aggregate_time_series_older_than(days, trim_old=trim_old)
    return {"status": "ok", "days": days, "trim_old": trim_old, "aggregated_groups": aggregated}
|
||||
|
||||
# 12. API route – Stats
|
||||
# -------------------------------------------------------------
|
||||
@app.post("/api/stats")
|
||||
|
|
@ -1965,6 +2008,7 @@ async def startup_event() -> None:
|
|||
@app.on_event("shutdown")
|
||||
async def shutdown_event() -> None:
|
||||
await close_all_sse_queues()
|
||||
await flush_remaining_buffers()
|
||||
await app_state["session"].close()
|
||||
if token_worker_task is not None:
|
||||
token_worker_task.cancel()
|
||||
|
|
|
|||
|
|
@ -269,7 +269,7 @@
|
|||
/></a>
|
||||
<div class="header-row">
|
||||
<h1>Router Dashboard</h1>
|
||||
<button id="total-tokens-btn">Stats Total</button>
|
||||
<button id="total-tokens-btn">Stats Total</button><span id="aggregation-status" class="loading" style="margin-left:8px;"></span>
|
||||
</div>
|
||||
|
||||
<button onclick="toggleDarkMode()" id="dark-mode-button">
|
||||
|
|
@ -1008,6 +1008,23 @@ document.addEventListener('DOMContentLoaded', () => {
|
|||
const modal = document.getElementById('total-tokens-modal');
|
||||
const numberEl = document.getElementById('total-tokens-number');
|
||||
numberEl.textContent = data.total_tokens;
|
||||
document.getElementById('aggregation-status').textContent = 'Aggregating...';
|
||||
try {
|
||||
const aggResp = await fetch('/api/aggregate_time_series_days', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ days: 30 , trim_old: true})
|
||||
});
|
||||
if (aggResp.ok) {
|
||||
const aggData = await aggResp.json();
|
||||
const aggr = aggData.aggregated_groups ?? 0;
|
||||
document.getElementById('aggregation-status').textContent = `Aggregated ${aggr} groups`;
|
||||
} else {
|
||||
document.getElementById('aggregation-status').textContent = 'Aggregation failed';
|
||||
}
|
||||
} catch (err) {
|
||||
document.getElementById('aggregation-status').textContent = 'Aggregation error';
|
||||
}
|
||||
const chartCanvas = document.getElementById('total-tokens-chart');
|
||||
if (chartCanvas) {
|
||||
// Destroy existing chart if it exists
|
||||
|
|
@ -1066,7 +1083,7 @@ document.addEventListener('DOMContentLoaded', () => {
|
|||
},
|
||||
title: {
|
||||
display: true,
|
||||
text: 'Token Distribution by Endpoint per Model'
|
||||
text: 'Token Distribution by Model per Endpoint'
|
||||
},
|
||||
tooltip: {
|
||||
callbacks: {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue