mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-05 13:52:40 +02:00
feat: updated agent harness
This commit is contained in:
parent
9ec9b64348
commit
31a372bb84
139 changed files with 12583 additions and 1111 deletions
344
surfsense_backend/app/agents/new_chat/middleware/permission.py
Normal file
344
surfsense_backend/app/agents/new_chat/middleware/permission.py
Normal file
|
|
@ -0,0 +1,344 @@
|
|||
"""
|
||||
PermissionMiddleware — pattern-based allow/deny/ask with HITL fallback.
|
||||
|
||||
Mirrors ``opencode/packages/opencode/src/permission/index.ts`` but uses
|
||||
SurfSense's existing ``interrupt({type, action, context})`` payload shape
|
||||
(see ``app/agents/new_chat/tools/hitl.py``) so the frontend keeps
|
||||
working unchanged. Tier 2.1 in the OpenCode-port plan.
|
||||
|
||||
Operation:
|
||||
1. ``aafter_model`` inspects the latest ``AIMessage.tool_calls``.
|
||||
2. For each call, the middleware builds a list of ``patterns`` (the
|
||||
tool name plus any tool-specific patterns from the resolver). It
|
||||
evaluates each pattern against the layered rulesets and aggregates
|
||||
the results: ``deny`` > ``ask`` > ``allow``.
|
||||
3. On ``deny``: replaces the call with a synthetic ``ToolMessage``
|
||||
containing a :class:`StreamingError`.
|
||||
4. On ``ask``: raises a SurfSense-style ``interrupt(...)``. The reply
|
||||
shape is ``{"decision_type": "once|always|reject", "feedback"?: str}``.
|
||||
- ``once``: proceed.
|
||||
- ``always``: also persist allow rules for ``request.always`` patterns.
|
||||
- ``reject`` w/o feedback: raise :class:`RejectedError`.
|
||||
- ``reject`` w/ feedback: raise :class:`CorrectedError`.
|
||||
5. On ``allow``: proceed unchanged.
|
||||
|
||||
Note: this middleware does NOT prune the exposed tool list before the
model runs. Pre-model filtering of globally denied tools (opencode's
``Permission.disabled`` equivalent) is left to the agent factory, so
this layer only blocks at execution time via the rule evaluator —
mutating the tool list per call would invalidate provider
prompt-cache prefixes.
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
)
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
from langgraph.runtime import Runtime
|
||||
from langgraph.types import interrupt
|
||||
|
||||
from app.agents.new_chat.errors import (
|
||||
CorrectedError,
|
||||
RejectedError,
|
||||
StreamingError,
|
||||
)
|
||||
from app.agents.new_chat.permissions import (
|
||||
Rule,
|
||||
Ruleset,
|
||||
aggregate_action,
|
||||
evaluate_many,
|
||||
)
|
||||
from app.observability import otel as ot
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Mapping ``tool_name -> resolver`` that converts ``args`` to a list of
|
||||
# patterns to evaluate. The first pattern is conventionally the bare
|
||||
# tool name; later entries narrow down to specific resources.
|
||||
PatternResolver = Callable[[dict[str, Any]], list[str]]
|
||||
|
||||
|
||||
def _default_pattern_resolver(name: str) -> PatternResolver:
|
||||
def _resolve(args: dict[str, Any]) -> list[str]:
|
||||
# Bare name covers the default catch-all; primary-arg fallbacks
|
||||
# are best added per-tool by callers.
|
||||
del args
|
||||
return [name]
|
||||
|
||||
return _resolve
|
||||
|
||||
|
||||
class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
    """Allow/deny/ask layer over the agent's tool calls.

    Args:
        rulesets: Layered rulesets to evaluate. Earlier entries are
            overridden by later ones (last-match-wins). Typical layering:
            ``defaults < global < space < thread < runtime_approved``.
        pattern_resolvers: Optional per-tool callables that return a list
            of patterns to evaluate. When a tool isn't listed, the bare
            tool name is used as the only pattern.
        runtime_ruleset: Mutable :class:`Ruleset` that the middleware
            extends in-place when the user replies ``"always"`` to an
            ask interrupt. Reused across all calls in the same agent
            instance so newly-allowed rules apply to subsequent calls.
        always_emit_interrupt_payload: If True, every ask uses the
            SurfSense interrupt wire format (default). Set False to
            disable interrupts and treat ``ask`` as ``deny`` for
            non-interactive deployments.
    """

    # This middleware contributes no tools of its own; it only gates
    # calls to tools supplied elsewhere.
    tools = ()

    def __init__(
        self,
        *,
        rulesets: list[Ruleset] | None = None,
        pattern_resolvers: dict[str, PatternResolver] | None = None,
        runtime_ruleset: Ruleset | None = None,
        always_emit_interrupt_payload: bool = True,
    ) -> None:
        super().__init__()
        # Defensive copies so callers mutating their inputs later don't
        # silently change this middleware's behavior.
        self._static_rulesets: list[Ruleset] = list(rulesets or [])
        self._pattern_resolvers: dict[str, PatternResolver] = dict(
            pattern_resolvers or {}
        )
        # NOT copied: "always" approvals are appended in-place so the
        # caller-supplied ruleset sees them too.
        self._runtime_ruleset: Ruleset = runtime_ruleset or Ruleset(
            origin="runtime_approved"
        )
        # When False, ``ask`` short-circuits to a reject decision
        # instead of raising an interrupt (non-interactive mode).
        self._emit_interrupt: bool = always_emit_interrupt_payload

    # ------------------------------------------------------------------
    # Tool-filter step (opencode `Permission.disabled` equivalent)
    # ------------------------------------------------------------------

    def _globally_denied(self, tool_name: str) -> bool:
        """Return True if a deny rule with no narrowing pattern matches."""
        # ``"*"`` probes the catch-all pattern only, i.e. "is this tool
        # denied regardless of its arguments".
        rules = evaluate_many(tool_name, ["*"], *self._all_rulesets())
        return aggregate_action(rules) == "deny"

    def _all_rulesets(self) -> list[Ruleset]:
        """Return static layers plus the runtime layer, runtime last
        (last-match-wins, so runtime approvals override statics)."""
        return [*self._static_rulesets, self._runtime_ruleset]

    # NOTE: ``before_model`` filtering of the tools list is left to the
    # agent factory. This middleware only blocks at execution time — and
    # only via the rule-evaluator path, not by mutating ``request.tools``.
    # Mutating ``request.tools`` per-call would invalidate provider
    # prompt-cache prefixes (see Operational risks: prompt-cache regression).

    # ------------------------------------------------------------------
    # Tool-call evaluation
    # ------------------------------------------------------------------

    def _resolve_patterns(self, tool_name: str, args: dict[str, Any]) -> list[str]:
        """Return the pattern list to evaluate for one tool call.

        Falls back to ``[tool_name]`` when no resolver is registered,
        when the resolver raises, or when it returns an empty list —
        a call is never evaluated against zero patterns.
        """
        resolver = self._pattern_resolvers.get(
            tool_name, _default_pattern_resolver(tool_name)
        )
        try:
            patterns = resolver(args or {})
        except Exception:
            # A broken resolver must not take down the whole turn; the
            # bare name still gets a correct allow/deny/ask verdict.
            logger.exception("Pattern resolver for %s raised; using bare name", tool_name)
            patterns = [tool_name]
        if not patterns:
            patterns = [tool_name]
        return patterns

    def _evaluate(
        self, tool_name: str, args: dict[str, Any]
    ) -> tuple[str, list[str], list[Rule]]:
        """Evaluate one call and return ``(action, patterns, rules)``.

        ``action`` is the aggregate verdict (``deny`` > ``ask`` >
        ``allow`` per :func:`aggregate_action`); ``rules`` are the
        matching rules, kept for interrupt/deny diagnostics.
        """
        patterns = self._resolve_patterns(tool_name, args)
        rules = evaluate_many(tool_name, patterns, *self._all_rulesets())
        action = aggregate_action(rules)
        return action, patterns, rules

    # ------------------------------------------------------------------
    # HITL ask flow — SurfSense wire format
    # ------------------------------------------------------------------

    def _raise_interrupt(
        self,
        *,
        tool_name: str,
        args: dict[str, Any],
        patterns: list[str],
        rules: list[Rule],
    ) -> dict[str, Any]:
        """Block on user approval via SurfSense's ``interrupt`` shape.

        Returns a normalized ``{"decision_type": ..., "feedback"?: ...}``
        dict; anything unrecognized collapses to a reject decision.
        """
        if not self._emit_interrupt:
            # Non-interactive deployments: treat every ask as a reject.
            return {"decision_type": "reject"}

        # ``params`` (NOT ``args``) is what SurfSense's streaming
        # normalizer forwards. Other fields move into ``context``.
        payload = {
            "type": "permission_ask",
            "action": {"tool": tool_name, "params": args or {}},
            "context": {
                "patterns": patterns,
                "rules": [
                    {
                        "permission": r.permission,
                        "pattern": r.pattern,
                        "action": r.action,
                    }
                    for r in rules
                ],
                # Rules of thumb for the frontend: surface the patterns
                # the user can promote to "always" with a single reply.
                "always": patterns,
            },
        }
        # Tier 3b: permission.asked + interrupt.raised spans (no-op when
        # OTel is disabled). Both fire here so dashboards can correlate
        # "we asked X" with "interrupt was actually delivered".
        with ot.permission_asked_span(
            permission=tool_name,
            pattern=patterns[0] if patterns else None,
            extra={"permission.patterns": list(patterns)},
        ), ot.interrupt_span(interrupt_type="permission_ask"):
            # ``interrupt`` raises on first execution and returns the
            # user's reply when the graph is resumed.
            decision = interrupt(payload)
        if isinstance(decision, dict):
            return decision
        # Tolerate a plain string reply ("once", "always", "reject")
        if isinstance(decision, str):
            return {"decision_type": decision}
        return {"decision_type": "reject"}

    def _persist_always(
        self, tool_name: str, patterns: list[str]
    ) -> None:
        """Promote ``always`` reply into runtime allow rules.

        Persistence to ``agent_permission_rules`` is done by the
        streaming layer (``stream_new_chat``) once it observes the
        ``always`` reply — the middleware just keeps an in-memory
        copy so subsequent calls in the same stream see the rule.
        """
        for pattern in patterns:
            self._runtime_ruleset.rules.append(
                Rule(permission=tool_name, pattern=pattern, action="allow")
            )

    # ------------------------------------------------------------------
    # Synthesizing deny -> ToolMessage
    # ------------------------------------------------------------------

    @staticmethod
    def _deny_message(
        tool_call: dict[str, Any],
        rule: Rule,
    ) -> ToolMessage:
        """Build the error ``ToolMessage`` that replaces a denied call.

        The structured :class:`StreamingError` rides along in
        ``additional_kwargs["error"]`` for the streaming layer; the
        plain-text content is what the model sees on the next turn.
        """
        err = StreamingError(
            code="permission_denied",
            retryable=False,
            suggestion=(
                f"rule permission={rule.permission!r} pattern={rule.pattern!r} "
                f"blocked this call"
            ),
        )
        return ToolMessage(
            content=(
                f"Permission denied: rule {rule.permission}/{rule.pattern} "
                f"blocked tool {tool_call.get('name')!r}."
            ),
            # Empty-string fallback keeps the message well-formed even
            # for a malformed call that carries no id.
            tool_call_id=tool_call.get("id") or "",
            name=tool_call.get("name"),
            status="error",
            additional_kwargs={"error": err.model_dump()},
        )

    # ------------------------------------------------------------------
    # The hook: aafter_model
    # ------------------------------------------------------------------

    def _process(
        self,
        state: AgentState,
        runtime: Runtime[Any],
    ) -> dict[str, Any] | None:
        """Gate the latest ``AIMessage.tool_calls``; shared sync/async core.

        Returns ``None`` when nothing changed, otherwise a state update
        containing the rewritten ``AIMessage`` plus one synthetic error
        ``ToolMessage`` per denied call. Raises :class:`RejectedError` /
        :class:`CorrectedError` when the user rejects an ask.
        """
        del runtime # unused
        messages = state.get("messages") or []
        if not messages:
            return None
        last = messages[-1]
        if not isinstance(last, AIMessage) or not last.tool_calls:
            return None

        deny_messages: list[ToolMessage] = []
        kept_calls: list[dict[str, Any]] = []
        any_change = False

        for raw in last.tool_calls:
            # Tool calls may arrive as dicts or as attribute objects;
            # normalize to the dict shape before evaluating.
            call = dict(raw) if isinstance(raw, dict) else {
                "name": getattr(raw, "name", None),
                "args": getattr(raw, "args", {}),
                "id": getattr(raw, "id", None),
                "type": "tool_call",
            }
            name = call.get("name") or ""
            args = call.get("args") or {}
            action, patterns, rules = self._evaluate(name, args)

            if action == "deny":
                # Find the deny rule for the suggestion text
                deny_rule = next((r for r in rules if r.action == "deny"), rules[0])
                deny_messages.append(self._deny_message(call, deny_rule))
                any_change = True
                continue

            if action == "ask":
                # May suspend the graph via ``interrupt`` until the user
                # replies; on resume we get the decision dict back.
                decision = self._raise_interrupt(
                    tool_name=name, args=args, patterns=patterns, rules=rules
                )
                kind = str(decision.get("decision_type") or "reject").lower()
                if kind == "once":
                    kept_calls.append(call)
                elif kind == "always":
                    # Remember the approval for later calls in this stream,
                    # then proceed with this one.
                    self._persist_always(name, patterns)
                    kept_calls.append(call)
                elif kind == "reject":
                    feedback = decision.get("feedback")
                    if isinstance(feedback, str) and feedback.strip():
                        # Feedback present: the user corrected the agent.
                        raise CorrectedError(feedback, tool=name)
                    raise RejectedError(tool=name, pattern=patterns[0] if patterns else None)
                else:
                    logger.warning(
                        "Unknown permission decision %r; treating as reject", kind
                    )
                    raise RejectedError(tool=name)
                continue

            # allow
            kept_calls.append(call)

        # Fast path: every call survived and nothing was denied, so the
        # original message can stand as-is.
        if not any_change and len(kept_calls) == len(last.tool_calls):
            return None

        updated = last.model_copy(update={"tool_calls": kept_calls})
        result_messages: list[Any] = [updated]
        if deny_messages:
            result_messages.extend(deny_messages)
        return {"messages": result_messages}

    def after_model( # type: ignore[override]
        self, state: AgentState, runtime: Runtime[ContextT]
    ) -> dict[str, Any] | None:
        """Sync hook — delegates to :meth:`_process`."""
        return self._process(state, runtime)

    async def aafter_model( # type: ignore[override]
        self, state: AgentState, runtime: Runtime[ContextT]
    ) -> dict[str, Any] | None:
        """Async hook — same gating logic as :meth:`after_model`."""
        return self._process(state, runtime)
|
||||
|
||||
|
||||
# Public API of this module.
__all__ = ["PatternResolver", "PermissionMiddleware"]
|
||||
Loading…
Add table
Add a link
Reference in a new issue