Harden HITL for multi-step tasks: bypass internal MCP gate, full-args dedup, and decision-envelope normalization.

This commit is contained in:
CREDO23 2026-05-04 19:25:27 +02:00
parent 4ac3f0b304
commit 277bd50f37
6 changed files with 442 additions and 65 deletions

View file

@ -21,6 +21,7 @@ A tool with no resolver from either path simply opts out of dedup.
from __future__ import annotations
import json
import logging
from collections.abc import Callable
from typing import Any
@ -57,6 +58,19 @@ def wrap_dedup_key_by_arg_name(arg_name: str) -> DedupResolver:
return _resolver
def dedup_key_full_args(args: dict[str, Any]) -> str:
"""Resolver that collapses calls only when **every** argument is identical.
Safe default for tools where no single field uniquely identifies a call
(e.g. MCP tools whose first required field is a shared workspace id).
"""
try:
return json.dumps(args, sort_keys=True, default=str)
except (TypeError, ValueError):
return repr(sorted(args.items())) if isinstance(args, dict) else repr(args)
# Backwards-compatible alias for code that imported the original
# private name. New callers should use :func:`wrap_dedup_key_by_arg_name`.
_wrap_string_key = wrap_dedup_key_by_arg_name

View file

@ -19,8 +19,9 @@ Operation:
the results: ``deny`` > ``ask`` > ``allow``.
3. On ``deny``: replaces the call with a synthetic ``ToolMessage``
containing a :class:`StreamingError`.
4. On ``ask``: raises a SurfSense-style ``interrupt(...)``. The reply
shape is ``{"decision_type": "once|always|reject", "feedback"?: str}``.
4. On ``ask``: raises a SurfSense-style ``interrupt(...)``. Both the legacy
SurfSense shape and LangChain HITL ``{"decisions": [{"type": ...}]}``
replies are accepted via :func:`_normalize_permission_decision`.
- ``once``: proceed.
- ``always``: also persist allow rules for ``request.always`` patterns.
- ``reject`` w/o feedback: raise :class:`RejectedError`.
@ -81,6 +82,75 @@ def _default_pattern_resolver(name: str) -> PatternResolver:
return _resolve
# Translation from the LangChain HITL envelope (what ``stream_resume_chat``
# sends) to SurfSense's legacy ``decision_type`` shape. ``edit`` keeps the
# original tool args — tools needing argument edits should use
# ``request_approval`` from ``app/agents/new_chat/tools/hitl.py``.
_LC_TYPE_TO_PERMISSION_DECISION: dict[str, str] = {
"approve": "once",
"reject": "reject",
"edit": "once",
}
def _normalize_permission_decision(decision: Any) -> dict[str, Any]:
"""Coerce any accepted reply shape into ``{"decision_type": ..., "feedback"?}``.
Falls back to ``reject`` (with a warning) on unrecognized payloads so the
middleware fails closed.
"""
if isinstance(decision, str):
return {"decision_type": decision}
if not isinstance(decision, dict):
logger.warning(
"Unrecognized permission resume value (%s); treating as reject",
type(decision).__name__,
)
return {"decision_type": "reject"}
if decision.get("decision_type"):
return decision
payload: dict[str, Any] = decision
decisions = decision.get("decisions")
if isinstance(decisions, list) and decisions:
first = decisions[0]
if isinstance(first, dict):
payload = first
raw_type = payload.get("type") or payload.get("decision_type")
if not raw_type:
logger.warning(
"Permission resume missing decision type (keys=%s); treating as reject",
list(payload.keys()),
)
return {"decision_type": "reject"}
raw_type = str(raw_type).lower()
mapped = _LC_TYPE_TO_PERMISSION_DECISION.get(raw_type)
if mapped is None:
# Tolerate legacy values arriving without ``decision_type`` wrapping.
if raw_type in {"once", "always", "reject"}:
mapped = raw_type
else:
logger.warning(
"Unknown permission decision type %r; treating as reject", raw_type
)
mapped = "reject"
if raw_type == "edit":
logger.warning(
"Permission middleware received an 'edit' decision; original args "
"kept (edits not merged here)."
)
out: dict[str, Any] = {"decision_type": mapped}
feedback = payload.get("feedback") or payload.get("message")
if isinstance(feedback, str) and feedback.strip():
out["feedback"] = feedback
return out
class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
"""Allow/deny/ask layer over the agent's tool calls.
@ -214,12 +284,7 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
ot.interrupt_span(interrupt_type="permission_ask"),
):
decision = interrupt(payload)
if isinstance(decision, dict):
return decision
# Tolerate a plain string reply ("once", "always", "reject")
if isinstance(decision, str):
return {"decision_type": decision}
return {"decision_type": "reject"}
return _normalize_permission_decision(decision)
def _persist_always(self, tool_name: str, patterns: list[str]) -> None:
"""Promote ``always`` reply into runtime allow rules.
@ -355,4 +420,5 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
__all__ = [
"PatternResolver",
"PermissionMiddleware",
"_normalize_permission_decision",
]

