multi_agent_chat/permissions: persist 'always' decisions to trusted-tools list

Until now an "Always Allow" reply only updated the in-memory runtime
ruleset, evaporating after the session ended. Persist it to the
existing connector.config['trusted_tools'] list so the next session's
fetch_user_allowlist_rulesets picks it up and the user is never asked
again for the same (connector, tool) pair.

- TrustedToolSaver + make_trusted_tool_saver(user_id) in
  user_tool_allowlist: opens its own session via async_session_maker
  per call, logs and swallows failures (in-memory promotion is the
  canonical "always" path, durable persistence is opportunistic).

- PermissionMiddleware._process now performs no I/O (its only side
  effect is the in-memory runtime-ruleset promotion): it returns
  (state_update, list[_AlwaysPromotion]). aafter_model awaits the
  saver for each promotion; the sync after_model discards them, so
  sync callers get no durable persistence. Promotions are only
  emitted for tools whose metadata exposes an integer
  mcp_connector_id, so native tools and KB FS ops are correctly
  skipped.

- main_agent factory builds the saver once per turn and stashes it in
  dependencies["trusted_tool_saver"]; pack_subagent and the KB
  middleware stack forward it through build_permission_mw.

- Renamed pm._process(state, None) call sites in two existing tests to
  pm.after_model(state, None) so they exercise the public hook
  contract instead of the now-tuple-returning private method.
This commit is contained in:
CREDO23 2026-05-15 14:07:08 +02:00
parent a97d1548a6
commit 6671c91841
9 changed files with 323 additions and 103 deletions

View file

@ -29,7 +29,10 @@ from app.agents.new_chat.tools.invalid_tool import INVALID_TOOL_NAME, invalid_to
from app.agents.new_chat.tools.registry import build_tools_async
from app.db import ChatVisibility
from app.services.connector_service import ConnectorService
from app.services.user_tool_allowlist import fetch_user_allowlist_rulesets
from app.services.user_tool_allowlist import (
fetch_user_allowlist_rulesets,
make_trusted_tool_saver,
)
from app.utils.perf import get_perf_logger
from ..system_prompt import build_main_agent_system_prompt
@ -153,28 +156,37 @@ async def create_multi_agent_chat_deep_agent(
# ``ask`` via last-match-wins. Anonymous turns and read failures both
# degrade to "no user rules" rather than blocking the turn.
user_allowlist_by_subagent: dict[str, Any] = {}
trusted_tool_saver = None
if user_id:
_t0 = time.perf_counter()
try:
import uuid as _uuid
user_allowlist_by_subagent = await fetch_user_allowlist_rulesets(
db_session,
user_id=_uuid.UUID(user_id),
search_space_id=search_space_id,
user_uuid = _uuid.UUID(user_id)
except (TypeError, ValueError):
user_uuid = None
if user_uuid is not None:
_t0 = time.perf_counter()
try:
user_allowlist_by_subagent = await fetch_user_allowlist_rulesets(
db_session,
user_id=user_uuid,
search_space_id=search_space_id,
)
except Exception as e:
logging.warning(
"User allow-list fetch failed; subagents will run without user trust rules this turn: %s",
e,
)
user_allowlist_by_subagent = {}
_perf_log.info(
"[create_agent] fetch_user_allowlist_rulesets in %.3fs (%d subagents have rules)",
time.perf_counter() - _t0,
len(user_allowlist_by_subagent),
)
except Exception as e:
logging.warning(
"User allow-list fetch failed; subagents will run without user trust rules this turn: %s",
e,
)
user_allowlist_by_subagent = {}
_perf_log.info(
"[create_agent] fetch_user_allowlist_rulesets in %.3fs (%d subagents have rules)",
time.perf_counter() - _t0,
len(user_allowlist_by_subagent),
)
trusted_tool_saver = make_trusted_tool_saver(user_uuid)
dependencies["user_allowlist_by_subagent"] = user_allowlist_by_subagent
dependencies["trusted_tool_saver"] = trusted_tool_saver
modified_disabled_tools = list(disabled_tools) if disabled_tools else []

View file

@ -9,31 +9,12 @@ This middleware layers OpenCode's wildcard-ruleset model on top of the
unified langchain HITL wire format (see :mod:`hitl_wire`), so it sits
beside ``HumanInTheLoopMiddleware`` and self-gated approvals on a single
parallel-HITL routing layer in ``task_tool`` + ``resume_routing``.
Per-tool-call flow inside :meth:`_process`:
1. Skip when the last message has no tool calls.
2. For each call, evaluate the rules. ``deny`` is replaced with a
synthetic :class:`ToolMessage` carrying a typed
:class:`StreamingError`. ``ask`` raises an interrupt via
:mod:`interrupt.request`; the resulting decision is dispatched here:
- ``once`` keep the call as-is.
- ``always`` also extend the runtime ruleset.
- ``reject`` (with feedback) :class:`CorrectedError`.
- ``reject`` (no feedback) :class:`RejectedError`.
``allow`` keeps the call unchanged.
3. Returns an updated ``AIMessage`` (tool calls minus the denied ones)
plus any deny ``ToolMessage`` entries appended after it. Tool-list
filtering at ``before_model`` is intentionally not done here that
would invalidate provider prompt-cache prefixes.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import Any
from langchain.agents.middleware.types import (
@ -47,6 +28,7 @@ from langgraph.runtime import Runtime
from app.agents.new_chat.errors import CorrectedError, RejectedError
from app.agents.new_chat.permissions import Ruleset
from app.services.user_tool_allowlist import TrustedToolSaver
from ..ask.edit import merge_edited_args
from ..ask.request import request_permission_decision
@ -59,6 +41,14 @@ from .runtime_promote import persist_always
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class _AlwaysPromotion:
    """A pending request to save an ``always`` decision to the user's trust list.

    Instances are created while tool calls are being evaluated and are later
    passed, field by field, to the trusted-tool saver callback.
    ``frozen=True`` makes each instance an immutable value object.
    """

    # Value of the tool's ``mcp_connector_id`` metadata entry.
    connector_id: int
    # Name of the tool the user answered ``always`` for.
    tool_name: str
class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
"""Allow/deny/ask layer over the agent's tool calls.
@ -76,6 +66,10 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
tools_by_name: Map from tool name to :class:`BaseTool`, used to
decorate ``ask`` interrupts with the tool's description and
MCP metadata for the FE card.
trusted_tool_saver: Async callback invoked on ``always`` decisions
for MCP tools (those whose ``metadata`` carries an
``mcp_connector_id``). Without it the promotion only lives
in-memory for the current agent instance.
"""
tools = ()
@ -88,6 +82,7 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
runtime_ruleset: Ruleset | None = None,
always_emit_interrupt_payload: bool = True,
tools_by_name: dict[str, BaseTool] | None = None,
trusted_tool_saver: TrustedToolSaver | None = None,
) -> None:
super().__init__()
self._static_rulesets: list[Ruleset] = list(rulesets or [])
@ -99,23 +94,31 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
)
self._emit_interrupt = always_emit_interrupt_payload
self._tools_by_name: dict[str, BaseTool] = dict(tools_by_name or {})
self._trusted_tool_saver: TrustedToolSaver | None = trusted_tool_saver
def _process(
self,
state: AgentState,
runtime: Runtime[Any],
) -> dict[str, Any] | None:
) -> tuple[dict[str, Any] | None, list[_AlwaysPromotion]]:
"""Pure decision pass: returns ``(state_update, pending_promotions)``.
Side effects performed here are in-memory only (rule promotion
into ``runtime_ruleset``). DB writes for ``always`` decisions
are queued as ``_AlwaysPromotion`` and flushed by the async hook.
"""
del runtime
messages = state.get("messages") or []
if not messages:
return None
return None, []
last = messages[-1]
if not isinstance(last, AIMessage) or not last.tool_calls:
return None
return None, []
rulesets = all_rulesets(self._static_rulesets, self._runtime_ruleset)
deny_messages: list[ToolMessage] = []
kept_calls: list[dict[str, Any]] = []
promotions: list[_AlwaysPromotion] = []
any_change = False
for raw in last.tool_calls:
@ -162,6 +165,9 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
any_change = True
if kind == "always":
persist_always(self._runtime_ruleset, name, patterns)
promotion = self._build_always_promotion(name)
if promotion is not None:
promotions.append(promotion)
kept_calls.append(final_call)
elif kind == "reject":
feedback = decision.get("feedback")
@ -180,23 +186,39 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
kept_calls.append(call)
if not any_change and len(kept_calls) == len(last.tool_calls):
return None
return None, promotions
updated = last.model_copy(update={"tool_calls": kept_calls})
result_messages: list[Any] = [updated]
if deny_messages:
result_messages.extend(deny_messages)
return {"messages": result_messages}
return {"messages": result_messages}, promotions
def _build_always_promotion(self, tool_name: str) -> _AlwaysPromotion | None:
"""Return a save request iff the tool exposes an ``mcp_connector_id``."""
tool = self._tools_by_name.get(tool_name)
metadata = getattr(tool, "metadata", None) or {}
connector_id = metadata.get("mcp_connector_id")
if not isinstance(connector_id, int):
return None
return _AlwaysPromotion(connector_id=connector_id, tool_name=tool_name)
def after_model( # type: ignore[override]
self, state: AgentState, runtime: Runtime[ContextT]
) -> dict[str, Any] | None:
return self._process(state, runtime)
update, _ = self._process(state, runtime)
return update
async def aafter_model( # type: ignore[override]
self, state: AgentState, runtime: Runtime[ContextT]
) -> dict[str, Any] | None:
return self._process(state, runtime)
update, promotions = self._process(state, runtime)
if self._trusted_tool_saver is not None:
for promotion in promotions:
await self._trusted_tool_saver(
promotion.connector_id, promotion.tool_name
)
return update
__all__ = ["PermissionMiddleware"]

View file

@ -29,6 +29,7 @@ from langchain_core.tools import BaseTool
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.agents.new_chat.permissions import Rule, Ruleset
from app.services.user_tool_allowlist import TrustedToolSaver
from .core import PermissionMiddleware
@ -43,6 +44,7 @@ def build_permission_mw(
flags: AgentFeatureFlags,
subagent_rulesets: list[Ruleset] | None = None,
tools: Sequence[BaseTool] | None = None,
trusted_tool_saver: TrustedToolSaver | None = None,
) -> PermissionMiddleware | None:
"""Return a configured :class:`PermissionMiddleware` or ``None`` when no work is needed.
@ -58,6 +60,9 @@ def build_permission_mw(
an explicit ``ask`` rule always asks.
tools: Subagent tools used to decorate ``ask`` interrupts with
FE-card metadata (description, MCP connector). Optional.
trusted_tool_saver: Async callback invoked when an MCP tool's
``always`` decision lands; persists the user's preference to
``connector.config['trusted_tools']``. Optional.
Returns:
``None`` when the engine has no rules to enforce
@ -73,7 +78,11 @@ def build_permission_mw(
if subagent_rulesets:
rulesets.extend(subagent_rulesets)
tools_by_name = {t.name: t for t in (tools or [])}
return PermissionMiddleware(rulesets=rulesets, tools_by_name=tools_by_name)
return PermissionMiddleware(
rulesets=rulesets,
tools_by_name=tools_by_name,
trusted_tool_saver=trusted_tool_saver,
)
__all__ = ["build_permission_mw"]

View file

@ -93,7 +93,11 @@ def build_kb_middleware(
user_allowlist = _kb_user_allowlist(dependencies, subagent_name)
if user_allowlist is not None:
rulesets.append(user_allowlist)
permission_mw = build_permission_mw(flags=flags, subagent_rulesets=rulesets)
permission_mw = build_permission_mw(
flags=flags,
subagent_rulesets=rulesets,
trusted_tool_saver=dependencies.get("trusted_tool_saver"),
)
return [
mws["todos"],
build_kb_context_projection_mw(),

View file

@ -74,7 +74,10 @@ def pack_subagent(
if user_allowlist is not None:
subagent_rulesets.append(user_allowlist)
per_subagent_perm = build_permission_mw(
flags=flags, subagent_rulesets=subagent_rulesets, tools=tools
flags=flags,
subagent_rulesets=subagent_rulesets,
tools=tools,
trusted_tool_saver=dependencies.get("trusted_tool_saver"),
)
prepended: list[Any] = []

View file

@ -1,33 +1,16 @@
"""User-scoped tool allow-list backed by ``SearchSourceConnector.config``.
"""User-scoped trusted-tools list backed by ``SearchSourceConnector.config``.
Stores the user's "always allow" preferences as a list of tool names under
``connector.config['trusted_tools']``. Storage is per
``(user_id, search_space_id, connector_id)`` i.e. tied to a specific
connected account inside a specific workspace, exactly what the UI cares
about.
Callers split into two roles:
- **Writers** the ``/connectors/.../trust-tool`` and ``/untrust-tool``
HTTP routes, and the chat resume handler when it processes a
``{type: "always"}`` decision. Both call
:func:`add_user_trust` / :func:`remove_user_trust`. The FE button is
the upstream UI trigger but it talks to the routes, never to this
module directly.
- **Reader** the subagent compile path, which calls
:func:`fetch_user_allowlist_rulesets` and layers the result after the
subagent's coded ruleset. User ``allow`` rules then override coded
``ask`` via the rule engine's last-match-wins evaluation.
Coded ``deny`` rules are intentionally **not** overridable by this
allow-list only ``ask`` can be promoted to ``allow``. The rule engine
enforces this naturally because user rules only ever emit ``allow``.
Storage is per ``(user_id, search_space_id, connector_id)`` under
``connector.config['trusted_tools']``. The list only ever encodes
``allow`` decisions; coded ``deny`` rules cannot be overridden here.
"""
from __future__ import annotations
import logging
import uuid
from collections import defaultdict
from collections.abc import Awaitable, Callable
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
@ -37,10 +20,14 @@ from app.agents.multi_agent_chat.constants import (
CONNECTOR_TYPE_TO_CONNECTOR_AGENT_MAPS,
)
from app.agents.new_chat.permissions import Rule, Ruleset
from app.db import SearchSourceConnector
from app.db import SearchSourceConnector, async_session_maker
logger = logging.getLogger(__name__)
_TRUSTED_TOOLS_KEY = "trusted_tools"
TrustedToolSaver = Callable[[int, str], Awaitable[None]]
async def _load_owned_connector(
session: AsyncSession,
@ -48,11 +35,7 @@ async def _load_owned_connector(
user_id: uuid.UUID,
connector_id: int,
) -> SearchSourceConnector | None:
"""Return a connector iff it belongs to ``user_id``, else ``None``.
Ownership scoping is mandatory: the trust list mutates user-private
data, callers must never write across user boundaries.
"""
"""Return the connector iff owned by ``user_id``, else ``None``."""
result = await session.execute(
select(SearchSourceConnector).where(
SearchSourceConnector.id == connector_id,
@ -84,11 +67,7 @@ async def add_user_trust(
connector_id: int,
tool_name: str,
) -> list[str]:
"""Append ``tool_name`` to the connector's trusted list (idempotent).
Returns the updated trusted-tools list. Raises ``LookupError`` when
the connector does not exist or is not owned by ``user_id``.
"""
"""Append ``tool_name`` to the connector's trusted list; raise ``LookupError`` if not owned."""
connector = await _load_owned_connector(
session, user_id=user_id, connector_id=connector_id
)
@ -112,11 +91,7 @@ async def remove_user_trust(
connector_id: int,
tool_name: str,
) -> list[str]:
"""Remove ``tool_name`` from the connector's trusted list (idempotent).
Returns the updated trusted-tools list. Raises ``LookupError`` when
the connector does not exist or is not owned by ``user_id``.
"""
"""Remove ``tool_name`` from the connector's trusted list; raise ``LookupError`` if not owned."""
connector = await _load_owned_connector(
session, user_id=user_id, connector_id=connector_id
)
@ -139,20 +114,10 @@ async def fetch_user_allowlist_rulesets(
user_id: uuid.UUID,
search_space_id: int,
) -> dict[str, Ruleset]:
"""Project the user's trusted-tool lists into per-subagent rulesets.
"""Project the user's trusted tools into per-subagent ``allow`` rulesets.
Walks every connector the user owns in this workspace, maps each
``connector_type`` to its consuming subagent via
:data:`CONNECTOR_TYPE_TO_CONNECTOR_AGENT_MAPS`, and emits one
``Rule(permission=tool_name, pattern="*", action="allow")`` per
trusted entry. Rules from different connector accounts feeding the
same subagent (e.g. two Linear workspaces) are merged into one
ruleset; duplicates are harmless under last-match-wins.
Connectors whose type is not mapped (search APIs, Github, etc.) and
connectors with empty trust lists contribute nothing. Subagents
with no trusted tools are absent from the returned dict callers
should treat ``missing == empty``.
Subagents with no trusted tools are absent from the result
callers must treat ``missing == empty``.
"""
result = await session.execute(
select(
@ -189,8 +154,35 @@ async def fetch_user_allowlist_rulesets(
}
def make_trusted_tool_saver(user_id: uuid.UUID) -> TrustedToolSaver:
    """Build a saver coroutine function bound to ``user_id``.

    Every call of the returned callable opens a fresh session from
    ``async_session_maker``, adds ``tool_name`` to the connector's trusted
    list through :func:`add_user_trust`, and commits. Persistence is
    best-effort: a ``LookupError`` is logged as a warning, any other
    ``Exception`` is logged with its traceback, and neither is re-raised.
    """

    async def save_trusted_tool(connector_id: int, tool_name: str) -> None:
        try:
            async with async_session_maker() as db:
                await add_user_trust(
                    db,
                    user_id=user_id,
                    connector_id=connector_id,
                    tool_name=tool_name,
                )
                await db.commit()
        except LookupError as err:
            # Raised by ``add_user_trust``; nothing to persist in that case.
            logger.warning("trusted-tool save skipped: %s", err)
        except Exception:
            logger.exception(
                "trusted-tool save failed for connector=%s tool=%s",
                connector_id,
                tool_name,
            )

    return save_trusted_tool
__all__ = [
"TrustedToolSaver",
"add_user_trust",
"fetch_user_allowlist_rulesets",
"make_trusted_tool_saver",
"remove_user_trust",
]

View file

@ -128,7 +128,7 @@ def _emit_tool_call(tool_name: str, args: dict[str, Any], call_id: str):
def _compile_graph_with(pm, tool_name: str, args: dict[str, Any], call_id: str):
def after(state: _State) -> dict[str, Any] | None:
return pm._process(state, None) # type: ignore[arg-type]
return pm.after_model(state, None) # type: ignore[arg-type]
g = StateGraph(_State)
g.add_node("emit", _emit_tool_call(tool_name, args, call_id))

View file

@ -84,9 +84,7 @@ def _build_graph_with_permission_middleware(
def after_node(state: _State) -> dict[str, Any] | None:
if pm is None:
return None
# PermissionMiddleware._process ignores runtime; the test never relies
# on the runtime context, so passing None keeps the harness lean.
return pm._process(state, None) # type: ignore[arg-type]
return pm.after_model(state, None) # type: ignore[arg-type]
g = StateGraph(_State)
g.add_node("emit", node)

View file

@ -0,0 +1,180 @@
"""``always`` decisions for MCP tools are saved via the trusted-tool saver."""
from __future__ import annotations
from typing import Annotated, Any
import pytest
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.tools import StructuredTool
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.types import Command
from pydantic import BaseModel
from typing_extensions import TypedDict
from app.agents.multi_agent_chat.middleware.shared.permissions import (
build_permission_mw,
)
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.agents.new_chat.permissions import Rule, Ruleset
# Empty args schema shared by the stub tools built in this module.
class _NoArgs(BaseModel):
    pass
async def _noop(**_kwargs) -> str:
    """Stand-in tool coroutine: accept any keyword arguments and return ``""``."""
    return ""
def _ask_rule(tool_name: str) -> Rule:
    """Build a rule with action ``ask`` and pattern ``"*"`` for ``tool_name``."""
    return Rule(permission=tool_name, pattern="*", action="ask")
def _make_mcp_tool(*, name: str, connector_id: int):
    """Create a stub ``StructuredTool`` that carries MCP connector metadata.

    ``connector_id`` is stored under ``metadata["mcp_connector_id"]``; the
    remaining metadata entries are fixed test values.
    """
    mcp_metadata = {
        "mcp_connector_id": connector_id,
        "mcp_connector_name": "Linear",
        "mcp_transport": "http",
        "hitl": True,
    }
    return StructuredTool(
        name=name,
        description=f"Run {name} via MCP.",
        coroutine=_noop,
        args_schema=_NoArgs,
        metadata=mcp_metadata,
    )
def _make_native_tool(*, name: str):
    """Create a stub ``StructuredTool`` whose metadata has no ``mcp_connector_id``."""
    native_metadata = {"hitl": True}
    return StructuredTool(
        name=name,
        description=f"Native {name}.",
        coroutine=_noop,
        args_schema=_NoArgs,
        metadata=native_metadata,
    )
# Graph state for the test harness; ``messages`` is annotated with ``add_messages``.
class _State(TypedDict, total=False):
    messages: Annotated[list, add_messages]
def _build_graph(pm, tool_name: str):
    """Compile a two-node graph: emit one tool call, then run the permission hook.

    The ``emit`` node appends an ``AIMessage`` with a single call to
    ``tool_name`` (id ``"call-1"``, empty args); the ``permission`` node is
    ``pm.aafter_model``. The graph is compiled with an ``InMemorySaver``
    checkpointer.
    """

    def emit(_state: _State) -> dict[str, Any]:
        # Build a fresh tool-call dict on every invocation.
        call = {
            "name": tool_name,
            "args": {},
            "id": "call-1",
            "type": "tool_call",
        }
        ai_message = AIMessage(content="", tool_calls=[call])
        return {"messages": [ai_message]}

    builder = StateGraph(_State)
    builder.add_node("emit", emit)
    builder.add_node("permission", pm.aafter_model)  # type: ignore[arg-type]
    builder.add_edge(START, "emit")
    builder.add_edge("emit", "permission")
    builder.add_edge("permission", END)
    return builder.compile(checkpointer=InMemorySaver())
@pytest.mark.asyncio
async def test_always_decision_saves_mcp_tool_via_callback():
    """An ``always`` reply for an MCP tool reaches the saver as ``(connector_id, tool_name)``."""
    recorded: list[tuple[int, str]] = []

    async def trusted_tool_saver(connector_id: int, tool_name: str) -> None:
        recorded.append((connector_id, tool_name))

    mcp_tool = _make_mcp_tool(name="linear_create_issue", connector_id=7)
    ask_ruleset = Ruleset(origin="linear", rules=[_ask_rule(mcp_tool.name)])
    pm = build_permission_mw(
        flags=AgentFeatureFlags(enable_permission=False),
        subagent_rulesets=[ask_ruleset],
        tools=[mcp_tool],
        trusted_tool_saver=trusted_tool_saver,
    )
    assert pm is not None

    graph = _build_graph(pm, mcp_tool.name)
    run_config = {"configurable": {"thread_id": "always-mcp"}}
    # Seed run first, then resume the same thread with the user's decision.
    await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, run_config)
    resume = Command(resume={"decisions": [{"type": "always"}]})
    await graph.ainvoke(resume, run_config)

    assert recorded == [(7, "linear_create_issue")]
@pytest.mark.asyncio
async def test_once_decision_does_not_save():
    """An ``approve`` reply must not invoke the trusted-tool saver."""
    recorded: list[tuple[int, str]] = []

    async def trusted_tool_saver(connector_id: int, tool_name: str) -> None:
        recorded.append((connector_id, tool_name))

    mcp_tool = _make_mcp_tool(name="linear_create_issue", connector_id=7)
    ask_ruleset = Ruleset(origin="linear", rules=[_ask_rule(mcp_tool.name)])
    pm = build_permission_mw(
        flags=AgentFeatureFlags(enable_permission=False),
        subagent_rulesets=[ask_ruleset],
        tools=[mcp_tool],
        trusted_tool_saver=trusted_tool_saver,
    )
    assert pm is not None

    graph = _build_graph(pm, mcp_tool.name)
    run_config = {"configurable": {"thread_id": "once-mcp"}}
    # Seed run first, then resume the same thread with the user's decision.
    await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, run_config)
    resume = Command(resume={"decisions": [{"type": "approve"}]})
    await graph.ainvoke(resume, run_config)

    assert recorded == []
@pytest.mark.asyncio
async def test_always_decision_for_native_tool_skips_save():
    """Native tools have no ``mcp_connector_id`` so there is nowhere to persist trust."""
    recorded: list[tuple[int, str]] = []

    async def trusted_tool_saver(connector_id: int, tool_name: str) -> None:
        recorded.append((connector_id, tool_name))

    native_tool = _make_native_tool(name="rm")
    ask_ruleset = Ruleset(origin="kb", rules=[_ask_rule(native_tool.name)])
    pm = build_permission_mw(
        flags=AgentFeatureFlags(enable_permission=False),
        subagent_rulesets=[ask_ruleset],
        tools=[native_tool],
        trusted_tool_saver=trusted_tool_saver,
    )
    assert pm is not None

    graph = _build_graph(pm, native_tool.name)
    run_config = {"configurable": {"thread_id": "always-native"}}
    # Seed run first, then resume the same thread with the user's decision.
    await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, run_config)
    resume = Command(resume={"decisions": [{"type": "always"}]})
    await graph.ainvoke(resume, run_config)

    assert recorded == []
@pytest.mark.asyncio
async def test_always_decision_with_no_saver_callback_is_a_noop():
    """Anonymous turns build the middleware without a ``trusted_tool_saver``; must not crash."""
    mcp_tool = _make_mcp_tool(name="linear_create_issue", connector_id=7)
    ask_ruleset = Ruleset(origin="linear", rules=[_ask_rule(mcp_tool.name)])
    pm = build_permission_mw(
        flags=AgentFeatureFlags(enable_permission=False),
        subagent_rulesets=[ask_ruleset],
        tools=[mcp_tool],
        trusted_tool_saver=None,
    )
    assert pm is not None

    graph = _build_graph(pm, mcp_tool.name)
    run_config = {"configurable": {"thread_id": "anon-always"}}
    # No explicit assertion follows: the test passes when both runs finish
    # without raising.
    await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, run_config)
    resume = Command(resume={"decisions": [{"type": "always"}]})
    await graph.ainvoke(resume, run_config)