62 lines
1.7 KiB
Python
62 lines
1.7 KiB
Python
"""Server-sent-events plumbing.
|
|
|
|
Captures the current ``usage_counts`` / ``token_usage_counts`` snapshot and
fans it out to every subscribed asyncio.Queue. Routes that need a live
dashboard feed call ``subscribe`` to obtain a queue and ``unsubscribe`` to
release it.
|
|
"""
|
|
import asyncio
|
|
from typing import Dict
|
|
|
|
import orjson
|
|
|
|
from state import (
|
|
usage_counts,
|
|
token_usage_counts,
|
|
_subscribers,
|
|
_subscribers_lock,
|
|
)
|
|
|
|
|
|
def _capture_snapshot() -> str:
    """Serialize the current usage counters to a JSON string.

    Caller must hold at least one of usage_lock/token_usage_lock.

    Returns:
        A JSON document (decoded from UTF-8 bytes) with the keys
        ``usage_counts`` and ``token_usage_counts``, serialized with
        sorted keys.
    """
    # Copy both mappings into plain dicts before handing them to orjson.
    payload = {
        "usage_counts": dict(usage_counts),
        "token_usage_counts": dict(token_usage_counts),
    }
    encoded = orjson.dumps(payload, option=orjson.OPT_SORT_KEYS)
    return encoded.decode("utf-8")
|
|
|
|
|
|
async def _distribute_snapshot(snapshot: str) -> None:
    """Push a pre-captured snapshot to all SSE subscribers.

    Must be called outside any usage lock.

    Every queue in ``_subscribers`` receives ``snapshot``. A queue that is
    already full has its oldest entry discarded first, so the newest
    snapshot always fits. Only non-blocking queue operations are used, so
    ``_subscribers_lock`` is never held across a queue wait.

    Args:
        snapshot: The serialized snapshot to enqueue on each subscriber.
    """
    async with _subscribers_lock:
        for q in _subscribers:
            if q.full():
                # ``await q.get()`` waits on an empty queue and never raises
                # QueueEmpty; ``get_nowait()`` is the call that actually
                # matches this handler and cannot suspend under the lock.
                try:
                    q.get_nowait()
                except asyncio.QueueEmpty:
                    pass
            # A slot is free at this point (either the queue was not full
            # or one entry was just dropped), so this cannot block.
            q.put_nowait(snapshot)
|
|
|
|
|
|
async def close_all_sse_queues() -> None:
    """Enqueue the ``None`` shutdown sentinel on every subscriber queue.

    A queue that is already full has its oldest entry discarded first.
    Awaiting ``q.put(None)`` on a full queue would block until its consumer
    reads an item, which could stall the caller indefinitely when a
    consumer is not draining its queue.
    """
    # Iterate over a copy of the subscriber collection.
    for q in list(_subscribers):
        if q.full():
            # Make room without waiting; the dropped entry is the oldest
            # item still pending on this queue.
            try:
                q.get_nowait()
            except asyncio.QueueEmpty:
                pass
        # sentinel value that the generator will recognise
        q.put_nowait(None)
|
|
|
|
|
|
async def subscribe() -> asyncio.Queue:
    """Create a new subscriber queue, register it, and return it.

    The queue is bounded to 10 items and is added to ``_subscribers``
    while ``_subscribers_lock`` is held.

    Returns:
        The newly registered ``asyncio.Queue``.
    """
    new_queue: asyncio.Queue = asyncio.Queue(maxsize=10)
    async with _subscribers_lock:
        _subscribers.add(new_queue)
    return new_queue
|
|
|
|
|
|
async def unsubscribe(q: asyncio.Queue) -> None:
    """Remove ``q`` from the subscriber collection.

    The removal happens while ``_subscribers_lock`` is held and uses
    ``discard``, so the queue is simply dropped from ``_subscribers``.

    Args:
        q: The queue to remove from ``_subscribers``.
    """
    async with _subscribers_lock:
        _subscribers.discard(q)
|
|
|
|
|
|
async def get_usage_counts() -> Dict:
    """Return a shallow copy of the current ``usage_counts`` mapping."""
    counts_copy = dict(usage_counts)  # top-level copy only; values are shared
    return counts_copy
|