View file

@ -33,6 +33,7 @@ from sqlalchemy import cast, select
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.middleware.dedup_tool_calls import dedup_key_full_args
from app.agents.new_chat.tools.hitl import request_approval
from app.agents.new_chat.tools.mcp_client import MCPClient
from app.db import SearchSourceConnector
@ -45,7 +46,10 @@ _MCP_CACHE_MAX_SIZE = 50
_MCP_DISCOVERY_TIMEOUT_SECONDS = 30
_TOOL_CALL_MAX_RETRIES = 3
_TOOL_CALL_RETRY_DELAY = 1.5 # seconds, doubles per attempt
_mcp_tools_cache: dict[int, tuple[float, list[StructuredTool]]] = {}
# Keyed by ``(search_space_id, bypass_internal_hitl)`` so single-agent and
# multi-agent paths cannot share tool closures with different HITL wiring.
_MCPCacheKey = tuple[int, bool]
_mcp_tools_cache: dict[_MCPCacheKey, tuple[float, list[StructuredTool]]] = {}
def _evict_expired_mcp_cache() -> None:
@ -137,12 +141,13 @@ async def _create_mcp_tool_from_definition_stdio(
connector_name: str = "",
connector_id: int | None = None,
trusted_tools: list[str] | None = None,
bypass_internal_hitl: bool = False,
) -> StructuredTool:
"""Create a LangChain tool from an MCP tool definition (stdio transport).
All MCP tools are unconditionally wrapped with HITL approval.
``request_approval()`` is called OUTSIDE the try/except so that
``GraphInterrupt`` propagates cleanly to LangGraph.
Set ``bypass_internal_hitl=True`` when an outer ``HumanInTheLoopMiddleware``
already gates the tool, otherwise the body's ``request_approval()`` is the
sole HITL gate (single-agent path).
"""
tool_name = tool_def.get("name", "unnamed_tool")
raw_description = tool_def.get("description", "No description provided")
@ -161,24 +166,29 @@ async def _create_mcp_tool_from_definition_stdio(
"""Execute the MCP tool call via the client with retry support."""
logger.debug("MCP tool '%s' called", tool_name)
# HITL — OUTSIDE try/except so GraphInterrupt propagates to LangGraph
hitl_result = request_approval(
action_type="mcp_tool_call",
tool_name=tool_name,
params=kwargs,
context={
"mcp_server": connector_name,
"tool_description": raw_description,
"mcp_transport": "stdio",
"mcp_connector_id": connector_id,
},
trusted_tools=trusted_tools,
)
if hitl_result.rejected:
return "Tool call rejected by user."
call_kwargs = _unpack_synthetic_input_data(
{k: v for k, v in hitl_result.params.items() if v is not None}
)
if bypass_internal_hitl:
call_kwargs = _unpack_synthetic_input_data(
{k: v for k, v in kwargs.items() if v is not None}
)
else:
# Outside try/except so ``GraphInterrupt`` propagates to LangGraph.
hitl_result = request_approval(
action_type="mcp_tool_call",
tool_name=tool_name,
params=kwargs,
context={
"mcp_server": connector_name,
"tool_description": raw_description,
"mcp_transport": "stdio",
"mcp_connector_id": connector_id,
},
trusted_tools=trusted_tools,
)
if hitl_result.rejected:
return "Tool call rejected by user."
call_kwargs = _unpack_synthetic_input_data(
{k: v for k, v in hitl_result.params.items() if v is not None}
)
last_error: Exception | None = None
for attempt in range(_TOOL_CALL_MAX_RETRIES):
@ -221,7 +231,9 @@ async def _create_mcp_tool_from_definition_stdio(
"mcp_connector_name": connector_name or None,
"mcp_is_generic": True,
"hitl": True,
"hitl_dedup_key": next(iter(input_schema.get("required", [])), None),
# Full-args hash: shared identifiers (cloudId, workspaceId, …)
# would otherwise collapse legitimate batches.
"dedup_key": dedup_key_full_args,
},
)
@ -240,11 +252,14 @@ async def _create_mcp_tool_from_definition_http(
readonly_tools: frozenset[str] | None = None,
tool_name_prefix: str | None = None,
is_generic_mcp: bool = False,
bypass_internal_hitl: bool = False,
) -> StructuredTool:
"""Create a LangChain tool from an MCP tool definition (HTTP transport).
Write tools are wrapped with HITL approval; read-only tools (listed in
``readonly_tools``) execute immediately without user confirmation.
``readonly_tools``) execute immediately without user confirmation. Set
``bypass_internal_hitl=True`` when an outer ``HumanInTheLoopMiddleware``
already gates the tool.
When ``tool_name_prefix`` is set (multi-account disambiguation), the
tool exposed to the LLM gets a prefixed name (e.g. ``linear_25_list_issues``)
@ -302,7 +317,7 @@ async def _create_mcp_tool_from_definition_http(
"""Execute the MCP tool call via HTTP transport."""
logger.debug("MCP HTTP tool '%s' called", exposed_name)
if is_readonly:
if is_readonly or bypass_internal_hitl:
call_kwargs = _unpack_synthetic_input_data(
{k: v for k, v in kwargs.items() if v is not None}
)
@ -385,7 +400,9 @@ async def _create_mcp_tool_from_definition_http(
"mcp_connector_name": connector_name or None,
"mcp_is_generic": is_generic_mcp,
"hitl": not is_readonly,
"hitl_dedup_key": next(iter(input_schema.get("required", [])), None),
# Full-args hash: shared identifiers (cloudId, workspaceId, …)
# would otherwise collapse legitimate batches.
"dedup_key": dedup_key_full_args,
"mcp_original_tool_name": original_tool_name,
"mcp_connector_id": connector_id,
},
@ -400,6 +417,8 @@ async def _load_stdio_mcp_tools(
connector_name: str,
server_config: dict[str, Any],
trusted_tools: list[str] | None = None,
*,
bypass_internal_hitl: bool = False,
) -> list[StructuredTool]:
"""Load tools from a stdio-based MCP server."""
tools: list[StructuredTool] = []
@ -451,6 +470,7 @@ async def _load_stdio_mcp_tools(
connector_name=connector_name,
connector_id=connector_id,
trusted_tools=trusted_tools,
bypass_internal_hitl=bypass_internal_hitl,
)
tools.append(tool)
except Exception as e:
@ -473,6 +493,8 @@ async def _load_http_mcp_tools(
readonly_tools: frozenset[str] | None = None,
tool_name_prefix: str | None = None,
is_generic_mcp: bool = False,
*,
bypass_internal_hitl: bool = False,
) -> list[StructuredTool]:
"""Load tools from an HTTP-based MCP server.
@ -598,6 +620,7 @@ async def _load_http_mcp_tools(
readonly_tools=readonly_tools,
tool_name_prefix=tool_name_prefix,
is_generic_mcp=is_generic_mcp,
bypass_internal_hitl=bypass_internal_hitl,
)
tools.append(tool)
except Exception as e:
@ -905,14 +928,10 @@ async def _mark_connector_auth_expired(connector_id: int) -> None:
def invalidate_mcp_tools_cache(search_space_id: int | None = None) -> None:
"""Invalidate cached MCP tools.
Args:
search_space_id: If provided, only invalidate for this search space.
If None, invalidate all cached MCP tools.
"""
"""Invalidate cached MCP tools (both ``bypass_internal_hitl`` variants together)."""
if search_space_id is not None:
_mcp_tools_cache.pop(search_space_id, None)
for key in [k for k in _mcp_tools_cache if k[0] == search_space_id]:
_mcp_tools_cache.pop(key, None)
else:
_mcp_tools_cache.clear()
@ -920,27 +939,29 @@ def invalidate_mcp_tools_cache(search_space_id: int | None = None) -> None:
async def load_mcp_tools(
session: AsyncSession,
search_space_id: int,
*,
bypass_internal_hitl: bool = False,
) -> list[StructuredTool]:
"""Load all MCP tools from user's active MCP server connectors.
"""Load all MCP tools from the user's active MCP server connectors.
This discovers tools dynamically from MCP servers using the protocol.
Supports both stdio (local process) and HTTP (remote server) transports.
Results are cached per search space for up to 5 minutes to avoid
re-spawning MCP server processes on every chat message.
Results are cached per ``(search_space_id, bypass_internal_hitl)`` for up
to 5 minutes; bypass is keyed because each variant builds a different tool
closure (with vs. without the in-wrapper ``request_approval`` gate).
"""
_evict_expired_mcp_cache()
now = time.monotonic()
cached = _mcp_tools_cache.get(search_space_id)
cache_key: _MCPCacheKey = (search_space_id, bypass_internal_hitl)
cached = _mcp_tools_cache.get(cache_key)
if cached is not None:
cached_at, cached_tools = cached
if now - cached_at < _MCP_CACHE_TTL_SECONDS:
logger.info(
"Using cached MCP tools for search space %s (%d tools, age=%.0fs)",
"Using cached MCP tools for search space %s (%d tools, age=%.0fs, bypass_hitl=%s)",
search_space_id,
len(cached_tools),
now - cached_at,
bypass_internal_hitl,
)
return list(cached_tools)
@ -1064,6 +1085,7 @@ async def load_mcp_tools(
readonly_tools=task["readonly_tools"],
tool_name_prefix=task["tool_name_prefix"],
is_generic_mcp=task.get("is_generic_mcp", False),
bypass_internal_hitl=bypass_internal_hitl,
),
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
)
@ -1074,6 +1096,7 @@ async def load_mcp_tools(
task["connector_name"],
task["server_config"],
trusted_tools=task["trusted_tools"],
bypass_internal_hitl=bypass_internal_hitl,
),
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
)
@ -1095,14 +1118,17 @@ async def load_mcp_tools(
results = await asyncio.gather(*[_discover_one(t) for t in discovery_tasks])
tools: list[StructuredTool] = [tool for sublist in results for tool in sublist]
_mcp_tools_cache[search_space_id] = (now, tools)
_mcp_tools_cache[cache_key] = (now, tools)
if len(_mcp_tools_cache) > _MCP_CACHE_MAX_SIZE:
oldest_key = min(_mcp_tools_cache, key=lambda k: _mcp_tools_cache[k][0])
del _mcp_tools_cache[oldest_key]
logger.info(
"Loaded %d MCP tools for search space %d", len(tools), search_space_id
"Loaded %d MCP tools for search space %d (bypass_hitl=%s)",
len(tools),
search_space_id,
bypass_internal_hitl,
)
return tools