diff --git a/.gitignore b/.gitignore index 702c855..100cc12 100644 --- a/.gitignore +++ b/.gitignore @@ -61,4 +61,7 @@ cython_debug/ *.sqlite3 # Config -config.yaml \ No newline at end of file +config.yaml + +# SQLite +*.db* \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index e456af3..073496d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,11 +3,17 @@ FROM python:3.13-slim ENV PYTHONUNBUFFERED=1 \ PYTHONDONTWRITEBYTECODE=1 +# Install SQLite +RUN apt-get update && apt-get install -y sqlite3 + WORKDIR /app COPY requirements.txt . RUN pip install --no-cache-dir --upgrade pip \ && pip install --no-cache-dir -r requirements.txt +# Create database directory and set permissions +RUN mkdir -p /app/data && chown -R www-data:www-data /app/data + COPY . . RUN chmod +x /app/entrypoint.sh diff --git a/db.py b/db.py new file mode 100644 index 0000000..aaea508 --- /dev/null +++ b/db.py @@ -0,0 +1,267 @@ +import aiosqlite +import asyncio +from typing import Optional +from pathlib import Path +from datetime import datetime, timezone +from collections import defaultdict + +class TokenDatabase: + def __init__(self, db_path: str = "token_counts.db"): + self.db_path = db_path + self._db: Optional[aiosqlite.Connection] = None + self._db_lock = asyncio.Lock() + self._operation_lock = asyncio.Lock() + + def _ensure_db_directory(self): + """Ensure the directory for the database exists.""" + db_dir = Path(self.db_path).parent + if not db_dir.exists(): + db_dir.mkdir(parents=True, exist_ok=True) + + async def _get_connection(self) -> aiosqlite.Connection: + """Return a persistent connection with WAL mode and FK enforcement enabled.""" + if self._db is None: + async with self._db_lock: + if self._db is None: + self._ensure_db_directory() + self._db = await aiosqlite.connect(self.db_path) + # Enable WAL and foreign keys for reliability and integrity + await self._db.execute("PRAGMA journal_mode=WAL;") + await self._db.execute("PRAGMA foreign_keys = ON;") + await self._db.commit() + return self._db + + async def close(self): + """Close the persistent database connection, if open.""" + if self._db is not None: + await self._db.close() + self._db = None + + async def init_db(self): + """Initialize the database tables.""" + db = await self._get_connection() + async with self._operation_lock: + await db.execute(''' + CREATE TABLE IF NOT EXISTS token_counts ( + endpoint TEXT, + model TEXT, + input_tokens INTEGER DEFAULT 0, + output_tokens INTEGER DEFAULT 0, + total_tokens INTEGER DEFAULT 0, + PRIMARY KEY(endpoint, model) + ) + ''') + await db.execute(''' + CREATE TABLE IF NOT EXISTS token_time_series ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + endpoint TEXT, + model TEXT, + input_tokens INTEGER, + output_tokens INTEGER, + total_tokens INTEGER, + timestamp INTEGER, + FOREIGN KEY(endpoint, model) REFERENCES token_counts(endpoint, model) + ) + ''') + await db.commit() + + async def update_token_counts(self, endpoint: str, model: str, input_tokens: int, output_tokens: int): + """Update token counts for a specific endpoint and model.""" + total_tokens = input_tokens + output_tokens + db = await self._get_connection() + async with self._operation_lock: + await db.execute(''' + INSERT INTO token_counts (endpoint, model, input_tokens, output_tokens, total_tokens) + VALUES (?, ?, ?, ?, ?) + ON CONFLICT (endpoint, model) DO UPDATE SET + input_tokens = input_tokens + ?, + output_tokens = output_tokens + ?, + total_tokens = total_tokens + ? + ''', (endpoint, model, input_tokens, output_tokens, total_tokens, input_tokens, output_tokens, total_tokens)) + await db.commit() + + async def add_time_series_entry(self, endpoint: str, model: str, input_tokens: int, output_tokens: int): + """Add a time series entry with approximate timestamp.""" + total_tokens = input_tokens + output_tokens + # Use current minute/hour as approximate timestamp in UTC + now = datetime.now(tz=timezone.utc) + timestamp = int(datetime(now.year, now.month, now.day, now.hour, now.minute).timestamp()) + + db = await self._get_connection() + async with self._operation_lock: + await db.execute(''' + INSERT INTO token_time_series (endpoint, model, input_tokens, output_tokens, total_tokens, timestamp) + VALUES (?, ?, ?, ?, ?, ?) + ''', (endpoint, model, input_tokens, output_tokens, total_tokens, timestamp)) + await db.commit() + + async def update_batched_counts(self, counts: dict): + """Update multiple token counts in a single transaction.""" + if not counts: + return + db = await self._get_connection() + async with self._operation_lock: + try: + await db.execute('BEGIN') + for endpoint, models in counts.items(): + for model, (input_tokens, output_tokens) in models.items(): + total_tokens = input_tokens + output_tokens + await db.execute(''' + INSERT INTO token_counts (endpoint, model, input_tokens, output_tokens, total_tokens) + VALUES (?, ?, ?, ?, ?) + ON CONFLICT (endpoint, model) DO UPDATE SET + input_tokens = input_tokens + ?, + output_tokens = output_tokens + ?, + total_tokens = total_tokens + ? + ''', (endpoint, model, input_tokens, output_tokens, total_tokens, + input_tokens, output_tokens, total_tokens)) + await db.commit() + except Exception: + # Rollback on error to maintain consistency + try: + await db.execute('ROLLBACK') + except Exception: + pass + raise + + async def add_batched_time_series(self, entries: list): + """Add multiple time series entries in a single transaction.""" + db = await self._get_connection() + async with self._operation_lock: + try: + await db.execute('BEGIN') + for entry in entries: + await db.execute(''' + INSERT INTO token_time_series (endpoint, model, input_tokens, output_tokens, total_tokens, timestamp) + VALUES (?, ?, ?, ?, ?, ?) + ''', (entry['endpoint'], entry['model'], entry['input_tokens'], + entry['output_tokens'], entry['total_tokens'], entry['timestamp'])) + await db.commit() + except Exception: + try: + await db.execute('ROLLBACK') + except Exception: + pass + raise + + async def load_token_counts(self): + """Load all token counts from database.""" + db = await self._get_connection() + async with self._operation_lock: + async with db.execute('SELECT endpoint, model, input_tokens, output_tokens, total_tokens FROM token_counts') as cursor: + async for row in cursor: + yield { + 'endpoint': row[0], + 'model': row[1], + 'input_tokens': row[2], + 'output_tokens': row[3], + 'total_tokens': row[4] + } + + async def get_latest_time_series(self, limit: int = 100): + """Get the latest time series entries.""" + db = await self._get_connection() + async with self._operation_lock: + async with db.execute(''' + SELECT endpoint, model, input_tokens, output_tokens, total_tokens, timestamp + FROM token_time_series + ORDER BY timestamp DESC + LIMIT ? + ''', (limit,)) as cursor: + async for row in cursor: + yield { + 'endpoint': row[0], + 'model': row[1], + 'input_tokens': row[2], + 'output_tokens': row[3], + 'total_tokens': row[4], + 'timestamp': row[5] + } + + async def get_token_counts_for_model(self, model): + """Get token counts for a specific model, aggregated across all endpoints.""" + db = await self._get_connection() + async with self._operation_lock: + async with db.execute('SELECT endpoint, model, input_tokens, output_tokens, total_tokens FROM token_counts WHERE model = ?', (model,)) as cursor: + total_input = 0 + total_output = 0 + total_tokens = 0 + async for row in cursor: + total_input += row[2] + total_output += row[3] + total_tokens += row[4] + + if total_input > 0 or total_output > 0: + return { + 'endpoint': 'aggregated', + 'model': model, + 'input_tokens': total_input, + 'output_tokens': total_output, + 'total_tokens': total_tokens + } + return None + + async def aggregate_time_series_older_than(self, days: int, trim_old: bool = False) -> int: + """ + Aggregate time_series entries older than 'days' days into daily aggregates by + endpoint, model and UTC date (YYYY-MM-DD). The results are stored in + token_time_series_daily with a UNIQUE constraint on (endpoint, model, date). + + Returns the number of aggregated groups (distinct (endpoint, model, date) tuples) + that were created/updated. + """ + if not isinstance(days, int) or days <= 0: + days = 30 + + cutoff_ts = int(datetime.now(tz=timezone.utc).timestamp()) - (days * 86400) + + db = await self._get_connection() + aggregated_count = 0 + + async with self._operation_lock: + # Ensure daily table exists + await db.execute(''' + CREATE TABLE IF NOT EXISTS token_time_series_daily ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + endpoint TEXT, + model TEXT, + date TEXT, + input_tokens INTEGER DEFAULT 0, + output_tokens INTEGER DEFAULT 0, + total_tokens INTEGER DEFAULT 0, + UNIQUE(endpoint, model, date) + ) + ''') + await db.commit() + + cursor = await db.execute(''' + SELECT endpoint, model, date(timestamp, 'unixepoch') as day, + SUM(input_tokens) as in_sum, + SUM(output_tokens) as out_sum, + SUM(total_tokens) as tot_sum + FROM token_time_series + WHERE timestamp < ? + GROUP BY endpoint, model, day + ''', (cutoff_ts,)) + rows = await cursor.fetchall() + + for row in rows: + endpoint, model, day, in_sum, out_sum, tot_sum = row + await db.execute(''' + INSERT INTO token_time_series_daily (endpoint, model, date, input_tokens, output_tokens, total_tokens) + VALUES (?, ?, ?, ?, ?, ?) + ON CONFLICT (endpoint, model, date) DO UPDATE SET + input_tokens = input_tokens + ?, + output_tokens = output_tokens + ?, + total_tokens = total_tokens + ? + ''', (endpoint, model, day, int(in_sum or 0), int(out_sum or 0), int(tot_sum or 0), + int(in_sum or 0), int(out_sum or 0), int(tot_sum or 0))) + aggregated_count += 1 + + # Trim old entries if requested + if trim_old: + await db.execute('DELETE FROM token_time_series WHERE timestamp < ?', (cutoff_ts,)) + + await db.commit() + + return aggregated_count diff --git a/entrypoint.sh b/entrypoint.sh old mode 100644 new mode 100755 index cee2f17..6682851 --- a/entrypoint.sh +++ b/entrypoint.sh @@ -1,6 +1,13 @@ #!/usr/bin/env sh set -e +# Create database directory if it doesn't exist +mkdir -p /app/data +chown -R www-data:www-data /app/data + +# Set database path environment variable +export NOMYO_ROUTER_DB_PATH="/app/data/token_counts.db" + CONFIG_PATH_ARG="" SHOW_HELP=0 diff --git a/requirements.txt b/requirements.txt index e296839..a0c7ab6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -36,3 +36,4 @@ typing-inspection==0.4.1 typing_extensions==4.14.1 uvicorn==0.38.0 yarl==1.20.1 +aiosqlite diff --git a/router.py b/router.py index dff8aa9..3b60770 100644 --- a/router.py +++ b/router.py @@ -2,11 +2,12 @@ title: NOMYO Router - an Ollama Proxy with Endpoint:Model aware routing author: alpha-nerd-nomyo author_url: https://github.com/nomyo-ai -version: 0.4 +version: 0.5 license: AGPL """ # ------------------------------------------------------------- -import orjson, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, datetime, random, base64, io +import orjson, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, random, base64, io +from datetime import datetime, timezone from pathlib import Path from typing import Dict, Set, List, Optional from urllib.parse import urlparse @@ -45,6 +46,18 @@ app_state = { "connector": None, } token_worker_task: asyncio.Task | None = None +flush_task: asyncio.Task | None = None + +# ------------------------------------------------------------------ +# Token Count Buffer (for write-behind pattern) +# ------------------------------------------------------------------ +# Structure: {endpoint: {model: (input_tokens, output_tokens)}} +token_buffer: dict[str, dict[str, tuple[int, int]]] = defaultdict(lambda: defaultdict(lambda: (0, 0))) +# Time series buffer with timestamp +time_series_buffer: list[dict[str, int | str]] = [] + +# Configuration for periodic flushing +FLUSH_INTERVAL = 10 # seconds # ------------------------------------------------------------- # 1. Configuration loader @@ -61,6 +74,9 @@ class Config(BaseSettings): api_keys: Dict[str, str] = Field(default_factory=dict) + # Database configuration + db_path: str = Field(default=os.getenv("NOMYO_ROUTER_DB_PATH", "token_counts.db")) + class Config: # Load from `config.yaml` first, then from env variables env_prefix = "NOMYO_ROUTER_" @@ -101,6 +117,8 @@ def _config_path_from_env() -> Path: return Path(candidate).expanduser() return Path("config.yaml") +from db import TokenDatabase + # Create the global config object – it will be overwritten on startup config = Config.from_yaml(_config_path_from_env()) @@ -130,6 +148,9 @@ token_usage_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict( usage_lock = asyncio.Lock() # protects access to usage_counts token_usage_lock = asyncio.Lock() +# Database instance +db: "TokenDatabase" = None + # ------------------------------------------------------------- # 4. Helperfunctions # ------------------------------------------------------------- @@ -181,7 +202,7 @@ def _format_connection_issue(url: str, error: Exception) -> str: ) return f"Error while contacting {url}: {error}" - + def is_ext_openai_endpoint(endpoint: str) -> bool: if "/v1" not in endpoint: return False @@ -199,9 +220,66 @@ def is_ext_openai_endpoint(endpoint: str) -> bool: async def token_worker() -> None: while True: endpoint, model, prompt, comp = await token_queue.get() + # Accumulate counts in memory buffer + token_buffer[endpoint][model] = ( + token_buffer[endpoint].get(model, (0, 0))[0] + prompt, + token_buffer[endpoint].get(model, (0, 0))[1] + comp + ) + + # Add to time series buffer with timestamp (UTC) + now = datetime.now(tz=timezone.utc) + timestamp = int(datetime(now.year, now.month, now.day, now.hour, now.minute, tzinfo=timezone.utc).timestamp()) + time_series_buffer.append({ + 'endpoint': endpoint, + 'model': model, + 'input_tokens': prompt, + 'output_tokens': comp, + 'total_tokens': prompt + comp, + 'timestamp': timestamp + }) + + # Update in-memory counts for immediate reporting async with token_usage_lock: token_usage_counts[endpoint][model] += (prompt + comp) - await publish_snapshot() + await publish_snapshot() + +async def flush_buffer() -> None: + """Periodically flush accumulated token counts to the database.""" + while True: + await asyncio.sleep(FLUSH_INTERVAL) + + # Flush token counts + if token_buffer: + await db.update_batched_counts(token_buffer) + token_buffer.clear() + + # Flush time series entries + if time_series_buffer: + await db.add_batched_time_series(time_series_buffer) + time_series_buffer.clear() + +async def flush_remaining_buffers() -> None: + """ + Flush any in-memory buffers to the database on shutdown. + This is designed to be safely invoked during shutdown and should not raise. + """ + try: + flushed_entries = 0 + if token_buffer: + await db.update_batched_counts(token_buffer) + flushed_entries += sum(len(v) for v in token_buffer.values()) + token_buffer.clear() + if time_series_buffer: + await db.add_batched_time_series(time_series_buffer) + flushed_entries += len(time_series_buffer) + time_series_buffer.clear() + if flushed_entries: + print(f"[shutdown] Flushed {flushed_entries} in-memory entries to DB on shutdown.") + else: + print("[shutdown] No in-memory entries to flush on shutdown.") + except Exception as e: + # Do not raise during shutdown – log and continue teardown + print(f"[shutdown] Error flushing remaining buffers: {e}") class fetch: async def available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]: @@ -366,7 +444,7 @@ async def decrement_usage(endpoint: str, model: str) -> None: def iso8601_ns(): ns = time.time_ns() sec, ns_rem = divmod(ns, 1_000_000_000) - dt = datetime.datetime.fromtimestamp(sec, tz=datetime.timezone.utc) + dt = datetime.fromtimestamp(sec, tz=timezone.utc) return ( f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d}T" f"{dt.hour:02d}:{dt.minute:02d}:{dt.second:02d}." @@ -628,7 +706,6 @@ async def choose_endpoint(model: str) -> str: f"None of the configured endpoints ({', '.join(config.endpoints)}) " f"advertise the model '{model}'." ) - # 3️⃣ Among the candidates, find those that have the model *loaded* # (concurrently, but only for the filtered list) load_tasks = [fetch.loaded_models(ep) for ep in candidate_endpoints] @@ -743,7 +820,8 @@ async def proxy(request: Request): chunk = rechunk.openai_completion2ollama(chunk, stream, start_ts) prompt_tok = chunk.prompt_eval_count or 0 comp_tok = chunk.eval_count or 0 - await token_queue.put((endpoint, model, prompt_tok, comp_tok)) + if prompt_tok != 0 or comp_tok != 0: + await token_queue.put((endpoint, model, prompt_tok, comp_tok)) if hasattr(chunk, "model_dump_json"): json_line = chunk.model_dump_json() else: @@ -757,7 +835,8 @@ async def proxy(request: Request): response = async_gen.model_dump_json() prompt_tok = async_gen.prompt_eval_count or 0 comp_tok = async_gen.eval_count or 0 - await token_queue.put((endpoint, model, prompt_tok, comp_tok)) + if prompt_tok != 0 or comp_tok != 0: + await token_queue.put((endpoint, model, prompt_tok, comp_tok)) json_line = ( response if hasattr(async_gen, "model_dump_json") @@ -859,7 +938,8 @@ async def chat_proxy(request: Request): # `chunk` can be a dict or a pydantic model – dump to JSON safely prompt_tok = chunk.prompt_eval_count or 0 comp_tok = chunk.eval_count or 0 - await token_queue.put((endpoint, model, prompt_tok, comp_tok)) + if prompt_tok != 0 or comp_tok != 0: + await token_queue.put((endpoint, model, prompt_tok, comp_tok)) if hasattr(chunk, "model_dump_json"): json_line = chunk.model_dump_json() else: @@ -873,7 +953,8 @@ async def chat_proxy(request: Request): response = async_gen.model_dump_json() prompt_tok = async_gen.prompt_eval_count or 0 comp_tok = async_gen.eval_count or 0 - await token_queue.put((endpoint, model, prompt_tok, comp_tok)) + if prompt_tok != 0 or comp_tok != 0: + await token_queue.put((endpoint, model, prompt_tok, comp_tok)) json_line = ( response if hasattr(async_gen, "model_dump_json") @@ -1086,7 +1167,7 @@ async def show_proxy(request: Request, model: Optional[str] = None): if not model: payload = orjson.loads(body_bytes.decode("utf-8")) model = payload.get("model") - + if not model: raise HTTPException( status_code=400, detail="Missing required field 'model'" @@ -1105,6 +1186,97 @@ async def show_proxy(request: Request, model: Optional[str] = None): # 4. Return ShowResponse return show +# ------------------------------------------------------------- +@app.get("/api/token_counts") +async def token_counts_proxy(): + breakdown = [] + total = 0 + async for entry in db.load_token_counts(): + total += entry['total_tokens'] + breakdown.append({ + "endpoint": entry["endpoint"], + "model": entry["model"], + "input_tokens": entry["input_tokens"], + "output_tokens": entry["output_tokens"], + "total_tokens": entry["total_tokens"], + }) + return {"total_tokens": total, "breakdown": breakdown} + +@app.post("/api/aggregate_time_series_days") +async def aggregate_time_series_days_proxy(request: Request): + """ + Aggregate time_series entries older than days into daily aggregates by endpoint/model/date. + """ + try: + body_bytes = await request.body() + if not body_bytes: + days = 30 + trim_old = False + else: + payload = orjson.loads(body_bytes.decode("utf-8")) + days = int(payload.get("days", 30)) + trim_old = bool(payload.get("trim_old", False)) + except Exception: + days = 30 + trim_old = False + aggregated = await db.aggregate_time_series_older_than(days, trim_old=trim_old) + return {"status": "ok", "days": days, "trim_old": trim_old, "aggregated_groups": aggregated} + +# 12. API route – Stats +# ------------------------------------------------------------- +@app.post("/api/stats") +async def stats_proxy(request: Request, model: Optional[str] = None): + """ + Return token usage statistics for a specific model. + """ + try: + body_bytes = await request.body() + + if not model: + payload = orjson.loads(body_bytes.decode("utf-8")) + model = payload.get("model") + + if not model: + raise HTTPException( + status_code=400, detail="Missing required field 'model'" + ) + except orjson.JSONDecodeError as e: + raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e + + # Get token counts from database + token_data = await db.get_token_counts_for_model(model) + + if not token_data: + raise HTTPException( + status_code=404, detail="No token data found for this model" + ) + + # Get time series data for the last 30 days (43200 minutes = 30 days) + # Assuming entries are grouped by minute, 30 days = 43200 entries max + time_series = [] + endpoint_totals = defaultdict(int) # Track tokens per endpoint + + async for entry in db.get_latest_time_series(limit=50000): + if entry['model'] == model: + time_series.append({ + 'endpoint': entry['endpoint'], + 'timestamp': entry['timestamp'], + 'input_tokens': entry['input_tokens'], + 'output_tokens': entry['output_tokens'], + 'total_tokens': entry['total_tokens'] + }) + # Accumulate total tokens per endpoint + endpoint_totals[entry['endpoint']] += entry['total_tokens'] + + return { + 'model': model, + 'input_tokens': token_data['input_tokens'], + 'output_tokens': token_data['output_tokens'], + 'total_tokens': token_data['total_tokens'], + 'time_series': time_series, + 'endpoint_distribution': dict(endpoint_totals) + } + # ------------------------------------------------------------- # 12. API route – Copy # ------------------------------------------------------------- @@ -1482,7 +1654,7 @@ async def openai_chat_completions_proxy(request: Request): optional_params = { "tools": tools, "response_format": response_format, - "stream_options": stream_options, + "stream_options": stream_options or {"include_usage": True }, "max_completion_tokens": max_completion_tokens, "max_tokens": max_tokens, "temperature": temperature, @@ -1524,13 +1696,23 @@ async def openai_chat_completions_proxy(request: Request): if hasattr(chunk, "model_dump_json") else orjson.dumps(chunk) ) - if chunk.choices[0].delta.content is not None: - yield f"data: {data}\n\n".encode("utf-8") + if chunk.choices: + if chunk.choices[0].delta.content is not None: + yield f"data: {data}\n\n".encode("utf-8") + if chunk.usage is not None: + prompt_tok = chunk.usage.prompt_tokens or 0 + comp_tok = chunk.usage.completion_tokens or 0 + if prompt_tok != 0 or comp_tok != 0: + if not is_ext_openai_endpoint(endpoint): + if not ":" in model: + local_model = model+":latest" + await token_queue.put((endpoint, local_model, prompt_tok, comp_tok)) yield b"data: [DONE]\n\n" else: prompt_tok = async_gen.usage.prompt_tokens or 0 comp_tok = async_gen.usage.completion_tokens or 0 - await token_queue.put((endpoint, model, prompt_tok, comp_tok)) + if prompt_tok != 0 or comp_tok != 0: + await token_queue.put((endpoint, model, prompt_tok, comp_tok)) json_line = ( async_gen.model_dump_json() if hasattr(async_gen, "model_dump_json") @@ -1591,7 +1773,7 @@ async def openai_completions_proxy(request: Request): "seed": seed, "stop": stop, "stream": stream, - "stream_options": stream_options, + "stream_options": stream_options or {"include_usage": True }, "temperature": temperature, "top_p": top_p, "max_tokens": max_tokens, @@ -1619,7 +1801,7 @@ async def openai_completions_proxy(request: Request): oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys[endpoint]) # 3. Async generator that streams completions data and decrements the counter - async def stream_ocompletions_response(): + async def stream_ocompletions_response(model=model): try: # The chat method returns a generator of dicts (or GenerateResponse) async_gen = await oclient.completions.create(**params) @@ -1630,13 +1812,24 @@ async def openai_completions_proxy(request: Request): if hasattr(chunk, "model_dump_json") else orjson.dumps(chunk) ) - yield f"data: {data}\n\n".encode("utf-8") + if chunk.choices: + if chunk.choices[0].finish_reason == None: + yield f"data: {data}\n\n".encode("utf-8") + if chunk.usage is not None: + prompt_tok = chunk.usage.prompt_tokens or 0 + comp_tok = chunk.usage.completion_tokens or 0 + if prompt_tok != 0 or comp_tok != 0: + if not is_ext_openai_endpoint(endpoint): + if not ":" in model: + local_model = model+":latest" + await token_queue.put((endpoint, local_model, prompt_tok, comp_tok)) # Final DONE event yield b"data: [DONE]\n\n" else: prompt_tok = async_gen.usage.prompt_tokens or 0 comp_tok = async_gen.usage.completion_tokens or 0 - await token_queue.put((endpoint, model, prompt_tok, comp_tok)) + if prompt_tok != 0 or comp_tok != 0: + await token_queue.put((endpoint, model, prompt_tok, comp_tok)) json_line = ( async_gen.model_dump_json() if hasattr(async_gen, "model_dump_json") @@ -1772,7 +1965,7 @@ async def usage_stream(request: Request): # ------------------------------------------------------------- @app.on_event("startup") async def startup_event() -> None: - global config + global config, db # Load YAML config (or use defaults if not present) config_path = _config_path_from_env() config = Config.from_yaml(config_path) @@ -1787,7 +1980,21 @@ async def startup_event() -> None: f"No configuration file found at {config_path}. " "Falling back to default settings." ) - + + # Initialize database + db = TokenDatabase(config.db_path) + await db.init_db() + + # Load existing token counts from database + async for count_entry in db.load_token_counts(): + endpoint = count_entry['endpoint'] + model = count_entry['model'] + input_tokens = count_entry['input_tokens'] + output_tokens = count_entry['output_tokens'] + total_tokens = count_entry['total_tokens'] + + token_usage_counts[endpoint][model] = total_tokens + ssl_context = ssl.create_default_context() connector = aiohttp.TCPConnector(limit=0, limit_per_host=512, ssl=ssl_context) timeout = aiohttp.ClientTimeout(total=60, connect=15, sock_read=120, sock_connect=15) @@ -1796,10 +2003,14 @@ async def startup_event() -> None: app_state["connector"] = connector app_state["session"] = session token_worker_task = asyncio.create_task(token_worker()) + flush_task = asyncio.create_task(flush_buffer()) @app.on_event("shutdown") async def shutdown_event() -> None: await close_all_sse_queues() + await flush_remaining_buffers() await app_state["session"].close() if token_worker_task is not None: - token_worker_task.cancel() \ No newline at end of file + token_worker_task.cancel() + if flush_task is not None: + flush_task.cancel() diff --git a/static/index.html b/static/index.html index 6f11ccb..0d22721 100644 --- a/static/index.html +++ b/static/index.html @@ -3,6 +3,7 @@ NOMYO Router Dashboard + -

Router Dashboard

+
+

Router Dashboard

+ +