diff --git a/README.md b/README.md index abbf181..d7e08c1 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ is a transparent proxy for [Ollama](https://github.com/ollama/ollama) with model deployment aware routing. -Screenshot_NOMYO_Router_0-2-2_Dashboard
+[![Click for video](https://github.com/user-attachments/assets/ddacdf88-e3f3-41dd-8be6-f165b22d9879)](https://eu1.nomyo.ai/assets/dash.mp4) It runs between your frontend application and Ollama backend and is transparent for both, the front- and backend. @@ -78,6 +78,7 @@ docker run -d \ ``` Notes: + - `-e CONFIG_PATH` sets the `NOMYO_ROUTER_CONFIG_PATH` environment variable under the hood; you can export it directly instead if you prefer. - To override the bind address or port, export `UVICORN_HOST` or `UVICORN_PORT`, or pass the corresponding uvicorn flags after `--`, e.g. `nomyo-router --config-path /config/config.yaml -- --port 9000`. - Use `docker logs nomyo-router` to confirm the loaded endpoints and concurrency settings at startup. diff --git a/db.py b/db.py index 24c8480..9f4efd3 100644 --- a/db.py +++ b/db.py @@ -50,6 +50,7 @@ class TokenDatabase: PRIMARY KEY(endpoint, model) ) ''') + await db.execute('CREATE INDEX IF NOT EXISTS idx_token_time_series_timestamp ON token_time_series(timestamp)') await db.execute(''' CREATE TABLE IF NOT EXISTS token_time_series ( id INTEGER PRIMARY KEY AUTOINCREMENT, diff --git a/enhance.py b/enhance.py index 9be1fef..f3200c7 100644 --- a/enhance.py +++ b/enhance.py @@ -1,6 +1,6 @@ from pydantic import BaseModel -class feedback(BaseModel): +class Feedback(BaseModel): query_id: int content: str @@ -25,13 +25,13 @@ def moe(query: str, query_id: int, response: str) -> str: def moe_select_candidate(query: str, candidates: list[str]) -> str: if not candidates: - raise ValueError("No candidates supplied") + raise ValueError("No candidates supplied") candidate_sections = "" - for i, cand in enumerate(candidates[:3], start=0): + for i, cand in enumerate(candidates[:3], start=1): candidate_sections += f""" - {cand.message.content} + {cand} """ @@ -45,5 +45,4 @@ def moe_select_candidate(query: str, candidates: list[str]) -> str: **Do NOT** mention candidate numbers, strengths, weaknesses, or any other commentary. 
Just give the final answer—nothing else. """ - return select_prompt.strip() - + return select_prompt.strip() diff --git a/entrypoint.sh b/entrypoint.sh index 6682851..e9ccbc3 100755 --- a/entrypoint.sh +++ b/entrypoint.sh @@ -50,7 +50,7 @@ Options: Any arguments that remain after the options above are passed directly to uvicorn. Environment variables: - CONFIG_PATH Alternative way to specify the config path. + CONFIG_PATH Alternative way to specify the config path. NOMYO_ROUTER_CONFIG_PATH Overrides the config path (same as --config-path). UVICORN_HOST Host interface to bind to (default: 0.0.0.0). UVICORN_PORT Port to listen on (default: 12434). diff --git a/requirements.txt b/requirements.txt index c88fd0d..01e704b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,7 +21,7 @@ jiter==0.10.0 multidict==6.6.4 ollama==0.6.0 openai==1.102.0 -orjson==3.11.4 +orjson>=3.11.5 pillow==11.3.0 propcache==0.3.2 pydantic==2.11.7 diff --git a/router.py b/router.py index 8ba1d50..908b8c9 100644 --- a/router.py +++ b/router.py @@ -2,7 +2,7 @@ title: NOMYO Router - an Ollama Proxy with Endpoint:Model aware routing author: alpha-nerd-nomyo author_url: https://github.com/nomyo-ai -version: 0.5 +version: 0.6 license: AGPL """ # ------------------------------------------------------------- @@ -34,6 +34,20 @@ _loaded_models_cache: dict[str, tuple[Set[str], float]] = {} # timeout expires, after which the endpoint will be queried again.
_error_cache: dict[str, float] = {} +# ------------------------------------------------------------------ +# Cache locks +# ------------------------------------------------------------------ +_models_cache_lock = asyncio.Lock() +_loaded_models_cache_lock = asyncio.Lock() +_error_cache_lock = asyncio.Lock() + +# ------------------------------------------------------------------ +# In-flight request tracking (prevents cache stampede) +# ------------------------------------------------------------------ +_inflight_available_models: dict[str, asyncio.Task] = {} +_inflight_loaded_models: dict[str, asyncio.Task] = {} +_inflight_lock = asyncio.Lock() + # ------------------------------------------------------------------ # Queues # ------------------------------------------------------------------ @@ -339,6 +353,10 @@ async def token_worker() -> None: try: while True: endpoint, model, prompt, comp = await token_queue.get() + # Calculate timestamp once before acquiring lock + now = datetime.now(tz=timezone.utc) + timestamp = int(datetime(now.year, now.month, now.day, now.hour, now.minute, tzinfo=timezone.utc).timestamp()) + # Accumulate counts in memory buffer (protected by lock) async with buffer_lock: token_buffer[endpoint][model] = ( @@ -347,8 +365,6 @@ async def token_worker() -> None: ) # Add to time series buffer with timestamp (UTC) - now = datetime.now(tz=timezone.utc) - timestamp = int(datetime(now.year, now.month, now.day, now.hour, now.minute, tzinfo=timezone.utc).timestamp()) time_series_buffer.append({ 'endpoint': endpoint, 'model': model, @@ -369,13 +385,15 @@ async def token_worker() -> None: while not token_queue.empty(): try: endpoint, model, prompt, comp = token_queue.get_nowait() + # Calculate timestamp once before acquiring lock + now = datetime.now(tz=timezone.utc) + timestamp = int(datetime(now.year, now.month, now.day, now.hour, now.minute, tzinfo=timezone.utc).timestamp()) + async with buffer_lock: token_buffer[endpoint][model] = ( 
token_buffer[endpoint].get(model, (0, 0))[0] + prompt, token_buffer[endpoint].get(model, (0, 0))[1] + comp ) - now = datetime.now(tz=timezone.utc) - timestamp = int(datetime(now.year, now.month, now.day, now.hour, now.minute, tzinfo=timezone.utc).timestamp()) time_series_buffer.append({ 'endpoint': endpoint, 'model': model, @@ -477,42 +495,22 @@ async def flush_remaining_buffers() -> None: print(f"[shutdown] Error flushing remaining buffers: {e}") class fetch: - async def available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]: + async def _fetch_available_models_internal(endpoint: str, api_key: Optional[str] = None) -> Set[str]: """ - Query /api/tags and return a set of all model names that the - endpoint *advertises* (i.e. is capable of serving). This endpoint lists - every model that is installed on the Ollama instance, regardless of - whether the model is currently loaded into memory. - - If the request fails (e.g. timeout, 5xx, or malformed response), an empty - set is returned. + Internal function that performs the actual HTTP request to fetch available models. + This is called by available_models() after checking caches and in-flight requests. 
""" headers = None if api_key is not None: headers = {"Authorization": "Bearer " + api_key} - if endpoint in _models_cache: - models, cached_at = _models_cache[endpoint] - if _is_fresh(cached_at, 300): - return models - else: - # stale entry – drop it - del _models_cache[endpoint] - - if endpoint in _error_cache: - if _is_fresh(_error_cache[endpoint], 10): - # Still within the short error TTL – pretend nothing is available - return set() - else: - # Error expired – remove it - del _error_cache[endpoint] - if "/v1" in endpoint: endpoint_url = f"{endpoint}/models" key = "data" else: endpoint_url = f"{endpoint}/api/tags" key = "models" + client: aiohttp.ClientSession = app_state["session"] try: async with client.get(endpoint_url, headers=headers) as resp: @@ -521,43 +519,77 @@ class fetch: items = data.get(key, []) models = {item.get("id") or item.get("name") for item in items if item.get("id") or item.get("name")} - - if models: + + # Update cache with lock protection + async with _models_cache_lock: _models_cache[endpoint] = (models, time.time()) - return models - else: - # Empty list – treat as “no models”, but still cache for 300s - _models_cache[endpoint] = (models, time.time()) - return models + return models except Exception as e: # Treat any error as if the endpoint offers no models message = _format_connection_issue(endpoint_url, e) print(f"[fetch.available_models] {message}") - _error_cache[endpoint] = time.time() + # Update error cache with lock protection + async with _error_cache_lock: + _error_cache[endpoint] = time.time() return set() - - async def loaded_models(endpoint: str) -> Set[str]: + async def available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]: """ - Query /api/ps and return a set of model names that are currently - loaded on that endpoint. If the request fails (e.g. timeout, 5xx), an empty + Query /api/tags and return a set of all model names that the + endpoint *advertises* (i.e. is capable of serving). 
This endpoint lists + every model that is installed on the Ollama instance, regardless of + whether the model is currently loaded into memory. + + Uses request coalescing to prevent cache stampede: if multiple requests + arrive when cache is expired, only one actual HTTP request is made. + + If the request fails (e.g. timeout, 5xx, or malformed response), an empty set is returned. """ - if is_ext_openai_endpoint(endpoint): - return set() - if endpoint in _loaded_models_cache: - models, cached_at = _loaded_models_cache[endpoint] - if _is_fresh(cached_at, 30): - return models - else: - # stale entry – drop it - del _loaded_models_cache[endpoint] + # Check models cache with lock protection + async with _models_cache_lock: + if endpoint in _models_cache: + models, cached_at = _models_cache[endpoint] + if _is_fresh(cached_at, 300): + return models + # Stale entry - remove it + del _models_cache[endpoint] - if endpoint in _error_cache: - if _is_fresh(_error_cache[endpoint], 10): - return set() - else: + # Check error cache with lock protection + async with _error_cache_lock: + if endpoint in _error_cache: + if _is_fresh(_error_cache[endpoint], 10): + # Still within the short error TTL – pretend nothing is available + return set() + # Error expired – remove it del _error_cache[endpoint] + + # Request coalescing: check if another request is already fetching this endpoint + async with _inflight_lock: + if endpoint in _inflight_available_models: + # Another request is already fetching - wait for it + task = _inflight_available_models[endpoint] + else: + # Create new fetch task + task = asyncio.create_task(fetch._fetch_available_models_internal(endpoint, api_key)) + _inflight_available_models[endpoint] = task + + try: + # Wait for the fetch to complete (either ours or another request's) + result = await task + return result + finally: + # Clean up in-flight tracking (only if we created it) + async with _inflight_lock: + if _inflight_available_models.get(endpoint) == task: + 
_inflight_available_models.pop(endpoint, None) + + + async def _fetch_loaded_models_internal(endpoint: str) -> Set[str]: + """ + Internal function that performs the actual HTTP request to fetch loaded models. + This is called by loaded_models() after checking caches and in-flight requests. + """ client: aiohttp.ClientSession = app_state["session"] try: async with client.get(f"{endpoint}/api/ps") as resp: @@ -566,7 +598,10 @@ class fetch: # The response format is: # {"models": [{"name": "model1"}, {"name": "model2"}]} models = {m.get("name") for m in data.get("models", []) if m.get("name")} - _loaded_models_cache[endpoint] = (models, time.time()) + + # Update cache with lock protection + async with _loaded_models_cache_lock: + _loaded_models_cache[endpoint] = (models, time.time()) return models except Exception as e: # If anything goes wrong we simply assume the endpoint has no models @@ -574,6 +609,75 @@ class fetch: print(f"[fetch.loaded_models] {message}") return set() + async def _refresh_loaded_models(endpoint: str) -> None: + """ + Background task to refresh loaded models cache without blocking the caller. + Used for stale-while-revalidate pattern. + """ + try: + await fetch._fetch_loaded_models_internal(endpoint) + except Exception as e: + # Silently fail - cache will remain stale but functional + print(f"[fetch._refresh_loaded_models] Background refresh failed for {endpoint}: {e}") + + async def loaded_models(endpoint: str) -> Set[str]: + """ + Query /api/ps and return a set of model names that are currently + loaded on that endpoint. If the request fails (e.g. timeout, 5xx), an empty + set is returned. + + Uses request coalescing to prevent cache stampede and stale-while-revalidate + to serve requests immediately even when cache is stale (refreshing in background). 
+ """ + if is_ext_openai_endpoint(endpoint): + return set() + + # Check loaded models cache with lock protection + async with _loaded_models_cache_lock: + if endpoint in _loaded_models_cache: + models, cached_at = _loaded_models_cache[endpoint] + + # FRESH: < 10s old - return immediately + if _is_fresh(cached_at, 10): + return models + + # STALE: 10-60s old - return stale data and refresh in background + if _is_fresh(cached_at, 60): + # Kick off background refresh (fire-and-forget) + asyncio.create_task(fetch._refresh_loaded_models(endpoint)) + return models # Return stale data immediately + + # EXPIRED: > 60s old - too stale, must refresh synchronously + del _loaded_models_cache[endpoint] + + # Check error cache with lock protection + async with _error_cache_lock: + if endpoint in _error_cache: + if _is_fresh(_error_cache[endpoint], 10): + return set() + # Error expired - remove it + del _error_cache[endpoint] + + # Request coalescing: check if another request is already fetching this endpoint + async with _inflight_lock: + if endpoint in _inflight_loaded_models: + # Another request is already fetching - wait for it + task = _inflight_loaded_models[endpoint] + else: + # Create new fetch task + task = asyncio.create_task(fetch._fetch_loaded_models_internal(endpoint)) + _inflight_loaded_models[endpoint] = task + + try: + # Wait for the fetch to complete (either ours or another request's) + result = await task + return result + finally: + # Clean up in-flight tracking (only if we created it) + async with _inflight_lock: + if _inflight_loaded_models.get(endpoint) == task: + _inflight_loaded_models.pop(endpoint, None) + async def endpoint_details(endpoint: str, route: str, detail: str, api_key: Optional[str] = None) -> List[dict]: """ Query / to fetch and return a List of dicts with details @@ -621,7 +725,7 @@ def dedupe_on_keys(dicts, key_fields): async def increment_usage(endpoint: str, model: str) -> None: async with usage_lock: usage_counts[endpoint][model] += 1 - 
await publish_snapshot() + await publish_snapshot() async def decrement_usage(endpoint: str, model: str) -> None: async with usage_lock: @@ -634,7 +738,7 @@ async def decrement_usage(endpoint: str, model: str) -> None: usage_counts[endpoint].pop(model, None) #if not usage_counts[endpoint]: # usage_counts.pop(endpoint, None) - await publish_snapshot() + await publish_snapshot() async def _make_chat_request(endpoint: str, model: str, messages: list, tools=None, stream: bool = False, think: bool = False, format=None, options=None, keep_alive: str = None) -> ollama.ChatResponse: """ @@ -957,10 +1061,14 @@ class rechunk: # SSE Helpser # ------------------------------------------------------------------ async def publish_snapshot(): - async with usage_lock: - snapshot = orjson.dumps({"usage_counts": usage_counts, - "token_usage_counts": token_usage_counts, - }, option=orjson.OPT_SORT_KEYS).decode("utf-8") + # NOTE: This function assumes usage_lock OR token_usage_lock is already held by the caller + # Create a snapshot without acquiring the lock (caller must hold it) + snapshot = orjson.dumps({ + "usage_counts": dict(usage_counts), # Create a copy + "token_usage_counts": dict(token_usage_counts) + }, option=orjson.OPT_SORT_KEYS).decode("utf-8") + + # Distribute the snapshot (no lock needed here since we have a copy) async with _subscribers_lock: for q in _subscribers: # If the queue is full, drop the message to avoid back‑pressure. 
@@ -1056,21 +1164,33 @@ async def choose_endpoint(model: str) -> str: # (concurrently, but only for the filtered list) load_tasks = [fetch.loaded_models(ep) for ep in candidate_endpoints] loaded_sets = await asyncio.gather(*load_tasks) - + + # Protect all reads of usage_counts with the lock async with usage_lock: # Helper: get current usage count for (endpoint, model) def current_usage(ep: str) -> int: return usage_counts.get(ep, {}).get(model, 0) - + # 3️⃣ Endpoints that have the model loaded *and* a free slot loaded_and_free = [ ep for ep, models in zip(candidate_endpoints, loaded_sets) if model in models and usage_counts.get(ep, {}).get(model, 0) < config.max_concurrent_connections ] - + if loaded_and_free: - ep = min(loaded_and_free, key=current_usage) - return ep + # Sort by per-model usage in DESCENDING order to ensure model affinity + # Endpoints with higher usage (already handling this model) should be preferred + # until they reach max_concurrent_connections + loaded_and_free.sort( + key=lambda ep: -usage_counts.get(ep, {}).get(model, 0) # Negative for descending order + ) + + # If all endpoints have zero usage for this model, randomize to distribute + # different models across different endpoints for better resource utilization + if all(usage_counts.get(ep, {}).get(model, 0) == 0 for ep in loaded_and_free): + return random.choice(loaded_and_free) + + return loaded_and_free[0] # 4️⃣ Endpoints among the candidates that simply have a free slot endpoints_with_free_slot = [ @@ -1079,8 +1199,22 @@ async def choose_endpoint(model: str) -> str: ] if endpoints_with_free_slot: - #return random.choice(endpoints_with_free_slot) - endpoints_with_free_slot.sort(key=lambda ep: sum(usage_counts.get(ep, {}).values())) + # Sort by per-model usage (descending) first to ensure model affinity + # Even if the model isn't showing as "loaded" in /api/ps yet (e.g., during initial loading), + # we want to send subsequent requests to the endpoint that already has connections for 
this model + # Then by total endpoint usage (ascending) to balance idle endpoints + endpoints_with_free_slot.sort( + key=lambda ep: ( + -usage_counts.get(ep, {}).get(model, 0), # Primary: per-model usage (descending - prefer endpoints with connections) + sum(usage_counts.get(ep, {}).values()) # Secondary: total endpoint usage (ascending - prefer idle endpoints) + ) + ) + + # If all endpoints have zero usage for this specific model, randomize to distribute + # different models across different endpoints for better resource utilization + if all(usage_counts.get(ep, {}).get(model, 0) == 0 for ep in endpoints_with_free_slot): + return random.choice(endpoints_with_free_slot) + return endpoints_with_free_slot[0] # 5️⃣ All candidate endpoints are saturated – pick one with lowest usages count (will queue) @@ -2070,7 +2204,16 @@ async def openai_chat_completions_proxy(request: Request): async def stream_ochat_response(): try: # The chat method returns a generator of dicts (or GenerateResponse) - async_gen = await oclient.chat.completions.create(**params) + try: + async_gen = await oclient.chat.completions.create(**params) + except openai.BadRequestError as e: + # If tools are not supported by the model, retry without tools + if "does not support tools" in str(e): + print(f"[openai_chat_completions_proxy] Model {model} doesn't support tools, retrying without tools") + params_without_tools = {k: v for k, v in params.items() if k != "tools"} + async_gen = await oclient.chat.completions.create(**params_without_tools) + else: + raise if stream == True: async for chunk in async_gen: data = (