multi_agent_chat/permissions: persist 'always' decisions to trusted-tools list

Until now an "Always Allow" reply only updated the in-memory runtime
ruleset, so the decision was lost when the session ended. Persist it to
the existing connector.config['trusted_tools'] list so the next session's
fetch_user_allowlist_rulesets picks it up and the user is not asked
again for the same (connector, tool) pair.

- TrustedToolSaver + make_trusted_tool_saver(user_id) in
  user_tool_allowlist: opens its own session via async_session_maker
  per call, and logs and swallows failures (in-memory promotion is the
  canonical "always" path; durable persistence is opportunistic).

- PermissionMiddleware._process now does no I/O: it returns
  (state_update, list[_AlwaysPromotion]), and its only side effect is
  the in-memory runtime-ruleset promotion. aafter_model awaits the
  saver for each promotion; after_model discards them. Promotions are
  only emitted for tools whose metadata exposes mcp_connector_id, so
  native tools and KB FS ops are correctly skipped.

- main_agent factory builds the saver once per turn and stashes it in
  dependencies["trusted_tool_saver"]; pack_subagent and the KB
  middleware stack forward it through build_permission_mw.

- Renamed pm._process(state, None) call sites in two existing tests to
  pm.after_model(state, None) so they exercise the public hook
  contract instead of the now-tuple-returning private method.
This commit is contained in:
CREDO23 2026-05-15 14:07:08 +02:00
parent a97d1548a6
commit 6671c91841
9 changed files with 323 additions and 103 deletions

View file

