diff --git a/README.md b/README.md
index 37e1686..6c3f402 100644
--- a/README.md
+++ b/README.md
@@ -41,7 +41,7 @@ source .venv/router/bin/activate
 pip3 install -r requirements.txt
 ```
 
-on the shell do:
+[Optional] In your shell, set your OpenAI API key:
 
 ```
 export OPENAI_KEY=YOUR_SECRET_API_KEY
diff --git a/router.py b/router.py
index 7cc24ac..2deb639 100644
--- a/router.py
+++ b/router.py
@@ -123,6 +123,7 @@ default_headers={
 # 3. Global state: per‑endpoint per‑model active connection counters
 # -------------------------------------------------------------
 usage_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
+token_usage_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
 usage_lock = asyncio.Lock()  # protects access to usage_counts
 
 # -------------------------------------------------------------
@@ -191,6 +192,13 @@ def is_ext_openai_endpoint(endpoint: str) -> bool:
 
     return True  # It's an external OpenAI endpoint
 
+def record_token_usage(endpoint: str, model: str, prompt: int = 0, completion: int = 0) -> None:
+    async def _record():
+        async with usage_lock:  # reuse the same lock that protects usage_counts
+            token_usage_counts[endpoint][model] += (prompt + completion)
+        await publish_snapshot()  # broadcast outside the lock: publish_snapshot acquires usage_lock itself
+    asyncio.create_task(_record())
+
 class fetch:
     async def available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
         """
@@ -336,15 +344,14 @@ async def decrement_usage(endpoint: str, model: str) -> None:
     await publish_snapshot()
 
 def iso8601_ns():
-    ns_since_epoch = time.time_ns()
-    dt = datetime.datetime.fromtimestamp(
-        ns_since_epoch / 1_000_000_000,  # seconds
-        tz=datetime.timezone.utc
+    ns = time.time_ns()
+    sec, ns_rem = divmod(ns, 1_000_000_000)
+    dt = datetime.datetime.fromtimestamp(sec, tz=datetime.timezone.utc)
+    return (
+        f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d}T"
+        f"{dt.hour:02d}:{dt.minute:02d}:{dt.second:02d}."
+        f"{ns_rem:09d}Z"
     )
-    iso8601_with_ns = (
-        dt.strftime("%Y-%m-%dT%H:%M:%S.") + f"{ns_since_epoch % 1_000_000_000:09d}Z"
-    )
-    return iso8601_with_ns
 
 def is_base64(image_string):
     try:
@@ -507,7 +514,9 @@ class rechunk:
 # ------------------------------------------------------------------
 async def publish_snapshot():
     async with usage_lock:
-        snapshot = json.dumps({"usage_counts": usage_counts}, sort_keys=True)
+        snapshot = json.dumps({"usage_counts": usage_counts,
+                               "token_usage_counts": token_usage_counts,
+                               }, sort_keys=True)
     async with _subscribers_lock:
         for q in _subscribers:
             # If the queue is full, drop the message to avoid back‑pressure.
@@ -710,6 +719,9 @@ async def proxy(request: Request):
             async for chunk in async_gen:
                 if is_openai_endpoint:
                     chunk = rechunk.openai_completion2ollama(chunk, stream, start_ts)
+                prompt_tok = getattr(chunk, "prompt_eval_count", None) or 0  # chunk may be a plain dict
+                comp_tok = getattr(chunk, "eval_count", None) or 0
+                record_token_usage(endpoint, model, prompt_tok, comp_tok)
                 if hasattr(chunk, "model_dump_json"):
                     json_line = chunk.model_dump_json()
                 else:
@@ -721,6 +733,9 @@
                 response = response.model_dump_json()
             else:
                 response = async_gen.model_dump_json()
+            prompt_tok = getattr(async_gen, "prompt_eval_count", None) or 0
+            comp_tok = getattr(async_gen, "eval_count", None) or 0
+            record_token_usage(endpoint, model, prompt_tok, comp_tok)
             json_line = (
                 response
                 if hasattr(async_gen, "model_dump_json")
@@ -791,7 +806,7 @@ async def chat_proxy(request: Request):
     optional_params = {
         "tools": tools,
         "stream": stream,
-        "stream_options": {"include_usage": True} if stream is not None else None,
+        "stream_options": {"include_usage": True} if stream else None,
         "max_tokens": options.get("num_predict") if options and "num_predict" in options else None,
         "frequency_penalty": options.get("frequency_penalty") if options and "frequency_penalty" in options else None,
         "presence_penalty": options.get("presence_penalty") if options and "presence_penalty" in options else None,
@@ -820,6 +835,9 @@
                 if is_openai_endpoint:
                     chunk = rechunk.openai_chat_completion2ollama(chunk, stream, start_ts)
                 # `chunk` can be a dict or a pydantic model – dump to JSON safely
+                prompt_tok = getattr(chunk, "prompt_eval_count", None) or 0  # chunk may be a plain dict
+                comp_tok = getattr(chunk, "eval_count", None) or 0
+                record_token_usage(endpoint, model, prompt_tok, comp_tok)
                 if hasattr(chunk, "model_dump_json"):
                     json_line = chunk.model_dump_json()
                 else:
@@ -831,6 +849,9 @@
                 response = response.model_dump_json()
             else:
                 response = async_gen.model_dump_json()
+            prompt_tok = getattr(async_gen, "prompt_eval_count", None) or 0
+            comp_tok = getattr(async_gen, "eval_count", None) or 0
+            record_token_usage(endpoint, model, prompt_tok, comp_tok)
             json_line = (
                 response
                 if hasattr(async_gen, "model_dump_json")
@@ -1315,7 +1336,8 @@ async def usage_proxy(request: Request):
     Return a snapshot of the usage counter for each endpoint.
     Useful for debugging / monitoring.
     """
-    return {"usage_counts": usage_counts}
+    return {"usage_counts": usage_counts,
+            "token_usage_counts": token_usage_counts}
 
 # -------------------------------------------------------------
 # 20. Proxy config route – for monitoring and frontent usage
@@ -1485,6 +1507,9 @@ async def openai_chat_completions_proxy(request: Request):
                 yield f"data: {data}\n\n".encode("utf-8")
             yield b"data: [DONE]\n\n"
         else:
+            usage = getattr(async_gen, "usage", None)
+            if usage:  # guard against responses without a usage block
+                record_token_usage(endpoint, payload.get("model"), usage.prompt_tokens or 0, usage.completion_tokens or 0)
             json_line = (
                 async_gen.model_dump_json()
                 if hasattr(async_gen, "model_dump_json")
@@ -1588,6 +1613,9 @@ async def openai_completions_proxy(request: Request):
             # Final DONE event
             yield b"data: [DONE]\n\n"
         else:
+            usage = getattr(async_gen, "usage", None)
+            if usage:  # guard against responses without a usage block
+                record_token_usage(endpoint, payload.get("model"), usage.prompt_tokens or 0, usage.completion_tokens or 0)
             json_line = (
                 async_gen.model_dump_json()
                 if hasattr(async_gen, "model_dump_json")
diff --git a/static/index.html b/static/index.html
index ef6119b..6f11ccb 100644
--- a/static/index.html
+++ b/static/index.html
@@ -267,11 +267,12 @@
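Reviewer note: the snippet below is a minimal sketch for sanity-checking the new `token_usage_counts` counters once this patch is applied. The host, port, and route path are assumptions; the decorator that mounts `usage_proxy()` is outside this diff, so adjust `USAGE_URL` to the real route.

```python
# Minimal sanity check for the new token counters (a sketch, not part of the patch).
# ASSUMPTION: USAGE_URL is a placeholder; the decorator that mounts usage_proxy()
# is not shown in this diff, so adjust host, port, and path to your deployment.
import json
import urllib.request

USAGE_URL = "http://localhost:8000/api/usage"  # hypothetical route

with urllib.request.urlopen(USAGE_URL) as resp:
    snapshot = json.load(resp)

# "usage_counts" holds active-connection counts; "token_usage_counts" now holds
# cumulative prompt + completion tokens per endpoint and model.
for endpoint, models in snapshot.get("token_usage_counts", {}).items():
    for model, total_tokens in models.items():
        print(f"{endpoint} / {model}: {total_tokens} tokens")
```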