adding fetch class and ollama client completions on openai endpoints

This commit is contained in:
Alpha Nerd 2025-09-13 16:57:09 +02:00
parent 0a7fd8ca52
commit 9ea852f154

271
router.py
View file

@ -119,111 +119,112 @@ async def _ensure_success(resp: aiohttp.ClientResponse) -> None:
text = await resp.text()
raise HTTPException(status_code=resp.status, detail=text)
async def fetch_available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
"""
Query <endpoint>/api/tags and return a set of all model names that the
endpoint *advertises* (i.e. is capable of serving). This endpoint lists
every model that is installed on the Ollama instance, regardless of
whether the model is currently loaded into memory.
class fetch:
async def available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
"""
Query <endpoint>/api/tags and return a set of all model names that the
endpoint *advertises* (i.e. is capable of serving). This endpoint lists
every model that is installed on the Ollama instance, regardless of
whether the model is currently loaded into memory.
If the request fails (e.g. timeout, 5xx, or malformed response), an empty
set is returned.
"""
headers = None
if api_key is not None:
headers = {"Authorization": "Bearer " + api_key}
If the request fails (e.g. timeout, 5xx, or malformed response), an empty
set is returned.
"""
headers = None
if api_key is not None:
headers = {"Authorization": "Bearer " + api_key}
if endpoint in _models_cache:
models, cached_at = _models_cache[endpoint]
if _is_fresh(cached_at, 300):
return models
else:
# stale entry drop it
del _models_cache[endpoint]
if endpoint in _error_cache:
if _is_fresh(_error_cache[endpoint], 1):
# Still within the short error TTL pretend nothing is available
return set()
else:
# Error expired remove it
del _error_cache[endpoint]
if "/v1" in endpoint:
endpoint_url = f"{endpoint}/models"
key = "data"
else:
endpoint_url = f"{endpoint}/api/tags"
key = "models"
client: aiohttp.ClientSession = app_state["session"]
try:
async with client.get(endpoint_url, headers=headers) as resp:
await _ensure_success(resp)
data = await resp.json()
items = data.get(key, [])
models = {item.get("id") or item.get("name") for item in items if item.get("id") or item.get("name")}
if models:
_models_cache[endpoint] = (models, time.time())
if endpoint in _models_cache:
models, cached_at = _models_cache[endpoint]
if _is_fresh(cached_at, 300):
return models
else:
# Empty list treat as “no models”, but still cache for 300s
_models_cache[endpoint] = (models, time.time())
return models
except Exception as e:
# Treat any error as if the endpoint offers no models
print(f"[fetch_available_models] {endpoint} error: {e}")
_error_cache[endpoint] = time.time()
return set()
# stale entry drop it
del _models_cache[endpoint]
if endpoint in _error_cache:
if _is_fresh(_error_cache[endpoint], 1):
# Still within the short error TTL pretend nothing is available
return set()
else:
# Error expired remove it
del _error_cache[endpoint]
if "/v1" in endpoint:
endpoint_url = f"{endpoint}/models"
key = "data"
else:
endpoint_url = f"{endpoint}/api/tags"
key = "models"
client: aiohttp.ClientSession = app_state["session"]
try:
async with client.get(endpoint_url, headers=headers) as resp:
await _ensure_success(resp)
data = await resp.json()
items = data.get(key, [])
models = {item.get("id") or item.get("name") for item in items if item.get("id") or item.get("name")}
if models:
_models_cache[endpoint] = (models, time.time())
return models
else:
# Empty list treat as “no models”, but still cache for 300s
_models_cache[endpoint] = (models, time.time())
return models
except Exception as e:
# Treat any error as if the endpoint offers no models
print(f"[fetch.available_models] {endpoint} error: {e}")
_error_cache[endpoint] = time.time()
return set()
async def fetch_loaded_models(endpoint: str) -> Set[str]:
    """
    Return the names of the models currently loaded on an endpoint.

    Sends a GET to <endpoint>/api/ps and collects every non-empty "name"
    found under the "models" key of the JSON reply. Any exception raised
    while requesting or decoding the reply (e.g. timeout, 5xx) results in
    an empty set.
    """
    session: aiohttp.ClientSession = app_state["session"]
    ps_url = f"{endpoint}/api/ps"
    try:
        async with session.get(ps_url) as resp:
            await _ensure_success(resp)
            payload = await resp.json()
        # Expected reply shape:
        # {"models": [{"name": "model1"}, {"name": "model2"}]}
        loaded: Set[str] = set()
        for entry in payload.get("models", []):
            name = entry.get("name")
            if name:
                loaded.add(name)
        return loaded
    except Exception:
        # If anything goes wrong we simply assume the endpoint has no models
        return set()
async def loaded_models(endpoint: str) -> Set[str]:
    """
    Query <endpoint>/api/ps and return the set of model names that are
    currently loaded on that endpoint.

    The reply is expected to look like
    {"models": [{"name": "model1"}, {"name": "model2"}]}; entries without a
    non-empty "name" are ignored, and a missing or null "models" value is
    treated as an empty list.

    If the request fails (e.g. timeout, 5xx, malformed response) the error is
    printed with the endpoint it belongs to and an empty set is returned, so
    callers can treat the endpoint as having nothing loaded.
    """
    client: aiohttp.ClientSession = app_state["session"]
    try:
        async with client.get(f"{endpoint}/api/ps") as resp:
            await _ensure_success(resp)
            data = await resp.json()
        # "or []" covers a reply whose "models" value is null
        return {m.get("name") for m in (data.get("models") or []) if m.get("name")}
    except Exception as e:
        # Best effort: report the failure instead of hiding it, then assume
        # the endpoint has no models loaded.
        print(f"[fetch.loaded_models] {endpoint} error: {e}")
        return set()
async def fetch_endpoint_details(endpoint: str, route: str, detail: str, api_key: Optional[str] = None) -> List[dict]:
    """
    GET <endpoint><route> and hand back the value stored under the <detail>
    key of the JSON reply (an empty list when that key is absent).

    When api_key is given it is sent as a bearer Authorization header. On any
    failure the exception is printed and an empty list is returned.
    """
    session: aiohttp.ClientSession = app_state["session"]
    auth_headers = {"Authorization": "Bearer " + api_key} if api_key is not None else None
    try:
        async with session.get(f"{endpoint}{route}", headers=auth_headers) as resp:
            await _ensure_success(resp)
            body = await resp.json()
            return body.get(detail, [])
    except Exception as err:
        # Without a usable reply there are no details to report
        print(err)
        return []
async def endpoint_details(endpoint: str, route: str, detail: str, api_key: Optional[str] = None) -> List[dict]:
    """
    Query <endpoint><route> and return the value stored under the <detail>
    key of the JSON reply.

    endpoint -- base URL of the backend
    route    -- path appended verbatim to the endpoint (e.g. "/api/tags")
    detail   -- top-level key to extract from the reply (e.g. "models")
    api_key  -- optional key, sent as "Authorization: Bearer <api_key>"

    Returns the extracted value, or an empty list when the key is missing.
    If the request fails for any reason, the error is printed together with
    the URL it belongs to and an empty list is returned.

    NOTE(review): the extracted value is returned unchanged, so for a route
    such as "/api/version" it is presumably not a list of dicts - confirm
    against the callers before relying on the annotation.
    """
    client: aiohttp.ClientSession = app_state["session"]
    headers = None
    if api_key is not None:
        headers = {"Authorization": "Bearer " + api_key}
    url = f"{endpoint}{route}"
    try:
        async with client.get(url, headers=headers) as resp:
            await _ensure_success(resp)
            data = await resp.json()
            return data.get(detail, [])
    except Exception as e:
        # If anything goes wrong we cannot reply details; say which request failed
        print(f"[fetch.endpoint_details] {url} error: {e}")
        return []
def ep2base(ep):
    """
    Map a configured endpoint to the base URL of its OpenAI-compatible API.

    An endpoint that already contains "/v1" is returned unchanged; otherwise
    "/v1" is appended. Trailing slashes are removed before appending so that
    "http://host/" becomes "http://host/v1" rather than "http://host//v1".
    """
    if "/v1" in ep:
        return ep
    return ep.rstrip("/") + "/v1"
def ep2base(ep):
    """
    Map a configured endpoint to the base URL of its OpenAI-compatible API.

    An endpoint that already contains "/v1" is returned unchanged; otherwise
    "/v1" is appended. Trailing slashes are removed before appending so that
    "http://host/" becomes "http://host/v1" rather than "http://host//v1".
    """
    if "/v1" in ep:
        return ep
    return ep.rstrip("/") + "/v1"
def dedupe_on_keys(dicts, key_fields):
"""
@ -288,6 +289,19 @@ class rechunk:
else:
rechunk["message"] = {"role": chunk.choices[0].message.role, "content": chunk.choices[0].message.content, "thinking": None, "images": None, "tool_name": None, "tool_calls": None}
return rechunk
def openai_completion2ollama(chunk: dict, stream: bool, start_ts: float):
    """
    Convert an OpenAI text-completion response (or one streamed chunk) into
    an Ollama generate-style dict.

    chunk    -- response object; read through attributes (.model, .choices,
                .usage) even though the signature annotates it as dict
    stream   -- accepted so the call sites can pass it; not used here
    start_ts -- time.perf_counter() reading taken before the request was sent

    Returns a dict with the keys model, created_at, load_duration,
    done_reason, total_duration, eval_duration, thinking, context and
    response. eval_duration is only filled in when the chunk carries usage
    information.

    A chunk without any choice (e.g. a usage-only chunk emitted at the end of
    a stream by some OpenAI-compatible servers) yields done_reason None and
    an empty response instead of raising IndexError.
    """
    choice = chunk.choices[0] if chunk.choices else None
    # Elapsed wall time since start_ts, in milliseconds (seconds * 1000).
    # NOTE(review): native Ollama durations appear to be nanoseconds - confirm
    # which unit downstream clients expect.
    elapsed_ms = int((time.perf_counter() - start_ts) * 1000)
    return {
        "model": chunk.model,
        "created_at": iso8601_ns(),
        "load_duration": None,
        "done_reason": choice.finish_reason if choice is not None else None,
        "total_duration": None,
        "eval_duration": elapsed_ms if chunk.usage is not None else None,
        "thinking": None,
        "context": None,
        "response": choice.text if choice is not None else "",
    }
# ------------------------------------------------------------------
# SSE Helpers
@ -350,8 +364,8 @@ async def choose_endpoint(model: str) -> str:
6 If no endpoint advertises the model at all, raise an error.
"""
# 1⃣ Gather advertised-model sets for all endpoints concurrently
tag_tasks = [fetch_available_models(ep) for ep in config.endpoints if "/v1" not in ep]
tag_tasks += [fetch_available_models(ep, config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep]
tag_tasks = [fetch.available_models(ep) for ep in config.endpoints if "/v1" not in ep]
tag_tasks += [fetch.available_models(ep, config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep]
advertised_sets = await asyncio.gather(*tag_tasks)
# 2⃣ Filter endpoints that advertise the requested model
@ -369,7 +383,7 @@ async def choose_endpoint(model: str) -> str:
# 3⃣ Among the candidates, find those that have the model *loaded*
# (concurrently, but only for the filtered list)
load_tasks = [fetch_loaded_models(ep) for ep in candidate_endpoints]
load_tasks = [fetch.loaded_models(ep) for ep in candidate_endpoints]
loaded_sets = await asyncio.gather(*load_tasks)
async with usage_lock:
@ -409,7 +423,6 @@ async def proxy(request: Request):
"""
Proxy a generate request to Ollama and stream the response back to the client.
"""
# 1. Parse and validate request
try:
body_bytes = await request.body()
payload = json.loads(body_bytes.decode("utf-8"))
@ -439,29 +452,50 @@ async def proxy(request: Request):
except json.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Decide which endpoint to use
endpoint = await choose_endpoint(model)
# Increment usage counter for this endpoint-model pair
await increment_usage(endpoint, model)
is_openai_endpoint = "/v1" in endpoint
if is_openai_endpoint:
params = {
"prompt": prompt,
"model": model,
}
# 3. Create Ollama client instance
client = ollama.AsyncClient(host=endpoint)
optional_params = {
"stream": stream,
}
params.update({k: v for k, v in optional_params.items() if v is not None})
oclient = openai.AsyncOpenAI(base_url=endpoint, api_key=config.api_keys[endpoint])
else:
client = ollama.AsyncClient(host=endpoint)
# 4. Async generator that streams data and decrements the counter
async def stream_generate_response():
try:
async_gen = await client.generate(model=model, prompt=prompt, suffix=suffix, system=system, template=template, context=context, stream=stream, think=think, raw=raw, format=_format, images=images, options=options, keep_alive=keep_alive)
if is_openai_endpoint:
start_ts = time.perf_counter()
async_gen = await oclient.completions.create(**params)
else:
async_gen = await client.generate(model=model, prompt=prompt, suffix=suffix, system=system, template=template, context=context, stream=stream, think=think, raw=raw, format=_format, images=images, options=options, keep_alive=keep_alive)
if stream == True:
async for chunk in async_gen:
if is_openai_endpoint:
chunk = rechunk.openai_completion2ollama(chunk, stream, start_ts)
if hasattr(chunk, "model_dump_json"):
json_line = chunk.model_dump_json()
else:
json_line = json.dumps(chunk)
yield json_line.encode("utf-8") + b"\n"
else:
if is_openai_endpoint:
response = rechunk.openai_completion2ollama(async_gen, stream, start_ts)
response = json.dumps(response)
else:
response = async_gen.model_dump_json()
json_line = (
async_gen.model_dump_json()
response
if hasattr(async_gen, "model_dump_json")
else json.dumps(async_gen)
)
@ -513,7 +547,8 @@ async def chat_proxy(request: Request):
# 2. Endpoint logic
endpoint = await choose_endpoint(model)
await increment_usage(endpoint, model)
if "/v1" in endpoint:
is_openai_endpoint = "/v1" in endpoint
if is_openai_endpoint:
params = {
"messages": messages,
"model": model,
@ -530,7 +565,6 @@ async def chat_proxy(request: Request):
client = ollama.AsyncClient(host=endpoint)
# 3. Async generator that streams chat data and decrements the counter
is_openai_endpoint = "/v1" in endpoint
async def stream_chat_response():
try:
# The chat method returns a generator of dicts (or GenerateResponse)
@ -560,7 +594,6 @@ async def chat_proxy(request: Request):
if hasattr(async_gen, "model_dump_json")
else json.dumps(async_gen)
)
print(json_line)
yield json_line.encode("utf-8") + b"\n"
finally:
@ -941,7 +974,7 @@ async def version_proxy(request: Request):
"""
# 1. Query all endpoints for version
tasks = [fetch_endpoint_details(ep, "/api/version", "version") for ep in config.endpoints if "/v1" not in ep]
tasks = [fetch.endpoint_details(ep, "/api/version", "version") for ep in config.endpoints if "/v1" not in ep]
all_versions = await asyncio.gather(*tasks)
def version_key(v):
@ -964,8 +997,8 @@ async def tags_proxy(request: Request):
"""
# 1. Query all endpoints for models
tasks = [fetch_endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints if "/v1" not in ep]
tasks += [fetch_endpoint_details(ep, "/models", "data", config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep]
tasks = [fetch.endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints if "/v1" not in ep]
tasks += [fetch.endpoint_details(ep, "/models", "data", config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep]
all_models = await asyncio.gather(*tasks)
models = {'models': []}
@ -988,7 +1021,7 @@ async def ps_proxy(request: Request):
"""
# 1. Query all endpoints for running models
tasks = [fetch_endpoint_details(ep, "/api/ps", "models") for ep in config.endpoints if "/v1" not in ep]
tasks = [fetch.endpoint_details(ep, "/api/ps", "models") for ep in config.endpoints if "/v1" not in ep]
loaded_models = await asyncio.gather(*tasks)
models = {'models': []}
@ -1300,8 +1333,8 @@ async def openai_models_proxy(request: Request):
"""
# 1. Query all endpoints for models
tasks = [fetch_endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints if "/v1" not in ep]
tasks += [fetch_endpoint_details(ep, "/models", "data", config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep]
tasks = [fetch.endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints if "/v1" not in ep]
tasks += [fetch.endpoint_details(ep, "/models", "data", config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep]
all_models = await asyncio.gather(*tasks)
models = {'data': []}
@ -1351,9 +1384,7 @@ async def health_proxy(request: Request):
* The HTTP status code is 200 when everything is healthy, 503 otherwise.
"""
# Run all health checks in parallel
tasks = [
fetch_endpoint_details(ep, "/api/version", "version") for ep in config.endpoints
]
tasks = [fetch.endpoint_details(ep, "/api/version", "version") for ep in config.endpoints]
results = await asyncio.gather(*tasks, return_exceptions=True)