Add files via upload

final touches
Alpha Nerd 2025-09-05 12:11:31 +02:00 committed by GitHub
parent 40ef8ec0c2
commit 9fc0593d3a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

router.py (315 changed lines)

@@ -2,14 +2,16 @@
title: NOMYO Router - an Ollama Proxy with Endpoint:Model aware routing
author: alpha-nerd-nomyo
author_url: https://github.com/nomyo-ai
version: 0.1
version: 0.2.2
license: AGPL
"""
# -------------------------------------------------------------
import json, time, asyncio, yaml, httpx, ollama, openai
import json, time, asyncio, yaml, httpx, ollama, openai, os, re
from httpx_aiohttp import AiohttpTransport
from pathlib import Path
from typing import Dict, Set, List
from typing import Dict, Set, List, Optional
from fastapi import FastAPI, Request, HTTPException
from fastapi_sse import sse_handler
from fastapi.staticfiles import StaticFiles
from starlette.responses import StreamingResponse, JSONResponse, Response, HTMLResponse, RedirectResponse
from pydantic import Field
@@ -19,12 +21,18 @@ from collections import defaultdict
# ------------------------------------------------------------------
# In-memory caches
# ------------------------------------------------------------------
# Successful results are cached for 300s
_models_cache: dict[str, tuple[Set[str], float]] = {}
# Transient errors are cached for 30s; the key stays until the
# Transient errors are cached for 1s; the key stays until the
# timeout expires, after which the endpoint will be queried again.
_error_cache: dict[str, float] = {}
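The `_is_fresh` helper used by the cache checks below is defined outside this hunk; a minimal sketch of the semantics it presumably implements (a TTL in seconds checked against a `time.time()` stamp):

def _is_fresh(cached_at: float, ttl: float) -> bool:
    # A cache entry counts as fresh while fewer than `ttl` seconds have elapsed.
    return (time.time() - cached_at) < ttl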
# ------------------------------------------------------------------
# SSE Queues
# ------------------------------------------------------------------
_subscribers: Set[asyncio.Queue] = set()
_subscribers_lock = asyncio.Lock()
# -------------------------------------------------------------
# 1. Configuration loader
# -------------------------------------------------------------
@@ -38,18 +46,35 @@ class Config(BaseSettings):
# Max concurrent connections per endpoint:model pair; see OLLAMA_NUM_PARALLEL
max_concurrent_connections: int = 1
api_keys: Dict[str, str] = Field(default_factory=dict)
class Config:
# Load from `config.yaml` first, then from env variables
env_prefix = "OLLAMA_PROXY_"
env_prefix = "NOMYO_ROUTER_"
yaml_file = Path("config.yaml") # relative to cwd
@classmethod
def _expand_env_refs(cls, obj):
"""Recursively replace `${VAR}` with os.getenv('VAR')."""
if isinstance(obj, dict):
return {k: cls._expand_env_refs(v) for k, v in obj.items()}
if isinstance(obj, list):
return [cls._expand_env_refs(v) for v in obj]
if isinstance(obj, str):
# Only expand if it is exactly ${VAR}
m = re.fullmatch(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}", obj)
if m:
return os.getenv(m.group(1), "")
return obj
@classmethod
def from_yaml(cls, path: Path) -> "Config":
"""Load the YAML file and create the Config instance."""
if path.exists():
with path.open("r", encoding="utf-8") as fp:
data = yaml.safe_load(fp) or {}
return cls(**data)
cleaned = cls._expand_env_refs(data)
return cls(**cleaned)
return cls()
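As a usage sketch (the endpoint URLs and the MY_OPENAI_KEY variable are illustrative, and the `endpoints` field is assumed from its use elsewhere in the file): a YAML value that is exactly `${VAR}` is replaced by that environment variable's value when the file is loaded.

# Hypothetical round trip; `os` and `Path` are already imported above.
# (Sketch only: this would clobber any real config.yaml in the cwd.)
os.environ["MY_OPENAI_KEY"] = "sk-example"
Path("config.yaml").write_text(
    "endpoints:\n"
    "  - http://localhost:11434\n"
    "  - https://api.openai.com/v1\n"
    "api_keys:\n"
    "  https://api.openai.com/v1: ${MY_OPENAI_KEY}\n",
    encoding="utf-8",
)
cfg = Config.from_yaml(Path("config.yaml"))
assert cfg.api_keys["https://api.openai.com/v1"] == "sk-example"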
# Create the global config object; it will be overwritten on startup
@@ -59,6 +84,7 @@ config = Config()
# 2. FastAPI application
# -------------------------------------------------------------
app = FastAPI()
sse_handler.app = app
# -------------------------------------------------------------
# 3. Global state: per-endpoint, per-model active connection counters
@@ -79,15 +105,15 @@ def get_httpx_client(endpoint: str) -> httpx.AsyncClient:
"""
return httpx.AsyncClient(
base_url=endpoint,
timeout=httpx.Timeout(5.0, read=5.0, write=5.0, connect=5.0),
limits=httpx.Limits(
max_keepalive_connections=64,
max_connections=64
)
timeout=httpx.Timeout(5.0, read=5.0, write=None, connect=5.0),
#limits=httpx.Limits(
# max_keepalive_connections=64,
# max_connections=64
#),
transport=AiohttpTransport()
)
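The revised timeout keeps 5 s for connect/read but sets `write=None`, which in httpx disables the write timeout entirely (presumably so long streaming request bodies are not cut off). The positional argument fills any field not named:

# The positional 5.0 is the default for any unset field (here: pool);
# write=None disables the write timeout entirely.
t = httpx.Timeout(5.0, read=5.0, write=None, connect=5.0)
assert t.pool == 5.0 and t.write is None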
#@cached(cache=Cache.MEMORY, ttl=300)
async def fetch_available_models(endpoint: str) -> Set[str]:
async def fetch_available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
"""
Query <endpoint>/api/tags and return a set of all model names that the
endpoint *advertises* (i.e. is capable of serving). This endpoint lists
@@ -97,6 +123,10 @@ async def fetch_available_models(endpoint: str) -> Set[str]:
If the request fails (e.g. timeout, 5xx, or malformed response), an empty
set is returned.
"""
headers = None
if api_key is not None:
headers = {"Authorization": "Bearer " + api_key}
if endpoint in _models_cache:
models, cached_at = _models_cache[endpoint]
if _is_fresh(cached_at, 300):
@@ -113,10 +143,10 @@ async def fetch_available_models(endpoint: str) -> Set[str]:
# Error expired: remove it
del _error_cache[endpoint]
client = get_httpx_client(endpoint)
try:
client = get_httpx_client(endpoint)
if "/v1" in endpoint:
resp = await client.get(f"/models")
resp = await client.get(f"/models", headers=headers)
else:
resp = await client.get(f"/api/tags")
resp.raise_for_status()
@@ -124,15 +154,15 @@ async def fetch_available_models(endpoint: str) -> Set[str]:
# Expected format:
# {"models": [{"name": "model1"}, {"name": "model2"}]}
if "/v1" in endpoint:
models = {m.get("id") for m in data.get("data", []) if m.get("name")}
models = {m.get("id") for m in data.get("data", []) if m.get("id")}
else:
models = {m.get("name") for m in data.get("models", []) if m.get("name")}
if models:
_models_cache[endpoint] = (models, time.time())
return models
else:
# Empty list: treat as “no models”, but still cache for 300s
_models_cache[endpoint] = (models, time.time())
return models
except Exception as e:
@@ -140,6 +170,8 @@ async def fetch_available_models(endpoint: str) -> Set[str]:
print(f"[fetch_available_models] {endpoint} error: {e}")
_error_cache[endpoint] = time.time()
return set()
finally:
await client.aclose()
async def fetch_loaded_models(endpoint: str) -> Set[str]:
@@ -148,8 +180,8 @@ async def fetch_loaded_models(endpoint: str) -> Set[str]:
loaded on that endpoint. If the request fails (e.g. timeout, 5xx), an empty
set is returned.
"""
client = get_httpx_client(endpoint)
try:
client = get_httpx_client(endpoint)
resp = await client.get(f"/api/ps")
resp.raise_for_status()
data = resp.json()
@@ -160,15 +192,21 @@ async def fetch_loaded_models(endpoint: str) -> Set[str]:
except Exception:
# If anything goes wrong we simply assume the endpoint has no models
return set()
finally:
await client.aclose()
async def fetch_endpoint_details(endpoint: str, route: str, detail: str) -> List[dict]:
async def fetch_endpoint_details(endpoint: str, route: str, detail: str, api_key: Optional[str] = None) -> List[dict]:
"""
Query <endpoint>/<route> to fetch <detail> and return a List of dicts with details
for the corresponding Ollama endpoint. If the request fails, an empty list is returned.
"""
client = get_httpx_client(endpoint)
headers = None
if api_key is not None:
headers = {"Authorization": "Bearer " + api_key}
try:
resp = await client.get(f"{route}")
client = get_httpx_client(endpoint)
resp = await client.get(f"{route}", headers=headers)
resp.raise_for_status()
data = resp.json()
detail = data.get(detail, [])
@@ -176,7 +214,9 @@ async def fetch_endpoint_details(endpoint: str, route: str, detail: str) -> List
except Exception as e:
# If anything goes wrong we cannot return any details
print(e)
return {detail: []}
return []
finally:
await client.aclose()
def ep2base(ep):
if "/v1" in ep:
@@ -202,6 +242,7 @@ def dedupe_on_keys(dicts, key_fields):
async def increment_usage(endpoint: str, model: str) -> None:
async with usage_lock:
usage_counts[endpoint][model] += 1
await publish_snapshot()
async def decrement_usage(endpoint: str, model: str) -> None:
async with usage_lock:
@@ -212,8 +253,43 @@ async def decrement_usage(endpoint: str, model: str) -> None:
# Optionally, clean up zero entries
if usage_counts[endpoint].get(model, 0) == 0:
usage_counts[endpoint].pop(model, None)
if not usage_counts[endpoint]:
usage_counts.pop(endpoint, None)
#if not usage_counts[endpoint]:
# usage_counts.pop(endpoint, None)
await publish_snapshot()
# ------------------------------------------------------------------
# SSE Helpers
# ------------------------------------------------------------------
async def publish_snapshot():
snapshot = json.dumps({"usage_counts": usage_counts})
async with _subscribers_lock:
for q in _subscribers:
# If the queue is full, drop the message to avoid backpressure.
if q.full():
continue
await q.put(snapshot)
# ------------------------------------------------------------------
# Subscriber helpers
# ------------------------------------------------------------------
async def subscribe() -> asyncio.Queue:
"""
Returns a new Queue that will receive every snapshot.
"""
q: asyncio.Queue = asyncio.Queue(maxsize=10)
async with _subscribers_lock:
_subscribers.add(q)
return q
async def unsubscribe(q: asyncio.Queue):
async with _subscribers_lock:
_subscribers.discard(q)
# ------------------------------------------------------------------
# Convenience wrapper: returns the current snapshot (for the proxy)
# ------------------------------------------------------------------
async def get_usage_counts() -> Dict:
return dict(usage_counts) # shallow copy
# -------------------------------------------------------------
# 5. Endpoint selection logic (respecting the configurable limit)
@@ -237,7 +313,8 @@ async def choose_endpoint(model: str) -> str:
6️⃣ If no endpoint advertises the model at all, raise an error.
"""
# 1️⃣ Gather advertised-model sets for all endpoints concurrently
tag_tasks = [fetch_available_models(ep) for ep in config.endpoints]
tag_tasks = [fetch_available_models(ep) for ep in config.endpoints if "/v1" not in ep]
tag_tasks += [fetch_available_models(ep, config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep]
advertised_sets = await asyncio.gather(*tag_tasks)
# 2️⃣ Filter endpoints that advertise the requested model
@@ -595,16 +672,17 @@ async def create_proxy(request: Request):
# 11. API route: Show
# -------------------------------------------------------------
@app.post("/api/show")
async def show_proxy(request: Request):
async def show_proxy(request: Request, model: Optional[str] = None):
"""
Proxy a model show request to Ollama and reply with ShowResponse.
"""
try:
body_bytes = await request.body()
payload = json.loads(body_bytes.decode("utf-8"))
model = payload.get("model")
if not model:
payload = json.loads(body_bytes.decode("utf-8"))
model = payload.get("model")
if not model:
raise HTTPException(
@@ -615,7 +693,7 @@ async def show_proxy(request: Request):
# 2. Endpoint logic
endpoint = await choose_endpoint(model)
await increment_usage(endpoint, model)
#await increment_usage(endpoint, model)
client = ollama.AsyncClient(host=endpoint)
# 3. Proxy a simple show request
@@ -628,7 +706,7 @@ async def show_proxy(request: Request):
# 12. API route: Copy
# -------------------------------------------------------------
@app.post("/api/copy")
async def copy_proxy(request: Request):
async def copy_proxy(request: Request, source: Optional[str] = None, destination: Optional[str] = None):
"""
Proxy a model copy request to each Ollama endpoint and reply with a status code.
@@ -636,10 +714,14 @@ async def copy_proxy(request: Request):
# 1. Parse and validate request
try:
body_bytes = await request.body()
payload = json.loads(body_bytes.decode("utf-8"))
src = payload.get("source")
dst = payload.get("destination")
if not source and not destination:
payload = json.loads(body_bytes.decode("utf-8"))
src = payload.get("source")
dst = payload.get("destination")
else:
src = source
dst = destination
if not src:
raise HTTPException(
@@ -655,26 +737,20 @@ async def copy_proxy(request: Request):
# 3. Iterate over all endpoints to copy the model on each endpoint
status_list = []
for endpoint in config.endpoints:
client = ollama.AsyncClient(host=endpoint)
# 4. Proxy a simple copy request
copy = await client.copy(source=src, destination=dst)
status_list.append(copy.status)
if "/v1" not in endpoint:
client = ollama.AsyncClient(host=endpoint)
# 4. Proxy a simple copy request
copy = await client.copy(source=src, destination=dst)
status_list.append(copy.status)
# 4. Return 200 OK if all went well, 404 if any endpoint failed
if 404 in status_list:
return Response(
status_code=404
)
else:
return Response(
status_code=200
)
return Response(status_code=404 if 404 in status_list else 200)
# -------------------------------------------------------------
# 13. API route: Delete
# -------------------------------------------------------------
@app.delete("/api/delete")
async def delete_proxy(request: Request):
async def delete_proxy(request: Request, model: Optional[str] = None):
"""
Proxy a model delete request to each Ollama endpoint and reply with a status code.
@@ -682,9 +758,10 @@ async def delete_proxy(request: Request):
# 1. Parse and validate request
try:
body_bytes = await request.body()
payload = json.loads(body_bytes.decode("utf-8"))
model = payload.get("model")
if not model:
payload = json.loads(body_bytes.decode("utf-8"))
model = payload.get("model")
if not model:
raise HTTPException(
@@ -696,36 +773,33 @@ async def delete_proxy(request: Request):
# 2. Iterate over all endpoints to delete the model on each endpoint
status_list = []
for endpoint in config.endpoints:
client = ollama.AsyncClient(host=endpoint)
# 3. Proxy a simple copy request
copy = await client.delete(model=model)
status_list.append(copy.status)
if "/v1" not in endpoint:
client = ollama.AsyncClient(host=endpoint)
# 3. Proxy a simple delete request
copy = await client.delete(model=model)
status_list.append(copy.status)
# 4. Return 200 OK; if a single endpoint fails, respond with 404
if 404 in status_list:
return Response(
status_code=404
)
else:
return Response(
status_code=200
)
return Response(status_code=404 if 404 in status_list else 200)
# -------------------------------------------------------------
# 14. API route: Pull
# -------------------------------------------------------------
@app.post("/api/pull")
async def pull_proxy(request: Request):
async def pull_proxy(request: Request, model: Optional[str] = None):
"""
Proxy a pull request to all Ollama endpoints and report the status back.
"""
# 1. Parse and validate request
try:
body_bytes = await request.body()
payload = json.loads(body_bytes.decode("utf-8"))
model = payload.get("model")
insecure = payload.get("insecure")
if not model:
payload = json.loads(body_bytes.decode("utf-8"))
model = payload.get("model")
insecure = payload.get("insecure")
else:
insecure = None
if not model:
raise HTTPException(
@@ -737,10 +811,11 @@ async def pull_proxy(request: Request):
# 2. Iterate over all endpoints to pull the model
status_list = []
for endpoint in config.endpoints:
client = ollama.AsyncClient(host=endpoint)
# 3. Proxy a simple pull request
pull = await client.pull(model=model, insecure=insecure, stream=False)
status_list.append(pull)
if "/v1" not in endpoint:
client = ollama.AsyncClient(host=endpoint)
# 3. Proxy a simple pull request
pull = await client.pull(model=model, insecure=insecure, stream=False)
status_list.append(pull)
combined_status = []
for status in status_list:
@@ -802,9 +877,9 @@ async def version_proxy(request: Request):
"""
# 1. Query all endpoints for version
tasks = [fetch_endpoint_details(ep, "/api/version", "version") for ep in config.endpoints]
tasks = [fetch_endpoint_details(ep, "/api/version", "version") for ep in config.endpoints if "/v1" not in ep]
all_versions = await asyncio.gather(*tasks)
def version_key(v):
return tuple(map(int, v.split('.')))
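`version_key` maps a dotted version string to an integer tuple, so picking the highest version with `max(...)` sorts numerically where plain string comparison would not:

# version_key("0.11.4") -> (0, 11, 4); as strings, "0.9.0" > "0.11.4".
versions = ["0.9.0", "0.11.4", "0.10.1"]
print(max(versions, key=version_key))  # -> 0.11.4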
@@ -823,9 +898,10 @@ async def tags_proxy(request: Request):
Proxy a tags request to Ollama endpoints and reply with a unique list of all models.
"""
# 1. Query all endpoints for models
tasks = [fetch_endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints if "/v1" not in ep]
tasks += [fetch_endpoint_details(ep, "/models", "data") for ep in config.endpoints if "/v1" in ep]
tasks += [fetch_endpoint_details(ep, "/models", "data", config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep]
all_models = await asyncio.gather(*tasks)
models = {'models': []}
@@ -834,7 +910,7 @@ async def tags_proxy(request: Request):
# 2. Return a JSONResponse with a deduplicated list of unique models for inference
return JSONResponse(
content={"models": dedupe_on_keys(models['models'], ['digest','name'])},
content={"models": dedupe_on_keys(models['models'], ['digest','name','id'])},
status_code=200,
)
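The body of `dedupe_on_keys` is not shown in this diff (only its signature appears in a hunk header above); a sketch of the behavior assumed here, keeping the first dict seen for each tuple of key-field values:

def dedupe_on_keys(dicts: List[dict], key_fields: List[str]) -> List[dict]:
    # Deduplicate on the combination of the given fields; missing keys count as None.
    seen: Set[tuple] = set()
    unique: List[dict] = []
    for d in dicts:
        key = tuple(d.get(k) for k in key_fields)
        if key not in seen:
            seen.add(key)
            unique.append(d)
    return unique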
@@ -884,9 +960,10 @@ async def config_proxy(request: Request):
"""
async def check_endpoint(url: str):
try:
async with httpx.AsyncClient(timeout=1) as client:
async with httpx.AsyncClient(timeout=1, transport=AiohttpTransport()) as client:
if "/v1" in url:
r = await client.get(f"{url}/models")
headers = {"Authorization": "Bearer " + config.api_keys[url]}
r = await client.get(f"{url}/models", headers=headers)
else:
r = await client.get(f"{url}/api/version")
r.raise_for_status()
@@ -897,6 +974,8 @@ async def config_proxy(request: Request):
return {"url": url, "status": "ok", "version": data.get("version")}
except Exception as exc:
return {"url": url, "status": "error", "detail": str(exc)}
results = await asyncio.gather(*[check_endpoint(ep) for ep in config.endpoints])
return {"endpoints": results}
@@ -918,6 +997,7 @@ async def openai_embedding_proxy(request: Request):
model = payload.get("model")
input = payload.get("input")
if not model:
raise HTTPException(
status_code=400, detail="Missing required field 'model'"
@@ -932,10 +1012,14 @@ async def openai_embedding_proxy(request: Request):
# 2. Endpoint logic
endpoint = await choose_endpoint(model)
await increment_usage(endpoint, model)
oclient = openai.AsyncOpenAI(base_url=endpoint+"/v1", api_key="ollama")
if "/v1" in endpoint:
api_key = config.api_keys[endpoint]
else:
api_key = "ollama"
oclient = openai.AsyncOpenAI(base_url=endpoint+"/v1", api_key=api_key)
# 3. Create the embedding, then decrement the counter
async_gen = await oclient.embeddings.create(input = [input], model=model)
async_gen = await oclient.embeddings.create(input=[input], model=model)
await decrement_usage(endpoint, model)
@@ -968,23 +1052,14 @@ async def openai_chat_completions_proxy(request: Request):
temperature = payload.get("temperature")
top_p = payload.get("top_p")
max_tokens = payload.get("max_tokens")
max_completion_tokens = payload.get("max_completion_tokens")
tools = payload.get("tools")
headers = request.headers
api_key = headers.get("Authorization")
api_key = api_key.split()[1]
params = {
"messages": messages,
"model": model,
"frequency_penalty": frequency_penalty,
"presence_penalty": presence_penalty,
"seed": seed,
"stop": stop,
"stream": stream,
"temperature": temperature,
"top_p": top_p,
"max_tokens": max_tokens
}
if tools is not None:
@@ -993,6 +1068,20 @@ async def openai_chat_completions_proxy(request: Request):
params["response_format"] = response_format
if stream_options is not None:
params["stream_options"] = stream_options
if max_completion_tokens is not None:
params["max_completion_tokens"] = max_completion_tokens
if max_tokens is not None:
params["max_tokens"] = max_tokens
if temperature is not None:
params["temperature"] = temperature
if top_p is not None:
params["top_p"] = top_p
if seed is not None:
params["seed"] = seed
if presence_penalty is not None:
params["presence_penalty"] = presence_penalty
if frequency_penalty is not None:
params["frequency_penalty"] = frequency_penalty
if not model:
raise HTTPException(
@@ -1009,7 +1098,7 @@ async def openai_chat_completions_proxy(request: Request):
endpoint = await choose_endpoint(model)
await increment_usage(endpoint, model)
base_url = ep2base(endpoint)
oclient = openai.AsyncOpenAI(base_url=base_url, api_key=api_key)
oclient = openai.AsyncOpenAI(base_url=base_url, api_key=config.api_keys[endpoint])
# 3. Async generator that streams completions data and decrements the counter
async def stream_ochat_response():
@@ -1069,11 +1158,8 @@ async def openai_completions_proxy(request: Request):
temperature = payload.get("temperature")
top_p = payload.get("top_p")
max_tokens = payload.get("max_tokens")
max_completion_tokens = payload.get("max_completion_tokens")
suffix = payload.get("suffix")
headers = request.headers
api_key = headers.get("Authorization")
api_key = api_key.split()[1]
params = {
"prompt": prompt,
@@ -1107,7 +1193,7 @@ async def openai_completions_proxy(request: Request):
endpoint = await choose_endpoint(model)
await increment_usage(endpoint, model)
base_url = ep2base(endpoint)
oclient = openai.AsyncOpenAI(base_url=base_url, api_key=api_key)
oclient = openai.AsyncOpenAI(base_url=base_url, api_key=config.api_keys[endpoint])
# 3. Async generator that streams completions data and decrements the counter
async def stream_ocompletions_response():
@@ -1148,16 +1234,21 @@ async def openai_completions_proxy(request: Request):
@app.get("/v1/models")
async def openai_models_proxy(request: Request):
"""
Proxy a models request to Ollama endpoints and reply with a unique list of all models.
Proxy an OpenAI API models request to Ollama endpoints and reply with a unique list of all models.
"""
# 1. Query all endpoints for models
tasks = [fetch_endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints if "/v1" not in ep]
tasks += [fetch_endpoint_details(ep, "/models", "data") for ep in config.endpoints if "/v1" in ep]
tasks += [fetch_endpoint_details(ep, "/models", "data", config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep]
all_models = await asyncio.gather(*tasks)
models = {'data': []}
for modellist in all_models:
for model in modellist:
if not "id" in model.keys(): # Relable Ollama models with OpenAI Model.id from Model.name
model['id'] = model['name']
else:
model['name'] = model['id']
models['data'] += modellist
# 2. Return a JSONResponse with a deduplicated list of unique models for inference
@@ -1178,8 +1269,8 @@ async def redirect_favicon():
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
"""
Render the landing page that lists the configured endpoints
and the models available / running.
Render the dynamic NOMYO Router dashboard listing the configured endpoints
and the model details, availability & task status.
"""
return HTMLResponse(content=open("static/index.html", "r").read(), status_code=200)
@@ -1225,7 +1316,33 @@ async def health_proxy(request: Request):
return JSONResponse(content=response_payload, status_code=http_status)
# -------------------------------------------------------------
# 27. FastAPI startup event: load configuration
# 27. SSE route for usage broadcasts
# -------------------------------------------------------------
@app.get("/api/usage-stream")
async def usage_stream(request: Request):
"""
Server-Sent Events stream that emits a JSON payload every time the
global `usage_counts` dictionary changes.
"""
async def event_generator():
# The queue that receives *every* new snapshot
queue = await subscribe()
try:
while True:
# If the client disconnects, exit the loop
if await request.is_disconnected():
break
data = await queue.get()
# Send the data as a single SSE message
yield f"data: {data}\n\n"
finally:
# Cleanup: unsubscribe from the broadcast channel
await unsubscribe(queue)
return StreamingResponse(event_generator(), media_type="text/event-stream")
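A hypothetical client for this stream (the router address is an assumption; httpx is already a project dependency):

import asyncio, json, httpx

async def watch_usage(base_url: str = "http://localhost:8000") -> None:
    # Read the SSE stream line by line and decode each `data:` payload.
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("GET", f"{base_url}/api/usage-stream") as resp:
            async for line in resp.aiter_lines():
                if line.startswith("data: "):
                    print(json.loads(line[len("data: "):])["usage_counts"])

# asyncio.run(watch_usage())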
# -------------------------------------------------------------
# 28. FastAPI startup event: load configuration
# -------------------------------------------------------------
@app.on_event("startup")
async def startup_event() -> None: