Add files via upload
centralizing remote endpoint secrets management for unified endpoints

This commit is contained in:
parent e7fd79c461
commit 2ead1112e7

1 changed file with 42 additions and 33 deletions
router.py | 75 (+42, -33)
@@ -2,11 +2,11 @@
 title: NOMYO Router - an Ollama Proxy with Endpoint:Model aware routing
 author: alpha-nerd-nomyo
 author_url: https://github.com/nomyo-ai
-version: 0.2.1
+version: 0.2.2
 license: AGPL
 """
 # -------------------------------------------------------------
-import json, time, asyncio, yaml, httpx, ollama, openai
+import json, time, asyncio, yaml, httpx, ollama, openai, os, re
 from pathlib import Path
 from typing import Dict, Set, List, Optional
 from fastapi import FastAPI, Request, HTTPException
@@ -38,18 +38,35 @@ class Config(BaseSettings):
     # Max concurrent connections per endpoint‑model pair, see OLLAMA_NUM_PARALLEL
     max_concurrent_connections: int = 1

+    api_keys: Dict[str, str] = Field(default_factory=dict)
+
     class Config:
         # Load from `config.yaml` first, then from env variables
-        env_prefix = "OLLAMA_PROXY_"
+        env_prefix = "NOMYO_ROUTER_"
         yaml_file = Path("config.yaml")  # relative to cwd

+    @classmethod
+    def _expand_env_refs(cls, obj):
+        """Recursively replace `${VAR}` with os.getenv('VAR')."""
+        if isinstance(obj, dict):
+            return {k: cls._expand_env_refs(v) for k, v in obj.items()}
+        if isinstance(obj, list):
+            return [cls._expand_env_refs(v) for v in obj]
+        if isinstance(obj, str):
+            # Only expand if it is exactly ${VAR}
+            m = re.fullmatch(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}", obj)
+            if m:
+                return os.getenv(m.group(1), "")
+        return obj
+
     @classmethod
     def from_yaml(cls, path: Path) -> "Config":
         """Load the YAML file and create the Config instance."""
         if path.exists():
             with path.open("r", encoding="utf-8") as fp:
                 data = yaml.safe_load(fp) or {}
-            return cls(**data)
+            cleaned = cls._expand_env_refs(data)
+            return cls(**cleaned)
         return cls()

 # Create the global config object – it will be overwritten on startup
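The practical upshot of `_expand_env_refs` is that config.yaml can reference a secret indirectly via ${VAR} instead of embedding it in the file. A minimal standalone sketch of the resolution behavior (the endpoint URL and variable name below are illustrative, not taken from the repo):

    import os, re, yaml

    def expand_env_refs(obj):
        # Mirrors Config._expand_env_refs: resolve values that are exactly ${VAR}.
        if isinstance(obj, dict):
            return {k: expand_env_refs(v) for k, v in obj.items()}
        if isinstance(obj, list):
            return [expand_env_refs(v) for v in obj]
        if isinstance(obj, str):
            m = re.fullmatch(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}", obj)
            if m:
                return os.getenv(m.group(1), "")
        return obj

    os.environ["OPENAI_API_KEY"] = "sk-demo"  # stands in for a real secret
    raw = yaml.safe_load("""
    endpoints:
      - https://api.openai.com/v1
    api_keys:
      https://api.openai.com/v1: ${OPENAI_API_KEY}
    """)
    print(expand_env_refs(raw)["api_keys"])  # {'https://api.openai.com/v1': 'sk-demo'}

Note that only values which are exactly ${VAR} are expanded; partial interpolation such as "prefix-${VAR}" passes through unchanged, and an unset variable resolves to an empty string.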
@@ -224,7 +241,7 @@ async def decrement_usage(endpoint: str, model: str) -> None:
 # -------------------------------------------------------------
 # 5. Endpoint selection logic (respecting the configurable limit)
 # -------------------------------------------------------------
-async def choose_endpoint(model: str, api_key: Optional[str] = None) -> str:
+async def choose_endpoint(model: str) -> str:
     """
     Determine which endpoint to use for the given model while respecting
     the `max_concurrent_connections` per endpoint‑model pair **and**
@@ -243,7 +260,8 @@ async def choose_endpoint(model: str, api_key: Optional[str] = None) -> str:
     6️⃣ If no endpoint advertises the model at all, raise an error.
     """
     # 1️⃣ Gather advertised‑model sets for all endpoints concurrently
-    tag_tasks = [fetch_available_models(ep, api_key) for ep in config.endpoints]
+    tag_tasks = [fetch_available_models(ep) for ep in config.endpoints if "/v1" not in ep]
+    tag_tasks += [fetch_available_models(ep, config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep]
     advertised_sets = await asyncio.gather(*tag_tasks)

     # 2️⃣ Filter endpoints that advertise the requested model
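The selection logic now sources credentials itself: any endpoint whose URL contains "/v1" is treated as an OpenAI-compatible remote and its key is looked up in config.api_keys, while plain Ollama endpoints are queried without credentials. A small sketch of that partition (the URLs are made up for illustration):

    endpoints = ["http://10.0.0.5:11434", "https://api.openai.com/v1"]
    api_keys = {"https://api.openai.com/v1": "sk-demo"}  # hypothetical mapping

    # Ollama nodes need no key; "/v1" remotes are paired with their secret.
    ollama_eps = [ep for ep in endpoints if "/v1" not in ep]
    openai_eps = [(ep, api_keys[ep]) for ep in endpoints if "/v1" in ep]
    print(ollama_eps, openai_eps)

Because the lookup uses config.api_keys[ep] directly, every "/v1" endpoint must have an entry in api_keys or endpoint selection fails with a KeyError.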
@@ -808,9 +826,8 @@ async def version_proxy(request: Request):

     """
     # 1. Query all endpoints for version
-    tasks = [fetch_endpoint_details(ep, "/api/version", "version") for ep in config.endpoints]
+    tasks = [fetch_endpoint_details(ep, "/api/version", "version") for ep in config.endpoints if "/v1" not in ep]
     all_versions = await asyncio.gather(*tasks)
-    all_versions = [v for v in all_versions if v != "N/A"]

     def version_key(v):
         return tuple(map(int, v.split('.')))
@@ -830,9 +847,10 @@ async def tags_proxy(request: Request):
     Proxy a tags request to Ollama endpoints and reply with a unique list of all models.

     """
+
     # 1. Query all endpoints for models
     tasks = [fetch_endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints if "/v1" not in ep]
-    tasks += [fetch_endpoint_details(ep, "/models", "data") for ep in config.endpoints if "/v1" in ep] #needs api_key TODO:add central mgmt
+    tasks += [fetch_endpoint_details(ep, "/models", "data", config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep]
     all_models = await asyncio.gather(*tasks)

     models = {'models': []}
@@ -841,7 +859,7 @@ async def tags_proxy(request: Request):

     # 2. Return a JSONResponse with a deduplicated list of unique models for inference
     return JSONResponse(
-        content={"models": dedupe_on_keys(models['models'], ['digest','name'])},
+        content={"models": dedupe_on_keys(models['models'], ['digest','name','id'])},
         status_code=200,
     )
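dedupe_on_keys is defined elsewhere in router.py; the change here only widens the dedup key with 'id', so OpenAI-style model entries (which carry id rather than digest/name) collapse correctly alongside Ollama tags. A plausible sketch of such a helper, assuming first-occurrence-wins semantics (the real implementation may differ):

    from typing import Dict, List

    def dedupe_on_keys(items: List[Dict], keys: List[str]) -> List[Dict]:
        # Keep the first item for each unique combination of the given keys.
        seen, unique = set(), []
        for item in items:
            fingerprint = tuple(item.get(k) for k in keys)
            if fingerprint not in seen:
                seen.add(fingerprint)
                unique.append(item)
        return unique

    models = [{"name": "llama3", "digest": "abc"},
              {"name": "llama3", "digest": "abc"},
              {"id": "gpt-4o"}]
    print(dedupe_on_keys(models, ["digest", "name", "id"]))  # two unique entries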
@@ -893,7 +911,8 @@ async def config_proxy(request: Request):
     try:
         async with httpx.AsyncClient(timeout=1) as client:
             if "/v1" in url:
-                r = await client.get(f"{url}/models")
+                headers = {"Authorization": "Bearer " + config.api_keys[url]}
+                r = await client.get(f"{url}/models", headers=headers)
             else:
                 r = await client.get(f"{url}/api/version")
             r.raise_for_status()
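The health probe now authenticates against OpenAI-compatible remotes, which reject anonymous /models requests. A standalone sketch of the same probe logic, with a made-up URL and the api_keys mapping passed in explicitly:

    import asyncio, httpx

    async def probe(url: str, api_keys: dict) -> bool:
        # Authenticated GET /models for "/v1" remotes,
        # unauthenticated GET /api/version for Ollama nodes.
        async with httpx.AsyncClient(timeout=1) as client:
            if "/v1" in url:
                headers = {"Authorization": "Bearer " + api_keys[url]}
                r = await client.get(f"{url}/models", headers=headers)
            else:
                r = await client.get(f"{url}/api/version")
            return r.status_code == 200

    # asyncio.run(probe("http://10.0.0.5:11434", {}))  # True when the node is up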
@@ -925,9 +944,6 @@ async def openai_embedding_proxy(request: Request):
     model = payload.get("model")
     input = payload.get("input")

-    headers = request.headers
-    api_key = headers.get("Authorization")
-    api_key = api_key.split()[1]

     if not model:
         raise HTTPException(
@@ -941,12 +957,16 @@ async def openai_embedding_proxy(request: Request):
         raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

     # 2. Endpoint logic
-    endpoint = await choose_endpoint(model, api_key)
+    endpoint = await choose_endpoint(model)
     await increment_usage(endpoint, model)
+    if "/v1" in endpoint:
+        api_key = config.api_keys[endpoint]
+    else:
+        api_key = "ollama"
     oclient = openai.AsyncOpenAI(base_url=endpoint+"/v1", api_key=api_key)

     # 3. Async generator that streams embedding data and decrements the counter
-    async_gen = await oclient.embeddings.create(input = [input], model=model)
+    async_gen = await oclient.embeddings.create(input=[input], model=model)

     await decrement_usage(endpoint, model)
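The proxy no longer trusts the caller's Authorization header for upstream calls; the key is chosen from the router's own config. The "ollama" fallback works because a local Ollama server's OpenAI-compatible API typically ignores the bearer token, while the client library still requires some non-empty string. A hypothetical helper capturing the inline logic:

    def resolve_upstream_key(endpoint: str, api_keys: dict) -> str:
        # "/v1" remotes get their configured secret; local Ollama accepts
        # any placeholder since it never checks the bearer token.
        return api_keys[endpoint] if "/v1" in endpoint else "ollama"

    print(resolve_upstream_key("http://10.0.0.5:11434", {}))  # -> "ollama"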
@@ -981,10 +1001,6 @@ async def openai_chat_completions_proxy(request: Request):
     max_tokens = payload.get("max_tokens")
     max_completion_tokens = payload.get("max_completion_tokens")
     tools = payload.get("tools")

-    headers = request.headers
-    api_key = headers.get("Authorization")
-    api_key = api_key.split()[1]
-
     params = {
         "messages": messages,
@@ -1025,10 +1041,10 @@ async def openai_chat_completions_proxy(request: Request):
         raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

     # 2. Endpoint logic
-    endpoint = await choose_endpoint(model, api_key)
+    endpoint = await choose_endpoint(model)
     await increment_usage(endpoint, model)
     base_url = ep2base(endpoint)
-    oclient = openai.AsyncOpenAI(base_url=base_url, api_key=api_key)
+    oclient = openai.AsyncOpenAI(base_url=base_url, api_key=config.api_keys[endpoint])

     # 3. Async generator that streams completions data and decrements the counter
     async def stream_ochat_response():
@@ -1088,11 +1104,8 @@ async def openai_completions_proxy(request: Request):
     temperature = payload.get("temperature")
     top_p = payload.get("top_p")
     max_tokens = payload.get("max_tokens")
     max_completion_tokens = payload.get("max_completion_tokens")
     suffix = payload.get("suffix")
-    headers = request.headers
-    api_key = headers.get("Authorization")
-    api_key = api_key.split()[1]

     params = {
         "prompt": prompt,
@@ -1123,10 +1136,10 @@ async def openai_completions_proxy(request: Request):
         raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

     # 2. Endpoint logic
-    endpoint = await choose_endpoint(model, api_key)
+    endpoint = await choose_endpoint(model)
     await increment_usage(endpoint, model)
     base_url = ep2base(endpoint)
-    oclient = openai.AsyncOpenAI(base_url=base_url, api_key=api_key)
+    oclient = openai.AsyncOpenAI(base_url=base_url, api_key=config.api_keys[endpoint])

     # 3. Async generator that streams completions data and decrements the counter
     async def stream_ocompletions_response():
@@ -1170,13 +1183,9 @@ async def openai_models_proxy(request: Request):
     Proxy a models request to Ollama endpoints and reply with a unique list of all models.

     """
-    headers = request.headers
-    api_key = headers.get("Authorization")
-    api_key = api_key.split()[1]
-
     # 1. Query all endpoints for models
     tasks = [fetch_endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints if "/v1" not in ep]
-    tasks += [fetch_endpoint_details(ep, "/models", "data", api_key) for ep in config.endpoints if "/v1" in ep]
+    tasks += [fetch_endpoint_details(ep, "/models", "data", config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep]
     all_models = await asyncio.gather(*tasks)

     models = {'data': []}
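Taken together, clients of the router no longer need (or get to use) upstream secrets: whatever Authorization header they send is ignored for upstream calls, and the router injects the key from its own config. A hypothetical client call illustrating the effect; the host, port, and route prefix are assumptions, not taken from the diff:

    import openai

    # Any placeholder satisfies the client library; the router substitutes
    # the real upstream key from config.api_keys before forwarding.
    client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="unused")
    resp = client.chat.completions.create(
        model="llama3",
        messages=[{"role": "user", "content": "ping"}],
    )
    print(resp.choices[0].message.content)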