mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-17 18:35:19 +02:00
multi_agent_chat/permissions: clone PermissionMiddleware with SRP split and edit support
This commit is contained in:
parent
3f77c74daf
commit
9b82f2db1d
15 changed files with 660 additions and 0 deletions
|
|
@ -0,0 +1,16 @@
|
||||||
|
"""Pattern-based allow/deny/ask middleware with HITL fallback.
|
||||||
|
|
||||||
|
Public surface: :class:`PermissionMiddleware` plus
|
||||||
|
:func:`normalize_permission_decision` for the streaming layer and the
|
||||||
|
:data:`PatternResolver` type for callers that register per-tool resolvers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .decision import normalize_permission_decision
|
||||||
|
from .middleware import PermissionMiddleware
|
||||||
|
from .pattern_resolver import PatternResolver
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"PatternResolver",
|
||||||
|
"PermissionMiddleware",
|
||||||
|
"normalize_permission_decision",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,91 @@
|
||||||
|
"""Coerce inbound permission decisions to a canonical dict shape.
|
||||||
|
|
||||||
|
Two wire formats are accepted:
|
||||||
|
- SurfSense legacy: ``{"decision_type": "once"|"always"|"reject", "feedback"?}``.
|
||||||
|
- LangChain HITL envelope: ``{"decisions": [{"type": "approve"|"edit"|"reject", ...}]}``.
|
||||||
|
|
||||||
|
The middleware downstream only inspects the canonical shape returned here,
|
||||||
|
so adding a new envelope means changing this module alone.
|
||||||
|
|
||||||
|
The middleware fails closed: any unrecognised payload becomes ``reject``
|
||||||
|
(with a warning) so the agent never proceeds on ambiguous input.
|
||||||
|
|
||||||
|
When the reply is an ``edit``, the result keeps ``decision_type="once"``
|
||||||
|
(the call still goes through) and adds an ``edited_args`` key holding the
|
||||||
|
user-modified ``args`` dict. The orchestrator merges those into the
|
||||||
|
``tool_call`` before keeping it; see :mod:`interrupt.edit.merge`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .interrupt.edit import extract_edited_args
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ``edit`` collapses to ``once``; any ``edited_args`` ride on the result.
|
||||||
|
_LC_TYPE_TO_PERMISSION_DECISION: dict[str, str] = {
|
||||||
|
"approve": "once",
|
||||||
|
"reject": "reject",
|
||||||
|
"edit": "once",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_permission_decision(decision: Any) -> dict[str, Any]:
|
||||||
|
"""Return ``{"decision_type": ..., "feedback"?: str, "edited_args"?: dict}``."""
|
||||||
|
if isinstance(decision, str):
|
||||||
|
return {"decision_type": decision}
|
||||||
|
if not isinstance(decision, dict):
|
||||||
|
logger.warning(
|
||||||
|
"Unrecognized permission resume value (%s); treating as reject",
|
||||||
|
type(decision).__name__,
|
||||||
|
)
|
||||||
|
return {"decision_type": "reject"}
|
||||||
|
|
||||||
|
if decision.get("decision_type"):
|
||||||
|
return decision
|
||||||
|
|
||||||
|
payload: dict[str, Any] = decision
|
||||||
|
decisions = decision.get("decisions")
|
||||||
|
if isinstance(decisions, list) and decisions:
|
||||||
|
first = decisions[0]
|
||||||
|
if isinstance(first, dict):
|
||||||
|
payload = first
|
||||||
|
|
||||||
|
raw_type = payload.get("type") or payload.get("decision_type")
|
||||||
|
if not raw_type:
|
||||||
|
logger.warning(
|
||||||
|
"Permission resume missing decision type (keys=%s); treating as reject",
|
||||||
|
list(payload.keys()),
|
||||||
|
)
|
||||||
|
return {"decision_type": "reject"}
|
||||||
|
|
||||||
|
raw_type = str(raw_type).lower()
|
||||||
|
mapped = _LC_TYPE_TO_PERMISSION_DECISION.get(raw_type)
|
||||||
|
if mapped is None:
|
||||||
|
# Tolerate legacy values arriving without ``decision_type`` wrapping.
|
||||||
|
if raw_type in {"once", "always", "reject"}:
|
||||||
|
mapped = raw_type
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"Unknown permission decision type %r; treating as reject", raw_type
|
||||||
|
)
|
||||||
|
mapped = "reject"
|
||||||
|
|
||||||
|
out: dict[str, Any] = {"decision_type": mapped}
|
||||||
|
feedback = payload.get("feedback") or payload.get("message")
|
||||||
|
if isinstance(feedback, str) and feedback.strip():
|
||||||
|
out["feedback"] = feedback
|
||||||
|
|
||||||
|
if raw_type == "edit":
|
||||||
|
edited = extract_edited_args(payload)
|
||||||
|
if edited:
|
||||||
|
out["edited_args"] = edited
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["normalize_permission_decision"]
|
||||||
|
|
@ -0,0 +1,39 @@
|
||||||
|
"""Synthesise a ``ToolMessage`` for a denied tool call.
|
||||||
|
|
||||||
|
The denied call is replaced with this message so the model sees a typed
|
||||||
|
``permission_denied`` error in ``ToolMessage.additional_kwargs["error"]``
|
||||||
|
and can adjust its plan without retrying the same forbidden call.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.messages import ToolMessage
|
||||||
|
|
||||||
|
from app.agents.new_chat.errors import StreamingError
|
||||||
|
from app.agents.new_chat.permissions import Rule
|
||||||
|
|
||||||
|
|
||||||
|
def build_deny_message(tool_call: dict[str, Any], rule: Rule) -> ToolMessage:
|
||||||
|
err = StreamingError(
|
||||||
|
code="permission_denied",
|
||||||
|
retryable=False,
|
||||||
|
suggestion=(
|
||||||
|
f"rule permission={rule.permission!r} pattern={rule.pattern!r} "
|
||||||
|
f"blocked this call"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return ToolMessage(
|
||||||
|
content=(
|
||||||
|
f"Permission denied: rule {rule.permission}/{rule.pattern} "
|
||||||
|
f"blocked tool {tool_call.get('name')!r}."
|
||||||
|
),
|
||||||
|
tool_call_id=tool_call.get("id") or "",
|
||||||
|
name=tool_call.get("name"),
|
||||||
|
status="error",
|
||||||
|
additional_kwargs={"error": err.model_dump()},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["build_deny_message"]
|
||||||
|
|
@ -0,0 +1,9 @@
|
||||||
|
"""Build and raise the ``permission_ask`` interrupt (payload + request)."""
|
||||||
|
|
||||||
|
from .payload import build_permission_ask_payload
|
||||||
|
from .request import request_permission_decision
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"build_permission_ask_payload",
|
||||||
|
"request_permission_decision",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,6 @@
|
||||||
|
"""Apply ``edit`` permission decisions to tool calls (extract + merge)."""
|
||||||
|
|
||||||
|
from .extract import extract_edited_args
|
||||||
|
from .merge import merge_edited_args
|
||||||
|
|
||||||
|
__all__ = ["extract_edited_args", "merge_edited_args"]
|
||||||
|
|
@ -0,0 +1,34 @@
|
||||||
|
"""Extract edited args from a permission decision payload.
|
||||||
|
|
||||||
|
Two shapes are accepted (mirrors :func:`app.agents.new_chat.tools.hitl._parse_decision`):
|
||||||
|
|
||||||
|
- LangChain HITL envelope: ``{"edited_action": {"args": {...}}}``.
|
||||||
|
- Legacy flat shape: ``{"args": {...}}``.
|
||||||
|
|
||||||
|
Returns ``None`` when no edited args are present. The orchestrator decides
|
||||||
|
whether to merge them (see :mod:`interrupt.edit.merge`); this module is pure parsing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
def extract_edited_args(decision_payload: dict[str, Any] | None) -> dict[str, Any] | None:
|
||||||
|
if not isinstance(decision_payload, dict):
|
||||||
|
return None
|
||||||
|
|
||||||
|
edited_action = decision_payload.get("edited_action")
|
||||||
|
if isinstance(edited_action, dict):
|
||||||
|
edited_args = edited_action.get("args")
|
||||||
|
if isinstance(edited_args, dict):
|
||||||
|
return edited_args
|
||||||
|
|
||||||
|
flat_args = decision_payload.get("args")
|
||||||
|
if isinstance(flat_args, dict):
|
||||||
|
return flat_args
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["extract_edited_args"]
|
||||||
|
|
@ -0,0 +1,25 @@
|
||||||
|
"""Apply edited args to a tool call.
|
||||||
|
|
||||||
|
Semantics match :func:`app.agents.new_chat.tools.hitl.request_approval`'s
|
||||||
|
``final_params = {**params, **edited_params}`` — shallow merge, edited
|
||||||
|
values override originals. Keys absent from ``edited_args`` keep their
|
||||||
|
original values, so partial edits are safe.
|
||||||
|
|
||||||
|
Returns a NEW ``tool_call`` dict (the input is not mutated) so the caller
|
||||||
|
can swap it into the ``AIMessage.tool_calls`` list without aliasing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
def merge_edited_args(
|
||||||
|
tool_call: dict[str, Any], edited_args: dict[str, Any]
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
original_args = tool_call.get("args") or {}
|
||||||
|
merged_args = {**original_args, **edited_args}
|
||||||
|
return {**tool_call, "args": merged_args}
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["merge_edited_args"]
|
||||||
|
|
@ -0,0 +1,43 @@
|
||||||
|
"""Build the ``permission_ask`` interrupt payload (pure data).
|
||||||
|
|
||||||
|
The frontend's streaming layer keys off ``type`` and renders the approval
|
||||||
|
card from ``action`` (the tool call being reviewed) and ``context``
|
||||||
|
(the matched rules and patterns that prompted the ask). ``context.always``
|
||||||
|
lists the patterns the user can promote to a permanent allow rule with a
|
||||||
|
single ``"always"`` reply.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.agents.new_chat.permissions import Rule
|
||||||
|
|
||||||
|
|
||||||
|
def build_permission_ask_payload(
|
||||||
|
*,
|
||||||
|
tool_name: str,
|
||||||
|
args: dict[str, Any],
|
||||||
|
patterns: list[str],
|
||||||
|
rules: list[Rule],
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "permission_ask",
|
||||||
|
# ``params`` (not ``args``) is what SurfSense's streaming normalizer forwards.
|
||||||
|
"action": {"tool": tool_name, "params": args or {}},
|
||||||
|
"context": {
|
||||||
|
"patterns": patterns,
|
||||||
|
"rules": [
|
||||||
|
{
|
||||||
|
"permission": r.permission,
|
||||||
|
"pattern": r.pattern,
|
||||||
|
"action": r.action,
|
||||||
|
}
|
||||||
|
for r in rules
|
||||||
|
],
|
||||||
|
"always": patterns,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["build_permission_ask_payload"]
|
||||||
|
|
@ -0,0 +1,52 @@
|
||||||
|
"""Request a permission decision from the user (side-effectful entry point).
|
||||||
|
|
||||||
|
Wraps :func:`langgraph.types.interrupt` with the OTel spans that the
|
||||||
|
SurfSense dashboard expects, then normalises the resume value through
|
||||||
|
:func:`decision.normalize_permission_decision`.
|
||||||
|
|
||||||
|
When ``emit_interrupt`` is ``False`` the call short-circuits to
|
||||||
|
``reject``; this is used by non-interactive deployments where ``ask`` must
|
||||||
|
not block.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langgraph.types import interrupt
|
||||||
|
|
||||||
|
from app.agents.new_chat.permissions import Rule
|
||||||
|
from app.observability import otel as ot
|
||||||
|
|
||||||
|
from ..decision import normalize_permission_decision
|
||||||
|
from .payload import build_permission_ask_payload
|
||||||
|
|
||||||
|
|
||||||
|
def request_permission_decision(
|
||||||
|
*,
|
||||||
|
tool_name: str,
|
||||||
|
args: dict[str, Any],
|
||||||
|
patterns: list[str],
|
||||||
|
rules: list[Rule],
|
||||||
|
emit_interrupt: bool,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
if not emit_interrupt:
|
||||||
|
return {"decision_type": "reject"}
|
||||||
|
|
||||||
|
payload = build_permission_ask_payload(
|
||||||
|
tool_name=tool_name, args=args, patterns=patterns, rules=rules
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
ot.permission_asked_span(
|
||||||
|
permission=tool_name,
|
||||||
|
pattern=patterns[0] if patterns else None,
|
||||||
|
extra={"permission.patterns": list(patterns)},
|
||||||
|
),
|
||||||
|
ot.interrupt_span(interrupt_type="permission_ask"),
|
||||||
|
):
|
||||||
|
decision = interrupt(payload)
|
||||||
|
return normalize_permission_decision(decision)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["request_permission_decision"]
|
||||||
|
|
@ -0,0 +1,13 @@
|
||||||
|
"""The orchestrator class plus its evaluation and ruleset-view helpers."""
|
||||||
|
|
||||||
|
from .core import PermissionMiddleware
|
||||||
|
from .evaluation import evaluate_tool_call, resolve_patterns
|
||||||
|
from .ruleset_view import all_rulesets, globally_denied
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"PermissionMiddleware",
|
||||||
|
"all_rulesets",
|
||||||
|
"evaluate_tool_call",
|
||||||
|
"globally_denied",
|
||||||
|
"resolve_patterns",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,195 @@
|
||||||
|
"""``PermissionMiddleware`` — pattern-based allow/deny/ask with HITL fallback.
|
||||||
|
|
||||||
|
LangChain's :class:`HumanInTheLoopMiddleware` only supports a static
|
||||||
|
"this tool always asks" decision per tool. There's no rule-based
|
||||||
|
allow/deny/ask, no glob patterns, no per-space/per-thread overrides, and
|
||||||
|
no auto-deny synthesis.
|
||||||
|
|
||||||
|
This middleware layers OpenCode's wildcard-ruleset model on top of
|
||||||
|
SurfSense's ``interrupt({type, action, context})`` payload shape (see
|
||||||
|
:mod:`app.agents.new_chat.tools.hitl`) so the frontend keeps working
|
||||||
|
unchanged.
|
||||||
|
|
||||||
|
Per-tool-call flow inside :meth:`_process`:
|
||||||
|
|
||||||
|
1. Skip when the last message has no tool calls.
|
||||||
|
2. For each call, evaluate the rules. ``deny`` is replaced with a
|
||||||
|
synthetic :class:`ToolMessage` carrying a typed
|
||||||
|
:class:`StreamingError`. ``ask`` raises an interrupt via
|
||||||
|
:mod:`interrupt.request`; the resulting decision is dispatched here:
|
||||||
|
|
||||||
|
- ``once`` → keep the call as-is.
|
||||||
|
- ``always`` → also extend the runtime ruleset.
|
||||||
|
- ``reject`` (with feedback) → :class:`CorrectedError`.
|
||||||
|
- ``reject`` (no feedback) → :class:`RejectedError`.
|
||||||
|
|
||||||
|
``allow`` keeps the call unchanged.
|
||||||
|
|
||||||
|
3. Returns an updated ``AIMessage`` (tool calls minus the denied ones)
|
||||||
|
plus any deny ``ToolMessage`` entries appended after it. Tool-list
|
||||||
|
filtering at ``before_model`` is intentionally not done here — that
|
||||||
|
would invalidate provider prompt-cache prefixes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain.agents.middleware.types import (
|
||||||
|
AgentMiddleware,
|
||||||
|
AgentState,
|
||||||
|
ContextT,
|
||||||
|
)
|
||||||
|
from langchain_core.messages import AIMessage, ToolMessage
|
||||||
|
from langgraph.runtime import Runtime
|
||||||
|
|
||||||
|
from app.agents.new_chat.errors import CorrectedError, RejectedError
|
||||||
|
from app.agents.new_chat.permissions import Ruleset
|
||||||
|
|
||||||
|
from ..deny import build_deny_message
|
||||||
|
from ..interrupt.edit import merge_edited_args
|
||||||
|
from ..interrupt import request_permission_decision
|
||||||
|
from ..pattern_resolver import PatternResolver
|
||||||
|
from ..runtime_promote import persist_always
|
||||||
|
from .evaluation import evaluate_tool_call
|
||||||
|
from .ruleset_view import all_rulesets
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
|
"""Allow/deny/ask layer over the agent's tool calls.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
rulesets: Layered rulesets to evaluate (earliest-to-latest wins).
|
||||||
|
Typical layering: ``defaults < global < space < thread < runtime_approved``.
|
||||||
|
pattern_resolvers: Optional per-tool callables that map ``args``
|
||||||
|
to wildcard patterns. Tools without an entry use the bare
|
||||||
|
tool name as the only pattern.
|
||||||
|
runtime_ruleset: Mutable :class:`Ruleset` extended in-place when
|
||||||
|
the user replies ``"always"``. Reused across calls in the
|
||||||
|
same agent instance so newly-allowed rules apply downstream.
|
||||||
|
always_emit_interrupt_payload: Set ``False`` to make ``ask``
|
||||||
|
collapse to ``deny`` (for non-interactive deployments).
|
||||||
|
"""
|
||||||
|
|
||||||
|
tools = ()
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
rulesets: list[Ruleset] | None = None,
|
||||||
|
pattern_resolvers: dict[str, PatternResolver] | None = None,
|
||||||
|
runtime_ruleset: Ruleset | None = None,
|
||||||
|
always_emit_interrupt_payload: bool = True,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self._static_rulesets: list[Ruleset] = list(rulesets or [])
|
||||||
|
self._pattern_resolvers: dict[str, PatternResolver] = dict(
|
||||||
|
pattern_resolvers or {}
|
||||||
|
)
|
||||||
|
self._runtime_ruleset: Ruleset = runtime_ruleset or Ruleset(
|
||||||
|
origin="runtime_approved"
|
||||||
|
)
|
||||||
|
self._emit_interrupt = always_emit_interrupt_payload
|
||||||
|
|
||||||
|
def _process(
|
||||||
|
self,
|
||||||
|
state: AgentState,
|
||||||
|
runtime: Runtime[Any],
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
del runtime
|
||||||
|
messages = state.get("messages") or []
|
||||||
|
if not messages:
|
||||||
|
return None
|
||||||
|
last = messages[-1]
|
||||||
|
if not isinstance(last, AIMessage) or not last.tool_calls:
|
||||||
|
return None
|
||||||
|
|
||||||
|
rulesets = all_rulesets(self._static_rulesets, self._runtime_ruleset)
|
||||||
|
deny_messages: list[ToolMessage] = []
|
||||||
|
kept_calls: list[dict[str, Any]] = []
|
||||||
|
any_change = False
|
||||||
|
|
||||||
|
for raw in last.tool_calls:
|
||||||
|
call = (
|
||||||
|
dict(raw)
|
||||||
|
if isinstance(raw, dict)
|
||||||
|
else {
|
||||||
|
"name": getattr(raw, "name", None),
|
||||||
|
"args": getattr(raw, "args", {}),
|
||||||
|
"id": getattr(raw, "id", None),
|
||||||
|
"type": "tool_call",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
name = call.get("name") or ""
|
||||||
|
args = call.get("args") or {}
|
||||||
|
action, patterns, rules = evaluate_tool_call(
|
||||||
|
name, args, self._pattern_resolvers, rulesets
|
||||||
|
)
|
||||||
|
|
||||||
|
if action == "deny":
|
||||||
|
deny_rule = next((r for r in rules if r.action == "deny"), rules[0])
|
||||||
|
deny_messages.append(build_deny_message(call, deny_rule))
|
||||||
|
any_change = True
|
||||||
|
continue
|
||||||
|
|
||||||
|
if action == "ask":
|
||||||
|
decision = request_permission_decision(
|
||||||
|
tool_name=name,
|
||||||
|
args=args,
|
||||||
|
patterns=patterns,
|
||||||
|
rules=rules,
|
||||||
|
emit_interrupt=self._emit_interrupt,
|
||||||
|
)
|
||||||
|
kind = str(decision.get("decision_type") or "reject").lower()
|
||||||
|
edited_args = decision.get("edited_args")
|
||||||
|
if kind in ("once", "always"):
|
||||||
|
final_call = (
|
||||||
|
merge_edited_args(call, edited_args)
|
||||||
|
if isinstance(edited_args, dict) and edited_args
|
||||||
|
else call
|
||||||
|
)
|
||||||
|
if final_call is not call:
|
||||||
|
any_change = True
|
||||||
|
if kind == "always":
|
||||||
|
persist_always(self._runtime_ruleset, name, patterns)
|
||||||
|
kept_calls.append(final_call)
|
||||||
|
elif kind == "reject":
|
||||||
|
feedback = decision.get("feedback")
|
||||||
|
if isinstance(feedback, str) and feedback.strip():
|
||||||
|
raise CorrectedError(feedback, tool=name)
|
||||||
|
raise RejectedError(
|
||||||
|
tool=name, pattern=patterns[0] if patterns else None
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"Unknown permission decision %r; treating as reject", kind
|
||||||
|
)
|
||||||
|
raise RejectedError(tool=name)
|
||||||
|
continue
|
||||||
|
|
||||||
|
kept_calls.append(call)
|
||||||
|
|
||||||
|
if not any_change and len(kept_calls) == len(last.tool_calls):
|
||||||
|
return None
|
||||||
|
|
||||||
|
updated = last.model_copy(update={"tool_calls": kept_calls})
|
||||||
|
result_messages: list[Any] = [updated]
|
||||||
|
if deny_messages:
|
||||||
|
result_messages.extend(deny_messages)
|
||||||
|
return {"messages": result_messages}
|
||||||
|
|
||||||
|
def after_model( # type: ignore[override]
|
||||||
|
self, state: AgentState, runtime: Runtime[ContextT]
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
return self._process(state, runtime)
|
||||||
|
|
||||||
|
async def aafter_model( # type: ignore[override]
|
||||||
|
self, state: AgentState, runtime: Runtime[ContextT]
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
return self._process(state, runtime)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["PermissionMiddleware"]
|
||||||
|
|
@ -0,0 +1,60 @@
|
||||||
|
"""Resolve patterns for a tool call and aggregate the resulting rules.
|
||||||
|
|
||||||
|
Two stages run on every tool call:
|
||||||
|
|
||||||
|
1. :func:`resolve_patterns` asks the tool's resolver (or the default) for
|
||||||
|
the wildcard patterns the rule engine should evaluate. Resolver
|
||||||
|
failures fall back to the bare tool name so a buggy resolver can't
|
||||||
|
cascade into permission decisions.
|
||||||
|
2. :func:`evaluate_tool_call` runs the rule engine against those patterns
|
||||||
|
and collapses the per-pattern rules into a single action
|
||||||
|
(``deny`` > ``ask`` > ``allow``).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.agents.new_chat.permissions import (
|
||||||
|
Rule,
|
||||||
|
RuleAction,
|
||||||
|
Ruleset,
|
||||||
|
aggregate_action,
|
||||||
|
evaluate_many,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ..pattern_resolver import PatternResolver, default_pattern_resolver
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_patterns(
|
||||||
|
tool_name: str,
|
||||||
|
args: dict[str, Any],
|
||||||
|
pattern_resolvers: dict[str, PatternResolver],
|
||||||
|
) -> list[str]:
|
||||||
|
resolver = pattern_resolvers.get(tool_name, default_pattern_resolver(tool_name))
|
||||||
|
try:
|
||||||
|
patterns = resolver(args or {})
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Pattern resolver for %s raised; using bare name", tool_name)
|
||||||
|
patterns = [tool_name]
|
||||||
|
if not patterns:
|
||||||
|
patterns = [tool_name]
|
||||||
|
return patterns
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_tool_call(
|
||||||
|
tool_name: str,
|
||||||
|
args: dict[str, Any],
|
||||||
|
pattern_resolvers: dict[str, PatternResolver],
|
||||||
|
rulesets: list[Ruleset],
|
||||||
|
) -> tuple[RuleAction, list[str], list[Rule]]:
|
||||||
|
patterns = resolve_patterns(tool_name, args, pattern_resolvers)
|
||||||
|
rules = evaluate_many(tool_name, patterns, *rulesets)
|
||||||
|
action = aggregate_action(rules)
|
||||||
|
return action, patterns, rules
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["evaluate_tool_call", "resolve_patterns"]
|
||||||
|
|
@ -0,0 +1,27 @@
|
||||||
|
"""Combined view over static + runtime rulesets.
|
||||||
|
|
||||||
|
Static rulesets come from the agent factory (defaults, space-scoped,
|
||||||
|
thread-scoped, etc.). The runtime ruleset is the in-memory one that
|
||||||
|
:func:`runtime_promote.persist_always` extends when the user replies
|
||||||
|
``"always"``. Evaluators always see them merged in this order so newly-
|
||||||
|
promoted rules apply to subsequent calls.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from app.agents.new_chat.permissions import Ruleset, aggregate_action, evaluate_many
|
||||||
|
|
||||||
|
|
||||||
|
def all_rulesets(
|
||||||
|
static_rulesets: list[Ruleset], runtime_ruleset: Ruleset
|
||||||
|
) -> list[Ruleset]:
|
||||||
|
return [*static_rulesets, runtime_ruleset]
|
||||||
|
|
||||||
|
|
||||||
|
def globally_denied(tool_name: str, rulesets: list[Ruleset]) -> bool:
|
||||||
|
"""True if an unconditional deny rule blocks every invocation of ``tool_name``."""
|
||||||
|
rules = evaluate_many(tool_name, ["*"], *rulesets)
|
||||||
|
return aggregate_action(rules) == "deny"
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["all_rulesets", "globally_denied"]
|
||||||
|
|
@ -0,0 +1,28 @@
|
||||||
|
"""Per-tool pattern resolution.
|
||||||
|
|
||||||
|
A :data:`PatternResolver` turns a tool's ``args`` dict into a list of
|
||||||
|
wildcard patterns evaluated against the layered rulesets. The first
|
||||||
|
pattern is conventionally the bare tool name (catch-all); later entries
|
||||||
|
narrow down to specific resources (file paths, ids, etc.).
|
||||||
|
|
||||||
|
Tools without a custom resolver fall back to :func:`default_pattern_resolver`,
|
||||||
|
which yields only the bare tool name.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
PatternResolver = Callable[[dict[str, Any]], list[str]]
|
||||||
|
|
||||||
|
|
||||||
|
def default_pattern_resolver(name: str) -> PatternResolver:
|
||||||
|
def _resolve(args: dict[str, Any]) -> list[str]:
|
||||||
|
del args
|
||||||
|
return [name]
|
||||||
|
|
||||||
|
return _resolve
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["PatternResolver", "default_pattern_resolver"]
|
||||||
|
|
@ -0,0 +1,22 @@
|
||||||
|
"""Promote an ``"always"`` reply into in-memory allow rules.
|
||||||
|
|
||||||
|
Subsequent calls within the same agent instance match these new rules and
|
||||||
|
proceed without prompting. Durable persistence (to ``agent_permission_rules``)
|
||||||
|
is the streaming layer's job — this module keeps the in-memory copy only.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from app.agents.new_chat.permissions import Rule, Ruleset
|
||||||
|
|
||||||
|
|
||||||
|
def persist_always(
|
||||||
|
runtime_ruleset: Ruleset, tool_name: str, patterns: list[str]
|
||||||
|
) -> None:
|
||||||
|
for pattern in patterns:
|
||||||
|
runtime_ruleset.rules.append(
|
||||||
|
Rule(permission=tool_name, pattern=pattern, action="allow")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["persist_always"]
|
||||||
Loading…
Add table
Add a link
Reference in a new issue