mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-19 18:45:15 +02:00
multi_agent_chat/permissions: persist 'always' decisions to trusted-tools list
Until now an "Always Allow" reply only updated the in-memory runtime ruleset, evaporating after the session ended. Persist it to the existing connector.config['trusted_tools'] list so the next session's fetch_user_allowlist_rulesets picks it up and the user is never asked again for the same (connector, tool) pair. - TrustedToolSaver + make_trusted_tool_saver(user_id) in user_tool_allowlist: opens its own session via async_session_maker per call, logs and swallows failures (in-memory promotion is the canonical "always" path, durable persistence is opportunistic). - PermissionMiddleware._process is now pure: returns (state_update, list[_AlwaysPromotion]). aafter_model awaits the saver for each promotion; after_model discards them. Promotions are only emitted for tools whose metadata exposes mcp_connector_id, so native tools and KB FS ops are correctly skipped. - main_agent factory builds the saver once per turn and stashes it in dependencies["trusted_tool_saver"]; pack_subagent and the KB middleware stack forward it through build_permission_mw. - Renamed pm._process(state, None) call sites in two existing tests to pm.after_model(state, None) so they exercise the public hook contract instead of the now-tuple-returning private method.
This commit is contained in:
parent
a97d1548a6
commit
6671c91841
9 changed files with 323 additions and 103 deletions
|
|
@ -9,31 +9,12 @@ This middleware layers OpenCode's wildcard-ruleset model on top of the
|
|||
unified langchain HITL wire format (see :mod:`hitl_wire`), so it sits
|
||||
beside ``HumanInTheLoopMiddleware`` and self-gated approvals on a single
|
||||
parallel-HITL routing layer in ``task_tool`` + ``resume_routing``.
|
||||
|
||||
Per-tool-call flow inside :meth:`_process`:
|
||||
|
||||
1. Skip when the last message has no tool calls.
|
||||
2. For each call, evaluate the rules. ``deny`` is replaced with a
|
||||
synthetic :class:`ToolMessage` carrying a typed
|
||||
:class:`StreamingError`. ``ask`` raises an interrupt via
|
||||
:mod:`interrupt.request`; the resulting decision is dispatched here:
|
||||
|
||||
- ``once`` → keep the call as-is.
|
||||
- ``always`` → also extend the runtime ruleset.
|
||||
- ``reject`` (with feedback) → :class:`CorrectedError`.
|
||||
- ``reject`` (no feedback) → :class:`RejectedError`.
|
||||
|
||||
``allow`` keeps the call unchanged.
|
||||
|
||||
3. Returns an updated ``AIMessage`` (tool calls minus the denied ones)
|
||||
plus any deny ``ToolMessage`` entries appended after it. Tool-list
|
||||
filtering at ``before_model`` is intentionally not done here — that
|
||||
would invalidate provider prompt-cache prefixes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
|
|
@ -47,6 +28,7 @@ from langgraph.runtime import Runtime
|
|||
|
||||
from app.agents.new_chat.errors import CorrectedError, RejectedError
|
||||
from app.agents.new_chat.permissions import Ruleset
|
||||
from app.services.user_tool_allowlist import TrustedToolSaver
|
||||
|
||||
from ..ask.edit import merge_edited_args
|
||||
from ..ask.request import request_permission_decision
|
||||
|
|
@ -59,6 +41,14 @@ from .runtime_promote import persist_always
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _AlwaysPromotion:
|
||||
"""A pending request to save an ``always`` decision to the user's trust list."""
|
||||
|
||||
connector_id: int
|
||||
tool_name: str
|
||||
|
||||
|
||||
class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||
"""Allow/deny/ask layer over the agent's tool calls.
|
||||
|
||||
|
|
@ -76,6 +66,10 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
tools_by_name: Map from tool name to :class:`BaseTool`, used to
|
||||
decorate ``ask`` interrupts with the tool's description and
|
||||
MCP metadata for the FE card.
|
||||
trusted_tool_saver: Async callback invoked on ``always`` decisions
|
||||
for MCP tools (those whose ``metadata`` carries an
|
||||
``mcp_connector_id``). Without it the promotion only lives
|
||||
in-memory for the current agent instance.
|
||||
"""
|
||||
|
||||
tools = ()
|
||||
|
|
@ -88,6 +82,7 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
runtime_ruleset: Ruleset | None = None,
|
||||
always_emit_interrupt_payload: bool = True,
|
||||
tools_by_name: dict[str, BaseTool] | None = None,
|
||||
trusted_tool_saver: TrustedToolSaver | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._static_rulesets: list[Ruleset] = list(rulesets or [])
|
||||
|
|
@ -99,23 +94,31 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
)
|
||||
self._emit_interrupt = always_emit_interrupt_payload
|
||||
self._tools_by_name: dict[str, BaseTool] = dict(tools_by_name or {})
|
||||
self._trusted_tool_saver: TrustedToolSaver | None = trusted_tool_saver
|
||||
|
||||
def _process(
|
||||
self,
|
||||
state: AgentState,
|
||||
runtime: Runtime[Any],
|
||||
) -> dict[str, Any] | None:
|
||||
) -> tuple[dict[str, Any] | None, list[_AlwaysPromotion]]:
|
||||
"""Pure decision pass: returns ``(state_update, pending_promotions)``.
|
||||
|
||||
Side effects performed here are in-memory only (rule promotion
|
||||
into ``runtime_ruleset``). DB writes for ``always`` decisions
|
||||
are queued as ``_AlwaysPromotion`` and flushed by the async hook.
|
||||
"""
|
||||
del runtime
|
||||
messages = state.get("messages") or []
|
||||
if not messages:
|
||||
return None
|
||||
return None, []
|
||||
last = messages[-1]
|
||||
if not isinstance(last, AIMessage) or not last.tool_calls:
|
||||
return None
|
||||
return None, []
|
||||
|
||||
rulesets = all_rulesets(self._static_rulesets, self._runtime_ruleset)
|
||||
deny_messages: list[ToolMessage] = []
|
||||
kept_calls: list[dict[str, Any]] = []
|
||||
promotions: list[_AlwaysPromotion] = []
|
||||
any_change = False
|
||||
|
||||
for raw in last.tool_calls:
|
||||
|
|
@ -162,6 +165,9 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
any_change = True
|
||||
if kind == "always":
|
||||
persist_always(self._runtime_ruleset, name, patterns)
|
||||
promotion = self._build_always_promotion(name)
|
||||
if promotion is not None:
|
||||
promotions.append(promotion)
|
||||
kept_calls.append(final_call)
|
||||
elif kind == "reject":
|
||||
feedback = decision.get("feedback")
|
||||
|
|
@ -180,23 +186,39 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
kept_calls.append(call)
|
||||
|
||||
if not any_change and len(kept_calls) == len(last.tool_calls):
|
||||
return None
|
||||
return None, promotions
|
||||
|
||||
updated = last.model_copy(update={"tool_calls": kept_calls})
|
||||
result_messages: list[Any] = [updated]
|
||||
if deny_messages:
|
||||
result_messages.extend(deny_messages)
|
||||
return {"messages": result_messages}
|
||||
return {"messages": result_messages}, promotions
|
||||
|
||||
def _build_always_promotion(self, tool_name: str) -> _AlwaysPromotion | None:
|
||||
"""Return a save request iff the tool exposes an ``mcp_connector_id``."""
|
||||
tool = self._tools_by_name.get(tool_name)
|
||||
metadata = getattr(tool, "metadata", None) or {}
|
||||
connector_id = metadata.get("mcp_connector_id")
|
||||
if not isinstance(connector_id, int):
|
||||
return None
|
||||
return _AlwaysPromotion(connector_id=connector_id, tool_name=tool_name)
|
||||
|
||||
def after_model( # type: ignore[override]
|
||||
self, state: AgentState, runtime: Runtime[ContextT]
|
||||
) -> dict[str, Any] | None:
|
||||
return self._process(state, runtime)
|
||||
update, _ = self._process(state, runtime)
|
||||
return update
|
||||
|
||||
async def aafter_model( # type: ignore[override]
|
||||
self, state: AgentState, runtime: Runtime[ContextT]
|
||||
) -> dict[str, Any] | None:
|
||||
return self._process(state, runtime)
|
||||
update, promotions = self._process(state, runtime)
|
||||
if self._trusted_tool_saver is not None:
|
||||
for promotion in promotions:
|
||||
await self._trusted_tool_saver(
|
||||
promotion.connector_id, promotion.tool_name
|
||||
)
|
||||
return update
|
||||
|
||||
|
||||
__all__ = ["PermissionMiddleware"]
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ from langchain_core.tools import BaseTool
|
|||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.permissions import Rule, Ruleset
|
||||
from app.services.user_tool_allowlist import TrustedToolSaver
|
||||
|
||||
from .core import PermissionMiddleware
|
||||
|
||||
|
|
@ -43,6 +44,7 @@ def build_permission_mw(
|
|||
flags: AgentFeatureFlags,
|
||||
subagent_rulesets: list[Ruleset] | None = None,
|
||||
tools: Sequence[BaseTool] | None = None,
|
||||
trusted_tool_saver: TrustedToolSaver | None = None,
|
||||
) -> PermissionMiddleware | None:
|
||||
"""Return a configured :class:`PermissionMiddleware` or ``None`` when no work is needed.
|
||||
|
||||
|
|
@ -58,6 +60,9 @@ def build_permission_mw(
|
|||
an explicit ``ask`` rule always asks.
|
||||
tools: Subagent tools used to decorate ``ask`` interrupts with
|
||||
FE-card metadata (description, MCP connector). Optional.
|
||||
trusted_tool_saver: Async callback invoked when an MCP tool's
|
||||
``always`` decision lands; persists the user's preference to
|
||||
``connector.config['trusted_tools']``. Optional.
|
||||
|
||||
Returns:
|
||||
``None`` when the engine has no rules to enforce
|
||||
|
|
@ -73,7 +78,11 @@ def build_permission_mw(
|
|||
if subagent_rulesets:
|
||||
rulesets.extend(subagent_rulesets)
|
||||
tools_by_name = {t.name: t for t in (tools or [])}
|
||||
return PermissionMiddleware(rulesets=rulesets, tools_by_name=tools_by_name)
|
||||
return PermissionMiddleware(
|
||||
rulesets=rulesets,
|
||||
tools_by_name=tools_by_name,
|
||||
trusted_tool_saver=trusted_tool_saver,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["build_permission_mw"]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue