mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-19 18:45:15 +02:00
multi_agent_chat/permissions: persist 'always' decisions to trusted-tools list
Until now an "Always Allow" reply only updated the in-memory runtime ruleset, so the decision evaporated when the session ended. Persist it to the existing connector.config['trusted_tools'] list so that the next session's fetch_user_allowlist_rulesets picks it up and the user is never asked again for the same (connector, tool) pair.

- TrustedToolSaver + make_trusted_tool_saver(user_id) in user_tool_allowlist: the saver opens its own session via async_session_maker on every call, and logs and swallows failures (in-memory promotion is the canonical "always" path; durable persistence is opportunistic).
- PermissionMiddleware._process is now pure with respect to I/O: it returns (state_update, list[_AlwaysPromotion]). aafter_model awaits the saver for each promotion; after_model discards them. Promotions are only emitted for tools whose metadata exposes mcp_connector_id, so native tools and KB FS ops are correctly skipped.
- The main_agent factory builds the saver once per turn and stashes it in dependencies["trusted_tool_saver"]; pack_subagent and the KB middleware stack forward it through build_permission_mw.
- Renamed the pm._process(state, None) call sites in two existing tests to pm.after_model(state, None), so they exercise the public hook contract instead of the private method, which now returns a tuple.
This commit is contained in:
parent
a97d1548a6
commit
6671c91841
9 changed files with 323 additions and 103 deletions
|
|
@ -29,7 +29,10 @@ from app.agents.new_chat.tools.invalid_tool import INVALID_TOOL_NAME, invalid_to
|
||||||
from app.agents.new_chat.tools.registry import build_tools_async
|
from app.agents.new_chat.tools.registry import build_tools_async
|
||||||
from app.db import ChatVisibility
|
from app.db import ChatVisibility
|
||||||
from app.services.connector_service import ConnectorService
|
from app.services.connector_service import ConnectorService
|
||||||
from app.services.user_tool_allowlist import fetch_user_allowlist_rulesets
|
from app.services.user_tool_allowlist import (
|
||||||
|
fetch_user_allowlist_rulesets,
|
||||||
|
make_trusted_tool_saver,
|
||||||
|
)
|
||||||
from app.utils.perf import get_perf_logger
|
from app.utils.perf import get_perf_logger
|
||||||
|
|
||||||
from ..system_prompt import build_main_agent_system_prompt
|
from ..system_prompt import build_main_agent_system_prompt
|
||||||
|
|
@ -153,28 +156,37 @@ async def create_multi_agent_chat_deep_agent(
|
||||||
# ``ask`` via last-match-wins. Anonymous turns and read failures both
|
# ``ask`` via last-match-wins. Anonymous turns and read failures both
|
||||||
# degrade to "no user rules" rather than blocking the turn.
|
# degrade to "no user rules" rather than blocking the turn.
|
||||||
user_allowlist_by_subagent: dict[str, Any] = {}
|
user_allowlist_by_subagent: dict[str, Any] = {}
|
||||||
|
trusted_tool_saver = None
|
||||||
if user_id:
|
if user_id:
|
||||||
_t0 = time.perf_counter()
|
|
||||||
try:
|
try:
|
||||||
import uuid as _uuid
|
import uuid as _uuid
|
||||||
|
|
||||||
user_allowlist_by_subagent = await fetch_user_allowlist_rulesets(
|
user_uuid = _uuid.UUID(user_id)
|
||||||
db_session,
|
except (TypeError, ValueError):
|
||||||
user_id=_uuid.UUID(user_id),
|
user_uuid = None
|
||||||
search_space_id=search_space_id,
|
|
||||||
|
if user_uuid is not None:
|
||||||
|
_t0 = time.perf_counter()
|
||||||
|
try:
|
||||||
|
user_allowlist_by_subagent = await fetch_user_allowlist_rulesets(
|
||||||
|
db_session,
|
||||||
|
user_id=user_uuid,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(
|
||||||
|
"User allow-list fetch failed; subagents will run without user trust rules this turn: %s",
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
user_allowlist_by_subagent = {}
|
||||||
|
_perf_log.info(
|
||||||
|
"[create_agent] fetch_user_allowlist_rulesets in %.3fs (%d subagents have rules)",
|
||||||
|
time.perf_counter() - _t0,
|
||||||
|
len(user_allowlist_by_subagent),
|
||||||
)
|
)
|
||||||
except Exception as e:
|
trusted_tool_saver = make_trusted_tool_saver(user_uuid)
|
||||||
logging.warning(
|
|
||||||
"User allow-list fetch failed; subagents will run without user trust rules this turn: %s",
|
|
||||||
e,
|
|
||||||
)
|
|
||||||
user_allowlist_by_subagent = {}
|
|
||||||
_perf_log.info(
|
|
||||||
"[create_agent] fetch_user_allowlist_rulesets in %.3fs (%d subagents have rules)",
|
|
||||||
time.perf_counter() - _t0,
|
|
||||||
len(user_allowlist_by_subagent),
|
|
||||||
)
|
|
||||||
dependencies["user_allowlist_by_subagent"] = user_allowlist_by_subagent
|
dependencies["user_allowlist_by_subagent"] = user_allowlist_by_subagent
|
||||||
|
dependencies["trusted_tool_saver"] = trusted_tool_saver
|
||||||
|
|
||||||
modified_disabled_tools = list(disabled_tools) if disabled_tools else []
|
modified_disabled_tools = list(disabled_tools) if disabled_tools else []
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -9,31 +9,12 @@ This middleware layers OpenCode's wildcard-ruleset model on top of the
|
||||||
unified langchain HITL wire format (see :mod:`hitl_wire`), so it sits
|
unified langchain HITL wire format (see :mod:`hitl_wire`), so it sits
|
||||||
beside ``HumanInTheLoopMiddleware`` and self-gated approvals on a single
|
beside ``HumanInTheLoopMiddleware`` and self-gated approvals on a single
|
||||||
parallel-HITL routing layer in ``task_tool`` + ``resume_routing``.
|
parallel-HITL routing layer in ``task_tool`` + ``resume_routing``.
|
||||||
|
|
||||||
Per-tool-call flow inside :meth:`_process`:
|
|
||||||
|
|
||||||
1. Skip when the last message has no tool calls.
|
|
||||||
2. For each call, evaluate the rules. ``deny`` is replaced with a
|
|
||||||
synthetic :class:`ToolMessage` carrying a typed
|
|
||||||
:class:`StreamingError`. ``ask`` raises an interrupt via
|
|
||||||
:mod:`interrupt.request`; the resulting decision is dispatched here:
|
|
||||||
|
|
||||||
- ``once`` → keep the call as-is.
|
|
||||||
- ``always`` → also extend the runtime ruleset.
|
|
||||||
- ``reject`` (with feedback) → :class:`CorrectedError`.
|
|
||||||
- ``reject`` (no feedback) → :class:`RejectedError`.
|
|
||||||
|
|
||||||
``allow`` keeps the call unchanged.
|
|
||||||
|
|
||||||
3. Returns an updated ``AIMessage`` (tool calls minus the denied ones)
|
|
||||||
plus any deny ``ToolMessage`` entries appended after it. Tool-list
|
|
||||||
filtering at ``before_model`` is intentionally not done here — that
|
|
||||||
would invalidate provider prompt-cache prefixes.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain.agents.middleware.types import (
|
from langchain.agents.middleware.types import (
|
||||||
|
|
@ -47,6 +28,7 @@ from langgraph.runtime import Runtime
|
||||||
|
|
||||||
from app.agents.new_chat.errors import CorrectedError, RejectedError
|
from app.agents.new_chat.errors import CorrectedError, RejectedError
|
||||||
from app.agents.new_chat.permissions import Ruleset
|
from app.agents.new_chat.permissions import Ruleset
|
||||||
|
from app.services.user_tool_allowlist import TrustedToolSaver
|
||||||
|
|
||||||
from ..ask.edit import merge_edited_args
|
from ..ask.edit import merge_edited_args
|
||||||
from ..ask.request import request_permission_decision
|
from ..ask.request import request_permission_decision
|
||||||
|
|
@ -59,6 +41,14 @@ from .runtime_promote import persist_always
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class _AlwaysPromotion:
|
||||||
|
"""A pending request to save an ``always`` decision to the user's trust list."""
|
||||||
|
|
||||||
|
connector_id: int
|
||||||
|
tool_name: str
|
||||||
|
|
||||||
|
|
||||||
class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
"""Allow/deny/ask layer over the agent's tool calls.
|
"""Allow/deny/ask layer over the agent's tool calls.
|
||||||
|
|
||||||
|
|
@ -76,6 +66,10 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
tools_by_name: Map from tool name to :class:`BaseTool`, used to
|
tools_by_name: Map from tool name to :class:`BaseTool`, used to
|
||||||
decorate ``ask`` interrupts with the tool's description and
|
decorate ``ask`` interrupts with the tool's description and
|
||||||
MCP metadata for the FE card.
|
MCP metadata for the FE card.
|
||||||
|
trusted_tool_saver: Async callback invoked on ``always`` decisions
|
||||||
|
for MCP tools (those whose ``metadata`` carries an
|
||||||
|
``mcp_connector_id``). Without it the promotion only lives
|
||||||
|
in-memory for the current agent instance.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
tools = ()
|
tools = ()
|
||||||
|
|
@ -88,6 +82,7 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
runtime_ruleset: Ruleset | None = None,
|
runtime_ruleset: Ruleset | None = None,
|
||||||
always_emit_interrupt_payload: bool = True,
|
always_emit_interrupt_payload: bool = True,
|
||||||
tools_by_name: dict[str, BaseTool] | None = None,
|
tools_by_name: dict[str, BaseTool] | None = None,
|
||||||
|
trusted_tool_saver: TrustedToolSaver | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._static_rulesets: list[Ruleset] = list(rulesets or [])
|
self._static_rulesets: list[Ruleset] = list(rulesets or [])
|
||||||
|
|
@ -99,23 +94,31 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
)
|
)
|
||||||
self._emit_interrupt = always_emit_interrupt_payload
|
self._emit_interrupt = always_emit_interrupt_payload
|
||||||
self._tools_by_name: dict[str, BaseTool] = dict(tools_by_name or {})
|
self._tools_by_name: dict[str, BaseTool] = dict(tools_by_name or {})
|
||||||
|
self._trusted_tool_saver: TrustedToolSaver | None = trusted_tool_saver
|
||||||
|
|
||||||
def _process(
|
def _process(
|
||||||
self,
|
self,
|
||||||
state: AgentState,
|
state: AgentState,
|
||||||
runtime: Runtime[Any],
|
runtime: Runtime[Any],
|
||||||
) -> dict[str, Any] | None:
|
) -> tuple[dict[str, Any] | None, list[_AlwaysPromotion]]:
|
||||||
|
"""Pure decision pass: returns ``(state_update, pending_promotions)``.
|
||||||
|
|
||||||
|
Side effects performed here are in-memory only (rule promotion
|
||||||
|
into ``runtime_ruleset``). DB writes for ``always`` decisions
|
||||||
|
are queued as ``_AlwaysPromotion`` and flushed by the async hook.
|
||||||
|
"""
|
||||||
del runtime
|
del runtime
|
||||||
messages = state.get("messages") or []
|
messages = state.get("messages") or []
|
||||||
if not messages:
|
if not messages:
|
||||||
return None
|
return None, []
|
||||||
last = messages[-1]
|
last = messages[-1]
|
||||||
if not isinstance(last, AIMessage) or not last.tool_calls:
|
if not isinstance(last, AIMessage) or not last.tool_calls:
|
||||||
return None
|
return None, []
|
||||||
|
|
||||||
rulesets = all_rulesets(self._static_rulesets, self._runtime_ruleset)
|
rulesets = all_rulesets(self._static_rulesets, self._runtime_ruleset)
|
||||||
deny_messages: list[ToolMessage] = []
|
deny_messages: list[ToolMessage] = []
|
||||||
kept_calls: list[dict[str, Any]] = []
|
kept_calls: list[dict[str, Any]] = []
|
||||||
|
promotions: list[_AlwaysPromotion] = []
|
||||||
any_change = False
|
any_change = False
|
||||||
|
|
||||||
for raw in last.tool_calls:
|
for raw in last.tool_calls:
|
||||||
|
|
@ -162,6 +165,9 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
any_change = True
|
any_change = True
|
||||||
if kind == "always":
|
if kind == "always":
|
||||||
persist_always(self._runtime_ruleset, name, patterns)
|
persist_always(self._runtime_ruleset, name, patterns)
|
||||||
|
promotion = self._build_always_promotion(name)
|
||||||
|
if promotion is not None:
|
||||||
|
promotions.append(promotion)
|
||||||
kept_calls.append(final_call)
|
kept_calls.append(final_call)
|
||||||
elif kind == "reject":
|
elif kind == "reject":
|
||||||
feedback = decision.get("feedback")
|
feedback = decision.get("feedback")
|
||||||
|
|
@ -180,23 +186,39 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
kept_calls.append(call)
|
kept_calls.append(call)
|
||||||
|
|
||||||
if not any_change and len(kept_calls) == len(last.tool_calls):
|
if not any_change and len(kept_calls) == len(last.tool_calls):
|
||||||
return None
|
return None, promotions
|
||||||
|
|
||||||
updated = last.model_copy(update={"tool_calls": kept_calls})
|
updated = last.model_copy(update={"tool_calls": kept_calls})
|
||||||
result_messages: list[Any] = [updated]
|
result_messages: list[Any] = [updated]
|
||||||
if deny_messages:
|
if deny_messages:
|
||||||
result_messages.extend(deny_messages)
|
result_messages.extend(deny_messages)
|
||||||
return {"messages": result_messages}
|
return {"messages": result_messages}, promotions
|
||||||
|
|
||||||
|
def _build_always_promotion(self, tool_name: str) -> _AlwaysPromotion | None:
    """Return a save request iff the tool exposes an integer ``mcp_connector_id``.

    Args:
        tool_name: Name of the tool whose ``always`` decision is being promoted.

    Returns:
        An ``_AlwaysPromotion`` carrying the tool's connector id, or ``None``
        when the tool is not present in ``self._tools_by_name``, has no
        metadata, or its ``mcp_connector_id`` is not a genuine integer.
    """
    tool = self._tools_by_name.get(tool_name)
    # ``tool`` may be ``None`` and ``metadata`` may be ``None``; both collapse
    # to an empty mapping so the lookup below stays a plain ``.get``.
    metadata = getattr(tool, "metadata", None) or {}
    connector_id = metadata.get("mcp_connector_id")
    # ``bool`` is a subclass of ``int``: without the explicit exclusion a stray
    # ``True``/``False`` would be persisted as connector id 1/0.
    if isinstance(connector_id, bool) or not isinstance(connector_id, int):
        return None
    return _AlwaysPromotion(connector_id=connector_id, tool_name=tool_name)
|
||||||
|
|
||||||
def after_model( # type: ignore[override]
|
def after_model( # type: ignore[override]
|
||||||
self, state: AgentState, runtime: Runtime[ContextT]
|
self, state: AgentState, runtime: Runtime[ContextT]
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
return self._process(state, runtime)
|
update, _ = self._process(state, runtime)
|
||||||
|
return update
|
||||||
|
|
||||||
async def aafter_model(  # type: ignore[override]
    self, state: AgentState, runtime: Runtime[ContextT]
) -> dict[str, Any] | None:
    """Run the decision pass, then flush queued ``always`` promotions.

    Args:
        state: Agent state forwarded unchanged to ``self._process``.
        runtime: Runtime forwarded unchanged to ``self._process``.

    Returns:
        The state update produced by ``self._process`` (possibly ``None``).
        The outcome of the saver calls never alters the returned update.
    """
    update, promotions = self._process(state, runtime)
    saver = self._trusted_tool_saver
    if saver is None:
        # No durable store configured: the promotions are simply dropped.
        return update

    # The same (connector, tool) pair can be promoted more than once in a
    # single pass; one save per pair is enough.
    seen: set[tuple[int, str]] = set()
    for promotion in promotions:
        key = (promotion.connector_id, promotion.tool_name)
        if key in seen:
            continue
        seen.add(key)
        try:
            await saver(promotion.connector_id, promotion.tool_name)
        except Exception:
            # Persistence is best-effort: a failing saver must not abort the
            # turn, and must not block the remaining promotions.
            logger.exception(
                "trusted-tool saver raised for connector=%s tool=%s",
                promotion.connector_id,
                promotion.tool_name,
            )
    return update
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["PermissionMiddleware"]
|
__all__ = ["PermissionMiddleware"]
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,7 @@ from langchain_core.tools import BaseTool
|
||||||
|
|
||||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||||
from app.agents.new_chat.permissions import Rule, Ruleset
|
from app.agents.new_chat.permissions import Rule, Ruleset
|
||||||
|
from app.services.user_tool_allowlist import TrustedToolSaver
|
||||||
|
|
||||||
from .core import PermissionMiddleware
|
from .core import PermissionMiddleware
|
||||||
|
|
||||||
|
|
@ -43,6 +44,7 @@ def build_permission_mw(
|
||||||
flags: AgentFeatureFlags,
|
flags: AgentFeatureFlags,
|
||||||
subagent_rulesets: list[Ruleset] | None = None,
|
subagent_rulesets: list[Ruleset] | None = None,
|
||||||
tools: Sequence[BaseTool] | None = None,
|
tools: Sequence[BaseTool] | None = None,
|
||||||
|
trusted_tool_saver: TrustedToolSaver | None = None,
|
||||||
) -> PermissionMiddleware | None:
|
) -> PermissionMiddleware | None:
|
||||||
"""Return a configured :class:`PermissionMiddleware` or ``None`` when no work is needed.
|
"""Return a configured :class:`PermissionMiddleware` or ``None`` when no work is needed.
|
||||||
|
|
||||||
|
|
@ -58,6 +60,9 @@ def build_permission_mw(
|
||||||
an explicit ``ask`` rule always asks.
|
an explicit ``ask`` rule always asks.
|
||||||
tools: Subagent tools used to decorate ``ask`` interrupts with
|
tools: Subagent tools used to decorate ``ask`` interrupts with
|
||||||
FE-card metadata (description, MCP connector). Optional.
|
FE-card metadata (description, MCP connector). Optional.
|
||||||
|
trusted_tool_saver: Async callback invoked when an MCP tool's
|
||||||
|
``always`` decision lands; persists the user's preference to
|
||||||
|
``connector.config['trusted_tools']``. Optional.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
``None`` when the engine has no rules to enforce
|
``None`` when the engine has no rules to enforce
|
||||||
|
|
@ -73,7 +78,11 @@ def build_permission_mw(
|
||||||
if subagent_rulesets:
|
if subagent_rulesets:
|
||||||
rulesets.extend(subagent_rulesets)
|
rulesets.extend(subagent_rulesets)
|
||||||
tools_by_name = {t.name: t for t in (tools or [])}
|
tools_by_name = {t.name: t for t in (tools or [])}
|
||||||
return PermissionMiddleware(rulesets=rulesets, tools_by_name=tools_by_name)
|
return PermissionMiddleware(
|
||||||
|
rulesets=rulesets,
|
||||||
|
tools_by_name=tools_by_name,
|
||||||
|
trusted_tool_saver=trusted_tool_saver,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["build_permission_mw"]
|
__all__ = ["build_permission_mw"]
|
||||||
|
|
|
||||||
|
|
@ -93,7 +93,11 @@ def build_kb_middleware(
|
||||||
user_allowlist = _kb_user_allowlist(dependencies, subagent_name)
|
user_allowlist = _kb_user_allowlist(dependencies, subagent_name)
|
||||||
if user_allowlist is not None:
|
if user_allowlist is not None:
|
||||||
rulesets.append(user_allowlist)
|
rulesets.append(user_allowlist)
|
||||||
permission_mw = build_permission_mw(flags=flags, subagent_rulesets=rulesets)
|
permission_mw = build_permission_mw(
|
||||||
|
flags=flags,
|
||||||
|
subagent_rulesets=rulesets,
|
||||||
|
trusted_tool_saver=dependencies.get("trusted_tool_saver"),
|
||||||
|
)
|
||||||
return [
|
return [
|
||||||
mws["todos"],
|
mws["todos"],
|
||||||
build_kb_context_projection_mw(),
|
build_kb_context_projection_mw(),
|
||||||
|
|
|
||||||
|
|
@ -74,7 +74,10 @@ def pack_subagent(
|
||||||
if user_allowlist is not None:
|
if user_allowlist is not None:
|
||||||
subagent_rulesets.append(user_allowlist)
|
subagent_rulesets.append(user_allowlist)
|
||||||
per_subagent_perm = build_permission_mw(
|
per_subagent_perm = build_permission_mw(
|
||||||
flags=flags, subagent_rulesets=subagent_rulesets, tools=tools
|
flags=flags,
|
||||||
|
subagent_rulesets=subagent_rulesets,
|
||||||
|
tools=tools,
|
||||||
|
trusted_tool_saver=dependencies.get("trusted_tool_saver"),
|
||||||
)
|
)
|
||||||
|
|
||||||
prepended: list[Any] = []
|
prepended: list[Any] = []
|
||||||
|
|
|
||||||
|
|
@ -1,33 +1,16 @@
|
||||||
"""User-scoped tool allow-list backed by ``SearchSourceConnector.config``.
|
"""User-scoped trusted-tools list backed by ``SearchSourceConnector.config``.
|
||||||
|
|
||||||
Stores the user's "always allow" preferences as a list of tool names under
|
Storage is per ``(user_id, search_space_id, connector_id)`` under
|
||||||
``connector.config['trusted_tools']``. Storage is per
|
``connector.config['trusted_tools']``. The list only ever encodes
|
||||||
``(user_id, search_space_id, connector_id)`` — i.e. tied to a specific
|
``allow`` decisions; coded ``deny`` rules cannot be overridden here.
|
||||||
connected account inside a specific workspace, exactly what the UI cares
|
|
||||||
about.
|
|
||||||
|
|
||||||
Callers split into two roles:
|
|
||||||
|
|
||||||
- **Writers** — the ``/connectors/.../trust-tool`` and ``/untrust-tool``
|
|
||||||
HTTP routes, and the chat resume handler when it processes a
|
|
||||||
``{type: "always"}`` decision. Both call
|
|
||||||
:func:`add_user_trust` / :func:`remove_user_trust`. The FE button is
|
|
||||||
the upstream UI trigger but it talks to the routes, never to this
|
|
||||||
module directly.
|
|
||||||
- **Reader** — the subagent compile path, which calls
|
|
||||||
:func:`fetch_user_allowlist_rulesets` and layers the result after the
|
|
||||||
subagent's coded ruleset. User ``allow`` rules then override coded
|
|
||||||
``ask`` via the rule engine's last-match-wins evaluation.
|
|
||||||
|
|
||||||
Coded ``deny`` rules are intentionally **not** overridable by this
|
|
||||||
allow-list — only ``ask`` can be promoted to ``allow``. The rule engine
|
|
||||||
enforces this naturally because user rules only ever emit ``allow``.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
@ -37,10 +20,14 @@ from app.agents.multi_agent_chat.constants import (
|
||||||
CONNECTOR_TYPE_TO_CONNECTOR_AGENT_MAPS,
|
CONNECTOR_TYPE_TO_CONNECTOR_AGENT_MAPS,
|
||||||
)
|
)
|
||||||
from app.agents.new_chat.permissions import Rule, Ruleset
|
from app.agents.new_chat.permissions import Rule, Ruleset
|
||||||
from app.db import SearchSourceConnector
|
from app.db import SearchSourceConnector, async_session_maker
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_TRUSTED_TOOLS_KEY = "trusted_tools"
|
_TRUSTED_TOOLS_KEY = "trusted_tools"
|
||||||
|
|
||||||
|
TrustedToolSaver = Callable[[int, str], Awaitable[None]]
|
||||||
|
|
||||||
|
|
||||||
async def _load_owned_connector(
|
async def _load_owned_connector(
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
|
|
@ -48,11 +35,7 @@ async def _load_owned_connector(
|
||||||
user_id: uuid.UUID,
|
user_id: uuid.UUID,
|
||||||
connector_id: int,
|
connector_id: int,
|
||||||
) -> SearchSourceConnector | None:
|
) -> SearchSourceConnector | None:
|
||||||
"""Return a connector iff it belongs to ``user_id``, else ``None``.
|
"""Return the connector iff owned by ``user_id``, else ``None``."""
|
||||||
|
|
||||||
Ownership scoping is mandatory: the trust list mutates user-private
|
|
||||||
data, callers must never write across user boundaries.
|
|
||||||
"""
|
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
select(SearchSourceConnector).where(
|
select(SearchSourceConnector).where(
|
||||||
SearchSourceConnector.id == connector_id,
|
SearchSourceConnector.id == connector_id,
|
||||||
|
|
@ -84,11 +67,7 @@ async def add_user_trust(
|
||||||
connector_id: int,
|
connector_id: int,
|
||||||
tool_name: str,
|
tool_name: str,
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
"""Append ``tool_name`` to the connector's trusted list (idempotent).
|
"""Append ``tool_name`` to the connector's trusted list; raise ``LookupError`` if not owned."""
|
||||||
|
|
||||||
Returns the updated trusted-tools list. Raises ``LookupError`` when
|
|
||||||
the connector does not exist or is not owned by ``user_id``.
|
|
||||||
"""
|
|
||||||
connector = await _load_owned_connector(
|
connector = await _load_owned_connector(
|
||||||
session, user_id=user_id, connector_id=connector_id
|
session, user_id=user_id, connector_id=connector_id
|
||||||
)
|
)
|
||||||
|
|
@ -112,11 +91,7 @@ async def remove_user_trust(
|
||||||
connector_id: int,
|
connector_id: int,
|
||||||
tool_name: str,
|
tool_name: str,
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
"""Remove ``tool_name`` from the connector's trusted list (idempotent).
|
"""Remove ``tool_name`` from the connector's trusted list; raise ``LookupError`` if not owned."""
|
||||||
|
|
||||||
Returns the updated trusted-tools list. Raises ``LookupError`` when
|
|
||||||
the connector does not exist or is not owned by ``user_id``.
|
|
||||||
"""
|
|
||||||
connector = await _load_owned_connector(
|
connector = await _load_owned_connector(
|
||||||
session, user_id=user_id, connector_id=connector_id
|
session, user_id=user_id, connector_id=connector_id
|
||||||
)
|
)
|
||||||
|
|
@ -139,20 +114,10 @@ async def fetch_user_allowlist_rulesets(
|
||||||
user_id: uuid.UUID,
|
user_id: uuid.UUID,
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
) -> dict[str, Ruleset]:
|
) -> dict[str, Ruleset]:
|
||||||
"""Project the user's trusted-tool lists into per-subagent rulesets.
|
"""Project the user's trusted tools into per-subagent ``allow`` rulesets.
|
||||||
|
|
||||||
Walks every connector the user owns in this workspace, maps each
|
Subagents with no trusted tools are absent from the result —
|
||||||
``connector_type`` to its consuming subagent via
|
callers must treat ``missing == empty``.
|
||||||
:data:`CONNECTOR_TYPE_TO_CONNECTOR_AGENT_MAPS`, and emits one
|
|
||||||
``Rule(permission=tool_name, pattern="*", action="allow")`` per
|
|
||||||
trusted entry. Rules from different connector accounts feeding the
|
|
||||||
same subagent (e.g. two Linear workspaces) are merged into one
|
|
||||||
ruleset; duplicates are harmless under last-match-wins.
|
|
||||||
|
|
||||||
Connectors whose type is not mapped (search APIs, Github, etc.) and
|
|
||||||
connectors with empty trust lists contribute nothing. Subagents
|
|
||||||
with no trusted tools are absent from the returned dict — callers
|
|
||||||
should treat ``missing == empty``.
|
|
||||||
"""
|
"""
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
select(
|
select(
|
||||||
|
|
@ -189,8 +154,35 @@ async def fetch_user_allowlist_rulesets(
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def make_trusted_tool_saver(user_id: uuid.UUID) -> TrustedToolSaver:
    """Bind ``user_id`` into a saver closure; failures are logged, never raised.

    Args:
        user_id: Passed as ``user_id`` to ``add_user_trust`` on every save.

    Returns:
        An async callable ``(connector_id, tool_name) -> None``. Each call
        opens a fresh session from ``async_session_maker``, calls
        ``add_user_trust`` and commits.
    """

    async def trusted_tool_saver(connector_id: int, tool_name: str) -> None:
        try:
            async with async_session_maker() as session:
                await add_user_trust(
                    session,
                    user_id=user_id,
                    connector_id=connector_id,
                    tool_name=tool_name,
                )
                await session.commit()
        except LookupError as exc:
            # Raised by ``add_user_trust``; include the same identifiers as
            # the generic failure log so a skipped save can be traced.
            logger.warning(
                "trusted-tool save skipped for connector=%s tool=%s: %s",
                connector_id,
                tool_name,
                exc,
            )
        except Exception:
            # Best-effort boundary: log with traceback, never propagate.
            logger.exception(
                "trusted-tool save failed for connector=%s tool=%s",
                connector_id,
                tool_name,
            )

    return trusted_tool_saver
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"TrustedToolSaver",
|
||||||
"add_user_trust",
|
"add_user_trust",
|
||||||
"fetch_user_allowlist_rulesets",
|
"fetch_user_allowlist_rulesets",
|
||||||
|
"make_trusted_tool_saver",
|
||||||
"remove_user_trust",
|
"remove_user_trust",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -128,7 +128,7 @@ def _emit_tool_call(tool_name: str, args: dict[str, Any], call_id: str):
|
||||||
|
|
||||||
def _compile_graph_with(pm, tool_name: str, args: dict[str, Any], call_id: str):
|
def _compile_graph_with(pm, tool_name: str, args: dict[str, Any], call_id: str):
|
||||||
def after(state: _State) -> dict[str, Any] | None:
|
def after(state: _State) -> dict[str, Any] | None:
|
||||||
return pm._process(state, None) # type: ignore[arg-type]
|
return pm.after_model(state, None) # type: ignore[arg-type]
|
||||||
|
|
||||||
g = StateGraph(_State)
|
g = StateGraph(_State)
|
||||||
g.add_node("emit", _emit_tool_call(tool_name, args, call_id))
|
g.add_node("emit", _emit_tool_call(tool_name, args, call_id))
|
||||||
|
|
|
||||||
|
|
@ -84,9 +84,7 @@ def _build_graph_with_permission_middleware(
|
||||||
def after_node(state: _State) -> dict[str, Any] | None:
|
def after_node(state: _State) -> dict[str, Any] | None:
|
||||||
if pm is None:
|
if pm is None:
|
||||||
return None
|
return None
|
||||||
# PermissionMiddleware._process ignores runtime; the test never relies
|
return pm.after_model(state, None) # type: ignore[arg-type]
|
||||||
# on the runtime context, so passing None keeps the harness lean.
|
|
||||||
return pm._process(state, None) # type: ignore[arg-type]
|
|
||||||
|
|
||||||
g = StateGraph(_State)
|
g = StateGraph(_State)
|
||||||
g.add_node("emit", node)
|
g.add_node("emit", node)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,180 @@
|
||||||
|
"""``always`` decisions for MCP tools are saved via the trusted-tool saver."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Annotated, Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from langchain_core.messages import AIMessage, HumanMessage
|
||||||
|
from langchain_core.tools import StructuredTool
|
||||||
|
from langgraph.checkpoint.memory import InMemorySaver
|
||||||
|
from langgraph.graph import END, START, StateGraph
|
||||||
|
from langgraph.graph.message import add_messages
|
||||||
|
from langgraph.types import Command
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
|
from app.agents.multi_agent_chat.middleware.shared.permissions import (
|
||||||
|
build_permission_mw,
|
||||||
|
)
|
||||||
|
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||||
|
from app.agents.new_chat.permissions import Rule, Ruleset
|
||||||
|
|
||||||
|
|
||||||
|
class _NoArgs(BaseModel):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def _noop(**_kwargs) -> str:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _ask_rule(tool_name: str) -> Rule:
|
||||||
|
return Rule(permission=tool_name, pattern="*", action="ask")
|
||||||
|
|
||||||
|
|
||||||
|
def _make_mcp_tool(*, name: str, connector_id: int):
|
||||||
|
return StructuredTool(
|
||||||
|
name=name,
|
||||||
|
description=f"Run {name} via MCP.",
|
||||||
|
coroutine=_noop,
|
||||||
|
args_schema=_NoArgs,
|
||||||
|
metadata={
|
||||||
|
"mcp_connector_id": connector_id,
|
||||||
|
"mcp_connector_name": "Linear",
|
||||||
|
"mcp_transport": "http",
|
||||||
|
"hitl": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_native_tool(*, name: str):
|
||||||
|
return StructuredTool(
|
||||||
|
name=name,
|
||||||
|
description=f"Native {name}.",
|
||||||
|
coroutine=_noop,
|
||||||
|
args_schema=_NoArgs,
|
||||||
|
metadata={"hitl": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _State(TypedDict, total=False):
|
||||||
|
messages: Annotated[list, add_messages]
|
||||||
|
|
||||||
|
|
||||||
|
def _build_graph(pm, tool_name: str):
|
||||||
|
def emit(_state: _State) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"messages": [
|
||||||
|
AIMessage(
|
||||||
|
content="",
|
||||||
|
tool_calls=[
|
||||||
|
{
|
||||||
|
"name": tool_name,
|
||||||
|
"args": {},
|
||||||
|
"id": "call-1",
|
||||||
|
"type": "tool_call",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
g = StateGraph(_State)
|
||||||
|
g.add_node("emit", emit)
|
||||||
|
g.add_node("permission", pm.aafter_model) # type: ignore[arg-type]
|
||||||
|
g.add_edge(START, "emit")
|
||||||
|
g.add_edge("emit", "permission")
|
||||||
|
g.add_edge("permission", END)
|
||||||
|
return g.compile(checkpointer=InMemorySaver())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_always_decision_saves_mcp_tool_via_callback():
|
||||||
|
saved: list[tuple[int, str]] = []
|
||||||
|
|
||||||
|
async def trusted_tool_saver(connector_id: int, tool_name: str) -> None:
|
||||||
|
saved.append((connector_id, tool_name))
|
||||||
|
|
||||||
|
tool = _make_mcp_tool(name="linear_create_issue", connector_id=7)
|
||||||
|
pm = build_permission_mw(
|
||||||
|
flags=AgentFeatureFlags(enable_permission=False),
|
||||||
|
subagent_rulesets=[Ruleset(origin="linear", rules=[_ask_rule(tool.name)])],
|
||||||
|
tools=[tool],
|
||||||
|
trusted_tool_saver=trusted_tool_saver,
|
||||||
|
)
|
||||||
|
assert pm is not None
|
||||||
|
|
||||||
|
graph = _build_graph(pm, tool.name)
|
||||||
|
config = {"configurable": {"thread_id": "always-mcp"}}
|
||||||
|
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||||
|
await graph.ainvoke(Command(resume={"decisions": [{"type": "always"}]}), config)
|
||||||
|
|
||||||
|
assert saved == [(7, "linear_create_issue")]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_once_decision_does_not_save():
|
||||||
|
saved: list[tuple[int, str]] = []
|
||||||
|
|
||||||
|
async def trusted_tool_saver(connector_id: int, tool_name: str) -> None:
|
||||||
|
saved.append((connector_id, tool_name))
|
||||||
|
|
||||||
|
tool = _make_mcp_tool(name="linear_create_issue", connector_id=7)
|
||||||
|
pm = build_permission_mw(
|
||||||
|
flags=AgentFeatureFlags(enable_permission=False),
|
||||||
|
subagent_rulesets=[Ruleset(origin="linear", rules=[_ask_rule(tool.name)])],
|
||||||
|
tools=[tool],
|
||||||
|
trusted_tool_saver=trusted_tool_saver,
|
||||||
|
)
|
||||||
|
assert pm is not None
|
||||||
|
|
||||||
|
graph = _build_graph(pm, tool.name)
|
||||||
|
config = {"configurable": {"thread_id": "once-mcp"}}
|
||||||
|
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||||
|
await graph.ainvoke(Command(resume={"decisions": [{"type": "approve"}]}), config)
|
||||||
|
|
||||||
|
assert saved == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_always_decision_for_native_tool_skips_save():
|
||||||
|
"""Native tools have no ``mcp_connector_id`` so there is nowhere to persist trust."""
|
||||||
|
saved: list[tuple[int, str]] = []
|
||||||
|
|
||||||
|
async def trusted_tool_saver(connector_id: int, tool_name: str) -> None:
|
||||||
|
saved.append((connector_id, tool_name))
|
||||||
|
|
||||||
|
tool = _make_native_tool(name="rm")
|
||||||
|
pm = build_permission_mw(
|
||||||
|
flags=AgentFeatureFlags(enable_permission=False),
|
||||||
|
subagent_rulesets=[Ruleset(origin="kb", rules=[_ask_rule(tool.name)])],
|
||||||
|
tools=[tool],
|
||||||
|
trusted_tool_saver=trusted_tool_saver,
|
||||||
|
)
|
||||||
|
assert pm is not None
|
||||||
|
|
||||||
|
graph = _build_graph(pm, tool.name)
|
||||||
|
config = {"configurable": {"thread_id": "always-native"}}
|
||||||
|
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||||
|
await graph.ainvoke(Command(resume={"decisions": [{"type": "always"}]}), config)
|
||||||
|
|
||||||
|
assert saved == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_always_decision_with_no_saver_callback_is_a_noop():
|
||||||
|
"""Anonymous turns build the middleware without a ``trusted_tool_saver``; must not crash."""
|
||||||
|
tool = _make_mcp_tool(name="linear_create_issue", connector_id=7)
|
||||||
|
pm = build_permission_mw(
|
||||||
|
flags=AgentFeatureFlags(enable_permission=False),
|
||||||
|
subagent_rulesets=[Ruleset(origin="linear", rules=[_ask_rule(tool.name)])],
|
||||||
|
tools=[tool],
|
||||||
|
trusted_tool_saver=None,
|
||||||
|
)
|
||||||
|
assert pm is not None
|
||||||
|
|
||||||
|
graph = _build_graph(pm, tool.name)
|
||||||
|
config = {"configurable": {"thread_id": "anon-always"}}
|
||||||
|
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||||
|
await graph.ainvoke(Command(resume={"decisions": [{"type": "always"}]}), config)
|
||||||
Loading…
Add table
Add a link
Reference in a new issue