diff --git a/requirements.txt b/requirements.txt
index 9e89c90..da6fe43 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -32,6 +32,7 @@ PyYAML==6.0.3
 sniffio==1.3.1
 starlette==0.49.1
 truststore==0.10.4
+tiktoken==0.12.0
 tqdm==4.67.1
 typing-inspection==0.4.1
 typing_extensions==4.14.1
diff --git a/router.py b/router.py
index 7c77cbf..c1f4d54 100644
--- a/router.py
+++ b/router.py
@@ -78,6 +78,55 @@ def _mask_secrets(text: str) -> str:
     text = re.sub(r"(?i)(api[-_ ]key\s*[:=]\s*)([^\s]+)", r"\1***redacted***", text)
     return text
 
+# ------------------------------------------------------------------
+# Context-window sliding-window helpers
+# ------------------------------------------------------------------
+try:
+    import tiktoken as _tiktoken
+    _tiktoken_enc = _tiktoken.get_encoding("cl100k_base")
+except Exception:
+    _tiktoken_enc = None
+
+def _count_message_tokens(messages: list) -> int:
+    """Approximate token count for a message list.
+
+    Uses tiktoken cl100k_base when available (within ~5-15% of llama tokenizers).
+    Falls back to a chars/4 heuristic if tiktoken is unavailable.
+    Formula follows OpenAI's per-message overhead: 4 tokens/message + content + 2 priming.
+    """
+    if _tiktoken_enc is None:
+        return sum(len(str(m.get("content", ""))) for m in messages) // 4
+
+    total = 2  # priming tokens
+    for msg in messages:
+        total += 4  # per-message role/separator overhead
+        content = msg.get("content", "")
+        if isinstance(content, str):
+            total += len(_tiktoken_enc.encode(content))
+        elif isinstance(content, list):
+            for part in content:
+                if isinstance(part, dict) and part.get("type") == "text":
+                    total += len(_tiktoken_enc.encode(part.get("text", "")))
+    return total
+
+def _trim_messages_for_context(messages: list, n_ctx: int, safety_margin: int = 256) -> list:
+    """Sliding-window trim — mirrors what llama.cpp context-shift used to do.
+
+    Keeps all system messages and the most recent non-system messages that fit
+    within (n_ctx - safety_margin) tokens. Oldest non-system messages are dropped
+    first (FIFO). The last message is always preserved.
+ """ + target = n_ctx - safety_margin + system_msgs = [m for m in messages if m.get("role") == "system"] + non_system = [m for m in messages if m.get("role") != "system"] + + while len(non_system) > 1: + if _count_message_tokens(system_msgs + non_system) <= target: + break + non_system.pop(0) # drop oldest non-system message + + return system_msgs + non_system + # ------------------------------------------------------------------ # Globals # ------------------------------------------------------------------ @@ -985,6 +1034,18 @@ async def _make_chat_request(model: str, messages: list, tools=None, stream: boo start_ts = time.perf_counter() try: response = await oclient.chat.completions.create(**params) + except openai.BadRequestError as e: + if "exceed_context_size_error" in str(e) or "exceeds the available context size" in str(e): + err_body = getattr(e, "body", {}) or {} + err_detail = err_body.get("error", {}) if isinstance(err_body, dict) else {} + n_ctx_limit = err_detail.get("n_ctx", 0) or err_detail.get("n_prompt_tokens", 0) + if not n_ctx_limit: + raise + trimmed = _trim_messages_for_context(params.get("messages", []), n_ctx_limit) + print(f"[_make_chat_request] Context exceeded ({err_detail.get('n_prompt_tokens')}/{n_ctx_limit} tokens), dropped {len(params.get('messages', [])) - len(trimmed)} oldest message(s) and retrying") + response = await oclient.chat.completions.create(**{**params, "messages": trimmed}) + else: + raise except openai.InternalServerError as e: if "image input is not supported" in str(e): print(f"[_make_chat_request] Model {model} doesn't support images, retrying with text-only messages") @@ -1857,6 +1918,18 @@ async def chat_proxy(request: Request): start_ts = time.perf_counter() try: async_gen = await oclient.chat.completions.create(**params) + except openai.BadRequestError as e: + if "exceed_context_size_error" in str(e) or "exceeds the available context size" in str(e): + err_body = getattr(e, "body", {}) or {} + err_detail = err_body.get("error", {}) if isinstance(err_body, dict) else {} + n_ctx_limit = err_detail.get("n_ctx", 0) or err_detail.get("n_prompt_tokens", 0) + if not n_ctx_limit: + raise + trimmed = _trim_messages_for_context(params.get("messages", []), n_ctx_limit) + print(f"[chat_proxy] Context exceeded ({err_detail.get('n_prompt_tokens')}/{n_ctx_limit} tokens), dropped {len(params.get('messages', [])) - len(trimmed)} oldest message(s) and retrying") + async_gen = await oclient.chat.completions.create(**{**params, "messages": trimmed}) + else: + raise except openai.InternalServerError as e: if "image input is not supported" in str(e): print(f"[chat_proxy] Model {model} doesn't support images, retrying with text-only messages") @@ -2918,6 +2991,18 @@ async def openai_chat_completions_proxy(request: Request): print(f"[openai_chat_completions_proxy] Model {model} doesn't support tools, retrying without tools") params_without_tools = {k: v for k, v in send_params.items() if k != "tools"} async_gen = await oclient.chat.completions.create(**params_without_tools) + elif "exceed_context_size_error" in str(e) or "exceeds the available context size" in str(e): + # Backend context limit hit — apply sliding-window trim (context-shift at message level) + err_body = getattr(e, "body", {}) or {} + err_detail = err_body.get("error", {}) if isinstance(err_body, dict) else {} + n_ctx_limit = err_detail.get("n_ctx", 0) or err_detail.get("n_prompt_tokens", 0) + msgs = send_params.get("messages", []) + if not n_ctx_limit: + raise + trimmed_messages = 
_trim_messages_for_context(msgs, n_ctx_limit) + dropped = len(msgs) - len(trimmed_messages) + print(f"[openai_chat_completions_proxy] Context window exceeded ({err_detail.get('n_prompt_tokens')}/{n_ctx_limit} tokens), dropped {dropped} oldest message(s) and retrying") + async_gen = await oclient.chat.completions.create(**{**send_params, "messages": trimmed_messages}) else: raise except openai.InternalServerError as e:
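
For context, the sketch below shows how the new trimming helper is expected to behave. It is illustration only, not part of the diff: it assumes `router.py` is importable as a module, and the message contents, `n_ctx`, and `safety_margin` values are invented for the example.

```python
# Minimal sketch, not part of this change: exercises _trim_messages_for_context
# and _count_message_tokens as defined in the diff above. Message contents and
# the n_ctx/safety_margin values are made up for illustration.
from router import _count_message_tokens, _trim_messages_for_context

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "first question " * 300},     # large, old turn
    {"role": "assistant", "content": "first answer " * 300},  # large, old turn
    {"role": "user", "content": "latest question"},           # most recent turn
]

print("approx prompt tokens:", _count_message_tokens(messages))

trimmed = _trim_messages_for_context(messages, n_ctx=512, safety_margin=256)

# The system message is kept, the oldest non-system turns are dropped first,
# and the most recent message survives even if it alone exceeds the target.
assert trimmed[0]["role"] == "system"
assert trimmed[-1]["content"] == "latest question"
print(f"kept {len(trimmed)} of {len(messages)} messages after trimming")
```

Trimming at message boundaries, rather than truncating tokens mid-message, keeps every surviving message intact: the helper drops whole non-system messages FIFO until the estimate from `_count_message_tokens` fits under `n_ctx - safety_margin`.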