From 27dfc07889970f2446e158d70f4e1b0c33023be8 Mon Sep 17 00:00:00 2001 From: alpha nerd Date: Tue, 12 May 2026 18:33:47 +0200 Subject: [PATCH] feat: add conversation-endpoint affinity to benefit from hot kv-caches if possible --- config.yaml | 10 +++ router.py | 199 ++++++++++++++++++++++++++++++++++++++-------------- 2 files changed, 157 insertions(+), 52 deletions(-) diff --git a/config.yaml b/config.yaml index 76fbbe1..6fc57e6 100644 --- a/config.yaml +++ b/config.yaml @@ -26,6 +26,16 @@ max_concurrent_connections: 2 # When false (default), equally-idle endpoints are chosen at random. # priority_routing: true +# Conversation affinity (optional, default: false). +# Routes follow-up requests back to the endpoint that previously served the +# same conversation so the llama.cpp / Ollama prompt cache (KV cache) stays +# warm — first turn does a cold prefill, follow-ups skip it. Soft preference: +# falls back to the standard algorithm when the affine endpoint no longer has +# the model loaded or has no free slot. Conversations are fingerprinted by +# (model, first system + first user turn). +# conversation_affinity: true +# conversation_affinity_ttl: 300 # seconds; matches Ollama's default keep_alive + # Optional router-level API key that gates router/API/web UI access (leave empty to disable) nomyo-router-api-key: "" diff --git a/router.py b/router.py index 603387e..34e02da 100644 --- a/router.py +++ b/router.py @@ -6,7 +6,7 @@ version: 0.7 license: AGPL """ # ------------------------------------------------------------- -import orjson, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, random, base64, io, enhance, secrets, math, socket, httpx +import orjson, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, random, base64, io, enhance, secrets, math, socket, httpx, hashlib try: import truststore; truststore.inject_into_ssl() except ImportError: @@ -223,6 +223,15 @@ class Config(BaseSettings): # When True, config order = priority; routes by utilization ratio + config index (WRR) priority_routing: bool = Field(default=False) + # Conversation affinity: route the same conversation back to the endpoint that + # previously served it, to keep the llama.cpp / Ollama prompt cache (KV cache) warm. + # Soft preference — falls back to the standard algorithm when the affine endpoint + # is saturated or no longer has the model loaded. + conversation_affinity: bool = Field(default=False) + # TTL (seconds) for affinity entries. Defaults to Ollama's default keep_alive (5 min): + # if the backend has already evicted the model, the KV cache is cold anyway. + conversation_affinity_ttl: int = Field(default=300) + api_keys: Dict[str, str] = Field(default_factory=dict) # Optional router-level API key used to gate access to this service and dashboard router_api_key: Optional[str] = Field(default=None, env="NOMYO_ROUTER_API_KEY") @@ -436,6 +445,45 @@ token_usage_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict( usage_lock = asyncio.Lock() # protects access to usage_counts token_usage_lock = asyncio.Lock() +# Conversation affinity map: fingerprint -> (endpoint, expires_at_monotonic). +# Keeps the same conversation pinned to the endpoint that already has its +# KV-cache prefix warm. Never held together with usage_lock. +_affinity_map: Dict[str, tuple[str, float]] = {} +_affinity_lock = asyncio.Lock() +_AFFINITY_MAX_ENTRIES = 10000 + + +def _conversation_fingerprint(model: str, messages: Optional[list], + prompt: Optional[str]) -> Optional[str]: + """ + Stable hash over (model, first system + first user turn). That prefix + determines whether the backend's prompt cache is reusable; later turns + don't influence the routing decision because they extend the same prefix. + Returns None when there is no usable prefix. + """ + parts: list[str] = [model or "_"] + if messages: + for m in messages: + role = m.get("role") if isinstance(m, dict) else None + if role not in ("system", "user"): + continue + content = m.get("content") + if isinstance(content, list): # OpenAI multimodal parts + content = "".join( + p.get("text", "") for p in content + if isinstance(p, dict) and p.get("type") == "text" + ) + if not isinstance(content, str): + continue + parts.append(f"{role}:{content}") + if role == "user": + break + elif prompt: + parts.append(f"user:{prompt}") + else: + return None + return hashlib.sha1("\x1f".join(parts).encode("utf-8", "replace")).hexdigest() + # Database instance db: "TokenDatabase" = None @@ -1738,7 +1786,8 @@ def get_max_connections(ep: str) -> int: "max_concurrent_connections", config.max_concurrent_connections ) -async def choose_endpoint(model: str, reserve: bool = True) -> tuple[str, str]: +async def choose_endpoint(model: str, reserve: bool = True, + affinity_key: Optional[str] = None) -> tuple[str, str]: """ Determine which endpoint to use for the given model while respecting the `max_concurrent_connections` per endpoint‑model pair **and** @@ -1748,10 +1797,14 @@ async def choose_endpoint(model: str, reserve: bool = True) -> tuple[str, str]: 1️⃣ Query every endpoint for its advertised models (`/api/tags`). 2️⃣ Build a list of endpoints that contain the requested model. + 2️⃣.5 If conversation affinity is enabled and the caller passes + ``affinity_key``, prefer the endpoint that previously served the + same conversation — but only when it still has the model loaded + and a free slot. Otherwise fall through to the standard logic. 3️⃣ For those endpoints, find those that have the model loaded (`/api/ps`) *and* still have a free slot. 4️⃣ If none are both loaded and free, fall back to any endpoint - from the filtered list that simply has a free slot and randomly + from the filtered list that simply has a free slot and randomly select one. 5️⃣ If all are saturated, pick any endpoint from the filtered list (the request will queue on that endpoint). @@ -1799,6 +1852,19 @@ async def choose_endpoint(model: str, reserve: bool = True) -> tuple[str, str]: load_tasks = [fetch.loaded_models(ep) for ep in candidate_endpoints] loaded_sets = await asyncio.gather(*load_tasks) + # Look up a possible affinity hint *before* taking usage_lock. The two + # locks are never held together to avoid lock-ordering issues. + affine_ep: Optional[str] = None + if config.conversation_affinity and affinity_key: + async with _affinity_lock: + entry = _affinity_map.get(affinity_key) + if entry is not None: + ep, expires_at = entry + if expires_at < time.monotonic(): + _affinity_map.pop(affinity_key, None) + else: + affine_ep = ep + # Protect all reads/writes of usage_counts with the lock so that selection # and reservation are atomic — concurrent callers see each other's pending load. async with usage_lock: @@ -1814,59 +1880,75 @@ async def choose_endpoint(model: str, reserve: bool = True) -> tuple[str, str]: # Priority map: position in all_endpoints list (lower = higher priority) ep_priority = {ep: i for i, ep in enumerate(all_endpoints)} - # 3️⃣ Endpoints that have the model loaded *and* a free slot - loaded_and_free = [ - ep for ep, models in zip(candidate_endpoints, loaded_sets) - if model in models and tracking_usage(ep) < get_max_connections(ep) - ] + selected: Optional[str] = None - if loaded_and_free: - if config.priority_routing: - # WRR: sort by config order first (stable), then by utilization ratio. - # Stable sort preserves priority for equal-ratio endpoints. - loaded_and_free.sort(key=lambda ep: ep_priority.get(ep, 999)) - loaded_and_free.sort(key=utilization_ratio) - selected = loaded_and_free[0] - else: - # Sort ascending for load balancing — all endpoints here already have the - # model loaded, so there is no model-switching cost to optimise for. - loaded_and_free.sort(key=tracking_usage) - # When all candidates are equally idle, randomise to avoid always picking - # the first entry in a stable sort. - if all(tracking_usage(ep) == 0 for ep in loaded_and_free): - selected = random.choice(loaded_and_free) - else: - selected = loaded_and_free[0] - else: - # 4️⃣ Endpoints among the candidates that simply have a free slot - endpoints_with_free_slot = [ - ep for ep in candidate_endpoints - if tracking_usage(ep) < get_max_connections(ep) + # 2️⃣.5 Conversation affinity preference — only honour the hint when + # the affine endpoint still advertises the model loaded *and* has a + # free slot. Otherwise fall back to the standard algorithm. + if affine_ep: + ep_loaded = { + ep: set(models) + for ep, models in zip(candidate_endpoints, loaded_sets) + } + if (affine_ep in candidate_endpoints + and model in ep_loaded.get(affine_ep, set()) + and tracking_usage(affine_ep) < get_max_connections(affine_ep)): + selected = affine_ep + + if selected is None: + # 3️⃣ Endpoints that have the model loaded *and* a free slot + loaded_and_free = [ + ep for ep, models in zip(candidate_endpoints, loaded_sets) + if model in models and tracking_usage(ep) < get_max_connections(ep) ] - if endpoints_with_free_slot: + if loaded_and_free: if config.priority_routing: - endpoints_with_free_slot.sort(key=lambda ep: ep_priority.get(ep, 999)) - endpoints_with_free_slot.sort(key=utilization_ratio) - selected = endpoints_with_free_slot[0] + # WRR: sort by config order first (stable), then by utilization ratio. + # Stable sort preserves priority for equal-ratio endpoints. + loaded_and_free.sort(key=lambda ep: ep_priority.get(ep, 999)) + loaded_and_free.sort(key=utilization_ratio) + selected = loaded_and_free[0] else: - # Sort by total endpoint load (ascending) to prefer idle endpoints. - endpoints_with_free_slot.sort( - key=lambda ep: sum(usage_counts.get(ep, {}).values()) - ) - if all(tracking_usage(ep) == 0 for ep in endpoints_with_free_slot): - selected = random.choice(endpoints_with_free_slot) + # Sort ascending for load balancing — all endpoints here already have the + # model loaded, so there is no model-switching cost to optimise for. + loaded_and_free.sort(key=tracking_usage) + # When all candidates are equally idle, randomise to avoid always picking + # the first entry in a stable sort. + if all(tracking_usage(ep) == 0 for ep in loaded_and_free): + selected = random.choice(loaded_and_free) else: - selected = endpoints_with_free_slot[0] + selected = loaded_and_free[0] else: - # 5️⃣ All candidate endpoints are saturated – pick the least-busy one (will queue) - if config.priority_routing: - selected = min( - candidate_endpoints, - key=lambda ep: (utilization_ratio(ep), ep_priority.get(ep, 999)), - ) + # 4️⃣ Endpoints among the candidates that simply have a free slot + endpoints_with_free_slot = [ + ep for ep in candidate_endpoints + if tracking_usage(ep) < get_max_connections(ep) + ] + + if endpoints_with_free_slot: + if config.priority_routing: + endpoints_with_free_slot.sort(key=lambda ep: ep_priority.get(ep, 999)) + endpoints_with_free_slot.sort(key=utilization_ratio) + selected = endpoints_with_free_slot[0] + else: + # Sort by total endpoint load (ascending) to prefer idle endpoints. + endpoints_with_free_slot.sort( + key=lambda ep: sum(usage_counts.get(ep, {}).values()) + ) + if all(tracking_usage(ep) == 0 for ep in endpoints_with_free_slot): + selected = random.choice(endpoints_with_free_slot) + else: + selected = endpoints_with_free_slot[0] else: - selected = min(candidate_endpoints, key=tracking_usage) + # 5️⃣ All candidate endpoints are saturated – pick the least-busy one (will queue) + if config.priority_routing: + selected = min( + candidate_endpoints, + key=lambda ep: (utilization_ratio(ep), ep_priority.get(ep, 999)), + ) + else: + selected = min(candidate_endpoints, key=tracking_usage) tracking_model = get_tracking_model(selected, model) snapshot = None @@ -1875,6 +1957,15 @@ async def choose_endpoint(model: str, reserve: bool = True) -> tuple[str, str]: snapshot = _capture_snapshot() if snapshot is not None: await _distribute_snapshot(snapshot) + # Record / refresh affinity *after* releasing usage_lock. + if reserve and config.conversation_affinity and affinity_key: + expires_at = time.monotonic() + config.conversation_affinity_ttl + async with _affinity_lock: + _affinity_map[affinity_key] = (selected, expires_at) + if len(_affinity_map) > _AFFINITY_MAX_ENTRIES: + now = time.monotonic() + for k in [k for k, v in _affinity_map.items() if v[1] < now]: + _affinity_map.pop(k, None) return selected, tracking_model # ------------------------------------------------------------- @@ -1925,7 +2016,8 @@ async def proxy(request: Request): yield _cached return StreamingResponse(_serve_cached_generate(), media_type="application/json") - endpoint, tracking_model = await choose_endpoint(model) + _affinity_key = _conversation_fingerprint(model, None, prompt) + endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key) use_openai = is_openai_compatible(endpoint) if use_openai: if ":latest" in model: @@ -2095,7 +2187,8 @@ async def chat_proxy(request: Request): opt = True else: opt = False - endpoint, tracking_model = await choose_endpoint(model) + _affinity_key = _conversation_fingerprint(model, messages, None) + endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key) use_openai = is_openai_compatible(endpoint) if use_openai: if ":latest" in model: @@ -3228,7 +3321,8 @@ async def openai_chat_completions_proxy(request: Request): return StreamingResponse(_serve_cached_ochat_json(), media_type="application/json") # 2. Endpoint logic - endpoint, tracking_model = await choose_endpoint(model) + _affinity_key = _conversation_fingerprint(model, messages, None) + endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key) oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) # 3. Helpers and API call — done in handler scope so try/except works reliably async def _normalize_images_in_messages(msgs: list) -> list: @@ -3538,7 +3632,8 @@ async def openai_completions_proxy(request: Request): return StreamingResponse(_serve_cached_ocompl_json(), media_type="application/json") # 2. Endpoint logic - endpoint, tracking_model = await choose_endpoint(model) + _affinity_key = _conversation_fingerprint(model, None, prompt) + endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key) oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) # 3. Async generator that streams completions data and decrements the counter