diff --git a/README.md b/README.md
index abbf181..d7e08c1 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,7 @@
is a transparent proxy for [Ollama](https://github.com/ollama/ollama) with model deployment aware routing.
-
+[](https://eu1.nomyo.ai/assets/dash.mp4)
It runs between your frontend application and Ollama backend and is transparent for both, the front- and backend.
@@ -78,6 +78,7 @@ docker run -d \
```
Notes:
+
- `-e CONFIG_PATH` sets the `NOMYO_ROUTER_CONFIG_PATH` environment variable under the hood; you can export it directly instead if you prefer.
- To override the bind address or port, export `UVICORN_HOST` or `UVICORN_PORT`, or pass the corresponding uvicorn flags after `--`, e.g. `nomyo-router --config-path /config/config.yaml -- --port 9000`.
- Use `docker logs nomyo-router` to confirm the loaded endpoints and concurrency settings at startup.
diff --git a/db.py b/db.py
index 24c8480..9f4efd3 100644
--- a/db.py
+++ b/db.py
@@ -50,6 +50,7 @@ class TokenDatabase:
PRIMARY KEY(endpoint, model)
)
''')
+ await db.execute('CREATE INDEX IF NOT EXISTS idx_token_time_series_timestamp ON token_time_series(timestamp)')
await db.execute('''
CREATE TABLE IF NOT EXISTS token_time_series (
id INTEGER PRIMARY KEY AUTOINCREMENT,
diff --git a/enhance.py b/enhance.py
index 9be1fef..f3200c7 100644
--- a/enhance.py
+++ b/enhance.py
@@ -1,6 +1,6 @@
from pydantic import BaseModel
-class feedback(BaseModel):
+class Feedback(BaseModel):
query_id: int
content: str
@@ -25,13 +25,13 @@ def moe(query: str, query_id: int, response: str) -> str:
def moe_select_candidate(query: str, candidates: list[str]) -> str:
if not candidates:
- raise ValueError("No candidates supplied")
+ raise ValueError("No candidates supplied")
candidate_sections = ""
- for i, cand in enumerate(candidates[:3], start=0):
+ for i, cand in enumerate(candidates[:3], start=1):
candidate_sections += f"""
- {cand.message.content}
+ {cand}
"""
@@ -45,5 +45,4 @@ def moe_select_candidate(query: str, candidates: list[str]) -> str:
**Do NOT** mention candidate numbers, strengths, weaknesses, or any other commentary.
Just give the final answer—nothing else.
"""
- return select_prompt.strip()
-
+ return select_prompt.strip()
\ No newline at end of file
diff --git a/entrypoint.sh b/entrypoint.sh
index 6682851..e9ccbc3 100755
--- a/entrypoint.sh
+++ b/entrypoint.sh
@@ -50,7 +50,7 @@ Options:
Any arguments that remain after the options above are passed directly to uvicorn.
Environment variables:
- CONFIG_PATH Alternative way to specify the config path.
+ CONFIG_PATH_ARG Alternative way to specify the config path.
NOMYO_ROUTER_CONFIG_PATH Overrides the config path (same as --config-path).
UVICORN_HOST Host interface to bind to (default: 0.0.0.0).
UVICORN_PORT Port to listen on (default: 12434).
diff --git a/requirements.txt b/requirements.txt
index c88fd0d..01e704b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -21,7 +21,7 @@ jiter==0.10.0
multidict==6.6.4
ollama==0.6.0
openai==1.102.0
-orjson==3.11.4
+orjson>=3.11.5
pillow==11.3.0
propcache==0.3.2
pydantic==2.11.7
diff --git a/router.py b/router.py
index 8ba1d50..908b8c9 100644
--- a/router.py
+++ b/router.py
@@ -2,7 +2,7 @@
title: NOMYO Router - an Ollama Proxy with Endpoint:Model aware routing
author: alpha-nerd-nomyo
author_url: https://github.com/nomyo-ai
-version: 0.5
+version: 0.6
license: AGPL
"""
# -------------------------------------------------------------
@@ -34,6 +34,20 @@ _loaded_models_cache: dict[str, tuple[Set[str], float]] = {}
# timeout expires, after which the endpoint will be queried again.
_error_cache: dict[str, float] = {}
+# ------------------------------------------------------------------
+# Cache locks
+# ------------------------------------------------------------------
+_models_cache_lock = asyncio.Lock()
+_loaded_models_cache_lock = asyncio.Lock()
+_error_cache_lock = asyncio.Lock()
+
+# ------------------------------------------------------------------
+# In-flight request tracking (prevents cache stampede)
+# ------------------------------------------------------------------
+_inflight_available_models: dict[str, asyncio.Task] = {}
+_inflight_loaded_models: dict[str, asyncio.Task] = {}
+_inflight_lock = asyncio.Lock()
+
# ------------------------------------------------------------------
# Queues
# ------------------------------------------------------------------
@@ -339,6 +353,10 @@ async def token_worker() -> None:
try:
while True:
endpoint, model, prompt, comp = await token_queue.get()
+ # Calculate timestamp once before acquiring lock
+ now = datetime.now(tz=timezone.utc)
+ timestamp = int(datetime(now.year, now.month, now.day, now.hour, now.minute, tzinfo=timezone.utc).timestamp())
+
# Accumulate counts in memory buffer (protected by lock)
async with buffer_lock:
token_buffer[endpoint][model] = (
@@ -347,8 +365,6 @@ async def token_worker() -> None:
)
# Add to time series buffer with timestamp (UTC)
- now = datetime.now(tz=timezone.utc)
- timestamp = int(datetime(now.year, now.month, now.day, now.hour, now.minute, tzinfo=timezone.utc).timestamp())
time_series_buffer.append({
'endpoint': endpoint,
'model': model,
@@ -369,13 +385,15 @@ async def token_worker() -> None:
while not token_queue.empty():
try:
endpoint, model, prompt, comp = token_queue.get_nowait()
+ # Calculate timestamp once before acquiring lock
+ now = datetime.now(tz=timezone.utc)
+ timestamp = int(datetime(now.year, now.month, now.day, now.hour, now.minute, tzinfo=timezone.utc).timestamp())
+
async with buffer_lock:
token_buffer[endpoint][model] = (
token_buffer[endpoint].get(model, (0, 0))[0] + prompt,
token_buffer[endpoint].get(model, (0, 0))[1] + comp
)
- now = datetime.now(tz=timezone.utc)
- timestamp = int(datetime(now.year, now.month, now.day, now.hour, now.minute, tzinfo=timezone.utc).timestamp())
time_series_buffer.append({
'endpoint': endpoint,
'model': model,
@@ -477,42 +495,22 @@ async def flush_remaining_buffers() -> None:
print(f"[shutdown] Error flushing remaining buffers: {e}")
class fetch:
- async def available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
+ async def _fetch_available_models_internal(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
"""
- Query /api/tags and return a set of all model names that the
- endpoint *advertises* (i.e. is capable of serving). This endpoint lists
- every model that is installed on the Ollama instance, regardless of
- whether the model is currently loaded into memory.
-
- If the request fails (e.g. timeout, 5xx, or malformed response), an empty
- set is returned.
+ Internal function that performs the actual HTTP request to fetch available models.
+ This is called by available_models() after checking caches and in-flight requests.
"""
headers = None
if api_key is not None:
headers = {"Authorization": "Bearer " + api_key}
- if endpoint in _models_cache:
- models, cached_at = _models_cache[endpoint]
- if _is_fresh(cached_at, 300):
- return models
- else:
- # stale entry – drop it
- del _models_cache[endpoint]
-
- if endpoint in _error_cache:
- if _is_fresh(_error_cache[endpoint], 10):
- # Still within the short error TTL – pretend nothing is available
- return set()
- else:
- # Error expired – remove it
- del _error_cache[endpoint]
-
if "/v1" in endpoint:
endpoint_url = f"{endpoint}/models"
key = "data"
else:
endpoint_url = f"{endpoint}/api/tags"
key = "models"
+
client: aiohttp.ClientSession = app_state["session"]
try:
async with client.get(endpoint_url, headers=headers) as resp:
@@ -521,43 +519,77 @@ class fetch:
items = data.get(key, [])
models = {item.get("id") or item.get("name") for item in items if item.get("id") or item.get("name")}
-
- if models:
+
+ # Update cache with lock protection
+ async with _models_cache_lock:
_models_cache[endpoint] = (models, time.time())
- return models
- else:
- # Empty list – treat as “no models”, but still cache for 300s
- _models_cache[endpoint] = (models, time.time())
- return models
+ return models
except Exception as e:
# Treat any error as if the endpoint offers no models
message = _format_connection_issue(endpoint_url, e)
print(f"[fetch.available_models] {message}")
- _error_cache[endpoint] = time.time()
+ # Update error cache with lock protection
+ async with _error_cache_lock:
+ _error_cache[endpoint] = time.time()
return set()
-
- async def loaded_models(endpoint: str) -> Set[str]:
+ async def available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
"""
- Query /api/ps and return a set of model names that are currently
- loaded on that endpoint. If the request fails (e.g. timeout, 5xx), an empty
+ Query /api/tags and return a set of all model names that the
+ endpoint *advertises* (i.e. is capable of serving). This endpoint lists
+ every model that is installed on the Ollama instance, regardless of
+ whether the model is currently loaded into memory.
+
+ Uses request coalescing to prevent cache stampede: if multiple requests
+ arrive when cache is expired, only one actual HTTP request is made.
+
+ If the request fails (e.g. timeout, 5xx, or malformed response), an empty
set is returned.
"""
- if is_ext_openai_endpoint(endpoint):
- return set()
- if endpoint in _loaded_models_cache:
- models, cached_at = _loaded_models_cache[endpoint]
- if _is_fresh(cached_at, 30):
- return models
- else:
- # stale entry – drop it
- del _loaded_models_cache[endpoint]
+ # Check models cache with lock protection
+ async with _models_cache_lock:
+ if endpoint in _models_cache:
+ models, cached_at = _models_cache[endpoint]
+ if _is_fresh(cached_at, 300):
+ return models
+ # Stale entry - remove it
+ del _models_cache[endpoint]
- if endpoint in _error_cache:
- if _is_fresh(_error_cache[endpoint], 10):
- return set()
- else:
+ # Check error cache with lock protection
+ async with _error_cache_lock:
+ if endpoint in _error_cache:
+ if _is_fresh(_error_cache[endpoint], 10):
+ # Still within the short error TTL – pretend nothing is available
+ return set()
+ # Error expired – remove it
del _error_cache[endpoint]
+
+ # Request coalescing: check if another request is already fetching this endpoint
+ async with _inflight_lock:
+ if endpoint in _inflight_available_models:
+ # Another request is already fetching - wait for it
+ task = _inflight_available_models[endpoint]
+ else:
+ # Create new fetch task
+ task = asyncio.create_task(fetch._fetch_available_models_internal(endpoint, api_key))
+ _inflight_available_models[endpoint] = task
+
+ try:
+ # Wait for the fetch to complete (either ours or another request's)
+ result = await task
+ return result
+ finally:
+ # Clean up in-flight tracking (only if we created it)
+ async with _inflight_lock:
+ if _inflight_available_models.get(endpoint) == task:
+ _inflight_available_models.pop(endpoint, None)
+
+
+ async def _fetch_loaded_models_internal(endpoint: str) -> Set[str]:
+ """
+ Internal function that performs the actual HTTP request to fetch loaded models.
+ This is called by loaded_models() after checking caches and in-flight requests.
+ """
client: aiohttp.ClientSession = app_state["session"]
try:
async with client.get(f"{endpoint}/api/ps") as resp:
@@ -566,7 +598,10 @@ class fetch:
# The response format is:
# {"models": [{"name": "model1"}, {"name": "model2"}]}
models = {m.get("name") for m in data.get("models", []) if m.get("name")}
- _loaded_models_cache[endpoint] = (models, time.time())
+
+ # Update cache with lock protection
+ async with _loaded_models_cache_lock:
+ _loaded_models_cache[endpoint] = (models, time.time())
return models
except Exception as e:
# If anything goes wrong we simply assume the endpoint has no models
@@ -574,6 +609,75 @@ class fetch:
print(f"[fetch.loaded_models] {message}")
return set()
+ async def _refresh_loaded_models(endpoint: str) -> None:
+ """
+ Background task to refresh loaded models cache without blocking the caller.
+ Used for stale-while-revalidate pattern.
+ """
+ try:
+ await fetch._fetch_loaded_models_internal(endpoint)
+ except Exception as e:
+ # Silently fail - cache will remain stale but functional
+ print(f"[fetch._refresh_loaded_models] Background refresh failed for {endpoint}: {e}")
+
+ async def loaded_models(endpoint: str) -> Set[str]:
+ """
+ Query /api/ps and return a set of model names that are currently
+ loaded on that endpoint. If the request fails (e.g. timeout, 5xx), an empty
+ set is returned.
+
+ Uses request coalescing to prevent cache stampede and stale-while-revalidate
+ to serve requests immediately even when cache is stale (refreshing in background).
+ """
+ if is_ext_openai_endpoint(endpoint):
+ return set()
+
+ # Check loaded models cache with lock protection
+ async with _loaded_models_cache_lock:
+ if endpoint in _loaded_models_cache:
+ models, cached_at = _loaded_models_cache[endpoint]
+
+ # FRESH: < 10s old - return immediately
+ if _is_fresh(cached_at, 10):
+ return models
+
+ # STALE: 10-60s old - return stale data and refresh in background
+ if _is_fresh(cached_at, 60):
+ # Kick off background refresh (fire-and-forget)
+ asyncio.create_task(fetch._refresh_loaded_models(endpoint))
+ return models # Return stale data immediately
+
+ # EXPIRED: > 60s old - too stale, must refresh synchronously
+ del _loaded_models_cache[endpoint]
+
+ # Check error cache with lock protection
+ async with _error_cache_lock:
+ if endpoint in _error_cache:
+ if _is_fresh(_error_cache[endpoint], 10):
+ return set()
+ # Error expired - remove it
+ del _error_cache[endpoint]
+
+ # Request coalescing: check if another request is already fetching this endpoint
+ async with _inflight_lock:
+ if endpoint in _inflight_loaded_models:
+ # Another request is already fetching - wait for it
+ task = _inflight_loaded_models[endpoint]
+ else:
+ # Create new fetch task
+ task = asyncio.create_task(fetch._fetch_loaded_models_internal(endpoint))
+ _inflight_loaded_models[endpoint] = task
+
+ try:
+ # Wait for the fetch to complete (either ours or another request's)
+ result = await task
+ return result
+ finally:
+ # Clean up in-flight tracking (only if we created it)
+ async with _inflight_lock:
+ if _inflight_loaded_models.get(endpoint) == task:
+ _inflight_loaded_models.pop(endpoint, None)
+
async def endpoint_details(endpoint: str, route: str, detail: str, api_key: Optional[str] = None) -> List[dict]:
"""
Query / to fetch and return a List of dicts with details
@@ -621,7 +725,7 @@ def dedupe_on_keys(dicts, key_fields):
async def increment_usage(endpoint: str, model: str) -> None:
async with usage_lock:
usage_counts[endpoint][model] += 1
- await publish_snapshot()
+ await publish_snapshot()
async def decrement_usage(endpoint: str, model: str) -> None:
async with usage_lock:
@@ -634,7 +738,7 @@ async def decrement_usage(endpoint: str, model: str) -> None:
usage_counts[endpoint].pop(model, None)
#if not usage_counts[endpoint]:
# usage_counts.pop(endpoint, None)
- await publish_snapshot()
+ await publish_snapshot()
async def _make_chat_request(endpoint: str, model: str, messages: list, tools=None, stream: bool = False, think: bool = False, format=None, options=None, keep_alive: str = None) -> ollama.ChatResponse:
"""
@@ -957,10 +1061,14 @@ class rechunk:
# SSE Helpser
# ------------------------------------------------------------------
async def publish_snapshot():
- async with usage_lock:
- snapshot = orjson.dumps({"usage_counts": usage_counts,
- "token_usage_counts": token_usage_counts,
- }, option=orjson.OPT_SORT_KEYS).decode("utf-8")
+ # NOTE: This function assumes usage_lock OR token_usage_lock is already held by the caller
+ # Create a snapshot without acquiring the lock (caller must hold it)
+ snapshot = orjson.dumps({
+ "usage_counts": dict(usage_counts), # Create a copy
+ "token_usage_counts": dict(token_usage_counts)
+ }, option=orjson.OPT_SORT_KEYS).decode("utf-8")
+
+ # Distribute the snapshot (no lock needed here since we have a copy)
async with _subscribers_lock:
for q in _subscribers:
# If the queue is full, drop the message to avoid back‑pressure.
@@ -1056,21 +1164,33 @@ async def choose_endpoint(model: str) -> str:
# (concurrently, but only for the filtered list)
load_tasks = [fetch.loaded_models(ep) for ep in candidate_endpoints]
loaded_sets = await asyncio.gather(*load_tasks)
-
+
+ # Protect all reads of usage_counts with the lock
async with usage_lock:
# Helper: get current usage count for (endpoint, model)
def current_usage(ep: str) -> int:
return usage_counts.get(ep, {}).get(model, 0)
-
+
# 3️⃣ Endpoints that have the model loaded *and* a free slot
loaded_and_free = [
ep for ep, models in zip(candidate_endpoints, loaded_sets)
if model in models and usage_counts.get(ep, {}).get(model, 0) < config.max_concurrent_connections
]
-
+
if loaded_and_free:
- ep = min(loaded_and_free, key=current_usage)
- return ep
+ # Sort by per-model usage in DESCENDING order to ensure model affinity
+ # Endpoints with higher usage (already handling this model) should be preferred
+ # until they reach max_concurrent_connections
+ loaded_and_free.sort(
+ key=lambda ep: -usage_counts.get(ep, {}).get(model, 0) # Negative for descending order
+ )
+
+ # If all endpoints have zero usage for this model, randomize to distribute
+ # different models across different endpoints for better resource utilization
+ if all(usage_counts.get(ep, {}).get(model, 0) == 0 for ep in loaded_and_free):
+ return random.choice(loaded_and_free)
+
+ return loaded_and_free[0]
# 4️⃣ Endpoints among the candidates that simply have a free slot
endpoints_with_free_slot = [
@@ -1079,8 +1199,22 @@ async def choose_endpoint(model: str) -> str:
]
if endpoints_with_free_slot:
- #return random.choice(endpoints_with_free_slot)
- endpoints_with_free_slot.sort(key=lambda ep: sum(usage_counts.get(ep, {}).values()))
+ # Sort by per-model usage (descending) first to ensure model affinity
+ # Even if the model isn't showing as "loaded" in /api/ps yet (e.g., during initial loading),
+ # we want to send subsequent requests to the endpoint that already has connections for this model
+ # Then by total endpoint usage (ascending) to balance idle endpoints
+ endpoints_with_free_slot.sort(
+ key=lambda ep: (
+ -usage_counts.get(ep, {}).get(model, 0), # Primary: per-model usage (descending - prefer endpoints with connections)
+ sum(usage_counts.get(ep, {}).values()) # Secondary: total endpoint usage (ascending - prefer idle endpoints)
+ )
+ )
+
+ # If all endpoints have zero usage for this specific model, randomize to distribute
+ # different models across different endpoints for better resource utilization
+ if all(usage_counts.get(ep, {}).get(model, 0) == 0 for ep in endpoints_with_free_slot):
+ return random.choice(endpoints_with_free_slot)
+
return endpoints_with_free_slot[0]
# 5️⃣ All candidate endpoints are saturated – pick one with lowest usages count (will queue)
@@ -2070,7 +2204,16 @@ async def openai_chat_completions_proxy(request: Request):
async def stream_ochat_response():
try:
# The chat method returns a generator of dicts (or GenerateResponse)
- async_gen = await oclient.chat.completions.create(**params)
+ try:
+ async_gen = await oclient.chat.completions.create(**params)
+ except openai.BadRequestError as e:
+ # If tools are not supported by the model, retry without tools
+ if "does not support tools" in str(e):
+ print(f"[openai_chat_completions_proxy] Model {model} doesn't support tools, retrying without tools")
+ params_without_tools = {k: v for k, v in params.items() if k != "tools"}
+ async_gen = await oclient.chat.completions.create(**params_without_tools)
+ else:
+ raise
if stream == True:
async for chunk in async_gen:
data = (