diff --git a/requirements.txt b/requirements.txt index 8ef45ff..4ffd391 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,6 +10,7 @@ click==8.2.1 distro==1.9.0 exceptiongroup==1.3.0 fastapi==0.116.1 +fastapi-sse==1.1.1 frozenlist==1.7.0 h11==0.16.0 httpcore==1.0.9 diff --git a/router.py b/router.py index b73e4e6..9b83fb4 100644 --- a/router.py +++ b/router.py @@ -11,6 +11,7 @@ from httpx_aiohttp import AiohttpTransport from pathlib import Path from typing import Dict, Set, List, Optional from fastapi import FastAPI, Request, HTTPException +from fastapi_sse import sse_handler from fastapi.staticfiles import StaticFiles from starlette.responses import StreamingResponse, JSONResponse, Response, HTMLResponse, RedirectResponse from pydantic import Field @@ -26,6 +27,12 @@ _models_cache: dict[str, tuple[Set[str], float]] = {} # timeout expires, after which the endpoint will be queried again. _error_cache: dict[str, float] = {} +# ------------------------------------------------------------------ +# SSE Queues +# ------------------------------------------------------------------ +_subscribers: Set[asyncio.Queue] = set() +_subscribers_lock = asyncio.Lock() + # ------------------------------------------------------------- # 1. Configuration loader # ------------------------------------------------------------- @@ -77,6 +84,7 @@ config = Config() # 2. FastAPI application # ------------------------------------------------------------- app = FastAPI() +sse_handler.app = app # ------------------------------------------------------------- # 3. 
Global state: per‑endpoint per‑model active connection counters @@ -234,6 +242,7 @@ def dedupe_on_keys(dicts, key_fields): async def increment_usage(endpoint: str, model: str) -> None: async with usage_lock: usage_counts[endpoint][model] += 1 + await publish_snapshot() async def decrement_usage(endpoint: str, model: str) -> None: async with usage_lock: @@ -246,6 +255,41 @@ async def decrement_usage(endpoint: str, model: str) -> None: usage_counts[endpoint].pop(model, None) #if not usage_counts[endpoint]: # usage_counts.pop(endpoint, None) + await publish_snapshot() + +# ------------------------------------------------------------------ +# SSE Helpers +# ------------------------------------------------------------------ +async def publish_snapshot(): + snapshot = json.dumps({"usage_counts": usage_counts}) + async with _subscribers_lock: + for q in _subscribers: + # If the queue is full, drop the message to avoid back‑pressure. + if q.full(): + continue + await q.put(snapshot) + +# ------------------------------------------------------------------ +# Subscriber helpers +# ------------------------------------------------------------------ +async def subscribe() -> asyncio.Queue: + """ + Returns a new Queue that will receive every snapshot. + """ + q: asyncio.Queue = asyncio.Queue(maxsize=10) + async with _subscribers_lock: + _subscribers.add(q) + return q + +async def unsubscribe(q: asyncio.Queue): + async with _subscribers_lock: + _subscribers.discard(q) + +# ------------------------------------------------------------------ +# Convenience wrapper – returns the current snapshot (for the proxy) +# ------------------------------------------------------------------ +async def get_usage_counts() -> Dict: + return dict(usage_counts) # shallow copy # ------------------------------------------------------------- # 5. 
Endpoint selection logic (respecting the configurable limit) @@ -1272,7 +1316,33 @@ async def health_proxy(request: Request): return JSONResponse(content=response_payload, status_code=http_status) # ------------------------------------------------------------- -# 27. FastAPI startup event – load configuration +# 27. SSE route for usage broadcasts +# ------------------------------------------------------------- +@app.get("/api/usage-stream") +async def usage_stream(request: Request): + """ + A Server‑Sent‑Events stream that emits a JSON payload every time the + global `usage_counts` dictionary changes. + """ + async def event_generator(): + # The queue that receives *every* new snapshot + queue = await subscribe() + try: + while True: + # If the client disconnects, cancel the loop + if await request.is_disconnected(): + break + data = await queue.get() + # Send the data as a single SSE message + yield f"data: {data}\n\n" + finally: + # Clean‑up: unsubscribe from the broadcast channel + await unsubscribe(queue) + + return StreamingResponse(event_generator(), media_type="text/event-stream") + +# ------------------------------------------------------------- +# 28. 
FastAPI startup event – load configuration # ------------------------------------------------------------- @app.on_event("startup") async def startup_event() -> None: diff --git a/static/index.html b/static/index.html index a2945b4..c4e522a 100644 --- a/static/index.html +++ b/static/index.html @@ -253,7 +253,7 @@ async function loadPS(){ /* ---------- Usage Chart (stacked‑percentage) ---------- */ function getColor(seed){ const h = Math.abs(hashString(seed) % 360); - return `hsl(${h}, 70%, 50%)`; + return `hsl(${h}, 80%, 40%)`; } function hashString(str){ let hash = 0; @@ -263,28 +263,57 @@ function hashString(str){ } return Math.abs(hash); } -async function loadUsage(){ - try{ - const data = await fetchJSON('/api/usage'); +async function loadUsage() { + // Create the EventSource once and keep it around + const source = new EventSource('/api/usage-stream'); + + // ----------------------------------------------------------------- + // Helper that receives the payload and renders the chart + // ----------------------------------------------------------------- + const renderChart = (data) => { const chart = document.getElementById('usage-chart'); const usage = data.usage_counts || {}; let html = ''; - for (const [endpoint, models] of Object.entries(usage)){ - const total = Object.values(models).reduce((a,b)=>a+b,0); - html += `
${endpoint}
`; - for (const [model, count] of Object.entries(models)){ - const pct = total ? (count/total)*100 : 0; + for (const [endpoint, models] of Object.entries(usage)) { + const total = Object.values(models).reduce((a, b) => a + b, 0); + + html += `
+
${endpoint}
+
`; + + for (const [model, count] of Object.entries(models)) { + const pct = total ? (count / total) * 100 : 0; const width = pct.toFixed(2); const color = getColor(model); - html += `
${model} (${count})
`; + html += `
+ ${model} (${count}) +
`; } + html += `
`; } chart.innerHTML = html; - }catch(e){ - console.error('Failed to load usage counts', e); - } + }; + + // ----------------------------------------------------------------- + // Event handlers + // ----------------------------------------------------------------- + source.onmessage = (e) => { + try { + const payload = JSON.parse(e.data); // SSE sends plain text + renderChart(payload); + } catch (err) { + console.error('Failed to parse SSE payload', err); + } + }; + + source.onerror = (err) => { + console.error('SSE connection error. Retrying...', err); + // EventSource will automatically try to reconnect. + }; + window.addEventListener('beforeunload', () => source.close()); } /* ---------- Init ---------- */ @@ -294,7 +323,7 @@ window.addEventListener('load', ()=>{ loadPS(); loadUsage(); setInterval(loadPS, 60_000); - setInterval(loadUsage, 1_000); + setInterval(loadEndpoints, 300_000); });