diff --git a/Dockerfile b/Dockerfile index 75ee26f..0caf66c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -13,13 +13,14 @@ ARG SEMANTIC_CACHE=false ENV HF_HOME=/app/data/hf_cache # Install SQLite -RUN apt-get update && apt-get install -y --no-install-recommends sqlite3 \ +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends sqlite3 git \ && rm -rf /var/lib/apt/lists/* + WORKDIR /app COPY requirements.txt . -RUN pip install --no-cache-dir --upgrade pip \ - && pip install --no-cache-dir -r requirements.txt +RUN pip install --root-user-action=ignore --no-cache-dir --upgrade pip \ + && pip install --root-user-action=ignore --no-cache-dir -r requirements.txt # Semantic cache deps — only installed when SEMANTIC_CACHE=true # CPU-only torch must be installed before sentence-transformers to avoid diff --git a/requirements.txt b/requirements.txt index 9e89c90..da6fe43 100644 --- a/requirements.txt +++ b/requirements.txt @@ -32,6 +32,7 @@ PyYAML==6.0.3 sniffio==1.3.1 starlette==0.49.1 truststore==0.10.4 +tiktoken==0.12.0 tqdm==4.67.1 typing-inspection==0.4.1 typing_extensions==4.14.1 diff --git a/router.py b/router.py index 43ac488..a7f6a75 100644 --- a/router.py +++ b/router.py @@ -78,6 +78,107 @@ def _mask_secrets(text: str) -> str: text = re.sub(r"(?i)(api[-_ ]key\s*[:=]\s*)([^\s]+)", r"\1***redacted***", text) return text +# ------------------------------------------------------------------ +# Context-window sliding-window helpers +# ------------------------------------------------------------------ +try: + import tiktoken as _tiktoken + _tiktoken_enc = _tiktoken.get_encoding("cl100k_base") +except Exception: + _tiktoken_enc = None + +def _count_message_tokens(messages: list) -> int: + """Approximate token count for a message list. + + Uses tiktoken cl100k_base when available (within ~5-15% of llama tokenizers). + Falls back to char/4 heuristic if tiktoken is unavailable. + Formula follows OpenAI's per-message overhead: 4 tokens/message + content + 2 priming. + """ + if _tiktoken_enc is None: + return sum(len(str(m.get("content", ""))) for m in messages) // 4 + + total = 2 # priming tokens + for msg in messages: + total += 4 # per-message role/separator overhead + content = msg.get("content", "") + if isinstance(content, str): + total += len(_tiktoken_enc.encode(content)) + elif isinstance(content, list): + for part in content: + if isinstance(part, dict) and part.get("type") == "text": + total += len(_tiktoken_enc.encode(part.get("text", ""))) + return total + +def _trim_messages_for_context( + messages: list, + n_ctx: int, + safety_margin: int = None, + target_tokens: int = None, +) -> list: + """Sliding-window trim — mirrors what llama.cpp context-shift used to do. + + Keeps all system messages and the most recent non-system messages that fit + within (n_ctx - safety_margin) tokens. Oldest non-system messages are dropped + first (FIFO). The last message is always preserved. + + safety_margin defaults to 1/4 of n_ctx to leave headroom for the generated + response, including RAG tool results and tool call JSON synthesis. + + target_tokens: if provided, overrides the (n_ctx - safety_margin) target. + Pass a calibrated value when actual n_prompt_tokens is known from the error + body so that tiktoken underestimation vs the backend tokenizer is corrected. + """ + if target_tokens is not None: + target = target_tokens + else: + if safety_margin is None: + safety_margin = n_ctx // 4 + target = n_ctx - safety_margin + system_msgs = [m for m in messages if m.get("role") == "system"] + non_system = [m for m in messages if m.get("role") != "system"] + + while len(non_system) > 1: + if _count_message_tokens(system_msgs + non_system) <= target: + break + non_system.pop(0) # drop oldest non-system message + + # Ensure the first non-system message is a user message (chat templates require it). + # Drop any leading assistant/tool messages that were left after trimming. + while non_system and non_system[0].get("role") != "user": + non_system.pop(0) + + return system_msgs + non_system + + + +def _calibrated_trim_target(msgs: list, n_ctx: int, actual_tokens: int) -> int: + """Return a tiktoken-scale trim target based on how much backend tokens must be shed. + + actual_tokens includes messages + tool schemas + overhead as counted by the backend. + _count_message_tokens only counts message text, so we cannot derive an accurate + per-token scale from the ratio. Instead we compute the *delta* we need to remove + in backend space, then convert just that delta to tiktoken scale (×1.2 buffer). + + Example: actual=17993, n_ctx=16384, headroom=4096 → need to shed 5705 backend + tokens → shed 6846 tiktoken tokens from messages. + """ + cur_tiktoken = _count_message_tokens(msgs) + headroom = n_ctx // 4 # reserve for generated output + max_prompt = n_ctx - headroom # desired max backend tokens in prompt + to_shed = max(0, actual_tokens - max_prompt) # backend tokens we must drop + # Convert to tiktoken scale with 20% buffer (tiktoken underestimates llama by ~15-20%) + tiktoken_to_shed = int(to_shed * 1.2) + return max(1, cur_tiktoken - tiktoken_to_shed) + +# Per-(endpoint, model) n_ctx cache. +# Populated from two sources: +# 1. 400 exceed_context_size_error body → n_ctx field +# 2. finish_reason/done_reason == "length" in streaming → prompt_tokens + completion_tokens +# Only used for proactive pre-trimming when n_ctx <= _CTX_TRIM_SMALL_LIMIT, +# so large-context models (200k+ for coding) are never touched. +_endpoint_nctx: dict[tuple[str, str], int] = {} +_CTX_TRIM_SMALL_LIMIT = 32768 # only proactively trim models with n_ctx at or below this + # ------------------------------------------------------------------ # Globals # ------------------------------------------------------------------ @@ -707,10 +808,15 @@ class fetch: # Check error cache with lock protection async with _available_error_cache_lock: if endpoint in _available_error_cache: - if _is_fresh(_available_error_cache[endpoint], 300): - # Still within the short error TTL – pretend nothing is available + err_age = time.time() - _available_error_cache[endpoint] + if err_age < 30: + # Very fresh error (<30s) – endpoint likely still down, bail fast return set() - # Error expired – remove it + elif err_age < 300: + # Stale error (30-300s) – endpoint may have recovered, probe in background + asyncio.create_task(fetch._refresh_available_models(endpoint, api_key)) + return set() + # Error expired (>300s) – remove and fall through to fresh fetch del _available_error_cache[endpoint] # Request coalescing: check if another request is already fetching this endpoint @@ -983,7 +1089,44 @@ async def _make_chat_request(model: str, messages: list, tools=None, stream: boo try: if use_openai: start_ts = time.perf_counter() - response = await oclient.chat.completions.create(**params) + try: + response = await oclient.chat.completions.create(**params) + except Exception as e: + _e_str = str(e) + print(f"[_make_chat_request] caught {type(e).__name__}: {_e_str[:200]}") + if "exceed_context_size_error" in _e_str or "exceeds the available context size" in _e_str: + err_body = getattr(e, "body", {}) or {} + err_detail = err_body.get("error", {}) if isinstance(err_body, dict) else {} + n_ctx_limit = err_detail.get("n_ctx", 0) + actual_tokens = err_detail.get("n_prompt_tokens", 0) + if not n_ctx_limit: + _m = re.search(r"'n_ctx':\s*(\d+)", _e_str) + if _m: + n_ctx_limit = int(_m.group(1)) + _m = re.search(r"'n_prompt_tokens':\s*(\d+)", _e_str) + if _m: + actual_tokens = int(_m.group(1)) + if not n_ctx_limit: + raise + msgs_to_trim = params.get("messages", []) + cal_target = _calibrated_trim_target(msgs_to_trim, n_ctx_limit, actual_tokens) + trimmed = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target) + print(f"[_make_chat_request] Context exceeded ({actual_tokens}/{n_ctx_limit} tokens, tiktoken_target={cal_target}), dropped {len(msgs_to_trim) - len(trimmed)} oldest message(s) and retrying") + try: + response = await oclient.chat.completions.create(**{**params, "messages": trimmed}) + except Exception as e2: + if "exceed_context_size_error" in str(e2) or "exceeds the available context size" in str(e2): + print(f"[_make_chat_request] Context still exceeded after trimming, also stripping tools") + params_no_tools = {k: v for k, v in params.items() if k not in ("tools", "tool_choice")} + response = await oclient.chat.completions.create(**{**params_no_tools, "messages": trimmed}) + else: + raise + elif "image input is not supported" in _e_str: + print(f"[_make_chat_request] Model {model} doesn't support images, retrying with text-only messages") + params = {**params, "messages": _strip_images_from_messages(params.get("messages", []))} + response = await oclient.chat.completions.create(**params) + else: + raise if stream: # For streaming, we need to collect all chunks chunks = [] @@ -1212,6 +1355,22 @@ def transform_images_to_data_urls(message_list): return message_list +def _strip_images_from_messages(messages: list) -> list: + """Remove image_url parts from message content, keeping only text.""" + result = [] + for msg in messages: + content = msg.get("content") + if isinstance(content, list): + text_only = [p for p in content if p.get("type") != "image_url"] + if len(text_only) == 1 and text_only[0].get("type") == "text": + content = text_only[0]["text"] + else: + content = text_only + result.append({**msg, "content": content}) + else: + result.append(msg) + return result + def _accumulate_openai_tc_delta(chunk, accumulator: dict) -> None: """Accumulate tool_call deltas from a single OpenAI streaming chunk. @@ -1825,23 +1984,93 @@ async def chat_proxy(request: Request): oclient = openai.AsyncOpenAI(base_url=ep2base(endpoint), default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) else: client = ollama.AsyncClient(host=endpoint) + # For OpenAI endpoints: make the API call in handler scope + # (try/except inside async generators is unreliable with Starlette's streaming) + start_ts = None + async_gen = None + if use_openai: + start_ts = time.perf_counter() + # Proactive trim: only for small-ctx models we've already seen run out of space + _lookup_model = _normalize_llama_model_name(model) if endpoint in config.llama_server_endpoints else model + _known_nctx = _endpoint_nctx.get((endpoint, _lookup_model)) + if _known_nctx and _known_nctx <= _CTX_TRIM_SMALL_LIMIT: + _pre_target = int((_known_nctx - _known_nctx // 4) / 1.2) + _pre_est = _count_message_tokens(params.get("messages", [])) + if _pre_est > _pre_target: + _pre_msgs = params.get("messages", []) + _pre_trimmed = _trim_messages_for_context(_pre_msgs, _known_nctx, target_tokens=_pre_target) + _dropped = len(_pre_msgs) - len(_pre_trimmed) + print(f"[ctx-pre] n_ctx={_known_nctx} est={_pre_est} target={_pre_target} dropped={_dropped}", flush=True) + params = {**params, "messages": _pre_trimmed} + try: + async_gen = await oclient.chat.completions.create(**params) + except Exception as e: + _e_str = str(e) + print(f"[chat_proxy] caught {type(e).__name__}: {_e_str[:200]}") + if "exceed_context_size_error" in _e_str or "exceeds the available context size" in _e_str: + err_body = getattr(e, "body", {}) or {} + err_detail = err_body.get("error", {}) if isinstance(err_body, dict) else {} + n_ctx_limit = err_detail.get("n_ctx", 0) + actual_tokens = err_detail.get("n_prompt_tokens", 0) + if not n_ctx_limit: + _m = re.search(r"'n_ctx':\s*(\d+)", _e_str) + if _m: + n_ctx_limit = int(_m.group(1)) + _m = re.search(r"'n_prompt_tokens':\s*(\d+)", _e_str) + if _m: + actual_tokens = int(_m.group(1)) + if not n_ctx_limit: + await decrement_usage(endpoint, tracking_model) + raise + if n_ctx_limit <= _CTX_TRIM_SMALL_LIMIT: + _endpoint_nctx[(endpoint, model)] = n_ctx_limit + msgs_to_trim = params.get("messages", []) + cal_target = _calibrated_trim_target(msgs_to_trim, n_ctx_limit, actual_tokens) + trimmed = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target) + print(f"[chat_proxy] Context exceeded ({actual_tokens}/{n_ctx_limit} tokens, tiktoken_target={cal_target}), dropped {len(msgs_to_trim) - len(trimmed)} oldest message(s) and retrying") + try: + async_gen = await oclient.chat.completions.create(**{**params, "messages": trimmed}) + except Exception as e2: + _e2_str = str(e2) + if "exceed_context_size_error" in _e2_str or "exceeds the available context size" in _e2_str: + print(f"[chat_proxy] Context still exceeded after trimming messages, also stripping tools") + params_no_tools = {k: v for k, v in params.items() if k not in ("tools", "tool_choice")} + try: + async_gen = await oclient.chat.completions.create(**{**params_no_tools, "messages": trimmed}) + except Exception: + await decrement_usage(endpoint, tracking_model) + raise + else: + await decrement_usage(endpoint, tracking_model) + raise + elif "image input is not supported" in _e_str: + print(f"[chat_proxy] Model {model} doesn't support images, retrying with text-only messages") + try: + params = {**params, "messages": _strip_images_from_messages(params.get("messages", []))} + async_gen = await oclient.chat.completions.create(**params) + except Exception: + await decrement_usage(endpoint, tracking_model) + raise + else: + await decrement_usage(endpoint, tracking_model) + raise + # 3. Async generator that streams chat data and decrements the counter async def stream_chat_response(): try: # The chat method returns a generator of dicts (or GenerateResponse) if use_openai: - start_ts = time.perf_counter() - async_gen = await oclient.chat.completions.create(**params) + _async_gen = async_gen # established in handler scope above else: if opt == True: # Use the dedicated MOE helper function - async_gen = await _make_moe_requests(model, messages, tools, think, _format, options, keep_alive) + _async_gen = await _make_moe_requests(model, messages, tools, think, _format, options, keep_alive) else: - async_gen = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=_format, options=options, keep_alive=keep_alive, logprobs=logprobs, top_logprobs=top_logprobs) + _async_gen = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=_format, options=options, keep_alive=keep_alive, logprobs=logprobs, top_logprobs=top_logprobs) if stream == True: tc_acc = {} # accumulate OpenAI tool-call deltas across chunks content_parts: list[str] = [] - async for chunk in async_gen: + async for chunk in _async_gen: if use_openai: _accumulate_openai_tc_delta(chunk, tc_acc) chunk = rechunk.openai_chat_completion2ollama(chunk, stream, start_ts) @@ -1860,6 +2089,20 @@ async def chat_proxy(request: Request): # Accumulate and store cache on done chunk — before yield so it always runs # Works for both Ollama-native and OpenAI-compatible backends; chunks are # already converted to Ollama format by rechunk before this point. + if getattr(chunk, "done", False): + # Detect context exhaustion mid-generation for small-ctx models + _dr = getattr(chunk, "done_reason", None) + # Only cache when no max_tokens limit was set — otherwise + # finish_reason=length might just mean max_tokens was hit, + # not that the context window was exhausted. + _req_max_tok = params.get("max_tokens") or params.get("max_completion_tokens") or params.get("num_predict") + if _dr == "length" and not _req_max_tok: + _pt = getattr(chunk, "prompt_eval_count", 0) or 0 + _ct = getattr(chunk, "eval_count", 0) or 0 + _inferred_nctx = _pt + _ct + if 0 < _inferred_nctx <= _CTX_TRIM_SMALL_LIMIT: + _endpoint_nctx[(endpoint, model)] = _inferred_nctx + print(f"[ctx-cache] done_reason=length → cached n_ctx={_inferred_nctx} for ({endpoint},{model})", flush=True) if _cache is not None and not _is_moe and _cache_enabled: if chunk.message and getattr(chunk.message, "content", None): content_parts.append(chunk.message.content) @@ -1884,18 +2127,18 @@ async def chat_proxy(request: Request): yield json_line.encode("utf-8") + b"\n" else: if use_openai: - response = rechunk.openai_chat_completion2ollama(async_gen, stream, start_ts) + response = rechunk.openai_chat_completion2ollama(_async_gen, stream, start_ts) response = response.model_dump_json() else: - response = async_gen.model_dump_json() - prompt_tok = async_gen.prompt_eval_count or 0 - comp_tok = async_gen.eval_count or 0 + response = _async_gen.model_dump_json() + prompt_tok = _async_gen.prompt_eval_count or 0 + comp_tok = _async_gen.eval_count or 0 if prompt_tok != 0 or comp_tok != 0: await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok)) json_line = ( response - if hasattr(async_gen, "model_dump_json") - else orjson.dumps(async_gen) + if hasattr(_async_gen, "model_dump_json") + else orjson.dumps(_async_gen) ) cache_bytes = json_line.encode("utf-8") + b"\n" yield cache_bytes @@ -2567,7 +2810,7 @@ async def ps_details_proxy(request: Request): # Fetch /props for each llama-server model to get context length (n_ctx) # and unload sleeping models automatically - async def _fetch_llama_props(endpoint: str, model_id: str) -> tuple[int | None, bool]: + async def _fetch_llama_props(endpoint: str, model_id: str) -> tuple[int | None, bool, bool]: client: aiohttp.ClientSession = app_state["session"] base_url = endpoint.rstrip("/").removesuffix("/v1") props_url = f"{base_url}/props?model={model_id}" @@ -2582,6 +2825,8 @@ async def ps_details_proxy(request: Request): dgs = data.get("default_generation_settings", {}) n_ctx = dgs.get("n_ctx") is_sleeping = data.get("is_sleeping", False) + # Embedding models have no sampling params in default_generation_settings + is_generation = "temperature" in dgs if is_sleeping: unload_url = f"{base_url}/models/unload" @@ -2595,18 +2840,22 @@ async def ps_details_proxy(request: Request): except Exception as ue: print(f"[ps_details] Failed to unload sleeping model {model_id} from {endpoint}: {ue}") - return n_ctx, is_sleeping + return n_ctx, is_sleeping, is_generation except Exception as e: print(f"[ps_details] Failed to fetch props from {props_url}: {e}") - return None, False + return None, False, False props_results = await asyncio.gather( *[_fetch_llama_props(ep, mid) for ep, mid in props_requests] ) - for model_dict, (n_ctx, is_sleeping) in zip(llama_models_pending, props_results): + for (ep, raw_id), model_dict, (n_ctx, is_sleeping, is_generation) in zip(props_requests, llama_models_pending, props_results): if n_ctx is not None: model_dict["context_length"] = n_ctx + if is_generation and 0 < n_ctx <= _CTX_TRIM_SMALL_LIMIT: + normalized = _normalize_llama_model_name(raw_id) + _endpoint_nctx[(ep, normalized)] = n_ctx + print(f"[ctx-cache/ps] cached n_ctx={n_ctx} for ({ep},{normalized})", flush=True) if not is_sleeping: models.append(model_dict) @@ -2686,6 +2935,21 @@ async def openai_embedding_proxy(request: Request): model = payload.get("model") doc = payload.get("input") + # Normalize multimodal input: extract only text parts for embedding models + if isinstance(doc, list): + normalized = [] + for item in doc: + if isinstance(item, dict): + # Multimodal content part - extract text only, skip images + if item.get("type") == "text": + normalized.append(item.get("text", "")) + # Skip image_url and other non-text types + else: + normalized.append(item) + doc = normalized if len(normalized) != 1 else normalized[0] + elif isinstance(doc, dict) and doc.get("type") == "text": + doc = doc.get("text", "") + if not model: raise HTTPException( status_code=400, detail="Missing required field 'model'" @@ -2819,7 +3083,7 @@ async def openai_chat_completions_proxy(request: Request): endpoint, tracking_model = await choose_endpoint(model) base_url = ep2base(endpoint) oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) - # 3. Async generator that streams completions data and decrements the counter + # 3. Helpers and API call — done in handler scope so try/except works reliably async def _normalize_images_in_messages(msgs: list) -> list: """Fetch remote image URLs and convert them to base64 data URLs so Ollama/llama-server can handle them without making outbound HTTP requests.""" @@ -2854,25 +3118,104 @@ async def openai_chat_completions_proxy(request: Request): resolved.append({**msg, "content": new_content}) return resolved + # Make the API call in handler scope — try/except inside async generators is unreliable + # with Starlette's streaming machinery, so we resolve errors here before the generator starts. + send_params = params + if not is_ext_openai_endpoint(endpoint): + resolved_msgs = await _normalize_images_in_messages(params.get("messages", [])) + send_params = {**params, "messages": resolved_msgs} + # Proactive trim: only for small-ctx models we've already seen run out of space + _lookup_model = _normalize_llama_model_name(model) if endpoint in config.llama_server_endpoints else model + _known_nctx = _endpoint_nctx.get((endpoint, _lookup_model)) + if _known_nctx and _known_nctx <= _CTX_TRIM_SMALL_LIMIT: + _pre_target = int(((_known_nctx - _known_nctx // 4)) / 1.2) + _pre_est = _count_message_tokens(send_params.get("messages", [])) + if _pre_est > _pre_target: + _pre_msgs = send_params.get("messages", []) + _pre_trimmed = _trim_messages_for_context(_pre_msgs, _known_nctx, target_tokens=_pre_target) + _dropped = len(_pre_msgs) - len(_pre_trimmed) + print(f"[ctx-pre] n_ctx={_known_nctx} est={_pre_est} target={_pre_target} dropped={_dropped}", flush=True) + send_params = {**send_params, "messages": _pre_trimmed} + try: + async_gen = await oclient.chat.completions.create(**send_params) + except Exception as e: + _e_str = str(e) + _is_ctx_err = "exceed_context_size_error" in _e_str or "exceeds the available context size" in _e_str + print(f"[ochat] caught={type(e).__name__} ctx={_is_ctx_err} msg={_e_str[:120]}", flush=True) + if "does not support tools" in _e_str: + # Model doesn't support tools — retry without them + print(f"[ochat] retry: no tools", flush=True) + try: + params_without_tools = {k: v for k, v in send_params.items() if k != "tools"} + async_gen = await oclient.chat.completions.create(**params_without_tools) + except Exception: + await decrement_usage(endpoint, tracking_model) + raise + elif _is_ctx_err: + # Backend context limit hit — apply sliding-window trim (context-shift at message level) + err_body = getattr(e, "body", {}) or {} + err_detail = err_body.get("error", {}) if isinstance(err_body, dict) else {} + n_ctx_limit = err_detail.get("n_ctx", 0) + actual_tokens = err_detail.get("n_prompt_tokens", 0) + # Fallback: parse from string if body parsing yielded nothing (SDK may not parse llama-server errors) + if not n_ctx_limit: + import re as _re + _m = _re.search(r"'n_ctx':\s*(\d+)", _e_str) + if _m: + n_ctx_limit = int(_m.group(1)) + _m = _re.search(r"'n_prompt_tokens':\s*(\d+)", _e_str) + if _m: + actual_tokens = int(_m.group(1)) + print(f"[ctx-trim] n_ctx={n_ctx_limit} actual={actual_tokens}", flush=True) + if not n_ctx_limit: + await decrement_usage(endpoint, tracking_model) + raise + if n_ctx_limit <= _CTX_TRIM_SMALL_LIMIT: + _endpoint_nctx[(endpoint, model)] = n_ctx_limit + + msgs_to_trim = send_params.get("messages", []) + try: + cal_target = _calibrated_trim_target(msgs_to_trim, n_ctx_limit, actual_tokens) + trimmed_messages = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target) + except Exception as _helper_exc: + print(f"[ctx-trim] helper crash: {type(_helper_exc).__name__}: {str(_helper_exc)[:100]}", flush=True) + await decrement_usage(endpoint, tracking_model) + raise + dropped = len(msgs_to_trim) - len(trimmed_messages) + print(f"[ctx-trim] target={cal_target} dropped={dropped} remaining={len(trimmed_messages)} retrying-1", flush=True) + try: + async_gen = await oclient.chat.completions.create(**{**send_params, "messages": trimmed_messages}) + print(f"[ctx-trim] retry-1 ok", flush=True) + except Exception as e2: + _e2_str = str(e2) + if "exceed_context_size_error" in _e2_str or "exceeds the available context size" in _e2_str: + # Still too large — tool definitions likely consuming too many tokens, strip them too + print(f"[ctx-trim] retry-1 still exceeded, stripping tools retrying-2", flush=True) + params_no_tools = {k: v for k, v in send_params.items() if k not in ("tools", "tool_choice")} + try: + async_gen = await oclient.chat.completions.create(**{**params_no_tools, "messages": trimmed_messages}) + print(f"[ctx-trim] retry-2 ok", flush=True) + except Exception: + await decrement_usage(endpoint, tracking_model) + raise + else: + await decrement_usage(endpoint, tracking_model) + raise + elif "image input is not supported" in _e_str: + # Model doesn't support images — strip and retry + print(f"[openai_chat_completions_proxy] Model {model} doesn't support images, retrying with text-only messages") + try: + async_gen = await oclient.chat.completions.create(**{**send_params, "messages": _strip_images_from_messages(send_params.get("messages", []))}) + except Exception: + await decrement_usage(endpoint, tracking_model) + raise + else: + await decrement_usage(endpoint, tracking_model) + raise + + # 4. Async generator — only streams the already-established async_gen async def stream_ochat_response(): try: - # The chat method returns a generator of dicts (or GenerateResponse) - try: - # For non-external endpoints (Ollama, llama-server), resolve remote - # image URLs to base64 data URLs so the server can handle them locally. - send_params = params - if not is_ext_openai_endpoint(endpoint): - resolved_msgs = await _normalize_images_in_messages(params.get("messages", [])) - send_params = {**params, "messages": resolved_msgs} - async_gen = await oclient.chat.completions.create(**send_params) - except openai.BadRequestError as e: - # If tools are not supported by the model, retry without tools - if "does not support tools" in str(e): - print(f"[openai_chat_completions_proxy] Model {model} doesn't support tools, retrying without tools") - params_without_tools = {k: v for k, v in send_params.items() if k != "tools"} - async_gen = await oclient.chat.completions.create(**params_without_tools) - else: - raise if stream == True: content_parts: list[str] = [] usage_snapshot: dict = {} @@ -2909,6 +3252,15 @@ async def openai_chat_completions_proxy(request: Request): prompt_tok, comp_tok = llama_usage if prompt_tok != 0 or comp_tok != 0: await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok)) + # Detect context exhaustion mid-generation for small-ctx models. + # Guard: skip if max_tokens was set in the request — finish_reason=length + # could just mean the caller's token budget was exhausted, not the context window. + _req_max_tok = send_params.get("max_tokens") or send_params.get("max_completion_tokens") + if chunk.choices and chunk.choices[0].finish_reason == "length" and not _req_max_tok: + _inferred_nctx = (prompt_tok + comp_tok) or 0 + if 0 < _inferred_nctx <= _CTX_TRIM_SMALL_LIMIT: + _endpoint_nctx[(endpoint, model)] = _inferred_nctx + print(f"[ctx-cache] finish_reason=length → cached n_ctx={_inferred_nctx} for ({endpoint},{model})", flush=True) # Cache assembled streaming response — before [DONE] so it always runs if _cache is not None and _cache_enabled and content_parts: assembled = orjson.dumps({ @@ -3044,10 +3396,15 @@ async def openai_completions_proxy(request: Request): oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) # 3. Async generator that streams completions data and decrements the counter + # Make the API call in handler scope (try/except inside async generators is unreliable) + try: + async_gen = await oclient.completions.create(**params) + except Exception: + await decrement_usage(endpoint, tracking_model) + raise + async def stream_ocompletions_response(model=model): try: - # The chat method returns a generator of dicts (or GenerateResponse) - async_gen = await oclient.completions.create(**params) if stream == True: text_parts: list[str] = [] usage_snapshot: dict = {} diff --git a/static/index.html b/static/index.html index fe14ef5..cac7f8d 100644 --- a/static/index.html +++ b/static/index.html @@ -316,13 +316,32 @@ display: flex; align-items: center; /* vertically center the button with the headline */ gap: 1rem; - } + } + .logo-chart-row { + display: flex; + align-items: stretch; + gap: 1rem; + margin-bottom: 1rem; + } + #header-tps-container { + flex: 1; + background: white; + border-radius: 6px; + padding: 0.25rem 0.75rem; + height: 100px; + position: relative; + } - +
+ +
+ +
+

Router Dashboard

@@ -419,6 +438,11 @@ let statsChart = null; let rawTimeSeries = null; let totalTokensChart = null; + let headerTpsChart = null; + const TPS_HISTORY_SIZE = 60; + const tpsHistory = []; + let latestPerModelTokens = {}; + const modelFirstSeen = {}; let usageSource = null; const API_KEY_STORAGE_KEY = "nomyo-router-api-key"; @@ -928,7 +952,7 @@ function renderTimeSeriesChart(timeSeriesData, chart, minutes) { const uniqueEndpoints = Array.from(new Set(endpoints)); const endpointsData = encodeURIComponent(JSON.stringify(uniqueEndpoints)); return ` - ${modelName} stats + ${modelName} stats ${renderInstanceList(endpoints)} ${params} ${quant} @@ -1009,6 +1033,7 @@ function renderTimeSeriesChart(timeSeriesData, chart, minutes) { try { const payload = JSON.parse(e.data); // SSE sends plain text renderChart(payload); + updateTpsChart(payload); const usage = payload.usage_counts || {}; const tokens = payload.token_usage_counts || {}; @@ -1035,6 +1060,84 @@ function renderTimeSeriesChart(timeSeriesData, chart, minutes) { window.addEventListener("beforeunload", () => source.close()); } + /* ---------- Header TPS Chart ---------- */ + function initHeaderChart() { + const canvas = document.getElementById('header-tps-canvas'); + if (!canvas) return; + headerTpsChart = new Chart(canvas.getContext('2d'), { + type: 'line', + data: { labels: [], datasets: [] }, + options: { + responsive: true, + maintainAspectRatio: false, + animation: false, + scales: { + x: { display: false }, + y: { display: true, min: 0, ticks: { font: { size: 10 } } } + }, + plugins: { + legend: { display: false }, + tooltip: { + callbacks: { + title: (items) => items[0]?.dataset?.label || '', + label: (item) => `${item.parsed.y.toFixed(1)} tok/s` + } + } + }, + elements: { point: { radius: 0 } } + } + }); + } + + function updateTpsChart(payload) { + const tokens = payload.token_usage_counts || {}; + const perModelTokens = {}; + psRows.forEach((_, model) => { + let total = 0; + for (const ep in tokens) total += tokens[ep]?.[model] || 0; + // Normalise against the first-seen cumulative total so history + // entries start at 0 and the || 0 fallback never causes a spike. + if (!(model in modelFirstSeen)) modelFirstSeen[model] = total; + perModelTokens[model] = total - modelFirstSeen[model]; + }); + latestPerModelTokens = perModelTokens; + } + + function tickTpsChart() { + if (!headerTpsChart) return; + tpsHistory.push({ time: Date.now(), perModelTokens: { ...latestPerModelTokens } }); + if (tpsHistory.length > TPS_HISTORY_SIZE) tpsHistory.shift(); + if (tpsHistory.length < 2) return; + + // Only chart models present in the latest snapshot — never accumulate + // stale names from old history entries. + const allModels = Object.keys(tpsHistory[tpsHistory.length - 1].perModelTokens); + + const labels = tpsHistory.map(h => new Date(h.time).toLocaleTimeString()); + const datasets = Array.from(allModels).map(model => { + const data = tpsHistory.map((h, i) => { + if (i === 0) return 0; + const prev = tpsHistory[i - 1]; + const dt = (h.time - prev.time) / 1000; + const dTokens = (h.perModelTokens[model] || 0) - (prev.perModelTokens[model] || 0); + return dt > 0 ? Math.max(0, dTokens / dt) : 0; + }); + return { + label: model, + data, + borderColor: getColor(model), + backgroundColor: 'transparent', + borderWidth: 2, + tension: 0.3, + pointRadius: 0 + }; + }); + + headerTpsChart.data.labels = labels; + headerTpsChart.data.datasets = datasets; + headerTpsChart.update('none'); + } + /* ---------- Init ---------- */ window.addEventListener("load", () => { updateApiKeyIndicator(); @@ -1068,6 +1171,8 @@ function renderTimeSeriesChart(timeSeriesData, chart, minutes) { loadTags(); loadPS(); loadUsage(); + initHeaderChart(); + setInterval(tickTpsChart, 1000); setInterval(loadPS, 60_000); setInterval(loadEndpoints, 300_000);