feat(router): implement in-flight request tracking to prevent cache stampede in high concurrency scenarios

Added an in-flight request tracking mechanism to prevent cache stampede when multiple concurrent requests arrive for the same endpoint. This introduces two new dictionaries that track ongoing requests and a lock that coordinates access to them. The available_models method was refactored to delegate the actual HTTP call to an internal helper and now coalesces requests, so only one HTTP request is made per endpoint when a cache entry expires. The loaded_models method was updated to use the same caching and coalescing pattern, and additionally serves stale entries while refreshing them in the background (stale-while-revalidate). The commit also hoists snapshot locking out of publish_snapshot to its callers, randomizes endpoint selection when per-model usage is tied at zero, and adds a retry-without-tools fallback to the OpenAI chat completions proxy.
This commit is contained in:
Alpha Nerd 2026-01-29 18:00:33 +01:00
parent bfdae1e4a6
commit 4ca1a5667e

router.py (231 changed lines)
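For orientation, the coalescing pattern this commit applies to both fetch paths reduces to the following standalone sketch. The names (coalesced, _inflight) are illustrative, not the identifiers used in router.py:

```python
import asyncio
from typing import Any, Awaitable, Callable

_inflight: dict[str, asyncio.Task] = {}
_inflight_lock = asyncio.Lock()

async def coalesced(key: str, fetch: Callable[[], Awaitable[Any]]) -> Any:
    """Run fetch() at most once per key across concurrent callers."""
    async with _inflight_lock:
        task = _inflight.get(key)
        if task is None:
            # First caller creates the task; later callers await the same one.
            task = asyncio.create_task(fetch())
            _inflight[key] = task
    try:
        return await task
    finally:
        async with _inflight_lock:
            # Remove only our own entry, so a newer task registered
            # under the same key is left untouched.
            if _inflight.get(key) is task:
                _inflight.pop(key, None)
```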

@@ -41,6 +41,13 @@ _models_cache_lock = asyncio.Lock()
_loaded_models_cache_lock = asyncio.Lock()
_error_cache_lock = asyncio.Lock()
# ------------------------------------------------------------------
# In-flight request tracking (prevents cache stampede)
# ------------------------------------------------------------------
_inflight_available_models: dict[str, asyncio.Task] = {}
_inflight_loaded_models: dict[str, asyncio.Task] = {}
_inflight_lock = asyncio.Lock()
# ------------------------------------------------------------------
# Queues
# ------------------------------------------------------------------
@@ -375,6 +382,44 @@ async def flush_remaining_buffers() -> None:
print(f"[shutdown] Error flushing remaining buffers: {e}")
class fetch:
async def _fetch_available_models_internal(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
"""
Internal function that performs the actual HTTP request to fetch available models.
This is called by available_models() after checking caches and in-flight requests.
"""
headers = None
if api_key is not None:
headers = {"Authorization": "Bearer " + api_key}
if "/v1" in endpoint:
endpoint_url = f"{endpoint}/models"
key = "data"
else:
endpoint_url = f"{endpoint}/api/tags"
key = "models"
client: aiohttp.ClientSession = app_state["session"]
try:
async with client.get(endpoint_url, headers=headers) as resp:
await _ensure_success(resp)
data = await resp.json()
items = data.get(key, [])
models = {item.get("id") or item.get("name") for item in items if item.get("id") or item.get("name")}
# Update cache with lock protection
async with _models_cache_lock:
_models_cache[endpoint] = (models, time.time())
return models
except Exception as e:
# Treat any error as if the endpoint offers no models
message = _format_connection_issue(endpoint_url, e)
print(f"[fetch.available_models] {message}")
# Update error cache with lock protection
async with _error_cache_lock:
_error_cache[endpoint] = time.time()
return set()
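The two response shapes the helper above handles look roughly like this (the payloads are illustrative examples, not captured responses):

```python
# OpenAI-compatible endpoints (URL contains "/v1"): GET <endpoint>/models
openai_payload = {"data": [{"id": "llama3:8b"}, {"id": "qwen2:7b"}]}

# Native Ollama endpoints: GET <endpoint>/api/tags
ollama_payload = {"models": [{"name": "llama3:8b"}, {"name": "qwen2:7b"}]}

def extract_models(data: dict, key: str) -> set:
    # Mirrors the set comprehension above: accept "id" or "name",
    # skip items that carry neither.
    items = data.get(key, [])
    return {item.get("id") or item.get("name")
            for item in items
            if item.get("id") or item.get("name")}

assert extract_models(openai_payload, "data") == {"llama3:8b", "qwen2:7b"}
assert extract_models(ollama_payload, "models") == {"llama3:8b", "qwen2:7b"}
```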
async def available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
"""
Query <endpoint>/api/tags and return a set of all model names that the
@@ -382,13 +427,12 @@ class fetch:
every model that is installed on the Ollama instance, regardless of
whether the model is currently loaded into memory.
Uses request coalescing to prevent a cache stampede: if multiple requests
arrive while the cache is expired, only one actual HTTP request is made.
If the request fails (e.g. timeout, 5xx, or malformed response), an empty
set is returned.
"""
headers = None
if api_key is not None:
headers = {"Authorization": "Bearer " + api_key}
# Check models cache with lock protection
async with _models_cache_lock:
if endpoint in _models_cache:
@@ -407,65 +451,32 @@ class fetch:
# Error expired - remove it
del _error_cache[endpoint]
if "/v1" in endpoint:
endpoint_url = f"{endpoint}/models"
key = "data"
else:
endpoint_url = f"{endpoint}/api/tags"
key = "models"
client: aiohttp.ClientSession = app_state["session"]
# Request coalescing: check if another request is already fetching this endpoint
async with _inflight_lock:
if endpoint in _inflight_available_models:
# Another request is already fetching - wait for it
task = _inflight_available_models[endpoint]
else:
# Create new fetch task
task = asyncio.create_task(fetch._fetch_available_models_internal(endpoint, api_key))
_inflight_available_models[endpoint] = task
try:
async with client.get(endpoint_url, headers=headers) as resp:
await _ensure_success(resp)
data = await resp.json()
items = data.get(key, [])
models = {item.get("id") or item.get("name") for item in items if item.get("id") or item.get("name")}
# Update cache with lock protection
async with _models_cache_lock:
if models:
_models_cache[endpoint] = (models, time.time())
else:
# Empty list - treat as "no models", but still cache for 300s
_models_cache[endpoint] = (models, time.time())
return models
except Exception as e:
# Treat any error as if the endpoint offers no models
message = _format_connection_issue(endpoint_url, e)
print(f"[fetch.available_models] {message}")
# Update error cache with lock protection
async with _error_cache_lock:
_error_cache[endpoint] = time.time()
return set()
# Wait for the fetch to complete (either ours or another request's)
result = await task
return result
finally:
# Clean up in-flight tracking (only if we created it)
async with _inflight_lock:
if _inflight_available_models.get(endpoint) == task:
_inflight_available_models.pop(endpoint, None)
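A quick way to verify the coalescing behaves as intended: fire many concurrent requests and count how often the underlying fetch runs. A test sketch reusing the coalesced() helper from the first sketch above (the counter and fake fetch are hypothetical):

```python
import asyncio

calls = 0

async def fake_fetch() -> set:
    global calls
    calls += 1
    await asyncio.sleep(0.05)  # simulate network latency
    return {"llama3:8b"}

async def main() -> None:
    # 50 concurrent requests for the same endpoint should trigger
    # exactly one underlying call while the task is in flight.
    results = await asyncio.gather(*(coalesced("ep1", fake_fetch) for _ in range(50)))
    assert calls == 1
    assert all(r == {"llama3:8b"} for r in results)

asyncio.run(main())
```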
async def loaded_models(endpoint: str) -> Set[str]:
async def _fetch_loaded_models_internal(endpoint: str) -> Set[str]:
"""
Query <endpoint>/api/ps and return a set of model names that are currently
loaded on that endpoint. If the request fails (e.g. timeout, 5xx), an empty
set is returned.
Internal function that performs the actual HTTP request to fetch loaded models.
This is called by loaded_models() after checking caches and in-flight requests.
"""
if is_ext_openai_endpoint(endpoint):
return set()
# Check loaded models cache with lock protection
async with _loaded_models_cache_lock:
if endpoint in _loaded_models_cache:
models, cached_at = _loaded_models_cache[endpoint]
if _is_fresh(cached_at, 30):
return models
# Stale entry - remove it
del _loaded_models_cache[endpoint]
# Check error cache with lock protection
async with _error_cache_lock:
if endpoint in _error_cache:
if _is_fresh(_error_cache[endpoint], 10):
return set()
# Error expired - remove it
del _error_cache[endpoint]
client: aiohttp.ClientSession = app_state["session"]
try:
async with client.get(f"{endpoint}/api/ps") as resp:
@@ -485,6 +496,75 @@ class fetch:
print(f"[fetch.loaded_models] {message}")
return set()
async def _refresh_loaded_models(endpoint: str) -> None:
"""
Background task to refresh loaded models cache without blocking the caller.
Used for stale-while-revalidate pattern.
"""
try:
await fetch._fetch_loaded_models_internal(endpoint)
except Exception as e:
# Silently fail - cache will remain stale but functional
print(f"[fetch._refresh_loaded_models] Background refresh failed for {endpoint}: {e}")
async def loaded_models(endpoint: str) -> Set[str]:
"""
Query <endpoint>/api/ps and return a set of model names that are currently
loaded on that endpoint. If the request fails (e.g. timeout, 5xx), an empty
set is returned.
Uses request coalescing to prevent a cache stampede and stale-while-revalidate
to serve requests immediately even when the cache is stale (refreshing in the background).
"""
if is_ext_openai_endpoint(endpoint):
return set()
# Check loaded models cache with lock protection
async with _loaded_models_cache_lock:
if endpoint in _loaded_models_cache:
models, cached_at = _loaded_models_cache[endpoint]
# FRESH: < 10s old - return immediately
if _is_fresh(cached_at, 10):
return models
# STALE: 10-60s old - return stale data and refresh in background
if _is_fresh(cached_at, 60):
# Kick off background refresh (fire-and-forget)
asyncio.create_task(fetch._refresh_loaded_models(endpoint))
return models # Return stale data immediately
# EXPIRED: > 60s old - too stale, must refresh synchronously
del _loaded_models_cache[endpoint]
# Check error cache with lock protection
async with _error_cache_lock:
if endpoint in _error_cache:
if _is_fresh(_error_cache[endpoint], 10):
return set()
# Error expired - remove it
del _error_cache[endpoint]
# Request coalescing: check if another request is already fetching this endpoint
async with _inflight_lock:
if endpoint in _inflight_loaded_models:
# Another request is already fetching - wait for it
task = _inflight_loaded_models[endpoint]
else:
# Create new fetch task
task = asyncio.create_task(fetch._fetch_loaded_models_internal(endpoint))
_inflight_loaded_models[endpoint] = task
try:
# Wait for the fetch to complete (either ours or another request's)
result = await task
return result
finally:
# Clean up in-flight tracking (only if we created it)
async with _inflight_lock:
if _inflight_loaded_models.get(endpoint) == task:
_inflight_loaded_models.pop(endpoint, None)
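The freshness tiers in loaded_models() reduce to a small decision function; a sketch assuming the same 10 s and 60 s thresholds:

```python
import time
from typing import Optional

def cache_decision(cached_at: float, now: Optional[float] = None) -> str:
    """Classify a cache entry under the 10s/60s stale-while-revalidate policy."""
    age = (now if now is not None else time.time()) - cached_at
    if age < 10:
        return "fresh"    # serve the cached value as-is
    if age < 60:
        return "stale"    # serve the cached value, refresh in the background
    return "expired"      # drop the entry and fetch synchronously

assert cache_decision(cached_at=100.0, now=105.0) == "fresh"
assert cache_decision(cached_at=100.0, now=130.0) == "stale"
assert cache_decision(cached_at=100.0, now=200.0) == "expired"
```

The stale tier is what keeps request latency flat: callers never block on a refresh unless the entry has aged past 60 seconds.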
async def endpoint_details(endpoint: str, route: str, detail: str, api_key: Optional[str] = None) -> List[dict]:
"""
Query <endpoint>/<route> to fetch <detail> and return a List of dicts with details
@@ -868,12 +948,12 @@ class rechunk:
# SSE Helpers
# ------------------------------------------------------------------
async def publish_snapshot():
# Take a consistent snapshot while holding the lock
async with usage_lock:
snapshot = orjson.dumps({
"usage_counts": dict(usage_counts), # Create a copy
"token_usage_counts": dict(token_usage_counts)
}, option=orjson.OPT_SORT_KEYS).decode("utf-8")
# NOTE: This function assumes usage_lock OR token_usage_lock is already held by the caller
# Create a snapshot without acquiring the lock (caller must hold it)
snapshot = orjson.dumps({
"usage_counts": dict(usage_counts), # Create a copy
"token_usage_counts": dict(token_usage_counts)
}, option=orjson.OPT_SORT_KEYS).decode("utf-8")
# Distribute the snapshot (no lock needed here since we have a copy)
async with _subscribers_lock:
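With the lock hoisted out of publish_snapshot(), every caller must now hold usage_lock (or token_usage_lock) itself, presumably so a caller that has already mutated the counters under the lock can publish without re-acquiring it. An illustrative caller (record_usage is hypothetical; the other names come from router.py):

```python
async def record_usage(endpoint: str, model: str) -> None:
    # Acquire usage_lock once: mutate the counters and publish the
    # snapshot while still holding it, since publish_snapshot() no
    # longer locks internally.
    async with usage_lock:
        usage_counts.setdefault(endpoint, {})
        usage_counts[endpoint][model] = usage_counts[endpoint].get(model, 0) + 1
        await publish_snapshot()
```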
@@ -991,6 +1071,12 @@ async def choose_endpoint(model: str) -> str:
loaded_and_free.sort(
key=lambda ep: -usage_counts.get(ep, {}).get(model, 0) # Negative for descending order
)
# If all endpoints have zero usage for this model, randomize to distribute
# different models across different endpoints for better resource utilization
if all(usage_counts.get(ep, {}).get(model, 0) == 0 for ep in loaded_and_free):
return random.choice(loaded_and_free)
return loaded_and_free[0]
# 4⃣ Endpoints among the candidates that simply have a free slot
@@ -1010,6 +1096,12 @@ async def choose_endpoint(model: str) -> str:
sum(usage_counts.get(ep, {}).values()) # Secondary: total endpoint usage (ascending - prefer idle endpoints)
)
)
# If all endpoints have zero usage for this specific model, randomize to distribute
# different models across different endpoints for better resource utilization
if all(usage_counts.get(ep, {}).get(model, 0) == 0 for ep in endpoints_with_free_slot):
return random.choice(endpoints_with_free_slot)
return endpoints_with_free_slot[0]
# 5⃣ All candidate endpoints are saturated - pick the one with the lowest usage count (will queue)
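The tie-breaking rule added in both selection steps reads: sort candidates by per-model usage, but when no candidate has served this model yet, pick uniformly at random so distinct models spread across endpoints instead of piling onto the first one. A standalone sketch (names here are illustrative):

```python
import random

def pick_endpoint(candidates: list, usage: dict, model: str) -> str:
    ordered = sorted(candidates, key=lambda ep: -usage.get(ep, {}).get(model, 0))
    # All-zero usage for this model: the sort order is arbitrary,
    # so randomize instead of always returning the first candidate.
    if all(usage.get(ep, {}).get(model, 0) == 0 for ep in ordered):
        return random.choice(ordered)
    return ordered[0]

usage = {"ep-a": {"llama3": 2}, "ep-b": {}}
assert pick_endpoint(["ep-a", "ep-b"], usage, "llama3") == "ep-a"  # sticky: usage wins
print(pick_endpoint(["ep-a", "ep-b"], usage, "qwen2"))             # tie: random choice
```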
@@ -1974,7 +2066,16 @@ async def openai_chat_completions_proxy(request: Request):
async def stream_ochat_response():
try:
# The chat method returns a generator of dicts (or GenerateResponse)
async_gen = await oclient.chat.completions.create(**params)
try:
async_gen = await oclient.chat.completions.create(**params)
except openai.BadRequestError as e:
# If tools are not supported by the model, retry without tools
if "does not support tools" in str(e):
print(f"[openai_chat_completions_proxy] Model {model} doesn't support tools, retrying without tools")
params_without_tools = {k: v for k, v in params.items() if k != "tools"}
async_gen = await oclient.chat.completions.create(**params_without_tools)
else:
raise
if stream == True:
async for chunk in async_gen:
data = (
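The fallback above generalizes to a small wrapper; a sketch assuming the openai>=1.0 async client (the helper name is illustrative):

```python
import openai

async def create_with_tool_fallback(client, params: dict):
    """Call chat.completions.create, dropping "tools" if the model rejects them."""
    try:
        return await client.chat.completions.create(**params)
    except openai.BadRequestError as e:
        # Matching on the error message is brittle but mirrors the
        # check above; only this specific failure triggers the retry.
        if "does not support tools" in str(e) and "tools" in params:
            stripped = {k: v for k, v in params.items() if k != "tools"}
            return await client.chat.completions.create(**stripped)
        raise
```

One caveat worth noting: the string match depends on the backend's exact error wording, so a backend that phrases the rejection differently will still surface the original BadRequestError.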