fix: model name normalization for the context cache; preemptive context-shifting for smaller context windows after a previous failure
This commit is contained in:
parent
be60a348e1
commit
e416542bf8
1 changed files with 69 additions and 1 deletions
70
router.py
70
router.py
|
|
@ -142,6 +142,11 @@ def _trim_messages_for_context(
|
||||||
break
|
break
|
||||||
non_system.pop(0) # drop oldest non-system message
|
non_system.pop(0) # drop oldest non-system message
|
||||||
|
|
||||||
|
# Ensure the first non-system message is a user message (chat templates require it).
|
||||||
|
# Drop any leading assistant/tool messages that were left after trimming.
|
||||||
|
while non_system and non_system[0].get("role") != "user":
|
||||||
|
non_system.pop(0)
|
||||||
|
|
||||||
return system_msgs + non_system
|
return system_msgs + non_system
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -165,6 +170,14 @@ def _calibrated_trim_target(msgs: list, n_ctx: int, actual_tokens: int) -> int:
|
||||||
tiktoken_to_shed = int(to_shed * 1.2)
|
tiktoken_to_shed = int(to_shed * 1.2)
|
||||||
return max(1, cur_tiktoken - tiktoken_to_shed)
|
return max(1, cur_tiktoken - tiktoken_to_shed)
|
||||||
|
|
||||||
|
# Per-(endpoint, model) n_ctx cache.
|
||||||
|
# Populated from two sources:
|
||||||
|
# 1. 400 exceed_context_size_error body → n_ctx field
|
||||||
|
# 2. finish_reason/done_reason == "length" in streaming → prompt_tokens + completion_tokens
|
||||||
|
# Only used for proactive pre-trimming when n_ctx <= _CTX_TRIM_SMALL_LIMIT,
|
||||||
|
# so large-context models (200k+ for coding) are never touched.
|
||||||
|
_endpoint_nctx: dict[tuple[str, str], int] = {}
|
||||||
|
_CTX_TRIM_SMALL_LIMIT = 32768 # only proactively trim models with n_ctx at or below this
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Globals
|
# Globals
|
||||||
|
|
@ -1970,6 +1983,18 @@ async def chat_proxy(request: Request):
|
||||||
async_gen = None
|
async_gen = None
|
||||||
if use_openai:
|
if use_openai:
|
||||||
start_ts = time.perf_counter()
|
start_ts = time.perf_counter()
|
||||||
|
# Proactive trim: only for small-ctx models we've already seen run out of space
|
||||||
|
_lookup_model = _normalize_llama_model_name(model) if endpoint in config.llama_server_endpoints else model
|
||||||
|
_known_nctx = _endpoint_nctx.get((endpoint, _lookup_model))
|
||||||
|
if _known_nctx and _known_nctx <= _CTX_TRIM_SMALL_LIMIT:
|
||||||
|
_pre_target = int((_known_nctx - _known_nctx // 4) / 1.2)
|
||||||
|
_pre_est = _count_message_tokens(params.get("messages", []))
|
||||||
|
if _pre_est > _pre_target:
|
||||||
|
_pre_msgs = params.get("messages", [])
|
||||||
|
_pre_trimmed = _trim_messages_for_context(_pre_msgs, _known_nctx, target_tokens=_pre_target)
|
||||||
|
_dropped = len(_pre_msgs) - len(_pre_trimmed)
|
||||||
|
print(f"[ctx-pre] n_ctx={_known_nctx} est={_pre_est} target={_pre_target} dropped={_dropped}", flush=True)
|
||||||
|
params = {**params, "messages": _pre_trimmed}
|
||||||
try:
|
try:
|
||||||
async_gen = await oclient.chat.completions.create(**params)
|
async_gen = await oclient.chat.completions.create(**params)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -1983,6 +2008,8 @@ async def chat_proxy(request: Request):
|
||||||
if not n_ctx_limit:
|
if not n_ctx_limit:
|
||||||
await decrement_usage(endpoint, tracking_model)
|
await decrement_usage(endpoint, tracking_model)
|
||||||
raise
|
raise
|
||||||
|
if n_ctx_limit <= _CTX_TRIM_SMALL_LIMIT:
|
||||||
|
_endpoint_nctx[(endpoint, model)] = n_ctx_limit
|
||||||
msgs_to_trim = params.get("messages", [])
|
msgs_to_trim = params.get("messages", [])
|
||||||
cal_target = _calibrated_trim_target(msgs_to_trim, n_ctx_limit, actual_tokens)
|
cal_target = _calibrated_trim_target(msgs_to_trim, n_ctx_limit, actual_tokens)
|
||||||
trimmed = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target)
|
trimmed = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target)
|
||||||
|
|
@ -2048,6 +2075,20 @@ async def chat_proxy(request: Request):
|
||||||
# Accumulate and store cache on done chunk — before yield so it always runs
|
# Accumulate and store cache on done chunk — before yield so it always runs
|
||||||
# Works for both Ollama-native and OpenAI-compatible backends; chunks are
|
# Works for both Ollama-native and OpenAI-compatible backends; chunks are
|
||||||
# already converted to Ollama format by rechunk before this point.
|
# already converted to Ollama format by rechunk before this point.
|
||||||
|
if getattr(chunk, "done", False):
|
||||||
|
# Detect context exhaustion mid-generation for small-ctx models
|
||||||
|
_dr = getattr(chunk, "done_reason", None)
|
||||||
|
# Only cache when no max_tokens limit was set — otherwise
|
||||||
|
# finish_reason=length might just mean max_tokens was hit,
|
||||||
|
# not that the context window was exhausted.
|
||||||
|
_req_max_tok = params.get("max_tokens") or params.get("max_completion_tokens") or params.get("num_predict")
|
||||||
|
if _dr == "length" and not _req_max_tok:
|
||||||
|
_pt = getattr(chunk, "prompt_eval_count", 0) or 0
|
||||||
|
_ct = getattr(chunk, "eval_count", 0) or 0
|
||||||
|
_inferred_nctx = _pt + _ct
|
||||||
|
if 0 < _inferred_nctx <= _CTX_TRIM_SMALL_LIMIT:
|
||||||
|
_endpoint_nctx[(endpoint, model)] = _inferred_nctx
|
||||||
|
print(f"[ctx-cache] done_reason=length → cached n_ctx={_inferred_nctx} for ({endpoint},{model})", flush=True)
|
||||||
if _cache is not None and not _is_moe and _cache_enabled:
|
if _cache is not None and not _is_moe and _cache_enabled:
|
||||||
if chunk.message and getattr(chunk.message, "content", None):
|
if chunk.message and getattr(chunk.message, "content", None):
|
||||||
content_parts.append(chunk.message.content)
|
content_parts.append(chunk.message.content)
|
||||||
|
|
@ -2792,9 +2833,13 @@ async def ps_details_proxy(request: Request):
|
||||||
*[_fetch_llama_props(ep, mid) for ep, mid in props_requests]
|
*[_fetch_llama_props(ep, mid) for ep, mid in props_requests]
|
||||||
)
|
)
|
||||||
|
|
||||||
for model_dict, (n_ctx, is_sleeping) in zip(llama_models_pending, props_results):
|
for (ep, raw_id), model_dict, (n_ctx, is_sleeping) in zip(props_requests, llama_models_pending, props_results):
|
||||||
if n_ctx is not None:
|
if n_ctx is not None:
|
||||||
model_dict["context_length"] = n_ctx
|
model_dict["context_length"] = n_ctx
|
||||||
|
if 0 < n_ctx <= _CTX_TRIM_SMALL_LIMIT:
|
||||||
|
normalized = _normalize_llama_model_name(raw_id)
|
||||||
|
_endpoint_nctx[(ep, normalized)] = n_ctx
|
||||||
|
print(f"[ctx-cache/ps] cached n_ctx={n_ctx} for ({ep},{normalized})", flush=True)
|
||||||
if not is_sleeping:
|
if not is_sleeping:
|
||||||
models.append(model_dict)
|
models.append(model_dict)
|
||||||
|
|
||||||
|
|
@ -3063,6 +3108,18 @@ async def openai_chat_completions_proxy(request: Request):
|
||||||
if not is_ext_openai_endpoint(endpoint):
|
if not is_ext_openai_endpoint(endpoint):
|
||||||
resolved_msgs = await _normalize_images_in_messages(params.get("messages", []))
|
resolved_msgs = await _normalize_images_in_messages(params.get("messages", []))
|
||||||
send_params = {**params, "messages": resolved_msgs}
|
send_params = {**params, "messages": resolved_msgs}
|
||||||
|
# Proactive trim: only for small-ctx models we've already seen run out of space
|
||||||
|
_lookup_model = _normalize_llama_model_name(model) if endpoint in config.llama_server_endpoints else model
|
||||||
|
_known_nctx = _endpoint_nctx.get((endpoint, _lookup_model))
|
||||||
|
if _known_nctx and _known_nctx <= _CTX_TRIM_SMALL_LIMIT:
|
||||||
|
_pre_target = int(((_known_nctx - _known_nctx // 4)) / 1.2)
|
||||||
|
_pre_est = _count_message_tokens(send_params.get("messages", []))
|
||||||
|
if _pre_est > _pre_target:
|
||||||
|
_pre_msgs = send_params.get("messages", [])
|
||||||
|
_pre_trimmed = _trim_messages_for_context(_pre_msgs, _known_nctx, target_tokens=_pre_target)
|
||||||
|
_dropped = len(_pre_msgs) - len(_pre_trimmed)
|
||||||
|
print(f"[ctx-pre] n_ctx={_known_nctx} est={_pre_est} target={_pre_target} dropped={_dropped}", flush=True)
|
||||||
|
send_params = {**send_params, "messages": _pre_trimmed}
|
||||||
try:
|
try:
|
||||||
async_gen = await oclient.chat.completions.create(**send_params)
|
async_gen = await oclient.chat.completions.create(**send_params)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -3088,6 +3145,8 @@ async def openai_chat_completions_proxy(request: Request):
|
||||||
if not n_ctx_limit:
|
if not n_ctx_limit:
|
||||||
await decrement_usage(endpoint, tracking_model)
|
await decrement_usage(endpoint, tracking_model)
|
||||||
raise
|
raise
|
||||||
|
if n_ctx_limit <= _CTX_TRIM_SMALL_LIMIT:
|
||||||
|
_endpoint_nctx[(endpoint, model)] = n_ctx_limit
|
||||||
|
|
||||||
msgs_to_trim = send_params.get("messages", [])
|
msgs_to_trim = send_params.get("messages", [])
|
||||||
try:
|
try:
|
||||||
|
|
@ -3168,6 +3227,15 @@ async def openai_chat_completions_proxy(request: Request):
|
||||||
prompt_tok, comp_tok = llama_usage
|
prompt_tok, comp_tok = llama_usage
|
||||||
if prompt_tok != 0 or comp_tok != 0:
|
if prompt_tok != 0 or comp_tok != 0:
|
||||||
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
|
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
|
||||||
|
# Detect context exhaustion mid-generation for small-ctx models.
|
||||||
|
# Guard: skip if max_tokens was set in the request — finish_reason=length
|
||||||
|
# could just mean the caller's token budget was exhausted, not the context window.
|
||||||
|
_req_max_tok = send_params.get("max_tokens") or send_params.get("max_completion_tokens")
|
||||||
|
if chunk.choices and chunk.choices[0].finish_reason == "length" and not _req_max_tok:
|
||||||
|
_inferred_nctx = (prompt_tok + comp_tok) or 0
|
||||||
|
if 0 < _inferred_nctx <= _CTX_TRIM_SMALL_LIMIT:
|
||||||
|
_endpoint_nctx[(endpoint, model)] = _inferred_nctx
|
||||||
|
print(f"[ctx-cache] finish_reason=length → cached n_ctx={_inferred_nctx} for ({endpoint},{model})", flush=True)
|
||||||
# Cache assembled streaming response — before [DONE] so it always runs
|
# Cache assembled streaming response — before [DONE] so it always runs
|
||||||
if _cache is not None and _cache_enabled and content_parts:
|
if _cache is not None and _cache_enabled and content_parts:
|
||||||
assembled = orjson.dumps({
|
assembled = orjson.dumps({
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue