feat: updated agent harness

This commit is contained in:
DESKTOP-RTLN3BA\$punk 2026-04-28 09:22:19 -07:00
parent 9ec9b64348
commit 31a372bb84
139 changed files with 12583 additions and 1111 deletions

View file

@ -1,11 +1,23 @@
"""Middleware components for the SurfSense new chat agent."""
from app.agents.new_chat.middleware.action_log import ActionLogMiddleware
from app.agents.new_chat.middleware.anonymous_document import (
AnonymousDocumentMiddleware,
)
from app.agents.new_chat.middleware.busy_mutex import BusyMutexMiddleware
from app.agents.new_chat.middleware.compaction import (
SurfSenseCompactionMiddleware,
create_surfsense_compaction_middleware,
)
from app.agents.new_chat.middleware.context_editing import (
ClearToolUsesEdit,
SpillingContextEditingMiddleware,
SpillToBackendEdit,
)
from app.agents.new_chat.middleware.dedup_tool_calls import (
DedupHITLToolCallsMiddleware,
)
from app.agents.new_chat.middleware.doom_loop import DoomLoopMiddleware
from app.agents.new_chat.middleware.file_intent import (
FileIntentMiddleware,
)
@ -26,16 +38,46 @@ from app.agents.new_chat.middleware.knowledge_tree import (
from app.agents.new_chat.middleware.memory_injection import (
MemoryInjectionMiddleware,
)
from app.agents.new_chat.middleware.noop_injection import NoopInjectionMiddleware
from app.agents.new_chat.middleware.otel_span import OtelSpanMiddleware
from app.agents.new_chat.middleware.permission import PermissionMiddleware
from app.agents.new_chat.middleware.retry_after import RetryAfterMiddleware
from app.agents.new_chat.middleware.skills_backends import (
BuiltinSkillsBackend,
SearchSpaceSkillsBackend,
build_skills_backend_factory,
default_skills_sources,
)
from app.agents.new_chat.middleware.tool_call_repair import (
ToolCallNameRepairMiddleware,
)
__all__ = [
"ActionLogMiddleware",
"AnonymousDocumentMiddleware",
"BuiltinSkillsBackend",
"BusyMutexMiddleware",
"ClearToolUsesEdit",
"DedupHITLToolCallsMiddleware",
"DoomLoopMiddleware",
"FileIntentMiddleware",
"KnowledgeBasePersistenceMiddleware",
"KnowledgeBaseSearchMiddleware",
"KnowledgePriorityMiddleware",
"KnowledgeTreeMiddleware",
"MemoryInjectionMiddleware",
"NoopInjectionMiddleware",
"OtelSpanMiddleware",
"PermissionMiddleware",
"RetryAfterMiddleware",
"SearchSpaceSkillsBackend",
"SpillToBackendEdit",
"SpillingContextEditingMiddleware",
"SurfSenseCompactionMiddleware",
"SurfSenseFilesystemMiddleware",
"ToolCallNameRepairMiddleware",
"build_skills_backend_factory",
"commit_staged_filesystem_state",
"create_surfsense_compaction_middleware",
"default_skills_sources",
]
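
A minimal sketch of one stack ordering these exports support, following the placement notes in the module docstrings below (ActionLogMiddleware sits just inside PermissionMiddleware and outside DedupHITLToolCallsMiddleware; OtelSpanMiddleware sits between BusyMutexMiddleware and RetryAfterMiddleware). Constructor arguments beyond those shown in this diff are assumptions.

from app.agents.new_chat.middleware import (
    ActionLogMiddleware,
    BusyMutexMiddleware,
    DedupHITLToolCallsMiddleware,
    OtelSpanMiddleware,
    PermissionMiddleware,
    RetryAfterMiddleware,
)

def build_middleware_stack(thread_id: int, search_space_id: int, user_id: str | None):
    # Outermost first: concurrency gate, then spans (one per retry attempt),
    # then retries, then permission checks, action logging, and dedup.
    return [
        BusyMutexMiddleware(),
        OtelSpanMiddleware(),
        RetryAfterMiddleware(max_retries=2),
        PermissionMiddleware(),  # constructor args elided in this sketch
        ActionLogMiddleware(
            thread_id=thread_id,
            search_space_id=search_space_id,
            user_id=user_id,
        ),
        DedupHITLToolCallsMiddleware(),
    ]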

View file

@ -0,0 +1,294 @@
"""Append-only action-log middleware for the SurfSense agent.
Wraps every tool call via :meth:`AgentMiddleware.awrap_tool_call` and writes
a row to :class:`~app.db.AgentActionLog` after the tool returns. Tools opt
into reversibility by declaring a ``reverse`` callable on their
:class:`~app.agents.new_chat.tools.registry.ToolDefinition`; the rendered
descriptor is persisted in ``reverse_descriptor`` for use by
``/api/threads/{thread_id}/revert/{action_id}``.
Design points:
* **Defensive.** Logging never blocks the agent. We catch every exception
on the DB write path and emit a warning; the tool's ``ToolMessage``
result is always returned untouched.
* **Lightweight payload.** Only the tool ``name`` + ``args`` (capped) +
``result_id`` + ``reverse_descriptor`` are stored. Tool output text
remains in the LangGraph checkpoint / spilled tool-output files.
* **Best-effort reversibility.** We invoke ``reverse(args, result_obj)``
with the parsed JSON result when the tool's content is a JSON object;
otherwise the raw text is passed. Exceptions in the reverse callable
are swallowed and logged; a failed descriptor render simply means the
action is NOT marked reversible.
"""
from __future__ import annotations
import json
import logging
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Any
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import ToolMessage
from app.agents.new_chat.feature_flags import get_flags
from app.agents.new_chat.tools.registry import ToolDefinition
if TYPE_CHECKING: # pragma: no cover - type-only
from langchain.agents.middleware.types import ToolCallRequest
from langgraph.types import Command
logger = logging.getLogger(__name__)
# Cap for the persisted ``args`` JSON to avoid bloating the action log with
# accidentally-huge inputs. Values are truncated and a flag is set in the
# stored payload so consumers can detect truncation.
_MAX_ARGS_PERSIST_BYTES = 32 * 1024 # 32KB
class ActionLogMiddleware(AgentMiddleware):
"""Persist a row in :class:`AgentActionLog` after every tool call.
Should be placed near the OUTERMOST end of the tool-call wrapping stack
so that it sees the *final* :class:`ToolMessage` after all retries,
permission checks, and dedup logic have run. In practice that means
placing it just inside :class:`PermissionMiddleware` and outside
:class:`DedupHITLToolCallsMiddleware`.
The middleware is fully a no-op when:
* the master kill-switch ``SURFSENSE_DISABLE_NEW_AGENT_STACK`` is set
(checked via :func:`get_flags`),
* the per-feature flag ``enable_action_log`` is off, or
* persistence raises (defensive: tool-call dispatch always succeeds).
Args:
thread_id: The current chat thread's primary-key id. Required to
persist a row; if ``None`` the middleware silently no-ops.
search_space_id: Search-space id for cascade-on-delete safety.
user_id: UUID string of the user driving this turn (nullable in
anonymous mode).
tool_definitions: Optional mapping of tool name -> :class:`ToolDefinition`
so the middleware can look up the tool's ``reverse`` callable.
When omitted, no actions are marked reversible.
"""
tools = ()
def __init__(
self,
*,
thread_id: int | None,
search_space_id: int,
user_id: str | None,
tool_definitions: dict[str, ToolDefinition] | None = None,
) -> None:
super().__init__()
self._thread_id = thread_id
self._search_space_id = search_space_id
self._user_id = user_id
self._tool_definitions = dict(tool_definitions or {})
def _enabled(self) -> bool:
flags = get_flags()
if flags.disable_new_agent_stack:
return False
return bool(flags.enable_action_log) and self._thread_id is not None
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[
[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]
],
) -> ToolMessage | Command[Any]:
if not self._enabled():
return await handler(request)
result: ToolMessage | Command[Any]
error_payload: dict[str, Any] | None = None
try:
result = await handler(request)
except Exception as exc:
# Persist the failure too so revert/audit can see it, then
# re-raise so downstream middleware (RetryAfter, etc.) handles it.
error_payload = {"type": type(exc).__name__, "message": str(exc)}
await self._record(
request=request,
result=None,
error_payload=error_payload,
)
raise
await self._record(request=request, result=result, error_payload=None)
return result
async def _record(
self,
*,
request: ToolCallRequest,
result: ToolMessage | Command[Any] | None,
error_payload: dict[str, Any] | None,
) -> None:
"""Persist one ``agent_action_log`` row. Defensive: never raises."""
try:
from app.db import AgentActionLog, shielded_async_session
tool_name = _resolve_tool_name(request)
args_payload = _resolve_args_payload(request)
result_id = _resolve_result_id(result)
reverse_descriptor, reversible = self._render_reverse(
tool_name=tool_name,
args=_resolve_args_dict(request),
result=result,
)
row = AgentActionLog(
thread_id=self._thread_id,
user_id=self._user_id,
search_space_id=self._search_space_id,
turn_id=_resolve_turn_id(request),
message_id=_resolve_message_id(request),
tool_name=tool_name,
args=args_payload,
result_id=result_id,
reversible=reversible,
reverse_descriptor=reverse_descriptor,
error=error_payload,
)
async with shielded_async_session() as session:
session.add(row)
await session.commit()
except Exception:
logger.warning(
"ActionLogMiddleware failed to persist action log row",
exc_info=True,
)
def _render_reverse(
self,
*,
tool_name: str,
args: dict[str, Any] | None,
result: ToolMessage | Command[Any] | None,
) -> tuple[dict[str, Any] | None, bool]:
"""Run the tool's ``reverse`` callable and return its descriptor.
Returns a tuple of ``(descriptor_or_None, reversible_bool)``. When
the tool has no ``reverse`` callable, or when the callable raises,
the action is marked non-reversible.
"""
if not result or not isinstance(result, ToolMessage):
return None, False
if args is None:
return None, False
tool_def = self._tool_definitions.get(tool_name)
if tool_def is None or tool_def.reverse is None:
return None, False
try:
parsed_result = _parse_tool_result_content(result)
descriptor = tool_def.reverse(args, parsed_result)
except Exception:
logger.warning(
"Reverse descriptor render failed for tool %s",
tool_name,
exc_info=True,
)
return None, False
if not isinstance(descriptor, dict):
return None, False
return descriptor, True
# ---------------------------------------------------------------------------
# Resolution helpers — defensive against tool_call request shape variation.
# ---------------------------------------------------------------------------
def _resolve_tool_name(request: Any) -> str:
try:
tool = getattr(request, "tool", None)
if tool is not None:
name = getattr(tool, "name", None)
if isinstance(name, str) and name:
return name
call = getattr(request, "tool_call", None) or {}
if isinstance(call, dict):
name = call.get("name")
if isinstance(name, str) and name:
return name
except Exception: # pragma: no cover - defensive
pass
return "unknown"
def _resolve_args_dict(request: Any) -> dict[str, Any] | None:
try:
call = getattr(request, "tool_call", None)
if not isinstance(call, dict):
return None
args = call.get("args")
if isinstance(args, dict):
return args
return None
except Exception: # pragma: no cover - defensive
return None
def _resolve_args_payload(request: Any) -> dict[str, Any] | None:
"""Return a JSON-serializable args dict, truncated if too big."""
args = _resolve_args_dict(request)
if args is None:
return None
try:
encoded = json.dumps(args, default=str)
except Exception:
return {"_repr": repr(args)[:_MAX_ARGS_PERSIST_BYTES]}
if len(encoded) <= _MAX_ARGS_PERSIST_BYTES:
return args
return {
"_truncated": True,
"_size": len(encoded),
"_preview": encoded[:_MAX_ARGS_PERSIST_BYTES],
}
def _resolve_turn_id(request: Any) -> str | None:
try:
call = getattr(request, "tool_call", None) or {}
if isinstance(call, dict):
tid = call.get("id")
if isinstance(tid, str):
return tid
except Exception: # pragma: no cover
pass
return None
def _resolve_message_id(request: Any) -> str | None:
"""Tool-call IDs serve as best-available message correlator at this layer."""
return _resolve_turn_id(request)
def _resolve_result_id(result: Any) -> str | None:
if isinstance(result, ToolMessage):
msg_id = getattr(result, "id", None)
if isinstance(msg_id, str):
return msg_id
return None
def _parse_tool_result_content(result: ToolMessage) -> Any:
content = result.content
if isinstance(content, str):
try:
return json.loads(content)
except (json.JSONDecodeError, ValueError):
return content
return content
__all__ = ["ActionLogMiddleware"]
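
A hypothetical registry entry illustrating the reverse contract described in the docstring above: ``reverse(args, parsed_result)`` must return a JSON-serializable dict, or the action is recorded as non-reversible. ToolDefinition fields other than ``name`` and ``reverse`` are assumptions.

from typing import Any

from app.agents.new_chat.tools.registry import ToolDefinition

def _reverse_create_notion_page(args: dict[str, Any], result: Any) -> dict[str, Any]:
    # ``result`` is the parsed JSON object when the ToolMessage content was
    # a JSON string, otherwise the raw text; guard accordingly.
    page_id = result.get("page_id") if isinstance(result, dict) else None
    return {
        "tool": "delete_notion_page",
        "args": {"page_title": args.get("title", ""), "page_id": page_id},
    }

create_notion_page_def = ToolDefinition(
    name="create_notion_page",
    reverse=_reverse_create_notion_page,  # omitted -> never marked reversible
)
# Passed to the middleware via
# ActionLogMiddleware(..., tool_definitions={"create_notion_page": create_notion_page_def})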

View file

@ -0,0 +1,231 @@
"""
BusyMutexMiddleware: per-thread asyncio lock + cancel token.
Tier 2.2 in the OpenCode-port plan. Mirrors opencode's
``Stream.scoped(AbortController)`` pattern (single-process, in-memory
lock + cooperative cancellation). For multi-worker deployments a
distributed lock backend (Redis or PostgreSQL advisory locks) is a
phase-2 follow-up.
What this provides:
- A ``WeakValueDictionary[str, asyncio.Lock]`` keyed by ``thread_id``;
acquiring the lock during ``before_agent`` blocks any concurrent
prompt on the same thread until release.
- A per-thread ``asyncio.Event`` (``cancel_event``) that long-running
tools can poll to abort cooperatively. The event is reset between
turns. Tools should check ``runtime.context.cancel_event.is_set()``
in tight inner loops.
- A typed :class:`~app.agents.new_chat.errors.BusyError` raised when a
second turn arrives while the lock is held.
Note: SurfSense's ``stream_new_chat`` is the call site that should
acquire/release. Wiring this as middleware means the contract is
explicit and the lock manager is shared with subagents that compile
their own ``create_agent`` runnables.
"""
from __future__ import annotations
import asyncio
import logging
import weakref
from typing import Any
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ContextT,
ResponseT,
)
from langgraph.config import get_config
from langgraph.runtime import Runtime
from app.agents.new_chat.errors import BusyError
logger = logging.getLogger(__name__)
class _ThreadLockManager:
"""Process-local registry of per-thread asyncio locks + cancel events."""
def __init__(self) -> None:
self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = (
weakref.WeakValueDictionary()
)
self._cancel_events: dict[str, asyncio.Event] = {}
def lock_for(self, thread_id: str) -> asyncio.Lock:
lock = self._locks.get(thread_id)
if lock is None:
lock = asyncio.Lock()
self._locks[thread_id] = lock
return lock
def cancel_event(self, thread_id: str) -> asyncio.Event:
event = self._cancel_events.get(thread_id)
if event is None:
event = asyncio.Event()
self._cancel_events[thread_id] = event
return event
def request_cancel(self, thread_id: str) -> bool:
event = self._cancel_events.get(thread_id)
if event is None:
return False
event.set()
return True
def reset(self, thread_id: str) -> None:
event = self._cancel_events.get(thread_id)
if event is not None:
event.clear()
# Module-level singleton — process-local but reused across all agent
# instances built in this process. Subagents created in nested
# ``create_agent`` calls also get this so locks are coherent.
manager = _ThreadLockManager()
def get_cancel_event(thread_id: str) -> asyncio.Event:
"""Public accessor used by long-running tools to poll cancellation."""
return manager.cancel_event(thread_id)
def request_cancel(thread_id: str) -> bool:
"""Trip the cancel event for ``thread_id``. Returns True if found."""
return manager.request_cancel(thread_id)
def reset_cancel(thread_id: str) -> None:
"""Reset the cancel event for ``thread_id`` (called between turns)."""
manager.reset(thread_id)
class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
"""Block concurrent prompts on the same thread.
Acquires the thread's lock in ``abefore_agent`` and releases in
``aafter_agent``. If the lock is held, raises :class:`BusyError`
so the caller can emit a ``surfsense.busy`` SSE event with the
in-flight request id.
Args:
require_thread_id: When True, raise :class:`BusyError` if no
``thread_id`` can be resolved from the active
``RunnableConfig``. Default is False: we treat a missing
thread_id as "this turn has nothing to lock against" and
no-op the mutex. Set True only when you trust the call
site to always provide ``configurable.thread_id`` (e.g.
in production where ``stream_new_chat`` always does).
"""
def __init__(self, *, require_thread_id: bool = False) -> None:
super().__init__()
self._require_thread_id = require_thread_id
self.tools = []
# Per-call locks owned by this middleware. We track them as
# an instance attribute so ``aafter_agent`` knows which lock
# to release.
self._held_locks: dict[str, asyncio.Lock] = {}
@staticmethod
def _thread_id(runtime: Runtime[ContextT]) -> str | None:
"""Extract ``thread_id`` from the active LangGraph ``RunnableConfig``.
``langgraph.runtime.Runtime`` deliberately does NOT expose ``config``.
The runnable config (where ``configurable.thread_id`` lives) must be
fetched via :func:`langgraph.config.get_config` from inside a node /
middleware. We fall back to ``getattr(runtime, "config", None)`` for
unit tests / legacy runtimes that synthesize a config-bearing stub.
"""
def _from_dict(cfg: Any) -> str | None:
if not isinstance(cfg, dict):
return None
tid = (cfg.get("configurable") or {}).get("thread_id")
return str(tid) if tid is not None else None
# Preferred path: real LangGraph runtime context.
try:
tid = _from_dict(get_config())
except Exception:
tid = None
if tid is not None:
return tid
# Fallback for tests and any runtime that surfaces a config dict
# directly on the runtime instance.
return _from_dict(getattr(runtime, "config", None))
async def abefore_agent( # type: ignore[override]
self,
state: AgentState[Any],
runtime: Runtime[ContextT],
) -> dict[str, Any] | None:
del state
thread_id = self._thread_id(runtime)
if thread_id is None:
if self._require_thread_id:
raise BusyError("no thread_id configured")
logger.debug(
"BusyMutexMiddleware: no thread_id resolved from RunnableConfig; "
"skipping per-thread lock for this turn."
)
return None
lock = manager.lock_for(thread_id)
if lock.locked():
raise BusyError(request_id=thread_id)
await lock.acquire()
self._held_locks[thread_id] = lock
# Reset the cancel event so this turn starts fresh
reset_cancel(thread_id)
return None
async def aafter_agent( # type: ignore[override]
self,
state: AgentState[Any],
runtime: Runtime[ContextT],
) -> dict[str, Any] | None:
del state
thread_id = self._thread_id(runtime)
if thread_id is None:
return None
lock = self._held_locks.pop(thread_id, None)
if lock is not None and lock.locked():
lock.release()
# Always clear cancel event between turns so a stale signal
# doesn't leak into the next request.
reset_cancel(thread_id)
return None
# Provide sync no-ops because the middleware base class allows them
def before_agent( # type: ignore[override]
self, state: AgentState[Any], runtime: Runtime[ContextT]
) -> dict[str, Any] | None:
# Sync path: no asyncio.Lock to acquire. Best we can do is reject
# if anyone else is in flight.
thread_id = self._thread_id(runtime)
if thread_id is None:
if self._require_thread_id:
raise BusyError("no thread_id configured")
return None
lock = manager.lock_for(thread_id)
if lock.locked():
raise BusyError(request_id=thread_id)
return None
def after_agent( # type: ignore[override]
self, state: AgentState[Any], runtime: Runtime[ContextT]
) -> dict[str, Any] | None:
return None
__all__ = [
"BusyMutexMiddleware",
"get_cancel_event",
"manager",
"request_cancel",
"reset_cancel",
]
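
A sketch of the cooperative-cancellation contract from the docstring above: a long-running tool polls the shared cancel event between work items. The tool body is illustrative; only ``get_cancel_event`` comes from this module.

import asyncio

from app.agents.new_chat.middleware.busy_mutex import get_cancel_event

async def crawl_urls(thread_id: str, urls: list[str]) -> list[str]:
    cancel = get_cancel_event(thread_id)
    collected: list[str] = []
    for url in urls:
        if cancel.is_set():
            # Abort cooperatively; BusyMutexMiddleware clears this event
            # between turns, so the signal never leaks into the next turn.
            raise asyncio.CancelledError(f"cancelled before {url}")
        await asyncio.sleep(0.1)  # stand-in for real per-item work
        collected.append(url)
    return collected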

View file

@ -0,0 +1,253 @@
"""
SurfSense compaction middleware.
Subclasses :class:`deepagents.middleware.summarization.SummarizationMiddleware`
to add SurfSense-specific behavior:
1. **Structured summary template** (OpenCode-style ``## Goal / Constraints /
Progress / Key Decisions / Next Steps / Critical Context / Relevant Files``).
2. **Protect SurfSense-specific SystemMessages** so injected hints
(``<priority_documents>``, ``<workspace_tree>``, ``<file_operation_contract>``,
``<user_memory>``, ``<team_memory>``, ``<user_name>``, ``<memory_warning>``)
are *not* summarized away and are kept verbatim in the post-summary
message list.
3. **Sanitize ``content=None``** when feeding messages into ``get_buffer_string``
(Azure OpenAI / LiteLLM defense: when a provider streams an AIMessage
containing only tool_calls and no text, ``content`` can be ``None`` and
``get_buffer_string`` crashes iterating over ``None``). This used to live in
``safe_summarization.py``; folded in here.
This replaces ``app.agents.new_chat.middleware.safe_summarization``.
Tier 1.3 in the OpenCode-port plan.
"""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any
from deepagents.middleware.summarization import (
SummarizationMiddleware,
compute_summarization_defaults,
)
from langchain_core.messages import SystemMessage
from app.observability import otel as ot
if TYPE_CHECKING:
from deepagents.backends.protocol import BACKEND_TYPES
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AnyMessage
logger = logging.getLogger(__name__)
# OpenCode-faithful structured summary template. Mirrors
# ``opencode/packages/opencode/src/session/compaction.ts:40-75``. Kept as a
# module-level constant so unit tests can assert on its sections.
SURFSENSE_SUMMARY_PROMPT = """<role>
SurfSense Conversation Compaction Assistant
</role>
<primary_objective>
Extract the most important context from the conversation history below into a structured summary that will replace the older messages.
</primary_objective>
<instructions>
You are running because the conversation has grown beyond the model's input window. The conversation history below will be summarized and replaced with your output. Use the structured template that follows; keep each section concise but comprehensive enough that the agent can resume work without losing context. Each section is a checklist — populate it with relevant content or write "None" if there is nothing to report.
## Goal
What is the user's primary goal or request? State it in one or two sentences.
## Constraints
What boundaries must the agent respect (citations rules, visibility scope, allowed tools, user-imposed style, deadlines, deny-listed topics)?
## Progress
What has the agent already accomplished? List each completed step succinctly. Do not reproduce tool output; just record the conclusion.
## Key Decisions
What choices were made and why? Include rejected alternatives and the reasoning behind selecting the current path.
## Next Steps
What specific tasks remain to achieve the goal? Order them by dependency.
## Critical Context
What facts, IDs, document titles, query keywords, error messages, or partial answers must persist into the next turn? Include verbatim quotes only when the exact wording matters (e.g. a precise filter clause or a literal name).
## Relevant Files
What documents or paths in the SurfSense knowledge base are in play? Use ``/documents/...`` paths exactly as they appeared in the workspace tree.
</instructions>
<messages>
Messages to summarize:
{messages}
</messages>
Respond ONLY with the structured summary. Do not include any text before or after.
"""
# SystemMessage prefixes that must NOT be summarized away. They are
# re-injected on every turn by the corresponding middleware, but the
# compaction step happens *before* re-injection in some paths, so we
# must preserve them verbatim across the cutoff.
PROTECTED_SYSTEM_PREFIXES: tuple[str, ...] = (
"<priority_documents>", # KnowledgePriorityMiddleware
"<workspace_tree>", # KnowledgeTreeMiddleware
"<file_operation_contract>", # FileIntentMiddleware
"<user_memory>", # MemoryInjectionMiddleware
"<team_memory>", # MemoryInjectionMiddleware
"<user_name>", # MemoryInjectionMiddleware
"<memory_warning>", # MemoryInjectionMiddleware
)
def _is_protected_system_message(msg: AnyMessage) -> bool:
"""Return True if ``msg`` is a SystemMessage we must not summarize."""
if not isinstance(msg, SystemMessage):
return False
content = msg.content
if not isinstance(content, str):
return False
stripped = content.lstrip()
return any(stripped.startswith(prefix) for prefix in PROTECTED_SYSTEM_PREFIXES)
def _sanitize_message_content(msg: AnyMessage) -> AnyMessage:
"""Return ``msg`` with ``content=None`` coerced to ``""``.
Folds in the historical defense from ``safe_summarization.py``:
``get_buffer_string`` reads ``m.text`` which iterates ``self.content``,
so a ``None`` content (Azure OpenAI / LiteLLM streaming a tool-only
AIMessage) explodes. We return a copy with empty string content so
downstream consumers see an empty body without mutating the original.
"""
if getattr(msg, "content", "not-missing") is not None:
return msg
try:
return msg.model_copy(update={"content": ""})
except AttributeError:
import copy
new_msg = copy.copy(msg)
try:
new_msg.content = ""
except Exception:
logger.debug(
"Could not sanitize content=None on message of type %s",
type(msg).__name__,
)
return msg
return new_msg
class SurfSenseCompactionMiddleware(SummarizationMiddleware):
"""SummarizationMiddleware tuned for SurfSense.
Notes
-----
- Overrides :meth:`_partition_messages` so protected SystemMessages
survive into the ``preserved_messages`` half regardless of cutoff.
- Overrides :meth:`_filter_summary_messages` so the buffer-string path
never iterates ``None`` content.
- Inherits everything else (auto-trigger, backend offload,
``_summarization_event`` plumbing, ``ContextOverflowError`` fallback).
"""
def _partition_messages( # type: ignore[override]
self,
conversation_messages: list[AnyMessage],
cutoff_index: int,
) -> tuple[list[AnyMessage], list[AnyMessage]]:
"""Split messages but always preserve SurfSense protected SystemMessages.
Mirrors OpenCode's ``PRUNE_PROTECTED_TOOLS`` philosophy
(``opencode/packages/opencode/src/session/compaction.ts``): some
message types are always kept verbatim because they are part of the
agent's working contract, not transient output.
Also opens a ``compaction.run`` OTel span (no-op when OTel is off)
so dashboards can count compaction events and message-volume
without having to instrument upstream callers.
"""
# Opening a span here is appropriate because partitioning is the
# first call SummarizationMiddleware makes when it has decided to
# summarize; we record the volume and then close as a normal span.
with ot.compaction_span(
reason="auto",
messages_in=len(conversation_messages),
extra={"compaction.cutoff_index": int(cutoff_index)},
):
messages_to_summarize, preserved_messages = (
super()._partition_messages(conversation_messages, cutoff_index)
)
protected: list[AnyMessage] = []
kept_for_summary: list[AnyMessage] = []
for msg in messages_to_summarize:
if _is_protected_system_message(msg):
protected.append(msg)
else:
kept_for_summary.append(msg)
# Place protected blocks at the *front* of preserved_messages so
# they keep their original ordering relative to the summary
# HumanMessage that precedes the rest of the preserved tail.
return kept_for_summary, [*protected, *preserved_messages]
def _filter_summary_messages( # type: ignore[override]
self, messages: list[AnyMessage]
) -> list[AnyMessage]:
"""Filter previous summaries AND sanitize ``content=None``.
Folds the ``safe_summarization.py`` defense in: when the buffer
builder iterates ``m.text`` over ``None`` it explodes; sanitizing
here covers both the sync and async offload paths.
"""
filtered = super()._filter_summary_messages(messages)
return [_sanitize_message_content(m) for m in filtered]
def create_surfsense_compaction_middleware(
model: BaseChatModel,
backend: BACKEND_TYPES,
*,
summary_prompt: str | None = None,
history_path_prefix: str = "/conversation_history",
**overrides: Any,
) -> SurfSenseCompactionMiddleware:
"""Build a :class:`SurfSenseCompactionMiddleware` with sensible defaults.
Pulls profile-aware ``trigger`` / ``keep`` / ``truncate_args_settings``
via :func:`deepagents.middleware.summarization.compute_summarization_defaults`
so callers get the same behavior as ``create_summarization_middleware``
plus our overrides.
Args:
model: Chat model to call for summary generation.
backend: Backend instance or factory for offloading conversation history.
summary_prompt: Optional override; defaults to :data:`SURFSENSE_SUMMARY_PROMPT`.
history_path_prefix: Path prefix for offloaded conversation history.
**overrides: Forwarded to :class:`SurfSenseCompactionMiddleware`.
"""
defaults = compute_summarization_defaults(model)
return SurfSenseCompactionMiddleware(
model=model,
backend=backend,
trigger=overrides.pop("trigger", defaults["trigger"]),
keep=overrides.pop("keep", defaults["keep"]),
trim_tokens_to_summarize=overrides.pop("trim_tokens_to_summarize", None),
truncate_args_settings=overrides.pop(
"truncate_args_settings", defaults["truncate_args_settings"]
),
summary_prompt=summary_prompt or SURFSENSE_SUMMARY_PROMPT,
history_path_prefix=history_path_prefix,
**overrides,
)
__all__ = [
"PROTECTED_SYSTEM_PREFIXES",
"SURFSENSE_SUMMARY_PROMPT",
"SurfSenseCompactionMiddleware",
"create_surfsense_compaction_middleware",
]
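
A minimal wiring sketch, assuming the caller already holds a chat model and a deepagents backend; only the factory and its keyword names are taken from the code above.

from langchain_core.language_models import BaseChatModel

from app.agents.new_chat.middleware.compaction import (
    create_surfsense_compaction_middleware,
)

def build_compaction(model: BaseChatModel, backend):
    # ``backend`` is a backend instance or factory (BACKEND_TYPES).
    return create_surfsense_compaction_middleware(
        model,
        backend,
        history_path_prefix="/conversation_history",
        # trigger / keep / truncate_args_settings fall back to the
        # profile-aware defaults from compute_summarization_defaults(model).
    )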

View file

@ -0,0 +1,349 @@
"""
SpillToBackendEdit + SpillingContextEditingMiddleware.
Mirrors OpenCode's spill-to-disk behavior in
``opencode/packages/opencode/src/tool/truncate.ts``. Before
``ClearToolUsesEdit`` rewrites old ``ToolMessage.content`` to a placeholder,
we capture the full original content and write it to the runtime backend
under ``/tool_outputs/{thread_id}/{message_id}.txt``. The placeholder is
upgraded to ``"[cleared — full output at /tool_outputs/.../{id}.txt; ask the
explore subagent to read it]"`` so the agent can recover it on demand.
Tier 1.2 in the OpenCode-port plan.
Why this is a middleware subclass instead of a plain ``ContextEdit``:
``ContextEdit.apply`` is sync, but writing to the backend is async. We
capture the spill payloads inside ``apply`` and flush them via
``await backend.aupload_files(...)`` from ``awrap_model_call`` *before*
delegating to the handler, so the explore subagent can always read what
the placeholder advertises.
"""
from __future__ import annotations
import logging
import threading
from collections.abc import Awaitable, Callable, Sequence
from copy import deepcopy
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from langchain.agents.middleware.context_editing import (
ClearToolUsesEdit,
ContextEdit,
ContextEditingMiddleware,
TokenCounter,
)
from langchain_core.messages import (
AIMessage,
AnyMessage,
BaseMessage,
ToolMessage,
)
from langchain_core.messages.utils import count_tokens_approximately
from langgraph.config import get_config
if TYPE_CHECKING:
from deepagents.backends.protocol import BackendProtocol
from langchain.agents.middleware.types import (
ModelRequest,
ModelResponse,
)
logger = logging.getLogger(__name__)
DEFAULT_SPILL_PREFIX = "/tool_outputs"
def _build_spill_placeholder(spill_path: str) -> str:
"""Build the user-facing placeholder text shown to the model."""
return (
f"[cleared — full output at {spill_path}; "
f"ask the explore subagent to read it]"
)
def _get_thread_id_or_session() -> str:
"""Best-effort thread_id discovery for the spill path.
Falls back to a process-stable string if no LangGraph config is
available (e.g. unit tests). The exact value doesn't matter as long
as it's stable within one stream so the placeholder paths line up
with the actual upload path.
"""
try:
config = get_config()
thread_id = config.get("configurable", {}).get("thread_id")
if thread_id is not None:
return str(thread_id)
except RuntimeError:
pass
return "no_thread"
@dataclass(slots=True)
class SpillToBackendEdit(ContextEdit):
"""Capture-and-replace context edit that spills full tool output to the backend.
Behaves like :class:`ClearToolUsesEdit` (same trigger / keep / exclude
semantics) **and** records the original ``ToolMessage.content`` in
:attr:`pending_spills` so the wrapping middleware can flush them
before the model call.
Args:
trigger: Token threshold above which the edit fires.
clear_at_least: Minimum number of tokens to reclaim (best effort).
keep: Number of most-recent ``ToolMessage`` instances to leave
untouched.
exclude_tools: Names of tools whose output is NOT spilled.
clear_tool_inputs: Also clear the originating ``AIMessage.tool_calls``
args when their pair is cleared.
path_prefix: Path under the backend where spills are written.
Default ``"/tool_outputs"``.
"""
trigger: int = 100_000
clear_at_least: int = 0
keep: int = 3
clear_tool_inputs: bool = False
exclude_tools: Sequence[str] = ()
path_prefix: str = DEFAULT_SPILL_PREFIX
pending_spills: list[tuple[str, bytes]] = field(default_factory=list)
_lock: threading.Lock = field(default_factory=threading.Lock)
def drain_pending(self) -> list[tuple[str, bytes]]:
"""Return and clear the pending-spill list atomically."""
with self._lock:
out = list(self.pending_spills)
self.pending_spills.clear()
return out
def apply(
self,
messages: list[AnyMessage],
*,
count_tokens: TokenCounter,
) -> None:
"""Mirror ``ClearToolUsesEdit.apply`` but capture originals first."""
tokens = count_tokens(messages)
if tokens <= self.trigger:
return
candidates = [
(idx, msg) for idx, msg in enumerate(messages) if isinstance(msg, ToolMessage)
]
if self.keep >= len(candidates):
return
if self.keep:
candidates = candidates[: -self.keep]
thread_id = _get_thread_id_or_session()
excluded_tools = set(self.exclude_tools)
for idx, tool_message in candidates:
if tool_message.response_metadata.get("context_editing", {}).get("cleared"):
continue
ai_message = next(
(m for m in reversed(messages[:idx]) if isinstance(m, AIMessage)),
None,
)
if ai_message is None:
continue
tool_call = next(
(
call
for call in ai_message.tool_calls
if call.get("id") == tool_message.tool_call_id
),
None,
)
if tool_call is None:
continue
tool_name = tool_message.name or tool_call["name"]
if tool_name in excluded_tools:
continue
message_id = tool_message.id or tool_message.tool_call_id or "unknown"
spill_path = f"{self.path_prefix}/{thread_id}/{message_id}.txt"
original = tool_message.content
payload = self._encode_payload(original)
with self._lock:
self.pending_spills.append((spill_path, payload))
messages[idx] = tool_message.model_copy(
update={
"artifact": None,
"content": _build_spill_placeholder(spill_path),
"response_metadata": {
**tool_message.response_metadata,
"context_editing": {
"cleared": True,
"strategy": "spill_to_backend",
"spill_path": spill_path,
},
},
}
)
if self.clear_tool_inputs:
ai_idx = messages.index(ai_message)
messages[ai_idx] = self._clear_input_args(
ai_message, tool_message.tool_call_id or ""
)
if self.clear_at_least > 0:
new_token_count = count_tokens(messages)
cleared_tokens = max(0, tokens - new_token_count)
if cleared_tokens >= self.clear_at_least:
break
@staticmethod
def _encode_payload(content: Any) -> bytes:
"""Serialize ``ToolMessage.content`` to bytes for upload."""
if isinstance(content, bytes):
return content
if isinstance(content, str):
return content.encode("utf-8")
try:
import json
return json.dumps(content, default=str).encode("utf-8")
except Exception:
return str(content).encode("utf-8")
@staticmethod
def _clear_input_args(message: AIMessage, tool_call_id: str) -> AIMessage:
updated_tool_calls: list[dict[str, Any]] = []
cleared_any = False
for tool_call in message.tool_calls:
updated = dict(tool_call)
if updated.get("id") == tool_call_id:
updated["args"] = {}
cleared_any = True
updated_tool_calls.append(updated)
metadata = dict(getattr(message, "response_metadata", {}))
if cleared_any:
ctx = dict(metadata.get("context_editing", {}))
ids = set(ctx.get("cleared_tool_inputs", []))
ids.add(tool_call_id)
ctx["cleared_tool_inputs"] = sorted(ids)
metadata["context_editing"] = ctx
return message.model_copy(
update={
"tool_calls": updated_tool_calls,
"response_metadata": metadata,
}
)
BackendResolver = "Callable[[Any], BackendProtocol] | BackendProtocol"
class SpillingContextEditingMiddleware(ContextEditingMiddleware):
""":class:`ContextEditingMiddleware` that flushes :class:`SpillToBackendEdit` writes.
Runs the configured edits as the parent does, then flushes any
pending spills via the supplied backend resolver before delegating
to the model handler. Spill failures are logged but never abort the
model call; the placeholder text is already in the message, so the
worst case is the agent gets a placeholder it cannot follow up on.
"""
def __init__(
self,
*,
edits: Sequence[ContextEdit],
backend_resolver: BackendResolver | None = None,
token_count_method: str = "approximate",
) -> None:
super().__init__(edits=list(edits), token_count_method=token_count_method) # type: ignore[arg-type]
self._backend_resolver = backend_resolver
def _resolve_backend(self, request: ModelRequest) -> BackendProtocol | None:
if self._backend_resolver is None:
return None
if callable(self._backend_resolver):
try:
from langchain.tools import ToolRuntime
tool_runtime = ToolRuntime(
state=getattr(request, "state", {}),
context=getattr(request.runtime, "context", None),
stream_writer=getattr(request.runtime, "stream_writer", None),
store=getattr(request.runtime, "store", None),
config=getattr(request.runtime, "config", None) or {},
tool_call_id=None,
)
return self._backend_resolver(tool_runtime)
except Exception:
logger.exception("Failed to resolve spill backend")
return None
return self._backend_resolver # type: ignore[return-value]
def _collect_pending(self) -> list[tuple[str, bytes]]:
out: list[tuple[str, bytes]] = []
for edit in self.edits:
if isinstance(edit, SpillToBackendEdit):
out.extend(edit.drain_pending())
return out
async def awrap_model_call( # type: ignore[override]
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> Any:
if not request.messages:
return await handler(request)
if self.token_count_method == "approximate":
def count_tokens(messages: Sequence[BaseMessage]) -> int:
return count_tokens_approximately(messages)
else:
system_msg = [request.system_message] if request.system_message else []
def count_tokens(messages: Sequence[BaseMessage]) -> int:
return request.model.get_num_tokens_from_messages(
system_msg + list(messages), request.tools
)
edited_messages = deepcopy(list(request.messages))
for edit in self.edits:
edit.apply(edited_messages, count_tokens=count_tokens)
pending = self._collect_pending()
if pending:
backend = self._resolve_backend(request)
if backend is not None:
try:
await backend.aupload_files(pending)
except Exception:
logger.exception(
"Spill-to-backend upload failed (%d files); placeholders "
"remain in messages but content is unrecoverable",
len(pending),
)
else:
logger.warning(
"SpillToBackendEdit produced %d pending spills but no backend "
"resolver was configured; content is unrecoverable",
len(pending),
)
return await handler(request.override(messages=edited_messages))
__all__ = [
"DEFAULT_SPILL_PREFIX",
"ClearToolUsesEdit",
"SpillToBackendEdit",
"SpillingContextEditingMiddleware",
"_build_spill_placeholder",
]
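
A sketch of wiring the spill edit into the middleware. ``backend_resolver`` follows the ``BackendResolver`` shape above; the threshold values mirror the dataclass defaults, and the exclusion list is illustrative.

from app.agents.new_chat.middleware.context_editing import (
    SpillToBackendEdit,
    SpillingContextEditingMiddleware,
)

def build_context_editing(backend_resolver):
    # ``backend_resolver`` is a BackendProtocol instance or a callable
    # receiving a ToolRuntime, per BackendResolver above.
    spill_edit = SpillToBackendEdit(
        trigger=100_000,          # fire once context passes ~100k tokens
        keep=3,                   # leave the 3 newest ToolMessages intact
        exclude_tools=("ls",),    # hypothetical: cheap output, never spill
    )
    return SpillingContextEditingMiddleware(
        edits=[spill_edit],
        backend_resolver=backend_resolver,
    )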

View file

@ -2,17 +2,28 @@
When the LLM emits multiple calls to the same HITL tool with the same
primary argument (e.g. two ``delete_calendar_event("Doctor Appointment")``),
only the first call is kept. Non-HITL tools are never touched.
This runs in the ``after_model`` hook **before** any tool executes so
the duplicate call is stripped from the AIMessage that gets checkpointed.
That means it is also safe across LangGraph ``interrupt()`` boundaries:
the removed call will never appear on graph resume.
Dedup-key resolution order (Tier 2.3 / cleanup in the OpenCode-port plan):
1. :class:`ToolDefinition.dedup_key` callable provided by the registry
entry. This is the canonical mechanism after the cleanup-tier removal
of the legacy ``PRIMARY_ARG`` map.
2. ``tool.metadata["hitl_dedup_key"]`` string with a primary arg name;
used by MCP / Composio tools whose schemas the registry doesn't see.
A tool with no resolver from either path simply opts out of dedup.
"""
from __future__ import annotations
import logging
from collections.abc import Callable
from typing import Any
from langchain.agents.middleware import AgentMiddleware, AgentState
@ -20,81 +31,84 @@ from langgraph.runtime import Runtime
logger = logging.getLogger(__name__)
_NATIVE_HITL_TOOL_DEDUP_KEYS: dict[str, str] = {
# Gmail
"send_gmail_email": "subject",
"create_gmail_draft": "subject",
"update_gmail_draft": "draft_subject_or_id",
"trash_gmail_email": "email_subject_or_id",
# Google Calendar
"create_calendar_event": "title",
"update_calendar_event": "event_title_or_id",
"delete_calendar_event": "event_title_or_id",
# Google Drive
"create_google_drive_file": "file_name",
"delete_google_drive_file": "file_name",
# OneDrive
"create_onedrive_file": "file_name",
"delete_onedrive_file": "file_name",
# Dropbox
"create_dropbox_file": "file_name",
"delete_dropbox_file": "file_name",
# Notion
"create_notion_page": "title",
"update_notion_page": "page_title",
"delete_notion_page": "page_title",
# Linear
"create_linear_issue": "title",
"update_linear_issue": "issue_ref",
"delete_linear_issue": "issue_ref",
# Jira
"create_jira_issue": "summary",
"update_jira_issue": "issue_title_or_key",
"delete_jira_issue": "issue_title_or_key",
# Confluence
"create_confluence_page": "title",
"update_confluence_page": "page_title_or_id",
"delete_confluence_page": "page_title_or_id",
}
# Resolver type — given the tool ``args`` dict returns a stable
# string used to dedupe consecutive calls. ``None`` means no dedup.
DedupResolver = Callable[[dict[str, Any]], str]
def wrap_dedup_key_by_arg_name(arg_name: str) -> DedupResolver:
"""Adapt a string-arg name into a :data:`DedupResolver`.
Convenience helper used by registry entries that just want to dedupe
on a single arg's lowercased value (the most common case for native
HITL tools like ``send_gmail_email`` keyed on ``subject``).
Example::
ToolDefinition(
name="send_gmail_email",
...,
dedup_key=wrap_dedup_key_by_arg_name("subject"),
)
"""
def _resolver(args: dict[str, Any]) -> str:
return str(args.get(arg_name, "")).lower()
return _resolver
# Backwards-compatible alias for code that imported the original
# private name. New callers should use :func:`wrap_dedup_key_by_arg_name`.
_wrap_string_key = wrap_dedup_key_by_arg_name
class DedupHITLToolCallsMiddleware(AgentMiddleware): # type: ignore[type-arg]
"""Remove duplicate HITL tool calls from a single LLM response.
Only the **first** occurrence of each (tool-name, primary-arg-value)
Only the **first** occurrence of each ``(tool-name, dedup_key)``
pair is kept; subsequent duplicates are silently dropped.
The dedup map is built from two sources:
The dedup-resolver map is built from two sources, in priority order:
1. A comprehensive list of native HITL tools (hardcoded above).
2. Any ``StructuredTool`` instances passed via *agent_tools* whose
``metadata`` contains ``{"hitl": True, "hitl_dedup_key": "..."}``.
This is how MCP tools automatically get dedup support.
1. ``tool.metadata["dedup_key"]`` callable provided by the registry's
``ToolDefinition.dedup_key`` (Tier 2.3). Receives the args dict
and returns a string signature. This is the canonical mechanism
after the cleanup-tier removal of the legacy ``PRIMARY_ARG`` map.
2. ``tool.metadata["hitl_dedup_key"]`` string with a primary arg
name; primarily used by MCP / Composio tools.
"""
tools = ()
def __init__(self, *, agent_tools: list[Any] | None = None) -> None:
self._dedup_keys: dict[str, str] = dict(_NATIVE_HITL_TOOL_DEDUP_KEYS)
self._resolvers: dict[str, DedupResolver] = {}
for t in agent_tools or []:
meta = getattr(t, "metadata", None) or {}
callable_key = meta.get("dedup_key")
if callable(callable_key):
self._resolvers[t.name] = callable_key
continue
if meta.get("hitl") and meta.get("hitl_dedup_key"):
self._dedup_keys[t.name] = meta["hitl_dedup_key"]
self._resolvers[t.name] = wrap_dedup_key_by_arg_name(
meta["hitl_dedup_key"]
)
def after_model(
self, state: AgentState, runtime: Runtime[Any]
) -> dict[str, Any] | None:
return self._dedup(state, self._dedup_keys)
return self._dedup(state, self._resolvers)
async def aafter_model(
self, state: AgentState, runtime: Runtime[Any]
) -> dict[str, Any] | None:
return self._dedup(state, self._dedup_keys)
return self._dedup(state, self._resolvers)
@staticmethod
def _dedup(
state: AgentState,
dedup_keys: dict[str, str], # type: ignore[type-arg]
resolvers: dict[str, DedupResolver],
) -> dict[str, Any] | None:
messages = state.get("messages")
if not messages:
@ -110,9 +124,16 @@ class DedupHITLToolCallsMiddleware(AgentMiddleware): # type: ignore[type-arg]
for tc in tool_calls:
name = tc.get("name", "")
dedup_key_arg = dedup_keys.get(name)
if dedup_key_arg is not None:
arg_val = str(tc.get("args", {}).get(dedup_key_arg, "")).lower()
resolver = resolvers.get(name)
if resolver is not None:
try:
arg_val = resolver(tc.get("args", {}) or {})
except Exception:
logger.exception(
"Dedup resolver for tool %s raised; keeping call", name
)
deduped.append(tc)
continue
key = (name, arg_val)
if key in seen:
logger.info(

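Both opt-in paths from the resolution order above, sketched. ``wrap_dedup_key_by_arg_name`` is taken from this module; the composite calendar key is illustrative.

from typing import Any

from app.agents.new_chat.middleware.dedup_tool_calls import (
    wrap_dedup_key_by_arg_name,
)

# Path 1 (canonical): a ToolDefinition.dedup_key callable. A composite key
# treats "same title on the same day" as a duplicate.
def calendar_dedup_key(args: dict[str, Any]) -> str:
    return f"{str(args.get('title', '')).lower()}::{args.get('start_date', '')}"

# Path 2: single-arg tools adapt the arg name into a resolver.
subject_key = wrap_dedup_key_by_arg_name("subject")
assert subject_key({"subject": "Weekly Sync"}) == "weekly sync"

# MCP / Composio tools instead set metadata the middleware reads directly:
# tool.metadata = {"hitl": True, "hitl_dedup_key": "title"}
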
View file

@ -0,0 +1,228 @@
"""
DoomLoopMiddleware: pattern-based detector for repeated identical tool calls.
Mirrors ``opencode/packages/opencode/src/session/processor.ts`` doom-loop
behavior. When the same tool with the same arguments is called N times
in a row, the agent has likely entered an infinite loop. We surface this
to the user as an interrupt with ``permission="doom_loop"`` so the UI
can render an "Are you stuck? Continue / cancel?" affordance.
Tier 1.11 in the OpenCode-port plan.
This ships **OFF by default** until the frontend explicitly handles
``context.permission == "doom_loop"`` interrupts (the plan flips
``SURFSENSE_ENABLE_DOOM_LOOP=true`` once the UI is ready).
Wire format: uses SurfSense's existing ``interrupt()`` payload shape
(see ``app/agents/new_chat/tools/hitl.py``):
{
"type": "permission_ask",
"action": {"tool": <name>, "params": <args>},
"context": {"permission": "doom_loop", "recent_signatures": [...]},
}
so the frontend that already handles HITL prompts can render this with
no changes beyond a string check.
"""
from __future__ import annotations
import hashlib
import json
import logging
from collections import deque
from typing import Any
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ContextT,
ResponseT,
)
from langchain_core.messages import AIMessage
from langgraph.config import get_config
from langgraph.runtime import Runtime
from langgraph.types import interrupt
from app.observability import otel as ot
logger = logging.getLogger(__name__)
def _signature(name: str, args: Any) -> str:
"""Hash a tool call ``(name, args)`` to a short signature."""
try:
canonical = json.dumps(args, sort_keys=True, default=str)
except (TypeError, ValueError):
canonical = repr(args)
digest = hashlib.sha1(f"{name}::{canonical}".encode()).hexdigest()
return digest[:16]
class DoomLoopMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
"""Detect repeated identical tool calls and prompt the user.
Tracks a sliding window of the most-recent ``threshold`` tool-call
signatures across the live request. When all entries match, raise
a SurfSense-style HITL interrupt with ``permission="doom_loop"``.
Args:
threshold: How many consecutive identical signatures count as a
doom loop. Default 3 (opencode parity).
"""
def __init__(self, *, threshold: int = 3) -> None:
super().__init__()
if threshold < 2:
raise ValueError("DoomLoopMiddleware threshold must be >= 2")
self._threshold = threshold
self.tools = []
# Per-thread sliding windows. We can't put this in graph state
# without state-schema gymnastics; for one process-lifetime it's
# fine to keep an in-memory map keyed by thread_id.
self._windows: dict[str, deque[str]] = {}
@staticmethod
def _thread_id_from_runtime(runtime: Runtime[ContextT]) -> str:
"""Resolve the thread id for sliding-window keying.
Prefer LangGraph's ``get_config()`` (the only way to read
``RunnableConfig`` inside a node; :class:`Runtime` does NOT carry
a ``config`` attribute). Fall back to ``runtime.config`` for unit
tests that synthesize a config-bearing stub. Default
``"no_thread"`` is intentionally only used when both lookups fail:
it would collapse all threads into one window, so we keep the
debug log loud.
"""
def _from_dict(cfg: Any) -> str | None:
if not isinstance(cfg, dict):
return None
tid = (cfg.get("configurable") or {}).get("thread_id")
return str(tid) if tid is not None else None
try:
tid = _from_dict(get_config())
except Exception:
tid = None
if tid is not None:
return tid
tid = _from_dict(getattr(runtime, "config", None))
if tid is not None:
return tid
logger.debug(
"DoomLoopMiddleware: no thread_id resolved from RunnableConfig; "
"falling back to shared 'no_thread' window."
)
return "no_thread"
def _window(self, thread_id: str) -> deque[str]:
win = self._windows.get(thread_id)
if win is None:
win = deque(maxlen=self._threshold)
self._windows[thread_id] = win
return win
def _detect(
self, message: AIMessage, runtime: Runtime[ContextT]
) -> tuple[bool, list[str], dict[str, Any] | None]:
if not message.tool_calls:
return False, [], None
thread_id = self._thread_id_from_runtime(runtime)
window = self._window(thread_id)
triggered_call: dict[str, Any] | None = None
for call in message.tool_calls:
name = call.get("name") if isinstance(call, dict) else getattr(call, "name", None)
args = call.get("args") if isinstance(call, dict) else getattr(call, "args", {})
if not isinstance(name, str):
continue
sig = _signature(name, args)
window.append(sig)
if (
len(window) >= self._threshold
and len(set(window)) == 1
):
triggered_call = {"tool": name, "params": args or {}}
break
if triggered_call is None:
return False, list(window), None
return True, list(window), triggered_call
def after_model( # type: ignore[override]
self,
state: AgentState[ResponseT],
runtime: Runtime[ContextT],
) -> dict[str, Any] | None:
messages = state.get("messages") or []
if not messages:
return None
last = messages[-1]
if not isinstance(last, AIMessage):
return None
triggered, signatures, action = self._detect(last, runtime)
if not triggered:
return None
logger.warning(
"Doom loop detected: tool %s called %d times in a row (sig=%s)",
action["tool"] if action else "<unknown>",
self._threshold,
signatures[-1] if signatures else "<empty>",
)
# Tier 3b: interrupt.raised span with permission=doom_loop attribute
# so dashboards can break out doom-loop interrupts from regular
# permission asks via the ``interrupt.permission`` attribute.
with ot.interrupt_span(
interrupt_type="permission_ask",
extra={
"interrupt.permission": "doom_loop",
"interrupt.threshold": self._threshold,
"interrupt.tool": (action or {}).get("tool", "<unknown>"),
},
):
decision = interrupt(
{
"type": "permission_ask",
"action": action or {"tool": "<unknown>", "params": {}},
"context": {
"permission": "doom_loop",
"recent_signatures": signatures,
"threshold": self._threshold,
},
}
)
# Reset window so the next decision (continue/cancel) starts fresh.
thread_id = self._thread_id_from_runtime(runtime)
self._windows.pop(thread_id, None)
# Decision shape mirrors ``tools/hitl.py``: {"decision_type": "..."}
# If the user cancelled, jump to end. Otherwise return ``None`` so the
# tool call proceeds. The frontend's exact reply names may differ —
# we tolerate any shape that contains a string with "reject"/"cancel".
if isinstance(decision, dict):
kind = str(decision.get("decision_type") or decision.get("type") or "").lower()
if "reject" in kind or "cancel" in kind:
return {"jump_to": "end"}
return None
async def aafter_model( # type: ignore[override]
self,
state: AgentState[ResponseT],
runtime: Runtime[ContextT],
) -> dict[str, Any] | None:
return self.after_model(state, runtime)
__all__ = [
"DoomLoopMiddleware",
"_signature",
]
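
A quick demonstration of the detection rule using only ``_signature`` from above; the deque mirrors the middleware's per-thread window at the default threshold of 3.

from collections import deque

from app.agents.new_chat.middleware.doom_loop import _signature

window: deque[str] = deque(maxlen=3)
for args in ({"q": "foo"}, {"q": "foo"}, {"q": "foo"}):
    window.append(_signature("search", args))
assert len(window) == 3 and len(set(window)) == 1  # all identical: doom loop

window.append(_signature("search", {"q": "bar"}))  # changed args
assert len(set(window)) > 1  # streak broken; no interrupt would be raised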

View file

@ -31,14 +31,17 @@ from collections.abc import Sequence
from datetime import UTC, datetime
from typing import Any
from langchain.agents import create_agent
from langchain.agents.middleware import AgentMiddleware, AgentState
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.runnables import Runnable
from langgraph.runtime import Runtime
from litellm import token_counter
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy import select
from app.agents.new_chat.feature_flags import get_flags
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState
from app.agents.new_chat.path_resolver import (
@ -589,6 +592,53 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
self.available_document_types = available_document_types
self.top_k = top_k
self.mentioned_document_ids = mentioned_document_ids or []
# Tier 4.2: build the kb-planner private Runnable ONCE here so we
# don't pay the create_agent compile cost (50-200ms) on every turn.
# Disabled by default behind ``enable_kb_planner_runnable``; when off
# the planner falls back to the legacy ``self.llm.ainvoke`` path.
self._planner: Runnable | None = None
self._planner_compile_failed = False
def _build_kb_planner_runnable(self) -> Runnable | None:
"""Compile the kb-planner private :class:`Runnable` once.
Returns ``None`` when the feature flag is disabled, when the LLM is
unavailable, or when ``create_agent`` raises (we fall back to the
legacy ``self.llm.ainvoke`` path in that case). Compilation happens
lazily on first call, then memoized via ``self._planner``.
The compiled agent is constructed without tools (the planner's
contract is "answer with structured JSON") but with ``RetryAfter``
+ the OpenCode-port retry/limit middleware so it shares the parent
agent's resilience guarantees.
"""
if self._planner is not None or self._planner_compile_failed:
return self._planner
if self.llm is None:
return None
flags = get_flags()
if (
not flags.enable_kb_planner_runnable
or flags.disable_new_agent_stack
):
return None
from app.agents.new_chat.middleware.retry_after import RetryAfterMiddleware
try:
self._planner = create_agent(
self.llm,
tools=[],
middleware=[RetryAfterMiddleware(max_retries=2)],
)
except Exception as exc: # pragma: no cover - defensive
logger.warning(
"kb-planner Runnable compile failed; falling back to llm.ainvoke: %s",
exc,
)
self._planner_compile_failed = True
self._planner = None
return self._planner
async def _plan_search_inputs(
self,
@ -611,11 +661,32 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
loop = asyncio.get_running_loop()
t0 = loop.time()
# Tier 4.2: prefer the compiled-once planner Runnable when enabled;
# otherwise fall back to ``self.llm.ainvoke``. The ``surfsense:internal``
# tag is preserved on both paths so ``_stream_agent_events`` still
# suppresses the planner's intermediate events from the UI.
planner = self._build_kb_planner_runnable()
try:
response = await self.llm.ainvoke(
[HumanMessage(content=prompt)],
config={"tags": ["surfsense:internal"]},
)
if planner is not None:
planner_state = await planner.ainvoke(
{"messages": [HumanMessage(content=prompt)]},
config={"tags": ["surfsense:internal"]},
)
response_messages = (
planner_state.get("messages", [])
if isinstance(planner_state, dict)
else []
)
response = (
response_messages[-1]
if response_messages
else AIMessage(content="")
)
else:
response = await self.llm.ainvoke(
[HumanMessage(content=prompt)],
config={"tags": ["surfsense:internal"]},
)
plan = _parse_kb_search_plan_response(_extract_text_from_message(response))
optimized_query = (
re.sub(r"\s+", " ", plan.optimized_query).strip() or user_text

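The compile-once pattern in isolation, as a sketch: the first call pays the create_agent compile cost, later calls reuse the cached Runnable, and a failed compile poisons the cache so the caller permanently falls back to ``llm.ainvoke``. Names other than ``create_agent`` and ``RetryAfterMiddleware`` are illustrative.

from langchain.agents import create_agent
from langchain_core.language_models import BaseChatModel

from app.agents.new_chat.middleware.retry_after import RetryAfterMiddleware

class PlannerCache:
    def __init__(self, llm: BaseChatModel) -> None:
        self._llm = llm
        self._planner = None
        self._compile_failed = False

    def get(self):
        if self._planner is None and not self._compile_failed:
            try:
                # Pay the 50-200ms compile cost exactly once per process.
                self._planner = create_agent(
                    self._llm,
                    tools=[],
                    middleware=[RetryAfterMiddleware(max_retries=2)],
                )
            except Exception:
                self._compile_failed = True  # permanent llm.ainvoke fallback
        return self._planner
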
View file

@ -0,0 +1,133 @@
"""
``_noop`` provider-compatibility tool + injection middleware.
OpenCode injects a ``_noop`` tool for LiteLLM/Bedrock/Copilot when the
model call has empty tools but message history includes prior
``tool_calls`` some providers 400 in that shape (see
``opencode/packages/opencode/src/session/llm.ts:209-228``). SurfSense uses
LiteLLM, and the compaction summarize call (no tools, history full of
tool calls) hits this. Tier 1.5 in the OpenCode-port plan.
Operation: a :class:`NoopInjectionMiddleware` ``wrap_model_call`` checks
if the request has zero tools but the last AI message in history includes
``tool_calls``. If yes, it injects the ``_noop`` tool for that call only, never globally,
mirroring opencode's gating exactly. The :func:`noop_tool` returns empty
content when called (which it should never be in practice).
"""
from __future__ import annotations
import logging
from collections.abc import Awaitable, Callable
from typing import Any
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ContextT,
ModelRequest,
ModelResponse,
ResponseT,
)
from langchain_core.messages import AIMessage
from langchain_core.tools import tool
logger = logging.getLogger(__name__)
NOOP_TOOL_NAME = "_noop"
NOOP_TOOL_DESCRIPTION = (
"Do not call this tool. It exists only for API compatibility."
)
@tool(name_or_callable=NOOP_TOOL_NAME, description=NOOP_TOOL_DESCRIPTION)
def noop_tool() -> str:
"""Return empty content. Never expected to be called."""
return ""
# Provider markers that benefit from ``_noop`` injection. These match
# opencode's gating list. We also accept any string containing one of
# these substrings (so e.g. ``litellm`` matches ``ChatLiteLLM``).
_NOOP_NEEDED_PROVIDERS: tuple[str, ...] = (
"litellm",
"bedrock",
"copilot",
)
def _provider_needs_noop(model: Any) -> bool:
"""Heuristic: does this model's provider need the _noop injection?"""
try:
ls_params = model._get_ls_params()
provider = str(ls_params.get("ls_provider", "")).lower()
except Exception:
provider = ""
if not provider:
cls_name = type(model).__name__.lower()
provider = cls_name
return any(needle in provider for needle in _NOOP_NEEDED_PROVIDERS)
def _last_ai_has_tool_calls(messages: list[Any]) -> bool:
for msg in reversed(messages):
if isinstance(msg, AIMessage):
return bool(msg.tool_calls)
return False
class NoopInjectionMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
"""Inject the ``_noop`` tool only when the provider would otherwise 400.
The check fires per model call, not at agent build time, because the
summarization path generates a no-tool subcall at runtime. The
extra tool is appended to ``request.tools`` as an instance; the
actual ``langchain_core.tools.BaseTool`` is bound at every call site
that creates the agent.
"""
def __init__(self, *, noop_tool_instance: Any | None = None) -> None:
super().__init__()
self._noop_tool = noop_tool_instance or noop_tool
self.tools = []
def _should_inject(self, request: ModelRequest[ContextT]) -> bool:
if request.tools:
return False
if not _last_ai_has_tool_calls(request.messages):
return False
return _provider_needs_noop(request.model)
def _augmented(self, request: ModelRequest[ContextT]) -> ModelRequest[ContextT]:
return request.override(tools=[self._noop_tool])
def wrap_model_call( # type: ignore[override]
self,
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
) -> Any:
if self._should_inject(request):
logger.debug("Injecting _noop tool for provider compatibility")
return handler(self._augmented(request))
return handler(request)
async def awrap_model_call( # type: ignore[override]
self,
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]],
) -> Any:
if self._should_inject(request):
logger.debug("Injecting _noop tool for provider compatibility")
return await handler(self._augmented(request))
return await handler(request)
__all__ = [
"NOOP_TOOL_DESCRIPTION",
"NOOP_TOOL_NAME",
"NoopInjectionMiddleware",
"_provider_needs_noop",
"noop_tool",
]
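A quick sanity sketch of the gating above; the model class and message history are illustrative stand-ins, not SurfSense fixtures:

from langchain_core.messages import AIMessage

class FakeChatLiteLLM:  # hypothetical; only the class name feeds the heuristic
    pass

history = [
    AIMessage(content="", tool_calls=[
        {"name": "search", "args": {}, "id": "call_1", "type": "tool_call"},
    ]),
]
assert _provider_needs_noop(FakeChatLiteLLM())  # "litellm" substring match
assert _last_ai_has_tool_calls(history)
# With request.tools empty, _should_inject(...) returns True and the
# middleware appends `noop_tool` for this single model call only.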

View file

@@ -0,0 +1,202 @@
"""
OpenTelemetry span middleware for the SurfSense ``new_chat`` agent.
Wraps both ``model.call`` (LLM invocations) and ``tool.call`` (tool
executions) with OTel spans, attaching low-cardinality span names and
high-cardinality identifiers as attributes (per the Tier 3b plan).
This middleware is intentionally a thin adapter over
:mod:`app.observability.otel`; when OTel is not configured all spans
collapse to no-ops and the wrapper adds <1µs overhead per call. When
OTel **is** configured (``OTEL_EXPORTER_OTLP_ENDPOINT`` set), every
model and tool call gets a span with the standard attributes the
plan's dashboards expect.
"""
from __future__ import annotations
import logging
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Any
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import AIMessage, ToolMessage
from app.observability import otel as ot
if TYPE_CHECKING: # pragma: no cover — type-only
from langchain.agents.middleware.types import (
ModelRequest,
ModelResponse,
ToolCallRequest,
)
from langgraph.types import Command
logger = logging.getLogger(__name__)
class OtelSpanMiddleware(AgentMiddleware):
"""Emit ``model.call`` and ``tool.call`` OTel spans for every invocation.
Should be placed near the **outer** end of the middleware list so
that the spans encompass retry/fallback wrapper effects (i.e. ``N``
model.call spans for ``N`` retry attempts) but inside any concurrency/
auth gate. Empirically this means **between** ``BusyMutex`` and
``RetryAfter``.
"""
def __init__(self, *, instrumentation_name: str = "surfsense.new_chat") -> None:
super().__init__()
self._instrumentation_name = instrumentation_name
# ------------------------------------------------------------------
# Model call spans
# ------------------------------------------------------------------
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[
[ModelRequest], Awaitable[ModelResponse | AIMessage | Any]
],
) -> ModelResponse | AIMessage | Any:
if not ot.is_enabled():
return await handler(request)
model_id, provider = _resolve_model_attrs(request)
with ot.model_call_span(model_id=model_id, provider=provider) as sp:
try:
result = await handler(request)
except Exception:
# span context manager records + re-raises
raise
else:
_annotate_model_response(sp, result)
return result
# ------------------------------------------------------------------
# Tool call spans
# ------------------------------------------------------------------
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[
[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]
],
) -> ToolMessage | Command[Any]:
if not ot.is_enabled():
return await handler(request)
tool_name = _resolve_tool_name(request)
input_size = _resolve_input_size(request)
with ot.tool_call_span(tool_name, input_size=input_size) as sp:
result = await handler(request)
_annotate_tool_result(sp, result)
return result
# ---------------------------------------------------------------------------
# Attribute helpers (kept defensive; we never want OTel bookkeeping to break
# a real model/tool call).
# ---------------------------------------------------------------------------
def _resolve_model_attrs(request: Any) -> tuple[str | None, str | None]:
"""Extract ``model.id`` and ``model.provider`` from a ``ModelRequest``."""
model_id: str | None = None
provider: str | None = None
try:
model = getattr(request, "model", None)
if model is None:
return None, None
# langchain BaseChatModel exposes a few different identifiers
for attr in ("model_name", "model", "model_id"):
value = getattr(model, attr, None)
if value:
model_id = str(value)
break
# provider sometimes lives on ``_llm_type`` (legacy) or ``provider``
for attr in ("provider", "_llm_type"):
value = getattr(model, attr, None)
if value:
provider = str(value)
break
except Exception: # pragma: no cover — defensive
pass
return model_id, provider
def _resolve_tool_name(request: Any) -> str:
try:
tool = getattr(request, "tool", None)
if tool is not None:
name = getattr(tool, "name", None)
if isinstance(name, str) and name:
return name
# Fall back to the tool_call dict
call = getattr(request, "tool_call", None) or {}
name = call.get("name") if isinstance(call, dict) else None
if isinstance(name, str) and name:
return name
except Exception: # pragma: no cover — defensive
pass
return "unknown"
def _resolve_input_size(request: Any) -> int | None:
try:
call = getattr(request, "tool_call", None)
if not isinstance(call, dict) or not call:
return None
args = call.get("args")
if args is None:
return None
return len(repr(args))
except Exception: # pragma: no cover — defensive
return None
def _annotate_model_response(span: Any, result: Any) -> None:
"""Best-effort: attach prompt/completion token counts when available."""
try:
# ModelResponse may be a dataclass with .result containing AIMessage
msg: Any
if isinstance(result, AIMessage):
msg = result
else:
inner = getattr(result, "result", None)
msg = inner[-1] if isinstance(inner, list) and inner else inner
if msg is None:
return
usage = getattr(msg, "usage_metadata", None) or {}
if isinstance(usage, dict):
if (n := usage.get("input_tokens")) is not None:
span.set_attribute("tokens.prompt", int(n))
if (n := usage.get("output_tokens")) is not None:
span.set_attribute("tokens.completion", int(n))
if (n := usage.get("total_tokens")) is not None:
span.set_attribute("tokens.total", int(n))
tool_calls = getattr(msg, "tool_calls", None) or []
span.set_attribute("model.tool_calls", len(tool_calls))
except Exception: # pragma: no cover — defensive
pass
def _annotate_tool_result(span: Any, result: Any) -> None:
try:
if isinstance(result, ToolMessage):
content = result.content if isinstance(result.content, str) else repr(result.content)
span.set_attribute("tool.output.size", len(content))
status = getattr(result, "status", None)
if isinstance(status, str):
span.set_attribute("tool.status", status)
kwargs = getattr(result, "additional_kwargs", None) or {}
if isinstance(kwargs, dict) and kwargs.get("error"):
span.set_attribute("tool.error", True)
except Exception: # pragma: no cover — defensive
pass
__all__ = ["OtelSpanMiddleware"]
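Placement sketch per the class docstring; the surrounding middleware names come from this commit, but the list itself is illustrative:

# Illustrative ordering only: spans sit inside the concurrency gate but
# outside the retry loop, so N retry attempts produce N model.call spans.
middleware = [
    BusyMutexMiddleware(),
    OtelSpanMiddleware(),
    RetryAfterMiddleware(),
    # ... rest of the stack
]
# All spans collapse to no-ops unless OTEL_EXPORTER_OTLP_ENDPOINT is set.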

View file

@@ -0,0 +1,344 @@
"""
PermissionMiddleware: pattern-based allow/deny/ask with HITL fallback.
Mirrors ``opencode/packages/opencode/src/permission/index.ts`` but uses
SurfSense's existing ``interrupt({type, action, context})`` payload shape
(see ``app/agents/new_chat/tools/hitl.py``) so the frontend keeps
working unchanged. Tier 2.1 in the OpenCode-port plan.
Operation:
1. ``aafter_model`` inspects the latest ``AIMessage.tool_calls``.
2. For each call, the middleware builds a list of ``patterns`` (the
tool name plus any tool-specific patterns from the resolver). It
evaluates each pattern against the layered rulesets and aggregates
the results: ``deny`` > ``ask`` > ``allow``.
3. On ``deny``: replaces the call with a synthetic ``ToolMessage``
containing a :class:`StreamingError`.
4. On ``ask``: raises a SurfSense-style ``interrupt(...)``. The reply
shape is ``{"decision_type": "once|always|reject", "feedback"?: str}``.
- ``once``: proceed.
- ``always``: also persist allow rules for ``request.always`` patterns.
- ``reject`` w/o feedback: raise :class:`RejectedError`.
- ``reject`` w/ feedback: raise :class:`CorrectedError`.
5. On ``allow``: proceed unchanged.
A companion *pre-model* tool-filter step strips globally denied tools
from the exposed tool list before the model gets to see them; this is
opencode's ``Permission.disabled`` equivalent and dramatically reduces
the chance the model emits a deny-only call. That filtering is performed
by the agent factory (which can use :meth:`_globally_denied`), not by
this middleware; see the NOTE in the class body on prompt-cache safety.
"""
from __future__ import annotations
import logging
from collections.abc import Callable
from typing import Any
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ContextT,
)
from langchain_core.messages import AIMessage, ToolMessage
from langgraph.runtime import Runtime
from langgraph.types import interrupt
from app.agents.new_chat.errors import (
CorrectedError,
RejectedError,
StreamingError,
)
from app.agents.new_chat.permissions import (
Rule,
Ruleset,
aggregate_action,
evaluate_many,
)
from app.observability import otel as ot
logger = logging.getLogger(__name__)
# Mapping ``tool_name -> resolver`` that converts ``args`` to a list of
# patterns to evaluate. The first pattern is conventionally the bare
# tool name; later entries narrow down to specific resources.
PatternResolver = Callable[[dict[str, Any]], list[str]]
def _default_pattern_resolver(name: str) -> PatternResolver:
def _resolve(args: dict[str, Any]) -> list[str]:
# Bare name covers the default catch-all; primary-arg fallbacks
# are best added per-tool by callers.
del args
return [name]
return _resolve
class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
"""Allow/deny/ask layer over the agent's tool calls.
Args:
rulesets: Layered rulesets to evaluate. Earlier entries are
overridden by later ones (last-match-wins). Typical layering:
``defaults < global < space < thread < runtime_approved``.
pattern_resolvers: Optional per-tool callables that return a list
of patterns to evaluate. When a tool isn't listed, the bare
tool name is used as the only pattern.
runtime_ruleset: Mutable :class:`Ruleset` that the middleware
extends in-place when the user replies ``"always"`` to an
ask interrupt. Reused across all calls in the same agent
instance so newly-allowed rules apply to subsequent calls.
always_emit_interrupt_payload: If True, every ask uses the
SurfSense interrupt wire format (default). Set False to
disable interrupts and treat ``ask`` as a rejection for
non-interactive deployments.
"""
tools = ()
def __init__(
self,
*,
rulesets: list[Ruleset] | None = None,
pattern_resolvers: dict[str, PatternResolver] | None = None,
runtime_ruleset: Ruleset | None = None,
always_emit_interrupt_payload: bool = True,
) -> None:
super().__init__()
self._static_rulesets: list[Ruleset] = list(rulesets or [])
self._pattern_resolvers: dict[str, PatternResolver] = dict(
pattern_resolvers or {}
)
self._runtime_ruleset: Ruleset = runtime_ruleset or Ruleset(
origin="runtime_approved"
)
self._emit_interrupt = always_emit_interrupt_payload
# ------------------------------------------------------------------
# Tool-filter step (opencode `Permission.disabled` equivalent)
# ------------------------------------------------------------------
def _globally_denied(self, tool_name: str) -> bool:
"""Return True if a deny rule with no narrowing pattern matches."""
rules = evaluate_many(tool_name, ["*"], *self._all_rulesets())
return aggregate_action(rules) == "deny"
def _all_rulesets(self) -> list[Ruleset]:
return [*self._static_rulesets, self._runtime_ruleset]
# NOTE: ``before_model`` filtering of the tools list is left to the
# agent factory. This middleware only blocks at execution time — and
# only via the rule-evaluator path, not by mutating ``request.tools``.
# Mutating ``request.tools`` per-call would invalidate provider
# prompt-cache prefixes (see Operational risks: prompt-cache regression).
# ------------------------------------------------------------------
# Tool-call evaluation
# ------------------------------------------------------------------
def _resolve_patterns(self, tool_name: str, args: dict[str, Any]) -> list[str]:
resolver = self._pattern_resolvers.get(
tool_name, _default_pattern_resolver(tool_name)
)
try:
patterns = resolver(args or {})
except Exception:
logger.exception("Pattern resolver for %s raised; using bare name", tool_name)
patterns = [tool_name]
if not patterns:
patterns = [tool_name]
return patterns
def _evaluate(
self, tool_name: str, args: dict[str, Any]
) -> tuple[str, list[str], list[Rule]]:
patterns = self._resolve_patterns(tool_name, args)
rules = evaluate_many(tool_name, patterns, *self._all_rulesets())
action = aggregate_action(rules)
return action, patterns, rules
# ------------------------------------------------------------------
# HITL ask flow — SurfSense wire format
# ------------------------------------------------------------------
def _raise_interrupt(
self,
*,
tool_name: str,
args: dict[str, Any],
patterns: list[str],
rules: list[Rule],
) -> dict[str, Any]:
"""Block on user approval via SurfSense's ``interrupt`` shape."""
if not self._emit_interrupt:
return {"decision_type": "reject"}
# ``params`` (NOT ``args``) is what SurfSense's streaming
# normalizer forwards. Other fields move into ``context``.
payload = {
"type": "permission_ask",
"action": {"tool": tool_name, "params": args or {}},
"context": {
"patterns": patterns,
"rules": [
{
"permission": r.permission,
"pattern": r.pattern,
"action": r.action,
}
for r in rules
],
# Rules of thumb for the frontend: surface the patterns
# the user can promote to "always" with a single reply.
"always": patterns,
},
}
# Tier 3b: permission.asked + interrupt.raised spans (no-op when
# OTel is disabled). Both fire here so dashboards can correlate
# "we asked X" with "interrupt was actually delivered".
with ot.permission_asked_span(
permission=tool_name,
pattern=patterns[0] if patterns else None,
extra={"permission.patterns": list(patterns)},
), ot.interrupt_span(interrupt_type="permission_ask"):
decision = interrupt(payload)
if isinstance(decision, dict):
return decision
# Tolerate a plain string reply ("once", "always", "reject")
if isinstance(decision, str):
return {"decision_type": decision}
return {"decision_type": "reject"}
def _persist_always(
self, tool_name: str, patterns: list[str]
) -> None:
"""Promote ``always`` reply into runtime allow rules.
Persistence to ``agent_permission_rules`` is done by the
streaming layer (``stream_new_chat``) once it observes the
``always`` reply; the middleware just keeps an in-memory
copy so subsequent calls in the same stream see the rule.
"""
for pattern in patterns:
self._runtime_ruleset.rules.append(
Rule(permission=tool_name, pattern=pattern, action="allow")
)
# ------------------------------------------------------------------
# Synthesizing deny -> ToolMessage
# ------------------------------------------------------------------
@staticmethod
def _deny_message(
tool_call: dict[str, Any],
rule: Rule,
) -> ToolMessage:
err = StreamingError(
code="permission_denied",
retryable=False,
suggestion=(
f"rule permission={rule.permission!r} pattern={rule.pattern!r} "
f"blocked this call"
),
)
return ToolMessage(
content=(
f"Permission denied: rule {rule.permission}/{rule.pattern} "
f"blocked tool {tool_call.get('name')!r}."
),
tool_call_id=tool_call.get("id") or "",
name=tool_call.get("name"),
status="error",
additional_kwargs={"error": err.model_dump()},
)
# ------------------------------------------------------------------
# The hook: aafter_model
# ------------------------------------------------------------------
def _process(
self,
state: AgentState,
runtime: Runtime[Any],
) -> dict[str, Any] | None:
del runtime # unused
messages = state.get("messages") or []
if not messages:
return None
last = messages[-1]
if not isinstance(last, AIMessage) or not last.tool_calls:
return None
deny_messages: list[ToolMessage] = []
kept_calls: list[dict[str, Any]] = []
any_change = False
for raw in last.tool_calls:
call = dict(raw) if isinstance(raw, dict) else {
"name": getattr(raw, "name", None),
"args": getattr(raw, "args", {}),
"id": getattr(raw, "id", None),
"type": "tool_call",
}
name = call.get("name") or ""
args = call.get("args") or {}
action, patterns, rules = self._evaluate(name, args)
if action == "deny":
# Find the deny rule for the suggestion text
deny_rule = next((r for r in rules if r.action == "deny"), rules[0])
deny_messages.append(self._deny_message(call, deny_rule))
any_change = True
continue
if action == "ask":
decision = self._raise_interrupt(
tool_name=name, args=args, patterns=patterns, rules=rules
)
kind = str(decision.get("decision_type") or "reject").lower()
if kind == "once":
kept_calls.append(call)
elif kind == "always":
self._persist_always(name, patterns)
kept_calls.append(call)
elif kind == "reject":
feedback = decision.get("feedback")
if isinstance(feedback, str) and feedback.strip():
raise CorrectedError(feedback, tool=name)
raise RejectedError(tool=name, pattern=patterns[0] if patterns else None)
else:
logger.warning(
"Unknown permission decision %r; treating as reject", kind
)
raise RejectedError(tool=name)
continue
# allow
kept_calls.append(call)
if not any_change and len(kept_calls) == len(last.tool_calls):
return None
updated = last.model_copy(update={"tool_calls": kept_calls})
result_messages: list[Any] = [updated]
if deny_messages:
result_messages.extend(deny_messages)
return {"messages": result_messages}
def after_model( # type: ignore[override]
self, state: AgentState, runtime: Runtime[ContextT]
) -> dict[str, Any] | None:
return self._process(state, runtime)
async def aafter_model( # type: ignore[override]
self, state: AgentState, runtime: Runtime[ContextT]
) -> dict[str, Any] | None:
return self._process(state, runtime)
__all__ = [
"PatternResolver",
"PermissionMiddleware",
]
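A minimal layering sketch; the ``Rule``/``Ruleset`` shapes follow their uses in this file, and the ``"*"`` catch-all is assumed to be supported by the evaluator:

# Sketch: defaults allow everything, the space layer asks for write_file.
defaults = Ruleset(origin="defaults")
defaults.rules.append(Rule(permission="*", pattern="*", action="allow"))  # assumed catch-all
space = Ruleset(origin="space")
space.rules.append(Rule(permission="write_file", pattern="*", action="ask"))

mw = PermissionMiddleware(rulesets=[defaults, space])
# A write_file call now interrupts with the permission_ask payload; an
# "always" reply appends allow rules to the runtime ruleset, so later
# calls in the same stream proceed without asking.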

View file

@@ -0,0 +1,245 @@
"""
RetryAfterMiddleware: header-aware retry with custom backoff and SSE eventing.
Why standalone instead of subclassing ``ModelRetryMiddleware``: the upstream
class calls module-level ``calculate_delay`` inline (no overridable
``_calculate_delay`` hook), so a subclass cannot inject Retry-After header
delays without rewriting the loop. Tier 1.4 in the OpenCode-port plan.
Behaviour:
- Extracts ``Retry-After`` / ``retry-after-ms`` from
``litellm.exceptions.RateLimitError.response.headers`` (or any exception
exposing a similar shape).
- Sleeps ``max(exponential_backoff, header_delay)`` between retries.
- Returns ``False`` from ``retry_on`` for ``ContextWindowExceededError`` /
``ContextOverflowError`` so :class:`SurfSenseCompactionMiddleware` (or
the LangChain summarization fallback path) handles those instead.
- Emits ``surfsense.retrying`` via ``adispatch_custom_event`` on each retry
so ``stream_new_chat`` can forward it to clients as an SSE event.
"""
from __future__ import annotations
import asyncio
import logging
import random
import re
import time
from collections.abc import Awaitable, Callable
from typing import Any
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ContextT,
ModelRequest,
ModelResponse,
ResponseT,
)
from langchain_core.callbacks import adispatch_custom_event, dispatch_custom_event
from langchain_core.messages import AIMessage
logger = logging.getLogger(__name__)
# Names of exception classes for which a retry would not help — context
# overflow needs compaction, auth needs human intervention, etc. Detected
# by class-name substring so we don't have to import LiteLLM/Anthropic
# here (which would tie this module to optional deps).
_NON_RETRYABLE_NAME_HINTS: tuple[str, ...] = (
"ContextWindowExceeded",
"ContextOverflow",
"AuthenticationError",
"InvalidRequestError",
"PermissionDenied",
"InvalidApiKey",
"ContextLimit",
)
def _is_non_retryable(exc: BaseException) -> bool:
name = type(exc).__name__
return any(hint in name for hint in _NON_RETRYABLE_NAME_HINTS)
def _extract_retry_after_seconds(exc: BaseException) -> float | None:
"""Return seconds-to-wait suggested by the provider, if any.
Looks at ``exc.response.headers`` or ``exc.headers`` for the standard
HTTP ``Retry-After`` header (in seconds) or its millisecond cousin
``retry-after-ms`` (sometimes used by Anthropic / OpenAI). Falls back
to a regex on the exception message for shapes like
``"Please retry after 30s"``.
"""
headers: dict[str, Any] | None = None
response = getattr(exc, "response", None)
if response is not None:
headers = getattr(response, "headers", None)
if headers is None:
headers = getattr(exc, "headers", None)
if isinstance(headers, dict):
# Normalize keys to lowercase for case-insensitive matching
norm = {str(k).lower(): v for k, v in headers.items()}
ms = norm.get("retry-after-ms")
if ms is not None:
try:
return float(ms) / 1000.0
except (TypeError, ValueError):
pass
seconds = norm.get("retry-after")
if seconds is not None:
try:
return float(seconds)
except (TypeError, ValueError):
pass
# Last resort: scan the message for "retry after X" / "retry after Xs" shapes.
msg = str(exc)
match = re.search(r"retry\s+after\s+([0-9]+(?:\.[0-9]+)?)", msg, re.IGNORECASE)
if match:
try:
return float(match.group(1))
except ValueError:
return None
return None
def _exponential_delay(
attempt: int,
*,
initial_delay: float,
backoff_factor: float,
max_delay: float,
jitter: bool,
) -> float:
"""Compute an exponential-backoff delay with optional ±25% jitter."""
delay = initial_delay * (backoff_factor**attempt) if backoff_factor else initial_delay
delay = min(delay, max_delay)
if jitter and delay > 0:
delay *= 1 + random.uniform(-0.25, 0.25)
return max(delay, 0.0)
class RetryAfterMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
"""Retry middleware that honors provider-issued Retry-After hints.
Drop-in replacement for :class:`langchain.agents.middleware.ModelRetryMiddleware`
when working with LiteLLM/Anthropic/OpenAI providers that surface
rate-limit hints in headers. Always emits ``surfsense.retrying`` SSE
events so the UI can show a friendly "rate limited, retrying in Xs"
indicator.
Args:
max_retries: Maximum retries after the initial attempt (default 3).
initial_delay: Initial backoff delay in seconds.
backoff_factor: Exponential growth factor for backoff.
max_delay: Cap on per-attempt delay in seconds.
jitter: Whether to add ±25% jitter.
retry_on: Optional callable that returns True for retryable
exceptions. The default retries everything except known
non-retryable classes (context overflow, auth, etc.).
"""
def __init__(
self,
*,
max_retries: int = 3,
initial_delay: float = 1.0,
backoff_factor: float = 2.0,
max_delay: float = 60.0,
jitter: bool = True,
retry_on: Callable[[BaseException], bool] | None = None,
) -> None:
super().__init__()
self.max_retries = max_retries
self.initial_delay = initial_delay
self.backoff_factor = backoff_factor
self.max_delay = max_delay
self.jitter = jitter
self._retry_on: Callable[[BaseException], bool] = retry_on or (
lambda exc: not _is_non_retryable(exc)
)
def _should_retry(self, exc: BaseException) -> bool:
try:
return bool(self._retry_on(exc))
except Exception:
logger.exception("retry_on callable raised; defaulting to False")
return False
def _delay_for_attempt(self, attempt: int, exc: BaseException) -> float:
backoff = _exponential_delay(
attempt,
initial_delay=self.initial_delay,
backoff_factor=self.backoff_factor,
max_delay=self.max_delay,
jitter=self.jitter,
)
header = _extract_retry_after_seconds(exc) or 0.0
return max(backoff, header)
def wrap_model_call( # type: ignore[override]
self,
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
) -> ModelResponse[ResponseT] | AIMessage:
for attempt in range(self.max_retries + 1):
try:
return handler(request)
except Exception as exc:
if not self._should_retry(exc) or attempt >= self.max_retries:
raise
delay = self._delay_for_attempt(attempt, exc)
try:
dispatch_custom_event(
"surfsense.retrying",
{
"attempt": attempt + 1,
"max_retries": self.max_retries,
"delay_ms": int(delay * 1000),
"reason": type(exc).__name__,
},
)
except Exception:
logger.debug("dispatch_custom_event failed; suppressed", exc_info=True)
if delay > 0:
time.sleep(delay)
# Unreachable
raise RuntimeError("RetryAfterMiddleware: retry loop exited without resolution")
async def awrap_model_call( # type: ignore[override]
self,
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]],
) -> ModelResponse[ResponseT] | AIMessage:
for attempt in range(self.max_retries + 1):
try:
return await handler(request)
except Exception as exc:
if not self._should_retry(exc) or attempt >= self.max_retries:
raise
delay = self._delay_for_attempt(attempt, exc)
try:
await adispatch_custom_event(
"surfsense.retrying",
{
"attempt": attempt + 1,
"max_retries": self.max_retries,
"delay_ms": int(delay * 1000),
"reason": type(exc).__name__,
},
)
except Exception:
logger.debug(
"adispatch_custom_event failed; suppressed", exc_info=True
)
if delay > 0:
await asyncio.sleep(delay)
raise RuntimeError("RetryAfterMiddleware: retry loop exited without resolution")
__all__ = [
"RetryAfterMiddleware",
"_extract_retry_after_seconds",
"_is_non_retryable",
]
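The header-extraction helper in isolation; the exception class is a duck-typed stand-in for a provider error:

# Duck-typed stand-in for a provider rate-limit error.
class FakeRateLimitError(Exception):
    def __init__(self, headers: dict):
        super().__init__("rate limited; please retry after 30s")
        self.response = type("Resp", (), {"headers": headers})()

assert _extract_retry_after_seconds(FakeRateLimitError({"Retry-After": "12"})) == 12.0
assert _extract_retry_after_seconds(FakeRateLimitError({"retry-after-ms": 2500})) == 2.5
assert _extract_retry_after_seconds(FakeRateLimitError({})) == 30.0  # message fallback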

View file

@@ -1,123 +0,0 @@
"""Safe wrapper around deepagents' SummarizationMiddleware.
Upstream issue
--------------
`deepagents.middleware.summarization.SummarizationMiddleware._aoffload_to_backend`
(and its sync counterpart) call
``get_buffer_string(filtered_messages)`` before writing the evicted history
to the backend file. In recent ``langchain-core`` versions, ``get_buffer_string``
accesses ``m.text``, which iterates ``self.content``; this raises
``TypeError: 'NoneType' object is not iterable`` whenever an ``AIMessage``
has ``content=None`` (common when a model returns *only* tool_calls, seen
frequently with Azure OpenAI ``gpt-5.x`` responses streamed through
LiteLLM).
The exception aborts the whole agent turn, so the user just sees "Error during
chat" with no assistant response.
Fix
---
We subclass ``SummarizationMiddleware`` and override
``_filter_summary_messages`` (the only call site that feeds messages into
``get_buffer_string``) to return *copies* of messages whose ``content`` is
``None`` with ``content=""``. The originals flowing through the rest of the
agent state are untouched.
We also expose a drop-in ``create_safe_summarization_middleware`` factory
that mirrors ``deepagents.middleware.summarization.create_summarization_middleware``
but instantiates our safe subclass.
"""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
from deepagents.middleware.summarization import (
SummarizationMiddleware,
compute_summarization_defaults,
)
if TYPE_CHECKING:
from deepagents.backends.protocol import BACKEND_TYPES
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AnyMessage
logger = logging.getLogger(__name__)
def _sanitize_message_content(msg: AnyMessage) -> AnyMessage:
"""Return ``msg`` with ``content`` coerced to a non-``None`` value.
``get_buffer_string`` reads ``m.text`` which iterates ``self.content``;
when a provider streams back an ``AIMessage`` with only tool_calls and
no text, ``content`` can be ``None`` and the iteration explodes. We
replace ``None`` with an empty string so downstream consumers that only
care about text see an empty body.
The original message is left untouched; we return a copy via
pydantic's ``model_copy`` when available, otherwise we fall back to
re-setting the attribute on a shallow copy.
"""
if getattr(msg, "content", "not-missing") is not None:
return msg
try:
return msg.model_copy(update={"content": ""})
except AttributeError:
import copy
new_msg = copy.copy(msg)
try:
new_msg.content = ""
except Exception: # pragma: no cover - defensive
logger.debug(
"Could not sanitize content=None on message of type %s",
type(msg).__name__,
)
return msg
return new_msg
class SafeSummarizationMiddleware(SummarizationMiddleware):
"""`SummarizationMiddleware` that tolerates messages with ``content=None``.
Only ``_filter_summary_messages`` is overridden; this is the single
helper invoked by both the sync and async offload paths immediately
before ``get_buffer_string``. Normalising here means we get coverage
for both without having to copy the (long, rapidly-changing) offload
implementations from upstream.
"""
def _filter_summary_messages(self, messages: list[AnyMessage]) -> list[AnyMessage]:
filtered = super()._filter_summary_messages(messages)
return [_sanitize_message_content(m) for m in filtered]
def create_safe_summarization_middleware(
model: BaseChatModel,
backend: BACKEND_TYPES,
) -> SafeSummarizationMiddleware:
"""Drop-in replacement for ``create_summarization_middleware``.
Mirrors the defaults computed by ``deepagents`` but returns our
``SafeSummarizationMiddleware`` subclass so the
``content=None`` crash in ``get_buffer_string`` is avoided.
"""
defaults = compute_summarization_defaults(model)
return SafeSummarizationMiddleware(
model=model,
backend=backend,
trigger=defaults["trigger"],
keep=defaults["keep"],
trim_tokens_to_summarize=None,
truncate_args_settings=defaults["truncate_args_settings"],
)
__all__ = [
"SafeSummarizationMiddleware",
"create_safe_summarization_middleware",
]
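For the record, the repro and the fix this (now-removed) wrapper addressed, assuming a pydantic-v2 ``langchain-core`` where ``model_construct`` bypasses validation:

from langchain_core.messages import AIMessage

# Lax deserialization can yield content=None; model_construct simulates that.
bad = AIMessage.model_construct(content=None)
fixed = _sanitize_message_content(bad)
assert fixed.content == ""   # sanitized copy
assert bad.content is None   # original untouched
# get_buffer_string([bad]) is the call that raised TypeError upstream.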

View file

@@ -0,0 +1,332 @@
"""Skills backends for SurfSense.
Implements two minimal :class:`deepagents.backends.protocol.BackendProtocol`
subclasses tailored for use with :class:`deepagents.middleware.skills.SkillsMiddleware`.
The middleware only needs four methods to load skills from a backend:
* ``ls_info`` / ``als_info``: list directories under a source path.
* ``download_files`` / ``adownload_files``: fetch ``SKILL.md`` bytes.
Other ``BackendProtocol`` methods (``read``/``write``/``edit``/``grep_raw``)
default to ``NotImplementedError`` from the base class. They are never reached
by the skills middleware because skill content is rendered into the system
prompt at agent build time, not edited at runtime.
Two backends are provided:
* :class:`BuiltinSkillsBackend`: disk-backed read of bundled skills from
``app/agents/new_chat/skills/builtin/``.
* :class:`SearchSpaceSkillsBackend`: a thin read-only wrapper over
:class:`KBPostgresBackend` that filters notes under the privileged folder
``/documents/_skills/``.
Both backends are intentionally read-only: skill authoring happens out of band
(via filesystem or a search-space-admin route), so we never expose
``write`` / ``edit`` / ``upload_files``. The base class' ``NotImplementedError``
gives a clean failure mode if anything tries.
"""
from __future__ import annotations
import logging
from collections.abc import Callable
from dataclasses import replace
from pathlib import Path
from typing import TYPE_CHECKING
from deepagents.backends.composite import CompositeBackend
from deepagents.backends.protocol import (
BackendProtocol,
FileDownloadResponse,
FileInfo,
)
from deepagents.backends.state import StateBackend
if TYPE_CHECKING:
from langchain.tools import ToolRuntime
from app.agents.new_chat.middleware.kb_postgres_backend import KBPostgresBackend
logger = logging.getLogger(__name__)
# Limit per Agent Skills spec; matches deepagents.middleware.skills.MAX_SKILL_FILE_SIZE.
_MAX_SKILL_FILE_SIZE = 10 * 1024 * 1024
def _default_builtin_root() -> Path:
"""Return the absolute path to the bundled builtin skills directory.
Located at ``app/agents/new_chat/skills/builtin/`` relative to this module.
"""
return (Path(__file__).resolve().parent.parent / "skills" / "builtin").resolve()
class BuiltinSkillsBackend(BackendProtocol):
"""Read-only disk-backed skills source.
Maps a virtual ``/skills/builtin/`` namespace onto a directory on local disk,
where each skill is its own subdirectory containing a ``SKILL.md`` file::
<root>/<skill-name>/SKILL.md
The middleware calls :meth:`als_info` with the source path and expects a
``list[FileInfo]`` whose ``is_dir=True`` entries are descended into. Then it
calls :meth:`adownload_files` with the synthesized ``SKILL.md`` paths and
parses YAML frontmatter from the returned ``content`` bytes.
Mounting under :class:`~deepagents.backends.composite.CompositeBackend` at
prefix ``/skills/builtin/`` means the middleware can issue paths like
``/skills/builtin/kb-research/SKILL.md`` which the composite strips down to
``/kb-research/SKILL.md`` before forwarding here. We treat any leading
slash as anchoring at :attr:`root`.
"""
def __init__(self, root: Path | str | None = None) -> None:
self.root: Path = Path(root).resolve() if root else _default_builtin_root()
if not self.root.exists():
logger.info(
"BuiltinSkillsBackend root %s does not exist; skills will be empty.",
self.root,
)
def _resolve(self, path: str) -> Path:
"""Resolve a virtual posix path under :attr:`root`, refusing escapes."""
bare = path.lstrip("/")
candidate = (self.root / bare).resolve() if bare else self.root
# Refuse symlink/.. traversal that escapes the root.
try:
candidate.relative_to(self.root)
except ValueError as exc:
raise ValueError(f"path {path!r} escapes builtin skills root") from exc
return candidate
def ls_info(self, path: str) -> list[FileInfo]:
try:
target = self._resolve(path)
except ValueError as exc:
logger.warning("BuiltinSkillsBackend.ls_info refused: %s", exc)
return []
if not target.exists() or not target.is_dir():
return []
infos: list[FileInfo] = []
# Build virtual paths anchored at "/" because CompositeBackend already
# stripped the route prefix before calling us.
target_virtual = "/" if target == self.root else (
"/" + str(target.relative_to(self.root)).replace("\\", "/")
)
for child in sorted(target.iterdir()):
child_virtual = (
target_virtual.rstrip("/") + "/" + child.name
if target_virtual != "/"
else "/" + child.name
)
info: FileInfo = {
"path": child_virtual,
"is_dir": child.is_dir(),
}
if child.is_file():
try:
info["size"] = child.stat().st_size
except OSError: # pragma: no cover - defensive
pass
infos.append(info)
return infos
def download_files(self, paths: list[str]) -> list[FileDownloadResponse]:
responses: list[FileDownloadResponse] = []
for p in paths:
try:
target = self._resolve(p)
except ValueError:
responses.append(FileDownloadResponse(path=p, error="invalid_path"))
continue
if not target.exists():
responses.append(FileDownloadResponse(path=p, error="file_not_found"))
continue
if target.is_dir():
responses.append(FileDownloadResponse(path=p, error="is_directory"))
continue
try:
# Hard cap to avoid loading rogue mega-files into memory.
size = target.stat().st_size
if size > _MAX_SKILL_FILE_SIZE:
logger.warning(
"Builtin skill file %s exceeds %d bytes; truncating.",
target,
_MAX_SKILL_FILE_SIZE,
)
with target.open("rb") as fh:
content = fh.read(_MAX_SKILL_FILE_SIZE)
else:
content = target.read_bytes()
except PermissionError:
responses.append(FileDownloadResponse(path=p, error="permission_denied"))
continue
except OSError as exc: # pragma: no cover - defensive
logger.warning("Builtin skill read failed %s: %s", target, exc)
responses.append(FileDownloadResponse(path=p, error="file_not_found"))
continue
responses.append(FileDownloadResponse(path=p, content=content, error=None))
return responses
class SearchSpaceSkillsBackend(BackendProtocol):
"""Read-only view of search-space-authored skills.
Wraps a :class:`KBPostgresBackend` and only ever reads under the privileged
folder ``/documents/_skills/`` (configurable). The folder is intended to be
writable only by search-space admins; this backend never writes.
The skills middleware expects a layout like::
/<source_root>/<skill-name>/SKILL.md
But the KB stores documents like ``/documents/_skills/<name>/SKILL.md``.
We expose the inner namespace by remapping each path. When mounted under
:class:`CompositeBackend` at prefix ``/skills/space/`` the paths the
middleware sees become ``/skills/space/<name>/SKILL.md``; the composite
strips ``/skills/space/`` and hands us ``/<name>/SKILL.md``, which we
rewrite to ``/documents/_skills/<name>/SKILL.md`` before forwarding to the
KB.
No new database table is needed: the privileged folder convention is
enforced server-side outside of this class. We intentionally leave all
write/edit methods unimplemented (the base class raises ``NotImplementedError``).
"""
DEFAULT_KB_ROOT: str = "/documents/_skills"
def __init__(
self,
kb_backend: KBPostgresBackend,
*,
kb_root: str = DEFAULT_KB_ROOT,
) -> None:
self._kb = kb_backend
# Normalize trailing slash off so we can join cleanly.
self._kb_root = kb_root.rstrip("/") or "/"
def _to_kb(self, path: str) -> str:
"""Rewrite a virtual path into the underlying KB namespace."""
bare = path.lstrip("/")
if not bare:
return self._kb_root
return f"{self._kb_root}/{bare}"
def _from_kb(self, kb_path: str) -> str:
"""Rewrite a KB path back into our virtual namespace."""
if not kb_path.startswith(self._kb_root):
return kb_path # pragma: no cover - defensive
rel = kb_path[len(self._kb_root) :]
return rel if rel.startswith("/") else "/" + rel
def ls_info(self, path: str) -> list[FileInfo]:
# KBPostgresBackend exposes only the async API meaningfully; the sync
# path falls back to ``asyncio.to_thread(...)`` in the base class. We
# keep this stub to satisfy abstract resolution; the middleware calls
# ``als_info``.
raise NotImplementedError("SearchSpaceSkillsBackend is async-only")
async def als_info(self, path: str) -> list[FileInfo]:
kb_path = self._to_kb(path)
try:
infos = await self._kb.als_info(kb_path)
except Exception as exc: # pragma: no cover - defensive
logger.warning("SearchSpaceSkillsBackend.als_info failed: %s", exc)
return []
remapped: list[FileInfo] = []
for info in infos:
kb_p = info.get("path", "")
if not kb_p.startswith(self._kb_root):
continue
remapped.append({**info, "path": self._from_kb(kb_p)})
return remapped
def download_files(self, paths: list[str]) -> list[FileDownloadResponse]:
raise NotImplementedError("SearchSpaceSkillsBackend is async-only")
async def adownload_files(self, paths: list[str]) -> list[FileDownloadResponse]:
kb_paths = [self._to_kb(p) for p in paths]
responses = await self._kb.adownload_files(kb_paths)
# Re-map response paths back to the virtual namespace so the middleware
# correlates them to the input list correctly.
remapped: list[FileDownloadResponse] = []
for original, resp in zip(paths, responses, strict=True):
remapped.append(replace(resp, path=original))
return remapped
SKILLS_BUILTIN_PREFIX = "/skills/builtin/"
SKILLS_SPACE_PREFIX = "/skills/space/"
def build_skills_backend_factory(
*,
builtin_root: Path | str | None = None,
search_space_id: int | None = None,
) -> Callable[[ToolRuntime], BackendProtocol]:
"""Return a runtime-aware factory for the skills :class:`CompositeBackend`.
When ``search_space_id`` is provided the composite includes a
:class:`SearchSpaceSkillsBackend` route at ``/skills/space/`` over a fresh
per-runtime :class:`KBPostgresBackend`, mirroring how
:func:`build_backend_resolver` constructs the main filesystem backend.
When ``search_space_id`` is ``None`` (e.g., desktop-local mode or unit
tests) only the bundled :class:`BuiltinSkillsBackend` is exposed.
Returning a factory rather than a fixed instance is intentional: the
underlying KB backend depends on per-call ``ToolRuntime`` state
(``staged_dirs``, ``files`` cache, runtime config), so a single shared
instance cannot serve multiple concurrent agent runs.
"""
builtin = BuiltinSkillsBackend(builtin_root)
if search_space_id is None:
def _factory_builtin_only(runtime: ToolRuntime) -> BackendProtocol:
# Default StateBackend is intentionally inert: any path outside the
# ``/skills/builtin/`` route resolves to an empty per-runtime state
# so the SkillsMiddleware can iterate sources without raising.
return CompositeBackend(
default=StateBackend(runtime),
routes={SKILLS_BUILTIN_PREFIX: builtin},
)
return _factory_builtin_only
def _factory_with_space(runtime: ToolRuntime) -> BackendProtocol:
# Imported lazily to avoid a hard dependency at module import time:
# ``KBPostgresBackend`` pulls in DB models, which are unnecessary for
# the unit-tested builtin path.
from app.agents.new_chat.middleware.kb_postgres_backend import (
KBPostgresBackend,
)
kb = KBPostgresBackend(search_space_id, runtime)
space = SearchSpaceSkillsBackend(kb)
return CompositeBackend(
default=StateBackend(runtime),
routes={
SKILLS_BUILTIN_PREFIX: builtin,
SKILLS_SPACE_PREFIX: space,
},
)
return _factory_with_space
def default_skills_sources() -> list[str]:
"""Return the canonical source list for SkillsMiddleware (built-in then space)."""
return [SKILLS_BUILTIN_PREFIX, SKILLS_SPACE_PREFIX]
__all__ = [
"SKILLS_BUILTIN_PREFIX",
"SKILLS_SPACE_PREFIX",
"BuiltinSkillsBackend",
"SearchSpaceSkillsBackend",
"build_skills_backend_factory",
"default_skills_sources",
]
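Intended consumption shape (a sketch; ``runtime`` is whatever ``ToolRuntime`` the agent passes to backend factories):

# Builtin-only factory, e.g. for unit tests or desktop-local mode.
factory = build_skills_backend_factory(search_space_id=None)
sources = default_skills_sources()  # ["/skills/builtin/", "/skills/space/"]

# Per agent run (sketch):
#   backend = factory(runtime)
#   backend.ls_info("/skills/builtin/")  # routed to BuiltinSkillsBackend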

View file

@@ -0,0 +1,190 @@
"""
ToolCallNameRepairMiddleware: two-stage tool-name repair.
Mirrors ``opencode/packages/opencode/src/session/llm.ts:339-358`` plus
``opencode/packages/opencode/src/tool/invalid.ts``. Tier 1.7 in the
OpenCode-port plan.
Operation:
1. **Stage 1, lowercase repair:** if a tool call's ``name`` is not in
the registry but ``name.lower()`` is, rewrite in place. Catches
models that emit ``Search`` instead of ``search``.
2. **Stage 2, invalid fallback:** if still unmatched, rewrite the call
to ``invalid`` with ``args={"tool": original_name, "error": <error>}``
so the registered :func:`invalid_tool` returns the error to the model
for self-correction.
Distinct from :class:`deepagents.middleware.PatchToolCallsMiddleware`,
which patches *dangling* tool calls (no matching ToolMessage); that
class does not handle the wrong-name case at all.
"""
from __future__ import annotations
import difflib
import logging
from typing import Any
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ContextT,
ResponseT,
)
from langchain_core.messages import AIMessage
from langgraph.runtime import Runtime
from app.agents.new_chat.tools.invalid_tool import INVALID_TOOL_NAME
logger = logging.getLogger(__name__)
def _coerce_existing_tool_call(call: Any) -> dict[str, Any]:
"""Normalize a tool call entry to a mutable dict."""
if isinstance(call, dict):
return dict(call)
return {
"name": getattr(call, "name", None),
"args": getattr(call, "args", {}),
"id": getattr(call, "id", None),
"type": "tool_call",
}
class ToolCallNameRepairMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
"""Two-stage tool-name repair on the most recent ``AIMessage``.
Args:
registered_tool_names: Set of canonically-registered tool names.
``invalid`` should be in this set so the fallback dispatches.
fuzzy_match_threshold: Optional ``difflib`` ratio (0-1) for the
fuzzy-match step that runs *between* lowercase and invalid.
Set to ``None`` to disable fuzzy matching (opencode parity).
"""
def __init__(
self,
*,
registered_tool_names: set[str],
fuzzy_match_threshold: float | None = 0.85,
) -> None:
super().__init__()
self._registered = set(registered_tool_names)
self._fuzzy_threshold = fuzzy_match_threshold
self.tools = []
def _registered_for_runtime(self, runtime: Runtime[ContextT]) -> set[str]:
"""Allow runtime overrides to expand the set (e.g. dynamic MCP tools)."""
ctx_tools = getattr(runtime.context, "registered_tool_names", None)
if isinstance(ctx_tools, (set, frozenset)):
return self._registered | set(ctx_tools)
if isinstance(ctx_tools, (list, tuple)):
return self._registered | set(ctx_tools)
return self._registered
def _repair_one(
self,
call: dict[str, Any],
registered: set[str],
) -> dict[str, Any]:
name = call.get("name")
if not isinstance(name, str):
return call
if name in registered:
return call
# Stage 1 — lowercase: map the lowercased name back to its canonical
# registered spelling (covers both ``Search`` -> ``search`` and
# mixed-case canonical names).
by_lower = {n.lower(): n for n in registered}
lowered = name.lower()
if lowered in by_lower:
call["name"] = by_lower[lowered]
metadata = dict(call.get("response_metadata") or {})
metadata.setdefault("repair", "lowercase")
call["response_metadata"] = metadata
return call
# Optional fuzzy step (enabled by default; pass ``fuzzy_match_threshold=None`` for opencode parity)
if self._fuzzy_threshold is not None:
close = difflib.get_close_matches(
name, registered, n=1, cutoff=self._fuzzy_threshold
)
if close:
call["name"] = close[0]
metadata = dict(call.get("response_metadata") or {})
metadata.setdefault("repair", f"fuzzy:{name}->{close[0]}")
call["response_metadata"] = metadata
return call
# Stage 2 — invalid fallback
if INVALID_TOOL_NAME in registered:
original_args = call.get("args") or {}
error_msg = (
f"Tool name '{name}' is not registered. "
f"Original arguments were: {original_args!r}."
)
call["name"] = INVALID_TOOL_NAME
call["args"] = {"tool": name, "error": error_msg}
metadata = dict(call.get("response_metadata") or {})
metadata.setdefault("repair", f"invalid_fallback:{name}")
call["response_metadata"] = metadata
else:
logger.warning(
"Could not repair unknown tool call %r; 'invalid' tool not registered",
name,
)
return call
def _maybe_repair(
self,
message: AIMessage,
registered: set[str],
) -> AIMessage | None:
if not message.tool_calls:
return None
new_calls: list[dict[str, Any]] = []
any_changed = False
for raw in message.tool_calls:
call = _coerce_existing_tool_call(raw)
before = (call.get("name"), call.get("args"))
repaired = self._repair_one(call, registered)
after = (repaired.get("name"), repaired.get("args"))
if before != after:
any_changed = True
new_calls.append(repaired)
if not any_changed:
return None
return message.model_copy(update={"tool_calls": new_calls})
def after_model( # type: ignore[override]
self,
state: AgentState[ResponseT],
runtime: Runtime[ContextT],
) -> dict[str, Any] | None:
messages = state.get("messages") or []
if not messages:
return None
last = messages[-1]
if not isinstance(last, AIMessage):
return None
registered = self._registered_for_runtime(runtime)
repaired = self._maybe_repair(last, registered)
if repaired is None:
return None
return {"messages": [repaired]}
async def aafter_model( # type: ignore[override]
self,
state: AgentState[ResponseT],
runtime: Runtime[ContextT],
) -> dict[str, Any] | None:
return self.after_model(state, runtime)
__all__ = [
"ToolCallNameRepairMiddleware",
]
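Shape of both repair stages, exercising ``_repair_one`` directly; this assumes ``INVALID_TOOL_NAME == "invalid"`` (illustrative only):

registered = {"search", "invalid"}  # assumes INVALID_TOOL_NAME == "invalid"
mw = ToolCallNameRepairMiddleware(
    registered_tool_names=registered, fuzzy_match_threshold=None
)

# Stage 1: casing repair.
call = mw._repair_one({"name": "Search", "args": {"q": "x"}, "id": "1"}, registered)
assert call["name"] == "search"

# Stage 2: unknown name is rewritten to the invalid tool for self-correction.
call = mw._repair_one({"name": "serach", "args": {}, "id": "2"}, registered)
assert call["name"] == "invalid" and call["args"]["tool"] == "serach"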