refactor(agents): split dedup_tool_calls; move HITL middleware to main_agent

DedupHITLToolCallsMiddleware is only wired by the main_agent stack, but
its module also exports dedup-key resolvers consumed by the shared MCP
tool layer. Splitting keeps the resolvers (dedup_key_full_args,
wrap_dedup_key_by_arg_name, DedupResolver) in shared and moves the
middleware class verbatim into main_agent/middleware/dedup_hitl.py
(merged with its builder), eliminating the shared->main_agent dependency
that a flat move would create. No behavior change.
This commit is contained in:
CREDO23 2026-06-05 11:17:44 +02:00
parent afa51e97cf
commit fbd5ccc35a
5 changed files with 127 additions and 111 deletions

View file

@ -1,12 +1,127 @@
"""Drop duplicate HITL tool calls before execution."""
"""Drop duplicate HITL tool calls before execution.
When the LLM emits multiple calls to the same HITL tool with the same
primary argument (e.g. two ``delete_calendar_event("Doctor Appointment")``),
only the first call is kept. Non-HITL tools are never touched.
This runs in the ``after_model`` hook **before** any tool executes so
the duplicate call is stripped from the AIMessage that gets checkpointed.
That means it is also safe across LangGraph ``interrupt()`` boundaries:
the removed call will never appear on graph resume.
Dedup-key resolution order (read from each tool's own ``metadata``):
1. ``tool.metadata["dedup_key"]`` callable mapping the args dict to a
stable signature string. This is the canonical mechanism.
2. ``tool.metadata["hitl_dedup_key"]`` string naming a primary arg;
used by MCP / Composio tools that only expose a single key field.
A tool with no resolver from either path simply opts out of dedup.
"""
from __future__ import annotations
import logging
from collections.abc import Sequence
from typing import Any
from langchain.agents.middleware import AgentMiddleware, AgentState
from langchain_core.tools import BaseTool
from langgraph.runtime import Runtime
from app.agents.shared.middleware import DedupHITLToolCallsMiddleware
from app.agents.shared.middleware.dedup_tool_calls import (
DedupResolver,
wrap_dedup_key_by_arg_name,
)
logger = logging.getLogger(__name__)
class DedupHITLToolCallsMiddleware(AgentMiddleware): # type: ignore[type-arg]
"""Remove duplicate HITL tool calls from a single LLM response.
Only the **first** occurrence of each ``(tool-name, dedup_key)``
pair is kept; subsequent duplicates are silently dropped.
The dedup-resolver map is built from two sources, in priority order:
1. ``tool.metadata["dedup_key"]`` callable that receives the args dict
and returns a string signature. This is the canonical mechanism.
2. ``tool.metadata["hitl_dedup_key"]`` string with a primary arg
name; primarily used by MCP / Composio tools.
"""
tools = ()
def __init__(self, *, agent_tools: list[Any] | None = None) -> None:
self._resolvers: dict[str, DedupResolver] = {}
for t in agent_tools or []:
meta = getattr(t, "metadata", None) or {}
callable_key = meta.get("dedup_key")
if callable(callable_key):
self._resolvers[t.name] = callable_key
continue
if meta.get("hitl") and meta.get("hitl_dedup_key"):
self._resolvers[t.name] = wrap_dedup_key_by_arg_name(
meta["hitl_dedup_key"]
)
def after_model(
self, state: AgentState, runtime: Runtime[Any]
) -> dict[str, Any] | None:
return self._dedup(state, self._resolvers)
async def aafter_model(
self, state: AgentState, runtime: Runtime[Any]
) -> dict[str, Any] | None:
return self._dedup(state, self._resolvers)
@staticmethod
def _dedup(
state: AgentState,
resolvers: dict[str, DedupResolver],
) -> dict[str, Any] | None:
messages = state.get("messages")
if not messages:
return None
last_msg = messages[-1]
if last_msg.type != "ai" or not getattr(last_msg, "tool_calls", None):
return None
tool_calls: list[dict[str, Any]] = last_msg.tool_calls
seen: set[tuple[str, str]] = set()
deduped: list[dict[str, Any]] = []
for tc in tool_calls:
name = tc.get("name", "")
resolver = resolvers.get(name)
if resolver is not None:
try:
arg_val = resolver(tc.get("args", {}) or {})
except Exception:
logger.exception(
"Dedup resolver for tool %s raised; keeping call", name
)
deduped.append(tc)
continue
key = (name, arg_val)
if key in seen:
logger.info(
"Dedup: dropped duplicate HITL tool call %s(%s)",
name,
arg_val,
)
continue
seen.add(key)
deduped.append(tc)
if len(deduped) == len(tool_calls):
return None
updated_msg = last_msg.model_copy(update={"tool_calls": deduped})
return {"messages": [updated_msg]}
def build_dedup_hitl_mw(tools: Sequence[BaseTool]) -> DedupHITLToolCallsMiddleware:

View file

@ -17,9 +17,6 @@ from app.agents.shared.middleware.context_editing import (
SpillingContextEditingMiddleware,
SpillToBackendEdit,
)
from app.agents.shared.middleware.dedup_tool_calls import (
DedupHITLToolCallsMiddleware,
)
from app.agents.shared.middleware.doom_loop import DoomLoopMiddleware
from app.agents.shared.middleware.kb_persistence import (
KnowledgeBasePersistenceMiddleware,
@ -47,7 +44,6 @@ __all__ = [
"AnonymousDocumentMiddleware",
"BusyMutexMiddleware",
"ClearToolUsesEdit",
"DedupHITLToolCallsMiddleware",
"DoomLoopMiddleware",
"KnowledgeBasePersistenceMiddleware",
"KnowledgePriorityMiddleware",

View file

@ -1,15 +1,11 @@
"""Middleware that deduplicates HITL tool calls within a single LLM response.
"""Dedup-key resolvers for tool-call deduplication.
When the LLM emits multiple calls to the same HITL tool with the same
primary argument (e.g. two ``delete_calendar_event("Doctor Appointment")``),
only the first call is kept. Non-HITL tools are never touched.
A *resolver* maps a tool's ``args`` dict to a stable signature string used to
collapse duplicate calls. These helpers are shared: the MCP tool layer uses
:func:`dedup_key_full_args` as a safe default, and the main-agent
``DedupHITLToolCallsMiddleware`` builds its resolver map from them.
This runs in the ``after_model`` hook **before** any tool executes so
the duplicate call is stripped from the AIMessage that gets checkpointed.
That means it is also safe across LangGraph ``interrupt()`` boundaries:
the removed call will never appear on graph resume.
Dedup-key resolution order (read from each tool's own ``metadata``):
Resolver resolution order (read from each tool's own ``metadata``):
1. ``tool.metadata["dedup_key"]`` callable mapping the args dict to a
stable signature string. This is the canonical mechanism.
@ -22,15 +18,9 @@ A tool with no resolver from either path simply opts out of dedup.
from __future__ import annotations
import json
import logging
from collections.abc import Callable
from typing import Any
from langchain.agents.middleware import AgentMiddleware, AgentState
from langgraph.runtime import Runtime
logger = logging.getLogger(__name__)
# Resolver type — given the tool ``args`` dict returns a stable
# string used to dedupe consecutive calls. ``None`` means no dedup.
DedupResolver = Callable[[dict[str, Any]], str]
@ -67,90 +57,3 @@ def dedup_key_full_args(args: dict[str, Any]) -> str:
# Backwards-compatible alias for code that imported the original
# private name. New callers should use :func:`wrap_dedup_key_by_arg_name`.
_wrap_string_key = wrap_dedup_key_by_arg_name
class DedupHITLToolCallsMiddleware(AgentMiddleware): # type: ignore[type-arg]
"""Remove duplicate HITL tool calls from a single LLM response.
Only the **first** occurrence of each ``(tool-name, dedup_key)``
pair is kept; subsequent duplicates are silently dropped.
The dedup-resolver map is built from two sources, in priority order:
1. ``tool.metadata["dedup_key"]`` callable that receives the args dict
and returns a string signature. This is the canonical mechanism.
2. ``tool.metadata["hitl_dedup_key"]`` string with a primary arg
name; primarily used by MCP / Composio tools.
"""
tools = ()
def __init__(self, *, agent_tools: list[Any] | None = None) -> None:
self._resolvers: dict[str, DedupResolver] = {}
for t in agent_tools or []:
meta = getattr(t, "metadata", None) or {}
callable_key = meta.get("dedup_key")
if callable(callable_key):
self._resolvers[t.name] = callable_key
continue
if meta.get("hitl") and meta.get("hitl_dedup_key"):
self._resolvers[t.name] = wrap_dedup_key_by_arg_name(
meta["hitl_dedup_key"]
)
def after_model(
self, state: AgentState, runtime: Runtime[Any]
) -> dict[str, Any] | None:
return self._dedup(state, self._resolvers)
async def aafter_model(
self, state: AgentState, runtime: Runtime[Any]
) -> dict[str, Any] | None:
return self._dedup(state, self._resolvers)
@staticmethod
def _dedup(
state: AgentState,
resolvers: dict[str, DedupResolver],
) -> dict[str, Any] | None:
messages = state.get("messages")
if not messages:
return None
last_msg = messages[-1]
if last_msg.type != "ai" or not getattr(last_msg, "tool_calls", None):
return None
tool_calls: list[dict[str, Any]] = last_msg.tool_calls
seen: set[tuple[str, str]] = set()
deduped: list[dict[str, Any]] = []
for tc in tool_calls:
name = tc.get("name", "")
resolver = resolvers.get(name)
if resolver is not None:
try:
arg_val = resolver(tc.get("args", {}) or {})
except Exception:
logger.exception(
"Dedup resolver for tool %s raised; keeping call", name
)
deduped.append(tc)
continue
key = (name, arg_val)
if key in seen:
logger.info(
"Dedup: dropped duplicate HITL tool call %s(%s)",
name,
arg_val,
)
continue
seen.add(key)
deduped.append(tc)
if len(deduped) == len(tool_calls):
return None
updated_msg = last_msg.model_copy(update={"tool_calls": deduped})
return {"messages": [updated_msg]}

View file

@ -6,7 +6,7 @@ import pytest
from langchain_core.messages import AIMessage
from langchain_core.tools import StructuredTool
from app.agents.shared.middleware.dedup_tool_calls import (
from app.agents.multi_agent_chat.main_agent.middleware.dedup_hitl import (
DedupHITLToolCallsMiddleware,
)

View file

@ -2,8 +2,10 @@ import pytest
from langchain_core.messages import AIMessage
from langchain_core.tools import StructuredTool
from app.agents.shared.middleware.dedup_tool_calls import (
from app.agents.multi_agent_chat.main_agent.middleware.dedup_hitl import (
DedupHITLToolCallsMiddleware,
)
from app.agents.shared.middleware.dedup_tool_calls import (
wrap_dedup_key_by_arg_name,
)