nomyo-router/tokens.py

178 lines
7.1 KiB
Python
Raw Normal View History

"""Token-count write-behind pipeline.
``token_worker`` drains ``token_queue`` into the in-memory buffer (and into
``token_usage_counts`` for immediate SSE reporting). ``flush_buffer``
periodically persists the buffer to SQLite via ``TokenDatabase``.
``flush_remaining_buffers`` is invoked on shutdown to drain whatever is left.
The lock order is ``buffer_lock`` then ``token_usage_lock``; see
``choose_endpoint`` for why we never combine them with ``usage_lock``.
"""
import asyncio
from datetime import datetime, timezone
from state import (
token_queue,
token_buffer,
time_series_buffer,
buffer_lock,
token_usage_counts,
token_usage_lock,
FLUSH_INTERVAL,
)
from sse import _capture_snapshot, _distribute_snapshot
from db import get_db
async def token_worker() -> None:
try:
while True:
endpoint, model, prompt, comp = await token_queue.get()
# Calculate timestamp once before acquiring lock
now = datetime.now(tz=timezone.utc)
timestamp = int(datetime(now.year, now.month, now.day, now.hour, now.minute, tzinfo=timezone.utc).timestamp())
# Accumulate counts in memory buffer (protected by lock)
async with buffer_lock:
token_buffer[endpoint][model] = (
token_buffer[endpoint].get(model, (0, 0))[0] + prompt,
token_buffer[endpoint].get(model, (0, 0))[1] + comp
)
# Add to time series buffer with timestamp (UTC)
time_series_buffer.append({
'endpoint': endpoint,
'model': model,
'input_tokens': prompt,
'output_tokens': comp,
'total_tokens': prompt + comp,
'timestamp': timestamp
})
# Update in-memory counts for immediate reporting
async with token_usage_lock:
token_usage_counts[endpoint][model] += (prompt + comp)
snapshot = _capture_snapshot()
await _distribute_snapshot(snapshot)
except asyncio.CancelledError:
# Gracefully handle task cancellation during shutdown
print("[token_worker] Task cancelled, processing remaining queue items...")
# Process any remaining items in the queue before exiting
while not token_queue.empty():
try:
endpoint, model, prompt, comp = token_queue.get_nowait()
# Calculate timestamp once before acquiring lock
now = datetime.now(tz=timezone.utc)
timestamp = int(datetime(now.year, now.month, now.day, now.hour, now.minute, tzinfo=timezone.utc).timestamp())
async with buffer_lock:
token_buffer[endpoint][model] = (
token_buffer[endpoint].get(model, (0, 0))[0] + prompt,
token_buffer[endpoint].get(model, (0, 0))[1] + comp
)
time_series_buffer.append({
'endpoint': endpoint,
'model': model,
'input_tokens': prompt,
'output_tokens': comp,
'total_tokens': prompt + comp,
'timestamp': timestamp
})
async with token_usage_lock:
token_usage_counts[endpoint][model] += (prompt + comp)
snapshot = _capture_snapshot()
await _distribute_snapshot(snapshot)
except asyncio.QueueEmpty:
break
print("[token_worker] Task cancelled, remaining items processed.")
raise
async def _take_buffers() -> tuple:
    """Detach and return the pending token-count and time-series buffers.

    Returns ``(counts, series, entries)``:

    * ``counts``: a copy of ``token_buffer``
      (``{endpoint: {model: (prompt, completion)}}``), or ``None`` if it was
      empty;
    * ``series``: a copy of ``time_series_buffer``, or ``None`` if it was
      empty;
    * ``entries``: the number of (endpoint, model) count entries plus
      time-series rows detached.

    Both shared buffers are cleared while ``buffer_lock`` is held, so records
    added afterwards by ``token_worker`` go into a fresh batch.
    """
    async with buffer_lock:
        counts = None
        series = None
        entries = 0
        if token_buffer:
            counts = {ep: dict(models) for ep, models in token_buffer.items()}
            entries += sum(len(models) for models in token_buffer.values())
            token_buffer.clear()
        if time_series_buffer:
            series = list(time_series_buffer)
            entries += len(time_series_buffer)
            time_series_buffer.clear()
    return counts, series, entries


async def _persist(counts, series) -> None:
    """Write detached buffers to the database; exceptions propagate to the caller.

    Runs outside ``buffer_lock`` so DB latency never blocks ``token_worker``.
    """
    if not counts and not series:
        return  # nothing to write, so don't touch the DB at all
    db = get_db()
    if counts:
        await db.update_batched_counts(counts)
    if series:
        await db.add_batched_time_series(series)


async def flush_buffer() -> None:
    """Periodically flush accumulated token counts to the database.

    Every ``FLUSH_INTERVAL`` seconds the pending buffers are detached and
    written. A failed write is logged and the loop keeps running. On
    cancellation, whatever is still buffered is flushed once more before the
    ``CancelledError`` is re-raised.
    """
    try:
        while True:
            await asyncio.sleep(FLUSH_INTERVAL)
            counts, series, _ = await _take_buffers()
            try:
                await _persist(counts, series)
            except Exception as e:
                # Keep the flusher alive: letting this escape would end the task
                # and silently stop all future persistence. The detached batch
                # is not re-queued (re-merging could double-count if the DB
                # committed part of it).
                print(f"[flush_buffer] Error flushing buffers: {e}")
    except asyncio.CancelledError:
        # Gracefully handle task cancellation during shutdown
        print("[flush_buffer] Task cancelled, flushing remaining buffers...")
        try:
            counts, series, _ = await _take_buffers()
            await _persist(counts, series)
            print("[flush_buffer] Task cancelled, remaining buffers flushed.")
        except Exception as e:
            print(f"[flush_buffer] Error during shutdown flush: {e}")
        raise


async def flush_remaining_buffers() -> None:
    """
    Flush any in-memory buffers to the database on shutdown.

    This is designed to be safely invoked during shutdown and never raises;
    errors are logged instead.
    """
    try:
        counts, series, flushed_entries = await _take_buffers()
        await _persist(counts, series)
        if flushed_entries:
            print(f"[shutdown] Flushed {flushed_entries} in-memory entries to DB on shutdown.")
        else:
            print("[shutdown] No in-memory entries to flush on shutdown.")
    except Exception as e:
        # Do not raise during shutdown; log and continue teardown.
        print(f"[shutdown] Error flushing remaining buffers: {e}")