Add files via upload
adding persistent connections for endpoints adding cache to available models routine
This commit is contained in:
parent
9f19350f55
commit
0c6387f5af
2 changed files with 42 additions and 24 deletions
|
|
@ -1,3 +1,4 @@
|
|||
aiocache==0.12.3
|
||||
annotated-types==0.7.0
|
||||
anyio==4.10.0
|
||||
certifi==2025.8.3
|
||||
|
|
|
|||
65
router.py
65
router.py
|
|
@ -15,6 +15,7 @@ from starlette.responses import StreamingResponse, JSONResponse, Response, HTMLR
|
|||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
from collections import defaultdict
|
||||
from aiocache import cached, Cache
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# 1. Configuration loader
|
||||
|
|
@ -60,6 +61,21 @@ usage_lock = asyncio.Lock() # protects access to usage_counts
|
|||
# -------------------------------------------------------------
|
||||
# 4. Helperfunctions
|
||||
# -------------------------------------------------------------
|
||||
def get_httpx_client(endpoint: str) -> httpx.AsyncClient:
|
||||
"""
|
||||
Use persistent connections to request endpoint info for reliable results
|
||||
in high load situations or saturated endpoints.
|
||||
"""
|
||||
return httpx.AsyncClient(
|
||||
base_url=endpoint,
|
||||
timeout=httpx.Timeout(5.0, read=5.0, write=5.0, connect=5.0),
|
||||
limits=httpx.Limits(
|
||||
max_keepalive_connections=64,
|
||||
max_connections=64
|
||||
)
|
||||
)
|
||||
|
||||
@cached(cache=Cache.MEMORY, ttl=300)
|
||||
async def fetch_available_models(endpoint: str) -> Set[str]:
|
||||
"""
|
||||
Query <endpoint>/api/tags and return a set of all model names that the
|
||||
|
|
@ -70,41 +86,42 @@ async def fetch_available_models(endpoint: str) -> Set[str]:
|
|||
If the request fails (e.g. timeout, 5xx, or malformed response), an empty
|
||||
set is returned.
|
||||
"""
|
||||
client = get_httpx_client(endpoint)
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=2.5) as client:
|
||||
if "/v1" in endpoint:
|
||||
resp = await client.get(f"{endpoint}/models")
|
||||
else:
|
||||
resp = await client.get(f"{endpoint}/api/tags")
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
# Expected format:
|
||||
# {"models": [{"name": "model1"}, {"name": "model2"}]}
|
||||
if "/v1" in endpoint:
|
||||
models = {m.get("id") for m in data.get("data", []) if m.get("name")}
|
||||
else:
|
||||
models = {m.get("name") for m in data.get("models", []) if m.get("name")}
|
||||
return models
|
||||
if "/v1" in endpoint:
|
||||
resp = await client.get(f"/models")
|
||||
else:
|
||||
resp = await client.get(f"/api/tags")
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
# Expected format:
|
||||
# {"models": [{"name": "model1"}, {"name": "model2"}]}
|
||||
if "/v1" in endpoint:
|
||||
models = {m.get("id") for m in data.get("data", []) if m.get("name")}
|
||||
else:
|
||||
models = {m.get("name") for m in data.get("models", []) if m.get("name")}
|
||||
return models
|
||||
except Exception as e:
|
||||
# Treat any error as if the endpoint offers no models
|
||||
print(e)
|
||||
return set()
|
||||
|
||||
|
||||
async def fetch_loaded_models(endpoint: str) -> Set[str]:
|
||||
"""
|
||||
Query <endpoint>/api/ps and return a set of model names that are currently
|
||||
loaded on that endpoint. If the request fails (e.g. timeout, 5xx), an empty
|
||||
set is returned.
|
||||
"""
|
||||
client = get_httpx_client(endpoint)
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=1.0) as client:
|
||||
resp = await client.get(f"{endpoint}/api/ps")
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
# The response format is:
|
||||
# {"models": [{"name": "model1"}, {"name": "model2"}]}
|
||||
models = {m.get("name") for m in data.get("models", []) if m.get("name")}
|
||||
return models
|
||||
resp = await client.get(f"/api/ps")
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
# The response format is:
|
||||
# {"models": [{"name": "model1"}, {"name": "model2"}]}
|
||||
models = {m.get("name") for m in data.get("models", []) if m.get("name")}
|
||||
return models
|
||||
except Exception:
|
||||
# If anything goes wrong we simply assume the endpoint has no models
|
||||
return set()
|
||||
|
|
@ -193,7 +210,7 @@ async def choose_endpoint(model: str) -> str:
|
|||
ep for ep, models in zip(config.endpoints, advertised_sets)
|
||||
if model in models
|
||||
]
|
||||
|
||||
|
||||
# 6️⃣
|
||||
if not candidate_endpoints:
|
||||
raise RuntimeError(
|
||||
|
|
@ -231,7 +248,7 @@ async def choose_endpoint(model: str) -> str:
|
|||
ep = min(endpoints_with_free_slot, key=current_usage)
|
||||
return ep
|
||||
|
||||
# 5️⃣ All candidate endpoints are saturated – pick any (will queue)
|
||||
# 5️⃣ All candidate endpoints are saturated – pick one with lowest usages count (will queue)
|
||||
ep = min(candidate_endpoints, key=current_usage)
|
||||
return ep
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue