From afa51e97cfbeecd2e90fc996525dd574ef44c942 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 5 Jun 2026 11:15:13 +0200 Subject: [PATCH] refactor(agents): delete dead single-agent-only middleware file_intent (FileIntentMiddleware) and flatten_system (FlattenSystemMessageMiddleware) were only ever instantiated in the single-agent chat_deepagent stack, which was removed in 14bbea085. They have no production consumer in multi_agent_chat. Delete both modules and their unit tests. Also drop the vestigial KnowledgeBaseSearchMiddleware alias (= the live KnowledgePriorityMiddleware); its tests now target the real class so the behavior coverage is preserved. Trim the three barrel/__all__ entries and strip the now-dead class names from comments. --- .../app/agents/shared/context.py | 4 +- .../app/agents/shared/middleware/__init__.py | 10 - .../agents/shared/middleware/compaction.py | 2 +- .../agents/shared/middleware/file_intent.py | 334 ----------------- .../shared/middleware/flatten_system.py | 233 ------------ .../shared/middleware/knowledge_search.py | 5 - .../app/agents/shared/prompt_caching.py | 14 +- .../agents/new_chat/test_flatten_system.py | 344 ------------------ .../middleware/test_file_intent_middleware.py | 214 ----------- .../unit/middleware/test_knowledge_search.py | 20 +- 10 files changed, 19 insertions(+), 1161 deletions(-) delete mode 100644 surfsense_backend/app/agents/shared/middleware/file_intent.py delete mode 100644 surfsense_backend/app/agents/shared/middleware/flatten_system.py delete mode 100644 surfsense_backend/tests/unit/agents/new_chat/test_flatten_system.py delete mode 100644 surfsense_backend/tests/unit/middleware/test_file_intent_middleware.py diff --git a/surfsense_backend/app/agents/shared/context.py b/surfsense_backend/app/agents/shared/context.py index 1b3ea3d20..50b761f5b 100644 --- a/surfsense_backend/app/agents/shared/context.py +++ b/surfsense_backend/app/agents/shared/context.py @@ -50,8 +50,8 @@ class SurfSenseContextSchema: (cloud filesystem mode). Surfaced as ``[USER-MENTIONED]`` entries in ```` so the agent prioritises walking those folders with ``ls`` / ``find_documents``. - file_operation_contract: One-shot file operation contract emitted - by ``FileIntentMiddleware`` for the upcoming turn. + file_operation_contract: One-shot file operation contract for the + upcoming turn (reserved; not currently populated). turn_id / request_id: Correlation IDs surfaced by the streaming task; populated for telemetry. diff --git a/surfsense_backend/app/agents/shared/middleware/__init__.py b/surfsense_backend/app/agents/shared/middleware/__init__.py index fb6eacfdb..7aaeb2713 100644 --- a/surfsense_backend/app/agents/shared/middleware/__init__.py +++ b/surfsense_backend/app/agents/shared/middleware/__init__.py @@ -21,18 +21,11 @@ from app.agents.shared.middleware.dedup_tool_calls import ( DedupHITLToolCallsMiddleware, ) from app.agents.shared.middleware.doom_loop import DoomLoopMiddleware -from app.agents.shared.middleware.file_intent import ( - FileIntentMiddleware, -) -from app.agents.shared.middleware.flatten_system import ( - FlattenSystemMessageMiddleware, -) from app.agents.shared.middleware.kb_persistence import ( KnowledgeBasePersistenceMiddleware, commit_staged_filesystem_state, ) from app.agents.shared.middleware.knowledge_search import ( - KnowledgeBaseSearchMiddleware, KnowledgePriorityMiddleware, ) from app.agents.shared.middleware.knowledge_tree import ( @@ -56,10 +49,7 @@ __all__ = [ "ClearToolUsesEdit", "DedupHITLToolCallsMiddleware", "DoomLoopMiddleware", - "FileIntentMiddleware", - "FlattenSystemMessageMiddleware", "KnowledgeBasePersistenceMiddleware", - "KnowledgeBaseSearchMiddleware", "KnowledgePriorityMiddleware", "KnowledgeTreeMiddleware", "MemoryInjectionMiddleware", diff --git a/surfsense_backend/app/agents/shared/middleware/compaction.py b/surfsense_backend/app/agents/shared/middleware/compaction.py index f8d340e5d..6a533be6b 100644 --- a/surfsense_backend/app/agents/shared/middleware/compaction.py +++ b/surfsense_backend/app/agents/shared/middleware/compaction.py @@ -94,7 +94,7 @@ Respond ONLY with the structured summary. Do not include any text before or afte PROTECTED_SYSTEM_PREFIXES: tuple[str, ...] = ( "", # KnowledgePriorityMiddleware "", # KnowledgeTreeMiddleware - "", # FileIntentMiddleware + "", # reserved file-operation contract prefix "", # MemoryInjectionMiddleware "", # MemoryInjectionMiddleware "", # MemoryInjectionMiddleware diff --git a/surfsense_backend/app/agents/shared/middleware/file_intent.py b/surfsense_backend/app/agents/shared/middleware/file_intent.py deleted file mode 100644 index 7897e13d6..000000000 --- a/surfsense_backend/app/agents/shared/middleware/file_intent.py +++ /dev/null @@ -1,334 +0,0 @@ -"""Semantic file-intent routing middleware for new chat turns. - -This middleware classifies the latest human turn into a small intent set: -- chat_only -- file_write -- file_read - -For ``file_write`` turns it injects a strict system contract so the model -uses filesystem tools before claiming success, and provides a deterministic -fallback path when no filename is specified by the user. -""" - -from __future__ import annotations - -import json -import logging -import re -from datetime import UTC, datetime -from enum import StrEnum -from typing import Any - -from langchain.agents.middleware import AgentMiddleware, AgentState -from langchain_core.language_models import BaseChatModel -from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage -from langgraph.runtime import Runtime -from pydantic import BaseModel, Field, ValidationError - -logger = logging.getLogger(__name__) - - -class FileOperationIntent(StrEnum): - CHAT_ONLY = "chat_only" - FILE_WRITE = "file_write" - FILE_READ = "file_read" - - -class FileIntentPlan(BaseModel): - intent: FileOperationIntent = Field( - description="Primary user intent for this turn." - ) - confidence: float = Field( - ge=0.0, - le=1.0, - default=0.5, - description="Model confidence in the selected intent.", - ) - suggested_filename: str | None = Field( - default=None, - description="Optional filename (e.g. notes.md) inferred from user request.", - ) - suggested_directory: str | None = Field( - default=None, - description=( - "Optional directory path (e.g. /reports/q2 or reports/q2) inferred from " - "user request." - ), - ) - suggested_path: str | None = Field( - default=None, - description=( - "Optional full file path (e.g. /reports/q2/summary.md). If present, this " - "takes precedence over suggested_directory + suggested_filename." - ), - ) - - -def _extract_text_from_message(message: BaseMessage) -> str: - content = getattr(message, "content", "") - if isinstance(content, str): - return content - if isinstance(content, list): - parts: list[str] = [] - for item in content: - if isinstance(item, str): - parts.append(item) - elif isinstance(item, dict) and item.get("type") == "text": - parts.append(str(item.get("text", ""))) - return "\n".join(part for part in parts if part) - return str(content) - - -def _extract_json_payload(text: str) -> str: - stripped = text.strip() - fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", stripped, re.DOTALL) - if fenced: - return fenced.group(1) - start = stripped.find("{") - end = stripped.rfind("}") - if start != -1 and end != -1 and end > start: - return stripped[start : end + 1] - return stripped - - -def _sanitize_filename(value: str) -> str: - name = re.sub(r"[\\/:*?\"<>|]+", "_", value).strip() - name = re.sub(r"\s+", "-", name) - name = name.strip("._-") - if not name: - name = "note" - if len(name) > 80: - name = name[:80].rstrip("-_.") - return name - - -def _sanitize_path_segment(value: str) -> str: - segment = re.sub(r"[\\/:*?\"<>|]+", "_", value).strip() - segment = re.sub(r"\s+", "_", segment) - segment = segment.strip("._-") - return segment - - -def _normalize_directory(value: str) -> str: - raw = value.strip().replace("\\", "/") - raw = raw.strip("/") - if not raw: - return "" - parts = [_sanitize_path_segment(part) for part in raw.split("/") if part.strip()] - parts = [part for part in parts if part] - return "/".join(parts) - - -def _normalize_file_path(value: str) -> str: - raw = value.strip().replace("\\", "/").strip() - if not raw: - return "" - had_trailing_slash = raw.endswith("/") - raw = raw.strip("/") - if not raw: - return "" - parts = [_sanitize_path_segment(part) for part in raw.split("/") if part.strip()] - parts = [part for part in parts if part] - if not parts: - return "" - if had_trailing_slash: - return f"/{'/'.join(parts)}/" - return f"/{'/'.join(parts)}" - - -def _infer_directory_from_user_text(user_text: str) -> str | None: - patterns = ( - r"\b(?:in|inside|under)\s+(?:the\s+)?([a-zA-Z0-9 _\-/]+?)\s+folder\b", - r"\b(?:in|inside|under)\s+([a-zA-Z0-9 _\-/]+?)\b", - ) - lowered = user_text.lower() - for pattern in patterns: - match = re.search(pattern, lowered, flags=re.IGNORECASE) - if not match: - continue - candidate = match.group(1).strip() - if candidate in {"the", "a", "an"}: - continue - normalized = _normalize_directory(candidate) - if normalized: - return normalized - return None - - -def _fallback_path( - suggested_filename: str | None, - *, - suggested_directory: str | None = None, - suggested_path: str | None = None, - user_text: str, -) -> str: - inferred_dir = _infer_directory_from_user_text(user_text) - - sanitized_filename = "" - if suggested_filename: - sanitized_filename = _sanitize_filename(suggested_filename) - if sanitized_filename.lower().endswith(".txt"): - sanitized_filename = f"{sanitized_filename[:-4]}.md" - if not sanitized_filename: - sanitized_filename = "notes.md" - elif "." not in sanitized_filename: - sanitized_filename = f"{sanitized_filename}.md" - - normalized_suggested_path = ( - _normalize_file_path(suggested_path) if suggested_path else "" - ) - if normalized_suggested_path: - if normalized_suggested_path.endswith("/"): - return f"{normalized_suggested_path.rstrip('/')}/{sanitized_filename}" - return normalized_suggested_path - - directory = _normalize_directory(suggested_directory or "") - if not directory and inferred_dir: - directory = inferred_dir - if directory: - return f"/{directory}/{sanitized_filename}" - - return f"/{sanitized_filename}" - - -def _build_classifier_prompt(*, recent_conversation: str, user_text: str) -> str: - return ( - "Classify the latest user request into a filesystem intent for an AI agent.\n" - "Return JSON only with this exact schema:\n" - '{"intent":"chat_only|file_write|file_read","confidence":0.0,"suggested_filename":"string or null","suggested_directory":"string or null","suggested_path":"string or null"}\n\n' - "Rules:\n" - "- Use semantic intent, not literal keywords.\n" - "- file_write: user asks to create/save/write/update/edit content as a file.\n" - "- file_read: user asks to open/read/list/search existing files.\n" - "- chat_only: conversational/analysis responses without required file operations.\n" - "- For file_write, choose a concise semantic suggested_filename and match the requested format.\n" - "- If the user mentions a folder/directory, populate suggested_directory.\n" - "- If user specifies an explicit full path, populate suggested_path.\n" - "- Use extensions that match user intent (e.g. .md, .json, .yaml, .csv, .py, .ts, .js, .html, .css, .sql).\n" - "- Do not use .txt; prefer .md for generic text notes.\n" - "- Do not include dates or timestamps in suggested_filename unless explicitly requested.\n" - "- Never include markdown or explanation.\n\n" - f"Recent conversation:\n{recent_conversation or '(none)'}\n\n" - f"Latest user message:\n{user_text}" - ) - - -def _build_recent_conversation( - messages: list[BaseMessage], *, max_messages: int = 6 -) -> str: - rows: list[str] = [] - filtered: list[tuple[str, BaseMessage]] = [] - for msg in messages: - role: str | None = None - if isinstance(msg, HumanMessage): - role = "user" - elif isinstance(msg, AIMessage): - if getattr(msg, "tool_calls", None): - continue - role = "assistant" - else: - continue - filtered.append((role, msg)) - for role, msg in filtered[-max_messages:]: - text = re.sub(r"\s+", " ", _extract_text_from_message(msg)).strip() - if text: - rows.append(f"{role}: {text[:280]}") - return "\n".join(rows) - - -class FileIntentMiddleware(AgentMiddleware): # type: ignore[type-arg] - """Classify file intent and inject a strict file-write contract.""" - - tools = () - - def __init__(self, *, llm: BaseChatModel | None = None) -> None: - self.llm = llm - - async def _classify_intent( - self, *, messages: list[BaseMessage], user_text: str - ) -> FileIntentPlan: - if self.llm is None: - return FileIntentPlan(intent=FileOperationIntent.CHAT_ONLY, confidence=0.0) - - prompt = _build_classifier_prompt( - recent_conversation=_build_recent_conversation(messages), - user_text=user_text, - ) - try: - response = await self.llm.ainvoke( - [HumanMessage(content=prompt)], - config={"tags": ["surfsense:internal"]}, - ) - payload = json.loads( - _extract_json_payload(_extract_text_from_message(response)) - ) - plan = FileIntentPlan.model_validate(payload) - return plan - except (json.JSONDecodeError, ValidationError, ValueError) as exc: - logger.warning("File intent classifier returned invalid output: %s", exc) - except Exception as exc: # pragma: no cover - defensive fallback - logger.warning("File intent classifier failed: %s", exc) - - return FileIntentPlan(intent=FileOperationIntent.CHAT_ONLY, confidence=0.0) - - async def abefore_agent( # type: ignore[override] - self, - state: AgentState, - runtime: Runtime[Any], - ) -> dict[str, Any] | None: - del runtime - messages = state.get("messages") or [] - if not messages: - return None - - last_human: HumanMessage | None = None - for msg in reversed(messages): - if isinstance(msg, HumanMessage): - last_human = msg - break - if last_human is None: - return None - - user_text = _extract_text_from_message(last_human).strip() - if not user_text: - return None - - plan = await self._classify_intent(messages=messages, user_text=user_text) - suggested_path = _fallback_path( - plan.suggested_filename, - suggested_directory=plan.suggested_directory, - suggested_path=plan.suggested_path, - user_text=user_text, - ) - contract = { - "intent": plan.intent.value, - "confidence": plan.confidence, - "suggested_path": suggested_path, - "timestamp": datetime.now(UTC).isoformat(), - "turn_id": state.get("turn_id", ""), - } - - if plan.intent != FileOperationIntent.FILE_WRITE: - return {"file_operation_contract": contract} - - contract_msg = SystemMessage( - content=( - "\n" - "This turn intent is file_write.\n" - f"Suggested default path: {suggested_path}\n" - "Rules:\n" - "- You MUST call write_file or edit_file before claiming success.\n" - "- If no path is provided by the user, use the suggested default path.\n" - "- Do not claim a file was created/updated unless tool output confirms it.\n" - "- If the write/edit fails, clearly report failure instead of success.\n" - "- Do not include timestamps or dates in generated file content unless the user explicitly asks for them.\n" - "- For open-ended requests (e.g., random note), generate useful concrete content, not placeholders.\n" - "" - ) - ) - - # Insert just before the latest human turn so it applies to this request. - new_messages = list(messages) - insert_at = max(len(new_messages) - 1, 0) - new_messages.insert(insert_at, contract_msg) - return {"messages": new_messages, "file_operation_contract": contract} diff --git a/surfsense_backend/app/agents/shared/middleware/flatten_system.py b/surfsense_backend/app/agents/shared/middleware/flatten_system.py deleted file mode 100644 index 4a621d70a..000000000 --- a/surfsense_backend/app/agents/shared/middleware/flatten_system.py +++ /dev/null @@ -1,233 +0,0 @@ -r"""Coalesce multi-block system messages into a single text block. - -Several middlewares in our deepagent stack each call -``append_to_system_message`` on the way down to the model -(``TodoListMiddleware``, ``SurfSenseFilesystemMiddleware``, -``SkillsMiddleware``, ``SubAgentMiddleware`` …). By the time the -request reaches the LLM, the system message has 5+ separate text blocks. - -Anthropic enforces a hard cap of **4 ``cache_control`` blocks per -request**, and we configure 2 injection points -(``index: 0`` + ``index: -1``). With ``index: 0`` always targeting -the prepended ``request.system_message``, this middleware is the -defensive partner: it guarantees that "the system block" is *one* -content block, so LiteLLM's ``AnthropicCacheControlHook`` and any -OpenRouter→Anthropic transformer can never multiply our budget into -several breakpoints by spreading ``cache_control`` across multiple -text blocks of a multi-block system content. - -Without flattening we used to see:: - - OpenrouterException - {"error":{"message":"Provider returned error", - "code":400,"metadata":{"raw":"...A maximum of 4 blocks with - cache_control may be provided. Found 5."}}} - -(Same error class documented in -https://github.com/BerriAI/litellm/issues/15696 and -https://github.com/BerriAI/litellm/issues/20485 — the litellm-side fix -in PR #15395 covers the litellm transformer but does not protect us -when the OpenRouter SaaS itself does the redistribution.) - -A separate fix in :mod:`app.agents.shared.prompt_caching` (switching -the first injection point from ``role: system`` to ``index: 0``) -neutralises the *primary* cause of the same 400 — multiple -``SystemMessage``\ s injected by ``before_agent`` middlewares -(priority/tree/memory/file-intent/anonymous-doc) accumulating across -turns, each tagged with ``cache_control`` by the ``role: system`` -matcher. This middleware remains useful as defence-in-depth against -the multi-block redistribution path. - -Placement: innermost on the system-message-mutation chain, after every -appender (``todo``/``filesystem``/``skills``/``subagents``) and after -summarization, but before ``noop``/``retry``/``fallback`` so each retry -attempt sees a flattened payload. - -Idempotent: a string-content system message is left untouched. A list -that contains anything other than plain text blocks (e.g. an image) is -also left untouched — those are rare on system messages and we'd lose -the non-text payload by joining. -""" - -from __future__ import annotations - -import logging -from collections.abc import Awaitable, Callable -from typing import Any - -from langchain.agents.middleware.types import ( - AgentMiddleware, - AgentState, - ContextT, - ModelRequest, - ModelResponse, - ResponseT, -) -from langchain_core.messages import SystemMessage - -logger = logging.getLogger(__name__) - - -def _flatten_text_blocks(content: list[Any]) -> str | None: - """Return joined text if every block is a plain ``{"type": "text"}``. - - Returns ``None`` when the list contains anything that isn't a text - block we can safely concatenate (image, audio, file, non-standard - blocks, dicts with extra non-cache_control fields). The caller - leaves the original content untouched in that case rather than - silently dropping payload. - - ``cache_control`` on individual blocks is intentionally discarded — - the whole point of flattening is to let LiteLLM's - ``cache_control_injection_points`` re-place a single breakpoint on - the resulting one-block system content. - """ - chunks: list[str] = [] - for block in content: - if isinstance(block, str): - chunks.append(block) - continue - if not isinstance(block, dict): - return None - if block.get("type") != "text": - return None - text = block.get("text") - if not isinstance(text, str): - return None - chunks.append(text) - return "\n\n".join(chunks) - - -def _flattened_request( - request: ModelRequest[ContextT], -) -> ModelRequest[ContextT] | None: - """Return a request with system_message flattened, or ``None`` for no-op.""" - sys_msg = request.system_message - if sys_msg is None: - return None - content = sys_msg.content - if not isinstance(content, list) or len(content) <= 1: - return None - - flattened = _flatten_text_blocks(content) - if flattened is None: - return None - - new_sys = SystemMessage( - content=flattened, - additional_kwargs=dict(sys_msg.additional_kwargs), - response_metadata=dict(sys_msg.response_metadata), - ) - if sys_msg.id is not None: - new_sys.id = sys_msg.id - return request.override(system_message=new_sys) - - -def _diagnostic_summary(request: ModelRequest[Any]) -> str: - """One-line dump of cache_control-relevant request shape. - - Temporary diagnostic to prove where the ``Found N`` cache_control - breakpoints are coming from when Anthropic 400s. Removed once the - root cause is confirmed and a fix is in place. - """ - sys_msg = request.system_message - if sys_msg is None: - sys_shape = "none" - elif isinstance(sys_msg.content, str): - sys_shape = f"str(len={len(sys_msg.content)})" - elif isinstance(sys_msg.content, list): - sys_shape = f"list(blocks={len(sys_msg.content)})" - else: - sys_shape = f"other({type(sys_msg.content).__name__})" - - role_hist: list[str] = [] - multi_block_msgs = 0 - msgs_with_cc = 0 - sys_msgs_in_history = 0 - for m in request.messages: - mtype = getattr(m, "type", type(m).__name__) - role_hist.append(mtype) - if isinstance(m, SystemMessage): - sys_msgs_in_history += 1 - c = getattr(m, "content", None) - if isinstance(c, list): - multi_block_msgs += 1 - for blk in c: - if isinstance(blk, dict) and "cache_control" in blk: - msgs_with_cc += 1 - break - if "cache_control" in getattr(m, "additional_kwargs", {}) or {}: - msgs_with_cc += 1 - - tools = request.tools or [] - tools_with_cc = 0 - for t in tools: - if isinstance(t, dict) and ( - "cache_control" in t or "cache_control" in t.get("function", {}) - ): - tools_with_cc += 1 - - return ( - f"sys={sys_shape} msgs={len(request.messages)} " - f"sys_msgs_in_history={sys_msgs_in_history} " - f"multi_block_msgs={multi_block_msgs} pre_existing_msg_cc={msgs_with_cc} " - f"tools={len(tools)} pre_existing_tool_cc={tools_with_cc} " - f"roles={role_hist[-8:]}" - ) - - -class FlattenSystemMessageMiddleware( - AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT] -): - """Collapse a multi-text-block system message to a single string. - - Sits innermost on the system-message-mutation chain so it observes - every middleware's contribution. Has no other side effect — the - body of every block is preserved, just joined with ``"\\n\\n"``. - """ - - def __init__(self) -> None: - super().__init__() - self.tools = [] - - def wrap_model_call( # type: ignore[override] - self, - request: ModelRequest[ContextT], - handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]], - ) -> Any: - if logger.isEnabledFor(logging.DEBUG): - logger.debug("[flatten_system_diag] %s", _diagnostic_summary(request)) - flattened = _flattened_request(request) - if flattened is not None: - if logger.isEnabledFor(logging.DEBUG): - logger.debug( - "[flatten_system] collapsed %d system blocks to one", - len(request.system_message.content), # type: ignore[arg-type, union-attr] - ) - return handler(flattened) - return handler(request) - - async def awrap_model_call( # type: ignore[override] - self, - request: ModelRequest[ContextT], - handler: Callable[ - [ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]] - ], - ) -> Any: - if logger.isEnabledFor(logging.DEBUG): - logger.debug("[flatten_system_diag] %s", _diagnostic_summary(request)) - flattened = _flattened_request(request) - if flattened is not None: - if logger.isEnabledFor(logging.DEBUG): - logger.debug( - "[flatten_system] collapsed %d system blocks to one", - len(request.system_message.content), # type: ignore[arg-type, union-attr] - ) - return await handler(flattened) - return await handler(request) - - -__all__ = [ - "FlattenSystemMessageMiddleware", - "_flatten_text_blocks", - "_flattened_request", -] diff --git a/surfsense_backend/app/agents/shared/middleware/knowledge_search.py b/surfsense_backend/app/agents/shared/middleware/knowledge_search.py index b71ed7035..26f06f4a5 100644 --- a/surfsense_backend/app/agents/shared/middleware/knowledge_search.py +++ b/surfsense_backend/app/agents/shared/middleware/knowledge_search.py @@ -1049,12 +1049,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg] return priority, matched_chunk_ids -# Backwards-compatible alias for any external imports. -KnowledgeBaseSearchMiddleware = KnowledgePriorityMiddleware - - __all__ = [ - "KnowledgeBaseSearchMiddleware", "KnowledgePriorityMiddleware", "browse_recent_documents", "fetch_mentioned_documents", diff --git a/surfsense_backend/app/agents/shared/prompt_caching.py b/surfsense_backend/app/agents/shared/prompt_caching.py index f8aae45a8..d72ef22bc 100644 --- a/surfsense_backend/app/agents/shared/prompt_caching.py +++ b/surfsense_backend/app/agents/shared/prompt_caching.py @@ -78,14 +78,12 @@ logger = logging.getLogger(__name__) # blocks; we use 2 here, leaving headroom for Phase-2 tool caching. # # IMPORTANT: ``index: 0`` (not ``role: system``). The deepagent stack's -# ``before_agent`` middlewares (priority, tree, memory, file-intent, -# anonymous-doc) insert ``SystemMessage`` instances into -# ``state["messages"]`` that accumulate across turns. With -# ``role: system`` the LiteLLM hook would tag *every* one of them with -# ``cache_control`` and overflow Anthropic's 4-block limit. ``index: 0`` -# always targets the langchain-prepended ``request.system_message`` -# (which our ``FlattenSystemMessageMiddleware`` reduces to a single text -# block), giving us exactly one stable cache breakpoint. +# ``before_agent`` middlewares (priority, tree, memory, anonymous-doc) +# insert ``SystemMessage`` instances into ``state["messages"]`` that +# accumulate across turns. With ``role: system`` the LiteLLM hook would +# tag *every* one of them with ``cache_control`` and overflow Anthropic's +# 4-block limit. ``index: 0`` always targets the langchain-prepended +# ``request.system_message``, giving us exactly one stable cache breakpoint. _DEFAULT_INJECTION_POINTS: tuple[dict[str, Any], ...] = ( {"location": "message", "index": 0}, {"location": "message", "index": -1}, diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_flatten_system.py b/surfsense_backend/tests/unit/agents/new_chat/test_flatten_system.py deleted file mode 100644 index f38d1ebc2..000000000 --- a/surfsense_backend/tests/unit/agents/new_chat/test_flatten_system.py +++ /dev/null @@ -1,344 +0,0 @@ -"""Tests for ``FlattenSystemMessageMiddleware``. - -The middleware exists to defend against Anthropic's "Found 5 cache_control -blocks" 400 when our deepagent middleware stack stacks 5+ text blocks on -the system message and the OpenRouter→Anthropic adapter redistributes -``cache_control`` across all of them. The flattening collapses every -all-text system content list to a single string before the LLM call. -""" - -from __future__ import annotations - -from typing import Any -from unittest.mock import MagicMock - -import pytest -from langchain_core.messages import HumanMessage, SystemMessage - -from app.agents.shared.middleware.flatten_system import ( - FlattenSystemMessageMiddleware, - _flatten_text_blocks, - _flattened_request, -) - -pytestmark = pytest.mark.unit - - -# --------------------------------------------------------------------------- -# _flatten_text_blocks — pure helper, the heart of the middleware. -# --------------------------------------------------------------------------- - - -class TestFlattenTextBlocks: - def test_joins_text_blocks_with_double_newline(self) -> None: - blocks = [ - {"type": "text", "text": ""}, - {"type": "text", "text": ""}, - {"type": "text", "text": ""}, - ] - assert ( - _flatten_text_blocks(blocks) - == "\n\n\n\n" - ) - - def test_handles_single_text_block(self) -> None: - blocks = [{"type": "text", "text": "only one"}] - assert _flatten_text_blocks(blocks) == "only one" - - def test_handles_empty_list(self) -> None: - assert _flatten_text_blocks([]) == "" - - def test_passes_through_bare_string_blocks(self) -> None: - # LangChain content can mix bare strings and dict blocks. - blocks = ["raw string", {"type": "text", "text": "dict block"}] - assert _flatten_text_blocks(blocks) == "raw string\n\ndict block" - - def test_returns_none_for_image_block(self) -> None: - # System messages with images are rare — but we never want to - # silently lose the image payload by joining as text. - blocks = [ - {"type": "text", "text": "look at this"}, - {"type": "image_url", "image_url": {"url": "data:image/png..."}}, - ] - assert _flatten_text_blocks(blocks) is None - - def test_returns_none_for_non_dict_non_str_block(self) -> None: - blocks = [{"type": "text", "text": "hi"}, 42] # type: ignore[list-item] - assert _flatten_text_blocks(blocks) is None - - def test_returns_none_when_text_field_missing(self) -> None: - blocks = [{"type": "text"}] # no ``text`` key - assert _flatten_text_blocks(blocks) is None - - def test_returns_none_when_text_is_not_string(self) -> None: - blocks = [{"type": "text", "text": ["nested", "list"]}] - assert _flatten_text_blocks(blocks) is None - - def test_drops_cache_control_from_inner_blocks(self) -> None: - # The whole point: existing cache_control on inner blocks is - # discarded so LiteLLM's ``cache_control_injection_points`` can - # re-attach exactly one breakpoint after flattening. - blocks = [ - {"type": "text", "text": "first"}, - { - "type": "text", - "text": "second", - "cache_control": {"type": "ephemeral"}, - }, - ] - flattened = _flatten_text_blocks(blocks) - assert flattened == "first\n\nsecond" - assert "cache_control" not in flattened # type: ignore[operator] - - -# --------------------------------------------------------------------------- -# _flattened_request — decides when to override and when to no-op. -# --------------------------------------------------------------------------- - - -def _make_request(system_message: SystemMessage | None) -> Any: - """Build a minimal ModelRequest stub. We only need .system_message - and .override(system_message=...) — the middleware never touches - other fields. - """ - request = MagicMock() - request.system_message = system_message - - def override(**kwargs: Any) -> Any: - new_request = MagicMock() - new_request.system_message = kwargs.get( - "system_message", request.system_message - ) - new_request.messages = kwargs.get("messages", getattr(request, "messages", [])) - new_request.tools = kwargs.get("tools", getattr(request, "tools", [])) - return new_request - - request.override = override - return request - - -class TestFlattenedRequest: - def test_collapses_multi_block_system_to_string(self) -> None: - sys = SystemMessage( - content=[ - {"type": "text", "text": ""}, - {"type": "text", "text": ""}, - {"type": "text", "text": ""}, - {"type": "text", "text": ""}, - {"type": "text", "text": ""}, - ] - ) - request = _make_request(sys) - flattened = _flattened_request(request) - - assert flattened is not None - assert isinstance(flattened.system_message, SystemMessage) - assert flattened.system_message.content == ( - "\n\n\n\n\n\n\n\n" - ) - - def test_no_op_for_string_content(self) -> None: - sys = SystemMessage(content="already a string") - request = _make_request(sys) - assert _flattened_request(request) is None - - def test_no_op_for_single_block_list(self) -> None: - # One block already produces one breakpoint — no need to flatten. - sys = SystemMessage(content=[{"type": "text", "text": "single"}]) - request = _make_request(sys) - assert _flattened_request(request) is None - - def test_no_op_when_system_message_missing(self) -> None: - request = _make_request(None) - assert _flattened_request(request) is None - - def test_no_op_when_list_contains_non_text_block(self) -> None: - sys = SystemMessage( - content=[ - {"type": "text", "text": "look"}, - {"type": "image_url", "image_url": {"url": "data:..."}}, - ] - ) - request = _make_request(sys) - assert _flattened_request(request) is None - - def test_preserves_additional_kwargs_and_metadata(self) -> None: - # Defensive: nothing in the current chain sets these on a system - # message, but losing them silently when something does in the - # future would be a regression. ``name`` in particular is the only - # ``additional_kwargs`` field that ChatLiteLLM's - # ``_convert_message_to_dict`` propagates onto the wire. - sys = SystemMessage( - content=[ - {"type": "text", "text": "a"}, - {"type": "text", "text": "b"}, - ], - additional_kwargs={"name": "surfsense_system", "x": 1}, - response_metadata={"tokens": 42}, - ) - sys.id = "sys-msg-1" - request = _make_request(sys) - - flattened = _flattened_request(request) - assert flattened is not None - assert flattened.system_message.content == "a\n\nb" - assert flattened.system_message.additional_kwargs == { - "name": "surfsense_system", - "x": 1, - } - assert flattened.system_message.response_metadata == {"tokens": 42} - assert flattened.system_message.id == "sys-msg-1" - - def test_idempotent_when_run_twice(self) -> None: - sys = SystemMessage( - content=[ - {"type": "text", "text": "a"}, - {"type": "text", "text": "b"}, - ] - ) - request = _make_request(sys) - first = _flattened_request(request) - assert first is not None - - # Second pass on the already-flattened request should be a no-op. - # We re-wrap in a request stub since the helper inspects - # ``request.system_message.content``. - second_request = _make_request(first.system_message) - assert _flattened_request(second_request) is None - - -# --------------------------------------------------------------------------- -# Middleware integration — verify the handler sees a flattened request. -# --------------------------------------------------------------------------- - - -class TestMiddlewareWrap: - @pytest.mark.asyncio - async def test_async_passes_flattened_request_to_handler(self) -> None: - sys = SystemMessage( - content=[ - {"type": "text", "text": "alpha"}, - {"type": "text", "text": "beta"}, - ] - ) - request = _make_request(sys) - captured: dict[str, Any] = {} - - async def handler(req: Any) -> str: - captured["request"] = req - return "ok" - - mw = FlattenSystemMessageMiddleware() - result = await mw.awrap_model_call(request, handler) - - assert result == "ok" - assert isinstance(captured["request"].system_message, SystemMessage) - assert captured["request"].system_message.content == "alpha\n\nbeta" - - @pytest.mark.asyncio - async def test_async_passes_through_when_already_string(self) -> None: - sys = SystemMessage(content="just a string") - request = _make_request(sys) - captured: dict[str, Any] = {} - - async def handler(req: Any) -> str: - captured["request"] = req - return "ok" - - mw = FlattenSystemMessageMiddleware() - await mw.awrap_model_call(request, handler) - - # Same request object: no override happened. - assert captured["request"] is request - - def test_sync_passes_flattened_request_to_handler(self) -> None: - sys = SystemMessage( - content=[ - {"type": "text", "text": "alpha"}, - {"type": "text", "text": "beta"}, - ] - ) - request = _make_request(sys) - captured: dict[str, Any] = {} - - def handler(req: Any) -> str: - captured["request"] = req - return "ok" - - mw = FlattenSystemMessageMiddleware() - result = mw.wrap_model_call(request, handler) - - assert result == "ok" - assert captured["request"].system_message.content == "alpha\n\nbeta" - - def test_sync_passes_through_when_no_system_message(self) -> None: - request = _make_request(None) - captured: dict[str, Any] = {} - - def handler(req: Any) -> str: - captured["request"] = req - return "ok" - - mw = FlattenSystemMessageMiddleware() - mw.wrap_model_call(request, handler) - assert captured["request"] is request - - -# --------------------------------------------------------------------------- -# Regression guard — pin the worst-case shape that triggered the -# "Found 5" 400 in production. Confirms we collapse 5 blocks to 1 so the -# downstream cache_control_injection_points can only place 1 breakpoint -# on the system message regardless of provider redistribution quirks. -# --------------------------------------------------------------------------- - - -def test_regression_five_block_system_collapses_to_one_block() -> None: - sys = SystemMessage( - content=[ - {"type": "text", "text": ""}, - {"type": "text", "text": ""}, - {"type": "text", "text": ""}, - {"type": "text", "text": ""}, - {"type": "text", "text": ""}, - ] - ) - request = _make_request(sys) - flattened = _flattened_request(request) - - assert flattened is not None - assert isinstance(flattened.system_message.content, str) - # The exact join doesn't matter for the cache_control accounting — - # only that there is exactly ONE content block when LiteLLM's - # AnthropicCacheControlHook later targets ``role: system``. - assert " None: - # Sanity: the middleware MUST NOT touch user messages — only the - # system message. Multi-block user content is the path that carries - # image attachments and would lose its image_url block on - # accidental flatten. - sys = SystemMessage( - content=[ - {"type": "text", "text": "a"}, - {"type": "text", "text": "b"}, - ] - ) - user = HumanMessage( - content=[ - {"type": "text", "text": "look at this"}, - {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}, - ] - ) - request = _make_request(sys) - request.messages = [user] - - flattened = _flattened_request(request) - assert flattened is not None - # System flattened to string … - assert isinstance(flattened.system_message.content, str) - # … user message is untouched (the helper does not even look at it). - assert flattened.messages == [user] - assert isinstance(user.content, list) - assert len(user.content) == 2 diff --git a/surfsense_backend/tests/unit/middleware/test_file_intent_middleware.py b/surfsense_backend/tests/unit/middleware/test_file_intent_middleware.py deleted file mode 100644 index e1d522201..000000000 --- a/surfsense_backend/tests/unit/middleware/test_file_intent_middleware.py +++ /dev/null @@ -1,214 +0,0 @@ -import pytest -from langchain_core.messages import AIMessage, HumanMessage - -from app.agents.shared.middleware.file_intent import ( - FileIntentMiddleware, - FileOperationIntent, - _fallback_path, -) - -pytestmark = pytest.mark.unit - - -class _FakeLLM: - def __init__(self, response_text: str): - self._response_text = response_text - - async def ainvoke(self, *_args, **_kwargs): - return AIMessage(content=self._response_text) - - -@pytest.mark.asyncio -async def test_file_write_intent_injects_contract_message(): - llm = _FakeLLM( - '{"intent":"file_write","confidence":0.93,"suggested_filename":"ideas.md"}' - ) - middleware = FileIntentMiddleware(llm=llm) - state = { - "messages": [HumanMessage(content="Create another random note for me")], - "turn_id": "123:456", - } - - result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type] - - assert result is not None - contract = result["file_operation_contract"] - assert contract["intent"] == FileOperationIntent.FILE_WRITE.value - assert contract["suggested_path"] == "/ideas.md" - assert contract["turn_id"] == "123:456" - assert any( - "file_operation_contract" in str(msg.content) - for msg in result["messages"] - if hasattr(msg, "content") - ) - - -@pytest.mark.asyncio -async def test_non_write_intent_does_not_inject_contract_message(): - llm = _FakeLLM('{"intent":"file_read","confidence":0.88,"suggested_filename":null}') - middleware = FileIntentMiddleware(llm=llm) - original_messages = [HumanMessage(content="Read /notes.md")] - state = {"messages": original_messages, "turn_id": "abc:def"} - - result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type] - - assert result is not None - assert ( - result["file_operation_contract"]["intent"] - == FileOperationIntent.FILE_READ.value - ) - assert "messages" not in result - - -@pytest.mark.asyncio -async def test_file_write_null_filename_uses_semantic_default_path(): - llm = _FakeLLM( - '{"intent":"file_write","confidence":0.74,"suggested_filename":null}' - ) - middleware = FileIntentMiddleware(llm=llm) - state = { - "messages": [HumanMessage(content="create a random markdown file")], - "turn_id": "turn:1", - } - - result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type] - - assert result is not None - contract = result["file_operation_contract"] - assert contract["intent"] == FileOperationIntent.FILE_WRITE.value - assert contract["suggested_path"] == "/notes.md" - - -@pytest.mark.asyncio -async def test_file_write_null_filename_defaults_to_markdown_path(): - llm = _FakeLLM( - '{"intent":"file_write","confidence":0.71,"suggested_filename":null}' - ) - middleware = FileIntentMiddleware(llm=llm) - state = { - "messages": [HumanMessage(content="create a sample json config file")], - "turn_id": "turn:2", - } - - result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type] - - assert result is not None - contract = result["file_operation_contract"] - assert contract["intent"] == FileOperationIntent.FILE_WRITE.value - assert contract["suggested_path"] == "/notes.md" - - -@pytest.mark.asyncio -async def test_file_write_txt_suggestion_is_normalized_to_markdown(): - llm = _FakeLLM( - '{"intent":"file_write","confidence":0.82,"suggested_filename":"random.txt"}' - ) - middleware = FileIntentMiddleware(llm=llm) - state = { - "messages": [HumanMessage(content="create a random file")], - "turn_id": "turn:3", - } - - result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type] - - assert result is not None - contract = result["file_operation_contract"] - assert contract["intent"] == FileOperationIntent.FILE_WRITE.value - assert contract["suggested_path"] == "/random.md" - - -@pytest.mark.asyncio -async def test_file_write_with_suggested_directory_preserves_folder(): - llm = _FakeLLM( - '{"intent":"file_write","confidence":0.86,"suggested_filename":"random.md","suggested_directory":"pc backups","suggested_path":null}' - ) - middleware = FileIntentMiddleware(llm=llm) - state = { - "messages": [HumanMessage(content="create a random file in pc backups folder")], - "turn_id": "turn:4", - } - - result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type] - - assert result is not None - contract = result["file_operation_contract"] - assert contract["intent"] == FileOperationIntent.FILE_WRITE.value - assert contract["suggested_path"] == "/pc_backups/random.md" - - -@pytest.mark.asyncio -async def test_file_write_with_suggested_path_takes_precedence(): - llm = _FakeLLM( - '{"intent":"file_write","confidence":0.9,"suggested_filename":"ignored.md","suggested_directory":"docs","suggested_path":"/reports/q2/summary.md"}' - ) - middleware = FileIntentMiddleware(llm=llm) - state = { - "messages": [HumanMessage(content="create report")], - "turn_id": "turn:5", - } - - result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type] - - assert result is not None - contract = result["file_operation_contract"] - assert contract["intent"] == FileOperationIntent.FILE_WRITE.value - assert contract["suggested_path"] == "/reports/q2/summary.md" - - -@pytest.mark.asyncio -async def test_file_write_infers_directory_from_user_text_when_missing(): - llm = _FakeLLM( - '{"intent":"file_write","confidence":0.83,"suggested_filename":"random.md","suggested_directory":null,"suggested_path":null}' - ) - middleware = FileIntentMiddleware(llm=llm) - state = { - "messages": [HumanMessage(content="create a random file in pc backups folder")], - "turn_id": "turn:6", - } - - result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type] - - assert result is not None - contract = result["file_operation_contract"] - assert contract["intent"] == FileOperationIntent.FILE_WRITE.value - assert contract["suggested_path"] == "/pc_backups/random.md" - - -def test_fallback_path_normalizes_windows_slashes() -> None: - resolved = _fallback_path( - suggested_filename="summary.md", - suggested_path=r"\reports\q2\summary.md", - user_text="create report", - ) - - assert resolved == "/reports/q2/summary.md" - - -def test_fallback_path_normalizes_windows_drive_path() -> None: - resolved = _fallback_path( - suggested_filename=None, - suggested_path=r"C:\Users\anish\notes\todo.md", - user_text="create note", - ) - - assert resolved == "/C/Users/anish/notes/todo.md" - - -def test_fallback_path_normalizes_mixed_separators_and_duplicate_slashes() -> None: - resolved = _fallback_path( - suggested_filename="summary.md", - suggested_path=r"\\reports\\q2//summary.md", - user_text="create report", - ) - - assert resolved == "/reports/q2/summary.md" - - -def test_fallback_path_keeps_posix_style_absolute_path_for_linux_and_macos() -> None: - resolved = _fallback_path( - suggested_filename=None, - suggested_path="/var/log/surfsense/notes.md", - user_text="create note", - ) - - assert resolved == "/var/log/surfsense/notes.md" diff --git a/surfsense_backend/tests/unit/middleware/test_knowledge_search.py b/surfsense_backend/tests/unit/middleware/test_knowledge_search.py index c97bcde0a..00304794b 100644 --- a/surfsense_backend/tests/unit/middleware/test_knowledge_search.py +++ b/surfsense_backend/tests/unit/middleware/test_knowledge_search.py @@ -10,7 +10,7 @@ from app.agents.multi_agent_chat.shared.middleware.filesystem.backends.document_ ) from app.agents.shared.middleware.knowledge_search import ( KBSearchPlan, - KnowledgeBaseSearchMiddleware, + KnowledgePriorityMiddleware, _normalize_optional_date_range, _parse_kb_search_plan_response, _render_recent_conversation, @@ -203,7 +203,7 @@ class FakeBudgetLLM: return sum(len(msg.get("content", "")) for msg in messages) -class TestKnowledgeBaseSearchMiddlewarePlanner: +class TestKnowledgePriorityMiddlewarePlanner: @pytest.fixture(autouse=True) def _disable_planner_runnable(self, monkeypatch): # ``FakeLLM`` is a duck-typed mock; ``create_agent`` (used when the @@ -273,7 +273,7 @@ class TestKnowledgeBaseSearchMiddlewarePlanner: } ) ) - middleware = KnowledgeBaseSearchMiddleware(llm=llm, search_space_id=37) + middleware = KnowledgePriorityMiddleware(llm=llm, search_space_id=37) result = await middleware.abefore_agent( { @@ -307,7 +307,7 @@ class TestKnowledgeBaseSearchMiddlewarePlanner: fake_search_knowledge_base, ) - middleware = KnowledgeBaseSearchMiddleware( + middleware = KnowledgePriorityMiddleware( llm=FakeLLM("not json"), search_space_id=37, ) @@ -336,7 +336,7 @@ class TestKnowledgeBaseSearchMiddlewarePlanner: fake_search_knowledge_base, ) - middleware = KnowledgeBaseSearchMiddleware( + middleware = KnowledgePriorityMiddleware( llm=FakeLLM( json.dumps( { @@ -395,7 +395,7 @@ class TestKnowledgeBaseSearchMiddlewarePlanner: } ) ) - middleware = KnowledgeBaseSearchMiddleware(llm=llm, search_space_id=42) + middleware = KnowledgePriorityMiddleware(llm=llm, search_space_id=42) result = await middleware.abefore_agent( {"messages": [HumanMessage(content="what's my latest file?")]}, @@ -442,7 +442,7 @@ class TestKnowledgeBaseSearchMiddlewarePlanner: } ) ) - middleware = KnowledgeBaseSearchMiddleware(llm=llm, search_space_id=42) + middleware = KnowledgePriorityMiddleware(llm=llm, search_space_id=42) await middleware.abefore_agent( {"messages": [HumanMessage(content="find the quarterly revenue report")]}, @@ -559,7 +559,7 @@ class TestKnowledgePriorityMentionDrain: fake_search_knowledge_base, ) - middleware = KnowledgeBaseSearchMiddleware( + middleware = KnowledgePriorityMiddleware( llm=self._planner_llm(), search_space_id=42, mentioned_document_ids=[1, 2, 3], @@ -609,7 +609,7 @@ class TestKnowledgePriorityMentionDrain: # Simulate a cached middleware instance whose closure was seeded # by a previous turn's cache-miss build (mentions=[1,2,3]). - middleware = KnowledgeBaseSearchMiddleware( + middleware = KnowledgePriorityMiddleware( llm=self._planner_llm(), search_space_id=42, mentioned_document_ids=[1, 2, 3], @@ -652,7 +652,7 @@ class TestKnowledgePriorityMentionDrain: fake_search_knowledge_base, ) - middleware = KnowledgeBaseSearchMiddleware( + middleware = KnowledgePriorityMiddleware( llm=self._planner_llm(), search_space_id=42, mentioned_document_ids=[7, 8],