Add files via upload

preparations for /v1 endpoints with auth
Alpha Nerd 2025-09-03 16:34:41 +02:00 committed by GitHub
parent 0a456e6e21
commit d257073cb1


@@ -8,7 +8,7 @@ license: AGPL
 # -------------------------------------------------------------
 import json, time, asyncio, yaml, httpx, ollama, openai
 from pathlib import Path
-from typing import Dict, Set, List
+from typing import Dict, Set, List, Optional
 from fastapi import FastAPI, Request, HTTPException
 from fastapi.staticfiles import StaticFiles
 from starlette.responses import StreamingResponse, JSONResponse, Response, HTMLResponse, RedirectResponse
@@ -86,7 +86,7 @@ def get_httpx_client(endpoint: str) -> httpx.AsyncClient:
     )
 )
-async def fetch_available_models(endpoint: str) -> Set[str]:
+async def fetch_available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
     """
     Query <endpoint>/api/tags and return a set of all model names that the
     endpoint *advertises* (i.e. is capable of serving). This endpoint lists
@@ -96,6 +96,10 @@ async def fetch_available_models(endpoint: str) -> Set[str]:
     If the request fails (e.g. timeout, 5xx, or malformed response), an empty
     set is returned.
     """
+    headers = None
+    if api_key is not None:
+        headers = {"Authorization": "Bearer " + api_key}
     if endpoint in _models_cache:
         models, cached_at = _models_cache[endpoint]
         if _is_fresh(cached_at, 300):
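The auth plumbing added here reduces to "attach a Bearer header only when a key is configured". A minimal standalone sketch of the pattern (helper name hypothetical, not part of the commit):

from typing import Optional

def build_auth_headers(api_key: Optional[str]) -> Optional[dict]:
    # No key configured: pass headers=None so httpx sends no Authorization header.
    if api_key is None:
        return None
    return {"Authorization": "Bearer " + api_key}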
@@ -115,7 +119,7 @@ async def fetch_available_models(endpoint: str) -> Set[str]:
     client = get_httpx_client(endpoint)
     try:
         if "/v1" in endpoint:
-            resp = await client.get(f"/models")
+            resp = await client.get(f"/models", headers=headers)
         else:
             resp = await client.get(f"/api/tags")
         resp.raise_for_status()
@@ -123,7 +127,7 @@ async def fetch_available_models(endpoint: str) -> Set[str]:
         # Expected format:
         # {"models": [{"name": "model1"}, {"name": "model2"}]}
         if "/v1" in endpoint:
-            models = {m.get("id") for m in data.get("data", []) if m.get("name")}
+            models = {m.get("id") for m in data.get("data", []) if m.get("id")}
         else:
             models = {m.get("name") for m in data.get("models", []) if m.get("name")}
@@ -160,15 +164,14 @@ async def fetch_loaded_models(endpoint: str) -> Set[str]:
         # If anything goes wrong we simply assume the endpoint has no models
         return set()
-async def fetch_endpoint_details(endpoint: str, route: str, detail: str, api_key: str = None) -> List[dict]:
+async def fetch_endpoint_details(endpoint: str, route: str, detail: str, api_key: Optional[str] = None) -> List[dict]:
     """
     Query <endpoint>/<route> to fetch <detail> and return a List of dicts with details
     for the corresponding Ollama endpoint. If the request fails we respond with "N/A" for detail.
     """
+    headers = None
     if api_key is not None:
         headers = {"Authorization": "Bearer " + api_key}
-    else:
-        headers = None
     client = get_httpx_client(endpoint)
     try:
         resp = await client.get(f"{route}", headers=headers)
@@ -179,7 +182,7 @@ async def fetch_endpoint_details(endpoint: str, route: str, detail: str, api_key
     except Exception as e:
         # If anything goes wrong we cannot reply details
         print(e)
-        return "N/A"
+        return []
 def ep2base(ep):
     if "/v1" in ep:
@@ -221,7 +224,7 @@ async def decrement_usage(endpoint: str, model: str) -> None:
 # -------------------------------------------------------------
 # 5. Endpoint selection logic (respecting the configurable limit)
 # -------------------------------------------------------------
-async def choose_endpoint(model: str) -> str:
+async def choose_endpoint(model: str, api_key: Optional[str] = None) -> str:
     """
     Determine which endpoint to use for the given model while respecting
     the `max_concurrent_connections` per endpoint/model pair **and**
@@ -240,7 +243,7 @@
     6️⃣ If no endpoint advertises the model at all, raise an error.
     """
     # 1️⃣ Gather advertised-model sets for all endpoints concurrently
-    tag_tasks = [fetch_available_models(ep) for ep in config.endpoints]
+    tag_tasks = [fetch_available_models(ep, api_key) for ep in config.endpoints]
     advertised_sets = await asyncio.gather(*tag_tasks)
     # 2️⃣ Filter endpoints that advertise the requested model
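Step 1️⃣ is a plain concurrent fan-out: one fetch_available_models() task per configured endpoint, awaited together. A sketch of the same idea as a reusable helper (helper name hypothetical):

import asyncio
from typing import Dict, List, Optional, Set

async def advertised_by_endpoint(endpoints: List[str],
                                 api_key: Optional[str] = None) -> Dict[str, Set[str]]:
    # Query all endpoints concurrently instead of one after another.
    tasks = [fetch_available_models(ep, api_key) for ep in endpoints]
    return dict(zip(endpoints, await asyncio.gather(*tasks)))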
@@ -938,7 +941,7 @@ async def openai_embedding_proxy(request: Request):
         raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
     # 2. Endpoint logic
-    endpoint = await choose_endpoint(model)
+    endpoint = await choose_endpoint(model, api_key)
     await increment_usage(endpoint, model)
     oclient = openai.AsyncOpenAI(base_url=endpoint+"/v1", api_key=api_key)
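Every proxy route repeats the same bookkeeping: pick an endpoint, count the request in, forward, count it out. A sketch of that pattern as a single helper (helper name hypothetical; choose_endpoint, increment_usage, and decrement_usage are the functions from this file):

async def proxied(model: str, api_key, call):
    # choose_endpoint() now receives the client's key so /v1 backends
    # can be probed for their model list during selection.
    endpoint = await choose_endpoint(model, api_key)
    await increment_usage(endpoint, model)
    try:
        return await call(endpoint)
    finally:
        # Release the slot even if the upstream call raised.
        await decrement_usage(endpoint, model)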
@@ -976,6 +979,7 @@ async def openai_chat_completions_proxy(request: Request):
     temperature = payload.get("temperature")
     top_p = payload.get("top_p")
     max_tokens = payload.get("max_tokens")
+    max_completion_tokens = payload.get("max_completion_tokens")
     tools = payload.get("tools")
     headers = request.headers
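The api_key forwarded below presumably comes off the incoming Authorization header; a hypothetical extraction helper (not shown in the diff):

from typing import Optional

def extract_api_key(headers) -> Optional[str]:
    # Accept "Bearer <token>" case-insensitively; anything else means no key.
    auth = headers.get("authorization", "")
    if auth.lower().startswith("bearer "):
        return auth[7:].strip() or None
    return None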
@@ -985,14 +989,9 @@ async def openai_chat_completions_proxy(request: Request):
     params = {
         "messages": messages,
         "model": model,
-        "frequency_penalty": frequency_penalty,
-        "presence_penalty": presence_penalty,
         "seed": seed,
         "stop": stop,
         "stream": stream,
-        "temperature": temperature,
-        "top_p": top_p,
-        "max_tokens": max_tokens
     }
     if tools is not None:
@@ -1001,6 +1000,18 @@ async def openai_chat_completions_proxy(request: Request):
         params["response_format"] = response_format
     if stream_options is not None:
         params["stream_options"] = stream_options
+    if max_completion_tokens is not None:
+        params["max_completion_tokens"] = max_completion_tokens
+    if max_tokens is not None:
+        params["max_tokens"] = max_tokens
+    if temperature is not None:
+        params["temperature"] = temperature
+    if top_p is not None:
+        params["top_p"] = top_p
+    if presence_penalty is not None:
+        params["presence_penalty"] = presence_penalty
+    if frequency_penalty is not None:
+        params["frequency_penalty"] = frequency_penalty
     if not model:
         raise HTTPException(
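Taken together, the two hunks above implement an "omit unset fields" rule: sampling options are forwarded only when the client actually sent them, since serializing explicit nulls can be rejected by some backends. An equivalent compact sketch (field list taken from the diff, helper name hypothetical):

def build_params(payload: dict, messages, model: str, stream: bool) -> dict:
    params = {"messages": messages, "model": model,
              "seed": payload.get("seed"), "stop": payload.get("stop"),
              "stream": stream}
    optional = ("tools", "response_format", "stream_options",
                "max_completion_tokens", "max_tokens", "temperature",
                "top_p", "presence_penalty", "frequency_penalty")
    # None means "client did not send it": leave the key out entirely.
    params.update({k: payload[k] for k in optional if payload.get(k) is not None})
    return params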
@@ -1014,7 +1025,7 @@ async def openai_chat_completions_proxy(request: Request):
         raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
     # 2. Endpoint logic
-    endpoint = await choose_endpoint(model)
+    endpoint = await choose_endpoint(model, api_key)
     await increment_usage(endpoint, model)
     base_url = ep2base(endpoint)
     oclient = openai.AsyncOpenAI(base_url=base_url, api_key=api_key)
@@ -1112,7 +1123,7 @@ async def openai_completions_proxy(request: Request):
         raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
     # 2. Endpoint logic
-    endpoint = await choose_endpoint(model)
+    endpoint = await choose_endpoint(model, api_key)
     await increment_usage(endpoint, model)
     base_url = ep2base(endpoint)
     oclient = openai.AsyncOpenAI(base_url=base_url, api_key=api_key)
@@ -1171,8 +1182,10 @@ async def openai_models_proxy(request: Request):
     models = {'data': []}
     for modellist in all_models:
         for model in modellist:
-            if not id in model.keys(): # Relabel Ollama models with OpenAI Model.id from Model.name
+            if not "id" in model.keys(): # Relabel Ollama models with OpenAI Model.id from Model.name
                 model['id'] = model['name']
+            else:
+                model['name'] = model['id']
         models['data'] += modellist
     # 2. Return a JSONResponse with a deduplicated list of unique models for inference
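The fix also matters for correctness: the old `if not id in model.keys()` tested the builtin function `id`, which is never a key, so the branch always ran. After relabeling, every entry carries both an id and a name, so the deduplication mentioned in the comment can key off id alone. A sketch (helper name hypothetical):

def dedupe_models(entries: list) -> list:
    # Keep the first occurrence of each model id across all endpoints.
    seen, unique = set(), []
    for m in entries:
        if m["id"] not in seen:
            seen.add(m["id"])
            unique.append(m)
    return unique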