feat: deduplicate background refresh tasks and extend cache TTL
Adds lock-protected dictionaries to track running background refresh tasks, preventing duplicate executions per endpoint. Increases cache freshness thresholds from 30s to 300s to reduce blocking behavior.

fix: /v1 endpoints use the correct media_type and report usage information with proper logging
parent c9ff384bb2
commit 0bad604b02
1 changed file with 129 additions and 31 deletions
router.py  +129 -31
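Before the diff, a minimal self-contained sketch of the deduplication pattern the commit message describes (illustrative names only; the router's real helpers and caches appear in the hunks below):

import asyncio

# Hypothetical stand-ins for the per-endpoint refresh machinery added in this commit.
_bg_tasks: dict[str, asyncio.Task] = {}   # endpoint -> currently running refresh task
_bg_lock = asyncio.Lock()                 # guards _bg_tasks against concurrent mutation

async def _do_refresh(endpoint: str) -> None:
    # Placeholder for the real fetch that repopulates the cache for this endpoint.
    await asyncio.sleep(0.1)

async def refresh_in_background(endpoint: str) -> None:
    # Start at most one refresh per endpoint; callers that arrive while one is
    # running simply return and keep serving the stale cache entry.
    async with _bg_lock:
        existing = _bg_tasks.get(endpoint)
        if existing is not None and not existing.done():
            return
        task = asyncio.create_task(_do_refresh(endpoint))
        _bg_tasks[endpoint] = task
    try:
        await task
    except Exception as exc:
        # Swallow errors: the stale cache stays usable until the next attempt.
        print(f"background refresh failed for {endpoint}: {exc}")
    finally:
        async with _bg_lock:
            # Remove the bookkeeping entry only if it is still our task.
            if _bg_tasks.get(endpoint) is task:
                _bg_tasks.pop(endpoint, None)

A caller that finds a stale cache entry can kick this off with asyncio.create_task(...) and return the stale value immediately, which is the stale-while-revalidate behaviour the new 300s thresholds rely on.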
@@ -53,6 +53,9 @@ _loaded_error_cache_lock = asyncio.Lock()
 _inflight_available_models: dict[str, asyncio.Task] = {}
 _inflight_loaded_models: dict[str, asyncio.Task] = {}
 _inflight_lock = asyncio.Lock()
+_bg_refresh_available: dict[str, asyncio.Task] = {}
+_bg_refresh_loaded: dict[str, asyncio.Task] = {}
+_bg_refresh_lock = asyncio.Lock()
 
 # ------------------------------------------------------------------
 # Queues
@@ -605,12 +608,23 @@ class fetch:
         """
         Background task to refresh available models cache without blocking the caller.
         Used for stale-while-revalidate pattern.
+        Deduplicates: only one background refresh runs per endpoint at a time.
         """
+        async with _bg_refresh_lock:
+            if endpoint in _bg_refresh_available and not _bg_refresh_available[endpoint].done():
+                return  # A refresh is already running for this endpoint
+            task = asyncio.create_task(fetch._fetch_available_models_internal(endpoint, api_key))
+            _bg_refresh_available[endpoint] = task
+
         try:
-            await fetch._fetch_available_models_internal(endpoint, api_key)
+            await task
         except Exception as e:
             # Silently fail - cache will remain stale but functional
             print(f"[fetch._refresh_available_models] Background refresh failed for {endpoint}: {e}")
+        finally:
+            async with _bg_refresh_lock:
+                if _bg_refresh_available.get(endpoint) is task:
+                    _bg_refresh_available.pop(endpoint, None)
 
     async def available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
         """
@@ -634,7 +648,7 @@ class fetch:
         if endpoint in _models_cache:
             models, cached_at = _models_cache[endpoint]
 
-            # FRESH: < 300s old - return immediately
+            # FRESH: <= 300s old - return immediately
             if _is_fresh(cached_at, 300):
                 return models
 
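The freshness checks in this and the following hunks all go through _is_fresh, which is not part of the diff; presumably it is a plain age comparison along these lines (an assumption for illustration, not the router's actual implementation):

import time

def _is_fresh(cached_at: float, ttl_seconds: float) -> bool:
    # Assumed helper: an entry stamped with a time.time() value counts as fresh
    # while its age is below the TTL.
    return (time.time() - cached_at) < ttl_seconds

# Under the new 300s threshold an entry cached two minutes ago is still fresh,
# whereas the old 30s threshold would already have forced a refresh:
# _is_fresh(time.time() - 120, 300) -> True
# _is_fresh(time.time() - 120, 30)  -> False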
@@ -649,7 +663,7 @@ class fetch:
         # Check error cache with lock protection
         async with _available_error_cache_lock:
             if endpoint in _available_error_cache:
-                if _is_fresh(_available_error_cache[endpoint], 30):
+                if _is_fresh(_available_error_cache[endpoint], 300):
                     # Still within the short error TTL – pretend nothing is available
                     return set()
                 # Error expired – remove it
@@ -735,12 +749,23 @@ class fetch:
         """
         Background task to refresh loaded models cache without blocking the caller.
         Used for stale-while-revalidate pattern.
+        Deduplicates: only one background refresh runs per endpoint at a time.
        """
+        async with _bg_refresh_lock:
+            if endpoint in _bg_refresh_loaded and not _bg_refresh_loaded[endpoint].done():
+                return  # A refresh is already running for this endpoint
+            task = asyncio.create_task(fetch._fetch_loaded_models_internal(endpoint))
+            _bg_refresh_loaded[endpoint] = task
+
         try:
-            await fetch._fetch_loaded_models_internal(endpoint)
+            await task
         except Exception as e:
             # Silently fail - cache will remain stale but functional
             print(f"[fetch._refresh_loaded_models] Background refresh failed for {endpoint}: {e}")
+        finally:
+            async with _bg_refresh_lock:
+                if _bg_refresh_loaded.get(endpoint) is task:
+                    _bg_refresh_loaded.pop(endpoint, None)
 
     async def loaded_models(endpoint: str) -> Set[str]:
         """
@@ -760,7 +785,7 @@ class fetch:
             models, cached_at = _loaded_models_cache[endpoint]
 
             # FRESH: < 10s old - return immediately
-            if _is_fresh(cached_at, 30):
+            if _is_fresh(cached_at, 300):
                 return models
 
             # STALE: 10-60s old - return stale data and refresh in background
@@ -775,7 +800,7 @@ class fetch:
         # Check error cache with lock protection
         async with _loaded_error_cache_lock:
             if endpoint in _loaded_error_cache:
-                if _is_fresh(_loaded_error_cache[endpoint], 30):
+                if _is_fresh(_loaded_error_cache[endpoint], 300):
                     return set()
                 # Error expired - remove it
                 del _loaded_error_cache[endpoint]
@@ -813,7 +838,7 @@ class fetch:
         if not skip_error_cache:
             async with _available_error_cache_lock:
                 if endpoint in _available_error_cache:
-                    if _is_fresh(_available_error_cache[endpoint], 30):
+                    if _is_fresh(_available_error_cache[endpoint], 300):
                         return []
 
         client: aiohttp.ClientSession = app_state["session"]
@@ -923,11 +948,17 @@ async def _make_chat_request(endpoint: str, model: str, messages: list, tools=No
         async for chunk in response:
             chunks.append(chunk)
             _accumulate_openai_tc_delta(chunk, tc_acc)
+            prompt_tok = 0
+            comp_tok = 0
             if chunk.usage is not None:
                 prompt_tok = chunk.usage.prompt_tokens or 0
                 comp_tok = chunk.usage.completion_tokens or 0
-                if prompt_tok != 0 or comp_tok != 0:
-                    await token_queue.put((endpoint, model, prompt_tok, comp_tok))
+            else:
+                llama_usage = rechunk.extract_usage_from_llama_timings(chunk)
+                if llama_usage:
+                    prompt_tok, comp_tok = llama_usage
+            if prompt_tok != 0 or comp_tok != 0:
+                await token_queue.put((endpoint, model, prompt_tok, comp_tok))
         # Convert to Ollama format
         if chunks:
             response = rechunk.openai_chat_completion2ollama(chunks[-1], stream, start_ts)
@@ -935,8 +966,15 @@ async def _make_chat_request(endpoint: str, model: str, messages: list, tools=No
         if tc_acc and response.message:
             response.message.tool_calls = _build_ollama_tool_calls(tc_acc)
     else:
-        prompt_tok = response.usage.prompt_tokens or 0
-        comp_tok = response.usage.completion_tokens or 0
+        prompt_tok = 0
+        comp_tok = 0
+        if response.usage is not None:
+            prompt_tok = response.usage.prompt_tokens or 0
+            comp_tok = response.usage.completion_tokens or 0
+        else:
+            llama_usage = rechunk.extract_usage_from_llama_timings(response)
+            if llama_usage:
+                prompt_tok, comp_tok = llama_usage
         if prompt_tok != 0 or comp_tok != 0:
             await token_queue.put((endpoint, model, prompt_tok, comp_tok))
         response = rechunk.openai_chat_completion2ollama(response, stream, start_ts)
@@ -1317,6 +1355,34 @@ class rechunk:
                             eval_duration=None,
                             embeddings=[chunk.data[0].embedding])
         return rechunk
 
+    def extract_usage_from_llama_timings(obj) -> tuple[int, int] | None:
+        """Extract (prompt_tokens, completion_tokens) from llama-server's timings object.
+
+        llama-server returns a ``timings`` dict instead of the standard OpenAI
+        ``usage`` field::
+
+            "timings": {
+                "cache_n": 236,      // prompt tokens reused from cache
+                "prompt_n": 1,       // prompt tokens processed
+                "predicted_n": 35    // predicted (completion) tokens
+            }
+
+        prompt_tokens = prompt_n + cache_n
+        completion_tokens = predicted_n
+
+        Returns ``(prompt_tokens, completion_tokens)`` or ``None`` when no
+        timings are found.
+        """
+        timings = getattr(obj, "timings", None)
+        if timings is None:
+            return None
+        if isinstance(timings, dict):
+            prompt_n = timings.get("prompt_n", 0) or 0
+            cache_n = timings.get("cache_n", 0) or 0
+            predicted_n = timings.get("predicted_n", 0) or 0
+            return (prompt_n + cache_n, predicted_n)
+        return None
+
 # ------------------------------------------------------------------
 # SSE Helpser
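For reference, how the new helper behaves on a llama-server style payload (a usage sketch; it assumes the rechunk class from router.py is importable and uses SimpleNamespace as a stand-in for the chunk objects the router actually passes in):

from types import SimpleNamespace

from router import rechunk  # assumes router.py is importable as a module

# Chunk carrying llama-server timings instead of an OpenAI usage field.
chunk = SimpleNamespace(timings={"cache_n": 236, "prompt_n": 1, "predicted_n": 35})

# prompt_tokens = prompt_n + cache_n = 237, completion_tokens = predicted_n = 35
assert rechunk.extract_usage_from_llama_timings(chunk) == (237, 35)

# Objects without a timings attribute (normal OpenAI responses) yield None.
assert rechunk.extract_usage_from_llama_timings(SimpleNamespace()) is None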
@@ -2595,19 +2661,35 @@ async def openai_chat_completions_proxy(request: Request):
                 if chunk.choices:
                     if chunk.choices[0].delta.content is not None:
                         yield f"data: {data}\n\n".encode("utf-8")
+                elif chunk.usage is not None:
+                    # Forward the usage-only final chunk (e.g. from llama-server)
+                    yield f"data: {data}\n\n".encode("utf-8")
+                prompt_tok = 0
+                comp_tok = 0
                 if chunk.usage is not None:
                     prompt_tok = chunk.usage.prompt_tokens or 0
                     comp_tok = chunk.usage.completion_tokens or 0
-                    if prompt_tok != 0 or comp_tok != 0:
-                        local_model = model
-                        if not is_ext_openai_endpoint(endpoint):
-                            if not ":" in model:
-                                local_model = model if ":" in model else model + ":latest"
-                        await token_queue.put((endpoint, local_model, prompt_tok, comp_tok))
+                else:
+                    llama_usage = rechunk.extract_usage_from_llama_timings(chunk)
+                    if llama_usage:
+                        prompt_tok, comp_tok = llama_usage
+                if prompt_tok != 0 or comp_tok != 0:
+                    local_model = model
+                    if not is_ext_openai_endpoint(endpoint):
+                        if not ":" in model:
+                            local_model = model if ":" in model else model + ":latest"
+                    await token_queue.put((endpoint, local_model, prompt_tok, comp_tok))
             yield b"data: [DONE]\n\n"
         else:
-            prompt_tok = async_gen.usage.prompt_tokens or 0
-            comp_tok = async_gen.usage.completion_tokens or 0
+            prompt_tok = 0
+            comp_tok = 0
+            if async_gen.usage is not None:
+                prompt_tok = async_gen.usage.prompt_tokens or 0
+                comp_tok = async_gen.usage.completion_tokens or 0
+            else:
+                llama_usage = rechunk.extract_usage_from_llama_timings(async_gen)
+                if llama_usage:
+                    prompt_tok, comp_tok = llama_usage
             if prompt_tok != 0 or comp_tok != 0:
                 await token_queue.put((endpoint, model, prompt_tok, comp_tok))
             json_line = (
@@ -2624,7 +2706,7 @@ async def openai_chat_completions_proxy(request: Request):
     # 4. Return a StreamingResponse backed by the generator
     return StreamingResponse(
         stream_ochat_response(),
-        media_type="application/json",
+        media_type="text/event-stream" if stream else "application/json",
     )
 
 # -------------------------------------------------------------
@@ -2712,20 +2794,36 @@ async def openai_completions_proxy(request: Request):
                 if chunk.choices:
                     if chunk.choices[0].finish_reason == None:
                         yield f"data: {data}\n\n".encode("utf-8")
                 elif chunk.usage is not None:
-                    prompt_tok = chunk.usage.prompt_tokens or 0
-                    comp_tok = chunk.usage.completion_tokens or 0
-                    if prompt_tok != 0 or comp_tok != 0:
-                        local_model = model
-                        if not is_ext_openai_endpoint(endpoint):
-                            if not ":" in model:
-                                local_model = model if ":" in model else model + ":latest"
-                        await token_queue.put((endpoint, local_model, prompt_tok, comp_tok))
+                    # Forward the usage-only final chunk (e.g. from llama-server)
+                    yield f"data: {data}\n\n".encode("utf-8")
+                prompt_tok = 0
+                comp_tok = 0
+                if chunk.usage is not None:
+                    prompt_tok = chunk.usage.prompt_tokens or 0
+                    comp_tok = chunk.usage.completion_tokens or 0
+                else:
+                    llama_usage = rechunk.extract_usage_from_llama_timings(chunk)
+                    if llama_usage:
+                        prompt_tok, comp_tok = llama_usage
+                if prompt_tok != 0 or comp_tok != 0:
+                    local_model = model
+                    if not is_ext_openai_endpoint(endpoint):
+                        if not ":" in model:
+                            local_model = model if ":" in model else model + ":latest"
+                    await token_queue.put((endpoint, local_model, prompt_tok, comp_tok))
             # Final DONE event
             yield b"data: [DONE]\n\n"
         else:
-            prompt_tok = async_gen.usage.prompt_tokens or 0
-            comp_tok = async_gen.usage.completion_tokens or 0
+            prompt_tok = 0
+            comp_tok = 0
+            if async_gen.usage is not None:
+                prompt_tok = async_gen.usage.prompt_tokens or 0
+                comp_tok = async_gen.usage.completion_tokens or 0
+            else:
+                llama_usage = rechunk.extract_usage_from_llama_timings(async_gen)
+                if llama_usage:
+                    prompt_tok, comp_tok = llama_usage
             if prompt_tok != 0 or comp_tok != 0:
                 await token_queue.put((endpoint, model, prompt_tok, comp_tok))
             json_line = (
@@ -2742,7 +2840,7 @@ async def openai_completions_proxy(request: Request):
     # 4. Return a StreamingResponse backed by the generator
     return StreamingResponse(
         stream_ocompletions_response(),
-        media_type="application/json",
+        media_type="text/event-stream" if stream else "application/json",
     )
 
 # -------------------------------------------------------------