From fbd5ccc35aa6b2f538fb8c55c49f508c54f63c53 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 5 Jun 2026 11:17:44 +0200 Subject: [PATCH] refactor(agents): split dedup_tool_calls; move HITL middleware to main_agent DedupHITLToolCallsMiddleware is only wired by the main_agent stack, but its module also exports dedup-key resolvers consumed by the shared MCP tool layer. Splitting keeps the resolvers (dedup_key_full_args, wrap_dedup_key_by_arg_name, DedupResolver) in shared and moves the middleware class verbatim into main_agent/middleware/dedup_hitl.py (merged with its builder), eliminating the shared->main_agent dependency that a flat move would create. No behavior change. --- .../main_agent/middleware/dedup_hitl.py | 119 +++++++++++++++++- .../app/agents/shared/middleware/__init__.py | 4 - .../shared/middleware/dedup_tool_calls.py | 109 +--------------- .../agents/new_chat/test_dedup_tool_calls.py | 2 +- .../middleware/test_dedup_hitl_tool_calls.py | 4 +- 5 files changed, 127 insertions(+), 111 deletions(-) diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/middleware/dedup_hitl.py b/surfsense_backend/app/agents/multi_agent_chat/main_agent/middleware/dedup_hitl.py index f5536bca9..61af45a22 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/middleware/dedup_hitl.py +++ b/surfsense_backend/app/agents/multi_agent_chat/main_agent/middleware/dedup_hitl.py @@ -1,12 +1,127 @@ -"""Drop duplicate HITL tool calls before execution.""" +"""Drop duplicate HITL tool calls before execution. + +When the LLM emits multiple calls to the same HITL tool with the same +primary argument (e.g. two ``delete_calendar_event("Doctor Appointment")``), +only the first call is kept. Non-HITL tools are never touched. + +This runs in the ``after_model`` hook — **before** any tool executes — so +the duplicate call is stripped from the AIMessage that gets checkpointed. +That means it is also safe across LangGraph ``interrupt()`` boundaries: +the removed call will never appear on graph resume. + +Dedup-key resolution order (read from each tool's own ``metadata``): + +1. ``tool.metadata["dedup_key"]`` — callable mapping the args dict to a + stable signature string. This is the canonical mechanism. +2. ``tool.metadata["hitl_dedup_key"]`` — string naming a primary arg; + used by MCP / Composio tools that only expose a single key field. + +A tool with no resolver from either path simply opts out of dedup. +""" from __future__ import annotations +import logging from collections.abc import Sequence +from typing import Any +from langchain.agents.middleware import AgentMiddleware, AgentState from langchain_core.tools import BaseTool +from langgraph.runtime import Runtime -from app.agents.shared.middleware import DedupHITLToolCallsMiddleware +from app.agents.shared.middleware.dedup_tool_calls import ( + DedupResolver, + wrap_dedup_key_by_arg_name, +) + +logger = logging.getLogger(__name__) + + +class DedupHITLToolCallsMiddleware(AgentMiddleware): # type: ignore[type-arg] + """Remove duplicate HITL tool calls from a single LLM response. + + Only the **first** occurrence of each ``(tool-name, dedup_key)`` + pair is kept; subsequent duplicates are silently dropped. + + The dedup-resolver map is built from two sources, in priority order: + + 1. ``tool.metadata["dedup_key"]`` — callable that receives the args dict + and returns a string signature. This is the canonical mechanism. + 2. ``tool.metadata["hitl_dedup_key"]`` — string with a primary arg + name; primarily used by MCP / Composio tools. + """ + + tools = () + + def __init__(self, *, agent_tools: list[Any] | None = None) -> None: + self._resolvers: dict[str, DedupResolver] = {} + + for t in agent_tools or []: + meta = getattr(t, "metadata", None) or {} + callable_key = meta.get("dedup_key") + if callable(callable_key): + self._resolvers[t.name] = callable_key + continue + if meta.get("hitl") and meta.get("hitl_dedup_key"): + self._resolvers[t.name] = wrap_dedup_key_by_arg_name( + meta["hitl_dedup_key"] + ) + + def after_model( + self, state: AgentState, runtime: Runtime[Any] + ) -> dict[str, Any] | None: + return self._dedup(state, self._resolvers) + + async def aafter_model( + self, state: AgentState, runtime: Runtime[Any] + ) -> dict[str, Any] | None: + return self._dedup(state, self._resolvers) + + @staticmethod + def _dedup( + state: AgentState, + resolvers: dict[str, DedupResolver], + ) -> dict[str, Any] | None: + messages = state.get("messages") + if not messages: + return None + + last_msg = messages[-1] + if last_msg.type != "ai" or not getattr(last_msg, "tool_calls", None): + return None + + tool_calls: list[dict[str, Any]] = last_msg.tool_calls + seen: set[tuple[str, str]] = set() + deduped: list[dict[str, Any]] = [] + + for tc in tool_calls: + name = tc.get("name", "") + resolver = resolvers.get(name) + if resolver is not None: + try: + arg_val = resolver(tc.get("args", {}) or {}) + except Exception: + logger.exception( + "Dedup resolver for tool %s raised; keeping call", name + ) + deduped.append(tc) + continue + key = (name, arg_val) + if key in seen: + logger.info( + "Dedup: dropped duplicate HITL tool call %s(%s)", + name, + arg_val, + ) + continue + seen.add(key) + deduped.append(tc) + + if len(deduped) == len(tool_calls): + return None + + updated_msg = last_msg.model_copy(update={"tool_calls": deduped}) + return {"messages": [updated_msg]} def build_dedup_hitl_mw(tools: Sequence[BaseTool]) -> DedupHITLToolCallsMiddleware: diff --git a/surfsense_backend/app/agents/shared/middleware/__init__.py b/surfsense_backend/app/agents/shared/middleware/__init__.py index 7aaeb2713..e9652325c 100644 --- a/surfsense_backend/app/agents/shared/middleware/__init__.py +++ b/surfsense_backend/app/agents/shared/middleware/__init__.py @@ -17,9 +17,6 @@ from app.agents.shared.middleware.context_editing import ( SpillingContextEditingMiddleware, SpillToBackendEdit, ) -from app.agents.shared.middleware.dedup_tool_calls import ( - DedupHITLToolCallsMiddleware, -) from app.agents.shared.middleware.doom_loop import DoomLoopMiddleware from app.agents.shared.middleware.kb_persistence import ( KnowledgeBasePersistenceMiddleware, @@ -47,7 +44,6 @@ __all__ = [ "AnonymousDocumentMiddleware", "BusyMutexMiddleware", "ClearToolUsesEdit", - "DedupHITLToolCallsMiddleware", "DoomLoopMiddleware", "KnowledgeBasePersistenceMiddleware", "KnowledgePriorityMiddleware", diff --git a/surfsense_backend/app/agents/shared/middleware/dedup_tool_calls.py b/surfsense_backend/app/agents/shared/middleware/dedup_tool_calls.py index 69b107dbe..087a69ae6 100644 --- a/surfsense_backend/app/agents/shared/middleware/dedup_tool_calls.py +++ b/surfsense_backend/app/agents/shared/middleware/dedup_tool_calls.py @@ -1,15 +1,11 @@ -"""Middleware that deduplicates HITL tool calls within a single LLM response. +"""Dedup-key resolvers for tool-call deduplication. -When the LLM emits multiple calls to the same HITL tool with the same -primary argument (e.g. two ``delete_calendar_event("Doctor Appointment")``), -only the first call is kept. Non-HITL tools are never touched. +A *resolver* maps a tool's ``args`` dict to a stable signature string used to +collapse duplicate calls. These helpers are shared: the MCP tool layer uses +:func:`dedup_key_full_args` as a safe default, and the main-agent +``DedupHITLToolCallsMiddleware`` builds its resolver map from them. -This runs in the ``after_model`` hook — **before** any tool executes — so -the duplicate call is stripped from the AIMessage that gets checkpointed. -That means it is also safe across LangGraph ``interrupt()`` boundaries: -the removed call will never appear on graph resume. - -Dedup-key resolution order (read from each tool's own ``metadata``): +Resolver resolution order (read from each tool's own ``metadata``): 1. ``tool.metadata["dedup_key"]`` — callable mapping the args dict to a stable signature string. This is the canonical mechanism. @@ -22,15 +18,9 @@ A tool with no resolver from either path simply opts out of dedup. from __future__ import annotations import json -import logging from collections.abc import Callable from typing import Any -from langchain.agents.middleware import AgentMiddleware, AgentState -from langgraph.runtime import Runtime - -logger = logging.getLogger(__name__) - # Resolver type — given the tool ``args`` dict returns a stable # string used to dedupe consecutive calls. ``None`` means no dedup. DedupResolver = Callable[[dict[str, Any]], str] @@ -67,90 +57,3 @@ def dedup_key_full_args(args: dict[str, Any]) -> str: # Backwards-compatible alias for code that imported the original # private name. New callers should use :func:`wrap_dedup_key_by_arg_name`. _wrap_string_key = wrap_dedup_key_by_arg_name - - -class DedupHITLToolCallsMiddleware(AgentMiddleware): # type: ignore[type-arg] - """Remove duplicate HITL tool calls from a single LLM response. - - Only the **first** occurrence of each ``(tool-name, dedup_key)`` - pair is kept; subsequent duplicates are silently dropped. - - The dedup-resolver map is built from two sources, in priority order: - - 1. ``tool.metadata["dedup_key"]`` — callable that receives the args dict - and returns a string signature. This is the canonical mechanism. - 2. ``tool.metadata["hitl_dedup_key"]`` — string with a primary arg - name; primarily used by MCP / Composio tools. - """ - - tools = () - - def __init__(self, *, agent_tools: list[Any] | None = None) -> None: - self._resolvers: dict[str, DedupResolver] = {} - - for t in agent_tools or []: - meta = getattr(t, "metadata", None) or {} - callable_key = meta.get("dedup_key") - if callable(callable_key): - self._resolvers[t.name] = callable_key - continue - if meta.get("hitl") and meta.get("hitl_dedup_key"): - self._resolvers[t.name] = wrap_dedup_key_by_arg_name( - meta["hitl_dedup_key"] - ) - - def after_model( - self, state: AgentState, runtime: Runtime[Any] - ) -> dict[str, Any] | None: - return self._dedup(state, self._resolvers) - - async def aafter_model( - self, state: AgentState, runtime: Runtime[Any] - ) -> dict[str, Any] | None: - return self._dedup(state, self._resolvers) - - @staticmethod - def _dedup( - state: AgentState, - resolvers: dict[str, DedupResolver], - ) -> dict[str, Any] | None: - messages = state.get("messages") - if not messages: - return None - - last_msg = messages[-1] - if last_msg.type != "ai" or not getattr(last_msg, "tool_calls", None): - return None - - tool_calls: list[dict[str, Any]] = last_msg.tool_calls - seen: set[tuple[str, str]] = set() - deduped: list[dict[str, Any]] = [] - - for tc in tool_calls: - name = tc.get("name", "") - resolver = resolvers.get(name) - if resolver is not None: - try: - arg_val = resolver(tc.get("args", {}) or {}) - except Exception: - logger.exception( - "Dedup resolver for tool %s raised; keeping call", name - ) - deduped.append(tc) - continue - key = (name, arg_val) - if key in seen: - logger.info( - "Dedup: dropped duplicate HITL tool call %s(%s)", - name, - arg_val, - ) - continue - seen.add(key) - deduped.append(tc) - - if len(deduped) == len(tool_calls): - return None - - updated_msg = last_msg.model_copy(update={"tool_calls": deduped}) - return {"messages": [updated_msg]} diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_dedup_tool_calls.py b/surfsense_backend/tests/unit/agents/new_chat/test_dedup_tool_calls.py index 6996a717f..c64ebc630 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_dedup_tool_calls.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_dedup_tool_calls.py @@ -6,7 +6,7 @@ import pytest from langchain_core.messages import AIMessage from langchain_core.tools import StructuredTool -from app.agents.shared.middleware.dedup_tool_calls import ( +from app.agents.multi_agent_chat.main_agent.middleware.dedup_hitl import ( DedupHITLToolCallsMiddleware, ) diff --git a/surfsense_backend/tests/unit/middleware/test_dedup_hitl_tool_calls.py b/surfsense_backend/tests/unit/middleware/test_dedup_hitl_tool_calls.py index aa4bab204..4646a9590 100644 --- a/surfsense_backend/tests/unit/middleware/test_dedup_hitl_tool_calls.py +++ b/surfsense_backend/tests/unit/middleware/test_dedup_hitl_tool_calls.py @@ -2,8 +2,10 @@ import pytest from langchain_core.messages import AIMessage from langchain_core.tools import StructuredTool -from app.agents.shared.middleware.dedup_tool_calls import ( +from app.agents.multi_agent_chat.main_agent.middleware.dedup_hitl import ( DedupHITLToolCallsMiddleware, +) +from app.agents.shared.middleware.dedup_tool_calls import ( wrap_dedup_key_by_arg_name, )