mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-06 14:22:47 +02:00
Harden HITL for multi-step tasks: bypass internal MCP gate, full-args dedup, and decision-envelope normalization.
This commit is contained in:
parent
4ac3f0b304
commit
277bd50f37
6 changed files with 442 additions and 65 deletions
|
|
@ -21,6 +21,7 @@ A tool with no resolver from either path simply opts out of dedup.
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
@ -57,6 +58,19 @@ def wrap_dedup_key_by_arg_name(arg_name: str) -> DedupResolver:
|
||||||
return _resolver
|
return _resolver
|
||||||
|
|
||||||
|
|
||||||
|
def dedup_key_full_args(args: dict[str, Any]) -> str:
|
||||||
|
"""Resolver that collapses calls only when **every** argument is identical.
|
||||||
|
|
||||||
|
Safe default for tools where no single field uniquely identifies a call
|
||||||
|
(e.g. MCP tools whose first required field is a shared workspace id).
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
return json.dumps(args, sort_keys=True, default=str)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return repr(sorted(args.items())) if isinstance(args, dict) else repr(args)
|
||||||
|
|
||||||
|
|
||||||
# Backwards-compatible alias for code that imported the original
|
# Backwards-compatible alias for code that imported the original
|
||||||
# private name. New callers should use :func:`wrap_dedup_key_by_arg_name`.
|
# private name. New callers should use :func:`wrap_dedup_key_by_arg_name`.
|
||||||
_wrap_string_key = wrap_dedup_key_by_arg_name
|
_wrap_string_key = wrap_dedup_key_by_arg_name
|
||||||
|
|
|
||||||
|
|
@ -19,8 +19,9 @@ Operation:
|
||||||
the results: ``deny`` > ``ask`` > ``allow``.
|
the results: ``deny`` > ``ask`` > ``allow``.
|
||||||
3. On ``deny``: replaces the call with a synthetic ``ToolMessage``
|
3. On ``deny``: replaces the call with a synthetic ``ToolMessage``
|
||||||
containing a :class:`StreamingError`.
|
containing a :class:`StreamingError`.
|
||||||
4. On ``ask``: raises a SurfSense-style ``interrupt(...)``. The reply
|
4. On ``ask``: raises a SurfSense-style ``interrupt(...)``. Both the legacy
|
||||||
shape is ``{"decision_type": "once|always|reject", "feedback"?: str}``.
|
SurfSense shape and LangChain HITL ``{"decisions": [{"type": ...}]}``
|
||||||
|
replies are accepted via :func:`_normalize_permission_decision`.
|
||||||
- ``once``: proceed.
|
- ``once``: proceed.
|
||||||
- ``always``: also persist allow rules for ``request.always`` patterns.
|
- ``always``: also persist allow rules for ``request.always`` patterns.
|
||||||
- ``reject`` w/o feedback: raise :class:`RejectedError`.
|
- ``reject`` w/o feedback: raise :class:`RejectedError`.
|
||||||
|
|
@ -81,6 +82,75 @@ def _default_pattern_resolver(name: str) -> PatternResolver:
|
||||||
return _resolve
|
return _resolve
|
||||||
|
|
||||||
|
|
||||||
|
# Translation from the LangChain HITL envelope (what ``stream_resume_chat``
|
||||||
|
# sends) to SurfSense's legacy ``decision_type`` shape. ``edit`` keeps the
|
||||||
|
# original tool args — tools needing argument edits should use
|
||||||
|
# ``request_approval`` from ``app/agents/new_chat/tools/hitl.py``.
|
||||||
|
_LC_TYPE_TO_PERMISSION_DECISION: dict[str, str] = {
|
||||||
|
"approve": "once",
|
||||||
|
"reject": "reject",
|
||||||
|
"edit": "once",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_permission_decision(decision: Any) -> dict[str, Any]:
|
||||||
|
"""Coerce any accepted reply shape into ``{"decision_type": ..., "feedback"?}``.
|
||||||
|
|
||||||
|
Falls back to ``reject`` (with a warning) on unrecognized payloads so the
|
||||||
|
middleware fails closed.
|
||||||
|
"""
|
||||||
|
if isinstance(decision, str):
|
||||||
|
return {"decision_type": decision}
|
||||||
|
if not isinstance(decision, dict):
|
||||||
|
logger.warning(
|
||||||
|
"Unrecognized permission resume value (%s); treating as reject",
|
||||||
|
type(decision).__name__,
|
||||||
|
)
|
||||||
|
return {"decision_type": "reject"}
|
||||||
|
|
||||||
|
if decision.get("decision_type"):
|
||||||
|
return decision
|
||||||
|
|
||||||
|
payload: dict[str, Any] = decision
|
||||||
|
decisions = decision.get("decisions")
|
||||||
|
if isinstance(decisions, list) and decisions:
|
||||||
|
first = decisions[0]
|
||||||
|
if isinstance(first, dict):
|
||||||
|
payload = first
|
||||||
|
|
||||||
|
raw_type = payload.get("type") or payload.get("decision_type")
|
||||||
|
if not raw_type:
|
||||||
|
logger.warning(
|
||||||
|
"Permission resume missing decision type (keys=%s); treating as reject",
|
||||||
|
list(payload.keys()),
|
||||||
|
)
|
||||||
|
return {"decision_type": "reject"}
|
||||||
|
|
||||||
|
raw_type = str(raw_type).lower()
|
||||||
|
mapped = _LC_TYPE_TO_PERMISSION_DECISION.get(raw_type)
|
||||||
|
if mapped is None:
|
||||||
|
# Tolerate legacy values arriving without ``decision_type`` wrapping.
|
||||||
|
if raw_type in {"once", "always", "reject"}:
|
||||||
|
mapped = raw_type
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"Unknown permission decision type %r; treating as reject", raw_type
|
||||||
|
)
|
||||||
|
mapped = "reject"
|
||||||
|
|
||||||
|
if raw_type == "edit":
|
||||||
|
logger.warning(
|
||||||
|
"Permission middleware received an 'edit' decision; original args "
|
||||||
|
"kept (edits not merged here)."
|
||||||
|
)
|
||||||
|
|
||||||
|
out: dict[str, Any] = {"decision_type": mapped}
|
||||||
|
feedback = payload.get("feedback") or payload.get("message")
|
||||||
|
if isinstance(feedback, str) and feedback.strip():
|
||||||
|
out["feedback"] = feedback
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
"""Allow/deny/ask layer over the agent's tool calls.
|
"""Allow/deny/ask layer over the agent's tool calls.
|
||||||
|
|
||||||
|
|
@ -214,12 +284,7 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
ot.interrupt_span(interrupt_type="permission_ask"),
|
ot.interrupt_span(interrupt_type="permission_ask"),
|
||||||
):
|
):
|
||||||
decision = interrupt(payload)
|
decision = interrupt(payload)
|
||||||
if isinstance(decision, dict):
|
return _normalize_permission_decision(decision)
|
||||||
return decision
|
|
||||||
# Tolerate a plain string reply ("once", "always", "reject")
|
|
||||||
if isinstance(decision, str):
|
|
||||||
return {"decision_type": decision}
|
|
||||||
return {"decision_type": "reject"}
|
|
||||||
|
|
||||||
def _persist_always(self, tool_name: str, patterns: list[str]) -> None:
|
def _persist_always(self, tool_name: str, patterns: list[str]) -> None:
|
||||||
"""Promote ``always`` reply into runtime allow rules.
|
"""Promote ``always`` reply into runtime allow rules.
|
||||||
|
|
@ -355,4 +420,5 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"PatternResolver",
|
"PatternResolver",
|
||||||
"PermissionMiddleware",
|
"PermissionMiddleware",
|
||||||
|
"_normalize_permission_decision",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -33,6 +33,7 @@ from sqlalchemy import cast, select
|
||||||
from sqlalchemy.dialects.postgresql import JSONB
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.agents.new_chat.middleware.dedup_tool_calls import dedup_key_full_args
|
||||||
from app.agents.new_chat.tools.hitl import request_approval
|
from app.agents.new_chat.tools.hitl import request_approval
|
||||||
from app.agents.new_chat.tools.mcp_client import MCPClient
|
from app.agents.new_chat.tools.mcp_client import MCPClient
|
||||||
from app.db import SearchSourceConnector
|
from app.db import SearchSourceConnector
|
||||||
|
|
@ -45,7 +46,10 @@ _MCP_CACHE_MAX_SIZE = 50
|
||||||
_MCP_DISCOVERY_TIMEOUT_SECONDS = 30
|
_MCP_DISCOVERY_TIMEOUT_SECONDS = 30
|
||||||
_TOOL_CALL_MAX_RETRIES = 3
|
_TOOL_CALL_MAX_RETRIES = 3
|
||||||
_TOOL_CALL_RETRY_DELAY = 1.5 # seconds, doubles per attempt
|
_TOOL_CALL_RETRY_DELAY = 1.5 # seconds, doubles per attempt
|
||||||
_mcp_tools_cache: dict[int, tuple[float, list[StructuredTool]]] = {}
|
# Keyed by ``(search_space_id, bypass_internal_hitl)`` so single-agent and
|
||||||
|
# multi-agent paths cannot share tool closures with different HITL wiring.
|
||||||
|
_MCPCacheKey = tuple[int, bool]
|
||||||
|
_mcp_tools_cache: dict[_MCPCacheKey, tuple[float, list[StructuredTool]]] = {}
|
||||||
|
|
||||||
|
|
||||||
def _evict_expired_mcp_cache() -> None:
|
def _evict_expired_mcp_cache() -> None:
|
||||||
|
|
@ -137,12 +141,13 @@ async def _create_mcp_tool_from_definition_stdio(
|
||||||
connector_name: str = "",
|
connector_name: str = "",
|
||||||
connector_id: int | None = None,
|
connector_id: int | None = None,
|
||||||
trusted_tools: list[str] | None = None,
|
trusted_tools: list[str] | None = None,
|
||||||
|
bypass_internal_hitl: bool = False,
|
||||||
) -> StructuredTool:
|
) -> StructuredTool:
|
||||||
"""Create a LangChain tool from an MCP tool definition (stdio transport).
|
"""Create a LangChain tool from an MCP tool definition (stdio transport).
|
||||||
|
|
||||||
All MCP tools are unconditionally wrapped with HITL approval.
|
Set ``bypass_internal_hitl=True`` when an outer ``HumanInTheLoopMiddleware``
|
||||||
``request_approval()`` is called OUTSIDE the try/except so that
|
already gates the tool, otherwise the body's ``request_approval()`` is the
|
||||||
``GraphInterrupt`` propagates cleanly to LangGraph.
|
sole HITL gate (single-agent path).
|
||||||
"""
|
"""
|
||||||
tool_name = tool_def.get("name", "unnamed_tool")
|
tool_name = tool_def.get("name", "unnamed_tool")
|
||||||
raw_description = tool_def.get("description", "No description provided")
|
raw_description = tool_def.get("description", "No description provided")
|
||||||
|
|
@ -161,24 +166,29 @@ async def _create_mcp_tool_from_definition_stdio(
|
||||||
"""Execute the MCP tool call via the client with retry support."""
|
"""Execute the MCP tool call via the client with retry support."""
|
||||||
logger.debug("MCP tool '%s' called", tool_name)
|
logger.debug("MCP tool '%s' called", tool_name)
|
||||||
|
|
||||||
# HITL — OUTSIDE try/except so GraphInterrupt propagates to LangGraph
|
if bypass_internal_hitl:
|
||||||
hitl_result = request_approval(
|
call_kwargs = _unpack_synthetic_input_data(
|
||||||
action_type="mcp_tool_call",
|
{k: v for k, v in kwargs.items() if v is not None}
|
||||||
tool_name=tool_name,
|
)
|
||||||
params=kwargs,
|
else:
|
||||||
context={
|
# Outside try/except so ``GraphInterrupt`` propagates to LangGraph.
|
||||||
"mcp_server": connector_name,
|
hitl_result = request_approval(
|
||||||
"tool_description": raw_description,
|
action_type="mcp_tool_call",
|
||||||
"mcp_transport": "stdio",
|
tool_name=tool_name,
|
||||||
"mcp_connector_id": connector_id,
|
params=kwargs,
|
||||||
},
|
context={
|
||||||
trusted_tools=trusted_tools,
|
"mcp_server": connector_name,
|
||||||
)
|
"tool_description": raw_description,
|
||||||
if hitl_result.rejected:
|
"mcp_transport": "stdio",
|
||||||
return "Tool call rejected by user."
|
"mcp_connector_id": connector_id,
|
||||||
call_kwargs = _unpack_synthetic_input_data(
|
},
|
||||||
{k: v for k, v in hitl_result.params.items() if v is not None}
|
trusted_tools=trusted_tools,
|
||||||
)
|
)
|
||||||
|
if hitl_result.rejected:
|
||||||
|
return "Tool call rejected by user."
|
||||||
|
call_kwargs = _unpack_synthetic_input_data(
|
||||||
|
{k: v for k, v in hitl_result.params.items() if v is not None}
|
||||||
|
)
|
||||||
|
|
||||||
last_error: Exception | None = None
|
last_error: Exception | None = None
|
||||||
for attempt in range(_TOOL_CALL_MAX_RETRIES):
|
for attempt in range(_TOOL_CALL_MAX_RETRIES):
|
||||||
|
|
@ -221,7 +231,9 @@ async def _create_mcp_tool_from_definition_stdio(
|
||||||
"mcp_connector_name": connector_name or None,
|
"mcp_connector_name": connector_name or None,
|
||||||
"mcp_is_generic": True,
|
"mcp_is_generic": True,
|
||||||
"hitl": True,
|
"hitl": True,
|
||||||
"hitl_dedup_key": next(iter(input_schema.get("required", [])), None),
|
# Full-args hash: shared identifiers (cloudId, workspaceId, …)
|
||||||
|
# would otherwise collapse legitimate batches.
|
||||||
|
"dedup_key": dedup_key_full_args,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -240,11 +252,14 @@ async def _create_mcp_tool_from_definition_http(
|
||||||
readonly_tools: frozenset[str] | None = None,
|
readonly_tools: frozenset[str] | None = None,
|
||||||
tool_name_prefix: str | None = None,
|
tool_name_prefix: str | None = None,
|
||||||
is_generic_mcp: bool = False,
|
is_generic_mcp: bool = False,
|
||||||
|
bypass_internal_hitl: bool = False,
|
||||||
) -> StructuredTool:
|
) -> StructuredTool:
|
||||||
"""Create a LangChain tool from an MCP tool definition (HTTP transport).
|
"""Create a LangChain tool from an MCP tool definition (HTTP transport).
|
||||||
|
|
||||||
Write tools are wrapped with HITL approval; read-only tools (listed in
|
Write tools are wrapped with HITL approval; read-only tools (listed in
|
||||||
``readonly_tools``) execute immediately without user confirmation.
|
``readonly_tools``) execute immediately without user confirmation. Set
|
||||||
|
``bypass_internal_hitl=True`` when an outer ``HumanInTheLoopMiddleware``
|
||||||
|
already gates the tool.
|
||||||
|
|
||||||
When ``tool_name_prefix`` is set (multi-account disambiguation), the
|
When ``tool_name_prefix`` is set (multi-account disambiguation), the
|
||||||
tool exposed to the LLM gets a prefixed name (e.g. ``linear_25_list_issues``)
|
tool exposed to the LLM gets a prefixed name (e.g. ``linear_25_list_issues``)
|
||||||
|
|
@ -302,7 +317,7 @@ async def _create_mcp_tool_from_definition_http(
|
||||||
"""Execute the MCP tool call via HTTP transport."""
|
"""Execute the MCP tool call via HTTP transport."""
|
||||||
logger.debug("MCP HTTP tool '%s' called", exposed_name)
|
logger.debug("MCP HTTP tool '%s' called", exposed_name)
|
||||||
|
|
||||||
if is_readonly:
|
if is_readonly or bypass_internal_hitl:
|
||||||
call_kwargs = _unpack_synthetic_input_data(
|
call_kwargs = _unpack_synthetic_input_data(
|
||||||
{k: v for k, v in kwargs.items() if v is not None}
|
{k: v for k, v in kwargs.items() if v is not None}
|
||||||
)
|
)
|
||||||
|
|
@ -385,7 +400,9 @@ async def _create_mcp_tool_from_definition_http(
|
||||||
"mcp_connector_name": connector_name or None,
|
"mcp_connector_name": connector_name or None,
|
||||||
"mcp_is_generic": is_generic_mcp,
|
"mcp_is_generic": is_generic_mcp,
|
||||||
"hitl": not is_readonly,
|
"hitl": not is_readonly,
|
||||||
"hitl_dedup_key": next(iter(input_schema.get("required", [])), None),
|
# Full-args hash: shared identifiers (cloudId, workspaceId, …)
|
||||||
|
# would otherwise collapse legitimate batches.
|
||||||
|
"dedup_key": dedup_key_full_args,
|
||||||
"mcp_original_tool_name": original_tool_name,
|
"mcp_original_tool_name": original_tool_name,
|
||||||
"mcp_connector_id": connector_id,
|
"mcp_connector_id": connector_id,
|
||||||
},
|
},
|
||||||
|
|
@ -400,6 +417,8 @@ async def _load_stdio_mcp_tools(
|
||||||
connector_name: str,
|
connector_name: str,
|
||||||
server_config: dict[str, Any],
|
server_config: dict[str, Any],
|
||||||
trusted_tools: list[str] | None = None,
|
trusted_tools: list[str] | None = None,
|
||||||
|
*,
|
||||||
|
bypass_internal_hitl: bool = False,
|
||||||
) -> list[StructuredTool]:
|
) -> list[StructuredTool]:
|
||||||
"""Load tools from a stdio-based MCP server."""
|
"""Load tools from a stdio-based MCP server."""
|
||||||
tools: list[StructuredTool] = []
|
tools: list[StructuredTool] = []
|
||||||
|
|
@ -451,6 +470,7 @@ async def _load_stdio_mcp_tools(
|
||||||
connector_name=connector_name,
|
connector_name=connector_name,
|
||||||
connector_id=connector_id,
|
connector_id=connector_id,
|
||||||
trusted_tools=trusted_tools,
|
trusted_tools=trusted_tools,
|
||||||
|
bypass_internal_hitl=bypass_internal_hitl,
|
||||||
)
|
)
|
||||||
tools.append(tool)
|
tools.append(tool)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -473,6 +493,8 @@ async def _load_http_mcp_tools(
|
||||||
readonly_tools: frozenset[str] | None = None,
|
readonly_tools: frozenset[str] | None = None,
|
||||||
tool_name_prefix: str | None = None,
|
tool_name_prefix: str | None = None,
|
||||||
is_generic_mcp: bool = False,
|
is_generic_mcp: bool = False,
|
||||||
|
*,
|
||||||
|
bypass_internal_hitl: bool = False,
|
||||||
) -> list[StructuredTool]:
|
) -> list[StructuredTool]:
|
||||||
"""Load tools from an HTTP-based MCP server.
|
"""Load tools from an HTTP-based MCP server.
|
||||||
|
|
||||||
|
|
@ -598,6 +620,7 @@ async def _load_http_mcp_tools(
|
||||||
readonly_tools=readonly_tools,
|
readonly_tools=readonly_tools,
|
||||||
tool_name_prefix=tool_name_prefix,
|
tool_name_prefix=tool_name_prefix,
|
||||||
is_generic_mcp=is_generic_mcp,
|
is_generic_mcp=is_generic_mcp,
|
||||||
|
bypass_internal_hitl=bypass_internal_hitl,
|
||||||
)
|
)
|
||||||
tools.append(tool)
|
tools.append(tool)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -905,14 +928,10 @@ async def _mark_connector_auth_expired(connector_id: int) -> None:
|
||||||
|
|
||||||
|
|
||||||
def invalidate_mcp_tools_cache(search_space_id: int | None = None) -> None:
|
def invalidate_mcp_tools_cache(search_space_id: int | None = None) -> None:
|
||||||
"""Invalidate cached MCP tools.
|
"""Invalidate cached MCP tools (both ``bypass_internal_hitl`` variants together)."""
|
||||||
|
|
||||||
Args:
|
|
||||||
search_space_id: If provided, only invalidate for this search space.
|
|
||||||
If None, invalidate all cached MCP tools.
|
|
||||||
"""
|
|
||||||
if search_space_id is not None:
|
if search_space_id is not None:
|
||||||
_mcp_tools_cache.pop(search_space_id, None)
|
for key in [k for k in _mcp_tools_cache if k[0] == search_space_id]:
|
||||||
|
_mcp_tools_cache.pop(key, None)
|
||||||
else:
|
else:
|
||||||
_mcp_tools_cache.clear()
|
_mcp_tools_cache.clear()
|
||||||
|
|
||||||
|
|
@ -920,27 +939,29 @@ def invalidate_mcp_tools_cache(search_space_id: int | None = None) -> None:
|
||||||
async def load_mcp_tools(
|
async def load_mcp_tools(
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
|
*,
|
||||||
|
bypass_internal_hitl: bool = False,
|
||||||
) -> list[StructuredTool]:
|
) -> list[StructuredTool]:
|
||||||
"""Load all MCP tools from user's active MCP server connectors.
|
"""Load all MCP tools from the user's active MCP server connectors.
|
||||||
|
|
||||||
This discovers tools dynamically from MCP servers using the protocol.
|
Results are cached per ``(search_space_id, bypass_internal_hitl)`` for up
|
||||||
Supports both stdio (local process) and HTTP (remote server) transports.
|
to 5 minutes; bypass is keyed because each variant builds a different tool
|
||||||
|
closure (with vs. without the in-wrapper ``request_approval`` gate).
|
||||||
Results are cached per search space for up to 5 minutes to avoid
|
|
||||||
re-spawning MCP server processes on every chat message.
|
|
||||||
"""
|
"""
|
||||||
_evict_expired_mcp_cache()
|
_evict_expired_mcp_cache()
|
||||||
|
|
||||||
now = time.monotonic()
|
now = time.monotonic()
|
||||||
cached = _mcp_tools_cache.get(search_space_id)
|
cache_key: _MCPCacheKey = (search_space_id, bypass_internal_hitl)
|
||||||
|
cached = _mcp_tools_cache.get(cache_key)
|
||||||
if cached is not None:
|
if cached is not None:
|
||||||
cached_at, cached_tools = cached
|
cached_at, cached_tools = cached
|
||||||
if now - cached_at < _MCP_CACHE_TTL_SECONDS:
|
if now - cached_at < _MCP_CACHE_TTL_SECONDS:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Using cached MCP tools for search space %s (%d tools, age=%.0fs)",
|
"Using cached MCP tools for search space %s (%d tools, age=%.0fs, bypass_hitl=%s)",
|
||||||
search_space_id,
|
search_space_id,
|
||||||
len(cached_tools),
|
len(cached_tools),
|
||||||
now - cached_at,
|
now - cached_at,
|
||||||
|
bypass_internal_hitl,
|
||||||
)
|
)
|
||||||
return list(cached_tools)
|
return list(cached_tools)
|
||||||
|
|
||||||
|
|
@ -1064,6 +1085,7 @@ async def load_mcp_tools(
|
||||||
readonly_tools=task["readonly_tools"],
|
readonly_tools=task["readonly_tools"],
|
||||||
tool_name_prefix=task["tool_name_prefix"],
|
tool_name_prefix=task["tool_name_prefix"],
|
||||||
is_generic_mcp=task.get("is_generic_mcp", False),
|
is_generic_mcp=task.get("is_generic_mcp", False),
|
||||||
|
bypass_internal_hitl=bypass_internal_hitl,
|
||||||
),
|
),
|
||||||
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
|
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
|
||||||
)
|
)
|
||||||
|
|
@ -1074,6 +1096,7 @@ async def load_mcp_tools(
|
||||||
task["connector_name"],
|
task["connector_name"],
|
||||||
task["server_config"],
|
task["server_config"],
|
||||||
trusted_tools=task["trusted_tools"],
|
trusted_tools=task["trusted_tools"],
|
||||||
|
bypass_internal_hitl=bypass_internal_hitl,
|
||||||
),
|
),
|
||||||
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
|
timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS,
|
||||||
)
|
)
|
||||||
|
|
@ -1095,14 +1118,17 @@ async def load_mcp_tools(
|
||||||
results = await asyncio.gather(*[_discover_one(t) for t in discovery_tasks])
|
results = await asyncio.gather(*[_discover_one(t) for t in discovery_tasks])
|
||||||
tools: list[StructuredTool] = [tool for sublist in results for tool in sublist]
|
tools: list[StructuredTool] = [tool for sublist in results for tool in sublist]
|
||||||
|
|
||||||
_mcp_tools_cache[search_space_id] = (now, tools)
|
_mcp_tools_cache[cache_key] = (now, tools)
|
||||||
|
|
||||||
if len(_mcp_tools_cache) > _MCP_CACHE_MAX_SIZE:
|
if len(_mcp_tools_cache) > _MCP_CACHE_MAX_SIZE:
|
||||||
oldest_key = min(_mcp_tools_cache, key=lambda k: _mcp_tools_cache[k][0])
|
oldest_key = min(_mcp_tools_cache, key=lambda k: _mcp_tools_cache[k][0])
|
||||||
del _mcp_tools_cache[oldest_key]
|
del _mcp_tools_cache[oldest_key]
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Loaded %d MCP tools for search space %d", len(tools), search_space_id
|
"Loaded %d MCP tools for search space %d (bypass_hitl=%s)",
|
||||||
|
len(tools),
|
||||||
|
search_space_id,
|
||||||
|
bypass_internal_hitl,
|
||||||
)
|
)
|
||||||
return tools
|
return tools
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -130,6 +130,79 @@ def test_registry_propagates_dedup_key_to_tool_metadata() -> None:
|
||||||
assert sample == "plan"
|
assert sample == "plan"
|
||||||
|
|
||||||
|
|
||||||
|
def test_full_args_dedup_keeps_distinct_calls_sharing_a_field() -> None:
|
||||||
|
"""Regression: MCP tools (e.g. ``createJiraIssue``) used to dedup on
|
||||||
|
the schema's first required field, which is often the workspace /
|
||||||
|
cloudId — so 3 distinct issues in the same workspace collapsed to 1.
|
||||||
|
|
||||||
|
With :func:`dedup_key_full_args` only fully identical arg dicts dedup.
|
||||||
|
"""
|
||||||
|
from app.agents.new_chat.middleware.dedup_tool_calls import dedup_key_full_args
|
||||||
|
|
||||||
|
tool = _make_tool("createJiraIssue", dedup_key=dedup_key_full_args)
|
||||||
|
mw = DedupHITLToolCallsMiddleware(agent_tools=[tool])
|
||||||
|
state = {
|
||||||
|
"messages": [
|
||||||
|
_msg(
|
||||||
|
{
|
||||||
|
"name": "createJiraIssue",
|
||||||
|
"args": {
|
||||||
|
"cloudId": "ws.atlassian.net",
|
||||||
|
"projectKey": "PROJ",
|
||||||
|
"summary": "Fix login bug",
|
||||||
|
},
|
||||||
|
"id": "1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "createJiraIssue",
|
||||||
|
"args": {
|
||||||
|
"cloudId": "ws.atlassian.net",
|
||||||
|
"projectKey": "PROJ",
|
||||||
|
"summary": "Add dark mode",
|
||||||
|
},
|
||||||
|
"id": "2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "createJiraIssue",
|
||||||
|
"args": {
|
||||||
|
"cloudId": "ws.atlassian.net",
|
||||||
|
"projectKey": "PROJ",
|
||||||
|
"summary": "Improve perf",
|
||||||
|
},
|
||||||
|
"id": "3",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
]
|
||||||
|
}
|
||||||
|
out = mw.after_model(state, _Runtime())
|
||||||
|
assert out is None # nothing dropped — all three differ in summary
|
||||||
|
|
||||||
|
|
||||||
|
def test_full_args_dedup_drops_only_exact_duplicates() -> None:
|
||||||
|
from app.agents.new_chat.middleware.dedup_tool_calls import dedup_key_full_args
|
||||||
|
|
||||||
|
tool = _make_tool("createJiraIssue", dedup_key=dedup_key_full_args)
|
||||||
|
mw = DedupHITLToolCallsMiddleware(agent_tools=[tool])
|
||||||
|
args = {"cloudId": "ws.atlassian.net", "summary": "Fix bug"}
|
||||||
|
state = {
|
||||||
|
"messages": [
|
||||||
|
_msg(
|
||||||
|
{"name": "createJiraIssue", "args": args, "id": "1"},
|
||||||
|
{"name": "createJiraIssue", "args": dict(args), "id": "2"},
|
||||||
|
{
|
||||||
|
"name": "createJiraIssue",
|
||||||
|
"args": {**args, "summary": "Different"},
|
||||||
|
"id": "3",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
]
|
||||||
|
}
|
||||||
|
out = mw.after_model(state, _Runtime())
|
||||||
|
assert out is not None
|
||||||
|
new_calls = out["messages"][0].tool_calls
|
||||||
|
assert {c["id"] for c in new_calls} == {"1", "3"}
|
||||||
|
|
||||||
|
|
||||||
def test_unknown_tool_passes_through() -> None:
|
def test_unknown_tool_passes_through() -> None:
|
||||||
mw = DedupHITLToolCallsMiddleware(agent_tools=None)
|
mw = DedupHITLToolCallsMiddleware(agent_tools=None)
|
||||||
state = {
|
state = {
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,10 @@ import pytest
|
||||||
from langchain_core.messages import AIMessage, ToolMessage
|
from langchain_core.messages import AIMessage, ToolMessage
|
||||||
|
|
||||||
from app.agents.new_chat.errors import CorrectedError, RejectedError
|
from app.agents.new_chat.errors import CorrectedError, RejectedError
|
||||||
from app.agents.new_chat.middleware.permission import PermissionMiddleware
|
from app.agents.new_chat.middleware.permission import (
|
||||||
|
PermissionMiddleware,
|
||||||
|
_normalize_permission_decision,
|
||||||
|
)
|
||||||
from app.agents.new_chat.permissions import Rule, Ruleset
|
from app.agents.new_chat.permissions import Rule, Ruleset
|
||||||
|
|
||||||
pytestmark = pytest.mark.unit
|
pytestmark = pytest.mark.unit
|
||||||
|
|
@ -112,3 +115,151 @@ class TestAsk:
|
||||||
# Runtime ruleset got the always-allow rule
|
# Runtime ruleset got the always-allow rule
|
||||||
new_rules = [r for r in mw._runtime_ruleset.rules if r.action == "allow"]
|
new_rules = [r for r in mw._runtime_ruleset.rules if r.action == "allow"]
|
||||||
assert any(r.permission == "send_email" for r in new_rules)
|
assert any(r.permission == "send_email" for r in new_rules)
|
||||||
|
|
||||||
|
|
||||||
|
class TestNormalizeDecision:
|
||||||
|
"""Resume shapes ``_normalize_permission_decision`` must accept."""
|
||||||
|
|
||||||
|
def test_legacy_decision_type_dict_passes_through(self) -> None:
|
||||||
|
decision = {"decision_type": "once"}
|
||||||
|
assert _normalize_permission_decision(decision) == {"decision_type": "once"}
|
||||||
|
|
||||||
|
def test_legacy_decision_type_with_feedback_passes_through(self) -> None:
|
||||||
|
decision = {"decision_type": "reject", "feedback": "no thanks"}
|
||||||
|
assert _normalize_permission_decision(decision) == decision
|
||||||
|
|
||||||
|
def test_plain_string_wrapped(self) -> None:
|
||||||
|
assert _normalize_permission_decision("once") == {"decision_type": "once"}
|
||||||
|
assert _normalize_permission_decision("reject") == {"decision_type": "reject"}
|
||||||
|
|
||||||
|
def test_lc_envelope_approve_maps_to_once(self) -> None:
|
||||||
|
decision = {"decisions": [{"type": "approve"}]}
|
||||||
|
assert _normalize_permission_decision(decision) == {"decision_type": "once"}
|
||||||
|
|
||||||
|
def test_lc_envelope_reject_maps_to_reject(self) -> None:
|
||||||
|
decision = {"decisions": [{"type": "reject"}]}
|
||||||
|
assert _normalize_permission_decision(decision) == {"decision_type": "reject"}
|
||||||
|
|
||||||
|
def test_lc_envelope_reject_with_message_carries_feedback(self) -> None:
|
||||||
|
decision = {
|
||||||
|
"decisions": [{"type": "reject", "message": "wrong recipient"}]
|
||||||
|
}
|
||||||
|
out = _normalize_permission_decision(decision)
|
||||||
|
assert out == {"decision_type": "reject", "feedback": "wrong recipient"}
|
||||||
|
|
||||||
|
def test_lc_envelope_reject_with_feedback_field(self) -> None:
|
||||||
|
decision = {
|
||||||
|
"decisions": [{"type": "reject", "feedback": "tighten the subject"}]
|
||||||
|
}
|
||||||
|
out = _normalize_permission_decision(decision)
|
||||||
|
assert out == {"decision_type": "reject", "feedback": "tighten the subject"}
|
||||||
|
|
||||||
|
def test_lc_envelope_edit_maps_to_once(self) -> None:
|
||||||
|
# Pins the contract: edited args are NOT merged by permission.
|
||||||
|
decision = {
|
||||||
|
"decisions": [
|
||||||
|
{
|
||||||
|
"type": "edit",
|
||||||
|
"edited_action": {
|
||||||
|
"name": "send_email",
|
||||||
|
"args": {"subject": "edited"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
assert _normalize_permission_decision(decision) == {"decision_type": "once"}
|
||||||
|
|
||||||
|
def test_lc_single_decision_without_envelope(self) -> None:
|
||||||
|
assert _normalize_permission_decision({"type": "approve"}) == {
|
||||||
|
"decision_type": "once"
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_unknown_type_falls_back_to_reject(self) -> None:
|
||||||
|
decision = {"decisions": [{"type": "totally_unknown"}]}
|
||||||
|
assert _normalize_permission_decision(decision) == {"decision_type": "reject"}
|
||||||
|
|
||||||
|
def test_missing_type_falls_back_to_reject(self) -> None:
|
||||||
|
assert _normalize_permission_decision({"decisions": [{}]}) == {
|
||||||
|
"decision_type": "reject"
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_non_dict_non_string_falls_back_to_reject(self) -> None:
|
||||||
|
assert _normalize_permission_decision(None) == {"decision_type": "reject"}
|
||||||
|
assert _normalize_permission_decision(42) == {"decision_type": "reject"}
|
||||||
|
|
||||||
|
def test_empty_decisions_list_falls_back_to_reject(self) -> None:
|
||||||
|
# Fail-closed on a malformed reply rather than treat it as approve.
|
||||||
|
assert _normalize_permission_decision({"decisions": []}) == {
|
||||||
|
"decision_type": "reject"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TestResumeShapesEndToEnd:
|
||||||
|
"""LangChain HITL envelope reaches ``_process`` correctly via ``_raise_interrupt``."""
|
||||||
|
|
||||||
|
def test_lc_approve_envelope_keeps_call(self) -> None:
|
||||||
|
mw = PermissionMiddleware(rulesets=[])
|
||||||
|
mw._raise_interrupt = lambda **kw: { # type: ignore[assignment]
|
||||||
|
"decisions": [{"type": "approve"}]
|
||||||
|
}
|
||||||
|
state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]}
|
||||||
|
original = mw._raise_interrupt
|
||||||
|
mw._raise_interrupt = lambda **kw: _normalize_permission_decision( # type: ignore[assignment]
|
||||||
|
original(**kw)
|
||||||
|
)
|
||||||
|
out = mw.after_model(state, _FakeRuntime())
|
||||||
|
assert out is None
|
||||||
|
|
||||||
|
def test_lc_reject_envelope_raises(self) -> None:
|
||||||
|
mw = PermissionMiddleware(rulesets=[])
|
||||||
|
original = lambda **kw: {"decisions": [{"type": "reject"}]} # noqa: E731
|
||||||
|
mw._raise_interrupt = lambda **kw: _normalize_permission_decision( # type: ignore[assignment]
|
||||||
|
original(**kw)
|
||||||
|
)
|
||||||
|
state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]}
|
||||||
|
with pytest.raises(RejectedError):
|
||||||
|
mw.after_model(state, _FakeRuntime())
|
||||||
|
|
||||||
|
def test_lc_reject_with_message_raises_corrected(self) -> None:
|
||||||
|
mw = PermissionMiddleware(rulesets=[])
|
||||||
|
original = lambda **kw: { # noqa: E731
|
||||||
|
"decisions": [{"type": "reject", "message": "wrong recipient"}]
|
||||||
|
}
|
||||||
|
mw._raise_interrupt = lambda **kw: _normalize_permission_decision( # type: ignore[assignment]
|
||||||
|
original(**kw)
|
||||||
|
)
|
||||||
|
state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]}
|
||||||
|
with pytest.raises(CorrectedError) as excinfo:
|
||||||
|
mw.after_model(state, _FakeRuntime())
|
||||||
|
assert excinfo.value.feedback == "wrong recipient"
|
||||||
|
|
||||||
|
def test_lc_edit_envelope_keeps_call_with_original_args(self) -> None:
|
||||||
|
# Pins the "edit -> once, args unchanged" contract.
|
||||||
|
mw = PermissionMiddleware(rulesets=[])
|
||||||
|
original = lambda **kw: { # noqa: E731
|
||||||
|
"decisions": [
|
||||||
|
{
|
||||||
|
"type": "edit",
|
||||||
|
"edited_action": {
|
||||||
|
"name": "send_email",
|
||||||
|
"args": {"to": "edited@example.com"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
mw._raise_interrupt = lambda **kw: _normalize_permission_decision( # type: ignore[assignment]
|
||||||
|
original(**kw)
|
||||||
|
)
|
||||||
|
state = {
|
||||||
|
"messages": [
|
||||||
|
_msg(
|
||||||
|
{
|
||||||
|
"name": "send_email",
|
||||||
|
"args": {"to": "original@example.com"},
|
||||||
|
"id": "1",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
]
|
||||||
|
}
|
||||||
|
out = mw.after_model(state, _FakeRuntime())
|
||||||
|
assert out is None
|
||||||
|
|
|
||||||
|
|
@ -146,6 +146,31 @@ function markInterruptsCompleted(contentParts: Array<{ type: string; result?: un
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Most recent pending tool-call card with this name, so a new HITL interrupt
|
||||||
|
* does not overwrite an already-approved card with the same tool name.
|
||||||
|
*/
|
||||||
|
function findHitlTargetToolCallId(
|
||||||
|
toolCallIndices: Map<string, number>,
|
||||||
|
contentParts: Array<{
|
||||||
|
type: string;
|
||||||
|
toolName?: string;
|
||||||
|
result?: unknown;
|
||||||
|
}>,
|
||||||
|
toolName: string
|
||||||
|
): string | null {
|
||||||
|
const entries = Array.from(toolCallIndices.entries());
|
||||||
|
for (let i = entries.length - 1; i >= 0; i--) {
|
||||||
|
const [tcId, idx] = entries[i];
|
||||||
|
const part = contentParts[idx];
|
||||||
|
if (!part || part.type !== "tool-call" || part.toolName !== toolName) continue;
|
||||||
|
const result = part.result as Record<string, unknown> | undefined | null;
|
||||||
|
if (result == null) return tcId;
|
||||||
|
if (result.__interrupt__ === true && !result.__decided__) return tcId;
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Zod schema for mentioned document info (for type-safe parsing)
|
* Zod schema for mentioned document info (for type-safe parsing)
|
||||||
*/
|
*/
|
||||||
|
|
@ -949,12 +974,13 @@ export default function NewChatPage() {
|
||||||
args: Record<string, unknown>;
|
args: Record<string, unknown>;
|
||||||
}>;
|
}>;
|
||||||
for (const action of actionRequests) {
|
for (const action of actionRequests) {
|
||||||
const existingIdx = Array.from(toolCallIndices.entries()).find(([, idx]) => {
|
const targetTcId = findHitlTargetToolCallId(
|
||||||
const part = contentParts[idx];
|
toolCallIndices,
|
||||||
return part?.type === "tool-call" && part.toolName === action.name;
|
contentParts,
|
||||||
});
|
action.name
|
||||||
if (existingIdx) {
|
);
|
||||||
updateToolCall(contentPartsState, existingIdx[0], {
|
if (targetTcId) {
|
||||||
|
updateToolCall(contentPartsState, targetTcId, {
|
||||||
result: { __interrupt__: true, ...interruptData },
|
result: { __interrupt__: true, ...interruptData },
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -1265,6 +1291,7 @@ export default function NewChatPage() {
|
||||||
body: JSON.stringify({
|
body: JSON.stringify({
|
||||||
search_space_id: searchSpaceId,
|
search_space_id: searchSpaceId,
|
||||||
decisions,
|
decisions,
|
||||||
|
disabled_tools: disabledTools.length > 0 ? disabledTools : undefined,
|
||||||
filesystem_mode: selection.filesystem_mode,
|
filesystem_mode: selection.filesystem_mode,
|
||||||
client_platform: selection.client_platform,
|
client_platform: selection.client_platform,
|
||||||
local_filesystem_mounts: selection.local_filesystem_mounts,
|
local_filesystem_mounts: selection.local_filesystem_mounts,
|
||||||
|
|
@ -1388,12 +1415,13 @@ export default function NewChatPage() {
|
||||||
args: Record<string, unknown>;
|
args: Record<string, unknown>;
|
||||||
}>;
|
}>;
|
||||||
for (const action of actionRequests) {
|
for (const action of actionRequests) {
|
||||||
const existingIdx = Array.from(toolCallIndices.entries()).find(([, idx]) => {
|
const targetTcId = findHitlTargetToolCallId(
|
||||||
const part = contentParts[idx];
|
toolCallIndices,
|
||||||
return part?.type === "tool-call" && part.toolName === action.name;
|
contentParts,
|
||||||
});
|
action.name
|
||||||
if (existingIdx) {
|
);
|
||||||
updateToolCall(contentPartsState, existingIdx[0], {
|
if (targetTcId) {
|
||||||
|
updateToolCall(contentPartsState, targetTcId, {
|
||||||
result: {
|
result: {
|
||||||
__interrupt__: true,
|
__interrupt__: true,
|
||||||
...interruptData,
|
...interruptData,
|
||||||
|
|
@ -1514,6 +1542,25 @@ export default function NewChatPage() {
|
||||||
const decision = detail.decisions[0];
|
const decision = detail.decisions[0];
|
||||||
const decisionType = decision?.type as "approve" | "reject" | "edit";
|
const decisionType = decision?.type as "approve" | "reject" | "edit";
|
||||||
|
|
||||||
|
// Fan a single click out to N decisions when the backend bundled
|
||||||
|
// N tool calls into one HITLRequest (one Approve/Reject covers
|
||||||
|
// the whole batch until per-card decisions land).
|
||||||
|
const interruptData = pendingInterrupt.interruptData as
|
||||||
|
| { action_requests?: unknown[] }
|
||||||
|
| undefined;
|
||||||
|
const expectedCount = Array.isArray(interruptData?.action_requests)
|
||||||
|
? interruptData.action_requests.length
|
||||||
|
: detail.decisions.length;
|
||||||
|
const submittedDecisions =
|
||||||
|
detail.decisions.length >= expectedCount || expectedCount <= 1
|
||||||
|
? detail.decisions
|
||||||
|
: [
|
||||||
|
...detail.decisions,
|
||||||
|
...Array.from({ length: expectedCount - detail.decisions.length }, () => ({
|
||||||
|
...detail.decisions[detail.decisions.length - 1],
|
||||||
|
})),
|
||||||
|
];
|
||||||
|
|
||||||
setMessages((prev) =>
|
setMessages((prev) =>
|
||||||
prev.map((m) => {
|
prev.map((m) => {
|
||||||
if (m.id !== pendingInterrupt.assistantMsgId) return m;
|
if (m.id !== pendingInterrupt.assistantMsgId) return m;
|
||||||
|
|
@ -1554,7 +1601,7 @@ export default function NewChatPage() {
|
||||||
return { ...m, content: newContent as unknown as ThreadMessageLike["content"] };
|
return { ...m, content: newContent as unknown as ThreadMessageLike["content"] };
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
handleResume(detail.decisions);
|
handleResume(submittedDecisions);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
window.addEventListener("hitl-decision", handler);
|
window.addEventListener("hitl-decision", handler);
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue