refactor(agents): split dedup_tool_calls; move HITL middleware to main_agent

DedupHITLToolCallsMiddleware is only wired by the main_agent stack, but
its module also exports dedup-key resolvers consumed by the shared MCP
tool layer. Splitting keeps the resolvers (dedup_key_full_args,
wrap_dedup_key_by_arg_name, DedupResolver) in shared and moves the
middleware class verbatim into main_agent/middleware/dedup_hitl.py
(merged with its builder), eliminating the shared->main_agent dependency
that a flat move would create. No behavior change.
This commit is contained in:
CREDO23 2026-06-05 11:17:44 +02:00
parent afa51e97cf
commit fbd5ccc35a
5 changed files with 127 additions and 111 deletions

View file

@ -17,9 +17,6 @@ from app.agents.shared.middleware.context_editing import (
SpillingContextEditingMiddleware,
SpillToBackendEdit,
)
from app.agents.shared.middleware.dedup_tool_calls import (
DedupHITLToolCallsMiddleware,
)
from app.agents.shared.middleware.doom_loop import DoomLoopMiddleware
from app.agents.shared.middleware.kb_persistence import (
KnowledgeBasePersistenceMiddleware,
@ -47,7 +44,6 @@ __all__ = [
"AnonymousDocumentMiddleware",
"BusyMutexMiddleware",
"ClearToolUsesEdit",
"DedupHITLToolCallsMiddleware",
"DoomLoopMiddleware",
"KnowledgeBasePersistenceMiddleware",
"KnowledgePriorityMiddleware",

View file

@ -1,15 +1,11 @@
"""Middleware that deduplicates HITL tool calls within a single LLM response.
"""Dedup-key resolvers for tool-call deduplication.
When the LLM emits multiple calls to the same HITL tool with the same
primary argument (e.g. two ``delete_calendar_event("Doctor Appointment")``),
only the first call is kept. Non-HITL tools are never touched.
A *resolver* maps a tool's ``args`` dict to a stable signature string used to
collapse duplicate calls. These helpers are shared: the MCP tool layer uses
:func:`dedup_key_full_args` as a safe default, and the main-agent
``DedupHITLToolCallsMiddleware`` builds its resolver map from them.
This runs in the ``after_model`` hook **before** any tool executes so
the duplicate call is stripped from the AIMessage that gets checkpointed.
That means it is also safe across LangGraph ``interrupt()`` boundaries:
the removed call will never appear on graph resume.
Dedup-key resolution order (read from each tool's own ``metadata``):
Resolver resolution order (read from each tool's own ``metadata``):
1. ``tool.metadata["dedup_key"]`` callable mapping the args dict to a
stable signature string. This is the canonical mechanism.
@ -22,15 +18,9 @@ A tool with no resolver from either path simply opts out of dedup.
from __future__ import annotations
import json
import logging
from collections.abc import Callable
from typing import Any
from langchain.agents.middleware import AgentMiddleware, AgentState
from langgraph.runtime import Runtime
logger = logging.getLogger(__name__)
# Resolver type — a callable that receives a tool call's ``args`` dict and
# returns a stable signature string; calls to the same tool that produce the
# same signature are treated as duplicates. A tool with no resolver
# registered opts out of dedup entirely.
DedupResolver = Callable[[dict[str, Any]], str]
@ -67,90 +57,3 @@ def dedup_key_full_args(args: dict[str, Any]) -> str:
# Backwards-compatible alias for code that imported the original
# private name. New callers should use :func:`wrap_dedup_key_by_arg_name`.
_wrap_string_key = wrap_dedup_key_by_arg_name
class DedupHITLToolCallsMiddleware(AgentMiddleware):  # type: ignore[type-arg]
    """Strip repeated tool calls from the most recent AI message.

    For every tool that has a dedup resolver, the first call producing a
    given ``(tool-name, signature)`` pair is kept and any later call with
    the same pair is discarded. Calls to tools without a resolver always
    pass through unchanged.

    Resolvers are collected per tool from ``tool.metadata``:

    1. ``metadata["dedup_key"]`` — when callable, it is used directly as
       the resolver (receives the args dict, returns the signature).
    2. Otherwise, when both ``metadata["hitl"]`` and
       ``metadata["hitl_dedup_key"]`` are truthy, the latter is passed to
       :func:`wrap_dedup_key_by_arg_name` to build the resolver.
    """

    tools = ()

    def __init__(self, *, agent_tools: list[Any] | None = None) -> None:
        """Build the tool-name -> resolver map from *agent_tools*.

        Args:
            agent_tools: Tool objects exposing ``name`` and, optionally,
                ``metadata``. ``None`` or an empty list yields an empty
                resolver map, i.e. nothing is ever deduplicated.
        """
        self._resolvers: dict[str, DedupResolver] = {}
        registry = self._resolvers
        for tool in agent_tools or ():
            metadata = getattr(tool, "metadata", None) or {}
            explicit = metadata.get("dedup_key")
            if callable(explicit):
                # An explicit callable always wins over the arg-name form.
                registry[tool.name] = explicit
            elif metadata.get("hitl") and metadata.get("hitl_dedup_key"):
                registry[tool.name] = wrap_dedup_key_by_arg_name(
                    metadata["hitl_dedup_key"]
                )

    def after_model(
        self, state: AgentState, runtime: Runtime[Any]
    ) -> dict[str, Any] | None:
        """Synchronous hook: deduplicate the last message's tool calls."""
        return self._dedup(state, self._resolvers)

    async def aafter_model(
        self, state: AgentState, runtime: Runtime[Any]
    ) -> dict[str, Any] | None:
        """Async hook: same work as :meth:`after_model`."""
        return self._dedup(state, self._resolvers)

    @staticmethod
    def _dedup(
        state: AgentState,
        resolvers: dict[str, DedupResolver],
    ) -> dict[str, Any] | None:
        """Return a state update with duplicates removed, or ``None``.

        Args:
            state: Agent state; only its ``"messages"`` entry is read.
            resolvers: Map from tool name to signature resolver.

        Returns:
            ``{"messages": [copy_of_last_message]}`` with the filtered
            ``tool_calls`` when at least one call was dropped; ``None``
            when there is nothing to inspect or nothing was dropped.
        """
        history = state.get("messages")
        if not history:
            return None

        latest = history[-1]
        if latest.type != "ai":
            return None
        if not getattr(latest, "tool_calls", None):
            return None

        calls: list[dict[str, Any]] = latest.tool_calls
        kept: list[dict[str, Any]] = []
        signatures: set[tuple[str, str]] = set()

        for call in calls:
            tool_name = call.get("name", "")
            resolve = resolvers.get(tool_name)
            if resolve is None:
                # No resolver registered: this tool is never deduplicated.
                kept.append(call)
                continue

            try:
                signature = resolve(call.get("args", {}) or {})
            except Exception:
                # A failing resolver must not cost us the call itself.
                logger.exception(
                    "Dedup resolver for tool %s raised; keeping call", tool_name
                )
                kept.append(call)
                continue

            fingerprint = (tool_name, signature)
            if fingerprint in signatures:
                logger.info(
                    "Dedup: dropped duplicate HITL tool call %s(%s)",
                    tool_name,
                    signature,
                )
                continue

            signatures.add(fingerprint)
            kept.append(call)

        if len(kept) == len(calls):
            return None

        return {"messages": [latest.model_copy(update={"tool_calls": kept})]}