feat: add reactive auto context-shift in openai endpoints to recover from out-of-context errors

This commit is contained in:
Alpha Nerd 2026-03-12 10:15:52 +01:00
parent 95c643109a
commit 9acc37951a
2 changed files with 86 additions and 0 deletions

View file

@ -78,6 +78,55 @@ def _mask_secrets(text: str) -> str:
text = re.sub(r"(?i)(api[-_ ]key\s*[:=]\s*)([^\s]+)", r"\1***redacted***", text)
return text
# ------------------------------------------------------------------
# Context-window sliding-window helpers
# ------------------------------------------------------------------
try:
import tiktoken as _tiktoken
_tiktoken_enc = _tiktoken.get_encoding("cl100k_base")
except Exception:
_tiktoken_enc = None
def _count_message_tokens(messages: list) -> int:
"""Approximate token count for a message list.
Uses tiktoken cl100k_base when available (within ~5-15% of llama tokenizers).
Falls back to char/4 heuristic if tiktoken is unavailable.
Formula follows OpenAI's per-message overhead: 4 tokens/message + content + 2 priming.
"""
if _tiktoken_enc is None:
return sum(len(str(m.get("content", ""))) for m in messages) // 4
total = 2 # priming tokens
for msg in messages:
total += 4 # per-message role/separator overhead
content = msg.get("content", "")
if isinstance(content, str):
total += len(_tiktoken_enc.encode(content))
elif isinstance(content, list):
for part in content:
if isinstance(part, dict) and part.get("type") == "text":
total += len(_tiktoken_enc.encode(part.get("text", "")))
return total
def _trim_messages_for_context(messages: list, n_ctx: int, safety_margin: int = 256) -> list:
    """Sliding-window trim — mirrors what llama.cpp context-shift used to do.

    Keeps all system messages and the most recent non-system messages that
    fit within (n_ctx - safety_margin) tokens. Oldest non-system messages
    are dropped first (FIFO). The last message is always preserved, even if
    it alone exceeds the budget.
    """
    budget = n_ctx - safety_margin
    keep_system: list = []
    conversation: list = []
    for msg in messages:
        (keep_system if msg.get("role") == "system" else conversation).append(msg)
    # FIFO eviction: discard the oldest non-system message until we fit,
    # never dropping the final message.
    while len(conversation) > 1 and _count_message_tokens(keep_system + conversation) > budget:
        del conversation[0]
    return keep_system + conversation
# ------------------------------------------------------------------
# Globals
# ------------------------------------------------------------------
@ -985,6 +1034,18 @@ async def _make_chat_request(model: str, messages: list, tools=None, stream: boo
start_ts = time.perf_counter()
try:
response = await oclient.chat.completions.create(**params)
except openai.BadRequestError as e:
if "exceed_context_size_error" in str(e) or "exceeds the available context size" in str(e):
err_body = getattr(e, "body", {}) or {}
err_detail = err_body.get("error", {}) if isinstance(err_body, dict) else {}
n_ctx_limit = err_detail.get("n_ctx", 0) or err_detail.get("n_prompt_tokens", 0)
if not n_ctx_limit:
raise
trimmed = _trim_messages_for_context(params.get("messages", []), n_ctx_limit)
print(f"[_make_chat_request] Context exceeded ({err_detail.get('n_prompt_tokens')}/{n_ctx_limit} tokens), dropped {len(params.get('messages', [])) - len(trimmed)} oldest message(s) and retrying")
response = await oclient.chat.completions.create(**{**params, "messages": trimmed})
else:
raise
except openai.InternalServerError as e:
if "image input is not supported" in str(e):
print(f"[_make_chat_request] Model {model} doesn't support images, retrying with text-only messages")
@ -1857,6 +1918,18 @@ async def chat_proxy(request: Request):
start_ts = time.perf_counter()
try:
async_gen = await oclient.chat.completions.create(**params)
except openai.BadRequestError as e:
if "exceed_context_size_error" in str(e) or "exceeds the available context size" in str(e):
err_body = getattr(e, "body", {}) or {}
err_detail = err_body.get("error", {}) if isinstance(err_body, dict) else {}
n_ctx_limit = err_detail.get("n_ctx", 0) or err_detail.get("n_prompt_tokens", 0)
if not n_ctx_limit:
raise
trimmed = _trim_messages_for_context(params.get("messages", []), n_ctx_limit)
print(f"[chat_proxy] Context exceeded ({err_detail.get('n_prompt_tokens')}/{n_ctx_limit} tokens), dropped {len(params.get('messages', [])) - len(trimmed)} oldest message(s) and retrying")
async_gen = await oclient.chat.completions.create(**{**params, "messages": trimmed})
else:
raise
except openai.InternalServerError as e:
if "image input is not supported" in str(e):
print(f"[chat_proxy] Model {model} doesn't support images, retrying with text-only messages")
@ -2918,6 +2991,18 @@ async def openai_chat_completions_proxy(request: Request):
print(f"[openai_chat_completions_proxy] Model {model} doesn't support tools, retrying without tools")
params_without_tools = {k: v for k, v in send_params.items() if k != "tools"}
async_gen = await oclient.chat.completions.create(**params_without_tools)
elif "exceed_context_size_error" in str(e) or "exceeds the available context size" in str(e):
# Backend context limit hit — apply sliding-window trim (context-shift at message level)
err_body = getattr(e, "body", {}) or {}
err_detail = err_body.get("error", {}) if isinstance(err_body, dict) else {}
n_ctx_limit = err_detail.get("n_ctx", 0) or err_detail.get("n_prompt_tokens", 0)
msgs = send_params.get("messages", [])
if not n_ctx_limit:
raise
trimmed_messages = _trim_messages_for_context(msgs, n_ctx_limit)
dropped = len(msgs) - len(trimmed_messages)
print(f"[openai_chat_completions_proxy] Context window exceeded ({err_detail.get('n_prompt_tokens')}/{n_ctx_limit} tokens), dropped {dropped} oldest message(s) and retrying")
async_gen = await oclient.chat.completions.create(**{**send_params, "messages": trimmed_messages})
else:
raise
except openai.InternalServerError as e: