multi_agent_chat/permissions: persist 'always' decisions to trusted-tools list

Until now an "Always Allow" reply only updated the in-memory runtime
ruleset, evaporating after the session ended. Persist it to the
existing connector.config['trusted_tools'] list so the next session's
fetch_user_allowlist_rulesets picks it up and the user is never asked
again for the same (connector, tool) pair.

- TrustedToolSaver + make_trusted_tool_saver(user_id) in
  user_tool_allowlist: opens its own session via async_session_maker
  per call, logs and swallows failures (in-memory promotion is the
  canonical "always" path, durable persistence is opportunistic).

- PermissionMiddleware._process is now pure: returns
  (state_update, list[_AlwaysPromotion]). aafter_model awaits the
  saver for each promotion; after_model discards them. Promotions are
  only emitted for tools whose metadata exposes mcp_connector_id, so
  native tools and KB FS ops are correctly skipped.

- main_agent factory builds the saver once per turn and stashes it in
  dependencies["trusted_tool_saver"]; pack_subagent and the KB
  middleware stack forward it through build_permission_mw.

- Renamed pm._process(state, None) call sites in two existing tests to
  pm.after_model(state, None) so they exercise the public hook
  contract instead of the now-tuple-returning private method.
This commit is contained in:
CREDO23 2026-05-15 14:07:08 +02:00
parent a97d1548a6
commit 6671c91841
9 changed files with 323 additions and 103 deletions

View file

@ -9,31 +9,12 @@ This middleware layers OpenCode's wildcard-ruleset model on top of the
unified langchain HITL wire format (see :mod:`hitl_wire`), so it sits
beside ``HumanInTheLoopMiddleware`` and self-gated approvals on a single
parallel-HITL routing layer in ``task_tool`` + ``resume_routing``.
Per-tool-call flow inside :meth:`_process`:
1. Skip when the last message has no tool calls.
2. For each call, evaluate the rules. ``deny`` is replaced with a
synthetic :class:`ToolMessage` carrying a typed
:class:`StreamingError`. ``ask`` raises an interrupt via
:mod:`interrupt.request`; the resulting decision is dispatched here:
- ``once`` keep the call as-is.
- ``always`` also extend the runtime ruleset.
- ``reject`` (with feedback) :class:`CorrectedError`.
- ``reject`` (no feedback) :class:`RejectedError`.
``allow`` keeps the call unchanged.
3. Returns an updated ``AIMessage`` (tool calls minus the denied ones)
plus any deny ``ToolMessage`` entries appended after it. Tool-list
filtering at ``before_model`` is intentionally not done here that
would invalidate provider prompt-cache prefixes.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import Any
from langchain.agents.middleware.types import (
@ -47,6 +28,7 @@ from langgraph.runtime import Runtime
from app.agents.new_chat.errors import CorrectedError, RejectedError
from app.agents.new_chat.permissions import Ruleset
from app.services.user_tool_allowlist import TrustedToolSaver
from ..ask.edit import merge_edited_args
from ..ask.request import request_permission_decision
@ -59,6 +41,14 @@ from .runtime_promote import persist_always
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class _AlwaysPromotion:
"""A pending request to save an ``always`` decision to the user's trust list."""
connector_id: int
tool_name: str
class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
"""Allow/deny/ask layer over the agent's tool calls.
@ -76,6 +66,10 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
tools_by_name: Map from tool name to :class:`BaseTool`, used to
decorate ``ask`` interrupts with the tool's description and
MCP metadata for the FE card.
trusted_tool_saver: Async callback invoked on ``always`` decisions
for MCP tools (those whose ``metadata`` carries an
``mcp_connector_id``). Without it the promotion only lives
in-memory for the current agent instance.
"""
tools = ()
@ -88,6 +82,7 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
runtime_ruleset: Ruleset | None = None,
always_emit_interrupt_payload: bool = True,
tools_by_name: dict[str, BaseTool] | None = None,
trusted_tool_saver: TrustedToolSaver | None = None,
) -> None:
super().__init__()
self._static_rulesets: list[Ruleset] = list(rulesets or [])
@ -99,23 +94,31 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
)
self._emit_interrupt = always_emit_interrupt_payload
self._tools_by_name: dict[str, BaseTool] = dict(tools_by_name or {})
self._trusted_tool_saver: TrustedToolSaver | None = trusted_tool_saver
def _process(
self,
state: AgentState,
runtime: Runtime[Any],
) -> dict[str, Any] | None:
) -> tuple[dict[str, Any] | None, list[_AlwaysPromotion]]:
"""Pure decision pass: returns ``(state_update, pending_promotions)``.
Side effects performed here are in-memory only (rule promotion
into ``runtime_ruleset``). DB writes for ``always`` decisions
are queued as ``_AlwaysPromotion`` and flushed by the async hook.
"""
del runtime
messages = state.get("messages") or []
if not messages:
return None
return None, []
last = messages[-1]
if not isinstance(last, AIMessage) or not last.tool_calls:
return None
return None, []
rulesets = all_rulesets(self._static_rulesets, self._runtime_ruleset)
deny_messages: list[ToolMessage] = []
kept_calls: list[dict[str, Any]] = []
promotions: list[_AlwaysPromotion] = []
any_change = False
for raw in last.tool_calls:
@ -162,6 +165,9 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
any_change = True
if kind == "always":
persist_always(self._runtime_ruleset, name, patterns)
promotion = self._build_always_promotion(name)
if promotion is not None:
promotions.append(promotion)
kept_calls.append(final_call)
elif kind == "reject":
feedback = decision.get("feedback")
@ -180,23 +186,39 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
kept_calls.append(call)
if not any_change and len(kept_calls) == len(last.tool_calls):
return None
return None, promotions
updated = last.model_copy(update={"tool_calls": kept_calls})
result_messages: list[Any] = [updated]
if deny_messages:
result_messages.extend(deny_messages)
return {"messages": result_messages}
return {"messages": result_messages}, promotions
def _build_always_promotion(self, tool_name: str) -> _AlwaysPromotion | None:
"""Return a save request iff the tool exposes an ``mcp_connector_id``."""
tool = self._tools_by_name.get(tool_name)
metadata = getattr(tool, "metadata", None) or {}
connector_id = metadata.get("mcp_connector_id")
if not isinstance(connector_id, int):
return None
return _AlwaysPromotion(connector_id=connector_id, tool_name=tool_name)
def after_model( # type: ignore[override]
self, state: AgentState, runtime: Runtime[ContextT]
) -> dict[str, Any] | None:
return self._process(state, runtime)
update, _ = self._process(state, runtime)
return update
async def aafter_model( # type: ignore[override]
self, state: AgentState, runtime: Runtime[ContextT]
) -> dict[str, Any] | None:
return self._process(state, runtime)
update, promotions = self._process(state, runtime)
if self._trusted_tool_saver is not None:
for promotion in promotions:
await self._trusted_tool_saver(
promotion.connector_id, promotion.tool_name
)
return update
__all__ = ["PermissionMiddleware"]

