diff --git a/db.py b/db.py
index 0816c17..9a96ca2 100644
--- a/db.py
+++ b/db.py
@@ -132,3 +132,17 @@ class TokenDatabase:
                     'total_tokens': row[4],
                     'timestamp': row[5]
                 }
+
+    async def get_token_counts_for_model(self, model):
+        """Get token counts for a specific model."""
+        async with aiosqlite.connect(self.db_path) as db:
+            async with db.execute('SELECT endpoint, model, input_tokens, output_tokens, total_tokens FROM token_counts WHERE model = ?', (model,)) as cursor:
+                async for row in cursor:
+                    return {
+                        'endpoint': row[0],
+                        'model': row[1],
+                        'input_tokens': row[2],
+                        'output_tokens': row[3],
+                        'total_tokens': row[4]
+                    }
+        return None
diff --git a/router.py b/router.py
index 1b0af1c..94831f9 100644
--- a/router.py
+++ b/router.py
@@ -797,7 +797,8 @@ async def proxy(request: Request):
                 chunk = rechunk.openai_completion2ollama(chunk, stream, start_ts)
                 prompt_tok = chunk.prompt_eval_count or 0
                 comp_tok = chunk.eval_count or 0
-                await token_queue.put((endpoint, model, prompt_tok, comp_tok))
+                if prompt_tok != 0 or comp_tok != 0:
+                    await token_queue.put((endpoint, model, prompt_tok, comp_tok))
                 if hasattr(chunk, "model_dump_json"):
                     json_line = chunk.model_dump_json()
                 else:
@@ -811,7 +812,8 @@ async def proxy(request: Request):
             response = async_gen.model_dump_json()
             prompt_tok = async_gen.prompt_eval_count or 0
             comp_tok = async_gen.eval_count or 0
-            await token_queue.put((endpoint, model, prompt_tok, comp_tok))
+            if prompt_tok != 0 or comp_tok != 0:
+                await token_queue.put((endpoint, model, prompt_tok, comp_tok))
             json_line = (
                 response
                 if hasattr(async_gen, "model_dump_json")
@@ -913,7 +915,8 @@ async def chat_proxy(request: Request):
                 # `chunk` can be a dict or a pydantic model – dump to JSON safely
                 prompt_tok = chunk.prompt_eval_count or 0
                 comp_tok = chunk.eval_count or 0
-                await token_queue.put((endpoint, model, prompt_tok, comp_tok))
+                if prompt_tok != 0 or comp_tok != 0:
+                    await token_queue.put((endpoint, model, prompt_tok, comp_tok))
                 if hasattr(chunk, "model_dump_json"):
                     json_line = chunk.model_dump_json()
                 else:
@@ -927,7 +930,8 @@ async def chat_proxy(request: Request):
             response = async_gen.model_dump_json()
             prompt_tok = async_gen.prompt_eval_count or 0
             comp_tok = async_gen.eval_count or 0
-            await token_queue.put((endpoint, model, prompt_tok, comp_tok))
+            if prompt_tok != 0 or comp_tok != 0:
+                await token_queue.put((endpoint, model, prompt_tok, comp_tok))
             json_line = (
                 response
                 if hasattr(async_gen, "model_dump_json")
@@ -1140,7 +1144,7 @@ async def show_proxy(request: Request, model: Optional[str] = None):
         if not model:
             payload = orjson.loads(body_bytes.decode("utf-8"))
             model = payload.get("model")
-    
+
         if not model:
             raise HTTPException(
                 status_code=400, detail="Missing required field 'model'"
             )
@@ -1159,6 +1163,55 @@ async def show_proxy(request: Request, model: Optional[str] = None):
     # 4. Return ShowResponse
     return show
+# -------------------------------------------------------------
+# 12. API route – Stats
+# -------------------------------------------------------------
+@app.post("/api/stats")
+async def stats_proxy(request: Request, model: Optional[str] = None):
+    """
+    Return token usage statistics for a specific model.
+    """
+    try:
+        body_bytes = await request.body()
+
+        if not model:
+            payload = orjson.loads(body_bytes.decode("utf-8"))
+            model = payload.get("model")
+
+        if not model:
+            raise HTTPException(
+                status_code=400, detail="Missing required field 'model'"
+            )
+    except orjson.JSONDecodeError as e:
+        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
+
+    # Get token counts from database
+    token_data = await db.get_token_counts_for_model(model)
+
+    if not token_data:
+        raise HTTPException(
+            status_code=404, detail="No token data found for this model"
+        )
+
+    # Get time series data
+    time_series = []
+    async for entry in db.get_latest_time_series(limit=10):
+        if entry['model'] == model:
+            time_series.append({
+                'timestamp': entry['timestamp'],
+                'input_tokens': entry['input_tokens'],
+                'output_tokens': entry['output_tokens'],
+                'total_tokens': entry['total_tokens']
+            })
+
+    return {
+        'model': model,
+        'input_tokens': token_data['input_tokens'],
+        'output_tokens': token_data['output_tokens'],
+        'total_tokens': token_data['total_tokens'],
+        'time_series': time_series
+    }
+
 # -------------------------------------------------------------
 # 12. API route – Copy
 # -------------------------------------------------------------
@@ -1584,7 +1637,8 @@ async def openai_chat_completions_proxy(request: Request):
             else:
                 prompt_tok = async_gen.usage.prompt_tokens or 0
                 comp_tok = async_gen.usage.completion_tokens or 0
-            await token_queue.put((endpoint, model, prompt_tok, comp_tok))
+            if prompt_tok != 0 or comp_tok != 0:
+                await token_queue.put((endpoint, model, prompt_tok, comp_tok))
             json_line = (
                 async_gen.model_dump_json()
                 if hasattr(async_gen, "model_dump_json")
@@ -1690,7 +1744,8 @@ async def openai_completions_proxy(request: Request):
             else:
                 prompt_tok = async_gen.usage.prompt_tokens or 0
                 comp_tok = async_gen.usage.completion_tokens or 0
-            await token_queue.put((endpoint, model, prompt_tok, comp_tok))
+            if prompt_tok != 0 or comp_tok != 0:
+                await token_queue.put((endpoint, model, prompt_tok, comp_tok))
             json_line = (
                 async_gen.model_dump_json()
                 if hasattr(async_gen, "model_dump_json")
diff --git a/static/index.html b/static/index.html
index b4f5466..0ecbf7d 100644
--- a/static/index.html
+++ b/static/index.html
@@ -447,7 +447,7 @@
                         ? `${digest.slice(0, 12)}...${digest.slice(-12)}`
                         : digest;
                     return `
-                        <td>${m.name}</td>
+                        <td>${m.name} <a href="#" class="stats-link" data-model="${m.name}">stats</a></td>
                         <td>${m.details.parameter_size}</td>
                         <td>${m.details.quantization_level}</td>
                         <td>${m.context_length}</td>
@@ -636,6 +636,56 @@
                     modal.style.display = "none";
                 }
             });
+
+            /* stats logic */
+            document.body.addEventListener("click", async (e) => {
+                if (!e.target.matches(".stats-link")) return;
+                e.preventDefault();
+                const model = e.target.dataset.model;
+                try {
+                    const resp = await fetch(
+                        `/api/stats?model=${encodeURIComponent(model)}`,
+                        { method: "POST" },
+                    );
+                    if (!resp.ok)
+                        throw new Error(`Status ${resp.status}`);
+                    const data = await resp.json();
+                    const content = document.getElementById("stats-content");
+                    content.innerHTML = `
+                        <div><strong>Token Usage</strong></div>
+                        <div>Input tokens: ${data.input_tokens}</div>
+                        <div>Output tokens: ${data.output_tokens}</div>
+                        <div>Total tokens: ${data.total_tokens}</div>
+                        <div><strong>Usage Over Time</strong></div>
+                        <div>
+                            ${data.time_series.length > 0 ?
+                                data.time_series.map(ts => `
+                                    <div>
+                                        ${new Date(ts.timestamp * 1000).toLocaleString()}
+                                        <div>Input: ${ts.input_tokens}, Output: ${ts.output_tokens}, Total: ${ts.total_tokens}</div>
+                                    </div>
+                                `).join('') :
+                                '<div>No time series data available</div>'
+                            }
+                        </div>
+                    `;
+                    document.getElementById("stats-modal").style.display = "flex";
+                } catch (err) {
+                    console.error(err);
+                    alert(`Could not load model stats: ${err.message}`);
+                }
+            });
+
+            /* stats modal close */
+            const statsModal = document.getElementById("stats-modal");
+            statsModal.addEventListener("click", (e) => {
+                if (
+                    e.target === statsModal ||
+                    e.target.matches(".close-btn")
+                ) {
+                    statsModal.style.display = "none";
+                }
+            });
        });
@@ -646,5 +696,15 @@

             
         
+        <div id="stats-modal" class="modal">
+            <div class="modal-content">
+                <span class="close-btn">&times;</span>
+                <div id="stats-content"></div>
+            </div>
+        </div>
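
A quick way to sanity-check the new /api/stats route is a small client script. The sketch below is illustrative only and not part of this diff: it assumes the router is reachable at http://localhost:8000, uses the requests library purely for convenience, and the model name "llama3" is just an example.

    # Illustrative client for the /api/stats route added above.
    # Assumptions: router at http://localhost:8000 and a model named "llama3"
    # already recorded in token_counts; neither is defined by this diff.
    import requests

    BASE_URL = "http://localhost:8000"


    def fetch_model_stats(model: str) -> dict:
        # The route accepts the model either as a ?model= query parameter
        # (as the front end does) or as {"model": ...} in the JSON body.
        resp = requests.post(f"{BASE_URL}/api/stats", json={"model": model}, timeout=10)
        resp.raise_for_status()  # 404 means no token data has been recorded yet
        return resp.json()


    if __name__ == "__main__":
        stats = fetch_model_stats("llama3")
        print(f"total tokens: {stats['total_tokens']}")
        print(f"time-series entries: {len(stats['time_series'])}")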