mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-06 20:15:17 +02:00
refactor(agents): delete dead single-agent-only middleware
file_intent (FileIntentMiddleware) and flatten_system
(FlattenSystemMessageMiddleware) were only ever instantiated in the
single-agent chat_deepagent stack, which was removed in 14bbea085. They
have no production consumer in multi_agent_chat. Delete both modules and
their unit tests.
Also drop the vestigial KnowledgeBaseSearchMiddleware alias (= the live
KnowledgePriorityMiddleware); its tests now target the real class so the
behavior coverage is preserved. Trim the three barrel/__all__ entries and
strip the now-dead class names from comments.
This commit is contained in:
parent
21509e7eca
commit
afa51e97cf
10 changed files with 19 additions and 1161 deletions
|
|
@ -50,8 +50,8 @@ class SurfSenseContextSchema:
|
||||||
(cloud filesystem mode). Surfaced as ``[USER-MENTIONED]``
|
(cloud filesystem mode). Surfaced as ``[USER-MENTIONED]``
|
||||||
entries in ``<priority_documents>`` so the agent prioritises
|
entries in ``<priority_documents>`` so the agent prioritises
|
||||||
walking those folders with ``ls`` / ``find_documents``.
|
walking those folders with ``ls`` / ``find_documents``.
|
||||||
file_operation_contract: One-shot file operation contract emitted
|
file_operation_contract: One-shot file operation contract for the
|
||||||
by ``FileIntentMiddleware`` for the upcoming turn.
|
upcoming turn (reserved; not currently populated).
|
||||||
turn_id / request_id: Correlation IDs surfaced by the streaming
|
turn_id / request_id: Correlation IDs surfaced by the streaming
|
||||||
task; populated for telemetry.
|
task; populated for telemetry.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -21,18 +21,11 @@ from app.agents.shared.middleware.dedup_tool_calls import (
|
||||||
DedupHITLToolCallsMiddleware,
|
DedupHITLToolCallsMiddleware,
|
||||||
)
|
)
|
||||||
from app.agents.shared.middleware.doom_loop import DoomLoopMiddleware
|
from app.agents.shared.middleware.doom_loop import DoomLoopMiddleware
|
||||||
from app.agents.shared.middleware.file_intent import (
|
|
||||||
FileIntentMiddleware,
|
|
||||||
)
|
|
||||||
from app.agents.shared.middleware.flatten_system import (
|
|
||||||
FlattenSystemMessageMiddleware,
|
|
||||||
)
|
|
||||||
from app.agents.shared.middleware.kb_persistence import (
|
from app.agents.shared.middleware.kb_persistence import (
|
||||||
KnowledgeBasePersistenceMiddleware,
|
KnowledgeBasePersistenceMiddleware,
|
||||||
commit_staged_filesystem_state,
|
commit_staged_filesystem_state,
|
||||||
)
|
)
|
||||||
from app.agents.shared.middleware.knowledge_search import (
|
from app.agents.shared.middleware.knowledge_search import (
|
||||||
KnowledgeBaseSearchMiddleware,
|
|
||||||
KnowledgePriorityMiddleware,
|
KnowledgePriorityMiddleware,
|
||||||
)
|
)
|
||||||
from app.agents.shared.middleware.knowledge_tree import (
|
from app.agents.shared.middleware.knowledge_tree import (
|
||||||
|
|
@ -56,10 +49,7 @@ __all__ = [
|
||||||
"ClearToolUsesEdit",
|
"ClearToolUsesEdit",
|
||||||
"DedupHITLToolCallsMiddleware",
|
"DedupHITLToolCallsMiddleware",
|
||||||
"DoomLoopMiddleware",
|
"DoomLoopMiddleware",
|
||||||
"FileIntentMiddleware",
|
|
||||||
"FlattenSystemMessageMiddleware",
|
|
||||||
"KnowledgeBasePersistenceMiddleware",
|
"KnowledgeBasePersistenceMiddleware",
|
||||||
"KnowledgeBaseSearchMiddleware",
|
|
||||||
"KnowledgePriorityMiddleware",
|
"KnowledgePriorityMiddleware",
|
||||||
"KnowledgeTreeMiddleware",
|
"KnowledgeTreeMiddleware",
|
||||||
"MemoryInjectionMiddleware",
|
"MemoryInjectionMiddleware",
|
||||||
|
|
|
||||||
|
|
@ -94,7 +94,7 @@ Respond ONLY with the structured summary. Do not include any text before or afte
|
||||||
PROTECTED_SYSTEM_PREFIXES: tuple[str, ...] = (
|
PROTECTED_SYSTEM_PREFIXES: tuple[str, ...] = (
|
||||||
"<priority_documents>", # KnowledgePriorityMiddleware
|
"<priority_documents>", # KnowledgePriorityMiddleware
|
||||||
"<workspace_tree>", # KnowledgeTreeMiddleware
|
"<workspace_tree>", # KnowledgeTreeMiddleware
|
||||||
"<file_operation_contract>", # FileIntentMiddleware
|
"<file_operation_contract>", # reserved file-operation contract prefix
|
||||||
"<user_memory>", # MemoryInjectionMiddleware
|
"<user_memory>", # MemoryInjectionMiddleware
|
||||||
"<team_memory>", # MemoryInjectionMiddleware
|
"<team_memory>", # MemoryInjectionMiddleware
|
||||||
"<user_name>", # MemoryInjectionMiddleware
|
"<user_name>", # MemoryInjectionMiddleware
|
||||||
|
|
|
||||||
|
|
@ -1,334 +0,0 @@
|
||||||
"""Semantic file-intent routing middleware for new chat turns.
|
|
||||||
|
|
||||||
This middleware classifies the latest human turn into a small intent set:
|
|
||||||
- chat_only
|
|
||||||
- file_write
|
|
||||||
- file_read
|
|
||||||
|
|
||||||
For ``file_write`` turns it injects a strict system contract so the model
|
|
||||||
uses filesystem tools before claiming success, and provides a deterministic
|
|
||||||
fallback path when no filename is specified by the user.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import re
|
|
||||||
from datetime import UTC, datetime
|
|
||||||
from enum import StrEnum
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from langchain.agents.middleware import AgentMiddleware, AgentState
|
|
||||||
from langchain_core.language_models import BaseChatModel
|
|
||||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
|
||||||
from langgraph.runtime import Runtime
|
|
||||||
from pydantic import BaseModel, Field, ValidationError
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class FileOperationIntent(StrEnum):
|
|
||||||
CHAT_ONLY = "chat_only"
|
|
||||||
FILE_WRITE = "file_write"
|
|
||||||
FILE_READ = "file_read"
|
|
||||||
|
|
||||||
|
|
||||||
# Structured result of classifying one user turn.
# NOTE: documented with comments rather than a class docstring on purpose —
# pydantic copies a model docstring into the generated JSON schema, which
# would change observable output.
class FileIntentPlan(BaseModel):
    # Required: which of the supported intents the turn maps to.
    intent: FileOperationIntent = Field(
        description="Primary user intent for this turn."
    )
    # Bounded to [0.0, 1.0]; falls back to 0.5 when the classifier omits it.
    confidence: float = Field(
        default=0.5,
        ge=0.0,
        le=1.0,
        description="Model confidence in the selected intent.",
    )
    # Bare filename hint, if any.
    suggested_filename: str | None = Field(
        default=None,
        description="Optional filename (e.g. notes.md) inferred from user request.",
    )
    # Directory hint, if any.
    suggested_directory: str | None = Field(
        default=None,
        description=(
            "Optional directory path (e.g. /reports/q2 or reports/q2) inferred from "
            "user request."
        ),
    )
    # Full path hint; per its description it wins over directory + filename.
    suggested_path: str | None = Field(
        default=None,
        description=(
            "Optional full file path (e.g. /reports/q2/summary.md). If present, this "
            "takes precedence over suggested_directory + suggested_filename."
        ),
    )
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_text_from_message(message: BaseMessage) -> str:
    """Return the textual content of *message* as a single string.

    String content is returned unchanged. List content contributes its bare
    string items and the ``text`` of ``{"type": "text"}`` dict items; empty
    fragments are dropped and the rest are joined with newlines. Any other
    content is passed through ``str()``. A missing ``content`` attribute
    yields an empty string.
    """
    payload = getattr(message, "content", "")
    if isinstance(payload, str):
        return payload
    if not isinstance(payload, list):
        return str(payload)

    fragments: list[str] = []
    for block in payload:
        if isinstance(block, str):
            fragments.append(block)
        elif isinstance(block, dict) and block.get("type") == "text":
            fragments.append(str(block.get("text", "")))
    # filter(None, ...) drops empty fragments before joining.
    return "\n".join(filter(None, fragments))
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_json_payload(text: str) -> str:
    """Pull the most likely JSON object substring out of *text*.

    Preference order: the object inside a fenced code block (optionally
    tagged ``json``), then the span from the first ``{`` to the last ``}``,
    then the stripped text itself when neither is found.
    """
    candidate = text.strip()

    fence_match = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", candidate, re.DOTALL)
    if fence_match is not None:
        return fence_match.group(1)

    open_idx = candidate.find("{")
    close_idx = candidate.rfind("}")
    # Both braces must exist and the closing one must come after the opening.
    if -1 < open_idx < close_idx:
        return candidate[open_idx : close_idx + 1]
    return candidate
|
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_filename(value: str) -> str:
    """Turn *value* into a safe, bounded-length filename.

    Reserved characters (``\\ / : * ? " < > |``) become ``_``, whitespace
    runs become ``-``, and leading/trailing ``.``, ``_`` and ``-`` are
    stripped. An empty result falls back to ``"note"``.

    Names longer than 80 characters are truncated. When the name ends in a
    short alphanumeric extension, the stem is truncated and the extension is
    kept, so a long ``….json`` name does not lose its format (and is not
    later treated as extension-less). Otherwise the name is cut at 80
    characters.
    """
    max_length = 80
    max_extension_length = 10

    name = re.sub(r"[\\/:*?\"<>|]+", "_", value).strip()
    name = re.sub(r"\s+", "-", name)
    name = name.strip("._-") or "note"
    if len(name) <= max_length:
        return name

    stem, dot, extension = name.rpartition(".")
    if (
        dot
        and stem
        and extension.isalnum()
        and len(extension) <= max_extension_length
    ):
        # Reserve room for ".<extension>" and trim the stem only. The stem
        # starts with a non-separator character, so it never becomes empty.
        budget = max_length - len(extension) - 1
        return f"{stem[:budget].rstrip('-_.')}.{extension}"
    return name[:max_length].rstrip("-_.")
|
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_path_segment(value: str) -> str:
    """Clean a single path component.

    Reserved characters become ``_``, whitespace runs become ``_``, and
    leading/trailing ``.``, ``_`` and ``-`` are removed. The result may be
    empty (e.g. for ``".."``).
    """
    without_reserved = re.sub(r"[\\/:*?\"<>|]+", "_", value)
    collapsed = re.sub(r"\s+", "_", without_reserved.strip())
    return collapsed.strip("._-")
|
|
||||||
|
|
||||||
|
|
||||||
def _normalize_directory(value: str) -> str:
    """Normalize a directory hint to ``seg/seg/...`` form.

    Backslashes are treated as ``/``, surrounding slashes are removed, and
    every component is run through ``_sanitize_path_segment``; components
    that end up empty are dropped. Returns ``""`` when nothing usable
    remains. The result has no leading or trailing slash.
    """
    unified = value.strip().replace("\\", "/").strip("/")
    if not unified:
        return ""

    cleaned = (
        _sanitize_path_segment(piece)
        for piece in unified.split("/")
        if piece.strip()
    )
    return "/".join(segment for segment in cleaned if segment)
|
|
||||||
|
|
||||||
|
|
||||||
def _normalize_file_path(value: str) -> str:
    """Normalize a path hint to an absolute ``/seg/seg`` form.

    Backslashes are treated as ``/`` and each component is sanitized with
    ``_sanitize_path_segment``; empty components are dropped. A trailing
    slash on the input is preserved on the output, which callers can use to
    tell "directory" from "file". Returns ``""`` when nothing usable remains.
    """
    unified = value.strip().replace("\\", "/").strip()
    if not unified:
        return ""

    # Remember the trailing slash before it is stripped away.
    keep_trailing_slash = unified.endswith("/")
    trimmed = unified.strip("/")
    if not trimmed:
        return ""

    segments: list[str] = []
    for piece in trimmed.split("/"):
        if not piece.strip():
            continue
        cleaned = _sanitize_path_segment(piece)
        if cleaned:
            segments.append(cleaned)
    if not segments:
        return ""

    suffix = "/" if keep_trailing_slash else ""
    return "/" + "/".join(segments) + suffix
|
|
||||||
|
|
||||||
|
|
||||||
def _infer_directory_from_user_text(user_text: str) -> str | None:
|
|
||||||
patterns = (
|
|
||||||
r"\b(?:in|inside|under)\s+(?:the\s+)?([a-zA-Z0-9 _\-/]+?)\s+folder\b",
|
|
||||||
r"\b(?:in|inside|under)\s+([a-zA-Z0-9 _\-/]+?)\b",
|
|
||||||
)
|
|
||||||
lowered = user_text.lower()
|
|
||||||
for pattern in patterns:
|
|
||||||
match = re.search(pattern, lowered, flags=re.IGNORECASE)
|
|
||||||
if not match:
|
|
||||||
continue
|
|
||||||
candidate = match.group(1).strip()
|
|
||||||
if candidate in {"the", "a", "an"}:
|
|
||||||
continue
|
|
||||||
normalized = _normalize_directory(candidate)
|
|
||||||
if normalized:
|
|
||||||
return normalized
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _fallback_path(
    suggested_filename: str | None,
    *,
    suggested_directory: str | None = None,
    suggested_path: str | None = None,
    user_text: str,
) -> str:
    """Build the default absolute path for a file-write turn.

    Resolution order:

    1. ``suggested_path`` (normalized). If it ends with ``/`` it is treated
       as a directory and the filename is appended; otherwise it is
       returned as-is.
    2. ``suggested_directory`` (normalized), else a directory inferred from
       ``user_text``, joined with the filename.
    3. The filename at the root.

    The filename is the sanitized ``suggested_filename`` with ``.txt``
    rewritten to ``.md``, ``.md`` appended when it has no extension, and
    ``notes.md`` used when no filename is available.
    """
    directory_hint = _infer_directory_from_user_text(user_text)

    filename = _sanitize_filename(suggested_filename) if suggested_filename else ""
    if filename.lower().endswith(".txt"):
        # Generic text notes are stored as markdown.
        filename = f"{filename[:-4]}.md"
    if not filename:
        filename = "notes.md"
    elif "." not in filename:
        filename = f"{filename}.md"

    explicit = _normalize_file_path(suggested_path) if suggested_path else ""
    if explicit:
        if explicit.endswith("/"):
            return f"{explicit.rstrip('/')}/{filename}"
        return explicit

    # An explicit directory suggestion wins over the text-derived guess.
    directory = _normalize_directory(suggested_directory or "") or directory_hint
    if directory:
        return f"/{directory}/{filename}"
    return f"/{filename}"
|
|
||||||
|
|
||||||
|
|
||||||
def _build_classifier_prompt(*, recent_conversation: str, user_text: str) -> str:
    """Render the intent-classification prompt.

    The prompt consists of four sections separated by blank lines: the task
    statement with the JSON schema, the rule list, the recent conversation
    (``(none)`` when empty), and the latest user message.
    """
    schema = (
        '{"intent":"chat_only|file_write|file_read","confidence":0.0,'
        '"suggested_filename":"string or null","suggested_directory":"string or null",'
        '"suggested_path":"string or null"}'
    )
    rules = (
        "- Use semantic intent, not literal keywords.",
        "- file_write: user asks to create/save/write/update/edit content as a file.",
        "- file_read: user asks to open/read/list/search existing files.",
        "- chat_only: conversational/analysis responses without required file operations.",
        "- For file_write, choose a concise semantic suggested_filename and match the requested format.",
        "- If the user mentions a folder/directory, populate suggested_directory.",
        "- If user specifies an explicit full path, populate suggested_path.",
        "- Use extensions that match user intent (e.g. .md, .json, .yaml, .csv, .py, .ts, .js, .html, .css, .sql).",
        "- Do not use .txt; prefer .md for generic text notes.",
        "- Do not include dates or timestamps in suggested_filename unless explicitly requested.",
        "- Never include markdown or explanation.",
    )

    task_section = (
        "Classify the latest user request into a filesystem intent for an AI agent.\n"
        "Return JSON only with this exact schema:\n" + schema
    )
    rules_section = "Rules:\n" + "\n".join(rules)
    history = recent_conversation or "(none)"
    return "\n\n".join(
        (
            task_section,
            rules_section,
            f"Recent conversation:\n{history}",
            f"Latest user message:\n{user_text}",
        )
    )
|
|
||||||
|
|
||||||
|
|
||||||
def _build_recent_conversation(
    messages: list[BaseMessage], *, max_messages: int = 6
) -> str:
    """Summarize the tail of the conversation as ``role: text`` lines.

    Only human messages and assistant messages without tool calls are
    considered. The last ``max_messages`` of those are rendered with
    whitespace collapsed and text capped at 280 characters; messages with
    no text are omitted.
    """
    transcript: list[tuple[str, BaseMessage]] = []
    for message in messages:
        if isinstance(message, HumanMessage):
            speaker = "user"
        elif isinstance(message, AIMessage) and not getattr(
            message, "tool_calls", None
        ):
            speaker = "assistant"
        else:
            # Tool-calling assistant turns and every other message type
            # are left out of the summary.
            continue
        transcript.append((speaker, message))

    lines: list[str] = []
    for speaker, message in transcript[-max_messages:]:
        flat = re.sub(r"\s+", " ", _extract_text_from_message(message)).strip()
        if flat:
            lines.append(f"{speaker}: {flat[:280]}")
    return "\n".join(lines)
|
|
||||||
|
|
||||||
|
|
||||||
class FileIntentMiddleware(AgentMiddleware):  # type: ignore[type-arg]
    """Classify file intent and inject a strict file-write contract."""

    # This middleware contributes no tools of its own.
    tools = ()

    def __init__(self, *, llm: BaseChatModel | None = None) -> None:
        """Store the optional classifier model.

        Args:
            llm: Chat model used for intent classification. When ``None``
                every turn is classified as ``chat_only``.
        """
        self.llm = llm

    async def _classify_intent(
        self, *, messages: list[BaseMessage], user_text: str
    ) -> FileIntentPlan:
        """Ask the classifier model for a ``FileIntentPlan``.

        Never raises: a missing model, unparsable output, or any model
        failure yields a ``chat_only`` plan with confidence ``0.0``.
        """
        if self.llm is None:
            return FileIntentPlan(intent=FileOperationIntent.CHAT_ONLY, confidence=0.0)

        prompt = _build_classifier_prompt(
            recent_conversation=_build_recent_conversation(messages),
            user_text=user_text,
        )
        try:
            response = await self.llm.ainvoke(
                [HumanMessage(content=prompt)],
                config={"tags": ["surfsense:internal"]},
            )
            payload = json.loads(
                _extract_json_payload(_extract_text_from_message(response))
            )
            return FileIntentPlan.model_validate(payload)
        except (json.JSONDecodeError, ValidationError, ValueError) as exc:
            logger.warning("File intent classifier returned invalid output: %s", exc)
        except Exception as exc:  # pragma: no cover - defensive fallback
            # Classification is best-effort; a failure must not break the turn.
            logger.warning("File intent classifier failed: %s", exc)

        return FileIntentPlan(intent=FileOperationIntent.CHAT_ONLY, confidence=0.0)

    async def abefore_agent(  # type: ignore[override]
        self,
        state: AgentState,
        runtime: Runtime[Any],
    ) -> dict[str, Any] | None:
        """Classify the latest human turn and emit the file-operation contract.

        Returns:
            ``None`` when there are no messages, no human message, or the
            latest human message has no text. Otherwise a state update with
            ``file_operation_contract``; for ``file_write`` turns the update
            also carries ``messages`` with a contract ``SystemMessage``
            inserted immediately before the latest human message.
        """
        del runtime
        messages = state.get("messages") or []
        if not messages:
            return None

        # Locate the latest human turn; its index is needed later so the
        # contract lands directly in front of it.
        last_human_index = next(
            (
                index
                for index in range(len(messages) - 1, -1, -1)
                if isinstance(messages[index], HumanMessage)
            ),
            None,
        )
        if last_human_index is None:
            return None
        last_human = messages[last_human_index]

        user_text = _extract_text_from_message(last_human).strip()
        if not user_text:
            return None

        plan = await self._classify_intent(messages=messages, user_text=user_text)
        suggested_path = _fallback_path(
            plan.suggested_filename,
            suggested_directory=plan.suggested_directory,
            suggested_path=plan.suggested_path,
            user_text=user_text,
        )
        contract = {
            "intent": plan.intent.value,
            "confidence": plan.confidence,
            "suggested_path": suggested_path,
            "timestamp": datetime.now(UTC).isoformat(),
            "turn_id": state.get("turn_id", ""),
        }

        if plan.intent != FileOperationIntent.FILE_WRITE:
            return {"file_operation_contract": contract}

        contract_msg = SystemMessage(
            content=(
                "<file_operation_contract>\n"
                "This turn intent is file_write.\n"
                f"Suggested default path: {suggested_path}\n"
                "Rules:\n"
                "- You MUST call write_file or edit_file before claiming success.\n"
                "- If no path is provided by the user, use the suggested default path.\n"
                "- Do not claim a file was created/updated unless tool output confirms it.\n"
                "- If the write/edit fails, clearly report failure instead of success.\n"
                "- Do not include timestamps or dates in generated file content unless the user explicitly asks for them.\n"
                "- For open-ended requests (e.g., random note), generate useful concrete content, not placeholders.\n"
                "</file_operation_contract>"
            )
        )

        # Insert just before the latest human turn so it applies to this
        # request. Using the located index (rather than "second to last")
        # keeps this correct when the human message is not the final entry.
        new_messages = list(messages)
        new_messages.insert(last_human_index, contract_msg)
        return {"messages": new_messages, "file_operation_contract": contract}
|
|
||||||
|
|
@ -1,233 +0,0 @@
|
||||||
r"""Coalesce multi-block system messages into a single text block.
|
|
||||||
|
|
||||||
Several middlewares in our deepagent stack each call
|
|
||||||
``append_to_system_message`` on the way down to the model
|
|
||||||
(``TodoListMiddleware``, ``SurfSenseFilesystemMiddleware``,
|
|
||||||
``SkillsMiddleware``, ``SubAgentMiddleware`` …). By the time the
|
|
||||||
request reaches the LLM, the system message has 5+ separate text blocks.
|
|
||||||
|
|
||||||
Anthropic enforces a hard cap of **4 ``cache_control`` blocks per
|
|
||||||
request**, and we configure 2 injection points
|
|
||||||
(``index: 0`` + ``index: -1``). With ``index: 0`` always targeting
|
|
||||||
the prepended ``request.system_message``, this middleware is the
|
|
||||||
defensive partner: it guarantees that "the system block" is *one*
|
|
||||||
content block, so LiteLLM's ``AnthropicCacheControlHook`` and any
|
|
||||||
OpenRouter→Anthropic transformer can never multiply our budget into
|
|
||||||
several breakpoints by spreading ``cache_control`` across multiple
|
|
||||||
text blocks of a multi-block system content.
|
|
||||||
|
|
||||||
Without flattening we used to see::
|
|
||||||
|
|
||||||
OpenrouterException - {"error":{"message":"Provider returned error",
|
|
||||||
"code":400,"metadata":{"raw":"...A maximum of 4 blocks with
|
|
||||||
cache_control may be provided. Found 5."}}}
|
|
||||||
|
|
||||||
(Same error class documented in
|
|
||||||
https://github.com/BerriAI/litellm/issues/15696 and
|
|
||||||
https://github.com/BerriAI/litellm/issues/20485 — the litellm-side fix
|
|
||||||
in PR #15395 covers the litellm transformer but does not protect us
|
|
||||||
when the OpenRouter SaaS itself does the redistribution.)
|
|
||||||
|
|
||||||
A separate fix in :mod:`app.agents.shared.prompt_caching` (switching
|
|
||||||
the first injection point from ``role: system`` to ``index: 0``)
|
|
||||||
neutralises the *primary* cause of the same 400 — multiple
|
|
||||||
``SystemMessage``\ s injected by ``before_agent`` middlewares
|
|
||||||
(priority/tree/memory/file-intent/anonymous-doc) accumulating across
|
|
||||||
turns, each tagged with ``cache_control`` by the ``role: system``
|
|
||||||
matcher. This middleware remains useful as defence-in-depth against
|
|
||||||
the multi-block redistribution path.
|
|
||||||
|
|
||||||
Placement: innermost on the system-message-mutation chain, after every
|
|
||||||
appender (``todo``/``filesystem``/``skills``/``subagents``) and after
|
|
||||||
summarization, but before ``noop``/``retry``/``fallback`` so each retry
|
|
||||||
attempt sees a flattened payload.
|
|
||||||
|
|
||||||
Idempotent: a string-content system message is left untouched. A list
|
|
||||||
that contains anything other than plain text blocks (e.g. an image) is
|
|
||||||
also left untouched — those are rare on system messages and we'd lose
|
|
||||||
the non-text payload by joining.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from collections.abc import Awaitable, Callable
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from langchain.agents.middleware.types import (
|
|
||||||
AgentMiddleware,
|
|
||||||
AgentState,
|
|
||||||
ContextT,
|
|
||||||
ModelRequest,
|
|
||||||
ModelResponse,
|
|
||||||
ResponseT,
|
|
||||||
)
|
|
||||||
from langchain_core.messages import SystemMessage
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def _flatten_text_blocks(content: list[Any]) -> str | None:
    """Join an all-text content list into one string, or return ``None``.

    Bare strings and ``{"type": "text", "text": <str>}`` dicts are accepted
    and joined with a blank line between them. Anything else (a non-dict
    item, a dict of another type such as an image, or a text block whose
    ``text`` is not a string) makes the whole call return ``None`` so the
    caller can leave the original content untouched instead of losing
    payload.

    Only the ``text`` value of a dict block is kept; other keys on the
    block, including ``cache_control``, are discarded.
    """
    pieces: list[str] = []
    for entry in content:
        if isinstance(entry, str):
            pieces.append(entry)
        elif (
            isinstance(entry, dict)
            and entry.get("type") == "text"
            and isinstance(entry.get("text"), str)
        ):
            pieces.append(entry["text"])
        else:
            # Not safely representable as plain text: refuse to flatten.
            return None
    return "\n\n".join(pieces)
|
|
||||||
|
|
||||||
|
|
||||||
def _flattened_request(
    request: ModelRequest[ContextT],
) -> ModelRequest[ContextT] | None:
    """Return a copy of *request* with a single-string system message.

    Returns ``None`` (meaning "nothing to do") when there is no system
    message, when its content is not a list with more than one block, or
    when the blocks cannot be flattened to text. The replacement message
    keeps copies of the original ``additional_kwargs`` and
    ``response_metadata`` and reuses the original ``id`` when it is set.
    """
    system = request.system_message
    if system is None:
        return None

    blocks = system.content
    if not (isinstance(blocks, list) and len(blocks) > 1):
        return None

    merged = _flatten_text_blocks(blocks)
    if merged is None:
        return None

    replacement = SystemMessage(
        content=merged,
        additional_kwargs=dict(system.additional_kwargs),
        response_metadata=dict(system.response_metadata),
    )
    if system.id is not None:
        replacement.id = system.id
    return request.override(system_message=replacement)
|
|
||||||
|
|
||||||
|
|
||||||
def _diagnostic_summary(request: ModelRequest[Any]) -> str:
    """One-line dump of cache_control-relevant request shape.

    Temporary diagnostic to prove where the ``Found N`` cache_control
    breakpoints are coming from when Anthropic 400s. Removed once the
    root cause is confirmed and a fix is in place.

    The summary reports the system message shape, the message count, how
    many history messages are system messages, how many have list content,
    how many already carry ``cache_control`` (in a content block or in
    ``additional_kwargs``; each message is counted at most once), the tool
    count, how many dict tools carry ``cache_control``, and the types of the
    last eight messages.
    """
    sys_msg = request.system_message
    if sys_msg is None:
        sys_shape = "none"
    elif isinstance(sys_msg.content, str):
        sys_shape = f"str(len={len(sys_msg.content)})"
    elif isinstance(sys_msg.content, list):
        sys_shape = f"list(blocks={len(sys_msg.content)})"
    else:
        sys_shape = f"other({type(sys_msg.content).__name__})"

    role_hist: list[str] = []
    multi_block_msgs = 0
    msgs_with_cc = 0
    sys_msgs_in_history = 0
    for m in request.messages:
        role_hist.append(getattr(m, "type", type(m).__name__))
        if isinstance(m, SystemMessage):
            sys_msgs_in_history += 1

        content = getattr(m, "content", None)
        block_has_cc = False
        if isinstance(content, list):
            multi_block_msgs += 1
            block_has_cc = any(
                isinstance(blk, dict) and "cache_control" in blk for blk in content
            )

        # Parenthesised on purpose: the ``or {}`` guard must apply to the
        # kwargs value, not to the result of the ``in`` test, so a ``None``
        # ``additional_kwargs`` cannot raise.
        extra_kwargs = getattr(m, "additional_kwargs", None) or {}
        if block_has_cc or "cache_control" in extra_kwargs:
            msgs_with_cc += 1

    tools = request.tools or []
    tools_with_cc = 0
    for t in tools:
        if not isinstance(t, dict):
            continue
        # ``function`` may be absent or ``None``; treat both as empty.
        if "cache_control" in t or "cache_control" in (t.get("function") or {}):
            tools_with_cc += 1

    return (
        f"sys={sys_shape} msgs={len(request.messages)} "
        f"sys_msgs_in_history={sys_msgs_in_history} "
        f"multi_block_msgs={multi_block_msgs} pre_existing_msg_cc={msgs_with_cc} "
        f"tools={len(tools)} pre_existing_tool_cc={tools_with_cc} "
        f"roles={role_hist[-8:]}"
    )
|
|
||||||
|
|
||||||
|
|
||||||
class FlattenSystemMessageMiddleware(
    AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]
):
    """Collapse a multi-text-block system message to a single string.

    Wraps the model call: when the request's system message is a list of
    plain text blocks, the handler receives a copy whose system message is
    one string (blocks joined with ``"\\n\\n"``); otherwise the request is
    forwarded unchanged. The sync and async hooks behave identically.
    """

    def __init__(self) -> None:
        super().__init__()
        # This middleware contributes no tools.
        self.tools = []

    def wrap_model_call(  # type: ignore[override]
        self,
        request: ModelRequest[ContextT],
        handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
    ) -> Any:
        """Flatten the system message (if applicable) and call *handler*."""
        debug_enabled = logger.isEnabledFor(logging.DEBUG)
        if debug_enabled:
            logger.debug("[flatten_system_diag] %s", _diagnostic_summary(request))

        collapsed = _flattened_request(request)
        if collapsed is None:
            # Nothing to flatten: pass the request through untouched.
            return handler(request)

        if debug_enabled:
            logger.debug(
                "[flatten_system] collapsed %d system blocks to one",
                len(request.system_message.content),  # type: ignore[arg-type, union-attr]
            )
        return handler(collapsed)

    async def awrap_model_call(  # type: ignore[override]
        self,
        request: ModelRequest[ContextT],
        handler: Callable[
            [ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
        ],
    ) -> Any:
        """Async counterpart of :meth:`wrap_model_call`."""
        debug_enabled = logger.isEnabledFor(logging.DEBUG)
        if debug_enabled:
            logger.debug("[flatten_system_diag] %s", _diagnostic_summary(request))

        collapsed = _flattened_request(request)
        if collapsed is None:
            # Nothing to flatten: pass the request through untouched.
            return await handler(request)

        if debug_enabled:
            logger.debug(
                "[flatten_system] collapsed %d system blocks to one",
                len(request.system_message.content),  # type: ignore[arg-type, union-attr]
            )
        return await handler(collapsed)
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"FlattenSystemMessageMiddleware",
|
|
||||||
"_flatten_text_blocks",
|
|
||||||
"_flattened_request",
|
|
||||||
]
|
|
||||||
|
|
@ -1049,12 +1049,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
return priority, matched_chunk_ids
|
return priority, matched_chunk_ids
|
||||||
|
|
||||||
|
|
||||||
# Backwards-compatible alias for any external imports.
|
|
||||||
KnowledgeBaseSearchMiddleware = KnowledgePriorityMiddleware
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"KnowledgeBaseSearchMiddleware",
|
|
||||||
"KnowledgePriorityMiddleware",
|
"KnowledgePriorityMiddleware",
|
||||||
"browse_recent_documents",
|
"browse_recent_documents",
|
||||||
"fetch_mentioned_documents",
|
"fetch_mentioned_documents",
|
||||||
|
|
|
||||||
|
|
@ -78,14 +78,12 @@ logger = logging.getLogger(__name__)
|
||||||
# blocks; we use 2 here, leaving headroom for Phase-2 tool caching.
|
# blocks; we use 2 here, leaving headroom for Phase-2 tool caching.
|
||||||
#
|
#
|
||||||
# IMPORTANT: ``index: 0`` (not ``role: system``). The deepagent stack's
|
# IMPORTANT: ``index: 0`` (not ``role: system``). The deepagent stack's
|
||||||
# ``before_agent`` middlewares (priority, tree, memory, file-intent,
|
# ``before_agent`` middlewares (priority, tree, memory, anonymous-doc)
|
||||||
# anonymous-doc) insert ``SystemMessage`` instances into
|
# insert ``SystemMessage`` instances into ``state["messages"]`` that
|
||||||
# ``state["messages"]`` that accumulate across turns. With
|
# accumulate across turns. With ``role: system`` the LiteLLM hook would
|
||||||
# ``role: system`` the LiteLLM hook would tag *every* one of them with
|
# tag *every* one of them with ``cache_control`` and overflow Anthropic's
|
||||||
# ``cache_control`` and overflow Anthropic's 4-block limit. ``index: 0``
|
# 4-block limit. ``index: 0`` always targets the langchain-prepended
|
||||||
# always targets the langchain-prepended ``request.system_message``
|
# ``request.system_message``, giving us exactly one stable cache breakpoint.
|
||||||
# (which our ``FlattenSystemMessageMiddleware`` reduces to a single text
|
|
||||||
# block), giving us exactly one stable cache breakpoint.
|
|
||||||
_DEFAULT_INJECTION_POINTS: tuple[dict[str, Any], ...] = (
|
_DEFAULT_INJECTION_POINTS: tuple[dict[str, Any], ...] = (
|
||||||
{"location": "message", "index": 0},
|
{"location": "message", "index": 0},
|
||||||
{"location": "message", "index": -1},
|
{"location": "message", "index": -1},
|
||||||
|
|
|
||||||
|
|
@ -1,344 +0,0 @@
|
||||||
"""Tests for ``FlattenSystemMessageMiddleware``.
|
|
||||||
|
|
||||||
The middleware exists to defend against Anthropic's "Found 5 cache_control
|
|
||||||
blocks" 400 when our deepagent middleware stack stacks 5+ text blocks on
|
|
||||||
the system message and the OpenRouter→Anthropic adapter redistributes
|
|
||||||
``cache_control`` across all of them. The flattening collapses every
|
|
||||||
all-text system content list to a single string before the LLM call.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
from unittest.mock import MagicMock
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
|
||||||
|
|
||||||
from app.agents.shared.middleware.flatten_system import (
|
|
||||||
FlattenSystemMessageMiddleware,
|
|
||||||
_flatten_text_blocks,
|
|
||||||
_flattened_request,
|
|
||||||
)
|
|
||||||
|
|
||||||
pytestmark = pytest.mark.unit
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# _flatten_text_blocks — pure helper, the heart of the middleware.
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestFlattenTextBlocks:
|
|
||||||
def test_joins_text_blocks_with_double_newline(self) -> None:
|
|
||||||
blocks = [
|
|
||||||
{"type": "text", "text": "<surfsense base>"},
|
|
||||||
{"type": "text", "text": "<filesystem section>"},
|
|
||||||
{"type": "text", "text": "<skills section>"},
|
|
||||||
]
|
|
||||||
assert (
|
|
||||||
_flatten_text_blocks(blocks)
|
|
||||||
== "<surfsense base>\n\n<filesystem section>\n\n<skills section>"
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_handles_single_text_block(self) -> None:
|
|
||||||
blocks = [{"type": "text", "text": "only one"}]
|
|
||||||
assert _flatten_text_blocks(blocks) == "only one"
|
|
||||||
|
|
||||||
def test_handles_empty_list(self) -> None:
|
|
||||||
assert _flatten_text_blocks([]) == ""
|
|
||||||
|
|
||||||
def test_passes_through_bare_string_blocks(self) -> None:
|
|
||||||
# LangChain content can mix bare strings and dict blocks.
|
|
||||||
blocks = ["raw string", {"type": "text", "text": "dict block"}]
|
|
||||||
assert _flatten_text_blocks(blocks) == "raw string\n\ndict block"
|
|
||||||
|
|
||||||
def test_returns_none_for_image_block(self) -> None:
|
|
||||||
# System messages with images are rare — but we never want to
|
|
||||||
# silently lose the image payload by joining as text.
|
|
||||||
blocks = [
|
|
||||||
{"type": "text", "text": "look at this"},
|
|
||||||
{"type": "image_url", "image_url": {"url": "data:image/png..."}},
|
|
||||||
]
|
|
||||||
assert _flatten_text_blocks(blocks) is None
|
|
||||||
|
|
||||||
def test_returns_none_for_non_dict_non_str_block(self) -> None:
|
|
||||||
blocks = [{"type": "text", "text": "hi"}, 42] # type: ignore[list-item]
|
|
||||||
assert _flatten_text_blocks(blocks) is None
|
|
||||||
|
|
||||||
def test_returns_none_when_text_field_missing(self) -> None:
|
|
||||||
blocks = [{"type": "text"}] # no ``text`` key
|
|
||||||
assert _flatten_text_blocks(blocks) is None
|
|
||||||
|
|
||||||
def test_returns_none_when_text_is_not_string(self) -> None:
|
|
||||||
blocks = [{"type": "text", "text": ["nested", "list"]}]
|
|
||||||
assert _flatten_text_blocks(blocks) is None
|
|
||||||
|
|
||||||
def test_drops_cache_control_from_inner_blocks(self) -> None:
|
|
||||||
# The whole point: existing cache_control on inner blocks is
|
|
||||||
# discarded so LiteLLM's ``cache_control_injection_points`` can
|
|
||||||
# re-attach exactly one breakpoint after flattening.
|
|
||||||
blocks = [
|
|
||||||
{"type": "text", "text": "first"},
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": "second",
|
|
||||||
"cache_control": {"type": "ephemeral"},
|
|
||||||
},
|
|
||||||
]
|
|
||||||
flattened = _flatten_text_blocks(blocks)
|
|
||||||
assert flattened == "first\n\nsecond"
|
|
||||||
assert "cache_control" not in flattened # type: ignore[operator]
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# _flattened_request — decides when to override and when to no-op.
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _make_request(system_message: SystemMessage | None) -> Any:
|
|
||||||
"""Build a minimal ModelRequest stub. We only need .system_message
|
|
||||||
and .override(system_message=...) — the middleware never touches
|
|
||||||
other fields.
|
|
||||||
"""
|
|
||||||
request = MagicMock()
|
|
||||||
request.system_message = system_message
|
|
||||||
|
|
||||||
def override(**kwargs: Any) -> Any:
|
|
||||||
new_request = MagicMock()
|
|
||||||
new_request.system_message = kwargs.get(
|
|
||||||
"system_message", request.system_message
|
|
||||||
)
|
|
||||||
new_request.messages = kwargs.get("messages", getattr(request, "messages", []))
|
|
||||||
new_request.tools = kwargs.get("tools", getattr(request, "tools", []))
|
|
||||||
return new_request
|
|
||||||
|
|
||||||
request.override = override
|
|
||||||
return request
|
|
||||||
|
|
||||||
|
|
||||||
class TestFlattenedRequest:
|
|
||||||
def test_collapses_multi_block_system_to_string(self) -> None:
|
|
||||||
sys = SystemMessage(
|
|
||||||
content=[
|
|
||||||
{"type": "text", "text": "<base>"},
|
|
||||||
{"type": "text", "text": "<todo>"},
|
|
||||||
{"type": "text", "text": "<filesystem>"},
|
|
||||||
{"type": "text", "text": "<skills>"},
|
|
||||||
{"type": "text", "text": "<subagents>"},
|
|
||||||
]
|
|
||||||
)
|
|
||||||
request = _make_request(sys)
|
|
||||||
flattened = _flattened_request(request)
|
|
||||||
|
|
||||||
assert flattened is not None
|
|
||||||
assert isinstance(flattened.system_message, SystemMessage)
|
|
||||||
assert flattened.system_message.content == (
|
|
||||||
"<base>\n\n<todo>\n\n<filesystem>\n\n<skills>\n\n<subagents>"
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_no_op_for_string_content(self) -> None:
|
|
||||||
sys = SystemMessage(content="already a string")
|
|
||||||
request = _make_request(sys)
|
|
||||||
assert _flattened_request(request) is None
|
|
||||||
|
|
||||||
def test_no_op_for_single_block_list(self) -> None:
|
|
||||||
# One block already produces one breakpoint — no need to flatten.
|
|
||||||
sys = SystemMessage(content=[{"type": "text", "text": "single"}])
|
|
||||||
request = _make_request(sys)
|
|
||||||
assert _flattened_request(request) is None
|
|
||||||
|
|
||||||
def test_no_op_when_system_message_missing(self) -> None:
|
|
||||||
request = _make_request(None)
|
|
||||||
assert _flattened_request(request) is None
|
|
||||||
|
|
||||||
def test_no_op_when_list_contains_non_text_block(self) -> None:
|
|
||||||
sys = SystemMessage(
|
|
||||||
content=[
|
|
||||||
{"type": "text", "text": "look"},
|
|
||||||
{"type": "image_url", "image_url": {"url": "data:..."}},
|
|
||||||
]
|
|
||||||
)
|
|
||||||
request = _make_request(sys)
|
|
||||||
assert _flattened_request(request) is None
|
|
||||||
|
|
||||||
def test_preserves_additional_kwargs_and_metadata(self) -> None:
|
|
||||||
# Defensive: nothing in the current chain sets these on a system
|
|
||||||
# message, but losing them silently when something does in the
|
|
||||||
# future would be a regression. ``name`` in particular is the only
|
|
||||||
# ``additional_kwargs`` field that ChatLiteLLM's
|
|
||||||
# ``_convert_message_to_dict`` propagates onto the wire.
|
|
||||||
sys = SystemMessage(
|
|
||||||
content=[
|
|
||||||
{"type": "text", "text": "a"},
|
|
||||||
{"type": "text", "text": "b"},
|
|
||||||
],
|
|
||||||
additional_kwargs={"name": "surfsense_system", "x": 1},
|
|
||||||
response_metadata={"tokens": 42},
|
|
||||||
)
|
|
||||||
sys.id = "sys-msg-1"
|
|
||||||
request = _make_request(sys)
|
|
||||||
|
|
||||||
flattened = _flattened_request(request)
|
|
||||||
assert flattened is not None
|
|
||||||
assert flattened.system_message.content == "a\n\nb"
|
|
||||||
assert flattened.system_message.additional_kwargs == {
|
|
||||||
"name": "surfsense_system",
|
|
||||||
"x": 1,
|
|
||||||
}
|
|
||||||
assert flattened.system_message.response_metadata == {"tokens": 42}
|
|
||||||
assert flattened.system_message.id == "sys-msg-1"
|
|
||||||
|
|
||||||
def test_idempotent_when_run_twice(self) -> None:
|
|
||||||
sys = SystemMessage(
|
|
||||||
content=[
|
|
||||||
{"type": "text", "text": "a"},
|
|
||||||
{"type": "text", "text": "b"},
|
|
||||||
]
|
|
||||||
)
|
|
||||||
request = _make_request(sys)
|
|
||||||
first = _flattened_request(request)
|
|
||||||
assert first is not None
|
|
||||||
|
|
||||||
# Second pass on the already-flattened request should be a no-op.
|
|
||||||
# We re-wrap in a request stub since the helper inspects
|
|
||||||
# ``request.system_message.content``.
|
|
||||||
second_request = _make_request(first.system_message)
|
|
||||||
assert _flattened_request(second_request) is None
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Middleware integration — verify the handler sees a flattened request.
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestMiddlewareWrap:
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_async_passes_flattened_request_to_handler(self) -> None:
|
|
||||||
sys = SystemMessage(
|
|
||||||
content=[
|
|
||||||
{"type": "text", "text": "alpha"},
|
|
||||||
{"type": "text", "text": "beta"},
|
|
||||||
]
|
|
||||||
)
|
|
||||||
request = _make_request(sys)
|
|
||||||
captured: dict[str, Any] = {}
|
|
||||||
|
|
||||||
async def handler(req: Any) -> str:
|
|
||||||
captured["request"] = req
|
|
||||||
return "ok"
|
|
||||||
|
|
||||||
mw = FlattenSystemMessageMiddleware()
|
|
||||||
result = await mw.awrap_model_call(request, handler)
|
|
||||||
|
|
||||||
assert result == "ok"
|
|
||||||
assert isinstance(captured["request"].system_message, SystemMessage)
|
|
||||||
assert captured["request"].system_message.content == "alpha\n\nbeta"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_async_passes_through_when_already_string(self) -> None:
|
|
||||||
sys = SystemMessage(content="just a string")
|
|
||||||
request = _make_request(sys)
|
|
||||||
captured: dict[str, Any] = {}
|
|
||||||
|
|
||||||
async def handler(req: Any) -> str:
|
|
||||||
captured["request"] = req
|
|
||||||
return "ok"
|
|
||||||
|
|
||||||
mw = FlattenSystemMessageMiddleware()
|
|
||||||
await mw.awrap_model_call(request, handler)
|
|
||||||
|
|
||||||
# Same request object: no override happened.
|
|
||||||
assert captured["request"] is request
|
|
||||||
|
|
||||||
def test_sync_passes_flattened_request_to_handler(self) -> None:
|
|
||||||
sys = SystemMessage(
|
|
||||||
content=[
|
|
||||||
{"type": "text", "text": "alpha"},
|
|
||||||
{"type": "text", "text": "beta"},
|
|
||||||
]
|
|
||||||
)
|
|
||||||
request = _make_request(sys)
|
|
||||||
captured: dict[str, Any] = {}
|
|
||||||
|
|
||||||
def handler(req: Any) -> str:
|
|
||||||
captured["request"] = req
|
|
||||||
return "ok"
|
|
||||||
|
|
||||||
mw = FlattenSystemMessageMiddleware()
|
|
||||||
result = mw.wrap_model_call(request, handler)
|
|
||||||
|
|
||||||
assert result == "ok"
|
|
||||||
assert captured["request"].system_message.content == "alpha\n\nbeta"
|
|
||||||
|
|
||||||
def test_sync_passes_through_when_no_system_message(self) -> None:
|
|
||||||
request = _make_request(None)
|
|
||||||
captured: dict[str, Any] = {}
|
|
||||||
|
|
||||||
def handler(req: Any) -> str:
|
|
||||||
captured["request"] = req
|
|
||||||
return "ok"
|
|
||||||
|
|
||||||
mw = FlattenSystemMessageMiddleware()
|
|
||||||
mw.wrap_model_call(request, handler)
|
|
||||||
assert captured["request"] is request
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Regression guard — pin the worst-case shape that triggered the
|
|
||||||
# "Found 5" 400 in production. Confirms we collapse 5 blocks to 1 so the
|
|
||||||
# downstream cache_control_injection_points can only place 1 breakpoint
|
|
||||||
# on the system message regardless of provider redistribution quirks.
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def test_regression_five_block_system_collapses_to_one_block() -> None:
|
|
||||||
sys = SystemMessage(
|
|
||||||
content=[
|
|
||||||
{"type": "text", "text": "<surfsense base + BASE_AGENT_PROMPT>"},
|
|
||||||
{"type": "text", "text": "<TodoListMiddleware section>"},
|
|
||||||
{"type": "text", "text": "<SurfSenseFilesystemMiddleware section>"},
|
|
||||||
{"type": "text", "text": "<SkillsMiddleware section>"},
|
|
||||||
{"type": "text", "text": "<SubAgentMiddleware section>"},
|
|
||||||
]
|
|
||||||
)
|
|
||||||
request = _make_request(sys)
|
|
||||||
flattened = _flattened_request(request)
|
|
||||||
|
|
||||||
assert flattened is not None
|
|
||||||
assert isinstance(flattened.system_message.content, str)
|
|
||||||
# The exact join doesn't matter for the cache_control accounting —
|
|
||||||
# only that there is exactly ONE content block when LiteLLM's
|
|
||||||
# AnthropicCacheControlHook later targets ``role: system``.
|
|
||||||
assert "<surfsense base" in flattened.system_message.content
|
|
||||||
assert "<SubAgentMiddleware" in flattened.system_message.content
|
|
||||||
|
|
||||||
|
|
||||||
def test_regression_human_message_not_modified() -> None:
|
|
||||||
# Sanity: the middleware MUST NOT touch user messages — only the
|
|
||||||
# system message. Multi-block user content is the path that carries
|
|
||||||
# image attachments and would lose its image_url block on
|
|
||||||
# accidental flatten.
|
|
||||||
sys = SystemMessage(
|
|
||||||
content=[
|
|
||||||
{"type": "text", "text": "a"},
|
|
||||||
{"type": "text", "text": "b"},
|
|
||||||
]
|
|
||||||
)
|
|
||||||
user = HumanMessage(
|
|
||||||
content=[
|
|
||||||
{"type": "text", "text": "look at this"},
|
|
||||||
{"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
|
|
||||||
]
|
|
||||||
)
|
|
||||||
request = _make_request(sys)
|
|
||||||
request.messages = [user]
|
|
||||||
|
|
||||||
flattened = _flattened_request(request)
|
|
||||||
assert flattened is not None
|
|
||||||
# System flattened to string …
|
|
||||||
assert isinstance(flattened.system_message.content, str)
|
|
||||||
# … user message is untouched (the helper does not even look at it).
|
|
||||||
assert flattened.messages == [user]
|
|
||||||
assert isinstance(user.content, list)
|
|
||||||
assert len(user.content) == 2
|
|
||||||
|
|
@ -1,214 +0,0 @@
|
||||||
import pytest
|
|
||||||
from langchain_core.messages import AIMessage, HumanMessage
|
|
||||||
|
|
||||||
from app.agents.shared.middleware.file_intent import (
|
|
||||||
FileIntentMiddleware,
|
|
||||||
FileOperationIntent,
|
|
||||||
_fallback_path,
|
|
||||||
)
|
|
||||||
|
|
||||||
pytestmark = pytest.mark.unit
|
|
||||||
|
|
||||||
|
|
||||||
class _FakeLLM:
|
|
||||||
def __init__(self, response_text: str):
|
|
||||||
self._response_text = response_text
|
|
||||||
|
|
||||||
async def ainvoke(self, *_args, **_kwargs):
|
|
||||||
return AIMessage(content=self._response_text)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_file_write_intent_injects_contract_message():
|
|
||||||
llm = _FakeLLM(
|
|
||||||
'{"intent":"file_write","confidence":0.93,"suggested_filename":"ideas.md"}'
|
|
||||||
)
|
|
||||||
middleware = FileIntentMiddleware(llm=llm)
|
|
||||||
state = {
|
|
||||||
"messages": [HumanMessage(content="Create another random note for me")],
|
|
||||||
"turn_id": "123:456",
|
|
||||||
}
|
|
||||||
|
|
||||||
result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type]
|
|
||||||
|
|
||||||
assert result is not None
|
|
||||||
contract = result["file_operation_contract"]
|
|
||||||
assert contract["intent"] == FileOperationIntent.FILE_WRITE.value
|
|
||||||
assert contract["suggested_path"] == "/ideas.md"
|
|
||||||
assert contract["turn_id"] == "123:456"
|
|
||||||
assert any(
|
|
||||||
"file_operation_contract" in str(msg.content)
|
|
||||||
for msg in result["messages"]
|
|
||||||
if hasattr(msg, "content")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_non_write_intent_does_not_inject_contract_message():
|
|
||||||
llm = _FakeLLM('{"intent":"file_read","confidence":0.88,"suggested_filename":null}')
|
|
||||||
middleware = FileIntentMiddleware(llm=llm)
|
|
||||||
original_messages = [HumanMessage(content="Read /notes.md")]
|
|
||||||
state = {"messages": original_messages, "turn_id": "abc:def"}
|
|
||||||
|
|
||||||
result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type]
|
|
||||||
|
|
||||||
assert result is not None
|
|
||||||
assert (
|
|
||||||
result["file_operation_contract"]["intent"]
|
|
||||||
== FileOperationIntent.FILE_READ.value
|
|
||||||
)
|
|
||||||
assert "messages" not in result
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_file_write_null_filename_uses_semantic_default_path():
|
|
||||||
llm = _FakeLLM(
|
|
||||||
'{"intent":"file_write","confidence":0.74,"suggested_filename":null}'
|
|
||||||
)
|
|
||||||
middleware = FileIntentMiddleware(llm=llm)
|
|
||||||
state = {
|
|
||||||
"messages": [HumanMessage(content="create a random markdown file")],
|
|
||||||
"turn_id": "turn:1",
|
|
||||||
}
|
|
||||||
|
|
||||||
result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type]
|
|
||||||
|
|
||||||
assert result is not None
|
|
||||||
contract = result["file_operation_contract"]
|
|
||||||
assert contract["intent"] == FileOperationIntent.FILE_WRITE.value
|
|
||||||
assert contract["suggested_path"] == "/notes.md"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_file_write_null_filename_defaults_to_markdown_path():
|
|
||||||
llm = _FakeLLM(
|
|
||||||
'{"intent":"file_write","confidence":0.71,"suggested_filename":null}'
|
|
||||||
)
|
|
||||||
middleware = FileIntentMiddleware(llm=llm)
|
|
||||||
state = {
|
|
||||||
"messages": [HumanMessage(content="create a sample json config file")],
|
|
||||||
"turn_id": "turn:2",
|
|
||||||
}
|
|
||||||
|
|
||||||
result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type]
|
|
||||||
|
|
||||||
assert result is not None
|
|
||||||
contract = result["file_operation_contract"]
|
|
||||||
assert contract["intent"] == FileOperationIntent.FILE_WRITE.value
|
|
||||||
assert contract["suggested_path"] == "/notes.md"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_file_write_txt_suggestion_is_normalized_to_markdown():
|
|
||||||
llm = _FakeLLM(
|
|
||||||
'{"intent":"file_write","confidence":0.82,"suggested_filename":"random.txt"}'
|
|
||||||
)
|
|
||||||
middleware = FileIntentMiddleware(llm=llm)
|
|
||||||
state = {
|
|
||||||
"messages": [HumanMessage(content="create a random file")],
|
|
||||||
"turn_id": "turn:3",
|
|
||||||
}
|
|
||||||
|
|
||||||
result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type]
|
|
||||||
|
|
||||||
assert result is not None
|
|
||||||
contract = result["file_operation_contract"]
|
|
||||||
assert contract["intent"] == FileOperationIntent.FILE_WRITE.value
|
|
||||||
assert contract["suggested_path"] == "/random.md"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_file_write_with_suggested_directory_preserves_folder():
|
|
||||||
llm = _FakeLLM(
|
|
||||||
'{"intent":"file_write","confidence":0.86,"suggested_filename":"random.md","suggested_directory":"pc backups","suggested_path":null}'
|
|
||||||
)
|
|
||||||
middleware = FileIntentMiddleware(llm=llm)
|
|
||||||
state = {
|
|
||||||
"messages": [HumanMessage(content="create a random file in pc backups folder")],
|
|
||||||
"turn_id": "turn:4",
|
|
||||||
}
|
|
||||||
|
|
||||||
result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type]
|
|
||||||
|
|
||||||
assert result is not None
|
|
||||||
contract = result["file_operation_contract"]
|
|
||||||
assert contract["intent"] == FileOperationIntent.FILE_WRITE.value
|
|
||||||
assert contract["suggested_path"] == "/pc_backups/random.md"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_file_write_with_suggested_path_takes_precedence():
|
|
||||||
llm = _FakeLLM(
|
|
||||||
'{"intent":"file_write","confidence":0.9,"suggested_filename":"ignored.md","suggested_directory":"docs","suggested_path":"/reports/q2/summary.md"}'
|
|
||||||
)
|
|
||||||
middleware = FileIntentMiddleware(llm=llm)
|
|
||||||
state = {
|
|
||||||
"messages": [HumanMessage(content="create report")],
|
|
||||||
"turn_id": "turn:5",
|
|
||||||
}
|
|
||||||
|
|
||||||
result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type]
|
|
||||||
|
|
||||||
assert result is not None
|
|
||||||
contract = result["file_operation_contract"]
|
|
||||||
assert contract["intent"] == FileOperationIntent.FILE_WRITE.value
|
|
||||||
assert contract["suggested_path"] == "/reports/q2/summary.md"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_file_write_infers_directory_from_user_text_when_missing():
|
|
||||||
llm = _FakeLLM(
|
|
||||||
'{"intent":"file_write","confidence":0.83,"suggested_filename":"random.md","suggested_directory":null,"suggested_path":null}'
|
|
||||||
)
|
|
||||||
middleware = FileIntentMiddleware(llm=llm)
|
|
||||||
state = {
|
|
||||||
"messages": [HumanMessage(content="create a random file in pc backups folder")],
|
|
||||||
"turn_id": "turn:6",
|
|
||||||
}
|
|
||||||
|
|
||||||
result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type]
|
|
||||||
|
|
||||||
assert result is not None
|
|
||||||
contract = result["file_operation_contract"]
|
|
||||||
assert contract["intent"] == FileOperationIntent.FILE_WRITE.value
|
|
||||||
assert contract["suggested_path"] == "/pc_backups/random.md"
|
|
||||||
|
|
||||||
|
|
||||||
def test_fallback_path_normalizes_windows_slashes() -> None:
|
|
||||||
resolved = _fallback_path(
|
|
||||||
suggested_filename="summary.md",
|
|
||||||
suggested_path=r"\reports\q2\summary.md",
|
|
||||||
user_text="create report",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert resolved == "/reports/q2/summary.md"
|
|
||||||
|
|
||||||
|
|
||||||
def test_fallback_path_normalizes_windows_drive_path() -> None:
|
|
||||||
resolved = _fallback_path(
|
|
||||||
suggested_filename=None,
|
|
||||||
suggested_path=r"C:\Users\anish\notes\todo.md",
|
|
||||||
user_text="create note",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert resolved == "/C/Users/anish/notes/todo.md"
|
|
||||||
|
|
||||||
|
|
||||||
def test_fallback_path_normalizes_mixed_separators_and_duplicate_slashes() -> None:
|
|
||||||
resolved = _fallback_path(
|
|
||||||
suggested_filename="summary.md",
|
|
||||||
suggested_path=r"\\reports\\q2//summary.md",
|
|
||||||
user_text="create report",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert resolved == "/reports/q2/summary.md"
|
|
||||||
|
|
||||||
|
|
||||||
def test_fallback_path_keeps_posix_style_absolute_path_for_linux_and_macos() -> None:
|
|
||||||
resolved = _fallback_path(
|
|
||||||
suggested_filename=None,
|
|
||||||
suggested_path="/var/log/surfsense/notes.md",
|
|
||||||
user_text="create note",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert resolved == "/var/log/surfsense/notes.md"
|
|
||||||
|
|
@ -10,7 +10,7 @@ from app.agents.multi_agent_chat.shared.middleware.filesystem.backends.document_
|
||||||
)
|
)
|
||||||
from app.agents.shared.middleware.knowledge_search import (
|
from app.agents.shared.middleware.knowledge_search import (
|
||||||
KBSearchPlan,
|
KBSearchPlan,
|
||||||
KnowledgeBaseSearchMiddleware,
|
KnowledgePriorityMiddleware,
|
||||||
_normalize_optional_date_range,
|
_normalize_optional_date_range,
|
||||||
_parse_kb_search_plan_response,
|
_parse_kb_search_plan_response,
|
||||||
_render_recent_conversation,
|
_render_recent_conversation,
|
||||||
|
|
@ -203,7 +203,7 @@ class FakeBudgetLLM:
|
||||||
return sum(len(msg.get("content", "")) for msg in messages)
|
return sum(len(msg.get("content", "")) for msg in messages)
|
||||||
|
|
||||||
|
|
||||||
class TestKnowledgeBaseSearchMiddlewarePlanner:
|
class TestKnowledgePriorityMiddlewarePlanner:
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def _disable_planner_runnable(self, monkeypatch):
|
def _disable_planner_runnable(self, monkeypatch):
|
||||||
# ``FakeLLM`` is a duck-typed mock; ``create_agent`` (used when the
|
# ``FakeLLM`` is a duck-typed mock; ``create_agent`` (used when the
|
||||||
|
|
@ -273,7 +273,7 @@ class TestKnowledgeBaseSearchMiddlewarePlanner:
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
middleware = KnowledgeBaseSearchMiddleware(llm=llm, search_space_id=37)
|
middleware = KnowledgePriorityMiddleware(llm=llm, search_space_id=37)
|
||||||
|
|
||||||
result = await middleware.abefore_agent(
|
result = await middleware.abefore_agent(
|
||||||
{
|
{
|
||||||
|
|
@ -307,7 +307,7 @@ class TestKnowledgeBaseSearchMiddlewarePlanner:
|
||||||
fake_search_knowledge_base,
|
fake_search_knowledge_base,
|
||||||
)
|
)
|
||||||
|
|
||||||
middleware = KnowledgeBaseSearchMiddleware(
|
middleware = KnowledgePriorityMiddleware(
|
||||||
llm=FakeLLM("not json"),
|
llm=FakeLLM("not json"),
|
||||||
search_space_id=37,
|
search_space_id=37,
|
||||||
)
|
)
|
||||||
|
|
@ -336,7 +336,7 @@ class TestKnowledgeBaseSearchMiddlewarePlanner:
|
||||||
fake_search_knowledge_base,
|
fake_search_knowledge_base,
|
||||||
)
|
)
|
||||||
|
|
||||||
middleware = KnowledgeBaseSearchMiddleware(
|
middleware = KnowledgePriorityMiddleware(
|
||||||
llm=FakeLLM(
|
llm=FakeLLM(
|
||||||
json.dumps(
|
json.dumps(
|
||||||
{
|
{
|
||||||
|
|
@ -395,7 +395,7 @@ class TestKnowledgeBaseSearchMiddlewarePlanner:
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
middleware = KnowledgeBaseSearchMiddleware(llm=llm, search_space_id=42)
|
middleware = KnowledgePriorityMiddleware(llm=llm, search_space_id=42)
|
||||||
|
|
||||||
result = await middleware.abefore_agent(
|
result = await middleware.abefore_agent(
|
||||||
{"messages": [HumanMessage(content="what's my latest file?")]},
|
{"messages": [HumanMessage(content="what's my latest file?")]},
|
||||||
|
|
@ -442,7 +442,7 @@ class TestKnowledgeBaseSearchMiddlewarePlanner:
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
middleware = KnowledgeBaseSearchMiddleware(llm=llm, search_space_id=42)
|
middleware = KnowledgePriorityMiddleware(llm=llm, search_space_id=42)
|
||||||
|
|
||||||
await middleware.abefore_agent(
|
await middleware.abefore_agent(
|
||||||
{"messages": [HumanMessage(content="find the quarterly revenue report")]},
|
{"messages": [HumanMessage(content="find the quarterly revenue report")]},
|
||||||
|
|
@ -559,7 +559,7 @@ class TestKnowledgePriorityMentionDrain:
|
||||||
fake_search_knowledge_base,
|
fake_search_knowledge_base,
|
||||||
)
|
)
|
||||||
|
|
||||||
middleware = KnowledgeBaseSearchMiddleware(
|
middleware = KnowledgePriorityMiddleware(
|
||||||
llm=self._planner_llm(),
|
llm=self._planner_llm(),
|
||||||
search_space_id=42,
|
search_space_id=42,
|
||||||
mentioned_document_ids=[1, 2, 3],
|
mentioned_document_ids=[1, 2, 3],
|
||||||
|
|
@ -609,7 +609,7 @@ class TestKnowledgePriorityMentionDrain:
|
||||||
|
|
||||||
# Simulate a cached middleware instance whose closure was seeded
|
# Simulate a cached middleware instance whose closure was seeded
|
||||||
# by a previous turn's cache-miss build (mentions=[1,2,3]).
|
# by a previous turn's cache-miss build (mentions=[1,2,3]).
|
||||||
middleware = KnowledgeBaseSearchMiddleware(
|
middleware = KnowledgePriorityMiddleware(
|
||||||
llm=self._planner_llm(),
|
llm=self._planner_llm(),
|
||||||
search_space_id=42,
|
search_space_id=42,
|
||||||
mentioned_document_ids=[1, 2, 3],
|
mentioned_document_ids=[1, 2, 3],
|
||||||
|
|
@ -652,7 +652,7 @@ class TestKnowledgePriorityMentionDrain:
|
||||||
fake_search_knowledge_base,
|
fake_search_knowledge_base,
|
||||||
)
|
)
|
||||||
|
|
||||||
middleware = KnowledgeBaseSearchMiddleware(
|
middleware = KnowledgePriorityMiddleware(
|
||||||
llm=self._planner_llm(),
|
llm=self._planner_llm(),
|
||||||
search_space_id=42,
|
search_space_id=42,
|
||||||
mentioned_document_ids=[7, 8],
|
mentioned_document_ids=[7, 8],
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue