mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-06 20:15:17 +02:00
refactor(agents): move middleware package to app/agents/shared (slice 5c)
Relocate the entire new_chat/middleware/ package to the shared kernel as one cohesive unit (it is live shared infrastructure: the multi-agent stack wraps nearly every middleware via multi_agent_chat/middleware/main_agent/*, and anonymous_agent consumes it too). Flip 69 live importers across both the package-path and submodule-path forms. Shims left for the frozen single-agent stack: a package __init__ re-export plus submodule shims for permission, skills_backends, and scoped_model_fallback (the three imported via submodule path by chat_deepagent/subagents). Cycle break: importing shared.middleware previously reached back into new_chat.tools at module load, which dragged in new_chat.__init__ -> chat_deepagent -> the middleware shim -> half-initialized shared.middleware. Made action_log's ToolDefinition import TYPE_CHECKING-only and tool_call_repair's INVALID_TOOL_NAME import function-local. These tools-package back-edges fully resolve in slice 6. Asset note: skills_backends._default_builtin_root now walks to app/agents/new_chat/skills/builtin (the skills/ tree migrates in slice 7).
This commit is contained in:
parent
6f488d9564
commit
227983a104
98 changed files with 1131 additions and 999 deletions
87
surfsense_backend/app/agents/shared/middleware/__init__.py
Normal file
87
surfsense_backend/app/agents/shared/middleware/__init__.py
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
"""Middleware components for the SurfSense new chat agent."""
|
||||
|
||||
from app.agents.shared.middleware.action_log import ActionLogMiddleware
|
||||
from app.agents.shared.middleware.anonymous_document import (
|
||||
AnonymousDocumentMiddleware,
|
||||
)
|
||||
from app.agents.shared.middleware.busy_mutex import BusyMutexMiddleware
|
||||
from app.agents.shared.middleware.compaction import (
|
||||
SurfSenseCompactionMiddleware,
|
||||
create_surfsense_compaction_middleware,
|
||||
)
|
||||
from app.agents.shared.middleware.context_editing import (
|
||||
ClearToolUsesEdit,
|
||||
SpillingContextEditingMiddleware,
|
||||
SpillToBackendEdit,
|
||||
)
|
||||
from app.agents.shared.middleware.dedup_tool_calls import (
|
||||
DedupHITLToolCallsMiddleware,
|
||||
)
|
||||
from app.agents.shared.middleware.doom_loop import DoomLoopMiddleware
|
||||
from app.agents.shared.middleware.file_intent import (
|
||||
FileIntentMiddleware,
|
||||
)
|
||||
from app.agents.shared.middleware.filesystem import (
|
||||
SurfSenseFilesystemMiddleware,
|
||||
)
|
||||
from app.agents.shared.middleware.flatten_system import (
|
||||
FlattenSystemMessageMiddleware,
|
||||
)
|
||||
from app.agents.shared.middleware.kb_persistence import (
|
||||
KnowledgeBasePersistenceMiddleware,
|
||||
commit_staged_filesystem_state,
|
||||
)
|
||||
from app.agents.shared.middleware.knowledge_search import (
|
||||
KnowledgeBaseSearchMiddleware,
|
||||
KnowledgePriorityMiddleware,
|
||||
)
|
||||
from app.agents.shared.middleware.knowledge_tree import (
|
||||
KnowledgeTreeMiddleware,
|
||||
)
|
||||
from app.agents.shared.middleware.memory_injection import (
|
||||
MemoryInjectionMiddleware,
|
||||
)
|
||||
from app.agents.shared.middleware.noop_injection import NoopInjectionMiddleware
|
||||
from app.agents.shared.middleware.otel_span import OtelSpanMiddleware
|
||||
from app.agents.shared.middleware.permission import PermissionMiddleware
|
||||
from app.agents.shared.middleware.retry_after import RetryAfterMiddleware
|
||||
from app.agents.shared.middleware.skills_backends import (
|
||||
BuiltinSkillsBackend,
|
||||
SearchSpaceSkillsBackend,
|
||||
build_skills_backend_factory,
|
||||
default_skills_sources,
|
||||
)
|
||||
from app.agents.shared.middleware.tool_call_repair import (
|
||||
ToolCallNameRepairMiddleware,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ActionLogMiddleware",
|
||||
"AnonymousDocumentMiddleware",
|
||||
"BuiltinSkillsBackend",
|
||||
"BusyMutexMiddleware",
|
||||
"ClearToolUsesEdit",
|
||||
"DedupHITLToolCallsMiddleware",
|
||||
"DoomLoopMiddleware",
|
||||
"FileIntentMiddleware",
|
||||
"FlattenSystemMessageMiddleware",
|
||||
"KnowledgeBasePersistenceMiddleware",
|
||||
"KnowledgeBaseSearchMiddleware",
|
||||
"KnowledgePriorityMiddleware",
|
||||
"KnowledgeTreeMiddleware",
|
||||
"MemoryInjectionMiddleware",
|
||||
"NoopInjectionMiddleware",
|
||||
"OtelSpanMiddleware",
|
||||
"PermissionMiddleware",
|
||||
"RetryAfterMiddleware",
|
||||
"SearchSpaceSkillsBackend",
|
||||
"SpillToBackendEdit",
|
||||
"SpillingContextEditingMiddleware",
|
||||
"SurfSenseCompactionMiddleware",
|
||||
"SurfSenseFilesystemMiddleware",
|
||||
"ToolCallNameRepairMiddleware",
|
||||
"build_skills_backend_factory",
|
||||
"commit_staged_filesystem_state",
|
||||
"create_surfsense_compaction_middleware",
|
||||
"default_skills_sources",
|
||||
]
|
||||
367
surfsense_backend/app/agents/shared/middleware/action_log.py
Normal file
367
surfsense_backend/app/agents/shared/middleware/action_log.py
Normal file
|
|
@ -0,0 +1,367 @@
|
|||
"""Append-only action-log middleware for the SurfSense agent.
|
||||
|
||||
Wraps every tool call via :meth:`AgentMiddleware.awrap_tool_call` and writes
|
||||
a row to :class:`~app.db.AgentActionLog` after the tool returns. Tools opt
|
||||
into reversibility by declaring a ``reverse`` callable on their
|
||||
:class:`~app.agents.new_chat.tools.registry.ToolDefinition`; the rendered
|
||||
descriptor is persisted in ``reverse_descriptor`` for use by
|
||||
``/api/threads/{thread_id}/revert/{action_id}``.
|
||||
|
||||
Design points:
|
||||
|
||||
* **Defensive.** Logging never blocks the agent. We catch every exception
|
||||
on the DB write path and emit a warning; the tool's ``ToolMessage``
|
||||
result is always returned untouched.
|
||||
* **Lightweight payload.** Only the tool ``name`` + ``args`` (capped) +
|
||||
``result_id`` + ``reverse_descriptor`` are stored. Tool output text
|
||||
remains in the LangGraph checkpoint / spilled tool-output files.
|
||||
* **Best-effort reversibility.** We invoke ``reverse(args, result_obj)``
|
||||
with the parsed JSON result when the tool's content is a JSON object;
|
||||
otherwise the raw text is passed. Exceptions in the reverse callable
|
||||
are swallowed and logged — a failed descriptor render simply means the
|
||||
action is NOT marked reversible.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.callbacks import adispatch_custom_event
|
||||
from langchain_core.messages import ToolMessage
|
||||
|
||||
from app.agents.shared.feature_flags import get_flags
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover - type-only
|
||||
from langchain.agents.middleware.types import ToolCallRequest
|
||||
from langgraph.types import Command
|
||||
|
||||
# Type-only import: keeping it lazy avoids a module-load cycle through the
|
||||
# frozen single-agent package (new_chat.__init__ -> chat_deepagent ->
|
||||
# middleware shim). Resolves to app.agents.shared.tools once tools migrate.
|
||||
from app.agents.new_chat.tools.registry import ToolDefinition
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Cap for the persisted ``args`` JSON to avoid bloating the action log with
|
||||
# accidentally-huge inputs. Values are truncated and a flag is set in the
|
||||
# stored payload so consumers can detect truncation.
|
||||
_MAX_ARGS_PERSIST_BYTES = 32 * 1024 # 32KB
|
||||
|
||||
|
||||
class ActionLogMiddleware(AgentMiddleware):
|
||||
"""Persist a row in :class:`AgentActionLog` after every tool call.
|
||||
|
||||
Should be placed near the OUTERMOST end of the tool-call wrapping stack
|
||||
so that it sees the *final* :class:`ToolMessage` after all retries,
|
||||
permission checks, and dedup logic have run. In practice that means
|
||||
placing it just inside :class:`PermissionMiddleware` and outside
|
||||
:class:`DedupHITLToolCallsMiddleware`.
|
||||
|
||||
The middleware is fully a no-op when:
|
||||
|
||||
* the master kill-switch ``SURFSENSE_DISABLE_NEW_AGENT_STACK`` is set
|
||||
(checked via :func:`get_flags`),
|
||||
* the per-feature flag ``enable_action_log`` is off, or
|
||||
* persistence raises (defensive: tool-call dispatch always succeeds).
|
||||
|
||||
Args:
|
||||
thread_id: The current chat thread's primary-key id. Required to
|
||||
persist a row; if ``None`` the middleware silently no-ops.
|
||||
search_space_id: Search-space id for cascade-on-delete safety.
|
||||
user_id: UUID string of the user driving this turn (nullable in
|
||||
anonymous mode).
|
||||
tool_definitions: Optional mapping of tool name -> :class:`ToolDefinition`
|
||||
so the middleware can look up the tool's ``reverse`` callable.
|
||||
When omitted, no actions are marked reversible.
|
||||
"""
|
||||
|
||||
tools = ()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
thread_id: int | None,
|
||||
search_space_id: int,
|
||||
user_id: str | None,
|
||||
tool_definitions: dict[str, ToolDefinition] | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._thread_id = thread_id
|
||||
self._search_space_id = search_space_id
|
||||
self._user_id = user_id
|
||||
self._tool_definitions = dict(tool_definitions or {})
|
||||
|
||||
def _enabled(self) -> bool:
|
||||
flags = get_flags()
|
||||
if flags.disable_new_agent_stack:
|
||||
return False
|
||||
return bool(flags.enable_action_log) and self._thread_id is not None
|
||||
|
||||
async def awrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
|
||||
) -> ToolMessage | Command[Any]:
|
||||
if not self._enabled():
|
||||
return await handler(request)
|
||||
|
||||
result: ToolMessage | Command[Any]
|
||||
error_payload: dict[str, Any] | None = None
|
||||
try:
|
||||
result = await handler(request)
|
||||
except Exception as exc:
|
||||
# Persist the failure too so revert/audit can see it, then
|
||||
# re-raise so downstream middleware (RetryAfter, etc.) handles it.
|
||||
error_payload = {"type": type(exc).__name__, "message": str(exc)}
|
||||
await self._record(
|
||||
request=request,
|
||||
result=None,
|
||||
error_payload=error_payload,
|
||||
)
|
||||
raise
|
||||
|
||||
await self._record(request=request, result=result, error_payload=None)
|
||||
return result
|
||||
|
||||
async def _record(
|
||||
self,
|
||||
*,
|
||||
request: ToolCallRequest,
|
||||
result: ToolMessage | Command[Any] | None,
|
||||
error_payload: dict[str, Any] | None,
|
||||
) -> None:
|
||||
"""Persist one ``agent_action_log`` row. Defensive: never raises."""
|
||||
try:
|
||||
from app.db import AgentActionLog, shielded_async_session
|
||||
|
||||
tool_name = _resolve_tool_name(request)
|
||||
args_payload = _resolve_args_payload(request)
|
||||
result_id = _resolve_result_id(result)
|
||||
reverse_descriptor, reversible = self._render_reverse(
|
||||
tool_name=tool_name,
|
||||
args=_resolve_args_dict(request),
|
||||
result=result,
|
||||
)
|
||||
|
||||
tool_call_id = _resolve_tool_call_id(request)
|
||||
chat_turn_id = _resolve_chat_turn_id(request)
|
||||
|
||||
row = AgentActionLog(
|
||||
thread_id=self._thread_id,
|
||||
user_id=self._user_id,
|
||||
search_space_id=self._search_space_id,
|
||||
# ``turn_id`` is the deprecated alias of ``tool_call_id``
|
||||
# kept for one release for safe rollback. New consumers
|
||||
# should read ``tool_call_id`` directly.
|
||||
turn_id=tool_call_id,
|
||||
tool_call_id=tool_call_id,
|
||||
chat_turn_id=chat_turn_id,
|
||||
message_id=_resolve_message_id(request),
|
||||
tool_name=tool_name,
|
||||
args=args_payload,
|
||||
result_id=result_id,
|
||||
reversible=reversible,
|
||||
reverse_descriptor=reverse_descriptor,
|
||||
error=error_payload,
|
||||
)
|
||||
async with shielded_async_session() as session:
|
||||
session.add(row)
|
||||
await session.commit()
|
||||
row_id = int(row.id) if row.id is not None else None
|
||||
row_created_at = row.created_at
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"ActionLogMiddleware failed to persist action log row",
|
||||
exc_info=True,
|
||||
)
|
||||
return
|
||||
|
||||
# Surface a side-channel SSE event so the chat tool card can
|
||||
# render a Revert button immediately after the row is durable.
|
||||
# ``stream_new_chat`` translates this into a
|
||||
# ``data-action-log`` SSE event. We DO NOT include the
|
||||
# ``reverse_descriptor`` payload here; only a presence flag.
|
||||
try:
|
||||
await adispatch_custom_event(
|
||||
"action_log",
|
||||
{
|
||||
"id": row_id,
|
||||
"lc_tool_call_id": tool_call_id,
|
||||
"chat_turn_id": chat_turn_id,
|
||||
"tool_name": tool_name,
|
||||
"reversible": bool(reversible),
|
||||
"reverse_descriptor_present": reverse_descriptor is not None,
|
||||
"created_at": row_created_at.isoformat()
|
||||
if row_created_at
|
||||
else None,
|
||||
"error": error_payload is not None,
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"ActionLogMiddleware failed to dispatch action_log event",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
def _render_reverse(
|
||||
self,
|
||||
*,
|
||||
tool_name: str,
|
||||
args: dict[str, Any] | None,
|
||||
result: ToolMessage | Command[Any] | None,
|
||||
) -> tuple[dict[str, Any] | None, bool]:
|
||||
"""Run the tool's ``reverse`` callable and return its descriptor.
|
||||
|
||||
Returns a tuple of ``(descriptor_or_None, reversible_bool)``. When
|
||||
the tool has no ``reverse`` callable, or when the callable raises,
|
||||
the action is marked non-reversible.
|
||||
"""
|
||||
if not result or not isinstance(result, ToolMessage):
|
||||
return None, False
|
||||
if args is None:
|
||||
return None, False
|
||||
tool_def = self._tool_definitions.get(tool_name)
|
||||
if tool_def is None or tool_def.reverse is None:
|
||||
return None, False
|
||||
try:
|
||||
parsed_result = _parse_tool_result_content(result)
|
||||
descriptor = tool_def.reverse(args, parsed_result)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Reverse descriptor render failed for tool %s",
|
||||
tool_name,
|
||||
exc_info=True,
|
||||
)
|
||||
return None, False
|
||||
if not isinstance(descriptor, dict):
|
||||
return None, False
|
||||
return descriptor, True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Resolution helpers — defensive against tool_call request shape variation.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _resolve_tool_name(request: Any) -> str:
|
||||
try:
|
||||
tool = getattr(request, "tool", None)
|
||||
if tool is not None:
|
||||
name = getattr(tool, "name", None)
|
||||
if isinstance(name, str) and name:
|
||||
return name
|
||||
call = getattr(request, "tool_call", None) or {}
|
||||
if isinstance(call, dict):
|
||||
name = call.get("name")
|
||||
if isinstance(name, str) and name:
|
||||
return name
|
||||
except Exception: # pragma: no cover - defensive
|
||||
pass
|
||||
return "unknown"
|
||||
|
||||
|
||||
def _resolve_args_dict(request: Any) -> dict[str, Any] | None:
|
||||
try:
|
||||
call = getattr(request, "tool_call", None)
|
||||
if not isinstance(call, dict):
|
||||
return None
|
||||
args = call.get("args")
|
||||
if isinstance(args, dict):
|
||||
return args
|
||||
return None
|
||||
except Exception: # pragma: no cover - defensive
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_args_payload(request: Any) -> dict[str, Any] | None:
|
||||
"""Return a JSON-serializable args dict, truncated if too big."""
|
||||
args = _resolve_args_dict(request)
|
||||
if args is None:
|
||||
return None
|
||||
try:
|
||||
encoded = json.dumps(args, default=str)
|
||||
except Exception:
|
||||
return {"_repr": repr(args)[:_MAX_ARGS_PERSIST_BYTES]}
|
||||
if len(encoded) <= _MAX_ARGS_PERSIST_BYTES:
|
||||
return args
|
||||
return {
|
||||
"_truncated": True,
|
||||
"_size": len(encoded),
|
||||
"_preview": encoded[:_MAX_ARGS_PERSIST_BYTES],
|
||||
}
|
||||
|
||||
|
||||
def _resolve_tool_call_id(request: Any) -> str | None:
|
||||
"""Return the LangChain ``tool_call.id`` for this request, if any."""
|
||||
try:
|
||||
call = getattr(request, "tool_call", None) or {}
|
||||
if isinstance(call, dict):
|
||||
tid = call.get("id")
|
||||
if isinstance(tid, str):
|
||||
return tid
|
||||
except Exception: # pragma: no cover
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
# Deprecated alias kept for one release. Old callers and tests treated
|
||||
# ``turn_id`` as if it carried the LangChain tool_call id; the new column
|
||||
# lives under ``tool_call_id``. Both resolve to the same value today.
|
||||
_resolve_turn_id = _resolve_tool_call_id
|
||||
|
||||
|
||||
def _resolve_chat_turn_id(request: Any) -> str | None:
|
||||
"""Return ``configurable.turn_id`` for this request, if accessible.
|
||||
|
||||
``ToolRuntime.config`` is exposed by LangGraph (see
|
||||
``langgraph/prebuilt/tool_node.py``); the chat-turn correlation id
|
||||
lives at ``runtime.config["configurable"]["turn_id"]``.
|
||||
"""
|
||||
try:
|
||||
runtime = getattr(request, "runtime", None)
|
||||
if runtime is None:
|
||||
return None
|
||||
config = getattr(runtime, "config", None)
|
||||
if not isinstance(config, dict):
|
||||
return None
|
||||
configurable = config.get("configurable")
|
||||
if not isinstance(configurable, dict):
|
||||
return None
|
||||
value = configurable.get("turn_id")
|
||||
if isinstance(value, str) and value:
|
||||
return value
|
||||
except Exception: # pragma: no cover - defensive
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_message_id(request: Any) -> str | None:
|
||||
"""Tool-call IDs serve as best-available message correlator at this layer."""
|
||||
return _resolve_tool_call_id(request)
|
||||
|
||||
|
||||
def _resolve_result_id(result: Any) -> str | None:
|
||||
if isinstance(result, ToolMessage):
|
||||
msg_id = getattr(result, "id", None)
|
||||
if isinstance(msg_id, str):
|
||||
return msg_id
|
||||
return None
|
||||
|
||||
|
||||
def _parse_tool_result_content(result: ToolMessage) -> Any:
|
||||
content = result.content
|
||||
if isinstance(content, str):
|
||||
try:
|
||||
return json.loads(content)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
return content
|
||||
return content
|
||||
|
||||
|
||||
__all__ = ["ActionLogMiddleware"]
|
||||
|
|
@ -0,0 +1,91 @@
|
|||
"""Lightweight middleware that loads the anonymous-session document into state.
|
||||
|
||||
Anonymous chats receive a single uploaded document via Redis (no DB row,
|
||||
read-only). This middleware loads it once on the first turn into
|
||||
``state['kb_anon_doc']`` so:
|
||||
|
||||
* :class:`KnowledgeTreeMiddleware` can render the synthetic ``/documents``
|
||||
view without touching the DB.
|
||||
* :class:`KnowledgePriorityMiddleware` skips hybrid search and emits a
|
||||
degenerate priority list.
|
||||
* :class:`KBPostgresBackend` (``als_info`` / ``aread`` / ``_load_file_data``)
|
||||
recognises the synthetic path.
|
||||
|
||||
The middleware is a no-op when ``anon_session_id`` is not provided or when
|
||||
the document is already cached in state.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from app.agents.shared.filesystem_state import SurfSenseFilesystemState
|
||||
from app.agents.shared.path_resolver import DOCUMENTS_ROOT, safe_filename
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AnonymousDocumentMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||
"""Load the anonymous user's uploaded document from Redis into state."""
|
||||
|
||||
tools = ()
|
||||
state_schema = SurfSenseFilesystemState
|
||||
|
||||
def __init__(self, *, anon_session_id: str | None) -> None:
|
||||
self.anon_session_id = anon_session_id
|
||||
|
||||
async def abefore_agent( # type: ignore[override]
|
||||
self,
|
||||
state: AgentState,
|
||||
runtime: Runtime[Any],
|
||||
) -> dict[str, Any] | None:
|
||||
del runtime
|
||||
if not self.anon_session_id:
|
||||
return None
|
||||
if state.get("kb_anon_doc"):
|
||||
return None
|
||||
|
||||
anon_doc = await self._load_anon_document()
|
||||
if anon_doc is None:
|
||||
return None
|
||||
return {"kb_anon_doc": anon_doc}
|
||||
|
||||
async def _load_anon_document(self) -> dict[str, Any] | None:
|
||||
"""Read ``anon:doc:<session_id>`` from Redis."""
|
||||
try:
|
||||
import redis.asyncio as aioredis # local import to keep cold paths cheap
|
||||
|
||||
from app.config import config
|
||||
|
||||
redis_client = aioredis.from_url(
|
||||
config.REDIS_APP_URL, decode_responses=True
|
||||
)
|
||||
try:
|
||||
redis_key = f"anon:doc:{self.anon_session_id}"
|
||||
data = await redis_client.get(redis_key)
|
||||
if not data:
|
||||
return None
|
||||
payload = json.loads(data)
|
||||
finally:
|
||||
await redis_client.aclose()
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to load anonymous document from Redis: %s", exc)
|
||||
return None
|
||||
|
||||
title = str(payload.get("filename") or "uploaded_document")
|
||||
content = str(payload.get("content") or "")
|
||||
path = f"{DOCUMENTS_ROOT}/{safe_filename(title)}"
|
||||
return {
|
||||
"path": path,
|
||||
"title": title,
|
||||
"content": content,
|
||||
"chunks": [{"chunk_id": -1, "content": content}] if content else [],
|
||||
}
|
||||
|
||||
|
||||
__all__ = ["AnonymousDocumentMiddleware"]
|
||||
328
surfsense_backend/app/agents/shared/middleware/busy_mutex.py
Normal file
328
surfsense_backend/app/agents/shared/middleware/busy_mutex.py
Normal file
|
|
@ -0,0 +1,328 @@
|
|||
"""
|
||||
BusyMutexMiddleware — per-thread asyncio lock + cancel token.
|
||||
|
||||
LangChain has no built-in concept of "this thread is already running a
|
||||
turn — refuse the second concurrent request". Without it, a user
|
||||
double-clicking "send" or refreshing the page mid-stream can spawn two
|
||||
turns racing on the same checkpoint, producing duplicated tool calls
|
||||
and mangled state.
|
||||
|
||||
Ported from OpenCode's ``Stream.scoped(AbortController)`` pattern: a
|
||||
single-process, in-memory lock + cooperative cancellation token keyed by
|
||||
``thread_id``. For multi-worker deployments a distributed lock backend
|
||||
(Redis or PostgreSQL advisory locks) is a phase-2 follow-up.
|
||||
|
||||
What this provides:
|
||||
- A ``WeakValueDictionary[str, asyncio.Lock]`` keyed by ``thread_id``;
|
||||
acquiring the lock during ``before_agent`` blocks any concurrent
|
||||
prompt on the same thread until release.
|
||||
- A per-thread ``asyncio.Event`` (``cancel_event``) that long-running
|
||||
tools can poll to abort cooperatively. The event is reset between
|
||||
turns. Tools should check ``runtime.context.cancel_event.is_set()``
|
||||
in tight inner loops.
|
||||
- A typed :class:`~app.agents.shared.errors.BusyError` raised when a
|
||||
second turn arrives while the lock is held.
|
||||
|
||||
Note: SurfSense's ``stream_new_chat`` is the call site that should
|
||||
acquire/release. Wiring this as middleware means the contract is
|
||||
explicit and the lock manager is shared with subagents that compile
|
||||
their own ``create_agent`` runnables.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import weakref
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
ResponseT,
|
||||
)
|
||||
from langgraph.config import get_config
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from app.agents.shared.errors import BusyError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _ThreadLockManager:
|
||||
"""Process-local registry of per-thread asyncio locks + cancel events."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = (
|
||||
weakref.WeakValueDictionary()
|
||||
)
|
||||
self._cancel_events: dict[str, asyncio.Event] = {}
|
||||
self._cancel_requested_at_ms: dict[str, int] = {}
|
||||
self._cancel_attempt_count: dict[str, int] = {}
|
||||
# Monotonic per-thread epoch used to prevent stale middleware
|
||||
# teardown from releasing a newer turn's lock.
|
||||
self._turn_epoch: dict[str, int] = {}
|
||||
|
||||
def lock_for(self, thread_id: str) -> asyncio.Lock:
|
||||
lock = self._locks.get(thread_id)
|
||||
if lock is None:
|
||||
lock = asyncio.Lock()
|
||||
self._locks[thread_id] = lock
|
||||
return lock
|
||||
|
||||
def cancel_event(self, thread_id: str) -> asyncio.Event:
|
||||
event = self._cancel_events.get(thread_id)
|
||||
if event is None:
|
||||
event = asyncio.Event()
|
||||
self._cancel_events[thread_id] = event
|
||||
return event
|
||||
|
||||
def request_cancel(self, thread_id: str) -> bool:
|
||||
event = self._cancel_events.get(thread_id)
|
||||
if event is None:
|
||||
event = asyncio.Event()
|
||||
self._cancel_events[thread_id] = event
|
||||
event.set()
|
||||
now_ms = int(time.time() * 1000)
|
||||
self._cancel_requested_at_ms[thread_id] = now_ms
|
||||
self._cancel_attempt_count[thread_id] = (
|
||||
self._cancel_attempt_count.get(thread_id, 0) + 1
|
||||
)
|
||||
return True
|
||||
|
||||
def is_cancel_requested(self, thread_id: str) -> bool:
|
||||
event = self._cancel_events.get(thread_id)
|
||||
return bool(event and event.is_set())
|
||||
|
||||
def cancel_state(self, thread_id: str) -> tuple[int, int] | None:
|
||||
if not self.is_cancel_requested(thread_id):
|
||||
return None
|
||||
attempts = self._cancel_attempt_count.get(thread_id, 1)
|
||||
requested_at_ms = self._cancel_requested_at_ms.get(thread_id, 0)
|
||||
return attempts, requested_at_ms
|
||||
|
||||
def reset(self, thread_id: str) -> None:
|
||||
event = self._cancel_events.get(thread_id)
|
||||
if event is not None:
|
||||
event.clear()
|
||||
self._cancel_requested_at_ms.pop(thread_id, None)
|
||||
self._cancel_attempt_count.pop(thread_id, None)
|
||||
|
||||
def bump_turn_epoch(self, thread_id: str) -> int:
|
||||
epoch = self._turn_epoch.get(thread_id, 0) + 1
|
||||
self._turn_epoch[thread_id] = epoch
|
||||
return epoch
|
||||
|
||||
def current_turn_epoch(self, thread_id: str) -> int:
|
||||
return self._turn_epoch.get(thread_id, 0)
|
||||
|
||||
def end_turn(self, thread_id: str) -> None:
|
||||
"""Best-effort terminal cleanup for a thread turn.
|
||||
|
||||
This is intentionally idempotent and safe to call from outer stream
|
||||
finally-blocks where middleware teardown might be skipped due to abort
|
||||
or disconnect edge-cases.
|
||||
"""
|
||||
# Invalidate any in-flight middleware holder first. This guarantees a
|
||||
# stale ``aafter_agent`` from an older attempt cannot unlock a newer
|
||||
# retry that already acquired the lock for the same thread.
|
||||
self.bump_turn_epoch(thread_id)
|
||||
lock = self._locks.get(thread_id)
|
||||
if lock is not None and lock.locked():
|
||||
lock.release()
|
||||
self.reset(thread_id)
|
||||
|
||||
def release(self, thread_id: str) -> bool:
|
||||
"""Force-release the per-thread lock; safety-net for turns that end before ``__end__``.
|
||||
|
||||
``BusyMutexMiddleware.aafter_agent`` only releases on graph completion, so
|
||||
an ``interrupt()`` pause or an early streaming bail-out would otherwise
|
||||
leak the lock and block the next request with :class:`BusyError`. Returns
|
||||
``True`` when a held lock was released, ``False`` otherwise.
|
||||
"""
|
||||
lock = self._locks.get(thread_id)
|
||||
if lock is None or not lock.locked():
|
||||
return False
|
||||
try:
|
||||
lock.release()
|
||||
except RuntimeError:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
# Module-level singleton — process-local but reused across all agent
|
||||
# instances built in this process. Subagents created in nested
|
||||
# ``create_agent`` calls also get this so locks are coherent.
|
||||
manager = _ThreadLockManager()
|
||||
|
||||
|
||||
def get_cancel_event(thread_id: str) -> asyncio.Event:
|
||||
"""Public accessor used by long-running tools to poll cancellation."""
|
||||
return manager.cancel_event(thread_id)
|
||||
|
||||
|
||||
def request_cancel(thread_id: str) -> bool:
|
||||
"""Trip the cancel event for ``thread_id``. Always returns True."""
|
||||
return manager.request_cancel(thread_id)
|
||||
|
||||
|
||||
def is_cancel_requested(thread_id: str) -> bool:
|
||||
"""Return whether ``thread_id`` currently has a pending cancel signal."""
|
||||
return manager.is_cancel_requested(thread_id)
|
||||
|
||||
|
||||
def get_cancel_state(thread_id: str) -> tuple[int, int] | None:
|
||||
"""Return ``(attempt_count, requested_at_ms)`` for pending cancel state."""
|
||||
return manager.cancel_state(thread_id)
|
||||
|
||||
|
||||
def reset_cancel(thread_id: str) -> None:
|
||||
"""Reset the cancel event for ``thread_id`` (called between turns)."""
|
||||
manager.reset(thread_id)
|
||||
|
||||
|
||||
def end_turn(thread_id: str) -> None:
|
||||
"""Force end-of-turn cleanup for lock + cancel state."""
|
||||
manager.end_turn(thread_id)
|
||||
|
||||
|
||||
class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
|
||||
"""Block concurrent prompts on the same thread.
|
||||
|
||||
Acquires the thread's lock in ``abefore_agent`` and releases in
|
||||
``aafter_agent``. If the lock is held, raises :class:`BusyError`
|
||||
so the caller can emit a ``surfsense.busy`` SSE event with the
|
||||
in-flight request id.
|
||||
|
||||
Args:
|
||||
require_thread_id: When True, raise :class:`BusyError` if no
|
||||
``thread_id`` can be resolved from the active
|
||||
``RunnableConfig``. Default is False — we treat a missing
|
||||
thread_id as "this turn has nothing to lock against" and
|
||||
no-op the mutex. Set True only when you trust the call
|
||||
site to always provide ``configurable.thread_id`` (e.g.
|
||||
in production where ``stream_new_chat`` always does).
|
||||
"""
|
||||
|
||||
def __init__(self, *, require_thread_id: bool = False) -> None:
|
||||
super().__init__()
|
||||
self._require_thread_id = require_thread_id
|
||||
self.tools = []
|
||||
# Per-call lock ownership tracked as (lock, epoch). ``aafter_agent``
|
||||
# only releases when its epoch still matches the manager's current
|
||||
# epoch for the thread, preventing stale unlock races.
|
||||
self._held_locks: dict[str, tuple[asyncio.Lock, int]] = {}
|
||||
|
||||
@staticmethod
|
||||
def _thread_id(runtime: Runtime[ContextT]) -> str | None:
|
||||
"""Extract ``thread_id`` from the active LangGraph ``RunnableConfig``.
|
||||
|
||||
``langgraph.runtime.Runtime`` deliberately does NOT expose ``config``.
|
||||
The runnable config (where ``configurable.thread_id`` lives) must be
|
||||
fetched via :func:`langgraph.config.get_config` from inside a node /
|
||||
middleware. We fall back to ``getattr(runtime, "config", None)`` for
|
||||
unit tests / legacy runtimes that synthesize a config-bearing stub.
|
||||
"""
|
||||
|
||||
def _from_dict(cfg: Any) -> str | None:
|
||||
if not isinstance(cfg, dict):
|
||||
return None
|
||||
tid = (cfg.get("configurable") or {}).get("thread_id")
|
||||
return str(tid) if tid is not None else None
|
||||
|
||||
# Preferred path: real LangGraph runtime context.
|
||||
try:
|
||||
tid = _from_dict(get_config())
|
||||
except Exception:
|
||||
tid = None
|
||||
if tid is not None:
|
||||
return tid
|
||||
|
||||
# Fallback for tests and any runtime that surfaces a config dict
|
||||
# directly on the runtime instance.
|
||||
return _from_dict(getattr(runtime, "config", None))
|
||||
|
||||
async def abefore_agent( # type: ignore[override]
|
||||
self,
|
||||
state: AgentState[Any],
|
||||
runtime: Runtime[ContextT],
|
||||
) -> dict[str, Any] | None:
|
||||
del state
|
||||
thread_id = self._thread_id(runtime)
|
||||
if thread_id is None:
|
||||
if self._require_thread_id:
|
||||
raise BusyError("no thread_id configured")
|
||||
logger.debug(
|
||||
"BusyMutexMiddleware: no thread_id resolved from RunnableConfig; "
|
||||
"skipping per-thread lock for this turn."
|
||||
)
|
||||
return None
|
||||
|
||||
lock = manager.lock_for(thread_id)
|
||||
if lock.locked():
|
||||
raise BusyError(request_id=thread_id)
|
||||
await lock.acquire()
|
||||
epoch = manager.bump_turn_epoch(thread_id)
|
||||
self._held_locks[thread_id] = (lock, epoch)
|
||||
# Reset the cancel event so this turn starts fresh
|
||||
reset_cancel(thread_id)
|
||||
return None
|
||||
|
||||
async def aafter_agent( # type: ignore[override]
|
||||
self,
|
||||
state: AgentState[Any],
|
||||
runtime: Runtime[ContextT],
|
||||
) -> dict[str, Any] | None:
|
||||
del state
|
||||
thread_id = self._thread_id(runtime)
|
||||
if thread_id is None:
|
||||
return None
|
||||
held = self._held_locks.pop(thread_id, None)
|
||||
if held is None:
|
||||
return None
|
||||
lock, held_epoch = held
|
||||
if held_epoch != manager.current_turn_epoch(thread_id):
|
||||
# Stale teardown from an older attempt (e.g. runtime-recovery path
|
||||
# already advanced epoch). Do not touch current lock/cancel state.
|
||||
return None
|
||||
if lock.locked():
|
||||
lock.release()
|
||||
# Always clear cancel event between turns so a stale signal
|
||||
# doesn't leak into the next request.
|
||||
reset_cancel(thread_id)
|
||||
return None
|
||||
|
||||
# Provide sync no-ops because the middleware base class allows them
|
||||
def before_agent( # type: ignore[override]
|
||||
self, state: AgentState[Any], runtime: Runtime[ContextT]
|
||||
) -> dict[str, Any] | None:
|
||||
# Sync path: no asyncio.Lock to acquire. Best we can do is reject
|
||||
# if anyone else is in flight.
|
||||
thread_id = self._thread_id(runtime)
|
||||
if thread_id is None:
|
||||
if self._require_thread_id:
|
||||
raise BusyError("no thread_id configured")
|
||||
return None
|
||||
lock = manager.lock_for(thread_id)
|
||||
if lock.locked():
|
||||
raise BusyError(request_id=thread_id)
|
||||
return None
|
||||
|
||||
def after_agent( # type: ignore[override]
|
||||
self, state: AgentState[Any], runtime: Runtime[ContextT]
|
||||
) -> dict[str, Any] | None:
|
||||
return None
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BusyMutexMiddleware",
|
||||
"end_turn",
|
||||
"get_cancel_event",
|
||||
"get_cancel_state",
|
||||
"is_cancel_requested",
|
||||
"manager",
|
||||
"request_cancel",
|
||||
"reset_cancel",
|
||||
]
|
||||
255
surfsense_backend/app/agents/shared/middleware/compaction.py
Normal file
255
surfsense_backend/app/agents/shared/middleware/compaction.py
Normal file
|
|
@ -0,0 +1,255 @@
|
|||
"""
|
||||
SurfSense compaction middleware.
|
||||
|
||||
Subclasses :class:`deepagents.middleware.summarization.SummarizationMiddleware`
|
||||
to add SurfSense-specific behavior:
|
||||
|
||||
1. **Structured summary template** (OpenCode-style ``## Goal / Constraints /
|
||||
Progress / Key Decisions / Next Steps / Critical Context / Relevant Files``)
|
||||
— see :data:`SURFSENSE_SUMMARY_PROMPT` below. The base
|
||||
``SummarizationMiddleware`` only ships a freeform "summarize this"
|
||||
prompt; the structured template is ported from OpenCode's
|
||||
``compaction.ts``.
|
||||
2. **Protect SurfSense-specific SystemMessages** so injected hints
|
||||
(``<priority_documents>``, ``<workspace_tree>``, ``<file_operation_contract>``,
|
||||
``<user_memory>``, ``<team_memory>``, ``<user_name>``, ``<memory_warning>``)
|
||||
are *not* summarized away and are kept verbatim in the post-summary
|
||||
message list. Mirrors OpenCode's ``PRUNE_PROTECTED_TOOLS`` philosophy
|
||||
(some message types are part of the agent's contract and must survive
|
||||
compaction unchanged).
|
||||
3. **Sanitize ``content=None``** when feeding messages into ``get_buffer_string``
|
||||
(Azure OpenAI / LiteLLM defense — when a provider streams an AIMessage
|
||||
containing only tool_calls and no text, ``content`` can be ``None`` and
|
||||
``get_buffer_string`` crashes iterating over ``None``). SurfSense-specific.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from deepagents.middleware.summarization import (
|
||||
SummarizationMiddleware,
|
||||
compute_summarization_defaults,
|
||||
)
|
||||
from langchain_core.messages import SystemMessage
|
||||
|
||||
from app.observability import metrics as ot_metrics, otel as ot
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from deepagents.backends.protocol import BACKEND_TYPES
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AnyMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Structured summary template ported from OpenCode's
|
||||
# ``opencode/packages/opencode/src/session/compaction.ts:40-75``. Kept as a
|
||||
# module-level constant so unit tests can assert on its sections.
|
||||
SURFSENSE_SUMMARY_PROMPT = """<role>
|
||||
SurfSense Conversation Compaction Assistant
|
||||
</role>
|
||||
|
||||
<primary_objective>
|
||||
Extract the most important context from the conversation history below into a structured summary that will replace the older messages.
|
||||
</primary_objective>
|
||||
|
||||
<instructions>
|
||||
You are running because the conversation has grown beyond the model's input window. The conversation history below will be summarized and replaced with your output. Use the structured template that follows; keep each section concise but comprehensive enough that the agent can resume work without losing context. Each section is a checklist — populate it with relevant content or write "None" if there is nothing to report.
|
||||
|
||||
## Goal
|
||||
What is the user's primary goal or request? State it in one or two sentences.
|
||||
|
||||
## Constraints
|
||||
What boundaries must the agent respect (citations rules, visibility scope, allowed tools, user-imposed style, deadlines, deny-listed topics)?
|
||||
|
||||
## Progress
|
||||
What has the agent already accomplished? List each completed step succinctly. Do not reproduce tool output; just record the conclusion.
|
||||
|
||||
## Key Decisions
|
||||
What choices were made and why? Include rejected alternatives and the reasoning behind selecting the current path.
|
||||
|
||||
## Next Steps
|
||||
What specific tasks remain to achieve the goal? Order them by dependency.
|
||||
|
||||
## Critical Context
|
||||
What facts, IDs, document titles, query keywords, error messages, or partial answers must persist into the next turn? Include verbatim quotes only when the exact wording matters (e.g. a precise filter clause or a literal name).
|
||||
|
||||
## Relevant Files
|
||||
What documents or paths in the SurfSense knowledge base are in play? Use ``/documents/...`` paths exactly as they appeared in the workspace tree.
|
||||
</instructions>
|
||||
|
||||
<messages>
|
||||
Messages to summarize:
|
||||
{messages}
|
||||
</messages>
|
||||
|
||||
Respond ONLY with the structured summary. Do not include any text before or after.
|
||||
"""
|
||||
|
||||
# SystemMessage prefixes that must NOT be summarized away. They are
|
||||
# re-injected on every turn by the corresponding middleware, but the
|
||||
# compaction step happens *before* re-injection in some paths, so we
|
||||
# must preserve them verbatim across the cutoff.
|
||||
PROTECTED_SYSTEM_PREFIXES: tuple[str, ...] = (
|
||||
"<priority_documents>", # KnowledgePriorityMiddleware
|
||||
"<workspace_tree>", # KnowledgeTreeMiddleware
|
||||
"<file_operation_contract>", # FileIntentMiddleware
|
||||
"<user_memory>", # MemoryInjectionMiddleware
|
||||
"<team_memory>", # MemoryInjectionMiddleware
|
||||
"<user_name>", # MemoryInjectionMiddleware
|
||||
"<memory_warning>", # MemoryInjectionMiddleware
|
||||
)
|
||||
|
||||
|
||||
def _is_protected_system_message(msg: AnyMessage) -> bool:
|
||||
"""Return True if ``msg`` is a SystemMessage we must not summarize."""
|
||||
if not isinstance(msg, SystemMessage):
|
||||
return False
|
||||
content = msg.content
|
||||
if not isinstance(content, str):
|
||||
return False
|
||||
stripped = content.lstrip()
|
||||
return any(stripped.startswith(prefix) for prefix in PROTECTED_SYSTEM_PREFIXES)
|
||||
|
||||
|
||||
def _sanitize_message_content(msg: AnyMessage) -> AnyMessage:
|
||||
"""Return ``msg`` with ``content=None`` coerced to ``""``.
|
||||
|
||||
Folds in the historical defense from ``safe_summarization.py`` —
|
||||
``get_buffer_string`` reads ``m.text`` which iterates ``self.content``,
|
||||
so a ``None`` content (Azure OpenAI / LiteLLM streaming a tool-only
|
||||
AIMessage) explodes. We return a copy with empty string content so
|
||||
downstream consumers see an empty body without mutating the original.
|
||||
"""
|
||||
if getattr(msg, "content", "not-missing") is not None:
|
||||
return msg
|
||||
try:
|
||||
return msg.model_copy(update={"content": ""})
|
||||
except AttributeError:
|
||||
import copy
|
||||
|
||||
new_msg = copy.copy(msg)
|
||||
try:
|
||||
new_msg.content = ""
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"Could not sanitize content=None on message of type %s",
|
||||
type(msg).__name__,
|
||||
)
|
||||
return msg
|
||||
return new_msg
|
||||
|
||||
|
||||
class SurfSenseCompactionMiddleware(SummarizationMiddleware):
|
||||
"""SummarizationMiddleware tuned for SurfSense.
|
||||
|
||||
Notes
|
||||
-----
|
||||
- Overrides :meth:`_partition_messages` so protected SystemMessages
|
||||
survive into the ``preserved_messages`` half regardless of cutoff.
|
||||
- Overrides :meth:`_filter_summary_messages` so the buffer-string path
|
||||
never iterates ``None`` content.
|
||||
- Inherits everything else (auto-trigger, backend offload,
|
||||
``_summarization_event`` plumbing, ``ContextOverflowError`` fallback).
|
||||
"""
|
||||
|
||||
def _partition_messages( # type: ignore[override]
|
||||
self,
|
||||
conversation_messages: list[AnyMessage],
|
||||
cutoff_index: int,
|
||||
) -> tuple[list[AnyMessage], list[AnyMessage]]:
|
||||
"""Split messages but always preserve SurfSense protected SystemMessages.
|
||||
|
||||
Mirrors OpenCode's ``PRUNE_PROTECTED_TOOLS`` philosophy
|
||||
(``opencode/packages/opencode/src/session/compaction.ts``): some
|
||||
message types are always kept verbatim because they are part of the
|
||||
agent's working contract, not transient output.
|
||||
|
||||
Also opens a ``compaction.run`` OTel span (no-op when OTel is off)
|
||||
so dashboards can count compaction events and message-volume
|
||||
without having to instrument upstream callers.
|
||||
"""
|
||||
# Opening a span here is appropriate because partitioning is the
|
||||
# first call SummarizationMiddleware makes when it has decided to
|
||||
# summarize; we record the volume and then close as a normal span.
|
||||
with ot.compaction_span(
|
||||
reason="auto",
|
||||
messages_in=len(conversation_messages),
|
||||
extra={"compaction.cutoff_index": int(cutoff_index)},
|
||||
):
|
||||
ot_metrics.record_compaction_run(reason="auto")
|
||||
messages_to_summarize, preserved_messages = super()._partition_messages(
|
||||
conversation_messages, cutoff_index
|
||||
)
|
||||
|
||||
protected: list[AnyMessage] = []
|
||||
kept_for_summary: list[AnyMessage] = []
|
||||
for msg in messages_to_summarize:
|
||||
if _is_protected_system_message(msg):
|
||||
protected.append(msg)
|
||||
else:
|
||||
kept_for_summary.append(msg)
|
||||
|
||||
# Place protected blocks at the *front* of preserved_messages so
|
||||
# they keep their original ordering relative to the summary
|
||||
# HumanMessage that precedes the rest of the preserved tail.
|
||||
return kept_for_summary, [*protected, *preserved_messages]
|
||||
|
||||
def _filter_summary_messages( # type: ignore[override]
|
||||
self, messages: list[AnyMessage]
|
||||
) -> list[AnyMessage]:
|
||||
"""Filter previous summaries AND sanitize ``content=None``.
|
||||
|
||||
Folds the ``safe_summarization.py`` defense in: when the buffer
|
||||
builder iterates ``m.text`` over ``None`` it explodes; sanitizing
|
||||
here covers both the sync and async offload paths.
|
||||
"""
|
||||
filtered = super()._filter_summary_messages(messages)
|
||||
return [_sanitize_message_content(m) for m in filtered]
|
||||
|
||||
|
||||
def create_surfsense_compaction_middleware(
|
||||
model: BaseChatModel,
|
||||
backend: BACKEND_TYPES,
|
||||
*,
|
||||
summary_prompt: str | None = None,
|
||||
history_path_prefix: str = "/conversation_history",
|
||||
**overrides: Any,
|
||||
) -> SurfSenseCompactionMiddleware:
|
||||
"""Build a :class:`SurfSenseCompactionMiddleware` with sensible defaults.
|
||||
|
||||
Pulls profile-aware ``trigger`` / ``keep`` / ``truncate_args_settings``
|
||||
via :func:`deepagents.middleware.summarization.compute_summarization_defaults`
|
||||
so callers get the same behavior as ``create_summarization_middleware``
|
||||
plus our overrides.
|
||||
|
||||
Args:
|
||||
model: Chat model to call for summary generation.
|
||||
backend: Backend instance or factory for offloading conversation history.
|
||||
summary_prompt: Optional override; defaults to :data:`SURFSENSE_SUMMARY_PROMPT`.
|
||||
history_path_prefix: Path prefix for offloaded conversation history.
|
||||
**overrides: Forwarded to :class:`SurfSenseCompactionMiddleware`.
|
||||
"""
|
||||
defaults = compute_summarization_defaults(model)
|
||||
return SurfSenseCompactionMiddleware(
|
||||
model=model,
|
||||
backend=backend,
|
||||
trigger=overrides.pop("trigger", defaults["trigger"]),
|
||||
keep=overrides.pop("keep", defaults["keep"]),
|
||||
trim_tokens_to_summarize=overrides.pop("trim_tokens_to_summarize", None),
|
||||
truncate_args_settings=overrides.pop(
|
||||
"truncate_args_settings", defaults["truncate_args_settings"]
|
||||
),
|
||||
summary_prompt=summary_prompt or SURFSENSE_SUMMARY_PROMPT,
|
||||
history_path_prefix=history_path_prefix,
|
||||
**overrides,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PROTECTED_SYSTEM_PREFIXES",
|
||||
"SURFSENSE_SUMMARY_PROMPT",
|
||||
"SurfSenseCompactionMiddleware",
|
||||
"create_surfsense_compaction_middleware",
|
||||
]
|
||||
|
|
@ -0,0 +1,350 @@
|
|||
"""
|
||||
SpillToBackendEdit + SpillingContextEditingMiddleware.
|
||||
|
||||
LangChain's :class:`ClearToolUsesEdit` discards old ``ToolMessage.content``
|
||||
when the context-editing budget triggers, replacing the body with a fixed
|
||||
placeholder. That's lossy: anything the agent might want to revisit is
|
||||
gone. The spill-to-disk pattern (originally from OpenCode's
|
||||
``opencode/packages/opencode/src/tool/truncate.ts``) keeps the prune
|
||||
behavior but writes the full original payload to the runtime backend
|
||||
under ``/tool_outputs/{thread_id}/{message_id}.txt`` first. The
|
||||
placeholder is then upgraded to point at the spill path so the agent
|
||||
(or a subagent) can read it back on demand.
|
||||
|
||||
Why this is a middleware subclass instead of a plain ``ContextEdit``:
|
||||
``ContextEdit.apply`` is sync, but writing to the backend is async. We
|
||||
capture the spill payloads inside ``apply`` and flush them via
|
||||
``await backend.aupload_files(...)`` from ``awrap_model_call`` *before*
|
||||
delegating to the handler, so the explore subagent can always read what
|
||||
the placeholder advertises.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from collections.abc import Awaitable, Callable, Sequence
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain.agents.middleware.context_editing import (
|
||||
ClearToolUsesEdit,
|
||||
ContextEdit,
|
||||
ContextEditingMiddleware,
|
||||
TokenCounter,
|
||||
)
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AnyMessage,
|
||||
BaseMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.messages.utils import count_tokens_approximately
|
||||
from langgraph.config import get_config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from deepagents.backends.protocol import BackendProtocol
|
||||
from langchain.agents.middleware.types import (
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_SPILL_PREFIX = "/tool_outputs"
|
||||
|
||||
|
||||
def _build_spill_placeholder(spill_path: str) -> str:
|
||||
"""Build the user-facing placeholder text shown to the model."""
|
||||
return (
|
||||
f"[cleared — full output at {spill_path}; ask the explore subagent to read it]"
|
||||
)
|
||||
|
||||
|
||||
def _get_thread_id_or_session() -> str:
|
||||
"""Best-effort thread_id discovery for the spill path.
|
||||
|
||||
Falls back to a process-stable string if no LangGraph config is
|
||||
available (e.g. unit tests). The exact value doesn't matter as long
|
||||
as it's stable within one stream so the placeholder paths line up
|
||||
with the actual upload path.
|
||||
"""
|
||||
try:
|
||||
config = get_config()
|
||||
thread_id = config.get("configurable", {}).get("thread_id")
|
||||
if thread_id is not None:
|
||||
return str(thread_id)
|
||||
except RuntimeError:
|
||||
pass
|
||||
return "no_thread"
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class SpillToBackendEdit(ContextEdit):
|
||||
"""Capture-and-replace context edit that spills full tool output to the backend.
|
||||
|
||||
Behaves like :class:`ClearToolUsesEdit` (same trigger / keep / exclude
|
||||
semantics) **and** records the original ``ToolMessage.content`` in
|
||||
:attr:`pending_spills` so the wrapping middleware can flush them
|
||||
before the model call.
|
||||
|
||||
Args:
|
||||
trigger: Token threshold above which the edit fires.
|
||||
clear_at_least: Minimum number of tokens to reclaim (best effort).
|
||||
keep: Number of most-recent ``ToolMessage`` instances to leave
|
||||
untouched.
|
||||
exclude_tools: Names of tools whose output is NOT spilled.
|
||||
clear_tool_inputs: Also clear the originating ``AIMessage.tool_calls``
|
||||
args when their pair is cleared.
|
||||
path_prefix: Path under the backend where spills are written.
|
||||
Default ``"/tool_outputs"``.
|
||||
"""
|
||||
|
||||
trigger: int = 100_000
|
||||
clear_at_least: int = 0
|
||||
keep: int = 3
|
||||
clear_tool_inputs: bool = False
|
||||
exclude_tools: Sequence[str] = ()
|
||||
path_prefix: str = DEFAULT_SPILL_PREFIX
|
||||
|
||||
pending_spills: list[tuple[str, bytes]] = field(default_factory=list)
|
||||
_lock: threading.Lock = field(default_factory=threading.Lock)
|
||||
|
||||
def drain_pending(self) -> list[tuple[str, bytes]]:
|
||||
"""Return and clear the pending-spill list atomically."""
|
||||
with self._lock:
|
||||
out = list(self.pending_spills)
|
||||
self.pending_spills.clear()
|
||||
return out
|
||||
|
||||
def apply(
|
||||
self,
|
||||
messages: list[AnyMessage],
|
||||
*,
|
||||
count_tokens: TokenCounter,
|
||||
) -> None:
|
||||
"""Mirror ``ClearToolUsesEdit.apply`` but capture originals first."""
|
||||
tokens = count_tokens(messages)
|
||||
if tokens <= self.trigger:
|
||||
return
|
||||
|
||||
candidates = [
|
||||
(idx, msg)
|
||||
for idx, msg in enumerate(messages)
|
||||
if isinstance(msg, ToolMessage)
|
||||
]
|
||||
if self.keep >= len(candidates):
|
||||
return
|
||||
if self.keep:
|
||||
candidates = candidates[: -self.keep]
|
||||
|
||||
thread_id = _get_thread_id_or_session()
|
||||
excluded_tools = set(self.exclude_tools)
|
||||
|
||||
for idx, tool_message in candidates:
|
||||
if tool_message.response_metadata.get("context_editing", {}).get("cleared"):
|
||||
continue
|
||||
|
||||
ai_message = next(
|
||||
(m for m in reversed(messages[:idx]) if isinstance(m, AIMessage)),
|
||||
None,
|
||||
)
|
||||
if ai_message is None:
|
||||
continue
|
||||
|
||||
tool_call = next(
|
||||
(
|
||||
call
|
||||
for call in ai_message.tool_calls
|
||||
if call.get("id") == tool_message.tool_call_id
|
||||
),
|
||||
None,
|
||||
)
|
||||
if tool_call is None:
|
||||
continue
|
||||
|
||||
tool_name = tool_message.name or tool_call["name"]
|
||||
if tool_name in excluded_tools:
|
||||
continue
|
||||
|
||||
message_id = tool_message.id or tool_message.tool_call_id or "unknown"
|
||||
spill_path = f"{self.path_prefix}/{thread_id}/{message_id}.txt"
|
||||
|
||||
original = tool_message.content
|
||||
payload = self._encode_payload(original)
|
||||
with self._lock:
|
||||
self.pending_spills.append((spill_path, payload))
|
||||
|
||||
messages[idx] = tool_message.model_copy(
|
||||
update={
|
||||
"artifact": None,
|
||||
"content": _build_spill_placeholder(spill_path),
|
||||
"response_metadata": {
|
||||
**tool_message.response_metadata,
|
||||
"context_editing": {
|
||||
"cleared": True,
|
||||
"strategy": "spill_to_backend",
|
||||
"spill_path": spill_path,
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
if self.clear_tool_inputs:
|
||||
ai_idx = messages.index(ai_message)
|
||||
messages[ai_idx] = self._clear_input_args(
|
||||
ai_message, tool_message.tool_call_id or ""
|
||||
)
|
||||
|
||||
if self.clear_at_least > 0:
|
||||
new_token_count = count_tokens(messages)
|
||||
cleared_tokens = max(0, tokens - new_token_count)
|
||||
if cleared_tokens >= self.clear_at_least:
|
||||
break
|
||||
|
||||
@staticmethod
|
||||
def _encode_payload(content: Any) -> bytes:
|
||||
"""Serialize ``ToolMessage.content`` to bytes for upload."""
|
||||
if isinstance(content, bytes):
|
||||
return content
|
||||
if isinstance(content, str):
|
||||
return content.encode("utf-8")
|
||||
try:
|
||||
import json
|
||||
|
||||
return json.dumps(content, default=str).encode("utf-8")
|
||||
except Exception:
|
||||
return str(content).encode("utf-8")
|
||||
|
||||
@staticmethod
|
||||
def _clear_input_args(message: AIMessage, tool_call_id: str) -> AIMessage:
|
||||
updated_tool_calls: list[dict[str, Any]] = []
|
||||
cleared_any = False
|
||||
for tool_call in message.tool_calls:
|
||||
updated = dict(tool_call)
|
||||
if updated.get("id") == tool_call_id:
|
||||
updated["args"] = {}
|
||||
cleared_any = True
|
||||
updated_tool_calls.append(updated)
|
||||
|
||||
metadata = dict(getattr(message, "response_metadata", {}))
|
||||
if cleared_any:
|
||||
ctx = dict(metadata.get("context_editing", {}))
|
||||
ids = set(ctx.get("cleared_tool_inputs", []))
|
||||
ids.add(tool_call_id)
|
||||
ctx["cleared_tool_inputs"] = sorted(ids)
|
||||
metadata["context_editing"] = ctx
|
||||
return message.model_copy(
|
||||
update={
|
||||
"tool_calls": updated_tool_calls,
|
||||
"response_metadata": metadata,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
BackendResolver = "Callable[[Any], BackendProtocol] | BackendProtocol"
|
||||
|
||||
|
||||
class SpillingContextEditingMiddleware(ContextEditingMiddleware):
|
||||
""":class:`ContextEditingMiddleware` that flushes :class:`SpillToBackendEdit` writes.
|
||||
|
||||
Runs the configured edits as the parent does, then flushes any
|
||||
pending spills via the supplied backend resolver before delegating
|
||||
to the model handler. Spill failures are logged but never abort the
|
||||
model call — the placeholder text is already in the message, so the
|
||||
worst case is the agent gets a placeholder it cannot follow up on.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
edits: Sequence[ContextEdit],
|
||||
backend_resolver: BackendResolver | None = None,
|
||||
token_count_method: str = "approximate",
|
||||
) -> None:
|
||||
super().__init__(edits=list(edits), token_count_method=token_count_method) # type: ignore[arg-type]
|
||||
self._backend_resolver = backend_resolver
|
||||
|
||||
def _resolve_backend(self, request: ModelRequest) -> BackendProtocol | None:
|
||||
if self._backend_resolver is None:
|
||||
return None
|
||||
if callable(self._backend_resolver):
|
||||
try:
|
||||
from langchain.tools import ToolRuntime
|
||||
|
||||
tool_runtime = ToolRuntime(
|
||||
state=getattr(request, "state", {}),
|
||||
context=getattr(request.runtime, "context", None),
|
||||
stream_writer=getattr(request.runtime, "stream_writer", None),
|
||||
store=getattr(request.runtime, "store", None),
|
||||
config=getattr(request.runtime, "config", None) or {},
|
||||
tool_call_id=None,
|
||||
)
|
||||
return self._backend_resolver(tool_runtime)
|
||||
except Exception:
|
||||
logger.exception("Failed to resolve spill backend")
|
||||
return None
|
||||
return self._backend_resolver # type: ignore[return-value]
|
||||
|
||||
def _collect_pending(self) -> list[tuple[str, bytes]]:
|
||||
out: list[tuple[str, bytes]] = []
|
||||
for edit in self.edits:
|
||||
if isinstance(edit, SpillToBackendEdit):
|
||||
out.extend(edit.drain_pending())
|
||||
return out
|
||||
|
||||
async def awrap_model_call( # type: ignore[override]
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> Any:
|
||||
if not request.messages:
|
||||
return await handler(request)
|
||||
|
||||
if self.token_count_method == "approximate":
|
||||
|
||||
def count_tokens(messages: Sequence[BaseMessage]) -> int:
|
||||
return count_tokens_approximately(messages)
|
||||
|
||||
else:
|
||||
system_msg = [request.system_message] if request.system_message else []
|
||||
|
||||
def count_tokens(messages: Sequence[BaseMessage]) -> int:
|
||||
return request.model.get_num_tokens_from_messages(
|
||||
system_msg + list(messages), request.tools
|
||||
)
|
||||
|
||||
edited_messages = deepcopy(list(request.messages))
|
||||
for edit in self.edits:
|
||||
edit.apply(edited_messages, count_tokens=count_tokens)
|
||||
|
||||
pending = self._collect_pending()
|
||||
if pending:
|
||||
backend = self._resolve_backend(request)
|
||||
if backend is not None:
|
||||
try:
|
||||
await backend.aupload_files(pending)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Spill-to-backend upload failed (%d files); placeholders "
|
||||
"remain in messages but content is unrecoverable",
|
||||
len(pending),
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"SpillToBackendEdit produced %d pending spills but no backend "
|
||||
"resolver was configured; content is unrecoverable",
|
||||
len(pending),
|
||||
)
|
||||
|
||||
return await handler(request.override(messages=edited_messages))
|
||||
|
||||
|
||||
__all__ = [
|
||||
"DEFAULT_SPILL_PREFIX",
|
||||
"ClearToolUsesEdit",
|
||||
"SpillToBackendEdit",
|
||||
"SpillingContextEditingMiddleware",
|
||||
"_build_spill_placeholder",
|
||||
]
|
||||
|
|
@ -0,0 +1,164 @@
|
|||
"""Middleware that deduplicates HITL tool calls within a single LLM response.
|
||||
|
||||
When the LLM emits multiple calls to the same HITL tool with the same
|
||||
primary argument (e.g. two ``delete_calendar_event("Doctor Appointment")``),
|
||||
only the first call is kept. Non-HITL tools are never touched.
|
||||
|
||||
This runs in the ``after_model`` hook — **before** any tool executes — so
|
||||
the duplicate call is stripped from the AIMessage that gets checkpointed.
|
||||
That means it is also safe across LangGraph ``interrupt()`` boundaries:
|
||||
the removed call will never appear on graph resume.
|
||||
|
||||
Dedup-key resolution order:
|
||||
|
||||
1. :class:`ToolDefinition.dedup_key` — callable provided by the registry
|
||||
entry. This is the canonical mechanism.
|
||||
2. ``tool.metadata["hitl_dedup_key"]`` — string with a primary arg name;
|
||||
used by MCP / Composio tools whose schemas the registry doesn't see.
|
||||
|
||||
A tool with no resolver from either path simply opts out of dedup.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Resolver type — given the tool ``args`` dict returns a stable
|
||||
# string used to dedupe consecutive calls. ``None`` means no dedup.
|
||||
DedupResolver = Callable[[dict[str, Any]], str]
|
||||
|
||||
|
||||
def wrap_dedup_key_by_arg_name(arg_name: str) -> DedupResolver:
|
||||
"""Adapt a string-arg name into a :data:`DedupResolver`.
|
||||
|
||||
Convenience helper used by registry entries that just want to dedupe
|
||||
on a single arg's lowercased value (the most common case for native
|
||||
HITL tools like ``send_gmail_email`` keyed on ``subject``).
|
||||
|
||||
Example::
|
||||
|
||||
ToolDefinition(
|
||||
name="send_gmail_email",
|
||||
...,
|
||||
dedup_key=wrap_dedup_key_by_arg_name("subject"),
|
||||
)
|
||||
"""
|
||||
|
||||
def _resolver(args: dict[str, Any]) -> str:
|
||||
return str(args.get(arg_name, "")).lower()
|
||||
|
||||
return _resolver
|
||||
|
||||
|
||||
def dedup_key_full_args(args: dict[str, Any]) -> str:
|
||||
"""Resolver that collapses calls only when **every** argument is identical.
|
||||
|
||||
Safe default for tools where no single field uniquely identifies a call
|
||||
(e.g. MCP tools whose first required field is a shared workspace id).
|
||||
"""
|
||||
|
||||
try:
|
||||
return json.dumps(args, sort_keys=True, default=str)
|
||||
except (TypeError, ValueError):
|
||||
return repr(sorted(args.items())) if isinstance(args, dict) else repr(args)
|
||||
|
||||
|
||||
# Backwards-compatible alias for code that imported the original
|
||||
# private name. New callers should use :func:`wrap_dedup_key_by_arg_name`.
|
||||
_wrap_string_key = wrap_dedup_key_by_arg_name
|
||||
|
||||
|
||||
class DedupHITLToolCallsMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||
"""Remove duplicate HITL tool calls from a single LLM response.
|
||||
|
||||
Only the **first** occurrence of each ``(tool-name, dedup_key)``
|
||||
pair is kept; subsequent duplicates are silently dropped.
|
||||
|
||||
The dedup-resolver map is built from two sources, in priority order:
|
||||
|
||||
1. ``tool.metadata["dedup_key"]`` — callable provided by the registry's
|
||||
``ToolDefinition.dedup_key``. Receives the args dict and returns
|
||||
a string signature. This is the canonical mechanism.
|
||||
2. ``tool.metadata["hitl_dedup_key"]`` — string with a primary arg
|
||||
name; primarily used by MCP / Composio tools.
|
||||
"""
|
||||
|
||||
tools = ()
|
||||
|
||||
def __init__(self, *, agent_tools: list[Any] | None = None) -> None:
|
||||
self._resolvers: dict[str, DedupResolver] = {}
|
||||
|
||||
for t in agent_tools or []:
|
||||
meta = getattr(t, "metadata", None) or {}
|
||||
callable_key = meta.get("dedup_key")
|
||||
if callable(callable_key):
|
||||
self._resolvers[t.name] = callable_key
|
||||
continue
|
||||
if meta.get("hitl") and meta.get("hitl_dedup_key"):
|
||||
self._resolvers[t.name] = wrap_dedup_key_by_arg_name(
|
||||
meta["hitl_dedup_key"]
|
||||
)
|
||||
|
||||
def after_model(
|
||||
self, state: AgentState, runtime: Runtime[Any]
|
||||
) -> dict[str, Any] | None:
|
||||
return self._dedup(state, self._resolvers)
|
||||
|
||||
async def aafter_model(
|
||||
self, state: AgentState, runtime: Runtime[Any]
|
||||
) -> dict[str, Any] | None:
|
||||
return self._dedup(state, self._resolvers)
|
||||
|
||||
@staticmethod
|
||||
def _dedup(
|
||||
state: AgentState,
|
||||
resolvers: dict[str, DedupResolver],
|
||||
) -> dict[str, Any] | None:
|
||||
messages = state.get("messages")
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
last_msg = messages[-1]
|
||||
if last_msg.type != "ai" or not getattr(last_msg, "tool_calls", None):
|
||||
return None
|
||||
|
||||
tool_calls: list[dict[str, Any]] = last_msg.tool_calls
|
||||
seen: set[tuple[str, str]] = set()
|
||||
deduped: list[dict[str, Any]] = []
|
||||
|
||||
for tc in tool_calls:
|
||||
name = tc.get("name", "")
|
||||
resolver = resolvers.get(name)
|
||||
if resolver is not None:
|
||||
try:
|
||||
arg_val = resolver(tc.get("args", {}) or {})
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Dedup resolver for tool %s raised; keeping call", name
|
||||
)
|
||||
deduped.append(tc)
|
||||
continue
|
||||
key = (name, arg_val)
|
||||
if key in seen:
|
||||
logger.info(
|
||||
"Dedup: dropped duplicate HITL tool call %s(%s)",
|
||||
name,
|
||||
arg_val,
|
||||
)
|
||||
continue
|
||||
seen.add(key)
|
||||
deduped.append(tc)
|
||||
|
||||
if len(deduped) == len(tool_calls):
|
||||
return None
|
||||
|
||||
updated_msg = last_msg.model_copy(update={"tool_calls": deduped})
|
||||
return {"messages": [updated_msg]}
|
||||
238
surfsense_backend/app/agents/shared/middleware/doom_loop.py
Normal file
238
surfsense_backend/app/agents/shared/middleware/doom_loop.py
Normal file
|
|
@ -0,0 +1,238 @@
|
|||
"""
|
||||
DoomLoopMiddleware — pattern-based detector for repeated identical tool calls.
|
||||
|
||||
LangChain has :class:`ToolCallLimitMiddleware` which caps the *total* number
|
||||
of tool calls per turn — but it can't tell apart "10 distinct, useful
|
||||
calls" from "the same call 10 times in a row". This middleware fills that
|
||||
gap with a sliding-window check on tool-call signatures, ported from
|
||||
OpenCode's ``packages/opencode/src/session/processor.ts``.
|
||||
|
||||
When the same tool with the same arguments is called N times in a row,
|
||||
the agent has likely entered an infinite loop. We surface this to the
|
||||
user as an interrupt with ``permission="doom_loop"`` so the UI can
|
||||
render an "Are you stuck? Continue / cancel?" affordance.
|
||||
|
||||
This ships **OFF by default** until the frontend explicitly handles
|
||||
``context.permission == "doom_loop"`` interrupts.
|
||||
|
||||
Wire format: uses SurfSense's existing ``interrupt()`` payload shape
|
||||
(see ``app/agents/new_chat/tools/hitl.py``):
|
||||
|
||||
{
|
||||
"type": "permission_ask",
|
||||
"action": {"tool": <name>, "params": <args>},
|
||||
"context": {"permission": "doom_loop", "recent_signatures": [...]},
|
||||
}
|
||||
|
||||
so the frontend that already handles HITL prompts can render this with
|
||||
no changes beyond a string check.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from collections import deque
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
ResponseT,
|
||||
)
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.config import get_config
|
||||
from langgraph.runtime import Runtime
|
||||
from langgraph.types import interrupt
|
||||
|
||||
from app.observability import metrics as ot_metrics, otel as ot
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _signature(name: str, args: Any) -> str:
|
||||
"""Hash a tool call ``(name, args)`` to a short signature."""
|
||||
try:
|
||||
canonical = json.dumps(args, sort_keys=True, default=str)
|
||||
except (TypeError, ValueError):
|
||||
canonical = repr(args)
|
||||
digest = hashlib.sha1(f"{name}::{canonical}".encode()).hexdigest()
|
||||
return digest[:16]
|
||||
|
||||
|
||||
class DoomLoopMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
|
||||
"""Detect repeated identical tool calls and prompt the user.
|
||||
|
||||
Tracks a sliding window of the most-recent ``threshold`` tool-call
|
||||
signatures across the live request. When all entries match, raise
|
||||
a SurfSense-style HITL interrupt with ``permission="doom_loop"``.
|
||||
|
||||
Args:
|
||||
threshold: How many consecutive identical signatures count as a
|
||||
doom loop. Default 3 (matches OpenCode's processor.ts).
|
||||
"""
|
||||
|
||||
def __init__(self, *, threshold: int = 3) -> None:
|
||||
super().__init__()
|
||||
if threshold < 2:
|
||||
raise ValueError("DoomLoopMiddleware threshold must be >= 2")
|
||||
self._threshold = threshold
|
||||
self.tools = []
|
||||
# Per-thread sliding windows. We can't put this in graph state
|
||||
# without state-schema gymnastics; for one process-lifetime it's
|
||||
# fine to keep an in-memory map keyed by thread_id.
|
||||
self._windows: dict[str, deque[str]] = {}
|
||||
|
||||
@staticmethod
|
||||
def _thread_id_from_runtime(runtime: Runtime[ContextT]) -> str:
|
||||
"""Resolve the thread id for sliding-window keying.
|
||||
|
||||
Prefer LangGraph's ``get_config()`` (the only way to read
|
||||
``RunnableConfig`` inside a node — :class:`Runtime` does NOT carry
|
||||
a ``config`` attribute). Fall back to ``runtime.config`` for unit
|
||||
tests that synthesize a config-bearing stub. Default
|
||||
``"no_thread"`` is intentionally only used when both lookups fail
|
||||
— it would collapse all threads into one window so we keep the
|
||||
debug log loud.
|
||||
"""
|
||||
|
||||
def _from_dict(cfg: Any) -> str | None:
|
||||
if not isinstance(cfg, dict):
|
||||
return None
|
||||
tid = (cfg.get("configurable") or {}).get("thread_id")
|
||||
return str(tid) if tid is not None else None
|
||||
|
||||
try:
|
||||
tid = _from_dict(get_config())
|
||||
except Exception:
|
||||
tid = None
|
||||
if tid is not None:
|
||||
return tid
|
||||
|
||||
tid = _from_dict(getattr(runtime, "config", None))
|
||||
if tid is not None:
|
||||
return tid
|
||||
|
||||
logger.debug(
|
||||
"DoomLoopMiddleware: no thread_id resolved from RunnableConfig; "
|
||||
"falling back to shared 'no_thread' window."
|
||||
)
|
||||
return "no_thread"
|
||||
|
||||
def _window(self, thread_id: str) -> deque[str]:
|
||||
win = self._windows.get(thread_id)
|
||||
if win is None:
|
||||
win = deque(maxlen=self._threshold)
|
||||
self._windows[thread_id] = win
|
||||
return win
|
||||
|
||||
def _detect(
|
||||
self, message: AIMessage, runtime: Runtime[ContextT]
|
||||
) -> tuple[bool, list[str], dict[str, Any] | None]:
|
||||
if not message.tool_calls:
|
||||
return False, [], None
|
||||
|
||||
thread_id = self._thread_id_from_runtime(runtime)
|
||||
window = self._window(thread_id)
|
||||
|
||||
triggered_call: dict[str, Any] | None = None
|
||||
for call in message.tool_calls:
|
||||
name = (
|
||||
call.get("name")
|
||||
if isinstance(call, dict)
|
||||
else getattr(call, "name", None)
|
||||
)
|
||||
args = (
|
||||
call.get("args")
|
||||
if isinstance(call, dict)
|
||||
else getattr(call, "args", {})
|
||||
)
|
||||
if not isinstance(name, str):
|
||||
continue
|
||||
sig = _signature(name, args)
|
||||
window.append(sig)
|
||||
if len(window) >= self._threshold and len(set(window)) == 1:
|
||||
triggered_call = {"name": name, "params": args or {}}
|
||||
break
|
||||
|
||||
if triggered_call is None:
|
||||
return False, list(window), None
|
||||
return True, list(window), triggered_call
|
||||
|
||||
def after_model( # type: ignore[override]
|
||||
self,
|
||||
state: AgentState[ResponseT],
|
||||
runtime: Runtime[ContextT],
|
||||
) -> dict[str, Any] | None:
|
||||
messages = state.get("messages") or []
|
||||
if not messages:
|
||||
return None
|
||||
last = messages[-1]
|
||||
if not isinstance(last, AIMessage):
|
||||
return None
|
||||
|
||||
triggered, signatures, action = self._detect(last, runtime)
|
||||
if not triggered:
|
||||
return None
|
||||
|
||||
logger.warning(
|
||||
"Doom loop detected: tool %s called %d times in a row (sig=%s)",
|
||||
action["name"] if action else "<unknown>",
|
||||
self._threshold,
|
||||
signatures[-1] if signatures else "<empty>",
|
||||
)
|
||||
|
||||
# Open an interrupt.raised span with permission=doom_loop attribute
|
||||
# so dashboards can break out doom-loop interrupts from regular
|
||||
# permission asks via the ``interrupt.permission`` attribute.
|
||||
with ot.interrupt_span(
|
||||
interrupt_type="permission_ask",
|
||||
extra={
|
||||
"interrupt.permission": "doom_loop",
|
||||
"interrupt.threshold": self._threshold,
|
||||
"interrupt.tool": (action or {}).get("tool", "<unknown>"),
|
||||
},
|
||||
):
|
||||
ot_metrics.record_interrupt(interrupt_type="permission_ask")
|
||||
decision = interrupt(
|
||||
{
|
||||
"type": "permission_ask",
|
||||
"action": action or {"tool": "<unknown>", "params": {}},
|
||||
"context": {
|
||||
"permission": "doom_loop",
|
||||
"recent_signatures": signatures,
|
||||
"threshold": self._threshold,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Reset window so the next decision (continue/cancel) starts fresh.
|
||||
thread_id = self._thread_id_from_runtime(runtime)
|
||||
self._windows.pop(thread_id, None)
|
||||
|
||||
# Decision shape mirrors ``tools/hitl.py``: {"decision_type": "..."}
|
||||
# If the user cancelled, jump to end. Otherwise return ``None`` so the
|
||||
# tool call proceeds. The frontend's exact reply names may differ —
|
||||
# we tolerate any shape that contains a string with "reject"/"cancel".
|
||||
if isinstance(decision, dict):
|
||||
kind = str(
|
||||
decision.get("decision_type") or decision.get("type") or ""
|
||||
).lower()
|
||||
if "reject" in kind or "cancel" in kind:
|
||||
return {"jump_to": "end"}
|
||||
return None
|
||||
|
||||
async def aafter_model( # type: ignore[override]
|
||||
self,
|
||||
state: AgentState[ResponseT],
|
||||
runtime: Runtime[ContextT],
|
||||
) -> dict[str, Any] | None:
|
||||
return self.after_model(state, runtime)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"DoomLoopMiddleware",
|
||||
"_signature",
|
||||
]
|
||||
334
surfsense_backend/app/agents/shared/middleware/file_intent.py
Normal file
334
surfsense_backend/app/agents/shared/middleware/file_intent.py
Normal file
|
|
@ -0,0 +1,334 @@
|
|||
"""Semantic file-intent routing middleware for new chat turns.
|
||||
|
||||
This middleware classifies the latest human turn into a small intent set:
|
||||
- chat_only
|
||||
- file_write
|
||||
- file_read
|
||||
|
||||
For ``file_write`` turns it injects a strict system contract so the model
|
||||
uses filesystem tools before claiming success, and provides a deterministic
|
||||
fallback path when no filename is specified by the user.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from datetime import UTC, datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
from langgraph.runtime import Runtime
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FileOperationIntent(StrEnum):
|
||||
CHAT_ONLY = "chat_only"
|
||||
FILE_WRITE = "file_write"
|
||||
FILE_READ = "file_read"
|
||||
|
||||
|
||||
class FileIntentPlan(BaseModel):
|
||||
intent: FileOperationIntent = Field(
|
||||
description="Primary user intent for this turn."
|
||||
)
|
||||
confidence: float = Field(
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
default=0.5,
|
||||
description="Model confidence in the selected intent.",
|
||||
)
|
||||
suggested_filename: str | None = Field(
|
||||
default=None,
|
||||
description="Optional filename (e.g. notes.md) inferred from user request.",
|
||||
)
|
||||
suggested_directory: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Optional directory path (e.g. /reports/q2 or reports/q2) inferred from "
|
||||
"user request."
|
||||
),
|
||||
)
|
||||
suggested_path: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Optional full file path (e.g. /reports/q2/summary.md). If present, this "
|
||||
"takes precedence over suggested_directory + suggested_filename."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _extract_text_from_message(message: BaseMessage) -> str:
|
||||
content = getattr(message, "content", "")
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts: list[str] = []
|
||||
for item in content:
|
||||
if isinstance(item, str):
|
||||
parts.append(item)
|
||||
elif isinstance(item, dict) and item.get("type") == "text":
|
||||
parts.append(str(item.get("text", "")))
|
||||
return "\n".join(part for part in parts if part)
|
||||
return str(content)
|
||||
|
||||
|
||||
def _extract_json_payload(text: str) -> str:
|
||||
stripped = text.strip()
|
||||
fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", stripped, re.DOTALL)
|
||||
if fenced:
|
||||
return fenced.group(1)
|
||||
start = stripped.find("{")
|
||||
end = stripped.rfind("}")
|
||||
if start != -1 and end != -1 and end > start:
|
||||
return stripped[start : end + 1]
|
||||
return stripped
|
||||
|
||||
|
||||
def _sanitize_filename(value: str) -> str:
|
||||
name = re.sub(r"[\\/:*?\"<>|]+", "_", value).strip()
|
||||
name = re.sub(r"\s+", "-", name)
|
||||
name = name.strip("._-")
|
||||
if not name:
|
||||
name = "note"
|
||||
if len(name) > 80:
|
||||
name = name[:80].rstrip("-_.")
|
||||
return name
|
||||
|
||||
|
||||
def _sanitize_path_segment(value: str) -> str:
|
||||
segment = re.sub(r"[\\/:*?\"<>|]+", "_", value).strip()
|
||||
segment = re.sub(r"\s+", "_", segment)
|
||||
segment = segment.strip("._-")
|
||||
return segment
|
||||
|
||||
|
||||
def _normalize_directory(value: str) -> str:
|
||||
raw = value.strip().replace("\\", "/")
|
||||
raw = raw.strip("/")
|
||||
if not raw:
|
||||
return ""
|
||||
parts = [_sanitize_path_segment(part) for part in raw.split("/") if part.strip()]
|
||||
parts = [part for part in parts if part]
|
||||
return "/".join(parts)
|
||||
|
||||
|
||||
def _normalize_file_path(value: str) -> str:
|
||||
raw = value.strip().replace("\\", "/").strip()
|
||||
if not raw:
|
||||
return ""
|
||||
had_trailing_slash = raw.endswith("/")
|
||||
raw = raw.strip("/")
|
||||
if not raw:
|
||||
return ""
|
||||
parts = [_sanitize_path_segment(part) for part in raw.split("/") if part.strip()]
|
||||
parts = [part for part in parts if part]
|
||||
if not parts:
|
||||
return ""
|
||||
if had_trailing_slash:
|
||||
return f"/{'/'.join(parts)}/"
|
||||
return f"/{'/'.join(parts)}"
|
||||
|
||||
|
||||
def _infer_directory_from_user_text(user_text: str) -> str | None:
|
||||
patterns = (
|
||||
r"\b(?:in|inside|under)\s+(?:the\s+)?([a-zA-Z0-9 _\-/]+?)\s+folder\b",
|
||||
r"\b(?:in|inside|under)\s+([a-zA-Z0-9 _\-/]+?)\b",
|
||||
)
|
||||
lowered = user_text.lower()
|
||||
for pattern in patterns:
|
||||
match = re.search(pattern, lowered, flags=re.IGNORECASE)
|
||||
if not match:
|
||||
continue
|
||||
candidate = match.group(1).strip()
|
||||
if candidate in {"the", "a", "an"}:
|
||||
continue
|
||||
normalized = _normalize_directory(candidate)
|
||||
if normalized:
|
||||
return normalized
|
||||
return None
|
||||
|
||||
|
||||
def _fallback_path(
|
||||
suggested_filename: str | None,
|
||||
*,
|
||||
suggested_directory: str | None = None,
|
||||
suggested_path: str | None = None,
|
||||
user_text: str,
|
||||
) -> str:
|
||||
inferred_dir = _infer_directory_from_user_text(user_text)
|
||||
|
||||
sanitized_filename = ""
|
||||
if suggested_filename:
|
||||
sanitized_filename = _sanitize_filename(suggested_filename)
|
||||
if sanitized_filename.lower().endswith(".txt"):
|
||||
sanitized_filename = f"{sanitized_filename[:-4]}.md"
|
||||
if not sanitized_filename:
|
||||
sanitized_filename = "notes.md"
|
||||
elif "." not in sanitized_filename:
|
||||
sanitized_filename = f"{sanitized_filename}.md"
|
||||
|
||||
normalized_suggested_path = (
|
||||
_normalize_file_path(suggested_path) if suggested_path else ""
|
||||
)
|
||||
if normalized_suggested_path:
|
||||
if normalized_suggested_path.endswith("/"):
|
||||
return f"{normalized_suggested_path.rstrip('/')}/{sanitized_filename}"
|
||||
return normalized_suggested_path
|
||||
|
||||
directory = _normalize_directory(suggested_directory or "")
|
||||
if not directory and inferred_dir:
|
||||
directory = inferred_dir
|
||||
if directory:
|
||||
return f"/{directory}/{sanitized_filename}"
|
||||
|
||||
return f"/{sanitized_filename}"
|
||||
|
||||
|
||||
def _build_classifier_prompt(*, recent_conversation: str, user_text: str) -> str:
|
||||
return (
|
||||
"Classify the latest user request into a filesystem intent for an AI agent.\n"
|
||||
"Return JSON only with this exact schema:\n"
|
||||
'{"intent":"chat_only|file_write|file_read","confidence":0.0,"suggested_filename":"string or null","suggested_directory":"string or null","suggested_path":"string or null"}\n\n'
|
||||
"Rules:\n"
|
||||
"- Use semantic intent, not literal keywords.\n"
|
||||
"- file_write: user asks to create/save/write/update/edit content as a file.\n"
|
||||
"- file_read: user asks to open/read/list/search existing files.\n"
|
||||
"- chat_only: conversational/analysis responses without required file operations.\n"
|
||||
"- For file_write, choose a concise semantic suggested_filename and match the requested format.\n"
|
||||
"- If the user mentions a folder/directory, populate suggested_directory.\n"
|
||||
"- If user specifies an explicit full path, populate suggested_path.\n"
|
||||
"- Use extensions that match user intent (e.g. .md, .json, .yaml, .csv, .py, .ts, .js, .html, .css, .sql).\n"
|
||||
"- Do not use .txt; prefer .md for generic text notes.\n"
|
||||
"- Do not include dates or timestamps in suggested_filename unless explicitly requested.\n"
|
||||
"- Never include markdown or explanation.\n\n"
|
||||
f"Recent conversation:\n{recent_conversation or '(none)'}\n\n"
|
||||
f"Latest user message:\n{user_text}"
|
||||
)
|
||||
|
||||
|
||||
def _build_recent_conversation(
|
||||
messages: list[BaseMessage], *, max_messages: int = 6
|
||||
) -> str:
|
||||
rows: list[str] = []
|
||||
filtered: list[tuple[str, BaseMessage]] = []
|
||||
for msg in messages:
|
||||
role: str | None = None
|
||||
if isinstance(msg, HumanMessage):
|
||||
role = "user"
|
||||
elif isinstance(msg, AIMessage):
|
||||
if getattr(msg, "tool_calls", None):
|
||||
continue
|
||||
role = "assistant"
|
||||
else:
|
||||
continue
|
||||
filtered.append((role, msg))
|
||||
for role, msg in filtered[-max_messages:]:
|
||||
text = re.sub(r"\s+", " ", _extract_text_from_message(msg)).strip()
|
||||
if text:
|
||||
rows.append(f"{role}: {text[:280]}")
|
||||
return "\n".join(rows)
|
||||
|
||||
|
||||
class FileIntentMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||
"""Classify file intent and inject a strict file-write contract."""
|
||||
|
||||
tools = ()
|
||||
|
||||
def __init__(self, *, llm: BaseChatModel | None = None) -> None:
|
||||
self.llm = llm
|
||||
|
||||
async def _classify_intent(
|
||||
self, *, messages: list[BaseMessage], user_text: str
|
||||
) -> FileIntentPlan:
|
||||
if self.llm is None:
|
||||
return FileIntentPlan(intent=FileOperationIntent.CHAT_ONLY, confidence=0.0)
|
||||
|
||||
prompt = _build_classifier_prompt(
|
||||
recent_conversation=_build_recent_conversation(messages),
|
||||
user_text=user_text,
|
||||
)
|
||||
try:
|
||||
response = await self.llm.ainvoke(
|
||||
[HumanMessage(content=prompt)],
|
||||
config={"tags": ["surfsense:internal"]},
|
||||
)
|
||||
payload = json.loads(
|
||||
_extract_json_payload(_extract_text_from_message(response))
|
||||
)
|
||||
plan = FileIntentPlan.model_validate(payload)
|
||||
return plan
|
||||
except (json.JSONDecodeError, ValidationError, ValueError) as exc:
|
||||
logger.warning("File intent classifier returned invalid output: %s", exc)
|
||||
except Exception as exc: # pragma: no cover - defensive fallback
|
||||
logger.warning("File intent classifier failed: %s", exc)
|
||||
|
||||
return FileIntentPlan(intent=FileOperationIntent.CHAT_ONLY, confidence=0.0)
|
||||
|
||||
async def abefore_agent( # type: ignore[override]
|
||||
self,
|
||||
state: AgentState,
|
||||
runtime: Runtime[Any],
|
||||
) -> dict[str, Any] | None:
|
||||
del runtime
|
||||
messages = state.get("messages") or []
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
last_human: HumanMessage | None = None
|
||||
for msg in reversed(messages):
|
||||
if isinstance(msg, HumanMessage):
|
||||
last_human = msg
|
||||
break
|
||||
if last_human is None:
|
||||
return None
|
||||
|
||||
user_text = _extract_text_from_message(last_human).strip()
|
||||
if not user_text:
|
||||
return None
|
||||
|
||||
plan = await self._classify_intent(messages=messages, user_text=user_text)
|
||||
suggested_path = _fallback_path(
|
||||
plan.suggested_filename,
|
||||
suggested_directory=plan.suggested_directory,
|
||||
suggested_path=plan.suggested_path,
|
||||
user_text=user_text,
|
||||
)
|
||||
contract = {
|
||||
"intent": plan.intent.value,
|
||||
"confidence": plan.confidence,
|
||||
"suggested_path": suggested_path,
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
"turn_id": state.get("turn_id", ""),
|
||||
}
|
||||
|
||||
if plan.intent != FileOperationIntent.FILE_WRITE:
|
||||
return {"file_operation_contract": contract}
|
||||
|
||||
contract_msg = SystemMessage(
|
||||
content=(
|
||||
"<file_operation_contract>\n"
|
||||
"This turn intent is file_write.\n"
|
||||
f"Suggested default path: {suggested_path}\n"
|
||||
"Rules:\n"
|
||||
"- You MUST call write_file or edit_file before claiming success.\n"
|
||||
"- If no path is provided by the user, use the suggested default path.\n"
|
||||
"- Do not claim a file was created/updated unless tool output confirms it.\n"
|
||||
"- If the write/edit fails, clearly report failure instead of success.\n"
|
||||
"- Do not include timestamps or dates in generated file content unless the user explicitly asks for them.\n"
|
||||
"- For open-ended requests (e.g., random note), generate useful concrete content, not placeholders.\n"
|
||||
"</file_operation_contract>"
|
||||
)
|
||||
)
|
||||
|
||||
# Insert just before the latest human turn so it applies to this request.
|
||||
new_messages = list(messages)
|
||||
insert_at = max(len(new_messages) - 1, 0)
|
||||
new_messages.insert(insert_at, contract_msg)
|
||||
return {"messages": new_messages, "file_operation_contract": contract}
|
||||
1998
surfsense_backend/app/agents/shared/middleware/filesystem.py
Normal file
1998
surfsense_backend/app/agents/shared/middleware/filesystem.py
Normal file
File diff suppressed because it is too large
Load diff
233
surfsense_backend/app/agents/shared/middleware/flatten_system.py
Normal file
233
surfsense_backend/app/agents/shared/middleware/flatten_system.py
Normal file
|
|
@ -0,0 +1,233 @@
|
|||
r"""Coalesce multi-block system messages into a single text block.
|
||||
|
||||
Several middlewares in our deepagent stack each call
|
||||
``append_to_system_message`` on the way down to the model
|
||||
(``TodoListMiddleware``, ``SurfSenseFilesystemMiddleware``,
|
||||
``SkillsMiddleware``, ``SubAgentMiddleware`` …). By the time the
|
||||
request reaches the LLM, the system message has 5+ separate text blocks.
|
||||
|
||||
Anthropic enforces a hard cap of **4 ``cache_control`` blocks per
|
||||
request**, and we configure 2 injection points
|
||||
(``index: 0`` + ``index: -1``). With ``index: 0`` always targeting
|
||||
the prepended ``request.system_message``, this middleware is the
|
||||
defensive partner: it guarantees that "the system block" is *one*
|
||||
content block, so LiteLLM's ``AnthropicCacheControlHook`` and any
|
||||
OpenRouter→Anthropic transformer can never multiply our budget into
|
||||
several breakpoints by spreading ``cache_control`` across multiple
|
||||
text blocks of a multi-block system content.
|
||||
|
||||
Without flattening we used to see::
|
||||
|
||||
OpenrouterException - {"error":{"message":"Provider returned error",
|
||||
"code":400,"metadata":{"raw":"...A maximum of 4 blocks with
|
||||
cache_control may be provided. Found 5."}}}
|
||||
|
||||
(Same error class documented in
|
||||
https://github.com/BerriAI/litellm/issues/15696 and
|
||||
https://github.com/BerriAI/litellm/issues/20485 — the litellm-side fix
|
||||
in PR #15395 covers the litellm transformer but does not protect us
|
||||
when the OpenRouter SaaS itself does the redistribution.)
|
||||
|
||||
A separate fix in :mod:`app.agents.shared.prompt_caching` (switching
|
||||
the first injection point from ``role: system`` to ``index: 0``)
|
||||
neutralises the *primary* cause of the same 400 — multiple
|
||||
``SystemMessage``\ s injected by ``before_agent`` middlewares
|
||||
(priority/tree/memory/file-intent/anonymous-doc) accumulating across
|
||||
turns, each tagged with ``cache_control`` by the ``role: system``
|
||||
matcher. This middleware remains useful as defence-in-depth against
|
||||
the multi-block redistribution path.
|
||||
|
||||
Placement: innermost on the system-message-mutation chain, after every
|
||||
appender (``todo``/``filesystem``/``skills``/``subagents``) and after
|
||||
summarization, but before ``noop``/``retry``/``fallback`` so each retry
|
||||
attempt sees a flattened payload. See ``chat_deepagent.py``.
|
||||
|
||||
Idempotent: a string-content system message is left untouched. A list
|
||||
that contains anything other than plain text blocks (e.g. an image) is
|
||||
also left untouched — those are rare on system messages and we'd lose
|
||||
the non-text payload by joining.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
ResponseT,
|
||||
)
|
||||
from langchain_core.messages import SystemMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _flatten_text_blocks(content: list[Any]) -> str | None:
|
||||
"""Return joined text if every block is a plain ``{"type": "text"}``.
|
||||
|
||||
Returns ``None`` when the list contains anything that isn't a text
|
||||
block we can safely concatenate (image, audio, file, non-standard
|
||||
blocks, dicts with extra non-cache_control fields). The caller
|
||||
leaves the original content untouched in that case rather than
|
||||
silently dropping payload.
|
||||
|
||||
``cache_control`` on individual blocks is intentionally discarded —
|
||||
the whole point of flattening is to let LiteLLM's
|
||||
``cache_control_injection_points`` re-place a single breakpoint on
|
||||
the resulting one-block system content.
|
||||
"""
|
||||
chunks: list[str] = []
|
||||
for block in content:
|
||||
if isinstance(block, str):
|
||||
chunks.append(block)
|
||||
continue
|
||||
if not isinstance(block, dict):
|
||||
return None
|
||||
if block.get("type") != "text":
|
||||
return None
|
||||
text = block.get("text")
|
||||
if not isinstance(text, str):
|
||||
return None
|
||||
chunks.append(text)
|
||||
return "\n\n".join(chunks)
|
||||
|
||||
|
||||
def _flattened_request(
|
||||
request: ModelRequest[ContextT],
|
||||
) -> ModelRequest[ContextT] | None:
|
||||
"""Return a request with system_message flattened, or ``None`` for no-op."""
|
||||
sys_msg = request.system_message
|
||||
if sys_msg is None:
|
||||
return None
|
||||
content = sys_msg.content
|
||||
if not isinstance(content, list) or len(content) <= 1:
|
||||
return None
|
||||
|
||||
flattened = _flatten_text_blocks(content)
|
||||
if flattened is None:
|
||||
return None
|
||||
|
||||
new_sys = SystemMessage(
|
||||
content=flattened,
|
||||
additional_kwargs=dict(sys_msg.additional_kwargs),
|
||||
response_metadata=dict(sys_msg.response_metadata),
|
||||
)
|
||||
if sys_msg.id is not None:
|
||||
new_sys.id = sys_msg.id
|
||||
return request.override(system_message=new_sys)
|
||||
|
||||
|
||||
def _diagnostic_summary(request: ModelRequest[Any]) -> str:
|
||||
"""One-line dump of cache_control-relevant request shape.
|
||||
|
||||
Temporary diagnostic to prove where the ``Found N`` cache_control
|
||||
breakpoints are coming from when Anthropic 400s. Removed once the
|
||||
root cause is confirmed and a fix is in place.
|
||||
"""
|
||||
sys_msg = request.system_message
|
||||
if sys_msg is None:
|
||||
sys_shape = "none"
|
||||
elif isinstance(sys_msg.content, str):
|
||||
sys_shape = f"str(len={len(sys_msg.content)})"
|
||||
elif isinstance(sys_msg.content, list):
|
||||
sys_shape = f"list(blocks={len(sys_msg.content)})"
|
||||
else:
|
||||
sys_shape = f"other({type(sys_msg.content).__name__})"
|
||||
|
||||
role_hist: list[str] = []
|
||||
multi_block_msgs = 0
|
||||
msgs_with_cc = 0
|
||||
sys_msgs_in_history = 0
|
||||
for m in request.messages:
|
||||
mtype = getattr(m, "type", type(m).__name__)
|
||||
role_hist.append(mtype)
|
||||
if isinstance(m, SystemMessage):
|
||||
sys_msgs_in_history += 1
|
||||
c = getattr(m, "content", None)
|
||||
if isinstance(c, list):
|
||||
multi_block_msgs += 1
|
||||
for blk in c:
|
||||
if isinstance(blk, dict) and "cache_control" in blk:
|
||||
msgs_with_cc += 1
|
||||
break
|
||||
if "cache_control" in getattr(m, "additional_kwargs", {}) or {}:
|
||||
msgs_with_cc += 1
|
||||
|
||||
tools = request.tools or []
|
||||
tools_with_cc = 0
|
||||
for t in tools:
|
||||
if isinstance(t, dict) and (
|
||||
"cache_control" in t or "cache_control" in t.get("function", {})
|
||||
):
|
||||
tools_with_cc += 1
|
||||
|
||||
return (
|
||||
f"sys={sys_shape} msgs={len(request.messages)} "
|
||||
f"sys_msgs_in_history={sys_msgs_in_history} "
|
||||
f"multi_block_msgs={multi_block_msgs} pre_existing_msg_cc={msgs_with_cc} "
|
||||
f"tools={len(tools)} pre_existing_tool_cc={tools_with_cc} "
|
||||
f"roles={role_hist[-8:]}"
|
||||
)
|
||||
|
||||
|
||||
class FlattenSystemMessageMiddleware(
|
||||
AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]
|
||||
):
|
||||
"""Collapse a multi-text-block system message to a single string.
|
||||
|
||||
Sits innermost on the system-message-mutation chain so it observes
|
||||
every middleware's contribution. Has no other side effect — the
|
||||
body of every block is preserved, just joined with ``"\\n\\n"``.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.tools = []
|
||||
|
||||
def wrap_model_call( # type: ignore[override]
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
|
||||
) -> Any:
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
logger.debug("[flatten_system_diag] %s", _diagnostic_summary(request))
|
||||
flattened = _flattened_request(request)
|
||||
if flattened is not None:
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
logger.debug(
|
||||
"[flatten_system] collapsed %d system blocks to one",
|
||||
len(request.system_message.content), # type: ignore[arg-type, union-attr]
|
||||
)
|
||||
return handler(flattened)
|
||||
return handler(request)
|
||||
|
||||
async def awrap_model_call( # type: ignore[override]
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[
|
||||
[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
|
||||
],
|
||||
) -> Any:
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
logger.debug("[flatten_system_diag] %s", _diagnostic_summary(request))
|
||||
flattened = _flattened_request(request)
|
||||
if flattened is not None:
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
logger.debug(
|
||||
"[flatten_system] collapsed %d system blocks to one",
|
||||
len(request.system_message.content), # type: ignore[arg-type, union-attr]
|
||||
)
|
||||
return await handler(flattened)
|
||||
return await handler(request)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"FlattenSystemMessageMiddleware",
|
||||
"_flatten_text_blocks",
|
||||
"_flattened_request",
|
||||
]
|
||||
1546
surfsense_backend/app/agents/shared/middleware/kb_persistence.py
Normal file
1546
surfsense_backend/app/agents/shared/middleware/kb_persistence.py
Normal file
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
1057
surfsense_backend/app/agents/shared/middleware/knowledge_search.py
Normal file
1057
surfsense_backend/app/agents/shared/middleware/knowledge_search.py
Normal file
File diff suppressed because it is too large
Load diff
332
surfsense_backend/app/agents/shared/middleware/knowledge_tree.py
Normal file
332
surfsense_backend/app/agents/shared/middleware/knowledge_tree.py
Normal file
|
|
@ -0,0 +1,332 @@
|
|||
"""Workspace-tree middleware for the SurfSense agent.
|
||||
|
||||
Renders the full ``Folder``+``Document`` tree under ``/documents/`` once per
|
||||
turn (cloud only), caches it by ``(search_space_id, tree_version)``, and
|
||||
injects the result as a ``<workspace_tree>`` system message immediately
|
||||
before the latest human turn.
|
||||
|
||||
The render is bounded by two truncation layers:
|
||||
|
||||
1. **Entry cap** — at most ``MAX_TREE_ENTRIES`` lines. The remainder is
|
||||
replaced with a "use ls" hint.
|
||||
2. **Token cap** — at most ``MAX_TREE_TOKENS`` tokens (using the LLM's
|
||||
token-count profile when available). If the entry-truncated tree still
|
||||
exceeds the token cap we fall back to a root-only summary.
|
||||
|
||||
Anonymous mode renders only ``state['kb_anon_doc']`` (no DB calls).
|
||||
|
||||
This middleware also performs a one-time initialization of ``state['cwd']``
|
||||
to ``"/documents"`` so subsequent middlewares and tools always see a valid
|
||||
cwd in cloud mode.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import SystemMessage
|
||||
from langgraph.runtime import Runtime
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.agents.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.shared.filesystem_state import SurfSenseFilesystemState
|
||||
from app.agents.shared.path_resolver import (
|
||||
DOCUMENTS_ROOT,
|
||||
PathIndex,
|
||||
build_path_index,
|
||||
doc_to_virtual_path,
|
||||
)
|
||||
from app.db import Document, shielded_async_session
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
_perf_log = get_perf_logger()
|
||||
|
||||
try:
|
||||
from litellm import token_counter
|
||||
except Exception: # pragma: no cover - optional dep
|
||||
token_counter = None # type: ignore[assignment]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
MAX_TREE_ENTRIES = 500
|
||||
MAX_TREE_TOKENS = 4000
|
||||
|
||||
|
||||
def _approx_tokens(text: str) -> int:
|
||||
"""Cheap fallback token estimate (1 token ~= 4 chars)."""
|
||||
return max(1, (len(text) + 3) // 4)
|
||||
|
||||
|
||||
def _count_tokens(text: str, *, llm: BaseChatModel | None) -> int:
|
||||
if llm is None:
|
||||
return _approx_tokens(text)
|
||||
count_fn = getattr(llm, "_count_tokens", None)
|
||||
if callable(count_fn):
|
||||
try:
|
||||
return int(count_fn([{"role": "user", "content": text}]))
|
||||
except Exception:
|
||||
pass
|
||||
profile = getattr(llm, "profile", None)
|
||||
model_names: list[str] = []
|
||||
if isinstance(profile, dict):
|
||||
tcms = profile.get("token_count_models")
|
||||
if isinstance(tcms, list):
|
||||
model_names.extend(name for name in tcms if isinstance(name, str) and name)
|
||||
tcm = profile.get("token_count_model")
|
||||
if isinstance(tcm, str) and tcm and tcm not in model_names:
|
||||
model_names.append(tcm)
|
||||
model_name = model_names[0] if model_names else getattr(llm, "model", None)
|
||||
if not isinstance(model_name, str) or not model_name or token_counter is None:
|
||||
return _approx_tokens(text)
|
||||
try:
|
||||
return int(
|
||||
token_counter(
|
||||
messages=[{"role": "user", "content": text}],
|
||||
model=model_name,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
return _approx_tokens(text)
|
||||
|
||||
|
||||
class KnowledgeTreeMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||
"""Inject the workspace folder/document tree into the agent's context."""
|
||||
|
||||
tools = ()
|
||||
state_schema = SurfSenseFilesystemState
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
search_space_id: int,
|
||||
filesystem_mode: FilesystemMode,
|
||||
llm: BaseChatModel | None = None,
|
||||
max_entries: int = MAX_TREE_ENTRIES,
|
||||
max_tokens: int = MAX_TREE_TOKENS,
|
||||
inject_system_message: bool = True, # For backwards compatibility
|
||||
) -> None:
|
||||
self.search_space_id = search_space_id
|
||||
self.filesystem_mode = filesystem_mode
|
||||
self.llm = llm
|
||||
self.max_entries = max_entries
|
||||
self.max_tokens = max_tokens
|
||||
self.inject_system_message = inject_system_message
|
||||
self._cache: dict[tuple[int, int, bool], str] = {}
|
||||
|
||||
async def abefore_agent( # type: ignore[override]
|
||||
self,
|
||||
state: AgentState,
|
||||
runtime: Runtime[Any],
|
||||
) -> dict[str, Any] | None:
|
||||
del runtime
|
||||
if self.filesystem_mode != FilesystemMode.CLOUD:
|
||||
return None
|
||||
|
||||
start = time.perf_counter()
|
||||
update: dict[str, Any] = {}
|
||||
if not state.get("cwd"):
|
||||
update["cwd"] = DOCUMENTS_ROOT
|
||||
|
||||
anon_doc = state.get("kb_anon_doc")
|
||||
if anon_doc:
|
||||
tree_msg = self._render_anon_tree(anon_doc)
|
||||
cache_outcome = "anon"
|
||||
else:
|
||||
version = int(state.get("tree_version") or 0)
|
||||
cache_key = (self.search_space_id, version, False)
|
||||
cache_outcome = "hit" if cache_key in self._cache else "miss"
|
||||
tree_msg = await self._render_kb_tree(state)
|
||||
|
||||
update["workspace_tree_text"] = tree_msg
|
||||
|
||||
if self.inject_system_message:
|
||||
messages = list(state.get("messages") or [])
|
||||
insert_at = max(len(messages) - 1, 0)
|
||||
messages.insert(insert_at, SystemMessage(content=tree_msg))
|
||||
update["messages"] = messages
|
||||
|
||||
_perf_log.info(
|
||||
"[knowledge_tree] cache=%s chars=%d elapsed=%.3fs space=%d",
|
||||
cache_outcome,
|
||||
len(tree_msg),
|
||||
time.perf_counter() - start,
|
||||
self.search_space_id,
|
||||
)
|
||||
return update
|
||||
|
||||
def before_agent( # type: ignore[override]
|
||||
self,
|
||||
state: AgentState,
|
||||
runtime: Runtime[Any],
|
||||
) -> dict[str, Any] | None:
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
if loop.is_running():
|
||||
return None
|
||||
except RuntimeError:
|
||||
pass
|
||||
return asyncio.run(self.abefore_agent(state, runtime))
|
||||
|
||||
# ------------------------------------------------------------------ render
|
||||
|
||||
def _render_anon_tree(self, anon_doc: dict[str, Any]) -> str:
|
||||
path = str(anon_doc.get("path") or "")
|
||||
title = str(anon_doc.get("title") or "uploaded_document")
|
||||
return (
|
||||
"<workspace_tree>\n"
|
||||
"Anonymous session — only one read-only document is available.\n"
|
||||
f"{DOCUMENTS_ROOT}/\n"
|
||||
f" {path} — {title}\n"
|
||||
"</workspace_tree>"
|
||||
)
|
||||
|
||||
async def _render_kb_tree(self, state: AgentState) -> str:
|
||||
version = int(state.get("tree_version") or 0)
|
||||
cache_key = (self.search_space_id, version, False)
|
||||
cached = self._cache.get(cache_key)
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
try:
|
||||
async with shielded_async_session() as session:
|
||||
index = await build_path_index(session, self.search_space_id)
|
||||
doc_rows = await session.execute(
|
||||
select(Document.id, Document.title, Document.folder_id).where(
|
||||
Document.search_space_id == self.search_space_id
|
||||
)
|
||||
)
|
||||
docs = list(doc_rows.all())
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logger.warning("knowledge_tree: DB error %s", exc)
|
||||
return "<workspace_tree>\n(unavailable)\n</workspace_tree>"
|
||||
|
||||
rendered = self._format_tree(index, docs)
|
||||
self._cache[cache_key] = rendered
|
||||
return rendered
|
||||
|
||||
def _format_tree(self, index: PathIndex, docs: list[Any]) -> str:
|
||||
folder_paths = sorted(set(index.folder_paths.values()))
|
||||
doc_paths = sorted(
|
||||
doc_to_virtual_path(
|
||||
doc_id=row.id,
|
||||
title=str(row.title or "untitled"),
|
||||
folder_id=row.folder_id,
|
||||
index=index,
|
||||
)
|
||||
for row in docs
|
||||
)
|
||||
all_paths = sorted(set(folder_paths + doc_paths + [DOCUMENTS_ROOT]))
|
||||
|
||||
# Pre-compute which folders have at least one descendant (folder or doc).
|
||||
# A folder is "empty" iff no path in `all_paths` is strictly under it.
|
||||
# Used to emit an explicit "(empty)" marker so the LLM doesn't have to
|
||||
# infer emptiness from indentation alone.
|
||||
non_empty_folders = self._compute_non_empty_folders(folder_paths, doc_paths)
|
||||
|
||||
lines: list[str] = []
|
||||
for path in all_paths:
|
||||
depth = (
|
||||
0
|
||||
if path == DOCUMENTS_ROOT
|
||||
else len([p for p in path[len(DOCUMENTS_ROOT) :].split("/") if p])
|
||||
)
|
||||
indent = " " * depth
|
||||
is_dir = path == DOCUMENTS_ROOT or path in folder_paths
|
||||
display = (
|
||||
path.rsplit("/", 1)[-1] if path != DOCUMENTS_ROOT else "/documents"
|
||||
)
|
||||
if is_dir:
|
||||
if path != DOCUMENTS_ROOT and path not in non_empty_folders:
|
||||
lines.append(f"{indent}{display}/ (empty)")
|
||||
else:
|
||||
lines.append(f"{indent}{display}/")
|
||||
else:
|
||||
lines.append(f"{indent}{display}")
|
||||
if len(lines) >= self.max_entries:
|
||||
remaining = len(all_paths) - len(lines)
|
||||
if remaining > 0:
|
||||
lines.append(
|
||||
f"... {remaining} more entries — use "
|
||||
"ls('/documents/<folder>', offset, limit) to expand"
|
||||
)
|
||||
break
|
||||
|
||||
body = "\n".join(lines)
|
||||
rendered = f"<workspace_tree>\n{body}\n</workspace_tree>"
|
||||
|
||||
token_count = _count_tokens(rendered, llm=self.llm)
|
||||
if token_count <= self.max_tokens:
|
||||
return rendered
|
||||
|
||||
return self._format_root_summary(folder_paths, doc_paths)
|
||||
|
||||
@staticmethod
|
||||
def _compute_non_empty_folders(
|
||||
folder_paths: list[str], doc_paths: list[str]
|
||||
) -> set[str]:
|
||||
"""Return the set of folder paths that contain at least one descendant.
|
||||
|
||||
A folder is "non-empty" if any document path or any other folder path
|
||||
is strictly under it. Documents propagate emptiness up to every
|
||||
ancestor folder, while a sub-folder only marks its direct ancestors
|
||||
non-empty (so a chain of empty folders all read ``(empty)``).
|
||||
"""
|
||||
non_empty: set[str] = set()
|
||||
folder_set = set(folder_paths)
|
||||
|
||||
for doc_path in doc_paths:
|
||||
parent = doc_path.rsplit("/", 1)[0]
|
||||
while parent and parent != DOCUMENTS_ROOT:
|
||||
if parent in folder_set:
|
||||
non_empty.add(parent)
|
||||
parent = parent.rsplit("/", 1)[0]
|
||||
|
||||
for child in folder_paths:
|
||||
parent = child.rsplit("/", 1)[0]
|
||||
while parent and parent != DOCUMENTS_ROOT and parent in folder_set:
|
||||
non_empty.add(parent)
|
||||
parent = parent.rsplit("/", 1)[0]
|
||||
|
||||
return non_empty
|
||||
|
||||
def _format_root_summary(
|
||||
self, folder_paths: list[str], doc_paths: list[str]
|
||||
) -> str:
|
||||
top_level: dict[str, int] = {}
|
||||
loose_docs = 0
|
||||
for path in doc_paths:
|
||||
rel = path[len(DOCUMENTS_ROOT) :].lstrip("/")
|
||||
if "/" in rel:
|
||||
top = rel.split("/", 1)[0]
|
||||
top_level[top] = top_level.get(top, 0) + 1
|
||||
else:
|
||||
loose_docs += 1
|
||||
for path in folder_paths:
|
||||
rel = path[len(DOCUMENTS_ROOT) :].lstrip("/")
|
||||
if not rel:
|
||||
continue
|
||||
top = rel.split("/", 1)[0]
|
||||
top_level.setdefault(top, 0)
|
||||
|
||||
lines = [DOCUMENTS_ROOT + "/"]
|
||||
for name in sorted(top_level):
|
||||
count = top_level[name]
|
||||
lines.append(f" {name}/ ({count} document{'s' if count != 1 else ''})")
|
||||
if loose_docs:
|
||||
lines.append(
|
||||
f" ({loose_docs} loose document{'s' if loose_docs != 1 else ''})"
|
||||
)
|
||||
lines.append(
|
||||
"Tree is large; use list_tree('/documents/<folder>') to drill in "
|
||||
"or ls('/documents/<folder>', offset, limit) for paginated listings."
|
||||
)
|
||||
return "<workspace_tree>\n" + "\n".join(lines) + "\n</workspace_tree>"
|
||||
|
||||
|
||||
__all__ = ["KnowledgeTreeMiddleware"]
|
||||
|
|
@ -0,0 +1,613 @@
|
|||
"""Desktop local-folder filesystem backend for deepagents tools."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import fnmatch
|
||||
import os
|
||||
import threading
|
||||
from collections import deque
|
||||
from contextlib import ExitStack
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from deepagents.backends.protocol import (
|
||||
EditResult,
|
||||
FileDownloadResponse,
|
||||
FileInfo,
|
||||
FileUploadResponse,
|
||||
GrepMatch,
|
||||
WriteResult,
|
||||
)
|
||||
from deepagents.backends.utils import (
|
||||
create_file_data,
|
||||
format_read_response,
|
||||
perform_string_replacement,
|
||||
)
|
||||
|
||||
_INVALID_PATH = "invalid_path"
|
||||
_FILE_NOT_FOUND = "file_not_found"
|
||||
_IS_DIRECTORY = "is_directory"
|
||||
|
||||
|
||||
class LocalFolderBackend:
|
||||
"""Filesystem backend rooted to a single local folder."""
|
||||
|
||||
def __init__(self, root_path: str) -> None:
|
||||
root = Path(root_path).expanduser().resolve()
|
||||
if not root.exists() or not root.is_dir():
|
||||
msg = f"Local filesystem root does not exist or is not a directory: {root_path}"
|
||||
raise ValueError(msg)
|
||||
self._root = root
|
||||
self._locks: dict[str, threading.Lock] = {}
|
||||
self._locks_mu = threading.Lock()
|
||||
|
||||
def _lock_for(self, path: str) -> threading.Lock:
|
||||
with self._locks_mu:
|
||||
if path not in self._locks:
|
||||
self._locks[path] = threading.Lock()
|
||||
return self._locks[path]
|
||||
|
||||
def _resolve_virtual(self, virtual_path: str, *, allow_root: bool = False) -> Path:
|
||||
if not virtual_path.startswith("/"):
|
||||
msg = f"Invalid path (must be absolute): {virtual_path}"
|
||||
raise ValueError(msg)
|
||||
rel = virtual_path.lstrip("/")
|
||||
candidate = self._root if rel == "" else (self._root / rel)
|
||||
resolved = candidate.resolve()
|
||||
if not allow_root and resolved == self._root:
|
||||
msg = "Path must refer to a file or child directory under root"
|
||||
raise ValueError(msg)
|
||||
if not resolved.is_relative_to(self._root):
|
||||
msg = f"Path escapes local filesystem root: {virtual_path}"
|
||||
raise ValueError(msg)
|
||||
return resolved
|
||||
|
||||
@staticmethod
|
||||
def _to_virtual(path: Path, root: Path) -> str:
|
||||
rel = path.relative_to(root).as_posix()
|
||||
return "/" if rel == "." else f"/{rel}"
|
||||
|
||||
def _write_text_atomic(self, path: Path, content: str) -> None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
temp_path = path.with_suffix(f"{path.suffix}.tmp")
|
||||
temp_path.write_text(content, encoding="utf-8")
|
||||
os.replace(temp_path, path)
|
||||
|
||||
def _acquire_path_locks(self, *paths: str) -> ExitStack:
|
||||
ordered_paths = sorted(set(paths))
|
||||
stack = ExitStack()
|
||||
for path in ordered_paths:
|
||||
stack.enter_context(self._lock_for(path))
|
||||
return stack
|
||||
|
||||
@staticmethod
|
||||
def _clamp_page_size(page_size: int) -> int:
|
||||
return max(1, min(page_size, 1000))
|
||||
|
||||
def _read_dir_entries(self, directory_path: str) -> list[dict[str, Any]]:
|
||||
directory = Path(directory_path)
|
||||
try:
|
||||
children = sorted(
|
||||
directory.iterdir(),
|
||||
key=lambda p: (not p.is_dir(), p.name.lower()),
|
||||
)
|
||||
except OSError:
|
||||
return []
|
||||
|
||||
entries: list[dict[str, Any]] = []
|
||||
for child in children:
|
||||
try:
|
||||
stat_result = child.stat()
|
||||
except OSError:
|
||||
continue
|
||||
entries.append(
|
||||
{
|
||||
"path": self._to_virtual(child, self._root),
|
||||
"is_dir": child.is_dir(),
|
||||
"size": stat_result.st_size if child.is_file() else 0,
|
||||
"modified_at": str(stat_result.st_mtime),
|
||||
"absolute_path": str(child),
|
||||
}
|
||||
)
|
||||
return entries
|
||||
|
||||
def ls_info(self, path: str) -> list[FileInfo]:
|
||||
try:
|
||||
target = self._resolve_virtual(path, allow_root=True)
|
||||
except ValueError:
|
||||
return []
|
||||
if not target.exists() or not target.is_dir():
|
||||
return []
|
||||
infos: list[FileInfo] = []
|
||||
for child in sorted(
|
||||
target.iterdir(), key=lambda p: (not p.is_dir(), p.name.lower())
|
||||
):
|
||||
infos.append(
|
||||
FileInfo(
|
||||
path=self._to_virtual(child, self._root),
|
||||
is_dir=child.is_dir(),
|
||||
size=child.stat().st_size if child.is_file() else 0,
|
||||
modified_at=str(child.stat().st_mtime),
|
||||
)
|
||||
)
|
||||
return infos
|
||||
|
||||
async def als_info(self, path: str) -> list[FileInfo]:
|
||||
return await asyncio.to_thread(self.ls_info, path)
|
||||
|
||||
def read(self, file_path: str, offset: int = 0, limit: int = 2000) -> str:
|
||||
try:
|
||||
path = self._resolve_virtual(file_path)
|
||||
except ValueError:
|
||||
return f"Error: Invalid path '{file_path}'"
|
||||
if not path.exists():
|
||||
return f"Error: File '{file_path}' not found"
|
||||
if not path.is_file():
|
||||
return f"Error: Path '{file_path}' is not a file"
|
||||
content = path.read_text(encoding="utf-8", errors="replace")
|
||||
file_data = create_file_data(content)
|
||||
return format_read_response(file_data, offset, limit)
|
||||
|
||||
async def aread(self, file_path: str, offset: int = 0, limit: int = 2000) -> str:
|
||||
return await asyncio.to_thread(self.read, file_path, offset, limit)
|
||||
|
||||
def read_raw(self, file_path: str) -> str:
|
||||
"""Read raw file text without line-number formatting."""
|
||||
try:
|
||||
path = self._resolve_virtual(file_path)
|
||||
except ValueError:
|
||||
return f"Error: Invalid path '{file_path}'"
|
||||
if not path.exists():
|
||||
return f"Error: File '{file_path}' not found"
|
||||
if not path.is_file():
|
||||
return f"Error: Path '{file_path}' is not a file"
|
||||
return path.read_text(encoding="utf-8", errors="replace")
|
||||
|
||||
async def aread_raw(self, file_path: str) -> str:
|
||||
"""Async variant of read_raw."""
|
||||
return await asyncio.to_thread(self.read_raw, file_path)
|
||||
|
||||
def write(self, file_path: str, content: str) -> WriteResult:
|
||||
try:
|
||||
path = self._resolve_virtual(file_path)
|
||||
except ValueError:
|
||||
return WriteResult(error=f"Error: Invalid path '{file_path}'")
|
||||
lock = self._lock_for(file_path)
|
||||
with lock:
|
||||
if path.exists():
|
||||
return WriteResult(
|
||||
error=(
|
||||
f"Cannot write to {file_path} because it already exists. "
|
||||
"Read and then make an edit, or write to a new path."
|
||||
)
|
||||
)
|
||||
parent = path.parent
|
||||
if not parent.exists() or not parent.is_dir():
|
||||
return WriteResult(
|
||||
error=(
|
||||
f"Error: parent directory for '{file_path}' does not exist. "
|
||||
"Create the folder first or write to an existing directory."
|
||||
)
|
||||
)
|
||||
self._write_text_atomic(path, content)
|
||||
return WriteResult(path=file_path, files_update=None)
|
||||
|
||||
async def awrite(self, file_path: str, content: str) -> WriteResult:
|
||||
return await asyncio.to_thread(self.write, file_path, content)
|
||||
|
||||
def list_tree(
|
||||
self,
|
||||
path: str = "/",
|
||||
*,
|
||||
max_depth: int | None = 8,
|
||||
page_size: int = 500,
|
||||
include_files: bool = True,
|
||||
include_dirs: bool = True,
|
||||
) -> dict[str, Any]:
|
||||
if not include_files and not include_dirs:
|
||||
return {
|
||||
"entries": [],
|
||||
"truncated": False,
|
||||
}
|
||||
|
||||
normalized_depth = None if max_depth is None else max(0, int(max_depth))
|
||||
page_limit = self._clamp_page_size(int(page_size))
|
||||
try:
|
||||
start = self._resolve_virtual(path, allow_root=True)
|
||||
except ValueError:
|
||||
return {"error": f"Error: invalid path '{path}'"}
|
||||
if not start.exists():
|
||||
return {"error": f"Error: path '{path}' not found"}
|
||||
if start.is_file():
|
||||
stat_result = start.stat()
|
||||
if include_files:
|
||||
return {
|
||||
"entries": [
|
||||
{
|
||||
"path": self._to_virtual(start, self._root),
|
||||
"is_dir": False,
|
||||
"size": stat_result.st_size,
|
||||
"modified_at": str(stat_result.st_mtime),
|
||||
"depth": 0,
|
||||
}
|
||||
],
|
||||
"truncated": False,
|
||||
}
|
||||
return {
|
||||
"entries": [],
|
||||
"truncated": False,
|
||||
}
|
||||
|
||||
pending_dirs: deque[tuple[str, int]] = deque([(str(start), 0)])
|
||||
entries: list[dict[str, Any]] = []
|
||||
truncated = False
|
||||
while pending_dirs and not truncated:
|
||||
next_dir_path, next_depth = pending_dirs.popleft()
|
||||
active_entries = self._read_dir_entries(next_dir_path)
|
||||
for item in active_entries:
|
||||
item_depth = next_depth + 1
|
||||
if normalized_depth is not None and item_depth > normalized_depth:
|
||||
continue
|
||||
if item["is_dir"]:
|
||||
if normalized_depth is None or item_depth <= normalized_depth:
|
||||
pending_dirs.append((item["absolute_path"], item_depth))
|
||||
if include_dirs:
|
||||
entries.append(
|
||||
{
|
||||
"path": item["path"],
|
||||
"is_dir": True,
|
||||
"size": 0,
|
||||
"modified_at": item["modified_at"],
|
||||
"depth": item_depth,
|
||||
}
|
||||
)
|
||||
elif include_files:
|
||||
entries.append(
|
||||
{
|
||||
"path": item["path"],
|
||||
"is_dir": False,
|
||||
"size": item["size"],
|
||||
"modified_at": item["modified_at"],
|
||||
"depth": item_depth,
|
||||
}
|
||||
)
|
||||
if len(entries) >= page_limit:
|
||||
truncated = True
|
||||
break
|
||||
|
||||
return {
|
||||
"entries": entries,
|
||||
"truncated": truncated,
|
||||
}
|
||||
|
||||
async def alist_tree(
|
||||
self,
|
||||
path: str = "/",
|
||||
*,
|
||||
max_depth: int | None = 8,
|
||||
page_size: int = 500,
|
||||
include_files: bool = True,
|
||||
include_dirs: bool = True,
|
||||
) -> dict[str, Any]:
|
||||
return await asyncio.to_thread(
|
||||
self.list_tree,
|
||||
path,
|
||||
max_depth=max_depth,
|
||||
page_size=page_size,
|
||||
include_files=include_files,
|
||||
include_dirs=include_dirs,
|
||||
)
|
||||
|
||||
def move(
|
||||
self,
|
||||
source_path: str,
|
||||
destination_path: str,
|
||||
overwrite: bool = False,
|
||||
) -> WriteResult:
|
||||
try:
|
||||
source = self._resolve_virtual(source_path)
|
||||
destination = self._resolve_virtual(destination_path)
|
||||
except ValueError:
|
||||
return WriteResult(
|
||||
error=(
|
||||
f"Error: invalid source '{source_path}' or destination "
|
||||
f"'{destination_path}' path"
|
||||
)
|
||||
)
|
||||
if source == destination:
|
||||
return WriteResult(error="Error: source and destination paths are the same")
|
||||
with self._acquire_path_locks(source_path, destination_path):
|
||||
if not source.exists():
|
||||
return WriteResult(
|
||||
error=f"Error: source path '{source_path}' not found"
|
||||
)
|
||||
if destination.exists():
|
||||
if not overwrite:
|
||||
return WriteResult(
|
||||
error=(
|
||||
f"Error: destination path '{destination_path}' already exists. "
|
||||
"Set overwrite=True to replace files."
|
||||
)
|
||||
)
|
||||
if source.is_dir() or destination.is_dir():
|
||||
return WriteResult(
|
||||
error=(
|
||||
"Error: overwrite=True is only supported for file-to-file moves."
|
||||
)
|
||||
)
|
||||
destination.parent.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
if overwrite:
|
||||
os.replace(source, destination)
|
||||
else:
|
||||
source.rename(destination)
|
||||
except OSError as exc:
|
||||
return WriteResult(
|
||||
error=f"Error: failed to move '{source_path}': {exc}"
|
||||
)
|
||||
return WriteResult(
|
||||
path=self._to_virtual(destination, self._root), files_update=None
|
||||
)
|
||||
|
||||
async def amove(
|
||||
self,
|
||||
source_path: str,
|
||||
destination_path: str,
|
||||
overwrite: bool = False,
|
||||
) -> WriteResult:
|
||||
return await asyncio.to_thread(
|
||||
self.move, source_path, destination_path, overwrite
|
||||
)
|
||||
|
||||
def delete_file(self, file_path: str) -> WriteResult:
|
||||
"""Hard-delete a single file under root.
|
||||
|
||||
Refuses directories, root, and missing paths. Roughly mirrors POSIX
|
||||
``rm path``; ``-r`` recursion and glob expansion are explicitly
|
||||
out of scope.
|
||||
"""
|
||||
try:
|
||||
path = self._resolve_virtual(file_path)
|
||||
except ValueError:
|
||||
return WriteResult(error=f"Error: Invalid path '{file_path}'")
|
||||
with self._lock_for(file_path):
|
||||
if not path.exists():
|
||||
return WriteResult(error=f"Error: File '{file_path}' not found")
|
||||
if path.is_dir():
|
||||
return WriteResult(
|
||||
error=(
|
||||
f"Error: '{file_path}' is a directory. "
|
||||
"Use rmdir for empty directories."
|
||||
)
|
||||
)
|
||||
try:
|
||||
os.unlink(path)
|
||||
except OSError as exc:
|
||||
return WriteResult(
|
||||
error=f"Error: failed to delete '{file_path}': {exc}"
|
||||
)
|
||||
return WriteResult(path=file_path, files_update=None)
|
||||
|
||||
async def adelete_file(self, file_path: str) -> WriteResult:
|
||||
return await asyncio.to_thread(self.delete_file, file_path)
|
||||
|
||||
def rmdir(self, dir_path: str) -> WriteResult:
|
||||
"""Hard-delete an empty directory under root.
|
||||
|
||||
Refuses files, root, missing paths, and non-empty directories.
|
||||
``os.rmdir`` is naturally empty-only; we pre-check so the error is
|
||||
clearer for the agent.
|
||||
"""
|
||||
try:
|
||||
path = self._resolve_virtual(dir_path)
|
||||
except ValueError:
|
||||
return WriteResult(error=f"Error: Invalid path '{dir_path}'")
|
||||
with self._lock_for(dir_path):
|
||||
if not path.exists():
|
||||
return WriteResult(error=f"Error: Directory '{dir_path}' not found")
|
||||
if not path.is_dir():
|
||||
return WriteResult(error=f"Error: '{dir_path}' is not a directory")
|
||||
try:
|
||||
next(path.iterdir())
|
||||
except StopIteration:
|
||||
pass
|
||||
else:
|
||||
return WriteResult(
|
||||
error=(
|
||||
f"Error: directory '{dir_path}' is not empty. "
|
||||
"Remove its contents first."
|
||||
)
|
||||
)
|
||||
try:
|
||||
os.rmdir(path)
|
||||
except OSError as exc:
|
||||
return WriteResult(error=f"Error: failed to rmdir '{dir_path}': {exc}")
|
||||
return WriteResult(path=dir_path, files_update=None)
|
||||
|
||||
async def armdir(self, dir_path: str) -> WriteResult:
|
||||
return await asyncio.to_thread(self.rmdir, dir_path)
|
||||
|
||||
def edit(
|
||||
self,
|
||||
file_path: str,
|
||||
old_string: str,
|
||||
new_string: str,
|
||||
replace_all: bool = False,
|
||||
) -> EditResult:
|
||||
try:
|
||||
path = self._resolve_virtual(file_path)
|
||||
except ValueError:
|
||||
return EditResult(error=f"Error: Invalid path '{file_path}'")
|
||||
lock = self._lock_for(file_path)
|
||||
with lock:
|
||||
if not path.exists() or not path.is_file():
|
||||
return EditResult(error=f"Error: File '{file_path}' not found")
|
||||
content = path.read_text(encoding="utf-8", errors="replace")
|
||||
result = perform_string_replacement(
|
||||
content, old_string, new_string, replace_all
|
||||
)
|
||||
if isinstance(result, str):
|
||||
return EditResult(error=result)
|
||||
updated_content, occurrences = result
|
||||
self._write_text_atomic(path, updated_content)
|
||||
return EditResult(
|
||||
path=file_path, files_update=None, occurrences=int(occurrences)
|
||||
)
|
||||
|
||||
async def aedit(
|
||||
self,
|
||||
file_path: str,
|
||||
old_string: str,
|
||||
new_string: str,
|
||||
replace_all: bool = False,
|
||||
) -> EditResult:
|
||||
return await asyncio.to_thread(
|
||||
self.edit, file_path, old_string, new_string, replace_all
|
||||
)
|
||||
|
||||
def glob_info(self, pattern: str, path: str = "/") -> list[FileInfo]:
|
||||
try:
|
||||
base = self._resolve_virtual(path, allow_root=True)
|
||||
except ValueError:
|
||||
return []
|
||||
|
||||
if pattern.startswith("/"):
|
||||
search_base = self._root
|
||||
normalized_pattern = pattern.lstrip("/")
|
||||
else:
|
||||
search_base = base
|
||||
normalized_pattern = pattern
|
||||
|
||||
matches: list[FileInfo] = []
|
||||
for hit in search_base.glob(normalized_pattern):
|
||||
try:
|
||||
resolved = hit.resolve()
|
||||
if not resolved.is_relative_to(self._root):
|
||||
continue
|
||||
except Exception:
|
||||
continue
|
||||
matches.append(
|
||||
FileInfo(
|
||||
path=self._to_virtual(resolved, self._root),
|
||||
is_dir=resolved.is_dir(),
|
||||
size=resolved.stat().st_size if resolved.is_file() else 0,
|
||||
modified_at=str(resolved.stat().st_mtime),
|
||||
)
|
||||
)
|
||||
return matches
|
||||
|
||||
async def aglob_info(self, pattern: str, path: str = "/") -> list[FileInfo]:
|
||||
return await asyncio.to_thread(self.glob_info, pattern, path)
|
||||
|
||||
def _iter_candidate_files(self, path: str | None, glob: str | None) -> list[Path]:
|
||||
base_virtual = path or "/"
|
||||
try:
|
||||
base = self._resolve_virtual(base_virtual, allow_root=True)
|
||||
except ValueError:
|
||||
return []
|
||||
if not base.exists():
|
||||
return []
|
||||
|
||||
candidates = [p for p in base.rglob("*") if p.is_file()]
|
||||
if glob:
|
||||
candidates = [
|
||||
p
|
||||
for p in candidates
|
||||
if fnmatch.fnmatch(self._to_virtual(p, self._root), glob)
|
||||
or fnmatch.fnmatch(p.name, glob)
|
||||
]
|
||||
return candidates
|
||||
|
||||
def grep_raw(
|
||||
self, pattern: str, path: str | None = None, glob: str | None = None
|
||||
) -> list[GrepMatch] | str:
|
||||
if not pattern:
|
||||
return "Error: pattern cannot be empty"
|
||||
matches: list[GrepMatch] = []
|
||||
for file_path in self._iter_candidate_files(path, glob):
|
||||
try:
|
||||
lines = file_path.read_text(
|
||||
encoding="utf-8", errors="replace"
|
||||
).splitlines()
|
||||
except Exception:
|
||||
continue
|
||||
for idx, line in enumerate(lines, start=1):
|
||||
if pattern in line:
|
||||
matches.append(
|
||||
GrepMatch(
|
||||
path=self._to_virtual(file_path, self._root),
|
||||
line=idx,
|
||||
text=line,
|
||||
)
|
||||
)
|
||||
return matches
|
||||
|
||||
async def agrep_raw(
|
||||
self, pattern: str, path: str | None = None, glob: str | None = None
|
||||
) -> list[GrepMatch] | str:
|
||||
return await asyncio.to_thread(self.grep_raw, pattern, path, glob)
|
||||
|
||||
def upload_files(self, files: list[tuple[str, bytes]]) -> list[FileUploadResponse]:
|
||||
responses: list[FileUploadResponse] = []
|
||||
for virtual_path, content in files:
|
||||
try:
|
||||
target = self._resolve_virtual(virtual_path)
|
||||
target.parent.mkdir(parents=True, exist_ok=True)
|
||||
temp_path = target.with_suffix(f"{target.suffix}.tmp")
|
||||
temp_path.write_bytes(content)
|
||||
os.replace(temp_path, target)
|
||||
responses.append(FileUploadResponse(path=virtual_path, error=None))
|
||||
except FileNotFoundError:
|
||||
responses.append(
|
||||
FileUploadResponse(path=virtual_path, error=_FILE_NOT_FOUND)
|
||||
)
|
||||
except IsADirectoryError:
|
||||
responses.append(
|
||||
FileUploadResponse(path=virtual_path, error=_IS_DIRECTORY)
|
||||
)
|
||||
except Exception:
|
||||
responses.append(
|
||||
FileUploadResponse(path=virtual_path, error=_INVALID_PATH)
|
||||
)
|
||||
return responses
|
||||
|
||||
async def aupload_files(
|
||||
self, files: list[tuple[str, bytes]]
|
||||
) -> list[FileUploadResponse]:
|
||||
return await asyncio.to_thread(self.upload_files, files)
|
||||
|
||||
def download_files(self, paths: list[str]) -> list[FileDownloadResponse]:
|
||||
responses: list[FileDownloadResponse] = []
|
||||
for virtual_path in paths:
|
||||
try:
|
||||
target = self._resolve_virtual(virtual_path)
|
||||
if not target.exists():
|
||||
responses.append(
|
||||
FileDownloadResponse(
|
||||
path=virtual_path, content=None, error=_FILE_NOT_FOUND
|
||||
)
|
||||
)
|
||||
continue
|
||||
if target.is_dir():
|
||||
responses.append(
|
||||
FileDownloadResponse(
|
||||
path=virtual_path, content=None, error=_IS_DIRECTORY
|
||||
)
|
||||
)
|
||||
continue
|
||||
responses.append(
|
||||
FileDownloadResponse(
|
||||
path=virtual_path, content=target.read_bytes(), error=None
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
responses.append(
|
||||
FileDownloadResponse(
|
||||
path=virtual_path, content=None, error=_INVALID_PATH
|
||||
)
|
||||
)
|
||||
return responses
|
||||
|
||||
async def adownload_files(self, paths: list[str]) -> list[FileDownloadResponse]:
|
||||
return await asyncio.to_thread(self.download_files, paths)
|
||||
|
|
@ -0,0 +1,160 @@
|
|||
"""Memory injection middleware for the SurfSense agent.
|
||||
|
||||
Injects memory markdown into the system prompt on every turn:
|
||||
- Private threads: only personal memory (<user_memory>)
|
||||
- Shared threads: only team memory (<team_memory>)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langgraph.runtime import Runtime
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import ChatVisibility, SearchSpace, User, shielded_async_session
|
||||
from app.services.memory import MEMORY_HARD_LIMIT, MEMORY_SOFT_LIMIT
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_perf_log = get_perf_logger()
|
||||
|
||||
|
||||
class MemoryInjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||
"""Injects memory markdown into the conversation on every turn."""
|
||||
|
||||
tools = ()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
user_id: str | UUID | None,
|
||||
search_space_id: int,
|
||||
thread_visibility: ChatVisibility | None = None,
|
||||
) -> None:
|
||||
self.user_id = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||
self.search_space_id = search_space_id
|
||||
self.visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||
|
||||
async def abefore_agent( # type: ignore[override]
|
||||
self,
|
||||
state: AgentState,
|
||||
runtime: Runtime[Any],
|
||||
) -> dict[str, Any] | None:
|
||||
del runtime
|
||||
messages = state.get("messages") or []
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
last_message = messages[-1]
|
||||
if not isinstance(last_message, HumanMessage):
|
||||
return None
|
||||
|
||||
start = time.perf_counter()
|
||||
db_elapsed = 0.0
|
||||
memory_blocks: list[str] = []
|
||||
scope = "team" if self.visibility == ChatVisibility.SEARCH_SPACE else "user"
|
||||
|
||||
async with shielded_async_session() as session:
|
||||
db_start = time.perf_counter()
|
||||
if self.visibility == ChatVisibility.SEARCH_SPACE:
|
||||
team_memory = await self._load_team_memory(session)
|
||||
if team_memory:
|
||||
chars = len(team_memory)
|
||||
memory_blocks.append(
|
||||
f'<team_memory chars="{chars}" limit="{MEMORY_HARD_LIMIT}">\n'
|
||||
f"{team_memory}\n"
|
||||
f"</team_memory>"
|
||||
)
|
||||
if chars > MEMORY_SOFT_LIMIT:
|
||||
memory_blocks.append(
|
||||
f"<memory_warning>Team memory is at "
|
||||
f"{chars:,}/{MEMORY_HARD_LIMIT:,} characters and approaching "
|
||||
f"the hard limit. On your next update_memory call, consolidate "
|
||||
f"by merging duplicates, removing outdated entries, and "
|
||||
f"shortening descriptions before adding anything new."
|
||||
f"</memory_warning>"
|
||||
)
|
||||
elif self.user_id is not None:
|
||||
user_memory, display_name = await self._load_user_memory(session)
|
||||
if display_name and display_name.strip():
|
||||
first_name = display_name.strip().split()[0]
|
||||
memory_blocks.append(f"<user_name>{first_name}</user_name>")
|
||||
if user_memory:
|
||||
chars = len(user_memory)
|
||||
memory_blocks.append(
|
||||
f'<user_memory chars="{chars}" limit="{MEMORY_HARD_LIMIT}">\n'
|
||||
f"{user_memory}\n"
|
||||
f"</user_memory>"
|
||||
)
|
||||
if chars > MEMORY_SOFT_LIMIT:
|
||||
memory_blocks.append(
|
||||
f"<memory_warning>Your personal memory is at "
|
||||
f"{chars:,}/{MEMORY_HARD_LIMIT:,} characters and approaching "
|
||||
f"the hard limit. On your next update_memory call, consolidate "
|
||||
f"by merging duplicates, removing outdated entries, and "
|
||||
f"shortening descriptions before adding anything new."
|
||||
f"</memory_warning>"
|
||||
)
|
||||
|
||||
db_elapsed = time.perf_counter() - db_start
|
||||
|
||||
if not memory_blocks:
|
||||
_perf_log.info(
|
||||
"[memory_injection] scope=%s injected=0 db=%.3fs total=%.3fs",
|
||||
scope,
|
||||
db_elapsed,
|
||||
time.perf_counter() - start,
|
||||
)
|
||||
return None
|
||||
|
||||
memory_text = "\n\n".join(memory_blocks)
|
||||
memory_msg = SystemMessage(content=memory_text)
|
||||
|
||||
new_messages = list(messages)
|
||||
insert_idx = 1 if len(new_messages) > 1 else 0
|
||||
new_messages.insert(insert_idx, memory_msg)
|
||||
|
||||
_perf_log.info(
|
||||
"[memory_injection] scope=%s injected=1 chars=%d db=%.3fs total=%.3fs",
|
||||
scope,
|
||||
len(memory_text),
|
||||
db_elapsed,
|
||||
time.perf_counter() - start,
|
||||
)
|
||||
return {"messages": new_messages}
|
||||
|
||||
async def _load_user_memory(
|
||||
self, session: AsyncSession
|
||||
) -> tuple[str | None, str | None]:
|
||||
"""Return (memory_content, display_name)."""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(User.memory_md, User.display_name).where(User.id == self.user_id)
|
||||
)
|
||||
row = result.one_or_none()
|
||||
if row is None:
|
||||
return None, None
|
||||
return row.memory_md or None, row.display_name
|
||||
except Exception:
|
||||
logger.exception("Failed to load user memory")
|
||||
return None, None
|
||||
|
||||
async def _load_team_memory(self, session: AsyncSession) -> str | None:
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(SearchSpace.shared_memory_md).where(
|
||||
SearchSpace.id == self.search_space_id
|
||||
)
|
||||
)
|
||||
row = result.scalar_one_or_none()
|
||||
return row if row else None
|
||||
except Exception:
|
||||
logger.exception("Failed to load team memory")
|
||||
return None
|
||||
|
|
@ -0,0 +1,489 @@
|
|||
"""Aggregate multiple LocalFolderBackend roots behind mount-prefixed virtual paths."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from deepagents.backends.protocol import (
|
||||
EditResult,
|
||||
FileDownloadResponse,
|
||||
FileInfo,
|
||||
FileUploadResponse,
|
||||
GrepMatch,
|
||||
WriteResult,
|
||||
)
|
||||
|
||||
from app.agents.shared.middleware.local_folder_backend import LocalFolderBackend
|
||||
|
||||
_INVALID_PATH = "invalid_path"
|
||||
_FILE_NOT_FOUND = "file_not_found"
|
||||
_IS_DIRECTORY = "is_directory"
|
||||
|
||||
|
||||
class MultiRootLocalFolderBackend:
|
||||
"""Route filesystem operations to one of several mounted local roots.
|
||||
|
||||
Virtual paths are namespaced as:
|
||||
- `/<mount>/...`
|
||||
where `<mount>` is derived from each selected root folder name.
|
||||
"""
|
||||
|
||||
def __init__(self, mounts: tuple[tuple[str, str], ...]) -> None:
|
||||
if not mounts:
|
||||
msg = "At least one local mount is required"
|
||||
raise ValueError(msg)
|
||||
self._mount_to_backend: dict[str, LocalFolderBackend] = {}
|
||||
for raw_mount, raw_root in mounts:
|
||||
mount = raw_mount.strip()
|
||||
if not mount:
|
||||
msg = "Mount id cannot be empty"
|
||||
raise ValueError(msg)
|
||||
if mount in self._mount_to_backend:
|
||||
msg = f"Duplicate mount id: {mount}"
|
||||
raise ValueError(msg)
|
||||
normalized_root = str(Path(raw_root).expanduser().resolve())
|
||||
self._mount_to_backend[mount] = LocalFolderBackend(normalized_root)
|
||||
self._mount_order = tuple(self._mount_to_backend.keys())
|
||||
|
||||
def list_mounts(self) -> tuple[str, ...]:
|
||||
return self._mount_order
|
||||
|
||||
def default_mount(self) -> str:
|
||||
return self._mount_order[0]
|
||||
|
||||
def _mount_error(self) -> str:
|
||||
mounts = ", ".join(f"/{mount}" for mount in self._mount_order)
|
||||
return (
|
||||
"Path must start with one of the selected folders: "
|
||||
f"{mounts}. Example: /{self._mount_order[0]}/file.txt"
|
||||
)
|
||||
|
||||
def _split_mount_path(self, virtual_path: str) -> tuple[str, str]:
|
||||
if not virtual_path.startswith("/"):
|
||||
msg = f"Invalid path (must be absolute): {virtual_path}"
|
||||
raise ValueError(msg)
|
||||
rel = virtual_path.lstrip("/")
|
||||
if not rel:
|
||||
raise ValueError(self._mount_error())
|
||||
mount, _, remainder = rel.partition("/")
|
||||
backend = self._mount_to_backend.get(mount)
|
||||
if backend is None:
|
||||
raise ValueError(self._mount_error())
|
||||
local_path = f"/{remainder}" if remainder else "/"
|
||||
return mount, local_path
|
||||
|
||||
@staticmethod
|
||||
def _prefix_mount_path(mount: str, local_path: str) -> str:
|
||||
if local_path == "/":
|
||||
return f"/{mount}"
|
||||
return f"/{mount}{local_path}"
|
||||
|
||||
@staticmethod
|
||||
def _get_value(item: Any, key: str) -> Any:
|
||||
if isinstance(item, dict):
|
||||
return item.get(key)
|
||||
return getattr(item, key, None)
|
||||
|
||||
@classmethod
|
||||
def _get_str(cls, item: Any, key: str) -> str:
|
||||
value = cls._get_value(item, key)
|
||||
return value if isinstance(value, str) else ""
|
||||
|
||||
@classmethod
|
||||
def _get_int(cls, item: Any, key: str) -> int:
|
||||
value = cls._get_value(item, key)
|
||||
return int(value) if isinstance(value, int | float) else 0
|
||||
|
||||
@classmethod
|
||||
def _get_bool(cls, item: Any, key: str) -> bool:
|
||||
value = cls._get_value(item, key)
|
||||
return bool(value)
|
||||
|
||||
def _list_mount_roots(self) -> list[FileInfo]:
|
||||
return [
|
||||
FileInfo(path=f"/{mount}", is_dir=True, size=0, modified_at="0")
|
||||
for mount in self._mount_order
|
||||
]
|
||||
|
||||
def _transform_infos(self, mount: str, infos: list[FileInfo]) -> list[FileInfo]:
|
||||
transformed: list[FileInfo] = []
|
||||
for info in infos:
|
||||
transformed.append(
|
||||
FileInfo(
|
||||
path=self._prefix_mount_path(mount, self._get_str(info, "path")),
|
||||
is_dir=self._get_bool(info, "is_dir"),
|
||||
size=self._get_int(info, "size"),
|
||||
modified_at=self._get_str(info, "modified_at"),
|
||||
)
|
||||
)
|
||||
return transformed
|
||||
|
||||
def ls_info(self, path: str) -> list[FileInfo]:
|
||||
if path == "/":
|
||||
return self._list_mount_roots()
|
||||
try:
|
||||
mount, local_path = self._split_mount_path(path)
|
||||
except ValueError:
|
||||
return []
|
||||
return self._transform_infos(
|
||||
mount, self._mount_to_backend[mount].ls_info(local_path)
|
||||
)
|
||||
|
||||
async def als_info(self, path: str) -> list[FileInfo]:
|
||||
return await asyncio.to_thread(self.ls_info, path)
|
||||
|
||||
def list_tree(
|
||||
self,
|
||||
path: str = "/",
|
||||
*,
|
||||
max_depth: int | None = 8,
|
||||
page_size: int = 500,
|
||||
include_files: bool = True,
|
||||
include_dirs: bool = True,
|
||||
) -> dict[str, Any]:
|
||||
if path == "/":
|
||||
entries = [
|
||||
{
|
||||
"path": f"/{mount}",
|
||||
"is_dir": True,
|
||||
"size": 0,
|
||||
"modified_at": "0",
|
||||
"depth": 0,
|
||||
}
|
||||
for mount in self._mount_order
|
||||
]
|
||||
return {
|
||||
"entries": entries if include_dirs else [],
|
||||
"truncated": False,
|
||||
}
|
||||
|
||||
try:
|
||||
mount, local_path = self._split_mount_path(path)
|
||||
except ValueError as exc:
|
||||
return {"error": f"Error: {exc}"}
|
||||
|
||||
result = self._mount_to_backend[mount].list_tree(
|
||||
local_path,
|
||||
max_depth=max_depth,
|
||||
page_size=page_size,
|
||||
include_files=include_files,
|
||||
include_dirs=include_dirs,
|
||||
)
|
||||
if result.get("error"):
|
||||
return result
|
||||
|
||||
entries: list[dict[str, Any]] = []
|
||||
for entry in result.get("entries", []):
|
||||
raw_path = self._get_str(entry, "path")
|
||||
entries.append(
|
||||
{
|
||||
"path": self._prefix_mount_path(mount, raw_path),
|
||||
"is_dir": self._get_bool(entry, "is_dir"),
|
||||
"size": self._get_int(entry, "size"),
|
||||
"modified_at": self._get_str(entry, "modified_at"),
|
||||
"depth": self._get_int(entry, "depth"),
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"entries": entries,
|
||||
"truncated": self._get_bool(result, "truncated"),
|
||||
}
|
||||
|
||||
async def alist_tree(
|
||||
self,
|
||||
path: str = "/",
|
||||
*,
|
||||
max_depth: int | None = 8,
|
||||
page_size: int = 500,
|
||||
include_files: bool = True,
|
||||
include_dirs: bool = True,
|
||||
) -> dict[str, Any]:
|
||||
return await asyncio.to_thread(
|
||||
self.list_tree,
|
||||
path,
|
||||
max_depth=max_depth,
|
||||
page_size=page_size,
|
||||
include_files=include_files,
|
||||
include_dirs=include_dirs,
|
||||
)
|
||||
|
||||
def read(self, file_path: str, offset: int = 0, limit: int = 2000) -> str:
|
||||
try:
|
||||
mount, local_path = self._split_mount_path(file_path)
|
||||
except ValueError as exc:
|
||||
return f"Error: {exc}"
|
||||
return self._mount_to_backend[mount].read(local_path, offset, limit)
|
||||
|
||||
async def aread(self, file_path: str, offset: int = 0, limit: int = 2000) -> str:
|
||||
return await asyncio.to_thread(self.read, file_path, offset, limit)
|
||||
|
||||
def read_raw(self, file_path: str) -> str:
|
||||
try:
|
||||
mount, local_path = self._split_mount_path(file_path)
|
||||
except ValueError as exc:
|
||||
return f"Error: {exc}"
|
||||
return self._mount_to_backend[mount].read_raw(local_path)
|
||||
|
||||
async def aread_raw(self, file_path: str) -> str:
|
||||
return await asyncio.to_thread(self.read_raw, file_path)
|
||||
|
||||
def write(self, file_path: str, content: str) -> WriteResult:
|
||||
try:
|
||||
mount, local_path = self._split_mount_path(file_path)
|
||||
except ValueError as exc:
|
||||
return WriteResult(error=f"Error: {exc}")
|
||||
result = self._mount_to_backend[mount].write(local_path, content)
|
||||
if result.path:
|
||||
result.path = self._prefix_mount_path(mount, result.path)
|
||||
return result
|
||||
|
||||
async def awrite(self, file_path: str, content: str) -> WriteResult:
|
||||
return await asyncio.to_thread(self.write, file_path, content)
|
||||
|
||||
def move(
|
||||
self,
|
||||
source_path: str,
|
||||
destination_path: str,
|
||||
overwrite: bool = False,
|
||||
) -> WriteResult:
|
||||
try:
|
||||
source_mount, source_local_path = self._split_mount_path(source_path)
|
||||
destination_mount, destination_local_path = self._split_mount_path(
|
||||
destination_path
|
||||
)
|
||||
except ValueError as exc:
|
||||
return WriteResult(error=f"Error: {exc}")
|
||||
if source_mount != destination_mount:
|
||||
return WriteResult(
|
||||
error=(
|
||||
"Error: cross-mount moves are not supported. "
|
||||
"Source and destination must be under the same mounted root."
|
||||
)
|
||||
)
|
||||
result = self._mount_to_backend[source_mount].move(
|
||||
source_local_path,
|
||||
destination_local_path,
|
||||
overwrite=overwrite,
|
||||
)
|
||||
if result.path:
|
||||
result.path = self._prefix_mount_path(source_mount, result.path)
|
||||
return result
|
||||
|
||||
async def amove(
|
||||
self,
|
||||
source_path: str,
|
||||
destination_path: str,
|
||||
overwrite: bool = False,
|
||||
) -> WriteResult:
|
||||
return await asyncio.to_thread(
|
||||
self.move,
|
||||
source_path,
|
||||
destination_path,
|
||||
overwrite,
|
||||
)
|
||||
|
||||
def delete_file(self, file_path: str) -> WriteResult:
|
||||
try:
|
||||
mount, local_path = self._split_mount_path(file_path)
|
||||
except ValueError as exc:
|
||||
return WriteResult(error=f"Error: {exc}")
|
||||
result = self._mount_to_backend[mount].delete_file(local_path)
|
||||
if result.path:
|
||||
result.path = self._prefix_mount_path(mount, result.path)
|
||||
return result
|
||||
|
||||
async def adelete_file(self, file_path: str) -> WriteResult:
|
||||
return await asyncio.to_thread(self.delete_file, file_path)
|
||||
|
||||
def rmdir(self, dir_path: str) -> WriteResult:
|
||||
try:
|
||||
mount, local_path = self._split_mount_path(dir_path)
|
||||
except ValueError as exc:
|
||||
return WriteResult(error=f"Error: {exc}")
|
||||
if local_path == "/":
|
||||
return WriteResult(error=f"Error: cannot rmdir mount root '{dir_path}'")
|
||||
result = self._mount_to_backend[mount].rmdir(local_path)
|
||||
if result.path:
|
||||
result.path = self._prefix_mount_path(mount, result.path)
|
||||
return result
|
||||
|
||||
async def armdir(self, dir_path: str) -> WriteResult:
|
||||
return await asyncio.to_thread(self.rmdir, dir_path)
|
||||
|
||||
def edit(
|
||||
self,
|
||||
file_path: str,
|
||||
old_string: str,
|
||||
new_string: str,
|
||||
replace_all: bool = False,
|
||||
) -> EditResult:
|
||||
try:
|
||||
mount, local_path = self._split_mount_path(file_path)
|
||||
except ValueError as exc:
|
||||
return EditResult(error=f"Error: {exc}")
|
||||
result = self._mount_to_backend[mount].edit(
|
||||
local_path, old_string, new_string, replace_all
|
||||
)
|
||||
if result.path:
|
||||
result.path = self._prefix_mount_path(mount, result.path)
|
||||
return result
|
||||
|
||||
async def aedit(
|
||||
self,
|
||||
file_path: str,
|
||||
old_string: str,
|
||||
new_string: str,
|
||||
replace_all: bool = False,
|
||||
) -> EditResult:
|
||||
return await asyncio.to_thread(
|
||||
self.edit, file_path, old_string, new_string, replace_all
|
||||
)
|
||||
|
||||
def glob_info(self, pattern: str, path: str = "/") -> list[FileInfo]:
|
||||
if path == "/":
|
||||
prefixed_results: list[FileInfo] = []
|
||||
if pattern.startswith("/"):
|
||||
mount, _, remainder = pattern.lstrip("/").partition("/")
|
||||
backend = self._mount_to_backend.get(mount)
|
||||
if not backend:
|
||||
return []
|
||||
local_pattern = f"/{remainder}" if remainder else "/"
|
||||
return self._transform_infos(
|
||||
mount, backend.glob_info(local_pattern, path="/")
|
||||
)
|
||||
for mount, backend in self._mount_to_backend.items():
|
||||
prefixed_results.extend(
|
||||
self._transform_infos(mount, backend.glob_info(pattern, path="/"))
|
||||
)
|
||||
return prefixed_results
|
||||
|
||||
try:
|
||||
mount, local_path = self._split_mount_path(path)
|
||||
except ValueError:
|
||||
return []
|
||||
return self._transform_infos(
|
||||
mount, self._mount_to_backend[mount].glob_info(pattern, path=local_path)
|
||||
)
|
||||
|
||||
async def aglob_info(self, pattern: str, path: str = "/") -> list[FileInfo]:
|
||||
return await asyncio.to_thread(self.glob_info, pattern, path)
|
||||
|
||||
def grep_raw(
|
||||
self, pattern: str, path: str | None = None, glob: str | None = None
|
||||
) -> list[GrepMatch] | str:
|
||||
if not pattern:
|
||||
return "Error: pattern cannot be empty"
|
||||
if path is None or path == "/":
|
||||
all_matches: list[GrepMatch] = []
|
||||
for mount, backend in self._mount_to_backend.items():
|
||||
result = backend.grep_raw(pattern, path="/", glob=glob)
|
||||
if isinstance(result, str):
|
||||
return result
|
||||
all_matches.extend(
|
||||
[
|
||||
GrepMatch(
|
||||
path=self._prefix_mount_path(
|
||||
mount, self._get_str(match, "path")
|
||||
),
|
||||
line=self._get_int(match, "line"),
|
||||
text=self._get_str(match, "text"),
|
||||
)
|
||||
for match in result
|
||||
]
|
||||
)
|
||||
return all_matches
|
||||
try:
|
||||
mount, local_path = self._split_mount_path(path)
|
||||
except ValueError as exc:
|
||||
return f"Error: {exc}"
|
||||
|
||||
result = self._mount_to_backend[mount].grep_raw(
|
||||
pattern, path=local_path, glob=glob
|
||||
)
|
||||
if isinstance(result, str):
|
||||
return result
|
||||
return [
|
||||
GrepMatch(
|
||||
path=self._prefix_mount_path(mount, self._get_str(match, "path")),
|
||||
line=self._get_int(match, "line"),
|
||||
text=self._get_str(match, "text"),
|
||||
)
|
||||
for match in result
|
||||
]
|
||||
|
||||
async def agrep_raw(
|
||||
self, pattern: str, path: str | None = None, glob: str | None = None
|
||||
) -> list[GrepMatch] | str:
|
||||
return await asyncio.to_thread(self.grep_raw, pattern, path, glob)
|
||||
|
||||
def upload_files(self, files: list[tuple[str, bytes]]) -> list[FileUploadResponse]:
|
||||
grouped: dict[str, list[tuple[str, bytes]]] = {}
|
||||
invalid: list[FileUploadResponse] = []
|
||||
for virtual_path, content in files:
|
||||
try:
|
||||
mount, local_path = self._split_mount_path(virtual_path)
|
||||
except ValueError:
|
||||
invalid.append(
|
||||
FileUploadResponse(path=virtual_path, error=_INVALID_PATH)
|
||||
)
|
||||
continue
|
||||
grouped.setdefault(mount, []).append((local_path, content))
|
||||
|
||||
responses = list(invalid)
|
||||
for mount, mount_files in grouped.items():
|
||||
result = self._mount_to_backend[mount].upload_files(mount_files)
|
||||
responses.extend(
|
||||
[
|
||||
FileUploadResponse(
|
||||
path=self._prefix_mount_path(
|
||||
mount, self._get_str(item, "path")
|
||||
),
|
||||
error=self._get_str(item, "error") or None,
|
||||
)
|
||||
for item in result
|
||||
]
|
||||
)
|
||||
return responses
|
||||
|
||||
async def aupload_files(
|
||||
self, files: list[tuple[str, bytes]]
|
||||
) -> list[FileUploadResponse]:
|
||||
return await asyncio.to_thread(self.upload_files, files)
|
||||
|
||||
def download_files(self, paths: list[str]) -> list[FileDownloadResponse]:
|
||||
grouped: dict[str, list[str]] = {}
|
||||
invalid: list[FileDownloadResponse] = []
|
||||
for virtual_path in paths:
|
||||
try:
|
||||
mount, local_path = self._split_mount_path(virtual_path)
|
||||
except ValueError:
|
||||
invalid.append(
|
||||
FileDownloadResponse(
|
||||
path=virtual_path, content=None, error=_INVALID_PATH
|
||||
)
|
||||
)
|
||||
continue
|
||||
grouped.setdefault(mount, []).append(local_path)
|
||||
|
||||
responses = list(invalid)
|
||||
for mount, mount_paths in grouped.items():
|
||||
result = self._mount_to_backend[mount].download_files(mount_paths)
|
||||
responses.extend(
|
||||
[
|
||||
FileDownloadResponse(
|
||||
path=self._prefix_mount_path(
|
||||
mount, self._get_str(item, "path")
|
||||
),
|
||||
content=self._get_value(item, "content"),
|
||||
error=self._get_str(item, "error") or None,
|
||||
)
|
||||
for item in result
|
||||
]
|
||||
)
|
||||
return responses
|
||||
|
||||
async def adownload_files(self, paths: list[str]) -> list[FileDownloadResponse]:
|
||||
return await asyncio.to_thread(self.download_files, paths)
|
||||
141
surfsense_backend/app/agents/shared/middleware/noop_injection.py
Normal file
141
surfsense_backend/app/agents/shared/middleware/noop_injection.py
Normal file
|
|
@ -0,0 +1,141 @@
|
|||
"""
|
||||
``_noop`` provider-compatibility tool + injection middleware.
|
||||
|
||||
Some providers (LiteLLM, Bedrock, Copilot) 400 when a model call has
|
||||
empty ``tools`` but the message history includes prior ``tool_calls`` —
|
||||
they treat that shape as malformed even though it's perfectly valid
|
||||
LangChain. SurfSense hits this on the compaction summarize call (no
|
||||
tools, history full of tool calls).
|
||||
|
||||
Ported from OpenCode's ``packages/opencode/src/session/llm.ts:209-228``,
|
||||
which discovered and codified the workaround: inject a no-op tool *only*
|
||||
on those provider shapes so the request validates without ever being
|
||||
called.
|
||||
|
||||
Operation: a :class:`NoopInjectionMiddleware` ``wrap_model_call`` checks
|
||||
if the request has zero tools but the last AI message in history includes
|
||||
``tool_calls``. If yes, it injects the ``_noop`` tool only — never
|
||||
globally — mirroring OpenCode's gating exactly. The :func:`noop_tool`
|
||||
returns empty content when called (which it should never be in
|
||||
practice).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
ResponseT,
|
||||
)
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.tools import tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
NOOP_TOOL_NAME = "_noop"
|
||||
NOOP_TOOL_DESCRIPTION = "Do not call this tool. It exists only for API compatibility."
|
||||
|
||||
|
||||
@tool(name_or_callable=NOOP_TOOL_NAME, description=NOOP_TOOL_DESCRIPTION)
|
||||
def noop_tool() -> str:
|
||||
"""Return empty content. Never expected to be called."""
|
||||
return ""
|
||||
|
||||
|
||||
# Provider markers that benefit from ``_noop`` injection. These match
|
||||
# OpenCode's gating list (``llm.ts:209-228``). We also accept any string
|
||||
# containing one of these substrings so e.g. ``litellm`` matches
|
||||
# ``ChatLiteLLM``.
|
||||
_NOOP_NEEDED_PROVIDERS: tuple[str, ...] = (
|
||||
"litellm",
|
||||
"bedrock",
|
||||
"copilot",
|
||||
)
|
||||
|
||||
|
||||
def _provider_needs_noop(model: Any) -> bool:
|
||||
"""Heuristic: does this model's provider need the _noop injection?"""
|
||||
try:
|
||||
ls_params = model._get_ls_params()
|
||||
provider = str(ls_params.get("ls_provider", "")).lower()
|
||||
except Exception:
|
||||
provider = ""
|
||||
|
||||
if not provider:
|
||||
cls_name = type(model).__name__.lower()
|
||||
provider = cls_name
|
||||
|
||||
return any(needle in provider for needle in _NOOP_NEEDED_PROVIDERS)
|
||||
|
||||
|
||||
def _last_ai_has_tool_calls(messages: list[Any]) -> bool:
|
||||
for msg in reversed(messages):
|
||||
if isinstance(msg, AIMessage):
|
||||
return bool(msg.tool_calls)
|
||||
return False
|
||||
|
||||
|
||||
class NoopInjectionMiddleware(
|
||||
AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]
|
||||
):
|
||||
"""Inject the ``_noop`` tool only when the provider would otherwise 400.
|
||||
|
||||
The check fires per model call, not at agent build time, because the
|
||||
summarization path generates a no-tool subcall at runtime. The
|
||||
extra tool is appended to ``request.tools`` as an instance — the
|
||||
actual ``langchain_core.tools.BaseTool`` is bound on every call site
|
||||
that creates the agent.
|
||||
"""
|
||||
|
||||
def __init__(self, *, noop_tool_instance: Any | None = None) -> None:
|
||||
super().__init__()
|
||||
self._noop_tool = noop_tool_instance or noop_tool
|
||||
self.tools = []
|
||||
|
||||
def _should_inject(self, request: ModelRequest[ContextT]) -> bool:
|
||||
if request.tools:
|
||||
return False
|
||||
if not _last_ai_has_tool_calls(request.messages):
|
||||
return False
|
||||
return _provider_needs_noop(request.model)
|
||||
|
||||
def _augmented(self, request: ModelRequest[ContextT]) -> ModelRequest[ContextT]:
|
||||
return request.override(tools=[self._noop_tool])
|
||||
|
||||
def wrap_model_call( # type: ignore[override]
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
|
||||
) -> Any:
|
||||
if self._should_inject(request):
|
||||
logger.debug("Injecting _noop tool for provider compatibility")
|
||||
return handler(self._augmented(request))
|
||||
return handler(request)
|
||||
|
||||
async def awrap_model_call( # type: ignore[override]
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[
|
||||
[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
|
||||
],
|
||||
) -> Any:
|
||||
if self._should_inject(request):
|
||||
logger.debug("Injecting _noop tool for provider compatibility")
|
||||
return await handler(self._augmented(request))
|
||||
return await handler(request)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"NOOP_TOOL_DESCRIPTION",
|
||||
"NOOP_TOOL_NAME",
|
||||
"NoopInjectionMiddleware",
|
||||
"_provider_needs_noop",
|
||||
"noop_tool",
|
||||
]
|
||||
285
surfsense_backend/app/agents/shared/middleware/otel_span.py
Normal file
285
surfsense_backend/app/agents/shared/middleware/otel_span.py
Normal file
|
|
@ -0,0 +1,285 @@
|
|||
"""
|
||||
OpenTelemetry span middleware for the SurfSense ``new_chat`` agent.
|
||||
|
||||
Wraps both ``model.call`` (LLM invocations) and ``tool.call`` (tool
|
||||
executions) with OTel spans, attaching low-cardinality span names and
|
||||
high-cardinality identifiers as attributes.
|
||||
|
||||
This middleware is intentionally a thin adapter over
|
||||
:mod:`app.observability.otel`; when OTel is not configured all spans
|
||||
collapse to no-ops and the wrapper adds <1µs overhead per call. When
|
||||
OTel **is** configured (``OTEL_EXPORTER_OTLP_ENDPOINT`` set), every
|
||||
model and tool call gets a span with the standard attributes our
|
||||
dashboards expect.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
|
||||
from app.observability import metrics as ot_metrics, otel as ot
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover — type-only
|
||||
from langchain.agents.middleware.types import (
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
ToolCallRequest,
|
||||
)
|
||||
from langgraph.types import Command
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OtelSpanMiddleware(AgentMiddleware):
|
||||
"""Emit ``model.call`` and ``tool.call`` OTel spans for every invocation.
|
||||
|
||||
Should be placed near the **outer** end of the middleware list so
|
||||
that the spans encompass retry/fallback wrapper effects (i.e. ``N``
|
||||
model.call spans for ``N`` retry attempts) but inside any concurrency/
|
||||
auth gate. Empirically this means **between** ``BusyMutex`` and
|
||||
``RetryAfter``.
|
||||
"""
|
||||
|
||||
def __init__(self, *, instrumentation_name: str = "surfsense.new_chat") -> None:
|
||||
super().__init__()
|
||||
self._instrumentation_name = instrumentation_name
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Model call spans
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse | AIMessage | Any]],
|
||||
) -> ModelResponse | AIMessage | Any:
|
||||
if not ot.is_enabled():
|
||||
return await handler(request)
|
||||
|
||||
model_id, provider = _resolve_model_attrs(request)
|
||||
t0 = time.perf_counter()
|
||||
with ot.model_call_span(model_id=model_id, provider=provider) as sp:
|
||||
_annotate_model_request(sp, model_id=model_id, provider=provider)
|
||||
try:
|
||||
result = await handler(request)
|
||||
except Exception:
|
||||
ot_metrics.record_model_call_duration(
|
||||
(time.perf_counter() - t0) * 1000,
|
||||
model=model_id,
|
||||
provider=provider,
|
||||
)
|
||||
# span context manager records + re-raises
|
||||
raise
|
||||
else:
|
||||
input_tokens, output_tokens = _annotate_model_response(
|
||||
sp,
|
||||
result,
|
||||
model_id=model_id,
|
||||
provider=provider,
|
||||
)
|
||||
ot_metrics.record_model_call_duration(
|
||||
(time.perf_counter() - t0) * 1000,
|
||||
model=model_id,
|
||||
provider=provider,
|
||||
)
|
||||
ot_metrics.record_model_token_usage(
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
model=model_id,
|
||||
provider=provider,
|
||||
)
|
||||
return result
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tool call spans
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def awrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
|
||||
) -> ToolMessage | Command[Any]:
|
||||
if not ot.is_enabled():
|
||||
return await handler(request)
|
||||
|
||||
tool_name = _resolve_tool_name(request)
|
||||
input_size = _resolve_input_size(request)
|
||||
|
||||
t0 = time.perf_counter()
|
||||
with ot.tool_call_span(tool_name, input_size=input_size) as sp:
|
||||
try:
|
||||
result = await handler(request)
|
||||
except Exception:
|
||||
ot_metrics.record_tool_call_duration(
|
||||
(time.perf_counter() - t0) * 1000,
|
||||
tool_name=tool_name,
|
||||
)
|
||||
ot_metrics.record_tool_call_error(tool_name=tool_name)
|
||||
raise
|
||||
errored = _annotate_tool_result(sp, result)
|
||||
ot_metrics.record_tool_call_duration(
|
||||
(time.perf_counter() - t0) * 1000,
|
||||
tool_name=tool_name,
|
||||
)
|
||||
if errored:
|
||||
ot_metrics.record_tool_call_error(tool_name=tool_name)
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Attribute helpers (kept defensive; we never want OTel bookkeeping to break
|
||||
# a real model/tool call).
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _resolve_model_attrs(request: Any) -> tuple[str | None, str | None]:
|
||||
"""Extract ``model.id`` and ``model.provider`` from a ``ModelRequest``."""
|
||||
model_id: str | None = None
|
||||
provider: str | None = None
|
||||
try:
|
||||
model = getattr(request, "model", None)
|
||||
if model is None:
|
||||
return None, None
|
||||
# langchain BaseChatModel exposes a few different identifiers
|
||||
for attr in ("model_name", "model", "model_id"):
|
||||
value = getattr(model, attr, None)
|
||||
if value:
|
||||
model_id = str(value)
|
||||
break
|
||||
# provider sometimes lives on ``_llm_type`` (legacy) or ``provider``
|
||||
for attr in ("provider", "_llm_type"):
|
||||
value = getattr(model, attr, None)
|
||||
if value:
|
||||
provider = str(value)
|
||||
break
|
||||
except Exception: # pragma: no cover — defensive
|
||||
pass
|
||||
return model_id, provider
|
||||
|
||||
|
||||
def _resolve_tool_name(request: Any) -> str:
|
||||
try:
|
||||
tool = getattr(request, "tool", None)
|
||||
if tool is not None:
|
||||
name = getattr(tool, "name", None)
|
||||
if isinstance(name, str) and name:
|
||||
return name
|
||||
# Fall back to the tool_call dict
|
||||
call = getattr(request, "tool_call", None) or {}
|
||||
name = call.get("name") if isinstance(call, dict) else None
|
||||
if isinstance(name, str) and name:
|
||||
return name
|
||||
except Exception: # pragma: no cover — defensive
|
||||
pass
|
||||
return "unknown"
|
||||
|
||||
|
||||
def _resolve_input_size(request: Any) -> int | None:
|
||||
try:
|
||||
call = getattr(request, "tool_call", None)
|
||||
if not isinstance(call, dict) or not call:
|
||||
return None
|
||||
args = call.get("args")
|
||||
if args is None:
|
||||
return None
|
||||
return len(repr(args))
|
||||
except Exception: # pragma: no cover — defensive
|
||||
return None
|
||||
|
||||
|
||||
def _annotate_model_request(
|
||||
span: Any, *, model_id: str | None, provider: str | None
|
||||
) -> None:
|
||||
try:
|
||||
span.set_attribute("gen_ai.operation.name", "chat")
|
||||
if model_id:
|
||||
span.set_attribute("gen_ai.request.model", model_id)
|
||||
if provider:
|
||||
span.set_attribute("gen_ai.provider.name", provider)
|
||||
except Exception: # pragma: no cover — defensive
|
||||
pass
|
||||
|
||||
|
||||
def _annotate_model_response(
|
||||
span: Any,
|
||||
result: Any,
|
||||
*,
|
||||
model_id: str | None = None,
|
||||
provider: str | None = None,
|
||||
) -> tuple[int | None, int | None]:
|
||||
"""Best-effort: attach prompt/completion token counts when available."""
|
||||
input_tokens: int | None = None
|
||||
output_tokens: int | None = None
|
||||
try:
|
||||
# ModelResponse may be a dataclass with .result containing AIMessage
|
||||
msg: Any
|
||||
if isinstance(result, AIMessage):
|
||||
msg = result
|
||||
else:
|
||||
inner = getattr(result, "result", None)
|
||||
msg = inner[-1] if isinstance(inner, list) and inner else inner
|
||||
if msg is None:
|
||||
return None, None
|
||||
if provider:
|
||||
span.set_attribute("gen_ai.provider.name", provider)
|
||||
if model_id:
|
||||
span.set_attribute("gen_ai.request.model", model_id)
|
||||
response_model = getattr(msg, "response_metadata", {}) or {}
|
||||
if isinstance(response_model, dict):
|
||||
response_model = (
|
||||
response_model.get("model_name")
|
||||
or response_model.get("model")
|
||||
or response_model.get("model_id")
|
||||
)
|
||||
if not response_model:
|
||||
response_model = model_id
|
||||
if response_model:
|
||||
span.set_attribute("gen_ai.response.model", str(response_model))
|
||||
span.set_attribute("gen_ai.operation.name", "chat")
|
||||
usage = getattr(msg, "usage_metadata", None) or {}
|
||||
if isinstance(usage, dict):
|
||||
if (n := usage.get("input_tokens")) is not None:
|
||||
input_tokens = int(n)
|
||||
span.set_attribute("gen_ai.usage.input_tokens", input_tokens)
|
||||
if (n := usage.get("output_tokens")) is not None:
|
||||
output_tokens = int(n)
|
||||
span.set_attribute("gen_ai.usage.output_tokens", output_tokens)
|
||||
if (n := usage.get("total_tokens")) is not None:
|
||||
span.set_attribute("gen_ai.usage.total_tokens", int(n))
|
||||
tool_calls = getattr(msg, "tool_calls", None) or []
|
||||
span.set_attribute("model.tool_calls", len(tool_calls))
|
||||
except Exception: # pragma: no cover — defensive
|
||||
pass
|
||||
return input_tokens, output_tokens
|
||||
|
||||
|
||||
def _annotate_tool_result(span: Any, result: Any) -> bool:
|
||||
errored = False
|
||||
try:
|
||||
if isinstance(result, ToolMessage):
|
||||
content = (
|
||||
result.content
|
||||
if isinstance(result.content, str)
|
||||
else repr(result.content)
|
||||
)
|
||||
span.set_attribute("tool.output.size", len(content))
|
||||
status = getattr(result, "status", None)
|
||||
if isinstance(status, str):
|
||||
span.set_attribute("tool.status", status)
|
||||
errored = status.lower() == "error"
|
||||
kwargs = getattr(result, "additional_kwargs", None) or {}
|
||||
if isinstance(kwargs, dict) and kwargs.get("error"):
|
||||
span.set_attribute("tool.error", True)
|
||||
errored = True
|
||||
except Exception: # pragma: no cover — defensive
|
||||
pass
|
||||
return errored
|
||||
|
||||
|
||||
__all__ = ["OtelSpanMiddleware"]
|
||||
427
surfsense_backend/app/agents/shared/middleware/permission.py
Normal file
427
surfsense_backend/app/agents/shared/middleware/permission.py
Normal file
|
|
@ -0,0 +1,427 @@
|
|||
"""
|
||||
PermissionMiddleware — pattern-based allow/deny/ask with HITL fallback.
|
||||
|
||||
LangChain's :class:`HumanInTheLoopMiddleware` only supports a static
|
||||
"this tool always asks" decision per tool. There's no rule-based
|
||||
allow/deny/ask layered ruleset, no glob patterns, no per-search-space or
|
||||
per-thread overrides, and no auto-deny synthesis.
|
||||
|
||||
This middleware ports OpenCode's ``packages/opencode/src/permission/index.ts``
|
||||
ruleset model on top of SurfSense's existing ``interrupt({type, action,
|
||||
context})`` payload shape (see ``app/agents/new_chat/tools/hitl.py``) so
|
||||
the frontend keeps working unchanged.
|
||||
|
||||
Operation:
|
||||
1. ``aafter_model`` inspects the latest ``AIMessage.tool_calls``.
|
||||
2. For each call, the middleware builds a list of ``patterns`` (the
|
||||
tool name plus any tool-specific patterns from the resolver). It
|
||||
evaluates each pattern against the layered rulesets and aggregates
|
||||
the results: ``deny`` > ``ask`` > ``allow``.
|
||||
3. On ``deny``: replaces the call with a synthetic ``ToolMessage``
|
||||
containing a :class:`StreamingError`.
|
||||
4. On ``ask``: raises a SurfSense-style ``interrupt(...)``. Both the legacy
|
||||
SurfSense shape and LangChain HITL ``{"decisions": [{"type": ...}]}``
|
||||
replies are accepted via :func:`_normalize_permission_decision`.
|
||||
- ``once``: proceed.
|
||||
- ``approve_always``: also persist allow rules for ``request.always`` patterns.
|
||||
- ``reject`` w/o feedback: raise :class:`RejectedError`.
|
||||
- ``reject`` w/ feedback: raise :class:`CorrectedError`.
|
||||
5. On ``allow``: proceed unchanged.
|
||||
|
||||
The middleware also performs a *pre-model* tool-filter step (the
|
||||
``before_model`` hook) so globally denied tools are stripped from the
|
||||
exposed tool list before the model gets to see them. This mirrors
|
||||
OpenCode's ``Permission.disabled`` and dramatically reduces the chance
|
||||
the model emits a deny-only call.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
)
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
from langgraph.runtime import Runtime
|
||||
from langgraph.types import interrupt
|
||||
|
||||
from app.agents.shared.errors import (
|
||||
CorrectedError,
|
||||
RejectedError,
|
||||
StreamingError,
|
||||
)
|
||||
from app.agents.shared.permissions import (
|
||||
Rule,
|
||||
Ruleset,
|
||||
aggregate_action,
|
||||
evaluate_many,
|
||||
)
|
||||
from app.observability import metrics as ot_metrics, otel as ot
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Mapping ``tool_name -> resolver`` that converts ``args`` to a list of
|
||||
# patterns to evaluate. The first pattern is conventionally the bare
|
||||
# tool name; later entries narrow down to specific resources.
|
||||
PatternResolver = Callable[[dict[str, Any]], list[str]]
|
||||
|
||||
|
||||
def _default_pattern_resolver(name: str) -> PatternResolver:
|
||||
def _resolve(args: dict[str, Any]) -> list[str]:
|
||||
# Bare name covers the default catch-all; primary-arg fallbacks
|
||||
# are best added per-tool by callers.
|
||||
del args
|
||||
return [name]
|
||||
|
||||
return _resolve
|
||||
|
||||
|
||||
# Translation from the LangChain HITL envelope (what ``stream_resume_chat``
|
||||
# sends) to SurfSense's legacy ``decision_type`` shape. ``edit`` keeps the
|
||||
# original tool args — tools needing argument edits should use
|
||||
# ``request_approval`` from ``app/agents/new_chat/tools/hitl.py``.
|
||||
_LC_TYPE_TO_PERMISSION_DECISION: dict[str, str] = {
|
||||
"approve": "once",
|
||||
"reject": "reject",
|
||||
"edit": "once",
|
||||
"approve_always": "approve_always",
|
||||
}
|
||||
|
||||
|
||||
def _normalize_permission_decision(decision: Any) -> dict[str, Any]:
|
||||
"""Coerce any accepted reply shape into ``{"decision_type": ..., "feedback"?}``.
|
||||
|
||||
Falls back to ``reject`` (with a warning) on unrecognized payloads so the
|
||||
middleware fails closed.
|
||||
"""
|
||||
if isinstance(decision, str):
|
||||
return {"decision_type": decision}
|
||||
if not isinstance(decision, dict):
|
||||
logger.warning(
|
||||
"Unrecognized permission resume value (%s); treating as reject",
|
||||
type(decision).__name__,
|
||||
)
|
||||
return {"decision_type": "reject"}
|
||||
|
||||
if decision.get("decision_type"):
|
||||
return decision
|
||||
|
||||
payload: dict[str, Any] = decision
|
||||
decisions = decision.get("decisions")
|
||||
if isinstance(decisions, list) and decisions:
|
||||
first = decisions[0]
|
||||
if isinstance(first, dict):
|
||||
payload = first
|
||||
|
||||
raw_type = payload.get("type") or payload.get("decision_type")
|
||||
if not raw_type:
|
||||
logger.warning(
|
||||
"Permission resume missing decision type (keys=%s); treating as reject",
|
||||
list(payload.keys()),
|
||||
)
|
||||
return {"decision_type": "reject"}
|
||||
|
||||
raw_type = str(raw_type).lower()
|
||||
mapped = _LC_TYPE_TO_PERMISSION_DECISION.get(raw_type)
|
||||
if mapped is None:
|
||||
# Tolerate legacy values arriving without ``decision_type`` wrapping.
|
||||
if raw_type in {"once", "approve_always", "reject"}:
|
||||
mapped = raw_type
|
||||
else:
|
||||
logger.warning(
|
||||
"Unknown permission decision type %r; treating as reject", raw_type
|
||||
)
|
||||
mapped = "reject"
|
||||
|
||||
if raw_type == "edit":
|
||||
logger.warning(
|
||||
"Permission middleware received an 'edit' decision; original args "
|
||||
"kept (edits not merged here)."
|
||||
)
|
||||
|
||||
out: dict[str, Any] = {"decision_type": mapped}
|
||||
feedback = payload.get("feedback") or payload.get("message")
|
||||
if isinstance(feedback, str) and feedback.strip():
|
||||
out["feedback"] = feedback
|
||||
return out
|
||||
|
||||
|
||||
class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||
"""Allow/deny/ask layer over the agent's tool calls.
|
||||
|
||||
Args:
|
||||
rulesets: Layered rulesets to evaluate. Earlier entries are
|
||||
overridden by later ones (last-match-wins). Typical layering:
|
||||
``defaults < global < space < thread < runtime_approved``.
|
||||
pattern_resolvers: Optional per-tool callables that return a list
|
||||
of patterns to evaluate. When a tool isn't listed, the bare
|
||||
tool name is used as the only pattern.
|
||||
runtime_ruleset: Mutable :class:`Ruleset` that the middleware
|
||||
extends in-place when the user replies ``"approve_always"`` to
|
||||
an ask interrupt. Reused across all calls in the same agent
|
||||
instance so newly-allowed rules apply to subsequent calls.
|
||||
always_emit_interrupt_payload: If True, every ask uses the
|
||||
SurfSense interrupt wire format (default). Set False to
|
||||
disable interrupts and treat ``ask`` as ``deny`` for
|
||||
non-interactive deployments.
|
||||
"""
|
||||
|
||||
tools = ()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
rulesets: list[Ruleset] | None = None,
|
||||
pattern_resolvers: dict[str, PatternResolver] | None = None,
|
||||
runtime_ruleset: Ruleset | None = None,
|
||||
always_emit_interrupt_payload: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._static_rulesets: list[Ruleset] = list(rulesets or [])
|
||||
self._pattern_resolvers: dict[str, PatternResolver] = dict(
|
||||
pattern_resolvers or {}
|
||||
)
|
||||
self._runtime_ruleset: Ruleset = runtime_ruleset or Ruleset(
|
||||
origin="runtime_approved"
|
||||
)
|
||||
self._emit_interrupt = always_emit_interrupt_payload
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tool-filter step (mirrors OpenCode's ``Permission.disabled``)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _globally_denied(self, tool_name: str) -> bool:
|
||||
"""Return True if a deny rule with no narrowing pattern matches."""
|
||||
rules = evaluate_many(tool_name, ["*"], *self._all_rulesets())
|
||||
return aggregate_action(rules) == "deny"
|
||||
|
||||
def _all_rulesets(self) -> list[Ruleset]:
|
||||
return [*self._static_rulesets, self._runtime_ruleset]
|
||||
|
||||
# NOTE: ``before_model`` filtering of the tools list is left to the
|
||||
# agent factory. This middleware only blocks at execution time — and
|
||||
# only via the rule-evaluator path, not by mutating ``request.tools``.
|
||||
# Mutating ``request.tools`` per-call would invalidate provider
|
||||
# prompt-cache prefixes (see Operational risks: prompt-cache regression).
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tool-call evaluation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _resolve_patterns(self, tool_name: str, args: dict[str, Any]) -> list[str]:
|
||||
resolver = self._pattern_resolvers.get(
|
||||
tool_name, _default_pattern_resolver(tool_name)
|
||||
)
|
||||
try:
|
||||
patterns = resolver(args or {})
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Pattern resolver for %s raised; using bare name", tool_name
|
||||
)
|
||||
patterns = [tool_name]
|
||||
if not patterns:
|
||||
patterns = [tool_name]
|
||||
return patterns
|
||||
|
||||
def _evaluate(
|
||||
self, tool_name: str, args: dict[str, Any]
|
||||
) -> tuple[str, list[str], list[Rule]]:
|
||||
patterns = self._resolve_patterns(tool_name, args)
|
||||
rules = evaluate_many(tool_name, patterns, *self._all_rulesets())
|
||||
action = aggregate_action(rules)
|
||||
return action, patterns, rules
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# HITL ask flow — SurfSense wire format
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _raise_interrupt(
|
||||
self,
|
||||
*,
|
||||
tool_name: str,
|
||||
args: dict[str, Any],
|
||||
patterns: list[str],
|
||||
rules: list[Rule],
|
||||
) -> dict[str, Any]:
|
||||
"""Block on user approval via SurfSense's ``interrupt`` shape."""
|
||||
if not self._emit_interrupt:
|
||||
return {"decision_type": "reject"}
|
||||
|
||||
# ``params`` (NOT ``args``) is what SurfSense's streaming
|
||||
# normalizer forwards. Other fields move into ``context``.
|
||||
payload = {
|
||||
"type": "permission_ask",
|
||||
"action": {"tool": tool_name, "params": args or {}},
|
||||
"context": {
|
||||
"patterns": patterns,
|
||||
"rules": [
|
||||
{
|
||||
"permission": r.permission,
|
||||
"pattern": r.pattern,
|
||||
"action": r.action,
|
||||
}
|
||||
for r in rules
|
||||
],
|
||||
# Rules of thumb for the frontend: surface the patterns
|
||||
# the user can promote to "approve_always" with a single reply.
|
||||
"always": patterns,
|
||||
},
|
||||
}
|
||||
# Open ``permission.asked`` + ``interrupt.raised`` OTel spans
|
||||
# (no-op when OTel is disabled) so dashboards can correlate
|
||||
# "we asked X" with "interrupt was actually delivered".
|
||||
with (
|
||||
ot.permission_asked_span(
|
||||
permission=tool_name,
|
||||
pattern=patterns[0] if patterns else None,
|
||||
extra={"permission.patterns": list(patterns)},
|
||||
),
|
||||
ot.interrupt_span(interrupt_type="permission_ask"),
|
||||
):
|
||||
ot_metrics.record_permission_ask(permission=tool_name)
|
||||
ot_metrics.record_interrupt(interrupt_type="permission_ask")
|
||||
decision = interrupt(payload)
|
||||
return _normalize_permission_decision(decision)
|
||||
|
||||
def _persist_always(self, tool_name: str, patterns: list[str]) -> None:
|
||||
"""Promote ``approve_always`` reply into runtime allow rules.
|
||||
|
||||
Persistence to ``agent_permission_rules`` is done by the
|
||||
streaming layer (``stream_new_chat``) once it observes the
|
||||
``approve_always`` reply — the middleware just keeps an
|
||||
in-memory copy so subsequent calls in the same stream see the rule.
|
||||
"""
|
||||
for pattern in patterns:
|
||||
self._runtime_ruleset.rules.append(
|
||||
Rule(permission=tool_name, pattern=pattern, action="allow")
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Synthesizing deny -> ToolMessage
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _deny_message(
|
||||
tool_call: dict[str, Any],
|
||||
rule: Rule,
|
||||
) -> ToolMessage:
|
||||
err = StreamingError(
|
||||
code="permission_denied",
|
||||
retryable=False,
|
||||
suggestion=(
|
||||
f"rule permission={rule.permission!r} pattern={rule.pattern!r} "
|
||||
f"blocked this call"
|
||||
),
|
||||
)
|
||||
return ToolMessage(
|
||||
content=(
|
||||
f"Permission denied: rule {rule.permission}/{rule.pattern} "
|
||||
f"blocked tool {tool_call.get('name')!r}."
|
||||
),
|
||||
tool_call_id=tool_call.get("id") or "",
|
||||
name=tool_call.get("name"),
|
||||
status="error",
|
||||
additional_kwargs={"error": err.model_dump()},
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# The hook: aafter_model
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _process(
|
||||
self,
|
||||
state: AgentState,
|
||||
runtime: Runtime[Any],
|
||||
) -> dict[str, Any] | None:
|
||||
del runtime # unused
|
||||
messages = state.get("messages") or []
|
||||
if not messages:
|
||||
return None
|
||||
last = messages[-1]
|
||||
if not isinstance(last, AIMessage) or not last.tool_calls:
|
||||
return None
|
||||
|
||||
deny_messages: list[ToolMessage] = []
|
||||
kept_calls: list[dict[str, Any]] = []
|
||||
any_change = False
|
||||
|
||||
for raw in last.tool_calls:
|
||||
call = (
|
||||
dict(raw)
|
||||
if isinstance(raw, dict)
|
||||
else {
|
||||
"name": getattr(raw, "name", None),
|
||||
"args": getattr(raw, "args", {}),
|
||||
"id": getattr(raw, "id", None),
|
||||
"type": "tool_call",
|
||||
}
|
||||
)
|
||||
name = call.get("name") or ""
|
||||
args = call.get("args") or {}
|
||||
action, patterns, rules = self._evaluate(name, args)
|
||||
|
||||
if action == "deny":
|
||||
# Find the deny rule for the suggestion text
|
||||
deny_rule = next((r for r in rules if r.action == "deny"), rules[0])
|
||||
deny_messages.append(self._deny_message(call, deny_rule))
|
||||
any_change = True
|
||||
continue
|
||||
|
||||
if action == "ask":
|
||||
decision = self._raise_interrupt(
|
||||
tool_name=name, args=args, patterns=patterns, rules=rules
|
||||
)
|
||||
kind = str(decision.get("decision_type") or "reject").lower()
|
||||
if kind == "once":
|
||||
kept_calls.append(call)
|
||||
elif kind == "approve_always":
|
||||
self._persist_always(name, patterns)
|
||||
kept_calls.append(call)
|
||||
elif kind == "reject":
|
||||
feedback = decision.get("feedback")
|
||||
if isinstance(feedback, str) and feedback.strip():
|
||||
raise CorrectedError(feedback, tool=name)
|
||||
raise RejectedError(
|
||||
tool=name, pattern=patterns[0] if patterns else None
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Unknown permission decision %r; treating as reject", kind
|
||||
)
|
||||
raise RejectedError(tool=name)
|
||||
continue
|
||||
|
||||
# allow
|
||||
kept_calls.append(call)
|
||||
|
||||
if not any_change and len(kept_calls) == len(last.tool_calls):
|
||||
return None
|
||||
|
||||
updated = last.model_copy(update={"tool_calls": kept_calls})
|
||||
result_messages: list[Any] = [updated]
|
||||
if deny_messages:
|
||||
result_messages.extend(deny_messages)
|
||||
return {"messages": result_messages}
|
||||
|
||||
def after_model( # type: ignore[override]
|
||||
self, state: AgentState, runtime: Runtime[ContextT]
|
||||
) -> dict[str, Any] | None:
|
||||
return self._process(state, runtime)
|
||||
|
||||
async def aafter_model( # type: ignore[override]
|
||||
self, state: AgentState, runtime: Runtime[ContextT]
|
||||
) -> dict[str, Any] | None:
|
||||
return self._process(state, runtime)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PatternResolver",
|
||||
"PermissionMiddleware",
|
||||
"_normalize_permission_decision",
|
||||
]
|
||||
277
surfsense_backend/app/agents/shared/middleware/retry_after.py
Normal file
277
surfsense_backend/app/agents/shared/middleware/retry_after.py
Normal file
|
|
@ -0,0 +1,277 @@
|
|||
"""
|
||||
RetryAfterMiddleware — Header-aware retry with custom backoff and SSE eventing.
|
||||
|
||||
LangChain's :class:`ModelRetryMiddleware` retries on exceptions but ignores
|
||||
the ``Retry-After`` HTTP header — it just runs its own exponential backoff.
|
||||
That wastes time when a provider has explicitly told us how long to wait.
|
||||
This middleware honors the header (mirroring OpenCode's
|
||||
``packages/opencode/src/session/llm.ts`` retry pathway) and emits an SSE
|
||||
event so the UI can show "rate-limited, retrying in Ns".
|
||||
|
||||
We can't subclass ``ModelRetryMiddleware`` cleanly because its loop calls a
|
||||
module-level ``calculate_delay`` inline (no overridable
|
||||
``_calculate_delay`` hook), so this is a standalone implementation.
|
||||
|
||||
Behaviour:
|
||||
- Extracts ``Retry-After`` / ``retry-after-ms`` from
|
||||
``litellm.exceptions.RateLimitError.response.headers`` (or any exception
|
||||
exposing a similar shape).
|
||||
- Sleeps ``max(exponential_backoff, header_delay)`` between retries.
|
||||
- Returns ``False`` from ``retry_on`` for ``ContextWindowExceededError`` /
|
||||
``ContextOverflowError`` so :class:`SurfSenseCompactionMiddleware` (or
|
||||
the LangChain summarization fallback path) handles those instead.
|
||||
- Emits ``surfsense.retrying`` via ``adispatch_custom_event`` on each retry
|
||||
so ``stream_new_chat`` can forward it to clients as an SSE event.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
ResponseT,
|
||||
)
|
||||
from langchain_core.callbacks import adispatch_custom_event, dispatch_custom_event
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
from app.observability import metrics as ot_metrics, otel as ot
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Names of exception classes for which a retry would not help — context
|
||||
# overflow needs compaction, auth needs human intervention, etc. Detected
|
||||
# by class-name substring so we don't have to import LiteLLM/Anthropic
|
||||
# here (which would tie this module to optional deps).
|
||||
_NON_RETRYABLE_NAME_HINTS: tuple[str, ...] = (
|
||||
"ContextWindowExceeded",
|
||||
"ContextOverflow",
|
||||
"AuthenticationError",
|
||||
"InvalidRequestError",
|
||||
"PermissionDenied",
|
||||
"InvalidApiKey",
|
||||
"ContextLimit",
|
||||
)
|
||||
|
||||
|
||||
def _is_non_retryable(exc: BaseException) -> bool:
|
||||
name = type(exc).__name__
|
||||
return any(hint in name for hint in _NON_RETRYABLE_NAME_HINTS)
|
||||
|
||||
|
||||
def _extract_retry_after_seconds(exc: BaseException) -> float | None:
|
||||
"""Return seconds-to-wait suggested by the provider, if any.
|
||||
|
||||
Looks at ``exc.response.headers`` or ``exc.headers`` for the standard
|
||||
HTTP ``Retry-After`` header (in seconds) or its millisecond cousin
|
||||
``retry-after-ms`` (sometimes used by Anthropic / OpenAI). Falls back
|
||||
to a regex on the exception message for shapes like
|
||||
``"Please retry after 30s"``.
|
||||
"""
|
||||
headers: dict[str, Any] | None = None
|
||||
response = getattr(exc, "response", None)
|
||||
if response is not None:
|
||||
headers = getattr(response, "headers", None)
|
||||
if headers is None:
|
||||
headers = getattr(exc, "headers", None)
|
||||
|
||||
if isinstance(headers, dict):
|
||||
# Normalize keys to lowercase for case-insensitive matching
|
||||
norm = {str(k).lower(): v for k, v in headers.items()}
|
||||
ms = norm.get("retry-after-ms")
|
||||
if ms is not None:
|
||||
try:
|
||||
return float(ms) / 1000.0
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
seconds = norm.get("retry-after")
|
||||
if seconds is not None:
|
||||
try:
|
||||
return float(seconds)
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
|
||||
# Last resort: scan the message for "retry after Xs" or "X seconds"
|
||||
msg = str(exc)
|
||||
match = re.search(r"retry\s+after\s+([0-9]+(?:\.[0-9]+)?)", msg, re.IGNORECASE)
|
||||
if match:
|
||||
try:
|
||||
return float(match.group(1))
|
||||
except ValueError:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def _exponential_delay(
|
||||
attempt: int,
|
||||
*,
|
||||
initial_delay: float,
|
||||
backoff_factor: float,
|
||||
max_delay: float,
|
||||
jitter: bool,
|
||||
) -> float:
|
||||
"""Compute an exponential-backoff delay with optional ±25% jitter."""
|
||||
delay = (
|
||||
initial_delay * (backoff_factor**attempt) if backoff_factor else initial_delay
|
||||
)
|
||||
delay = min(delay, max_delay)
|
||||
if jitter and delay > 0:
|
||||
delay *= 1 + random.uniform(-0.25, 0.25)
|
||||
return max(delay, 0.0)
|
||||
|
||||
|
||||
class RetryAfterMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
|
||||
"""Retry middleware that honors provider-issued Retry-After hints.
|
||||
|
||||
Drop-in replacement for :class:`langchain.agents.middleware.ModelRetryMiddleware`
|
||||
when working with LiteLLM/Anthropic/OpenAI providers that surface
|
||||
rate-limit hints in headers. Always emits ``surfsense.retrying`` SSE
|
||||
events so the UI can show a friendly "rate limited, retrying in Xs"
|
||||
indicator.
|
||||
|
||||
Args:
|
||||
max_retries: Maximum retries after the initial attempt (default 3).
|
||||
initial_delay: Initial backoff delay in seconds.
|
||||
backoff_factor: Exponential growth factor for backoff.
|
||||
max_delay: Cap on per-attempt delay in seconds.
|
||||
jitter: Whether to add ±25% jitter.
|
||||
retry_on: Optional callable that returns True for retryable
|
||||
exceptions. The default retries everything except known
|
||||
non-retryable classes (context overflow, auth, etc.).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
max_retries: int = 3,
|
||||
initial_delay: float = 1.0,
|
||||
backoff_factor: float = 2.0,
|
||||
max_delay: float = 60.0,
|
||||
jitter: bool = True,
|
||||
retry_on: Callable[[BaseException], bool] | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.max_retries = max_retries
|
||||
self.initial_delay = initial_delay
|
||||
self.backoff_factor = backoff_factor
|
||||
self.max_delay = max_delay
|
||||
self.jitter = jitter
|
||||
self._retry_on: Callable[[BaseException], bool] = retry_on or (
|
||||
lambda exc: not _is_non_retryable(exc)
|
||||
)
|
||||
|
||||
def _should_retry(self, exc: BaseException) -> bool:
|
||||
try:
|
||||
return bool(self._retry_on(exc))
|
||||
except Exception:
|
||||
logger.exception("retry_on callable raised; defaulting to False")
|
||||
return False
|
||||
|
||||
def _delay_for_attempt(self, attempt: int, exc: BaseException) -> float:
|
||||
backoff = _exponential_delay(
|
||||
attempt,
|
||||
initial_delay=self.initial_delay,
|
||||
backoff_factor=self.backoff_factor,
|
||||
max_delay=self.max_delay,
|
||||
jitter=self.jitter,
|
||||
)
|
||||
header = _extract_retry_after_seconds(exc) or 0.0
|
||||
return max(backoff, header)
|
||||
|
||||
def wrap_model_call( # type: ignore[override]
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
|
||||
) -> ModelResponse[ResponseT] | AIMessage:
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
return handler(request)
|
||||
except Exception as exc:
|
||||
if not self._should_retry(exc) or attempt >= self.max_retries:
|
||||
raise
|
||||
delay = self._delay_for_attempt(attempt, exc)
|
||||
ot.add_event(
|
||||
"model.retry.scheduled",
|
||||
{
|
||||
"retry.attempt": attempt + 1,
|
||||
"retry.max": self.max_retries,
|
||||
"retry.delay_ms": int(delay * 1000),
|
||||
"retry.reason": ot_metrics.categorize_exception(exc),
|
||||
},
|
||||
)
|
||||
try:
|
||||
dispatch_custom_event(
|
||||
"surfsense.retrying",
|
||||
{
|
||||
"attempt": attempt + 1,
|
||||
"max_retries": self.max_retries,
|
||||
"delay_ms": int(delay * 1000),
|
||||
"reason": type(exc).__name__,
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"dispatch_custom_event failed; suppressed", exc_info=True
|
||||
)
|
||||
if delay > 0:
|
||||
time.sleep(delay)
|
||||
# Unreachable
|
||||
raise RuntimeError("RetryAfterMiddleware: retry loop exited without resolution")
|
||||
|
||||
async def awrap_model_call( # type: ignore[override]
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[
|
||||
[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
|
||||
],
|
||||
) -> ModelResponse[ResponseT] | AIMessage:
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
return await handler(request)
|
||||
except Exception as exc:
|
||||
if not self._should_retry(exc) or attempt >= self.max_retries:
|
||||
raise
|
||||
delay = self._delay_for_attempt(attempt, exc)
|
||||
ot.add_event(
|
||||
"model.retry.scheduled",
|
||||
{
|
||||
"retry.attempt": attempt + 1,
|
||||
"retry.max": self.max_retries,
|
||||
"retry.delay_ms": int(delay * 1000),
|
||||
"retry.reason": ot_metrics.categorize_exception(exc),
|
||||
},
|
||||
)
|
||||
try:
|
||||
await adispatch_custom_event(
|
||||
"surfsense.retrying",
|
||||
{
|
||||
"attempt": attempt + 1,
|
||||
"max_retries": self.max_retries,
|
||||
"delay_ms": int(delay * 1000),
|
||||
"reason": type(exc).__name__,
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"adispatch_custom_event failed; suppressed", exc_info=True
|
||||
)
|
||||
if delay > 0:
|
||||
await asyncio.sleep(delay)
|
||||
raise RuntimeError("RetryAfterMiddleware: retry loop exited without resolution")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"RetryAfterMiddleware",
|
||||
"_extract_retry_after_seconds",
|
||||
"_is_non_retryable",
|
||||
]
|
||||
|
|
@ -0,0 +1,111 @@
|
|||
"""Fallback only on provider/network errors; let programming bugs raise."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain.agents.middleware import ModelFallbackMiddleware
|
||||
|
||||
from app.observability import metrics as ot_metrics, otel as ot
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from langchain.agents.middleware.types import ModelRequest, ModelResponse
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
|
||||
# Matched by class name across the MRO so we don't have to import every
|
||||
# provider SDK (openai/anthropic/google/...). Extend as new providers ship.
|
||||
_FALLBACK_ELIGIBLE_NAMES: frozenset[str] = frozenset(
|
||||
{
|
||||
"RateLimitError",
|
||||
"APIStatusError",
|
||||
"InternalServerError",
|
||||
"ServiceUnavailableError",
|
||||
"BadGatewayError",
|
||||
"GatewayTimeoutError",
|
||||
"APIConnectionError",
|
||||
"APITimeoutError",
|
||||
"ConnectError",
|
||||
"ConnectTimeout",
|
||||
"ReadTimeout",
|
||||
"RemoteProtocolError",
|
||||
"TimeoutError",
|
||||
"TimeoutException",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _is_fallback_eligible(exc: BaseException) -> bool:
|
||||
return any(cls.__name__ in _FALLBACK_ELIGIBLE_NAMES for cls in type(exc).__mro__)
|
||||
|
||||
|
||||
class ScopedModelFallbackMiddleware(ModelFallbackMiddleware):
|
||||
"""Re-raise non-provider exceptions instead of walking the fallback chain."""
|
||||
|
||||
def wrap_model_call( # type: ignore[override]
|
||||
self,
|
||||
request: ModelRequest[Any],
|
||||
handler: Callable[[ModelRequest[Any]], ModelResponse[Any]],
|
||||
) -> ModelResponse[Any] | AIMessage:
|
||||
last_exception: Exception
|
||||
try:
|
||||
return handler(request)
|
||||
except Exception as e:
|
||||
if not _is_fallback_eligible(e):
|
||||
raise
|
||||
last_exception = e
|
||||
|
||||
for attempt, fallback_model in enumerate(self.models, start=1):
|
||||
ot.add_event(
|
||||
"model.fallback",
|
||||
{
|
||||
"fallback.attempt": attempt,
|
||||
"fallback.from": attempt - 1,
|
||||
"fallback.to": attempt,
|
||||
"fallback.reason": ot_metrics.categorize_exception(last_exception),
|
||||
},
|
||||
)
|
||||
try:
|
||||
return handler(request.override(model=fallback_model))
|
||||
except Exception as e:
|
||||
if not _is_fallback_eligible(e):
|
||||
raise
|
||||
last_exception = e
|
||||
continue
|
||||
|
||||
raise last_exception
|
||||
|
||||
async def awrap_model_call( # type: ignore[override]
|
||||
self,
|
||||
request: ModelRequest[Any],
|
||||
handler: Callable[[ModelRequest[Any]], Awaitable[ModelResponse[Any]]],
|
||||
) -> ModelResponse[Any] | AIMessage:
|
||||
last_exception: Exception
|
||||
try:
|
||||
return await handler(request)
|
||||
except Exception as e:
|
||||
if not _is_fallback_eligible(e):
|
||||
raise
|
||||
last_exception = e
|
||||
|
||||
for attempt, fallback_model in enumerate(self.models, start=1):
|
||||
ot.add_event(
|
||||
"model.fallback",
|
||||
{
|
||||
"fallback.attempt": attempt,
|
||||
"fallback.from": attempt - 1,
|
||||
"fallback.to": attempt,
|
||||
"fallback.reason": ot_metrics.categorize_exception(last_exception),
|
||||
},
|
||||
)
|
||||
try:
|
||||
return await handler(request.override(model=fallback_model))
|
||||
except Exception as e:
|
||||
if not _is_fallback_eligible(e):
|
||||
raise
|
||||
last_exception = e
|
||||
continue
|
||||
|
||||
raise last_exception
|
||||
|
|
@ -0,0 +1,344 @@
|
|||
"""Skills backends for SurfSense.
|
||||
|
||||
Implements two minimal :class:`deepagents.backends.protocol.BackendProtocol`
|
||||
subclasses tailored for use with :class:`deepagents.middleware.skills.SkillsMiddleware`.
|
||||
|
||||
The middleware only needs four methods to load skills from a backend:
|
||||
|
||||
* ``ls_info`` / ``als_info`` — list directories under a source path.
|
||||
* ``download_files`` / ``adownload_files`` — fetch ``SKILL.md`` bytes.
|
||||
|
||||
Other ``BackendProtocol`` methods (``read``/``write``/``edit``/``grep_raw`` …)
|
||||
default to ``NotImplementedError`` from the base class. They are never reached
|
||||
by the skills middleware because skill content is rendered into the system
|
||||
prompt at agent build time, not edited at runtime.
|
||||
|
||||
Two backends are provided:
|
||||
|
||||
* :class:`BuiltinSkillsBackend` — disk-backed read of bundled skills from
|
||||
``app/agents/new_chat/skills/builtin/``.
|
||||
* :class:`SearchSpaceSkillsBackend` — a thin read-only wrapper over
|
||||
:class:`KBPostgresBackend` that filters notes under the privileged folder
|
||||
``/documents/_skills/``.
|
||||
|
||||
Both backends are intentionally read-only: skill authoring happens out of band
|
||||
(via filesystem or a search-space-admin route), so we never expose
|
||||
``write`` / ``edit`` / ``upload_files``. The base class' ``NotImplementedError``
|
||||
gives a clean failure mode if anything tries.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from dataclasses import replace
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from deepagents.backends.composite import CompositeBackend
|
||||
from deepagents.backends.protocol import (
|
||||
BackendProtocol,
|
||||
FileDownloadResponse,
|
||||
FileInfo,
|
||||
)
|
||||
from deepagents.backends.state import StateBackend
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain.tools import ToolRuntime
|
||||
|
||||
from app.agents.shared.middleware.kb_postgres_backend import KBPostgresBackend
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Limit per Agent Skills spec; matches deepagents.middleware.skills.MAX_SKILL_FILE_SIZE.
|
||||
_MAX_SKILL_FILE_SIZE = 10 * 1024 * 1024
|
||||
|
||||
|
||||
def _default_builtin_root() -> Path:
|
||||
"""Return the absolute path to the bundled builtin skills directory.
|
||||
|
||||
The skill assets still live at ``app/agents/new_chat/skills/builtin/`` (the
|
||||
``skills/`` tree migrates to the shared kernel in a later slice). This module
|
||||
now lives under ``app/agents/shared/middleware/``, so we walk up to
|
||||
``app/agents/`` and back into ``new_chat/skills/builtin``. Once skills move,
|
||||
this becomes ``Path(__file__).resolve().parent.parent / "skills" / "builtin"``.
|
||||
"""
|
||||
agents_dir = Path(__file__).resolve().parent.parent.parent
|
||||
return (agents_dir / "new_chat" / "skills" / "builtin").resolve()
|
||||
|
||||
|
||||
class BuiltinSkillsBackend(BackendProtocol):
|
||||
"""Read-only disk-backed skills source.
|
||||
|
||||
Maps a virtual ``/skills/builtin/`` namespace onto a directory on local disk,
|
||||
where each skill is its own subdirectory containing a ``SKILL.md`` file::
|
||||
|
||||
<root>/<skill-name>/SKILL.md
|
||||
|
||||
The middleware calls :meth:`als_info` with the source path and expects a
|
||||
``list[FileInfo]`` whose ``is_dir=True`` entries are descended into. Then it
|
||||
calls :meth:`adownload_files` with the synthesized ``SKILL.md`` paths and
|
||||
parses YAML frontmatter from the returned ``content`` bytes.
|
||||
|
||||
Mounting under :class:`~deepagents.backends.composite.CompositeBackend` at
|
||||
prefix ``/skills/builtin/`` means the middleware can issue paths like
|
||||
``/skills/builtin/kb-research/SKILL.md`` which the composite strips down to
|
||||
``/kb-research/SKILL.md`` before forwarding here. We treat any leading
|
||||
slash as anchoring at :attr:`root`.
|
||||
"""
|
||||
|
||||
def __init__(self, root: Path | str | None = None) -> None:
|
||||
self.root: Path = Path(root).resolve() if root else _default_builtin_root()
|
||||
if not self.root.exists():
|
||||
logger.info(
|
||||
"BuiltinSkillsBackend root %s does not exist; skills will be empty.",
|
||||
self.root,
|
||||
)
|
||||
|
||||
def _resolve(self, path: str) -> Path:
|
||||
"""Resolve a virtual posix path under :attr:`root`, refusing escapes."""
|
||||
bare = path.lstrip("/")
|
||||
candidate = (self.root / bare).resolve() if bare else self.root
|
||||
# Refuse symlink/.. traversal that escapes the root.
|
||||
try:
|
||||
candidate.relative_to(self.root)
|
||||
except ValueError as exc:
|
||||
raise ValueError(f"path {path!r} escapes builtin skills root") from exc
|
||||
return candidate
|
||||
|
||||
def ls_info(self, path: str) -> list[FileInfo]:
|
||||
try:
|
||||
target = self._resolve(path)
|
||||
except ValueError as exc:
|
||||
logger.warning("BuiltinSkillsBackend.ls_info refused: %s", exc)
|
||||
return []
|
||||
if not target.exists() or not target.is_dir():
|
||||
return []
|
||||
|
||||
infos: list[FileInfo] = []
|
||||
# Build virtual paths anchored at "/" because CompositeBackend already
|
||||
# stripped the route prefix before calling us.
|
||||
target_virtual = (
|
||||
"/"
|
||||
if target == self.root
|
||||
else ("/" + str(target.relative_to(self.root)).replace("\\", "/"))
|
||||
)
|
||||
for child in sorted(target.iterdir()):
|
||||
if child.name == "__pycache__" or child.name.startswith("."):
|
||||
continue
|
||||
child_virtual = (
|
||||
target_virtual.rstrip("/") + "/" + child.name
|
||||
if target_virtual != "/"
|
||||
else "/" + child.name
|
||||
)
|
||||
info: FileInfo = {
|
||||
"path": child_virtual,
|
||||
"is_dir": child.is_dir(),
|
||||
}
|
||||
if child.is_file():
|
||||
with contextlib.suppress(OSError): # pragma: no cover - defensive
|
||||
info["size"] = child.stat().st_size
|
||||
infos.append(info)
|
||||
return infos
|
||||
|
||||
def download_files(self, paths: list[str]) -> list[FileDownloadResponse]:
|
||||
responses: list[FileDownloadResponse] = []
|
||||
for p in paths:
|
||||
try:
|
||||
target = self._resolve(p)
|
||||
except ValueError:
|
||||
responses.append(FileDownloadResponse(path=p, error="invalid_path"))
|
||||
continue
|
||||
if not target.exists():
|
||||
responses.append(FileDownloadResponse(path=p, error="file_not_found"))
|
||||
continue
|
||||
if target.is_dir():
|
||||
responses.append(FileDownloadResponse(path=p, error="is_directory"))
|
||||
continue
|
||||
try:
|
||||
# Hard cap to avoid loading rogue mega-files into memory.
|
||||
size = target.stat().st_size
|
||||
if size > _MAX_SKILL_FILE_SIZE:
|
||||
logger.warning(
|
||||
"Builtin skill file %s exceeds %d bytes; truncating.",
|
||||
target,
|
||||
_MAX_SKILL_FILE_SIZE,
|
||||
)
|
||||
with target.open("rb") as fh:
|
||||
content = fh.read(_MAX_SKILL_FILE_SIZE)
|
||||
else:
|
||||
content = target.read_bytes()
|
||||
except PermissionError:
|
||||
responses.append(
|
||||
FileDownloadResponse(path=p, error="permission_denied")
|
||||
)
|
||||
continue
|
||||
except OSError as exc: # pragma: no cover - defensive
|
||||
logger.warning("Builtin skill read failed %s: %s", target, exc)
|
||||
responses.append(FileDownloadResponse(path=p, error="file_not_found"))
|
||||
continue
|
||||
responses.append(FileDownloadResponse(path=p, content=content, error=None))
|
||||
return responses
|
||||
|
||||
|
||||
class SearchSpaceSkillsBackend(BackendProtocol):
|
||||
"""Read-only view of search-space-authored skills.
|
||||
|
||||
Wraps a :class:`KBPostgresBackend` and only ever reads under the privileged
|
||||
folder ``/documents/_skills/`` (configurable). The folder is intended to be
|
||||
writable only by search-space admins; this backend never writes.
|
||||
|
||||
The skills middleware expects a layout like::
|
||||
|
||||
/<source_root>/<skill-name>/SKILL.md
|
||||
|
||||
But the KB stores documents like ``/documents/_skills/<name>/SKILL.md``.
|
||||
We expose the inner namespace by remapping each path. When mounted under
|
||||
:class:`CompositeBackend` at prefix ``/skills/space/`` the paths the
|
||||
middleware sees become ``/skills/space/<name>/SKILL.md``; the composite
|
||||
strips ``/skills/space/`` and hands us ``/<name>/SKILL.md``, which we
|
||||
rewrite to ``/documents/_skills/<name>/SKILL.md`` before forwarding to the
|
||||
KB.
|
||||
|
||||
No new database table is needed: the privileged folder convention is
|
||||
enforced server-side outside of this class. We intentionally swallow any
|
||||
write/edit attempts (the base class raises ``NotImplementedError``).
|
||||
"""
|
||||
|
||||
DEFAULT_KB_ROOT: str = "/documents/_skills"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kb_backend: KBPostgresBackend,
|
||||
*,
|
||||
kb_root: str = DEFAULT_KB_ROOT,
|
||||
) -> None:
|
||||
self._kb = kb_backend
|
||||
# Normalize trailing slash off so we can join cleanly.
|
||||
self._kb_root = kb_root.rstrip("/") or "/"
|
||||
|
||||
def _to_kb(self, path: str) -> str:
|
||||
"""Rewrite a virtual path into the underlying KB namespace."""
|
||||
bare = path.lstrip("/")
|
||||
if not bare:
|
||||
return self._kb_root
|
||||
return f"{self._kb_root}/{bare}"
|
||||
|
||||
def _from_kb(self, kb_path: str) -> str:
|
||||
"""Rewrite a KB path back into our virtual namespace."""
|
||||
if not kb_path.startswith(self._kb_root):
|
||||
return kb_path # pragma: no cover - defensive
|
||||
rel = kb_path[len(self._kb_root) :]
|
||||
return rel if rel.startswith("/") else "/" + rel
|
||||
|
||||
def ls_info(self, path: str) -> list[FileInfo]:
|
||||
# KBPostgresBackend exposes only the async API meaningfully; the sync
|
||||
# path falls back to ``asyncio.to_thread(...)`` in the base class. We
|
||||
# keep this stub to satisfy abstract resolution; the middleware calls
|
||||
# ``als_info``.
|
||||
raise NotImplementedError("SearchSpaceSkillsBackend is async-only")
|
||||
|
||||
async def als_info(self, path: str) -> list[FileInfo]:
|
||||
kb_path = self._to_kb(path)
|
||||
try:
|
||||
infos = await self._kb.als_info(kb_path)
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logger.warning("SearchSpaceSkillsBackend.als_info failed: %s", exc)
|
||||
return []
|
||||
remapped: list[FileInfo] = []
|
||||
for info in infos:
|
||||
kb_p = info.get("path", "")
|
||||
if not kb_p.startswith(self._kb_root):
|
||||
continue
|
||||
remapped.append({**info, "path": self._from_kb(kb_p)})
|
||||
return remapped
|
||||
|
||||
def download_files(self, paths: list[str]) -> list[FileDownloadResponse]:
|
||||
raise NotImplementedError("SearchSpaceSkillsBackend is async-only")
|
||||
|
||||
async def adownload_files(self, paths: list[str]) -> list[FileDownloadResponse]:
|
||||
kb_paths = [self._to_kb(p) for p in paths]
|
||||
responses = await self._kb.adownload_files(kb_paths)
|
||||
# Re-map response paths back to the virtual namespace so the middleware
|
||||
# correlates them to the input list correctly.
|
||||
remapped: list[FileDownloadResponse] = []
|
||||
for original, resp in zip(paths, responses, strict=True):
|
||||
remapped.append(replace(resp, path=original))
|
||||
return remapped
|
||||
|
||||
|
||||
SKILLS_BUILTIN_PREFIX = "/skills/builtin/"
|
||||
SKILLS_SPACE_PREFIX = "/skills/space/"
|
||||
|
||||
|
||||
def build_skills_backend_factory(
|
||||
*,
|
||||
builtin_root: Path | str | None = None,
|
||||
search_space_id: int | None = None,
|
||||
) -> Callable[[ToolRuntime], BackendProtocol]:
|
||||
"""Return a runtime-aware factory for the skills :class:`CompositeBackend`.
|
||||
|
||||
When ``search_space_id`` is provided the composite includes a
|
||||
:class:`SearchSpaceSkillsBackend` route at ``/skills/space/`` over a fresh
|
||||
per-runtime :class:`KBPostgresBackend`, mirroring how
|
||||
:func:`build_backend_resolver` constructs the main filesystem backend.
|
||||
|
||||
When ``search_space_id`` is ``None`` (e.g., desktop-local mode or unit
|
||||
tests) only the bundled :class:`BuiltinSkillsBackend` is exposed.
|
||||
|
||||
Returning a factory rather than a fixed instance is intentional: the
|
||||
underlying KB backend depends on per-call ``ToolRuntime`` state
|
||||
(``staged_dirs``, ``files`` cache, runtime config), so a single shared
|
||||
instance cannot serve multiple concurrent agent runs.
|
||||
"""
|
||||
builtin = BuiltinSkillsBackend(builtin_root)
|
||||
|
||||
if search_space_id is None:
|
||||
|
||||
def _factory_builtin_only(runtime: ToolRuntime) -> BackendProtocol:
|
||||
# Default StateBackend is intentionally inert: any path outside the
|
||||
# ``/skills/builtin/`` route resolves to an empty per-runtime state
|
||||
# so the SkillsMiddleware can iterate sources without raising.
|
||||
return CompositeBackend(
|
||||
default=StateBackend(runtime),
|
||||
routes={SKILLS_BUILTIN_PREFIX: builtin},
|
||||
)
|
||||
|
||||
return _factory_builtin_only
|
||||
|
||||
def _factory_with_space(runtime: ToolRuntime) -> BackendProtocol:
|
||||
# Imported lazily to avoid a hard dependency at module import time:
|
||||
# ``KBPostgresBackend`` pulls in DB models, which are unnecessary for
|
||||
# the unit-tested builtin path.
|
||||
from app.agents.shared.middleware.kb_postgres_backend import (
|
||||
KBPostgresBackend,
|
||||
)
|
||||
|
||||
kb = KBPostgresBackend(search_space_id, runtime)
|
||||
space = SearchSpaceSkillsBackend(kb)
|
||||
return CompositeBackend(
|
||||
default=StateBackend(runtime),
|
||||
routes={
|
||||
SKILLS_BUILTIN_PREFIX: builtin,
|
||||
SKILLS_SPACE_PREFIX: space,
|
||||
},
|
||||
)
|
||||
|
||||
return _factory_with_space
|
||||
|
||||
|
||||
def default_skills_sources() -> list[str]:
|
||||
"""Return the canonical source list for SkillsMiddleware (built-in then space)."""
|
||||
return [SKILLS_BUILTIN_PREFIX, SKILLS_SPACE_PREFIX]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"SKILLS_BUILTIN_PREFIX",
|
||||
"SKILLS_SPACE_PREFIX",
|
||||
"BuiltinSkillsBackend",
|
||||
"SearchSpaceSkillsBackend",
|
||||
"build_skills_backend_factory",
|
||||
"default_skills_sources",
|
||||
]
|
||||
|
|
@ -0,0 +1,196 @@
|
|||
"""
|
||||
ToolCallNameRepairMiddleware — two-stage tool-name repair.
|
||||
|
||||
Operation:
|
||||
1. **Stage 1 — lowercase repair:** if a tool call's ``name`` is not in
|
||||
the registry but ``name.lower()`` is, rewrite in place. Catches
|
||||
models that emit ``Search`` instead of ``search``.
|
||||
2. **Stage 2 — invalid fallback:** if still unmatched, rewrite the call
|
||||
to ``invalid`` with ``args={"tool": original_name, "error": <error>}``
|
||||
so the registered :func:`invalid_tool` returns the error to the model
|
||||
for self-correction.
|
||||
|
||||
Ported from OpenCode's ``packages/opencode/src/session/llm.ts:339-358``
|
||||
+ ``packages/opencode/src/tool/invalid.ts``. LangChain has no equivalent:
|
||||
:class:`deepagents.middleware.PatchToolCallsMiddleware` patches
|
||||
*dangling* tool calls (no matching ToolMessage) but does nothing about
|
||||
wrong names, and the model framework's default behavior on an unknown
|
||||
name is to crash the turn rather than route to a self-correction
|
||||
fallback.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import difflib
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
ResponseT,
|
||||
)
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _coerce_existing_tool_call(call: Any) -> dict[str, Any]:
|
||||
"""Normalize a tool call entry to a mutable dict."""
|
||||
if isinstance(call, dict):
|
||||
return dict(call)
|
||||
return {
|
||||
"name": getattr(call, "name", None),
|
||||
"args": getattr(call, "args", {}),
|
||||
"id": getattr(call, "id", None),
|
||||
"type": "tool_call",
|
||||
}
|
||||
|
||||
|
||||
class ToolCallNameRepairMiddleware(
|
||||
AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]
|
||||
):
|
||||
"""Two-stage tool-name repair on the most recent ``AIMessage``.
|
||||
|
||||
Args:
|
||||
registered_tool_names: Set of canonically-registered tool names.
|
||||
``invalid`` should be in this set so the fallback dispatches.
|
||||
fuzzy_match_threshold: Optional ``difflib`` ratio (0-1) for the
|
||||
fuzzy-match step that runs *between* lowercase and invalid.
|
||||
Set to ``None`` to disable fuzzy matching (default in
|
||||
OpenCode; we mirror that to avoid silent rewrites).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
registered_tool_names: set[str],
|
||||
fuzzy_match_threshold: float | None = 0.85,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._registered = set(registered_tool_names)
|
||||
self._registered_lower = {name.lower(): name for name in self._registered}
|
||||
self._fuzzy_threshold = fuzzy_match_threshold
|
||||
self.tools = []
|
||||
|
||||
def _registered_for_runtime(self, runtime: Runtime[ContextT]) -> set[str]:
|
||||
"""Allow runtime overrides to expand the set (e.g. dynamic MCP tools)."""
|
||||
ctx_tools = getattr(runtime.context, "registered_tool_names", None)
|
||||
if isinstance(ctx_tools, set | frozenset):
|
||||
return self._registered | set(ctx_tools)
|
||||
if isinstance(ctx_tools, list | tuple):
|
||||
return self._registered | set(ctx_tools)
|
||||
return self._registered
|
||||
|
||||
def _repair_one(
|
||||
self,
|
||||
call: dict[str, Any],
|
||||
registered: set[str],
|
||||
) -> dict[str, Any]:
|
||||
name = call.get("name")
|
||||
if not isinstance(name, str):
|
||||
return call
|
||||
|
||||
if name in registered:
|
||||
return call
|
||||
|
||||
# Stage 1 — lowercase
|
||||
lowered = name.lower()
|
||||
if lowered in registered:
|
||||
call["name"] = lowered
|
||||
metadata = dict(call.get("response_metadata") or {})
|
||||
metadata.setdefault("repair", "lowercase")
|
||||
call["response_metadata"] = metadata
|
||||
return call
|
||||
|
||||
# Optional fuzzy step (off by default — see class docstring)
|
||||
if self._fuzzy_threshold is not None:
|
||||
close = difflib.get_close_matches(
|
||||
name, registered, n=1, cutoff=self._fuzzy_threshold
|
||||
)
|
||||
if close:
|
||||
call["name"] = close[0]
|
||||
metadata = dict(call.get("response_metadata") or {})
|
||||
metadata.setdefault("repair", f"fuzzy:{name}->{close[0]}")
|
||||
call["response_metadata"] = metadata
|
||||
return call
|
||||
|
||||
# Stage 2 — invalid fallback
|
||||
# Local import avoids a module-load cycle through the frozen single-agent
|
||||
# package (new_chat.__init__ -> chat_deepagent -> middleware shim).
|
||||
# Resolves to app.agents.shared.tools once tools migrate.
|
||||
from app.agents.new_chat.tools.invalid_tool import INVALID_TOOL_NAME
|
||||
|
||||
if INVALID_TOOL_NAME in registered:
|
||||
original_args = call.get("args") or {}
|
||||
error_msg = (
|
||||
f"Tool name '{name}' is not registered. "
|
||||
f"Original arguments were: {original_args!r}."
|
||||
)
|
||||
call["name"] = INVALID_TOOL_NAME
|
||||
call["args"] = {"tool": name, "error": error_msg}
|
||||
metadata = dict(call.get("response_metadata") or {})
|
||||
metadata.setdefault("repair", f"invalid_fallback:{name}")
|
||||
call["response_metadata"] = metadata
|
||||
else:
|
||||
logger.warning(
|
||||
"Could not repair unknown tool call %r; 'invalid' tool not registered",
|
||||
name,
|
||||
)
|
||||
return call
|
||||
|
||||
def _maybe_repair(
|
||||
self,
|
||||
message: AIMessage,
|
||||
registered: set[str],
|
||||
) -> AIMessage | None:
|
||||
if not message.tool_calls:
|
||||
return None
|
||||
|
||||
new_calls: list[dict[str, Any]] = []
|
||||
any_changed = False
|
||||
for raw in message.tool_calls:
|
||||
call = _coerce_existing_tool_call(raw)
|
||||
before = (call.get("name"), call.get("args"))
|
||||
repaired = self._repair_one(call, registered)
|
||||
after = (repaired.get("name"), repaired.get("args"))
|
||||
if before != after:
|
||||
any_changed = True
|
||||
new_calls.append(repaired)
|
||||
|
||||
if not any_changed:
|
||||
return None
|
||||
|
||||
return message.model_copy(update={"tool_calls": new_calls})
|
||||
|
||||
def after_model( # type: ignore[override]
|
||||
self,
|
||||
state: AgentState[ResponseT],
|
||||
runtime: Runtime[ContextT],
|
||||
) -> dict[str, Any] | None:
|
||||
messages = state.get("messages") or []
|
||||
if not messages:
|
||||
return None
|
||||
last = messages[-1]
|
||||
if not isinstance(last, AIMessage):
|
||||
return None
|
||||
|
||||
registered = self._registered_for_runtime(runtime)
|
||||
repaired = self._maybe_repair(last, registered)
|
||||
if repaired is None:
|
||||
return None
|
||||
return {"messages": [repaired]}
|
||||
|
||||
async def aafter_model( # type: ignore[override]
|
||||
self,
|
||||
state: AgentState[ResponseT],
|
||||
runtime: Runtime[ContextT],
|
||||
) -> dict[str, Any] | None:
|
||||
return self.after_model(state, runtime)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ToolCallNameRepairMiddleware",
|
||||
]
|
||||
|
|
@ -23,7 +23,7 @@ the receipt into the parent's ``receipts`` state via the append reducer.
|
|||
|
||||
The KB write path is the one exception: file-tool calls cannot emit a
|
||||
durable receipt because the actual DB writes happen end-of-turn inside
|
||||
:class:`app.agents.new_chat.middleware.kb_persistence.KnowledgeBasePersistenceMiddleware`.
|
||||
:class:`app.agents.shared.middleware.kb_persistence.KnowledgeBasePersistenceMiddleware`.
|
||||
KB tools therefore emit a *provisional* receipt with ``status="pending"``;
|
||||
the persistence middleware flips it to ``"success"`` or ``"failed"``
|
||||
before returning control to the parent.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue