mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-17 18:35:19 +02:00
multi_agent_chat/permissions: surface MCP tool metadata into ask interrupts
The FE permission card needs mcp_connector_id, mcp_server, and tool_description in the interrupt context to render "Always Allow" against the right connected account. Thread the tool through the ask pipeline: - pack_subagent → build_permission_mw(tools=...) → PermissionMiddleware (tools_by_name) → request_permission_decision(tool=...) → build_permission_ask_payload(tool=...) projects card fields out of BaseTool. - mcp_tool.py: stdio path now stashes mcp_connector_id in metadata for parity with the HTTP path.
This commit is contained in:
parent
ef1152b80e
commit
a97d1548a6
7 changed files with 236 additions and 31 deletions
|
|
@ -1,20 +1,11 @@
|
|||
"""Build the permission-ask interrupt payload (LC HITL wire + SurfSense context).
|
||||
|
||||
The FE's PermissionCard renders from:
|
||||
|
||||
- Standard langchain fields (``action_requests``, ``review_configs``) — drive
|
||||
the action chrome and the parallel-HITL routing layer (``task_tool``,
|
||||
``resume_routing``) that batches concurrent approvals.
|
||||
- ``interrupt_type="permission_ask"`` — selects the permission card variant.
|
||||
- ``context.patterns`` / ``context.rules`` — explain *why* the ask fired.
|
||||
- ``context.always`` — the patterns the user can promote to a permanent
|
||||
allow rule with a single ``"always"`` reply.
|
||||
"""
|
||||
"""Build the permission-ask interrupt payload (LC HITL wire + SurfSense context)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.agents.multi_agent_chat.subagents.shared.hitl.wire import (
|
||||
LC_DECISION_APPROVE,
|
||||
LC_DECISION_EDIT,
|
||||
|
|
@ -26,8 +17,6 @@ from app.agents.new_chat.permissions import Rule
|
|||
|
||||
PERMISSION_ASK_INTERRUPT_TYPE = "permission_ask"
|
||||
|
||||
# The full palette a permission card may surface: approve once, edit-then-
|
||||
# approve, reject, or "always" to promote the matched pattern.
|
||||
_PERMISSION_ASK_DECISIONS: list[str] = [
|
||||
LC_DECISION_APPROVE,
|
||||
LC_DECISION_REJECT,
|
||||
|
|
@ -36,36 +25,45 @@ _PERMISSION_ASK_DECISIONS: list[str] = [
|
|||
]
|
||||
|
||||
|
||||
def _card_fields_from_tool(tool: BaseTool | None) -> dict[str, Any]:
|
||||
"""Project the FE card's tool-scoped fields out of a BaseTool."""
|
||||
if tool is None:
|
||||
return {}
|
||||
metadata = getattr(tool, "metadata", None) or {}
|
||||
fields: dict[str, Any] = {}
|
||||
connector_id = metadata.get("mcp_connector_id")
|
||||
if connector_id is not None:
|
||||
fields["mcp_connector_id"] = connector_id
|
||||
connector_name = metadata.get("mcp_connector_name")
|
||||
if connector_name:
|
||||
fields["mcp_server"] = connector_name
|
||||
if tool.description:
|
||||
fields["tool_description"] = tool.description
|
||||
return fields
|
||||
|
||||
|
||||
def build_permission_ask_payload(
|
||||
*,
|
||||
tool_name: str,
|
||||
args: dict[str, Any],
|
||||
patterns: list[str],
|
||||
rules: list[Rule],
|
||||
tool: BaseTool | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Build the permission-ask interrupt payload.
|
||||
|
||||
Args:
|
||||
tool_name: The tool whose call is being reviewed.
|
||||
args: The tool call arguments shown in the card.
|
||||
patterns: Wildcard patterns the call matched (drives ``always``).
|
||||
rules: Matched ruleset entries surfaced for explainability.
|
||||
|
||||
Returns:
|
||||
A dict suitable for ``langgraph.types.interrupt(...)`` carrying both
|
||||
the LC HITL standard fields and SurfSense-specific context.
|
||||
``tool`` carries the FE card's tool-scoped fields (description, MCP
|
||||
connector). When omitted the card still renders, just without the
|
||||
"Always Allow against this connected account" surface.
|
||||
"""
|
||||
context: dict[str, Any] = {
|
||||
"patterns": patterns,
|
||||
"rules": [
|
||||
{
|
||||
"permission": r.permission,
|
||||
"pattern": r.pattern,
|
||||
"action": r.action,
|
||||
}
|
||||
{"permission": r.permission, "pattern": r.pattern, "action": r.action}
|
||||
for r in rules
|
||||
],
|
||||
"always": patterns,
|
||||
**_card_fields_from_tool(tool),
|
||||
}
|
||||
return build_lc_hitl_payload(
|
||||
tool_name=tool_name,
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from __future__ import annotations
|
|||
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
from langgraph.types import interrupt
|
||||
|
||||
from app.agents.new_chat.permissions import Rule
|
||||
|
|
@ -29,13 +30,18 @@ def request_permission_decision(
|
|||
patterns: list[str],
|
||||
rules: list[Rule],
|
||||
emit_interrupt: bool,
|
||||
tool: BaseTool | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Pause for an ``ask`` decision; return the canonical permission decision dict."""
|
||||
if not emit_interrupt:
|
||||
return {"decision_type": "reject"}
|
||||
|
||||
payload = build_permission_ask_payload(
|
||||
tool_name=tool_name, args=args, patterns=patterns, rules=rules
|
||||
tool_name=tool_name,
|
||||
args=args,
|
||||
patterns=patterns,
|
||||
rules=rules,
|
||||
tool=tool,
|
||||
)
|
||||
|
||||
with (
|
||||
|
|
|
|||
|
|
@ -42,6 +42,7 @@ from langchain.agents.middleware.types import (
|
|||
ContextT,
|
||||
)
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
from langchain_core.tools import BaseTool
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from app.agents.new_chat.errors import CorrectedError, RejectedError
|
||||
|
|
@ -72,6 +73,9 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
same agent instance so newly-allowed rules apply downstream.
|
||||
always_emit_interrupt_payload: Set ``False`` to make ``ask``
|
||||
collapse to ``deny`` (for non-interactive deployments).
|
||||
tools_by_name: Map from tool name to :class:`BaseTool`, used to
|
||||
decorate ``ask`` interrupts with the tool's description and
|
||||
MCP metadata for the FE card.
|
||||
"""
|
||||
|
||||
tools = ()
|
||||
|
|
@ -83,6 +87,7 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
pattern_resolvers: dict[str, PatternResolver] | None = None,
|
||||
runtime_ruleset: Ruleset | None = None,
|
||||
always_emit_interrupt_payload: bool = True,
|
||||
tools_by_name: dict[str, BaseTool] | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._static_rulesets: list[Ruleset] = list(rulesets or [])
|
||||
|
|
@ -93,6 +98,7 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
origin="runtime_approved"
|
||||
)
|
||||
self._emit_interrupt = always_emit_interrupt_payload
|
||||
self._tools_by_name: dict[str, BaseTool] = dict(tools_by_name or {})
|
||||
|
||||
def _process(
|
||||
self,
|
||||
|
|
@ -142,6 +148,7 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
patterns=patterns,
|
||||
rules=rules,
|
||||
emit_interrupt=self._emit_interrupt,
|
||||
tool=self._tools_by_name.get(name),
|
||||
)
|
||||
kind = str(decision.get("decision_type") or "reject").lower()
|
||||
edited_args = decision.get("edited_args")
|
||||
|
|
|
|||
|
|
@ -23,6 +23,10 @@ redundant here.
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.permissions import Rule, Ruleset
|
||||
|
||||
|
|
@ -38,6 +42,7 @@ def build_permission_mw(
|
|||
*,
|
||||
flags: AgentFeatureFlags,
|
||||
subagent_rulesets: list[Ruleset] | None = None,
|
||||
tools: Sequence[BaseTool] | None = None,
|
||||
) -> PermissionMiddleware | None:
|
||||
"""Return a configured :class:`PermissionMiddleware` or ``None`` when no work is needed.
|
||||
|
||||
|
|
@ -51,6 +56,8 @@ def build_permission_mw(
|
|||
aliasing a shared engine. Presence of any subagent ruleset
|
||||
forces the middleware on regardless of ``enable_permission`` —
|
||||
an explicit ``ask`` rule always asks.
|
||||
tools: Subagent tools used to decorate ``ask`` interrupts with
|
||||
FE-card metadata (description, MCP connector). Optional.
|
||||
|
||||
Returns:
|
||||
``None`` when the engine has no rules to enforce
|
||||
|
|
@ -65,7 +72,8 @@ def build_permission_mw(
|
|||
rulesets: list[Ruleset] = [_SURFSENSE_DEFAULTS]
|
||||
if subagent_rulesets:
|
||||
rulesets.extend(subagent_rulesets)
|
||||
return PermissionMiddleware(rulesets=rulesets)
|
||||
tools_by_name = {t.name: t for t in (tools or [])}
|
||||
return PermissionMiddleware(rulesets=rulesets, tools_by_name=tools_by_name)
|
||||
|
||||
|
||||
__all__ = ["build_permission_mw"]
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ def pack_subagent(
|
|||
if user_allowlist is not None:
|
||||
subagent_rulesets.append(user_allowlist)
|
||||
per_subagent_perm = build_permission_mw(
|
||||
flags=flags, subagent_rulesets=subagent_rulesets
|
||||
flags=flags, subagent_rulesets=subagent_rulesets, tools=tools
|
||||
)
|
||||
|
||||
prepended: list[Any] = []
|
||||
|
|
|
|||
|
|
@ -229,6 +229,7 @@ async def _create_mcp_tool_from_definition_stdio(
|
|||
"mcp_input_schema": input_schema,
|
||||
"mcp_transport": "stdio",
|
||||
"mcp_connector_name": connector_name or None,
|
||||
"mcp_connector_id": connector_id,
|
||||
"mcp_is_generic": True,
|
||||
"hitl": True,
|
||||
# Full-args hash: shared identifiers (cloudId, workspaceId, …)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,185 @@
|
|||
"""Permission-ask payload surfaces tool metadata for the FE card."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated, Any
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.tools import StructuredTool
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.graph.message import add_messages
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from app.agents.multi_agent_chat.middleware.shared.permissions import (
|
||||
build_permission_mw,
|
||||
)
|
||||
from app.agents.multi_agent_chat.middleware.shared.permissions.ask.payload import (
|
||||
build_permission_ask_payload,
|
||||
)
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.permissions import Rule, Ruleset
|
||||
|
||||
|
||||
class _NoArgs(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
async def _noop(**_kwargs) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
def _ask_rule(tool_name: str) -> Rule:
|
||||
return Rule(permission=tool_name, pattern="*", action="ask")
|
||||
|
||||
|
||||
def _make_mcp_tool(*, name: str, connector_id: int, connector_name: str):
|
||||
return StructuredTool(
|
||||
name=name,
|
||||
description=f"Run {name} via MCP.",
|
||||
coroutine=_noop,
|
||||
args_schema=_NoArgs,
|
||||
metadata={
|
||||
"mcp_connector_id": connector_id,
|
||||
"mcp_connector_name": connector_name,
|
||||
"mcp_transport": "http",
|
||||
"hitl": True,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def test_payload_surfaces_mcp_fields_from_tool():
|
||||
tool = _make_mcp_tool(
|
||||
name="linear_create_issue", connector_id=42, connector_name="Linear (acme)"
|
||||
)
|
||||
payload = build_permission_ask_payload(
|
||||
tool_name=tool.name,
|
||||
args={"title": "bug"},
|
||||
patterns=[tool.name],
|
||||
rules=[_ask_rule(tool.name)],
|
||||
tool=tool,
|
||||
)
|
||||
ctx = payload["context"]
|
||||
assert ctx["mcp_connector_id"] == 42
|
||||
assert ctx["mcp_server"] == "Linear (acme)"
|
||||
assert ctx["tool_description"] == "Run linear_create_issue via MCP."
|
||||
|
||||
|
||||
def test_payload_omits_tool_fields_when_tool_is_none():
|
||||
payload = build_permission_ask_payload(
|
||||
tool_name="rm",
|
||||
args={"path": "/tmp/x"},
|
||||
patterns=["rm"],
|
||||
rules=[_ask_rule("rm")],
|
||||
tool=None,
|
||||
)
|
||||
ctx = payload["context"]
|
||||
assert "mcp_connector_id" not in ctx
|
||||
assert "mcp_server" not in ctx
|
||||
assert "tool_description" not in ctx
|
||||
|
||||
|
||||
def test_payload_omits_falsy_mcp_metadata_fields():
|
||||
tool = StructuredTool(
|
||||
name="anon_tool",
|
||||
description="",
|
||||
coroutine=_noop,
|
||||
args_schema=_NoArgs,
|
||||
metadata={"mcp_connector_id": None, "mcp_connector_name": ""},
|
||||
)
|
||||
ctx = build_permission_ask_payload(
|
||||
tool_name=tool.name,
|
||||
args={},
|
||||
patterns=[tool.name],
|
||||
rules=[_ask_rule(tool.name)],
|
||||
tool=tool,
|
||||
)["context"]
|
||||
assert "mcp_connector_id" not in ctx
|
||||
assert "mcp_server" not in ctx
|
||||
assert "tool_description" not in ctx
|
||||
|
||||
|
||||
class _State(TypedDict, total=False):
|
||||
messages: Annotated[list, add_messages]
|
||||
|
||||
|
||||
def _emit_tool_call(tool_name: str, args: dict[str, Any], call_id: str):
|
||||
def _node(_state: _State) -> dict[str, Any]:
|
||||
return {
|
||||
"messages": [
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"name": tool_name,
|
||||
"args": args,
|
||||
"id": call_id,
|
||||
"type": "tool_call",
|
||||
}
|
||||
],
|
||||
)
|
||||
]
|
||||
}
|
||||
|
||||
return _node
|
||||
|
||||
|
||||
def _compile_graph_with(pm, tool_name: str, args: dict[str, Any], call_id: str):
|
||||
def after(state: _State) -> dict[str, Any] | None:
|
||||
return pm._process(state, None) # type: ignore[arg-type]
|
||||
|
||||
g = StateGraph(_State)
|
||||
g.add_node("emit", _emit_tool_call(tool_name, args, call_id))
|
||||
g.add_node("permission", after)
|
||||
g.add_edge(START, "emit")
|
||||
g.add_edge("emit", "permission")
|
||||
g.add_edge("permission", END)
|
||||
return g.compile(checkpointer=InMemorySaver())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middleware_decorates_interrupt_with_mcp_tool_metadata():
|
||||
tool = _make_mcp_tool(
|
||||
name="linear_create_issue", connector_id=7, connector_name="Linear"
|
||||
)
|
||||
pm = build_permission_mw(
|
||||
flags=AgentFeatureFlags(enable_permission=False),
|
||||
subagent_rulesets=[
|
||||
Ruleset(origin="linear", rules=[_ask_rule(tool.name)]),
|
||||
],
|
||||
tools=[tool],
|
||||
)
|
||||
assert pm is not None
|
||||
|
||||
graph = _compile_graph_with(pm, tool.name, {"title": "bug"}, "call-1")
|
||||
config = {"configurable": {"thread_id": "linear-ask"}}
|
||||
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||
|
||||
snap = await graph.aget_state(config)
|
||||
assert len(snap.interrupts) == 1
|
||||
ctx = snap.interrupts[0].value["context"]
|
||||
assert ctx["mcp_connector_id"] == 7
|
||||
assert ctx["mcp_server"] == "Linear"
|
||||
assert ctx["tool_description"] == "Run linear_create_issue via MCP."
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middleware_without_tool_index_still_asks_without_tool_fields():
|
||||
pm = build_permission_mw(
|
||||
flags=AgentFeatureFlags(enable_permission=False),
|
||||
subagent_rulesets=[Ruleset(origin="kb", rules=[_ask_rule("rm")])],
|
||||
)
|
||||
assert pm is not None
|
||||
|
||||
graph = _compile_graph_with(pm, "rm", {"path": "/tmp/foo"}, "call-rm")
|
||||
config = {"configurable": {"thread_id": "kb-rm"}}
|
||||
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||
|
||||
snap = await graph.aget_state(config)
|
||||
assert len(snap.interrupts) == 1
|
||||
ctx = snap.interrupts[0].value["context"]
|
||||
assert "mcp_connector_id" not in ctx
|
||||
assert "mcp_server" not in ctx
|
||||
assert "tool_description" not in ctx
|
||||
Loading…
Add table
Add a link
Reference in a new issue