enhance routing logic
add a pre-routing model check: allows the Ollama backend servers to run different model configurations
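
The check is built on Ollama's `/api/tags` endpoint, which lists every model installed on an instance, loaded or not. A minimal sketch of the idea, separate from the commit's own helper below; the URL and model name are only examples:

    import asyncio
    import httpx

    async def advertises_model(endpoint: str, model: str) -> bool:
        # /api/tags lists all installed models, whether or not they are
        # currently loaded into memory
        async with httpx.AsyncClient(timeout=1.0) as client:
            resp = await client.get(f"{endpoint}/api/tags")
            resp.raise_for_status()
        names = {m["name"] for m in resp.json().get("models", [])}
        return model in names

    # e.g. asyncio.run(advertises_model("http://localhost:11434", "llama3:8b"))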
This commit is contained in:
parent 516ec8b102 · commit 1403c08a81

1 changed file with 87 additions and 44 deletions
router.py (131 changed lines)
@@ -59,6 +59,28 @@ usage_lock = asyncio.Lock() # protects access to usage_counts
 # -------------------------------------------------------------
 # 4. Helper functions
 # -------------------------------------------------------------
+async def fetch_available_models(endpoint: str) -> Set[str]:
+    """
+    Query <endpoint>/api/tags and return a set of all model names that the
+    endpoint *advertises* (i.e. is capable of serving). This endpoint lists
+    every model that is installed on the Ollama instance, regardless of
+    whether the model is currently loaded into memory.
+
+    If the request fails (e.g. timeout, 5xx, or malformed response), an empty
+    set is returned.
+    """
+    try:
+        async with httpx.AsyncClient(timeout=1.0) as client:
+            resp = await client.get(f"{endpoint}/api/tags")
+            resp.raise_for_status()
+            data = resp.json()
+            # Expected format:
+            # {"models": [{"name": "model1"}, {"name": "model2"}]}
+            return {m.get("name") for m in data.get("models", []) if m.get("name")}
+    except Exception:
+        # Treat any error as if the endpoint offers no models
+        return set()
+
 async def fetch_loaded_models(endpoint: str) -> Set[str]:
     """
     Query <endpoint>/api/ps and return a set of model names that are currently
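
Because any failure collapses to an empty set, a backend that is down is treated the same as one with no models installed. A quick way to exercise the new helper in isolation (run from router.py's module context; the URL is an example, 11434 being Ollama's default port):

    import asyncio

    async def main() -> None:
        models = await fetch_available_models("http://localhost:11434")
        print(sorted(models))  # [] if the endpoint is down or has no models

    asyncio.run(main())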
@@ -108,50 +130,6 @@ def dedupe_on_keys(dicts, key_fields):
         out.append(d)
     return out
 
-# -------------------------------------------------------------
-# 5. Endpoint selection logic (respecting the configurable limit)
-# -------------------------------------------------------------
-async def choose_endpoint(model: str) -> str:
-    """
-    Determine which endpoint to use for the given model while respecting the
-    `max_concurrent_connections` per endpoint-model pair.
-
-    The selection algorithm is as follows:
-
-    1. Find all endpoints that have the model *already loaded* and that still
-       have a free slot (< max_concurrent_connections).
-    2. If none exist, find any endpoint that has a free slot regardless of
-       whether the model is loaded – the endpoint will load the model on demand.
-    3. If all endpoints are at capacity for this model, pick any endpoint
-       arbitrarily – the request will queue on that endpoint.
-    """
-    # Gather loaded models for all endpoints concurrently
-    tasks = [fetch_loaded_models(ep) for ep in config.endpoints]
-    loaded_sets = await asyncio.gather(*tasks)
-
-    async with usage_lock:
-        # 1️⃣ Endpoints that have the model loaded *and* a free slot
-        loaded_and_free = [
-            ep for ep, models in zip(config.endpoints, loaded_sets)
-            if model in models and usage_counts[ep].get(model, 0) < config.max_concurrent_connections
-        ]
-
-        if loaded_and_free:
-            # Prefer an endpoint that already hosts the model and has capacity
-            return random.choice(loaded_and_free)
-
-        # 2️⃣ Endpoints that simply have a free slot (model may or may not be loaded)
-        endpoints_with_free_slot = [
-            ep for ep in config.endpoints
-            if usage_counts[ep].get(model, 0) < config.max_concurrent_connections
-        ]
-
-        if endpoints_with_free_slot:
-            return random.choice(endpoints_with_free_slot)
-
-        # 3️⃣ All endpoints are at capacity – pick any (will queue on that endpoint according to OLLAMA_MAX_QUEUE)
-        return random.choice(config.endpoints)
-
 async def increment_usage(endpoint: str, model: str) -> None:
     async with usage_lock:
         usage_counts[endpoint][model] += 1
@@ -168,6 +146,71 @@ async def decrement_usage(endpoint: str, model: str) -> None:
         if not usage_counts[endpoint]:
             usage_counts.pop(endpoint, None)
 
+# -------------------------------------------------------------
+# 5. Endpoint selection logic (respecting the configurable limit)
+# -------------------------------------------------------------
+async def choose_endpoint(model: str) -> str:
+    """
+    Determine which endpoint to use for the given model while respecting
+    the `max_concurrent_connections` per endpoint-model pair **and**
+    ensuring that the chosen endpoint actually *advertises* the model.
+
+    The selection algorithm:
+
+    1️⃣ Query every endpoint for its advertised models (`/api/tags`).
+    2️⃣ Build a list of endpoints that advertise the requested model.
+    3️⃣ For those endpoints, find those that have the model loaded
+       (`/api/ps`) *and* still have a free slot.
+    4️⃣ If none are both loaded and free, fall back to any endpoint
+       from the filtered list that simply has a free slot.
+    5️⃣ If all are saturated, pick any endpoint from the filtered list
+       (the request will queue on that endpoint).
+    6️⃣ If no endpoint advertises the model at all, raise an error.
+    """
+    # 1️⃣ Gather advertised-model sets for all endpoints concurrently
+    tag_tasks = [fetch_available_models(ep) for ep in config.endpoints]
+    advertised_sets = await asyncio.gather(*tag_tasks)
+
+    # 2️⃣ Filter endpoints that advertise the requested model
+    candidate_endpoints = [
+        ep for ep, models in zip(config.endpoints, advertised_sets)
+        if model in models
+    ]
+
+    # 6️⃣ No endpoint advertises the model – fail fast
+    if not candidate_endpoints:
+        raise RuntimeError(
+            f"None of the configured endpoints ({', '.join(config.endpoints)}) "
+            f"advertise the model '{model}'."
+        )
+
+    # 3️⃣ Among the candidates, find those that have the model *loaded*
+    # (concurrently, but only for the filtered list)
+    load_tasks = [fetch_loaded_models(ep) for ep in candidate_endpoints]
+    loaded_sets = await asyncio.gather(*load_tasks)
+
+    async with usage_lock:
+        # 3️⃣ Endpoints that have the model loaded *and* a free slot
+        loaded_and_free = [
+            ep for ep, models in zip(candidate_endpoints, loaded_sets)
+            if model in models and usage_counts[ep].get(model, 0) < config.max_concurrent_connections
+        ]
+
+        if loaded_and_free:
+            return random.choice(loaded_and_free)
+
+        # 4️⃣ Endpoints among the candidates that simply have a free slot
+        endpoints_with_free_slot = [
+            ep for ep in candidate_endpoints
+            if usage_counts[ep].get(model, 0) < config.max_concurrent_connections
+        ]
+
+        if endpoints_with_free_slot:
+            return random.choice(endpoints_with_free_slot)
+
+        # 5️⃣ All candidate endpoints are saturated – pick any (will queue)
+        return random.choice(candidate_endpoints)
+
 # -------------------------------------------------------------
 # 6. API route – Generate
 # -------------------------------------------------------------
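
For context, the route in section 6 would presumably bracket the proxied call with the usage counters so a slot is always released. A minimal sketch under that assumption; the handler name, non-streaming payload, and proxy details are illustrative, not the commit's actual route:

    async def handle_generate(model: str, payload: dict) -> dict:
        # may raise RuntimeError if no backend advertises the model (step 6️⃣)
        endpoint = await choose_endpoint(model)
        await increment_usage(endpoint, model)
        try:
            # assumes payload sets "stream": False so /api/generate returns
            # a single JSON object instead of a stream
            async with httpx.AsyncClient(timeout=None) as client:
                resp = await client.post(f"{endpoint}/api/generate", json=payload)
                resp.raise_for_status()
                return resp.json()
        finally:
            # release the slot even when the backend errors
            await decrement_usage(endpoint, model)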