mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-06 20:15:17 +02:00
refactor(agents): delete dead single-agent-only middleware
file_intent (FileIntentMiddleware) and flatten_system
(FlattenSystemMessageMiddleware) were only ever instantiated in the
single-agent chat_deepagent stack, which was removed in 14bbea085. They
have no production consumer in multi_agent_chat. Delete both modules and
their unit tests.
Also drop the vestigial KnowledgeBaseSearchMiddleware alias (= the live
KnowledgePriorityMiddleware); its tests now target the real class so the
behavior coverage is preserved. Trim the three barrel/__all__ entries and
strip the now-dead class names from comments.
This commit is contained in:
parent
21509e7eca
commit
afa51e97cf
10 changed files with 19 additions and 1161 deletions
|
|
@ -50,8 +50,8 @@ class SurfSenseContextSchema:
|
|||
(cloud filesystem mode). Surfaced as ``[USER-MENTIONED]``
|
||||
entries in ``<priority_documents>`` so the agent prioritises
|
||||
walking those folders with ``ls`` / ``find_documents``.
|
||||
file_operation_contract: One-shot file operation contract emitted
|
||||
by ``FileIntentMiddleware`` for the upcoming turn.
|
||||
file_operation_contract: One-shot file operation contract for the
|
||||
upcoming turn (reserved; not currently populated).
|
||||
turn_id / request_id: Correlation IDs surfaced by the streaming
|
||||
task; populated for telemetry.
|
||||
|
||||
|
|
|
|||
|
|
@ -21,18 +21,11 @@ from app.agents.shared.middleware.dedup_tool_calls import (
|
|||
DedupHITLToolCallsMiddleware,
|
||||
)
|
||||
from app.agents.shared.middleware.doom_loop import DoomLoopMiddleware
|
||||
from app.agents.shared.middleware.file_intent import (
|
||||
FileIntentMiddleware,
|
||||
)
|
||||
from app.agents.shared.middleware.flatten_system import (
|
||||
FlattenSystemMessageMiddleware,
|
||||
)
|
||||
from app.agents.shared.middleware.kb_persistence import (
|
||||
KnowledgeBasePersistenceMiddleware,
|
||||
commit_staged_filesystem_state,
|
||||
)
|
||||
from app.agents.shared.middleware.knowledge_search import (
|
||||
KnowledgeBaseSearchMiddleware,
|
||||
KnowledgePriorityMiddleware,
|
||||
)
|
||||
from app.agents.shared.middleware.knowledge_tree import (
|
||||
|
|
@ -56,10 +49,7 @@ __all__ = [
|
|||
"ClearToolUsesEdit",
|
||||
"DedupHITLToolCallsMiddleware",
|
||||
"DoomLoopMiddleware",
|
||||
"FileIntentMiddleware",
|
||||
"FlattenSystemMessageMiddleware",
|
||||
"KnowledgeBasePersistenceMiddleware",
|
||||
"KnowledgeBaseSearchMiddleware",
|
||||
"KnowledgePriorityMiddleware",
|
||||
"KnowledgeTreeMiddleware",
|
||||
"MemoryInjectionMiddleware",
|
||||
|
|
|
|||
|
|
@ -94,7 +94,7 @@ Respond ONLY with the structured summary. Do not include any text before or afte
|
|||
PROTECTED_SYSTEM_PREFIXES: tuple[str, ...] = (
|
||||
"<priority_documents>", # KnowledgePriorityMiddleware
|
||||
"<workspace_tree>", # KnowledgeTreeMiddleware
|
||||
"<file_operation_contract>", # FileIntentMiddleware
|
||||
"<file_operation_contract>", # reserved file-operation contract prefix
|
||||
"<user_memory>", # MemoryInjectionMiddleware
|
||||
"<team_memory>", # MemoryInjectionMiddleware
|
||||
"<user_name>", # MemoryInjectionMiddleware
|
||||
|
|
|
|||
|
|
@ -1,334 +0,0 @@
|
|||
"""Semantic file-intent routing middleware for new chat turns.
|
||||
|
||||
This middleware classifies the latest human turn into a small intent set:
|
||||
- chat_only
|
||||
- file_write
|
||||
- file_read
|
||||
|
||||
For ``file_write`` turns it injects a strict system contract so the model
|
||||
uses filesystem tools before claiming success, and provides a deterministic
|
||||
fallback path when no filename is specified by the user.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from datetime import UTC, datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
from langgraph.runtime import Runtime
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FileOperationIntent(StrEnum):
|
||||
CHAT_ONLY = "chat_only"
|
||||
FILE_WRITE = "file_write"
|
||||
FILE_READ = "file_read"
|
||||
|
||||
|
||||
class FileIntentPlan(BaseModel):
|
||||
intent: FileOperationIntent = Field(
|
||||
description="Primary user intent for this turn."
|
||||
)
|
||||
confidence: float = Field(
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
default=0.5,
|
||||
description="Model confidence in the selected intent.",
|
||||
)
|
||||
suggested_filename: str | None = Field(
|
||||
default=None,
|
||||
description="Optional filename (e.g. notes.md) inferred from user request.",
|
||||
)
|
||||
suggested_directory: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Optional directory path (e.g. /reports/q2 or reports/q2) inferred from "
|
||||
"user request."
|
||||
),
|
||||
)
|
||||
suggested_path: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Optional full file path (e.g. /reports/q2/summary.md). If present, this "
|
||||
"takes precedence over suggested_directory + suggested_filename."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _extract_text_from_message(message: BaseMessage) -> str:
|
||||
content = getattr(message, "content", "")
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts: list[str] = []
|
||||
for item in content:
|
||||
if isinstance(item, str):
|
||||
parts.append(item)
|
||||
elif isinstance(item, dict) and item.get("type") == "text":
|
||||
parts.append(str(item.get("text", "")))
|
||||
return "\n".join(part for part in parts if part)
|
||||
return str(content)
|
||||
|
||||
|
||||
def _extract_json_payload(text: str) -> str:
|
||||
stripped = text.strip()
|
||||
fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", stripped, re.DOTALL)
|
||||
if fenced:
|
||||
return fenced.group(1)
|
||||
start = stripped.find("{")
|
||||
end = stripped.rfind("}")
|
||||
if start != -1 and end != -1 and end > start:
|
||||
return stripped[start : end + 1]
|
||||
return stripped
|
||||
|
||||
|
||||
def _sanitize_filename(value: str) -> str:
|
||||
name = re.sub(r"[\\/:*?\"<>|]+", "_", value).strip()
|
||||
name = re.sub(r"\s+", "-", name)
|
||||
name = name.strip("._-")
|
||||
if not name:
|
||||
name = "note"
|
||||
if len(name) > 80:
|
||||
name = name[:80].rstrip("-_.")
|
||||
return name
|
||||
|
||||
|
||||
def _sanitize_path_segment(value: str) -> str:
|
||||
segment = re.sub(r"[\\/:*?\"<>|]+", "_", value).strip()
|
||||
segment = re.sub(r"\s+", "_", segment)
|
||||
segment = segment.strip("._-")
|
||||
return segment
|
||||
|
||||
|
||||
def _normalize_directory(value: str) -> str:
|
||||
raw = value.strip().replace("\\", "/")
|
||||
raw = raw.strip("/")
|
||||
if not raw:
|
||||
return ""
|
||||
parts = [_sanitize_path_segment(part) for part in raw.split("/") if part.strip()]
|
||||
parts = [part for part in parts if part]
|
||||
return "/".join(parts)
|
||||
|
||||
|
||||
def _normalize_file_path(value: str) -> str:
|
||||
raw = value.strip().replace("\\", "/").strip()
|
||||
if not raw:
|
||||
return ""
|
||||
had_trailing_slash = raw.endswith("/")
|
||||
raw = raw.strip("/")
|
||||
if not raw:
|
||||
return ""
|
||||
parts = [_sanitize_path_segment(part) for part in raw.split("/") if part.strip()]
|
||||
parts = [part for part in parts if part]
|
||||
if not parts:
|
||||
return ""
|
||||
if had_trailing_slash:
|
||||
return f"/{'/'.join(parts)}/"
|
||||
return f"/{'/'.join(parts)}"
|
||||
|
||||
|
||||
def _infer_directory_from_user_text(user_text: str) -> str | None:
|
||||
patterns = (
|
||||
r"\b(?:in|inside|under)\s+(?:the\s+)?([a-zA-Z0-9 _\-/]+?)\s+folder\b",
|
||||
r"\b(?:in|inside|under)\s+([a-zA-Z0-9 _\-/]+?)\b",
|
||||
)
|
||||
lowered = user_text.lower()
|
||||
for pattern in patterns:
|
||||
match = re.search(pattern, lowered, flags=re.IGNORECASE)
|
||||
if not match:
|
||||
continue
|
||||
candidate = match.group(1).strip()
|
||||
if candidate in {"the", "a", "an"}:
|
||||
continue
|
||||
normalized = _normalize_directory(candidate)
|
||||
if normalized:
|
||||
return normalized
|
||||
return None
|
||||
|
||||
|
||||
def _fallback_path(
|
||||
suggested_filename: str | None,
|
||||
*,
|
||||
suggested_directory: str | None = None,
|
||||
suggested_path: str | None = None,
|
||||
user_text: str,
|
||||
) -> str:
|
||||
inferred_dir = _infer_directory_from_user_text(user_text)
|
||||
|
||||
sanitized_filename = ""
|
||||
if suggested_filename:
|
||||
sanitized_filename = _sanitize_filename(suggested_filename)
|
||||
if sanitized_filename.lower().endswith(".txt"):
|
||||
sanitized_filename = f"{sanitized_filename[:-4]}.md"
|
||||
if not sanitized_filename:
|
||||
sanitized_filename = "notes.md"
|
||||
elif "." not in sanitized_filename:
|
||||
sanitized_filename = f"{sanitized_filename}.md"
|
||||
|
||||
normalized_suggested_path = (
|
||||
_normalize_file_path(suggested_path) if suggested_path else ""
|
||||
)
|
||||
if normalized_suggested_path:
|
||||
if normalized_suggested_path.endswith("/"):
|
||||
return f"{normalized_suggested_path.rstrip('/')}/{sanitized_filename}"
|
||||
return normalized_suggested_path
|
||||
|
||||
directory = _normalize_directory(suggested_directory or "")
|
||||
if not directory and inferred_dir:
|
||||
directory = inferred_dir
|
||||
if directory:
|
||||
return f"/{directory}/{sanitized_filename}"
|
||||
|
||||
return f"/{sanitized_filename}"
|
||||
|
||||
|
||||
def _build_classifier_prompt(*, recent_conversation: str, user_text: str) -> str:
|
||||
return (
|
||||
"Classify the latest user request into a filesystem intent for an AI agent.\n"
|
||||
"Return JSON only with this exact schema:\n"
|
||||
'{"intent":"chat_only|file_write|file_read","confidence":0.0,"suggested_filename":"string or null","suggested_directory":"string or null","suggested_path":"string or null"}\n\n'
|
||||
"Rules:\n"
|
||||
"- Use semantic intent, not literal keywords.\n"
|
||||
"- file_write: user asks to create/save/write/update/edit content as a file.\n"
|
||||
"- file_read: user asks to open/read/list/search existing files.\n"
|
||||
"- chat_only: conversational/analysis responses without required file operations.\n"
|
||||
"- For file_write, choose a concise semantic suggested_filename and match the requested format.\n"
|
||||
"- If the user mentions a folder/directory, populate suggested_directory.\n"
|
||||
"- If user specifies an explicit full path, populate suggested_path.\n"
|
||||
"- Use extensions that match user intent (e.g. .md, .json, .yaml, .csv, .py, .ts, .js, .html, .css, .sql).\n"
|
||||
"- Do not use .txt; prefer .md for generic text notes.\n"
|
||||
"- Do not include dates or timestamps in suggested_filename unless explicitly requested.\n"
|
||||
"- Never include markdown or explanation.\n\n"
|
||||
f"Recent conversation:\n{recent_conversation or '(none)'}\n\n"
|
||||
f"Latest user message:\n{user_text}"
|
||||
)
|
||||
|
||||
|
||||
def _build_recent_conversation(
|
||||
messages: list[BaseMessage], *, max_messages: int = 6
|
||||
) -> str:
|
||||
rows: list[str] = []
|
||||
filtered: list[tuple[str, BaseMessage]] = []
|
||||
for msg in messages:
|
||||
role: str | None = None
|
||||
if isinstance(msg, HumanMessage):
|
||||
role = "user"
|
||||
elif isinstance(msg, AIMessage):
|
||||
if getattr(msg, "tool_calls", None):
|
||||
continue
|
||||
role = "assistant"
|
||||
else:
|
||||
continue
|
||||
filtered.append((role, msg))
|
||||
for role, msg in filtered[-max_messages:]:
|
||||
text = re.sub(r"\s+", " ", _extract_text_from_message(msg)).strip()
|
||||
if text:
|
||||
rows.append(f"{role}: {text[:280]}")
|
||||
return "\n".join(rows)
|
||||
|
||||
|
||||
class FileIntentMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||
"""Classify file intent and inject a strict file-write contract."""
|
||||
|
||||
tools = ()
|
||||
|
||||
def __init__(self, *, llm: BaseChatModel | None = None) -> None:
|
||||
self.llm = llm
|
||||
|
||||
async def _classify_intent(
|
||||
self, *, messages: list[BaseMessage], user_text: str
|
||||
) -> FileIntentPlan:
|
||||
if self.llm is None:
|
||||
return FileIntentPlan(intent=FileOperationIntent.CHAT_ONLY, confidence=0.0)
|
||||
|
||||
prompt = _build_classifier_prompt(
|
||||
recent_conversation=_build_recent_conversation(messages),
|
||||
user_text=user_text,
|
||||
)
|
||||
try:
|
||||
response = await self.llm.ainvoke(
|
||||
[HumanMessage(content=prompt)],
|
||||
config={"tags": ["surfsense:internal"]},
|
||||
)
|
||||
payload = json.loads(
|
||||
_extract_json_payload(_extract_text_from_message(response))
|
||||
)
|
||||
plan = FileIntentPlan.model_validate(payload)
|
||||
return plan
|
||||
except (json.JSONDecodeError, ValidationError, ValueError) as exc:
|
||||
logger.warning("File intent classifier returned invalid output: %s", exc)
|
||||
except Exception as exc: # pragma: no cover - defensive fallback
|
||||
logger.warning("File intent classifier failed: %s", exc)
|
||||
|
||||
return FileIntentPlan(intent=FileOperationIntent.CHAT_ONLY, confidence=0.0)
|
||||
|
||||
async def abefore_agent( # type: ignore[override]
|
||||
self,
|
||||
state: AgentState,
|
||||
runtime: Runtime[Any],
|
||||
) -> dict[str, Any] | None:
|
||||
del runtime
|
||||
messages = state.get("messages") or []
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
last_human: HumanMessage | None = None
|
||||
for msg in reversed(messages):
|
||||
if isinstance(msg, HumanMessage):
|
||||
last_human = msg
|
||||
break
|
||||
if last_human is None:
|
||||
return None
|
||||
|
||||
user_text = _extract_text_from_message(last_human).strip()
|
||||
if not user_text:
|
||||
return None
|
||||
|
||||
plan = await self._classify_intent(messages=messages, user_text=user_text)
|
||||
suggested_path = _fallback_path(
|
||||
plan.suggested_filename,
|
||||
suggested_directory=plan.suggested_directory,
|
||||
suggested_path=plan.suggested_path,
|
||||
user_text=user_text,
|
||||
)
|
||||
contract = {
|
||||
"intent": plan.intent.value,
|
||||
"confidence": plan.confidence,
|
||||
"suggested_path": suggested_path,
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
"turn_id": state.get("turn_id", ""),
|
||||
}
|
||||
|
||||
if plan.intent != FileOperationIntent.FILE_WRITE:
|
||||
return {"file_operation_contract": contract}
|
||||
|
||||
contract_msg = SystemMessage(
|
||||
content=(
|
||||
"<file_operation_contract>\n"
|
||||
"This turn intent is file_write.\n"
|
||||
f"Suggested default path: {suggested_path}\n"
|
||||
"Rules:\n"
|
||||
"- You MUST call write_file or edit_file before claiming success.\n"
|
||||
"- If no path is provided by the user, use the suggested default path.\n"
|
||||
"- Do not claim a file was created/updated unless tool output confirms it.\n"
|
||||
"- If the write/edit fails, clearly report failure instead of success.\n"
|
||||
"- Do not include timestamps or dates in generated file content unless the user explicitly asks for them.\n"
|
||||
"- For open-ended requests (e.g., random note), generate useful concrete content, not placeholders.\n"
|
||||
"</file_operation_contract>"
|
||||
)
|
||||
)
|
||||
|
||||
# Insert just before the latest human turn so it applies to this request.
|
||||
new_messages = list(messages)
|
||||
insert_at = max(len(new_messages) - 1, 0)
|
||||
new_messages.insert(insert_at, contract_msg)
|
||||
return {"messages": new_messages, "file_operation_contract": contract}
|
||||
|
|
@ -1,233 +0,0 @@
|
|||
r"""Coalesce multi-block system messages into a single text block.
|
||||
|
||||
Several middlewares in our deepagent stack each call
|
||||
``append_to_system_message`` on the way down to the model
|
||||
(``TodoListMiddleware``, ``SurfSenseFilesystemMiddleware``,
|
||||
``SkillsMiddleware``, ``SubAgentMiddleware`` …). By the time the
|
||||
request reaches the LLM, the system message has 5+ separate text blocks.
|
||||
|
||||
Anthropic enforces a hard cap of **4 ``cache_control`` blocks per
|
||||
request**, and we configure 2 injection points
|
||||
(``index: 0`` + ``index: -1``). With ``index: 0`` always targeting
|
||||
the prepended ``request.system_message``, this middleware is the
|
||||
defensive partner: it guarantees that "the system block" is *one*
|
||||
content block, so LiteLLM's ``AnthropicCacheControlHook`` and any
|
||||
OpenRouter→Anthropic transformer can never multiply our budget into
|
||||
several breakpoints by spreading ``cache_control`` across multiple
|
||||
text blocks of a multi-block system content.
|
||||
|
||||
Without flattening we used to see::
|
||||
|
||||
OpenrouterException - {"error":{"message":"Provider returned error",
|
||||
"code":400,"metadata":{"raw":"...A maximum of 4 blocks with
|
||||
cache_control may be provided. Found 5."}}}
|
||||
|
||||
(Same error class documented in
|
||||
https://github.com/BerriAI/litellm/issues/15696 and
|
||||
https://github.com/BerriAI/litellm/issues/20485 — the litellm-side fix
|
||||
in PR #15395 covers the litellm transformer but does not protect us
|
||||
when the OpenRouter SaaS itself does the redistribution.)
|
||||
|
||||
A separate fix in :mod:`app.agents.shared.prompt_caching` (switching
|
||||
the first injection point from ``role: system`` to ``index: 0``)
|
||||
neutralises the *primary* cause of the same 400 — multiple
|
||||
``SystemMessage``\ s injected by ``before_agent`` middlewares
|
||||
(priority/tree/memory/file-intent/anonymous-doc) accumulating across
|
||||
turns, each tagged with ``cache_control`` by the ``role: system``
|
||||
matcher. This middleware remains useful as defence-in-depth against
|
||||
the multi-block redistribution path.
|
||||
|
||||
Placement: innermost on the system-message-mutation chain, after every
|
||||
appender (``todo``/``filesystem``/``skills``/``subagents``) and after
|
||||
summarization, but before ``noop``/``retry``/``fallback`` so each retry
|
||||
attempt sees a flattened payload.
|
||||
|
||||
Idempotent: a string-content system message is left untouched. A list
|
||||
that contains anything other than plain text blocks (e.g. an image) is
|
||||
also left untouched — those are rare on system messages and we'd lose
|
||||
the non-text payload by joining.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
ResponseT,
|
||||
)
|
||||
from langchain_core.messages import SystemMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _flatten_text_blocks(content: list[Any]) -> str | None:
|
||||
"""Return joined text if every block is a plain ``{"type": "text"}``.
|
||||
|
||||
Returns ``None`` when the list contains anything that isn't a text
|
||||
block we can safely concatenate (image, audio, file, non-standard
|
||||
blocks, dicts with extra non-cache_control fields). The caller
|
||||
leaves the original content untouched in that case rather than
|
||||
silently dropping payload.
|
||||
|
||||
``cache_control`` on individual blocks is intentionally discarded —
|
||||
the whole point of flattening is to let LiteLLM's
|
||||
``cache_control_injection_points`` re-place a single breakpoint on
|
||||
the resulting one-block system content.
|
||||
"""
|
||||
chunks: list[str] = []
|
||||
for block in content:
|
||||
if isinstance(block, str):
|
||||
chunks.append(block)
|
||||
continue
|
||||
if not isinstance(block, dict):
|
||||
return None
|
||||
if block.get("type") != "text":
|
||||
return None
|
||||
text = block.get("text")
|
||||
if not isinstance(text, str):
|
||||
return None
|
||||
chunks.append(text)
|
||||
return "\n\n".join(chunks)
|
||||
|
||||
|
||||
def _flattened_request(
|
||||
request: ModelRequest[ContextT],
|
||||
) -> ModelRequest[ContextT] | None:
|
||||
"""Return a request with system_message flattened, or ``None`` for no-op."""
|
||||
sys_msg = request.system_message
|
||||
if sys_msg is None:
|
||||
return None
|
||||
content = sys_msg.content
|
||||
if not isinstance(content, list) or len(content) <= 1:
|
||||
return None
|
||||
|
||||
flattened = _flatten_text_blocks(content)
|
||||
if flattened is None:
|
||||
return None
|
||||
|
||||
new_sys = SystemMessage(
|
||||
content=flattened,
|
||||
additional_kwargs=dict(sys_msg.additional_kwargs),
|
||||
response_metadata=dict(sys_msg.response_metadata),
|
||||
)
|
||||
if sys_msg.id is not None:
|
||||
new_sys.id = sys_msg.id
|
||||
return request.override(system_message=new_sys)
|
||||
|
||||
|
||||
def _diagnostic_summary(request: ModelRequest[Any]) -> str:
|
||||
"""One-line dump of cache_control-relevant request shape.
|
||||
|
||||
Temporary diagnostic to prove where the ``Found N`` cache_control
|
||||
breakpoints are coming from when Anthropic 400s. Removed once the
|
||||
root cause is confirmed and a fix is in place.
|
||||
"""
|
||||
sys_msg = request.system_message
|
||||
if sys_msg is None:
|
||||
sys_shape = "none"
|
||||
elif isinstance(sys_msg.content, str):
|
||||
sys_shape = f"str(len={len(sys_msg.content)})"
|
||||
elif isinstance(sys_msg.content, list):
|
||||
sys_shape = f"list(blocks={len(sys_msg.content)})"
|
||||
else:
|
||||
sys_shape = f"other({type(sys_msg.content).__name__})"
|
||||
|
||||
role_hist: list[str] = []
|
||||
multi_block_msgs = 0
|
||||
msgs_with_cc = 0
|
||||
sys_msgs_in_history = 0
|
||||
for m in request.messages:
|
||||
mtype = getattr(m, "type", type(m).__name__)
|
||||
role_hist.append(mtype)
|
||||
if isinstance(m, SystemMessage):
|
||||
sys_msgs_in_history += 1
|
||||
c = getattr(m, "content", None)
|
||||
if isinstance(c, list):
|
||||
multi_block_msgs += 1
|
||||
for blk in c:
|
||||
if isinstance(blk, dict) and "cache_control" in blk:
|
||||
msgs_with_cc += 1
|
||||
break
|
||||
if "cache_control" in getattr(m, "additional_kwargs", {}) or {}:
|
||||
msgs_with_cc += 1
|
||||
|
||||
tools = request.tools or []
|
||||
tools_with_cc = 0
|
||||
for t in tools:
|
||||
if isinstance(t, dict) and (
|
||||
"cache_control" in t or "cache_control" in t.get("function", {})
|
||||
):
|
||||
tools_with_cc += 1
|
||||
|
||||
return (
|
||||
f"sys={sys_shape} msgs={len(request.messages)} "
|
||||
f"sys_msgs_in_history={sys_msgs_in_history} "
|
||||
f"multi_block_msgs={multi_block_msgs} pre_existing_msg_cc={msgs_with_cc} "
|
||||
f"tools={len(tools)} pre_existing_tool_cc={tools_with_cc} "
|
||||
f"roles={role_hist[-8:]}"
|
||||
)
|
||||
|
||||
|
||||
class FlattenSystemMessageMiddleware(
|
||||
AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]
|
||||
):
|
||||
"""Collapse a multi-text-block system message to a single string.
|
||||
|
||||
Sits innermost on the system-message-mutation chain so it observes
|
||||
every middleware's contribution. Has no other side effect — the
|
||||
body of every block is preserved, just joined with ``"\\n\\n"``.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.tools = []
|
||||
|
||||
def wrap_model_call( # type: ignore[override]
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
|
||||
) -> Any:
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
logger.debug("[flatten_system_diag] %s", _diagnostic_summary(request))
|
||||
flattened = _flattened_request(request)
|
||||
if flattened is not None:
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
logger.debug(
|
||||
"[flatten_system] collapsed %d system blocks to one",
|
||||
len(request.system_message.content), # type: ignore[arg-type, union-attr]
|
||||
)
|
||||
return handler(flattened)
|
||||
return handler(request)
|
||||
|
||||
async def awrap_model_call( # type: ignore[override]
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[
|
||||
[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
|
||||
],
|
||||
) -> Any:
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
logger.debug("[flatten_system_diag] %s", _diagnostic_summary(request))
|
||||
flattened = _flattened_request(request)
|
||||
if flattened is not None:
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
logger.debug(
|
||||
"[flatten_system] collapsed %d system blocks to one",
|
||||
len(request.system_message.content), # type: ignore[arg-type, union-attr]
|
||||
)
|
||||
return await handler(flattened)
|
||||
return await handler(request)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"FlattenSystemMessageMiddleware",
|
||||
"_flatten_text_blocks",
|
||||
"_flattened_request",
|
||||
]
|
||||
|
|
@ -1049,12 +1049,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
return priority, matched_chunk_ids
|
||||
|
||||
|
||||
# Backwards-compatible alias for any external imports.
|
||||
KnowledgeBaseSearchMiddleware = KnowledgePriorityMiddleware
|
||||
|
||||
|
||||
__all__ = [
|
||||
"KnowledgeBaseSearchMiddleware",
|
||||
"KnowledgePriorityMiddleware",
|
||||
"browse_recent_documents",
|
||||
"fetch_mentioned_documents",
|
||||
|
|
|
|||
|
|
@ -78,14 +78,12 @@ logger = logging.getLogger(__name__)
|
|||
# blocks; we use 2 here, leaving headroom for Phase-2 tool caching.
|
||||
#
|
||||
# IMPORTANT: ``index: 0`` (not ``role: system``). The deepagent stack's
|
||||
# ``before_agent`` middlewares (priority, tree, memory, file-intent,
|
||||
# anonymous-doc) insert ``SystemMessage`` instances into
|
||||
# ``state["messages"]`` that accumulate across turns. With
|
||||
# ``role: system`` the LiteLLM hook would tag *every* one of them with
|
||||
# ``cache_control`` and overflow Anthropic's 4-block limit. ``index: 0``
|
||||
# always targets the langchain-prepended ``request.system_message``
|
||||
# (which our ``FlattenSystemMessageMiddleware`` reduces to a single text
|
||||
# block), giving us exactly one stable cache breakpoint.
|
||||
# ``before_agent`` middlewares (priority, tree, memory, anonymous-doc)
|
||||
# insert ``SystemMessage`` instances into ``state["messages"]`` that
|
||||
# accumulate across turns. With ``role: system`` the LiteLLM hook would
|
||||
# tag *every* one of them with ``cache_control`` and overflow Anthropic's
|
||||
# 4-block limit. ``index: 0`` always targets the langchain-prepended
|
||||
# ``request.system_message``, giving us exactly one stable cache breakpoint.
|
||||
_DEFAULT_INJECTION_POINTS: tuple[dict[str, Any], ...] = (
|
||||
{"location": "message", "index": 0},
|
||||
{"location": "message", "index": -1},
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue