diff --git a/requirements.txt b/requirements.txt index e39b50c..4d43ce4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +aiocache==0.12.3 annotated-types==0.7.0 anyio==4.10.0 certifi==2025.8.3 diff --git a/router.py b/router.py index d356885..9173976 100644 --- a/router.py +++ b/router.py @@ -15,6 +15,7 @@ from starlette.responses import StreamingResponse, JSONResponse, Response, HTMLR from pydantic import Field from pydantic_settings import BaseSettings from collections import defaultdict +from aiocache import cached, Cache # ------------------------------------------------------------- # 1. Configuration loader @@ -60,6 +61,21 @@ usage_lock = asyncio.Lock() # protects access to usage_counts # ------------------------------------------------------------- # 4. Helperfunctions # ------------------------------------------------------------- +def get_httpx_client(endpoint: str) -> httpx.AsyncClient: + """ + Use persistent connections to request endpoint info for reliable results + in high load situations or saturated endpoints. + """ + return httpx.AsyncClient( + base_url=endpoint, + timeout=httpx.Timeout(5.0, read=5.0, write=5.0, connect=5.0), + limits=httpx.Limits( + max_keepalive_connections=64, + max_connections=64 + ) + ) + +@cached(cache=Cache.MEMORY, ttl=300) async def fetch_available_models(endpoint: str) -> Set[str]: """ Query /api/tags and return a set of all model names that the @@ -70,41 +86,42 @@ async def fetch_available_models(endpoint: str) -> Set[str]: If the request fails (e.g. timeout, 5xx, or malformed response), an empty set is returned. """ + client = get_httpx_client(endpoint) try: - async with httpx.AsyncClient(timeout=2.5) as client: - if "/v1" in endpoint: - resp = await client.get(f"{endpoint}/models") - else: - resp = await client.get(f"{endpoint}/api/tags") - resp.raise_for_status() - data = resp.json() - # Expected format: - # {"models": [{"name": "model1"}, {"name": "model2"}]} - if "/v1" in endpoint: - models = {m.get("id") for m in data.get("data", []) if m.get("name")} - else: - models = {m.get("name") for m in data.get("models", []) if m.get("name")} - return models + if "/v1" in endpoint: + resp = await client.get(f"/models") + else: + resp = await client.get(f"/api/tags") + resp.raise_for_status() + data = resp.json() + # Expected format: + # {"models": [{"name": "model1"}, {"name": "model2"}]} + if "/v1" in endpoint: + models = {m.get("id") for m in data.get("data", []) if m.get("name")} + else: + models = {m.get("name") for m in data.get("models", []) if m.get("name")} + return models except Exception as e: # Treat any error as if the endpoint offers no models print(e) return set() + async def fetch_loaded_models(endpoint: str) -> Set[str]: """ Query /api/ps and return a set of model names that are currently loaded on that endpoint. If the request fails (e.g. timeout, 5xx), an empty set is returned. """ + client = get_httpx_client(endpoint) try: - async with httpx.AsyncClient(timeout=1.0) as client: - resp = await client.get(f"{endpoint}/api/ps") - resp.raise_for_status() - data = resp.json() - # The response format is: - # {"models": [{"name": "model1"}, {"name": "model2"}]} - models = {m.get("name") for m in data.get("models", []) if m.get("name")} - return models + resp = await client.get(f"/api/ps") + resp.raise_for_status() + data = resp.json() + # The response format is: + # {"models": [{"name": "model1"}, {"name": "model2"}]} + models = {m.get("name") for m in data.get("models", []) if m.get("name")} + return models except Exception: # If anything goes wrong we simply assume the endpoint has no models return set() @@ -193,7 +210,7 @@ async def choose_endpoint(model: str) -> str: ep for ep, models in zip(config.endpoints, advertised_sets) if model in models ] - + # 6️⃣ if not candidate_endpoints: raise RuntimeError( @@ -231,7 +248,7 @@ async def choose_endpoint(model: str) -> str: ep = min(endpoints_with_free_slot, key=current_usage) return ep - # 5️⃣ All candidate endpoints are saturated – pick any (will queue) + # 5️⃣ All candidate endpoints are saturated – pick one with lowest usages count (will queue) ep = min(candidate_endpoints, key=current_usage) return ep