@ -29,7 +29,10 @@ from app.agents.new_chat.tools.invalid_tool import INVALID_TOOL_NAME, invalid_to
from app.agents.new_chat.tools.registry import build_tools_async
from app.db import ChatVisibility
from app.services.connector_service import ConnectorService
from app.services.user_tool_allowlist import fetch_user_allowlist_rulesets
from app.services.user_tool_allowlist import (
fetch_user_allowlist_rulesets,
make_trusted_tool_saver,
)
from app.utils.perf import get_perf_logger
from ..system_prompt import build_main_agent_system_prompt
@ -153,28 +156,37 @@ async def create_multi_agent_chat_deep_agent(
# ``ask`` via last-match-wins. Anonymous turns and read failures both
# degrade to "no user rules" rather than blocking the turn.
user_allowlist_by_subagent: dict[str, Any] = {}
trusted_tool_saver = None
if user_id:
_t0 = time.perf_counter()
try:
import uuid as _uuid
user_allowlist_by_subagent = await fetch_user_allowlist_rulesets(
db_session,
user_id=_uuid.UUID(user_id),
search_space_id=search_space_id,
user_uuid = _uuid.UUID(user_id)
except (TypeError, ValueError):
user_uuid = None
if user_uuid is not None:
_t0 = time.perf_counter()
try:
user_allowlist_by_subagent = await fetch_user_allowlist_rulesets(
db_session,
user_id=user_uuid,
search_space_id=search_space_id,
)
except Exception as e:
logging.warning(
"User allow-list fetch failed; subagents will run without user trust rules this turn: %s",
e,
)
user_allowlist_by_subagent = {}
_perf_log.info(
"[create_agent] fetch_user_allowlist_rulesets in %.3fs (%d subagents have rules)",
time.perf_counter() - _t0,
len(user_allowlist_by_subagent),
)
except Exception as e:
logging.warning(
"User allow-list fetch failed; subagents will run without user trust rules this turn: %s",
e,
)
user_allowlist_by_subagent = {}
_perf_log.info(
"[create_agent] fetch_user_allowlist_rulesets in %.3fs (%d subagents have rules)",
time.perf_counter() - _t0,
len(user_allowlist_by_subagent),
)
trusted_tool_saver = make_trusted_tool_saver(user_uuid)
dependencies["user_allowlist_by_subagent"] = user_allowlist_by_subagent
dependencies["trusted_tool_saver"] = trusted_tool_saver
modified_disabled_tools = list(disabled_tools) if disabled_tools else []

View file

@ -9,31 +9,12 @@ This middleware layers OpenCode's wildcard-ruleset model on top of the
unified langchain HITL wire format (see :mod:`hitl_wire`), so it sits
beside ``HumanInTheLoopMiddleware`` and self-gated approvals on a single
parallel-HITL routing layer in ``task_tool`` + ``resume_routing``.
Per-tool-call flow inside :meth:`_process`:
1. Skip when the last message has no tool calls.
2. For each call, evaluate the rules. ``deny`` is replaced with a
synthetic :class:`ToolMessage` carrying a typed
:class:`StreamingError`. ``ask`` raises an interrupt via
:mod:`interrupt.request`; the resulting decision is dispatched here:
- ``once`` → keep the call as-is.
- ``always`` → also extend the runtime ruleset.
- ``reject`` (with feedback) → :class:`CorrectedError`.
- ``reject`` (no feedback) → :class:`RejectedError`.
``allow`` keeps the call unchanged.
3. Returns an updated ``AIMessage`` (tool calls minus the denied ones)
plus any deny ``ToolMessage`` entries appended after it. Tool-list
filtering at ``before_model`` is intentionally not done here, because
that would invalidate provider prompt-cache prefixes.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import Any
from langchain.agents.middleware.types import (
@ -47,6 +28,7 @@ from langgraph.runtime import Runtime
from app.agents.new_chat.errors import CorrectedError, RejectedError
from app.agents.new_chat.permissions import Ruleset
from app.services.user_tool_allowlist import TrustedToolSaver
from ..ask.edit import merge_edited_args
from ..ask.request import request_permission_decision
@ -59,6 +41,14 @@ from .runtime_promote import persist_always
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class _AlwaysPromotion:
"""A pending request to save an ``always`` decision to the user's trust list."""
# Value of the tool's ``mcp_connector_id`` metadata entry.
connector_id: int
# Name of the tool the ``always`` decision was made for.
tool_name: str
class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
"""Allow/deny/ask layer over the agent's tool calls.
@ -76,6 +66,10 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
tools_by_name: Map from tool name to :class:`BaseTool`, used to
decorate ``ask`` interrupts with the tool's description and
MCP metadata for the FE card.
trusted_tool_saver: Async callback invoked on ``always`` decisions
for MCP tools (those whose ``metadata`` carries an
``mcp_connector_id``). Without it the promotion only lives
in-memory for the current agent instance.
"""
tools = ()
@ -88,6 +82,7 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
runtime_ruleset: Ruleset | None = None,
always_emit_interrupt_payload: bool = True,
tools_by_name: dict[str, BaseTool] | None = None,
trusted_tool_saver: TrustedToolSaver | None = None,
) -> None:
super().__init__()
self._static_rulesets: list[Ruleset] = list(rulesets or [])
@ -99,23 +94,31 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
)
self._emit_interrupt = always_emit_interrupt_payload
self._tools_by_name: dict[str, BaseTool] = dict(tools_by_name or {})
self._trusted_tool_saver: TrustedToolSaver | None = trusted_tool_saver
def _process(
self,
state: AgentState,
runtime: Runtime[Any],
) -> dict[str, Any] | None:
) -> tuple[dict[str, Any] | None, list[_AlwaysPromotion]]:
"""Pure decision pass: returns ``(state_update, pending_promotions)``.
Side effects performed here are in-memory only (rule promotion
into ``runtime_ruleset``). DB writes for ``always`` decisions
are queued as ``_AlwaysPromotion`` and flushed by the async hook.
"""
del runtime
messages = state.get("messages") or []
if not messages:
return None
return None, []
last = messages[-1]
if not isinstance(last, AIMessage) or not last.tool_calls:
return None
return None, []
rulesets = all_rulesets(self._static_rulesets, self._runtime_ruleset)
deny_messages: list[ToolMessage] = []
kept_calls: list[dict[str, Any]] = []
promotions: list[_AlwaysPromotion] = []
any_change = False
for raw in last.tool_calls:
@ -162,6 +165,9 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
any_change = True
if kind == "always":
persist_always(self._runtime_ruleset, name, patterns)
promotion = self._build_always_promotion(name)
if promotion is not None:
promotions.append(promotion)
kept_calls.append(final_call)
elif kind == "reject":
feedback = decision.get("feedback")
@ -180,23 +186,39 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
kept_calls.append(call)
if not any_change and len(kept_calls) == len(last.tool_calls):
return None
return None, promotions
updated = last.model_copy(update={"tool_calls": kept_calls})
result_messages: list[Any] = [updated]
if deny_messages:
result_messages.extend(deny_messages)
return {"messages": result_messages}
return {"messages": result_messages}, promotions
def _build_always_promotion(self, tool_name: str) -> _AlwaysPromotion | None:
    """Return a save request iff the tool exposes an integer ``mcp_connector_id``.

    Looks ``tool_name`` up in ``self._tools_by_name`` and reads
    ``mcp_connector_id`` from that tool's ``metadata``.

    Args:
        tool_name: Name of the tool whose ``always`` decision was just made.

    Returns:
        An ``_AlwaysPromotion`` for ``(connector_id, tool_name)``, or
        ``None`` when the tool is unknown, carries no metadata, or its
        ``mcp_connector_id`` is not a genuine integer.
    """
    tool = self._tools_by_name.get(tool_name)
    metadata = getattr(tool, "metadata", None) or {}
    connector_id = metadata.get("mcp_connector_id")
    # ``bool`` subclasses ``int``; without the explicit exclusion a stray
    # ``True``/``False`` would be queued as connector id 1/0.
    if isinstance(connector_id, bool) or not isinstance(connector_id, int):
        return None
    return _AlwaysPromotion(connector_id=connector_id, tool_name=tool_name)
def after_model( # type: ignore[override]
self, state: AgentState, runtime: Runtime[ContextT]
) -> dict[str, Any] | None:
return self._process(state, runtime)
update, _ = self._process(state, runtime)
return update
async def aafter_model( # type: ignore[override]
self, state: AgentState, runtime: Runtime[ContextT]
) -> dict[str, Any] | None:
return self._process(state, runtime)
update, promotions = self._process(state, runtime)
if self._trusted_tool_saver is not None:
for promotion in promotions:
await self._trusted_tool_saver(
promotion.connector_id, promotion.tool_name
)
return update
__all__ = ["PermissionMiddleware"]

View file

@ -29,6 +29,7 @@ from langchain_core.tools import BaseTool
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.agents.new_chat.permissions import Rule, Ruleset
from app.services.user_tool_allowlist import TrustedToolSaver
from .core import PermissionMiddleware
@ -43,6 +44,7 @@ def build_permission_mw(
flags: AgentFeatureFlags,
subagent_rulesets: list[Ruleset] | None = None,
tools: Sequence[BaseTool] | None = None,
trusted_tool_saver: TrustedToolSaver | None = None,
) -> PermissionMiddleware | None:
"""Return a configured :class:`PermissionMiddleware` or ``None`` when no work is needed.
@ -58,6 +60,9 @@ def build_permission_mw(
an explicit ``ask`` rule always asks.
tools: Subagent tools used to decorate ``ask`` interrupts with
FE-card metadata (description, MCP connector). Optional.
trusted_tool_saver: Async callback invoked when an MCP tool's
``always`` decision lands; persists the user's preference to
``connector.config['trusted_tools']``. Optional.
Returns:
``None`` when the engine has no rules to enforce
@ -73,7 +78,11 @@ def build_permission_mw(
if subagent_rulesets:
rulesets.extend(subagent_rulesets)
tools_by_name = {t.name: t for t in (tools or [])}
return PermissionMiddleware(rulesets=rulesets, tools_by_name=tools_by_name)
return PermissionMiddleware(
rulesets=rulesets,
tools_by_name=tools_by_name,
trusted_tool_saver=trusted_tool_saver,
)
__all__ = ["build_permission_mw"]

View file

@ -93,7 +93,11 @@ def build_kb_middleware(
user_allowlist = _kb_user_allowlist(dependencies, subagent_name)
if user_allowlist is not None:
rulesets.append(user_allowlist)
permission_mw = build_permission_mw(flags=flags, subagent_rulesets=rulesets)
permission_mw = build_permission_mw(
flags=flags,
subagent_rulesets=rulesets,
trusted_tool_saver=dependencies.get("trusted_tool_saver"),
)
return [
mws["todos"],
build_kb_context_projection_mw(),

View file

@ -74,7 +74,10 @@ def pack_subagent(
if user_allowlist is not None:
subagent_rulesets.append(user_allowlist)
per_subagent_perm = build_permission_mw(
flags=flags, subagent_rulesets=subagent_rulesets, tools=tools
flags=flags,
subagent_rulesets=subagent_rulesets,
tools=tools,
trusted_tool_saver=dependencies.get("trusted_tool_saver"),
)
prepended: list[Any] = []