Merge pull request #15 from nomyo-ai:dev-v0.5.x

Dev-v0.5.x -> Main
This commit is contained in:
Alpha Nerd 2025-12-09 12:08:46 +01:00 committed by GitHub
commit 67edbb5f8e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 1072 additions and 76 deletions

5
.gitignore vendored
View file

@ -61,4 +61,7 @@ cython_debug/
*.sqlite3
# Config
config.yaml
config.yaml
# SQLite
*.db*

View file

@ -3,11 +3,17 @@ FROM python:3.13-slim
ENV PYTHONUNBUFFERED=1 \
PYTHONDONTWRITEBYTECODE=1
# Install SQLite
RUN apt-get update && apt-get install -y sqlite3
WORKDIR /app
COPY requirements.txt .
RUN pip install --no-cache-dir --upgrade pip \
&& pip install --no-cache-dir -r requirements.txt
# Create database directory and set permissions
RUN mkdir -p /app/data && chown -R www-data:www-data /app/data
COPY . .
RUN chmod +x /app/entrypoint.sh

267
db.py Normal file
View file

@ -0,0 +1,267 @@
import aiosqlite
import asyncio
from typing import Optional
from pathlib import Path
from datetime import datetime, timezone
from collections import defaultdict
class TokenDatabase:
def __init__(self, db_path: str = "token_counts.db"):
self.db_path = db_path
self._db: Optional[aiosqlite.Connection] = None
self._db_lock = asyncio.Lock()
self._operation_lock = asyncio.Lock()
def _ensure_db_directory(self):
"""Ensure the directory for the database exists."""
db_dir = Path(self.db_path).parent
if not db_dir.exists():
db_dir.mkdir(parents=True, exist_ok=True)
async def _get_connection(self) -> aiosqlite.Connection:
"""Return a persistent connection with WAL mode and FK enforcement enabled."""
if self._db is None:
async with self._db_lock:
if self._db is None:
self._ensure_db_directory()
self._db = await aiosqlite.connect(self.db_path)
# Enable WAL and foreign keys for reliability and integrity
await self._db.execute("PRAGMA journal_mode=WAL;")
await self._db.execute("PRAGMA foreign_keys = ON;")
await self._db.commit()
return self._db
async def close(self):
"""Close the persistent database connection, if open."""
if self._db is not None:
await self._db.close()
self._db = None
async def init_db(self):
"""Initialize the database tables."""
db = await self._get_connection()
async with self._operation_lock:
await db.execute('''
CREATE TABLE IF NOT EXISTS token_counts (
endpoint TEXT,
model TEXT,
input_tokens INTEGER DEFAULT 0,
output_tokens INTEGER DEFAULT 0,
total_tokens INTEGER DEFAULT 0,
PRIMARY KEY(endpoint, model)
)
''')
await db.execute('''
CREATE TABLE IF NOT EXISTS token_time_series (
id INTEGER PRIMARY KEY AUTOINCREMENT,
endpoint TEXT,
model TEXT,
input_tokens INTEGER,
output_tokens INTEGER,
total_tokens INTEGER,
timestamp INTEGER,
FOREIGN KEY(endpoint, model) REFERENCES token_counts(endpoint, model)
)
''')
await db.commit()
async def update_token_counts(self, endpoint: str, model: str, input_tokens: int, output_tokens: int):
"""Update token counts for a specific endpoint and model."""
total_tokens = input_tokens + output_tokens
db = await self._get_connection()
async with self._operation_lock:
await db.execute('''
INSERT INTO token_counts (endpoint, model, input_tokens, output_tokens, total_tokens)
VALUES (?, ?, ?, ?, ?)
ON CONFLICT (endpoint, model) DO UPDATE SET
input_tokens = input_tokens + ?,
output_tokens = output_tokens + ?,
total_tokens = total_tokens + ?
''', (endpoint, model, input_tokens, output_tokens, total_tokens, input_tokens, output_tokens, total_tokens))
await db.commit()
async def add_time_series_entry(self, endpoint: str, model: str, input_tokens: int, output_tokens: int):
"""Add a time series entry with approximate timestamp."""
total_tokens = input_tokens + output_tokens
# Use current minute/hour as approximate timestamp in UTC
now = datetime.now(tz=timezone.utc)
timestamp = int(datetime(now.year, now.month, now.day, now.hour, now.minute).timestamp())
db = await self._get_connection()
async with self._operation_lock:
await db.execute('''
INSERT INTO token_time_series (endpoint, model, input_tokens, output_tokens, total_tokens, timestamp)
VALUES (?, ?, ?, ?, ?, ?)
''', (endpoint, model, input_tokens, output_tokens, total_tokens, timestamp))
await db.commit()
async def update_batched_counts(self, counts: dict):
"""Update multiple token counts in a single transaction."""
if not counts:
return
db = await self._get_connection()
async with self._operation_lock:
try:
await db.execute('BEGIN')
for endpoint, models in counts.items():
for model, (input_tokens, output_tokens) in models.items():
total_tokens = input_tokens + output_tokens
await db.execute('''
INSERT INTO token_counts (endpoint, model, input_tokens, output_tokens, total_tokens)
VALUES (?, ?, ?, ?, ?)
ON CONFLICT (endpoint, model) DO UPDATE SET
input_tokens = input_tokens + ?,
output_tokens = output_tokens + ?,
total_tokens = total_tokens + ?
''', (endpoint, model, input_tokens, output_tokens, total_tokens,
input_tokens, output_tokens, total_tokens))
await db.commit()
except Exception:
# Rollback on error to maintain consistency
try:
await db.execute('ROLLBACK')
except Exception:
pass
raise
async def add_batched_time_series(self, entries: list):
"""Add multiple time series entries in a single transaction."""
db = await self._get_connection()
async with self._operation_lock:
try:
await db.execute('BEGIN')
for entry in entries:
await db.execute('''
INSERT INTO token_time_series (endpoint, model, input_tokens, output_tokens, total_tokens, timestamp)
VALUES (?, ?, ?, ?, ?, ?)
''', (entry['endpoint'], entry['model'], entry['input_tokens'],
entry['output_tokens'], entry['total_tokens'], entry['timestamp']))
await db.commit()
except Exception:
try:
await db.execute('ROLLBACK')
except Exception:
pass
raise
async def load_token_counts(self):
"""Load all token counts from database."""
db = await self._get_connection()
async with self._operation_lock:
async with db.execute('SELECT endpoint, model, input_tokens, output_tokens, total_tokens FROM token_counts') as cursor:
async for row in cursor:
yield {
'endpoint': row[0],
'model': row[1],
'input_tokens': row[2],
'output_tokens': row[3],
'total_tokens': row[4]
}
async def get_latest_time_series(self, limit: int = 100):
"""Get the latest time series entries."""
db = await self._get_connection()
async with self._operation_lock:
async with db.execute('''
SELECT endpoint, model, input_tokens, output_tokens, total_tokens, timestamp
FROM token_time_series
ORDER BY timestamp DESC
LIMIT ?
''', (limit,)) as cursor:
async for row in cursor:
yield {
'endpoint': row[0],
'model': row[1],
'input_tokens': row[2],
'output_tokens': row[3],
'total_tokens': row[4],
'timestamp': row[5]
}
async def get_token_counts_for_model(self, model):
"""Get token counts for a specific model, aggregated across all endpoints."""
db = await self._get_connection()
async with self._operation_lock:
async with db.execute('SELECT endpoint, model, input_tokens, output_tokens, total_tokens FROM token_counts WHERE model = ?', (model,)) as cursor:
total_input = 0
total_output = 0
total_tokens = 0
async for row in cursor:
total_input += row[2]
total_output += row[3]
total_tokens += row[4]
if total_input > 0 or total_output > 0:
return {
'endpoint': 'aggregated',
'model': model,
'input_tokens': total_input,
'output_tokens': total_output,
'total_tokens': total_tokens
}
return None
async def aggregate_time_series_older_than(self, days: int, trim_old: bool = False) -> int:
"""
Aggregate time_series entries older than 'days' days into daily aggregates by
endpoint, model and UTC date (YYYY-MM-DD). The results are stored in
token_time_series_daily with a UNIQUE constraint on (endpoint, model, date).
Returns the number of aggregated groups (distinct (endpoint, model, date) tuples)
that were created/updated.
"""
if not isinstance(days, int) or days <= 0:
days = 30
cutoff_ts = int(datetime.now(tz=timezone.utc).timestamp()) - (days * 86400)
db = await self._get_connection()
aggregated_count = 0
async with self._operation_lock:
# Ensure daily table exists
await db.execute('''
CREATE TABLE IF NOT EXISTS token_time_series_daily (
id INTEGER PRIMARY KEY AUTOINCREMENT,
endpoint TEXT,
model TEXT,
date TEXT,
input_tokens INTEGER DEFAULT 0,
output_tokens INTEGER DEFAULT 0,
total_tokens INTEGER DEFAULT 0,
UNIQUE(endpoint, model, date)
)
''')
await db.commit()
cursor = await db.execute('''
SELECT endpoint, model, date(timestamp, 'unixepoch') as day,
SUM(input_tokens) as in_sum,
SUM(output_tokens) as out_sum,
SUM(total_tokens) as tot_sum
FROM token_time_series
WHERE timestamp < ?
GROUP BY endpoint, model, day
''', (cutoff_ts,))
rows = await cursor.fetchall()
for row in rows:
endpoint, model, day, in_sum, out_sum, tot_sum = row
await db.execute('''
INSERT INTO token_time_series_daily (endpoint, model, date, input_tokens, output_tokens, total_tokens)
VALUES (?, ?, ?, ?, ?, ?)
ON CONFLICT (endpoint, model, date) DO UPDATE SET
input_tokens = input_tokens + ?,
output_tokens = output_tokens + ?,
total_tokens = total_tokens + ?
''', (endpoint, model, day, int(in_sum or 0), int(out_sum or 0), int(tot_sum or 0),
int(in_sum or 0), int(out_sum or 0), int(tot_sum or 0)))
aggregated_count += 1
# Trim old entries if requested
if trim_old:
await db.execute('DELETE FROM token_time_series WHERE timestamp < ?', (cutoff_ts,))
await db.commit()
return aggregated_count

