diff --git a/router.py b/router.py index 8e5c2ed..1ef096a 100644 --- a/router.py +++ b/router.py @@ -53,6 +53,9 @@ _loaded_error_cache_lock = asyncio.Lock() _inflight_available_models: dict[str, asyncio.Task] = {} _inflight_loaded_models: dict[str, asyncio.Task] = {} _inflight_lock = asyncio.Lock() +_bg_refresh_available: dict[str, asyncio.Task] = {} +_bg_refresh_loaded: dict[str, asyncio.Task] = {} +_bg_refresh_lock = asyncio.Lock() # ------------------------------------------------------------------ # Queues @@ -414,6 +417,30 @@ def is_openai_compatible(endpoint: str) -> bool: """ return "/v1" in endpoint or endpoint in config.llama_server_endpoints +def get_tracking_model(endpoint: str, model: str) -> str: + """ + Normalize model name for tracking purposes so it matches the PS table key. + + - For llama-server endpoints: strips HF prefix and quantization suffix + - For Ollama endpoints: appends ":latest" if no version suffix is present + - For external OpenAI endpoints: returns as-is (not shown in PS) + + This ensures consistent model naming across all routes for usage tracking. + """ + # External OpenAI endpoints are not shown in PS, keep as-is + if is_ext_openai_endpoint(endpoint): + return model + + # llama-server endpoints use normalized names in PS + if endpoint in config.llama_server_endpoints: + return _normalize_llama_model_name(model) + + # Ollama endpoints: append ":latest" if no version suffix + if ":" not in model: + return model + ":latest" + + return model + async def token_worker() -> None: try: while True: @@ -605,12 +632,23 @@ class fetch: """ Background task to refresh available models cache without blocking the caller. Used for stale-while-revalidate pattern. + Deduplicates: only one background refresh runs per endpoint at a time. """ + async with _bg_refresh_lock: + if endpoint in _bg_refresh_available and not _bg_refresh_available[endpoint].done(): + return # A refresh is already running for this endpoint + task = asyncio.create_task(fetch._fetch_available_models_internal(endpoint, api_key)) + _bg_refresh_available[endpoint] = task + try: - await fetch._fetch_available_models_internal(endpoint, api_key) + await task except Exception as e: # Silently fail - cache will remain stale but functional print(f"[fetch._refresh_available_models] Background refresh failed for {endpoint}: {e}") + finally: + async with _bg_refresh_lock: + if _bg_refresh_available.get(endpoint) is task: + _bg_refresh_available.pop(endpoint, None) async def available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]: """ @@ -634,7 +672,7 @@ class fetch: if endpoint in _models_cache: models, cached_at = _models_cache[endpoint] - # FRESH: < 300s old - return immediately + # FRESH: <= 300s old - return immediately if _is_fresh(cached_at, 300): return models @@ -649,7 +687,7 @@ class fetch: # Check error cache with lock protection async with _available_error_cache_lock: if endpoint in _available_error_cache: - if _is_fresh(_available_error_cache[endpoint], 30): + if _is_fresh(_available_error_cache[endpoint], 300): # Still within the short error TTL – pretend nothing is available return set() # Error expired – remove it @@ -735,12 +773,23 @@ class fetch: """ Background task to refresh loaded models cache without blocking the caller. Used for stale-while-revalidate pattern. + Deduplicates: only one background refresh runs per endpoint at a time. """ + async with _bg_refresh_lock: + if endpoint in _bg_refresh_loaded and not _bg_refresh_loaded[endpoint].done(): + return # A refresh is already running for this endpoint + task = asyncio.create_task(fetch._fetch_loaded_models_internal(endpoint)) + _bg_refresh_loaded[endpoint] = task + try: - await fetch._fetch_loaded_models_internal(endpoint) + await task except Exception as e: # Silently fail - cache will remain stale but functional print(f"[fetch._refresh_loaded_models] Background refresh failed for {endpoint}: {e}") + finally: + async with _bg_refresh_lock: + if _bg_refresh_loaded.get(endpoint) is task: + _bg_refresh_loaded.pop(endpoint, None) async def loaded_models(endpoint: str) -> Set[str]: """ @@ -760,7 +809,7 @@ class fetch: models, cached_at = _loaded_models_cache[endpoint] # FRESH: < 10s old - return immediately - if _is_fresh(cached_at, 30): + if _is_fresh(cached_at, 10): return models # STALE: 10-60s old - return stale data and refresh in background @@ -775,7 +824,7 @@ class fetch: # Check error cache with lock protection async with _loaded_error_cache_lock: if endpoint in _loaded_error_cache: - if _is_fresh(_loaded_error_cache[endpoint], 30): + if _is_fresh(_loaded_error_cache[endpoint], 300): return set() # Error expired - remove it del _loaded_error_cache[endpoint] @@ -813,7 +862,7 @@ class fetch: if not skip_error_cache: async with _available_error_cache_lock: if endpoint in _available_error_cache: - if _is_fresh(_available_error_cache[endpoint], 30): + if _is_fresh(_available_error_cache[endpoint], 300): return [] client: aiohttp.ClientSession = app_state["session"] @@ -910,7 +959,9 @@ async def _make_chat_request(endpoint: str, model: str, messages: list, tools=No else: client = ollama.AsyncClient(host=endpoint) - await increment_usage(endpoint, model) + # Normalize model name for tracking so it matches the PS table key + tracking_model = get_tracking_model(endpoint, model) + await increment_usage(endpoint, tracking_model) try: if use_openai: @@ -923,11 +974,17 @@ async def _make_chat_request(endpoint: str, model: str, messages: list, tools=No async for chunk in response: chunks.append(chunk) _accumulate_openai_tc_delta(chunk, tc_acc) + prompt_tok = 0 + comp_tok = 0 if chunk.usage is not None: prompt_tok = chunk.usage.prompt_tokens or 0 comp_tok = chunk.usage.completion_tokens or 0 - if prompt_tok != 0 or comp_tok != 0: - await token_queue.put((endpoint, model, prompt_tok, comp_tok)) + else: + llama_usage = rechunk.extract_usage_from_llama_timings(chunk) + if llama_usage: + prompt_tok, comp_tok = llama_usage + if prompt_tok != 0 or comp_tok != 0: + await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok)) # Convert to Ollama format if chunks: response = rechunk.openai_chat_completion2ollama(chunks[-1], stream, start_ts) @@ -935,10 +992,17 @@ async def _make_chat_request(endpoint: str, model: str, messages: list, tools=No if tc_acc and response.message: response.message.tool_calls = _build_ollama_tool_calls(tc_acc) else: - prompt_tok = response.usage.prompt_tokens or 0 - comp_tok = response.usage.completion_tokens or 0 + prompt_tok = 0 + comp_tok = 0 + if response.usage is not None: + prompt_tok = response.usage.prompt_tokens or 0 + comp_tok = response.usage.completion_tokens or 0 + else: + llama_usage = rechunk.extract_usage_from_llama_timings(response) + if llama_usage: + prompt_tok, comp_tok = llama_usage if prompt_tok != 0 or comp_tok != 0: - await token_queue.put((endpoint, model, prompt_tok, comp_tok)) + await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok)) response = rechunk.openai_chat_completion2ollama(response, stream, start_ts) else: response = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=format, options=options, keep_alive=keep_alive) @@ -950,18 +1014,18 @@ async def _make_chat_request(endpoint: str, model: str, messages: list, tools=No prompt_tok = chunk.prompt_eval_count or 0 comp_tok = chunk.eval_count or 0 if prompt_tok != 0 or comp_tok != 0: - await token_queue.put((endpoint, model, prompt_tok, comp_tok)) + await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok)) if chunks: response = chunks[-1] else: prompt_tok = response.prompt_eval_count or 0 comp_tok = response.eval_count or 0 if prompt_tok != 0 or comp_tok != 0: - await token_queue.put((endpoint, model, prompt_tok, comp_tok)) + await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok)) return response finally: - await decrement_usage(endpoint, model) + await decrement_usage(endpoint, tracking_model) def get_last_user_content(messages): """ @@ -1317,6 +1381,34 @@ class rechunk: eval_duration=None, embeddings=[chunk.data[0].embedding]) return rechunk + + def extract_usage_from_llama_timings(obj) -> tuple[int, int] | None: + """Extract (prompt_tokens, completion_tokens) from llama-server's timings object. + + llama-server returns a ``timings`` dict instead of the standard OpenAI + ``usage`` field:: + + "timings": { + "cache_n": 236, // prompt tokens reused from cache + "prompt_n": 1, // prompt tokens processed + "predicted_n": 35 // predicted (completion) tokens + } + + prompt_tokens = prompt_n + cache_n + completion_tokens = predicted_n + + Returns ``(prompt_tokens, completion_tokens)`` or ``None`` when no + timings are found. + """ + timings = getattr(obj, "timings", None) + if timings is None: + return None + if isinstance(timings, dict): + prompt_n = timings.get("prompt_n", 0) or 0 + cache_n = timings.get("cache_n", 0) or 0 + predicted_n = timings.get("predicted_n", 0) or 0 + return (prompt_n + cache_n, predicted_n) + return None # ------------------------------------------------------------------ # SSE Helpser @@ -1433,27 +1525,26 @@ async def choose_endpoint(model: str) -> str: # Protect all reads of usage_counts with the lock async with usage_lock: - # Helper: get current usage count for (endpoint, model) - def current_usage(ep: str) -> int: - return usage_counts.get(ep, {}).get(model, 0) + # Helper: current usage for (endpoint, model) using the same normalized key + # that increment_usage/decrement_usage store — raw model names differ from + # tracking names for llama-server (HF prefix / quant suffix stripped). + def tracking_usage(ep: str) -> int: + return usage_counts.get(ep, {}).get(get_tracking_model(ep, model), 0) # 3️⃣ Endpoints that have the model loaded *and* a free slot loaded_and_free = [ ep for ep, models in zip(candidate_endpoints, loaded_sets) - if model in models and usage_counts.get(ep, {}).get(model, 0) < config.max_concurrent_connections + if model in models and tracking_usage(ep) < config.max_concurrent_connections ] if loaded_and_free: - # Sort by per-model usage in DESCENDING order to ensure model affinity - # Endpoints with higher usage (already handling this model) should be preferred - # until they reach max_concurrent_connections - loaded_and_free.sort( - key=lambda ep: -usage_counts.get(ep, {}).get(model, 0) # Negative for descending order - ) + # Sort ascending for load balancing — all endpoints here already have the + # model loaded, so there is no model-switching cost to optimise for. + loaded_and_free.sort(key=tracking_usage) - # If all endpoints have zero usage for this model, randomize to distribute - # different models across different endpoints for better resource utilization - if all(usage_counts.get(ep, {}).get(model, 0) == 0 for ep in loaded_and_free): + # When all candidates are equally idle, randomise to avoid always picking + # the first entry in a stable sort. + if all(tracking_usage(ep) == 0 for ep in loaded_and_free): return random.choice(loaded_and_free) return loaded_and_free[0] @@ -1461,30 +1552,22 @@ async def choose_endpoint(model: str) -> str: # 4️⃣ Endpoints among the candidates that simply have a free slot endpoints_with_free_slot = [ ep for ep in candidate_endpoints - if usage_counts.get(ep, {}).get(model, 0) < config.max_concurrent_connections + if tracking_usage(ep) < config.max_concurrent_connections ] if endpoints_with_free_slot: - # Sort by per-model usage (descending) first to ensure model affinity - # Even if the model isn't showing as "loaded" in /api/ps yet (e.g., during initial loading), - # we want to send subsequent requests to the endpoint that already has connections for this model - # Then by total endpoint usage (ascending) to balance idle endpoints + # Sort by total endpoint load (ascending) to prefer idle endpoints. endpoints_with_free_slot.sort( - key=lambda ep: ( - #-usage_counts.get(ep, {}).get(model, 0), # Primary: per-model usage (descending - prefer endpoints with connections) - sum(usage_counts.get(ep, {}).values()) # Secondary: total endpoint usage (ascending - prefer idle endpoints) - ) + key=lambda ep: sum(usage_counts.get(ep, {}).values()) ) - # If all endpoints have zero usage for this specific model, randomize to distribute - # different models across different endpoints for better resource utilization - if all(usage_counts.get(ep, {}).get(model, 0) == 0 for ep in endpoints_with_free_slot): + if all(tracking_usage(ep) == 0 for ep in endpoints_with_free_slot): return random.choice(endpoints_with_free_slot) return endpoints_with_free_slot[0] - # 5️⃣ All candidate endpoints are saturated – pick one with lowest usages count (will queue) - ep = min(candidate_endpoints, key=current_usage) + # 5️⃣ All candidate endpoints are saturated – pick the least-busy one (will queue) + ep = min(candidate_endpoints, key=tracking_usage) return ep # ------------------------------------------------------------- @@ -1528,6 +1611,8 @@ async def proxy(request: Request): endpoint = await choose_endpoint(model) use_openai = is_openai_compatible(endpoint) + # Normalize model name for tracking so it matches the PS table key + tracking_model = get_tracking_model(endpoint, model) if use_openai: if ":latest" in model: model = model.split(":latest") @@ -1552,7 +1637,7 @@ async def proxy(request: Request): oclient = openai.AsyncOpenAI(base_url=ep2base(endpoint), default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) else: client = ollama.AsyncClient(host=endpoint) - await increment_usage(endpoint, model) + await increment_usage(endpoint, tracking_model) # 4. Async generator that streams data and decrements the counter async def stream_generate_response(): @@ -1569,7 +1654,7 @@ async def proxy(request: Request): prompt_tok = chunk.prompt_eval_count or 0 comp_tok = chunk.eval_count or 0 if prompt_tok != 0 or comp_tok != 0: - await token_queue.put((endpoint, model, prompt_tok, comp_tok)) + await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok)) if hasattr(chunk, "model_dump_json"): json_line = chunk.model_dump_json() else: @@ -1584,7 +1669,7 @@ async def proxy(request: Request): prompt_tok = async_gen.prompt_eval_count or 0 comp_tok = async_gen.eval_count or 0 if prompt_tok != 0 or comp_tok != 0: - await token_queue.put((endpoint, model, prompt_tok, comp_tok)) + await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok)) json_line = ( response if hasattr(async_gen, "model_dump_json") @@ -1594,7 +1679,7 @@ async def proxy(request: Request): finally: # Ensure counter is decremented even if an exception occurs - await decrement_usage(endpoint, model) + await decrement_usage(endpoint, tracking_model) # 5. Return a StreamingResponse backed by the generator return StreamingResponse( @@ -1649,6 +1734,8 @@ async def chat_proxy(request: Request): opt = False endpoint = await choose_endpoint(model) use_openai = is_openai_compatible(endpoint) + # Normalize model name for tracking so it matches the PS table key + tracking_model = get_tracking_model(endpoint, model) if use_openai: if ":latest" in model: model = model.split(":latest") @@ -1679,7 +1766,7 @@ async def chat_proxy(request: Request): oclient = openai.AsyncOpenAI(base_url=ep2base(endpoint), default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) else: client = ollama.AsyncClient(host=endpoint) - await increment_usage(endpoint, model) + await increment_usage(endpoint, tracking_model) # 3. Async generator that streams chat data and decrements the counter async def stream_chat_response(): try: @@ -1706,7 +1793,7 @@ async def chat_proxy(request: Request): prompt_tok = chunk.prompt_eval_count or 0 comp_tok = chunk.eval_count or 0 if prompt_tok != 0 or comp_tok != 0: - await token_queue.put((endpoint, model, prompt_tok, comp_tok)) + await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok)) if hasattr(chunk, "model_dump_json"): json_line = chunk.model_dump_json() else: @@ -1721,7 +1808,7 @@ async def chat_proxy(request: Request): prompt_tok = async_gen.prompt_eval_count or 0 comp_tok = async_gen.eval_count or 0 if prompt_tok != 0 or comp_tok != 0: - await token_queue.put((endpoint, model, prompt_tok, comp_tok)) + await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok)) json_line = ( response if hasattr(async_gen, "model_dump_json") @@ -1731,7 +1818,7 @@ async def chat_proxy(request: Request): finally: # Ensure counter is decremented even if an exception occurs - await decrement_usage(endpoint, model) + await decrement_usage(endpoint, tracking_model) # 4. Return a StreamingResponse backed by the generator media_type = "application/x-ndjson" if stream else "application/json" @@ -1773,6 +1860,8 @@ async def embedding_proxy(request: Request): # 2. Endpoint logic endpoint = await choose_endpoint(model) use_openai = is_openai_compatible(endpoint) + # Normalize model name for tracking so it matches the PS table key + tracking_model = get_tracking_model(endpoint, model) if use_openai: if ":latest" in model: model = model.split(":latest") @@ -1780,7 +1869,7 @@ async def embedding_proxy(request: Request): client = openai.AsyncOpenAI(base_url=ep2base(endpoint), api_key=config.api_keys.get(endpoint, "no-key")) else: client = ollama.AsyncClient(host=endpoint) - await increment_usage(endpoint, model) + await increment_usage(endpoint, tracking_model) # 3. Async generator that streams embedding data and decrements the counter async def stream_embedding_response(): try: @@ -1797,7 +1886,7 @@ async def embedding_proxy(request: Request): yield json_line.encode("utf-8") + b"\n" finally: # Ensure counter is decremented even if an exception occurs - await decrement_usage(endpoint, model) + await decrement_usage(endpoint, tracking_model) # 5. Return a StreamingResponse backed by the generator return StreamingResponse( @@ -1839,6 +1928,8 @@ async def embed_proxy(request: Request): # 2. Endpoint logic endpoint = await choose_endpoint(model) use_openai = is_openai_compatible(endpoint) + # Normalize model name for tracking so it matches the PS table key + tracking_model = get_tracking_model(endpoint, model) if use_openai: if ":latest" in model: model = model.split(":latest") @@ -1846,7 +1937,7 @@ async def embed_proxy(request: Request): client = openai.AsyncOpenAI(base_url=ep2base(endpoint), api_key=config.api_keys.get(endpoint, "no-key")) else: client = ollama.AsyncClient(host=endpoint) - await increment_usage(endpoint, model) + await increment_usage(endpoint, tracking_model) # 3. Async generator that streams embed data and decrements the counter async def stream_embedding_response(): try: @@ -1863,7 +1954,7 @@ async def embed_proxy(request: Request): yield json_line.encode("utf-8") + b"\n" finally: # Ensure counter is decremented even if an exception occurs - await decrement_usage(endpoint, model) + await decrement_usage(endpoint, tracking_model) # 4. Return a StreamingResponse backed by the generator return StreamingResponse( @@ -2226,7 +2317,13 @@ async def version_proxy(request: Request): """ # 1. Query all endpoints for version tasks = [fetch.endpoint_details(ep, "/api/version", "version") for ep in config.endpoints if "/v1" not in ep] - all_versions = await asyncio.gather(*tasks) + all_versions_raw = await asyncio.gather(*tasks) + + # Filter out non-string values (e.g., empty lists from failed/timeout responses) + all_versions = [v for v in all_versions_raw if isinstance(v, str) and v] + + if not all_versions: + raise HTTPException(status_code=503, detail="No valid version response from any endpoint") def version_key(v): return tuple(map(int, v.split('.'))) @@ -2364,6 +2461,10 @@ async def ps_details_proxy(request: Request): # Add llama-server models with endpoint info and full status metadata (if any) if llama_loaded: + # Collect (endpoint, raw_id) pairs to fetch /props in parallel + props_requests: list[tuple[str, str]] = [] + llama_models_pending: list[dict] = [] + for (endpoint, modellist) in zip([ep for ep, _ in llama_tasks], llama_loaded): # Filter for loaded models only loaded_models = [item for item in modellist if _is_llama_model_loaded(item)] @@ -2388,7 +2489,53 @@ async def ps_details_proxy(request: Request): if isinstance(status_info, dict): model_with_endpoint["llama_status_args"] = status_info.get("args") model_with_endpoint["llama_status_preset"] = status_info.get("preset") - models.append(model_with_endpoint) + llama_models_pending.append(model_with_endpoint) + props_requests.append((endpoint, raw_id)) + + # Fetch /props for each llama-server model to get context length (n_ctx) + # and unload sleeping models automatically + async def _fetch_llama_props(endpoint: str, model_id: str) -> tuple[int | None, bool]: + client: aiohttp.ClientSession = app_state["session"] + base_url = endpoint.rstrip("/").removesuffix("/v1") + props_url = f"{base_url}/props?model={model_id}" + headers = None + api_key = config.api_keys.get(endpoint) + if api_key: + headers = {"Authorization": f"Bearer {api_key}"} + try: + async with client.get(props_url, headers=headers) as resp: + if resp.status == 200: + data = await resp.json() + dgs = data.get("default_generation_settings", {}) + n_ctx = dgs.get("n_ctx") + is_sleeping = data.get("is_sleeping", False) + + if is_sleeping: + unload_url = f"{base_url}/models/unload" + try: + async with client.post( + unload_url, + json={"model": model_id}, + headers=headers, + ) as unload_resp: + print(f"[ps_details] Unloaded sleeping model {model_id} from {endpoint}: {unload_resp.status}") + except Exception as ue: + print(f"[ps_details] Failed to unload sleeping model {model_id} from {endpoint}: {ue}") + + return n_ctx, is_sleeping + except Exception as e: + print(f"[ps_details] Failed to fetch props from {props_url}: {e}") + return None, False + + props_results = await asyncio.gather( + *[_fetch_llama_props(ep, mid) for ep, mid in props_requests] + ) + + for model_dict, (n_ctx, is_sleeping) in zip(llama_models_pending, props_results): + if n_ctx is not None: + model_dict["context_length"] = n_ctx + if not is_sleeping: + models.append(model_dict) return JSONResponse(content={"models": models}, status_code=200) @@ -2479,7 +2626,9 @@ async def openai_embedding_proxy(request: Request): # 2. Endpoint logic endpoint = await choose_endpoint(model) - await increment_usage(endpoint, model) + # Normalize model name for tracking so it matches the PS table key + tracking_model = get_tracking_model(endpoint, model) + await increment_usage(endpoint, tracking_model) if is_openai_compatible(endpoint): api_key = config.api_keys.get(endpoint, "no-key") else: @@ -2490,8 +2639,8 @@ async def openai_embedding_proxy(request: Request): # 3. Async generator that streams embedding data and decrements the counter async_gen = await oclient.embeddings.create(input=doc, model=model) - - await decrement_usage(endpoint, model) + + await decrement_usage(endpoint, tracking_model) # 5. Return a StreamingResponse backed by the generator return async_gen @@ -2568,7 +2717,9 @@ async def openai_chat_completions_proxy(request: Request): # 2. Endpoint logic endpoint = await choose_endpoint(model) - await increment_usage(endpoint, model) + # Normalize model name for tracking so it matches the PS table key + tracking_model = get_tracking_model(endpoint, model) + await increment_usage(endpoint, tracking_model) base_url = ep2base(endpoint) oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) # 3. Async generator that streams completions data and decrements the counter @@ -2593,23 +2744,42 @@ async def openai_chat_completions_proxy(request: Request): else orjson.dumps(chunk) ) if chunk.choices: - if chunk.choices[0].delta.content is not None: + delta = chunk.choices[0].delta + has_content = delta.content is not None + has_reasoning = ( + getattr(delta, "reasoning_content", None) is not None + or getattr(delta, "reasoning", None) is not None + ) + has_tool_calls = getattr(delta, "tool_calls", None) is not None + if has_content or has_reasoning or has_tool_calls: yield f"data: {data}\n\n".encode("utf-8") + elif chunk.usage is not None: + # Forward the usage-only final chunk (e.g. from llama-server) + yield f"data: {data}\n\n".encode("utf-8") + prompt_tok = 0 + comp_tok = 0 if chunk.usage is not None: prompt_tok = chunk.usage.prompt_tokens or 0 comp_tok = chunk.usage.completion_tokens or 0 - if prompt_tok != 0 or comp_tok != 0: - local_model = model - if not is_ext_openai_endpoint(endpoint): - if not ":" in model: - local_model = model if ":" in model else model + ":latest" - await token_queue.put((endpoint, local_model, prompt_tok, comp_tok)) + else: + llama_usage = rechunk.extract_usage_from_llama_timings(chunk) + if llama_usage: + prompt_tok, comp_tok = llama_usage + if prompt_tok != 0 or comp_tok != 0: + await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok)) yield b"data: [DONE]\n\n" else: - prompt_tok = async_gen.usage.prompt_tokens or 0 - comp_tok = async_gen.usage.completion_tokens or 0 + prompt_tok = 0 + comp_tok = 0 + if async_gen.usage is not None: + prompt_tok = async_gen.usage.prompt_tokens or 0 + comp_tok = async_gen.usage.completion_tokens or 0 + else: + llama_usage = rechunk.extract_usage_from_llama_timings(async_gen) + if llama_usage: + prompt_tok, comp_tok = llama_usage if prompt_tok != 0 or comp_tok != 0: - await token_queue.put((endpoint, model, prompt_tok, comp_tok)) + await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok)) json_line = ( async_gen.model_dump_json() if hasattr(async_gen, "model_dump_json") @@ -2619,12 +2789,12 @@ async def openai_chat_completions_proxy(request: Request): finally: # Ensure counter is decremented even if an exception occurs - await decrement_usage(endpoint, model) + await decrement_usage(endpoint, tracking_model) # 4. Return a StreamingResponse backed by the generator return StreamingResponse( stream_ochat_response(), - media_type="application/json", + media_type="text/event-stream" if stream else "application/json", ) # ------------------------------------------------------------- @@ -2693,7 +2863,9 @@ async def openai_completions_proxy(request: Request): # 2. Endpoint logic endpoint = await choose_endpoint(model) - await increment_usage(endpoint, model) + # Normalize model name for tracking so it matches the PS table key + tracking_model = get_tracking_model(endpoint, model) + await increment_usage(endpoint, tracking_model) base_url = ep2base(endpoint) oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) @@ -2710,24 +2882,42 @@ async def openai_completions_proxy(request: Request): else orjson.dumps(chunk) ) if chunk.choices: - if chunk.choices[0].finish_reason == None: + choice = chunk.choices[0] + has_text = getattr(choice, "text", None) is not None + has_reasoning = ( + getattr(choice, "reasoning_content", None) is not None + or getattr(choice, "reasoning", None) is not None + ) + if has_text or has_reasoning or choice.finish_reason is not None: yield f"data: {data}\n\n".encode("utf-8") + elif chunk.usage is not None: + # Forward the usage-only final chunk (e.g. from llama-server) + yield f"data: {data}\n\n".encode("utf-8") + prompt_tok = 0 + comp_tok = 0 if chunk.usage is not None: - prompt_tok = chunk.usage.prompt_tokens or 0 - comp_tok = chunk.usage.completion_tokens or 0 - if prompt_tok != 0 or comp_tok != 0: - local_model = model - if not is_ext_openai_endpoint(endpoint): - if not ":" in model: - local_model = model if ":" in model else model + ":latest" - await token_queue.put((endpoint, local_model, prompt_tok, comp_tok)) + prompt_tok = chunk.usage.prompt_tokens or 0 + comp_tok = chunk.usage.completion_tokens or 0 + else: + llama_usage = rechunk.extract_usage_from_llama_timings(chunk) + if llama_usage: + prompt_tok, comp_tok = llama_usage + if prompt_tok != 0 or comp_tok != 0: + await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok)) # Final DONE event yield b"data: [DONE]\n\n" else: - prompt_tok = async_gen.usage.prompt_tokens or 0 - comp_tok = async_gen.usage.completion_tokens or 0 + prompt_tok = 0 + comp_tok = 0 + if async_gen.usage is not None: + prompt_tok = async_gen.usage.prompt_tokens or 0 + comp_tok = async_gen.usage.completion_tokens or 0 + else: + llama_usage = rechunk.extract_usage_from_llama_timings(async_gen) + if llama_usage: + prompt_tok, comp_tok = llama_usage if prompt_tok != 0 or comp_tok != 0: - await token_queue.put((endpoint, model, prompt_tok, comp_tok)) + await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok)) json_line = ( async_gen.model_dump_json() if hasattr(async_gen, "model_dump_json") @@ -2737,12 +2927,12 @@ async def openai_completions_proxy(request: Request): finally: # Ensure counter is decremented even if an exception occurs - await decrement_usage(endpoint, model) + await decrement_usage(endpoint, tracking_model) # 4. Return a StreamingResponse backed by the generator return StreamingResponse( stream_ocompletions_response(), - media_type="application/json", + media_type="text/event-stream" if stream else "application/json", ) # ------------------------------------------------------------- diff --git a/static/index.html b/static/index.html index 3d4d364..fe14ef5 100644 --- a/static/index.html +++ b/static/index.html @@ -928,7 +928,7 @@ function renderTimeSeriesChart(timeSeriesData, chart, minutes) { const uniqueEndpoints = Array.from(new Set(endpoints)); const endpointsData = encodeURIComponent(JSON.stringify(uniqueEndpoints)); return `