diff --git a/router.py b/router.py
index 99ea34b..e835deb 100644
--- a/router.py
+++ b/router.py
@@ -8,7 +8,7 @@ license: AGPL
 # -------------------------------------------------------------
 import json, time, asyncio, yaml, httpx, ollama, openai
 from pathlib import Path
-from typing import Dict, Set, List
+from typing import Dict, Set, List, Optional
 from fastapi import FastAPI, Request, HTTPException
 from fastapi.staticfiles import StaticFiles
 from starlette.responses import StreamingResponse, JSONResponse, Response, HTMLResponse, RedirectResponse
@@ -86,7 +86,7 @@ def get_httpx_client(endpoint: str) -> httpx.AsyncClient:
         )
     )
 
-async def fetch_available_models(endpoint: str) -> Set[str]:
+async def fetch_available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
     """
     Query /api/tags and return a set of all model names that the
     endpoint *advertises* (i.e. is capable of serving). This endpoint lists
@@ -96,6 +96,10 @@ async def fetch_available_models(endpoint: str) -> Set[str]:
     If the request fails (e.g. timeout, 5xx, or malformed response),
     an empty set is returned.
     """
+    headers = None
+    if api_key is not None:
+        headers = {"Authorization": "Bearer " + api_key}
+
     if endpoint in _models_cache:
         models, cached_at = _models_cache[endpoint]
         if _is_fresh(cached_at, 300):
@@ -115,7 +119,7 @@ async def fetch_available_models(endpoint: str) -> Set[str]:
     client = get_httpx_client(endpoint)
     try:
         if "/v1" in endpoint:
-            resp = await client.get(f"/models")
+            resp = await client.get(f"/models", headers=headers)
         else:
             resp = await client.get(f"/api/tags")
         resp.raise_for_status()
@@ -123,7 +127,7 @@
         # Expected format:
         # {"models": [{"name": "model1"}, {"name": "model2"}]}
         if "/v1" in endpoint:
-            models = {m.get("id") for m in data.get("data", []) if m.get("name")}
+            models = {m.get("id") for m in data.get("data", []) if m.get("id")}
         else:
             models = {m.get("name") for m in data.get("models", []) if m.get("name")}
 
@@ -160,15 +164,14 @@ async def fetch_loaded_models(endpoint: str) -> Set[str]:
     # If anything goes wrong we simply assume the endpoint has no models
     return set()
 
-async def fetch_endpoint_details(endpoint: str, route: str, detail: str, api_key: str = None) -> List[dict]:
+async def fetch_endpoint_details(endpoint: str, route: str, detail: str, api_key: Optional[str] = None) -> List[dict]:
     """
     Query / to fetch and return a List of dicts with details for the corresponding Ollama endpoint.
-    If the request fails we respond with "N/A" for detail.
+    If the request fails an empty list is returned.
     """
+    headers = None
     if api_key is not None:
         headers = {"Authorization": "Bearer " + api_key}
-    else:
-        headers = None
     client = get_httpx_client(endpoint)
     try:
         resp = await client.get(f"{route}", headers=headers)
@@ -179,7 +182,7 @@
     except Exception as e:
         # If anything goes wrong we cannot reply details
         print(e)
-        return "N/A"
+        return []
 
 def ep2base(ep):
     if "/v1" in ep:
@@ -221,7 +224,7 @@ async def decrement_usage(endpoint: str, model: str) -> None:
 # -------------------------------------------------------------
 # 5. Endpoint selection logic (respecting the configurable limit)
 # -------------------------------------------------------------
-async def choose_endpoint(model: str) -> str:
+async def choose_endpoint(model: str, api_key: Optional[str] = None) -> str:
     """
     Determine which endpoint to use for the given model while respecting
     the `max_concurrent_connections` per endpoint‑model pair **and**
@@ -240,7 +243,7 @@
     6️⃣ If no endpoint advertises the model at all, raise an error.
     """
     # 1️⃣ Gather advertised‑model sets for all endpoints concurrently
-    tag_tasks = [fetch_available_models(ep) for ep in config.endpoints]
+    tag_tasks = [fetch_available_models(ep, api_key) for ep in config.endpoints]
     advertised_sets = await asyncio.gather(*tag_tasks)
 
     # 2️⃣ Filter endpoints that advertise the requested model
@@ -938,7 +941,7 @@ async def openai_embedding_proxy(request: Request):
         raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
 
     # 2. Endpoint logic
-    endpoint = await choose_endpoint(model)
+    endpoint = await choose_endpoint(model, api_key)
     await increment_usage(endpoint, model)
 
     oclient = openai.AsyncOpenAI(base_url=endpoint+"/v1", api_key=api_key)
@@ -976,6 +979,7 @@ async def openai_chat_completions_proxy(request: Request):
     temperature = payload.get("temperature")
     top_p = payload.get("top_p")
     max_tokens = payload.get("max_tokens")
+    max_completion_tokens = payload.get("max_completion_tokens")
     tools = payload.get("tools")
 
     headers = request.headers
@@ -985,14 +989,9 @@
     params = {
         "messages": messages,
         "model": model,
-        "frequency_penalty": frequency_penalty,
-        "presence_penalty": presence_penalty,
         "seed": seed,
         "stop": stop,
         "stream": stream,
-        "temperature": temperature,
-        "top_p": top_p,
-        "max_tokens": max_tokens
     }
 
     if tools is not None:
@@ -1001,6 +1000,18 @@
         params["response_format"] = response_format
     if stream_options is not None:
         params["stream_options"] = stream_options
+    if max_completion_tokens is not None:
+        params["max_completion_tokens"] = max_completion_tokens
+    if max_tokens is not None:
+        params["max_tokens"] = max_tokens
+    if temperature is not None:
+        params["temperature"] = temperature
+    if top_p is not None:
+        params["top_p"] = top_p
+    if presence_penalty is not None:
+        params["presence_penalty"] = presence_penalty
+    if frequency_penalty is not None:
+        params["frequency_penalty"] = frequency_penalty
 
     if not model:
         raise HTTPException(
@@ -1014,7 +1025,7 @@
         raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
 
     # 2. Endpoint logic
-    endpoint = await choose_endpoint(model)
+    endpoint = await choose_endpoint(model, api_key)
     await increment_usage(endpoint, model)
     base_url = ep2base(endpoint)
     oclient = openai.AsyncOpenAI(base_url=base_url, api_key=api_key)
@@ -1112,7 +1123,7 @@
         raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
 
     # 2. Endpoint logic
-    endpoint = await choose_endpoint(model)
+    endpoint = await choose_endpoint(model, api_key)
     await increment_usage(endpoint, model)
     base_url = ep2base(endpoint)
     oclient = openai.AsyncOpenAI(base_url=base_url, api_key=api_key)
@@ -1171,8 +1182,10 @@ async def openai_models_proxy(request: Request):
     models = {'data': []}
     for modellist in all_models:
         for model in modellist:
-            if not id in model.keys(): # Relable Ollama models with OpenAI Model.id from Model.name
+            if "id" not in model: # Relabel Ollama models with OpenAI Model.id from Model.name
                 model['id'] = model['name']
+            else:
+                model['name'] = model['id']
         models['data'] += modellist
 
     # 2. Return a JSONResponse with a deduplicated list of unique models for inference
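
Note on the conditional-params hunk in openai_chat_completions_proxy (illustration, not part of the patch): the openai-python SDK distinguishes arguments that are omitted (they default to a NOT_GIVEN sentinel and are left out of the request body) from arguments explicitly set to None, which are serialized as JSON null and rejected by some OpenAI-compatible backends. Adding each sampling option to `params` only when the client actually supplied it keeps absent fields absent upstream. A minimal sketch of the pattern, using a hypothetical payload:

    # Only forward sampling options the client actually sent; a missing key
    # stays missing instead of being forwarded as an explicit null.
    OPTIONAL_KEYS = ("max_completion_tokens", "max_tokens", "temperature",
                     "top_p", "presence_penalty", "frequency_penalty")

    payload = {"model": "llama3", "temperature": 0.2}  # hypothetical request body
    params = {"model": payload["model"]}
    for key in OPTIONAL_KEYS:
        if payload.get(key) is not None:
            params[key] = payload[key]

    assert params == {"model": "llama3", "temperature": 0.2}

The api_key threading works the same way end to end: the proxy routes call choose_endpoint(model, api_key), which passes the caller's bearer token down to fetch_available_models, so /v1 endpoints that require an Authorization header can still be probed for their model list.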