Add files via upload

adding persistent connections for endpoints
adding cache to available models routine
This commit is contained in:
Alpha Nerd 2025-09-01 11:07:07 +02:00 committed by GitHub
parent 9f19350f55
commit 0c6387f5af
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 42 additions and 24 deletions

View file

@ -1,3 +1,4 @@
aiocache==0.12.3
annotated-types==0.7.0
anyio==4.10.0
certifi==2025.8.3

View file

@ -15,6 +15,7 @@ from starlette.responses import StreamingResponse, JSONResponse, Response, HTMLR
from pydantic import Field
from pydantic_settings import BaseSettings
from collections import defaultdict
from aiocache import cached, Cache
# -------------------------------------------------------------
# 1. Configuration loader
@ -60,6 +61,21 @@ usage_lock = asyncio.Lock() # protects access to usage_counts
# -------------------------------------------------------------
# 4. Helperfunctions
# -------------------------------------------------------------
def get_httpx_client(endpoint: str) -> httpx.AsyncClient:
"""
Use persistent connections to request endpoint info for reliable results
in high load situations or saturated endpoints.
"""
return httpx.AsyncClient(
base_url=endpoint,
timeout=httpx.Timeout(5.0, read=5.0, write=5.0, connect=5.0),
limits=httpx.Limits(
max_keepalive_connections=64,
max_connections=64
)
)
@cached(cache=Cache.MEMORY, ttl=300)
async def fetch_available_models(endpoint: str) -> Set[str]:
"""
Query <endpoint>/api/tags and return a set of all model names that the
@ -70,41 +86,42 @@ async def fetch_available_models(endpoint: str) -> Set[str]:
If the request fails (e.g. timeout, 5xx, or malformed response), an empty
set is returned.
"""
client = get_httpx_client(endpoint)
try:
async with httpx.AsyncClient(timeout=2.5) as client:
if "/v1" in endpoint:
resp = await client.get(f"{endpoint}/models")
else:
resp = await client.get(f"{endpoint}/api/tags")
resp.raise_for_status()
data = resp.json()
# Expected format:
# {"models": [{"name": "model1"}, {"name": "model2"}]}
if "/v1" in endpoint:
models = {m.get("id") for m in data.get("data", []) if m.get("name")}
else:
models = {m.get("name") for m in data.get("models", []) if m.get("name")}
return models
if "/v1" in endpoint:
resp = await client.get(f"/models")
else:
resp = await client.get(f"/api/tags")
resp.raise_for_status()
data = resp.json()
# Expected format:
# {"models": [{"name": "model1"}, {"name": "model2"}]}
if "/v1" in endpoint:
models = {m.get("id") for m in data.get("data", []) if m.get("name")}
else:
models = {m.get("name") for m in data.get("models", []) if m.get("name")}
return models
except Exception as e:
# Treat any error as if the endpoint offers no models
print(e)
return set()
async def fetch_loaded_models(endpoint: str) -> Set[str]:
"""
Query <endpoint>/api/ps and return a set of model names that are currently
loaded on that endpoint. If the request fails (e.g. timeout, 5xx), an empty
set is returned.
"""
client = get_httpx_client(endpoint)
try:
async with httpx.AsyncClient(timeout=1.0) as client:
resp = await client.get(f"{endpoint}/api/ps")
resp.raise_for_status()
data = resp.json()
# The response format is:
# {"models": [{"name": "model1"}, {"name": "model2"}]}
models = {m.get("name") for m in data.get("models", []) if m.get("name")}
return models
resp = await client.get(f"/api/ps")
resp.raise_for_status()
data = resp.json()
# The response format is:
# {"models": [{"name": "model1"}, {"name": "model2"}]}
models = {m.get("name") for m in data.get("models", []) if m.get("name")}
return models
except Exception:
# If anything goes wrong we simply assume the endpoint has no models
return set()
@ -193,7 +210,7 @@ async def choose_endpoint(model: str) -> str:
ep for ep, models in zip(config.endpoints, advertised_sets)
if model in models
]
# 6
if not candidate_endpoints:
raise RuntimeError(
@ -231,7 +248,7 @@ async def choose_endpoint(model: str) -> str:
ep = min(endpoints_with_free_slot, key=current_usage)
return ep
# 5⃣ All candidate endpoints are saturated pick any (will queue)
# 5⃣ All candidate endpoints are saturated pick one with lowest usages count (will queue)
ep = min(candidate_endpoints, key=current_usage)
return ep