feat: add reactive auto context-shift in openai endpoints to recover from out-of-context errors
This commit is contained in:
parent
95c643109a
commit
9acc37951a
2 changed files with 86 additions and 0 deletions
85
router.py
85
router.py
|
|
@ -78,6 +78,55 @@ def _mask_secrets(text: str) -> str:
|
|||
text = re.sub(r"(?i)(api[-_ ]key\s*[:=]\s*)([^\s]+)", r"\1***redacted***", text)
|
||||
return text
|
||||
|
||||
# ------------------------------------------------------------------
# Context-window sliding-window helpers
# ------------------------------------------------------------------
# Optional tokenizer: cl100k_base approximates llama-family tokenizers
# closely enough for budgeting; without tiktoken we use a char heuristic.
_tiktoken_enc = None
try:
    import tiktoken as _tiktoken
    _tiktoken_enc = _tiktoken.get_encoding("cl100k_base")
except Exception:
    pass


def _count_message_tokens(messages: list) -> int:
    """Approximate token count for a message list.

    Uses tiktoken cl100k_base when available (within ~5-15% of llama tokenizers).
    Falls back to a chars/4 heuristic if tiktoken is unavailable.
    The tiktoken path follows OpenAI's accounting: 4 tokens of per-message
    overhead plus content tokens, plus 2 priming tokens.
    """
    enc = _tiktoken_enc
    if enc is None:
        # Rough heuristic: roughly 4 characters per token.
        return sum(len(str(m.get("content", ""))) for m in messages) // 4

    total = 2  # priming tokens
    for m in messages:
        total += 4  # role/separator overhead per message
        body = m.get("content", "")
        if isinstance(body, str):
            total += len(enc.encode(body))
        elif isinstance(body, list):
            # Multimodal content: only the text parts contribute tokens here.
            total += sum(
                len(enc.encode(part.get("text", "")))
                for part in body
                if isinstance(part, dict) and part.get("type") == "text"
            )
    return total
|
||||
|
||||
def _trim_messages_for_context(messages: list, n_ctx: int, safety_margin: int = 256) -> list:
    """Sliding-window trim — mirrors what llama.cpp context-shift used to do.

    All system messages are kept, plus as many of the most recent non-system
    messages as fit within (n_ctx - safety_margin) tokens.  The oldest
    non-system messages are discarded first (FIFO), and the final message is
    never dropped.
    """
    budget = n_ctx - safety_margin
    kept_system = [m for m in messages if m.get("role") == "system"]
    window = [m for m in messages if m.get("role") != "system"]

    # Shrink the window from the front until it fits, always retaining at
    # least the most recent non-system message.
    while len(window) > 1 and _count_message_tokens(kept_system + window) > budget:
        del window[0]

    return kept_system + window
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Globals
|
||||
# ------------------------------------------------------------------
|
||||
|
|
@ -985,6 +1034,18 @@ async def _make_chat_request(model: str, messages: list, tools=None, stream: boo
|
|||
start_ts = time.perf_counter()
|
||||
try:
|
||||
response = await oclient.chat.completions.create(**params)
|
||||
except openai.BadRequestError as e:
|
||||
if "exceed_context_size_error" in str(e) or "exceeds the available context size" in str(e):
|
||||
err_body = getattr(e, "body", {}) or {}
|
||||
err_detail = err_body.get("error", {}) if isinstance(err_body, dict) else {}
|
||||
n_ctx_limit = err_detail.get("n_ctx", 0) or err_detail.get("n_prompt_tokens", 0)
|
||||
if not n_ctx_limit:
|
||||
raise
|
||||
trimmed = _trim_messages_for_context(params.get("messages", []), n_ctx_limit)
|
||||
print(f"[_make_chat_request] Context exceeded ({err_detail.get('n_prompt_tokens')}/{n_ctx_limit} tokens), dropped {len(params.get('messages', [])) - len(trimmed)} oldest message(s) and retrying")
|
||||
response = await oclient.chat.completions.create(**{**params, "messages": trimmed})
|
||||
else:
|
||||
raise
|
||||
except openai.InternalServerError as e:
|
||||
if "image input is not supported" in str(e):
|
||||
print(f"[_make_chat_request] Model {model} doesn't support images, retrying with text-only messages")
|
||||
|
|
@ -1857,6 +1918,18 @@ async def chat_proxy(request: Request):
|
|||
start_ts = time.perf_counter()
|
||||
try:
|
||||
async_gen = await oclient.chat.completions.create(**params)
|
||||
except openai.BadRequestError as e:
|
||||
if "exceed_context_size_error" in str(e) or "exceeds the available context size" in str(e):
|
||||
err_body = getattr(e, "body", {}) or {}
|
||||
err_detail = err_body.get("error", {}) if isinstance(err_body, dict) else {}
|
||||
n_ctx_limit = err_detail.get("n_ctx", 0) or err_detail.get("n_prompt_tokens", 0)
|
||||
if not n_ctx_limit:
|
||||
raise
|
||||
trimmed = _trim_messages_for_context(params.get("messages", []), n_ctx_limit)
|
||||
print(f"[chat_proxy] Context exceeded ({err_detail.get('n_prompt_tokens')}/{n_ctx_limit} tokens), dropped {len(params.get('messages', [])) - len(trimmed)} oldest message(s) and retrying")
|
||||
async_gen = await oclient.chat.completions.create(**{**params, "messages": trimmed})
|
||||
else:
|
||||
raise
|
||||
except openai.InternalServerError as e:
|
||||
if "image input is not supported" in str(e):
|
||||
print(f"[chat_proxy] Model {model} doesn't support images, retrying with text-only messages")
|
||||
|
|
@ -2918,6 +2991,18 @@ async def openai_chat_completions_proxy(request: Request):
|
|||
print(f"[openai_chat_completions_proxy] Model {model} doesn't support tools, retrying without tools")
|
||||
params_without_tools = {k: v for k, v in send_params.items() if k != "tools"}
|
||||
async_gen = await oclient.chat.completions.create(**params_without_tools)
|
||||
elif "exceed_context_size_error" in str(e) or "exceeds the available context size" in str(e):
|
||||
# Backend context limit hit — apply sliding-window trim (context-shift at message level)
|
||||
err_body = getattr(e, "body", {}) or {}
|
||||
err_detail = err_body.get("error", {}) if isinstance(err_body, dict) else {}
|
||||
n_ctx_limit = err_detail.get("n_ctx", 0) or err_detail.get("n_prompt_tokens", 0)
|
||||
msgs = send_params.get("messages", [])
|
||||
if not n_ctx_limit:
|
||||
raise
|
||||
trimmed_messages = _trim_messages_for_context(msgs, n_ctx_limit)
|
||||
dropped = len(msgs) - len(trimmed_messages)
|
||||
print(f"[openai_chat_completions_proxy] Context window exceeded ({err_detail.get('n_prompt_tokens')}/{n_ctx_limit} tokens), dropped {dropped} oldest message(s) and retrying")
|
||||
async_gen = await oclient.chat.completions.create(**{**send_params, "messages": trimmed_messages})
|
||||
else:
|
||||
raise
|
||||
except openai.InternalServerError as e:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue