Add files via upload

final touches

parent 40ef8ec0c2, commit 9fc0593d3a

1 changed file with 216 additions and 99 deletions

router.py (315 lines changed: +216 −99)
@@ -2,14 +2,16 @@
 title: NOMYO Router - an Ollama Proxy with Endpoint:Model aware routing
 author: alpha-nerd-nomyo
 author_url: https://github.com/nomyo-ai
-version: 0.1
+version: 0.2.2
 license: AGPL
 """
 # -------------------------------------------------------------
-import json, time, asyncio, yaml, httpx, ollama, openai
+import json, time, asyncio, yaml, httpx, ollama, openai, os, re
+from httpx_aiohttp import AiohttpTransport
 from pathlib import Path
-from typing import Dict, Set, List
+from typing import Dict, Set, List, Optional
 from fastapi import FastAPI, Request, HTTPException
 from fastapi_sse import sse_handler
 from fastapi.staticfiles import StaticFiles
 from starlette.responses import StreamingResponse, JSONResponse, Response, HTMLResponse, RedirectResponse
+from pydantic import Field
@@ -19,12 +21,18 @@ from collections import defaultdict
 # ------------------------------------------------------------------
 # In‑memory caches
 # ------------------------------------------------------------------
-# Successful results are cached for 300 s
+# Successful results are cached for 300s
 _models_cache: dict[str, tuple[Set[str], float]] = {}
-# Transient errors are cached for 30 s – the key stays until the
+# Transient errors are cached for 1s – the key stays until the
 # timeout expires, after which the endpoint will be queried again.
 _error_cache: dict[str, float] = {}
 
+# ------------------------------------------------------------------
+# SSE Queues
+# ------------------------------------------------------------------
+_subscribers: Set[asyncio.Queue] = set()
+_subscribers_lock = asyncio.Lock()
+
 # -------------------------------------------------------------
 # 1. Configuration loader
 # -------------------------------------------------------------
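Note: the freshness checks in this file rely on a `_is_fresh` helper that the commit does not touch. A minimal sketch of what it presumably does (an assumption, the helper itself is not shown in the diff):

    def _is_fresh(cached_at: float, ttl: float) -> bool:
        # An entry is fresh while fewer than `ttl` seconds have
        # elapsed since it was stored.
        return (time.time() - cached_at) < ttl

With ttl=300 for `_models_cache` and ttl=1 for `_error_cache`, a failing endpoint is now retried after one second instead of thirty.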
@@ -38,18 +46,35 @@ class Config(BaseSettings):
     # Max concurrent connections per endpoint‑model pair, see OLLAMA_NUM_PARALLEL
     max_concurrent_connections: int = 1
 
+    api_keys: Dict[str, str] = Field(default_factory=dict)
+
     class Config:
         # Load from `config.yaml` first, then from env variables
-        env_prefix = "OLLAMA_PROXY_"
+        env_prefix = "NOMYO_ROUTER_"
        yaml_file = Path("config.yaml")  # relative to cwd
 
+    @classmethod
+    def _expand_env_refs(cls, obj):
+        """Recursively replace `${VAR}` with os.getenv('VAR')."""
+        if isinstance(obj, dict):
+            return {k: cls._expand_env_refs(v) for k, v in obj.items()}
+        if isinstance(obj, list):
+            return [cls._expand_env_refs(v) for v in obj]
+        if isinstance(obj, str):
+            # Only expand if it is exactly ${VAR}
+            m = re.fullmatch(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}", obj)
+            if m:
+                return os.getenv(m.group(1), "")
+        return obj
+
     @classmethod
     def from_yaml(cls, path: Path) -> "Config":
         """Load the YAML file and create the Config instance."""
         if path.exists():
             with path.open("r", encoding="utf-8") as fp:
                 data = yaml.safe_load(fp) or {}
-            return cls(**data)
+            cleaned = cls._expand_env_refs(data)
+            return cls(**cleaned)
         return cls()
 
 # Create the global config object – it will be overwritten on startup
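Note: the expansion is deliberately strict. Only values that are exactly `${VAR}` are replaced; embedded references are left alone. A sketch of the behaviour, assuming the `Config` class above and illustrative endpoint names:

    import os
    os.environ["OPENROUTER_KEY"] = "sk-example"

    data = {
        "api_keys": {"https://openrouter.ai/api/v1": "${OPENROUTER_KEY}"},
        "note": "prefix ${OPENROUTER_KEY} suffix",   # not exactly ${VAR}, so untouched
    }
    cleaned = Config._expand_env_refs(data)
    assert cleaned["api_keys"]["https://openrouter.ai/api/v1"] == "sk-example"
    assert cleaned["note"] == "prefix ${OPENROUTER_KEY} suffix"

This keeps API keys out of config.yaml while still letting the file name which variable to read.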
@@ -59,6 +84,7 @@ config = Config()
 # 2. FastAPI application
 # -------------------------------------------------------------
 app = FastAPI()
+sse_handler.app = app
 
 # -------------------------------------------------------------
 # 3. Global state: per‑endpoint per‑model active connection counters
@@ -79,15 +105,15 @@ def get_httpx_client(endpoint: str) -> httpx.AsyncClient:
     """
     return httpx.AsyncClient(
         base_url=endpoint,
-        timeout=httpx.Timeout(5.0, read=5.0, write=5.0, connect=5.0),
-        limits=httpx.Limits(
-            max_keepalive_connections=64,
-            max_connections=64
-        )
+        timeout=httpx.Timeout(5.0, read=5.0, write=None, connect=5.0),
+        #limits=httpx.Limits(
+        #    max_keepalive_connections=64,
+        #    max_connections=64
+        #),
+        transport=AiohttpTransport()
     )
 
 #@cached(cache=Cache.MEMORY, ttl=300)
-async def fetch_available_models(endpoint: str) -> Set[str]:
+async def fetch_available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
     """
     Query <endpoint>/api/tags and return a set of all model names that the
     endpoint *advertises* (i.e. is capable of serving). This endpoint lists
@@ -97,6 +123,10 @@ async def fetch_available_models(endpoint: str) -> Set[str]:
     If the request fails (e.g. timeout, 5xx, or malformed response), an empty
     set is returned.
     """
+    headers = None
+    if api_key is not None:
+        headers = {"Authorization": "Bearer " + api_key}
+
     if endpoint in _models_cache:
         models, cached_at = _models_cache[endpoint]
         if _is_fresh(cached_at, 300):
@@ -113,10 +143,10 @@ async def fetch_available_models(endpoint: str) -> Set[str]:
             # Error expired – remove it
             del _error_cache[endpoint]
 
+    client = get_httpx_client(endpoint)
     try:
-        client = get_httpx_client(endpoint)
         if "/v1" in endpoint:
-            resp = await client.get(f"/models")
+            resp = await client.get(f"/models", headers=headers)
         else:
            resp = await client.get(f"/api/tags")
         resp.raise_for_status()
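Note: moving the client construction out of the `try` pairs with the `finally: await client.aclose()` added further down. If `get_httpx_client()` failed inside the `try`, the cleanup clause would run before `client` was ever bound and mask the real error with a NameError. The distilled pattern:

    client = get_httpx_client(endpoint)   # bound before the try ...
    try:
        resp = await client.get("/api/tags")
        resp.raise_for_status()
    finally:
        await client.aclose()             # ... so cleanup is always safe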
@@ -124,15 +154,15 @@ async def fetch_available_models(endpoint: str) -> Set[str]:
         # Expected format:
         # {"models": [{"name": "model1"}, {"name": "model2"}]}
         if "/v1" in endpoint:
-            models = {m.get("id") for m in data.get("data", []) if m.get("name")}
+            models = {m.get("id") for m in data.get("data", []) if m.get("id")}
         else:
             models = {m.get("name") for m in data.get("models", []) if m.get("name")}
 
 
         if models:
             _models_cache[endpoint] = (models, time.time())
             return models
         else:
-            # Empty list – treat as “no models”, but still cache for 300 s
+            # Empty list – treat as “no models”, but still cache for 300s
             _models_cache[endpoint] = (models, time.time())
             return models
     except Exception as e:
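Note: this hunk also fixes a real bug. The old comprehension collected `m.get("id")` but filtered on `m.get("name")`, so OpenAI-style endpoints, whose model objects carry `id` rather than `name`, always produced an empty set. The two payload shapes the parser now distinguishes (model names illustrative):

    # Ollama /api/tags response: models keyed by "name"
    {"models": [{"name": "llama3:8b"}, {"name": "qwen2.5:7b"}]}

    # OpenAI-style /models response: models keyed by "id"
    {"data": [{"id": "gpt-4o-mini"}, {"id": "mistral-large"}]}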
@@ -140,6 +170,8 @@ async def fetch_available_models(endpoint: str) -> Set[str]:
         print(f"[fetch_available_models] {endpoint} error: {e}")
         _error_cache[endpoint] = time.time()
         return set()
+    finally:
+        await client.aclose()
 
 
 async def fetch_loaded_models(endpoint: str) -> Set[str]:
@@ -148,8 +180,8 @@ async def fetch_loaded_models(endpoint: str) -> Set[str]:
     loaded on that endpoint. If the request fails (e.g. timeout, 5xx), an empty
     set is returned.
     """
+    client = get_httpx_client(endpoint)
     try:
-        client = get_httpx_client(endpoint)
         resp = await client.get(f"/api/ps")
         resp.raise_for_status()
         data = resp.json()
@@ -160,15 +192,21 @@ async def fetch_loaded_models(endpoint: str) -> Set[str]:
     except Exception:
         # If anything goes wrong we simply assume the endpoint has no models
         return set()
+    finally:
+        await client.aclose()
 
-async def fetch_endpoint_details(endpoint: str, route: str, detail: str) -> List[dict]:
+async def fetch_endpoint_details(endpoint: str, route: str, detail: str, api_key: Optional[str] = None) -> List[dict]:
     """
     Query <endpoint>/<route> to fetch <detail> and return a List of dicts with details
     for the corresponding Ollama endpoint. If the request fails we respond with "N/A" for detail.
     """
+    client = get_httpx_client(endpoint)
+    headers = None
+    if api_key is not None:
+        headers = {"Authorization": "Bearer " + api_key}
+
     try:
-        client = get_httpx_client(endpoint)
-        resp = await client.get(f"{route}")
+        resp = await client.get(f"{route}", headers=headers)
         resp.raise_for_status()
         data = resp.json()
         detail = data.get(detail, [])
@@ -176,7 +214,9 @@ async def fetch_endpoint_details(endpoint: str, route: str, detail: str) -> List
     except Exception as e:
         # If anything goes wrong we cannot reply with details
         print(e)
-        return {detail: []}
+        return []
+    finally:
+        await client.aclose()
 
 def ep2base(ep):
     if "/v1" in ep:
@@ -202,6 +242,7 @@ def dedupe_on_keys(dicts, key_fields):
 async def increment_usage(endpoint: str, model: str) -> None:
     async with usage_lock:
         usage_counts[endpoint][model] += 1
+        await publish_snapshot()
 
 async def decrement_usage(endpoint: str, model: str) -> None:
     async with usage_lock:
@@ -212,8 +253,43 @@
         # Optionally, clean up zero entries
         if usage_counts[endpoint].get(model, 0) == 0:
             usage_counts[endpoint].pop(model, None)
-            if not usage_counts[endpoint]:
-                usage_counts.pop(endpoint, None)
+            #if not usage_counts[endpoint]:
+            #    usage_counts.pop(endpoint, None)
+        await publish_snapshot()
 
+# ------------------------------------------------------------------
+# SSE Helpers
+# ------------------------------------------------------------------
+async def publish_snapshot():
+    snapshot = json.dumps({"usage_counts": usage_counts})
+    async with _subscribers_lock:
+        for q in _subscribers:
+            # If the queue is full, drop the message to avoid back‑pressure.
+            if q.full():
+                continue
+            await q.put(snapshot)
+
+# ------------------------------------------------------------------
+# Subscriber helpers
+# ------------------------------------------------------------------
+async def subscribe() -> asyncio.Queue:
+    """
+    Returns a new Queue that will receive every snapshot.
+    """
+    q: asyncio.Queue = asyncio.Queue(maxsize=10)
+    async with _subscribers_lock:
+        _subscribers.add(q)
+    return q
+
+async def unsubscribe(q: asyncio.Queue):
+    async with _subscribers_lock:
+        _subscribers.discard(q)
+
+# ------------------------------------------------------------------
+# Convenience wrapper – returns the current snapshot (for the proxy)
+# ------------------------------------------------------------------
+async def get_usage_counts() -> Dict:
+    return dict(usage_counts)  # shallow copy
+
 # -------------------------------------------------------------
 # 5. Endpoint selection logic (respecting the configurable limit)
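Note: a self-contained sketch of the broadcast pattern these helpers implement (runnable standalone; names mirror the hunk, the endpoint URL is made up, and the lock is omitted for brevity). Each subscriber owns a bounded queue, the publisher fans out to every queue, and full queues are skipped so one slow consumer cannot stall the router:

    import asyncio, json

    subscribers: set[asyncio.Queue] = set()

    async def publish(payload: dict) -> None:
        msg = json.dumps(payload)
        for q in subscribers:
            if q.full():          # drop rather than block the publisher
                continue
            q.put_nowait(msg)     # safe: capacity was just checked

    async def main() -> None:
        q: asyncio.Queue = asyncio.Queue(maxsize=10)
        subscribers.add(q)
        await publish({"usage_counts": {"http://ep:11434": {"llama3": 1}}})
        print(await q.get())      # -> the JSON snapshot

    asyncio.run(main())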
@@ -237,7 +313,8 @@ async def choose_endpoint(model: str) -> str:
     6️⃣ If no endpoint advertises the model at all, raise an error.
     """
     # 1️⃣ Gather advertised‑model sets for all endpoints concurrently
-    tag_tasks = [fetch_available_models(ep) for ep in config.endpoints]
+    tag_tasks = [fetch_available_models(ep) for ep in config.endpoints if "/v1" not in ep]
+    tag_tasks += [fetch_available_models(ep, config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep]
     advertised_sets = await asyncio.gather(*tag_tasks)
 
     # 2️⃣ Filter endpoints that advertise the requested model
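Note: one thing to double-check after this split. `asyncio.gather` preserves task order, so `advertised_sets` now lists all non-/v1 results first and all /v1 results afterwards, which no longer matches the order of `config.endpoints`. If step 2️⃣ pairs endpoints with result sets positionally, building one ordered list first would keep them aligned (a sketch under that assumption, not part of the commit):

    ordered_eps = [ep for ep in config.endpoints if "/v1" not in ep] \
                + [ep for ep in config.endpoints if "/v1" in ep]
    tag_tasks = [fetch_available_models(ep, config.api_keys.get(ep))
                 for ep in ordered_eps]
    advertised_sets = await asyncio.gather(*tag_tasks)
    candidates = [ep for ep, models in zip(ordered_eps, advertised_sets)
                  if model in models]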
@@ -595,16 +672,17 @@ async def create_proxy(request: Request):
 # 11. API route – Show
 # -------------------------------------------------------------
 @app.post("/api/show")
-async def show_proxy(request: Request):
+async def show_proxy(request: Request, model: Optional[str] = None):
     """
     Proxy a model show request to Ollama and reply with ShowResponse.
 
     """
     try:
         body_bytes = await request.body()
-        payload = json.loads(body_bytes.decode("utf-8"))
-
-        model = payload.get("model")
+        if not model:
+            payload = json.loads(body_bytes.decode("utf-8"))
+            model = payload.get("model")
+
         if not model:
             raise HTTPException(
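Note: with `model` accepted as an optional query parameter, the route can now be driven two ways. A usage sketch, assuming the router listens on localhost:8000 and an illustrative model name:

    import httpx

    # Classic Ollama style: model in the JSON body
    httpx.post("http://localhost:8000/api/show", json={"model": "llama3:8b"})

    # New style: model as a query parameter, no body needed
    httpx.post("http://localhost:8000/api/show", params={"model": "llama3:8b"})

The body is only parsed when the query parameter is absent, so the second call never touches `json.loads`.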
@@ -615,7 +693,7 @@ async def show_proxy(request: Request):
 
         # 2. Endpoint logic
         endpoint = await choose_endpoint(model)
-        await increment_usage(endpoint, model)
+        #await increment_usage(endpoint, model)
         client = ollama.AsyncClient(host=endpoint)
 
         # 3. Proxy a simple show request
@@ -628,7 +706,7 @@
 # 12. API route – Copy
 # -------------------------------------------------------------
 @app.post("/api/copy")
-async def copy_proxy(request: Request):
+async def copy_proxy(request: Request, source: Optional[str] = None, destination: Optional[str] = None):
     """
     Proxy a model copy request to each Ollama endpoint and reply with Status Code.
 
@@ -636,10 +714,14 @@
     # 1. Parse and validate request
     try:
         body_bytes = await request.body()
-        payload = json.loads(body_bytes.decode("utf-8"))
-
-        src = payload.get("source")
-        dst = payload.get("destination")
+        if not source and not destination:
+            payload = json.loads(body_bytes.decode("utf-8"))
+            src = payload.get("source")
+            dst = payload.get("destination")
+        else:
+            src = source
+            dst = destination
 
         if not src:
             raise HTTPException(
@@ -655,26 +737,20 @@
     # 3. Iterate over all endpoints to copy the model on each endpoint
     status_list = []
     for endpoint in config.endpoints:
-        client = ollama.AsyncClient(host=endpoint)
-        # 4. Proxy a simple copy request
-        copy = await client.copy(source=src, destination=dst)
-        status_list.append(copy.status)
+        if "/v1" not in endpoint:
+            client = ollama.AsyncClient(host=endpoint)
+            # 4. Proxy a simple copy request
+            copy = await client.copy(source=src, destination=dst)
+            status_list.append(copy.status)
 
     # 4. Return with 200 OK if all went well, 404 if a single endpoint failed
-    if 404 in status_list:
-        return Response(
-            status_code=404
-        )
-    else:
-        return Response(
-            status_code=200
-        )
+    return Response(status_code=404 if 404 in status_list else 200)
 
 # -------------------------------------------------------------
 # 13. API route – Delete
 # -------------------------------------------------------------
 @app.delete("/api/delete")
-async def delete_proxy(request: Request):
+async def delete_proxy(request: Request, model: Optional[str] = None):
     """
     Proxy a model delete request to each Ollama endpoint and reply with Status Code.
 
@@ -682,9 +758,10 @@
     # 1. Parse and validate request
     try:
         body_bytes = await request.body()
-        payload = json.loads(body_bytes.decode("utf-8"))
-
-        model = payload.get("model")
+        if not model:
+            payload = json.loads(body_bytes.decode("utf-8"))
+            model = payload.get("model")
+
         if not model:
             raise HTTPException(
@@ -696,36 +773,33 @@
     # 2. Iterate over all endpoints to delete the model on each endpoint
     status_list = []
     for endpoint in config.endpoints:
-        client = ollama.AsyncClient(host=endpoint)
-        # 3. Proxy a simple delete request
-        copy = await client.delete(model=model)
-        status_list.append(copy.status)
+        if "/v1" not in endpoint:
+            client = ollama.AsyncClient(host=endpoint)
+            # 3. Proxy a simple delete request
+            copy = await client.delete(model=model)
+            status_list.append(copy.status)
 
     # 4. Return 200 OK; if a single endpoint fails, respond with 404
-    if 404 in status_list:
-        return Response(
-            status_code=404
-        )
-    else:
-        return Response(
-            status_code=200
-        )
+    return Response(status_code=404 if 404 in status_list else 200)
 
 # -------------------------------------------------------------
 # 14. API route – Pull
 # -------------------------------------------------------------
 @app.post("/api/pull")
-async def pull_proxy(request: Request):
+async def pull_proxy(request: Request, model: Optional[str] = None):
     """
     Proxy a pull request to all Ollama endpoints and report status back.
     """
     # 1. Parse and validate request
     try:
         body_bytes = await request.body()
-        payload = json.loads(body_bytes.decode("utf-8"))
-
-        model = payload.get("model")
-        insecure = payload.get("insecure")
+        if not model:
+            payload = json.loads(body_bytes.decode("utf-8"))
+            model = payload.get("model")
+            insecure = payload.get("insecure")
+        else:
+            insecure = None
 
         if not model:
             raise HTTPException(
@@ -737,10 +811,11 @@
     # 2. Iterate over all endpoints to pull the model
     status_list = []
     for endpoint in config.endpoints:
-        client = ollama.AsyncClient(host=endpoint)
-        # 3. Proxy a simple pull request
-        pull = await client.pull(model=model, insecure=insecure, stream=False)
-        status_list.append(pull)
+        if "/v1" not in endpoint:
+            client = ollama.AsyncClient(host=endpoint)
+            # 3. Proxy a simple pull request
+            pull = await client.pull(model=model, insecure=insecure, stream=False)
+            status_list.append(pull)
 
     combined_status = []
     for status in status_list:
@@ -802,9 +877,9 @@ async def version_proxy(request: Request):
 
     """
     # 1. Query all endpoints for version
-    tasks = [fetch_endpoint_details(ep, "/api/version", "version") for ep in config.endpoints]
+    tasks = [fetch_endpoint_details(ep, "/api/version", "version") for ep in config.endpoints if "/v1" not in ep]
     all_versions = await asyncio.gather(*tasks)
 
 
     def version_key(v):
         return tuple(map(int, v.split('.')))
@@ -823,9 +898,10 @@ async def tags_proxy(request: Request):
     Proxy a tags request to Ollama endpoints and reply with a unique list of all models.
 
     """
 
     # 1. Query all endpoints for models
     tasks = [fetch_endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints if "/v1" not in ep]
-    tasks += [fetch_endpoint_details(ep, "/models", "data") for ep in config.endpoints if "/v1" in ep]
+    tasks += [fetch_endpoint_details(ep, "/models", "data", config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep]
     all_models = await asyncio.gather(*tasks)
 
     models = {'models': []}
@@ -834,7 +910,7 @@
 
     # 2. Return a JSONResponse with a deduplicated list of unique models for inference
     return JSONResponse(
-        content={"models": dedupe_on_keys(models['models'], ['digest','name'])},
+        content={"models": dedupe_on_keys(models['models'], ['digest','name','id'])},
         status_code=200,
     )
 
@@ -884,9 +960,10 @@ async def config_proxy(request: Request):
     """
     async def check_endpoint(url: str):
         try:
-            async with httpx.AsyncClient(timeout=1) as client:
+            async with httpx.AsyncClient(timeout=1, transport=AiohttpTransport()) as client:
                 if "/v1" in url:
-                    r = await client.get(f"{url}/models")
+                    headers = {"Authorization": "Bearer " + config.api_keys[url]}
+                    r = await client.get(f"{url}/models", headers=headers)
                 else:
                     r = await client.get(f"{url}/api/version")
                 r.raise_for_status()
@@ -897,6 +974,8 @@
             return {"url": url, "status": "ok", "version": data.get("version")}
         except Exception as exc:
             return {"url": url, "status": "error", "detail": str(exc)}
+        finally:
+            await client.aclose()
 
     results = await asyncio.gather(*[check_endpoint(ep) for ep in config.endpoints])
     return {"endpoints": results}
@@ -918,6 +997,7 @@ async def openai_embedding_proxy(request: Request):
         model = payload.get("model")
         input = payload.get("input")
 
+
         if not model:
             raise HTTPException(
                 status_code=400, detail="Missing required field 'model'"
@@ -932,10 +1012,14 @@
         # 2. Endpoint logic
         endpoint = await choose_endpoint(model)
         await increment_usage(endpoint, model)
-        oclient = openai.AsyncOpenAI(base_url=endpoint+"/v1", api_key="ollama")
+        if "/v1" in endpoint:
+            api_key = config.api_keys[endpoint]
+        else:
+            api_key = "ollama"
+        oclient = openai.AsyncOpenAI(base_url=endpoint+"/v1", api_key=api_key)
 
         # 3. Async generator that streams embedding data and decrements the counter
-        async_gen = await oclient.embeddings.create(input = [input], model=model)
+        async_gen = await oclient.embeddings.create(input=[input], model=model)
 
         await decrement_usage(endpoint, model)
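Note: a wrinkle this hunk leaves in place. For an OpenAI-style endpoint that already ends in `/v1`, `base_url=endpoint+"/v1"` yields a doubled `.../v1/v1` path. The chat and completions routes handle exactly this with `ep2base(endpoint)`; presumably the same call would work here (a sketch mirroring those routes, assuming `ep2base` leaves a trailing /v1 intact and appends it otherwise):

    base_url = ep2base(endpoint)
    oclient = openai.AsyncOpenAI(base_url=base_url, api_key=api_key)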
@@ -968,23 +1052,14 @@ async def openai_chat_completions_proxy(request: Request):
         temperature = payload.get("temperature")
         top_p = payload.get("top_p")
         max_tokens = payload.get("max_tokens")
+        max_completion_tokens = payload.get("max_completion_tokens")
         tools = payload.get("tools")
 
-        headers = request.headers
-        api_key = headers.get("Authorization")
-        api_key = api_key.split()[1]
-
         params = {
             "messages": messages,
             "model": model,
-            "frequency_penalty": frequency_penalty,
-            "presence_penalty": presence_penalty,
-            "seed": seed,
             "stop": stop,
             "stream": stream,
-            "temperature": temperature,
-            "top_p": top_p,
-            "max_tokens": max_tokens
         }
 
         if tools is not None:
@@ -993,6 +1068,20 @@
             params["response_format"] = response_format
         if stream_options is not None:
             params["stream_options"] = stream_options
+        if max_completion_tokens is not None:
+            params["max_completion_tokens"] = max_completion_tokens
+        if max_tokens is not None:
+            params["max_tokens"] = max_tokens
+        if temperature is not None:
+            params["temperature"] = temperature
+        if top_p is not None:
+            params["top_p"] = top_p
+        if seed is not None:
+            params["seed"] = seed
+        if presence_penalty is not None:
+            params["presence_penalty"] = presence_penalty
+        if frequency_penalty is not None:
+            params["frequency_penalty"] = frequency_penalty
 
         if not model:
             raise HTTPException(
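Note: keeping unset options out of `params` entirely, rather than passing `None`, lets each upstream apply its own defaults and avoids backends that reject explicit nulls. A compact equivalent of the seven conditionals, assuming the same `payload` dict (a sketch, not part of the commit):

    for key in ("max_completion_tokens", "max_tokens", "temperature",
                "top_p", "seed", "presence_penalty", "frequency_penalty"):
        value = payload.get(key)
        if value is not None:
            params[key] = value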
@@ -1009,7 +1098,7 @@
         endpoint = await choose_endpoint(model)
         await increment_usage(endpoint, model)
         base_url = ep2base(endpoint)
-        oclient = openai.AsyncOpenAI(base_url=base_url, api_key=api_key)
+        oclient = openai.AsyncOpenAI(base_url=base_url, api_key=config.api_keys[endpoint])
 
         # 3. Async generator that streams completions data and decrements the counter
         async def stream_ochat_response():
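Note: `config.api_keys[endpoint]` (here and in the completions route below) raises `KeyError` for an endpoint with no configured key, such as a plain local Ollama instance. A defensive variant with a placeholder for keyless endpoints (an assumption, not what the commit does):

    oclient = openai.AsyncOpenAI(
        base_url=base_url,
        api_key=config.api_keys.get(endpoint, "ollama"),  # dummy key for local Ollama
    )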
@@ -1069,11 +1158,8 @@ async def openai_completions_proxy(request: Request):
         temperature = payload.get("temperature")
         top_p = payload.get("top_p")
         max_tokens = payload.get("max_tokens")
+        max_completion_tokens = payload.get("max_completion_tokens")
         suffix = payload.get("suffix")
 
-        headers = request.headers
-        api_key = headers.get("Authorization")
-        api_key = api_key.split()[1]
-
         params = {
             "prompt": prompt,
@@ -1107,7 +1193,7 @@
         endpoint = await choose_endpoint(model)
         await increment_usage(endpoint, model)
         base_url = ep2base(endpoint)
-        oclient = openai.AsyncOpenAI(base_url=base_url, api_key=api_key)
+        oclient = openai.AsyncOpenAI(base_url=base_url, api_key=config.api_keys[endpoint])
 
         # 3. Async generator that streams completions data and decrements the counter
         async def stream_ocompletions_response():
@@ -1148,16 +1234,21 @@
 @app.get("/v1/models")
 async def openai_models_proxy(request: Request):
     """
-    Proxy a models request to Ollama endpoints and reply with a unique list of all models.
+    Proxy an OpenAI API models request to Ollama endpoints and reply with a unique list of all models.
 
     """
     # 1. Query all endpoints for models
     tasks = [fetch_endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints if "/v1" not in ep]
-    tasks += [fetch_endpoint_details(ep, "/models", "data") for ep in config.endpoints if "/v1" in ep]
+    tasks += [fetch_endpoint_details(ep, "/models", "data", config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep]
     all_models = await asyncio.gather(*tasks)
 
     models = {'data': []}
     for modellist in all_models:
         for model in modellist:
-            if not "id" in model.keys(): # Relabel Ollama models with OpenAI Model.id from Model.name
+            if not "id" in model.keys(): # Relabel Ollama models with OpenAI Model.id from Model.name
                 model['id'] = model['name']
+            else:
+                model['name'] = model['id']
         models['data'] += modellist
 
     # 2. Return a JSONResponse with a deduplicated list of unique models for inference
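Note: after the relabeling loop every entry carries both keys, so Ollama-native and OpenAI-native models dedupe and render uniformly. In effect (names illustrative):

    entries = [{"name": "llama3:8b"}, {"id": "gpt-4o-mini"}]
    for m in entries:
        if "id" not in m:
            m["id"] = m["name"]     # Ollama model: derive id from name
        else:
            m["name"] = m["id"]     # OpenAI model: mirror id into name
    # every entry now has both "id" and "name"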
@@ -1178,8 +1269,8 @@ async def redirect_favicon():
 @app.get("/", response_class=HTMLResponse)
 async def index(request: Request):
     """
-    Render the landing page that lists the configured endpoints
-    and the models available / running.
+    Render the dynamic NOMYO Router dashboard listing the configured endpoints
+    and the model details, availability & task status.
     """
     return HTMLResponse(content=open("static/index.html", "r").read(), status_code=200)
 
@@ -1225,7 +1316,33 @@ async def health_proxy(request: Request):
     return JSONResponse(content=response_payload, status_code=http_status)
 
 # -------------------------------------------------------------
-# 27. FastAPI startup event – load configuration
+# 27. SSE route for usage broadcasts
+# -------------------------------------------------------------
+@app.get("/api/usage-stream")
+async def usage_stream(request: Request):
+    """
+    Server‑Sent‑Events route that emits a JSON payload every time the
+    global `usage_counts` dictionary changes.
+    """
+    async def event_generator():
+        # The queue that receives *every* new snapshot
+        queue = await subscribe()
+        try:
+            while True:
+                # If the client disconnects, cancel the loop
+                if await request.is_disconnected():
+                    break
+                data = await queue.get()
+                # Send the data as a single SSE message
+                yield f"data: {data}\n\n"
+        finally:
+            # Clean‑up: unsubscribe from the broadcast channel
+            await unsubscribe(queue)
+
+    return StreamingResponse(event_generator(), media_type="text/event-stream")
+
+# -------------------------------------------------------------
+# 28. FastAPI startup event – load configuration
 # -------------------------------------------------------------
 @app.on_event("startup")
 async def startup_event() -> None:
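Note: a minimal consumer for the new stream, assuming the router runs on localhost:8000 (plain httpx streaming, no SSE client library required):

    import httpx

    with httpx.stream("GET", "http://localhost:8000/api/usage-stream",
                      timeout=None) as r:
        for line in r.iter_lines():
            if line.startswith("data: "):
                print(line[len("data: "):])   # one JSON usage snapshot per event

Each message is one `data: {...}` line followed by a blank line, which is the framing `text/event-stream` expects.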