diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/ask/payload.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/ask/payload.py index 270a3888d..21438813e 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/ask/payload.py +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/ask/payload.py @@ -1,20 +1,11 @@ -"""Build the permission-ask interrupt payload (LC HITL wire + SurfSense context). - -The FE's PermissionCard renders from: - -- Standard langchain fields (``action_requests``, ``review_configs``) — drive - the action chrome and the parallel-HITL routing layer (``task_tool``, - ``resume_routing``) that batches concurrent approvals. -- ``interrupt_type="permission_ask"`` — selects the permission card variant. -- ``context.patterns`` / ``context.rules`` — explain *why* the ask fired. -- ``context.always`` — the patterns the user can promote to a permanent - allow rule with a single ``"always"`` reply. -""" +"""Build the permission-ask interrupt payload (LC HITL wire + SurfSense context).""" from __future__ import annotations from typing import Any +from langchain_core.tools import BaseTool + from app.agents.multi_agent_chat.subagents.shared.hitl.wire import ( LC_DECISION_APPROVE, LC_DECISION_EDIT, @@ -26,8 +17,6 @@ from app.agents.new_chat.permissions import Rule PERMISSION_ASK_INTERRUPT_TYPE = "permission_ask" -# The full palette a permission card may surface: approve once, edit-then- -# approve, reject, or "always" to promote the matched pattern. _PERMISSION_ASK_DECISIONS: list[str] = [ LC_DECISION_APPROVE, LC_DECISION_REJECT, @@ -36,36 +25,45 @@ _PERMISSION_ASK_DECISIONS: list[str] = [ ] +def _card_fields_from_tool(tool: BaseTool | None) -> dict[str, Any]: + """Project the FE card's tool-scoped fields out of a BaseTool.""" + if tool is None: + return {} + metadata = getattr(tool, "metadata", None) or {} + fields: dict[str, Any] = {} + connector_id = metadata.get("mcp_connector_id") + if connector_id is not None: + fields["mcp_connector_id"] = connector_id + connector_name = metadata.get("mcp_connector_name") + if connector_name: + fields["mcp_server"] = connector_name + if tool.description: + fields["tool_description"] = tool.description + return fields + + def build_permission_ask_payload( *, tool_name: str, args: dict[str, Any], patterns: list[str], rules: list[Rule], + tool: BaseTool | None = None, ) -> dict[str, Any]: """Build the permission-ask interrupt payload. - Args: - tool_name: The tool whose call is being reviewed. - args: The tool call arguments shown in the card. - patterns: Wildcard patterns the call matched (drives ``always``). - rules: Matched ruleset entries surfaced for explainability. - - Returns: - A dict suitable for ``langgraph.types.interrupt(...)`` carrying both - the LC HITL standard fields and SurfSense-specific context. + ``tool`` carries the FE card's tool-scoped fields (description, MCP + connector). When omitted the card still renders, just without the + "Always Allow against this connected account" surface. """ context: dict[str, Any] = { "patterns": patterns, "rules": [ - { - "permission": r.permission, - "pattern": r.pattern, - "action": r.action, - } + {"permission": r.permission, "pattern": r.pattern, "action": r.action} for r in rules ], "always": patterns, + **_card_fields_from_tool(tool), } return build_lc_hitl_payload( tool_name=tool_name, diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/ask/request.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/ask/request.py index 42e47ef98..d61d38f34 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/ask/request.py +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/ask/request.py @@ -13,6 +13,7 @@ from __future__ import annotations from typing import Any +from langchain_core.tools import BaseTool from langgraph.types import interrupt from app.agents.new_chat.permissions import Rule @@ -29,13 +30,18 @@ def request_permission_decision( patterns: list[str], rules: list[Rule], emit_interrupt: bool, + tool: BaseTool | None = None, ) -> dict[str, Any]: """Pause for an ``ask`` decision; return the canonical permission decision dict.""" if not emit_interrupt: return {"decision_type": "reject"} payload = build_permission_ask_payload( - tool_name=tool_name, args=args, patterns=patterns, rules=rules + tool_name=tool_name, + args=args, + patterns=patterns, + rules=rules, + tool=tool, ) with ( diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/core.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/core.py index a8bb24143..a96fca7dd 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/core.py +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/core.py @@ -42,6 +42,7 @@ from langchain.agents.middleware.types import ( ContextT, ) from langchain_core.messages import AIMessage, ToolMessage +from langchain_core.tools import BaseTool from langgraph.runtime import Runtime from app.agents.new_chat.errors import CorrectedError, RejectedError @@ -72,6 +73,9 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg] same agent instance so newly-allowed rules apply downstream. always_emit_interrupt_payload: Set ``False`` to make ``ask`` collapse to ``deny`` (for non-interactive deployments). + tools_by_name: Map from tool name to :class:`BaseTool`, used to + decorate ``ask`` interrupts with the tool's description and + MCP metadata for the FE card. """ tools = () @@ -83,6 +87,7 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg] pattern_resolvers: dict[str, PatternResolver] | None = None, runtime_ruleset: Ruleset | None = None, always_emit_interrupt_payload: bool = True, + tools_by_name: dict[str, BaseTool] | None = None, ) -> None: super().__init__() self._static_rulesets: list[Ruleset] = list(rulesets or []) @@ -93,6 +98,7 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg] origin="runtime_approved" ) self._emit_interrupt = always_emit_interrupt_payload + self._tools_by_name: dict[str, BaseTool] = dict(tools_by_name or {}) def _process( self, @@ -142,6 +148,7 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg] patterns=patterns, rules=rules, emit_interrupt=self._emit_interrupt, + tool=self._tools_by_name.get(name), ) kind = str(decision.get("decision_type") or "reject").lower() edited_args = decision.get("edited_args") diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/factory.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/factory.py index 14d5d8eb7..9642e2664 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/factory.py +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/factory.py @@ -23,6 +23,10 @@ redundant here. from __future__ import annotations +from collections.abc import Sequence + +from langchain_core.tools import BaseTool + from app.agents.new_chat.feature_flags import AgentFeatureFlags from app.agents.new_chat.permissions import Rule, Ruleset @@ -38,6 +42,7 @@ def build_permission_mw( *, flags: AgentFeatureFlags, subagent_rulesets: list[Ruleset] | None = None, + tools: Sequence[BaseTool] | None = None, ) -> PermissionMiddleware | None: """Return a configured :class:`PermissionMiddleware` or ``None`` when no work is needed. @@ -51,6 +56,8 @@ def build_permission_mw( aliasing a shared engine. Presence of any subagent ruleset forces the middleware on regardless of ``enable_permission`` — an explicit ``ask`` rule always asks. + tools: Subagent tools used to decorate ``ask`` interrupts with + FE-card metadata (description, MCP connector). Optional. Returns: ``None`` when the engine has no rules to enforce @@ -65,7 +72,8 @@ def build_permission_mw( rulesets: list[Ruleset] = [_SURFSENSE_DEFAULTS] if subagent_rulesets: rulesets.extend(subagent_rulesets) - return PermissionMiddleware(rulesets=rulesets) + tools_by_name = {t.name: t for t in (tools or [])} + return PermissionMiddleware(rulesets=rulesets, tools_by_name=tools_by_name) __all__ = ["build_permission_mw"] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/subagent_builder.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/subagent_builder.py index c61691405..3d1fa1504 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/subagent_builder.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/subagent_builder.py @@ -74,7 +74,7 @@ def pack_subagent( if user_allowlist is not None: subagent_rulesets.append(user_allowlist) per_subagent_perm = build_permission_mw( - flags=flags, subagent_rulesets=subagent_rulesets + flags=flags, subagent_rulesets=subagent_rulesets, tools=tools ) prepended: list[Any] = [] diff --git a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py index 92a808a5e..64368a878 100644 --- a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py +++ b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py @@ -229,6 +229,7 @@ async def _create_mcp_tool_from_definition_stdio( "mcp_input_schema": input_schema, "mcp_transport": "stdio", "mcp_connector_name": connector_name or None, + "mcp_connector_id": connector_id, "mcp_is_generic": True, "hitl": True, # Full-args hash: shared identifiers (cloudId, workspaceId, …) diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/shared/permissions/test_permission_ask_mcp_context.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/shared/permissions/test_permission_ask_mcp_context.py new file mode 100644 index 000000000..b6768c530 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/shared/permissions/test_permission_ask_mcp_context.py @@ -0,0 +1,185 @@ +"""Permission-ask payload surfaces tool metadata for the FE card.""" + +from __future__ import annotations + +from typing import Annotated, Any + +import pytest +from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.tools import StructuredTool +from langgraph.checkpoint.memory import InMemorySaver +from langgraph.graph import END, START, StateGraph +from langgraph.graph.message import add_messages +from pydantic import BaseModel +from typing_extensions import TypedDict + +from app.agents.multi_agent_chat.middleware.shared.permissions import ( + build_permission_mw, +) +from app.agents.multi_agent_chat.middleware.shared.permissions.ask.payload import ( + build_permission_ask_payload, +) +from app.agents.new_chat.feature_flags import AgentFeatureFlags +from app.agents.new_chat.permissions import Rule, Ruleset + + +class _NoArgs(BaseModel): + pass + + +async def _noop(**_kwargs) -> str: + return "" + + +def _ask_rule(tool_name: str) -> Rule: + return Rule(permission=tool_name, pattern="*", action="ask") + + +def _make_mcp_tool(*, name: str, connector_id: int, connector_name: str): + return StructuredTool( + name=name, + description=f"Run {name} via MCP.", + coroutine=_noop, + args_schema=_NoArgs, + metadata={ + "mcp_connector_id": connector_id, + "mcp_connector_name": connector_name, + "mcp_transport": "http", + "hitl": True, + }, + ) + + +def test_payload_surfaces_mcp_fields_from_tool(): + tool = _make_mcp_tool( + name="linear_create_issue", connector_id=42, connector_name="Linear (acme)" + ) + payload = build_permission_ask_payload( + tool_name=tool.name, + args={"title": "bug"}, + patterns=[tool.name], + rules=[_ask_rule(tool.name)], + tool=tool, + ) + ctx = payload["context"] + assert ctx["mcp_connector_id"] == 42 + assert ctx["mcp_server"] == "Linear (acme)" + assert ctx["tool_description"] == "Run linear_create_issue via MCP." + + +def test_payload_omits_tool_fields_when_tool_is_none(): + payload = build_permission_ask_payload( + tool_name="rm", + args={"path": "/tmp/x"}, + patterns=["rm"], + rules=[_ask_rule("rm")], + tool=None, + ) + ctx = payload["context"] + assert "mcp_connector_id" not in ctx + assert "mcp_server" not in ctx + assert "tool_description" not in ctx + + +def test_payload_omits_falsy_mcp_metadata_fields(): + tool = StructuredTool( + name="anon_tool", + description="", + coroutine=_noop, + args_schema=_NoArgs, + metadata={"mcp_connector_id": None, "mcp_connector_name": ""}, + ) + ctx = build_permission_ask_payload( + tool_name=tool.name, + args={}, + patterns=[tool.name], + rules=[_ask_rule(tool.name)], + tool=tool, + )["context"] + assert "mcp_connector_id" not in ctx + assert "mcp_server" not in ctx + assert "tool_description" not in ctx + + +class _State(TypedDict, total=False): + messages: Annotated[list, add_messages] + + +def _emit_tool_call(tool_name: str, args: dict[str, Any], call_id: str): + def _node(_state: _State) -> dict[str, Any]: + return { + "messages": [ + AIMessage( + content="", + tool_calls=[ + { + "name": tool_name, + "args": args, + "id": call_id, + "type": "tool_call", + } + ], + ) + ] + } + + return _node + + +def _compile_graph_with(pm, tool_name: str, args: dict[str, Any], call_id: str): + def after(state: _State) -> dict[str, Any] | None: + return pm._process(state, None) # type: ignore[arg-type] + + g = StateGraph(_State) + g.add_node("emit", _emit_tool_call(tool_name, args, call_id)) + g.add_node("permission", after) + g.add_edge(START, "emit") + g.add_edge("emit", "permission") + g.add_edge("permission", END) + return g.compile(checkpointer=InMemorySaver()) + + +@pytest.mark.asyncio +async def test_middleware_decorates_interrupt_with_mcp_tool_metadata(): + tool = _make_mcp_tool( + name="linear_create_issue", connector_id=7, connector_name="Linear" + ) + pm = build_permission_mw( + flags=AgentFeatureFlags(enable_permission=False), + subagent_rulesets=[ + Ruleset(origin="linear", rules=[_ask_rule(tool.name)]), + ], + tools=[tool], + ) + assert pm is not None + + graph = _compile_graph_with(pm, tool.name, {"title": "bug"}, "call-1") + config = {"configurable": {"thread_id": "linear-ask"}} + await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config) + + snap = await graph.aget_state(config) + assert len(snap.interrupts) == 1 + ctx = snap.interrupts[0].value["context"] + assert ctx["mcp_connector_id"] == 7 + assert ctx["mcp_server"] == "Linear" + assert ctx["tool_description"] == "Run linear_create_issue via MCP." + + +@pytest.mark.asyncio +async def test_middleware_without_tool_index_still_asks_without_tool_fields(): + pm = build_permission_mw( + flags=AgentFeatureFlags(enable_permission=False), + subagent_rulesets=[Ruleset(origin="kb", rules=[_ask_rule("rm")])], + ) + assert pm is not None + + graph = _compile_graph_with(pm, "rm", {"path": "/tmp/foo"}, "call-rm") + config = {"configurable": {"thread_id": "kb-rm"}} + await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config) + + snap = await graph.aget_state(config) + assert len(snap.interrupts) == 1 + ctx = snap.interrupts[0].value["context"] + assert "mcp_connector_id" not in ctx + assert "mcp_server" not in ctx + assert "tool_description" not in ctx