mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-17 18:35:19 +02:00
multi_agent_chat/permissions: persist 'always' decisions to trusted-tools list
Until now an "Always Allow" reply only updated the in-memory runtime ruleset, evaporating after the session ended. Persist it to the existing connector.config['trusted_tools'] list so the next session's fetch_user_allowlist_rulesets picks it up and the user is never asked again for the same (connector, tool) pair. - TrustedToolSaver + make_trusted_tool_saver(user_id) in user_tool_allowlist: opens its own session via async_session_maker per call, logs and swallows failures (in-memory promotion is the canonical "always" path, durable persistence is opportunistic). - PermissionMiddleware._process is now pure: returns (state_update, list[_AlwaysPromotion]). aafter_model awaits the saver for each promotion; after_model discards them. Promotions are only emitted for tools whose metadata exposes mcp_connector_id, so native tools and KB FS ops are correctly skipped. - main_agent factory builds the saver once per turn and stashes it in dependencies["trusted_tool_saver"]; pack_subagent and the KB middleware stack forward it through build_permission_mw. - Renamed pm._process(state, None) call sites in two existing tests to pm.after_model(state, None) so they exercise the public hook contract instead of the now-tuple-returning private method.
This commit is contained in:
parent
a97d1548a6
commit
6671c91841
9 changed files with 323 additions and 103 deletions
|
|
@ -29,7 +29,10 @@ from app.agents.new_chat.tools.invalid_tool import INVALID_TOOL_NAME, invalid_to
|
|||
from app.agents.new_chat.tools.registry import build_tools_async
|
||||
from app.db import ChatVisibility
|
||||
from app.services.connector_service import ConnectorService
|
||||
from app.services.user_tool_allowlist import fetch_user_allowlist_rulesets
|
||||
from app.services.user_tool_allowlist import (
|
||||
fetch_user_allowlist_rulesets,
|
||||
make_trusted_tool_saver,
|
||||
)
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
from ..system_prompt import build_main_agent_system_prompt
|
||||
|
|
@ -153,28 +156,37 @@ async def create_multi_agent_chat_deep_agent(
|
|||
# ``ask`` via last-match-wins. Anonymous turns and read failures both
|
||||
# degrade to "no user rules" rather than blocking the turn.
|
||||
user_allowlist_by_subagent: dict[str, Any] = {}
|
||||
trusted_tool_saver = None
|
||||
if user_id:
|
||||
_t0 = time.perf_counter()
|
||||
try:
|
||||
import uuid as _uuid
|
||||
|
||||
user_allowlist_by_subagent = await fetch_user_allowlist_rulesets(
|
||||
db_session,
|
||||
user_id=_uuid.UUID(user_id),
|
||||
search_space_id=search_space_id,
|
||||
user_uuid = _uuid.UUID(user_id)
|
||||
except (TypeError, ValueError):
|
||||
user_uuid = None
|
||||
|
||||
if user_uuid is not None:
|
||||
_t0 = time.perf_counter()
|
||||
try:
|
||||
user_allowlist_by_subagent = await fetch_user_allowlist_rulesets(
|
||||
db_session,
|
||||
user_id=user_uuid,
|
||||
search_space_id=search_space_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logging.warning(
|
||||
"User allow-list fetch failed; subagents will run without user trust rules this turn: %s",
|
||||
e,
|
||||
)
|
||||
user_allowlist_by_subagent = {}
|
||||
_perf_log.info(
|
||||
"[create_agent] fetch_user_allowlist_rulesets in %.3fs (%d subagents have rules)",
|
||||
time.perf_counter() - _t0,
|
||||
len(user_allowlist_by_subagent),
|
||||
)
|
||||
except Exception as e:
|
||||
logging.warning(
|
||||
"User allow-list fetch failed; subagents will run without user trust rules this turn: %s",
|
||||
e,
|
||||
)
|
||||
user_allowlist_by_subagent = {}
|
||||
_perf_log.info(
|
||||
"[create_agent] fetch_user_allowlist_rulesets in %.3fs (%d subagents have rules)",
|
||||
time.perf_counter() - _t0,
|
||||
len(user_allowlist_by_subagent),
|
||||
)
|
||||
trusted_tool_saver = make_trusted_tool_saver(user_uuid)
|
||||
dependencies["user_allowlist_by_subagent"] = user_allowlist_by_subagent
|
||||
dependencies["trusted_tool_saver"] = trusted_tool_saver
|
||||
|
||||
modified_disabled_tools = list(disabled_tools) if disabled_tools else []
|
||||
|
||||
|
|
|
|||
|
|
@ -9,31 +9,12 @@ This middleware layers OpenCode's wildcard-ruleset model on top of the
|
|||
unified langchain HITL wire format (see :mod:`hitl_wire`), so it sits
|
||||
beside ``HumanInTheLoopMiddleware`` and self-gated approvals on a single
|
||||
parallel-HITL routing layer in ``task_tool`` + ``resume_routing``.
|
||||
|
||||
Per-tool-call flow inside :meth:`_process`:
|
||||
|
||||
1. Skip when the last message has no tool calls.
|
||||
2. For each call, evaluate the rules. ``deny`` is replaced with a
|
||||
synthetic :class:`ToolMessage` carrying a typed
|
||||
:class:`StreamingError`. ``ask`` raises an interrupt via
|
||||
:mod:`interrupt.request`; the resulting decision is dispatched here:
|
||||
|
||||
- ``once`` → keep the call as-is.
|
||||
- ``always`` → also extend the runtime ruleset.
|
||||
- ``reject`` (with feedback) → :class:`CorrectedError`.
|
||||
- ``reject`` (no feedback) → :class:`RejectedError`.
|
||||
|
||||
``allow`` keeps the call unchanged.
|
||||
|
||||
3. Returns an updated ``AIMessage`` (tool calls minus the denied ones)
|
||||
plus any deny ``ToolMessage`` entries appended after it. Tool-list
|
||||
filtering at ``before_model`` is intentionally not done here — that
|
||||
would invalidate provider prompt-cache prefixes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
|
|
@ -47,6 +28,7 @@ from langgraph.runtime import Runtime
|
|||
|
||||
from app.agents.new_chat.errors import CorrectedError, RejectedError
|
||||
from app.agents.new_chat.permissions import Ruleset
|
||||
from app.services.user_tool_allowlist import TrustedToolSaver
|
||||
|
||||
from ..ask.edit import merge_edited_args
|
||||
from ..ask.request import request_permission_decision
|
||||
|
|
@ -59,6 +41,14 @@ from .runtime_promote import persist_always
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _AlwaysPromotion:
|
||||
"""A pending request to save an ``always`` decision to the user's trust list."""
|
||||
|
||||
connector_id: int
|
||||
tool_name: str
|
||||
|
||||
|
||||
class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||
"""Allow/deny/ask layer over the agent's tool calls.
|
||||
|
||||
|
|
@ -76,6 +66,10 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
tools_by_name: Map from tool name to :class:`BaseTool`, used to
|
||||
decorate ``ask`` interrupts with the tool's description and
|
||||
MCP metadata for the FE card.
|
||||
trusted_tool_saver: Async callback invoked on ``always`` decisions
|
||||
for MCP tools (those whose ``metadata`` carries an
|
||||
``mcp_connector_id``). Without it the promotion only lives
|
||||
in-memory for the current agent instance.
|
||||
"""
|
||||
|
||||
tools = ()
|
||||
|
|
@ -88,6 +82,7 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
runtime_ruleset: Ruleset | None = None,
|
||||
always_emit_interrupt_payload: bool = True,
|
||||
tools_by_name: dict[str, BaseTool] | None = None,
|
||||
trusted_tool_saver: TrustedToolSaver | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._static_rulesets: list[Ruleset] = list(rulesets or [])
|
||||
|
|
@ -99,23 +94,31 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
)
|
||||
self._emit_interrupt = always_emit_interrupt_payload
|
||||
self._tools_by_name: dict[str, BaseTool] = dict(tools_by_name or {})
|
||||
self._trusted_tool_saver: TrustedToolSaver | None = trusted_tool_saver
|
||||
|
||||
def _process(
|
||||
self,
|
||||
state: AgentState,
|
||||
runtime: Runtime[Any],
|
||||
) -> dict[str, Any] | None:
|
||||
) -> tuple[dict[str, Any] | None, list[_AlwaysPromotion]]:
|
||||
"""Pure decision pass: returns ``(state_update, pending_promotions)``.
|
||||
|
||||
Side effects performed here are in-memory only (rule promotion
|
||||
into ``runtime_ruleset``). DB writes for ``always`` decisions
|
||||
are queued as ``_AlwaysPromotion`` and flushed by the async hook.
|
||||
"""
|
||||
del runtime
|
||||
messages = state.get("messages") or []
|
||||
if not messages:
|
||||
return None
|
||||
return None, []
|
||||
last = messages[-1]
|
||||
if not isinstance(last, AIMessage) or not last.tool_calls:
|
||||
return None
|
||||
return None, []
|
||||
|
||||
rulesets = all_rulesets(self._static_rulesets, self._runtime_ruleset)
|
||||
deny_messages: list[ToolMessage] = []
|
||||
kept_calls: list[dict[str, Any]] = []
|
||||
promotions: list[_AlwaysPromotion] = []
|
||||
any_change = False
|
||||
|
||||
for raw in last.tool_calls:
|
||||
|
|
@ -162,6 +165,9 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
any_change = True
|
||||
if kind == "always":
|
||||
persist_always(self._runtime_ruleset, name, patterns)
|
||||
promotion = self._build_always_promotion(name)
|
||||
if promotion is not None:
|
||||
promotions.append(promotion)
|
||||
kept_calls.append(final_call)
|
||||
elif kind == "reject":
|
||||
feedback = decision.get("feedback")
|
||||
|
|
@ -180,23 +186,39 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
kept_calls.append(call)
|
||||
|
||||
if not any_change and len(kept_calls) == len(last.tool_calls):
|
||||
return None
|
||||
return None, promotions
|
||||
|
||||
updated = last.model_copy(update={"tool_calls": kept_calls})
|
||||
result_messages: list[Any] = [updated]
|
||||
if deny_messages:
|
||||
result_messages.extend(deny_messages)
|
||||
return {"messages": result_messages}
|
||||
return {"messages": result_messages}, promotions
|
||||
|
||||
def _build_always_promotion(self, tool_name: str) -> _AlwaysPromotion | None:
|
||||
"""Return a save request iff the tool exposes an ``mcp_connector_id``."""
|
||||
tool = self._tools_by_name.get(tool_name)
|
||||
metadata = getattr(tool, "metadata", None) or {}
|
||||
connector_id = metadata.get("mcp_connector_id")
|
||||
if not isinstance(connector_id, int):
|
||||
return None
|
||||
return _AlwaysPromotion(connector_id=connector_id, tool_name=tool_name)
|
||||
|
||||
def after_model( # type: ignore[override]
|
||||
self, state: AgentState, runtime: Runtime[ContextT]
|
||||
) -> dict[str, Any] | None:
|
||||
return self._process(state, runtime)
|
||||
update, _ = self._process(state, runtime)
|
||||
return update
|
||||
|
||||
async def aafter_model( # type: ignore[override]
|
||||
self, state: AgentState, runtime: Runtime[ContextT]
|
||||
) -> dict[str, Any] | None:
|
||||
return self._process(state, runtime)
|
||||
update, promotions = self._process(state, runtime)
|
||||
if self._trusted_tool_saver is not None:
|
||||
for promotion in promotions:
|
||||
await self._trusted_tool_saver(
|
||||
promotion.connector_id, promotion.tool_name
|
||||
)
|
||||
return update
|
||||
|
||||
|
||||
__all__ = ["PermissionMiddleware"]
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ from langchain_core.tools import BaseTool
|
|||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.permissions import Rule, Ruleset
|
||||
from app.services.user_tool_allowlist import TrustedToolSaver
|
||||
|
||||
from .core import PermissionMiddleware
|
||||
|
||||
|
|
@ -43,6 +44,7 @@ def build_permission_mw(
|
|||
flags: AgentFeatureFlags,
|
||||
subagent_rulesets: list[Ruleset] | None = None,
|
||||
tools: Sequence[BaseTool] | None = None,
|
||||
trusted_tool_saver: TrustedToolSaver | None = None,
|
||||
) -> PermissionMiddleware | None:
|
||||
"""Return a configured :class:`PermissionMiddleware` or ``None`` when no work is needed.
|
||||
|
||||
|
|
@ -58,6 +60,9 @@ def build_permission_mw(
|
|||
an explicit ``ask`` rule always asks.
|
||||
tools: Subagent tools used to decorate ``ask`` interrupts with
|
||||
FE-card metadata (description, MCP connector). Optional.
|
||||
trusted_tool_saver: Async callback invoked when an MCP tool's
|
||||
``always`` decision lands; persists the user's preference to
|
||||
``connector.config['trusted_tools']``. Optional.
|
||||
|
||||
Returns:
|
||||
``None`` when the engine has no rules to enforce
|
||||
|
|
@ -73,7 +78,11 @@ def build_permission_mw(
|
|||
if subagent_rulesets:
|
||||
rulesets.extend(subagent_rulesets)
|
||||
tools_by_name = {t.name: t for t in (tools or [])}
|
||||
return PermissionMiddleware(rulesets=rulesets, tools_by_name=tools_by_name)
|
||||
return PermissionMiddleware(
|
||||
rulesets=rulesets,
|
||||
tools_by_name=tools_by_name,
|
||||
trusted_tool_saver=trusted_tool_saver,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["build_permission_mw"]
|
||||
|
|
|
|||
|
|
@ -93,7 +93,11 @@ def build_kb_middleware(
|
|||
user_allowlist = _kb_user_allowlist(dependencies, subagent_name)
|
||||
if user_allowlist is not None:
|
||||
rulesets.append(user_allowlist)
|
||||
permission_mw = build_permission_mw(flags=flags, subagent_rulesets=rulesets)
|
||||
permission_mw = build_permission_mw(
|
||||
flags=flags,
|
||||
subagent_rulesets=rulesets,
|
||||
trusted_tool_saver=dependencies.get("trusted_tool_saver"),
|
||||
)
|
||||
return [
|
||||
mws["todos"],
|
||||
build_kb_context_projection_mw(),
|
||||
|
|
|
|||
|
|
@ -74,7 +74,10 @@ def pack_subagent(
|
|||
if user_allowlist is not None:
|
||||
subagent_rulesets.append(user_allowlist)
|
||||
per_subagent_perm = build_permission_mw(
|
||||
flags=flags, subagent_rulesets=subagent_rulesets, tools=tools
|
||||
flags=flags,
|
||||
subagent_rulesets=subagent_rulesets,
|
||||
tools=tools,
|
||||
trusted_tool_saver=dependencies.get("trusted_tool_saver"),
|
||||
)
|
||||
|
||||
prepended: list[Any] = []
|
||||
|
|
|
|||
|
|
@ -1,33 +1,16 @@
|
|||
"""User-scoped tool allow-list backed by ``SearchSourceConnector.config``.
|
||||
"""User-scoped trusted-tools list backed by ``SearchSourceConnector.config``.
|
||||
|
||||
Stores the user's "always allow" preferences as a list of tool names under
|
||||
``connector.config['trusted_tools']``. Storage is per
|
||||
``(user_id, search_space_id, connector_id)`` — i.e. tied to a specific
|
||||
connected account inside a specific workspace, exactly what the UI cares
|
||||
about.
|
||||
|
||||
Callers split into two roles:
|
||||
|
||||
- **Writers** — the ``/connectors/.../trust-tool`` and ``/untrust-tool``
|
||||
HTTP routes, and the chat resume handler when it processes a
|
||||
``{type: "always"}`` decision. Both call
|
||||
:func:`add_user_trust` / :func:`remove_user_trust`. The FE button is
|
||||
the upstream UI trigger but it talks to the routes, never to this
|
||||
module directly.
|
||||
- **Reader** — the subagent compile path, which calls
|
||||
:func:`fetch_user_allowlist_rulesets` and layers the result after the
|
||||
subagent's coded ruleset. User ``allow`` rules then override coded
|
||||
``ask`` via the rule engine's last-match-wins evaluation.
|
||||
|
||||
Coded ``deny`` rules are intentionally **not** overridable by this
|
||||
allow-list — only ``ask`` can be promoted to ``allow``. The rule engine
|
||||
enforces this naturally because user rules only ever emit ``allow``.
|
||||
Storage is per ``(user_id, search_space_id, connector_id)`` under
|
||||
``connector.config['trusted_tools']``. The list only ever encodes
|
||||
``allow`` decisions; coded ``deny`` rules cannot be overridden here.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
|
@ -37,10 +20,14 @@ from app.agents.multi_agent_chat.constants import (
|
|||
CONNECTOR_TYPE_TO_CONNECTOR_AGENT_MAPS,
|
||||
)
|
||||
from app.agents.new_chat.permissions import Rule, Ruleset
|
||||
from app.db import SearchSourceConnector
|
||||
from app.db import SearchSourceConnector, async_session_maker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_TRUSTED_TOOLS_KEY = "trusted_tools"
|
||||
|
||||
TrustedToolSaver = Callable[[int, str], Awaitable[None]]
|
||||
|
||||
|
||||
async def _load_owned_connector(
|
||||
session: AsyncSession,
|
||||
|
|
@ -48,11 +35,7 @@ async def _load_owned_connector(
|
|||
user_id: uuid.UUID,
|
||||
connector_id: int,
|
||||
) -> SearchSourceConnector | None:
|
||||
"""Return a connector iff it belongs to ``user_id``, else ``None``.
|
||||
|
||||
Ownership scoping is mandatory: the trust list mutates user-private
|
||||
data, callers must never write across user boundaries.
|
||||
"""
|
||||
"""Return the connector iff owned by ``user_id``, else ``None``."""
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == connector_id,
|
||||
|
|
@ -84,11 +67,7 @@ async def add_user_trust(
|
|||
connector_id: int,
|
||||
tool_name: str,
|
||||
) -> list[str]:
|
||||
"""Append ``tool_name`` to the connector's trusted list (idempotent).
|
||||
|
||||
Returns the updated trusted-tools list. Raises ``LookupError`` when
|
||||
the connector does not exist or is not owned by ``user_id``.
|
||||
"""
|
||||
"""Append ``tool_name`` to the connector's trusted list; raise ``LookupError`` if not owned."""
|
||||
connector = await _load_owned_connector(
|
||||
session, user_id=user_id, connector_id=connector_id
|
||||
)
|
||||
|
|
@ -112,11 +91,7 @@ async def remove_user_trust(
|
|||
connector_id: int,
|
||||
tool_name: str,
|
||||
) -> list[str]:
|
||||
"""Remove ``tool_name`` from the connector's trusted list (idempotent).
|
||||
|
||||
Returns the updated trusted-tools list. Raises ``LookupError`` when
|
||||
the connector does not exist or is not owned by ``user_id``.
|
||||
"""
|
||||
"""Remove ``tool_name`` from the connector's trusted list; raise ``LookupError`` if not owned."""
|
||||
connector = await _load_owned_connector(
|
||||
session, user_id=user_id, connector_id=connector_id
|
||||
)
|
||||
|
|
@ -139,20 +114,10 @@ async def fetch_user_allowlist_rulesets(
|
|||
user_id: uuid.UUID,
|
||||
search_space_id: int,
|
||||
) -> dict[str, Ruleset]:
|
||||
"""Project the user's trusted-tool lists into per-subagent rulesets.
|
||||
"""Project the user's trusted tools into per-subagent ``allow`` rulesets.
|
||||
|
||||
Walks every connector the user owns in this workspace, maps each
|
||||
``connector_type`` to its consuming subagent via
|
||||
:data:`CONNECTOR_TYPE_TO_CONNECTOR_AGENT_MAPS`, and emits one
|
||||
``Rule(permission=tool_name, pattern="*", action="allow")`` per
|
||||
trusted entry. Rules from different connector accounts feeding the
|
||||
same subagent (e.g. two Linear workspaces) are merged into one
|
||||
ruleset; duplicates are harmless under last-match-wins.
|
||||
|
||||
Connectors whose type is not mapped (search APIs, Github, etc.) and
|
||||
connectors with empty trust lists contribute nothing. Subagents
|
||||
with no trusted tools are absent from the returned dict — callers
|
||||
should treat ``missing == empty``.
|
||||
Subagents with no trusted tools are absent from the result —
|
||||
callers must treat ``missing == empty``.
|
||||
"""
|
||||
result = await session.execute(
|
||||
select(
|
||||
|
|
@ -189,8 +154,35 @@ async def fetch_user_allowlist_rulesets(
|
|||
}
|
||||
|
||||
|
||||
def make_trusted_tool_saver(user_id: uuid.UUID) -> TrustedToolSaver:
|
||||
"""Bind ``user_id`` into a saver closure; failures are logged, never raised."""
|
||||
|
||||
async def trusted_tool_saver(connector_id: int, tool_name: str) -> None:
|
||||
try:
|
||||
async with async_session_maker() as session:
|
||||
await add_user_trust(
|
||||
session,
|
||||
user_id=user_id,
|
||||
connector_id=connector_id,
|
||||
tool_name=tool_name,
|
||||
)
|
||||
await session.commit()
|
||||
except LookupError as exc:
|
||||
logger.warning("trusted-tool save skipped: %s", exc)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"trusted-tool save failed for connector=%s tool=%s",
|
||||
connector_id,
|
||||
tool_name,
|
||||
)
|
||||
|
||||
return trusted_tool_saver
|
||||
|
||||
|
||||
__all__ = [
|
||||
"TrustedToolSaver",
|
||||
"add_user_trust",
|
||||
"fetch_user_allowlist_rulesets",
|
||||
"make_trusted_tool_saver",
|
||||
"remove_user_trust",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -128,7 +128,7 @@ def _emit_tool_call(tool_name: str, args: dict[str, Any], call_id: str):
|
|||
|
||||
def _compile_graph_with(pm, tool_name: str, args: dict[str, Any], call_id: str):
|
||||
def after(state: _State) -> dict[str, Any] | None:
|
||||
return pm._process(state, None) # type: ignore[arg-type]
|
||||
return pm.after_model(state, None) # type: ignore[arg-type]
|
||||
|
||||
g = StateGraph(_State)
|
||||
g.add_node("emit", _emit_tool_call(tool_name, args, call_id))
|
||||
|
|
|
|||
|
|
@ -84,9 +84,7 @@ def _build_graph_with_permission_middleware(
|
|||
def after_node(state: _State) -> dict[str, Any] | None:
|
||||
if pm is None:
|
||||
return None
|
||||
# PermissionMiddleware._process ignores runtime; the test never relies
|
||||
# on the runtime context, so passing None keeps the harness lean.
|
||||
return pm._process(state, None) # type: ignore[arg-type]
|
||||
return pm.after_model(state, None) # type: ignore[arg-type]
|
||||
|
||||
g = StateGraph(_State)
|
||||
g.add_node("emit", node)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,180 @@
|
|||
"""``always`` decisions for MCP tools are saved via the trusted-tool saver."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated, Any
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.tools import StructuredTool
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.graph.message import add_messages
|
||||
from langgraph.types import Command
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from app.agents.multi_agent_chat.middleware.shared.permissions import (
|
||||
build_permission_mw,
|
||||
)
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.permissions import Rule, Ruleset
|
||||
|
||||
|
||||
class _NoArgs(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
async def _noop(**_kwargs) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
def _ask_rule(tool_name: str) -> Rule:
|
||||
return Rule(permission=tool_name, pattern="*", action="ask")
|
||||
|
||||
|
||||
def _make_mcp_tool(*, name: str, connector_id: int):
|
||||
return StructuredTool(
|
||||
name=name,
|
||||
description=f"Run {name} via MCP.",
|
||||
coroutine=_noop,
|
||||
args_schema=_NoArgs,
|
||||
metadata={
|
||||
"mcp_connector_id": connector_id,
|
||||
"mcp_connector_name": "Linear",
|
||||
"mcp_transport": "http",
|
||||
"hitl": True,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _make_native_tool(*, name: str):
|
||||
return StructuredTool(
|
||||
name=name,
|
||||
description=f"Native {name}.",
|
||||
coroutine=_noop,
|
||||
args_schema=_NoArgs,
|
||||
metadata={"hitl": True},
|
||||
)
|
||||
|
||||
|
||||
class _State(TypedDict, total=False):
|
||||
messages: Annotated[list, add_messages]
|
||||
|
||||
|
||||
def _build_graph(pm, tool_name: str):
|
||||
def emit(_state: _State) -> dict[str, Any]:
|
||||
return {
|
||||
"messages": [
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"name": tool_name,
|
||||
"args": {},
|
||||
"id": "call-1",
|
||||
"type": "tool_call",
|
||||
}
|
||||
],
|
||||
)
|
||||
]
|
||||
}
|
||||
|
||||
g = StateGraph(_State)
|
||||
g.add_node("emit", emit)
|
||||
g.add_node("permission", pm.aafter_model) # type: ignore[arg-type]
|
||||
g.add_edge(START, "emit")
|
||||
g.add_edge("emit", "permission")
|
||||
g.add_edge("permission", END)
|
||||
return g.compile(checkpointer=InMemorySaver())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_always_decision_saves_mcp_tool_via_callback():
|
||||
saved: list[tuple[int, str]] = []
|
||||
|
||||
async def trusted_tool_saver(connector_id: int, tool_name: str) -> None:
|
||||
saved.append((connector_id, tool_name))
|
||||
|
||||
tool = _make_mcp_tool(name="linear_create_issue", connector_id=7)
|
||||
pm = build_permission_mw(
|
||||
flags=AgentFeatureFlags(enable_permission=False),
|
||||
subagent_rulesets=[Ruleset(origin="linear", rules=[_ask_rule(tool.name)])],
|
||||
tools=[tool],
|
||||
trusted_tool_saver=trusted_tool_saver,
|
||||
)
|
||||
assert pm is not None
|
||||
|
||||
graph = _build_graph(pm, tool.name)
|
||||
config = {"configurable": {"thread_id": "always-mcp"}}
|
||||
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||
await graph.ainvoke(Command(resume={"decisions": [{"type": "always"}]}), config)
|
||||
|
||||
assert saved == [(7, "linear_create_issue")]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_once_decision_does_not_save():
|
||||
saved: list[tuple[int, str]] = []
|
||||
|
||||
async def trusted_tool_saver(connector_id: int, tool_name: str) -> None:
|
||||
saved.append((connector_id, tool_name))
|
||||
|
||||
tool = _make_mcp_tool(name="linear_create_issue", connector_id=7)
|
||||
pm = build_permission_mw(
|
||||
flags=AgentFeatureFlags(enable_permission=False),
|
||||
subagent_rulesets=[Ruleset(origin="linear", rules=[_ask_rule(tool.name)])],
|
||||
tools=[tool],
|
||||
trusted_tool_saver=trusted_tool_saver,
|
||||
)
|
||||
assert pm is not None
|
||||
|
||||
graph = _build_graph(pm, tool.name)
|
||||
config = {"configurable": {"thread_id": "once-mcp"}}
|
||||
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||
await graph.ainvoke(Command(resume={"decisions": [{"type": "approve"}]}), config)
|
||||
|
||||
assert saved == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_always_decision_for_native_tool_skips_save():
|
||||
"""Native tools have no ``mcp_connector_id`` so there is nowhere to persist trust."""
|
||||
saved: list[tuple[int, str]] = []
|
||||
|
||||
async def trusted_tool_saver(connector_id: int, tool_name: str) -> None:
|
||||
saved.append((connector_id, tool_name))
|
||||
|
||||
tool = _make_native_tool(name="rm")
|
||||
pm = build_permission_mw(
|
||||
flags=AgentFeatureFlags(enable_permission=False),
|
||||
subagent_rulesets=[Ruleset(origin="kb", rules=[_ask_rule(tool.name)])],
|
||||
tools=[tool],
|
||||
trusted_tool_saver=trusted_tool_saver,
|
||||
)
|
||||
assert pm is not None
|
||||
|
||||
graph = _build_graph(pm, tool.name)
|
||||
config = {"configurable": {"thread_id": "always-native"}}
|
||||
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||
await graph.ainvoke(Command(resume={"decisions": [{"type": "always"}]}), config)
|
||||
|
||||
assert saved == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_always_decision_with_no_saver_callback_is_a_noop():
|
||||
"""Anonymous turns build the middleware without a ``trusted_tool_saver``; must not crash."""
|
||||
tool = _make_mcp_tool(name="linear_create_issue", connector_id=7)
|
||||
pm = build_permission_mw(
|
||||
flags=AgentFeatureFlags(enable_permission=False),
|
||||
subagent_rulesets=[Ruleset(origin="linear", rules=[_ask_rule(tool.name)])],
|
||||
tools=[tool],
|
||||
trusted_tool_saver=None,
|
||||
)
|
||||
assert pm is not None
|
||||
|
||||
graph = _build_graph(pm, tool.name)
|
||||
config = {"configurable": {"thread_id": "anon-always"}}
|
||||
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||
await graph.ainvoke(Command(resume={"decisions": [{"type": "always"}]}), config)
|
||||
Loading…
Add table
Add a link
Reference in a new issue