mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-07 14:52:39 +02:00
Harden HITL for multi-step tasks: bypass internal MCP gate, full-args dedup, and decision-envelope normalization.
This commit is contained in:
parent
4ac3f0b304
commit
277bd50f37
6 changed files with 442 additions and 65 deletions
|
|
@ -21,6 +21,7 @@ A tool with no resolver from either path simply opts out of dedup.
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
|
@ -57,6 +58,19 @@ def wrap_dedup_key_by_arg_name(arg_name: str) -> DedupResolver:
|
|||
return _resolver
|
||||
|
||||
|
||||
def dedup_key_full_args(args: dict[str, Any]) -> str:
    """Resolver that collapses calls only when **every** argument is identical.

    Safe default for tools where no single field uniquely identifies a call
    (e.g. MCP tools whose first required field is a shared workspace id).

    The key is the canonical JSON form of ``args``: keys are sorted and
    values the encoder does not understand are passed through ``str``. When
    JSON encoding fails (non-JSON key types, keys that cannot be ordered
    against each other, circular references) the key falls back to the
    ``repr`` of the sorted items, or of ``args`` itself if it is not a dict.

    Args:
        args: Tool-call arguments to fingerprint.

    Returns:
        A string that is equal for two calls exactly when their arguments
        serialize identically, regardless of key insertion order.
    """
    try:
        return json.dumps(args, sort_keys=True, default=str)
    except (TypeError, ValueError):
        # Not JSON-encodable as-is; build a repr-based key below.
        pass

    if not isinstance(args, dict):
        return repr(args)

    try:
        items = sorted(args.items())
    except TypeError:
        # Keys of mixed types (e.g. ``int`` and ``str``) cannot be compared,
        # which is also what makes ``sort_keys=True`` fail above. Order them
        # by type name and repr so the key is still insertion-order
        # independent instead of raising from the fallback path.
        items = sorted(
            args.items(),
            key=lambda pair: (type(pair[0]).__name__, repr(pair[0])),
        )
    return repr(items)
|
||||
|
||||
|
||||
# Backwards-compatible alias for code that imported the original
|
||||
# private name. New callers should use :func:`wrap_dedup_key_by_arg_name`.
|
||||
_wrap_string_key = wrap_dedup_key_by_arg_name
|
||||
|
|
|
|||
|
|
@ -19,8 +19,9 @@ Operation:
|
|||
the results: ``deny`` > ``ask`` > ``allow``.
|
||||
3. On ``deny``: replaces the call with a synthetic ``ToolMessage``
|
||||
containing a :class:`StreamingError`.
|
||||
4. On ``ask``: raises a SurfSense-style ``interrupt(...)``. The reply
|
||||
shape is ``{"decision_type": "once|always|reject", "feedback"?: str}``.
|
||||
4. On ``ask``: raises a SurfSense-style ``interrupt(...)``. Both the legacy
|
||||
SurfSense shape and LangChain HITL ``{"decisions": [{"type": ...}]}``
|
||||
replies are accepted via :func:`_normalize_permission_decision`.
|
||||
- ``once``: proceed.
|
||||
- ``always``: also persist allow rules for ``request.always`` patterns.
|
||||
- ``reject`` w/o feedback: raise :class:`RejectedError`.
|
||||
|
|
@ -81,6 +82,75 @@ def _default_pattern_resolver(name: str) -> PatternResolver:
|
|||
return _resolve
|
||||
|
||||
|
||||
# Maps the ``type`` field of a LangChain HITL decision (the envelope that
# ``stream_resume_chat`` sends) onto SurfSense's legacy ``decision_type``
# values. ``edit`` is treated like an approval and the original tool args are
# kept -- tools needing argument edits should use ``request_approval`` from
# ``app/agents/new_chat/tools/hitl.py``.
_LC_TYPE_TO_PERMISSION_DECISION: dict[str, str] = {
    "approve": "once",
    "reject": "reject",
    "edit": "once",
}


def _normalize_permission_decision(decision: Any) -> dict[str, Any]:
    """Translate a resume value into ``{"decision_type": ..., "feedback"?}``.

    Accepted shapes:

    - a plain string, wrapped as-is into ``{"decision_type": <string>}``;
    - a dict with a truthy ``decision_type``, returned unchanged;
    - a LangChain HITL envelope ``{"decisions": [{"type": ...}]}``, of which
      only the first entry is read;
    - a flat dict carrying ``type`` directly.

    Anything else, including a missing or unknown decision type, is logged
    as a warning and turned into ``reject`` so the middleware fails closed.
    Optional feedback is read from ``feedback`` or ``message`` and kept only
    when it is a non-blank string.
    """
    if isinstance(decision, str):
        return {"decision_type": decision}

    if not isinstance(decision, dict):
        logger.warning(
            "Unrecognized permission resume value (%s); treating as reject",
            type(decision).__name__,
        )
        return {"decision_type": "reject"}

    # Legacy SurfSense shape: already normalized, hand it back untouched.
    if decision.get("decision_type"):
        return decision

    # Prefer the first entry of a LangChain ``decisions`` list when it is a
    # dict; otherwise read the fields off the outer dict itself.
    source: dict[str, Any] = decision
    envelope = decision.get("decisions")
    if isinstance(envelope, list) and envelope and isinstance(envelope[0], dict):
        source = envelope[0]

    declared = source.get("type") or source.get("decision_type")
    if not declared:
        logger.warning(
            "Permission resume missing decision type (keys=%s); treating as reject",
            list(source.keys()),
        )
        return {"decision_type": "reject"}

    kind = str(declared).lower()
    if kind in _LC_TYPE_TO_PERMISSION_DECISION:
        resolved = _LC_TYPE_TO_PERMISSION_DECISION[kind]
        if kind == "edit":
            logger.warning(
                "Permission middleware received an 'edit' decision; "
                "original args kept (edits not merged here)."
            )
    elif kind in ("once", "always", "reject"):
        # Tolerate legacy values arriving without ``decision_type`` wrapping.
        resolved = kind
    else:
        logger.warning(
            "Unknown permission decision type %r; treating as reject", kind
        )
        resolved = "reject"

    note = source.get("feedback") or source.get("message")
    if isinstance(note, str) and note.strip():
        return {"decision_type": resolved, "feedback": note}
    return {"decision_type": resolved}
|
||||
|
||||
|
||||
class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||
"""Allow/deny/ask layer over the agent's tool calls.
|
||||
|
||||
|
|
@ -214,12 +284,7 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
ot.interrupt_span(interrupt_type="permission_ask"),
|
||||
):
|
||||
decision = interrupt(payload)
|
||||
if isinstance(decision, dict):
|
||||
return decision
|
||||
# Tolerate a plain string reply ("once", "always", "reject")
|
||||
if isinstance(decision, str):
|
||||
return {"decision_type": decision}
|
||||
return {"decision_type": "reject"}
|
||||
return _normalize_permission_decision(decision)
|
||||
|
||||
def _persist_always(self, tool_name: str, patterns: list[str]) -> None:
|
||||
"""Promote ``always`` reply into runtime allow rules.
|
||||
|
|
@ -355,4 +420,5 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
__all__ = [
|
||||
"PatternResolver",
|
||||
"PermissionMiddleware",
|
||||
"_normalize_permission_decision",
|
||||
]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue