diff --git a/router.py b/router.py index 0a3c3d6..51cf51d 100644 --- a/router.py +++ b/router.py @@ -2,14 +2,16 @@ title: NOMYO Router - an Ollama Proxy with Endpoint:Model aware routing author: alpha-nerd-nomyo author_url: https://github.com/nomyo-ai -version: 0.1 +version: 0.2.2 license: AGPL """ # ------------------------------------------------------------- -import json, time, asyncio, yaml, httpx, ollama, openai +import json, time, asyncio, yaml, httpx, ollama, openai, os, re +from httpx_aiohttp import AiohttpTransport from pathlib import Path -from typing import Dict, Set, List +from typing import Dict, Set, List, Optional from fastapi import FastAPI, Request, HTTPException +from fastapi_sse import sse_handler from fastapi.staticfiles import StaticFiles from starlette.responses import StreamingResponse, JSONResponse, Response, HTMLResponse, RedirectResponse from pydantic import Field @@ -19,12 +21,18 @@ from collections import defaultdict # ------------------------------------------------------------------ # In‑memory caches # ------------------------------------------------------------------ -# Successful results are cached for 300 s +# Successful results are cached for 300s _models_cache: dict[str, tuple[Set[str], float]] = {} -# Transient errors are cached for 30 s – the key stays until the +# Transient errors are cached for 1s – the key stays until the # timeout expires, after which the endpoint will be queried again. _error_cache: dict[str, float] = {} +# ------------------------------------------------------------------ +# SSE Queues +# ------------------------------------------------------------------ +_subscribers: Set[asyncio.Queue] = set() +_subscribers_lock = asyncio.Lock() + # ------------------------------------------------------------- # 1. Configuration loader # ------------------------------------------------------------- @@ -38,18 +46,35 @@ class Config(BaseSettings): # Max concurrent connections per endpoint‑model pair, see OLLAMA_NUM_PARALLEL max_concurrent_connections: int = 1 + api_keys: Dict[str, str] = Field(default_factory=dict) + class Config: # Load from `config.yaml` first, then from env variables - env_prefix = "OLLAMA_PROXY_" + env_prefix = "NOMYO_ROUTER_" yaml_file = Path("config.yaml") # relative to cwd + @classmethod + def _expand_env_refs(cls, obj): + """Recursively replace `${VAR}` with os.getenv('VAR').""" + if isinstance(obj, dict): + return {k: cls._expand_env_refs(v) for k, v in obj.items()} + if isinstance(obj, list): + return [cls._expand_env_refs(v) for v in obj] + if isinstance(obj, str): + # Only expand if it is exactly ${VAR} + m = re.fullmatch(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}", obj) + if m: + return os.getenv(m.group(1), "") + return obj + @classmethod def from_yaml(cls, path: Path) -> "Config": """Load the YAML file and create the Config instance.""" if path.exists(): with path.open("r", encoding="utf-8") as fp: data = yaml.safe_load(fp) or {} - return cls(**data) + cleaned = cls._expand_env_refs(data) + return cls(**cleaned) return cls() # Create the global config object – it will be overwritten on startup @@ -59,6 +84,7 @@ config = Config() # 2. FastAPI application # ------------------------------------------------------------- app = FastAPI() +sse_handler.app = app # ------------------------------------------------------------- # 3. 
Global state: per‑endpoint per‑model active connection counters @@ -79,15 +105,15 @@ def get_httpx_client(endpoint: str) -> httpx.AsyncClient: """ return httpx.AsyncClient( base_url=endpoint, - timeout=httpx.Timeout(5.0, read=5.0, write=5.0, connect=5.0), - limits=httpx.Limits( - max_keepalive_connections=64, - max_connections=64 - ) + timeout=httpx.Timeout(5.0, read=5.0, write=None, connect=5.0), + #limits=httpx.Limits( + # max_keepalive_connections=64, + # max_connections=64 + #), + transport=AiohttpTransport() ) -#@cached(cache=Cache.MEMORY, ttl=300) -async def fetch_available_models(endpoint: str) -> Set[str]: +async def fetch_available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]: """ Query /api/tags and return a set of all model names that the endpoint *advertises* (i.e. is capable of serving). This endpoint lists @@ -97,6 +123,10 @@ async def fetch_available_models(endpoint: str) -> Set[str]: If the request fails (e.g. timeout, 5xx, or malformed response), an empty set is returned. """ + headers = None + if api_key is not None: + headers = {"Authorization": "Bearer " + api_key} + if endpoint in _models_cache: models, cached_at = _models_cache[endpoint] if _is_fresh(cached_at, 300): @@ -113,10 +143,10 @@ async def fetch_available_models(endpoint: str) -> Set[str]: # Error expired – remove it del _error_cache[endpoint] - client = get_httpx_client(endpoint) try: + client = get_httpx_client(endpoint) if "/v1" in endpoint: - resp = await client.get(f"/models") + resp = await client.get(f"/models", headers=headers) else: resp = await client.get(f"/api/tags") resp.raise_for_status() @@ -124,15 +154,15 @@ async def fetch_available_models(endpoint: str) -> Set[str]: # Expected format: # {"models": [{"name": "model1"}, {"name": "model2"}]} if "/v1" in endpoint: - models = {m.get("id") for m in data.get("data", []) if m.get("name")} + models = {m.get("id") for m in data.get("data", []) if m.get("id")} else: models = {m.get("name") for m in data.get("models", []) if m.get("name")} - + if models: _models_cache[endpoint] = (models, time.time()) return models else: - # Empty list – treat as “no models”, but still cache for 300 s + # Empty list – treat as “no models”, but still cache for 300s _models_cache[endpoint] = (models, time.time()) return models except Exception as e: @@ -140,6 +170,8 @@ async def fetch_available_models(endpoint: str) -> Set[str]: print(f"[fetch_available_models] {endpoint} error: {e}") _error_cache[endpoint] = time.time() return set() + finally: + await client.aclose() async def fetch_loaded_models(endpoint: str) -> Set[str]: @@ -148,8 +180,8 @@ async def fetch_loaded_models(endpoint: str) -> Set[str]: loaded on that endpoint. If the request fails (e.g. timeout, 5xx), an empty set is returned. """ - client = get_httpx_client(endpoint) try: + client = get_httpx_client(endpoint) resp = await client.get(f"/api/ps") resp.raise_for_status() data = resp.json() @@ -160,15 +192,21 @@ async def fetch_loaded_models(endpoint: str) -> Set[str]: except Exception: # If anything goes wrong we simply assume the endpoint has no models return set() + finally: + await client.aclose() -async def fetch_endpoint_details(endpoint: str, route: str, detail: str) -> List[dict]: +async def fetch_endpoint_details(endpoint: str, route: str, detail: str, api_key: Optional[str] = None) -> List[dict]: """ Query / to fetch and return a List of dicts with details for the corresponding Ollama endpoint. If the request fails we respond with "N/A" for detail. 
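+
+    Illustrative calls (a sketch only; the endpoint URLs and the key value
+    below are placeholders, not something this patch configures):
+
+        version = await fetch_endpoint_details(
+            "http://localhost:11434", "/api/version", "version")
+        data = await fetch_endpoint_details(
+            "https://api.example.com/v1", "/models", "data", api_key="sk-...")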
""" - client = get_httpx_client(endpoint) + headers = None + if api_key is not None: + headers = {"Authorization": "Bearer " + api_key} + try: - resp = await client.get(f"{route}") + client = get_httpx_client(endpoint) + resp = await client.get(f"{route}", headers=headers) resp.raise_for_status() data = resp.json() detail = data.get(detail, []) @@ -176,7 +214,9 @@ async def fetch_endpoint_details(endpoint: str, route: str, detail: str) -> List except Exception as e: # If anything goes wrong we cannot reply details print(e) - return {detail: []} + return [] + finally: + await client.aclose() def ep2base(ep): if "/v1" in ep: @@ -202,6 +242,7 @@ def dedupe_on_keys(dicts, key_fields): async def increment_usage(endpoint: str, model: str) -> None: async with usage_lock: usage_counts[endpoint][model] += 1 + await publish_snapshot() async def decrement_usage(endpoint: str, model: str) -> None: async with usage_lock: @@ -212,8 +253,43 @@ async def decrement_usage(endpoint: str, model: str) -> None: # Optionally, clean up zero entries if usage_counts[endpoint].get(model, 0) == 0: usage_counts[endpoint].pop(model, None) - if not usage_counts[endpoint]: - usage_counts.pop(endpoint, None) + #if not usage_counts[endpoint]: + # usage_counts.pop(endpoint, None) + await publish_snapshot() + +# ------------------------------------------------------------------ +# SSE Helpser +# ------------------------------------------------------------------ +async def publish_snapshot(): + snapshot = json.dumps({"usage_counts": usage_counts}) + async with _subscribers_lock: + for q in _subscribers: + # If the queue is full, drop the message to avoid back‑pressure. + if q.full(): + continue + await q.put(snapshot) + +# ------------------------------------------------------------------ +# Subscriber helpers +# ------------------------------------------------------------------ +async def subscribe() -> asyncio.Queue: + """ + Returns a new Queue that will receive every snapshot. + """ + q: asyncio.Queue = asyncio.Queue(maxsize=10) + async with _subscribers_lock: + _subscribers.add(q) + return q + +async def unsubscribe(q: asyncio.Queue): + async with _subscribers_lock: + _subscribers.discard(q) + +# ------------------------------------------------------------------ +# Convenience wrapper – returns the current snapshot (for the proxy) +# ------------------------------------------------------------------ +async def get_usage_counts() -> Dict: + return dict(usage_counts) # shallow copy # ------------------------------------------------------------- # 5. Endpoint selection logic (respecting the configurable limit) @@ -237,7 +313,8 @@ async def choose_endpoint(model: str) -> str: 6️⃣ If no endpoint advertises the model at all, raise an error. """ # 1️⃣ Gather advertised‑model sets for all endpoints concurrently - tag_tasks = [fetch_available_models(ep) for ep in config.endpoints] + tag_tasks = [fetch_available_models(ep) for ep in config.endpoints if "/v1" not in ep] + tag_tasks += [fetch_available_models(ep, config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep] advertised_sets = await asyncio.gather(*tag_tasks) # 2️⃣ Filter endpoints that advertise the requested model @@ -595,16 +672,17 @@ async def create_proxy(request: Request): # 11. API route – Show # ------------------------------------------------------------- @app.post("/api/show") -async def show_proxy(request: Request): +async def show_proxy(request: Request, model: Optional[str] = None): """ Proxy a model show request to Ollama and reply with ShowResponse. 
""" try: body_bytes = await request.body() - payload = json.loads(body_bytes.decode("utf-8")) - model = payload.get("model") + if not model: + payload = json.loads(body_bytes.decode("utf-8")) + model = payload.get("model") if not model: raise HTTPException( @@ -615,7 +693,7 @@ async def show_proxy(request: Request): # 2. Endpoint logic endpoint = await choose_endpoint(model) - await increment_usage(endpoint, model) + #await increment_usage(endpoint, model) client = ollama.AsyncClient(host=endpoint) # 3. Proxy a simple show request @@ -628,7 +706,7 @@ async def show_proxy(request: Request): # 12. API route – Copy # ------------------------------------------------------------- @app.post("/api/copy") -async def copy_proxy(request: Request): +async def copy_proxy(request: Request, source: Optional[str] = None, destination: Optional[str] = None): """ Proxy a model copy request to each Ollama endpoint and reply with Status Code. @@ -636,10 +714,14 @@ async def copy_proxy(request: Request): # 1. Parse and validate request try: body_bytes = await request.body() - payload = json.loads(body_bytes.decode("utf-8")) - src = payload.get("source") - dst = payload.get("destination") + if not source and not destination: + payload = json.loads(body_bytes.decode("utf-8")) + src = payload.get("source") + dst = payload.get("destination") + else: + src = source + dst = destination if not src: raise HTTPException( @@ -655,26 +737,20 @@ async def copy_proxy(request: Request): # 3. Iterate over all endpoints to copy the model on each endpoint status_list = [] for endpoint in config.endpoints: - client = ollama.AsyncClient(host=endpoint) - # 4. Proxy a simple copy request - copy = await client.copy(source=src, destination=dst) - status_list.append(copy.status) + if "/v1" not in endpoint: + client = ollama.AsyncClient(host=endpoint) + # 4. Proxy a simple copy request + copy = await client.copy(source=src, destination=dst) + status_list.append(copy.status) # 4. Return with 200 OK if all went well, 404 if a single endpoint failed - if 404 in status_list: - return Response( - status_code=404 - ) - else: - return Response( - status_code=200 - ) + return Response(status_code=404 if 404 in status_list else 200) # ------------------------------------------------------------- # 13. API route – Delete # ------------------------------------------------------------- @app.delete("/api/delete") -async def delete_proxy(request: Request): +async def delete_proxy(request: Request, model: Optional[str] = None): """ Proxy a model delete request to each Ollama endpoint and reply with Status Code. @@ -682,9 +758,10 @@ async def delete_proxy(request: Request): # 1. Parse and validate request try: body_bytes = await request.body() - payload = json.loads(body_bytes.decode("utf-8")) - model = payload.get("model") + if not model: + payload = json.loads(body_bytes.decode("utf-8")) + model = payload.get("model") if not model: raise HTTPException( @@ -696,36 +773,33 @@ async def delete_proxy(request: Request): # 2. Iterate over all endpoints to delete the model on each endpoint status_list = [] for endpoint in config.endpoints: - client = ollama.AsyncClient(host=endpoint) - # 3. Proxy a simple copy request - copy = await client.delete(model=model) - status_list.append(copy.status) + if "/v1" not in endpoint: + client = ollama.AsyncClient(host=endpoint) + # 3. Proxy a simple copy request + copy = await client.delete(model=model) + status_list.append(copy.status) # 4. 
Retrun 200 0K, if a single enpoint fails, respond with 404 - if 404 in status_list: - return Response( - status_code=404 - ) - else: - return Response( - status_code=200 - ) + return Response(status_code=404 if 404 in status_list else 200) # ------------------------------------------------------------- # 14. API route – Pull # ------------------------------------------------------------- @app.post("/api/pull") -async def pull_proxy(request: Request): +async def pull_proxy(request: Request, model: Optional[str] = None): """ Proxy a pull request to all Ollama endpoint and report status back. """ # 1. Parse and validate request try: body_bytes = await request.body() - payload = json.loads(body_bytes.decode("utf-8")) - model = payload.get("model") - insecure = payload.get("insecure") + if not model: + payload = json.loads(body_bytes.decode("utf-8")) + model = payload.get("model") + insecure = payload.get("insecure") + else: + insecure = None if not model: raise HTTPException( @@ -737,10 +811,11 @@ async def pull_proxy(request: Request): # 2. Iterate over all endpoints to pull the model status_list = [] for endpoint in config.endpoints: - client = ollama.AsyncClient(host=endpoint) - # 3. Proxy a simple pull request - pull = await client.pull(model=model, insecure=insecure, stream=False) - status_list.append(pull) + if "/v1" not in endpoint: + client = ollama.AsyncClient(host=endpoint) + # 3. Proxy a simple pull request + pull = await client.pull(model=model, insecure=insecure, stream=False) + status_list.append(pull) combined_status = [] for status in status_list: @@ -802,9 +877,9 @@ async def version_proxy(request: Request): """ # 1. Query all endpoints for version - tasks = [fetch_endpoint_details(ep, "/api/version", "version") for ep in config.endpoints] + tasks = [fetch_endpoint_details(ep, "/api/version", "version") for ep in config.endpoints if "/v1" not in ep] all_versions = await asyncio.gather(*tasks) - + def version_key(v): return tuple(map(int, v.split('.'))) @@ -823,9 +898,10 @@ async def tags_proxy(request: Request): Proxy a tags request to Ollama endpoints and reply with a unique list of all models. """ + # 1. Query all endpoints for models tasks = [fetch_endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints if "/v1" not in ep] - tasks += [fetch_endpoint_details(ep, "/models", "data") for ep in config.endpoints if "/v1" in ep] + tasks += [fetch_endpoint_details(ep, "/models", "data", config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep] all_models = await asyncio.gather(*tasks) models = {'models': []} @@ -834,7 +910,7 @@ async def tags_proxy(request: Request): # 2. 
Return a JSONResponse with a deduplicated list of unique models for inference return JSONResponse( - content={"models": dedupe_on_keys(models['models'], ['digest','name'])}, + content={"models": dedupe_on_keys(models['models'], ['digest','name','id'])}, status_code=200, ) @@ -884,9 +960,10 @@ async def config_proxy(request: Request): """ async def check_endpoint(url: str): try: - async with httpx.AsyncClient(timeout=1) as client: + async with httpx.AsyncClient(timeout=1, transport=AiohttpTransport()) as client: if "/v1" in url: - r = await client.get(f"{url}/models") + headers = {"Authorization": "Bearer " + config.api_keys[url]} + r = await client.get(f"{url}/models", headers=headers) else: r = await client.get(f"{url}/api/version") r.raise_for_status() @@ -897,6 +974,8 @@ async def config_proxy(request: Request): return {"url": url, "status": "ok", "version": data.get("version")} except Exception as exc: return {"url": url, "status": "error", "detail": str(exc)} + finally: + await client.aclose() results = await asyncio.gather(*[check_endpoint(ep) for ep in config.endpoints]) return {"endpoints": results} @@ -918,6 +997,7 @@ async def openai_embedding_proxy(request: Request): model = payload.get("model") input = payload.get("input") + if not model: raise HTTPException( status_code=400, detail="Missing required field 'model'" @@ -932,10 +1012,14 @@ async def openai_embedding_proxy(request: Request): # 2. Endpoint logic endpoint = await choose_endpoint(model) await increment_usage(endpoint, model) - oclient = openai.AsyncOpenAI(base_url=endpoint+"/v1", api_key="ollama") + if "/v1" in endpoint: + api_key = config.api_keys[endpoint] + else: + api_key = "ollama" + oclient = openai.AsyncOpenAI(base_url=endpoint+"/v1", api_key=api_key) # 3. Async generator that streams embedding data and decrements the counter - async_gen = await oclient.embeddings.create(input = [input], model=model) + async_gen = await oclient.embeddings.create(input=[input], model=model) await decrement_usage(endpoint, model) @@ -968,23 +1052,14 @@ async def openai_chat_completions_proxy(request: Request): temperature = payload.get("temperature") top_p = payload.get("top_p") max_tokens = payload.get("max_tokens") + max_completion_tokens = payload.get("max_completion_tokens") tools = payload.get("tools") - - headers = request.headers - api_key = headers.get("Authorization") - api_key = api_key.split()[1] params = { "messages": messages, "model": model, - "frequency_penalty": frequency_penalty, - "presence_penalty": presence_penalty, - "seed": seed, "stop": stop, "stream": stream, - "temperature": temperature, - "top_p": top_p, - "max_tokens": max_tokens } if tools is not None: @@ -993,6 +1068,20 @@ async def openai_chat_completions_proxy(request: Request): params["response_format"] = response_format if stream_options is not None: params["stream_options"] = stream_options + if max_completion_tokens is not None: + params["max_completion_tokens"] = max_completion_tokens + if max_tokens is not None: + params["max_tokens"] = max_tokens + if temperature is not None: + params["temperature"] = temperature + if top_p is not None: + params["top_p"] = top_p + if seed is not None: + params["seed"] = seed + if presence_penalty is not None: + params["presence_penalty"] = presence_penalty + if frequency_penalty is not None: + params["frequency_penalty"] = frequency_penalty if not model: raise HTTPException( @@ -1009,7 +1098,7 @@ async def openai_chat_completions_proxy(request: Request): endpoint = await choose_endpoint(model) await 
increment_usage(endpoint, model) base_url = ep2base(endpoint) - oclient = openai.AsyncOpenAI(base_url=base_url, api_key=api_key) + oclient = openai.AsyncOpenAI(base_url=base_url, api_key=config.api_keys[endpoint]) # 3. Async generator that streams completions data and decrements the counter async def stream_ochat_response(): @@ -1069,11 +1158,8 @@ async def openai_completions_proxy(request: Request): temperature = payload.get("temperature") top_p = payload.get("top_p") max_tokens = payload.get("max_tokens") + max_completion_tokens = payload.get("max_completion_tokens") suffix = payload.get("suffix") - - headers = request.headers - api_key = headers.get("Authorization") - api_key = api_key.split()[1] params = { "prompt": prompt, @@ -1107,7 +1193,7 @@ async def openai_completions_proxy(request: Request): endpoint = await choose_endpoint(model) await increment_usage(endpoint, model) base_url = ep2base(endpoint) - oclient = openai.AsyncOpenAI(base_url=base_url, api_key=api_key) + oclient = openai.AsyncOpenAI(base_url=base_url, api_key=config.api_keys[endpoint]) # 3. Async generator that streams completions data and decrements the counter async def stream_ocompletions_response(): @@ -1148,16 +1234,21 @@ async def openai_completions_proxy(request: Request): @app.get("/v1/models") async def openai_models_proxy(request: Request): """ - Proxy a models request to Ollama endpoints and reply with a unique list of all models. + Proxy an OpenAI API models request to Ollama endpoints and reply with a unique list of all models. """ # 1. Query all endpoints for models tasks = [fetch_endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints if "/v1" not in ep] - tasks += [fetch_endpoint_details(ep, "/models", "data") for ep in config.endpoints if "/v1" in ep] + tasks += [fetch_endpoint_details(ep, "/models", "data", config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep] all_models = await asyncio.gather(*tasks) models = {'data': []} for modellist in all_models: + for model in modellist: + if not "id" in model.keys(): # Relable Ollama models with OpenAI Model.id from Model.name + model['id'] = model['name'] + else: + model['name'] = model['id'] models['data'] += modellist # 2. Return a JSONResponse with a deduplicated list of unique models for inference @@ -1178,8 +1269,8 @@ async def redirect_favicon(): @app.get("/", response_class=HTMLResponse) async def index(request: Request): """ - Render the landing page that lists the configured endpoints - and the models available / running. + Render the dynamic NOMYO Router dashboard listing the configured endpoints + and the models details, availability & task status. """ return HTMLResponse(content=open("static/index.html", "r").read(), status_code=200) @@ -1225,7 +1316,33 @@ async def health_proxy(request: Request): return JSONResponse(content=response_payload, status_code=http_status) # ------------------------------------------------------------- -# 27. FastAPI startup event – load configuration +# 27. SSE route for usage broadcasts +# ------------------------------------------------------------- +@app.get("/api/usage-stream") +async def usage_stream(request: Request): + """ + Server‑Sent‑Events that emits a JSON payload every time the + global `usage_counts` dictionary changes. 
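+
+    Each event carries the snapshot produced by publish_snapshot(), e.g.
+    (the endpoint URL and model name below are placeholders):
+
+        data: {"usage_counts": {"http://localhost:11434": {"llama3": 1}}}
+
+    Any SSE-capable client can follow the stream, for example:
+        curl -N http://<router-host>/api/usage-stream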
+ """ + async def event_generator(): + # The queue that receives *every* new snapshot + queue = await subscribe() + try: + while True: + # If the client disconnects, cancel the loop + if await request.is_disconnected(): + break + data = await queue.get() + # Send the data as a single SSE message + yield f"data: {data}\n\n" + finally: + # Clean‑up: unsubscribe from the broadcast channel + await unsubscribe(queue) + + return StreamingResponse(event_generator(), media_type="text/event-stream") + +# ------------------------------------------------------------- +# 28. FastAPI startup event – load configuration # ------------------------------------------------------------- @app.on_event("startup") async def startup_event() -> None: