fix: model name normalization for context-cache preemptive context-shifting on small context windows after a previous failure

This commit is contained in:
Alpha Nerd 2026-03-12 16:08:01 +01:00
parent be60a348e1
commit e416542bf8

View file

@ -142,6 +142,11 @@ def _trim_messages_for_context(
break
non_system.pop(0) # drop oldest non-system message
# Ensure the first non-system message is a user message (chat templates require it).
# Drop any leading assistant/tool messages that were left after trimming.
while non_system and non_system[0].get("role") != "user":
non_system.pop(0)
return system_msgs + non_system
@ -165,6 +170,14 @@ def _calibrated_trim_target(msgs: list, n_ctx: int, actual_tokens: int) -> int:
tiktoken_to_shed = int(to_shed * 1.2)
return max(1, cur_tiktoken - tiktoken_to_shed)
# Per-(endpoint, model) cache of a backend's known context-window size (n_ctx).
# Populated from three observed signals:
#   1. an HTTP 400 exceed_context_size_error response body → its n_ctx field
#   2. finish_reason/done_reason == "length" on a streamed final chunk (when the
#      request set no max-tokens cap) → prompt_tokens + completion_tokens
#   3. llama-server /props polling during ps-details → reported n_ctx
# Consulted only for proactive pre-trimming when the cached n_ctx is at or below
# _CTX_TRIM_SMALL_LIMIT, so large-context models (200k+ for coding) are never touched.
# NOTE(review): keys mix raw and normalized model names depending on the writer —
# confirm all writers use the same normalization as the lookups.
_endpoint_nctx: dict[tuple[str, str], int] = {}
_CTX_TRIM_SMALL_LIMIT = 32768  # only proactively trim models with n_ctx at or below this
# ------------------------------------------------------------------
# Globals
@ -1970,6 +1983,18 @@ async def chat_proxy(request: Request):
async_gen = None
if use_openai:
start_ts = time.perf_counter()
# Proactive trim: only for small-ctx models we've already seen run out of space
_lookup_model = _normalize_llama_model_name(model) if endpoint in config.llama_server_endpoints else model
_known_nctx = _endpoint_nctx.get((endpoint, _lookup_model))
if _known_nctx and _known_nctx <= _CTX_TRIM_SMALL_LIMIT:
_pre_target = int((_known_nctx - _known_nctx // 4) / 1.2)
_pre_est = _count_message_tokens(params.get("messages", []))
if _pre_est > _pre_target:
_pre_msgs = params.get("messages", [])
_pre_trimmed = _trim_messages_for_context(_pre_msgs, _known_nctx, target_tokens=_pre_target)
_dropped = len(_pre_msgs) - len(_pre_trimmed)
print(f"[ctx-pre] n_ctx={_known_nctx} est={_pre_est} target={_pre_target} dropped={_dropped}", flush=True)
params = {**params, "messages": _pre_trimmed}
try:
async_gen = await oclient.chat.completions.create(**params)
except Exception as e:
@ -1983,6 +2008,8 @@ async def chat_proxy(request: Request):
if not n_ctx_limit:
await decrement_usage(endpoint, tracking_model)
raise
if n_ctx_limit <= _CTX_TRIM_SMALL_LIMIT:
_endpoint_nctx[(endpoint, model)] = n_ctx_limit
msgs_to_trim = params.get("messages", [])
cal_target = _calibrated_trim_target(msgs_to_trim, n_ctx_limit, actual_tokens)
trimmed = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target)
@ -2048,6 +2075,20 @@ async def chat_proxy(request: Request):
# Accumulate and store cache on done chunk — before yield so it always runs
# Works for both Ollama-native and OpenAI-compatible backends; chunks are
# already converted to Ollama format by rechunk before this point.
if getattr(chunk, "done", False):
# Detect context exhaustion mid-generation for small-ctx models
_dr = getattr(chunk, "done_reason", None)
# Only cache when no max_tokens limit was set — otherwise
# finish_reason=length might just mean max_tokens was hit,
# not that the context window was exhausted.
_req_max_tok = params.get("max_tokens") or params.get("max_completion_tokens") or params.get("num_predict")
if _dr == "length" and not _req_max_tok:
_pt = getattr(chunk, "prompt_eval_count", 0) or 0
_ct = getattr(chunk, "eval_count", 0) or 0
_inferred_nctx = _pt + _ct
if 0 < _inferred_nctx <= _CTX_TRIM_SMALL_LIMIT:
_endpoint_nctx[(endpoint, model)] = _inferred_nctx
print(f"[ctx-cache] done_reason=length → cached n_ctx={_inferred_nctx} for ({endpoint},{model})", flush=True)
if _cache is not None and not _is_moe and _cache_enabled:
if chunk.message and getattr(chunk.message, "content", None):
content_parts.append(chunk.message.content)
@ -2792,9 +2833,13 @@ async def ps_details_proxy(request: Request):
*[_fetch_llama_props(ep, mid) for ep, mid in props_requests]
)
for model_dict, (n_ctx, is_sleeping) in zip(llama_models_pending, props_results):
for (ep, raw_id), model_dict, (n_ctx, is_sleeping) in zip(props_requests, llama_models_pending, props_results):
if n_ctx is not None:
model_dict["context_length"] = n_ctx
if 0 < n_ctx <= _CTX_TRIM_SMALL_LIMIT:
normalized = _normalize_llama_model_name(raw_id)
_endpoint_nctx[(ep, normalized)] = n_ctx
print(f"[ctx-cache/ps] cached n_ctx={n_ctx} for ({ep},{normalized})", flush=True)
if not is_sleeping:
models.append(model_dict)
@ -3063,6 +3108,18 @@ async def openai_chat_completions_proxy(request: Request):
if not is_ext_openai_endpoint(endpoint):
resolved_msgs = await _normalize_images_in_messages(params.get("messages", []))
send_params = {**params, "messages": resolved_msgs}
# Proactive trim: only for small-ctx models we've already seen run out of space
_lookup_model = _normalize_llama_model_name(model) if endpoint in config.llama_server_endpoints else model
_known_nctx = _endpoint_nctx.get((endpoint, _lookup_model))
if _known_nctx and _known_nctx <= _CTX_TRIM_SMALL_LIMIT:
_pre_target = int(((_known_nctx - _known_nctx // 4)) / 1.2)
_pre_est = _count_message_tokens(send_params.get("messages", []))
if _pre_est > _pre_target:
_pre_msgs = send_params.get("messages", [])
_pre_trimmed = _trim_messages_for_context(_pre_msgs, _known_nctx, target_tokens=_pre_target)
_dropped = len(_pre_msgs) - len(_pre_trimmed)
print(f"[ctx-pre] n_ctx={_known_nctx} est={_pre_est} target={_pre_target} dropped={_dropped}", flush=True)
send_params = {**send_params, "messages": _pre_trimmed}
try:
async_gen = await oclient.chat.completions.create(**send_params)
except Exception as e:
@ -3088,6 +3145,8 @@ async def openai_chat_completions_proxy(request: Request):
if not n_ctx_limit:
await decrement_usage(endpoint, tracking_model)
raise
if n_ctx_limit <= _CTX_TRIM_SMALL_LIMIT:
_endpoint_nctx[(endpoint, model)] = n_ctx_limit
msgs_to_trim = send_params.get("messages", [])
try:
@ -3168,6 +3227,15 @@ async def openai_chat_completions_proxy(request: Request):
prompt_tok, comp_tok = llama_usage
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
# Detect context exhaustion mid-generation for small-ctx models.
# Guard: skip if max_tokens was set in the request — finish_reason=length
# could just mean the caller's token budget was exhausted, not the context window.
_req_max_tok = send_params.get("max_tokens") or send_params.get("max_completion_tokens")
if chunk.choices and chunk.choices[0].finish_reason == "length" and not _req_max_tok:
_inferred_nctx = (prompt_tok + comp_tok) or 0
if 0 < _inferred_nctx <= _CTX_TRIM_SMALL_LIMIT:
_endpoint_nctx[(endpoint, model)] = _inferred_nctx
print(f"[ctx-cache] finish_reason=length → cached n_ctx={_inferred_nctx} for ({endpoint},{model})", flush=True)
# Cache assembled streaming response — before [DONE] so it always runs
if _cache is not None and _cache_enabled and content_parts:
assembled = orjson.dumps({