mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-06 20:15:17 +02:00
refactor(agents): split dedup_tool_calls; move HITL middleware to main_agent
DedupHITLToolCallsMiddleware is only wired by the main_agent stack, but its module also exports dedup-key resolvers consumed by the shared MCP tool layer. Splitting keeps the resolvers (dedup_key_full_args, wrap_dedup_key_by_arg_name, DedupResolver) in shared and moves the middleware class verbatim into main_agent/middleware/dedup_hitl.py (merged with its builder), eliminating the shared->main_agent dependency that a flat move would create. No behavior change.
This commit is contained in:
parent
afa51e97cf
commit
fbd5ccc35a
5 changed files with 127 additions and 111 deletions
|
|
@ -1,12 +1,127 @@
|
||||||
"""Drop duplicate HITL tool calls before execution."""
|
"""Drop duplicate HITL tool calls before execution.
|
||||||
|
|
||||||
|
When the LLM emits multiple calls to the same HITL tool with the same
|
||||||
|
primary argument (e.g. two ``delete_calendar_event("Doctor Appointment")``),
|
||||||
|
only the first call is kept. Non-HITL tools are never touched.
|
||||||
|
|
||||||
|
This runs in the ``after_model`` hook — **before** any tool executes — so
|
||||||
|
the duplicate call is stripped from the AIMessage that gets checkpointed.
|
||||||
|
That means it is also safe across LangGraph ``interrupt()`` boundaries:
|
||||||
|
the removed call will never appear on graph resume.
|
||||||
|
|
||||||
|
Dedup-key resolution order (read from each tool's own ``metadata``):
|
||||||
|
|
||||||
|
1. ``tool.metadata["dedup_key"]`` — callable mapping the args dict to a
|
||||||
|
stable signature string. This is the canonical mechanism.
|
||||||
|
2. ``tool.metadata["hitl_dedup_key"]`` — string naming a primary arg;
|
||||||
|
used by MCP / Composio tools that only expose a single key field.
|
||||||
|
|
||||||
|
A tool with no resolver from either path simply opts out of dedup.
|
||||||
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
|
from langgraph.runtime import Runtime
|
||||||
|
|
||||||
from app.agents.shared.middleware import DedupHITLToolCallsMiddleware
|
from app.agents.shared.middleware.dedup_tool_calls import (
|
||||||
|
DedupResolver,
|
||||||
|
wrap_dedup_key_by_arg_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DedupHITLToolCallsMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
|
"""Remove duplicate HITL tool calls from a single LLM response.
|
||||||
|
|
||||||
|
Only the **first** occurrence of each ``(tool-name, dedup_key)``
|
||||||
|
pair is kept; subsequent duplicates are silently dropped.
|
||||||
|
|
||||||
|
The dedup-resolver map is built from two sources, in priority order:
|
||||||
|
|
||||||
|
1. ``tool.metadata["dedup_key"]`` — callable that receives the args dict
|
||||||
|
and returns a string signature. This is the canonical mechanism.
|
||||||
|
2. ``tool.metadata["hitl_dedup_key"]`` — string with a primary arg
|
||||||
|
name; primarily used by MCP / Composio tools.
|
||||||
|
"""
|
||||||
|
|
||||||
|
tools = ()
|
||||||
|
|
||||||
|
def __init__(self, *, agent_tools: list[Any] | None = None) -> None:
|
||||||
|
self._resolvers: dict[str, DedupResolver] = {}
|
||||||
|
|
||||||
|
for t in agent_tools or []:
|
||||||
|
meta = getattr(t, "metadata", None) or {}
|
||||||
|
callable_key = meta.get("dedup_key")
|
||||||
|
if callable(callable_key):
|
||||||
|
self._resolvers[t.name] = callable_key
|
||||||
|
continue
|
||||||
|
if meta.get("hitl") and meta.get("hitl_dedup_key"):
|
||||||
|
self._resolvers[t.name] = wrap_dedup_key_by_arg_name(
|
||||||
|
meta["hitl_dedup_key"]
|
||||||
|
)
|
||||||
|
|
||||||
|
def after_model(
|
||||||
|
self, state: AgentState, runtime: Runtime[Any]
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
return self._dedup(state, self._resolvers)
|
||||||
|
|
||||||
|
async def aafter_model(
|
||||||
|
self, state: AgentState, runtime: Runtime[Any]
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
return self._dedup(state, self._resolvers)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _dedup(
|
||||||
|
state: AgentState,
|
||||||
|
resolvers: dict[str, DedupResolver],
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
messages = state.get("messages")
|
||||||
|
if not messages:
|
||||||
|
return None
|
||||||
|
|
||||||
|
last_msg = messages[-1]
|
||||||
|
if last_msg.type != "ai" or not getattr(last_msg, "tool_calls", None):
|
||||||
|
return None
|
||||||
|
|
||||||
|
tool_calls: list[dict[str, Any]] = last_msg.tool_calls
|
||||||
|
seen: set[tuple[str, str]] = set()
|
||||||
|
deduped: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
for tc in tool_calls:
|
||||||
|
name = tc.get("name", "")
|
||||||
|
resolver = resolvers.get(name)
|
||||||
|
if resolver is not None:
|
||||||
|
try:
|
||||||
|
arg_val = resolver(tc.get("args", {}) or {})
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"Dedup resolver for tool %s raised; keeping call", name
|
||||||
|
)
|
||||||
|
deduped.append(tc)
|
||||||
|
continue
|
||||||
|
key = (name, arg_val)
|
||||||
|
if key in seen:
|
||||||
|
logger.info(
|
||||||
|
"Dedup: dropped duplicate HITL tool call %s(%s)",
|
||||||
|
name,
|
||||||
|
arg_val,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
seen.add(key)
|
||||||
|
deduped.append(tc)
|
||||||
|
|
||||||
|
if len(deduped) == len(tool_calls):
|
||||||
|
return None
|
||||||
|
|
||||||
|
updated_msg = last_msg.model_copy(update={"tool_calls": deduped})
|
||||||
|
return {"messages": [updated_msg]}
|
||||||
|
|
||||||
|
|
||||||
def build_dedup_hitl_mw(tools: Sequence[BaseTool]) -> DedupHITLToolCallsMiddleware:
|
def build_dedup_hitl_mw(tools: Sequence[BaseTool]) -> DedupHITLToolCallsMiddleware:
|
||||||
|
|
|
||||||
|
|
@ -17,9 +17,6 @@ from app.agents.shared.middleware.context_editing import (
|
||||||
SpillingContextEditingMiddleware,
|
SpillingContextEditingMiddleware,
|
||||||
SpillToBackendEdit,
|
SpillToBackendEdit,
|
||||||
)
|
)
|
||||||
from app.agents.shared.middleware.dedup_tool_calls import (
|
|
||||||
DedupHITLToolCallsMiddleware,
|
|
||||||
)
|
|
||||||
from app.agents.shared.middleware.doom_loop import DoomLoopMiddleware
|
from app.agents.shared.middleware.doom_loop import DoomLoopMiddleware
|
||||||
from app.agents.shared.middleware.kb_persistence import (
|
from app.agents.shared.middleware.kb_persistence import (
|
||||||
KnowledgeBasePersistenceMiddleware,
|
KnowledgeBasePersistenceMiddleware,
|
||||||
|
|
@ -47,7 +44,6 @@ __all__ = [
|
||||||
"AnonymousDocumentMiddleware",
|
"AnonymousDocumentMiddleware",
|
||||||
"BusyMutexMiddleware",
|
"BusyMutexMiddleware",
|
||||||
"ClearToolUsesEdit",
|
"ClearToolUsesEdit",
|
||||||
"DedupHITLToolCallsMiddleware",
|
|
||||||
"DoomLoopMiddleware",
|
"DoomLoopMiddleware",
|
||||||
"KnowledgeBasePersistenceMiddleware",
|
"KnowledgeBasePersistenceMiddleware",
|
||||||
"KnowledgePriorityMiddleware",
|
"KnowledgePriorityMiddleware",
|
||||||
|
|
|
||||||
|
|
@ -1,15 +1,11 @@
|
||||||
"""Middleware that deduplicates HITL tool calls within a single LLM response.
|
"""Dedup-key resolvers for tool-call deduplication.
|
||||||
|
|
||||||
When the LLM emits multiple calls to the same HITL tool with the same
|
A *resolver* maps a tool's ``args`` dict to a stable signature string used to
|
||||||
primary argument (e.g. two ``delete_calendar_event("Doctor Appointment")``),
|
collapse duplicate calls. These helpers are shared: the MCP tool layer uses
|
||||||
only the first call is kept. Non-HITL tools are never touched.
|
:func:`dedup_key_full_args` as a safe default, and the main-agent
|
||||||
|
``DedupHITLToolCallsMiddleware`` builds its resolver map from them.
|
||||||
|
|
||||||
This runs in the ``after_model`` hook — **before** any tool executes — so
|
Resolver resolution order (read from each tool's own ``metadata``):
|
||||||
the duplicate call is stripped from the AIMessage that gets checkpointed.
|
|
||||||
That means it is also safe across LangGraph ``interrupt()`` boundaries:
|
|
||||||
the removed call will never appear on graph resume.
|
|
||||||
|
|
||||||
Dedup-key resolution order (read from each tool's own ``metadata``):
|
|
||||||
|
|
||||||
1. ``tool.metadata["dedup_key"]`` — callable mapping the args dict to a
|
1. ``tool.metadata["dedup_key"]`` — callable mapping the args dict to a
|
||||||
stable signature string. This is the canonical mechanism.
|
stable signature string. This is the canonical mechanism.
|
||||||
|
|
@ -22,15 +18,9 @@ A tool with no resolver from either path simply opts out of dedup.
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain.agents.middleware import AgentMiddleware, AgentState
|
|
||||||
from langgraph.runtime import Runtime
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Resolver type — given the tool ``args`` dict returns a stable
|
# Resolver type — given the tool ``args`` dict returns a stable
|
||||||
# string used to dedupe consecutive calls. ``None`` means no dedup.
|
# string used to dedupe consecutive calls. ``None`` means no dedup.
|
||||||
DedupResolver = Callable[[dict[str, Any]], str]
|
DedupResolver = Callable[[dict[str, Any]], str]
|
||||||
|
|
@ -67,90 +57,3 @@ def dedup_key_full_args(args: dict[str, Any]) -> str:
|
||||||
# Backwards-compatible alias for code that imported the original
|
# Backwards-compatible alias for code that imported the original
|
||||||
# private name. New callers should use :func:`wrap_dedup_key_by_arg_name`.
|
# private name. New callers should use :func:`wrap_dedup_key_by_arg_name`.
|
||||||
_wrap_string_key = wrap_dedup_key_by_arg_name
|
_wrap_string_key = wrap_dedup_key_by_arg_name
|
||||||
|
|
||||||
|
|
||||||
class DedupHITLToolCallsMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|
||||||
"""Remove duplicate HITL tool calls from a single LLM response.
|
|
||||||
|
|
||||||
Only the **first** occurrence of each ``(tool-name, dedup_key)``
|
|
||||||
pair is kept; subsequent duplicates are silently dropped.
|
|
||||||
|
|
||||||
The dedup-resolver map is built from two sources, in priority order:
|
|
||||||
|
|
||||||
1. ``tool.metadata["dedup_key"]`` — callable that receives the args dict
|
|
||||||
and returns a string signature. This is the canonical mechanism.
|
|
||||||
2. ``tool.metadata["hitl_dedup_key"]`` — string with a primary arg
|
|
||||||
name; primarily used by MCP / Composio tools.
|
|
||||||
"""
|
|
||||||
|
|
||||||
tools = ()
|
|
||||||
|
|
||||||
def __init__(self, *, agent_tools: list[Any] | None = None) -> None:
|
|
||||||
self._resolvers: dict[str, DedupResolver] = {}
|
|
||||||
|
|
||||||
for t in agent_tools or []:
|
|
||||||
meta = getattr(t, "metadata", None) or {}
|
|
||||||
callable_key = meta.get("dedup_key")
|
|
||||||
if callable(callable_key):
|
|
||||||
self._resolvers[t.name] = callable_key
|
|
||||||
continue
|
|
||||||
if meta.get("hitl") and meta.get("hitl_dedup_key"):
|
|
||||||
self._resolvers[t.name] = wrap_dedup_key_by_arg_name(
|
|
||||||
meta["hitl_dedup_key"]
|
|
||||||
)
|
|
||||||
|
|
||||||
def after_model(
|
|
||||||
self, state: AgentState, runtime: Runtime[Any]
|
|
||||||
) -> dict[str, Any] | None:
|
|
||||||
return self._dedup(state, self._resolvers)
|
|
||||||
|
|
||||||
async def aafter_model(
|
|
||||||
self, state: AgentState, runtime: Runtime[Any]
|
|
||||||
) -> dict[str, Any] | None:
|
|
||||||
return self._dedup(state, self._resolvers)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _dedup(
|
|
||||||
state: AgentState,
|
|
||||||
resolvers: dict[str, DedupResolver],
|
|
||||||
) -> dict[str, Any] | None:
|
|
||||||
messages = state.get("messages")
|
|
||||||
if not messages:
|
|
||||||
return None
|
|
||||||
|
|
||||||
last_msg = messages[-1]
|
|
||||||
if last_msg.type != "ai" or not getattr(last_msg, "tool_calls", None):
|
|
||||||
return None
|
|
||||||
|
|
||||||
tool_calls: list[dict[str, Any]] = last_msg.tool_calls
|
|
||||||
seen: set[tuple[str, str]] = set()
|
|
||||||
deduped: list[dict[str, Any]] = []
|
|
||||||
|
|
||||||
for tc in tool_calls:
|
|
||||||
name = tc.get("name", "")
|
|
||||||
resolver = resolvers.get(name)
|
|
||||||
if resolver is not None:
|
|
||||||
try:
|
|
||||||
arg_val = resolver(tc.get("args", {}) or {})
|
|
||||||
except Exception:
|
|
||||||
logger.exception(
|
|
||||||
"Dedup resolver for tool %s raised; keeping call", name
|
|
||||||
)
|
|
||||||
deduped.append(tc)
|
|
||||||
continue
|
|
||||||
key = (name, arg_val)
|
|
||||||
if key in seen:
|
|
||||||
logger.info(
|
|
||||||
"Dedup: dropped duplicate HITL tool call %s(%s)",
|
|
||||||
name,
|
|
||||||
arg_val,
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
seen.add(key)
|
|
||||||
deduped.append(tc)
|
|
||||||
|
|
||||||
if len(deduped) == len(tool_calls):
|
|
||||||
return None
|
|
||||||
|
|
||||||
updated_msg = last_msg.model_copy(update={"tool_calls": deduped})
|
|
||||||
return {"messages": [updated_msg]}
|
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ import pytest
|
||||||
from langchain_core.messages import AIMessage
|
from langchain_core.messages import AIMessage
|
||||||
from langchain_core.tools import StructuredTool
|
from langchain_core.tools import StructuredTool
|
||||||
|
|
||||||
from app.agents.shared.middleware.dedup_tool_calls import (
|
from app.agents.multi_agent_chat.main_agent.middleware.dedup_hitl import (
|
||||||
DedupHITLToolCallsMiddleware,
|
DedupHITLToolCallsMiddleware,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,8 +2,10 @@ import pytest
|
||||||
from langchain_core.messages import AIMessage
|
from langchain_core.messages import AIMessage
|
||||||
from langchain_core.tools import StructuredTool
|
from langchain_core.tools import StructuredTool
|
||||||
|
|
||||||
from app.agents.shared.middleware.dedup_tool_calls import (
|
from app.agents.multi_agent_chat.main_agent.middleware.dedup_hitl import (
|
||||||
DedupHITLToolCallsMiddleware,
|
DedupHITLToolCallsMiddleware,
|
||||||
|
)
|
||||||
|
from app.agents.shared.middleware.dedup_tool_calls import (
|
||||||
wrap_dedup_key_by_arg_name,
|
wrap_dedup_key_by_arg_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue