"""Token-count write-behind pipeline.

``token_worker`` drains ``token_queue`` into the in-memory buffer (and into
``token_usage_counts`` for immediate SSE reporting). ``flush_buffer``
periodically persists the buffer to SQLite via ``TokenDatabase``.
``flush_remaining_buffers`` is invoked on shutdown to drain whatever is left.

The lock order is ``buffer_lock`` then ``token_usage_lock``. In this module
the two are acquired one after the other and never held at the same time —
see choose_endpoint for why we never combine them with usage_lock.
"""
import asyncio
from datetime import datetime, timezone

from state import (
    token_queue,
    token_buffer,
    time_series_buffer,
    buffer_lock,
    token_usage_counts,
    token_usage_lock,
    FLUSH_INTERVAL,
)
from sse import _capture_snapshot, _distribute_snapshot
from db import get_db


def _minute_timestamp() -> int:
    """Return the current UTC time truncated to the whole minute, as epoch seconds."""
    now = datetime.now(tz=timezone.utc)
    return int(now.replace(second=0, microsecond=0).timestamp())


async def _record_usage(endpoint, model, prompt, comp) -> None:
    """Account one queue item and notify SSE subscribers.

    Adds ``prompt``/``comp`` to the per-(endpoint, model) totals in
    ``token_buffer``, appends a minute-bucketed row to ``time_series_buffer``,
    bumps ``token_usage_counts`` for live reporting, then captures and
    distributes a fresh snapshot.
    """
    # Computed before acquiring the lock to keep the critical section short.
    timestamp = _minute_timestamp()

    async with buffer_lock:
        prev_prompt, prev_comp = token_buffer[endpoint].get(model, (0, 0))
        token_buffer[endpoint][model] = (prev_prompt + prompt, prev_comp + comp)
        # Time-series row keyed by the UTC minute bucket.
        time_series_buffer.append({
            'endpoint': endpoint,
            'model': model,
            'input_tokens': prompt,
            'output_tokens': comp,
            'total_tokens': prompt + comp,
            'timestamp': timestamp,
        })

    # Acquired only after buffer_lock has been released (never nested).
    async with token_usage_lock:
        token_usage_counts[endpoint][model] += prompt + comp

    snapshot = _capture_snapshot()
    await _distribute_snapshot(snapshot)


async def token_worker() -> None:
    """Consume ``token_queue`` forever, recording each
    ``(endpoint, model, prompt, comp)`` item via the in-memory buffers.

    On cancellation, items still sitting in the queue are recorded before the
    ``CancelledError`` is re-raised, so they are not dropped at shutdown.
    """
    try:
        while True:
            endpoint, model, prompt, comp = await token_queue.get()
            await _record_usage(endpoint, model, prompt, comp)
    except asyncio.CancelledError:
        # Gracefully handle task cancellation during shutdown
        print("[token_worker] Task cancelled, processing remaining queue items...")
        while not token_queue.empty():
            try:
                endpoint, model, prompt, comp = token_queue.get_nowait()
            except asyncio.QueueEmpty:
                break
            await _record_usage(endpoint, model, prompt, comp)
        print("[token_worker] Task cancelled, remaining items processed.")
        raise


async def _take_buffers():
    """Copy and clear both write-behind buffers while holding ``buffer_lock``.

    Returns ``(counts_copy, series_copy)``. ``counts_copy`` is a
    ``{endpoint: {model: (prompt, comp)}}`` copy of ``token_buffer``, and
    ``series_copy`` is a list copy of ``time_series_buffer``. Each is ``None``
    when the corresponding buffer was empty.
    """
    async with buffer_lock:
        if token_buffer:
            counts_copy = {ep: dict(models) for ep, models in token_buffer.items()}
            token_buffer.clear()
        else:
            counts_copy = None
        if time_series_buffer:
            series_copy = list(time_series_buffer)
            time_series_buffer.clear()
        else:
            series_copy = None
    return counts_copy, series_copy


async def _persist(counts_copy, series_copy) -> None:
    """Write buffer copies taken by ``_take_buffers`` to the database.

    Runs outside ``buffer_lock`` so producers are not blocked on DB I/O.
    The counts are written first, then the time series.
    """
    db = get_db()
    if counts_copy:
        await db.update_batched_counts(counts_copy)
    if series_copy:
        await db.add_batched_time_series(series_copy)


async def flush_buffer() -> None:
    """Periodically flush accumulated token counts to the database.

    Every ``FLUSH_INTERVAL`` seconds, the buffers are swapped out under the
    lock and written outside it. A failed write is logged and the loop keeps
    running, so a single database error no longer stops every later periodic
    flush. The failed batch has already been removed from the buffers and is
    not retried.

    On cancellation, whatever is still buffered is flushed once more, and then
    the ``CancelledError`` is re-raised.
    """
    try:
        while True:
            await asyncio.sleep(FLUSH_INTERVAL)
            counts_copy, series_copy = await _take_buffers()
            try:
                await _persist(counts_copy, series_copy)
            except asyncio.CancelledError:
                # Never swallow cancellation (it subclassed Exception before 3.8).
                raise
            except Exception as e:
                print(f"[flush_buffer] Error flushing buffers: {e}")
    except asyncio.CancelledError:
        # Gracefully handle task cancellation during shutdown
        print("[flush_buffer] Task cancelled, flushing remaining buffers...")
        try:
            counts_copy, series_copy = await _take_buffers()
            await _persist(counts_copy, series_copy)
            print("[flush_buffer] Task cancelled, remaining buffers flushed.")
        except Exception as e:
            print(f"[flush_buffer] Error during shutdown flush: {e}")
        raise


async def flush_remaining_buffers() -> None:
    """
    Flush any in-memory buffers to the database on shutdown.

    This is designed to be safely invoked during shutdown: errors raised while
    flushing (``Exception`` subclasses) are logged instead of propagated.
    """
    try:
        counts_copy, series_copy = await _take_buffers()

        # One entry per (endpoint, model) pair plus one per time-series row.
        flushed_entries = 0
        if counts_copy:
            flushed_entries += sum(len(models) for models in counts_copy.values())
        if series_copy:
            flushed_entries += len(series_copy)

        await _persist(counts_copy, series_copy)

        if flushed_entries:
            print(f"[shutdown] Flushed {flushed_entries} in-memory entries to DB on shutdown.")
        else:
            print("[shutdown] No in-memory entries to flush on shutdown.")
    except Exception as e:
        # Do not raise during shutdown – log and continue teardown
        print(f"[shutdown] Error flushing remaining buffers: {e}")