mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-19 18:45:15 +02:00
multi_agent_chat/permissions: persist 'always' decisions to trusted-tools list
Until now an "Always Allow" reply only updated the in-memory runtime ruleset, evaporating after the session ended. Persist it to the existing connector.config['trusted_tools'] list so the next session's fetch_user_allowlist_rulesets picks it up and the user is never asked again for the same (connector, tool) pair. - TrustedToolSaver + make_trusted_tool_saver(user_id) in user_tool_allowlist: opens its own session via async_session_maker per call, logs and swallows failures (in-memory promotion is the canonical "always" path, durable persistence is opportunistic). - PermissionMiddleware._process is now pure: returns (state_update, list[_AlwaysPromotion]). aafter_model awaits the saver for each promotion; after_model discards them. Promotions are only emitted for tools whose metadata exposes mcp_connector_id, so native tools and KB FS ops are correctly skipped. - main_agent factory builds the saver once per turn and stashes it in dependencies["trusted_tool_saver"]; pack_subagent and the KB middleware stack forward it through build_permission_mw. - Renamed pm._process(state, None) call sites in two existing tests to pm.after_model(state, None) so they exercise the public hook contract instead of the now-tuple-returning private method.
This commit is contained in:
parent
a97d1548a6
commit
6671c91841
9 changed files with 323 additions and 103 deletions
|
|
@ -128,7 +128,7 @@ def _emit_tool_call(tool_name: str, args: dict[str, Any], call_id: str):
|
|||
|
||||
def _compile_graph_with(pm, tool_name: str, args: dict[str, Any], call_id: str):
|
||||
def after(state: _State) -> dict[str, Any] | None:
|
||||
return pm._process(state, None) # type: ignore[arg-type]
|
||||
return pm.after_model(state, None) # type: ignore[arg-type]
|
||||
|
||||
g = StateGraph(_State)
|
||||
g.add_node("emit", _emit_tool_call(tool_name, args, call_id))
|
||||
|
|
|
|||
|
|
@ -84,9 +84,7 @@ def _build_graph_with_permission_middleware(
|
|||
def after_node(state: _State) -> dict[str, Any] | None:
|
||||
if pm is None:
|
||||
return None
|
||||
# PermissionMiddleware._process ignores runtime; the test never relies
|
||||
# on the runtime context, so passing None keeps the harness lean.
|
||||
return pm._process(state, None) # type: ignore[arg-type]
|
||||
return pm.after_model(state, None) # type: ignore[arg-type]
|
||||
|
||||
g = StateGraph(_State)
|
||||
g.add_node("emit", node)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,180 @@
|
|||
"""``always`` decisions for MCP tools are saved via the trusted-tool saver."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated, Any
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.tools import StructuredTool
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.graph.message import add_messages
|
||||
from langgraph.types import Command
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from app.agents.multi_agent_chat.middleware.shared.permissions import (
|
||||
build_permission_mw,
|
||||
)
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.permissions import Rule, Ruleset
|
||||
|
||||
|
||||
class _NoArgs(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
async def _noop(**_kwargs) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
def _ask_rule(tool_name: str) -> Rule:
|
||||
return Rule(permission=tool_name, pattern="*", action="ask")
|
||||
|
||||
|
||||
def _make_mcp_tool(*, name: str, connector_id: int):
|
||||
return StructuredTool(
|
||||
name=name,
|
||||
description=f"Run {name} via MCP.",
|
||||
coroutine=_noop,
|
||||
args_schema=_NoArgs,
|
||||
metadata={
|
||||
"mcp_connector_id": connector_id,
|
||||
"mcp_connector_name": "Linear",
|
||||
"mcp_transport": "http",
|
||||
"hitl": True,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _make_native_tool(*, name: str):
|
||||
return StructuredTool(
|
||||
name=name,
|
||||
description=f"Native {name}.",
|
||||
coroutine=_noop,
|
||||
args_schema=_NoArgs,
|
||||
metadata={"hitl": True},
|
||||
)
|
||||
|
||||
|
||||
class _State(TypedDict, total=False):
|
||||
messages: Annotated[list, add_messages]
|
||||
|
||||
|
||||
def _build_graph(pm, tool_name: str):
|
||||
def emit(_state: _State) -> dict[str, Any]:
|
||||
return {
|
||||
"messages": [
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"name": tool_name,
|
||||
"args": {},
|
||||
"id": "call-1",
|
||||
"type": "tool_call",
|
||||
}
|
||||
],
|
||||
)
|
||||
]
|
||||
}
|
||||
|
||||
g = StateGraph(_State)
|
||||
g.add_node("emit", emit)
|
||||
g.add_node("permission", pm.aafter_model) # type: ignore[arg-type]
|
||||
g.add_edge(START, "emit")
|
||||
g.add_edge("emit", "permission")
|
||||
g.add_edge("permission", END)
|
||||
return g.compile(checkpointer=InMemorySaver())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_always_decision_saves_mcp_tool_via_callback():
|
||||
saved: list[tuple[int, str]] = []
|
||||
|
||||
async def trusted_tool_saver(connector_id: int, tool_name: str) -> None:
|
||||
saved.append((connector_id, tool_name))
|
||||
|
||||
tool = _make_mcp_tool(name="linear_create_issue", connector_id=7)
|
||||
pm = build_permission_mw(
|
||||
flags=AgentFeatureFlags(enable_permission=False),
|
||||
subagent_rulesets=[Ruleset(origin="linear", rules=[_ask_rule(tool.name)])],
|
||||
tools=[tool],
|
||||
trusted_tool_saver=trusted_tool_saver,
|
||||
)
|
||||
assert pm is not None
|
||||
|
||||
graph = _build_graph(pm, tool.name)
|
||||
config = {"configurable": {"thread_id": "always-mcp"}}
|
||||
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||
await graph.ainvoke(Command(resume={"decisions": [{"type": "always"}]}), config)
|
||||
|
||||
assert saved == [(7, "linear_create_issue")]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_once_decision_does_not_save():
|
||||
saved: list[tuple[int, str]] = []
|
||||
|
||||
async def trusted_tool_saver(connector_id: int, tool_name: str) -> None:
|
||||
saved.append((connector_id, tool_name))
|
||||
|
||||
tool = _make_mcp_tool(name="linear_create_issue", connector_id=7)
|
||||
pm = build_permission_mw(
|
||||
flags=AgentFeatureFlags(enable_permission=False),
|
||||
subagent_rulesets=[Ruleset(origin="linear", rules=[_ask_rule(tool.name)])],
|
||||
tools=[tool],
|
||||
trusted_tool_saver=trusted_tool_saver,
|
||||
)
|
||||
assert pm is not None
|
||||
|
||||
graph = _build_graph(pm, tool.name)
|
||||
config = {"configurable": {"thread_id": "once-mcp"}}
|
||||
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||
await graph.ainvoke(Command(resume={"decisions": [{"type": "approve"}]}), config)
|
||||
|
||||
assert saved == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_always_decision_for_native_tool_skips_save():
|
||||
"""Native tools have no ``mcp_connector_id`` so there is nowhere to persist trust."""
|
||||
saved: list[tuple[int, str]] = []
|
||||
|
||||
async def trusted_tool_saver(connector_id: int, tool_name: str) -> None:
|
||||
saved.append((connector_id, tool_name))
|
||||
|
||||
tool = _make_native_tool(name="rm")
|
||||
pm = build_permission_mw(
|
||||
flags=AgentFeatureFlags(enable_permission=False),
|
||||
subagent_rulesets=[Ruleset(origin="kb", rules=[_ask_rule(tool.name)])],
|
||||
tools=[tool],
|
||||
trusted_tool_saver=trusted_tool_saver,
|
||||
)
|
||||
assert pm is not None
|
||||
|
||||
graph = _build_graph(pm, tool.name)
|
||||
config = {"configurable": {"thread_id": "always-native"}}
|
||||
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||
await graph.ainvoke(Command(resume={"decisions": [{"type": "always"}]}), config)
|
||||
|
||||
assert saved == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_always_decision_with_no_saver_callback_is_a_noop():
|
||||
"""Anonymous turns build the middleware without a ``trusted_tool_saver``; must not crash."""
|
||||
tool = _make_mcp_tool(name="linear_create_issue", connector_id=7)
|
||||
pm = build_permission_mw(
|
||||
flags=AgentFeatureFlags(enable_permission=False),
|
||||
subagent_rulesets=[Ruleset(origin="linear", rules=[_ask_rule(tool.name)])],
|
||||
tools=[tool],
|
||||
trusted_tool_saver=None,
|
||||
)
|
||||
assert pm is not None
|
||||
|
||||
graph = _build_graph(pm, tool.name)
|
||||
config = {"configurable": {"thread_id": "anon-always"}}
|
||||
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||
await graph.ainvoke(Command(resume={"decisions": [{"type": "always"}]}), config)
|
||||
Loading…
Add table
Add a link
Reference in a new issue