diff --git a/router.py b/router.py
index a990cb7..829b16a 100644
--- a/router.py
+++ b/router.py
@@ -41,6 +41,13 @@
 _models_cache_lock = asyncio.Lock()
 _loaded_models_cache_lock = asyncio.Lock()
 _error_cache_lock = asyncio.Lock()
+# ------------------------------------------------------------------
+# In-flight request tracking (prevents cache stampede)
+# ------------------------------------------------------------------
+_inflight_available_models: dict[str, asyncio.Task] = {}
+_inflight_loaded_models: dict[str, asyncio.Task] = {}
+_inflight_lock = asyncio.Lock()
+
 # ------------------------------------------------------------------
 # Queues
 # ------------------------------------------------------------------
@@ -375,6 +382,44 @@ async def flush_remaining_buffers() -> None:
         print(f"[shutdown] Error flushing remaining buffers: {e}")
 
 class fetch:
+    async def _fetch_available_models_internal(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
+        """
+        Internal function that performs the actual HTTP request to fetch available models.
+        This is called by available_models() after checking caches and in-flight requests.
+        """
+        headers = None
+        if api_key is not None:
+            headers = {"Authorization": "Bearer " + api_key}
+
+        if "/v1" in endpoint:
+            endpoint_url = f"{endpoint}/models"
+            key = "data"
+        else:
+            endpoint_url = f"{endpoint}/api/tags"
+            key = "models"
+
+        client: aiohttp.ClientSession = app_state["session"]
+        try:
+            async with client.get(endpoint_url, headers=headers) as resp:
+                await _ensure_success(resp)
+                data = await resp.json()
+
+                items = data.get(key, [])
+                models = {item.get("id") or item.get("name") for item in items if item.get("id") or item.get("name")}
+
+                # Update cache with lock protection
+                async with _models_cache_lock:
+                    _models_cache[endpoint] = (models, time.time())
+                return models
+        except Exception as e:
+            # Treat any error as if the endpoint offers no models
+            message = _format_connection_issue(endpoint_url, e)
+            print(f"[fetch.available_models] {message}")
+            # Update error cache with lock protection
+            async with _error_cache_lock:
+                _error_cache[endpoint] = time.time()
+            return set()
+
     async def available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
         """
         Query /api/tags and return a set of all model names that the
@@ -382,13 +427,12 @@ class fetch:
         every model that is installed on the Ollama instance, regardless
         of whether the model is currently loaded into memory.
 
+        Uses request coalescing to prevent cache stampede: if multiple requests
+        arrive when cache is expired, only one actual HTTP request is made.
+
         If the request fails (e.g. timeout, 5xx, or malformed response), an empty
         set is returned.
         """
-        headers = None
-        if api_key is not None:
-            headers = {"Authorization": "Bearer " + api_key}
-
         # Check models cache with lock protection
         async with _models_cache_lock:
             if endpoint in _models_cache:
@@ -407,65 +451,32 @@ class fetch:
                 # Error expired – remove it
                 del _error_cache[endpoint]
 
-        if "/v1" in endpoint:
-            endpoint_url = f"{endpoint}/models"
-            key = "data"
-        else:
-            endpoint_url = f"{endpoint}/api/tags"
-            key = "models"
-
         client: aiohttp.ClientSession = app_state["session"]
+        # Request coalescing: check if another request is already fetching this endpoint
+        async with _inflight_lock:
+            if endpoint in _inflight_available_models:
+                # Another request is already fetching - wait for it
+                task = _inflight_available_models[endpoint]
+            else:
+                # Create new fetch task
+                task = asyncio.create_task(fetch._fetch_available_models_internal(endpoint, api_key))
+                _inflight_available_models[endpoint] = task
+
         try:
-            async with client.get(endpoint_url, headers=headers) as resp:
-                await _ensure_success(resp)
-                data = await resp.json()
-
-                items = data.get(key, [])
-                models = {item.get("id") or item.get("name") for item in items if item.get("id") or item.get("name")}
-
-                # Update cache with lock protection
-                async with _models_cache_lock:
-                    if models:
-                        _models_cache[endpoint] = (models, time.time())
-                    else:
-                        # Empty list – treat as "no models", but still cache for 300s
-                        _models_cache[endpoint] = (models, time.time())
-                return models
-        except Exception as e:
-            # Treat any error as if the endpoint offers no models
-            message = _format_connection_issue(endpoint_url, e)
-            print(f"[fetch.available_models] {message}")
-            # Update error cache with lock protection
-            async with _error_cache_lock:
-                _error_cache[endpoint] = time.time()
-            return set()
+            # Wait for the fetch to complete (either ours or another request's)
+            result = await task
+            return result
+        finally:
+            # Clean up in-flight tracking (only if we created it)
+            async with _inflight_lock:
+                if _inflight_available_models.get(endpoint) == task:
+                    _inflight_available_models.pop(endpoint, None)
 
-    async def loaded_models(endpoint: str) -> Set[str]:
+    async def _fetch_loaded_models_internal(endpoint: str) -> Set[str]:
         """
-        Query /api/ps and return a set of model names that are currently
-        loaded on that endpoint. If the request fails (e.g. timeout, 5xx), an empty
-        set is returned.
+        Internal function that performs the actual HTTP request to fetch loaded models.
+        This is called by loaded_models() after checking caches and in-flight requests.
         """
-        if is_ext_openai_endpoint(endpoint):
-            return set()
-
-        # Check loaded models cache with lock protection
-        async with _loaded_models_cache_lock:
-            if endpoint in _loaded_models_cache:
-                models, cached_at = _loaded_models_cache[endpoint]
-                if _is_fresh(cached_at, 30):
-                    return models
-                # Stale entry - remove it
-                del _loaded_models_cache[endpoint]
-
-        # Check error cache with lock protection
-        async with _error_cache_lock:
-            if endpoint in _error_cache:
-                if _is_fresh(_error_cache[endpoint], 10):
-                    return set()
-                # Error expired - remove it
-                del _error_cache[endpoint]
-
         client: aiohttp.ClientSession = app_state["session"]
         try:
             async with client.get(f"{endpoint}/api/ps") as resp:
@@ -485,6 +496,75 @@ class fetch:
             print(f"[fetch.loaded_models] {message}")
             return set()
 
+    async def _refresh_loaded_models(endpoint: str) -> None:
+        """
+        Background task to refresh loaded models cache without blocking the caller.
+        Used for stale-while-revalidate pattern.
+        """
+        try:
+            await fetch._fetch_loaded_models_internal(endpoint)
+        except Exception as e:
+            # Silently fail - cache will remain stale but functional
+            print(f"[fetch._refresh_loaded_models] Background refresh failed for {endpoint}: {e}")
+
+    async def loaded_models(endpoint: str) -> Set[str]:
+        """
+        Query /api/ps and return a set of model names that are currently
+        loaded on that endpoint. If the request fails (e.g. timeout, 5xx), an empty
+        set is returned.
+
+        Uses request coalescing to prevent cache stampede and stale-while-revalidate
+        to serve requests immediately even when cache is stale (refreshing in background).
+        """
+        if is_ext_openai_endpoint(endpoint):
+            return set()
+
+        # Check loaded models cache with lock protection
+        async with _loaded_models_cache_lock:
+            if endpoint in _loaded_models_cache:
+                models, cached_at = _loaded_models_cache[endpoint]
+
+                # FRESH: < 10s old - return immediately
+                if _is_fresh(cached_at, 10):
+                    return models
+
+                # STALE: 10-60s old - return stale data and refresh in background
+                if _is_fresh(cached_at, 60):
+                    # Kick off background refresh (fire-and-forget)
+                    asyncio.create_task(fetch._refresh_loaded_models(endpoint))
+                    return models  # Return stale data immediately
+
+                # EXPIRED: > 60s old - too stale, must refresh synchronously
+                del _loaded_models_cache[endpoint]
+
+        # Check error cache with lock protection
+        async with _error_cache_lock:
+            if endpoint in _error_cache:
+                if _is_fresh(_error_cache[endpoint], 10):
+                    return set()
+                # Error expired - remove it
+                del _error_cache[endpoint]
+
+        # Request coalescing: check if another request is already fetching this endpoint
+        async with _inflight_lock:
+            if endpoint in _inflight_loaded_models:
+                # Another request is already fetching - wait for it
+                task = _inflight_loaded_models[endpoint]
+            else:
+                # Create new fetch task
+                task = asyncio.create_task(fetch._fetch_loaded_models_internal(endpoint))
+                _inflight_loaded_models[endpoint] = task
+
+        try:
+            # Wait for the fetch to complete (either ours or another request's)
+            result = await task
+            return result
+        finally:
+            # Clean up in-flight tracking (only if we created it)
+            async with _inflight_lock:
+                if _inflight_loaded_models.get(endpoint) == task:
+                    _inflight_loaded_models.pop(endpoint, None)
+
     async def endpoint_details(endpoint: str, route: str, detail: str, api_key: Optional[str] = None) -> List[dict]:
         """
         Query / to fetch and return a List of dicts with details
@@ -868,12 +948,12 @@ class rechunk:
 # SSE Helpser
 # ------------------------------------------------------------------
 async def publish_snapshot():
-    # Take a consistent snapshot while holding the lock
-    async with usage_lock:
-        snapshot = orjson.dumps({
-            "usage_counts": dict(usage_counts),  # Create a copy
-            "token_usage_counts": dict(token_usage_counts)
-        }, option=orjson.OPT_SORT_KEYS).decode("utf-8")
+    # NOTE: This function assumes usage_lock OR token_usage_lock is already held by the caller
+    # Create a snapshot without acquiring the lock (caller must hold it)
+    snapshot = orjson.dumps({
+        "usage_counts": dict(usage_counts),  # Create a copy
+        "token_usage_counts": dict(token_usage_counts)
+    }, option=orjson.OPT_SORT_KEYS).decode("utf-8")
 
     # Distribute the snapshot (no lock needed here since we have a copy)
     async with _subscribers_lock:
@@ -991,6 +1071,12 @@ async def choose_endpoint(model: str) -> str:
         loaded_and_free.sort(
             key=lambda ep: -usage_counts.get(ep, {}).get(model, 0)  # Negative for descending order
         )
+
+        # If all endpoints have zero usage for this model, randomize to distribute
+        # different models across different endpoints for better resource utilization
+        if all(usage_counts.get(ep, {}).get(model, 0) == 0 for ep in loaded_and_free):
+            return random.choice(loaded_and_free)
+
         return loaded_and_free[0]
 
     # 4️⃣ Endpoints among the candidates that simply have a free slot
@@ -1010,6 +1096,12 @@ async def choose_endpoint(model: str) -> str:
                 sum(usage_counts.get(ep, {}).values())  # Secondary: total endpoint usage (ascending - prefer idle endpoints)
             )
         )
+
+        # If all endpoints have zero usage for this specific model, randomize to distribute
+        # different models across different endpoints for better resource utilization
+        if all(usage_counts.get(ep, {}).get(model, 0) == 0 for ep in endpoints_with_free_slot):
+            return random.choice(endpoints_with_free_slot)
+
         return endpoints_with_free_slot[0]
 
     # 5️⃣ All candidate endpoints are saturated – pick one with lowest usages count (will queue)
@@ -1974,7 +2066,16 @@ async def openai_chat_completions_proxy(request: Request):
     async def stream_ochat_response():
         try:
             # The chat method returns a generator of dicts (or GenerateResponse)
-            async_gen = await oclient.chat.completions.create(**params)
+            try:
+                async_gen = await oclient.chat.completions.create(**params)
+            except openai.BadRequestError as e:
+                # If tools are not supported by the model, retry without tools
+                if "does not support tools" in str(e):
+                    print(f"[openai_chat_completions_proxy] Model {model} doesn't support tools, retrying without tools")
+                    params_without_tools = {k: v for k, v in params.items() if k != "tools"}
+                    async_gen = await oclient.chat.completions.create(**params_without_tools)
+                else:
+                    raise
             if stream == True:
                 async for chunk in async_gen:
                     data = (
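
For illustration only, here is a standalone sketch of the request-coalescing pattern this patch introduces in fetch.available_models() and fetch.loaded_models(), reduced to its essentials. It is not part of router.py; the names slow_fetch, coalesced_fetch, _inflight and _inflight_lock are hypothetical and exist only for this example.

# Standalone sketch (hypothetical names, not router.py code): concurrent callers
# share one in-flight asyncio.Task per key, so only one real fetch runs at a time.
# Module-level Lock() assumes Python 3.10+, where locks no longer bind to a loop
# at construction time (mirroring how the patch declares _inflight_lock).
import asyncio

_inflight: dict[str, asyncio.Task] = {}
_inflight_lock = asyncio.Lock()

async def slow_fetch(key: str) -> str:
    # Stand-in for the real HTTP request (e.g. GET /api/tags).
    await asyncio.sleep(0.1)
    print(f"actual fetch for {key}")
    return f"result-for-{key}"

async def coalesced_fetch(key: str) -> str:
    # First caller creates the task; concurrent callers await the same task.
    async with _inflight_lock:
        task = _inflight.get(key)
        if task is None:
            task = asyncio.create_task(slow_fetch(key))
            _inflight[key] = task
    try:
        return await task
    finally:
        # Only remove the entry if it still points at the task we awaited.
        async with _inflight_lock:
            if _inflight.get(key) is task:
                _inflight.pop(key, None)

async def main() -> None:
    # Ten concurrent callers, but "actual fetch" is printed exactly once.
    results = await asyncio.gather(*(coalesced_fetch("endpoint-a") for _ in range(10)))
    assert len(set(results)) == 1

asyncio.run(main())

The patch pairs this in-flight map with a timestamped cache, which is what lets loaded_models() additionally serve stale entries while a background _refresh_loaded_models() task renews them (the stale-while-revalidate path).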