Additions
- Frontend
- Internal Monitoring Endpoints
- External OpenAI compatible backends
This commit is contained in:
parent
1403c08a81
commit
9e0b53bba3
1 changed file with 142 additions and 31 deletions
173 router.py
@@ -10,7 +10,8 @@ import json, random, asyncio, yaml, httpx, ollama, openai
 from pathlib import Path
 from typing import Dict, Set, List
 from fastapi import FastAPI, Request, HTTPException
-from starlette.responses import StreamingResponse, JSONResponse, Response
+from fastapi.staticfiles import StaticFiles
+from starlette.responses import StreamingResponse, JSONResponse, Response, HTMLResponse
 from pydantic import Field
 from pydantic_settings import BaseSettings
 from collections import defaultdict
@@ -71,12 +72,19 @@ async def fetch_available_models(endpoint: str) -> Set[str]:
     """
     try:
         async with httpx.AsyncClient(timeout=1.0) as client:
-            resp = await client.get(f"{endpoint}/api/tags")
+            if "/v1" in endpoint:
+                resp = await client.get(f"{endpoint}/models")
+            else:
+                resp = await client.get(f"{endpoint}/api/tags")
             resp.raise_for_status()
             data = resp.json()
             # Expected format:
             # {"models": [{"name": "model1"}, {"name": "model2"}]}
-            return {m.get("name") for m in data.get("models", []) if m.get("name")}
+            if "/v1" in endpoint:
+                models = {m.get("id") for m in data.get("data", []) if m.get("id")}
+            else:
+                models = {m.get("name") for m in data.get("models", []) if m.get("name")}
+            return models
     except Exception:
         # Treat any error as if the endpoint offers no models
         return set()
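For reference, the two list formats this function now parses differ in shape; a minimal sketch with invented model names (field layout per the Ollama /api/tags and OpenAI-style /models responses):

# Ollama endpoint, GET {endpoint}/api/tags (abbreviated, hypothetical models):
ollama_payload = {"models": [{"name": "llama3:8b"}, {"name": "qwen2:7b"}]}
# OpenAI-compatible endpoint, GET {endpoint}/models (abbreviated, hypothetical model):
openai_payload = {"data": [{"id": "gpt-4o-mini", "object": "model"}]}

assert {m.get("name") for m in ollama_payload.get("models", [])} == {"llama3:8b", "qwen2:7b"}
assert {m.get("id") for m in openai_payload.get("data", [])} == {"gpt-4o-mini"}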
@@ -116,6 +124,13 @@ async def fetch_endpoint_details(endpoint: str, route: str, detail: str) -> List
         # If anything goes wrong we cannot report versions
         return {detail: "N/A"}
 
+def ep2base(ep):
+    if "/v1" in ep:
+        base_url = ep
+    else:
+        base_url = ep+"/v1"
+    return base_url
+
 def dedupe_on_keys(dicts, key_fields):
     """
     Helper function to deduplicate endpoint details based on given dict keys.
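A quick sanity check of the new ep2base helper (endpoint URLs are hypothetical). Note that the `"/v1" in ep` substring test classifies any URL containing /v1 as OpenAI compatible, so a path such as /v1beta would match too:

assert ep2base("http://10.0.0.1:11434") == "http://10.0.0.1:11434/v1"
assert ep2base("https://api.example.com/v1") == "https://api.example.com/v1"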
@@ -412,7 +427,7 @@ async def embedding_proxy(request: Request):
     )

 # -------------------------------------------------------------
-# 8. API route – Embed
+# 9. API route – Embed
 # -------------------------------------------------------------
 @app.post("/api/embed")
 async def embed_proxy(request: Request):
@@ -468,7 +483,7 @@ async def embed_proxy(request: Request):
     )

 # -------------------------------------------------------------
-# 9. API route – Create
+# 10. API route – Create
 # -------------------------------------------------------------
 @app.post("/api/create")
 async def create_proxy(request: Request):
@@ -516,7 +531,7 @@ async def create_proxy(request: Request):
     return dict(final_status)

 # -------------------------------------------------------------
-# 10. API route – Show
+# 11. API route – Show
 # -------------------------------------------------------------
 @app.post("/api/show")
 async def show_proxy(request: Request):
@@ -549,7 +564,7 @@ async def show_proxy(request: Request):
     return show

 # -------------------------------------------------------------
-# 11. API route – Copy
+# 12. API route – Copy
 # -------------------------------------------------------------
 @app.post("/api/copy")
 async def copy_proxy(request: Request):
@@ -595,7 +610,7 @@ async def copy_proxy(request: Request):
     )

 # -------------------------------------------------------------
-# 12. API route – Delete
+# 13. API route – Delete
 # -------------------------------------------------------------
 @app.delete("/api/delete")
 async def delete_proxy(request: Request):
@@ -636,7 +651,7 @@ async def delete_proxy(request: Request):
     )

 # -------------------------------------------------------------
-# 13. API route – Pull
+# 14. API route – Pull
 # -------------------------------------------------------------
 @app.post("/api/pull")
 async def pull_proxy(request: Request):
@@ -676,7 +691,7 @@ async def pull_proxy(request: Request):
     return dict(final_status)

 # -------------------------------------------------------------
-# 14. API route – Push
+# 15. API route – Push
 # -------------------------------------------------------------
 @app.post("/api/push")
 async def push_proxy(request: Request):
@@ -717,7 +732,7 @@ async def push_proxy(request: Request):


 # -------------------------------------------------------------
-# 15. API route – Version
+# 16. API route – Version
 # -------------------------------------------------------------
 @app.get("/api/version")
 async def version_proxy(request: Request):
@@ -739,7 +754,7 @@ async def version_proxy(request: Request):
     )

 # -------------------------------------------------------------
-# 16. API route – tags
+# 17. API route – tags
 # -------------------------------------------------------------
 @app.get("/api/tags")
 async def tags_proxy(request: Request):
@@ -748,7 +763,8 @@ async def tags_proxy(request: Request):

     """
     # 1. Query all endpoints for models
-    tasks = [fetch_endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints]
+    tasks = [fetch_endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints if "/v1" not in ep]
+    tasks += [fetch_endpoint_details(ep, "/models", "data") for ep in config.endpoints if "/v1" in ep]
     all_models = await asyncio.gather(*tasks)

     models = {'models': []}
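The same /v1 substring convention used by ep2base splits the configured endpoints into the two task lists; a minimal illustration with hypothetical endpoint values:

endpoints = ["http://10.0.0.1:11434", "https://api.example.com/v1"]
assert [ep for ep in endpoints if "/v1" not in ep] == ["http://10.0.0.1:11434"]
assert [ep for ep in endpoints if "/v1" in ep] == ["https://api.example.com/v1"]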
@@ -762,7 +778,7 @@ async def tags_proxy(request: Request):
     )

 # -------------------------------------------------------------
-# 17. API route – ps
+# 18. API route – ps
 # -------------------------------------------------------------
 @app.get("/api/ps")
 async def ps_proxy(request: Request):
@@ -771,7 +787,7 @@ async def ps_proxy(request: Request):

     """
     # 1. Query all endpoints for running models
-    tasks = [fetch_endpoint_details(ep, "/api/ps", "models") for ep in config.endpoints]
+    tasks = [fetch_endpoint_details(ep, "/api/ps", "models") for ep in config.endpoints if "/v1" not in ep]
     loaded_models = await asyncio.gather(*tasks)

     models = {'models': []}
@@ -785,7 +801,30 @@ async def ps_proxy(request: Request):
     )

 # -------------------------------------------------------------
-# 18. API route – OpenAI compatible Embedding
+# 19. Proxy usage route – for monitoring
+# -------------------------------------------------------------
+@app.get("/api/usage")
+async def usage_proxy(request: Request):
+    """
+    Return a snapshot of the usage counter for each endpoint.
+    Useful for debugging / monitoring.
+    """
+    return {"usage_counts": usage_counts}
+
+# -------------------------------------------------------------
+# 20. Proxy config route – for monitoring and frontend usage
+# -------------------------------------------------------------
+@app.get("/api/config")
+async def config_proxy(request: Request):
+    """
+    Return a simple JSON object that contains the configured
+    Ollama endpoints. The front‑end uses this to display
+    which endpoints are being proxied.
+    """
+    return {"endpoints": config.endpoints}
+
+# -------------------------------------------------------------
+# 21. API route – OpenAI compatible Embedding
 # -------------------------------------------------------------
 @app.post("/v1/embeddings")
 async def openai_embedding_proxy(request: Request):
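The two new monitoring routes can be polled directly; a sketch assuming the router listens on localhost:8000 (the response values shown are invented):

import httpx

usage = httpx.get("http://localhost:8000/api/usage").json()
# e.g. {"usage_counts": {"http://10.0.0.1:11434": 2}}
cfg = httpx.get("http://localhost:8000/api/config").json()
# e.g. {"endpoints": ["http://10.0.0.1:11434", "https://api.example.com/v1"]}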
@@ -826,7 +865,7 @@ async def openai_embedding_proxy(request: Request):
     return async_gen

 # -------------------------------------------------------------
-# 19. API route – OpenAI compatible Chat Completions
+# 22. API route – OpenAI compatible Chat Completions
 # -------------------------------------------------------------
 @app.post("/v1/chat/completions")
 async def openai_chat_completions_proxy(request: Request):
@@ -851,8 +890,31 @@ async def openai_chat_completions_proxy(request: Request):
     temperature = payload.get("temperature")
     top_p = payload.get("top_p")
     max_tokens = payload.get("max_tokens")
-    tools =payload.get("tools")
+    tools = payload.get("tools")

+    headers = request.headers
+    api_key = headers.get("Authorization")
+    api_key = api_key.split()[1]
+
+    params = {
+        "messages": messages,
+        "model": model,
+        "frequency_penalty": frequency_penalty,
+        "presence_penalty": presence_penalty,
+        "seed": seed,
+        "stop": stop,
+        "stream": stream,
+        "temperature": temperature,
+        "top_p": top_p,
+        "max_tokens": max_tokens
+    }
+
+    if tools is not None:
+        params["tools"] = tools
+    if response_format is not None:
+        params["response_format"] = response_format
+    if stream_options is not None:
+        params["stream_options"] = stream_options

     if not model:
         raise HTTPException(
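Note that api_key = api_key.split()[1] assumes every request carries a well-formed "Authorization: Bearer <token>" header; a request without one fails before any backend is contacted. A more defensive variant (a sketch, falling back to the placeholder key "ollama" that the previous code hard-coded):

auth = request.headers.get("Authorization")
parts = auth.split() if auth else []
# Use the caller's token when present, otherwise the placeholder accepted by Ollama
api_key = parts[1] if len(parts) == 2 else "ollama"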
@@ -868,13 +930,14 @@ async def openai_chat_completions_proxy(request: Request):
     # 2. Endpoint logic
     endpoint = await choose_endpoint(model)
     await increment_usage(endpoint, model)
-    oclient = openai.AsyncOpenAI(base_url=endpoint+"/v1", api_key="ollama")
+    base_url = ep2base(endpoint)
+    oclient = openai.AsyncOpenAI(base_url=base_url, api_key=api_key)

     # 3. Async generator that streams completions data and decrements the counter
     async def stream_ochat_response():
         try:
             # The chat method returns a generator of dicts (or GenerateResponse)
-            async_gen = await oclient.chat.completions.create(messages=messages, model=model, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, response_format=response_format, seed=seed, stop=stop, stream=stream, stream_options=stream_options, temperature=temperature, top_p=top_p, max_tokens=max_tokens, tools=tools)
+            async_gen = await oclient.chat.completions.create(**params)
             if stream == True:
                 async for chunk in async_gen:
                     data = (
@@ -904,7 +967,7 @@ async def openai_chat_completions_proxy(request: Request):
     )

 # -------------------------------------------------------------
-# 20. API route – OpenAI compatible Completions
+# 23. API route – OpenAI compatible Completions
 # -------------------------------------------------------------
 @app.post("/v1/completions")
 async def openai_completions_proxy(request: Request):
@@ -928,8 +991,28 @@ async def openai_completions_proxy(request: Request):
     temperature = payload.get("temperature")
     top_p = payload.get("top_p")
     max_tokens = payload.get("max_tokens")
-    suffix =payload.get("suffix")
+    suffix = payload.get("suffix")

+    headers = request.headers
+    api_key = headers.get("Authorization")
+    api_key = api_key.split()[1]
+
+    params = {
+        "prompt": prompt,
+        "model": model,
+        "frequency_penalty": frequency_penalty,
+        "presence_penalty": presence_penalty,
+        "seed": seed,
+        "stop": stop,
+        "stream": stream,
+        "temperature": temperature,
+        "top_p": top_p,
+        "max_tokens": max_tokens,
+        "suffix": suffix
+    }
+
+    if stream_options is not None:
+        params["stream_options"] = stream_options

     if not model:
         raise HTTPException(
@@ -945,13 +1028,14 @@ async def openai_completions_proxy(request: Request):
     # 2. Endpoint logic
     endpoint = await choose_endpoint(model)
     await increment_usage(endpoint, model)
-    oclient = openai.AsyncOpenAI(base_url=endpoint+"/v1", api_key="ollama")
+    base_url = ep2base(endpoint)
+    oclient = openai.AsyncOpenAI(base_url=base_url, api_key=api_key)

     # 3. Async generator that streams completions data and decrements the counter
     async def stream_ocompletions_response():
         try:
             # The chat method returns a generator of dicts (or GenerateResponse)
-            async_gen = await oclient.completions.create(model=model, prompt=prompt, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, seed=seed, stop=stop, stream=stream, stream_options=stream_options, temperature=temperature, top_p=top_p, max_tokens=max_tokens, suffix=suffix)
+            async_gen = await oclient.completions.create(**params)
             if stream == True:
                 async for chunk in async_gen:
                     data = (
@@ -981,17 +1065,44 @@ async def openai_completions_proxy(request: Request):
     )

 # -------------------------------------------------------------
-# 21. OpenAI API compatible endpoints #ToDo
+# 24. OpenAI API compatible models endpoint
 # -------------------------------------------------------------
-@app.post("/v1/models")
-async def not_implemented_yet(request: Request):
-    return Response(
-        status_code=501
-    )
+@app.get("/v1/models")
+async def openai_models_proxy(request: Request):
+    """
+    Proxy a models request to Ollama endpoints and reply with a unique list of all models.
+    """
+    # 1. Query all endpoints for models
+    tasks = [fetch_endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints if "/v1" not in ep]
+    tasks += [fetch_endpoint_details(ep, "/models", "data") for ep in config.endpoints if "/v1" in ep]
+    all_models = await asyncio.gather(*tasks)
+
+    models = {'data': []}
+    for modellist in all_models:
+        models['data'] += modellist
+
+    # 2. Return a JSONResponse with a deduplicated list of unique models for inference
+    return JSONResponse(
+        content={"data": dedupe_on_keys(models['data'], ['name'])},
+        status_code=200,
+    )

 # -------------------------------------------------------------
-# 22. FastAPI startup event – load configuration
+# 25. Serve the static front‑end
 # -------------------------------------------------------------
+app.mount("/static", StaticFiles(directory="static"), name="static")
+
+@app.get("/", response_class=HTMLResponse)
+async def index(request: Request):
+    """
+    Render the landing page that lists the configured endpoints
+    and the models available / running.
+    """
+    return HTMLResponse(content=open("static/index.html", "r").read(), status_code=200)
+
+# -------------------------------------------------------------
+# 26. FastAPI startup event – load configuration
+# -------------------------------------------------------------
 @app.on_event("startup")
 async def startup_event() -> None:
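With the new route in place, the aggregated model list can be fetched from the router itself; a minimal check assuming the router listens on localhost:8000 (entries coming from Ollama backends carry a "name" key, entries from OpenAI compatible backends carry an "id" key):

import httpx

resp = httpx.get("http://localhost:8000/v1/models")
resp.raise_for_status()
print([m.get("name") or m.get("id") for m in resp.json()["data"]])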