From c88ba1e5a4d4f1cbf6a827d253bcd4642e21a1a8 Mon Sep 17 00:00:00 2001
From: alpha nerd
Date: Tue, 19 May 2026 11:18:06 +0200
Subject: [PATCH] refac: modularize global states III

---
Reviewer notes (everything between the three-dash line and the first
"diff" line is commentary; it is not part of the commit):

 * Purpose: moves router.py's module-level caches, cache locks, in-flight
   task maps, subscriber set and token queue, app_state, token and
   time-series buffers, usage counters and the conversation-affinity map
   into a new module, state.py. router.py re-imports every moved name
   through two "from state import (...)" blocks: one after the
   context_window import, one in place of the old counters/affinity
   section.
 * token_worker_task and flush_task deliberately stay in router.py; the
   added comment states that they are rebound on startup.
 * "from state import X" copies the current binding into router's own
   namespace, so the move is transparent only while no code reassigns a
   moved name.
   NOTE(review): confirm router.py contains no "global NAME" followed by an
   assignment for any moved name. token_buffer and time_series_buffer are
   the first ones to check; their users are not visible in this patch.
 * _COMPLETION_ERROR_TTL, FLUSH_INTERVAL and _AFFINITY_MAX_ENTRIES are plain
   ints, so they fall outside the "only ever mutated" wording of the
   state.py docstring; importing them by name is fine as long as they are
   never reassigned at runtime.
 * NOTE(review): this copy was re-flowed from a whitespace-collapsed source.
   Line counts were re-checked against every hunk header and the diffstat,
   but leading indentation, spacing before inline comments and the exact
   placement of blank context lines are reconstructed. Compare against the
   original patch (or apply with whitespace tolerance) before relying on it.

 router.py | 121 +++++++++++++++++------------------------------------
 state.py  | 100 ++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 137 insertions(+), 84 deletions(-)
 create mode 100644 state.py

diff --git a/router.py b/router.py
index 825bafc..ecbad68 100644
--- a/router.py
+++ b/router.py
@@ -28,51 +28,6 @@ from pydantic_settings import BaseSettings
 from collections import defaultdict
 from PIL import Image
 
-# ------------------------------------------------------------------
-# In‑memory caches
-# ------------------------------------------------------------------
-# Successful results are cached for 300s
-_models_cache: dict[str, tuple[Set[str], float]] = {}
-_loaded_models_cache: dict[str, tuple[Set[str], float]] = {}
-# Transient errors are cached separately per concern so that a failure
-# in one path does not poison the other.
-_available_error_cache: dict[str, float] = {}
-_loaded_error_cache: dict[str, float] = {}
-# Per-(endpoint, model) completion-path failures. A llama-server in router
-# mode can keep returning /v1/models 200 OK after its delegated worker for
-# a specific model dies — the probe-level caches above will not catch this.
-# We record signals observed during actual completion attempts so
-# choose_endpoint can avoid the affected (endpoint, model) pair without
-# poisoning unrelated models on the same backend.
-_completion_error_cache: dict[tuple[str, str], float] = {}
-_COMPLETION_ERROR_TTL = 300
-
-# ------------------------------------------------------------------
-# Cache locks
-# ------------------------------------------------------------------
-_models_cache_lock = asyncio.Lock()
-_loaded_models_cache_lock = asyncio.Lock()
-_available_error_cache_lock = asyncio.Lock()
-_loaded_error_cache_lock = asyncio.Lock()
-_completion_error_cache_lock = asyncio.Lock()
-
-# ------------------------------------------------------------------
-# In-flight request tracking (prevents cache stampede)
-# ------------------------------------------------------------------
-_inflight_available_models: dict[str, asyncio.Task] = {}
-_inflight_loaded_models: dict[str, asyncio.Task] = {}
-_inflight_lock = asyncio.Lock()
-_bg_refresh_available: dict[str, asyncio.Task] = {}
-_bg_refresh_loaded: dict[str, asyncio.Task] = {}
-_bg_refresh_lock = asyncio.Lock()
-
-# ------------------------------------------------------------------
-# Queues
-# ------------------------------------------------------------------
-_subscribers: Set[asyncio.Queue] = set()
-_subscribers_lock = asyncio.Lock()
-token_queue: asyncio.Queue[tuple[str, str, int, int]] = asyncio.Queue()
-
 from security import _mask_secrets
 from context_window import (
     _count_message_tokens,
@@ -81,32 +36,38 @@ from context_window import (
     _endpoint_nctx,
     _CTX_TRIM_SMALL_LIMIT,
 )
+from state import (
+    _models_cache,
+    _loaded_models_cache,
+    _available_error_cache,
+    _loaded_error_cache,
+    _completion_error_cache,
+    _COMPLETION_ERROR_TTL,
+    _models_cache_lock,
+    _loaded_models_cache_lock,
+    _available_error_cache_lock,
+    _loaded_error_cache_lock,
+    _completion_error_cache_lock,
+    _inflight_available_models,
+    _inflight_loaded_models,
+    _inflight_lock,
+    _bg_refresh_available,
+    _bg_refresh_loaded,
+    _bg_refresh_lock,
+    _subscribers,
+    _subscribers_lock,
+    token_queue,
+    app_state,
+    token_buffer,
+    time_series_buffer,
+    buffer_lock,
+    FLUSH_INTERVAL,
+)
 
-# ------------------------------------------------------------------
-# Globals
-# ------------------------------------------------------------------
-app_state = {
-    "session": None,
-    "connector": None,
-    "socket_sessions": {}, # endpoint -> aiohttp.ClientSession(UnixConnector) for .sock endpoints
-    "httpx_clients": {}, # endpoint -> httpx.AsyncClient(UDS transport) for .sock endpoints
-}
+# Rebound on startup — must stay in router.py module namespace.
 token_worker_task: asyncio.Task | None = None
 flush_task: asyncio.Task | None = None
 
-# ------------------------------------------------------------------
-# Token Count Buffer (for write-behind pattern)
-# ------------------------------------------------------------------
-# Structure: {endpoint: {model: (input_tokens, output_tokens)}}
-token_buffer: dict[str, dict[str, tuple[int, int]]] = defaultdict(lambda: defaultdict(lambda: (0, 0)))
-# Time series buffer with timestamp
-time_series_buffer: list[dict[str, int | str]] = []
-# Lock to protect buffer access from race conditions
-buffer_lock = asyncio.Lock()
-
-# Configuration for periodic flushing
-FLUSH_INTERVAL = 10 # seconds
-
 from config import Config, _config_path_from_env
 from ollama._types import TokenLogprob, Logprob
 
@@ -228,23 +189,15 @@ async def _openai_api_status_error_handler(request: Request, exc: openai.APIStat
     return JSONResponse(status_code=exc.status_code, content=body)
 
 
-# -------------------------------------------------------------
-# 3. Global state: per‑endpoint per‑model active connection counters
-# -------------------------------------------------------------
-usage_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
-token_usage_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
-usage_lock = asyncio.Lock() # protects access to usage_counts
-token_usage_lock = asyncio.Lock()
-
-# Conversation affinity map: fingerprint -> (endpoint, model, expires_at_monotonic).
-# Keeps the same conversation pinned to the endpoint that already has its
-# KV-cache prefix warm. Model is stored so the dashboard can aggregate live
-# entries per (endpoint, model) without recomputing fingerprints.
-# Never held together with usage_lock.
-_affinity_map: Dict[str, tuple[str, str, float]] = {}
-_affinity_lock = asyncio.Lock()
-_AFFINITY_MAX_ENTRIES = 10000
-
+from state import (
+    usage_counts,
+    token_usage_counts,
+    usage_lock,
+    token_usage_lock,
+    _affinity_map,
+    _affinity_lock,
+    _AFFINITY_MAX_ENTRIES,
+)
 from fingerprint import _conversation_fingerprint
 
 
diff --git a/state.py b/state.py
new file mode 100644
index 0000000..9f2b3cd
--- /dev/null
+++ b/state.py
@@ -0,0 +1,100 @@
+"""Shared mutable router state.
+
+All process-wide caches, locks, in-flight task maps, queues, counters and
+buffers used by the router live here. These names are only ever *mutated*
+(dict/set updates, lock acquisitions, queue put/get) — never rebound — so
+importing them via ``from state import …`` is safe from every module.
+
+Rebound singletons (``config``, ``db``, ``token_worker_task``,
+``flush_task``) intentionally stay in router.py so their reassignment on
+startup is visible to all callers.
+"""
+import asyncio
+from collections import defaultdict
+from typing import Dict, Set
+
+
+# ------------------------------------------------------------------
+# In‑memory caches
+# ------------------------------------------------------------------
+# Successful results are cached for 300s
+_models_cache: dict[str, tuple[Set[str], float]] = {}
+_loaded_models_cache: dict[str, tuple[Set[str], float]] = {}
+# Transient errors are cached separately per concern so that a failure
+# in one path does not poison the other.
+_available_error_cache: dict[str, float] = {}
+_loaded_error_cache: dict[str, float] = {}
+# Per-(endpoint, model) completion-path failures. A llama-server in router
+# mode can keep returning /v1/models 200 OK after its delegated worker for
+# a specific model dies — the probe-level caches above will not catch this.
+# We record signals observed during actual completion attempts so
+# choose_endpoint can avoid the affected (endpoint, model) pair without
+# poisoning unrelated models on the same backend.
+_completion_error_cache: dict[tuple[str, str], float] = {}
+_COMPLETION_ERROR_TTL = 300
+
+# ------------------------------------------------------------------
+# Cache locks
+# ------------------------------------------------------------------
+_models_cache_lock = asyncio.Lock()
+_loaded_models_cache_lock = asyncio.Lock()
+_available_error_cache_lock = asyncio.Lock()
+_loaded_error_cache_lock = asyncio.Lock()
+_completion_error_cache_lock = asyncio.Lock()
+
+# ------------------------------------------------------------------
+# In-flight request tracking (prevents cache stampede)
+# ------------------------------------------------------------------
+_inflight_available_models: dict[str, asyncio.Task] = {}
+_inflight_loaded_models: dict[str, asyncio.Task] = {}
+_inflight_lock = asyncio.Lock()
+_bg_refresh_available: dict[str, asyncio.Task] = {}
+_bg_refresh_loaded: dict[str, asyncio.Task] = {}
+_bg_refresh_lock = asyncio.Lock()
+
+# ------------------------------------------------------------------
+# Queues
+# ------------------------------------------------------------------
+_subscribers: Set[asyncio.Queue] = set()
+_subscribers_lock = asyncio.Lock()
+token_queue: asyncio.Queue[tuple[str, str, int, int]] = asyncio.Queue()
+
+# ------------------------------------------------------------------
+# HTTP client / connector cache
+# ------------------------------------------------------------------
+app_state = {
+    "session": None,
+    "connector": None,
+    "socket_sessions": {}, # endpoint -> aiohttp.ClientSession(UnixConnector) for .sock endpoints
+    "httpx_clients": {}, # endpoint -> httpx.AsyncClient(UDS transport) for .sock endpoints
+}
+
+# ------------------------------------------------------------------
+# Token Count Buffer (for write-behind pattern)
+# ------------------------------------------------------------------
+# Structure: {endpoint: {model: (input_tokens, output_tokens)}}
+token_buffer: dict[str, dict[str, tuple[int, int]]] = defaultdict(lambda: defaultdict(lambda: (0, 0)))
+# Time series buffer with timestamp
+time_series_buffer: list[dict[str, int | str]] = []
+# Lock to protect buffer access from race conditions
+buffer_lock = asyncio.Lock()
+
+# Configuration for periodic flushing
+FLUSH_INTERVAL = 10 # seconds
+
+# ------------------------------------------------------------------
+# Per‑endpoint per‑model active connection counters
+# ------------------------------------------------------------------
+usage_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
+token_usage_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
+usage_lock = asyncio.Lock() # protects access to usage_counts
+token_usage_lock = asyncio.Lock()
+
+# Conversation affinity map: fingerprint -> (endpoint, model, expires_at_monotonic).
+# Keeps the same conversation pinned to the endpoint that already has its
+# KV-cache prefix warm. Model is stored so the dashboard can aggregate live
+# entries per (endpoint, model) without recomputing fingerprints.
+# Never held together with usage_lock.
+_affinity_map: Dict[str, tuple[str, str, float]] = {}
+_affinity_lock = asyncio.Lock()
+_AFFINITY_MAX_ENTRIES = 10000