multi_agent_chat/permissions: clone PermissionMiddleware with SRP split and edit support

This commit is contained in:
CREDO23 2026-05-12 12:58:53 +02:00
parent 3f77c74daf
commit 9b82f2db1d
15 changed files with 660 additions and 0 deletions

View file

@ -0,0 +1,16 @@
"""Pattern-based allow/deny/ask middleware with HITL fallback.
Public surface: :class:`PermissionMiddleware` plus
:func:`normalize_permission_decision` for the streaming layer and the
:data:`PatternResolver` type for callers that register per-tool resolvers.
"""
from .decision import normalize_permission_decision
from .middleware import PermissionMiddleware
from .pattern_resolver import PatternResolver
__all__ = [
"PatternResolver",
"PermissionMiddleware",
"normalize_permission_decision",
]

View file

@ -0,0 +1,91 @@
"""Coerce inbound permission decisions to a canonical dict shape.
Two wire formats are accepted:
- SurfSense legacy: ``{"decision_type": "once"|"always"|"reject", "feedback"?}``.
- LangChain HITL envelope: ``{"decisions": [{"type": "approve"|"edit"|"reject", ...}]}``.
The middleware downstream only inspects the canonical shape returned here,
so adding a new envelope means changing this module alone.
The middleware fails closed: any unrecognised payload becomes ``reject``
(with a warning) so the agent never proceeds on ambiguous input.
When the reply is an ``edit``, the result keeps ``decision_type="once"``
(the call still goes through) and adds an ``edited_args`` key holding the
user-modified ``args`` dict. The orchestrator merges those into the
``tool_call`` before keeping it; see :mod:`interrupt.edit.merge`.
"""
from __future__ import annotations
import logging
from typing import Any
from .interrupt.edit import extract_edited_args
logger = logging.getLogger(__name__)
# ``edit`` collapses to ``once``; any ``edited_args`` ride on the result.
_LC_TYPE_TO_PERMISSION_DECISION: dict[str, str] = {
"approve": "once",
"reject": "reject",
"edit": "once",
}
def normalize_permission_decision(decision: Any) -> dict[str, Any]:
"""Return ``{"decision_type": ..., "feedback"?: str, "edited_args"?: dict}``."""
if isinstance(decision, str):
return {"decision_type": decision}
if not isinstance(decision, dict):
logger.warning(
"Unrecognized permission resume value (%s); treating as reject",
type(decision).__name__,
)
return {"decision_type": "reject"}
if decision.get("decision_type"):
return decision
payload: dict[str, Any] = decision
decisions = decision.get("decisions")
if isinstance(decisions, list) and decisions:
first = decisions[0]
if isinstance(first, dict):
payload = first
raw_type = payload.get("type") or payload.get("decision_type")
if not raw_type:
logger.warning(
"Permission resume missing decision type (keys=%s); treating as reject",
list(payload.keys()),
)
return {"decision_type": "reject"}
raw_type = str(raw_type).lower()
mapped = _LC_TYPE_TO_PERMISSION_DECISION.get(raw_type)
if mapped is None:
# Tolerate legacy values arriving without ``decision_type`` wrapping.
if raw_type in {"once", "always", "reject"}:
mapped = raw_type
else:
logger.warning(
"Unknown permission decision type %r; treating as reject", raw_type
)
mapped = "reject"
out: dict[str, Any] = {"decision_type": mapped}
feedback = payload.get("feedback") or payload.get("message")
if isinstance(feedback, str) and feedback.strip():
out["feedback"] = feedback
if raw_type == "edit":
edited = extract_edited_args(payload)
if edited:
out["edited_args"] = edited
return out
__all__ = ["normalize_permission_decision"]

View file

