diff --git a/router.py b/router.py
index 5796972..99ea34b 100644
--- a/router.py
+++ b/router.py
@@ -160,14 +160,18 @@ async def fetch_loaded_models(endpoint: str) -> Set[str]:
         # If anything goes wrong we simply assume the endpoint has no models
         return set()
 
-async def fetch_endpoint_details(endpoint: str, route: str, detail: str) -> List[dict]:
+async def fetch_endpoint_details(endpoint: str, route: str, detail: str, api_key: str = None) -> List[dict]:
     """
     Query / to fetch and return a List of dicts with details for the corresponding Ollama endpoint.
     If the request fails we respond with "N/A" for detail.
     """
+    if api_key is not None:
+        headers = {"Authorization": "Bearer " + api_key}
+    else:
+        headers = None
     client = get_httpx_client(endpoint)
     try:
-        resp = await client.get(f"{route}")
+        resp = await client.get(f"{route}", headers=headers)
         resp.raise_for_status()
         data = resp.json()
         detail = data.get(detail, [])
@@ -175,7 +179,7 @@ async def fetch_endpoint_details(endpoint: str, route: str, detail: str) -> List
     except Exception as e:
         # If anything goes wrong we cannot reply details
         print(e)
-        return {detail: []}
+        return "N/A"
 
 def ep2base(ep):
     if "/v1" in ep:
@@ -803,7 +807,8 @@ async def version_proxy(request: Request):
     # 1. Query all endpoints for version
     tasks = [fetch_endpoint_details(ep, "/api/version", "version") for ep in config.endpoints]
     all_versions = await asyncio.gather(*tasks)
-
+    all_versions = [v for v in all_versions if v != "N/A"]
+
     def version_key(v):
         return tuple(map(int, v.split('.')))
 
@@ -824,7 +829,7 @@ async def tags_proxy(request: Request):
     """
     # 1. Query all endpoints for models
     tasks = [fetch_endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints if "/v1" not in ep]
-    tasks += [fetch_endpoint_details(ep, "/models", "data") for ep in config.endpoints if "/v1" in ep]
+    tasks += [fetch_endpoint_details(ep, "/models", "data") for ep in config.endpoints if "/v1" in ep]  # needs api_key, TODO: add central mgmt
     all_models = await asyncio.gather(*tasks)
 
     models = {'models': []}
@@ -1154,9 +1159,13 @@ async def openai_models_proxy(request: Request):
     Proxy a models request to Ollama endpoints and reply with
     a unique list of all models.
     """
+    headers = request.headers
+    api_key = headers.get("Authorization")
+    api_key = api_key.split()[1]
+
     # 1. Query all endpoints for models
     tasks = [fetch_endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints if "/v1" not in ep]
-    tasks += [fetch_endpoint_details(ep, "/models", "data") for ep in config.endpoints if "/v1" in ep]
+    tasks += [fetch_endpoint_details(ep, "/models", "data", api_key) for ep in config.endpoints if "/v1" in ep]
    all_models = await asyncio.gather(*tasks)
 
    models = {'data': []}
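Note on the bearer-token handling added to openai_models_proxy: request.headers.get("Authorization")
returns None when the header is absent, and api_key.split()[1] will then raise before any endpoint is
queried. A minimal, more defensive sketch (hypothetical, not part of this patch; it reuses the request,
config.endpoints, and fetch_endpoint_details names from the diff above):

    # Hypothetical defensive variant of the bearer-token extraction above; not part of the patch.
    auth_header = request.headers.get("Authorization")
    api_key = None
    if auth_header is not None:
        parts = auth_header.split()
        # Expect "Bearer <token>"; fall back to no key for anything else
        if len(parts) == 2 and parts[0].lower() == "bearer":
            api_key = parts[1]

    # Only the OpenAI-compatible ("/v1") endpoints receive the forwarded key
    tasks = [fetch_endpoint_details(ep, "/api/tags", "models")
             for ep in config.endpoints if "/v1" not in ep]
    tasks += [fetch_endpoint_details(ep, "/models", "data", api_key)
              for ep in config.endpoints if "/v1" in ep]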