7
entrypoint.sh Normal file → Executable file
View file

@ -1,6 +1,13 @@
#!/usr/bin/env sh
set -e
# Create database directory if it doesn't exist
mkdir -p /app/data
chown -R www-data:www-data /app/data
# Set database path environment variable
export NOMYO_ROUTER_DB_PATH="/app/data/token_counts.db"
CONFIG_PATH_ARG=""
SHOW_HELP=0

View file

@ -36,3 +36,4 @@ typing-inspection==0.4.1
typing_extensions==4.14.1
uvicorn==0.38.0
yarl==1.20.1
aiosqlite

255
router.py
View file

@ -2,11 +2,12 @@
title: NOMYO Router - an Ollama Proxy with Endpoint:Model aware routing
author: alpha-nerd-nomyo
author_url: https://github.com/nomyo-ai
version: 0.4
version: 0.5
license: AGPL
"""
# -------------------------------------------------------------
import orjson, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, datetime, random, base64, io
import orjson, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, random, base64, io
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, Set, List, Optional
from urllib.parse import urlparse
@ -45,6 +46,18 @@ app_state = {
"connector": None,
}
token_worker_task: asyncio.Task | None = None
flush_task: asyncio.Task | None = None
# ------------------------------------------------------------------
# Token Count Buffer (for write-behind pattern)
# ------------------------------------------------------------------
# Structure: {endpoint: {model: (input_tokens, output_tokens)}}
token_buffer: dict[str, dict[str, tuple[int, int]]] = defaultdict(lambda: defaultdict(lambda: (0, 0)))
# Time series buffer with timestamp
time_series_buffer: list[dict[str, int | str]] = []
# Configuration for periodic flushing
FLUSH_INTERVAL = 10 # seconds
# -------------------------------------------------------------
# 1. Configuration loader
@ -61,6 +74,9 @@ class Config(BaseSettings):
api_keys: Dict[str, str] = Field(default_factory=dict)
# Database configuration
db_path: str = Field(default=os.getenv("NOMYO_ROUTER_DB_PATH", "token_counts.db"))
class Config:
# Load from `config.yaml` first, then from env variables
env_prefix = "NOMYO_ROUTER_"
@ -101,6 +117,8 @@ def _config_path_from_env() -> Path:
return Path(candidate).expanduser()
return Path("config.yaml")
from db import TokenDatabase
# Create the global config object it will be overwritten on startup
config = Config.from_yaml(_config_path_from_env())
@ -130,6 +148,9 @@ token_usage_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(
usage_lock = asyncio.Lock() # protects access to usage_counts
token_usage_lock = asyncio.Lock()
# Database instance
db: "TokenDatabase" = None
# -------------------------------------------------------------
# 4. Helperfunctions
# -------------------------------------------------------------
@ -181,7 +202,7 @@ def _format_connection_issue(url: str, error: Exception) -> str:
)
return f"Error while contacting {url}: {error}"
def is_ext_openai_endpoint(endpoint: str) -> bool:
if "/v1" not in endpoint:
return False
@ -199,9 +220,66 @@ def is_ext_openai_endpoint(endpoint: str) -> bool:
async def token_worker() -> None:
while True:
endpoint, model, prompt, comp = await token_queue.get()
# Accumulate counts in memory buffer
token_buffer[endpoint][model] = (
token_buffer[endpoint].get(model, (0, 0))[0] + prompt,
token_buffer[endpoint].get(model, (0, 0))[1] + comp
)
# Add to time series buffer with timestamp (UTC)
now = datetime.now(tz=timezone.utc)
timestamp = int(datetime(now.year, now.month, now.day, now.hour, now.minute, tzinfo=timezone.utc).timestamp())
time_series_buffer.append({
'endpoint': endpoint,
'model': model,
'input_tokens': prompt,
'output_tokens': comp,
'total_tokens': prompt + comp,
'timestamp': timestamp
})
# Update in-memory counts for immediate reporting
async with token_usage_lock:
token_usage_counts[endpoint][model] += (prompt + comp)
await publish_snapshot()
await publish_snapshot()
async def flush_buffer() -> None:
"""Periodically flush accumulated token counts to the database."""
while True:
await asyncio.sleep(FLUSH_INTERVAL)
# Flush token counts
if token_buffer:
await db.update_batched_counts(token_buffer)
token_buffer.clear()
# Flush time series entries
if time_series_buffer:
await db.add_batched_time_series(time_series_buffer)
time_series_buffer.clear()
async def flush_remaining_buffers() -> None:
"""
Flush any in-memory buffers to the database on shutdown.
This is designed to be safely invoked during shutdown and should not raise.
"""
try:
flushed_entries = 0
if token_buffer:
await db.update_batched_counts(token_buffer)
flushed_entries += sum(len(v) for v in token_buffer.values())
token_buffer.clear()
if time_series_buffer:
await db.add_batched_time_series(time_series_buffer)
flushed_entries += len(time_series_buffer)
time_series_buffer.clear()
if flushed_entries:
print(f"[shutdown] Flushed {flushed_entries} in-memory entries to DB on shutdown.")
else:
print("[shutdown] No in-memory entries to flush on shutdown.")
except Exception as e:
# Do not raise during shutdown log and continue teardown
print(f"[shutdown] Error flushing remaining buffers: {e}")
class fetch:
async def available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
@ -366,7 +444,7 @@ async def decrement_usage(endpoint: str, model: str) -> None:
def iso8601_ns():
ns = time.time_ns()
sec, ns_rem = divmod(ns, 1_000_000_000)
dt = datetime.datetime.fromtimestamp(sec, tz=datetime.timezone.utc)
dt = datetime.fromtimestamp(sec, tz=timezone.utc)
return (
f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d}T"
f"{dt.hour:02d}:{dt.minute:02d}:{dt.second:02d}."
@ -628,7 +706,6 @@ async def choose_endpoint(model: str) -> str:
f"None of the configured endpoints ({', '.join(config.endpoints)}) "
f"advertise the model '{model}'."
)
# 3⃣ Among the candidates, find those that have the model *loaded*
# (concurrently, but only for the filtered list)
load_tasks = [fetch.loaded_models(ep) for ep in candidate_endpoints]
@ -743,7 +820,8 @@ async def proxy(request: Request):
chunk = rechunk.openai_completion2ollama(chunk, stream, start_ts)
prompt_tok = chunk.prompt_eval_count or 0
comp_tok = chunk.eval_count or 0
await token_queue.put((endpoint, model, prompt_tok, comp_tok))
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, model, prompt_tok, comp_tok))
if hasattr(chunk, "model_dump_json"):
json_line = chunk.model_dump_json()
else:
@ -757,7 +835,8 @@ async def proxy(request: Request):
response = async_gen.model_dump_json()
prompt_tok = async_gen.prompt_eval_count or 0
comp_tok = async_gen.eval_count or 0
await token_queue.put((endpoint, model, prompt_tok, comp_tok))
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, model, prompt_tok, comp_tok))
json_line = (
response
if hasattr(async_gen, "model_dump_json")
@ -859,7 +938,8 @@ async def chat_proxy(request: Request):
# `chunk` can be a dict or a pydantic model dump to JSON safely
prompt_tok = chunk.prompt_eval_count or 0
comp_tok = chunk.eval_count or 0
await token_queue.put((endpoint, model, prompt_tok, comp_tok))
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, model, prompt_tok, comp_tok))
if hasattr(chunk, "model_dump_json"):
json_line = chunk.model_dump_json()
else:
@ -873,7 +953,8 @@ async def chat_proxy(request: Request):
response = async_gen.model_dump_json()
prompt_tok = async_gen.prompt_eval_count or 0
comp_tok = async_gen.eval_count or 0
await token_queue.put((endpoint, model, prompt_tok, comp_tok))
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, model, prompt_tok, comp_tok))
json_line = (
response
if hasattr(async_gen, "model_dump_json")
@ -1086,7 +1167,7 @@ async def show_proxy(request: Request, model: Optional[str] = None):
if not model:
payload = orjson.loads(body_bytes.decode("utf-8"))
model = payload.get("model")
if not model:
raise HTTPException(
status_code=400, detail="Missing required field 'model'"
@ -1105,6 +1186,97 @@ async def show_proxy(request: Request, model: Optional[str] = None):
# 4. Return ShowResponse
return show
# -------------------------------------------------------------
@app.get("/api/token_counts")
async def token_counts_proxy():
breakdown = []
total = 0
async for entry in db.load_token_counts():
total += entry['total_tokens']
breakdown.append({
"endpoint": entry["endpoint"],
"model": entry["model"],
"input_tokens": entry["input_tokens"],
"output_tokens": entry["output_tokens"],
"total_tokens": entry["total_tokens"],
})
return {"total_tokens": total, "breakdown": breakdown}
@app.post("/api/aggregate_time_series_days")
async def aggregate_time_series_days_proxy(request: Request):
"""
Aggregate time_series entries older than days into daily aggregates by endpoint/model/date.
"""
try:
body_bytes = await request.body()
if not body_bytes:
days = 30
trim_old = False
else:
payload = orjson.loads(body_bytes.decode("utf-8"))
days = int(payload.get("days", 30))
trim_old = bool(payload.get("trim_old", False))
except Exception:
days = 30
trim_old = False
aggregated = await db.aggregate_time_series_older_than(days, trim_old=trim_old)
return {"status": "ok", "days": days, "trim_old": trim_old, "aggregated_groups": aggregated}
# 12. API route Stats
# -------------------------------------------------------------
@app.post("/api/stats")
async def stats_proxy(request: Request, model: Optional[str] = None):
"""
Return token usage statistics for a specific model.
"""
try:
body_bytes = await request.body()
if not model:
payload = orjson.loads(body_bytes.decode("utf-8"))
model = payload.get("model")
if not model:
raise HTTPException(
status_code=400, detail="Missing required field 'model'"
)
except orjson.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# Get token counts from database
token_data = await db.get_token_counts_for_model(model)
if not token_data:
raise HTTPException(
status_code=404, detail="No token data found for this model"
)
# Get time series data for the last 30 days (43200 minutes = 30 days)
# Assuming entries are grouped by minute, 30 days = 43200 entries max
time_series = []
endpoint_totals = defaultdict(int) # Track tokens per endpoint
async for entry in db.get_latest_time_series(limit=50000):
if entry['model'] == model:
time_series.append({
'endpoint': entry['endpoint'],
'timestamp': entry['timestamp'],
'input_tokens': entry['input_tokens'],
'output_tokens': entry['output_tokens'],
'total_tokens': entry['total_tokens']
})
# Accumulate total tokens per endpoint
endpoint_totals[entry['endpoint']] += entry['total_tokens']
return {
'model': model,
'input_tokens': token_data['input_tokens'],
'output_tokens': token_data['output_tokens'],
'total_tokens': token_data['total_tokens'],
'time_series': time_series,
'endpoint_distribution': dict(endpoint_totals)
}
# -------------------------------------------------------------
# 12. API route Copy
# -------------------------------------------------------------
@ -1482,7 +1654,7 @@ async def openai_chat_completions_proxy(request: Request):
optional_params = {
"tools": tools,
"response_format": response_format,
"stream_options": stream_options,
"stream_options": stream_options or {"include_usage": True },
"max_completion_tokens": max_completion_tokens,
"max_tokens": max_tokens,
"temperature": temperature,
@ -1524,13 +1696,23 @@ async def openai_chat_completions_proxy(request: Request):
if hasattr(chunk, "model_dump_json")
else orjson.dumps(chunk)
)
if chunk.choices[0].delta.content is not None:
yield f"data: {data}\n\n".encode("utf-8")
if chunk.choices:
if chunk.choices[0].delta.content is not None:
yield f"data: {data}\n\n".encode("utf-8")
if chunk.usage is not None:
prompt_tok = chunk.usage.prompt_tokens or 0
comp_tok = chunk.usage.completion_tokens or 0
if prompt_tok != 0 or comp_tok != 0:
if not is_ext_openai_endpoint(endpoint):
if not ":" in model:
local_model = model+":latest"
await token_queue.put((endpoint, local_model, prompt_tok, comp_tok))
yield b"data: [DONE]\n\n"
else:
prompt_tok = async_gen.usage.prompt_tokens or 0
comp_tok = async_gen.usage.completion_tokens or 0
await token_queue.put((endpoint, model, prompt_tok, comp_tok))
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, model, prompt_tok, comp_tok))
json_line = (
async_gen.model_dump_json()
if hasattr(async_gen, "model_dump_json")
@ -1591,7 +1773,7 @@ async def openai_completions_proxy(request: Request):
"seed": seed,
"stop": stop,
"stream": stream,
"stream_options": stream_options,
"stream_options": stream_options or {"include_usage": True },
"temperature": temperature,
"top_p": top_p,
"max_tokens": max_tokens,
@ -1619,7 +1801,7 @@ async def openai_completions_proxy(request: Request):
oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys[endpoint])
# 3. Async generator that streams completions data and decrements the counter
async def stream_ocompletions_response():
async def stream_ocompletions_response(model=model):
try:
# The chat method returns a generator of dicts (or GenerateResponse)
async_gen = await oclient.completions.create(**params)
@ -1630,13 +1812,24 @@ async def openai_completions_proxy(request: Request):
if hasattr(chunk, "model_dump_json")
else orjson.dumps(chunk)
)
yield f"data: {data}\n\n".encode("utf-8")
if chunk.choices:
if chunk.choices[0].finish_reason == None:
yield f"data: {data}\n\n".encode("utf-8")
if chunk.usage is not None:
prompt_tok = chunk.usage.prompt_tokens or 0
comp_tok = chunk.usage.completion_tokens or 0
if prompt_tok != 0 or comp_tok != 0:
if not is_ext_openai_endpoint(endpoint):
if not ":" in model:
local_model = model+":latest"
await token_queue.put((endpoint, local_model, prompt_tok, comp_tok))
# Final DONE event
yield b"data: [DONE]\n\n"
else:
prompt_tok = async_gen.usage.prompt_tokens or 0
comp_tok = async_gen.usage.completion_tokens or 0
await token_queue.put((endpoint, model, prompt_tok, comp_tok))
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, model, prompt_tok, comp_tok))
json_line = (
async_gen.model_dump_json()
if hasattr(async_gen, "model_dump_json")
@ -1772,7 +1965,7 @@ async def usage_stream(request: Request):
# -------------------------------------------------------------
@app.on_event("startup")
async def startup_event() -> None:
global config
global config, db
# Load YAML config (or use defaults if not present)
config_path = _config_path_from_env()
config = Config.from_yaml(config_path)
@ -1787,7 +1980,21 @@ async def startup_event() -> None:
f"No configuration file found at {config_path}. "
"Falling back to default settings."
)
# Initialize database
db = TokenDatabase(config.db_path)
await db.init_db()
# Load existing token counts from database
async for count_entry in db.load_token_counts():
endpoint = count_entry['endpoint']
model = count_entry['model']
input_tokens = count_entry['input_tokens']
output_tokens = count_entry['output_tokens']
total_tokens = count_entry['total_tokens']
token_usage_counts[endpoint][model] = total_tokens
ssl_context = ssl.create_default_context()
connector = aiohttp.TCPConnector(limit=0, limit_per_host=512, ssl=ssl_context)
timeout = aiohttp.ClientTimeout(total=60, connect=15, sock_read=120, sock_connect=15)
@ -1796,10 +2003,14 @@ async def startup_event() -> None:
app_state["connector"] = connector
app_state["session"] = session
token_worker_task = asyncio.create_task(token_worker())
flush_task = asyncio.create_task(flush_buffer())
@app.on_event("shutdown")
async def shutdown_event() -> None:
await close_all_sse_queues()
await flush_remaining_buffers()
await app_state["session"].close()
if token_worker_task is not None:
token_worker_task.cancel()
token_worker_task.cancel()
if flush_task is not None:
flush_task.cancel()

View file

@ -3,6 +3,7 @@
<head>
<meta charset="UTF-8" />
<title>NOMYO Router Dashboard</title>
<script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
<style>
body {
font-family: Arial, Helvetica, sans-serif;
@ -81,7 +82,7 @@
margin-left: 0.5rem;
margin-right: 0.5rem;
min-width: 30%;
font-size: 1rem;
transition: 0.3s;
}
@ -121,7 +122,8 @@
}
.copy-link,
.delete-link,
.show-link {
.show-link,
.stats-link {
font-size: 0.9em;
margin-left: 0.5em;
cursor: pointer;
@ -152,14 +154,14 @@
align-items: center;
justify-content: center;
}
.modal-content {
background: #fff;
padding: 1rem;
max-width: 90%;
max-height: 90%;
overflow: auto;
border-radius: 6px;
}
.modal-content {
background: #fff;
padding: 1rem;
width: 95%;
height: 95%;
overflow: auto;
border-radius: 6px;
}
.close-btn {
float: right;
cursor: pointer;
@ -210,13 +212,65 @@
order: 1;
}
}
/* ---------- Chart Timeframe Controls ---------- */
.timeframe-controls {
margin: 1rem 0;
}
.timeframe-controls button {
margin-right: 0.5rem;
padding: 0.25rem 0.5rem;
cursor: pointer;
background-color: #e0e0e0;
border: none;
border-radius: 4px;
}
.timeframe-controls button.active {
background-color: #0066cc;
color: white;
}
.chart-container {
position: relative;
height: 600px;
margin-top: 1rem;
}
.pie-chart-container {
position: relative;
height: 250px;
margin-top: 1rem;
max-width: 400px;
}
/* ---------- Stats Modal Layout ---------- */
.stats-content-wrapper {
display: flex;
flex-direction: row;
gap: 20px;
}
.main-stats-content {
flex: 1;
}
.endpoint-distribution-container {
flex: 0 0 auto;
width: 400px;
position: relative;
}
.endpoint-distribution-container h3 {
margin-top: 0;
}
.header-row {
display: flex;
align-items: center; /* vertically center the button with the headline */
gap: 1rem;
}
</style>
</head>
<body>
<a href="https://www.nomyo.ai" target="_blank"
><img src="./static/228394408.png" width="100px" height="100px"
/></a>
<h1>Router Dashboard</h1>
<div class="header-row">
<h1>Router Dashboard</h1>
<button id="total-tokens-btn">Stats Total</button><span id="aggregation-status" class="loading" style="margin-left:8px;"></span>
</div>
<button onclick="toggleDarkMode()" id="dark-mode-button">
🌗
@ -267,7 +321,7 @@
<th>Quant</th>
<th>Ctx</th>
<th>Digest</th>
<th>Token</th>
<th>Token</th>
</tr>
</thead>
<tbody id="ps-body">
@ -300,7 +354,159 @@
</div>
<script>
let psRows = new Map();
let psRows = new Map();
// Global placeholders for stats modal handling
let statsModal = null; // stats modal element
let statsChart = null; // Chart.js instance inside the modal
let rawTimeSeries = null; // raw timeseries data for the current model
let totalTokensChart = null; // Chart.js instance for total tokens modal
/* Integrated modal initialization and close handling into the main load block */
// Assign the stats modal element and attach the close handler once the DOM is ready
document.addEventListener('DOMContentLoaded', () => {
// Get the modal element (it now exists in the DOM)
statsModal = document.getElementById('stats-modal');
// Attach a single close handler (prevents multiple duplicate listeners)
if (statsModal) {
statsModal.addEventListener('click', (e) => {
if (e.target === statsModal || e.target.matches('.close-btn')) {
// Hide the modal
statsModal.style.display = 'none';
// Clean up the chart instance to avoid caching stale data
if (statsChart) {
statsChart.destroy();
statsChart = null;
}
// Remove the canvas element so a fresh one is created on next open
const oldCanvas = document.getElementById('time-series-chart');
if (oldCanvas) {
oldCanvas.remove();
}
// Reset stored timeseries data to avoid reuse of stale data
rawTimeSeries = null;
}
});
}
});
/* ---------- Global renderTimeSeriesChart ---------- */
function renderTimeSeriesChart(timeSeriesData, chart, minutes) {
// Guard clause
if (!Array.isArray(timeSeriesData) || !timeSeriesData.length) {
chart.data.labels = [];
chart.data.datasets[0].data = [];
chart.data.datasets[1].data = [];
chart.update();
return;
}
/* ── 1⃣ Determine bucket interval based on timeframe ──────────────────── */
let intervalMs;
let timeFormat;
if (minutes <= 60) {
// 1 hour: 5-minute buckets
intervalMs = 5 * 60 * 1000;
timeFormat = { hour: '2-digit', minute: '2-digit' };
} else if (minutes <= 1440) {
// 1 day: 1-hour buckets
intervalMs = 60 * 60 * 1000;
timeFormat = { month: 'short', day: 'numeric', hour: '2-digit', minute: '2-digit' };
} else if (minutes <= 10080) {
// 7 days: 6-hour buckets
intervalMs = 6 * 60 * 60 * 1000;
timeFormat = { month: 'short', day: 'numeric', hour: '2-digit' };
} else {
// 30 days: 1-day buckets
intervalMs = 24 * 60 * 60 * 1000;
timeFormat = { month: 'short', day: 'numeric' };
}
/* ── 2⃣ Get current time in local timezone ──────────────────────────── */
const now = new Date();
const nowMs = now.getTime();
const cutoffMs = nowMs - minutes * 60 * 1000;
/* ── 3⃣ Build ordered bucket slots aligned to local time boundaries ───── */
const slots = [];
// Round cutoff down to nearest bucket interval in local time
const cutoffDate = new Date(cutoffMs);
let startDate = new Date(cutoffDate);
if (minutes <= 60) {
// Align to 5-minute boundaries
startDate.setMinutes(Math.floor(startDate.getMinutes() / 5) * 5, 0, 0);
} else if (minutes <= 1440) {
// Align to hour boundaries
startDate.setMinutes(0, 0, 0);
} else if (minutes <= 10080) {
// Align to 6-hour boundaries (00:00, 06:00, 12:00, 18:00)
startDate.setHours(Math.floor(startDate.getHours() / 6) * 6, 0, 0, 0);
} else {
// Align to day boundaries
startDate.setHours(0, 0, 0, 0);
}
let slotTime = startDate.getTime();
while (slotTime <= nowMs) {
slots.push(slotTime);
slotTime += intervalMs;
}
/* ── 4⃣ Aggregate raw rows into local time buckets ───────────────────── */
const bucketMap = {};
timeSeriesData.forEach(row => {
// Database stores UTC timestamps in seconds, convert to local time milliseconds
const utcTimestampMs = row.timestamp * 1000;
// Check if within our time window
if (utcTimestampMs < cutoffMs || utcTimestampMs > nowMs) return;
// Find which bucket this timestamp belongs to
let closestSlot = null;
let minDiff = Infinity;
for (const slot of slots) {
const diff = Math.abs(utcTimestampMs - slot);
if (diff < minDiff && diff < intervalMs) {
minDiff = diff;
closestSlot = slot;
}
}
if (closestSlot !== null) {
if (!bucketMap[closestSlot]) bucketMap[closestSlot] = { input: 0, output: 0 };
bucketMap[closestSlot].input += row.input_tokens || 0;
bucketMap[closestSlot].output += row.output_tokens || 0;
}
});
/* ── 5⃣ Build labels in local timezone ───────────────────────────────── */
const labels = slots.map(ts => {
const d = new Date(ts);
return d.toLocaleString(undefined, {
...timeFormat,
timeZoneName: 'short'
});
});
const inputData = slots.map(ts => (bucketMap[ts]?.input ?? 0));
const outputData = slots.map(ts => (bucketMap[ts]?.output ?? 0));
/* ── 6⃣ Push into the Chart.js instance ─────────────────────────────── */
chart.data.labels = labels;
chart.data.datasets[0].data = inputData;
chart.data.datasets[1].data = outputData;
chart.update();
}
/* ---------- Utility ---------- */
async function fetchJSON(url) {
const resp = await fetch(url);
@ -403,25 +609,26 @@
e.preventDefault();
const model = link.dataset.model;
const ok = confirm(
`Delete the model “${model}”? This cannot be undone.`,
`Delete the model "${model}"? This cannot be undone.`,
);
if (!ok) return;
try {
const resp = await fetch(
`/api/delete?model=${encodeURIComponent(model)}`,
{ method: "DELETE" },
);
if (!resp.ok)
throw new Error(
`Delete failed: ${resp.status}`,
if (ok) {
try {
const resp = await fetch(
`/api/delete?model=${encodeURIComponent(model)}`,
{ method: "DELETE" },
);
alert(
`Model “${model}” deleted successfully.`,
);
loadTags();
} catch (err) {
console.error(err);
alert(`Error deleting ${model}: ${err}`);
if (!resp.ok)
throw new Error(
`Delete failed: ${resp.status}`,
);
alert(
`Model "${model}" deleted successfully.`,
);
loadTags();
} catch (err) {
console.error(err);
alert(`Error deleting ${model}: ${err}`);
}
}
});
});
@ -442,18 +649,22 @@
const tokenValue = existingRow
? existingRow.querySelector(".token-usage")?.textContent ?? 0
: 0;
const digest = m.digest || "";
const shortDigest = digest.length > 24
? `${digest.slice(0, 12)}...${digest.slice(-12)}`
: digest;
return `<tr data-model="${m.name}">
<td class="model">${m.name}</td>
<td class="model">${m.name} <a href="#" class="stats-link" data-model="${m.name}">stats</a></td>
<td>${m.details.parameter_size}</td>
<td>${m.details.quantization_level}</td>
<td>${m.context_length}</td>
<td>${m.digest}</td>
<td>${shortDigest}</td>
<td class="token-usage">${tokenValue}</td>
</tr>`;
})
.join("");
psRows.clear();
document
psRows.clear();
document
.querySelectorAll("#ps-body tr[data-model]")
.forEach((row) => {
const model = row.dataset.model;
@ -470,9 +681,9 @@
return `hsl(${h}, 80%, 30%)`;
}
function hashString(str) {
let hash = 0;
let hash = 42;
for (let i = 0; i < str.length; i++) {
hash = (hash << 5) - hash + str.charCodeAt(i);
hash = ((hash << 5) + hash) + str.charCodeAt(i);
hash |= 0;
}
return Math.abs(hash);
@ -481,9 +692,7 @@
// Create the EventSource once and keep it around
const source = new EventSource("/api/usage-stream");
// -----------------------------------------------------------------
// Helper that receives the payload and renders the chart
// -----------------------------------------------------------------
const renderChart = (data) => {
const chart = document.getElementById("usage-chart");
const usage = data.usage_counts || {};
@ -514,9 +723,7 @@
chart.innerHTML = html;
};
// -----------------------------------------------------------------
// Event handlers
// -----------------------------------------------------------------
source.onmessage = (e) => {
try {
const payload = JSON.parse(e.data); // SSE sends plain text
@ -525,18 +732,9 @@
const tokens = payload.token_usage_counts || {};
psRows.forEach((row, model) => {
/* regular usage count optional if you want to keep it */
let total = 0;
for (const ep in usage) {
total += usage[ep][model] || 0;
}
const usageCell = row.querySelector(".usage");
if (usageCell) usageCell.textContent = total;
/* token usage */
let tokenTotal = 0;
for (const ep in tokens) {
tokenTotal += tokens[ep][model] || 0;
tokenTotal += tokens[ep][model] || 0;
}
const tokenCell = row.querySelector(".token-usage");
if (tokenCell) tokenCell.textContent = tokenTotal;
@ -548,7 +746,6 @@
source.onerror = (err) => {
console.error("SSE connection error. Retrying...", err);
// EventSource will automatically try to reconnect.
};
window.addEventListener("beforeunload", () => source.close());
}
@ -561,7 +758,7 @@
loadUsage();
setInterval(loadPS, 60_000);
setInterval(loadEndpoints, 300_000);
/* show logic */
document.body.addEventListener("click", async (e) => {
if (!e.target.matches(".show-link")) return;
@ -632,8 +829,295 @@
modal.style.display = "none";
}
});
/* stats logic */
document.body.addEventListener("click", async (e) => {
if (!e.target.matches(".stats-link")) return;
e.preventDefault();
const model = e.target.dataset.model;
try {
const resp = await fetch(
`/api/stats?model=${encodeURIComponent(model)}`,
{ method: "POST" },
);
if (!resp.ok)
throw new Error(`Status ${resp.status}`);
const data = await resp.json();
const content = document.getElementById("stats-content");
content.innerHTML = `
<div class="stats-content-wrapper">
<div class="main-stats-content">
<h3>Token Usage</h3>
<p>Input tokens: ${data.input_tokens}</p>
<p>Output tokens: ${data.output_tokens}</p>
<p>Total tokens: ${data.total_tokens}</p>
<h3>Usage Over Time</h3>
<div class="timeframe-controls">
<button class="timeframe-btn active" data-minutes="60">Last 1 hour</button>
<button class="timeframe-btn" data-minutes="1440">Last 1 day</button>
<button class="timeframe-btn" data-minutes="10080">Last 7 days</button>
<button class="timeframe-btn" data-minutes="43200">Last 30 days</button>
</div>
<div class="chart-container">
<canvas id="time-series-chart"></canvas>
</div>
</div>
<div class="endpoint-distribution-container">
<h3>Endpoint Distribution</h3>
<div class="pie-chart-container">
<canvas id="endpoint-pie-chart"></canvas>
</div>
</div>
</div>
`;
document.getElementById("stats-modal").style.display = "flex";
// Initialise the charts (time-series + pie chart)
initStatsChart(data.time_series, data.endpoint_distribution);
} catch (err) {
console.error(err);
alert(`Could not load model stats: ${err.message}`);
}
});
/* ---------- Helper to initialise or refresh the stats chart ---------- */
function initStatsChart(timeSeriesData, endpointDistribution) {
// Destroy any existing chart instance
if (statsChart) {
statsChart.destroy();
statsChart = null;
}
// Remove any existing canvas and create a fresh one
const oldCanvas = document.getElementById('time-series-chart');
if (oldCanvas) {
oldCanvas.remove();
}
const canvas = document.createElement('canvas');
canvas.id = 'time-series-chart';
document.querySelector('.chart-container').appendChild(canvas);
// Create a new Chart.js instance
const ctx = canvas.getContext('2d');
const chart = new Chart(ctx, {
type: 'bar',
data: {
labels: [],
datasets: [
{ label: 'Input Tokens', data: [], backgroundColor: '#4CAF50' },
{ label: 'Output Tokens', data: [], backgroundColor: '#2196F3' }
]
},
options: {
responsive: true,
maintainAspectRatio: false,
scales: {
x: { stacked: true },
y: { stacked: true }
},
plugins: {
legend: { position: 'top' },
title: { display: true, text: 'Token Usage Over Time' }
}
}
});
// Store the chart globally for later updates
statsChart = chart;
// Store the raw timeseries data globally
rawTimeSeries = timeSeriesData || [];
// Render the initial view (default to 60 minutes)
renderTimeSeriesChart(rawTimeSeries, statsChart, 60);
// Attach timeframe button handlers (once)
document.querySelectorAll('.timeframe-btn').forEach(button => {
button.addEventListener('click', function () {
// Update active button styling
document.querySelectorAll('.timeframe-btn').forEach(btn => btn.classList.remove('active'));
this.classList.add('active');
// Rerender chart with the selected timeframe
const minutes = parseInt(this.dataset.minutes);
renderTimeSeriesChart(rawTimeSeries, statsChart, minutes);
});
});
// Create endpoint distribution pie chart
if (endpointDistribution && Object.keys(endpointDistribution).length > 0) {
const pieCanvas = document.getElementById('endpoint-pie-chart');
const pieCtx = pieCanvas.getContext('2d');
const endpoints = Object.keys(endpointDistribution);
const tokenCounts = Object.values(endpointDistribution);
const colors = endpoints.map(ep => getColor(ep));
new Chart(pieCtx, {
type: 'pie',
data: {
labels: endpoints,
datasets: [{
data: tokenCounts,
backgroundColor: colors,
borderWidth: 1,
borderColor: '#fff'
}]
},
options: {
responsive: true,
maintainAspectRatio: false,
plugins: {
legend: {
position: 'right',
labels: {
boxWidth: 12,
font: { size: 11 }
}
},
title: {
display: true,
text: 'Total Tokens per Endpoint'
},
tooltip: {
callbacks: {
label: function(context) {
const label = context.label || '';
const value = context.parsed || 0;
const total = context.dataset.data.reduce((a, b) => a + b, 0);
const percentage = ((value / total) * 100).toFixed(1);
return `${label}: ${value.toLocaleString()} tokens (${percentage}%)`;
}
}
}
}
}
});
}
}
});
</script>
<script>
document.addEventListener('DOMContentLoaded', () => {
const totalBtn = document.getElementById('total-tokens-btn');
if (totalBtn) {
totalBtn.addEventListener('click', async () => {
try {
const resp = await fetch('/api/token_counts');
const data = await resp.json();
const modal = document.getElementById('total-tokens-modal');
const numberEl = document.getElementById('total-tokens-number');
numberEl.textContent = data.total_tokens;
document.getElementById('aggregation-status').textContent = 'Aggregating...';
try {
const aggResp = await fetch('/api/aggregate_time_series_days', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ days: 30 , trim_old: true})
});
if (aggResp.ok) {
const aggData = await aggResp.json();
const aggr = aggData.aggregated_groups ?? 0;
document.getElementById('aggregation-status').textContent = `Aggregated ${aggr} groups`;
} else {
document.getElementById('aggregation-status').textContent = 'Aggregation failed';
}
} catch (err) {
document.getElementById('aggregation-status').textContent = 'Aggregation error';
}
const chartCanvas = document.getElementById('total-tokens-chart');
if (chartCanvas) {
// Destroy existing chart if it exists
if (totalTokensChart) {
totalTokensChart.destroy();
totalTokensChart = null;
}
const ctx = chartCanvas.getContext('2d');
const tokenCounts = data.breakdown || [];
/* NEW LOGIC: concentric rings per model */
const modelTotals = {};
const modelEndpointTotals = {};
tokenCounts.forEach(entry => {
const { model, endpoint, total_tokens } = entry;
modelTotals[model] = (modelTotals[model] || 0) + total_tokens;
if (!modelEndpointTotals[model]) modelEndpointTotals[model] = {};
modelEndpointTotals[model][endpoint] = (modelEndpointTotals[model][endpoint] || 0) + total_tokens;
});
const endpointsSet = new Set();
tokenCounts.forEach(entry => endpointsSet.add(entry.endpoint));
const endpoints = Array.from(endpointsSet);
const endpointColors = {};
endpoints.forEach(ep => {
endpointColors[ep] = getColor(ep);
});
const sortedModels = Object.keys(modelTotals).sort((a, b) => modelTotals[b] - modelTotals[a]);
const datasets = sortedModels.map(model => {
const data = endpoints.map(ep => (modelEndpointTotals[model][ep] || 0));
const backgroundColor = endpoints.map(ep => endpointColors[ep]);
return {
label: model,
data,
backgroundColor,
borderWidth: 1,
borderColor: '#fff'
};
});
totalTokensChart = new Chart(ctx, {
type: 'doughnut',
data: {
labels: endpoints,
datasets
},
options: {
responsive: true,
maintainAspectRatio: false,
cutout: '15%',
plugins: {
legend: {
position: 'right',
labels: {
boxWidth: 12,
font: { size: 11 }
}
},
title: {
display: true,
text: 'Token Distribution by Model per Endpoint'
},
tooltip: {
callbacks: {
label: function(context) {
const endpointName = context.chart.data.labels[context.dataIndex];
const modelName = context.dataset.label;
const value = context.parsed || 0;
const total = context.dataset.data.reduce((a, b) => a + b, 0);
const percentage = ((value / total) * 100).toFixed(1);
return `${modelName} - ${endpointName}: ${value.toLocaleString()} tokens (${percentage}%)`;
}
}
}
}
}
});
}
modal.style.display = 'flex';
} catch (err) {
console.error(err);
alert('Failed to load token counts');
}
});
}
const totalTokensModal = document.getElementById('total-tokens-modal');
if (totalTokensModal) {
totalTokensModal.addEventListener('click', (e) => {
if (e.target === totalTokensModal || e.target.matches('.close-btn')) {
totalTokensModal.style.display = 'none';
}
});
}
});
</script>
<div id="show-modal" class="modal">
<div class="modal-content">
@ -642,5 +1126,22 @@
<pre id="json-output"></pre>
</div>
</div>
</body>
<div id="stats-modal" class="modal">
<div class="modal-content">
<span class="close-btn">&times;</span>
<h2>Model Stats</h2>
<div id="stats-content">
<p>Loading stats...</p>
</div>
</div>
</div>
<div id="total-tokens-modal" class="modal">
<div class="modal-content">
<span class="close-btn">&times;</span>
<h2>Total Tokens</h2>
<p id="total-tokens-number"></p>
<canvas id="total-tokens-chart"></canvas>
</div>
</div>
</html>