@ -0,0 +1,39 @@
"""Synthesise a ``ToolMessage`` for a denied tool call.
The denied call is replaced with this message so the model sees a typed
``permission_denied`` error in ``ToolMessage.additional_kwargs["error"]``
and can adjust its plan without retrying the same forbidden call.
"""
from __future__ import annotations
from typing import Any
from langchain_core.messages import ToolMessage
from app.agents.new_chat.errors import StreamingError
from app.agents.new_chat.permissions import Rule
def build_deny_message(tool_call: dict[str, Any], rule: Rule) -> ToolMessage:
err = StreamingError(
code="permission_denied",
retryable=False,
suggestion=(
f"rule permission={rule.permission!r} pattern={rule.pattern!r} "
f"blocked this call"
),
)
return ToolMessage(
content=(
f"Permission denied: rule {rule.permission}/{rule.pattern} "
f"blocked tool {tool_call.get('name')!r}."
),
tool_call_id=tool_call.get("id") or "",
name=tool_call.get("name"),
status="error",
additional_kwargs={"error": err.model_dump()},
)
__all__ = ["build_deny_message"]

View file

@ -0,0 +1,9 @@
"""Build and raise the ``permission_ask`` interrupt (payload + request)."""
from .payload import build_permission_ask_payload
from .request import request_permission_decision
__all__ = [
"build_permission_ask_payload",
"request_permission_decision",
]

View file

@ -0,0 +1,6 @@
"""Apply ``edit`` permission decisions to tool calls (extract + merge)."""
from .extract import extract_edited_args
from .merge import merge_edited_args
__all__ = ["extract_edited_args", "merge_edited_args"]

View file

@ -0,0 +1,34 @@
"""Extract edited args from a permission decision payload.
Two shapes are accepted (mirrors :func:`app.agents.new_chat.tools.hitl._parse_decision`):
- LangChain HITL envelope: ``{"edited_action": {"args": {...}}}``.
- Legacy flat shape: ``{"args": {...}}``.
Returns ``None`` when no edited args are present. The orchestrator decides
whether to merge them (see :mod:`interrupt.edit.merge`); this module is pure parsing.
"""
from __future__ import annotations
from typing import Any
def extract_edited_args(decision_payload: dict[str, Any] | None) -> dict[str, Any] | None:
if not isinstance(decision_payload, dict):
return None
edited_action = decision_payload.get("edited_action")
if isinstance(edited_action, dict):
edited_args = edited_action.get("args")
if isinstance(edited_args, dict):
return edited_args
flat_args = decision_payload.get("args")
if isinstance(flat_args, dict):
return flat_args
return None
__all__ = ["extract_edited_args"]

View file

@ -0,0 +1,25 @@
"""Apply edited args to a tool call.
Semantics match :func:`app.agents.new_chat.tools.hitl.request_approval`'s
``final_params = {**params, **edited_params}`` shallow merge, edited
values override originals. Keys absent from ``edited_args`` keep their
original values, so partial edits are safe.
Returns a NEW ``tool_call`` dict (the input is not mutated) so the caller
can swap it into the ``AIMessage.tool_calls`` list without aliasing.
"""
from __future__ import annotations
from typing import Any
def merge_edited_args(
tool_call: dict[str, Any], edited_args: dict[str, Any]
) -> dict[str, Any]:
original_args = tool_call.get("args") or {}
merged_args = {**original_args, **edited_args}
return {**tool_call, "args": merged_args}
__all__ = ["merge_edited_args"]

View file

@ -0,0 +1,43 @@
"""Build the ``permission_ask`` interrupt payload (pure data).
The frontend's streaming layer keys off ``type`` and renders the approval
card from ``action`` (the tool call being reviewed) and ``context``
(the matched rules and patterns that prompted the ask). ``context.always``
lists the patterns the user can promote to a permanent allow rule with a
single ``"always"`` reply.
"""
from __future__ import annotations
from typing import Any
from app.agents.new_chat.permissions import Rule
def build_permission_ask_payload(
*,
tool_name: str,
args: dict[str, Any],
patterns: list[str],
rules: list[Rule],
) -> dict[str, Any]:
return {
"type": "permission_ask",
# ``params`` (not ``args``) is what SurfSense's streaming normalizer forwards.
"action": {"tool": tool_name, "params": args or {}},
"context": {
"patterns": patterns,
"rules": [
{
"permission": r.permission,
"pattern": r.pattern,
"action": r.action,
}
for r in rules
],
"always": patterns,
},
}
__all__ = ["build_permission_ask_payload"]

View file

@ -0,0 +1,52 @@
"""Request a permission decision from the user (side-effectful entry point).
Wraps :func:`langgraph.types.interrupt` with the OTel spans that the
SurfSense dashboard expects, then normalises the resume value through
:func:`decision.normalize_permission_decision`.
When ``emit_interrupt`` is ``False`` the call short-circuits to
``reject``; this is used by non-interactive deployments where ``ask`` must
not block.
"""
from __future__ import annotations
from typing import Any
from langgraph.types import interrupt
from app.agents.new_chat.permissions import Rule
from app.observability import otel as ot
from ..decision import normalize_permission_decision
from .payload import build_permission_ask_payload
def request_permission_decision(
*,
tool_name: str,
args: dict[str, Any],
patterns: list[str],
rules: list[Rule],
emit_interrupt: bool,
) -> dict[str, Any]:
if not emit_interrupt:
return {"decision_type": "reject"}
payload = build_permission_ask_payload(
tool_name=tool_name, args=args, patterns=patterns, rules=rules
)
with (
ot.permission_asked_span(
permission=tool_name,
pattern=patterns[0] if patterns else None,
extra={"permission.patterns": list(patterns)},
),
ot.interrupt_span(interrupt_type="permission_ask"),
):
decision = interrupt(payload)
return normalize_permission_decision(decision)
__all__ = ["request_permission_decision"]

View file

@ -0,0 +1,13 @@
"""The orchestrator class plus its evaluation and ruleset-view helpers."""
from .core import PermissionMiddleware
from .evaluation import evaluate_tool_call, resolve_patterns
from .ruleset_view import all_rulesets, globally_denied
__all__ = [
"PermissionMiddleware",
"all_rulesets",
"evaluate_tool_call",
"globally_denied",
"resolve_patterns",
]

View file

@ -0,0 +1,195 @@
"""``PermissionMiddleware`` — pattern-based allow/deny/ask with HITL fallback.
LangChain's :class:`HumanInTheLoopMiddleware` only supports a static
"this tool always asks" decision per tool. There's no rule-based
allow/deny/ask, no glob patterns, no per-space/per-thread overrides, and
no auto-deny synthesis.
This middleware layers OpenCode's wildcard-ruleset model on top of
SurfSense's ``interrupt({type, action, context})`` payload shape (see
:mod:`app.agents.new_chat.tools.hitl`) so the frontend keeps working
unchanged.
Per-tool-call flow inside :meth:`_process`:
1. Skip when the last message has no tool calls.
2. For each call, evaluate the rules. ``deny`` is replaced with a
synthetic :class:`ToolMessage` carrying a typed
:class:`StreamingError`. ``ask`` raises an interrupt via
:mod:`interrupt.request`; the resulting decision is dispatched here:
- ``once`` keep the call as-is.
- ``always`` also extend the runtime ruleset.
- ``reject`` (with feedback) :class:`CorrectedError`.
- ``reject`` (no feedback) :class:`RejectedError`.
``allow`` keeps the call unchanged.
3. Returns an updated ``AIMessage`` (tool calls minus the denied ones)
plus any deny ``ToolMessage`` entries appended after it. Tool-list
filtering at ``before_model`` is intentionally not done here that
would invalidate provider prompt-cache prefixes.
"""
from __future__ import annotations
import logging
from typing import Any
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ContextT,
)
from langchain_core.messages import AIMessage, ToolMessage
from langgraph.runtime import Runtime
from app.agents.new_chat.errors import CorrectedError, RejectedError
from app.agents.new_chat.permissions import Ruleset
from ..deny import build_deny_message
from ..interrupt.edit import merge_edited_args
from ..interrupt import request_permission_decision
from ..pattern_resolver import PatternResolver
from ..runtime_promote import persist_always
from .evaluation import evaluate_tool_call
from .ruleset_view import all_rulesets
logger = logging.getLogger(__name__)
class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
"""Allow/deny/ask layer over the agent's tool calls.
Args:
rulesets: Layered rulesets to evaluate (earliest-to-latest wins).
Typical layering: ``defaults < global < space < thread < runtime_approved``.
pattern_resolvers: Optional per-tool callables that map ``args``
to wildcard patterns. Tools without an entry use the bare
tool name as the only pattern.
runtime_ruleset: Mutable :class:`Ruleset` extended in-place when
the user replies ``"always"``. Reused across calls in the
same agent instance so newly-allowed rules apply downstream.
always_emit_interrupt_payload: Set ``False`` to make ``ask``
collapse to ``deny`` (for non-interactive deployments).
"""
tools = ()
def __init__(
self,
*,
rulesets: list[Ruleset] | None = None,
pattern_resolvers: dict[str, PatternResolver] | None = None,
runtime_ruleset: Ruleset | None = None,
always_emit_interrupt_payload: bool = True,
) -> None:
super().__init__()
self._static_rulesets: list[Ruleset] = list(rulesets or [])
self._pattern_resolvers: dict[str, PatternResolver] = dict(
pattern_resolvers or {}
)
self._runtime_ruleset: Ruleset = runtime_ruleset or Ruleset(
origin="runtime_approved"
)
self._emit_interrupt = always_emit_interrupt_payload
def _process(
self,
state: AgentState,
runtime: Runtime[Any],
) -> dict[str, Any] | None:
del runtime
messages = state.get("messages") or []
if not messages:
return None
last = messages[-1]
if not isinstance(last, AIMessage) or not last.tool_calls:
return None
rulesets = all_rulesets(self._static_rulesets, self._runtime_ruleset)
deny_messages: list[ToolMessage] = []
kept_calls: list[dict[str, Any]] = []
any_change = False
for raw in last.tool_calls:
call = (
dict(raw)
if isinstance(raw, dict)
else {
"name": getattr(raw, "name", None),
"args": getattr(raw, "args", {}),
"id": getattr(raw, "id", None),
"type": "tool_call",
}
)
name = call.get("name") or ""
args = call.get("args") or {}
action, patterns, rules = evaluate_tool_call(
name, args, self._pattern_resolvers, rulesets
)
if action == "deny":
deny_rule = next((r for r in rules if r.action == "deny"), rules[0])
deny_messages.append(build_deny_message(call, deny_rule))
any_change = True
continue
if action == "ask":
decision = request_permission_decision(
tool_name=name,
args=args,
patterns=patterns,
rules=rules,
emit_interrupt=self._emit_interrupt,
)
kind = str(decision.get("decision_type") or "reject").lower()
edited_args = decision.get("edited_args")
if kind in ("once", "always"):
final_call = (
merge_edited_args(call, edited_args)
if isinstance(edited_args, dict) and edited_args
else call
)
if final_call is not call:
any_change = True
if kind == "always":
persist_always(self._runtime_ruleset, name, patterns)
kept_calls.append(final_call)
elif kind == "reject":
feedback = decision.get("feedback")
if isinstance(feedback, str) and feedback.strip():
raise CorrectedError(feedback, tool=name)
raise RejectedError(
tool=name, pattern=patterns[0] if patterns else None
)
else:
logger.warning(
"Unknown permission decision %r; treating as reject", kind
)
raise RejectedError(tool=name)
continue
kept_calls.append(call)
if not any_change and len(kept_calls) == len(last.tool_calls):
return None
updated = last.model_copy(update={"tool_calls": kept_calls})
result_messages: list[Any] = [updated]
if deny_messages:
result_messages.extend(deny_messages)
return {"messages": result_messages}
def after_model( # type: ignore[override]
self, state: AgentState, runtime: Runtime[ContextT]
) -> dict[str, Any] | None:
return self._process(state, runtime)
async def aafter_model( # type: ignore[override]
self, state: AgentState, runtime: Runtime[ContextT]
) -> dict[str, Any] | None:
return self._process(state, runtime)
__all__ = ["PermissionMiddleware"]

View file

@ -0,0 +1,60 @@
"""Resolve patterns for a tool call and aggregate the resulting rules.
Two stages run on every tool call:
1. :func:`resolve_patterns` asks the tool's resolver (or the default) for
the wildcard patterns the rule engine should evaluate. Resolver
failures fall back to the bare tool name so a buggy resolver can't
cascade into permission decisions.
2. :func:`evaluate_tool_call` runs the rule engine against those patterns
and collapses the per-pattern rules into a single action
(``deny`` > ``ask`` > ``allow``).
"""
from __future__ import annotations
import logging
from typing import Any
from app.agents.new_chat.permissions import (
Rule,
RuleAction,
Ruleset,
aggregate_action,
evaluate_many,
)
from ..pattern_resolver import PatternResolver, default_pattern_resolver
logger = logging.getLogger(__name__)
def resolve_patterns(
tool_name: str,
args: dict[str, Any],
pattern_resolvers: dict[str, PatternResolver],
) -> list[str]:
resolver = pattern_resolvers.get(tool_name, default_pattern_resolver(tool_name))
try:
patterns = resolver(args or {})
except Exception:
logger.exception("Pattern resolver for %s raised; using bare name", tool_name)
patterns = [tool_name]
if not patterns:
patterns = [tool_name]
return patterns
def evaluate_tool_call(
tool_name: str,
args: dict[str, Any],
pattern_resolvers: dict[str, PatternResolver],
rulesets: list[Ruleset],
) -> tuple[RuleAction, list[str], list[Rule]]:
patterns = resolve_patterns(tool_name, args, pattern_resolvers)
rules = evaluate_many(tool_name, patterns, *rulesets)
action = aggregate_action(rules)
return action, patterns, rules
__all__ = ["evaluate_tool_call", "resolve_patterns"]

View file

@ -0,0 +1,27 @@
"""Combined view over static + runtime rulesets.
Static rulesets come from the agent factory (defaults, space-scoped,
thread-scoped, etc.). The runtime ruleset is the in-memory one that
:func:`runtime_promote.persist_always` extends when the user replies
``"always"``. Evaluators always see them merged in this order so newly-
promoted rules apply to subsequent calls.
"""
from __future__ import annotations
from app.agents.new_chat.permissions import Ruleset, aggregate_action, evaluate_many
def all_rulesets(
static_rulesets: list[Ruleset], runtime_ruleset: Ruleset
) -> list[Ruleset]:
return [*static_rulesets, runtime_ruleset]
def globally_denied(tool_name: str, rulesets: list[Ruleset]) -> bool:
"""True if an unconditional deny rule blocks every invocation of ``tool_name``."""
rules = evaluate_many(tool_name, ["*"], *rulesets)
return aggregate_action(rules) == "deny"
__all__ = ["all_rulesets", "globally_denied"]

View file

@ -0,0 +1,28 @@
"""Per-tool pattern resolution.
A :data:`PatternResolver` turns a tool's ``args`` dict into a list of
wildcard patterns evaluated against the layered rulesets. The first
pattern is conventionally the bare tool name (catch-all); later entries
narrow down to specific resources (file paths, ids, etc.).
Tools without a custom resolver fall back to :func:`default_pattern_resolver`,
which yields only the bare tool name.
"""
from __future__ import annotations
from collections.abc import Callable
from typing import Any
PatternResolver = Callable[[dict[str, Any]], list[str]]
def default_pattern_resolver(name: str) -> PatternResolver:
def _resolve(args: dict[str, Any]) -> list[str]:
del args
return [name]
return _resolve
__all__ = ["PatternResolver", "default_pattern_resolver"]

View file

@ -0,0 +1,22 @@
"""Promote an ``"always"`` reply into in-memory allow rules.
Subsequent calls within the same agent instance match these new rules and
proceed without prompting. Durable persistence (to ``agent_permission_rules``)
is the streaming layer's job — this module keeps the in-memory copy only.
"""
from __future__ import annotations
from app.agents.new_chat.permissions import Rule, Ruleset
def persist_always(
runtime_ruleset: Ruleset, tool_name: str, patterns: list[str]
) -> None:
for pattern in patterns:
runtime_ruleset.rules.append(
Rule(permission=tool_name, pattern=pattern, action="allow")
)
__all__ = ["persist_always"]