diff --git a/.gitignore b/.gitignore
index 702c855..74eef7d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -61,4 +61,7 @@ cython_debug/
 *.sqlite3
 
 # Config
-config.yaml
\ No newline at end of file
+config.yaml
+
+# SQLite
+*.db
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
index e456af3..073496d 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -3,11 +3,17 @@ FROM python:3.13-slim
 ENV PYTHONUNBUFFERED=1 \
     PYTHONDONTWRITEBYTECODE=1
 
+# Install SQLite
+RUN apt-get update && apt-get install -y sqlite3
+
 WORKDIR /app
 
 COPY requirements.txt .
 RUN pip install --no-cache-dir --upgrade pip \
  && pip install --no-cache-dir -r requirements.txt
 
+# Create database directory and set permissions
+RUN mkdir -p /app/data && chown -R www-data:www-data /app/data
+
 COPY . .
 RUN chmod +x /app/entrypoint.sh
diff --git a/db.py b/db.py
new file mode 100644
index 0000000..0816c17
--- /dev/null
+++ b/db.py
@@ -0,0 +1,131 @@
+import aiosqlite
+from pathlib import Path
+from datetime import datetime, timezone
+
+class TokenDatabase:
+    def __init__(self, db_path: str = "token_counts.db"):
+        self.db_path = db_path
+        self._ensure_db_directory()
+
+    def _ensure_db_directory(self):
+        """Ensure the directory for the database exists."""
+        db_dir = Path(self.db_path).parent
+        if not db_dir.exists():
+            db_dir.mkdir(parents=True, exist_ok=True)
+
+    async def init_db(self):
+        """Initialize the database tables."""
+        async with aiosqlite.connect(self.db_path) as db:
+            await db.execute('''
+                CREATE TABLE IF NOT EXISTS token_counts (
+                    endpoint TEXT,
+                    model TEXT,
+                    input_tokens INTEGER DEFAULT 0,
+                    output_tokens INTEGER DEFAULT 0,
+                    total_tokens INTEGER DEFAULT 0,
+                    PRIMARY KEY(endpoint, model)
+                )
+            ''')
+            await db.execute('''
+                CREATE TABLE IF NOT EXISTS token_time_series (
+                    id INTEGER PRIMARY KEY AUTOINCREMENT,
+                    endpoint TEXT,
+                    model TEXT,
+                    input_tokens INTEGER,
+                    output_tokens INTEGER,
+                    total_tokens INTEGER,
+                    timestamp INTEGER, -- Unix timestamp with approximate minute/hour precision
+                    FOREIGN KEY(endpoint, model) REFERENCES token_counts(endpoint, model)
+                )
+            ''')
+            await db.commit()
+
+    async def update_token_counts(self, endpoint: str, model: str, input_tokens: int, output_tokens: int):
+        """Update token counts for a specific endpoint and model."""
+        total_tokens = input_tokens + output_tokens
+        async with aiosqlite.connect(self.db_path) as db:
+            await db.execute('''
+                INSERT INTO token_counts (endpoint, model, input_tokens, output_tokens, total_tokens)
+                VALUES (?, ?, ?, ?, ?)
+                ON CONFLICT(endpoint, model) DO UPDATE SET
+                    input_tokens = input_tokens + ?,
+                    output_tokens = output_tokens + ?,
+                    total_tokens = total_tokens + ?
+            ''', (endpoint, model, input_tokens, output_tokens, total_tokens, input_tokens, output_tokens, total_tokens))
+            await db.commit()
+
+    async def add_time_series_entry(self, endpoint: str, model: str, input_tokens: int, output_tokens: int):
+        """Add a time series entry with approximate timestamp."""
+        total_tokens = input_tokens + output_tokens
+        # Truncate to the current minute, keeping tzinfo so .timestamp() stays UTC
+        now = datetime.now(tz=timezone.utc)
+        timestamp = int(now.replace(second=0, microsecond=0).timestamp())
+
+        async with aiosqlite.connect(self.db_path) as db:
+            await db.execute('''
+                INSERT INTO token_time_series (endpoint, model, input_tokens, output_tokens, total_tokens, timestamp)
+                VALUES (?, ?, ?, ?, ?, ?)
+            ''', (endpoint, model, input_tokens, output_tokens, total_tokens, timestamp))
+            await db.commit()
+
+    async def update_batched_counts(self, counts: dict):
+        """Update multiple token counts in a single transaction."""
+        if not counts:
+            return
+        async with aiosqlite.connect(self.db_path) as db:
+            for endpoint, models in counts.items():
+                for model, (input_tokens, output_tokens) in models.items():
+                    total_tokens = input_tokens + output_tokens
+                    await db.execute('''
+                        INSERT INTO token_counts (endpoint, model, input_tokens, output_tokens, total_tokens)
+                        VALUES (?, ?, ?, ?, ?)
+                        ON CONFLICT(endpoint, model) DO UPDATE SET
+                            input_tokens = input_tokens + ?,
+                            output_tokens = output_tokens + ?,
+                            total_tokens = total_tokens + ?
+                    ''', (endpoint, model, input_tokens, output_tokens, total_tokens,
+                          input_tokens, output_tokens, total_tokens))
+            await db.commit()
+
+    async def add_batched_time_series(self, entries: list):
+        """Add multiple time series entries in a single transaction."""
+        async with aiosqlite.connect(self.db_path) as db:
+            for entry in entries:
+                await db.execute('''
+                    INSERT INTO token_time_series (endpoint, model, input_tokens, output_tokens, total_tokens, timestamp)
+                    VALUES (?, ?, ?, ?, ?, ?)
+                ''', (entry['endpoint'], entry['model'], entry['input_tokens'],
+                      entry['output_tokens'], entry['total_tokens'], entry['timestamp']))
+            await db.commit()
+
+    async def load_token_counts(self):
+        """Load all token counts from database."""
+        async with aiosqlite.connect(self.db_path) as db:
+            async with db.execute('SELECT endpoint, model, input_tokens, output_tokens, total_tokens FROM token_counts') as cursor:
+                async for row in cursor:
+                    yield {
+                        'endpoint': row[0],
+                        'model': row[1],
+                        'input_tokens': row[2],
+                        'output_tokens': row[3],
+                        'total_tokens': row[4]
+                    }
+
+    async def get_latest_time_series(self, limit: int = 100):
+        """Get the latest time series entries."""
+        async with aiosqlite.connect(self.db_path) as db:
+            async with db.execute('''
+                SELECT endpoint, model, input_tokens, output_tokens, total_tokens, timestamp
+                FROM token_time_series
+                ORDER BY timestamp DESC
+                LIMIT ?
+            ''', (limit,)) as cursor:
+                async for row in cursor:
+                    yield {
+                        'endpoint': row[0],
+                        'model': row[1],
+                        'input_tokens': row[2],
+                        'output_tokens': row[3],
+                        'total_tokens': row[4],
+                        'timestamp': row[5]
+                    }
diff --git a/entrypoint.sh b/entrypoint.sh
index cee2f17..6682851 100644
--- a/entrypoint.sh
+++ b/entrypoint.sh
@@ -1,6 +1,13 @@
 #!/usr/bin/env sh
 set -e
 
+# Create database directory if it doesn't exist
+mkdir -p /app/data
+chown -R www-data:www-data /app/data
+
+# Set database path environment variable
+export NOMYO_ROUTER_DB_PATH="/app/data/token_counts.db"
+
 CONFIG_PATH_ARG=""
 SHOW_HELP=0
 
diff --git a/requirements.txt b/requirements.txt
index e296839..a0c7ab6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -36,3 +36,4 @@ typing-inspection==0.4.1
 typing_extensions==4.14.1
 uvicorn==0.38.0
 yarl==1.20.1
+aiosqlite
diff --git a/router.py b/router.py
index dff8aa9..1b0af1c 100644
--- a/router.py
+++ b/router.py
@@ -6,7 +6,8 @@ version: 0.4
 license: AGPL
 """
 # -------------------------------------------------------------
-import orjson, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, datetime, random, base64, io
+import orjson, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, random, base64, io
+from datetime import datetime, timezone
 from pathlib import Path
 from typing import Dict, Set, List, Optional
 from urllib.parse import urlparse
@@ -45,6 +46,18 @@ app_state = {
     "connector": None,
 }
 token_worker_task: asyncio.Task | None = None
+flush_task: asyncio.Task | None = None
+
+# ------------------------------------------------------------------
+# Token Count Buffer (for write-behind pattern)
+# ------------------------------------------------------------------
+# Structure: {endpoint: {model: (input_tokens, output_tokens)}}
+token_buffer: dict[str, dict[str, tuple[int, int]]] = defaultdict(lambda: defaultdict(lambda: (0, 0)))
+# Time series buffer with timestamp
+time_series_buffer: list[dict[str, int | str]] = []
+
+# Configuration for periodic flushing
+FLUSH_INTERVAL = 10  # seconds
 
 # -------------------------------------------------------------
 # 1. Configuration loader
@@ -61,6 +74,9 @@ class Config(BaseSettings):
 
     api_keys: Dict[str, str] = Field(default_factory=dict)
 
+    # Database configuration
+    db_path: str = Field(default=os.getenv("NOMYO_ROUTER_DB_PATH", "token_counts.db"))
+
     class Config:
         # Load from `config.yaml` first, then from env variables
         env_prefix = "NOMYO_ROUTER_"
@@ -101,6 +117,8 @@ def _config_path_from_env() -> Path:
         return Path(candidate).expanduser()
     return Path("config.yaml")
 
+from db import TokenDatabase
+
 # Create the global config object – it will be overwritten on startup
 config = Config.from_yaml(_config_path_from_env())
 
@@ -130,6 +148,9 @@ token_usage_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(
 usage_lock = asyncio.Lock()  # protects access to usage_counts
 token_usage_lock = asyncio.Lock()
 
+# Database instance (created in startup_event)
+db: "TokenDatabase | None" = None
+
 # -------------------------------------------------------------
 # 4. Helperfunctions
 # -------------------------------------------------------------
@@ -181,7 +202,7 @@ def _format_connection_issue(url: str, error: Exception) -> str:
     )
     return f"Error while contacting {url}: {error}"
 
-    
+
 def is_ext_openai_endpoint(endpoint: str) -> bool:
     if "/v1" not in endpoint:
         return False
@@ -199,9 +220,43 @@ def is_ext_openai_endpoint(endpoint: str) -> bool:
 async def token_worker() -> None:
     while True:
         endpoint, model, prompt, comp = await token_queue.get()
+        # Accumulate counts in memory buffer
+        token_buffer[endpoint][model] = (
+            token_buffer[endpoint].get(model, (0, 0))[0] + prompt,
+            token_buffer[endpoint].get(model, (0, 0))[1] + comp
+        )
+
+        # Add to time series buffer, truncated to the minute in UTC
+        now = datetime.now(tz=timezone.utc)
+        timestamp = int(now.replace(second=0, microsecond=0).timestamp())
+        time_series_buffer.append({
+            'endpoint': endpoint,
+            'model': model,
+            'input_tokens': prompt,
+            'output_tokens': comp,
+            'total_tokens': prompt + comp,
+            'timestamp': timestamp
+        })
+
+        # Update in-memory counts for immediate reporting
         async with token_usage_lock:
             token_usage_counts[endpoint][model] += (prompt + comp)
-            await publish_snapshot()
+        await publish_snapshot()
+
+async def flush_buffer() -> None:
+    """Periodically flush accumulated token counts to the database."""
+    while True:
+        await asyncio.sleep(FLUSH_INTERVAL)
+
+        # Snapshot and clear BEFORE awaiting so entries appended mid-write survive
+        if token_buffer:
+            pending_counts = {ep: dict(models) for ep, models in token_buffer.items()}
+            token_buffer.clear()
+            await db.update_batched_counts(pending_counts)
+        if time_series_buffer:
+            pending_series = time_series_buffer[:]
+            time_series_buffer.clear()
+            await db.add_batched_time_series(pending_series)
 
 class fetch:
     async def available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
@@ -366,7 +421,7 @@ async def decrement_usage(endpoint: str, model: str) -> None:
 def iso8601_ns():
     ns = time.time_ns()
     sec, ns_rem = divmod(ns, 1_000_000_000)
-    dt = datetime.datetime.fromtimestamp(sec, tz=datetime.timezone.utc)
+    dt = datetime.fromtimestamp(sec, tz=timezone.utc)
     return (
         f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d}T"
         f"{dt.hour:02d}:{dt.minute:02d}:{dt.second:02d}."
@@ -628,7 +683,6 @@ async def choose_endpoint(model: str) -> str:
             f"None of the configured endpoints ({', '.join(config.endpoints)}) "
             f"advertise the model '{model}'."
         )
-
     # 3️⃣ Among the candidates, find those that have the model *loaded*
     #    (concurrently, but only for the filtered list)
     load_tasks = [fetch.loaded_models(ep) for ep in candidate_endpoints]
@@ -1772,7 +1826,7 @@ async def usage_stream(request: Request):
 # -------------------------------------------------------------
 @app.on_event("startup")
 async def startup_event() -> None:
-    global config
+    global config, db, token_worker_task, flush_task
     # Load YAML config (or use defaults if not present)
     config_path = _config_path_from_env()
     config = Config.from_yaml(config_path)
@@ -1787,7 +1841,21 @@ async def startup_event() -> None:
             f"No configuration file found at {config_path}. "
             "Falling back to default settings."
         )
-    
+
+    # Initialize database
+    db = TokenDatabase(config.db_path)
+    await db.init_db()
+
+    # Load existing token counts from database
+    async for count_entry in db.load_token_counts():
+        endpoint = count_entry['endpoint']
+        model = count_entry['model']
+        input_tokens = count_entry['input_tokens']
+        output_tokens = count_entry['output_tokens']
+        total_tokens = count_entry['total_tokens']
+
+        token_usage_counts[endpoint][model] = total_tokens
+
     ssl_context = ssl.create_default_context()
     connector = aiohttp.TCPConnector(limit=0, limit_per_host=512, ssl=ssl_context)
     timeout = aiohttp.ClientTimeout(total=60, connect=15, sock_read=120, sock_connect=15)
@@ -1796,10 +1864,18 @@ async def startup_event() -> None:
     app_state["connector"] = connector
     app_state["session"] = session
     token_worker_task = asyncio.create_task(token_worker())
+    flush_task = asyncio.create_task(flush_buffer())
 
 @app.on_event("shutdown")
 async def shutdown_event() -> None:
     await close_all_sse_queues()
     await app_state["session"].close()
     if token_worker_task is not None:
-        token_worker_task.cancel()
\ No newline at end of file
+        token_worker_task.cancel()
+    if flush_task is not None:
+        flush_task.cancel()
+    # Persist any token counts still sitting in the in-memory buffers
+    if db is not None and token_buffer:
+        await db.update_batched_counts(dict(token_buffer))
+    if db is not None and time_series_buffer:
+        await db.add_batched_time_series(time_series_buffer)
diff --git a/static/index.html b/static/index.html
index 6f11ccb..b4f5466 100644
--- a/static/index.html
+++ b/static/index.html
@@ -442,12 +442,16 @@
           const tokenValue = existingRow
             ? existingRow.querySelector(".token-usage")?.textContent ?? 0
             : 0;
+          const digest = m.digest || "";
+          const shortDigest = digest.length > 24
+            ? `${digest.slice(0, 12)}...${digest.slice(-12)}`
+            : digest;
           return `
             <td>${m.name}</td>
             <td>${m.details.parameter_size}</td>
             <td>${m.details.quantization_level}</td>
             <td>${m.context_length}</td>
-            <td>${m.digest}</td>
+            <td>${shortDigest}</td>
             <td class="token-usage">${tokenValue}</td>
           `;
         })