View file

@ -29,6 +29,7 @@ from langchain_core.tools import BaseTool
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.agents.new_chat.permissions import Rule, Ruleset
from app.services.user_tool_allowlist import TrustedToolSaver
from .core import PermissionMiddleware
@ -43,6 +44,7 @@ def build_permission_mw(
flags: AgentFeatureFlags,
subagent_rulesets: list[Ruleset] | None = None,
tools: Sequence[BaseTool] | None = None,
trusted_tool_saver: TrustedToolSaver | None = None,
) -> PermissionMiddleware | None:
"""Return a configured :class:`PermissionMiddleware` or ``None`` when no work is needed.
@ -58,6 +60,9 @@ def build_permission_mw(
an explicit ``ask`` rule always asks.
tools: Subagent tools used to decorate ``ask`` interrupts with
FE-card metadata (description, MCP connector). Optional.
trusted_tool_saver: Async callback invoked when an MCP tool's
``always`` decision lands; persists the user's preference to
``connector.config['trusted_tools']``. Optional.
Returns:
``None`` when the engine has no rules to enforce
@ -73,7 +78,11 @@ def build_permission_mw(
if subagent_rulesets:
rulesets.extend(subagent_rulesets)
tools_by_name = {t.name: t for t in (tools or [])}
return PermissionMiddleware(rulesets=rulesets, tools_by_name=tools_by_name)
return PermissionMiddleware(
rulesets=rulesets,
tools_by_name=tools_by_name,
trusted_tool_saver=trusted_tool_saver,
)
__all__ = ["build_permission_mw"]