Compare commits

...
Sign in to create a new pull request.

5 commits

4 changed files with 182 additions and 41 deletions

3
.gitignore vendored
View file

@ -67,3 +67,6 @@ config.yaml
*.db* *.db*
*settings.json *settings.json
# Test suite (local only, not committed yet)
test/

View file

@ -1,6 +1,6 @@
# NOMYO Router # NOMYO Router
is a transparent proxy for [Ollama](https://github.com/ollama/ollama) with model deployment aware routing. is a transparent proxy for inference engines, i.e. [Ollama](https://github.com/ollama/ollama), [llama.cpp](https://github.com/ggml-org/llama.cpp/), [vllm](https://github.com/vllm-project/vllm) or any OpenAI V1 compatible endpoint with model deployment aware routing.
[![Click for video](https://github.com/user-attachments/assets/ddacdf88-e3f3-41dd-8be6-f165b22d9879)](https://eu1.nomyo.ai/assets/dash.mp4) [![Click for video](https://github.com/user-attachments/assets/ddacdf88-e3f3-41dd-8be6-f165b22d9879)](https://eu1.nomyo.ai/assets/dash.mp4)

View file

@ -9,8 +9,23 @@ llama_server_endpoints:
- http://192.168.0.50:8889/v1 - http://192.168.0.50:8889/v1
# Maximum concurrent connections *per endpoint-model pair* (equals to OLLAMA_NUM_PARALLEL) # Maximum concurrent connections *per endpoint-model pair* (equals to OLLAMA_NUM_PARALLEL)
# This is the global default; individual endpoints can override it via endpoint_config below.
max_concurrent_connections: 2 max_concurrent_connections: 2
# Per-endpoint overrides (optional). Any field not listed falls back to the global default.
# endpoint_config:
# "http://192.168.0.50:11434":
# max_concurrent_connections: 3
# "http://192.168.0.51:11434":
# max_concurrent_connections: 1
# Priority / WRR routing (optional, default: false).
# When true, requests are routed by utilization ratio (usage/max_concurrent_connections)
# and the config order of endpoints acts as the tiebreaker — the first endpoint listed
# is preferred when two endpoints are equally loaded.
# When false (default), equally-idle endpoints are chosen at random.
# priority_routing: true
# Optional router-level API key that gates router/API/web UI access (leave empty to disable) # Optional router-level API key that gates router/API/web UI access (leave empty to disable)
nomyo-router-api-key: "" nomyo-router-api-key: ""

167
router.py
View file

@ -6,7 +6,7 @@ version: 0.7
license: AGPL license: AGPL
""" """
# ------------------------------------------------------------- # -------------------------------------------------------------
import orjson, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, random, base64, io, enhance, secrets, math, socket import orjson, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, random, base64, io, enhance, secrets, math, socket, httpx
try: try:
import truststore; truststore.inject_into_ssl() import truststore; truststore.inject_into_ssl()
except ImportError: except ImportError:
@ -185,6 +185,8 @@ _CTX_TRIM_SMALL_LIMIT = 32768 # only proactively trim models with n_ctx at or b
app_state = { app_state = {
"session": None, "session": None,
"connector": None, "connector": None,
"socket_sessions": {}, # endpoint -> aiohttp.ClientSession(UnixConnector) for .sock endpoints
"httpx_clients": {}, # endpoint -> httpx.AsyncClient(UDS transport) for .sock endpoints
} }
token_worker_task: asyncio.Task | None = None token_worker_task: asyncio.Task | None = None
flush_task: asyncio.Task | None = None flush_task: asyncio.Task | None = None
@ -216,6 +218,10 @@ class Config(BaseSettings):
llama_server_endpoints: List[str] = Field(default_factory=list) llama_server_endpoints: List[str] = Field(default_factory=list)
# Max concurrent connections per endpoint-model pair, see OLLAMA_NUM_PARALLEL # Max concurrent connections per endpoint-model pair, see OLLAMA_NUM_PARALLEL
max_concurrent_connections: int = 1 max_concurrent_connections: int = 1
# Per-endpoint overrides: {endpoint_url: {max_concurrent_connections: N}}
endpoint_config: Dict[str, Dict] = Field(default_factory=dict)
# When True, config order = priority; routes by utilization ratio + config index (WRR)
priority_routing: bool = Field(default=False)
api_keys: Dict[str, str] = Field(default_factory=dict) api_keys: Dict[str, str] = Field(default_factory=dict)
# Optional router-level API key used to gate access to this service and dashboard # Optional router-level API key used to gate access to this service and dashboard
@ -494,6 +500,65 @@ def _extract_llama_quant(name: str) -> str:
return name.rsplit(":", 1)[1] return name.rsplit(":", 1)[1]
return "" return ""
def _is_unix_socket_endpoint(endpoint: str) -> bool:
"""Return True if endpoint uses Unix socket (.sock hostname convention).
Detects URLs like http://192.168.0.52.sock/v1 where the host ends with
.sock, indicating the connection should use a Unix domain socket at
/tmp/<host> instead of TCP.
"""
try:
host = endpoint.split("//", 1)[1].split("/")[0].split(":")[0]
return host.endswith(".sock")
except IndexError:
return False
def _get_socket_path(endpoint: str) -> str:
"""Derive Unix socket file path from a .sock endpoint URL.
http://192.168.0.52.sock/v1 -> /run/user/<uid>/192.168.0.52.sock
"""
host = endpoint.split("//", 1)[1].split("/")[0].split(":")[0]
return f"/run/user/{os.getuid()}/{host}"
def get_session(endpoint: str) -> aiohttp.ClientSession:
    """Return the appropriate aiohttp session for the given endpoint.

    Unix socket endpoints (.sock) get their own UnixConnector session;
    everything else shares the main TCP session.
    """
    if not _is_unix_socket_endpoint(endpoint):
        return app_state["session"]
    dedicated = app_state["socket_sessions"].get(endpoint)
    if dedicated is None:
        # No UDS session was registered for this .sock endpoint; fall back
        # to the shared TCP session rather than failing the request.
        return app_state["session"]
    return dedicated
def _make_openai_client(
    endpoint: str,
    default_headers: dict | None = None,
    api_key: str = "no-key",
) -> openai.AsyncOpenAI:
    """Return an AsyncOpenAI client configured for the given endpoint.

    For Unix socket endpoints, injects a pre-created httpx UDS transport
    so the OpenAI SDK connects via the socket instead of TCP.
    """
    client_kwargs: dict = {"api_key": api_key, "base_url": ep2base(endpoint)}
    if default_headers is not None:
        client_kwargs["default_headers"] = default_headers
    if _is_unix_socket_endpoint(endpoint):
        uds_client = app_state["httpx_clients"].get(endpoint)
        if uds_client is not None:
            # A UDS transport routes by socket path, but httpx still needs a
            # syntactically valid URL — use a localhost placeholder base.
            client_kwargs["http_client"] = uds_client
            client_kwargs["base_url"] = "http://localhost/v1"
    return openai.AsyncOpenAI(**client_kwargs)
def _is_llama_model_loaded(item: dict) -> bool: def _is_llama_model_loaded(item: dict) -> bool:
"""Return True if a llama-server /v1/models item has status 'loaded'. """Return True if a llama-server /v1/models item has status 'loaded'.
Handles both dict format ({"value": "loaded"}) and plain string ("loaded"). Handles both dict format ({"value": "loaded"}) and plain string ("loaded").
@ -733,7 +798,7 @@ class fetch:
endpoint_url = f"{endpoint}/api/tags" endpoint_url = f"{endpoint}/api/tags"
key = "models" key = "models"
client: aiohttp.ClientSession = app_state["session"] client: aiohttp.ClientSession = get_session(endpoint)
try: try:
async with client.get(endpoint_url, headers=headers) as resp: async with client.get(endpoint_url, headers=headers) as resp:
await _ensure_success(resp) await _ensure_success(resp)
@ -854,7 +919,7 @@ class fetch:
For Ollama endpoints: queries /api/ps and returns model names For Ollama endpoints: queries /api/ps and returns model names
For llama-server endpoints: queries /v1/models and filters for status.value == "loaded" For llama-server endpoints: queries /v1/models and filters for status.value == "loaded"
""" """
client: aiohttp.ClientSession = app_state["session"] client: aiohttp.ClientSession = get_session(endpoint)
# Check if this is a llama-server endpoint # Check if this is a llama-server endpoint
if endpoint in config.llama_server_endpoints: if endpoint in config.llama_server_endpoints:
@ -997,7 +1062,7 @@ class fetch:
if _is_fresh(_available_error_cache[endpoint], 300): if _is_fresh(_available_error_cache[endpoint], 300):
return [] return []
client: aiohttp.ClientSession = app_state["session"] client: aiohttp.ClientSession = get_session(endpoint)
headers = None headers = None
if api_key is not None: if api_key is not None:
headers = {"Authorization": "Bearer " + api_key} headers = {"Authorization": "Bearer " + api_key}
@ -1089,7 +1154,7 @@ async def _make_chat_request(model: str, messages: list, tools=None, stream: boo
"response_format": {"type": "json_schema", "json_schema": format} if format is not None else None "response_format": {"type": "json_schema", "json_schema": format} if format is not None else None
} }
params.update({k: v for k, v in optional_params.items() if v is not None}) params.update({k: v for k, v in optional_params.items() if v is not None})
oclient = openai.AsyncOpenAI(base_url=ep2base(endpoint), default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
else: else:
client = ollama.AsyncClient(host=endpoint) client = ollama.AsyncClient(host=endpoint)
@ -1636,6 +1701,12 @@ async def get_usage_counts() -> Dict:
# ------------------------------------------------------------- # -------------------------------------------------------------
# 5. Endpoint selection logic (respecting the configurable limit) # 5. Endpoint selection logic (respecting the configurable limit)
# ------------------------------------------------------------- # -------------------------------------------------------------
def get_max_connections(ep: str) -> int:
    """Per-endpoint max_concurrent_connections, falling back to the global value."""
    # Per-endpoint overrides live in config.endpoint_config[endpoint]; any
    # field missing there falls back to the router-wide default.
    per_endpoint: dict = config.endpoint_config.get(ep, {})
    return per_endpoint.get(
        "max_concurrent_connections",
        config.max_concurrent_connections,
    )
async def choose_endpoint(model: str, reserve: bool = True) -> tuple[str, str]: async def choose_endpoint(model: str, reserve: bool = True) -> tuple[str, str]:
""" """
Determine which endpoint to use for the given model while respecting Determine which endpoint to use for the given model while respecting
@ -1706,13 +1777,26 @@ async def choose_endpoint(model: str, reserve: bool = True) -> tuple[str, str]:
def tracking_usage(ep: str) -> int: def tracking_usage(ep: str) -> int:
return usage_counts.get(ep, {}).get(get_tracking_model(ep, model), 0) return usage_counts.get(ep, {}).get(get_tracking_model(ep, model), 0)
def utilization_ratio(ep: str) -> float:
return tracking_usage(ep) / get_max_connections(ep)
# Priority map: position in all_endpoints list (lower = higher priority)
ep_priority = {ep: i for i, ep in enumerate(all_endpoints)}
# 3⃣ Endpoints that have the model loaded *and* a free slot # 3⃣ Endpoints that have the model loaded *and* a free slot
loaded_and_free = [ loaded_and_free = [
ep for ep, models in zip(candidate_endpoints, loaded_sets) ep for ep, models in zip(candidate_endpoints, loaded_sets)
if model in models and tracking_usage(ep) < config.max_concurrent_connections if model in models and tracking_usage(ep) < get_max_connections(ep)
] ]
if loaded_and_free: if loaded_and_free:
if config.priority_routing:
# WRR: sort by config order first (stable), then by utilization ratio.
# Stable sort preserves priority for equal-ratio endpoints.
loaded_and_free.sort(key=lambda ep: ep_priority.get(ep, 999))
loaded_and_free.sort(key=utilization_ratio)
selected = loaded_and_free[0]
else:
# Sort ascending for load balancing — all endpoints here already have the # Sort ascending for load balancing — all endpoints here already have the
# model loaded, so there is no model-switching cost to optimise for. # model loaded, so there is no model-switching cost to optimise for.
loaded_and_free.sort(key=tracking_usage) loaded_and_free.sort(key=tracking_usage)
@ -1726,10 +1810,15 @@ async def choose_endpoint(model: str, reserve: bool = True) -> tuple[str, str]:
# 4⃣ Endpoints among the candidates that simply have a free slot # 4⃣ Endpoints among the candidates that simply have a free slot
endpoints_with_free_slot = [ endpoints_with_free_slot = [
ep for ep in candidate_endpoints ep for ep in candidate_endpoints
if tracking_usage(ep) < config.max_concurrent_connections if tracking_usage(ep) < get_max_connections(ep)
] ]
if endpoints_with_free_slot: if endpoints_with_free_slot:
if config.priority_routing:
endpoints_with_free_slot.sort(key=lambda ep: ep_priority.get(ep, 999))
endpoints_with_free_slot.sort(key=utilization_ratio)
selected = endpoints_with_free_slot[0]
else:
# Sort by total endpoint load (ascending) to prefer idle endpoints. # Sort by total endpoint load (ascending) to prefer idle endpoints.
endpoints_with_free_slot.sort( endpoints_with_free_slot.sort(
key=lambda ep: sum(usage_counts.get(ep, {}).values()) key=lambda ep: sum(usage_counts.get(ep, {}).values())
@ -1740,6 +1829,12 @@ async def choose_endpoint(model: str, reserve: bool = True) -> tuple[str, str]:
selected = endpoints_with_free_slot[0] selected = endpoints_with_free_slot[0]
else: else:
# 5⃣ All candidate endpoints are saturated — pick the least-busy one (will queue) # 5⃣ All candidate endpoints are saturated — pick the least-busy one (will queue)
if config.priority_routing:
selected = min(
candidate_endpoints,
key=lambda ep: (utilization_ratio(ep), ep_priority.get(ep, 999)),
)
else:
selected = min(candidate_endpoints, key=tracking_usage) selected = min(candidate_endpoints, key=tracking_usage)
tracking_model = get_tracking_model(selected, model) tracking_model = get_tracking_model(selected, model)
@ -1822,7 +1917,7 @@ async def proxy(request: Request):
"suffix": suffix, "suffix": suffix,
} }
params.update({k: v for k, v in optional_params.items() if v is not None}) params.update({k: v for k, v in optional_params.items() if v is not None})
oclient = openai.AsyncOpenAI(base_url=ep2base(endpoint), default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
else: else:
client = ollama.AsyncClient(host=endpoint) client = ollama.AsyncClient(host=endpoint)
@ -2000,7 +2095,7 @@ async def chat_proxy(request: Request):
"response_format": {"type": "json_schema", "json_schema": _format} if _format is not None else None "response_format": {"type": "json_schema", "json_schema": _format} if _format is not None else None
} }
params.update({k: v for k, v in optional_params.items() if v is not None}) params.update({k: v for k, v in optional_params.items() if v is not None})
oclient = openai.AsyncOpenAI(base_url=ep2base(endpoint), default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
else: else:
client = ollama.AsyncClient(host=endpoint) client = ollama.AsyncClient(host=endpoint)
# For OpenAI endpoints: make the API call in handler scope # For OpenAI endpoints: make the API call in handler scope
@ -2220,7 +2315,7 @@ async def embedding_proxy(request: Request):
if ":latest" in model: if ":latest" in model:
model = model.split(":latest") model = model.split(":latest")
model = model[0] model = model[0]
client = openai.AsyncOpenAI(base_url=ep2base(endpoint), api_key=config.api_keys.get(endpoint, "no-key")) client = _make_openai_client(endpoint, api_key=config.api_keys.get(endpoint, "no-key"))
else: else:
client = ollama.AsyncClient(host=endpoint) client = ollama.AsyncClient(host=endpoint)
# 3. Async generator that streams embedding data and decrements the counter # 3. Async generator that streams embedding data and decrements the counter
@ -2285,7 +2380,7 @@ async def embed_proxy(request: Request):
if ":latest" in model: if ":latest" in model:
model = model.split(":latest") model = model.split(":latest")
model = model[0] model = model[0]
client = openai.AsyncOpenAI(base_url=ep2base(endpoint), api_key=config.api_keys.get(endpoint, "no-key")) client = _make_openai_client(endpoint, api_key=config.api_keys.get(endpoint, "no-key"))
else: else:
client = ollama.AsyncClient(host=endpoint) client = ollama.AsyncClient(host=endpoint)
# 3. Async generator that streams embed data and decrements the counter # 3. Async generator that streams embed data and decrements the counter
@ -2834,7 +2929,7 @@ async def ps_details_proxy(request: Request):
# Fetch /props for each llama-server model to get context length (n_ctx) # Fetch /props for each llama-server model to get context length (n_ctx)
# and unload sleeping models automatically # and unload sleeping models automatically
async def _fetch_llama_props(endpoint: str, model_id: str) -> tuple[int | None, bool, bool]: async def _fetch_llama_props(endpoint: str, model_id: str) -> tuple[int | None, bool, bool]:
client: aiohttp.ClientSession = app_state["session"] client: aiohttp.ClientSession = get_session(endpoint)
base_url = endpoint.rstrip("/").removesuffix("/v1") base_url = endpoint.rstrip("/").removesuffix("/v1")
props_url = f"{base_url}/props?model={model_id}" props_url = f"{base_url}/props?model={model_id}"
headers = None headers = None
@ -2907,7 +3002,7 @@ async def config_proxy(request: Request):
which endpoints are being proxied. which endpoints are being proxied.
""" """
async def check_endpoint(url: str): async def check_endpoint(url: str):
client: aiohttp.ClientSession = app_state["session"] client: aiohttp.ClientSession = get_session(url)
headers = None headers = None
if "/v1" in url: if "/v1" in url:
headers = {"Authorization": "Bearer " + config.api_keys.get(url, "no-key")} headers = {"Authorization": "Bearer " + config.api_keys.get(url, "no-key")}
@ -2990,9 +3085,7 @@ async def openai_embedding_proxy(request: Request):
api_key = config.api_keys.get(endpoint, "no-key") api_key = config.api_keys.get(endpoint, "no-key")
else: else:
api_key = "ollama" api_key = "ollama"
base_url = ep2base(endpoint) oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=api_key)
oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=api_key)
try: try:
async_gen = await oclient.embeddings.create(input=doc, model=model) async_gen = await oclient.embeddings.create(input=doc, model=model)
@ -3105,8 +3198,7 @@ async def openai_chat_completions_proxy(request: Request):
# 2. Endpoint logic # 2. Endpoint logic
endpoint, tracking_model = await choose_endpoint(model) endpoint, tracking_model = await choose_endpoint(model)
base_url = ep2base(endpoint) oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
# 3. Helpers and API call — done in handler scope so try/except works reliably # 3. Helpers and API call — done in handler scope so try/except works reliably
async def _normalize_images_in_messages(msgs: list) -> list: async def _normalize_images_in_messages(msgs: list) -> list:
"""Fetch remote image URLs and convert them to base64 data URLs so """Fetch remote image URLs and convert them to base64 data URLs so
@ -3416,8 +3508,7 @@ async def openai_completions_proxy(request: Request):
# 2. Endpoint logic # 2. Endpoint logic
endpoint, tracking_model = await choose_endpoint(model) endpoint, tracking_model = await choose_endpoint(model)
base_url = ep2base(endpoint) oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
# 3. Async generator that streams completions data and decrements the counter # 3. Async generator that streams completions data and decrements the counter
# Make the API call in handler scope (try/except inside async generators is unreliable) # Make the API call in handler scope (try/except inside async generators is unreliable)
@ -3672,7 +3763,7 @@ async def rerank_proxy(request: Request):
"Authorization": f"Bearer {api_key}", "Authorization": f"Bearer {api_key}",
} }
client: aiohttp.ClientSession = app_state["session"] client: aiohttp.ClientSession = get_session(endpoint)
try: try:
async with client.post(rerank_url, json=upstream_payload, headers=headers) as resp: async with client.post(rerank_url, json=upstream_payload, headers=headers) as resp:
response_bytes = await resp.read() response_bytes = await resp.read()
@ -3845,7 +3936,9 @@ async def startup_event() -> None:
f"Loaded configuration from {config_path}:\n" f"Loaded configuration from {config_path}:\n"
f" endpoints={config.endpoints},\n" f" endpoints={config.endpoints},\n"
f" llama_server_endpoints={config.llama_server_endpoints},\n" f" llama_server_endpoints={config.llama_server_endpoints},\n"
f" max_concurrent_connections={config.max_concurrent_connections}" f" max_concurrent_connections={config.max_concurrent_connections},\n"
f" endpoint_config={config.endpoint_config},\n"
f" priority_routing={config.priority_routing}"
) )
else: else:
print( print(
@ -3874,6 +3967,19 @@ async def startup_event() -> None:
app_state["connector"] = connector app_state["connector"] = connector
app_state["session"] = session app_state["session"] = session
# Create per-endpoint Unix socket sessions for .sock endpoints
for ep in config.llama_server_endpoints:
if _is_unix_socket_endpoint(ep):
sock_path = _get_socket_path(ep)
sock_connector = aiohttp.UnixConnector(path=sock_path)
sock_timeout = aiohttp.ClientTimeout(total=300, connect=5, sock_read=300)
sock_session = aiohttp.ClientSession(connector=sock_connector, timeout=sock_timeout)
app_state["socket_sessions"][ep] = sock_session
transport = httpx.AsyncHTTPTransport(uds=sock_path)
app_state["httpx_clients"][ep] = httpx.AsyncClient(transport=transport, timeout=300.0)
print(f"[startup] Unix socket session: {ep} -> {sock_path}")
token_worker_task = asyncio.create_task(token_worker()) token_worker_task = asyncio.create_task(token_worker())
flush_task = asyncio.create_task(flush_buffer()) flush_task = asyncio.create_task(flush_buffer())
await init_llm_cache(config) await init_llm_cache(config)
@ -3883,6 +3989,23 @@ async def shutdown_event() -> None:
await close_all_sse_queues() await close_all_sse_queues()
await flush_remaining_buffers() await flush_remaining_buffers()
await app_state["session"].close() await app_state["session"].close()
# Close Unix socket sessions
for ep, sess in list(app_state.get("socket_sessions", {}).items()):
try:
await sess.close()
print(f"[shutdown] Closed Unix socket session: {ep}")
except Exception as e:
print(f"[shutdown] Error closing Unix socket session {ep}: {e}")
# Close httpx Unix socket clients
for ep, client in list(app_state.get("httpx_clients", {}).items()):
try:
await client.aclose()
print(f"[shutdown] Closed httpx client: {ep}")
except Exception as e:
print(f"[shutdown] Error closing httpx client {ep}: {e}")
if token_worker_task is not None: if token_worker_task is not None:
token_worker_task.cancel() token_worker_task.cancel()
if flush_task is not None: if flush_task is not None: