Add files via upload
preparations for /v1 endpoints with auth
This commit is contained in:
parent
0a456e6e21
commit
d257073cb1
1 changed file with 32 additions and 19 deletions
router.py (51 lines changed: +32 −19)
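As context for the diff below: the /v1 routes need to forward the caller's API key to the upstream backends. A minimal, hypothetical sketch of pulling a Bearer token out of an incoming FastAPI request; the helper name and extraction mechanism are illustrative only, not necessarily how this router does it.

# Hypothetical helper: extract a Bearer token from an incoming request.
# Names are illustrative; the router may obtain api_key differently.
from typing import Optional
from fastapi import Request

def extract_api_key(request: Request) -> Optional[str]:
    auth = request.headers.get("authorization")
    if auth and auth.lower().startswith("bearer "):
        return auth[7:]  # strip the "Bearer " prefix
    return None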
@@ -8,7 +8,7 @@ license: AGPL
 # -------------------------------------------------------------
 import json, time, asyncio, yaml, httpx, ollama, openai
 from pathlib import Path
-from typing import Dict, Set, List
+from typing import Dict, Set, List, Optional
 from fastapi import FastAPI, Request, HTTPException
 from fastapi.staticfiles import StaticFiles
 from starlette.responses import StreamingResponse, JSONResponse, Response, HTMLResponse, RedirectResponse
@@ -86,7 +86,7 @@ def get_httpx_client(endpoint: str) -> httpx.AsyncClient:
         )
     )
 
-async def fetch_available_models(endpoint: str) -> Set[str]:
+async def fetch_available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
     """
     Query <endpoint>/api/tags and return a set of all model names that the
     endpoint *advertises* (i.e. is capable of serving). This endpoint lists
@@ -96,6 +96,10 @@ async def fetch_available_models(endpoint: str) -> Set[str]:
     If the request fails (e.g. timeout, 5xx, or malformed response), an empty
     set is returned.
     """
+    headers = None
+    if api_key is not None:
+        headers = {"Authorization": "Bearer " + api_key}
+
     if endpoint in _models_cache:
         models, cached_at = _models_cache[endpoint]
         if _is_fresh(cached_at, 300):
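The new `headers` block turns the optional key into a standard Bearer header. For illustration, this is roughly how such a header is sent with httpx; a sketch, not the router's exact code, and the URL in the usage comment is a placeholder.

# Sketch: send an optional Bearer token with httpx.
import asyncio
from typing import Optional
import httpx

async def list_models(base_url: str, api_key: Optional[str] = None) -> dict:
    headers = {"Authorization": f"Bearer {api_key}"} if api_key else None
    async with httpx.AsyncClient(base_url=base_url, timeout=10.0) as client:
        resp = await client.get("/models", headers=headers)
        resp.raise_for_status()
        return resp.json()

# Example usage (placeholder URL):
# asyncio.run(list_models("https://example.invalid/v1", api_key="sk-..."))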
@@ -115,7 +119,7 @@ async def fetch_available_models(endpoint: str) -> Set[str]:
     client = get_httpx_client(endpoint)
     try:
         if "/v1" in endpoint:
-            resp = await client.get(f"/models")
+            resp = await client.get(f"/models", headers=headers)
         else:
             resp = await client.get(f"/api/tags")
         resp.raise_for_status()
@@ -123,7 +127,7 @@ async def fetch_available_models(endpoint: str) -> Set[str]:
         # Expected format:
         # {"models": [{"name": "model1"}, {"name": "model2"}]}
         if "/v1" in endpoint:
-            models = {m.get("id") for m in data.get("data", []) if m.get("name")}
+            models = {m.get("id") for m in data.get("data", []) if m.get("id")}
         else:
             models = {m.get("name") for m in data.get("models", []) if m.get("name")}
 
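The two branches reflect the two response shapes: OpenAI-style /v1/models returns {"data": [{"id": ...}]} while Ollama's /api/tags returns {"models": [{"name": ...}]}; the fix above filters on "id" for the OpenAI case instead of the non-existent "name". A stand-alone sketch of the same extraction:

# Sketch: normalise model listings from the two response formats.
from typing import Set

def extract_model_names(data: dict, openai_style: bool) -> Set[str]:
    if openai_style:
        # OpenAI-compatible: {"data": [{"id": "model1"}, ...]}
        return {m.get("id") for m in data.get("data", []) if m.get("id")}
    # Ollama: {"models": [{"name": "model1"}, ...]}
    return {m.get("name") for m in data.get("models", []) if m.get("name")}

# Example:
# extract_model_names({"data": [{"id": "llama3"}]}, openai_style=True)  -> {"llama3"}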
@@ -160,15 +164,14 @@ async def fetch_loaded_models(endpoint: str) -> Set[str]:
         # If anything goes wrong we simply assume the endpoint has no models
         return set()
 
-async def fetch_endpoint_details(endpoint: str, route: str, detail: str, api_key: str = None) -> List[dict]:
+async def fetch_endpoint_details(endpoint: str, route: str, detail: str, api_key: Optional[str] = None) -> List[dict]:
     """
     Query <endpoint>/<route> to fetch <detail> and return a List of dicts with details
     for the corresponding Ollama endpoint. If the request fails we respond with "N/A" for detail.
     """
+    headers = None
     if api_key is not None:
         headers = {"Authorization": "Bearer " + api_key}
-    else:
-        headers = None
     client = get_httpx_client(endpoint)
     try:
         resp = await client.get(f"{route}", headers=headers)
@@ -179,7 +182,7 @@ async def fetch_endpoint_details(endpoint: str, route: str, detail: str, api_key
     except Exception as e:
         # If anything goes wrong we cannot reply details
         print(e)
-        return "N/A"
+        return []
 
 def ep2base(ep):
     if "/v1" in ep:
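Returning [] on failure (instead of the string "N/A") keeps the failure value consistent with the declared List[dict] return type, so callers can iterate the result unconditionally. A small self-contained sketch of why the typed empty value is safer than a sentinel string:

# Sketch: a typed empty list survives iteration; a "N/A" string would not.
from typing import List

def summarize(details: List[dict]) -> List[str]:
    # Works for both a populated list and the empty-list failure case.
    return [d.get("name", "unknown") for d in details]

print(summarize([]))                    # -> []
print(summarize([{"name": "llama3"}]))  # -> ['llama3']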
@@ -221,7 +224,7 @@ async def decrement_usage(endpoint: str, model: str) -> None:
 # -------------------------------------------------------------
 # 5. Endpoint selection logic (respecting the configurable limit)
 # -------------------------------------------------------------
-async def choose_endpoint(model: str) -> str:
+async def choose_endpoint(model: str, api_key: Optional[str] = None) -> str:
     """
     Determine which endpoint to use for the given model while respecting
     the `max_concurrent_connections` per endpoint‑model pair **and**
@@ -240,7 +243,7 @@ async def choose_endpoint(model: str) -> str:
     6️⃣ If no endpoint advertises the model at all, raise an error.
     """
     # 1️⃣ Gather advertised‑model sets for all endpoints concurrently
-    tag_tasks = [fetch_available_models(ep) for ep in config.endpoints]
+    tag_tasks = [fetch_available_models(ep, api_key) for ep in config.endpoints]
     advertised_sets = await asyncio.gather(*tag_tasks)
 
     # 2️⃣ Filter endpoints that advertise the requested model
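choose_endpoint now threads the caller's key into every availability probe and runs them concurrently. A stand-alone sketch of that gather-and-filter pattern; the probe function, endpoint URLs, and model name are stand-ins, not the router's actual fetch_available_models:

# Sketch: probe endpoints concurrently and keep those advertising the model.
import asyncio
from typing import List, Optional, Set

async def probe(endpoint: str, api_key: Optional[str]) -> Set[str]:
    await asyncio.sleep(0)  # placeholder for the real HTTP call
    return {"llama3"} if endpoint.endswith("/v1") else set()

async def candidates(model: str, endpoints: List[str], api_key: Optional[str]) -> List[str]:
    advertised = await asyncio.gather(*(probe(ep, api_key) for ep in endpoints))
    return [ep for ep, models in zip(endpoints, advertised) if model in models]

print(asyncio.run(candidates("llama3", ["http://a:11434", "http://b:8000/v1"], None)))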
@@ -938,7 +941,7 @@ async def openai_embedding_proxy(request: Request):
         raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
 
     # 2. Endpoint logic
-    endpoint = await choose_endpoint(model)
+    endpoint = await choose_endpoint(model, api_key)
     await increment_usage(endpoint, model)
     oclient = openai.AsyncOpenAI(base_url=endpoint+"/v1", api_key=api_key)
 
@@ -976,6 +979,7 @@ async def openai_chat_completions_proxy(request: Request):
     temperature = payload.get("temperature")
     top_p = payload.get("top_p")
     max_tokens = payload.get("max_tokens")
+    max_completion_tokens = payload.get("max_completion_tokens")
     tools = payload.get("tools")
 
     headers = request.headers
@@ -985,14 +989,9 @@ async def openai_chat_completions_proxy(request: Request):
     params = {
         "messages": messages,
         "model": model,
-        "frequency_penalty": frequency_penalty,
-        "presence_penalty": presence_penalty,
         "seed": seed,
         "stop": stop,
         "stream": stream,
-        "temperature": temperature,
-        "top_p": top_p,
-        "max_tokens": max_tokens
     }
 
     if tools is not None:
@@ -1001,6 +1000,18 @@ async def openai_chat_completions_proxy(request: Request):
         params["response_format"] = response_format
     if stream_options is not None:
         params["stream_options"] = stream_options
+    if max_completion_tokens is not None:
+        params["max_completion_tokens"] = max_completion_tokens
+    if max_tokens is not None:
+        params["max_tokens"] = max_tokens
+    if temperature is not None:
+        params["temperature"] = temperature
+    if top_p is not None:
+        params["top_p"] = top_p
+    if presence_penalty is not None:
+        params["presence_penalty"] = presence_penalty
+    if frequency_penalty is not None:
+        params["frequency_penalty"] = frequency_penalty
 
     if not model:
         raise HTTPException(
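Moving the optional sampling parameters out of the static dict and adding them only when the client actually supplied them avoids forwarding explicit nulls upstream, which some OpenAI-compatible backends treat as invalid values. A compact equivalent of the same pattern, as a sketch rather than the router's code:

# Sketch: include optional parameters only when they are not None.
from typing import Any, Dict

def build_params(messages: list, model: str, stream: bool, **optional: Any) -> Dict[str, Any]:
    params: Dict[str, Any] = {"messages": messages, "model": model, "stream": stream}
    params.update({k: v for k, v in optional.items() if v is not None})
    return params

# Example: temperature is kept, max_tokens is dropped.
print(build_params([], "llama3", False, temperature=0.2, max_tokens=None))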
@@ -1014,7 +1025,7 @@ async def openai_chat_completions_proxy(request: Request):
         raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
 
     # 2. Endpoint logic
-    endpoint = await choose_endpoint(model)
+    endpoint = await choose_endpoint(model, api_key)
     await increment_usage(endpoint, model)
     base_url = ep2base(endpoint)
     oclient = openai.AsyncOpenAI(base_url=base_url, api_key=api_key)
@@ -1112,7 +1123,7 @@ async def openai_completions_proxy(request: Request):
         raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
 
     # 2. Endpoint logic
-    endpoint = await choose_endpoint(model)
+    endpoint = await choose_endpoint(model, api_key)
     await increment_usage(endpoint, model)
     base_url = ep2base(endpoint)
     oclient = openai.AsyncOpenAI(base_url=base_url, api_key=api_key)
@@ -1171,8 +1182,10 @@ async def openai_models_proxy(request: Request):
     models = {'data': []}
     for modellist in all_models:
         for model in modellist:
-            if not id in model.keys(): # Relable Ollama models with OpenAI Model.id from Model.name
+            if not "id" in model.keys(): # Relable Ollama models with OpenAI Model.id from Model.name
                 model['id'] = model['name']
+            else:
+                model['name'] = model['id']
         models['data'] += modellist
 
     # 2. Return a JSONResponse with a deduplicated list of unique models for inference
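The loop gives every entry both an OpenAI-style id and an Ollama-style name, copying whichever field the backend supplied. A self-contained sketch of that normalisation plus the deduplication mentioned in the trailing comment; using id as the dedup key is an assumption, not confirmed by the diff:

# Sketch: normalise id/name across Ollama and OpenAI listings, then dedup.
from typing import Dict, List

def merge_model_lists(all_models: List[List[dict]]) -> List[dict]:
    merged: Dict[str, dict] = {}
    for modellist in all_models:
        for model in modellist:
            if "id" not in model:
                model["id"] = model["name"]    # Ollama entry: copy name -> id
            else:
                model["name"] = model["id"]    # OpenAI entry: copy id -> name
            merged.setdefault(model["id"], model)  # first occurrence wins (assumed key)
    return list(merged.values())

print(merge_model_lists([[{"name": "llama3"}], [{"id": "llama3"}]]))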