mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-17 18:35:19 +02:00
multi_agent_chat/permissions: layer user allow-list into subagent compile
This commit is contained in:
parent
e99c06c887
commit
ef1152b80e
25 changed files with 304 additions and 62 deletions
|
|
@ -1,9 +1,10 @@
|
|||
"""Regression: subagent-owned rulesets layer cleanly into ``PermissionMiddleware``.
|
||||
|
||||
The KB unification swap (legacy ``interrupt_on`` map → KB-owned ``Ruleset``
|
||||
threaded through ``build_permission_mw(extra_rulesets=...)``) must produce
|
||||
*exactly one* interrupt per destructive FS call, in LC HITL shape, even
|
||||
when ``enable_permission`` is False — destructive ops always ask.
|
||||
threaded through ``build_permission_mw(subagent_rulesets=...)``) must
|
||||
produce *exactly one* interrupt per destructive FS call, in LC HITL
|
||||
shape, even when ``enable_permission`` is False — destructive ops always
|
||||
ask.
|
||||
|
||||
We exercise the production factory and a real ``PermissionMiddleware`` on a
|
||||
real ``StateGraph`` so the test catches regressions in factory gating,
|
||||
|
|
@ -54,7 +55,7 @@ class _State(TypedDict, total=False):
|
|||
def _build_graph_with_permission_middleware(
|
||||
*,
|
||||
flags: AgentFeatureFlags,
|
||||
extra_rulesets: list[Ruleset] | None,
|
||||
subagent_rulesets: list[Ruleset] | None,
|
||||
checkpointer: InMemorySaver,
|
||||
):
|
||||
"""Compile a one-node graph that emits a tool call for ``rm`` and
|
||||
|
|
@ -64,7 +65,7 @@ def _build_graph_with_permission_middleware(
|
|||
``after_model`` hook intercepts and (if a rule says ``ask``) raises
|
||||
a ``GraphInterrupt`` carrying the LC HITL payload.
|
||||
"""
|
||||
pm = build_permission_mw(flags=flags, extra_rulesets=extra_rulesets)
|
||||
pm = build_permission_mw(flags=flags, subagent_rulesets=subagent_rulesets)
|
||||
|
||||
def node(_state: _State) -> dict[str, Any]:
|
||||
msg = AIMessage(
|
||||
|
|
@ -108,10 +109,10 @@ async def test_kb_ruleset_raises_one_lc_hitl_ask_for_rm_even_when_permission_fla
|
|||
checkpointer = InMemorySaver()
|
||||
graph, pm = _build_graph_with_permission_middleware(
|
||||
flags=flags,
|
||||
extra_rulesets=[_kb_style_ruleset()],
|
||||
subagent_rulesets=[_kb_style_ruleset()],
|
||||
checkpointer=checkpointer,
|
||||
)
|
||||
assert pm is not None, "extras must force the middleware on"
|
||||
assert pm is not None, "subagent rulesets must force the middleware on"
|
||||
|
||||
config = {"configurable": {"thread_id": "kb-cloud-rm"}}
|
||||
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||
|
|
@ -136,7 +137,7 @@ async def test_kb_ruleset_resume_with_approve_lets_rm_through():
|
|||
checkpointer = InMemorySaver()
|
||||
graph, _ = _build_graph_with_permission_middleware(
|
||||
flags=flags,
|
||||
extra_rulesets=[_kb_style_ruleset()],
|
||||
subagent_rulesets=[_kb_style_ruleset()],
|
||||
checkpointer=checkpointer,
|
||||
)
|
||||
config = {"configurable": {"thread_id": "kb-cloud-rm-approve"}}
|
||||
|
|
@ -158,12 +159,12 @@ async def test_kb_ruleset_resume_with_approve_lets_rm_through():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_extras_with_permission_off_skips_middleware_entirely():
|
||||
"""No extras + permission off → factory returns ``None`` (no engine).
|
||||
async def test_no_subagent_rulesets_with_permission_off_skips_middleware_entirely():
|
||||
"""No subagent rulesets + permission off → factory returns ``None`` (no engine).
|
||||
|
||||
The legacy gating is preserved when no caller asks for rules: nothing
|
||||
runs, nothing pauses.
|
||||
"""
|
||||
flags = AgentFeatureFlags(enable_permission=False)
|
||||
pm = build_permission_mw(flags=flags, extra_rulesets=None)
|
||||
pm = build_permission_mw(flags=flags, subagent_rulesets=None)
|
||||
assert pm is None
|
||||
|
|
|
|||
|
|
@ -19,11 +19,14 @@ from langchain_core.language_models.fake_chat_models import (
|
|||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
|
||||
from app.agents.multi_agent_chat.middleware.shared.permissions.middleware.core import (
|
||||
PermissionMiddleware,
|
||||
)
|
||||
from app.agents.multi_agent_chat.subagents.shared.subagent_builder import (
|
||||
pack_subagent,
|
||||
)
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.permissions import Ruleset
|
||||
from app.agents.new_chat.permissions import Rule, Ruleset, evaluate
|
||||
|
||||
|
||||
class RateLimitError(Exception):
|
||||
|
|
@ -81,7 +84,7 @@ async def test_subagent_recovers_when_primary_llm_fails():
|
|||
system_prompt="be helpful",
|
||||
tools=[],
|
||||
ruleset=Ruleset(origin="resilience_test", rules=[]),
|
||||
flags=AgentFeatureFlags(),
|
||||
dependencies={"flags": AgentFeatureFlags()},
|
||||
model=primary,
|
||||
middleware_stack={"fallback": ModelFallbackMiddleware(fallback)},
|
||||
)
|
||||
|
|
@ -99,3 +102,142 @@ async def test_subagent_recovers_when_primary_llm_fails():
|
|||
final = result["messages"][-1]
|
||||
assert isinstance(final, AIMessage)
|
||||
assert final.content == "recovered via fallback"
|
||||
|
||||
|
||||
def _extract_permission_mw(spec) -> PermissionMiddleware:
|
||||
"""Find the lone PermissionMiddleware in a subagent's middleware list."""
|
||||
matches = [m for m in spec["middleware"] if isinstance(m, PermissionMiddleware)]
|
||||
assert len(matches) == 1, "expected exactly one PermissionMiddleware"
|
||||
return matches[0]
|
||||
|
||||
|
||||
def test_user_allowlist_overrides_coded_ask_via_last_match_wins():
|
||||
"""User ``allow`` rules promoted via "Always Allow" must beat coded ``ask`` rules."""
|
||||
coded = Ruleset(
|
||||
origin="connector",
|
||||
rules=[Rule(permission="save_issue", pattern="*", action="ask")],
|
||||
)
|
||||
user_allowlist = Ruleset(
|
||||
origin="user_allowlist:connector",
|
||||
rules=[Rule(permission="save_issue", pattern="*", action="allow")],
|
||||
)
|
||||
|
||||
result = pack_subagent(
|
||||
name="connector",
|
||||
description="test connector",
|
||||
system_prompt="x",
|
||||
tools=[],
|
||||
ruleset=coded,
|
||||
dependencies={
|
||||
"flags": AgentFeatureFlags(),
|
||||
"user_allowlist_by_subagent": {"connector": user_allowlist},
|
||||
},
|
||||
)
|
||||
|
||||
mw = _extract_permission_mw(result.spec)
|
||||
decided = evaluate("save_issue", "*", *mw._static_rulesets)
|
||||
assert decided.action == "allow", (
|
||||
f"user_allowlist must override coded ask; got {decided!r}"
|
||||
)
|
||||
|
||||
|
||||
def test_coded_ask_stays_when_user_allowlist_unrelated():
|
||||
"""User ``allow`` rules for OTHER tools must not leak into asked-tools."""
|
||||
coded = Ruleset(
|
||||
origin="connector",
|
||||
rules=[Rule(permission="delete_issue", pattern="*", action="ask")],
|
||||
)
|
||||
user_allowlist = Ruleset(
|
||||
origin="user_allowlist:connector",
|
||||
rules=[Rule(permission="save_issue", pattern="*", action="allow")],
|
||||
)
|
||||
|
||||
result = pack_subagent(
|
||||
name="connector",
|
||||
description="test",
|
||||
system_prompt="x",
|
||||
tools=[],
|
||||
ruleset=coded,
|
||||
dependencies={
|
||||
"flags": AgentFeatureFlags(),
|
||||
"user_allowlist_by_subagent": {"connector": user_allowlist},
|
||||
},
|
||||
)
|
||||
|
||||
mw = _extract_permission_mw(result.spec)
|
||||
decided = evaluate("delete_issue", "*", *mw._static_rulesets)
|
||||
assert decided.action == "ask"
|
||||
|
||||
|
||||
def test_missing_user_allowlist_keeps_coded_behaviour():
|
||||
"""``dependencies`` without ``user_allowlist_by_subagent`` is the common case."""
|
||||
coded = Ruleset(
|
||||
origin="connector",
|
||||
rules=[Rule(permission="save_issue", pattern="*", action="ask")],
|
||||
)
|
||||
|
||||
result = pack_subagent(
|
||||
name="connector",
|
||||
description="test",
|
||||
system_prompt="x",
|
||||
tools=[],
|
||||
ruleset=coded,
|
||||
dependencies={"flags": AgentFeatureFlags()},
|
||||
)
|
||||
|
||||
mw = _extract_permission_mw(result.spec)
|
||||
decided = evaluate("save_issue", "*", *mw._static_rulesets)
|
||||
assert decided.action == "ask"
|
||||
|
||||
|
||||
def test_user_allowlist_for_different_subagent_does_not_leak():
|
||||
"""User trust for ``linear`` must not affect a ``jira`` subagent compile."""
|
||||
coded = Ruleset(
|
||||
origin="jira",
|
||||
rules=[Rule(permission="save_issue", pattern="*", action="ask")],
|
||||
)
|
||||
linear_allowlist = Ruleset(
|
||||
origin="user_allowlist:linear",
|
||||
rules=[Rule(permission="save_issue", pattern="*", action="allow")],
|
||||
)
|
||||
|
||||
result = pack_subagent(
|
||||
name="jira",
|
||||
description="test",
|
||||
system_prompt="x",
|
||||
tools=[],
|
||||
ruleset=coded,
|
||||
dependencies={
|
||||
"flags": AgentFeatureFlags(),
|
||||
"user_allowlist_by_subagent": {"linear": linear_allowlist},
|
||||
},
|
||||
)
|
||||
|
||||
mw = _extract_permission_mw(result.spec)
|
||||
decided = evaluate("save_issue", "*", *mw._static_rulesets)
|
||||
assert decided.action == "ask"
|
||||
|
||||
|
||||
def test_empty_user_allowlist_is_tolerated():
|
||||
"""An empty ``Ruleset`` (no rules) must not flip evaluation to allow-everything."""
|
||||
coded = Ruleset(
|
||||
origin="connector",
|
||||
rules=[Rule(permission="save_issue", pattern="*", action="ask")],
|
||||
)
|
||||
empty = Ruleset(origin="user_allowlist:connector", rules=[])
|
||||
|
||||
result = pack_subagent(
|
||||
name="connector",
|
||||
description="test",
|
||||
system_prompt="x",
|
||||
tools=[],
|
||||
ruleset=coded,
|
||||
dependencies={
|
||||
"flags": AgentFeatureFlags(),
|
||||
"user_allowlist_by_subagent": {"connector": empty},
|
||||
},
|
||||
)
|
||||
|
||||
mw = _extract_permission_mw(result.spec)
|
||||
decided = evaluate("save_issue", "*", *mw._static_rulesets)
|
||||
assert decided.action == "ask"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue