diff --git a/surfsense_backend/app/agents/shared/middleware/__init__.py b/surfsense_backend/app/agents/shared/middleware/__init__.py index 9acc5292a..001025ff4 100644 --- a/surfsense_backend/app/agents/shared/middleware/__init__.py +++ b/surfsense_backend/app/agents/shared/middleware/__init__.py @@ -9,13 +9,11 @@ from app.agents.shared.middleware.kb_persistence import ( KnowledgeBasePersistenceMiddleware, commit_staged_filesystem_state, ) -from app.agents.shared.middleware.permission import PermissionMiddleware from app.agents.shared.middleware.retry_after import RetryAfterMiddleware __all__ = [ "BusyMutexMiddleware", "KnowledgeBasePersistenceMiddleware", - "PermissionMiddleware", "RetryAfterMiddleware", "SurfSenseCompactionMiddleware", "commit_staged_filesystem_state", diff --git a/surfsense_backend/app/agents/shared/middleware/permission.py b/surfsense_backend/app/agents/shared/middleware/permission.py deleted file mode 100644 index 4e624a81a..000000000 --- a/surfsense_backend/app/agents/shared/middleware/permission.py +++ /dev/null @@ -1,427 +0,0 @@ -""" -PermissionMiddleware — pattern-based allow/deny/ask with HITL fallback. - -LangChain's :class:`HumanInTheLoopMiddleware` only supports a static -"this tool always asks" decision per tool. There's no rule-based -allow/deny/ask layered ruleset, no glob patterns, no per-search-space or -per-thread overrides, and no auto-deny synthesis. - -This middleware ports OpenCode's ``packages/opencode/src/permission/index.ts`` -ruleset model on top of SurfSense's existing ``interrupt({type, action, -context})`` payload shape (see ``app/agents/shared/tools/hitl.py``) so -the frontend keeps working unchanged. - -Operation: -1. ``aafter_model`` inspects the latest ``AIMessage.tool_calls``. -2. For each call, the middleware builds a list of ``patterns`` (the - tool name plus any tool-specific patterns from the resolver). It - evaluates each pattern against the layered rulesets and aggregates - the results: ``deny`` > ``ask`` > ``allow``. -3. On ``deny``: replaces the call with a synthetic ``ToolMessage`` - containing a :class:`StreamingError`. -4. On ``ask``: raises a SurfSense-style ``interrupt(...)``. Both the legacy - SurfSense shape and LangChain HITL ``{"decisions": [{"type": ...}]}`` - replies are accepted via :func:`_normalize_permission_decision`. - - ``once``: proceed. - - ``approve_always``: also persist allow rules for ``request.always`` patterns. - - ``reject`` w/o feedback: raise :class:`RejectedError`. - - ``reject`` w/ feedback: raise :class:`CorrectedError`. -5. On ``allow``: proceed unchanged. - -The middleware also performs a *pre-model* tool-filter step (the -``before_model`` hook) so globally denied tools are stripped from the -exposed tool list before the model gets to see them. This mirrors -OpenCode's ``Permission.disabled`` and dramatically reduces the chance -the model emits a deny-only call. -""" - -from __future__ import annotations - -import logging -from collections.abc import Callable -from typing import Any - -from langchain.agents.middleware.types import ( - AgentMiddleware, - AgentState, - ContextT, -) -from langchain_core.messages import AIMessage, ToolMessage -from langgraph.runtime import Runtime -from langgraph.types import interrupt - -from app.agents.multi_agent_chat.shared.permissions import ( - Rule, - Ruleset, - aggregate_action, - evaluate_many, -) -from app.agents.shared.errors import ( - CorrectedError, - RejectedError, - StreamingError, -) -from app.observability import metrics as ot_metrics, otel as ot - -logger = logging.getLogger(__name__) - - -# Mapping ``tool_name -> resolver`` that converts ``args`` to a list of -# patterns to evaluate. The first pattern is conventionally the bare -# tool name; later entries narrow down to specific resources. -PatternResolver = Callable[[dict[str, Any]], list[str]] - - -def _default_pattern_resolver(name: str) -> PatternResolver: - def _resolve(args: dict[str, Any]) -> list[str]: - # Bare name covers the default catch-all; primary-arg fallbacks - # are best added per-tool by callers. - del args - return [name] - - return _resolve - - -# Translation from the LangChain HITL envelope (what ``stream_resume_chat`` -# sends) to SurfSense's legacy ``decision_type`` shape. ``edit`` keeps the -# original tool args — tools needing argument edits should use -# ``request_approval`` from ``app/agents/shared/tools/hitl.py``. -_LC_TYPE_TO_PERMISSION_DECISION: dict[str, str] = { - "approve": "once", - "reject": "reject", - "edit": "once", - "approve_always": "approve_always", -} - - -def _normalize_permission_decision(decision: Any) -> dict[str, Any]: - """Coerce any accepted reply shape into ``{"decision_type": ..., "feedback"?}``. - - Falls back to ``reject`` (with a warning) on unrecognized payloads so the - middleware fails closed. - """ - if isinstance(decision, str): - return {"decision_type": decision} - if not isinstance(decision, dict): - logger.warning( - "Unrecognized permission resume value (%s); treating as reject", - type(decision).__name__, - ) - return {"decision_type": "reject"} - - if decision.get("decision_type"): - return decision - - payload: dict[str, Any] = decision - decisions = decision.get("decisions") - if isinstance(decisions, list) and decisions: - first = decisions[0] - if isinstance(first, dict): - payload = first - - raw_type = payload.get("type") or payload.get("decision_type") - if not raw_type: - logger.warning( - "Permission resume missing decision type (keys=%s); treating as reject", - list(payload.keys()), - ) - return {"decision_type": "reject"} - - raw_type = str(raw_type).lower() - mapped = _LC_TYPE_TO_PERMISSION_DECISION.get(raw_type) - if mapped is None: - # Tolerate legacy values arriving without ``decision_type`` wrapping. - if raw_type in {"once", "approve_always", "reject"}: - mapped = raw_type - else: - logger.warning( - "Unknown permission decision type %r; treating as reject", raw_type - ) - mapped = "reject" - - if raw_type == "edit": - logger.warning( - "Permission middleware received an 'edit' decision; original args " - "kept (edits not merged here)." - ) - - out: dict[str, Any] = {"decision_type": mapped} - feedback = payload.get("feedback") or payload.get("message") - if isinstance(feedback, str) and feedback.strip(): - out["feedback"] = feedback - return out - - -class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg] - """Allow/deny/ask layer over the agent's tool calls. - - Args: - rulesets: Layered rulesets to evaluate. Earlier entries are - overridden by later ones (last-match-wins). Typical layering: - ``defaults < global < space < thread < runtime_approved``. - pattern_resolvers: Optional per-tool callables that return a list - of patterns to evaluate. When a tool isn't listed, the bare - tool name is used as the only pattern. - runtime_ruleset: Mutable :class:`Ruleset` that the middleware - extends in-place when the user replies ``"approve_always"`` to - an ask interrupt. Reused across all calls in the same agent - instance so newly-allowed rules apply to subsequent calls. - always_emit_interrupt_payload: If True, every ask uses the - SurfSense interrupt wire format (default). Set False to - disable interrupts and treat ``ask`` as ``deny`` for - non-interactive deployments. - """ - - tools = () - - def __init__( - self, - *, - rulesets: list[Ruleset] | None = None, - pattern_resolvers: dict[str, PatternResolver] | None = None, - runtime_ruleset: Ruleset | None = None, - always_emit_interrupt_payload: bool = True, - ) -> None: - super().__init__() - self._static_rulesets: list[Ruleset] = list(rulesets or []) - self._pattern_resolvers: dict[str, PatternResolver] = dict( - pattern_resolvers or {} - ) - self._runtime_ruleset: Ruleset = runtime_ruleset or Ruleset( - origin="runtime_approved" - ) - self._emit_interrupt = always_emit_interrupt_payload - - # ------------------------------------------------------------------ - # Tool-filter step (mirrors OpenCode's ``Permission.disabled``) - # ------------------------------------------------------------------ - - def _globally_denied(self, tool_name: str) -> bool: - """Return True if a deny rule with no narrowing pattern matches.""" - rules = evaluate_many(tool_name, ["*"], *self._all_rulesets()) - return aggregate_action(rules) == "deny" - - def _all_rulesets(self) -> list[Ruleset]: - return [*self._static_rulesets, self._runtime_ruleset] - - # NOTE: ``before_model`` filtering of the tools list is left to the - # agent factory. This middleware only blocks at execution time — and - # only via the rule-evaluator path, not by mutating ``request.tools``. - # Mutating ``request.tools`` per-call would invalidate provider - # prompt-cache prefixes (see Operational risks: prompt-cache regression). - - # ------------------------------------------------------------------ - # Tool-call evaluation - # ------------------------------------------------------------------ - - def _resolve_patterns(self, tool_name: str, args: dict[str, Any]) -> list[str]: - resolver = self._pattern_resolvers.get( - tool_name, _default_pattern_resolver(tool_name) - ) - try: - patterns = resolver(args or {}) - except Exception: - logger.exception( - "Pattern resolver for %s raised; using bare name", tool_name - ) - patterns = [tool_name] - if not patterns: - patterns = [tool_name] - return patterns - - def _evaluate( - self, tool_name: str, args: dict[str, Any] - ) -> tuple[str, list[str], list[Rule]]: - patterns = self._resolve_patterns(tool_name, args) - rules = evaluate_many(tool_name, patterns, *self._all_rulesets()) - action = aggregate_action(rules) - return action, patterns, rules - - # ------------------------------------------------------------------ - # HITL ask flow — SurfSense wire format - # ------------------------------------------------------------------ - - def _raise_interrupt( - self, - *, - tool_name: str, - args: dict[str, Any], - patterns: list[str], - rules: list[Rule], - ) -> dict[str, Any]: - """Block on user approval via SurfSense's ``interrupt`` shape.""" - if not self._emit_interrupt: - return {"decision_type": "reject"} - - # ``params`` (NOT ``args``) is what SurfSense's streaming - # normalizer forwards. Other fields move into ``context``. - payload = { - "type": "permission_ask", - "action": {"tool": tool_name, "params": args or {}}, - "context": { - "patterns": patterns, - "rules": [ - { - "permission": r.permission, - "pattern": r.pattern, - "action": r.action, - } - for r in rules - ], - # Rules of thumb for the frontend: surface the patterns - # the user can promote to "approve_always" with a single reply. - "always": patterns, - }, - } - # Open ``permission.asked`` + ``interrupt.raised`` OTel spans - # (no-op when OTel is disabled) so dashboards can correlate - # "we asked X" with "interrupt was actually delivered". - with ( - ot.permission_asked_span( - permission=tool_name, - pattern=patterns[0] if patterns else None, - extra={"permission.patterns": list(patterns)}, - ), - ot.interrupt_span(interrupt_type="permission_ask"), - ): - ot_metrics.record_permission_ask(permission=tool_name) - ot_metrics.record_interrupt(interrupt_type="permission_ask") - decision = interrupt(payload) - return _normalize_permission_decision(decision) - - def _persist_always(self, tool_name: str, patterns: list[str]) -> None: - """Promote ``approve_always`` reply into runtime allow rules. - - Persistence to ``agent_permission_rules`` is done by the - streaming layer (``stream_new_chat``) once it observes the - ``approve_always`` reply — the middleware just keeps an - in-memory copy so subsequent calls in the same stream see the rule. - """ - for pattern in patterns: - self._runtime_ruleset.rules.append( - Rule(permission=tool_name, pattern=pattern, action="allow") - ) - - # ------------------------------------------------------------------ - # Synthesizing deny -> ToolMessage - # ------------------------------------------------------------------ - - @staticmethod - def _deny_message( - tool_call: dict[str, Any], - rule: Rule, - ) -> ToolMessage: - err = StreamingError( - code="permission_denied", - retryable=False, - suggestion=( - f"rule permission={rule.permission!r} pattern={rule.pattern!r} " - f"blocked this call" - ), - ) - return ToolMessage( - content=( - f"Permission denied: rule {rule.permission}/{rule.pattern} " - f"blocked tool {tool_call.get('name')!r}." - ), - tool_call_id=tool_call.get("id") or "", - name=tool_call.get("name"), - status="error", - additional_kwargs={"error": err.model_dump()}, - ) - - # ------------------------------------------------------------------ - # The hook: aafter_model - # ------------------------------------------------------------------ - - def _process( - self, - state: AgentState, - runtime: Runtime[Any], - ) -> dict[str, Any] | None: - del runtime # unused - messages = state.get("messages") or [] - if not messages: - return None - last = messages[-1] - if not isinstance(last, AIMessage) or not last.tool_calls: - return None - - deny_messages: list[ToolMessage] = [] - kept_calls: list[dict[str, Any]] = [] - any_change = False - - for raw in last.tool_calls: - call = ( - dict(raw) - if isinstance(raw, dict) - else { - "name": getattr(raw, "name", None), - "args": getattr(raw, "args", {}), - "id": getattr(raw, "id", None), - "type": "tool_call", - } - ) - name = call.get("name") or "" - args = call.get("args") or {} - action, patterns, rules = self._evaluate(name, args) - - if action == "deny": - # Find the deny rule for the suggestion text - deny_rule = next((r for r in rules if r.action == "deny"), rules[0]) - deny_messages.append(self._deny_message(call, deny_rule)) - any_change = True - continue - - if action == "ask": - decision = self._raise_interrupt( - tool_name=name, args=args, patterns=patterns, rules=rules - ) - kind = str(decision.get("decision_type") or "reject").lower() - if kind == "once": - kept_calls.append(call) - elif kind == "approve_always": - self._persist_always(name, patterns) - kept_calls.append(call) - elif kind == "reject": - feedback = decision.get("feedback") - if isinstance(feedback, str) and feedback.strip(): - raise CorrectedError(feedback, tool=name) - raise RejectedError( - tool=name, pattern=patterns[0] if patterns else None - ) - else: - logger.warning( - "Unknown permission decision %r; treating as reject", kind - ) - raise RejectedError(tool=name) - continue - - # allow - kept_calls.append(call) - - if not any_change and len(kept_calls) == len(last.tool_calls): - return None - - updated = last.model_copy(update={"tool_calls": kept_calls}) - result_messages: list[Any] = [updated] - if deny_messages: - result_messages.extend(deny_messages) - return {"messages": result_messages} - - def after_model( # type: ignore[override] - self, state: AgentState, runtime: Runtime[ContextT] - ) -> dict[str, Any] | None: - return self._process(state, runtime) - - async def aafter_model( # type: ignore[override] - self, state: AgentState, runtime: Runtime[ContextT] - ) -> dict[str, Any] | None: - return self._process(state, runtime) - - -__all__ = [ - "PatternResolver", - "PermissionMiddleware", - "_normalize_permission_decision", -] diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_desktop_safety_rules.py b/surfsense_backend/tests/unit/agents/new_chat/test_desktop_safety_rules.py index 0f5a5c6d0..d9c5410d7 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_desktop_safety_rules.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_desktop_safety_rules.py @@ -16,7 +16,6 @@ from app.agents.multi_agent_chat.shared.permissions import ( aggregate_action, evaluate_many, ) -from app.agents.shared.middleware.permission import PermissionMiddleware pytestmark = pytest.mark.unit @@ -87,36 +86,3 @@ class TestDesktopSafetyOverridesAllowDefault: # Correct order: defaults < desktop_safety -> ask wins. action = _action_for("rm", SURFSENSE_DEFAULTS, DESKTOP_SAFETY_RULESET) assert action == "ask" - - -class TestPermissionMiddlewareIntegration: - def test_middleware_raises_interrupt_for_rm_in_desktop_mode(self) -> None: - from langchain_core.messages import AIMessage - - from app.agents.shared.errors import RejectedError - - mw = PermissionMiddleware(rulesets=[SURFSENSE_DEFAULTS, DESKTOP_SAFETY_RULESET]) - # Stub the interrupt to a "reject" decision so we can assert the - # ask path was taken without spinning up the LangGraph runtime. - mw._raise_interrupt = lambda **kw: {"decision_type": "reject"} # type: ignore[assignment] - - state = { - "messages": [ - AIMessage( - content="", - tool_calls=[ - { - "name": "rm", - "args": {"path": "/Users/me/Documents/important.docx"}, - "id": "tc-rm", - } - ], - ) - ] - } - - class _FakeRuntime: - config: dict = {"configurable": {"thread_id": "test"}} - - with pytest.raises(RejectedError): - mw.after_model(state, _FakeRuntime()) diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_permission_middleware.py b/surfsense_backend/tests/unit/agents/new_chat/test_permission_middleware.py deleted file mode 100644 index 66bd7a74b..000000000 --- a/surfsense_backend/tests/unit/agents/new_chat/test_permission_middleware.py +++ /dev/null @@ -1,263 +0,0 @@ -"""Tests for PermissionMiddleware end-to-end behavior.""" - -from __future__ import annotations - -import pytest -from langchain_core.messages import AIMessage, ToolMessage - -from app.agents.multi_agent_chat.shared.permissions import Rule, Ruleset -from app.agents.shared.errors import CorrectedError, RejectedError -from app.agents.shared.middleware.permission import ( - PermissionMiddleware, - _normalize_permission_decision, -) - -pytestmark = pytest.mark.unit - - -class _FakeRuntime: - config: dict = {"configurable": {"thread_id": "test"}} - - -def _msg(*tool_calls: dict) -> AIMessage: - return AIMessage(content="", tool_calls=list(tool_calls)) - - -class TestAllow: - def test_passthrough_when_allow(self) -> None: - rs = Ruleset(rules=[Rule("send_email", "*", "allow")]) - mw = PermissionMiddleware(rulesets=[rs]) - state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]} - out = mw.after_model(state, _FakeRuntime()) - assert out is None # no change - - -class TestDeny: - def test_replaces_with_deny_tool_message(self) -> None: - rs = Ruleset(rules=[Rule("send_email", "*", "deny")]) - mw = PermissionMiddleware(rulesets=[rs]) - state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]} - out = mw.after_model(state, _FakeRuntime()) - assert out is not None - msgs = out["messages"] - # Find the deny ToolMessage - deny_msgs = [m for m in msgs if isinstance(m, ToolMessage)] - assert len(deny_msgs) == 1 - assert deny_msgs[0].status == "error" - assert "permission_denied" in str(deny_msgs[0].additional_kwargs) - # AIMessage's tool_calls should now be empty (denied call removed) - ai_msg = next(m for m in msgs if isinstance(m, AIMessage)) - assert ai_msg.tool_calls == [] - - def test_mixed_allow_deny(self) -> None: - rs = Ruleset( - rules=[ - Rule("send_email", "*", "deny"), - Rule("read", "*", "allow"), - ] - ) - mw = PermissionMiddleware(rulesets=[rs]) - state = { - "messages": [ - _msg( - {"name": "send_email", "args": {}, "id": "1"}, - {"name": "read", "args": {}, "id": "2"}, - ) - ] - } - out = mw.after_model(state, _FakeRuntime()) - assert out is not None - ai_msg = next(m for m in out["messages"] if isinstance(m, AIMessage)) - assert len(ai_msg.tool_calls) == 1 - assert ai_msg.tool_calls[0]["name"] == "read" - - -class TestAsk: - def test_reject_without_feedback_raises(self) -> None: - # Default: nothing matches -> ask - rs = Ruleset(rules=[]) - mw = PermissionMiddleware(rulesets=[rs]) - - # Bypass real interrupt — patch the helper - mw._raise_interrupt = lambda **kw: {"decision_type": "reject"} # type: ignore[assignment] - state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]} - with pytest.raises(RejectedError): - mw.after_model(state, _FakeRuntime()) - - def test_reject_with_feedback_raises_corrected(self) -> None: - rs = Ruleset(rules=[]) - mw = PermissionMiddleware(rulesets=[rs]) - mw._raise_interrupt = lambda **kw: { # type: ignore[assignment] - "decision_type": "reject", - "feedback": "use a different subject line", - } - state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]} - with pytest.raises(CorrectedError) as excinfo: - mw.after_model(state, _FakeRuntime()) - assert excinfo.value.feedback == "use a different subject line" - - def test_once_proceeds_without_persisting(self) -> None: - mw = PermissionMiddleware(rulesets=[]) - mw._raise_interrupt = lambda **kw: {"decision_type": "once"} # type: ignore[assignment] - state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]} - out = mw.after_model(state, _FakeRuntime()) - # No state change because all calls kept - assert out is None - # No new rule persisted - assert mw._runtime_ruleset.rules == [] - - def test_approve_always_persists_runtime_rule(self) -> None: - mw = PermissionMiddleware(rulesets=[]) - mw._raise_interrupt = lambda **kw: {"decision_type": "approve_always"} # type: ignore[assignment] - state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]} - out = mw.after_model(state, _FakeRuntime()) - assert out is None # call kept - # Runtime ruleset got the always-allow rule - new_rules = [r for r in mw._runtime_ruleset.rules if r.action == "allow"] - assert any(r.permission == "send_email" for r in new_rules) - - -class TestNormalizeDecision: - """Resume shapes ``_normalize_permission_decision`` must accept.""" - - def test_legacy_decision_type_dict_passes_through(self) -> None: - decision = {"decision_type": "once"} - assert _normalize_permission_decision(decision) == {"decision_type": "once"} - - def test_legacy_decision_type_with_feedback_passes_through(self) -> None: - decision = {"decision_type": "reject", "feedback": "no thanks"} - assert _normalize_permission_decision(decision) == decision - - def test_plain_string_wrapped(self) -> None: - assert _normalize_permission_decision("once") == {"decision_type": "once"} - assert _normalize_permission_decision("reject") == {"decision_type": "reject"} - - def test_lc_envelope_approve_maps_to_once(self) -> None: - decision = {"decisions": [{"type": "approve"}]} - assert _normalize_permission_decision(decision) == {"decision_type": "once"} - - def test_lc_envelope_reject_maps_to_reject(self) -> None: - decision = {"decisions": [{"type": "reject"}]} - assert _normalize_permission_decision(decision) == {"decision_type": "reject"} - - def test_lc_envelope_reject_with_message_carries_feedback(self) -> None: - decision = {"decisions": [{"type": "reject", "message": "wrong recipient"}]} - out = _normalize_permission_decision(decision) - assert out == {"decision_type": "reject", "feedback": "wrong recipient"} - - def test_lc_envelope_reject_with_feedback_field(self) -> None: - decision = { - "decisions": [{"type": "reject", "feedback": "tighten the subject"}] - } - out = _normalize_permission_decision(decision) - assert out == {"decision_type": "reject", "feedback": "tighten the subject"} - - def test_lc_envelope_edit_maps_to_once(self) -> None: - # Pins the contract: edited args are NOT merged by permission. - decision = { - "decisions": [ - { - "type": "edit", - "edited_action": { - "name": "send_email", - "args": {"subject": "edited"}, - }, - } - ] - } - assert _normalize_permission_decision(decision) == {"decision_type": "once"} - - def test_lc_single_decision_without_envelope(self) -> None: - assert _normalize_permission_decision({"type": "approve"}) == { - "decision_type": "once" - } - - def test_unknown_type_falls_back_to_reject(self) -> None: - decision = {"decisions": [{"type": "totally_unknown"}]} - assert _normalize_permission_decision(decision) == {"decision_type": "reject"} - - def test_missing_type_falls_back_to_reject(self) -> None: - assert _normalize_permission_decision({"decisions": [{}]}) == { - "decision_type": "reject" - } - - def test_non_dict_non_string_falls_back_to_reject(self) -> None: - assert _normalize_permission_decision(None) == {"decision_type": "reject"} - assert _normalize_permission_decision(42) == {"decision_type": "reject"} - - def test_empty_decisions_list_falls_back_to_reject(self) -> None: - # Fail-closed on a malformed reply rather than treat it as approve. - assert _normalize_permission_decision({"decisions": []}) == { - "decision_type": "reject" - } - - -class TestResumeShapesEndToEnd: - """LangChain HITL envelope reaches ``_process`` correctly via ``_raise_interrupt``.""" - - def test_lc_approve_envelope_keeps_call(self) -> None: - mw = PermissionMiddleware(rulesets=[]) - mw._raise_interrupt = lambda **kw: { # type: ignore[assignment] - "decisions": [{"type": "approve"}] - } - state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]} - original = mw._raise_interrupt - mw._raise_interrupt = lambda **kw: _normalize_permission_decision( # type: ignore[assignment] - original(**kw) - ) - out = mw.after_model(state, _FakeRuntime()) - assert out is None - - def test_lc_reject_envelope_raises(self) -> None: - mw = PermissionMiddleware(rulesets=[]) - original = lambda **kw: {"decisions": [{"type": "reject"}]} # noqa: E731 - mw._raise_interrupt = lambda **kw: _normalize_permission_decision( # type: ignore[assignment] - original(**kw) - ) - state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]} - with pytest.raises(RejectedError): - mw.after_model(state, _FakeRuntime()) - - def test_lc_reject_with_message_raises_corrected(self) -> None: - mw = PermissionMiddleware(rulesets=[]) - original = lambda **kw: { # noqa: E731 - "decisions": [{"type": "reject", "message": "wrong recipient"}] - } - mw._raise_interrupt = lambda **kw: _normalize_permission_decision( # type: ignore[assignment] - original(**kw) - ) - state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]} - with pytest.raises(CorrectedError) as excinfo: - mw.after_model(state, _FakeRuntime()) - assert excinfo.value.feedback == "wrong recipient" - - def test_lc_edit_envelope_keeps_call_with_original_args(self) -> None: - # Pins the "edit -> once, args unchanged" contract. - mw = PermissionMiddleware(rulesets=[]) - original = lambda **kw: { # noqa: E731 - "decisions": [ - { - "type": "edit", - "edited_action": { - "name": "send_email", - "args": {"to": "edited@example.com"}, - }, - } - ] - } - mw._raise_interrupt = lambda **kw: _normalize_permission_decision( # type: ignore[assignment] - original(**kw) - ) - state = { - "messages": [ - _msg( - { - "name": "send_email", - "args": {"to": "original@example.com"}, - "id": "1", - } - ) - ] - } - out = mw.after_model(state, _FakeRuntime()) - assert out is None