diff --git a/router.py b/router.py
index 169f190..98f1fcb 100644
--- a/router.py
+++ b/router.py
@@ -59,6 +59,28 @@ usage_lock = asyncio.Lock()  # protects access to usage_counts
 # -------------------------------------------------------------
 # 4. Helper functions
 # -------------------------------------------------------------
+async def fetch_available_models(endpoint: str) -> Set[str]:
+    """
+    Query /api/tags and return a set of all model names that the
+    endpoint *advertises* (i.e. is capable of serving). The tags API
+    lists every model installed on the Ollama instance, regardless of
+    whether the model is currently loaded into memory.
+
+    If the request fails (e.g. timeout, 5xx, or malformed response),
+    an empty set is returned.
+    """
+    try:
+        async with httpx.AsyncClient(timeout=1.0) as client:
+            resp = await client.get(f"{endpoint}/api/tags")
+            resp.raise_for_status()
+            data = resp.json()
+            # Expected format:
+            # {"models": [{"name": "model1"}, {"name": "model2"}]}
+            return {m.get("name") for m in data.get("models", []) if m.get("name")}
+    except Exception:
+        # Treat any error as if the endpoint offers no models
+        return set()
+
 async def fetch_loaded_models(endpoint: str) -> Set[str]:
     """
     Query /api/ps and return a set of model names that are currently
@@ -108,50 +130,6 @@ def dedupe_on_keys(dicts, key_fields):
         out.append(d)
     return out
 
-# -------------------------------------------------------------
-# 5. Endpoint selection logic (respecting the configurable limit)
-# -------------------------------------------------------------
-async def choose_endpoint(model: str) -> str:
-    """
-    Determine which endpoint to use for the given model while respecting the
-    `max_concurrent_connections` per endpoint‑model pair.
-
-    The selection algorithm is as follows:
-
-    1. Find all endpoints that have the model *already loaded* and that still
-       have a free slot (< max_concurrent_connections).
-    2. If none exist, find any endpoint that has a free slot regardless of
-       whether the model is loaded – the endpoint will load the model on demand.
-    3. If all endpoints are at capacity for this model, pick any endpoint
-       arbitrarily – the request will queue on that endpoint.
- """ - # Gather loaded models for all endpoints concurrently - tasks = [fetch_loaded_models(ep) for ep in config.endpoints] - loaded_sets = await asyncio.gather(*tasks) - - async with usage_lock: - # 1️⃣ Endpoints that have the model loaded *and* a free slot - loaded_and_free = [ - ep for ep, models in zip(config.endpoints, loaded_sets) - if model in models and usage_counts[ep].get(model, 0) < config.max_concurrent_connections - ] - - if loaded_and_free: - # Prefer an endpoint that already hosts the model and has capacity - return random.choice(loaded_and_free) - - # 2️⃣ Endpoints that simply have a free slot (model may or may not be loaded) - endpoints_with_free_slot = [ - ep for ep in config.endpoints - if usage_counts[ep].get(model, 0) < config.max_concurrent_connections - ] - - if endpoints_with_free_slot: - return random.choice(endpoints_with_free_slot) - - # 3️⃣ All endpoints are at capacity – pick any (will queue on that endpoint according to OLLAMA_MAX_QUEUE) - return random.choice(config.endpoints) - async def increment_usage(endpoint: str, model: str) -> None: async with usage_lock: usage_counts[endpoint][model] += 1 @@ -168,6 +146,71 @@ async def decrement_usage(endpoint: str, model: str) -> None: if not usage_counts[endpoint]: usage_counts.pop(endpoint, None) +# ------------------------------------------------------------- +# 5. Endpoint selection logic (respecting the configurable limit) +# ------------------------------------------------------------- +async def choose_endpoint(model: str) -> str: + """ + Determine which endpoint to use for the given model while respecting + the `max_concurrent_connections` per endpoint‑model pair **and** + ensuring that the chosen endpoint actually *advertises* the model. + + The selection algorithm: + + 1️⃣ Query every endpoint for its advertised models (`/api/tags`). + 2️⃣ Build a list of endpoints that contain the requested model. + 3️⃣ For those endpoints, find those that have the model loaded + (`/api/ps`) *and* still have a free slot. + 4️⃣ If none are both loaded and free, fall back to any endpoint + from the filtered list that simply has a free slot. + 5️⃣ If all are saturated, pick any endpoint from the filtered list + (the request will queue on that endpoint). + 6️⃣ If no endpoint advertises the model at all, raise an error. + """ + # 1️⃣ Gather advertised‑model sets for all endpoints concurrently + tag_tasks = [fetch_available_models(ep) for ep in config.endpoints] + advertised_sets = await asyncio.gather(*tag_tasks) + + # 2️⃣ Filter endpoints that advertise the requested model + candidate_endpoints = [ + ep for ep, models in zip(config.endpoints, advertised_sets) + if model in models + ] + + # 6️⃣ + if not candidate_endpoints: + raise RuntimeError( + f"None of the configured endpoints ({', '.join(config.endpoints)}) " + f"advertise the model '{model}'." 
+        )
+
+    # 3️⃣ Among the candidates, find those that have the model *loaded*
+    #    (concurrently, but only for the filtered list)
+    load_tasks = [fetch_loaded_models(ep) for ep in candidate_endpoints]
+    loaded_sets = await asyncio.gather(*load_tasks)
+
+    async with usage_lock:
+        # 3️⃣ Endpoints that have the model loaded *and* a free slot
+        loaded_and_free = [
+            ep for ep, models in zip(candidate_endpoints, loaded_sets)
+            if model in models and usage_counts[ep].get(model, 0) < config.max_concurrent_connections
+        ]
+
+        if loaded_and_free:
+            return random.choice(loaded_and_free)
+
+        # 4️⃣ Endpoints among the candidates that simply have a free slot
+        endpoints_with_free_slot = [
+            ep for ep in candidate_endpoints
+            if usage_counts[ep].get(model, 0) < config.max_concurrent_connections
+        ]
+
+        if endpoints_with_free_slot:
+            return random.choice(endpoints_with_free_slot)
+
+        # 5️⃣ All candidate endpoints are saturated – pick any (will queue)
+        return random.choice(candidate_endpoints)
+
 # -------------------------------------------------------------
 # 6. API route – Generate
 # -------------------------------------------------------------
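
Usage note (not part of the patch): the "Generate" route in section 6 is the expected caller of choose_endpoint(), and it should bracket the proxied request with increment_usage()/decrement_usage() so the reserved slot is always released. A minimal sketch follows; the FastAPI app object, route shape, and buffering proxy are assumptions for illustration, and only choose_endpoint, increment_usage, and decrement_usage come from this patch:

    # Hypothetical caller: assumes FastAPI and the helpers defined in router.py.
    from fastapi import FastAPI, Request, Response
    import httpx

    app = FastAPI()

    @app.post("/api/generate")
    async def generate(request: Request) -> Response:
        body = await request.json()
        model = body["model"]

        # May raise RuntimeError if no endpoint advertises the model (step 6️⃣).
        endpoint = await choose_endpoint(model)

        # Reserve a slot before proxying; release it no matter what happens.
        await increment_usage(endpoint, model)
        try:
            async with httpx.AsyncClient(timeout=None) as client:
                upstream = await client.post(f"{endpoint}/api/generate", json=body)
            return Response(
                content=upstream.content,
                status_code=upstream.status_code,
                media_type=upstream.headers.get("content-type"),
            )
        finally:
            await decrement_usage(endpoint, model)

A production handler would likely stream the upstream response rather than buffer it; the sketch buffers for brevity.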