moving from httpx to aiohttp
This commit is contained in:
parent
b6a9aa82cb
commit
d3e4555c8c
2 changed files with 43 additions and 60 deletions
|
|
@ -14,8 +14,6 @@ fastapi-sse==1.1.1
|
|||
frozenlist==1.7.0
|
||||
h11==0.16.0
|
||||
httpcore==1.0.9
|
||||
httpx==0.28.1
|
||||
httpx-aiohttp==0.1.8
|
||||
idna==3.10
|
||||
jiter==0.10.0
|
||||
multidict==6.6.4
|
||||
|
|
|
|||
101
router.py
101
router.py
|
|
@ -6,8 +6,7 @@ version: 0.2.2
|
|||
license: AGPL
|
||||
"""
|
||||
# -------------------------------------------------------------
|
||||
import json, time, asyncio, yaml, httpx, ollama, openai, os, re
|
||||
from httpx_aiohttp import AiohttpTransport
|
||||
import json, time, asyncio, yaml, ollama, openai, os, re, aiohttp
|
||||
from pathlib import Path
|
||||
from typing import Dict, Set, List, Optional
|
||||
from fastapi import FastAPI, Request, HTTPException
|
||||
|
|
@ -95,23 +94,15 @@ usage_lock = asyncio.Lock() # protects access to usage_counts
|
|||
# -------------------------------------------------------------
|
||||
# 4. Helper functions
|
||||
# -------------------------------------------------------------
|
||||
aiotimeout = aiohttp.ClientTimeout(total=5)
|
||||
|
||||
def _is_fresh(cached_at: float, ttl: int) -> bool:
|
||||
return (time.time() - cached_at) < ttl
|
||||
|
||||
def get_httpx_client(endpoint: str) -> httpx.AsyncClient:
    """
    Build an httpx AsyncClient rooted at *endpoint*.

    The stated intent is to use persistent connections when requesting
    endpoint info, for reliable results under high load or with saturated
    endpoints. Requests are sent through the aiohttp-backed transport.
    """
    # Explicit connection limits (max_keepalive_connections=64,
    # max_connections=64) are commented out in the original and stay disabled.
    request_timeout = httpx.Timeout(5.0, read=5.0, write=None, connect=5.0)
    aiohttp_transport = AiohttpTransport()
    return httpx.AsyncClient(
        base_url=endpoint,
        timeout=request_timeout,
        transport=aiohttp_transport,
    )
|
||||
async def _ensure_success(resp: aiohttp.ClientResponse) -> None:
    """
    Raise an HTTPException when *resp* carries an error status.

    Any status below 400 returns silently. Otherwise the response body is
    read as text and raised as the exception detail, with the status code
    of the response.
    """
    if resp.status < 400:
        return
    body = await resp.text()
    raise HTTPException(status_code=resp.status, detail=body)
|
||||
|
||||
async def fetch_available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
|
||||
"""
|
||||
|
|
@ -143,35 +134,33 @@ async def fetch_available_models(endpoint: str, api_key: Optional[str] = None) -
|
|||
# Error expired – remove it
|
||||
del _error_cache[endpoint]
|
||||
|
||||
if "/v1" in endpoint:
|
||||
endpoint_url = f"{endpoint}/models"
|
||||
key = "data"
|
||||
else:
|
||||
endpoint_url = f"{endpoint}/api/tags"
|
||||
key = "models"
|
||||
try:
|
||||
client = get_httpx_client(endpoint)
|
||||
if "/v1" in endpoint:
|
||||
resp = await client.get(f"/models", headers=headers)
|
||||
else:
|
||||
resp = await client.get(f"/api/tags")
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
# Expected format:
|
||||
# {"models": [{"name": "model1"}, {"name": "model2"}]}
|
||||
if "/v1" in endpoint:
|
||||
models = {m.get("id") for m in data.get("data", []) if m.get("id")}
|
||||
else:
|
||||
models = {m.get("name") for m in data.get("models", []) if m.get("name")}
|
||||
|
||||
if models:
|
||||
_models_cache[endpoint] = (models, time.time())
|
||||
return models
|
||||
else:
|
||||
# Empty list – treat as “no models”, but still cache for 300s
|
||||
_models_cache[endpoint] = (models, time.time())
|
||||
return models
|
||||
async with aiohttp.ClientSession(timeout=aiotimeout) as client:
|
||||
async with client.get(endpoint_url, headers=headers) as resp:
|
||||
await _ensure_success(resp)
|
||||
data = await resp.json()
|
||||
|
||||
items = data.get(key, [])
|
||||
models = {item.get("id") or item.get("name") for item in items if item.get("id") or item.get("name")}
|
||||
|
||||
if models:
|
||||
_models_cache[endpoint] = (models, time.time())
|
||||
return models
|
||||
else:
|
||||
# Empty list – treat as “no models”, but still cache for 300s
|
||||
_models_cache[endpoint] = (models, time.time())
|
||||
return models
|
||||
except Exception as e:
|
||||
# Treat any error as if the endpoint offers no models
|
||||
print(f"[fetch_available_models] {endpoint} error: {e}")
|
||||
_error_cache[endpoint] = time.time()
|
||||
return set()
|
||||
finally:
|
||||
await client.aclose()
|
||||
|
||||
|
||||
async def fetch_loaded_models(endpoint: str) -> Set[str]:
|
||||
|
|
@ -181,10 +170,10 @@ async def fetch_loaded_models(endpoint: str) -> Set[str]:
|
|||
set is returned.
|
||||
"""
|
||||
try:
|
||||
client = get_httpx_client(endpoint)
|
||||
resp = await client.get(f"/api/ps")
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
async with aiohttp.ClientSession(timeout=aiotimeout) as client:
|
||||
async with client.get(f"/api/ps") as resp:
|
||||
await _ensure_success(resp)
|
||||
data = await resp.json()
|
||||
# The response format is:
|
||||
# {"models": [{"name": "model1"}, {"name": "model2"}]}
|
||||
models = {m.get("name") for m in data.get("models", []) if m.get("name")}
|
||||
|
|
@ -192,8 +181,6 @@ async def fetch_loaded_models(endpoint: str) -> Set[str]:
|
|||
except Exception:
|
||||
# If anything goes wrong we simply assume the endpoint has no models
|
||||
return set()
|
||||
finally:
|
||||
await client.aclose()
|
||||
|
||||
async def fetch_endpoint_details(endpoint: str, route: str, detail: str, api_key: Optional[str] = None) -> List[dict]:
|
||||
"""
|
||||
|
|
@ -205,18 +192,16 @@ async def fetch_endpoint_details(endpoint: str, route: str, detail: str, api_key
|
|||
headers = {"Authorization": "Bearer " + api_key}
|
||||
|
||||
try:
|
||||
client = get_httpx_client(endpoint)
|
||||
resp = await client.get(f"{route}", headers=headers)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
async with aiohttp.ClientSession(timeout=aiotimeout) as client:
|
||||
async with client.get(f"{endpoint}{route}", headers=headers) as resp:
|
||||
await _ensure_success(resp)
|
||||
data = await resp.json()
|
||||
detail = data.get(detail, [])
|
||||
return detail
|
||||
except Exception as e:
|
||||
# If anything goes wrong we cannot reply details
|
||||
print(e)
|
||||
return []
|
||||
finally:
|
||||
await client.aclose()
|
||||
|
||||
def ep2base(ep):
|
||||
if "/v1" in ep:
|
||||
|
|
@ -960,22 +945,22 @@ async def config_proxy(request: Request):
|
|||
"""
|
||||
async def check_endpoint(url: str):
    """
    Probe a single endpoint and report whether it answers.

    URLs containing "/v1" are queried at ``{url}/models`` with a Bearer
    token looked up in ``config.api_keys``; every other URL is queried at
    ``{url}/api/version``.

    Returns:
        ``{"url", "status": "ok", "version"}`` on success, where version is
        "latest" for "/v1" endpoints and the reported "version" field
        otherwise, or ``{"url", "status": "error", "detail"}`` when anything
        fails (missing API key, connection error, HTTP status >= 400,
        undecodable body).
    """
    uses_v1_api = "/v1" in url
    try:
        if uses_v1_api:
            probe_url = f"{url}/models"
            headers = {"Authorization": "Bearer " + config.api_keys[url]}
        else:
            probe_url = f"{url}/api/version"
            headers = None

        # Only the aiohttp session is used. Both context managers release
        # their resources on exit, so no explicit close call is needed
        # (aiohttp.ClientSession has no aclose() method).
        async with aiohttp.ClientSession(timeout=aiotimeout) as client:
            async with client.get(probe_url, headers=headers) as resp:
                await _ensure_success(resp)
                # Decoded for both endpoint kinds so an invalid body is
                # still reported as an error.
                data = await resp.json()

        if uses_v1_api:
            return {"url": url, "status": "ok", "version": "latest"}
        return {"url": url, "status": "ok", "version": data.get("version")}
    except Exception as exc:
        # Best-effort health check: any failure is reported, never raised.
        return {"url": url, "status": "error", "detail": str(exc)}
|
||||
|
||||
results = await asyncio.gather(*[check_endpoint(ep) for ep in config.endpoints])
|
||||
return {"endpoints": results}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue