diff --git a/surfsense_backend/app/agents/new_chat/llm_config.py b/surfsense_backend/app/agents/new_chat/llm_config.py
index 2b1c07cda..4ddb47330 100644
--- a/surfsense_backend/app/agents/new_chat/llm_config.py
+++ b/surfsense_backend/app/agents/new_chat/llm_config.py
@@ -15,6 +15,7 @@ from pathlib import Path
 
 import yaml
 from langchain_litellm import ChatLiteLLM
+from litellm import get_model_info
 from sqlalchemy import select
 from sqlalchemy.ext.asyncio import AsyncSession
 
@@ -62,6 +63,22 @@ PROVIDER_MAP = {
 }
 
 
+def _attach_model_profile(llm: ChatLiteLLM, model_string: str) -> None:
+    """Attach a ``profile`` dict to ChatLiteLLM with model context metadata."""
+    try:
+        info = get_model_info(model_string)
+        max_input_tokens = info.get("max_input_tokens")
+        if isinstance(max_input_tokens, int) and max_input_tokens > 0:
+            llm.profile = {
+                "max_input_tokens": max_input_tokens,
+                "max_input_tokens_upper": max_input_tokens,
+                "token_count_model": model_string,
+                "token_count_models": [model_string],
+            }
+    except Exception:
+        return
+
+
 @dataclass
 class AgentConfig:
     """
@@ -366,7 +383,9 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
     if llm_config.get("litellm_params"):
         litellm_kwargs.update(llm_config["litellm_params"])
 
-    return ChatLiteLLM(**litellm_kwargs)
+    llm = ChatLiteLLM(**litellm_kwargs)
+    _attach_model_profile(llm, model_string)
+    return llm
 
 
 def create_chat_litellm_from_agent_config(
@@ -419,4 +438,6 @@ def create_chat_litellm_from_agent_config(
     if agent_config.litellm_params:
         litellm_kwargs.update(agent_config.litellm_params)
 
-    return ChatLiteLLM(**litellm_kwargs)
+    llm = ChatLiteLLM(**litellm_kwargs)
+    _attach_model_profile(llm, model_string)
+    return llm
diff --git a/surfsense_backend/app/services/llm_router_service.py b/surfsense_backend/app/services/llm_router_service.py
index e8c0d2d47..7a0b6e55b 100644
--- a/surfsense_backend/app/services/llm_router_service.py
+++ b/surfsense_backend/app/services/llm_router_service.py
@@ -11,6 +11,7 @@ The router is initialized from global LLM configs and provides both synchronous
 ChatLiteLLM-like interface and async methods.
 """
 
+import copy
 import logging
 import re
 import time
@@ -331,11 +332,16 @@ _router_instance_cache: dict[bool, "ChatLiteLLMRouter"] = {}
 
 
 def _get_cached_context_profile(router: Router) -> dict | None:
-    """Compute and cache the min context profile across all router deployments.
+    """Compute and cache context profile across all router deployments.
 
     Called once on first ChatLiteLLMRouter creation; subsequent calls return
     the cached value. This avoids calling litellm.get_model_info() for every
     deployment on every request.
+
+    Caches both ``max_input_tokens`` (minimum across deployments, used by
+    SummarizationMiddleware) and ``max_input_tokens_upper`` (maximum across
+    deployments, used for context-trimming so we can target the largest
+    available model — the router's fallback logic handles smaller ones).
     """
     global _cached_context_profile, _cached_context_profile_computed
     if _cached_context_profile_computed:
@@ -344,20 +350,50 @@ def _get_cached_context_profile(router: Router) -> dict | None:
     from litellm import get_model_info
 
     min_ctx: int | None = None
+    max_ctx: int | None = None
+    token_count_model: str | None = None
+    ctx_pairs: list[tuple[int, str]] = []
     for deployment in router.model_list:
         params = deployment.get("litellm_params", {})
         base_model = params.get("base_model") or params.get("model", "")
         try:
             info = get_model_info(base_model)
             ctx = info.get("max_input_tokens")
-            if isinstance(ctx, int) and ctx > 0 and (min_ctx is None or ctx < min_ctx):
-                min_ctx = ctx
+            if isinstance(ctx, int) and ctx > 0:
+                if min_ctx is None or ctx < min_ctx:
+                    min_ctx = ctx
+                if max_ctx is None or ctx > max_ctx:
+                    max_ctx = ctx
+                if token_count_model is None:
+                    token_count_model = base_model
+                ctx_pairs.append((ctx, base_model))
         except Exception:
             continue
 
     if min_ctx is not None:
-        logger.info("ChatLiteLLMRouter profile: max_input_tokens=%d", min_ctx)
-        _cached_context_profile = {"max_input_tokens": min_ctx}
+        token_count_models: list[str] = []
+        if token_count_model:
+            token_count_models.append(token_count_model)
+        if ctx_pairs:
+            ctx_pairs.sort(key=lambda x: x[0])
+            smallest_model = ctx_pairs[0][1]
+            largest_model = ctx_pairs[-1][1]
+            if smallest_model not in token_count_models:
+                token_count_models.append(smallest_model)
+            if largest_model not in token_count_models:
+                token_count_models.append(largest_model)
+        logger.info(
+            "ChatLiteLLMRouter profile: max_input_tokens=%d, upper=%s, token_models=%s",
+            min_ctx,
+            max_ctx,
+            token_count_models,
+        )
+        _cached_context_profile = {
+            "max_input_tokens": min_ctx,
+            "max_input_tokens_upper": max_ctx,
+            "token_count_model": token_count_model,
+            "token_count_models": token_count_models,
+        }
     else:
         _cached_context_profile = None
 
@@ -425,6 +461,248 @@ class ChatLiteLLMRouter(BaseChatModel):
             logger.error(f"Failed to initialize ChatLiteLLMRouter: {e}")
             raise
 
+    # -----------------------------------------------------------------
+    # Context-aware trimming helpers
+    # -----------------------------------------------------------------
+
+    def _get_token_count_model_names(self) -> list[str]:
+        """Return concrete model names usable by ``litellm.token_counter``.
+
+        The router uses ``"auto"`` as the model group name but tokenizers need
+        concrete model identifiers. We keep multiple candidates and take the
+        most conservative count across them.
+        """
+        names: list[str] = []
+        profile = getattr(self, "profile", None)
+        if isinstance(profile, dict):
+            tcms = profile.get("token_count_models")
+            if isinstance(tcms, list):
+                for name in tcms:
+                    if isinstance(name, str) and name and name not in names:
+                        names.append(name)
+            tcm = profile.get("token_count_model")
+            if isinstance(tcm, str) and tcm and tcm not in names:
+                names.append(tcm)
+
+        if self._router and self._router.model_list:
+            for dep in self._router.model_list:
+                params = dep.get("litellm_params", {})
+                base = params.get("base_model") or params.get("model", "")
+                if base and base not in names:
+                    names.append(base)
+                if len(names) >= 3:
+                    break
+        if not names:
+            return ["gpt-4o"]
+        return names
+
+    def _count_tokens(self, messages: list[dict]) -> int | None:
+        """Return conservative token count across candidate deployment models."""
+        from litellm import token_counter as _tc
+
+        models = self._get_token_count_model_names()
+        counts: list[int] = []
+        for model_name in models:
+            try:
+                counts.append(_tc(messages=messages, model=model_name))
+            except Exception:
+                continue
+        return max(counts) if counts else None
+
+    def _get_max_input_tokens(self) -> int:
+        """Return the max input tokens to use for context trimming.
+
+        Prefers the *largest* context window across all deployments so we
+        maximise usable context (the router's ``context_window_fallbacks``
+        handle routing to the right model). Falls back to the minimum
+        profile value or a conservative default.
+        """
+        profile = getattr(self, "profile", None)
+        if isinstance(profile, dict):
+            upper = profile.get("max_input_tokens_upper")
+            if isinstance(upper, int) and upper > 0:
+                return upper
+            lower = profile.get("max_input_tokens")
+            if isinstance(lower, int) and lower > 0:
+                return lower
+        return 128_000
+
+    def _trim_messages_to_fit_context(
+        self,
+        messages: list[dict],
+        output_reserve_fraction: float = 0.10,
+    ) -> list[dict]:
+        """Trim message content via binary search to fit the model's context window.
+
+        When the total token count exceeds the model's ``max_input_tokens``,
+        this method identifies the largest messages (typically tool responses
+        containing search results) and uses binary search on each to find the
+        maximum content length that keeps the total within budget.
+
+        Cutting prefers ``</document>`` XML boundaries so complete documents
+        are preserved when possible.
+
+        This is model-aware: it reads the context limit from
+        ``litellm.get_model_info`` (cached in ``self.profile``) and counts
+        tokens with ``litellm.token_counter``.
+        """
+        max_input = self._get_max_input_tokens()
+        output_reserve = min(int(max_input * output_reserve_fraction), 16_384)
+        budget = max_input - output_reserve
+
+        total_tokens = self._count_tokens(messages)
+        if total_tokens is None:
+            return messages
+
+        if total_tokens <= budget:
+            return messages
+
+        perf = get_perf_logger()
+        perf.warning(
+            "[llm_router] context overflow detected: %d tokens > %d budget "
+            "(max_input=%d, reserve=%d). Trimming messages.",
+            total_tokens,
+            budget,
+            max_input,
+            output_reserve,
+        )
+
+        trimmed = copy.deepcopy(messages)
+
+        # Per-message token counts for trimmable candidates.
+        # Skip system messages to preserve agent instructions.
+        msg_token_map: dict[int, int] = {}
+        candidate_priority: dict[int, int] = {}
+        for i, msg in enumerate(trimmed):
+            if msg.get("role") == "system":
+                continue
+            role = msg.get("role")
+            content = msg.get("content", "")
+            if not isinstance(content, str) or len(content) < 500:
+                continue
+            # Prefer trimming tool/assistant outputs first.
+            # User messages are only trimmed if they clearly contain injected
+            # document context blobs.
+            is_doc_blob = "<document>" in content or "</document>" in content
+            if role in ("tool", "assistant"):
+                candidate_priority[i] = 0
+            elif role == "user" and is_doc_blob:
+                candidate_priority[i] = 1
+            else:
+                continue
+            token_count = self._count_tokens([msg])
+            if token_count is not None:
+                msg_token_map[i] = token_count
+
+        if not msg_token_map:
+            perf.warning("[llm_router] no trimmable messages found, returning as-is")
+            return trimmed
+
+        # Trim largest messages first
+        candidates = sorted(
+            msg_token_map.items(),
+            key=lambda x: (candidate_priority.get(x[0], 9), -x[1]),
+        )
+        running_total = total_tokens
+
+        trim_suffix = (
+            "\n\n<truncated>Remaining content trimmed to fit the model "
+            "context window.</truncated>"
+        )
+
+        for idx, orig_msg_tokens in candidates:
+            if running_total <= budget:
+                break
+
+            content = trimmed[idx]["content"]
+            orig_len = len(content)
+
+            # Binary search: find maximum content[:mid] that keeps total ≤ budget.
+            lo, hi = 200, orig_len - 1
+            best = 200
+
+            while lo <= hi:
+                mid = (lo + hi) // 2
+                trimmed[idx]["content"] = content[:mid] + trim_suffix
+                new_msg_tokens = self._count_tokens([trimmed[idx]])
+                if new_msg_tokens is None:
+                    hi = mid - 1
+                    continue
+
+                projected_total = running_total - orig_msg_tokens + new_msg_tokens
+                if projected_total <= budget:
+                    best = mid
+                    lo = mid + 1
+                else:
+                    hi = mid - 1
+
+            # Prefer cutting at a </document> boundary for cleaner output
+            last_doc_end = content[:best].rfind("</document>")
+            if last_doc_end > min(200, best // 4):
+                best = last_doc_end + len("</document>")
+
+            trimmed[idx]["content"] = content[:best] + trim_suffix
+
+            try:
+                new_msg_tokens = self._count_tokens([trimmed[idx]])
+                if new_msg_tokens is None:
+                    continue
+                running_total = running_total - orig_msg_tokens + new_msg_tokens
+            except Exception:
+                pass
+
+        # Hard guarantee: if still over budget, replace remaining large
+        # non-system messages with compact placeholders until we fit.
+        if running_total > budget:
+            fallback_indices: list[int] = []
+            for i, msg in enumerate(trimmed):
+                if msg.get("role") == "system":
+                    continue
+                content = msg.get("content")
+                if isinstance(content, str) and len(content) > 0:
+                    fallback_indices.append(i)
+
+            for idx in fallback_indices:
+                if running_total <= budget:
+                    break
+                role = trimmed[idx].get("role", "message")
+                placeholder = (
+                    f"[content omitted to fit model context window; role={role}]"
+                )
+                old_tokens = self._count_tokens([trimmed[idx]]) or 0
+                trimmed[idx]["content"] = placeholder
+                new_tokens = self._count_tokens([trimmed[idx]]) or 0
+                running_total = running_total - old_tokens + new_tokens
+
+            if running_total > budget:
+                perf.error(
+                    "[llm_router] unable to fit context even after aggressive trimming: "
+                    "tokens=%d budget=%d",
+                    running_total,
+                    budget,
+                )
+                # Final safety net: clear oldest non-system contents.
+                for idx in fallback_indices:
+                    if running_total <= budget:
+                        break
+                    old_tokens = self._count_tokens([trimmed[idx]]) or 0
+                    trimmed[idx]["content"] = ""
+                    new_tokens = self._count_tokens([trimmed[idx]]) or 0
+                    running_total = running_total - old_tokens + new_tokens
+
+        perf.info(
+            "[llm_router] messages trimmed: %d → %d tokens (budget=%d, max_input=%d)",
+            total_tokens,
+            running_total,
+            budget,
+            max_input,
+        )
+
+        return trimmed
+
+    # -----------------------------------------------------------------
+
     @property
     def _llm_type(self) -> str:
         return "litellm-router"
@@ -499,6 +777,7 @@ class ChatLiteLLMRouter(BaseChatModel):
 
         # Convert LangChain messages to OpenAI format
         formatted_messages = self._convert_messages(messages)
+        formatted_messages = self._trim_messages_to_fit_context(formatted_messages)
 
         # Add tools if bound
         call_kwargs = {**kwargs}
@@ -564,6 +843,7 @@ class ChatLiteLLMRouter(BaseChatModel):
 
         # Convert LangChain messages to OpenAI format
         formatted_messages = self._convert_messages(messages)
+        formatted_messages = self._trim_messages_to_fit_context(formatted_messages)
 
         # Add tools if bound
         call_kwargs = {**kwargs}
@@ -624,6 +904,7 @@ class ChatLiteLLMRouter(BaseChatModel):
             raise ValueError("Router not initialized")
 
         formatted_messages = self._convert_messages(messages)
+        formatted_messages = self._trim_messages_to_fit_context(formatted_messages)
 
         # Add tools if bound
         call_kwargs = {**kwargs}
@@ -673,6 +954,7 @@ class ChatLiteLLMRouter(BaseChatModel):
         msg_count = len(messages)
 
         formatted_messages = self._convert_messages(messages)
+        formatted_messages = self._trim_messages_to_fit_context(formatted_messages)
 
         # Add tools if bound
         call_kwargs = {**kwargs}