From 8355bf9a1efae3c73130fbfc89ad963f138f8d38 Mon Sep 17 00:00:00 2001 From: alpha nerd Date: Tue, 19 May 2026 12:48:55 +0200 Subject: [PATCH] refac: modularize sse, routing, db and token handling V --- db.py | 11 ++ router.py | 452 ++--------------------------------------------------- routing.py | 294 ++++++++++++++++++++++++++++++++++ sse.py | 62 ++++++++ tokens.py | 177 +++++++++++++++++++++ 5 files changed, 555 insertions(+), 441 deletions(-) create mode 100644 routing.py create mode 100644 sse.py create mode 100644 tokens.py diff --git a/db.py b/db.py index af7b252..3621144 100644 --- a/db.py +++ b/db.py @@ -4,6 +4,17 @@ from pathlib import Path from datetime import datetime, timezone from collections import defaultdict + +def get_db() -> "TokenDatabase": + """Return the live TokenDatabase instance held by router.py. + + Resolved lazily so submodules can access it without import cycles, and + so test patches of ``router.db`` flow through to all callers. + """ + import router # lazy to avoid module-load circular import + return router.db + + class TokenDatabase: def __init__(self, db_path: str = "token_counts.db"): self.db_path = db_path diff --git a/router.py b/router.py index 514590e..7b24bcc 100644 --- a/router.py +++ b/router.py @@ -233,175 +233,13 @@ from backends.normalize import ( get_tracking_model, ) -async def token_worker() -> None: - try: - while True: - endpoint, model, prompt, comp = await token_queue.get() - # Calculate timestamp once before acquiring lock - now = datetime.now(tz=timezone.utc) - timestamp = int(datetime(now.year, now.month, now.day, now.hour, now.minute, tzinfo=timezone.utc).timestamp()) - - # Accumulate counts in memory buffer (protected by lock) - async with buffer_lock: - token_buffer[endpoint][model] = ( - token_buffer[endpoint].get(model, (0, 0))[0] + prompt, - token_buffer[endpoint].get(model, (0, 0))[1] + comp - ) - - # Add to time series buffer with timestamp (UTC) - time_series_buffer.append({ - 'endpoint': endpoint, - 
'model': model, - 'input_tokens': prompt, - 'output_tokens': comp, - 'total_tokens': prompt + comp, - 'timestamp': timestamp - }) - - # Update in-memory counts for immediate reporting - async with token_usage_lock: - token_usage_counts[endpoint][model] += (prompt + comp) - snapshot = _capture_snapshot() - await _distribute_snapshot(snapshot) - except asyncio.CancelledError: - # Gracefully handle task cancellation during shutdown - print("[token_worker] Task cancelled, processing remaining queue items...") - # Process any remaining items in the queue before exiting - while not token_queue.empty(): - try: - endpoint, model, prompt, comp = token_queue.get_nowait() - # Calculate timestamp once before acquiring lock - now = datetime.now(tz=timezone.utc) - timestamp = int(datetime(now.year, now.month, now.day, now.hour, now.minute, tzinfo=timezone.utc).timestamp()) - - async with buffer_lock: - token_buffer[endpoint][model] = ( - token_buffer[endpoint].get(model, (0, 0))[0] + prompt, - token_buffer[endpoint].get(model, (0, 0))[1] + comp - ) - time_series_buffer.append({ - 'endpoint': endpoint, - 'model': model, - 'input_tokens': prompt, - 'output_tokens': comp, - 'total_tokens': prompt + comp, - 'timestamp': timestamp - }) - async with token_usage_lock: - token_usage_counts[endpoint][model] += (prompt + comp) - snapshot = _capture_snapshot() - await _distribute_snapshot(snapshot) - except asyncio.QueueEmpty: - break - print("[token_worker] Task cancelled, remaining items processed.") - raise - -async def flush_buffer() -> None: - """Periodically flush accumulated token counts to the database.""" - try: - while True: - await asyncio.sleep(FLUSH_INTERVAL) - - # Flush token counts and time series (protected by lock) - async with buffer_lock: - if token_buffer: - # Copy buffer before releasing lock for DB operation - buffer_copy = {ep: dict(models) for ep, models in token_buffer.items()} - token_buffer.clear() - else: - buffer_copy = None - - if time_series_buffer: - ts_copy 
= list(time_series_buffer) - time_series_buffer.clear() - else: - ts_copy = None - - # Perform DB operations outside the lock to avoid blocking - if buffer_copy: - await db.update_batched_counts(buffer_copy) - if ts_copy: - await db.add_batched_time_series(ts_copy) - except asyncio.CancelledError: - # Gracefully handle task cancellation during shutdown - print("[flush_buffer] Task cancelled, flushing remaining buffers...") - # Flush any remaining data before exiting - try: - async with buffer_lock: - if token_buffer: - buffer_copy = {ep: dict(models) for ep, models in token_buffer.items()} - token_buffer.clear() - else: - buffer_copy = None - if time_series_buffer: - ts_copy = list(time_series_buffer) - time_series_buffer.clear() - else: - ts_copy = None - if buffer_copy: - await db.update_batched_counts(buffer_copy) - if ts_copy: - await db.add_batched_time_series(ts_copy) - print("[flush_buffer] Task cancelled, remaining buffers flushed.") - except Exception as e: - print(f"[flush_buffer] Error during shutdown flush: {e}") - raise - -async def flush_remaining_buffers() -> None: - """ - Flush any in-memory buffers to the database on shutdown. - This is designed to be safely invoked during shutdown and should not raise. 
- """ - try: - flushed_entries = 0 - async with buffer_lock: - if token_buffer: - buffer_copy = {ep: dict(models) for ep, models in token_buffer.items()} - flushed_entries += sum(len(v) for v in token_buffer.values()) - token_buffer.clear() - else: - buffer_copy = None - if time_series_buffer: - ts_copy = list(time_series_buffer) - flushed_entries += len(time_series_buffer) - time_series_buffer.clear() - else: - ts_copy = None - # Perform DB operations outside the lock - if buffer_copy: - await db.update_batched_counts(buffer_copy) - if ts_copy: - await db.add_batched_time_series(ts_copy) - if flushed_entries: - print(f"[shutdown] Flushed {flushed_entries} in-memory entries to DB on shutdown.") - else: - print("[shutdown] No in-memory entries to flush on shutdown.") - except Exception as e: - # Do not raise during shutdown – log and continue teardown - print(f"[shutdown] Error flushing remaining buffers: {e}") +from tokens import token_worker, flush_buffer, flush_remaining_buffers from backends.probe import fetch -async def increment_usage(endpoint: str, model: str) -> None: - async with usage_lock: - usage_counts[endpoint][model] += 1 - snapshot = _capture_snapshot() - await _distribute_snapshot(snapshot) +from routing import increment_usage, decrement_usage -async def decrement_usage(endpoint: str, model: str) -> None: - async with usage_lock: - # Avoid negative counts - current = usage_counts[endpoint].get(model, 0) - if current > 0: - usage_counts[endpoint][model] = current - 1 - # Optionally, clean up zero entries - if usage_counts[endpoint].get(model, 0) == 0: - usage_counts[endpoint].pop(model, None) - #if not usage_counts[endpoint]: - # usage_counts.pop(endpoint, None) - snapshot = _capture_snapshot() - await _distribute_snapshot(snapshot) async def _make_chat_request(model: str, messages: list, tools=None, stream: bool = False, think: bool = False, format=None, options=None, keep_alive: str = None) -> ollama.ChatResponse: """ @@ -878,287 +716,19 @@ class 
rechunk: return (prompt_n + cache_n, predicted_n) return None -# ------------------------------------------------------------------ -# SSE Helpser -# ------------------------------------------------------------------ -def _capture_snapshot() -> str: - """Capture current usage counts as a JSON string. Caller must hold at least one of usage_lock/token_usage_lock.""" - return orjson.dumps({ - "usage_counts": dict(usage_counts), - "token_usage_counts": dict(token_usage_counts) - }, option=orjson.OPT_SORT_KEYS).decode("utf-8") - -async def _distribute_snapshot(snapshot: str) -> None: - """Push a pre-captured snapshot to all SSE subscribers. Must be called outside any usage lock.""" - async with _subscribers_lock: - for q in _subscribers: - if q.full(): - try: - await q.get() - except asyncio.QueueEmpty: - pass - await q.put(snapshot) - -async def close_all_sse_queues(): - for q in list(_subscribers): - # sentinel value that the generator will recognise - await q.put(None) - -# ------------------------------------------------------------------ -# Subscriber helpers -# ------------------------------------------------------------------ -async def subscribe() -> asyncio.Queue: - """ - Returns a new Queue that will receive every snapshot. - """ - q: asyncio.Queue = asyncio.Queue(maxsize=10) - async with _subscribers_lock: - _subscribers.add(q) - return q - -async def unsubscribe(q: asyncio.Queue): - async with _subscribers_lock: - _subscribers.discard(q) - -# ------------------------------------------------------------------ -# Convenience wrapper – returns the current snapshot (for the proxy) -# ------------------------------------------------------------------ -async def get_usage_counts() -> Dict: - return dict(usage_counts) # shallow copy +from sse import ( + _capture_snapshot, + _distribute_snapshot, + close_all_sse_queues, + subscribe, + unsubscribe, + get_usage_counts, +) # ------------------------------------------------------------- # 5. 
Endpoint selection logic (respecting the configurable limit) # ------------------------------------------------------------- -def get_max_connections(ep: str) -> int: - """Per-endpoint max_concurrent_connections, falling back to the global value.""" - return config.endpoint_config.get(ep, {}).get( - "max_concurrent_connections", config.max_concurrent_connections - ) - -async def choose_endpoint(model: str, reserve: bool = True, - affinity_key: Optional[str] = None) -> tuple[str, str]: - """ - Determine which endpoint to use for the given model while respecting - the `max_concurrent_connections` per endpoint‑model pair **and** - ensuring that the chosen endpoint actually *advertises* the model. - - The selection algorithm: - - 1️⃣ Query every endpoint for its advertised models (`/api/tags`). - 2️⃣ Build a list of endpoints that contain the requested model. - 2️⃣.5 If conversation affinity is enabled and the caller passes - ``affinity_key``, prefer the endpoint that previously served the - same conversation — but only when it still has the model loaded - and a free slot. Otherwise fall through to the standard logic. - 3️⃣ For those endpoints, find those that have the model loaded - (`/api/ps`) *and* still have a free slot. - 4️⃣ If none are both loaded and free, fall back to any endpoint - from the filtered list that simply has a free slot and randomly - select one. - 5️⃣ If all are saturated, pick any endpoint from the filtered list - (the request will queue on that endpoint). - 6️⃣ If no endpoint advertises the model at all, raise an error. 
- """ - # 1️⃣ Gather advertised‑model sets for all endpoints concurrently - # Include both config.endpoints and config.llama_server_endpoints - llama_eps_extra = [ep for ep in config.llama_server_endpoints if ep not in config.endpoints] - all_endpoints = config.endpoints + llama_eps_extra - - tag_tasks = [fetch.available_models(ep) for ep in config.endpoints if not is_openai_compatible(ep)] - tag_tasks += [fetch.available_models(ep, config.api_keys.get(ep)) for ep in config.endpoints if is_openai_compatible(ep)] - tag_tasks += [fetch.available_models(ep, config.api_keys.get(ep)) for ep in llama_eps_extra] - advertised_sets = await asyncio.gather(*tag_tasks) - - # 2️⃣ Filter endpoints that advertise the requested model - candidate_endpoints = [ - ep for ep, models in zip(all_endpoints, advertised_sets) - if model in models - ] - - # 6️⃣ - if not candidate_endpoints: - if ":latest" in model: #ollama naming convention not applicable to openai/llama-server - model_without_latest = model.split(":latest")[0] - candidate_endpoints = [ - ep for ep, models in zip(all_endpoints, advertised_sets) - if model_without_latest in models and (is_ext_openai_endpoint(ep) or ep in config.llama_server_endpoints) - ] - if not candidate_endpoints: - # Only add :latest suffix if model doesn't already have a version suffix - if ":" not in model: - model = model + ":latest" - candidate_endpoints = [ - ep for ep, models in zip(all_endpoints, advertised_sets) - if model in models - ] - if not candidate_endpoints: - raise RuntimeError( - f"None of the configured endpoints ({', '.join(all_endpoints)}) " - f"advertise the model '{model}'." - ) - # 3️⃣ Among the candidates, find those that have the model *loaded* - # (concurrently, but only for the filtered list) - load_tasks = [fetch.loaded_models(ep) for ep in candidate_endpoints] - loaded_sets = await asyncio.gather(*load_tasks) - - # 3️⃣.5 Exclude endpoints whose loaded-model probe has been failing - # recently. 
Without this filter, an endpoint where `/api/ps` returns 5xx - # would appear with an empty loaded set but pass through to the - # free-slot fallback (step 4) — sending completion calls to an - # unhealthy backend. See issue #83. - async with _loaded_error_cache_lock: - unhealthy = { - ep for ep, ts in _loaded_error_cache.items() - if _is_fresh(ts, 300) - } - if unhealthy: - filtered = [ - (ep, models) for ep, models in zip(candidate_endpoints, loaded_sets) - if ep not in unhealthy - ] - if filtered: - candidate_endpoints = [ep for ep, _ in filtered] - loaded_sets = [models for _, models in filtered] - # If *every* candidate is unhealthy we still fall through with the - # original list — refusing to route is worse than retrying a - # possibly-recovered backend. - - # 3️⃣.6 Exclude (endpoint, model) pairs whose completion path has recently - # failed with a backend connection error (e.g. llama-server in router mode - # whose delegated worker for *this* model died). /v1/models keeps reporting - # OK in that case, so the probe-level filter above cannot catch it. - async with _completion_error_cache_lock: - completion_broken = { - ep for (ep, m), ts in _completion_error_cache.items() - if m == model and _is_fresh(ts, _COMPLETION_ERROR_TTL) - } - if completion_broken: - filtered = [ - (ep, models) for ep, models in zip(candidate_endpoints, loaded_sets) - if ep not in completion_broken - ] - if filtered: - candidate_endpoints = [ep for ep, _ in filtered] - loaded_sets = [models for _, models in filtered] - # Same fallback: if every candidate is broken for this model, fall - # through and let the upstream retry — possibly the operator restarted - # the dead worker. - - # Look up a possible affinity hint *before* taking usage_lock. The two - # locks are never held together to avoid lock-ordering issues. 
- affine_ep: Optional[str] = None - if config.conversation_affinity and affinity_key: - async with _affinity_lock: - entry = _affinity_map.get(affinity_key) - if entry is not None: - ep, _stored_model, expires_at = entry - if expires_at < time.monotonic(): - _affinity_map.pop(affinity_key, None) - else: - affine_ep = ep - - # Protect all reads/writes of usage_counts with the lock so that selection - # and reservation are atomic — concurrent callers see each other's pending load. - async with usage_lock: - # Helper: current usage for (endpoint, model) using the same normalized key - # that increment_usage/decrement_usage store — raw model names differ from - # tracking names for llama-server (HF prefix / quant suffix stripped). - def tracking_usage(ep: str) -> int: - return usage_counts.get(ep, {}).get(get_tracking_model(ep, model), 0) - - def utilization_ratio(ep: str) -> float: - return tracking_usage(ep) / get_max_connections(ep) - - # Priority map: position in all_endpoints list (lower = higher priority) - ep_priority = {ep: i for i, ep in enumerate(all_endpoints)} - - selected: Optional[str] = None - - # 2️⃣.5 Conversation affinity preference — only honour the hint when - # the affine endpoint still advertises the model loaded *and* has a - # free slot. Otherwise fall back to the standard algorithm. - if affine_ep: - ep_loaded = { - ep: set(models) - for ep, models in zip(candidate_endpoints, loaded_sets) - } - if (affine_ep in candidate_endpoints - and model in ep_loaded.get(affine_ep, set()) - and tracking_usage(affine_ep) < get_max_connections(affine_ep)): - selected = affine_ep - - if selected is None: - # 3️⃣ Endpoints that have the model loaded *and* a free slot - loaded_and_free = [ - ep for ep, models in zip(candidate_endpoints, loaded_sets) - if model in models and tracking_usage(ep) < get_max_connections(ep) - ] - - if loaded_and_free: - if config.priority_routing: - # WRR: sort by config order first (stable), then by utilization ratio. 
- # Stable sort preserves priority for equal-ratio endpoints. - loaded_and_free.sort(key=lambda ep: ep_priority.get(ep, 999)) - loaded_and_free.sort(key=utilization_ratio) - selected = loaded_and_free[0] - else: - # Sort ascending for load balancing — all endpoints here already have the - # model loaded, so there is no model-switching cost to optimise for. - loaded_and_free.sort(key=tracking_usage) - # When all candidates are equally idle, randomise to avoid always picking - # the first entry in a stable sort. - if all(tracking_usage(ep) == 0 for ep in loaded_and_free): - selected = random.choice(loaded_and_free) - else: - selected = loaded_and_free[0] - else: - # 4️⃣ Endpoints among the candidates that simply have a free slot - endpoints_with_free_slot = [ - ep for ep in candidate_endpoints - if tracking_usage(ep) < get_max_connections(ep) - ] - - if endpoints_with_free_slot: - if config.priority_routing: - endpoints_with_free_slot.sort(key=lambda ep: ep_priority.get(ep, 999)) - endpoints_with_free_slot.sort(key=utilization_ratio) - selected = endpoints_with_free_slot[0] - else: - # Sort by total endpoint load (ascending) to prefer idle endpoints. 
- endpoints_with_free_slot.sort( - key=lambda ep: sum(usage_counts.get(ep, {}).values()) - ) - if all(tracking_usage(ep) == 0 for ep in endpoints_with_free_slot): - selected = random.choice(endpoints_with_free_slot) - else: - selected = endpoints_with_free_slot[0] - else: - # 5️⃣ All candidate endpoints are saturated – pick the least-busy one (will queue) - if config.priority_routing: - selected = min( - candidate_endpoints, - key=lambda ep: (utilization_ratio(ep), ep_priority.get(ep, 999)), - ) - else: - selected = min(candidate_endpoints, key=tracking_usage) - - tracking_model = get_tracking_model(selected, model) - snapshot = None - if reserve: - usage_counts[selected][tracking_model] += 1 - snapshot = _capture_snapshot() - if snapshot is not None: - await _distribute_snapshot(snapshot) - # Record / refresh affinity *after* releasing usage_lock. - if reserve and config.conversation_affinity and affinity_key: - expires_at = time.monotonic() + config.conversation_affinity_ttl - async with _affinity_lock: - _affinity_map[affinity_key] = (selected, model, expires_at) - if len(_affinity_map) > _AFFINITY_MAX_ENTRIES: - now = time.monotonic() - for k in [k for k, v in _affinity_map.items() if v[2] < now]: - _affinity_map.pop(k, None) - return selected, tracking_model +from routing import get_max_connections, choose_endpoint # ------------------------------------------------------------- # 6. API route – Generate diff --git a/routing.py b/routing.py new file mode 100644 index 0000000..059cf9a --- /dev/null +++ b/routing.py @@ -0,0 +1,294 @@ +"""Endpoint selection (load-balancing + conversation affinity). + +``choose_endpoint`` is the heart of routing — it picks an endpoint that +advertises the model, prefers ones with the model already loaded and a free +slot, applies the conversation-affinity hint when available, and honors +config-order priority routing when ``priority_routing`` is set. 
+ +``increment_usage`` / ``decrement_usage`` keep the per-(endpoint, model) +counter that drives utilization-based selection; they fan out an SSE +snapshot on every change. +""" +import asyncio +import random +import time +from typing import Optional + +from config import get_config +from state import ( + usage_counts, + usage_lock, + _loaded_error_cache, + _loaded_error_cache_lock, + _completion_error_cache, + _completion_error_cache_lock, + _COMPLETION_ERROR_TTL, + _affinity_map, + _affinity_lock, + _AFFINITY_MAX_ENTRIES, +) +from sse import _capture_snapshot, _distribute_snapshot +from backends.health import _is_fresh +from backends.normalize import ( + is_ext_openai_endpoint, + is_openai_compatible, + get_tracking_model, +) +from backends.probe import fetch + + +async def increment_usage(endpoint: str, model: str) -> None: + async with usage_lock: + usage_counts[endpoint][model] += 1 + snapshot = _capture_snapshot() + await _distribute_snapshot(snapshot) + + +async def decrement_usage(endpoint: str, model: str) -> None: + async with usage_lock: + # Avoid negative counts + current = usage_counts[endpoint].get(model, 0) + if current > 0: + usage_counts[endpoint][model] = current - 1 + # Optionally, clean up zero entries + if usage_counts[endpoint].get(model, 0) == 0: + usage_counts[endpoint].pop(model, None) + #if not usage_counts[endpoint]: + # usage_counts.pop(endpoint, None) + snapshot = _capture_snapshot() + await _distribute_snapshot(snapshot) + + +def get_max_connections(ep: str) -> int: + """Per-endpoint max_concurrent_connections, falling back to the global value.""" + cfg = get_config() + return cfg.endpoint_config.get(ep, {}).get( + "max_concurrent_connections", cfg.max_concurrent_connections + ) + + +async def choose_endpoint(model: str, reserve: bool = True, + affinity_key: Optional[str] = None) -> tuple[str, str]: + """ + Determine which endpoint to use for the given model while respecting + the `max_concurrent_connections` per endpoint‑model pair 
**and** + ensuring that the chosen endpoint actually *advertises* the model. + + The selection algorithm: + + 1️⃣ Query every endpoint for its advertised models (`/api/tags`). + 2️⃣ Build a list of endpoints that contain the requested model. + 2️⃣.5 If conversation affinity is enabled and the caller passes + ``affinity_key``, prefer the endpoint that previously served the + same conversation — but only when it still has the model loaded + and a free slot. Otherwise fall through to the standard logic. + 3️⃣ For those endpoints, find those that have the model loaded + (`/api/ps`) *and* still have a free slot. + 4️⃣ If none are both loaded and free, fall back to any endpoint + from the filtered list that simply has a free slot and randomly + select one. + 5️⃣ If all are saturated, pick any endpoint from the filtered list + (the request will queue on that endpoint). + 6️⃣ If no endpoint advertises the model at all, raise an error. + """ + config = get_config() + # 1️⃣ Gather advertised‑model sets for all endpoints concurrently + # Include both config.endpoints and config.llama_server_endpoints + llama_eps_extra = [ep for ep in config.llama_server_endpoints if ep not in config.endpoints] + all_endpoints = config.endpoints + llama_eps_extra + + tag_tasks = [fetch.available_models(ep) for ep in config.endpoints if not is_openai_compatible(ep)] + tag_tasks += [fetch.available_models(ep, config.api_keys.get(ep)) for ep in config.endpoints if is_openai_compatible(ep)] + tag_tasks += [fetch.available_models(ep, config.api_keys.get(ep)) for ep in llama_eps_extra] + advertised_sets = await asyncio.gather(*tag_tasks) + + # 2️⃣ Filter endpoints that advertise the requested model + candidate_endpoints = [ + ep for ep, models in zip(all_endpoints, advertised_sets) + if model in models + ] + + # 6️⃣ + if not candidate_endpoints: + if ":latest" in model: #ollama naming convention not applicable to openai/llama-server + model_without_latest = model.split(":latest")[0] + 
candidate_endpoints = [ + ep for ep, models in zip(all_endpoints, advertised_sets) + if model_without_latest in models and (is_ext_openai_endpoint(ep) or ep in config.llama_server_endpoints) + ] + if not candidate_endpoints: + # Only add :latest suffix if model doesn't already have a version suffix + if ":" not in model: + model = model + ":latest" + candidate_endpoints = [ + ep for ep, models in zip(all_endpoints, advertised_sets) + if model in models + ] + if not candidate_endpoints: + raise RuntimeError( + f"None of the configured endpoints ({', '.join(all_endpoints)}) " + f"advertise the model '{model}'." + ) + # 3️⃣ Among the candidates, find those that have the model *loaded* + # (concurrently, but only for the filtered list) + load_tasks = [fetch.loaded_models(ep) for ep in candidate_endpoints] + loaded_sets = await asyncio.gather(*load_tasks) + + # 3️⃣.5 Exclude endpoints whose loaded-model probe has been failing + # recently. Without this filter, an endpoint where `/api/ps` returns 5xx + # would appear with an empty loaded set but pass through to the + # free-slot fallback (step 4) — sending completion calls to an + # unhealthy backend. See issue #83. + async with _loaded_error_cache_lock: + unhealthy = { + ep for ep, ts in _loaded_error_cache.items() + if _is_fresh(ts, 300) + } + if unhealthy: + filtered = [ + (ep, models) for ep, models in zip(candidate_endpoints, loaded_sets) + if ep not in unhealthy + ] + if filtered: + candidate_endpoints = [ep for ep, _ in filtered] + loaded_sets = [models for _, models in filtered] + # If *every* candidate is unhealthy we still fall through with the + # original list — refusing to route is worse than retrying a + # possibly-recovered backend. + + # 3️⃣.6 Exclude (endpoint, model) pairs whose completion path has recently + # failed with a backend connection error (e.g. llama-server in router mode + # whose delegated worker for *this* model died). 
/v1/models keeps reporting + # OK in that case, so the probe-level filter above cannot catch it. + async with _completion_error_cache_lock: + completion_broken = { + ep for (ep, m), ts in _completion_error_cache.items() + if m == model and _is_fresh(ts, _COMPLETION_ERROR_TTL) + } + if completion_broken: + filtered = [ + (ep, models) for ep, models in zip(candidate_endpoints, loaded_sets) + if ep not in completion_broken + ] + if filtered: + candidate_endpoints = [ep for ep, _ in filtered] + loaded_sets = [models for _, models in filtered] + # Same fallback: if every candidate is broken for this model, fall + # through and let the upstream retry — possibly the operator restarted + # the dead worker. + + # Look up a possible affinity hint *before* taking usage_lock. The two + # locks are never held together to avoid lock-ordering issues. + affine_ep: Optional[str] = None + if config.conversation_affinity and affinity_key: + async with _affinity_lock: + entry = _affinity_map.get(affinity_key) + if entry is not None: + ep, _stored_model, expires_at = entry + if expires_at < time.monotonic(): + _affinity_map.pop(affinity_key, None) + else: + affine_ep = ep + + # Protect all reads/writes of usage_counts with the lock so that selection + # and reservation are atomic — concurrent callers see each other's pending load. + async with usage_lock: + # Helper: current usage for (endpoint, model) using the same normalized key + # that increment_usage/decrement_usage store — raw model names differ from + # tracking names for llama-server (HF prefix / quant suffix stripped). 
+ def tracking_usage(ep: str) -> int: + return usage_counts.get(ep, {}).get(get_tracking_model(ep, model), 0) + + def utilization_ratio(ep: str) -> float: + return tracking_usage(ep) / get_max_connections(ep) + + # Priority map: position in all_endpoints list (lower = higher priority) + ep_priority = {ep: i for i, ep in enumerate(all_endpoints)} + + selected: Optional[str] = None + + # 2️⃣.5 Conversation affinity preference — only honour the hint when + # the affine endpoint still advertises the model loaded *and* has a + # free slot. Otherwise fall back to the standard algorithm. + if affine_ep: + ep_loaded = { + ep: set(models) + for ep, models in zip(candidate_endpoints, loaded_sets) + } + if (affine_ep in candidate_endpoints + and model in ep_loaded.get(affine_ep, set()) + and tracking_usage(affine_ep) < get_max_connections(affine_ep)): + selected = affine_ep + + if selected is None: + # 3️⃣ Endpoints that have the model loaded *and* a free slot + loaded_and_free = [ + ep for ep, models in zip(candidate_endpoints, loaded_sets) + if model in models and tracking_usage(ep) < get_max_connections(ep) + ] + + if loaded_and_free: + if config.priority_routing: + # WRR: sort by config order first (stable), then by utilization ratio. + # Stable sort preserves priority for equal-ratio endpoints. + loaded_and_free.sort(key=lambda ep: ep_priority.get(ep, 999)) + loaded_and_free.sort(key=utilization_ratio) + selected = loaded_and_free[0] + else: + # Sort ascending for load balancing — all endpoints here already have the + # model loaded, so there is no model-switching cost to optimise for. + loaded_and_free.sort(key=tracking_usage) + # When all candidates are equally idle, randomise to avoid always picking + # the first entry in a stable sort. 
+ if all(tracking_usage(ep) == 0 for ep in loaded_and_free): + selected = random.choice(loaded_and_free) + else: + selected = loaded_and_free[0] + else: + # 4️⃣ Endpoints among the candidates that simply have a free slot + endpoints_with_free_slot = [ + ep for ep in candidate_endpoints + if tracking_usage(ep) < get_max_connections(ep) + ] + + if endpoints_with_free_slot: + if config.priority_routing: + endpoints_with_free_slot.sort(key=lambda ep: ep_priority.get(ep, 999)) + endpoints_with_free_slot.sort(key=utilization_ratio) + selected = endpoints_with_free_slot[0] + else: + # Sort by total endpoint load (ascending) to prefer idle endpoints. + endpoints_with_free_slot.sort( + key=lambda ep: sum(usage_counts.get(ep, {}).values()) + ) + if all(tracking_usage(ep) == 0 for ep in endpoints_with_free_slot): + selected = random.choice(endpoints_with_free_slot) + else: + selected = endpoints_with_free_slot[0] + else: + # 5️⃣ All candidate endpoints are saturated – pick the least-busy one (will queue) + if config.priority_routing: + selected = min( + candidate_endpoints, + key=lambda ep: (utilization_ratio(ep), ep_priority.get(ep, 999)), + ) + else: + selected = min(candidate_endpoints, key=tracking_usage) + + tracking_model = get_tracking_model(selected, model) + snapshot = None + if reserve: + usage_counts[selected][tracking_model] += 1 + snapshot = _capture_snapshot() + if snapshot is not None: + await _distribute_snapshot(snapshot) + # Record / refresh affinity *after* releasing usage_lock. 
# ---------------------------------------------------------------------------
# sse.py — server-sent-events fan-out
# ---------------------------------------------------------------------------

def _capture_snapshot() -> str:
    """Serialize the current usage counters to a JSON string.

    The payload has two top-level keys, ``usage_counts`` and
    ``token_usage_counts``, and is emitted with sorted keys.

    Caller must hold at least one of usage_lock/token_usage_lock.
    """
    payload = {
        "usage_counts": dict(usage_counts),
        "token_usage_counts": dict(token_usage_counts),
    }
    return orjson.dumps(payload, option=orjson.OPT_SORT_KEYS).decode("utf-8")


def _offer(q: asyncio.Queue, item) -> None:
    """Enqueue *item* on *q* without ever waiting.

    If the queue is full, the oldest entry is discarded first so the newest
    item always fits.  Only the ``*_nowait`` queue methods are used, so this
    never suspends the caller.
    """
    if q.full():
        try:
            q.get_nowait()
        except asyncio.QueueEmpty:
            # Defensive only: a full bounded queue cannot be empty.
            pass
    q.put_nowait(item)


async def _distribute_snapshot(snapshot: str) -> None:
    """Push a pre-captured snapshot to all SSE subscribers.

    Must be called outside any usage lock.  Slow subscribers lose their
    oldest pending snapshot instead of stalling the fan-out.
    """
    async with _subscribers_lock:
        for q in _subscribers:
            # Non-blocking on purpose: nothing may wait while the
            # subscribers lock is held.
            _offer(q, snapshot)


async def close_all_sse_queues():
    """Send the ``None`` shutdown sentinel to every subscriber queue.

    The sentinel is the value the SSE generator recognises as "stop".  It is
    enqueued without waiting: a full queue drops its oldest snapshot to make
    room, so a stalled or vanished consumer can no longer hang shutdown.
    """
    # Iterate over a copy so the set may change while we walk it.
    for q in list(_subscribers):
        _offer(q, None)


async def subscribe() -> asyncio.Queue:
    """Register and return a new bounded queue that receives every snapshot."""
    q: asyncio.Queue = asyncio.Queue(maxsize=10)
    async with _subscribers_lock:
        _subscribers.add(q)
    return q


async def unsubscribe(q: asyncio.Queue):
    """Remove *q* from the subscriber set; unknown queues are ignored."""
    async with _subscribers_lock:
        _subscribers.discard(q)


async def get_usage_counts() -> Dict:
    """Return a shallow copy of ``usage_counts`` (inner values are shared)."""
    return dict(usage_counts)


# ---------------------------------------------------------------------------
# tokens.py — token-count write-behind pipeline
# ---------------------------------------------------------------------------

def _minute_bucket_timestamp() -> int:
    """Return the current UTC time truncated to the minute, as a Unix timestamp."""
    now = datetime.now(tz=timezone.utc)
    return int(now.replace(second=0, microsecond=0).timestamp())


async def _record_token_usage(endpoint, model, prompt, comp) -> None:
    """Fold one queue item into the in-memory buffers and notify subscribers.

    *prompt* and *comp* are the input and output token counts of one request.

    NOTE(review): the source's indentation was lost, so the lock scopes below
    are reconstructed from its comments (``buffer_lock`` first, then
    ``token_usage_lock``, snapshot distributed outside both) — confirm.
    """
    # Compute the timestamp before taking any lock.
    timestamp = _minute_bucket_timestamp()

    async with buffer_lock:
        prev_prompt, prev_comp = token_buffer[endpoint].get(model, (0, 0))
        token_buffer[endpoint][model] = (prev_prompt + prompt, prev_comp + comp)
        time_series_buffer.append({
            'endpoint': endpoint,
            'model': model,
            'input_tokens': prompt,
            'output_tokens': comp,
            'total_tokens': prompt + comp,
            'timestamp': timestamp
        })

    # Update the live counters used for immediate reporting.
    async with token_usage_lock:
        token_usage_counts[endpoint][model] += (prompt + comp)
        snapshot = _capture_snapshot()
    await _distribute_snapshot(snapshot)


async def token_worker() -> None:
    """Drain ``token_queue`` forever, recording each item as it arrives.

    On cancellation the items still waiting in the queue are processed before
    ``asyncio.CancelledError`` is re-raised.
    """
    try:
        while True:
            endpoint, model, prompt, comp = await token_queue.get()
            await _record_token_usage(endpoint, model, prompt, comp)
    except asyncio.CancelledError:
        # Gracefully handle task cancellation during shutdown.
        print("[token_worker] Task cancelled, processing remaining queue items...")
        while True:
            try:
                endpoint, model, prompt, comp = token_queue.get_nowait()
            except asyncio.QueueEmpty:
                break
            await _record_token_usage(endpoint, model, prompt, comp)
        print("[token_worker] Task cancelled, remaining items processed.")
        raise


async def _drain_buffers():
    """Copy and clear both in-memory buffers under ``buffer_lock``.

    Returns ``(counts, series)``: a per-endpoint dict of per-model tuples and
    a list of time-series entries.  Either may be empty.
    """
    async with buffer_lock:
        counts = {ep: dict(models) for ep, models in token_buffer.items()}
        token_buffer.clear()
        series = list(time_series_buffer)
        time_series_buffer.clear()
    return counts, series


async def _persist(counts, series) -> None:
    """Write drained buffers to the database; empty inputs are skipped."""
    db = get_db()
    if counts:
        await db.update_batched_counts(counts)
    if series:
        await db.add_batched_time_series(series)


async def flush_buffer() -> None:
    """Periodically flush accumulated token counts to the database.

    Every ``FLUSH_INTERVAL`` seconds the buffers are drained under the lock
    and written outside it.  A failed write is logged and the loop keeps
    running, so one database error no longer stops all future flushes.  On
    cancellation a final flush is attempted before ``CancelledError`` is
    re-raised.
    """
    try:
        while True:
            await asyncio.sleep(FLUSH_INTERVAL)
            counts, series = await _drain_buffers()
            # DB work happens outside the lock to avoid blocking writers.
            try:
                await _persist(counts, series)
            except asyncio.CancelledError:
                raise
            except Exception as e:
                # The drained batch is dropped (as it was before this fix),
                # but the flusher itself stays alive.
                print(f"[flush_buffer] Error during periodic flush: {e}")
    except asyncio.CancelledError:
        # Gracefully handle task cancellation during shutdown.
        print("[flush_buffer] Task cancelled, flushing remaining buffers...")
        try:
            counts, series = await _drain_buffers()
            await _persist(counts, series)
            print("[flush_buffer] Task cancelled, remaining buffers flushed.")
        except Exception as e:
            print(f"[flush_buffer] Error during shutdown flush: {e}")
        raise


async def flush_remaining_buffers() -> None:
    """
    Flush any in-memory buffers to the database on shutdown.
    This is designed to be safely invoked during shutdown and should not raise.
    """
    try:
        counts, series = await _drain_buffers()
        flushed_entries = sum(len(models) for models in counts.values()) + len(series)
        # DB work happens outside the lock.
        await _persist(counts, series)
        if flushed_entries:
            print(f"[shutdown] Flushed {flushed_entries} in-memory entries to DB on shutdown.")
        else:
            print("[shutdown] No in-memory entries to flush on shutdown.")
    except Exception as e:
        # Do not raise during shutdown – log and continue teardown.
        print(f"[shutdown] Error flushing remaining buffers: {e}")