SurfSense/surfsense_backend/app/agents/shared/middleware/doom_loop.py
CREDO23 8faa03889d docs(agents): refresh comments that referenced the deleted single-agent stack (bucket B6)
After deleting app/agents/new_chat/, several shared-kernel comments still cited
new_chat paths/cycles. Update the two lazy-import comments in middleware to state
the real reason (tools.registry <-> shared.middleware cycle), and repoint dangling
``new_chat/tools/hitl.py`` / ``chat_deepagent`` doc references to their shared
locations. Comment-only; suite unaffected.
2026-06-04 13:47:10 +02:00

238 lines
8.4 KiB
Python

"""
DoomLoopMiddleware — pattern-based detector for repeated identical tool calls.
LangChain has :class:`ToolCallLimitMiddleware` which caps the *total* number
of tool calls per turn — but it can't tell apart "10 distinct, useful
calls" from "the same call 10 times in a row". This middleware fills that
gap with a sliding-window check on tool-call signatures, ported from
OpenCode's ``packages/opencode/src/session/processor.ts``.
When the same tool with the same arguments is called N times in a row,
the agent has likely entered an infinite loop. We surface this to the
user as an interrupt with ``permission="doom_loop"`` so the UI can
render an "Are you stuck? Continue / cancel?" affordance.
This ships **OFF by default** until the frontend explicitly handles
``context.permission == "doom_loop"`` interrupts.
Wire format: uses SurfSense's existing ``interrupt()`` payload shape
(see ``app/agents/shared/tools/hitl.py``):
    {
        "type": "permission_ask",
        "action": {"tool": <name>, "params": <args>},
        "context": {
            "permission": "doom_loop",
            "recent_signatures": [...],
            "threshold": <n>,
        },
    }
so the frontend that already handles HITL prompts can render this with
no changes beyond a string check.
"""
from __future__ import annotations

import hashlib
import json
import logging
from collections import deque
from typing import Any

from langchain.agents.middleware.types import (
    AgentMiddleware,
    AgentState,
    ContextT,
    ResponseT,
    hook_config,
)
from langchain_core.messages import AIMessage
from langgraph.config import get_config
from langgraph.runtime import Runtime
from langgraph.types import interrupt

from app.observability import metrics as ot_metrics, otel as ot
# Module-scoped logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
def _signature(name: str, args: Any) -> str:
"""Hash a tool call ``(name, args)`` to a short signature."""
try:
canonical = json.dumps(args, sort_keys=True, default=str)
except (TypeError, ValueError):
canonical = repr(args)
digest = hashlib.sha1(f"{name}::{canonical}".encode()).hexdigest()
return digest[:16]
class DoomLoopMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
    """Detect repeated identical tool calls and prompt the user.

    Tracks, per thread, a sliding window of the most-recent ``threshold``
    tool-call signatures. When a full window holds a single repeated
    signature, raise a SurfSense-style HITL interrupt with
    ``permission="doom_loop"``.

    Args:
        threshold: How many consecutive identical signatures count as a
            doom loop. Default 3 (matches OpenCode's processor.ts).
        max_tracked_threads: Maximum number of per-thread windows kept in
            memory. Once exceeded, the least-recently-used thread's window
            is discarded, so a long-lived process does not grow without
            bound.

    Raises:
        ValueError: If ``threshold < 2`` or ``max_tracked_threads < 1``.
    """

    def __init__(
        self, *, threshold: int = 3, max_tracked_threads: int = 10_000
    ) -> None:
        super().__init__()
        if threshold < 2:
            raise ValueError("DoomLoopMiddleware threshold must be >= 2")
        if max_tracked_threads < 1:
            raise ValueError(
                "DoomLoopMiddleware max_tracked_threads must be >= 1"
            )
        self._threshold = threshold
        self._max_tracked_threads = max_tracked_threads
        self.tools = []
        # Per-thread sliding windows keyed by thread_id. Storing them in
        # graph state would need state-schema changes, so they live in
        # memory. Dict insertion order doubles as LRU order (see _window).
        # NOTE(review): this state is per-instance and in-memory, so it is
        # lost on process restart and not shared between worker processes.
        self._windows: dict[str, deque[str]] = {}

    @staticmethod
    def _thread_id_from_runtime(runtime: Runtime[ContextT]) -> str:
        """Resolve the thread id for sliding-window keying.

        Prefer LangGraph's ``get_config()`` (the only way to read
        ``RunnableConfig`` inside a node; :class:`Runtime` does NOT carry
        a ``config`` attribute). Fall back to ``runtime.config`` for unit
        tests that synthesize a config-bearing stub. The ``"no_thread"``
        default is used only when both lookups fail. It collapses all
        threads into one window, so that path is logged.
        """

        def _from_dict(cfg: Any) -> str | None:
            if not isinstance(cfg, dict):
                return None
            tid = (cfg.get("configurable") or {}).get("thread_id")
            return str(tid) if tid is not None else None

        try:
            tid = _from_dict(get_config())
        except Exception:
            # get_config() raises outside a runnable context; fall through.
            tid = None
        if tid is not None:
            return tid
        tid = _from_dict(getattr(runtime, "config", None))
        if tid is not None:
            return tid
        logger.debug(
            "DoomLoopMiddleware: no thread_id resolved from RunnableConfig; "
            "falling back to shared 'no_thread' window."
        )
        return "no_thread"

    def _window(self, thread_id: str) -> deque[str]:
        """Return the sliding window for ``thread_id``, creating it if needed.

        Accessing a window marks it most-recently-used. The oldest windows
        are evicted once more than ``max_tracked_threads`` are held.
        """
        win = self._windows.pop(thread_id, None)
        if win is None:
            win = deque(maxlen=self._threshold)
        # Re-insert so this thread moves to the end (most recently used).
        self._windows[thread_id] = win
        while len(self._windows) > self._max_tracked_threads:
            # The first key is the least recently used. The entry just
            # inserted is last and max_tracked_threads >= 1, so it survives.
            self._windows.pop(next(iter(self._windows)))
        return win

    def _detect(
        self, message: AIMessage, runtime: Runtime[ContextT]
    ) -> tuple[bool, list[str], dict[str, Any] | None]:
        """Feed ``message``'s tool calls into the thread's window.

        Returns:
            ``(triggered, window_snapshot, action)``. ``action`` is the
            ``{"tool", "name", "params"}`` dict of the call that completed
            a window of identical signatures, or ``None`` when no doom loop
            was detected.
        """
        if not message.tool_calls:
            return False, [], None
        thread_id = self._thread_id_from_runtime(runtime)
        window = self._window(thread_id)
        triggered_call: dict[str, Any] | None = None
        for call in message.tool_calls:
            name = (
                call.get("name")
                if isinstance(call, dict)
                else getattr(call, "name", None)
            )
            args = (
                call.get("args")
                if isinstance(call, dict)
                else getattr(call, "args", {})
            )
            if not isinstance(name, str):
                continue
            sig = _signature(name, args)
            window.append(sig)
            if len(window) >= self._threshold and len(set(window)) == 1:
                # "tool" is the documented wire-format key that the frontend
                # and the OTel span read. "name" is kept for backward
                # compatibility with existing consumers of this payload.
                triggered_call = {"tool": name, "name": name, "params": args or {}}
                break
        if triggered_call is None:
            return False, list(window), None
        return True, list(window), triggered_call

    @staticmethod
    def _is_cancel(decision: Any) -> bool:
        """Return True when the resume ``decision`` asks to stop the run.

        Accepts either a dict carrying ``decision_type`` / ``type`` (the
        ``tools/hitl.py`` shape) or a bare string. Any value containing
        "reject" or "cancel" (case-insensitive) counts as a cancellation.
        """
        if isinstance(decision, dict):
            raw = decision.get("decision_type") or decision.get("type") or ""
        elif isinstance(decision, str):
            raw = decision
        else:
            return False
        kind = str(raw).lower()
        return "reject" in kind or "cancel" in kind

    # LangChain only routes a hook's ``jump_to`` return value when the hook
    # declares its destinations via ``hook_config(can_jump_to=...)``.
    # Without this, the cancel path below would be silently ignored.
    @hook_config(can_jump_to=["end"])
    def after_model(  # type: ignore[override]
        self,
        state: AgentState[ResponseT],
        runtime: Runtime[ContextT],
    ) -> dict[str, Any] | None:
        """Check the latest AI message for a doom loop and ask the user.

        Returns:
            ``{"jump_to": "end"}`` when the user rejects/cancels, otherwise
            ``None`` so the pending tool call proceeds.
        """
        messages = state.get("messages") or []
        if not messages:
            return None
        last = messages[-1]
        if not isinstance(last, AIMessage):
            return None
        triggered, signatures, action = self._detect(last, runtime)
        if not triggered:
            return None
        action = action or {"tool": "<unknown>", "params": {}}
        tool_name = action.get("tool", "<unknown>")
        logger.warning(
            "Doom loop detected: tool %s called %d times in a row (sig=%s)",
            tool_name,
            self._threshold,
            signatures[-1] if signatures else "<empty>",
        )
        # Open an interrupt.raised span with permission=doom_loop attribute
        # so dashboards can break out doom-loop interrupts from regular
        # permission asks via the ``interrupt.permission`` attribute.
        with ot.interrupt_span(
            interrupt_type="permission_ask",
            extra={
                "interrupt.permission": "doom_loop",
                "interrupt.threshold": self._threshold,
                "interrupt.tool": tool_name,
            },
        ):
            ot_metrics.record_interrupt(interrupt_type="permission_ask")
            # NOTE(review): per LangGraph's interrupt semantics, the first
            # call raises and the hook re-runs on resume. That is why the
            # window is cleared only after interrupt() returns a decision.
            decision = interrupt(
                {
                    "type": "permission_ask",
                    "action": action,
                    "context": {
                        "permission": "doom_loop",
                        "recent_signatures": signatures,
                        "threshold": self._threshold,
                    },
                }
            )
        # Reset the window so the next decision (continue/cancel) starts fresh.
        self._windows.pop(self._thread_id_from_runtime(runtime), None)
        if self._is_cancel(decision):
            return {"jump_to": "end"}
        return None

    @hook_config(can_jump_to=["end"])
    async def aafter_model(  # type: ignore[override]
        self,
        state: AgentState[ResponseT],
        runtime: Runtime[ContextT],
    ) -> dict[str, Any] | None:
        """Async entry point; delegates to the synchronous :meth:`after_model`."""
        return self.after_model(state, runtime)
# Public surface of this module; ``_signature`` is exported for tests.
__all__ = ["DoomLoopMiddleware", "_signature"]