""" PermissionMiddleware — pattern-based allow/deny/ask with HITL fallback. LangChain's :class:`HumanInTheLoopMiddleware` only supports a static "this tool always asks" decision per tool. There's no rule-based allow/deny/ask layered ruleset, no glob patterns, no per-search-space or per-thread overrides, and no auto-deny synthesis. This middleware ports OpenCode's ``packages/opencode/src/permission/index.ts`` ruleset model on top of SurfSense's existing ``interrupt({type, action, context})`` payload shape (see ``app/agents/new_chat/tools/hitl.py``) so the frontend keeps working unchanged. Operation: 1. ``aafter_model`` inspects the latest ``AIMessage.tool_calls``. 2. For each call, the middleware builds a list of ``patterns`` (the tool name plus any tool-specific patterns from the resolver). It evaluates each pattern against the layered rulesets and aggregates the results: ``deny`` > ``ask`` > ``allow``. 3. On ``deny``: replaces the call with a synthetic ``ToolMessage`` containing a :class:`StreamingError`. 4. On ``ask``: raises a SurfSense-style ``interrupt(...)``. The reply shape is ``{"decision_type": "once|always|reject", "feedback"?: str}``. - ``once``: proceed. - ``always``: also persist allow rules for ``request.always`` patterns. - ``reject`` w/o feedback: raise :class:`RejectedError`. - ``reject`` w/ feedback: raise :class:`CorrectedError`. 5. On ``allow``: proceed unchanged. The middleware also performs a *pre-model* tool-filter step (the ``before_model`` hook) so globally denied tools are stripped from the exposed tool list before the model gets to see them. This mirrors OpenCode's ``Permission.disabled`` and dramatically reduces the chance the model emits a deny-only call. """ from __future__ import annotations import logging from collections.abc import Callable from typing import Any from langchain.agents.middleware.types import ( AgentMiddleware, AgentState, ContextT, ) from langchain_core.messages import AIMessage, ToolMessage from langgraph.runtime import Runtime from langgraph.types import interrupt from app.agents.new_chat.errors import ( CorrectedError, RejectedError, StreamingError, ) from app.agents.new_chat.permissions import ( Rule, Ruleset, aggregate_action, evaluate_many, ) from app.observability import otel as ot logger = logging.getLogger(__name__) # Mapping ``tool_name -> resolver`` that converts ``args`` to a list of # patterns to evaluate. The first pattern is conventionally the bare # tool name; later entries narrow down to specific resources. PatternResolver = Callable[[dict[str, Any]], list[str]] def _default_pattern_resolver(name: str) -> PatternResolver: def _resolve(args: dict[str, Any]) -> list[str]: # Bare name covers the default catch-all; primary-arg fallbacks # are best added per-tool by callers. del args return [name] return _resolve class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg] """Allow/deny/ask layer over the agent's tool calls. Args: rulesets: Layered rulesets to evaluate. Earlier entries are overridden by later ones (last-match-wins). Typical layering: ``defaults < global < space < thread < runtime_approved``. pattern_resolvers: Optional per-tool callables that return a list of patterns to evaluate. When a tool isn't listed, the bare tool name is used as the only pattern. runtime_ruleset: Mutable :class:`Ruleset` that the middleware extends in-place when the user replies ``"always"`` to an ask interrupt. Reused across all calls in the same agent instance so newly-allowed rules apply to subsequent calls. always_emit_interrupt_payload: If True, every ask uses the SurfSense interrupt wire format (default). Set False to disable interrupts and treat ``ask`` as ``deny`` for non-interactive deployments. """ tools = () def __init__( self, *, rulesets: list[Ruleset] | None = None, pattern_resolvers: dict[str, PatternResolver] | None = None, runtime_ruleset: Ruleset | None = None, always_emit_interrupt_payload: bool = True, ) -> None: super().__init__() self._static_rulesets: list[Ruleset] = list(rulesets or []) self._pattern_resolvers: dict[str, PatternResolver] = dict( pattern_resolvers or {} ) self._runtime_ruleset: Ruleset = runtime_ruleset or Ruleset( origin="runtime_approved" ) self._emit_interrupt = always_emit_interrupt_payload # ------------------------------------------------------------------ # Tool-filter step (mirrors OpenCode's ``Permission.disabled``) # ------------------------------------------------------------------ def _globally_denied(self, tool_name: str) -> bool: """Return True if a deny rule with no narrowing pattern matches.""" rules = evaluate_many(tool_name, ["*"], *self._all_rulesets()) return aggregate_action(rules) == "deny" def _all_rulesets(self) -> list[Ruleset]: return [*self._static_rulesets, self._runtime_ruleset] # NOTE: ``before_model`` filtering of the tools list is left to the # agent factory. This middleware only blocks at execution time — and # only via the rule-evaluator path, not by mutating ``request.tools``. # Mutating ``request.tools`` per-call would invalidate provider # prompt-cache prefixes (see Operational risks: prompt-cache regression). # ------------------------------------------------------------------ # Tool-call evaluation # ------------------------------------------------------------------ def _resolve_patterns(self, tool_name: str, args: dict[str, Any]) -> list[str]: resolver = self._pattern_resolvers.get( tool_name, _default_pattern_resolver(tool_name) ) try: patterns = resolver(args or {}) except Exception: logger.exception( "Pattern resolver for %s raised; using bare name", tool_name ) patterns = [tool_name] if not patterns: patterns = [tool_name] return patterns def _evaluate( self, tool_name: str, args: dict[str, Any] ) -> tuple[str, list[str], list[Rule]]: patterns = self._resolve_patterns(tool_name, args) rules = evaluate_many(tool_name, patterns, *self._all_rulesets()) action = aggregate_action(rules) return action, patterns, rules # ------------------------------------------------------------------ # HITL ask flow — SurfSense wire format # ------------------------------------------------------------------ def _raise_interrupt( self, *, tool_name: str, args: dict[str, Any], patterns: list[str], rules: list[Rule], ) -> dict[str, Any]: """Block on user approval via SurfSense's ``interrupt`` shape.""" if not self._emit_interrupt: return {"decision_type": "reject"} # ``params`` (NOT ``args``) is what SurfSense's streaming # normalizer forwards. Other fields move into ``context``. payload = { "type": "permission_ask", "action": {"tool": tool_name, "params": args or {}}, "context": { "patterns": patterns, "rules": [ { "permission": r.permission, "pattern": r.pattern, "action": r.action, } for r in rules ], # Rules of thumb for the frontend: surface the patterns # the user can promote to "always" with a single reply. "always": patterns, }, } # Open ``permission.asked`` + ``interrupt.raised`` OTel spans # (no-op when OTel is disabled) so dashboards can correlate # "we asked X" with "interrupt was actually delivered". with ( ot.permission_asked_span( permission=tool_name, pattern=patterns[0] if patterns else None, extra={"permission.patterns": list(patterns)}, ), ot.interrupt_span(interrupt_type="permission_ask"), ): decision = interrupt(payload) if isinstance(decision, dict): return decision # Tolerate a plain string reply ("once", "always", "reject") if isinstance(decision, str): return {"decision_type": decision} return {"decision_type": "reject"} def _persist_always(self, tool_name: str, patterns: list[str]) -> None: """Promote ``always`` reply into runtime allow rules. Persistence to ``agent_permission_rules`` is done by the streaming layer (``stream_new_chat``) once it observes the ``always`` reply — the middleware just keeps an in-memory copy so subsequent calls in the same stream see the rule. """ for pattern in patterns: self._runtime_ruleset.rules.append( Rule(permission=tool_name, pattern=pattern, action="allow") ) # ------------------------------------------------------------------ # Synthesizing deny -> ToolMessage # ------------------------------------------------------------------ @staticmethod def _deny_message( tool_call: dict[str, Any], rule: Rule, ) -> ToolMessage: err = StreamingError( code="permission_denied", retryable=False, suggestion=( f"rule permission={rule.permission!r} pattern={rule.pattern!r} " f"blocked this call" ), ) return ToolMessage( content=( f"Permission denied: rule {rule.permission}/{rule.pattern} " f"blocked tool {tool_call.get('name')!r}." ), tool_call_id=tool_call.get("id") or "", name=tool_call.get("name"), status="error", additional_kwargs={"error": err.model_dump()}, ) # ------------------------------------------------------------------ # The hook: aafter_model # ------------------------------------------------------------------ def _process( self, state: AgentState, runtime: Runtime[Any], ) -> dict[str, Any] | None: del runtime # unused messages = state.get("messages") or [] if not messages: return None last = messages[-1] if not isinstance(last, AIMessage) or not last.tool_calls: return None deny_messages: list[ToolMessage] = [] kept_calls: list[dict[str, Any]] = [] any_change = False for raw in last.tool_calls: call = ( dict(raw) if isinstance(raw, dict) else { "name": getattr(raw, "name", None), "args": getattr(raw, "args", {}), "id": getattr(raw, "id", None), "type": "tool_call", } ) name = call.get("name") or "" args = call.get("args") or {} action, patterns, rules = self._evaluate(name, args) if action == "deny": # Find the deny rule for the suggestion text deny_rule = next((r for r in rules if r.action == "deny"), rules[0]) deny_messages.append(self._deny_message(call, deny_rule)) any_change = True continue if action == "ask": decision = self._raise_interrupt( tool_name=name, args=args, patterns=patterns, rules=rules ) kind = str(decision.get("decision_type") or "reject").lower() if kind == "once": kept_calls.append(call) elif kind == "always": self._persist_always(name, patterns) kept_calls.append(call) elif kind == "reject": feedback = decision.get("feedback") if isinstance(feedback, str) and feedback.strip(): raise CorrectedError(feedback, tool=name) raise RejectedError( tool=name, pattern=patterns[0] if patterns else None ) else: logger.warning( "Unknown permission decision %r; treating as reject", kind ) raise RejectedError(tool=name) continue # allow kept_calls.append(call) if not any_change and len(kept_calls) == len(last.tool_calls): return None updated = last.model_copy(update={"tool_calls": kept_calls}) result_messages: list[Any] = [updated] if deny_messages: result_messages.extend(deny_messages) return {"messages": result_messages} def after_model( # type: ignore[override] self, state: AgentState, runtime: Runtime[ContextT] ) -> dict[str, Any] | None: return self._process(state, runtime) async def aafter_model( # type: ignore[override] self, state: AgentState, runtime: Runtime[ContextT] ) -> dict[str, Any] | None: return self._process(state, runtime) __all__ = [ "PatternResolver", "PermissionMiddleware", ]