diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/factory.py b/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/factory.py index b2b553b83..8451b3b7d 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/factory.py +++ b/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/factory.py @@ -29,7 +29,10 @@ from app.agents.new_chat.tools.invalid_tool import INVALID_TOOL_NAME, invalid_to from app.agents.new_chat.tools.registry import build_tools_async from app.db import ChatVisibility from app.services.connector_service import ConnectorService -from app.services.user_tool_allowlist import fetch_user_allowlist_rulesets +from app.services.user_tool_allowlist import ( + fetch_user_allowlist_rulesets, + make_trusted_tool_saver, +) from app.utils.perf import get_perf_logger from ..system_prompt import build_main_agent_system_prompt @@ -153,28 +156,37 @@ async def create_multi_agent_chat_deep_agent( # ``ask`` via last-match-wins. Anonymous turns and read failures both # degrade to "no user rules" rather than blocking the turn. user_allowlist_by_subagent: dict[str, Any] = {} + trusted_tool_saver = None if user_id: - _t0 = time.perf_counter() try: import uuid as _uuid - user_allowlist_by_subagent = await fetch_user_allowlist_rulesets( - db_session, - user_id=_uuid.UUID(user_id), - search_space_id=search_space_id, + user_uuid = _uuid.UUID(user_id) + except (TypeError, ValueError): + user_uuid = None + + if user_uuid is not None: + _t0 = time.perf_counter() + try: + user_allowlist_by_subagent = await fetch_user_allowlist_rulesets( + db_session, + user_id=user_uuid, + search_space_id=search_space_id, + ) + except Exception as e: + logging.warning( + "User allow-list fetch failed; subagents will run without user trust rules this turn: %s", + e, + ) + user_allowlist_by_subagent = {} + _perf_log.info( + "[create_agent] fetch_user_allowlist_rulesets in %.3fs (%d subagents have rules)", + time.perf_counter() - _t0, + len(user_allowlist_by_subagent), ) - except Exception as e: - logging.warning( - "User allow-list fetch failed; subagents will run without user trust rules this turn: %s", - e, - ) - user_allowlist_by_subagent = {} - _perf_log.info( - "[create_agent] fetch_user_allowlist_rulesets in %.3fs (%d subagents have rules)", - time.perf_counter() - _t0, - len(user_allowlist_by_subagent), - ) + trusted_tool_saver = make_trusted_tool_saver(user_uuid) dependencies["user_allowlist_by_subagent"] = user_allowlist_by_subagent + dependencies["trusted_tool_saver"] = trusted_tool_saver modified_disabled_tools = list(disabled_tools) if disabled_tools else [] diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/core.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/core.py index a96fca7dd..f5c8e040c 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/core.py +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/core.py @@ -9,31 +9,12 @@ This middleware layers OpenCode's wildcard-ruleset model on top of the unified langchain HITL wire format (see :mod:`hitl_wire`), so it sits beside ``HumanInTheLoopMiddleware`` and self-gated approvals on a single parallel-HITL routing layer in ``task_tool`` + ``resume_routing``. - -Per-tool-call flow inside :meth:`_process`: - -1. Skip when the last message has no tool calls. -2. For each call, evaluate the rules. ``deny`` is replaced with a - synthetic :class:`ToolMessage` carrying a typed - :class:`StreamingError`. ``ask`` raises an interrupt via - :mod:`interrupt.request`; the resulting decision is dispatched here: - - - ``once`` → keep the call as-is. - - ``always`` → also extend the runtime ruleset. - - ``reject`` (with feedback) → :class:`CorrectedError`. - - ``reject`` (no feedback) → :class:`RejectedError`. - - ``allow`` keeps the call unchanged. - -3. Returns an updated ``AIMessage`` (tool calls minus the denied ones) - plus any deny ``ToolMessage`` entries appended after it. Tool-list - filtering at ``before_model`` is intentionally not done here — that - would invalidate provider prompt-cache prefixes. """ from __future__ import annotations import logging +from dataclasses import dataclass from typing import Any from langchain.agents.middleware.types import ( @@ -47,6 +28,7 @@ from langgraph.runtime import Runtime from app.agents.new_chat.errors import CorrectedError, RejectedError from app.agents.new_chat.permissions import Ruleset +from app.services.user_tool_allowlist import TrustedToolSaver from ..ask.edit import merge_edited_args from ..ask.request import request_permission_decision @@ -59,6 +41,14 @@ from .runtime_promote import persist_always logger = logging.getLogger(__name__) +@dataclass(frozen=True) +class _AlwaysPromotion: + """A pending request to save an ``always`` decision to the user's trust list.""" + + connector_id: int + tool_name: str + + class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg] """Allow/deny/ask layer over the agent's tool calls. @@ -76,6 +66,10 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg] tools_by_name: Map from tool name to :class:`BaseTool`, used to decorate ``ask`` interrupts with the tool's description and MCP metadata for the FE card. + trusted_tool_saver: Async callback invoked on ``always`` decisions + for MCP tools (those whose ``metadata`` carries an + ``mcp_connector_id``). Without it the promotion only lives + in-memory for the current agent instance. """ tools = () @@ -88,6 +82,7 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg] runtime_ruleset: Ruleset | None = None, always_emit_interrupt_payload: bool = True, tools_by_name: dict[str, BaseTool] | None = None, + trusted_tool_saver: TrustedToolSaver | None = None, ) -> None: super().__init__() self._static_rulesets: list[Ruleset] = list(rulesets or []) @@ -99,23 +94,31 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg] ) self._emit_interrupt = always_emit_interrupt_payload self._tools_by_name: dict[str, BaseTool] = dict(tools_by_name or {}) + self._trusted_tool_saver: TrustedToolSaver | None = trusted_tool_saver def _process( self, state: AgentState, runtime: Runtime[Any], - ) -> dict[str, Any] | None: + ) -> tuple[dict[str, Any] | None, list[_AlwaysPromotion]]: + """Pure decision pass: returns ``(state_update, pending_promotions)``. + + Side effects performed here are in-memory only (rule promotion + into ``runtime_ruleset``). DB writes for ``always`` decisions + are queued as ``_AlwaysPromotion`` and flushed by the async hook. + """ del runtime messages = state.get("messages") or [] if not messages: - return None + return None, [] last = messages[-1] if not isinstance(last, AIMessage) or not last.tool_calls: - return None + return None, [] rulesets = all_rulesets(self._static_rulesets, self._runtime_ruleset) deny_messages: list[ToolMessage] = [] kept_calls: list[dict[str, Any]] = [] + promotions: list[_AlwaysPromotion] = [] any_change = False for raw in last.tool_calls: @@ -162,6 +165,9 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg] any_change = True if kind == "always": persist_always(self._runtime_ruleset, name, patterns) + promotion = self._build_always_promotion(name) + if promotion is not None: + promotions.append(promotion) kept_calls.append(final_call) elif kind == "reject": feedback = decision.get("feedback") @@ -180,23 +186,39 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg] kept_calls.append(call) if not any_change and len(kept_calls) == len(last.tool_calls): - return None + return None, promotions updated = last.model_copy(update={"tool_calls": kept_calls}) result_messages: list[Any] = [updated] if deny_messages: result_messages.extend(deny_messages) - return {"messages": result_messages} + return {"messages": result_messages}, promotions + + def _build_always_promotion(self, tool_name: str) -> _AlwaysPromotion | None: + """Return a save request iff the tool exposes an ``mcp_connector_id``.""" + tool = self._tools_by_name.get(tool_name) + metadata = getattr(tool, "metadata", None) or {} + connector_id = metadata.get("mcp_connector_id") + if not isinstance(connector_id, int): + return None + return _AlwaysPromotion(connector_id=connector_id, tool_name=tool_name) def after_model( # type: ignore[override] self, state: AgentState, runtime: Runtime[ContextT] ) -> dict[str, Any] | None: - return self._process(state, runtime) + update, _ = self._process(state, runtime) + return update async def aafter_model( # type: ignore[override] self, state: AgentState, runtime: Runtime[ContextT] ) -> dict[str, Any] | None: - return self._process(state, runtime) + update, promotions = self._process(state, runtime) + if self._trusted_tool_saver is not None: + for promotion in promotions: + await self._trusted_tool_saver( + promotion.connector_id, promotion.tool_name + ) + return update __all__ = ["PermissionMiddleware"] diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/factory.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/factory.py index 9642e2664..3c061ded6 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/factory.py +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/factory.py @@ -29,6 +29,7 @@ from langchain_core.tools import BaseTool from app.agents.new_chat.feature_flags import AgentFeatureFlags from app.agents.new_chat.permissions import Rule, Ruleset +from app.services.user_tool_allowlist import TrustedToolSaver from .core import PermissionMiddleware @@ -43,6 +44,7 @@ def build_permission_mw( flags: AgentFeatureFlags, subagent_rulesets: list[Ruleset] | None = None, tools: Sequence[BaseTool] | None = None, + trusted_tool_saver: TrustedToolSaver | None = None, ) -> PermissionMiddleware | None: """Return a configured :class:`PermissionMiddleware` or ``None`` when no work is needed. @@ -58,6 +60,9 @@ def build_permission_mw( an explicit ``ask`` rule always asks. tools: Subagent tools used to decorate ``ask`` interrupts with FE-card metadata (description, MCP connector). Optional. + trusted_tool_saver: Async callback invoked when an MCP tool's + ``always`` decision lands; persists the user's preference to + ``connector.config['trusted_tools']``. Optional. Returns: ``None`` when the engine has no rules to enforce @@ -73,7 +78,11 @@ def build_permission_mw( if subagent_rulesets: rulesets.extend(subagent_rulesets) tools_by_name = {t.name: t for t in (tools or [])} - return PermissionMiddleware(rulesets=rulesets, tools_by_name=tools_by_name) + return PermissionMiddleware( + rulesets=rulesets, + tools_by_name=tools_by_name, + trusted_tool_saver=trusted_tool_saver, + ) __all__ = ["build_permission_mw"] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/middleware_stack.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/middleware_stack.py index e6c969678..778bb250c 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/middleware_stack.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/middleware_stack.py @@ -93,7 +93,11 @@ def build_kb_middleware( user_allowlist = _kb_user_allowlist(dependencies, subagent_name) if user_allowlist is not None: rulesets.append(user_allowlist) - permission_mw = build_permission_mw(flags=flags, subagent_rulesets=rulesets) + permission_mw = build_permission_mw( + flags=flags, + subagent_rulesets=rulesets, + trusted_tool_saver=dependencies.get("trusted_tool_saver"), + ) return [ mws["todos"], build_kb_context_projection_mw(), diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/subagent_builder.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/subagent_builder.py index 3d1fa1504..7173901f9 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/subagent_builder.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/subagent_builder.py @@ -74,7 +74,10 @@ def pack_subagent( if user_allowlist is not None: subagent_rulesets.append(user_allowlist) per_subagent_perm = build_permission_mw( - flags=flags, subagent_rulesets=subagent_rulesets, tools=tools + flags=flags, + subagent_rulesets=subagent_rulesets, + tools=tools, + trusted_tool_saver=dependencies.get("trusted_tool_saver"), ) prepended: list[Any] = [] diff --git a/surfsense_backend/app/services/user_tool_allowlist.py b/surfsense_backend/app/services/user_tool_allowlist.py index 83b075fb7..fb21a7df2 100644 --- a/surfsense_backend/app/services/user_tool_allowlist.py +++ b/surfsense_backend/app/services/user_tool_allowlist.py @@ -1,33 +1,16 @@ -"""User-scoped tool allow-list backed by ``SearchSourceConnector.config``. +"""User-scoped trusted-tools list backed by ``SearchSourceConnector.config``. -Stores the user's "always allow" preferences as a list of tool names under -``connector.config['trusted_tools']``. Storage is per -``(user_id, search_space_id, connector_id)`` — i.e. tied to a specific -connected account inside a specific workspace, exactly what the UI cares -about. - -Callers split into two roles: - -- **Writers** — the ``/connectors/.../trust-tool`` and ``/untrust-tool`` - HTTP routes, and the chat resume handler when it processes a - ``{type: "always"}`` decision. Both call - :func:`add_user_trust` / :func:`remove_user_trust`. The FE button is - the upstream UI trigger but it talks to the routes, never to this - module directly. -- **Reader** — the subagent compile path, which calls - :func:`fetch_user_allowlist_rulesets` and layers the result after the - subagent's coded ruleset. User ``allow`` rules then override coded - ``ask`` via the rule engine's last-match-wins evaluation. - -Coded ``deny`` rules are intentionally **not** overridable by this -allow-list — only ``ask`` can be promoted to ``allow``. The rule engine -enforces this naturally because user rules only ever emit ``allow``. +Storage is per ``(user_id, search_space_id, connector_id)`` under +``connector.config['trusted_tools']``. The list only ever encodes +``allow`` decisions; coded ``deny`` rules cannot be overridden here. """ from __future__ import annotations +import logging import uuid from collections import defaultdict +from collections.abc import Awaitable, Callable from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession @@ -37,10 +20,14 @@ from app.agents.multi_agent_chat.constants import ( CONNECTOR_TYPE_TO_CONNECTOR_AGENT_MAPS, ) from app.agents.new_chat.permissions import Rule, Ruleset -from app.db import SearchSourceConnector +from app.db import SearchSourceConnector, async_session_maker + +logger = logging.getLogger(__name__) _TRUSTED_TOOLS_KEY = "trusted_tools" +TrustedToolSaver = Callable[[int, str], Awaitable[None]] + async def _load_owned_connector( session: AsyncSession, @@ -48,11 +35,7 @@ async def _load_owned_connector( user_id: uuid.UUID, connector_id: int, ) -> SearchSourceConnector | None: - """Return a connector iff it belongs to ``user_id``, else ``None``. - - Ownership scoping is mandatory: the trust list mutates user-private - data, callers must never write across user boundaries. - """ + """Return the connector iff owned by ``user_id``, else ``None``.""" result = await session.execute( select(SearchSourceConnector).where( SearchSourceConnector.id == connector_id, @@ -84,11 +67,7 @@ async def add_user_trust( connector_id: int, tool_name: str, ) -> list[str]: - """Append ``tool_name`` to the connector's trusted list (idempotent). - - Returns the updated trusted-tools list. Raises ``LookupError`` when - the connector does not exist or is not owned by ``user_id``. - """ + """Append ``tool_name`` to the connector's trusted list; raise ``LookupError`` if not owned.""" connector = await _load_owned_connector( session, user_id=user_id, connector_id=connector_id ) @@ -112,11 +91,7 @@ async def remove_user_trust( connector_id: int, tool_name: str, ) -> list[str]: - """Remove ``tool_name`` from the connector's trusted list (idempotent). - - Returns the updated trusted-tools list. Raises ``LookupError`` when - the connector does not exist or is not owned by ``user_id``. - """ + """Remove ``tool_name`` from the connector's trusted list; raise ``LookupError`` if not owned.""" connector = await _load_owned_connector( session, user_id=user_id, connector_id=connector_id ) @@ -139,20 +114,10 @@ async def fetch_user_allowlist_rulesets( user_id: uuid.UUID, search_space_id: int, ) -> dict[str, Ruleset]: - """Project the user's trusted-tool lists into per-subagent rulesets. + """Project the user's trusted tools into per-subagent ``allow`` rulesets. - Walks every connector the user owns in this workspace, maps each - ``connector_type`` to its consuming subagent via - :data:`CONNECTOR_TYPE_TO_CONNECTOR_AGENT_MAPS`, and emits one - ``Rule(permission=tool_name, pattern="*", action="allow")`` per - trusted entry. Rules from different connector accounts feeding the - same subagent (e.g. two Linear workspaces) are merged into one - ruleset; duplicates are harmless under last-match-wins. - - Connectors whose type is not mapped (search APIs, Github, etc.) and - connectors with empty trust lists contribute nothing. Subagents - with no trusted tools are absent from the returned dict — callers - should treat ``missing == empty``. + Subagents with no trusted tools are absent from the result — + callers must treat ``missing == empty``. """ result = await session.execute( select( @@ -189,8 +154,35 @@ async def fetch_user_allowlist_rulesets( } +def make_trusted_tool_saver(user_id: uuid.UUID) -> TrustedToolSaver: + """Bind ``user_id`` into a saver closure; failures are logged, never raised.""" + + async def trusted_tool_saver(connector_id: int, tool_name: str) -> None: + try: + async with async_session_maker() as session: + await add_user_trust( + session, + user_id=user_id, + connector_id=connector_id, + tool_name=tool_name, + ) + await session.commit() + except LookupError as exc: + logger.warning("trusted-tool save skipped: %s", exc) + except Exception: + logger.exception( + "trusted-tool save failed for connector=%s tool=%s", + connector_id, + tool_name, + ) + + return trusted_tool_saver + + __all__ = [ + "TrustedToolSaver", "add_user_trust", "fetch_user_allowlist_rulesets", + "make_trusted_tool_saver", "remove_user_trust", ] diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/shared/permissions/test_permission_ask_mcp_context.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/shared/permissions/test_permission_ask_mcp_context.py index b6768c530..f70f027a9 100644 --- a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/shared/permissions/test_permission_ask_mcp_context.py +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/shared/permissions/test_permission_ask_mcp_context.py @@ -128,7 +128,7 @@ def _emit_tool_call(tool_name: str, args: dict[str, Any], call_id: str): def _compile_graph_with(pm, tool_name: str, args: dict[str, Any], call_id: str): def after(state: _State) -> dict[str, Any] | None: - return pm._process(state, None) # type: ignore[arg-type] + return pm.after_model(state, None) # type: ignore[arg-type] g = StateGraph(_State) g.add_node("emit", _emit_tool_call(tool_name, args, call_id)) diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/shared/permissions/test_subagent_owned_ruleset.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/shared/permissions/test_subagent_owned_ruleset.py index 6f3f34536..6406fb09a 100644 --- a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/shared/permissions/test_subagent_owned_ruleset.py +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/shared/permissions/test_subagent_owned_ruleset.py @@ -84,9 +84,7 @@ def _build_graph_with_permission_middleware( def after_node(state: _State) -> dict[str, Any] | None: if pm is None: return None - # PermissionMiddleware._process ignores runtime; the test never relies - # on the runtime context, so passing None keeps the harness lean. - return pm._process(state, None) # type: ignore[arg-type] + return pm.after_model(state, None) # type: ignore[arg-type] g = StateGraph(_State) g.add_node("emit", node) diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/shared/permissions/test_trusted_tool_save_on_always.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/shared/permissions/test_trusted_tool_save_on_always.py new file mode 100644 index 000000000..8b469a2cb --- /dev/null +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/shared/permissions/test_trusted_tool_save_on_always.py @@ -0,0 +1,180 @@ +"""``always`` decisions for MCP tools are saved via the trusted-tool saver.""" + +from __future__ import annotations + +from typing import Annotated, Any + +import pytest +from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.tools import StructuredTool +from langgraph.checkpoint.memory import InMemorySaver +from langgraph.graph import END, START, StateGraph +from langgraph.graph.message import add_messages +from langgraph.types import Command +from pydantic import BaseModel +from typing_extensions import TypedDict + +from app.agents.multi_agent_chat.middleware.shared.permissions import ( + build_permission_mw, +) +from app.agents.new_chat.feature_flags import AgentFeatureFlags +from app.agents.new_chat.permissions import Rule, Ruleset + + +class _NoArgs(BaseModel): + pass + + +async def _noop(**_kwargs) -> str: + return "" + + +def _ask_rule(tool_name: str) -> Rule: + return Rule(permission=tool_name, pattern="*", action="ask") + + +def _make_mcp_tool(*, name: str, connector_id: int): + return StructuredTool( + name=name, + description=f"Run {name} via MCP.", + coroutine=_noop, + args_schema=_NoArgs, + metadata={ + "mcp_connector_id": connector_id, + "mcp_connector_name": "Linear", + "mcp_transport": "http", + "hitl": True, + }, + ) + + +def _make_native_tool(*, name: str): + return StructuredTool( + name=name, + description=f"Native {name}.", + coroutine=_noop, + args_schema=_NoArgs, + metadata={"hitl": True}, + ) + + +class _State(TypedDict, total=False): + messages: Annotated[list, add_messages] + + +def _build_graph(pm, tool_name: str): + def emit(_state: _State) -> dict[str, Any]: + return { + "messages": [ + AIMessage( + content="", + tool_calls=[ + { + "name": tool_name, + "args": {}, + "id": "call-1", + "type": "tool_call", + } + ], + ) + ] + } + + g = StateGraph(_State) + g.add_node("emit", emit) + g.add_node("permission", pm.aafter_model) # type: ignore[arg-type] + g.add_edge(START, "emit") + g.add_edge("emit", "permission") + g.add_edge("permission", END) + return g.compile(checkpointer=InMemorySaver()) + + +@pytest.mark.asyncio +async def test_always_decision_saves_mcp_tool_via_callback(): + saved: list[tuple[int, str]] = [] + + async def trusted_tool_saver(connector_id: int, tool_name: str) -> None: + saved.append((connector_id, tool_name)) + + tool = _make_mcp_tool(name="linear_create_issue", connector_id=7) + pm = build_permission_mw( + flags=AgentFeatureFlags(enable_permission=False), + subagent_rulesets=[Ruleset(origin="linear", rules=[_ask_rule(tool.name)])], + tools=[tool], + trusted_tool_saver=trusted_tool_saver, + ) + assert pm is not None + + graph = _build_graph(pm, tool.name) + config = {"configurable": {"thread_id": "always-mcp"}} + await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config) + await graph.ainvoke(Command(resume={"decisions": [{"type": "always"}]}), config) + + assert saved == [(7, "linear_create_issue")] + + +@pytest.mark.asyncio +async def test_once_decision_does_not_save(): + saved: list[tuple[int, str]] = [] + + async def trusted_tool_saver(connector_id: int, tool_name: str) -> None: + saved.append((connector_id, tool_name)) + + tool = _make_mcp_tool(name="linear_create_issue", connector_id=7) + pm = build_permission_mw( + flags=AgentFeatureFlags(enable_permission=False), + subagent_rulesets=[Ruleset(origin="linear", rules=[_ask_rule(tool.name)])], + tools=[tool], + trusted_tool_saver=trusted_tool_saver, + ) + assert pm is not None + + graph = _build_graph(pm, tool.name) + config = {"configurable": {"thread_id": "once-mcp"}} + await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config) + await graph.ainvoke(Command(resume={"decisions": [{"type": "approve"}]}), config) + + assert saved == [] + + +@pytest.mark.asyncio +async def test_always_decision_for_native_tool_skips_save(): + """Native tools have no ``mcp_connector_id`` so there is nowhere to persist trust.""" + saved: list[tuple[int, str]] = [] + + async def trusted_tool_saver(connector_id: int, tool_name: str) -> None: + saved.append((connector_id, tool_name)) + + tool = _make_native_tool(name="rm") + pm = build_permission_mw( + flags=AgentFeatureFlags(enable_permission=False), + subagent_rulesets=[Ruleset(origin="kb", rules=[_ask_rule(tool.name)])], + tools=[tool], + trusted_tool_saver=trusted_tool_saver, + ) + assert pm is not None + + graph = _build_graph(pm, tool.name) + config = {"configurable": {"thread_id": "always-native"}} + await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config) + await graph.ainvoke(Command(resume={"decisions": [{"type": "always"}]}), config) + + assert saved == [] + + +@pytest.mark.asyncio +async def test_always_decision_with_no_saver_callback_is_a_noop(): + """Anonymous turns build the middleware without a ``trusted_tool_saver``; must not crash.""" + tool = _make_mcp_tool(name="linear_create_issue", connector_id=7) + pm = build_permission_mw( + flags=AgentFeatureFlags(enable_permission=False), + subagent_rulesets=[Ruleset(origin="linear", rules=[_ask_rule(tool.name)])], + tools=[tool], + trusted_tool_saver=None, + ) + assert pm is not None + + graph = _build_graph(pm, tool.name) + config = {"configurable": {"thread_id": "anon-always"}} + await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config) + await graph.ainvoke(Command(resume={"decisions": [{"type": "always"}]}), config)