feat(router): implement in-flight request tracking to prevent cache stampede in high-concurrency scenarios
Added an in-flight request tracking mechanism to prevent a cache stampede when multiple concurrent requests arrive for the same endpoint. This introduces new dictionaries that track ongoing requests and a lock that coordinates access to them. The available_models method was refactored to delegate the actual HTTP request to an internal helper function, and it now includes request coalescing logic to ensure that only one HTTP request is made per endpoint when a cache entry expires. The loaded_models method was also updated to use the new caching and coalescing pattern.
This commit is contained in:
parent
bfdae1e4a6
commit
4ca1a5667e
1 changed file with 166 additions and 65 deletions
231
router.py
231
router.py
|
|
@ -41,6 +41,13 @@ _models_cache_lock = asyncio.Lock()
|
||||||
_loaded_models_cache_lock = asyncio.Lock()
|
_loaded_models_cache_lock = asyncio.Lock()
|
||||||
_error_cache_lock = asyncio.Lock()
|
_error_cache_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# In-flight request tracking (prevents cache stampede)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
_inflight_available_models: dict[str, asyncio.Task] = {}
|
||||||
|
_inflight_loaded_models: dict[str, asyncio.Task] = {}
|
||||||
|
_inflight_lock = asyncio.Lock()
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Queues
|
# Queues
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
@ -375,6 +382,44 @@ async def flush_remaining_buffers() -> None:
|
||||||
print(f"[shutdown] Error flushing remaining buffers: {e}")
|
print(f"[shutdown] Error flushing remaining buffers: {e}")
|
||||||
|
|
||||||
class fetch:
|
class fetch:
|
||||||
|
async def _fetch_available_models_internal(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
|
||||||
|
"""
|
||||||
|
Internal function that performs the actual HTTP request to fetch available models.
|
||||||
|
This is called by available_models() after checking caches and in-flight requests.
|
||||||
|
"""
|
||||||
|
headers = None
|
||||||
|
if api_key is not None:
|
||||||
|
headers = {"Authorization": "Bearer " + api_key}
|
||||||
|
|
||||||
|
if "/v1" in endpoint:
|
||||||
|
endpoint_url = f"{endpoint}/models"
|
||||||
|
key = "data"
|
||||||
|
else:
|
||||||
|
endpoint_url = f"{endpoint}/api/tags"
|
||||||
|
key = "models"
|
||||||
|
|
||||||
|
client: aiohttp.ClientSession = app_state["session"]
|
||||||
|
try:
|
||||||
|
async with client.get(endpoint_url, headers=headers) as resp:
|
||||||
|
await _ensure_success(resp)
|
||||||
|
data = await resp.json()
|
||||||
|
|
||||||
|
items = data.get(key, [])
|
||||||
|
models = {item.get("id") or item.get("name") for item in items if item.get("id") or item.get("name")}
|
||||||
|
|
||||||
|
# Update cache with lock protection
|
||||||
|
async with _models_cache_lock:
|
||||||
|
_models_cache[endpoint] = (models, time.time())
|
||||||
|
return models
|
||||||
|
except Exception as e:
|
||||||
|
# Treat any error as if the endpoint offers no models
|
||||||
|
message = _format_connection_issue(endpoint_url, e)
|
||||||
|
print(f"[fetch.available_models] {message}")
|
||||||
|
# Update error cache with lock protection
|
||||||
|
async with _error_cache_lock:
|
||||||
|
_error_cache[endpoint] = time.time()
|
||||||
|
return set()
|
||||||
|
|
||||||
async def available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
|
async def available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
|
||||||
"""
|
"""
|
||||||
Query <endpoint>/api/tags and return a set of all model names that the
|
Query <endpoint>/api/tags and return a set of all model names that the
|
||||||
|
|
@ -382,13 +427,12 @@ class fetch:
|
||||||
every model that is installed on the Ollama instance, regardless of
|
every model that is installed on the Ollama instance, regardless of
|
||||||
whether the model is currently loaded into memory.
|
whether the model is currently loaded into memory.
|
||||||
|
|
||||||
|
Uses request coalescing to prevent cache stampede: if multiple requests
|
||||||
|
arrive when cache is expired, only one actual HTTP request is made.
|
||||||
|
|
||||||
If the request fails (e.g. timeout, 5xx, or malformed response), an empty
|
If the request fails (e.g. timeout, 5xx, or malformed response), an empty
|
||||||
set is returned.
|
set is returned.
|
||||||
"""
|
"""
|
||||||
headers = None
|
|
||||||
if api_key is not None:
|
|
||||||
headers = {"Authorization": "Bearer " + api_key}
|
|
||||||
|
|
||||||
# Check models cache with lock protection
|
# Check models cache with lock protection
|
||||||
async with _models_cache_lock:
|
async with _models_cache_lock:
|
||||||
if endpoint in _models_cache:
|
if endpoint in _models_cache:
|
||||||
|
|
@ -407,65 +451,32 @@ class fetch:
|
||||||
# Error expired – remove it
|
# Error expired – remove it
|
||||||
del _error_cache[endpoint]
|
del _error_cache[endpoint]
|
||||||
|
|
||||||
if "/v1" in endpoint:
|
# Request coalescing: check if another request is already fetching this endpoint
|
||||||
endpoint_url = f"{endpoint}/models"
|
async with _inflight_lock:
|
||||||
key = "data"
|
if endpoint in _inflight_available_models:
|
||||||
else:
|
# Another request is already fetching - wait for it
|
||||||
endpoint_url = f"{endpoint}/api/tags"
|
task = _inflight_available_models[endpoint]
|
||||||
key = "models"
|
else:
|
||||||
client: aiohttp.ClientSession = app_state["session"]
|
# Create new fetch task
|
||||||
|
task = asyncio.create_task(fetch._fetch_available_models_internal(endpoint, api_key))
|
||||||
|
_inflight_available_models[endpoint] = task
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with client.get(endpoint_url, headers=headers) as resp:
|
# Wait for the fetch to complete (either ours or another request's)
|
||||||
await _ensure_success(resp)
|
result = await task
|
||||||
data = await resp.json()
|
return result
|
||||||
|
finally:
|
||||||
items = data.get(key, [])
|
# Clean up in-flight tracking (only if we created it)
|
||||||
models = {item.get("id") or item.get("name") for item in items if item.get("id") or item.get("name")}
|
async with _inflight_lock:
|
||||||
|
if _inflight_available_models.get(endpoint) == task:
|
||||||
# Update cache with lock protection
|
_inflight_available_models.pop(endpoint, None)
|
||||||
async with _models_cache_lock:
|
|
||||||
if models:
|
|
||||||
_models_cache[endpoint] = (models, time.time())
|
|
||||||
else:
|
|
||||||
# Empty list – treat as "no models", but still cache for 300s
|
|
||||||
_models_cache[endpoint] = (models, time.time())
|
|
||||||
return models
|
|
||||||
except Exception as e:
|
|
||||||
# Treat any error as if the endpoint offers no models
|
|
||||||
message = _format_connection_issue(endpoint_url, e)
|
|
||||||
print(f"[fetch.available_models] {message}")
|
|
||||||
# Update error cache with lock protection
|
|
||||||
async with _error_cache_lock:
|
|
||||||
_error_cache[endpoint] = time.time()
|
|
||||||
return set()
|
|
||||||
|
|
||||||
|
|
||||||
async def loaded_models(endpoint: str) -> Set[str]:
|
async def _fetch_loaded_models_internal(endpoint: str) -> Set[str]:
|
||||||
"""
|
"""
|
||||||
Query <endpoint>/api/ps and return a set of model names that are currently
|
Internal function that performs the actual HTTP request to fetch loaded models.
|
||||||
loaded on that endpoint. If the request fails (e.g. timeout, 5xx), an empty
|
This is called by loaded_models() after checking caches and in-flight requests.
|
||||||
set is returned.
|
|
||||||
"""
|
"""
|
||||||
if is_ext_openai_endpoint(endpoint):
|
|
||||||
return set()
|
|
||||||
|
|
||||||
# Check loaded models cache with lock protection
|
|
||||||
async with _loaded_models_cache_lock:
|
|
||||||
if endpoint in _loaded_models_cache:
|
|
||||||
models, cached_at = _loaded_models_cache[endpoint]
|
|
||||||
if _is_fresh(cached_at, 30):
|
|
||||||
return models
|
|
||||||
# Stale entry - remove it
|
|
||||||
del _loaded_models_cache[endpoint]
|
|
||||||
|
|
||||||
# Check error cache with lock protection
|
|
||||||
async with _error_cache_lock:
|
|
||||||
if endpoint in _error_cache:
|
|
||||||
if _is_fresh(_error_cache[endpoint], 10):
|
|
||||||
return set()
|
|
||||||
# Error expired - remove it
|
|
||||||
del _error_cache[endpoint]
|
|
||||||
|
|
||||||
client: aiohttp.ClientSession = app_state["session"]
|
client: aiohttp.ClientSession = app_state["session"]
|
||||||
try:
|
try:
|
||||||
async with client.get(f"{endpoint}/api/ps") as resp:
|
async with client.get(f"{endpoint}/api/ps") as resp:
|
||||||
|
|
@ -485,6 +496,75 @@ class fetch:
|
||||||
print(f"[fetch.loaded_models] {message}")
|
print(f"[fetch.loaded_models] {message}")
|
||||||
return set()
|
return set()
|
||||||
|
|
||||||
|
async def _refresh_loaded_models(endpoint: str) -> None:
|
||||||
|
"""
|
||||||
|
Background task to refresh loaded models cache without blocking the caller.
|
||||||
|
Used for stale-while-revalidate pattern.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
await fetch._fetch_loaded_models_internal(endpoint)
|
||||||
|
except Exception as e:
|
||||||
|
# Silently fail - cache will remain stale but functional
|
||||||
|
print(f"[fetch._refresh_loaded_models] Background refresh failed for {endpoint}: {e}")
|
||||||
|
|
||||||
|
async def loaded_models(endpoint: str) -> Set[str]:
|
||||||
|
"""
|
||||||
|
Query <endpoint>/api/ps and return a set of model names that are currently
|
||||||
|
loaded on that endpoint. If the request fails (e.g. timeout, 5xx), an empty
|
||||||
|
set is returned.
|
||||||
|
|
||||||
|
Uses request coalescing to prevent cache stampede and stale-while-revalidate
|
||||||
|
to serve requests immediately even when cache is stale (refreshing in background).
|
||||||
|
"""
|
||||||
|
if is_ext_openai_endpoint(endpoint):
|
||||||
|
return set()
|
||||||
|
|
||||||
|
# Check loaded models cache with lock protection
|
||||||
|
async with _loaded_models_cache_lock:
|
||||||
|
if endpoint in _loaded_models_cache:
|
||||||
|
models, cached_at = _loaded_models_cache[endpoint]
|
||||||
|
|
||||||
|
# FRESH: < 10s old - return immediately
|
||||||
|
if _is_fresh(cached_at, 10):
|
||||||
|
return models
|
||||||
|
|
||||||
|
# STALE: 10-60s old - return stale data and refresh in background
|
||||||
|
if _is_fresh(cached_at, 60):
|
||||||
|
# Kick off background refresh (fire-and-forget)
|
||||||
|
asyncio.create_task(fetch._refresh_loaded_models(endpoint))
|
||||||
|
return models # Return stale data immediately
|
||||||
|
|
||||||
|
# EXPIRED: > 60s old - too stale, must refresh synchronously
|
||||||
|
del _loaded_models_cache[endpoint]
|
||||||
|
|
||||||
|
# Check error cache with lock protection
|
||||||
|
async with _error_cache_lock:
|
||||||
|
if endpoint in _error_cache:
|
||||||
|
if _is_fresh(_error_cache[endpoint], 10):
|
||||||
|
return set()
|
||||||
|
# Error expired - remove it
|
||||||
|
del _error_cache[endpoint]
|
||||||
|
|
||||||
|
# Request coalescing: check if another request is already fetching this endpoint
|
||||||
|
async with _inflight_lock:
|
||||||
|
if endpoint in _inflight_loaded_models:
|
||||||
|
# Another request is already fetching - wait for it
|
||||||
|
task = _inflight_loaded_models[endpoint]
|
||||||
|
else:
|
||||||
|
# Create new fetch task
|
||||||
|
task = asyncio.create_task(fetch._fetch_loaded_models_internal(endpoint))
|
||||||
|
_inflight_loaded_models[endpoint] = task
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Wait for the fetch to complete (either ours or another request's)
|
||||||
|
result = await task
|
||||||
|
return result
|
||||||
|
finally:
|
||||||
|
# Clean up in-flight tracking (only if we created it)
|
||||||
|
async with _inflight_lock:
|
||||||
|
if _inflight_loaded_models.get(endpoint) == task:
|
||||||
|
_inflight_loaded_models.pop(endpoint, None)
|
||||||
|
|
||||||
async def endpoint_details(endpoint: str, route: str, detail: str, api_key: Optional[str] = None) -> List[dict]:
|
async def endpoint_details(endpoint: str, route: str, detail: str, api_key: Optional[str] = None) -> List[dict]:
|
||||||
"""
|
"""
|
||||||
Query <endpoint>/<route> to fetch <detail> and return a List of dicts with details
|
Query <endpoint>/<route> to fetch <detail> and return a List of dicts with details
|
||||||
|
|
@ -868,12 +948,12 @@ class rechunk:
|
||||||
# SSE Helpser
|
# SSE Helpser
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
async def publish_snapshot():
|
async def publish_snapshot():
|
||||||
# Take a consistent snapshot while holding the lock
|
# NOTE: This function assumes usage_lock OR token_usage_lock is already held by the caller
|
||||||
async with usage_lock:
|
# Create a snapshot without acquiring the lock (caller must hold it)
|
||||||
snapshot = orjson.dumps({
|
snapshot = orjson.dumps({
|
||||||
"usage_counts": dict(usage_counts), # Create a copy
|
"usage_counts": dict(usage_counts), # Create a copy
|
||||||
"token_usage_counts": dict(token_usage_counts)
|
"token_usage_counts": dict(token_usage_counts)
|
||||||
}, option=orjson.OPT_SORT_KEYS).decode("utf-8")
|
}, option=orjson.OPT_SORT_KEYS).decode("utf-8")
|
||||||
|
|
||||||
# Distribute the snapshot (no lock needed here since we have a copy)
|
# Distribute the snapshot (no lock needed here since we have a copy)
|
||||||
async with _subscribers_lock:
|
async with _subscribers_lock:
|
||||||
|
|
@ -991,6 +1071,12 @@ async def choose_endpoint(model: str) -> str:
|
||||||
loaded_and_free.sort(
|
loaded_and_free.sort(
|
||||||
key=lambda ep: -usage_counts.get(ep, {}).get(model, 0) # Negative for descending order
|
key=lambda ep: -usage_counts.get(ep, {}).get(model, 0) # Negative for descending order
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# If all endpoints have zero usage for this model, randomize to distribute
|
||||||
|
# different models across different endpoints for better resource utilization
|
||||||
|
if all(usage_counts.get(ep, {}).get(model, 0) == 0 for ep in loaded_and_free):
|
||||||
|
return random.choice(loaded_and_free)
|
||||||
|
|
||||||
return loaded_and_free[0]
|
return loaded_and_free[0]
|
||||||
|
|
||||||
# 4️⃣ Endpoints among the candidates that simply have a free slot
|
# 4️⃣ Endpoints among the candidates that simply have a free slot
|
||||||
|
|
@ -1010,6 +1096,12 @@ async def choose_endpoint(model: str) -> str:
|
||||||
sum(usage_counts.get(ep, {}).values()) # Secondary: total endpoint usage (ascending - prefer idle endpoints)
|
sum(usage_counts.get(ep, {}).values()) # Secondary: total endpoint usage (ascending - prefer idle endpoints)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# If all endpoints have zero usage for this specific model, randomize to distribute
|
||||||
|
# different models across different endpoints for better resource utilization
|
||||||
|
if all(usage_counts.get(ep, {}).get(model, 0) == 0 for ep in endpoints_with_free_slot):
|
||||||
|
return random.choice(endpoints_with_free_slot)
|
||||||
|
|
||||||
return endpoints_with_free_slot[0]
|
return endpoints_with_free_slot[0]
|
||||||
|
|
||||||
# 5️⃣ All candidate endpoints are saturated – pick one with lowest usages count (will queue)
|
# 5️⃣ All candidate endpoints are saturated – pick one with lowest usages count (will queue)
|
||||||
|
|
@ -1974,7 +2066,16 @@ async def openai_chat_completions_proxy(request: Request):
|
||||||
async def stream_ochat_response():
|
async def stream_ochat_response():
|
||||||
try:
|
try:
|
||||||
# The chat method returns a generator of dicts (or GenerateResponse)
|
# The chat method returns a generator of dicts (or GenerateResponse)
|
||||||
async_gen = await oclient.chat.completions.create(**params)
|
try:
|
||||||
|
async_gen = await oclient.chat.completions.create(**params)
|
||||||
|
except openai.BadRequestError as e:
|
||||||
|
# If tools are not supported by the model, retry without tools
|
||||||
|
if "does not support tools" in str(e):
|
||||||
|
print(f"[openai_chat_completions_proxy] Model {model} doesn't support tools, retrying without tools")
|
||||||
|
params_without_tools = {k: v for k, v in params.items() if k != "tools"}
|
||||||
|
async_gen = await oclient.chat.completions.create(**params_without_tools)
|
||||||
|
else:
|
||||||
|
raise
|
||||||
if stream == True:
|
if stream == True:
|
||||||
async for chunk in async_gen:
|
async for chunk in async_gen:
|
||||||
data = (
|
data = (
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue