Add files via upload

centralizing remote endpoint secrets management for unified endpoints
This commit is contained in:
Alpha Nerd 2025-09-03 18:01:39 +02:00 committed by GitHub
parent e7fd79c461
commit 2ead1112e7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -2,11 +2,11 @@
title: NOMYO Router - an Ollama Proxy with Endpoint:Model aware routing
author: alpha-nerd-nomyo
author_url: https://github.com/nomyo-ai
version: 0.2.1
version: 0.2.2
license: AGPL
"""
# -------------------------------------------------------------
import json, time, asyncio, yaml, httpx, ollama, openai
import json, time, asyncio, yaml, httpx, ollama, openai, os, re
from pathlib import Path
from typing import Dict, Set, List, Optional
from fastapi import FastAPI, Request, HTTPException
@ -38,18 +38,35 @@ class Config(BaseSettings):
# Max concurrent connections per endpoint-model pair, see OLLAMA_NUM_PARALLEL
max_concurrent_connections: int = 1
api_keys: Dict[str, str] = Field(default_factory=dict)
class Config:
# Load from `config.yaml` first, then from env variables
env_prefix = "OLLAMA_PROXY_"
env_prefix = "NOMYO_ROUTER_"
yaml_file = Path("config.yaml") # relative to cwd
@classmethod
def _expand_env_refs(cls, obj):
    """Recursively replace ``${VAR}`` placeholders with environment values.

    Dicts and lists are walked recursively and rebuilt as new containers;
    dict keys are left untouched.  A string is substituted only when the
    *entire* value is a single ``${VAR}`` reference.  Strings that merely
    contain a reference, and every non-string scalar, are returned as-is.

    Args:
        obj: Parsed configuration data (dict, list, str or any scalar).

    Returns:
        The same structure with each whole-string ``${VAR}`` replaced by
        ``os.getenv("VAR")``.  An unset variable still expands to ``""``
        (unchanged behaviour), but a ``RuntimeWarning`` naming the variable
        is emitted so a missing secret does not go unnoticed.
    """
    # Local import: keeps the block self-contained without touching the
    # module-level import line.
    import warnings

    if isinstance(obj, dict):
        return {k: cls._expand_env_refs(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [cls._expand_env_refs(v) for v in obj]
    if isinstance(obj, str):
        # Only expand if it is exactly ${VAR}
        m = re.fullmatch(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}", obj)
        if m:
            var_name = m.group(1)
            value = os.getenv(var_name)
            if value is None:
                # Warn with the variable NAME only - never log secret values.
                warnings.warn(
                    f"Environment variable '{var_name}' referenced in the "
                    "configuration is not set; using an empty string.",
                    RuntimeWarning,
                    stacklevel=2,
                )
                return ""
            return value
    return obj
@classmethod
def from_yaml(cls, path: Path) -> "Config":
"""Load the YAML file and create the Config instance."""
if path.exists():
with path.open("r", encoding="utf-8") as fp:
data = yaml.safe_load(fp) or {}
return cls(**data)
cleaned = cls._expand_env_refs(data)
return cls(**cleaned)
return cls()
# Create the global config object - it will be overwritten on startup
@ -224,7 +241,7 @@ async def decrement_usage(endpoint: str, model: str) -> None:
# -------------------------------------------------------------
# 5. Endpoint selection logic (respecting the configurable limit)
# -------------------------------------------------------------
async def choose_endpoint(model: str, api_key: Optional[str] = None) -> str:
async def choose_endpoint(model: str) -> str:
"""
Determine which endpoint to use for the given model while respecting
the `max_concurrent_connections` per endpoint-model pair **and**
@ -243,7 +260,8 @@ async def choose_endpoint(model: str, api_key: Optional[str] = None) -> str:
6️⃣ If no endpoint advertises the model at all, raise an error.
"""
# 1️⃣ Gather advertised-model sets for all endpoints concurrently
tag_tasks = [fetch_available_models(ep, api_key) for ep in config.endpoints]
tag_tasks = [fetch_available_models(ep) for ep in config.endpoints if "/v1" not in ep]
tag_tasks += [fetch_available_models(ep, config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep]
advertised_sets = await asyncio.gather(*tag_tasks)
# 2️⃣ Filter endpoints that advertise the requested model
@ -808,9 +826,8 @@ async def version_proxy(request: Request):
"""
# 1. Query all endpoints for version
tasks = [fetch_endpoint_details(ep, "/api/version", "version") for ep in config.endpoints]
tasks = [fetch_endpoint_details(ep, "/api/version", "version") for ep in config.endpoints if "/v1" not in ep]
all_versions = await asyncio.gather(*tasks)
all_versions = [v for v in all_versions if v != "N/A"]
def version_key(v):
return tuple(map(int, v.split('.')))
@ -830,9 +847,10 @@ async def tags_proxy(request: Request):
Proxy a tags request to Ollama endpoints and reply with a unique list of all models.
"""
# 1. Query all endpoints for models
tasks = [fetch_endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints if "/v1" not in ep]
tasks += [fetch_endpoint_details(ep, "/models", "data") for ep in config.endpoints if "/v1" in ep] #needs api_key TODO:add central mgmt
tasks += [fetch_endpoint_details(ep, "/models", "data", config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep]
all_models = await asyncio.gather(*tasks)
models = {'models': []}
@ -841,7 +859,7 @@ async def tags_proxy(request: Request):
# 2. Return a JSONResponse with a deduplicated list of unique models for inference
return JSONResponse(
content={"models": dedupe_on_keys(models['models'], ['digest','name'])},
content={"models": dedupe_on_keys(models['models'], ['digest','name','id'])},
status_code=200,
)
@ -893,7 +911,8 @@ async def config_proxy(request: Request):
try:
async with httpx.AsyncClient(timeout=1) as client:
if "/v1" in url:
r = await client.get(f"{url}/models")
headers = {"Authorization": "Bearer " + config.api_keys[url]}
r = await client.get(f"{url}/models", headers=headers)
else:
r = await client.get(f"{url}/api/version")
r.raise_for_status()
@ -925,9 +944,6 @@ async def openai_embedding_proxy(request: Request):
model = payload.get("model")
input = payload.get("input")
headers = request.headers
api_key = headers.get("Authorization")
api_key = api_key.split()[1]
if not model:
raise HTTPException(
@ -941,12 +957,16 @@ async def openai_embedding_proxy(request: Request):
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Endpoint logic
endpoint = await choose_endpoint(model, api_key)
endpoint = await choose_endpoint(model)
await increment_usage(endpoint, model)
if "/v1" in endpoint:
api_key = config.api_keys[endpoint]
else:
api_key = "ollama"
oclient = openai.AsyncOpenAI(base_url=endpoint+"/v1", api_key=api_key)
# 3. Async generator that streams embedding data and decrements the counter
async_gen = await oclient.embeddings.create(input = [input], model=model)
async_gen = await oclient.embeddings.create(input=[input], model=model)
await decrement_usage(endpoint, model)
@ -981,10 +1001,6 @@ async def openai_chat_completions_proxy(request: Request):
max_tokens = payload.get("max_tokens")
max_completion_tokens = payload.get("max_completion_tokens")
tools = payload.get("tools")
headers = request.headers
api_key = headers.get("Authorization")
api_key = api_key.split()[1]
params = {
"messages": messages,
@ -1025,10 +1041,10 @@ async def openai_chat_completions_proxy(request: Request):
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Endpoint logic
endpoint = await choose_endpoint(model, api_key)
endpoint = await choose_endpoint(model)
await increment_usage(endpoint, model)
base_url = ep2base(endpoint)
oclient = openai.AsyncOpenAI(base_url=base_url, api_key=api_key)
oclient = openai.AsyncOpenAI(base_url=base_url, api_key=config.api_keys[endpoint])
# 3. Async generator that streams completions data and decrements the counter
async def stream_ochat_response():
@ -1088,11 +1104,8 @@ async def openai_completions_proxy(request: Request):
temperature = payload.get("temperature")
top_p = payload.get("top_p")
max_tokens = payload.get("max_tokens")
max_completion_tokens = payload.get("max_completion_tokens")
suffix = payload.get("suffix")
headers = request.headers
api_key = headers.get("Authorization")
api_key = api_key.split()[1]
params = {
"prompt": prompt,
@ -1123,10 +1136,10 @@ async def openai_completions_proxy(request: Request):
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Endpoint logic
endpoint = await choose_endpoint(model, api_key)
endpoint = await choose_endpoint(model)
await increment_usage(endpoint, model)
base_url = ep2base(endpoint)
oclient = openai.AsyncOpenAI(base_url=base_url, api_key=api_key)
oclient = openai.AsyncOpenAI(base_url=base_url, api_key=config.api_keys[endpoint])
# 3. Async generator that streams completions data and decrements the counter
async def stream_ocompletions_response():
@ -1170,13 +1183,9 @@ async def openai_models_proxy(request: Request):
Proxy a models request to Ollama endpoints and reply with a unique list of all models.
"""
headers = request.headers
api_key = headers.get("Authorization")
api_key = api_key.split()[1]
# 1. Query all endpoints for models
tasks = [fetch_endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints if "/v1" not in ep]
tasks += [fetch_endpoint_details(ep, "/models", "data", api_key) for ep in config.endpoints if "/v1" in ep]
tasks += [fetch_endpoint_details(ep, "/models", "data", config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep]
all_models = await asyncio.gather(*tasks)
models = {'data': []}