diff --git a/router.py b/router.py index c1f4d54..f0e13ab 100644 --- a/router.py +++ b/router.py @@ -109,14 +109,31 @@ def _count_message_tokens(messages: list) -> int: total += len(_tiktoken_enc.encode(part.get("text", ""))) return total -def _trim_messages_for_context(messages: list, n_ctx: int, safety_margin: int = 256) -> list: +def _trim_messages_for_context( + messages: list, + n_ctx: int, + safety_margin: int = None, + target_tokens: int = None, +) -> list: """Sliding-window trim — mirrors what llama.cpp context-shift used to do. Keeps all system messages and the most recent non-system messages that fit within (n_ctx - safety_margin) tokens. Oldest non-system messages are dropped first (FIFO). The last message is always preserved. + + safety_margin defaults to 1/4 of n_ctx to leave headroom for the generated + response, including RAG tool results and tool call JSON synthesis. + + target_tokens: if provided, overrides the (n_ctx - safety_margin) target. + Pass a calibrated value when actual n_prompt_tokens is known from the error + body so that tiktoken underestimation vs the backend tokenizer is corrected. """ - target = n_ctx - safety_margin + if target_tokens is not None: + target = target_tokens + else: + if safety_margin is None: + safety_margin = n_ctx // 4 + target = n_ctx - safety_margin system_msgs = [m for m in messages if m.get("role") == "system"] non_system = [m for m in messages if m.get("role") != "system"] @@ -127,6 +144,28 @@ def _trim_messages_for_context(messages: list, n_ctx: int, safety_margin: int = return system_msgs + non_system + + +def _calibrated_trim_target(msgs: list, n_ctx: int, actual_tokens: int) -> int: + """Return a tiktoken-scale trim target based on how much backend tokens must be shed. + + actual_tokens includes messages + tool schemas + overhead as counted by the backend. + _count_message_tokens only counts message text, so we cannot derive an accurate + per-token scale from the ratio. Instead we compute the *delta* we need to remove + in backend space, then convert just that delta to tiktoken scale (×1.2 buffer). + + Example: actual=17993, n_ctx=16384, headroom=4096 → need to shed 5705 backend + tokens → shed 6846 tiktoken tokens from messages. + """ + cur_tiktoken = _count_message_tokens(msgs) + headroom = n_ctx // 4 # reserve for generated output + max_prompt = n_ctx - headroom # desired max backend tokens in prompt + to_shed = max(0, actual_tokens - max_prompt) # backend tokens we must drop + # Convert to tiktoken scale with 20% buffer (tiktoken underestimates llama by ~15-20%) + tiktoken_to_shed = int(to_shed * 1.2) + return max(1, cur_tiktoken - tiktoken_to_shed) + + # ------------------------------------------------------------------ # Globals # ------------------------------------------------------------------ @@ -756,10 +795,15 @@ class fetch: # Check error cache with lock protection async with _available_error_cache_lock: if endpoint in _available_error_cache: - if _is_fresh(_available_error_cache[endpoint], 300): - # Still within the short error TTL – pretend nothing is available + err_age = time.time() - _available_error_cache[endpoint] + if err_age < 30: + # Very fresh error (<30s) – endpoint likely still down, bail fast return set() - # Error expired – remove it + elif err_age < 300: + # Stale error (30-300s) – endpoint may have recovered, probe in background + asyncio.create_task(fetch._refresh_available_models(endpoint, api_key)) + return set() + # Error expired (>300s) – remove and fall through to fresh fetch del _available_error_cache[endpoint] # Request coalescing: check if another request is already fetching this endpoint @@ -1034,20 +1078,30 @@ async def _make_chat_request(model: str, messages: list, tools=None, stream: boo start_ts = time.perf_counter() try: response = await oclient.chat.completions.create(**params) - except openai.BadRequestError as e: - if "exceed_context_size_error" in str(e) or "exceeds the available context size" in str(e): + except Exception as e: + _e_str = str(e) + print(f"[_make_chat_request] caught {type(e).__name__}: {_e_str[:200]}") + if "exceed_context_size_error" in _e_str or "exceeds the available context size" in _e_str: err_body = getattr(e, "body", {}) or {} err_detail = err_body.get("error", {}) if isinstance(err_body, dict) else {} - n_ctx_limit = err_detail.get("n_ctx", 0) or err_detail.get("n_prompt_tokens", 0) + n_ctx_limit = err_detail.get("n_ctx", 0) + actual_tokens = err_detail.get("n_prompt_tokens", 0) if not n_ctx_limit: raise - trimmed = _trim_messages_for_context(params.get("messages", []), n_ctx_limit) - print(f"[_make_chat_request] Context exceeded ({err_detail.get('n_prompt_tokens')}/{n_ctx_limit} tokens), dropped {len(params.get('messages', [])) - len(trimmed)} oldest message(s) and retrying") - response = await oclient.chat.completions.create(**{**params, "messages": trimmed}) - else: - raise - except openai.InternalServerError as e: - if "image input is not supported" in str(e): + msgs_to_trim = params.get("messages", []) + cal_target = _calibrated_trim_target(msgs_to_trim, n_ctx_limit, actual_tokens) + trimmed = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target) + print(f"[_make_chat_request] Context exceeded ({actual_tokens}/{n_ctx_limit} tokens, tiktoken_target={cal_target}), dropped {len(msgs_to_trim) - len(trimmed)} oldest message(s) and retrying") + try: + response = await oclient.chat.completions.create(**{**params, "messages": trimmed}) + except Exception as e2: + if "exceed_context_size_error" in str(e2) or "exceeds the available context size" in str(e2): + print(f"[_make_chat_request] Context still exceeded after trimming, also stripping tools") + params_no_tools = {k: v for k, v in params.items() if k not in ("tools", "tool_choice")} + response = await oclient.chat.completions.create(**{**params_no_tools, "messages": trimmed}) + else: + raise + elif "image input is not supported" in _e_str: print(f"[_make_chat_request] Model {model} doesn't support images, retrying with text-only messages") params = {**params, "messages": _strip_images_from_messages(params.get("messages", []))} response = await oclient.chat.completions.create(**params) @@ -1910,43 +1964,72 @@ async def chat_proxy(request: Request): oclient = openai.AsyncOpenAI(base_url=ep2base(endpoint), default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) else: client = ollama.AsyncClient(host=endpoint) + # For OpenAI endpoints: make the API call in handler scope + # (try/except inside async generators is unreliable with Starlette's streaming) + start_ts = None + async_gen = None + if use_openai: + start_ts = time.perf_counter() + try: + async_gen = await oclient.chat.completions.create(**params) + except Exception as e: + _e_str = str(e) + print(f"[chat_proxy] caught {type(e).__name__}: {_e_str[:200]}") + if "exceed_context_size_error" in _e_str or "exceeds the available context size" in _e_str: + err_body = getattr(e, "body", {}) or {} + err_detail = err_body.get("error", {}) if isinstance(err_body, dict) else {} + n_ctx_limit = err_detail.get("n_ctx", 0) + actual_tokens = err_detail.get("n_prompt_tokens", 0) + if not n_ctx_limit: + await decrement_usage(endpoint, tracking_model) + raise + msgs_to_trim = params.get("messages", []) + cal_target = _calibrated_trim_target(msgs_to_trim, n_ctx_limit, actual_tokens) + trimmed = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target) + print(f"[chat_proxy] Context exceeded ({actual_tokens}/{n_ctx_limit} tokens, tiktoken_target={cal_target}), dropped {len(msgs_to_trim) - len(trimmed)} oldest message(s) and retrying") + try: + async_gen = await oclient.chat.completions.create(**{**params, "messages": trimmed}) + except Exception as e2: + _e2_str = str(e2) + if "exceed_context_size_error" in _e2_str or "exceeds the available context size" in _e2_str: + print(f"[chat_proxy] Context still exceeded after trimming messages, also stripping tools") + params_no_tools = {k: v for k, v in params.items() if k not in ("tools", "tool_choice")} + try: + async_gen = await oclient.chat.completions.create(**{**params_no_tools, "messages": trimmed}) + except Exception: + await decrement_usage(endpoint, tracking_model) + raise + else: + await decrement_usage(endpoint, tracking_model) + raise + elif "image input is not supported" in _e_str: + print(f"[chat_proxy] Model {model} doesn't support images, retrying with text-only messages") + try: + params = {**params, "messages": _strip_images_from_messages(params.get("messages", []))} + async_gen = await oclient.chat.completions.create(**params) + except Exception: + await decrement_usage(endpoint, tracking_model) + raise + else: + await decrement_usage(endpoint, tracking_model) + raise + # 3. Async generator that streams chat data and decrements the counter async def stream_chat_response(): try: # The chat method returns a generator of dicts (or GenerateResponse) if use_openai: - start_ts = time.perf_counter() - try: - async_gen = await oclient.chat.completions.create(**params) - except openai.BadRequestError as e: - if "exceed_context_size_error" in str(e) or "exceeds the available context size" in str(e): - err_body = getattr(e, "body", {}) or {} - err_detail = err_body.get("error", {}) if isinstance(err_body, dict) else {} - n_ctx_limit = err_detail.get("n_ctx", 0) or err_detail.get("n_prompt_tokens", 0) - if not n_ctx_limit: - raise - trimmed = _trim_messages_for_context(params.get("messages", []), n_ctx_limit) - print(f"[chat_proxy] Context exceeded ({err_detail.get('n_prompt_tokens')}/{n_ctx_limit} tokens), dropped {len(params.get('messages', [])) - len(trimmed)} oldest message(s) and retrying") - async_gen = await oclient.chat.completions.create(**{**params, "messages": trimmed}) - else: - raise - except openai.InternalServerError as e: - if "image input is not supported" in str(e): - print(f"[chat_proxy] Model {model} doesn't support images, retrying with text-only messages") - params = {**params, "messages": _strip_images_from_messages(params.get("messages", []))} - async_gen = await oclient.chat.completions.create(**params) - else: - raise + _async_gen = async_gen # established in handler scope above else: if opt == True: # Use the dedicated MOE helper function - async_gen = await _make_moe_requests(model, messages, tools, think, _format, options, keep_alive) + _async_gen = await _make_moe_requests(model, messages, tools, think, _format, options, keep_alive) else: - async_gen = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=_format, options=options, keep_alive=keep_alive, logprobs=logprobs, top_logprobs=top_logprobs) + _async_gen = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=_format, options=options, keep_alive=keep_alive, logprobs=logprobs, top_logprobs=top_logprobs) if stream == True: tc_acc = {} # accumulate OpenAI tool-call deltas across chunks content_parts: list[str] = [] - async for chunk in async_gen: + async for chunk in _async_gen: if use_openai: _accumulate_openai_tc_delta(chunk, tc_acc) chunk = rechunk.openai_chat_completion2ollama(chunk, stream, start_ts) @@ -1989,18 +2072,18 @@ async def chat_proxy(request: Request): yield json_line.encode("utf-8") + b"\n" else: if use_openai: - response = rechunk.openai_chat_completion2ollama(async_gen, stream, start_ts) + response = rechunk.openai_chat_completion2ollama(_async_gen, stream, start_ts) response = response.model_dump_json() else: - response = async_gen.model_dump_json() - prompt_tok = async_gen.prompt_eval_count or 0 - comp_tok = async_gen.eval_count or 0 + response = _async_gen.model_dump_json() + prompt_tok = _async_gen.prompt_eval_count or 0 + comp_tok = _async_gen.eval_count or 0 if prompt_tok != 0 or comp_tok != 0: await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok)) json_line = ( response - if hasattr(async_gen, "model_dump_json") - else orjson.dumps(async_gen) + if hasattr(_async_gen, "model_dump_json") + else orjson.dumps(_async_gen) ) cache_bytes = json_line.encode("utf-8") + b"\n" yield cache_bytes @@ -2939,7 +3022,7 @@ async def openai_chat_completions_proxy(request: Request): endpoint, tracking_model = await choose_endpoint(model) base_url = ep2base(endpoint) oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) - # 3. Async generator that streams completions data and decrements the counter + # 3. Helpers and API call — done in handler scope so try/except works reliably async def _normalize_images_in_messages(msgs: list) -> list: """Fetch remote image URLs and convert them to base64 data URLs so Ollama/llama-server can handle them without making outbound HTTP requests.""" @@ -2974,44 +3057,81 @@ async def openai_chat_completions_proxy(request: Request): resolved.append({**msg, "content": new_content}) return resolved + # Make the API call in handler scope — try/except inside async generators is unreliable + # with Starlette's streaming machinery, so we resolve errors here before the generator starts. + send_params = params + if not is_ext_openai_endpoint(endpoint): + resolved_msgs = await _normalize_images_in_messages(params.get("messages", [])) + send_params = {**params, "messages": resolved_msgs} + try: + async_gen = await oclient.chat.completions.create(**send_params) + except Exception as e: + _e_str = str(e) + _is_ctx_err = "exceed_context_size_error" in _e_str or "exceeds the available context size" in _e_str + print(f"[ochat] caught={type(e).__name__} ctx={_is_ctx_err} msg={_e_str[:120]}", flush=True) + if "does not support tools" in _e_str: + # Model doesn't support tools — retry without them + print(f"[ochat] retry: no tools", flush=True) + try: + params_without_tools = {k: v for k, v in send_params.items() if k != "tools"} + async_gen = await oclient.chat.completions.create(**params_without_tools) + except Exception: + await decrement_usage(endpoint, tracking_model) + raise + elif _is_ctx_err: + # Backend context limit hit — apply sliding-window trim (context-shift at message level) + err_body = getattr(e, "body", {}) or {} + err_detail = err_body.get("error", {}) if isinstance(err_body, dict) else {} + n_ctx_limit = err_detail.get("n_ctx", 0) + actual_tokens = err_detail.get("n_prompt_tokens", 0) + print(f"[ctx-trim] n_ctx={n_ctx_limit} actual={actual_tokens}", flush=True) + if not n_ctx_limit: + await decrement_usage(endpoint, tracking_model) + raise + + msgs_to_trim = send_params.get("messages", []) + try: + cal_target = _calibrated_trim_target(msgs_to_trim, n_ctx_limit, actual_tokens) + trimmed_messages = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target) + except Exception as _helper_exc: + print(f"[ctx-trim] helper crash: {type(_helper_exc).__name__}: {str(_helper_exc)[:100]}", flush=True) + await decrement_usage(endpoint, tracking_model) + raise + dropped = len(msgs_to_trim) - len(trimmed_messages) + print(f"[ctx-trim] target={cal_target} dropped={dropped} remaining={len(trimmed_messages)} retrying-1", flush=True) + try: + async_gen = await oclient.chat.completions.create(**{**send_params, "messages": trimmed_messages}) + print(f"[ctx-trim] retry-1 ok", flush=True) + except Exception as e2: + _e2_str = str(e2) + if "exceed_context_size_error" in _e2_str or "exceeds the available context size" in _e2_str: + # Still too large — tool definitions likely consuming too many tokens, strip them too + print(f"[ctx-trim] retry-1 still exceeded, stripping tools retrying-2", flush=True) + params_no_tools = {k: v for k, v in send_params.items() if k not in ("tools", "tool_choice")} + try: + async_gen = await oclient.chat.completions.create(**{**params_no_tools, "messages": trimmed_messages}) + print(f"[ctx-trim] retry-2 ok", flush=True) + except Exception: + await decrement_usage(endpoint, tracking_model) + raise + else: + await decrement_usage(endpoint, tracking_model) + raise + elif "image input is not supported" in _e_str: + # Model doesn't support images — strip and retry + print(f"[openai_chat_completions_proxy] Model {model} doesn't support images, retrying with text-only messages") + try: + async_gen = await oclient.chat.completions.create(**{**send_params, "messages": _strip_images_from_messages(send_params.get("messages", []))}) + except Exception: + await decrement_usage(endpoint, tracking_model) + raise + else: + await decrement_usage(endpoint, tracking_model) + raise + + # 4. Async generator — only streams the already-established async_gen async def stream_ochat_response(): try: - # The chat method returns a generator of dicts (or GenerateResponse) - try: - # For non-external endpoints (Ollama, llama-server), resolve remote - # image URLs to base64 data URLs so the server can handle them locally. - send_params = params - if not is_ext_openai_endpoint(endpoint): - resolved_msgs = await _normalize_images_in_messages(params.get("messages", [])) - send_params = {**params, "messages": resolved_msgs} - async_gen = await oclient.chat.completions.create(**send_params) - except openai.BadRequestError as e: - # If tools are not supported by the model, retry without tools - if "does not support tools" in str(e): - print(f"[openai_chat_completions_proxy] Model {model} doesn't support tools, retrying without tools") - params_without_tools = {k: v for k, v in send_params.items() if k != "tools"} - async_gen = await oclient.chat.completions.create(**params_without_tools) - elif "exceed_context_size_error" in str(e) or "exceeds the available context size" in str(e): - # Backend context limit hit — apply sliding-window trim (context-shift at message level) - err_body = getattr(e, "body", {}) or {} - err_detail = err_body.get("error", {}) if isinstance(err_body, dict) else {} - n_ctx_limit = err_detail.get("n_ctx", 0) or err_detail.get("n_prompt_tokens", 0) - msgs = send_params.get("messages", []) - if not n_ctx_limit: - raise - trimmed_messages = _trim_messages_for_context(msgs, n_ctx_limit) - dropped = len(msgs) - len(trimmed_messages) - print(f"[openai_chat_completions_proxy] Context window exceeded ({err_detail.get('n_prompt_tokens')}/{n_ctx_limit} tokens), dropped {dropped} oldest message(s) and retrying") - async_gen = await oclient.chat.completions.create(**{**send_params, "messages": trimmed_messages}) - else: - raise - except openai.InternalServerError as e: - # If the model doesn't support image input, strip images and retry - if "image input is not supported" in str(e): - print(f"[openai_chat_completions_proxy] Model {model} doesn't support images, retrying with text-only messages") - async_gen = await oclient.chat.completions.create(**{**send_params, "messages": _strip_images_from_messages(send_params.get("messages", []))}) - else: - raise if stream == True: content_parts: list[str] = [] usage_snapshot: dict = {} @@ -3183,10 +3303,15 @@ async def openai_completions_proxy(request: Request): oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) # 3. Async generator that streams completions data and decrements the counter + # Make the API call in handler scope (try/except inside async generators is unreliable) + try: + async_gen = await oclient.completions.create(**params) + except Exception: + await decrement_usage(endpoint, tracking_model) + raise + async def stream_ocompletions_response(model=model): try: - # The chat method returns a generator of dicts (or GenerateResponse) - async_gen = await oclient.completions.create(**params) if stream == True: text_parts: list[str] = [] usage_snapshot: dict = {}