fix: usage locks now release before the subscriber queue awaits

This commit is contained in:
Alpha Nerd 2026-04-07 15:30:52 +02:00
parent 2c87472483
commit e7cd8d4d68
Signed by: alpha-nerd
SSH key fingerprint: SHA256:QkkAgVoYi9TQ0UKPkiKSfnerZy2h4qhi3SVPXJmBN+M
2 changed files with 20 additions and 14 deletions

View file

@ -1,5 +1,5 @@
aiohappyeyeballs==2.6.1 aiohappyeyeballs==2.6.1
aiohttp==3.13.3 aiohttp==3.13.4
aiosignal==1.4.0 aiosignal==1.4.0
annotated-types==0.7.0 annotated-types==0.7.0
anyio==4.10.0 anyio==4.10.0

View file

@ -590,7 +590,8 @@ async def token_worker() -> None:
# Update in-memory counts for immediate reporting # Update in-memory counts for immediate reporting
async with token_usage_lock: async with token_usage_lock:
token_usage_counts[endpoint][model] += (prompt + comp) token_usage_counts[endpoint][model] += (prompt + comp)
await publish_snapshot() snapshot = _capture_snapshot()
await _distribute_snapshot(snapshot)
except asyncio.CancelledError: except asyncio.CancelledError:
# Gracefully handle task cancellation during shutdown # Gracefully handle task cancellation during shutdown
print("[token_worker] Task cancelled, processing remaining queue items...") print("[token_worker] Task cancelled, processing remaining queue items...")
@ -617,7 +618,8 @@ async def token_worker() -> None:
}) })
async with token_usage_lock: async with token_usage_lock:
token_usage_counts[endpoint][model] += (prompt + comp) token_usage_counts[endpoint][model] += (prompt + comp)
await publish_snapshot() snapshot = _capture_snapshot()
await _distribute_snapshot(snapshot)
except asyncio.QueueEmpty: except asyncio.QueueEmpty:
break break
print("[token_worker] Task cancelled, remaining items processed.") print("[token_worker] Task cancelled, remaining items processed.")
@ -1033,7 +1035,8 @@ def dedupe_on_keys(dicts, key_fields):
async def increment_usage(endpoint: str, model: str) -> None: async def increment_usage(endpoint: str, model: str) -> None:
async with usage_lock: async with usage_lock:
usage_counts[endpoint][model] += 1 usage_counts[endpoint][model] += 1
await publish_snapshot() snapshot = _capture_snapshot()
await _distribute_snapshot(snapshot)
async def decrement_usage(endpoint: str, model: str) -> None: async def decrement_usage(endpoint: str, model: str) -> None:
async with usage_lock: async with usage_lock:
@ -1046,7 +1049,8 @@ async def decrement_usage(endpoint: str, model: str) -> None:
usage_counts[endpoint].pop(model, None) usage_counts[endpoint].pop(model, None)
#if not usage_counts[endpoint]: #if not usage_counts[endpoint]:
# usage_counts.pop(endpoint, None) # usage_counts.pop(endpoint, None)
await publish_snapshot() snapshot = _capture_snapshot()
await _distribute_snapshot(snapshot)
async def _make_chat_request(model: str, messages: list, tools=None, stream: bool = False, think: bool = False, format=None, options=None, keep_alive: str = None) -> ollama.ChatResponse: async def _make_chat_request(model: str, messages: list, tools=None, stream: bool = False, think: bool = False, format=None, options=None, keep_alive: str = None) -> ollama.ChatResponse:
""" """
@ -1580,18 +1584,17 @@ class rechunk:
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# SSE Helpser # SSE Helpser
# ------------------------------------------------------------------ # ------------------------------------------------------------------
async def publish_snapshot(): def _capture_snapshot() -> str:
# NOTE: This function assumes usage_lock OR token_usage_lock is already held by the caller """Capture current usage counts as a JSON string. Caller must hold at least one of usage_lock/token_usage_lock."""
# Create a snapshot without acquiring the lock (caller must hold it) return orjson.dumps({
snapshot = orjson.dumps({ "usage_counts": dict(usage_counts),
"usage_counts": dict(usage_counts), # Create a copy
"token_usage_counts": dict(token_usage_counts) "token_usage_counts": dict(token_usage_counts)
}, option=orjson.OPT_SORT_KEYS).decode("utf-8") }, option=orjson.OPT_SORT_KEYS).decode("utf-8")
# Distribute the snapshot (no lock needed here since we have a copy) async def _distribute_snapshot(snapshot: str) -> None:
"""Push a pre-captured snapshot to all SSE subscribers. Must be called outside any usage lock."""
async with _subscribers_lock: async with _subscribers_lock:
for q in _subscribers: for q in _subscribers:
# If the queue is full, drop the message to avoid backpressure.
if q.full(): if q.full():
try: try:
await q.get() await q.get()
@ -1736,10 +1739,13 @@ async def choose_endpoint(model: str, reserve: bool = True) -> tuple[str, str]:
selected = min(candidate_endpoints, key=tracking_usage) selected = min(candidate_endpoints, key=tracking_usage)
tracking_model = get_tracking_model(selected, model) tracking_model = get_tracking_model(selected, model)
snapshot = None
if reserve: if reserve:
usage_counts[selected][tracking_model] += 1 usage_counts[selected][tracking_model] += 1
await publish_snapshot() snapshot = _capture_snapshot()
return selected, tracking_model if snapshot is not None:
await _distribute_snapshot(snapshot)
return selected, tracking_model
# ------------------------------------------------------------- # -------------------------------------------------------------
# 6. API route Generate # 6. API route Generate