Additions

- Frontend
- Internal Monitoring Endpoints
- External OpenAI compatible backends
Alpha Nerd 2025-08-30 00:12:56 +02:00 committed by GitHub
parent 1403c08a81
commit 9e0b53bba3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

router.py (173 changed lines)

@@ -10,7 +10,8 @@ import json, random, asyncio, yaml, httpx, ollama, openai
from pathlib import Path
from typing import Dict, Set, List
from fastapi import FastAPI, Request, HTTPException
from starlette.responses import StreamingResponse, JSONResponse, Response
from fastapi.staticfiles import StaticFiles
from starlette.responses import StreamingResponse, JSONResponse, Response, HTMLResponse
from pydantic import Field
from pydantic_settings import BaseSettings
from collections import defaultdict
@@ -71,12 +72,19 @@ async def fetch_available_models(endpoint: str) -> Set[str]:
"""
try:
async with httpx.AsyncClient(timeout=1.0) as client:
resp = await client.get(f"{endpoint}/api/tags")
if "/v1" in endpoint:
resp = await client.get(f"{endpoint}/models")
else:
resp = await client.get(f"{endpoint}/api/tags")
resp.raise_for_status()
data = resp.json()
# Expected formats:
# Ollama: {"models": [{"name": "model1"}, {"name": "model2"}]}
# OpenAI: {"data": [{"id": "model1"}, {"id": "model2"}]}
return {m.get("name") for m in data.get("models", []) if m.get("name")}
if "/v1" in endpoint:
models = {m.get("id") for m in data.get("data", []) if m.get("name")}
else:
models = {m.get("name") for m in data.get("models", []) if m.get("name")}
return models
except Exception:
# Treat any error as if the endpoint offers no models
return set()
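For illustration, a minimal sketch of how the two response shapes reduce to the same kind of model-name set (the payloads below are made-up placeholders, not captured responses):

ollama_payload = {"models": [{"name": "llama3"}, {"name": "mistral"}]}
openai_payload = {"data": [{"id": "llama3"}, {"id": "gpt-oss"}]}

# Ollama lists models under "models"/"name", OpenAI under "data"/"id"
assert {m["name"] for m in ollama_payload["models"]} == {"llama3", "mistral"}
assert {m["id"] for m in openai_payload["data"]} == {"llama3", "gpt-oss"}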
@@ -116,6 +124,13 @@ async def fetch_endpoint_details(endpoint: str, route: str, detail: str) -> List
# If anything goes wrong we cannot reply with the requested detail
return {detail: "N/A"}
def ep2base(ep):
    """
    Normalise an endpoint URL to an OpenAI style base URL,
    appending "/v1" unless it is already present.
    """
    return ep if "/v1" in ep else ep + "/v1"
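# Usage sketch for ep2base (hypothetical hosts):
#   ep2base("http://10.0.0.1:11434")      -> "http://10.0.0.1:11434/v1"
#   ep2base("https://llm.example.com/v1") -> "https://llm.example.com/v1"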
def dedupe_on_keys(dicts, key_fields):
"""
Helper function to deduplicate endpoint details based on given dict keys.
@@ -412,7 +427,7 @@ async def embedding_proxy(request: Request):
)
# -------------------------------------------------------------
# 8. API route Embed
# 9. API route Embed
# -------------------------------------------------------------
@app.post("/api/embed")
async def embed_proxy(request: Request):
@@ -468,7 +483,7 @@ async def embed_proxy(request: Request):
)
# -------------------------------------------------------------
# 9. API route Create
# 10. API route Create
# -------------------------------------------------------------
@app.post("/api/create")
async def create_proxy(request: Request):
@@ -516,7 +531,7 @@ async def create_proxy(request: Request):
return dict(final_status)
# -------------------------------------------------------------
# 10. API route Show
# 11. API route Show
# -------------------------------------------------------------
@app.post("/api/show")
async def show_proxy(request: Request):
@@ -549,7 +564,7 @@ async def show_proxy(request: Request):
return show
# -------------------------------------------------------------
# 11. API route Copy
# 12. API route Copy
# -------------------------------------------------------------
@app.post("/api/copy")
async def copy_proxy(request: Request):
@@ -595,7 +610,7 @@ async def copy_proxy(request: Request):
)
# -------------------------------------------------------------
# 12. API route Delete
# 13. API route Delete
# -------------------------------------------------------------
@app.delete("/api/delete")
async def delete_proxy(request: Request):
@@ -636,7 +651,7 @@ async def delete_proxy(request: Request):
)
# -------------------------------------------------------------
# 13. API route Pull
# 14. API route Pull
# -------------------------------------------------------------
@app.post("/api/pull")
async def pull_proxy(request: Request):
@@ -676,7 +691,7 @@ async def pull_proxy(request: Request):
return dict(final_status)
# -------------------------------------------------------------
# 14. API route Push
# 15. API route Push
# -------------------------------------------------------------
@app.post("/api/push")
async def push_proxy(request: Request):
@@ -717,7 +732,7 @@ async def push_proxy(request: Request):
# -------------------------------------------------------------
# 15. API route Version
# 16. API route Version
# -------------------------------------------------------------
@app.get("/api/version")
async def version_proxy(request: Request):
@@ -739,7 +754,7 @@ async def version_proxy(request: Request):
)
# -------------------------------------------------------------
# 16. API route tags
# 17. API route tags
# -------------------------------------------------------------
@app.get("/api/tags")
async def tags_proxy(request: Request):
@@ -748,7 +763,8 @@ async def tags_proxy(request: Request):
"""
# 1. Query all endpoints for models
tasks = [fetch_endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints]
tasks = [fetch_endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints if "/v1" not in ep]
tasks += [fetch_endpoint_details(ep, "/models", "data") for ep in config.endpoints if "/v1" in ep]
all_models = await asyncio.gather(*tasks)
models = {'models': []}
@@ -762,7 +778,7 @@ async def tags_proxy(request: Request):
)
# -------------------------------------------------------------
# 17. API route ps
# 18. API route ps
# -------------------------------------------------------------
@app.get("/api/ps")
async def ps_proxy(request: Request):
@@ -771,7 +787,7 @@ async def ps_proxy(request: Request):
"""
# 1. Query all endpoints for running models
tasks = [fetch_endpoint_details(ep, "/api/ps", "models") for ep in config.endpoints]
tasks = [fetch_endpoint_details(ep, "/api/ps", "models") for ep in config.endpoints if "/v1" not in ep]
loaded_models = await asyncio.gather(*tasks)
models = {'models': []}
@@ -785,7 +801,30 @@ async def ps_proxy(request: Request):
)
# -------------------------------------------------------------
# 18. API route OpenAI compatible Embedding
# 19. Proxy usage route for monitoring
# -------------------------------------------------------------
@app.get("/api/usage")
async def usage_proxy(request: Request):
"""
Return a snapshot of the usage counter for each endpoint.
Useful for debugging / monitoring.
"""
return {"usage_counts": usage_counts}
# -------------------------------------------------------------
# 20. Proxy config route for monitoring and frontend usage
# -------------------------------------------------------------
@app.get("/api/config")
async def config_proxy(request: Request):
"""
Return a simple JSON object that contains the configured
Ollama endpoints. The frontend uses this to display
which endpoints are being proxied.
"""
return {"endpoints": config.endpoints}
# -------------------------------------------------------------
# 21. API route OpenAI compatible Embedding
# -------------------------------------------------------------
@app.post("/v1/embeddings")
async def openai_embedding_proxy(request: Request):
@@ -826,7 +865,7 @@ async def openai_embedding_proxy(request: Request):
return async_gen
# -------------------------------------------------------------
# 19. API route OpenAI compatible Chat Completions
# 22. API route OpenAI compatible Chat Completions
# -------------------------------------------------------------
@app.post("/v1/chat/completions")
async def openai_chat_completions_proxy(request: Request):
@@ -851,8 +890,31 @@ async def openai_chat_completions_proxy(request: Request):
temperature = payload.get("temperature")
top_p = payload.get("top_p")
max_tokens = payload.get("max_tokens")
tools =payload.get("tools")
tools = payload.get("tools")
headers = request.headers
api_key = headers.get("Authorization")
# Fall back to the previous placeholder key when no Authorization header is sent
api_key = api_key.split()[-1] if api_key else "ollama"
params = {
"messages": messages,
"model": model,
"frequency_penalty": frequency_penalty,
"presence_penalty": presence_penalty,
"seed": seed,
"stop": stop,
"stream": stream,
"temperature": temperature,
"top_p": top_p,
"max_tokens": max_tokens
}
if tools is not None:
params["tools"] = tools
if response_format is not None:
params["response_format"] = response_format
if stream_options is not None:
params["stream_options"] = stream_options
if not model:
raise HTTPException(
@@ -868,13 +930,14 @@ async def openai_chat_completions_proxy(request: Request):
# 2. Endpoint logic
endpoint = await choose_endpoint(model)
await increment_usage(endpoint, model)
oclient = openai.AsyncOpenAI(base_url=endpoint+"/v1", api_key="ollama")
base_url = ep2base(endpoint)
oclient = openai.AsyncOpenAI(base_url=base_url, api_key=api_key)
# 3. Async generator that streams completions data and decrements the counter
async def stream_ochat_response():
try:
# The chat method returns a generator of dicts (or GenerateResponse)
async_gen = await oclient.chat.completions.create(messages=messages, model=model, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, response_format=response_format, seed=seed, stop=stop, stream=stream, stream_options=stream_options, temperature=temperature, top_p=top_p, max_tokens=max_tokens, tools=tools)
async_gen = await oclient.chat.completions.create(**params)
if stream:
async for chunk in async_gen:
data = (
@@ -904,7 +967,7 @@ async def openai_chat_completions_proxy(request: Request):
)
# -------------------------------------------------------------
# 20. API route OpenAI compatible Completions
# 23. API route OpenAI compatible Completions
# -------------------------------------------------------------
@app.post("/v1/completions")
async def openai_completions_proxy(request: Request):
@@ -928,8 +991,28 @@ async def openai_completions_proxy(request: Request):
temperature = payload.get("temperature")
top_p = payload.get("top_p")
max_tokens = payload.get("max_tokens")
suffix =payload.get("suffix")
suffix = payload.get("suffix")
headers = request.headers
api_key = headers.get("Authorization")
# Fall back to the previous placeholder key when no Authorization header is sent
api_key = api_key.split()[-1] if api_key else "ollama"
params = {
"prompt": prompt,
"model": model,
"frequency_penalty": frequency_penalty,
"presence_penalty": presence_penalty,
"seed": seed,
"stop": stop,
"stream": stream,
"temperature": temperature,
"top_p": top_p,
"max_tokens": max_tokens,
"suffix": suffix
}
if stream_options is not None:
params["stream_options"] = stream_options
if not model:
raise HTTPException(
@@ -945,13 +1028,14 @@ async def openai_completions_proxy(request: Request):
# 2. Endpoint logic
endpoint = await choose_endpoint(model)
await increment_usage(endpoint, model)
oclient = openai.AsyncOpenAI(base_url=endpoint+"/v1", api_key="ollama")
base_url = ep2base(endpoint)
oclient = openai.AsyncOpenAI(base_url=base_url, api_key=api_key)
# 3. Async generator that streams completions data and decrements the counter
async def stream_ocompletions_response():
try:
# The chat method returns a generator of dicts (or GenerateResponse)
async_gen = await oclient.completions.create(model=model, prompt=prompt, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, seed=seed, stop=stop, stream=stream, stream_options=stream_options, temperature=temperature, top_p=top_p, max_tokens=max_tokens, suffix=suffix)
async_gen = await oclient.completions.create(**params)
if stream:
async for chunk in async_gen:
data = (
@@ -981,17 +1065,44 @@ async def openai_completions_proxy(request: Request):
)
# -------------------------------------------------------------
# 21. OpenAI API compatible endpoints #ToDo
# 24. OpenAI API compatible models endpoint
# -------------------------------------------------------------
@app.post("/v1/models")
async def not_implemented_yet(request: Request):
@app.get("/v1/models")
async def openai_models_proxy(request: Request):
"""
Proxy a models request to Ollama endpoints and reply with a unique list of all models.
return Response(
status_code=501
"""
# 1. Query all endpoints for models
tasks = [fetch_endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints if "/v1" not in ep]
tasks += [fetch_endpoint_details(ep, "/models", "data") for ep in config.endpoints if "/v1" in ep]
all_models = await asyncio.gather(*tasks)
models = {'data': []}
for modellist in all_models:
models['data'] += modellist
# 2. Return a JSONResponse with a deduplicated list of unique models for inference
return JSONResponse(
content={"data": dedupe_on_keys(models['data'], ['name'])},
status_code=200,
)
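# Illustrative merged response (placeholder entries; Ollama models keep
# their /api/tags fields, OpenAI models their /models fields):
#   {"data": [{"name": "llama3", ...}, {"id": "gpt-oss", "object": "model", ...}]}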
# -------------------------------------------------------------
# 22. FastAPI startup event load configuration
# 25. Serve the static frontend
# -------------------------------------------------------------
app.mount("/static", StaticFiles(directory="static"), name="static")
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
"""
Render the landing page that lists the configured endpoints
and the models available / running.
"""
return HTMLResponse(content=Path("static/index.html").read_text(), status_code=200)
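# A FileResponse would work here as well and avoids reading the whole
# file into memory up front (a sketch, not part of this commit):
#   from fastapi.responses import FileResponse
#   return FileResponse("static/index.html")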
# -------------------------------------------------------------
# 26. FastAPI startup event load configuration
# -------------------------------------------------------------
@app.on_event("startup")
async def startup_event() -> None: