mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-06 20:15:17 +02:00
refactor(agents): move middleware package to app/agents/shared (slice 5c)
Relocate the entire new_chat/middleware/ package to the shared kernel as one cohesive unit (it is live shared infrastructure: the multi-agent stack wraps nearly every middleware via multi_agent_chat/middleware/main_agent/*, and anonymous_agent consumes it too). Flip 69 live importers across both the package-path and submodule-path forms. Shims left for the frozen single-agent stack: a package __init__ re-export plus submodule shims for permission, skills_backends, and scoped_model_fallback (the three imported via submodule path by chat_deepagent/subagents). Cycle break: importing shared.middleware previously reached back into new_chat.tools at module load, which dragged in new_chat.__init__ -> chat_deepagent -> the middleware shim -> half-initialized shared.middleware. Made action_log's ToolDefinition import TYPE_CHECKING-only and tool_call_repair's INVALID_TOOL_NAME import function-local. These tools-package back-edges fully resolve in slice 6. Asset note: skills_backends._default_builtin_root now walks to app/agents/new_chat/skills/builtin (the skills/ tree migrates in slice 7).
This commit is contained in:
parent
6f488d9564
commit
227983a104
98 changed files with 1131 additions and 999 deletions
238
surfsense_backend/app/agents/shared/middleware/doom_loop.py
Normal file
238
surfsense_backend/app/agents/shared/middleware/doom_loop.py
Normal file
|
|
@ -0,0 +1,238 @@
|
|||
"""
|
||||
DoomLoopMiddleware — pattern-based detector for repeated identical tool calls.
|
||||
|
||||
LangChain has :class:`ToolCallLimitMiddleware` which caps the *total* number
|
||||
of tool calls per turn — but it can't tell apart "10 distinct, useful
|
||||
calls" from "the same call 10 times in a row". This middleware fills that
|
||||
gap with a sliding-window check on tool-call signatures, ported from
|
||||
OpenCode's ``packages/opencode/src/session/processor.ts``.
|
||||
|
||||
When the same tool with the same arguments is called N times in a row,
|
||||
the agent has likely entered an infinite loop. We surface this to the
|
||||
user as an interrupt with ``permission="doom_loop"`` so the UI can
|
||||
render an "Are you stuck? Continue / cancel?" affordance.
|
||||
|
||||
This ships **OFF by default** until the frontend explicitly handles
|
||||
``context.permission == "doom_loop"`` interrupts.
|
||||
|
||||
Wire format: uses SurfSense's existing ``interrupt()`` payload shape
|
||||
(see ``app/agents/new_chat/tools/hitl.py``):
|
||||
|
||||
{
|
||||
"type": "permission_ask",
|
||||
"action": {"tool": <name>, "params": <args>},
|
||||
"context": {"permission": "doom_loop", "recent_signatures": [...]},
|
||||
}
|
||||
|
||||
so the frontend that already handles HITL prompts can render this with
|
||||
no changes beyond a string check.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from collections import deque
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
ResponseT,
|
||||
)
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.config import get_config
|
||||
from langgraph.runtime import Runtime
|
||||
from langgraph.types import interrupt
|
||||
|
||||
from app.observability import metrics as ot_metrics, otel as ot
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _signature(name: str, args: Any) -> str:
|
||||
"""Hash a tool call ``(name, args)`` to a short signature."""
|
||||
try:
|
||||
canonical = json.dumps(args, sort_keys=True, default=str)
|
||||
except (TypeError, ValueError):
|
||||
canonical = repr(args)
|
||||
digest = hashlib.sha1(f"{name}::{canonical}".encode()).hexdigest()
|
||||
return digest[:16]
|
||||
|
||||
|
||||
class DoomLoopMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
|
||||
"""Detect repeated identical tool calls and prompt the user.
|
||||
|
||||
Tracks a sliding window of the most-recent ``threshold`` tool-call
|
||||
signatures across the live request. When all entries match, raise
|
||||
a SurfSense-style HITL interrupt with ``permission="doom_loop"``.
|
||||
|
||||
Args:
|
||||
threshold: How many consecutive identical signatures count as a
|
||||
doom loop. Default 3 (matches OpenCode's processor.ts).
|
||||
"""
|
||||
|
||||
def __init__(self, *, threshold: int = 3) -> None:
|
||||
super().__init__()
|
||||
if threshold < 2:
|
||||
raise ValueError("DoomLoopMiddleware threshold must be >= 2")
|
||||
self._threshold = threshold
|
||||
self.tools = []
|
||||
# Per-thread sliding windows. We can't put this in graph state
|
||||
# without state-schema gymnastics; for one process-lifetime it's
|
||||
# fine to keep an in-memory map keyed by thread_id.
|
||||
self._windows: dict[str, deque[str]] = {}
|
||||
|
||||
@staticmethod
|
||||
def _thread_id_from_runtime(runtime: Runtime[ContextT]) -> str:
|
||||
"""Resolve the thread id for sliding-window keying.
|
||||
|
||||
Prefer LangGraph's ``get_config()`` (the only way to read
|
||||
``RunnableConfig`` inside a node — :class:`Runtime` does NOT carry
|
||||
a ``config`` attribute). Fall back to ``runtime.config`` for unit
|
||||
tests that synthesize a config-bearing stub. Default
|
||||
``"no_thread"`` is intentionally only used when both lookups fail
|
||||
— it would collapse all threads into one window so we keep the
|
||||
debug log loud.
|
||||
"""
|
||||
|
||||
def _from_dict(cfg: Any) -> str | None:
|
||||
if not isinstance(cfg, dict):
|
||||
return None
|
||||
tid = (cfg.get("configurable") or {}).get("thread_id")
|
||||
return str(tid) if tid is not None else None
|
||||
|
||||
try:
|
||||
tid = _from_dict(get_config())
|
||||
except Exception:
|
||||
tid = None
|
||||
if tid is not None:
|
||||
return tid
|
||||
|
||||
tid = _from_dict(getattr(runtime, "config", None))
|
||||
if tid is not None:
|
||||
return tid
|
||||
|
||||
logger.debug(
|
||||
"DoomLoopMiddleware: no thread_id resolved from RunnableConfig; "
|
||||
"falling back to shared 'no_thread' window."
|
||||
)
|
||||
return "no_thread"
|
||||
|
||||
def _window(self, thread_id: str) -> deque[str]:
|
||||
win = self._windows.get(thread_id)
|
||||
if win is None:
|
||||
win = deque(maxlen=self._threshold)
|
||||
self._windows[thread_id] = win
|
||||
return win
|
||||
|
||||
def _detect(
|
||||
self, message: AIMessage, runtime: Runtime[ContextT]
|
||||
) -> tuple[bool, list[str], dict[str, Any] | None]:
|
||||
if not message.tool_calls:
|
||||
return False, [], None
|
||||
|
||||
thread_id = self._thread_id_from_runtime(runtime)
|
||||
window = self._window(thread_id)
|
||||
|
||||
triggered_call: dict[str, Any] | None = None
|
||||
for call in message.tool_calls:
|
||||
name = (
|
||||
call.get("name")
|
||||
if isinstance(call, dict)
|
||||
else getattr(call, "name", None)
|
||||
)
|
||||
args = (
|
||||
call.get("args")
|
||||
if isinstance(call, dict)
|
||||
else getattr(call, "args", {})
|
||||
)
|
||||
if not isinstance(name, str):
|
||||
continue
|
||||
sig = _signature(name, args)
|
||||
window.append(sig)
|
||||
if len(window) >= self._threshold and len(set(window)) == 1:
|
||||
triggered_call = {"name": name, "params": args or {}}
|
||||
break
|
||||
|
||||
if triggered_call is None:
|
||||
return False, list(window), None
|
||||
return True, list(window), triggered_call
|
||||
|
||||
def after_model( # type: ignore[override]
|
||||
self,
|
||||
state: AgentState[ResponseT],
|
||||
runtime: Runtime[ContextT],
|
||||
) -> dict[str, Any] | None:
|
||||
messages = state.get("messages") or []
|
||||
if not messages:
|
||||
return None
|
||||
last = messages[-1]
|
||||
if not isinstance(last, AIMessage):
|
||||
return None
|
||||
|
||||
triggered, signatures, action = self._detect(last, runtime)
|
||||
if not triggered:
|
||||
return None
|
||||
|
||||
logger.warning(
|
||||
"Doom loop detected: tool %s called %d times in a row (sig=%s)",
|
||||
action["name"] if action else "<unknown>",
|
||||
self._threshold,
|
||||
signatures[-1] if signatures else "<empty>",
|
||||
)
|
||||
|
||||
# Open an interrupt.raised span with permission=doom_loop attribute
|
||||
# so dashboards can break out doom-loop interrupts from regular
|
||||
# permission asks via the ``interrupt.permission`` attribute.
|
||||
with ot.interrupt_span(
|
||||
interrupt_type="permission_ask",
|
||||
extra={
|
||||
"interrupt.permission": "doom_loop",
|
||||
"interrupt.threshold": self._threshold,
|
||||
"interrupt.tool": (action or {}).get("tool", "<unknown>"),
|
||||
},
|
||||
):
|
||||
ot_metrics.record_interrupt(interrupt_type="permission_ask")
|
||||
decision = interrupt(
|
||||
{
|
||||
"type": "permission_ask",
|
||||
"action": action or {"tool": "<unknown>", "params": {}},
|
||||
"context": {
|
||||
"permission": "doom_loop",
|
||||
"recent_signatures": signatures,
|
||||
"threshold": self._threshold,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Reset window so the next decision (continue/cancel) starts fresh.
|
||||
thread_id = self._thread_id_from_runtime(runtime)
|
||||
self._windows.pop(thread_id, None)
|
||||
|
||||
# Decision shape mirrors ``tools/hitl.py``: {"decision_type": "..."}
|
||||
# If the user cancelled, jump to end. Otherwise return ``None`` so the
|
||||
# tool call proceeds. The frontend's exact reply names may differ —
|
||||
# we tolerate any shape that contains a string with "reject"/"cancel".
|
||||
if isinstance(decision, dict):
|
||||
kind = str(
|
||||
decision.get("decision_type") or decision.get("type") or ""
|
||||
).lower()
|
||||
if "reject" in kind or "cancel" in kind:
|
||||
return {"jump_to": "end"}
|
||||
return None
|
||||
|
||||
async def aafter_model( # type: ignore[override]
|
||||
self,
|
||||
state: AgentState[ResponseT],
|
||||
runtime: Runtime[ContextT],
|
||||
) -> dict[str, Any] | None:
|
||||
return self.after_model(state, runtime)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"DoomLoopMiddleware",
|
||||
"_signature",
|
||||
]
|
||||
Loading…
Add table
Add a link
Reference in a new issue