From e416542bf8b293fd8d5e7b31846c2ee4bb107d64 Mon Sep 17 00:00:00 2001 From: alpha-nerd-nomyo Date: Thu, 12 Mar 2026 16:08:01 +0100 Subject: [PATCH] fix: model name normalization for context_cash preemptive context-shifting for smaller context-windows with previous failure --- router.py | 70 ++++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 69 insertions(+), 1 deletion(-) diff --git a/router.py b/router.py index f0e13ab..d11c2ad 100644 --- a/router.py +++ b/router.py @@ -142,6 +142,11 @@ def _trim_messages_for_context( break non_system.pop(0) # drop oldest non-system message + # Ensure the first non-system message is a user message (chat templates require it). + # Drop any leading assistant/tool messages that were left after trimming. + while non_system and non_system[0].get("role") != "user": + non_system.pop(0) + return system_msgs + non_system @@ -165,6 +170,14 @@ def _calibrated_trim_target(msgs: list, n_ctx: int, actual_tokens: int) -> int: tiktoken_to_shed = int(to_shed * 1.2) return max(1, cur_tiktoken - tiktoken_to_shed) +# Per-(endpoint, model) n_ctx cache. +# Populated from two sources: +# 1. 400 exceed_context_size_error body → n_ctx field +# 2. finish_reason/done_reason == "length" in streaming → prompt_tokens + completion_tokens +# Only used for proactive pre-trimming when n_ctx <= _CTX_TRIM_SMALL_LIMIT, +# so large-context models (200k+ for coding) are never touched. +_endpoint_nctx: dict[tuple[str, str], int] = {} +_CTX_TRIM_SMALL_LIMIT = 32768 # only proactively trim models with n_ctx at or below this # ------------------------------------------------------------------ # Globals @@ -1970,6 +1983,18 @@ async def chat_proxy(request: Request): async_gen = None if use_openai: start_ts = time.perf_counter() + # Proactive trim: only for small-ctx models we've already seen run out of space + _lookup_model = _normalize_llama_model_name(model) if endpoint in config.llama_server_endpoints else model + _known_nctx = _endpoint_nctx.get((endpoint, _lookup_model)) + if _known_nctx and _known_nctx <= _CTX_TRIM_SMALL_LIMIT: + _pre_target = int((_known_nctx - _known_nctx // 4) / 1.2) + _pre_est = _count_message_tokens(params.get("messages", [])) + if _pre_est > _pre_target: + _pre_msgs = params.get("messages", []) + _pre_trimmed = _trim_messages_for_context(_pre_msgs, _known_nctx, target_tokens=_pre_target) + _dropped = len(_pre_msgs) - len(_pre_trimmed) + print(f"[ctx-pre] n_ctx={_known_nctx} est={_pre_est} target={_pre_target} dropped={_dropped}", flush=True) + params = {**params, "messages": _pre_trimmed} try: async_gen = await oclient.chat.completions.create(**params) except Exception as e: @@ -1983,6 +2008,8 @@ async def chat_proxy(request: Request): if not n_ctx_limit: await decrement_usage(endpoint, tracking_model) raise + if n_ctx_limit <= _CTX_TRIM_SMALL_LIMIT: + _endpoint_nctx[(endpoint, model)] = n_ctx_limit msgs_to_trim = params.get("messages", []) cal_target = _calibrated_trim_target(msgs_to_trim, n_ctx_limit, actual_tokens) trimmed = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target) @@ -2048,6 +2075,20 @@ async def chat_proxy(request: Request): # Accumulate and store cache on done chunk — before yield so it always runs # Works for both Ollama-native and OpenAI-compatible backends; chunks are # already converted to Ollama format by rechunk before this point. + if getattr(chunk, "done", False): + # Detect context exhaustion mid-generation for small-ctx models + _dr = getattr(chunk, "done_reason", None) + # Only cache when no max_tokens limit was set — otherwise + # finish_reason=length might just mean max_tokens was hit, + # not that the context window was exhausted. + _req_max_tok = params.get("max_tokens") or params.get("max_completion_tokens") or params.get("num_predict") + if _dr == "length" and not _req_max_tok: + _pt = getattr(chunk, "prompt_eval_count", 0) or 0 + _ct = getattr(chunk, "eval_count", 0) or 0 + _inferred_nctx = _pt + _ct + if 0 < _inferred_nctx <= _CTX_TRIM_SMALL_LIMIT: + _endpoint_nctx[(endpoint, model)] = _inferred_nctx + print(f"[ctx-cache] done_reason=length → cached n_ctx={_inferred_nctx} for ({endpoint},{model})", flush=True) if _cache is not None and not _is_moe and _cache_enabled: if chunk.message and getattr(chunk.message, "content", None): content_parts.append(chunk.message.content) @@ -2792,9 +2833,13 @@ async def ps_details_proxy(request: Request): *[_fetch_llama_props(ep, mid) for ep, mid in props_requests] ) - for model_dict, (n_ctx, is_sleeping) in zip(llama_models_pending, props_results): + for (ep, raw_id), model_dict, (n_ctx, is_sleeping) in zip(props_requests, llama_models_pending, props_results): if n_ctx is not None: model_dict["context_length"] = n_ctx + if 0 < n_ctx <= _CTX_TRIM_SMALL_LIMIT: + normalized = _normalize_llama_model_name(raw_id) + _endpoint_nctx[(ep, normalized)] = n_ctx + print(f"[ctx-cache/ps] cached n_ctx={n_ctx} for ({ep},{normalized})", flush=True) if not is_sleeping: models.append(model_dict) @@ -3063,6 +3108,18 @@ async def openai_chat_completions_proxy(request: Request): if not is_ext_openai_endpoint(endpoint): resolved_msgs = await _normalize_images_in_messages(params.get("messages", [])) send_params = {**params, "messages": resolved_msgs} + # Proactive trim: only for small-ctx models we've already seen run out of space + _lookup_model = _normalize_llama_model_name(model) if endpoint in config.llama_server_endpoints else model + _known_nctx = _endpoint_nctx.get((endpoint, _lookup_model)) + if _known_nctx and _known_nctx <= _CTX_TRIM_SMALL_LIMIT: + _pre_target = int(((_known_nctx - _known_nctx // 4)) / 1.2) + _pre_est = _count_message_tokens(send_params.get("messages", [])) + if _pre_est > _pre_target: + _pre_msgs = send_params.get("messages", []) + _pre_trimmed = _trim_messages_for_context(_pre_msgs, _known_nctx, target_tokens=_pre_target) + _dropped = len(_pre_msgs) - len(_pre_trimmed) + print(f"[ctx-pre] n_ctx={_known_nctx} est={_pre_est} target={_pre_target} dropped={_dropped}", flush=True) + send_params = {**send_params, "messages": _pre_trimmed} try: async_gen = await oclient.chat.completions.create(**send_params) except Exception as e: @@ -3088,6 +3145,8 @@ async def openai_chat_completions_proxy(request: Request): if not n_ctx_limit: await decrement_usage(endpoint, tracking_model) raise + if n_ctx_limit <= _CTX_TRIM_SMALL_LIMIT: + _endpoint_nctx[(endpoint, model)] = n_ctx_limit msgs_to_trim = send_params.get("messages", []) try: @@ -3168,6 +3227,15 @@ async def openai_chat_completions_proxy(request: Request): prompt_tok, comp_tok = llama_usage if prompt_tok != 0 or comp_tok != 0: await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok)) + # Detect context exhaustion mid-generation for small-ctx models. + # Guard: skip if max_tokens was set in the request — finish_reason=length + # could just mean the caller's token budget was exhausted, not the context window. + _req_max_tok = send_params.get("max_tokens") or send_params.get("max_completion_tokens") + if chunk.choices and chunk.choices[0].finish_reason == "length" and not _req_max_tok: + _inferred_nctx = (prompt_tok + comp_tok) or 0 + if 0 < _inferred_nctx <= _CTX_TRIM_SMALL_LIMIT: + _endpoint_nctx[(endpoint, model)] = _inferred_nctx + print(f"[ctx-cache] finish_reason=length → cached n_ctx={_inferred_nctx} for ({endpoint},{model})", flush=True) # Cache assembled streaming response — before [DONE] so it always runs if _cache is not None and _cache_enabled and content_parts: assembled = orjson.dumps({