nomyo-router/routing.py

295 lines
13 KiB
Python
Raw Normal View History

"""Endpoint selection (load-balancing + conversation affinity).
``choose_endpoint`` is the heart of routing: it picks an endpoint that
advertises the model, prefers ones with the model already loaded and a free
slot, applies the conversation-affinity hint when available, and honors
config-order priority routing when ``priority_routing`` is set.
``increment_usage`` / ``decrement_usage`` keep the per-(endpoint, model)
counter that drives utilization-based selection; they fan out an SSE
snapshot on every change.
"""
import asyncio
import random
import time
from typing import Optional
from config import get_config
from state import (
usage_counts,
usage_lock,
_loaded_error_cache,
_loaded_error_cache_lock,
_completion_error_cache,
_completion_error_cache_lock,
_COMPLETION_ERROR_TTL,
_affinity_map,
_affinity_lock,
_AFFINITY_MAX_ENTRIES,
)
from sse import _capture_snapshot, _distribute_snapshot
from backends.health import _is_fresh
from backends.normalize import (
is_ext_openai_endpoint,
is_openai_compatible,
get_tracking_model,
)
from backends.probe import fetch
async def increment_usage(endpoint: str, model: str) -> None:
    """Count one more in-flight request for (*endpoint*, *model*) and broadcast it.

    The counter is bumped and an SSE snapshot is captured while
    ``usage_lock`` is held; the snapshot is distributed once the lock has
    been released.
    """
    async with usage_lock:
        per_model = usage_counts[endpoint]
        per_model[model] += 1
        snapshot = _capture_snapshot()
    await _distribute_snapshot(snapshot)
async def decrement_usage(endpoint: str, model: str) -> None:
    """Release one in-flight request for (*endpoint*, *model*) and broadcast it.

    The counter never goes below zero. Once an entry reaches zero it is
    removed from the endpoint's per-model map (the endpoint's own map is
    kept, even when it becomes empty). The SSE snapshot is captured under
    ``usage_lock`` and distributed after the lock is released.
    """
    async with usage_lock:
        per_model = usage_counts[endpoint]
        count = per_model.get(model, 0)
        if count > 0:
            count -= 1
            per_model[model] = count
        if count == 0:
            # Drop exhausted entries instead of keeping zero counters around.
            per_model.pop(model, None)
        snapshot = _capture_snapshot()
    await _distribute_snapshot(snapshot)
def get_max_connections(ep: str) -> int:
    """Return the concurrent-connection cap for endpoint *ep*.

    A ``max_concurrent_connections`` key in the endpoint's own
    ``endpoint_config`` entry takes precedence; otherwise the global
    ``max_concurrent_connections`` setting is used.
    """
    settings = get_config()
    global_cap = settings.max_concurrent_connections
    overrides = settings.endpoint_config.get(ep, {})
    return overrides.get("max_concurrent_connections", global_cap)
def _exclude_endpoints(
    candidates: list[str],
    loaded_sets: list,
    excluded: set[str],
) -> tuple[list[str], list]:
    """Drop ``excluded`` endpoints (and their loaded-model sets) from the candidates.

    Both lists are filtered in lockstep so that index *i* of the result still
    pairs an endpoint with its own loaded set. If *every* candidate would be
    dropped, the inputs are returned unchanged: refusing to route is worse
    than retrying a possibly-recovered backend.
    """
    kept = [
        (ep, models)
        for ep, models in zip(candidates, loaded_sets)
        if ep not in excluded
    ]
    if not kept:
        return candidates, loaded_sets
    return [ep for ep, _ in kept], [models for _, models in kept]


async def choose_endpoint(model: str, reserve: bool = True,
                          affinity_key: Optional[str] = None) -> tuple[str, str]:
    """
    Determine which endpoint to use for the given model while respecting
    the `max_concurrent_connections` per endpoint/model pair **and**
    ensuring that the chosen endpoint actually *advertises* the model.

    The selection algorithm:
    1   Query every endpoint for its advertised models (`/api/tags`).
    2   Build a list of endpoints that contain the requested model.
    2.5 If conversation affinity is enabled and the caller passes
        ``affinity_key``, prefer the endpoint that previously served the
        same conversation, but only when it still has the model loaded
        and a free slot. Otherwise fall through to the standard logic.
    3   For those endpoints, find those that have the model loaded
        (`/api/ps`) *and* still have a free slot.
    4   If none are both loaded and free, fall back to any endpoint
        from the filtered list that simply has a free slot (least total
        load first; random choice when all of them are fully idle).
    5   If all are saturated, pick the least-busy endpoint from the
        filtered list (the request will queue on that endpoint).
    6   If no endpoint advertises the model at all, raise an error.

    Before steps 2.5-5, endpoints whose loaded-model probe failed recently,
    or whose completion path for this model failed recently, are dropped,
    unless that would leave no candidate at all.

    Args:
        model: Requested model name. It may be rewritten to ``<model>:latest``
            when only the suffixed name is advertised.
        reserve: When true, the usage counter for the chosen
            (endpoint, tracking model) is incremented atomically with the
            selection, an SSE snapshot is distributed, and (if enabled)
            the conversation affinity is recorded.
        affinity_key: Optional conversation identifier used for the
            affinity lookup/refresh.

    Returns:
        ``(endpoint, tracking_model)``: the selected endpoint and the model
        name produced by ``get_tracking_model`` for it.

    Raises:
        RuntimeError: if no configured endpoint advertises the model.
    """
    config = get_config()

    # Step 1: gather advertised-model sets for all endpoints concurrently.
    # Include both config.endpoints and config.llama_server_endpoints.
    llama_eps_extra = [ep for ep in config.llama_server_endpoints if ep not in config.endpoints]
    all_endpoints = config.endpoints + llama_eps_extra
    llama_extra_set = set(llama_eps_extra)

    def tags_task(ep: str):
        # OpenAI-compatible endpoints and llama-server-only endpoints are
        # queried with their API key; native endpoints without one.
        if ep in llama_extra_set or is_openai_compatible(ep):
            return fetch.available_models(ep, config.api_keys.get(ep))
        return fetch.available_models(ep)

    # Exactly one task per endpoint, in all_endpoints order: the zip() calls
    # below rely on advertised_sets[i] belonging to all_endpoints[i]. Grouping
    # the tasks by endpoint kind would pair endpoints with another endpoint's
    # model list whenever kinds are interleaved in config.endpoints.
    advertised_sets = await asyncio.gather(*(tags_task(ep) for ep in all_endpoints))

    def endpoints_advertising(name: str) -> list[str]:
        return [
            ep for ep, models in zip(all_endpoints, advertised_sets)
            if name in models
        ]

    # Step 2: filter endpoints that advertise the requested model.
    candidate_endpoints = endpoints_advertising(model)

    # Step 6 (with name-normalisation fallbacks before giving up).
    if not candidate_endpoints and ":latest" in model:
        # The Ollama ":latest" convention does not apply to OpenAI/llama-server,
        # so retry without the suffix, restricted to those endpoint kinds.
        model_without_latest = model.split(":latest")[0]
        candidate_endpoints = [
            ep for ep in endpoints_advertising(model_without_latest)
            if is_ext_openai_endpoint(ep) or ep in config.llama_server_endpoints
        ]
        # NOTE(review): ``model`` keeps its ":latest" suffix here, so the
        # loaded-model check, the affinity record and get_tracking_model()
        # below all see the suffixed name. Confirm fetch.loaded_models /
        # get_tracking_model normalise it for these endpoints.
    if not candidate_endpoints and ":" not in model:
        # Only add the :latest suffix if the model has no version suffix yet.
        model = model + ":latest"
        candidate_endpoints = endpoints_advertising(model)
    if not candidate_endpoints:
        raise RuntimeError(
            f"None of the configured endpoints ({', '.join(all_endpoints)}) "
            f"advertise the model '{model}'."
        )

    # Step 3 (probe): among the candidates, fetch what each has *loaded*.
    loaded_sets = await asyncio.gather(*(fetch.loaded_models(ep) for ep in candidate_endpoints))

    # Step 3.5: exclude endpoints whose loaded-model probe has been failing
    # recently. Without this filter, an endpoint where `/api/ps` returns 5xx
    # would appear with an empty loaded set but pass through to the
    # free-slot fallback (step 4), sending completion calls to an
    # unhealthy backend. See issue #83.
    async with _loaded_error_cache_lock:
        unhealthy = {
            ep for ep, ts in _loaded_error_cache.items()
            if _is_fresh(ts, 300)
        }
    candidate_endpoints, loaded_sets = _exclude_endpoints(
        candidate_endpoints, loaded_sets, unhealthy
    )

    # Step 3.6: exclude (endpoint, model) pairs whose completion path has
    # recently failed with a backend connection error (e.g. llama-server in
    # router mode whose delegated worker for *this* model died). /v1/models
    # keeps reporting OK in that case, so the probe-level filter above
    # cannot catch it. If every candidate is broken for this model, fall
    # through and let the upstream retry (the worker may have been restarted).
    async with _completion_error_cache_lock:
        completion_broken = {
            ep for (ep, m), ts in _completion_error_cache.items()
            if m == model and _is_fresh(ts, _COMPLETION_ERROR_TTL)
        }
    candidate_endpoints, loaded_sets = _exclude_endpoints(
        candidate_endpoints, loaded_sets, completion_broken
    )

    # Look up a possible affinity hint *before* taking usage_lock. The two
    # locks are never held together to avoid lock-ordering issues.
    affine_ep: Optional[str] = None
    if config.conversation_affinity and affinity_key:
        async with _affinity_lock:
            entry = _affinity_map.get(affinity_key)
            if entry is not None:
                ep, _stored_model, expires_at = entry
                if expires_at < time.monotonic():
                    _affinity_map.pop(affinity_key, None)
                else:
                    affine_ep = ep

    # Priority map: position in all_endpoints (lower = higher priority).
    ep_priority = {ep: i for i, ep in enumerate(all_endpoints)}

    # All reads/writes of usage_counts happen under the lock so that selection
    # and reservation are atomic: concurrent callers see each other's load.
    async with usage_lock:
        def tracking_usage(ep: str) -> int:
            # Use the same normalised key that increment/decrement_usage store;
            # raw model names differ from tracking names for llama-server.
            return usage_counts.get(ep, {}).get(get_tracking_model(ep, model), 0)

        def total_usage(ep: str) -> int:
            return sum(usage_counts.get(ep, {}).values())

        def has_free_slot(ep: str) -> bool:
            return tracking_usage(ep) < get_max_connections(ep)

        def priority_key(ep: str) -> tuple[float, int]:
            # WRR: lowest utilisation ratio first, config order breaks ties.
            return (tracking_usage(ep) / get_max_connections(ep), ep_priority.get(ep, 999))

        selected: Optional[str] = None

        # Step 2.5: honour the affinity hint only when the affine endpoint is
        # still a candidate, has the model loaded *and* has a free slot.
        if affine_ep:
            loaded_by_ep = dict(zip(candidate_endpoints, loaded_sets))
            if (affine_ep in loaded_by_ep
                    and model in loaded_by_ep[affine_ep]
                    and has_free_slot(affine_ep)):
                selected = affine_ep

        if selected is None:
            # Step 3: endpoints that have the model loaded *and* a free slot.
            loaded_and_free = [
                ep for ep, models in zip(candidate_endpoints, loaded_sets)
                if model in models and has_free_slot(ep)
            ]
            if loaded_and_free:
                if config.priority_routing:
                    selected = min(loaded_and_free, key=priority_key)
                elif all(tracking_usage(ep) == 0 for ep in loaded_and_free):
                    # Equally idle: randomise instead of always taking the first.
                    selected = random.choice(loaded_and_free)
                else:
                    # All have the model loaded, so there is no model-switching
                    # cost to optimise for: the least busy one wins.
                    selected = min(loaded_and_free, key=tracking_usage)
            else:
                # Step 4: candidates that simply have a free slot.
                endpoints_with_free_slot = [
                    ep for ep in candidate_endpoints if has_free_slot(ep)
                ]
                if endpoints_with_free_slot:
                    if config.priority_routing:
                        selected = min(endpoints_with_free_slot, key=priority_key)
                    elif all(total_usage(ep) == 0 for ep in endpoints_with_free_slot):
                        # Randomise only when every endpoint is completely idle.
                        # None of these has the model loaded, so this model's
                        # own usage is nearly always 0 and would not reflect
                        # how busy the endpoint is.
                        selected = random.choice(endpoints_with_free_slot)
                    else:
                        # Prefer the endpoint with the least total load.
                        selected = min(endpoints_with_free_slot, key=total_usage)
                elif config.priority_routing:
                    # Step 5: all saturated; pick the least busy (will queue).
                    selected = min(candidate_endpoints, key=priority_key)
                else:
                    selected = min(candidate_endpoints, key=tracking_usage)

        tracking_model = get_tracking_model(selected, model)
        snapshot = None
        if reserve:
            usage_counts[selected][tracking_model] += 1
            snapshot = _capture_snapshot()

    if snapshot is not None:
        await _distribute_snapshot(snapshot)

    # Record / refresh affinity *after* releasing usage_lock.
    if reserve and config.conversation_affinity and affinity_key:
        expires_at = time.monotonic() + config.conversation_affinity_ttl
        async with _affinity_lock:
            _affinity_map[affinity_key] = (selected, model, expires_at)
            if len(_affinity_map) > _AFFINITY_MAX_ENTRIES:
                now = time.monotonic()
                for k in [k for k, v in _affinity_map.items() if v[2] < now]:
                    _affinity_map.pop(k, None)
                # Purging expired entries alone does not bound the map: if
                # nothing has expired it keeps growing, and this scan repeats
                # on every request. Evict the entries closest to expiry until
                # the cap holds.
                overflow = len(_affinity_map) - _AFFINITY_MAX_ENTRIES
                if overflow > 0:
                    oldest = sorted(_affinity_map, key=lambda key: _affinity_map[key][2])
                    for k in oldest[:overflow]:
                        _affinity_map.pop(k, None)
    return selected, tracking_model