177 lines
7.1 KiB
Python
177 lines
7.1 KiB
Python
"""Token-count write-behind pipeline.
|
||
|
||
``token_worker`` drains ``token_queue`` into the in-memory buffer (and into
``token_usage_counts`` for immediate SSE reporting). ``flush_buffer``
periodically persists the buffer to SQLite via ``TokenDatabase``.
``flush_remaining_buffers`` is invoked on shutdown to drain whatever is left.

The lock order is ``buffer_lock`` then ``token_usage_lock`` — see
``choose_endpoint`` for why we never combine them with ``usage_lock``.
|
||
"""
|
||
import asyncio
|
||
from datetime import datetime, timezone
|
||
|
||
from state import (
|
||
token_queue,
|
||
token_buffer,
|
||
time_series_buffer,
|
||
buffer_lock,
|
||
token_usage_counts,
|
||
token_usage_lock,
|
||
FLUSH_INTERVAL,
|
||
)
|
||
from sse import _capture_snapshot, _distribute_snapshot
|
||
from db import get_db
|
||
|
||
|
||
async def _record_usage(endpoint, model, prompt, comp) -> None:
    """Fold one usage sample into the in-memory state and publish a snapshot.

    Adds ``prompt``/``comp`` to the ``(prompt, completion)`` pair stored in
    ``token_buffer[endpoint][model]``, appends one row to
    ``time_series_buffer``, bumps ``token_usage_counts[endpoint][model]`` by
    ``prompt + comp`` and distributes a freshly captured snapshot.

    NOTE(review): the two locks are taken one after the other (never nested)
    and the snapshot is distributed after ``token_usage_lock`` is released --
    confirm this matches the intended layout.
    """
    # Truncate "now" to the start of the current UTC minute; computed before
    # taking any lock so the critical sections stay short.
    minute_start = datetime.now(tz=timezone.utc).replace(second=0, microsecond=0)
    timestamp = int(minute_start.timestamp())

    # Accumulate counts in the write-behind buffers.
    async with buffer_lock:
        previous = token_buffer[endpoint].get(model, (0, 0))
        token_buffer[endpoint][model] = (
            previous[0] + prompt,
            previous[1] + comp,
        )
        time_series_buffer.append({
            'endpoint': endpoint,
            'model': model,
            'input_tokens': prompt,
            'output_tokens': comp,
            'total_tokens': prompt + comp,
            'timestamp': timestamp,
        })

    # Update the live counters used for immediate reporting.
    async with token_usage_lock:
        token_usage_counts[endpoint][model] += (prompt + comp)
        snapshot = _capture_snapshot()
    await _distribute_snapshot(snapshot)


async def token_worker() -> None:
    """Consume ``token_queue`` forever, recording every usage sample.

    Each queue item is unpacked as ``(endpoint, model, prompt, comp)``.

    Raises:
        asyncio.CancelledError: re-raised on cancellation, after every item
            that is already sitting in the queue has been recorded.
    """
    try:
        while True:
            endpoint, model, prompt, comp = await token_queue.get()
            await _record_usage(endpoint, model, prompt, comp)
    except asyncio.CancelledError:
        # Gracefully handle task cancellation during shutdown
        print("[token_worker] Task cancelled, processing remaining queue items...")
        # Process any remaining items in the queue before exiting
        while not token_queue.empty():
            try:
                endpoint, model, prompt, comp = token_queue.get_nowait()
            except asyncio.QueueEmpty:
                break
            await _record_usage(endpoint, model, prompt, comp)
        print("[token_worker] Task cancelled, remaining items processed.")
        raise
|
||
|
||
|
||
async def _flush_buffers_once() -> None:
    """Swap out the in-memory buffers and write their contents to the DB.

    The buffers are copied and cleared while holding ``buffer_lock``; the
    database calls happen after the lock is released so producers are not
    blocked on I/O.
    """
    async with buffer_lock:
        if token_buffer:
            buffer_copy = {ep: dict(models) for ep, models in token_buffer.items()}
            token_buffer.clear()
        else:
            buffer_copy = None

        if time_series_buffer:
            ts_copy = list(time_series_buffer)
            time_series_buffer.clear()
        else:
            ts_copy = None

    # Perform DB operations outside the lock to avoid blocking
    db = get_db()
    if buffer_copy:
        await db.update_batched_counts(buffer_copy)
    if ts_copy:
        await db.add_batched_time_series(ts_copy)


async def flush_buffer() -> None:
    """Periodically flush accumulated token counts to the database.

    Sleeps ``FLUSH_INTERVAL`` between flushes. Each flush runs as its own
    shielded task: once the buffers have been cleared, cancelling this
    coroutine must not abandon the batch that was just taken out of them.

    Raises:
        asyncio.CancelledError: re-raised on cancellation, after any flush
            that was in progress has finished and one final flush of the
            remaining buffers has been attempted. Errors from those shutdown
            flushes are logged, not propagated.
    """
    in_flight = None
    try:
        while True:
            await asyncio.sleep(FLUSH_INTERVAL)
            in_flight = asyncio.create_task(_flush_buffers_once())
            # shield(): a cancel of flush_buffer raises here but leaves the
            # flush task itself running, so its drained batch is still written.
            await asyncio.shield(in_flight)
            in_flight = None
    except asyncio.CancelledError:
        # Gracefully handle task cancellation during shutdown
        print("[flush_buffer] Task cancelled, flushing remaining buffers...")
        try:
            if in_flight is not None:
                # Let the interrupted flush finish before draining again.
                await in_flight
            # Flush any remaining data before exiting
            await _flush_buffers_once()
            print("[flush_buffer] Task cancelled, remaining buffers flushed.")
        except Exception as e:
            print(f"[flush_buffer] Error during shutdown flush: {e}")
        raise
|
||
|
||
|
||
async def flush_remaining_buffers() -> None:
    """Write whatever is still buffered in memory to the database.

    Meant to be called during shutdown. Any ``Exception`` raised while
    draining or writing is printed and swallowed so teardown can continue.
    """
    try:
        pending_counts = None
        pending_series = None
        total = 0

        # Take ownership of the buffered data while holding the lock.
        async with buffer_lock:
            if token_buffer:
                pending_counts = {
                    endpoint: dict(per_model)
                    for endpoint, per_model in token_buffer.items()
                }
                total += sum(len(per_model) for per_model in token_buffer.values())
                token_buffer.clear()
            if time_series_buffer:
                pending_series = list(time_series_buffer)
                total += len(time_series_buffer)
                time_series_buffer.clear()

        # The lock is released before touching the database.
        db = get_db()
        if pending_counts:
            await db.update_batched_counts(pending_counts)
        if pending_series:
            await db.add_batched_time_series(pending_series)

        if not total:
            print("[shutdown] No in-memory entries to flush on shutdown.")
        else:
            print(f"[shutdown] Flushed {total} in-memory entries to DB on shutdown.")
    except Exception as e:
        # Shutdown must proceed even if the flush fails: log and carry on.
        print(f"[shutdown] Error flushing remaining buffers: {e}")
|