From e7cd8d4d68827d5a0a5427fc21795da62101bfbe Mon Sep 17 00:00:00 2001
From: alpha nerd
Date: Tue, 7 Apr 2026 15:30:52 +0200
Subject: [PATCH] fix: usage locks now release before the subscriber queue
 awaits

---
 requirements.txt |  2 +-
 router.py        | 32 +++++++++++++++++++-------------
 2 files changed, 20 insertions(+), 14 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 222dfc8..2db1ba4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
 aiohappyeyeballs==2.6.1
-aiohttp==3.13.3
+aiohttp==3.13.4
 aiosignal==1.4.0
 annotated-types==0.7.0
 anyio==4.10.0
diff --git a/router.py b/router.py
index 94cf1c2..c87c5ca 100644
--- a/router.py
+++ b/router.py
@@ -590,7 +590,8 @@ async def token_worker() -> None:
             # Update in-memory counts for immediate reporting
             async with token_usage_lock:
                 token_usage_counts[endpoint][model] += (prompt + comp)
-                await publish_snapshot()
+                snapshot = _capture_snapshot()
+            await _distribute_snapshot(snapshot)
         except asyncio.CancelledError:
             # Gracefully handle task cancellation during shutdown
             print("[token_worker] Task cancelled, processing remaining queue items...")
@@ -617,7 +618,8 @@ async def token_worker() -> None:
                     })
                     async with token_usage_lock:
                         token_usage_counts[endpoint][model] += (prompt + comp)
-                        await publish_snapshot()
+                        snapshot = _capture_snapshot()
+                    await _distribute_snapshot(snapshot)
                 except asyncio.QueueEmpty:
                     break
             print("[token_worker] Task cancelled, remaining items processed.")
@@ -1033,7 +1035,8 @@ def dedupe_on_keys(dicts, key_fields):
 async def increment_usage(endpoint: str, model: str) -> None:
     async with usage_lock:
         usage_counts[endpoint][model] += 1
-        await publish_snapshot()
+        snapshot = _capture_snapshot()
+    await _distribute_snapshot(snapshot)
 
 async def decrement_usage(endpoint: str, model: str) -> None:
     async with usage_lock:
@@ -1046,7 +1049,8 @@ async def decrement_usage(endpoint: str, model: str) -> None:
             usage_counts[endpoint].pop(model, None)
             #if not usage_counts[endpoint]:
             #    usage_counts.pop(endpoint, None)
-        await publish_snapshot()
+        snapshot = _capture_snapshot()
+    await _distribute_snapshot(snapshot)
 
 async def _make_chat_request(model: str, messages: list, tools=None, stream: bool = False, think: bool = False, format=None, options=None, keep_alive: str = None) -> ollama.ChatResponse:
     """
@@ -1580,18 +1584,17 @@ class rechunk:
 # ------------------------------------------------------------------
 # SSE Helpser
 # ------------------------------------------------------------------
-async def publish_snapshot():
-    # NOTE: This function assumes usage_lock OR token_usage_lock is already held by the caller
-    # Create a snapshot without acquiring the lock (caller must hold it)
-    snapshot = orjson.dumps({
-        "usage_counts": dict(usage_counts),  # Create a copy
+def _capture_snapshot() -> str:
+    """Capture current usage counts as a JSON string. Caller must hold at least one of usage_lock/token_usage_lock."""
+    return orjson.dumps({
+        "usage_counts": dict(usage_counts),
         "token_usage_counts": dict(token_usage_counts)
     }, option=orjson.OPT_SORT_KEYS).decode("utf-8")
 
-    # Distribute the snapshot (no lock needed here since we have a copy)
+async def _distribute_snapshot(snapshot: str) -> None:
+    """Push a pre-captured snapshot to all SSE subscribers. Must be called outside any usage lock."""
     async with _subscribers_lock:
         for q in _subscribers:
-            # If the queue is full, drop the message to avoid back‑pressure.
             if q.full():
                 try:
                     await q.get()
@@ -1736,10 +1739,13 @@ async def choose_endpoint(model: str, reserve: bool = True) -> tuple[str, str]:
         selected = min(candidate_endpoints, key=tracking_usage)
 
     tracking_model = get_tracking_model(selected, model)
+    snapshot = None
     if reserve:
         usage_counts[selected][tracking_model] += 1
-        await publish_snapshot()
-    return selected, tracking_model
+        snapshot = _capture_snapshot()
+    if snapshot is not None:
+        await _distribute_snapshot(snapshot)
+    return selected, tracking_model
 
 # -------------------------------------------------------------
 # 6. API route – Generate
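
Note for reviewers (not part of the applied diff): every hunk above follows the same capture-then-distribute pattern, so a compact sketch may be easier to read than the scattered call sites. The scaffolding below only mirrors the shapes used in router.py (usage_counts, usage_lock, _subscribers, _subscribers_lock, increment_usage); the full-queue handling is simplified to a non-blocking drop, and the endpoint/model names in the demo are placeholders. It is illustrative, not the file itself.

    # Sketch of the capture-then-distribute pattern introduced by this patch.
    # Counts are serialized while the usage lock is held; the await on each
    # subscriber queue happens only after the lock has been released.
    import asyncio
    from collections import defaultdict

    import orjson

    usage_counts: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int))
    usage_lock = asyncio.Lock()
    _subscribers: list[asyncio.Queue] = []
    _subscribers_lock = asyncio.Lock()

    def _capture_snapshot() -> str:
        # Synchronous work only, so it is safe to call while holding usage_lock.
        plain = {k: dict(v) for k, v in usage_counts.items()}
        return orjson.dumps({"usage_counts": plain},
                            option=orjson.OPT_SORT_KEYS).decode("utf-8")

    async def _distribute_snapshot(snapshot: str) -> None:
        # No usage lock is held here, so a slow SSE consumer cannot stall counters.
        async with _subscribers_lock:
            for q in _subscribers:
                if q.full():
                    try:
                        q.get_nowait()  # drop the oldest snapshot rather than block
                    except asyncio.QueueEmpty:
                        pass  # another consumer drained it concurrently
                await q.put(snapshot)

    async def increment_usage(endpoint: str, model: str) -> None:
        async with usage_lock:
            usage_counts[endpoint][model] += 1
            snapshot = _capture_snapshot()    # capture under the lock ...
        await _distribute_snapshot(snapshot)  # ... publish after releasing it

    async def _demo() -> None:
        q: asyncio.Queue = asyncio.Queue(maxsize=2)
        _subscribers.append(q)
        await increment_usage("endpoint-a", "model-x")  # placeholder names
        print(await q.get())

    if __name__ == "__main__":
        asyncio.run(_demo())

The point of the split is that _capture_snapshot does only synchronous serialization and can run while the lock is held, while the only await on subscriber queues lives in _distribute_snapshot and runs after the lock is released, which is what the subject line of this patch describes.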