mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-06 20:15:17 +02:00
refactor(agents): split dedup_tool_calls; move HITL middleware to main_agent
DedupHITLToolCallsMiddleware is only wired by the main_agent stack, but its module also exports dedup-key resolvers consumed by the shared MCP tool layer. Splitting keeps the resolvers (dedup_key_full_args, wrap_dedup_key_by_arg_name, DedupResolver) in shared and moves the middleware class verbatim into main_agent/middleware/dedup_hitl.py (merged with its builder), eliminating the shared->main_agent dependency that a flat move would create. No behavior change.
This commit is contained in:
parent
afa51e97cf
commit
fbd5ccc35a
5 changed files with 127 additions and 111 deletions
|
|
@ -1,12 +1,127 @@
|
|||
"""Drop duplicate HITL tool calls before execution."""
|
||||
"""Drop duplicate HITL tool calls before execution.
|
||||
|
||||
When the LLM emits multiple calls to the same HITL tool with the same
|
||||
primary argument (e.g. two ``delete_calendar_event("Doctor Appointment")``),
|
||||
only the first call is kept. Non-HITL tools are never touched.
|
||||
|
||||
This runs in the ``after_model`` hook — **before** any tool executes — so
|
||||
the duplicate call is stripped from the AIMessage that gets checkpointed.
|
||||
That means it is also safe across LangGraph ``interrupt()`` boundaries:
|
||||
the removed call will never appear on graph resume.
|
||||
|
||||
Dedup-key resolution order (read from each tool's own ``metadata``):
|
||||
|
||||
1. ``tool.metadata["dedup_key"]`` — callable mapping the args dict to a
|
||||
stable signature string. This is the canonical mechanism.
|
||||
2. ``tool.metadata["hitl_dedup_key"]`` — string naming a primary arg;
|
||||
used by MCP / Composio tools that only expose a single key field.
|
||||
|
||||
A tool with no resolver from either path simply opts out of dedup.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||
from langchain_core.tools import BaseTool
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from app.agents.shared.middleware import DedupHITLToolCallsMiddleware
|
||||
from app.agents.shared.middleware.dedup_tool_calls import (
|
||||
DedupResolver,
|
||||
wrap_dedup_key_by_arg_name,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DedupHITLToolCallsMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||
"""Remove duplicate HITL tool calls from a single LLM response.
|
||||
|
||||
Only the **first** occurrence of each ``(tool-name, dedup_key)``
|
||||
pair is kept; subsequent duplicates are silently dropped.
|
||||
|
||||
The dedup-resolver map is built from two sources, in priority order:
|
||||
|
||||
1. ``tool.metadata["dedup_key"]`` — callable that receives the args dict
|
||||
and returns a string signature. This is the canonical mechanism.
|
||||
2. ``tool.metadata["hitl_dedup_key"]`` — string with a primary arg
|
||||
name; primarily used by MCP / Composio tools.
|
||||
"""
|
||||
|
||||
tools = ()
|
||||
|
||||
def __init__(self, *, agent_tools: list[Any] | None = None) -> None:
|
||||
self._resolvers: dict[str, DedupResolver] = {}
|
||||
|
||||
for t in agent_tools or []:
|
||||
meta = getattr(t, "metadata", None) or {}
|
||||
callable_key = meta.get("dedup_key")
|
||||
if callable(callable_key):
|
||||
self._resolvers[t.name] = callable_key
|
||||
continue
|
||||
if meta.get("hitl") and meta.get("hitl_dedup_key"):
|
||||
self._resolvers[t.name] = wrap_dedup_key_by_arg_name(
|
||||
meta["hitl_dedup_key"]
|
||||
)
|
||||
|
||||
def after_model(
|
||||
self, state: AgentState, runtime: Runtime[Any]
|
||||
) -> dict[str, Any] | None:
|
||||
return self._dedup(state, self._resolvers)
|
||||
|
||||
async def aafter_model(
|
||||
self, state: AgentState, runtime: Runtime[Any]
|
||||
) -> dict[str, Any] | None:
|
||||
return self._dedup(state, self._resolvers)
|
||||
|
||||
@staticmethod
|
||||
def _dedup(
|
||||
state: AgentState,
|
||||
resolvers: dict[str, DedupResolver],
|
||||
) -> dict[str, Any] | None:
|
||||
messages = state.get("messages")
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
last_msg = messages[-1]
|
||||
if last_msg.type != "ai" or not getattr(last_msg, "tool_calls", None):
|
||||
return None
|
||||
|
||||
tool_calls: list[dict[str, Any]] = last_msg.tool_calls
|
||||
seen: set[tuple[str, str]] = set()
|
||||
deduped: list[dict[str, Any]] = []
|
||||
|
||||
for tc in tool_calls:
|
||||
name = tc.get("name", "")
|
||||
resolver = resolvers.get(name)
|
||||
if resolver is not None:
|
||||
try:
|
||||
arg_val = resolver(tc.get("args", {}) or {})
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Dedup resolver for tool %s raised; keeping call", name
|
||||
)
|
||||
deduped.append(tc)
|
||||
continue
|
||||
key = (name, arg_val)
|
||||
if key in seen:
|
||||
logger.info(
|
||||
"Dedup: dropped duplicate HITL tool call %s(%s)",
|
||||
name,
|
||||
arg_val,
|
||||
)
|
||||
continue
|
||||
seen.add(key)
|
||||
deduped.append(tc)
|
||||
|
||||
if len(deduped) == len(tool_calls):
|
||||
return None
|
||||
|
||||
updated_msg = last_msg.model_copy(update={"tool_calls": deduped})
|
||||
return {"messages": [updated_msg]}
|
||||
|
||||
|
||||
def build_dedup_hitl_mw(tools: Sequence[BaseTool]) -> DedupHITLToolCallsMiddleware:
|
||||
|
|
|
|||
|
|
@ -17,9 +17,6 @@ from app.agents.shared.middleware.context_editing import (
|
|||
SpillingContextEditingMiddleware,
|
||||
SpillToBackendEdit,
|
||||
)
|
||||
from app.agents.shared.middleware.dedup_tool_calls import (
|
||||
DedupHITLToolCallsMiddleware,
|
||||
)
|
||||
from app.agents.shared.middleware.doom_loop import DoomLoopMiddleware
|
||||
from app.agents.shared.middleware.kb_persistence import (
|
||||
KnowledgeBasePersistenceMiddleware,
|
||||
|
|
@ -47,7 +44,6 @@ __all__ = [
|
|||
"AnonymousDocumentMiddleware",
|
||||
"BusyMutexMiddleware",
|
||||
"ClearToolUsesEdit",
|
||||
"DedupHITLToolCallsMiddleware",
|
||||
"DoomLoopMiddleware",
|
||||
"KnowledgeBasePersistenceMiddleware",
|
||||
"KnowledgePriorityMiddleware",
|
||||
|
|
|
|||
|
|
@ -1,15 +1,11 @@
|
|||
"""Middleware that deduplicates HITL tool calls within a single LLM response.
|
||||
"""Dedup-key resolvers for tool-call deduplication.
|
||||
|
||||
When the LLM emits multiple calls to the same HITL tool with the same
|
||||
primary argument (e.g. two ``delete_calendar_event("Doctor Appointment")``),
|
||||
only the first call is kept. Non-HITL tools are never touched.
|
||||
A *resolver* maps a tool's ``args`` dict to a stable signature string used to
|
||||
collapse duplicate calls. These helpers are shared: the MCP tool layer uses
|
||||
:func:`dedup_key_full_args` as a safe default, and the main-agent
|
||||
``DedupHITLToolCallsMiddleware`` builds its resolver map from them.
|
||||
|
||||
This runs in the ``after_model`` hook — **before** any tool executes — so
|
||||
the duplicate call is stripped from the AIMessage that gets checkpointed.
|
||||
That means it is also safe across LangGraph ``interrupt()`` boundaries:
|
||||
the removed call will never appear on graph resume.
|
||||
|
||||
Dedup-key resolution order (read from each tool's own ``metadata``):
|
||||
Resolver resolution order (read from each tool's own ``metadata``):
|
||||
|
||||
1. ``tool.metadata["dedup_key"]`` — callable mapping the args dict to a
|
||||
stable signature string. This is the canonical mechanism.
|
||||
|
|
@ -22,15 +18,9 @@ A tool with no resolver from either path simply opts out of dedup.
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Resolver type — given the tool ``args`` dict returns a stable
|
||||
# string used to dedupe consecutive calls. ``None`` means no dedup.
|
||||
DedupResolver = Callable[[dict[str, Any]], str]
|
||||
|
|
@ -67,90 +57,3 @@ def dedup_key_full_args(args: dict[str, Any]) -> str:
|
|||
# Backwards-compatible alias for code that imported the original
|
||||
# private name. New callers should use :func:`wrap_dedup_key_by_arg_name`.
|
||||
_wrap_string_key = wrap_dedup_key_by_arg_name
|
||||
|
||||
|
||||
class DedupHITLToolCallsMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||
"""Remove duplicate HITL tool calls from a single LLM response.
|
||||
|
||||
Only the **first** occurrence of each ``(tool-name, dedup_key)``
|
||||
pair is kept; subsequent duplicates are silently dropped.
|
||||
|
||||
The dedup-resolver map is built from two sources, in priority order:
|
||||
|
||||
1. ``tool.metadata["dedup_key"]`` — callable that receives the args dict
|
||||
and returns a string signature. This is the canonical mechanism.
|
||||
2. ``tool.metadata["hitl_dedup_key"]`` — string with a primary arg
|
||||
name; primarily used by MCP / Composio tools.
|
||||
"""
|
||||
|
||||
tools = ()
|
||||
|
||||
def __init__(self, *, agent_tools: list[Any] | None = None) -> None:
|
||||
self._resolvers: dict[str, DedupResolver] = {}
|
||||
|
||||
for t in agent_tools or []:
|
||||
meta = getattr(t, "metadata", None) or {}
|
||||
callable_key = meta.get("dedup_key")
|
||||
if callable(callable_key):
|
||||
self._resolvers[t.name] = callable_key
|
||||
continue
|
||||
if meta.get("hitl") and meta.get("hitl_dedup_key"):
|
||||
self._resolvers[t.name] = wrap_dedup_key_by_arg_name(
|
||||
meta["hitl_dedup_key"]
|
||||
)
|
||||
|
||||
def after_model(
|
||||
self, state: AgentState, runtime: Runtime[Any]
|
||||
) -> dict[str, Any] | None:
|
||||
return self._dedup(state, self._resolvers)
|
||||
|
||||
async def aafter_model(
|
||||
self, state: AgentState, runtime: Runtime[Any]
|
||||
) -> dict[str, Any] | None:
|
||||
return self._dedup(state, self._resolvers)
|
||||
|
||||
@staticmethod
|
||||
def _dedup(
|
||||
state: AgentState,
|
||||
resolvers: dict[str, DedupResolver],
|
||||
) -> dict[str, Any] | None:
|
||||
messages = state.get("messages")
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
last_msg = messages[-1]
|
||||
if last_msg.type != "ai" or not getattr(last_msg, "tool_calls", None):
|
||||
return None
|
||||
|
||||
tool_calls: list[dict[str, Any]] = last_msg.tool_calls
|
||||
seen: set[tuple[str, str]] = set()
|
||||
deduped: list[dict[str, Any]] = []
|
||||
|
||||
for tc in tool_calls:
|
||||
name = tc.get("name", "")
|
||||
resolver = resolvers.get(name)
|
||||
if resolver is not None:
|
||||
try:
|
||||
arg_val = resolver(tc.get("args", {}) or {})
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Dedup resolver for tool %s raised; keeping call", name
|
||||
)
|
||||
deduped.append(tc)
|
||||
continue
|
||||
key = (name, arg_val)
|
||||
if key in seen:
|
||||
logger.info(
|
||||
"Dedup: dropped duplicate HITL tool call %s(%s)",
|
||||
name,
|
||||
arg_val,
|
||||
)
|
||||
continue
|
||||
seen.add(key)
|
||||
deduped.append(tc)
|
||||
|
||||
if len(deduped) == len(tool_calls):
|
||||
return None
|
||||
|
||||
updated_msg = last_msg.model_copy(update={"tool_calls": deduped})
|
||||
return {"messages": [updated_msg]}
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import pytest
|
|||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.tools import StructuredTool
|
||||
|
||||
from app.agents.shared.middleware.dedup_tool_calls import (
|
||||
from app.agents.multi_agent_chat.main_agent.middleware.dedup_hitl import (
|
||||
DedupHITLToolCallsMiddleware,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -2,8 +2,10 @@ import pytest
|
|||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.tools import StructuredTool
|
||||
|
||||
from app.agents.shared.middleware.dedup_tool_calls import (
|
||||
from app.agents.multi_agent_chat.main_agent.middleware.dedup_hitl import (
|
||||
DedupHITLToolCallsMiddleware,
|
||||
)
|
||||
from app.agents.shared.middleware.dedup_tool_calls import (
|
||||
wrap_dedup_key_by_arg_name,
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue