From 2ead1112e74b1da425f976535eb03854bb2bcf84 Mon Sep 17 00:00:00 2001 From: alpha-nerd-nomyo Date: Wed, 3 Sep 2025 18:01:39 +0200 Subject: [PATCH] Add files via upload centralizing remote endpoint secrets management for unified endpoints --- router.py | 75 +++++++++++++++++++++++++++++++------------------------ 1 file changed, 42 insertions(+), 33 deletions(-) diff --git a/router.py b/router.py index e835deb..594496e 100644 --- a/router.py +++ b/router.py @@ -2,11 +2,11 @@ title: NOMYO Router - an Ollama Proxy with Endpoint:Model aware routing author: alpha-nerd-nomyo author_url: https://github.com/nomyo-ai -version: 0.2.1 +version: 0.2.2 license: AGPL """ # ------------------------------------------------------------- -import json, time, asyncio, yaml, httpx, ollama, openai +import json, time, asyncio, yaml, httpx, ollama, openai, os, re from pathlib import Path from typing import Dict, Set, List, Optional from fastapi import FastAPI, Request, HTTPException @@ -38,18 +38,35 @@ class Config(BaseSettings): # Max concurrent connections per endpoint‑model pair, see OLLAMA_NUM_PARALLEL max_concurrent_connections: int = 1 + api_keys: Dict[str, str] = Field(default_factory=dict) + class Config: # Load from `config.yaml` first, then from env variables - env_prefix = "OLLAMA_PROXY_" + env_prefix = "NOMYO_ROUTER_" yaml_file = Path("config.yaml") # relative to cwd + @classmethod + def _expand_env_refs(cls, obj): + """Recursively replace `${VAR}` with os.getenv('VAR').""" + if isinstance(obj, dict): + return {k: cls._expand_env_refs(v) for k, v in obj.items()} + if isinstance(obj, list): + return [cls._expand_env_refs(v) for v in obj] + if isinstance(obj, str): + # Only expand if it is exactly ${VAR} + m = re.fullmatch(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}", obj) + if m: + return os.getenv(m.group(1), "") + return obj + @classmethod def from_yaml(cls, path: Path) -> "Config": """Load the YAML file and create the Config instance.""" if path.exists(): with path.open("r", encoding="utf-8") as fp: data = yaml.safe_load(fp) or {} - return cls(**data) + cleaned = cls._expand_env_refs(data) + return cls(**cleaned) return cls() # Create the global config object – it will be overwritten on startup @@ -224,7 +241,7 @@ async def decrement_usage(endpoint: str, model: str) -> None: # ------------------------------------------------------------- # 5. Endpoint selection logic (respecting the configurable limit) # ------------------------------------------------------------- -async def choose_endpoint(model: str, api_key: Optional[str] = None) -> str: +async def choose_endpoint(model: str) -> str: """ Determine which endpoint to use for the given model while respecting the `max_concurrent_connections` per endpoint‑model pair **and** @@ -243,7 +260,8 @@ async def choose_endpoint(model: str, api_key: Optional[str] = None) -> str: 6️⃣ If no endpoint advertises the model at all, raise an error. """ # 1️⃣ Gather advertised‑model sets for all endpoints concurrently - tag_tasks = [fetch_available_models(ep, api_key) for ep in config.endpoints] + tag_tasks = [fetch_available_models(ep) for ep in config.endpoints if "/v1" not in ep] + tag_tasks += [fetch_available_models(ep, config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep] advertised_sets = await asyncio.gather(*tag_tasks) # 2️⃣ Filter endpoints that advertise the requested model @@ -808,9 +826,8 @@ async def version_proxy(request: Request): """ # 1. Query all endpoints for version - tasks = [fetch_endpoint_details(ep, "/api/version", "version") for ep in config.endpoints] + tasks = [fetch_endpoint_details(ep, "/api/version", "version") for ep in config.endpoints if "/v1" not in ep] all_versions = await asyncio.gather(*tasks) - all_versions = [v for v in all_versions if v != "N/A"] def version_key(v): return tuple(map(int, v.split('.'))) @@ -830,9 +847,10 @@ async def tags_proxy(request: Request): Proxy a tags request to Ollama endpoints and reply with a unique list of all models. """ + # 1. Query all endpoints for models tasks = [fetch_endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints if "/v1" not in ep] - tasks += [fetch_endpoint_details(ep, "/models", "data") for ep in config.endpoints if "/v1" in ep] #needs api_key TODO:add central mgmt + tasks += [fetch_endpoint_details(ep, "/models", "data", config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep] all_models = await asyncio.gather(*tasks) models = {'models': []} @@ -841,7 +859,7 @@ async def tags_proxy(request: Request): # 2. Return a JSONResponse with a deduplicated list of unique models for inference return JSONResponse( - content={"models": dedupe_on_keys(models['models'], ['digest','name'])}, + content={"models": dedupe_on_keys(models['models'], ['digest','name','id'])}, status_code=200, ) @@ -893,7 +911,8 @@ async def config_proxy(request: Request): try: async with httpx.AsyncClient(timeout=1) as client: if "/v1" in url: - r = await client.get(f"{url}/models") + headers = {"Authorization": "Bearer " + config.api_keys[url]} + r = await client.get(f"{url}/models", headers=headers) else: r = await client.get(f"{url}/api/version") r.raise_for_status() @@ -925,9 +944,6 @@ async def openai_embedding_proxy(request: Request): model = payload.get("model") input = payload.get("input") - headers = request.headers - api_key = headers.get("Authorization") - api_key = api_key.split()[1] if not model: raise HTTPException( @@ -941,12 +957,16 @@ async def openai_embedding_proxy(request: Request): raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e # 2. Endpoint logic - endpoint = await choose_endpoint(model, api_key) + endpoint = await choose_endpoint(model) await increment_usage(endpoint, model) + if "/v1" in endpoint: + api_key = config.api_keys[endpoint] + else: + api_key = "ollama" oclient = openai.AsyncOpenAI(base_url=endpoint+"/v1", api_key=api_key) # 3. Async generator that streams embedding data and decrements the counter - async_gen = await oclient.embeddings.create(input = [input], model=model) + async_gen = await oclient.embeddings.create(input=[input], model=model) await decrement_usage(endpoint, model) @@ -981,10 +1001,6 @@ async def openai_chat_completions_proxy(request: Request): max_tokens = payload.get("max_tokens") max_completion_tokens = payload.get("max_completion_tokens") tools = payload.get("tools") - - headers = request.headers - api_key = headers.get("Authorization") - api_key = api_key.split()[1] params = { "messages": messages, @@ -1025,10 +1041,10 @@ async def openai_chat_completions_proxy(request: Request): raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e # 2. Endpoint logic - endpoint = await choose_endpoint(model, api_key) + endpoint = await choose_endpoint(model) await increment_usage(endpoint, model) base_url = ep2base(endpoint) - oclient = openai.AsyncOpenAI(base_url=base_url, api_key=api_key) + oclient = openai.AsyncOpenAI(base_url=base_url, api_key=config.api_keys[endpoint]) # 3. Async generator that streams completions data and decrements the counter async def stream_ochat_response(): @@ -1088,11 +1104,8 @@ async def openai_completions_proxy(request: Request): temperature = payload.get("temperature") top_p = payload.get("top_p") max_tokens = payload.get("max_tokens") + max_completion_tokens = payload.get("max_completion_tokens") suffix = payload.get("suffix") - - headers = request.headers - api_key = headers.get("Authorization") - api_key = api_key.split()[1] params = { "prompt": prompt, @@ -1123,10 +1136,10 @@ async def openai_completions_proxy(request: Request): raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e # 2. Endpoint logic - endpoint = await choose_endpoint(model, api_key) + endpoint = await choose_endpoint(model) await increment_usage(endpoint, model) base_url = ep2base(endpoint) - oclient = openai.AsyncOpenAI(base_url=base_url, api_key=api_key) + oclient = openai.AsyncOpenAI(base_url=base_url, api_key=config.api_keys[endpoint]) # 3. Async generator that streams completions data and decrements the counter async def stream_ocompletions_response(): @@ -1170,13 +1183,9 @@ async def openai_models_proxy(request: Request): Proxy a models request to Ollama endpoints and reply with a unique list of all models. """ - headers = request.headers - api_key = headers.get("Authorization") - api_key = api_key.split()[1] - # 1. Query all endpoints for models tasks = [fetch_endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints if "/v1" not in ep] - tasks += [fetch_endpoint_details(ep, "/models", "data", api_key) for ep in config.endpoints if "/v1" in ep] + tasks += [fetch_endpoint_details(ep, "/models", "data", config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep] all_models = await asyncio.gather(*tasks) models = {'data': []}