mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-06 20:15:17 +02:00
refactor(agents): delete dead PermissionMiddleware twin in shared kernel
app/agents/shared/middleware/permission.py was an older, monolithic PermissionMiddleware superseded by the modular permissions/ package under multi_agent_chat/shared/middleware/ (core + evaluation + ask/ + factory). Production wires only the package (main_agent stack + every subagent builder); the kernel file was reachable only through the shared barrel re-export (itself unused) and two tests pinned to its dead internals (_raise_interrupt, _normalize_permission_decision, old after_model shape). - delete app/agents/shared/middleware/permission.py - drop PermissionMiddleware from the shared middleware barrel - delete test_permission_middleware.py (covered the dead impl only; live behavior is covered by tests/.../middleware/shared/permissions/*) - test_desktop_safety_rules.py: keep the ruleset-level regression tests, drop the dead import + TestPermissionMiddlewareIntegration class
This commit is contained in:
parent
8ae190a11d
commit
c0c4f57f5d
4 changed files with 0 additions and 726 deletions
|
|
@ -9,13 +9,11 @@ from app.agents.shared.middleware.kb_persistence import (
|
|||
KnowledgeBasePersistenceMiddleware,
|
||||
commit_staged_filesystem_state,
|
||||
)
|
||||
from app.agents.shared.middleware.permission import PermissionMiddleware
|
||||
from app.agents.shared.middleware.retry_after import RetryAfterMiddleware
|
||||
|
||||
__all__ = [
|
||||
"BusyMutexMiddleware",
|
||||
"KnowledgeBasePersistenceMiddleware",
|
||||
"PermissionMiddleware",
|
||||
"RetryAfterMiddleware",
|
||||
"SurfSenseCompactionMiddleware",
|
||||
"commit_staged_filesystem_state",
|
||||
|
|
|
|||
|
|
@ -1,427 +0,0 @@
|
|||
"""
|
||||
PermissionMiddleware — pattern-based allow/deny/ask with HITL fallback.
|
||||
|
||||
LangChain's :class:`HumanInTheLoopMiddleware` only supports a static
|
||||
"this tool always asks" decision per tool. There's no rule-based
|
||||
allow/deny/ask layered ruleset, no glob patterns, no per-search-space or
|
||||
per-thread overrides, and no auto-deny synthesis.
|
||||
|
||||
This middleware ports OpenCode's ``packages/opencode/src/permission/index.ts``
|
||||
ruleset model on top of SurfSense's existing ``interrupt({type, action,
|
||||
context})`` payload shape (see ``app/agents/shared/tools/hitl.py``) so
|
||||
the frontend keeps working unchanged.
|
||||
|
||||
Operation:
|
||||
1. ``aafter_model`` inspects the latest ``AIMessage.tool_calls``.
|
||||
2. For each call, the middleware builds a list of ``patterns`` (the
|
||||
tool name plus any tool-specific patterns from the resolver). It
|
||||
evaluates each pattern against the layered rulesets and aggregates
|
||||
the results: ``deny`` > ``ask`` > ``allow``.
|
||||
3. On ``deny``: replaces the call with a synthetic ``ToolMessage``
|
||||
containing a :class:`StreamingError`.
|
||||
4. On ``ask``: raises a SurfSense-style ``interrupt(...)``. Both the legacy
|
||||
SurfSense shape and LangChain HITL ``{"decisions": [{"type": ...}]}``
|
||||
replies are accepted via :func:`_normalize_permission_decision`.
|
||||
- ``once``: proceed.
|
||||
- ``approve_always``: also persist allow rules for ``request.always`` patterns.
|
||||
- ``reject`` w/o feedback: raise :class:`RejectedError`.
|
||||
- ``reject`` w/ feedback: raise :class:`CorrectedError`.
|
||||
5. On ``allow``: proceed unchanged.
|
||||
|
||||
The middleware also performs a *pre-model* tool-filter step (the
|
||||
``before_model`` hook) so globally denied tools are stripped from the
|
||||
exposed tool list before the model gets to see them. This mirrors
|
||||
OpenCode's ``Permission.disabled`` and dramatically reduces the chance
|
||||
the model emits a deny-only call.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
)
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
from langgraph.runtime import Runtime
|
||||
from langgraph.types import interrupt
|
||||
|
||||
from app.agents.multi_agent_chat.shared.permissions import (
|
||||
Rule,
|
||||
Ruleset,
|
||||
aggregate_action,
|
||||
evaluate_many,
|
||||
)
|
||||
from app.agents.shared.errors import (
|
||||
CorrectedError,
|
||||
RejectedError,
|
||||
StreamingError,
|
||||
)
|
||||
from app.observability import metrics as ot_metrics, otel as ot
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Mapping ``tool_name -> resolver`` that converts ``args`` to a list of
|
||||
# patterns to evaluate. The first pattern is conventionally the bare
|
||||
# tool name; later entries narrow down to specific resources.
|
||||
PatternResolver = Callable[[dict[str, Any]], list[str]]
|
||||
|
||||
|
||||
def _default_pattern_resolver(name: str) -> PatternResolver:
|
||||
def _resolve(args: dict[str, Any]) -> list[str]:
|
||||
# Bare name covers the default catch-all; primary-arg fallbacks
|
||||
# are best added per-tool by callers.
|
||||
del args
|
||||
return [name]
|
||||
|
||||
return _resolve
|
||||
|
||||
|
||||
# Translation from the LangChain HITL envelope (what ``stream_resume_chat``
|
||||
# sends) to SurfSense's legacy ``decision_type`` shape. ``edit`` keeps the
|
||||
# original tool args — tools needing argument edits should use
|
||||
# ``request_approval`` from ``app/agents/shared/tools/hitl.py``.
|
||||
_LC_TYPE_TO_PERMISSION_DECISION: dict[str, str] = {
|
||||
"approve": "once",
|
||||
"reject": "reject",
|
||||
"edit": "once",
|
||||
"approve_always": "approve_always",
|
||||
}
|
||||
|
||||
|
||||
def _normalize_permission_decision(decision: Any) -> dict[str, Any]:
|
||||
"""Coerce any accepted reply shape into ``{"decision_type": ..., "feedback"?}``.
|
||||
|
||||
Falls back to ``reject`` (with a warning) on unrecognized payloads so the
|
||||
middleware fails closed.
|
||||
"""
|
||||
if isinstance(decision, str):
|
||||
return {"decision_type": decision}
|
||||
if not isinstance(decision, dict):
|
||||
logger.warning(
|
||||
"Unrecognized permission resume value (%s); treating as reject",
|
||||
type(decision).__name__,
|
||||
)
|
||||
return {"decision_type": "reject"}
|
||||
|
||||
if decision.get("decision_type"):
|
||||
return decision
|
||||
|
||||
payload: dict[str, Any] = decision
|
||||
decisions = decision.get("decisions")
|
||||
if isinstance(decisions, list) and decisions:
|
||||
first = decisions[0]
|
||||
if isinstance(first, dict):
|
||||
payload = first
|
||||
|
||||
raw_type = payload.get("type") or payload.get("decision_type")
|
||||
if not raw_type:
|
||||
logger.warning(
|
||||
"Permission resume missing decision type (keys=%s); treating as reject",
|
||||
list(payload.keys()),
|
||||
)
|
||||
return {"decision_type": "reject"}
|
||||
|
||||
raw_type = str(raw_type).lower()
|
||||
mapped = _LC_TYPE_TO_PERMISSION_DECISION.get(raw_type)
|
||||
if mapped is None:
|
||||
# Tolerate legacy values arriving without ``decision_type`` wrapping.
|
||||
if raw_type in {"once", "approve_always", "reject"}:
|
||||
mapped = raw_type
|
||||
else:
|
||||
logger.warning(
|
||||
"Unknown permission decision type %r; treating as reject", raw_type
|
||||
)
|
||||
mapped = "reject"
|
||||
|
||||
if raw_type == "edit":
|
||||
logger.warning(
|
||||
"Permission middleware received an 'edit' decision; original args "
|
||||
"kept (edits not merged here)."
|
||||
)
|
||||
|
||||
out: dict[str, Any] = {"decision_type": mapped}
|
||||
feedback = payload.get("feedback") or payload.get("message")
|
||||
if isinstance(feedback, str) and feedback.strip():
|
||||
out["feedback"] = feedback
|
||||
return out
|
||||
|
||||
|
||||
class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||
"""Allow/deny/ask layer over the agent's tool calls.
|
||||
|
||||
Args:
|
||||
rulesets: Layered rulesets to evaluate. Earlier entries are
|
||||
overridden by later ones (last-match-wins). Typical layering:
|
||||
``defaults < global < space < thread < runtime_approved``.
|
||||
pattern_resolvers: Optional per-tool callables that return a list
|
||||
of patterns to evaluate. When a tool isn't listed, the bare
|
||||
tool name is used as the only pattern.
|
||||
runtime_ruleset: Mutable :class:`Ruleset` that the middleware
|
||||
extends in-place when the user replies ``"approve_always"`` to
|
||||
an ask interrupt. Reused across all calls in the same agent
|
||||
instance so newly-allowed rules apply to subsequent calls.
|
||||
always_emit_interrupt_payload: If True, every ask uses the
|
||||
SurfSense interrupt wire format (default). Set False to
|
||||
disable interrupts and treat ``ask`` as ``deny`` for
|
||||
non-interactive deployments.
|
||||
"""
|
||||
|
||||
tools = ()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
rulesets: list[Ruleset] | None = None,
|
||||
pattern_resolvers: dict[str, PatternResolver] | None = None,
|
||||
runtime_ruleset: Ruleset | None = None,
|
||||
always_emit_interrupt_payload: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._static_rulesets: list[Ruleset] = list(rulesets or [])
|
||||
self._pattern_resolvers: dict[str, PatternResolver] = dict(
|
||||
pattern_resolvers or {}
|
||||
)
|
||||
self._runtime_ruleset: Ruleset = runtime_ruleset or Ruleset(
|
||||
origin="runtime_approved"
|
||||
)
|
||||
self._emit_interrupt = always_emit_interrupt_payload
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tool-filter step (mirrors OpenCode's ``Permission.disabled``)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _globally_denied(self, tool_name: str) -> bool:
|
||||
"""Return True if a deny rule with no narrowing pattern matches."""
|
||||
rules = evaluate_many(tool_name, ["*"], *self._all_rulesets())
|
||||
return aggregate_action(rules) == "deny"
|
||||
|
||||
def _all_rulesets(self) -> list[Ruleset]:
|
||||
return [*self._static_rulesets, self._runtime_ruleset]
|
||||
|
||||
# NOTE: ``before_model`` filtering of the tools list is left to the
|
||||
# agent factory. This middleware only blocks at execution time — and
|
||||
# only via the rule-evaluator path, not by mutating ``request.tools``.
|
||||
# Mutating ``request.tools`` per-call would invalidate provider
|
||||
# prompt-cache prefixes (see Operational risks: prompt-cache regression).
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tool-call evaluation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _resolve_patterns(self, tool_name: str, args: dict[str, Any]) -> list[str]:
|
||||
resolver = self._pattern_resolvers.get(
|
||||
tool_name, _default_pattern_resolver(tool_name)
|
||||
)
|
||||
try:
|
||||
patterns = resolver(args or {})
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Pattern resolver for %s raised; using bare name", tool_name
|
||||
)
|
||||
patterns = [tool_name]
|
||||
if not patterns:
|
||||
patterns = [tool_name]
|
||||
return patterns
|
||||
|
||||
def _evaluate(
|
||||
self, tool_name: str, args: dict[str, Any]
|
||||
) -> tuple[str, list[str], list[Rule]]:
|
||||
patterns = self._resolve_patterns(tool_name, args)
|
||||
rules = evaluate_many(tool_name, patterns, *self._all_rulesets())
|
||||
action = aggregate_action(rules)
|
||||
return action, patterns, rules
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# HITL ask flow — SurfSense wire format
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _raise_interrupt(
|
||||
self,
|
||||
*,
|
||||
tool_name: str,
|
||||
args: dict[str, Any],
|
||||
patterns: list[str],
|
||||
rules: list[Rule],
|
||||
) -> dict[str, Any]:
|
||||
"""Block on user approval via SurfSense's ``interrupt`` shape."""
|
||||
if not self._emit_interrupt:
|
||||
return {"decision_type": "reject"}
|
||||
|
||||
# ``params`` (NOT ``args``) is what SurfSense's streaming
|
||||
# normalizer forwards. Other fields move into ``context``.
|
||||
payload = {
|
||||
"type": "permission_ask",
|
||||
"action": {"tool": tool_name, "params": args or {}},
|
||||
"context": {
|
||||
"patterns": patterns,
|
||||
"rules": [
|
||||
{
|
||||
"permission": r.permission,
|
||||
"pattern": r.pattern,
|
||||
"action": r.action,
|
||||
}
|
||||
for r in rules
|
||||
],
|
||||
# Rules of thumb for the frontend: surface the patterns
|
||||
# the user can promote to "approve_always" with a single reply.
|
||||
"always": patterns,
|
||||
},
|
||||
}
|
||||
# Open ``permission.asked`` + ``interrupt.raised`` OTel spans
|
||||
# (no-op when OTel is disabled) so dashboards can correlate
|
||||
# "we asked X" with "interrupt was actually delivered".
|
||||
with (
|
||||
ot.permission_asked_span(
|
||||
permission=tool_name,
|
||||
pattern=patterns[0] if patterns else None,
|
||||
extra={"permission.patterns": list(patterns)},
|
||||
),
|
||||
ot.interrupt_span(interrupt_type="permission_ask"),
|
||||
):
|
||||
ot_metrics.record_permission_ask(permission=tool_name)
|
||||
ot_metrics.record_interrupt(interrupt_type="permission_ask")
|
||||
decision = interrupt(payload)
|
||||
return _normalize_permission_decision(decision)
|
||||
|
||||
def _persist_always(self, tool_name: str, patterns: list[str]) -> None:
|
||||
"""Promote ``approve_always`` reply into runtime allow rules.
|
||||
|
||||
Persistence to ``agent_permission_rules`` is done by the
|
||||
streaming layer (``stream_new_chat``) once it observes the
|
||||
``approve_always`` reply — the middleware just keeps an
|
||||
in-memory copy so subsequent calls in the same stream see the rule.
|
||||
"""
|
||||
for pattern in patterns:
|
||||
self._runtime_ruleset.rules.append(
|
||||
Rule(permission=tool_name, pattern=pattern, action="allow")
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Synthesizing deny -> ToolMessage
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _deny_message(
|
||||
tool_call: dict[str, Any],
|
||||
rule: Rule,
|
||||
) -> ToolMessage:
|
||||
err = StreamingError(
|
||||
code="permission_denied",
|
||||
retryable=False,
|
||||
suggestion=(
|
||||
f"rule permission={rule.permission!r} pattern={rule.pattern!r} "
|
||||
f"blocked this call"
|
||||
),
|
||||
)
|
||||
return ToolMessage(
|
||||
content=(
|
||||
f"Permission denied: rule {rule.permission}/{rule.pattern} "
|
||||
f"blocked tool {tool_call.get('name')!r}."
|
||||
),
|
||||
tool_call_id=tool_call.get("id") or "",
|
||||
name=tool_call.get("name"),
|
||||
status="error",
|
||||
additional_kwargs={"error": err.model_dump()},
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# The hook: aafter_model
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _process(
|
||||
self,
|
||||
state: AgentState,
|
||||
runtime: Runtime[Any],
|
||||
) -> dict[str, Any] | None:
|
||||
del runtime # unused
|
||||
messages = state.get("messages") or []
|
||||
if not messages:
|
||||
return None
|
||||
last = messages[-1]
|
||||
if not isinstance(last, AIMessage) or not last.tool_calls:
|
||||
return None
|
||||
|
||||
deny_messages: list[ToolMessage] = []
|
||||
kept_calls: list[dict[str, Any]] = []
|
||||
any_change = False
|
||||
|
||||
for raw in last.tool_calls:
|
||||
call = (
|
||||
dict(raw)
|
||||
if isinstance(raw, dict)
|
||||
else {
|
||||
"name": getattr(raw, "name", None),
|
||||
"args": getattr(raw, "args", {}),
|
||||
"id": getattr(raw, "id", None),
|
||||
"type": "tool_call",
|
||||
}
|
||||
)
|
||||
name = call.get("name") or ""
|
||||
args = call.get("args") or {}
|
||||
action, patterns, rules = self._evaluate(name, args)
|
||||
|
||||
if action == "deny":
|
||||
# Find the deny rule for the suggestion text
|
||||
deny_rule = next((r for r in rules if r.action == "deny"), rules[0])
|
||||
deny_messages.append(self._deny_message(call, deny_rule))
|
||||
any_change = True
|
||||
continue
|
||||
|
||||
if action == "ask":
|
||||
decision = self._raise_interrupt(
|
||||
tool_name=name, args=args, patterns=patterns, rules=rules
|
||||
)
|
||||
kind = str(decision.get("decision_type") or "reject").lower()
|
||||
if kind == "once":
|
||||
kept_calls.append(call)
|
||||
elif kind == "approve_always":
|
||||
self._persist_always(name, patterns)
|
||||
kept_calls.append(call)
|
||||
elif kind == "reject":
|
||||
feedback = decision.get("feedback")
|
||||
if isinstance(feedback, str) and feedback.strip():
|
||||
raise CorrectedError(feedback, tool=name)
|
||||
raise RejectedError(
|
||||
tool=name, pattern=patterns[0] if patterns else None
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Unknown permission decision %r; treating as reject", kind
|
||||
)
|
||||
raise RejectedError(tool=name)
|
||||
continue
|
||||
|
||||
# allow
|
||||
kept_calls.append(call)
|
||||
|
||||
if not any_change and len(kept_calls) == len(last.tool_calls):
|
||||
return None
|
||||
|
||||
updated = last.model_copy(update={"tool_calls": kept_calls})
|
||||
result_messages: list[Any] = [updated]
|
||||
if deny_messages:
|
||||
result_messages.extend(deny_messages)
|
||||
return {"messages": result_messages}
|
||||
|
||||
def after_model( # type: ignore[override]
|
||||
self, state: AgentState, runtime: Runtime[ContextT]
|
||||
) -> dict[str, Any] | None:
|
||||
return self._process(state, runtime)
|
||||
|
||||
async def aafter_model( # type: ignore[override]
|
||||
self, state: AgentState, runtime: Runtime[ContextT]
|
||||
) -> dict[str, Any] | None:
|
||||
return self._process(state, runtime)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PatternResolver",
|
||||
"PermissionMiddleware",
|
||||
"_normalize_permission_decision",
|
||||
]
|
||||
|
|
@ -16,7 +16,6 @@ from app.agents.multi_agent_chat.shared.permissions import (
|
|||
aggregate_action,
|
||||
evaluate_many,
|
||||
)
|
||||
from app.agents.shared.middleware.permission import PermissionMiddleware
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
|
@ -87,36 +86,3 @@ class TestDesktopSafetyOverridesAllowDefault:
|
|||
# Correct order: defaults < desktop_safety -> ask wins.
|
||||
action = _action_for("rm", SURFSENSE_DEFAULTS, DESKTOP_SAFETY_RULESET)
|
||||
assert action == "ask"
|
||||
|
||||
|
||||
class TestPermissionMiddlewareIntegration:
|
||||
def test_middleware_raises_interrupt_for_rm_in_desktop_mode(self) -> None:
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
from app.agents.shared.errors import RejectedError
|
||||
|
||||
mw = PermissionMiddleware(rulesets=[SURFSENSE_DEFAULTS, DESKTOP_SAFETY_RULESET])
|
||||
# Stub the interrupt to a "reject" decision so we can assert the
|
||||
# ask path was taken without spinning up the LangGraph runtime.
|
||||
mw._raise_interrupt = lambda **kw: {"decision_type": "reject"} # type: ignore[assignment]
|
||||
|
||||
state = {
|
||||
"messages": [
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"name": "rm",
|
||||
"args": {"path": "/Users/me/Documents/important.docx"},
|
||||
"id": "tc-rm",
|
||||
}
|
||||
],
|
||||
)
|
||||
]
|
||||
}
|
||||
|
||||
class _FakeRuntime:
|
||||
config: dict = {"configurable": {"thread_id": "test"}}
|
||||
|
||||
with pytest.raises(RejectedError):
|
||||
mw.after_model(state, _FakeRuntime())
|
||||
|
|
|
|||
|
|
@ -1,263 +0,0 @@
|
|||
"""Tests for PermissionMiddleware end-to-end behavior."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
|
||||
from app.agents.multi_agent_chat.shared.permissions import Rule, Ruleset
|
||||
from app.agents.shared.errors import CorrectedError, RejectedError
|
||||
from app.agents.shared.middleware.permission import (
|
||||
PermissionMiddleware,
|
||||
_normalize_permission_decision,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
class _FakeRuntime:
|
||||
config: dict = {"configurable": {"thread_id": "test"}}
|
||||
|
||||
|
||||
def _msg(*tool_calls: dict) -> AIMessage:
|
||||
return AIMessage(content="", tool_calls=list(tool_calls))
|
||||
|
||||
|
||||
class TestAllow:
|
||||
def test_passthrough_when_allow(self) -> None:
|
||||
rs = Ruleset(rules=[Rule("send_email", "*", "allow")])
|
||||
mw = PermissionMiddleware(rulesets=[rs])
|
||||
state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]}
|
||||
out = mw.after_model(state, _FakeRuntime())
|
||||
assert out is None # no change
|
||||
|
||||
|
||||
class TestDeny:
|
||||
def test_replaces_with_deny_tool_message(self) -> None:
|
||||
rs = Ruleset(rules=[Rule("send_email", "*", "deny")])
|
||||
mw = PermissionMiddleware(rulesets=[rs])
|
||||
state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]}
|
||||
out = mw.after_model(state, _FakeRuntime())
|
||||
assert out is not None
|
||||
msgs = out["messages"]
|
||||
# Find the deny ToolMessage
|
||||
deny_msgs = [m for m in msgs if isinstance(m, ToolMessage)]
|
||||
assert len(deny_msgs) == 1
|
||||
assert deny_msgs[0].status == "error"
|
||||
assert "permission_denied" in str(deny_msgs[0].additional_kwargs)
|
||||
# AIMessage's tool_calls should now be empty (denied call removed)
|
||||
ai_msg = next(m for m in msgs if isinstance(m, AIMessage))
|
||||
assert ai_msg.tool_calls == []
|
||||
|
||||
def test_mixed_allow_deny(self) -> None:
|
||||
rs = Ruleset(
|
||||
rules=[
|
||||
Rule("send_email", "*", "deny"),
|
||||
Rule("read", "*", "allow"),
|
||||
]
|
||||
)
|
||||
mw = PermissionMiddleware(rulesets=[rs])
|
||||
state = {
|
||||
"messages": [
|
||||
_msg(
|
||||
{"name": "send_email", "args": {}, "id": "1"},
|
||||
{"name": "read", "args": {}, "id": "2"},
|
||||
)
|
||||
]
|
||||
}
|
||||
out = mw.after_model(state, _FakeRuntime())
|
||||
assert out is not None
|
||||
ai_msg = next(m for m in out["messages"] if isinstance(m, AIMessage))
|
||||
assert len(ai_msg.tool_calls) == 1
|
||||
assert ai_msg.tool_calls[0]["name"] == "read"
|
||||
|
||||
|
||||
class TestAsk:
|
||||
def test_reject_without_feedback_raises(self) -> None:
|
||||
# Default: nothing matches -> ask
|
||||
rs = Ruleset(rules=[])
|
||||
mw = PermissionMiddleware(rulesets=[rs])
|
||||
|
||||
# Bypass real interrupt — patch the helper
|
||||
mw._raise_interrupt = lambda **kw: {"decision_type": "reject"} # type: ignore[assignment]
|
||||
state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]}
|
||||
with pytest.raises(RejectedError):
|
||||
mw.after_model(state, _FakeRuntime())
|
||||
|
||||
def test_reject_with_feedback_raises_corrected(self) -> None:
|
||||
rs = Ruleset(rules=[])
|
||||
mw = PermissionMiddleware(rulesets=[rs])
|
||||
mw._raise_interrupt = lambda **kw: { # type: ignore[assignment]
|
||||
"decision_type": "reject",
|
||||
"feedback": "use a different subject line",
|
||||
}
|
||||
state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]}
|
||||
with pytest.raises(CorrectedError) as excinfo:
|
||||
mw.after_model(state, _FakeRuntime())
|
||||
assert excinfo.value.feedback == "use a different subject line"
|
||||
|
||||
def test_once_proceeds_without_persisting(self) -> None:
|
||||
mw = PermissionMiddleware(rulesets=[])
|
||||
mw._raise_interrupt = lambda **kw: {"decision_type": "once"} # type: ignore[assignment]
|
||||
state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]}
|
||||
out = mw.after_model(state, _FakeRuntime())
|
||||
# No state change because all calls kept
|
||||
assert out is None
|
||||
# No new rule persisted
|
||||
assert mw._runtime_ruleset.rules == []
|
||||
|
||||
def test_approve_always_persists_runtime_rule(self) -> None:
|
||||
mw = PermissionMiddleware(rulesets=[])
|
||||
mw._raise_interrupt = lambda **kw: {"decision_type": "approve_always"} # type: ignore[assignment]
|
||||
state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]}
|
||||
out = mw.after_model(state, _FakeRuntime())
|
||||
assert out is None # call kept
|
||||
# Runtime ruleset got the always-allow rule
|
||||
new_rules = [r for r in mw._runtime_ruleset.rules if r.action == "allow"]
|
||||
assert any(r.permission == "send_email" for r in new_rules)
|
||||
|
||||
|
||||
class TestNormalizeDecision:
|
||||
"""Resume shapes ``_normalize_permission_decision`` must accept."""
|
||||
|
||||
def test_legacy_decision_type_dict_passes_through(self) -> None:
|
||||
decision = {"decision_type": "once"}
|
||||
assert _normalize_permission_decision(decision) == {"decision_type": "once"}
|
||||
|
||||
def test_legacy_decision_type_with_feedback_passes_through(self) -> None:
|
||||
decision = {"decision_type": "reject", "feedback": "no thanks"}
|
||||
assert _normalize_permission_decision(decision) == decision
|
||||
|
||||
def test_plain_string_wrapped(self) -> None:
|
||||
assert _normalize_permission_decision("once") == {"decision_type": "once"}
|
||||
assert _normalize_permission_decision("reject") == {"decision_type": "reject"}
|
||||
|
||||
def test_lc_envelope_approve_maps_to_once(self) -> None:
|
||||
decision = {"decisions": [{"type": "approve"}]}
|
||||
assert _normalize_permission_decision(decision) == {"decision_type": "once"}
|
||||
|
||||
def test_lc_envelope_reject_maps_to_reject(self) -> None:
|
||||
decision = {"decisions": [{"type": "reject"}]}
|
||||
assert _normalize_permission_decision(decision) == {"decision_type": "reject"}
|
||||
|
||||
def test_lc_envelope_reject_with_message_carries_feedback(self) -> None:
|
||||
decision = {"decisions": [{"type": "reject", "message": "wrong recipient"}]}
|
||||
out = _normalize_permission_decision(decision)
|
||||
assert out == {"decision_type": "reject", "feedback": "wrong recipient"}
|
||||
|
||||
def test_lc_envelope_reject_with_feedback_field(self) -> None:
|
||||
decision = {
|
||||
"decisions": [{"type": "reject", "feedback": "tighten the subject"}]
|
||||
}
|
||||
out = _normalize_permission_decision(decision)
|
||||
assert out == {"decision_type": "reject", "feedback": "tighten the subject"}
|
||||
|
||||
def test_lc_envelope_edit_maps_to_once(self) -> None:
|
||||
# Pins the contract: edited args are NOT merged by permission.
|
||||
decision = {
|
||||
"decisions": [
|
||||
{
|
||||
"type": "edit",
|
||||
"edited_action": {
|
||||
"name": "send_email",
|
||||
"args": {"subject": "edited"},
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
assert _normalize_permission_decision(decision) == {"decision_type": "once"}
|
||||
|
||||
def test_lc_single_decision_without_envelope(self) -> None:
|
||||
assert _normalize_permission_decision({"type": "approve"}) == {
|
||||
"decision_type": "once"
|
||||
}
|
||||
|
||||
def test_unknown_type_falls_back_to_reject(self) -> None:
|
||||
decision = {"decisions": [{"type": "totally_unknown"}]}
|
||||
assert _normalize_permission_decision(decision) == {"decision_type": "reject"}
|
||||
|
||||
def test_missing_type_falls_back_to_reject(self) -> None:
|
||||
assert _normalize_permission_decision({"decisions": [{}]}) == {
|
||||
"decision_type": "reject"
|
||||
}
|
||||
|
||||
def test_non_dict_non_string_falls_back_to_reject(self) -> None:
|
||||
assert _normalize_permission_decision(None) == {"decision_type": "reject"}
|
||||
assert _normalize_permission_decision(42) == {"decision_type": "reject"}
|
||||
|
||||
def test_empty_decisions_list_falls_back_to_reject(self) -> None:
|
||||
# Fail-closed on a malformed reply rather than treat it as approve.
|
||||
assert _normalize_permission_decision({"decisions": []}) == {
|
||||
"decision_type": "reject"
|
||||
}
|
||||
|
||||
|
||||
class TestResumeShapesEndToEnd:
|
||||
"""LangChain HITL envelope reaches ``_process`` correctly via ``_raise_interrupt``."""
|
||||
|
||||
def test_lc_approve_envelope_keeps_call(self) -> None:
|
||||
mw = PermissionMiddleware(rulesets=[])
|
||||
mw._raise_interrupt = lambda **kw: { # type: ignore[assignment]
|
||||
"decisions": [{"type": "approve"}]
|
||||
}
|
||||
state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]}
|
||||
original = mw._raise_interrupt
|
||||
mw._raise_interrupt = lambda **kw: _normalize_permission_decision( # type: ignore[assignment]
|
||||
original(**kw)
|
||||
)
|
||||
out = mw.after_model(state, _FakeRuntime())
|
||||
assert out is None
|
||||
|
||||
def test_lc_reject_envelope_raises(self) -> None:
|
||||
mw = PermissionMiddleware(rulesets=[])
|
||||
original = lambda **kw: {"decisions": [{"type": "reject"}]} # noqa: E731
|
||||
mw._raise_interrupt = lambda **kw: _normalize_permission_decision( # type: ignore[assignment]
|
||||
original(**kw)
|
||||
)
|
||||
state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]}
|
||||
with pytest.raises(RejectedError):
|
||||
mw.after_model(state, _FakeRuntime())
|
||||
|
||||
def test_lc_reject_with_message_raises_corrected(self) -> None:
|
||||
mw = PermissionMiddleware(rulesets=[])
|
||||
original = lambda **kw: { # noqa: E731
|
||||
"decisions": [{"type": "reject", "message": "wrong recipient"}]
|
||||
}
|
||||
mw._raise_interrupt = lambda **kw: _normalize_permission_decision( # type: ignore[assignment]
|
||||
original(**kw)
|
||||
)
|
||||
state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]}
|
||||
with pytest.raises(CorrectedError) as excinfo:
|
||||
mw.after_model(state, _FakeRuntime())
|
||||
assert excinfo.value.feedback == "wrong recipient"
|
||||
|
||||
def test_lc_edit_envelope_keeps_call_with_original_args(self) -> None:
|
||||
# Pins the "edit -> once, args unchanged" contract.
|
||||
mw = PermissionMiddleware(rulesets=[])
|
||||
original = lambda **kw: { # noqa: E731
|
||||
"decisions": [
|
||||
{
|
||||
"type": "edit",
|
||||
"edited_action": {
|
||||
"name": "send_email",
|
||||
"args": {"to": "edited@example.com"},
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
mw._raise_interrupt = lambda **kw: _normalize_permission_decision( # type: ignore[assignment]
|
||||
original(**kw)
|
||||
)
|
||||
state = {
|
||||
"messages": [
|
||||
_msg(
|
||||
{
|
||||
"name": "send_email",
|
||||
"args": {"to": "original@example.com"},
|
||||
"id": "1",
|
||||
}
|
||||
)
|
||||
]
|
||||
}
|
||||
out = mw.after_model(state, _FakeRuntime())
|
||||
assert out is None
|
||||
Loading…
Add table
Add a link
Reference in a new issue