diff --git a/router.py b/router.py
index 8e5c2ed..5131aa5 100644
--- a/router.py
+++ b/router.py
@@ -53,6 +53,9 @@ _loaded_error_cache_lock = asyncio.Lock()
 _inflight_available_models: dict[str, asyncio.Task] = {}
 _inflight_loaded_models: dict[str, asyncio.Task] = {}
 _inflight_lock = asyncio.Lock()
+_bg_refresh_available: dict[str, asyncio.Task] = {}
+_bg_refresh_loaded: dict[str, asyncio.Task] = {}
+_bg_refresh_lock = asyncio.Lock()
 
 # ------------------------------------------------------------------
 # Queues
@@ -605,12 +608,23 @@ class fetch:
         """
         Background task to refresh available models cache without blocking the caller.
         Used for stale-while-revalidate pattern.
+        Deduplicates: only one background refresh runs per endpoint at a time.
         """
+        async with _bg_refresh_lock:
+            if endpoint in _bg_refresh_available and not _bg_refresh_available[endpoint].done():
+                return  # A refresh is already running for this endpoint
+            task = asyncio.create_task(fetch._fetch_available_models_internal(endpoint, api_key))
+            _bg_refresh_available[endpoint] = task
+
         try:
-            await fetch._fetch_available_models_internal(endpoint, api_key)
+            await task
         except Exception as e:
             # Silently fail - cache will remain stale but functional
             print(f"[fetch._refresh_available_models] Background refresh failed for {endpoint}: {e}")
+        finally:
+            async with _bg_refresh_lock:
+                if _bg_refresh_available.get(endpoint) is task:
+                    _bg_refresh_available.pop(endpoint, None)
 
     async def available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
         """
@@ -634,7 +648,7 @@ class fetch:
         if endpoint in _models_cache:
             models, cached_at = _models_cache[endpoint]
 
-            # FRESH: < 300s old - return immediately
+            # FRESH: <= 300s old - return immediately
             if _is_fresh(cached_at, 300):
                 return models
 
@@ -649,7 +663,7 @@ class fetch:
         # Check error cache with lock protection
         async with _available_error_cache_lock:
             if endpoint in _available_error_cache:
-                if _is_fresh(_available_error_cache[endpoint], 30):
+                if _is_fresh(_available_error_cache[endpoint], 300):
                     # Still within the short error TTL – pretend nothing is available
                     return set()
                 # Error expired – remove it
@@ -735,12 +749,23 @@ class fetch:
         """
         Background task to refresh loaded models cache without blocking the caller.
         Used for stale-while-revalidate pattern.
+        Deduplicates: only one background refresh runs per endpoint at a time.
         """
+        async with _bg_refresh_lock:
+            if endpoint in _bg_refresh_loaded and not _bg_refresh_loaded[endpoint].done():
+                return  # A refresh is already running for this endpoint
+            task = asyncio.create_task(fetch._fetch_loaded_models_internal(endpoint))
+            _bg_refresh_loaded[endpoint] = task
+
         try:
-            await fetch._fetch_loaded_models_internal(endpoint)
+            await task
         except Exception as e:
             # Silently fail - cache will remain stale but functional
             print(f"[fetch._refresh_loaded_models] Background refresh failed for {endpoint}: {e}")
+        finally:
+            async with _bg_refresh_lock:
+                if _bg_refresh_loaded.get(endpoint) is task:
+                    _bg_refresh_loaded.pop(endpoint, None)
 
     async def loaded_models(endpoint: str) -> Set[str]:
         """
@@ -760,7 +785,7 @@ class fetch:
             models, cached_at = _loaded_models_cache[endpoint]
 
             # FRESH: < 10s old - return immediately
-            if _is_fresh(cached_at, 30):
+            if _is_fresh(cached_at, 300):
                 return models
 
             # STALE: 10-60s old - return stale data and refresh in background
@@ -775,7 +800,7 @@ class fetch:
         # Check error cache with lock protection
         async with _loaded_error_cache_lock:
             if endpoint in _loaded_error_cache:
-                if _is_fresh(_loaded_error_cache[endpoint], 30):
+                if _is_fresh(_loaded_error_cache[endpoint], 300):
                     return set()
                 # Error expired - remove it
                 del _loaded_error_cache[endpoint]
@@ -813,7 +838,7 @@ class fetch:
         if not skip_error_cache:
             async with _available_error_cache_lock:
                 if endpoint in _available_error_cache:
-                    if _is_fresh(_available_error_cache[endpoint], 30):
+                    if _is_fresh(_available_error_cache[endpoint], 300):
                         return []
 
         client: aiohttp.ClientSession = app_state["session"]
@@ -923,11 +948,17 @@ async def _make_chat_request(endpoint: str, model: str, messages: list, tools=No
         async for chunk in response:
            chunks.append(chunk)
            _accumulate_openai_tc_delta(chunk, tc_acc)
+           prompt_tok = 0
+           comp_tok = 0
            if chunk.usage is not None:
                prompt_tok = chunk.usage.prompt_tokens or 0
                comp_tok = chunk.usage.completion_tokens or 0
-               if prompt_tok != 0 or comp_tok != 0:
-                   await token_queue.put((endpoint, model, prompt_tok, comp_tok))
+           else:
+               llama_usage = rechunk.extract_usage_from_llama_timings(chunk)
+               if llama_usage:
+                   prompt_tok, comp_tok = llama_usage
+           if prompt_tok != 0 or comp_tok != 0:
+               await token_queue.put((endpoint, model, prompt_tok, comp_tok))
         # Convert to Ollama format
         if chunks:
             response = rechunk.openai_chat_completion2ollama(chunks[-1], stream, start_ts)
@@ -935,8 +966,15 @@ async def _make_chat_request(endpoint: str, model: str, messages: list, tools=No
             if tc_acc and response.message:
                 response.message.tool_calls = _build_ollama_tool_calls(tc_acc)
     else:
-        prompt_tok = response.usage.prompt_tokens or 0
-        comp_tok = response.usage.completion_tokens or 0
+        prompt_tok = 0
+        comp_tok = 0
+        if response.usage is not None:
+            prompt_tok = response.usage.prompt_tokens or 0
+            comp_tok = response.usage.completion_tokens or 0
+        else:
+            llama_usage = rechunk.extract_usage_from_llama_timings(response)
+            if llama_usage:
+                prompt_tok, comp_tok = llama_usage
         if prompt_tok != 0 or comp_tok != 0:
             await token_queue.put((endpoint, model, prompt_tok, comp_tok))
         response = rechunk.openai_chat_completion2ollama(response, stream, start_ts)
@@ -1317,6 +1355,34 @@ class rechunk:
                            eval_duration=None,
                            embeddings=[chunk.data[0].embedding])
         return rechunk
+
+    def extract_usage_from_llama_timings(obj) -> tuple[int, int] | None:
+        """Extract (prompt_tokens, completion_tokens) from llama-server's timings object.
+
+        llama-server returns a ``timings`` dict instead of the standard OpenAI
+        ``usage`` field::
+
+            "timings": {
+                "cache_n": 236,      // prompt tokens reused from cache
+                "prompt_n": 1,       // prompt tokens processed
+                "predicted_n": 35    // predicted (completion) tokens
+            }
+
+        prompt_tokens = prompt_n + cache_n
+        completion_tokens = predicted_n
+
+        Returns ``(prompt_tokens, completion_tokens)`` or ``None`` when no
+        timings are found.
+        """
+        timings = getattr(obj, "timings", None)
+        if timings is None:
+            return None
+        if isinstance(timings, dict):
+            prompt_n = timings.get("prompt_n", 0) or 0
+            cache_n = timings.get("cache_n", 0) or 0
+            predicted_n = timings.get("predicted_n", 0) or 0
+            return (prompt_n + cache_n, predicted_n)
+        return None
 
 # ------------------------------------------------------------------
 # SSE Helpser
@@ -2595,19 +2661,35 @@ async def openai_chat_completions_proxy(request: Request):
                 if chunk.choices:
                     if chunk.choices[0].delta.content is not None:
                         yield f"data: {data}\n\n".encode("utf-8")
+                    elif chunk.usage is not None:
+                        # Forward the usage-only final chunk (e.g. from llama-server)
+                        yield f"data: {data}\n\n".encode("utf-8")
+                prompt_tok = 0
+                comp_tok = 0
                 if chunk.usage is not None:
                     prompt_tok = chunk.usage.prompt_tokens or 0
                     comp_tok = chunk.usage.completion_tokens or 0
-                    if prompt_tok != 0 or comp_tok != 0:
-                        local_model = model
-                        if not is_ext_openai_endpoint(endpoint):
-                            if not ":" in model:
-                                local_model = model if ":" in model else model + ":latest"
-                        await token_queue.put((endpoint, local_model, prompt_tok, comp_tok))
+                else:
+                    llama_usage = rechunk.extract_usage_from_llama_timings(chunk)
+                    if llama_usage:
+                        prompt_tok, comp_tok = llama_usage
+                if prompt_tok != 0 or comp_tok != 0:
+                    local_model = model
+                    if not is_ext_openai_endpoint(endpoint):
+                        if not ":" in model:
+                            local_model = model if ":" in model else model + ":latest"
+                    await token_queue.put((endpoint, local_model, prompt_tok, comp_tok))
             yield b"data: [DONE]\n\n"
         else:
-            prompt_tok = async_gen.usage.prompt_tokens or 0
-            comp_tok = async_gen.usage.completion_tokens or 0
+            prompt_tok = 0
+            comp_tok = 0
+            if async_gen.usage is not None:
+                prompt_tok = async_gen.usage.prompt_tokens or 0
+                comp_tok = async_gen.usage.completion_tokens or 0
+            else:
+                llama_usage = rechunk.extract_usage_from_llama_timings(async_gen)
+                if llama_usage:
+                    prompt_tok, comp_tok = llama_usage
             if prompt_tok != 0 or comp_tok != 0:
                 await token_queue.put((endpoint, model, prompt_tok, comp_tok))
             json_line = (
@@ -2624,7 +2706,7 @@ async def openai_chat_completions_proxy(request: Request):
     # 4. Return a StreamingResponse backed by the generator
     return StreamingResponse(
         stream_ochat_response(),
-        media_type="application/json",
+        media_type="text/event-stream" if stream else "application/json",
     )
 
 # -------------------------------------------------------------
@@ -2712,20 +2794,36 @@ async def openai_completions_proxy(request: Request):
                 if chunk.choices:
                     if chunk.choices[0].finish_reason == None:
                         yield f"data: {data}\n\n".encode("utf-8")
+                    elif chunk.usage is not None:
+                        # Forward the usage-only final chunk (e.g. from llama-server)
+                        yield f"data: {data}\n\n".encode("utf-8")
+                prompt_tok = 0
+                comp_tok = 0
                 if chunk.usage is not None:
-                    prompt_tok = chunk.usage.prompt_tokens or 0
-                    comp_tok = chunk.usage.completion_tokens or 0
-                    if prompt_tok != 0 or comp_tok != 0:
-                        local_model = model
-                        if not is_ext_openai_endpoint(endpoint):
-                            if not ":" in model:
-                                local_model = model if ":" in model else model + ":latest"
-                        await token_queue.put((endpoint, local_model, prompt_tok, comp_tok))
+                    prompt_tok = chunk.usage.prompt_tokens or 0
+                    comp_tok = chunk.usage.completion_tokens or 0
+                else:
+                    llama_usage = rechunk.extract_usage_from_llama_timings(chunk)
+                    if llama_usage:
+                        prompt_tok, comp_tok = llama_usage
+                if prompt_tok != 0 or comp_tok != 0:
+                    local_model = model
+                    if not is_ext_openai_endpoint(endpoint):
+                        if not ":" in model:
+                            local_model = model if ":" in model else model + ":latest"
+                    await token_queue.put((endpoint, local_model, prompt_tok, comp_tok))
             # Final DONE event
             yield b"data: [DONE]\n\n"
         else:
-            prompt_tok = async_gen.usage.prompt_tokens or 0
-            comp_tok = async_gen.usage.completion_tokens or 0
+            prompt_tok = 0
+            comp_tok = 0
+            if async_gen.usage is not None:
+                prompt_tok = async_gen.usage.prompt_tokens or 0
+                comp_tok = async_gen.usage.completion_tokens or 0
+            else:
+                llama_usage = rechunk.extract_usage_from_llama_timings(async_gen)
+                if llama_usage:
+                    prompt_tok, comp_tok = llama_usage
             if prompt_tok != 0 or comp_tok != 0:
                 await token_queue.put((endpoint, model, prompt_tok, comp_tok))
             json_line = (
@@ -2742,7 +2840,7 @@ async def openai_completions_proxy(request: Request):
     # 4. Return a StreamingResponse backed by the generator
    return StreamingResponse(
        stream_ocompletions_response(),
-       media_type="application/json",
+       media_type="text/event-stream" if stream else "application/json",
    )
 
 # -------------------------------------------------------------
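
Note: a minimal, self-contained sketch of the timings-to-usage mapping that the new rechunk.extract_usage_from_llama_timings helper performs. The SimpleNamespace stand-in below is illustrative only (the router passes real OpenAI-client chunk/response objects); the sample values are taken from the docstring in the patch.

    from types import SimpleNamespace

    # Hypothetical chunk shaped like a llama-server response: no OpenAI `usage`,
    # but a `timings` dict with cached, processed and predicted token counts.
    chunk = SimpleNamespace(
        usage=None,
        timings={"cache_n": 236, "prompt_n": 1, "predicted_n": 35},
    )

    timings = getattr(chunk, "timings", None)
    if chunk.usage is None and isinstance(timings, dict):
        # prompt_tokens = prompt_n + cache_n, completion_tokens = predicted_n
        prompt_tokens = (timings.get("prompt_n", 0) or 0) + (timings.get("cache_n", 0) or 0)
        completion_tokens = timings.get("predicted_n", 0) or 0
        print(prompt_tokens, completion_tokens)  # 237 35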