diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/compile_graph_sync.py b/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/compile_graph_sync.py index 86c2ac9e8..b86da932a 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/compile_graph_sync.py +++ b/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/compile_graph_sync.py @@ -14,9 +14,6 @@ from langgraph.types import Checkpointer from app.agents.multi_agent_chat.middleware.stack import ( build_main_agent_deepagent_middleware, ) -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolsPermissions, -) from app.agents.new_chat.context import SurfSenseContextSchema from app.agents.new_chat.feature_flags import AgentFeatureFlags from app.agents.new_chat.filesystem_selection import FilesystemMode @@ -42,7 +39,7 @@ def build_compiled_agent_graph_sync( flags: AgentFeatureFlags, checkpointer: Checkpointer, subagent_dependencies: dict[str, Any], - mcp_tools_by_agent: dict[str, ToolsPermissions] | None = None, + mcp_tools_by_agent: dict[str, list[BaseTool]] | None = None, disabled_tools: list[str] | None = None, ): """Sync compile: middleware + ``create_agent`` (run via ``asyncio.to_thread``).""" diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/agent_cache.py b/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/agent_cache.py index 42f984b79..1b542ebcd 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/agent_cache.py +++ b/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/agent_cache.py @@ -10,7 +10,6 @@ from langchain_core.language_models import BaseChatModel from langchain_core.tools import BaseTool from langgraph.types import Checkpointer -from app.agents.multi_agent_chat.subagents.shared.permissions import ToolsPermissions from app.agents.new_chat.agent_cache import ( flags_signature, get_cache, @@ -25,14 +24,12 @@ from app.db import 
ChatVisibility from ..graph.compile_graph_sync import build_compiled_agent_graph_sync -def mcp_signature(mcp_tools_by_agent: dict[str, ToolsPermissions]) -> str: +def mcp_signature(mcp_tools_by_agent: dict[str, list[BaseTool]]) -> str: """Hash the per-agent MCP tool surface so a change rotates the cache key.""" rows = [] for agent_name in sorted(mcp_tools_by_agent.keys()): - perms = mcp_tools_by_agent[agent_name] - allow_names = sorted(item.get("name", "") for item in perms.get("allow", [])) - ask_names = sorted(item.get("name", "") for item in perms.get("ask", [])) - rows.append((agent_name, allow_names, ask_names)) + names = sorted(getattr(t, "name", "") or "" for t in mcp_tools_by_agent[agent_name]) + rows.append((agent_name, names)) return stable_hash(rows) @@ -55,7 +52,7 @@ async def build_agent_with_cache( flags: AgentFeatureFlags, checkpointer: Checkpointer, subagent_dependencies: dict[str, Any], - mcp_tools_by_agent: dict[str, ToolsPermissions], + mcp_tools_by_agent: dict[str, list[BaseTool]], disabled_tools: list[str] | None, config_id: str | None, ) -> Any: diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/factory.py b/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/factory.py index 8988f0296..8451b3b7d 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/factory.py +++ b/surfsense_backend/app/agents/multi_agent_chat/main_agent/runtime/factory.py @@ -29,6 +29,10 @@ from app.agents.new_chat.tools.invalid_tool import INVALID_TOOL_NAME, invalid_to from app.agents.new_chat.tools.registry import build_tools_async from app.db import ChatVisibility from app.services.connector_service import ConnectorService +from app.services.user_tool_allowlist import ( + fetch_user_allowlist_rulesets, + make_trusted_tool_saver, +) from app.utils.perf import get_perf_logger from ..system_prompt import build_main_agent_system_prompt @@ -141,11 +145,49 @@ async def create_multi_agent_chat_deep_agent( ) 
mcp_tools_by_agent = {} _perf_log.info( - "[create_agent] load_mcp_tools_by_connector in %.3fs (%d buckets)", + "[create_agent] load_mcp_tools_by_connector in %.3fs (%d agents)", time.perf_counter() - _t0, len(mcp_tools_by_agent), ) + # User-scoped allow-list ("Always Allow" persisted to + # ``SearchSourceConnector.config.trusted_tools``). Layered last in each + # subagent's PermissionMiddleware so user ``allow`` overrides coded + # ``ask`` via last-match-wins. Anonymous turns and read failures both + # degrade to "no user rules" rather than blocking the turn. + user_allowlist_by_subagent: dict[str, Any] = {} + trusted_tool_saver = None + if user_id: + try: + import uuid as _uuid + + user_uuid = _uuid.UUID(user_id) + except (TypeError, ValueError): + user_uuid = None + + if user_uuid is not None: + _t0 = time.perf_counter() + try: + user_allowlist_by_subagent = await fetch_user_allowlist_rulesets( + db_session, + user_id=user_uuid, + search_space_id=search_space_id, + ) + except Exception as e: + logging.warning( + "User allow-list fetch failed; subagents will run without user trust rules this turn: %s", + e, + ) + user_allowlist_by_subagent = {} + _perf_log.info( + "[create_agent] fetch_user_allowlist_rulesets in %.3fs (%d subagents have rules)", + time.perf_counter() - _t0, + len(user_allowlist_by_subagent), + ) + trusted_tool_saver = make_trusted_tool_saver(user_uuid) + dependencies["user_allowlist_by_subagent"] = user_allowlist_by_subagent + dependencies["trusted_tool_saver"] = trusted_tool_saver + modified_disabled_tools = list(disabled_tools) if disabled_tools else [] if "search_knowledge_base" not in modified_disabled_tools: diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/builder/compose.py b/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/builder/compose.py index c21e69fcb..fc341dce3 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/builder/compose.py +++ 
b/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/builder/compose.py @@ -49,7 +49,7 @@ def build_main_agent_system_prompt( custom_system_instructions: str | None = None, use_default_system_instructions: bool = True, citations_enabled: bool = True, - model_name: str | None = None, # noqa: ARG001 — kept for caller compatibility + model_name: str | None = None, ) -> str: resolved_today = (today or datetime.now(UTC)).astimezone(UTC).date().isoformat() visibility = thread_visibility or ChatVisibility.PRIVATE @@ -62,7 +62,9 @@ def build_main_agent_system_prompt( if custom_system_instructions and custom_system_instructions.strip(): parts.append( - "\n" + custom_system_instructions.format(resolved_today=resolved_today) + "\n" + "\n" + + custom_system_instructions.format(resolved_today=resolved_today) + + "\n" ) if use_default_system_instructions: diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/routing.md b/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/routing.md index 1308c112c..4e27381d3 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/routing.md +++ b/surfsense_backend/app/agents/multi_agent_chat/main_agent/system_prompt/prompts/routing.md @@ -23,15 +23,21 @@ Use `task` for anything beyond the direct tools above. See `` for the live roster. Rules for `task`: -- **One specialist per `task` call.** A single `task` invocation must - describe work that one specialist can do end-to-end. Never bundle work - for two specialists into one task prompt — the specialist you route to - will silently drop the other half. -- **One `task` call per turn.** If the user's request spans multiple - specialists, handle them one at a time across consecutive turns: invoke - the first this turn, return, then invoke the next on your next turn (no - user input required between). Use `write_todos` to keep the plan alive - across those turns. 
+- **One specialist per `task` call.** A single `task` invocation targets + exactly one specialist; that specialist only has tools for its own + domain, so any work outside that domain in the same prompt won't run. +- **Parallelise independent specialist work.** When a turn needs multiple + `task` calls whose work doesn't depend on each other's results (e.g. + "create a ClickUp ticket AND a Linear ticket"), emit them as parallel + `task` calls. Two `task` calls are independent when: + - Neither's prompt references the other's output, and + - They target different specialists, OR the same specialist with + non-overlapping scopes (e.g. reading two unrelated paths). +- **Serialise dependent work across turns.** If one specialist's output + must inform another's input (e.g. "find the roadmap in my KB, then + email it to Maya"), invoke them on consecutive turns — first finishes, + then you call the second with the first's result baked into its prompt. + Use `write_todos` to keep the plan alive across those turns. - Within a single specialist, bundle every related step into the same task prompt (read + write + summary go together). - Put the **full instructions inside the task prompt** — the specialist @@ -66,19 +72,25 @@ user: "Find my Q2 roadmap and summarise the milestones." user: "Create a ClickUp ticket and a Linear ticket for the new feature flag." -→ This turn: +→ Independent work — call both specialists in parallel: write_todos([ {content: "Create ClickUp ticket for feature flag rollout", status: "in_progress"}, - {content: "Create Linear ticket for feature flag rollout", status: "pending"}, + {content: "Create Linear ticket for feature flag rollout", status: "in_progress"}, ]) task(clickup, "Create a ClickUp ticket titled 'Feature flag rollout' in the default list. Description: <…>. 
Tell me the ticket URL.") -→ Next turn: - write_todos([ - {content: "Create ClickUp ticket for feature flag rollout", status: "completed"}, - {content: "Create Linear ticket for feature flag rollout", status: "in_progress"}, - ]) task(linear, "Create a Linear ticket titled 'Feature flag rollout' in the default team. Description: <…>. Tell me the ticket URL.") + + +user: "Find my Q2 roadmap doc in the KB and email a summary to Maya." +→ The email body depends on the doc's contents — serialise across turns. + This turn: + task(knowledge_base, "Find the Q2 roadmap document under /documents + and return its full text plus a 3-bullet summary.") + Next turn (with the returned summary in hand): + task(gmail, "Send an email to Maya with subject 'Q2 roadmap summary' + and the following body: <summary returned by knowledge_base>.") + diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/__init__.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/__init__.py index d03b571ca..f370a71c7 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/__init__.py +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/__init__.py @@ -4,21 +4,27 @@ Replaces upstream ``SubAgentMiddleware`` to: - share the parent's checkpointer with each subagent, - forward ``runtime.config`` (thread_id, recursion_limit, …) into nested invokes, +- isolate each parallel ``task`` call in its own checkpoint slot via + per-call ``thread_id`` namespacing, - bridge ``Command(resume=...)`` from the parent into the subagent via the - ``config["configurable"]["surfsense_resume_value"]`` side-channel, + ``config["configurable"]["surfsense_resume_value"]`` side-channel, keyed by + ``tool_call_id`` so parallel siblings never race on a shared scalar, - target the resume at the captured interrupt id so a follow-up
``HumanInTheLoopMiddleware.after_model`` does not consume the same payload, -- re-raise any new subagent interrupt at the parent so the SSE stream surfaces it. +- stamp each subagent's pending interrupt with the parent's ``tool_call_id`` + so ``stream_resume_chat`` can route a flat ``decisions`` list back to the + right paused subagent. Module layout ------------- -- ``constants`` — shared keys / limits. -- ``config`` — RunnableConfig + side-channel resume read. -- ``resume`` — pending-interrupt detection, fan-out, ``Command(resume=...)`` builder. -- ``propagation`` — re-raise pending subagent interrupts at the parent. -- ``task_tool`` — the ``task`` tool factory (sync + async). -- ``middleware`` — :class:`SurfSenseCheckpointedSubAgentMiddleware` itself. +- ``constants`` — shared keys / limits. +- ``config`` — RunnableConfig + side-channel resume read + per-call ``thread_id``. +- ``resume`` — pending-interrupt detection, fan-out, ``Command(resume=...)`` builder. +- ``propagation`` — ``wrap_with_tool_call_id`` helper for stamping interrupt values. +- ``resume_routing``— slice a flat decisions list to per-``tool_call_id`` payloads. +- ``task_tool`` — the ``task`` tool factory (sync + async), and the catch-and-stamp chokepoint. +- ``middleware`` — :class:`SurfSenseCheckpointedSubAgentMiddleware` itself. 
""" from .middleware import SurfSenseCheckpointedSubAgentMiddleware diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/config.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/config.py index ac232b92a..ad5b58607 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/config.py +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/config.py @@ -21,7 +21,17 @@ _LANGGRAPH_SCRATCHPAD_KEY = "__pregel_scratchpad" def subagent_invoke_config(runtime: ToolRuntime) -> dict[str, Any]: - """RunnableConfig for the nested invoke; raises ``recursion_limit`` to the parent's budget.""" + """RunnableConfig for the nested invoke; raises ``recursion_limit`` and isolates ``thread_id``. + + Each parallel subagent invocation lands in its own checkpoint slot keyed + by an extended ``thread_id`` of the form ``{parent_thread}::task:{tool_call_id}``. + The same call across the resume cycle keeps reading from the same snapshot + (``tool_call_id`` is stable per LLM-emitted call). + + We namespace via ``thread_id`` rather than ``checkpoint_ns`` because + langgraph's ``aget_state`` interprets a non-empty ``checkpoint_ns`` as a + subgraph path and raises ``ValueError("Subgraph X not found")``. 
+ """ merged: dict[str, Any] = dict(runtime.config) if runtime.config else {} current_limit = merged.get("recursion_limit") try: @@ -30,43 +40,68 @@ def subagent_invoke_config(runtime: ToolRuntime) -> dict[str, Any]: current_int = 0 if current_int < DEFAULT_SUBAGENT_RECURSION_LIMIT: merged["recursion_limit"] = DEFAULT_SUBAGENT_RECURSION_LIMIT + + configurable: dict[str, Any] = dict(merged.get("configurable") or {}) + parent_thread_id = configurable.get("thread_id") + per_call_suffix = f"task:{runtime.tool_call_id}" + configurable["thread_id"] = ( + f"{parent_thread_id}::{per_call_suffix}" + if parent_thread_id + else per_call_suffix + ) + merged["configurable"] = configurable return merged def consume_surfsense_resume(runtime: ToolRuntime) -> Any: - """Pop the resume payload; siblings share ``configurable`` by reference.""" + """Pop the resume payload for *this* call's ``tool_call_id``. + + The configurable holds ``surfsense_resume_value: dict[tool_call_id, payload]`` + so parallel sibling subagents (each with their own ``tool_call_id``) read + only their own decision and never race on a shared scalar. 
+ """ cfg = runtime.config or {} configurable = cfg.get("configurable") if isinstance(cfg, dict) else None if not isinstance(configurable, dict): return None - return configurable.pop("surfsense_resume_value", None) + by_tcid = configurable.get("surfsense_resume_value") + if not isinstance(by_tcid, dict): + return None + payload = by_tcid.pop(runtime.tool_call_id, None) + if not by_tcid: + configurable.pop("surfsense_resume_value", None) + return payload def has_surfsense_resume(runtime: ToolRuntime) -> bool: - """True iff a resume payload is queued on this runtime (non-destructive).""" + """True iff a resume payload for this call's ``tool_call_id`` is queued (non-destructive).""" cfg = runtime.config or {} configurable = cfg.get("configurable") if isinstance(cfg, dict) else None if not isinstance(configurable, dict): return False - return "surfsense_resume_value" in configurable + by_tcid = configurable.get("surfsense_resume_value") + if not isinstance(by_tcid, dict): + return False + return runtime.tool_call_id in by_tcid def drain_parent_null_resume(runtime: ToolRuntime) -> None: """Consume the parent's lingering ``NULL_TASK_ID/RESUME`` write before delegating. ``stream_resume_chat`` wakes the main agent with - ``Command(resume={"decisions": [...]})`` so the propagated - ``_lg_interrupt(...)`` can return. langgraph stores that payload as the - parent task's ``null_resume`` pending write, which only gets consumed - *after* ``subagent.[a]invoke`` returns (when the post-call propagation - re-fires). While the subagent is mid-execution, any *new* ``interrupt()`` - inside it (e.g. a follow-up tool call after a mixed approve/reject) walks - ``subagent_scratchpad → parent_scratchpad.get_null_resume`` and picks up - the parent's still-live decisions — mismatching against a different number - of hanging tool calls and crashing ``HumanInTheLoopMiddleware``. 
+ ``Command(resume={tool_call_id: {"decisions": [...]}})`` so the previously + propagated parent-level interrupt can return. langgraph stores that + payload as the parent task's ``null_resume`` pending write. The ``task`` + tool then forwards this turn's slice into the subagent via its own + ``Command(resume=...)``. While the subagent is mid-execution, any *new* + ``interrupt()`` inside it (e.g. a follow-up tool call after a mixed + approve/reject) walks ``subagent_scratchpad → parent_scratchpad.get_null_resume`` + and picks up the parent's still-live decisions — mismatching against a + different number of hanging tool calls and crashing + ``HumanInTheLoopMiddleware``. Draining the write here closes that cross-graph leak so subagent - interrupts pause cleanly and re-propagate as a fresh approval card. + interrupts pause cleanly and bubble back up as a fresh approval card. """ cfg = runtime.config or {} configurable = cfg.get("configurable") if isinstance(cfg, dict) else None diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/middleware.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/middleware.py index da8a62cdc..8f51ffed7 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/middleware.py +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/middleware.py @@ -12,7 +12,6 @@ from deepagents.middleware.subagents import ( SubAgentMiddleware, ) from langchain.agents import create_agent -from langchain.agents.middleware import HumanInTheLoopMiddleware from langchain.chat_models import init_chat_model from langgraph.types import Checkpointer @@ -81,10 +80,6 @@ class SurfSenseCheckpointedSubAgentMiddleware(SubAgentMiddleware): middleware: list[Any] = list(spec.get("middleware", [])) - interrupt_on = spec.get("interrupt_on") - if interrupt_on: 
- middleware.append(HumanInTheLoopMiddleware(interrupt_on=interrupt_on)) - specs.append( { "name": spec["name"], diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/propagation.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/propagation.py index 55aae7201..cfebe1fd9 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/propagation.py +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/propagation.py @@ -1,74 +1,38 @@ -"""Re-raise still-pending subagent interrupts at the parent graph level. +"""Stamp the parent's ``tool_call_id`` onto a subagent's pending interrupt value. -After ``subagent.[a]invoke(Command(resume=...))`` returns, the subagent may -still hold a pending interrupt (e.g. the LLM produced a follow-up tool call -that fired a fresh ``interrupt()``). The parent's pregel cannot see that -interrupt because it lives in a separate compiled graph; we re-raise it here -so the parent's SSE stream surfaces it as the next approval card. +When a subagent (compiled as a langgraph subgraph and invoked from a parent +tool node) hits an ``interrupt(...)`` from its HITL middleware, langgraph +raises ``GraphInterrupt`` out of ``subagent.[a]invoke(...)``. The parent's +``task`` tool catches that exception, stamps ``tool_call_id`` onto each +``Interrupt.value`` using :func:`wrap_with_tool_call_id`, and re-raises a +fresh ``GraphInterrupt`` whose values carry that stamp. + +``stream_resume_chat`` then reads ``parent.state.interrupts[*].value["tool_call_id"]`` +to route a flat ``decisions`` list back to the right paused subagent — without +the stamp, parallel HITL across siblings would collapse into an ambiguous +bucket and resume would fail. 
+ +This module hosts only the stamping helper; the catch/re-raise lives in +``task_tool.py`` since that's the single chokepoint where the raw exception +is in our hands. """ from __future__ import annotations -import logging from typing import Any -from langchain_core.runnables import Runnable -from langgraph.types import interrupt as _lg_interrupt -from .resume import get_first_pending_subagent_interrupt +def wrap_with_tool_call_id(value: Any, tool_call_id: str) -> dict[str, Any]: + """Return a value dict that always carries the parent's ``tool_call_id``. -logger = logging.getLogger(__name__) + Dict values are shallow-copied with ``tool_call_id`` stamped on top, so + any value the subagent may already carry under that key (from a deeper + HITL level) is overwritten — the parent's call id is the only one + ``stream_resume_chat`` correlates against. - -def maybe_propagate_subagent_interrupt( - subagent: Runnable, - sub_config: dict[str, Any], - subagent_type: str, -) -> None: - """Re-raise a still-pending subagent interrupt at the parent so the SSE stream surfaces it.""" - get_state_sync = getattr(subagent, "get_state", None) - if not callable(get_state_sync): - return - try: - snapshot = get_state_sync(sub_config) - except Exception: # pragma: no cover - defensive - logger.debug( - "Subagent get_state failed during re-interrupt check", - exc_info=True, - ) - return - _pending_id, pending_value = get_first_pending_subagent_interrupt(snapshot) - if pending_value is None: - return - logger.info( - "Re-raising subagent %r interrupt to parent (multi-step HITL)", - subagent_type, - ) - _lg_interrupt(pending_value) - - -async def amaybe_propagate_subagent_interrupt( - subagent: Runnable, - sub_config: dict[str, Any], - subagent_type: str, -) -> None: - """Async counterpart of :func:`maybe_propagate_subagent_interrupt`.""" - aget_state = getattr(subagent, "aget_state", None) - if not callable(aget_state): - return - try: - snapshot = await aget_state(sub_config) - except 
Exception: # pragma: no cover - defensive - logger.debug( - "Subagent aget_state failed during re-interrupt check", - exc_info=True, - ) - return - _pending_id, pending_value = get_first_pending_subagent_interrupt(snapshot) - if pending_value is None: - return - logger.info( - "Re-raising subagent %r interrupt to parent (multi-step HITL)", - subagent_type, - ) - _lg_interrupt(pending_value) + Non-dict values are wrapped as ``{"value": <original>, "tool_call_id": ...}`` + so simple ``interrupt("approve?")`` patterns still propagate cleanly. + """ + if isinstance(value, dict): + return {**value, "tool_call_id": tool_call_id} + return {"value": value, "tool_call_id": tool_call_id} diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/resume_routing.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/resume_routing.py new file mode 100644 index 000000000..37f45e42f --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/resume_routing.py @@ -0,0 +1,183 @@ +"""Route a flat ``decisions`` list to per-``tool_call_id`` resume payloads. + +The frontend submits decisions in the same order the SSE stream emitted +approval cards. When multiple parallel subagents are paused, the backend uses +this module to: + +1. Read ``state.interrupts`` from the parent's paused snapshot, extracting + ``[(tool_call_id, action_count), ...]`` from each interrupt's value. + The ``tool_call_id`` is stamped on by ``propagation.wrap_with_tool_call_id`` + inside ``task_tool``'s catch-and-stamp block when a subagent's + ``GraphInterrupt`` bubbles up through ``[a]task``. +2. Slice the flat ``decisions`` list against that ordered pending list to + produce the dict shape expected by ``consume_surfsense_resume``. +3.
Re-key those slices by ``Interrupt.id`` (langgraph's primitive) for use as + the parent-level ``Command(resume={interrupt_id: payload})`` input — the + only shape langgraph accepts when multiple interrupts are pending. + +All helpers are pure: callers own the state and the input decisions; we +return new structures and never mutate. +""" + +from __future__ import annotations + +import logging +from collections.abc import Iterable +from typing import Any + +logger = logging.getLogger(__name__) + + +def slice_decisions_by_tool_call( + decisions: list[dict[str, Any]], + pending: Iterable[tuple[str, int]], +) -> dict[str, dict[str, Any]]: + """Slice ``decisions`` into ``{tool_call_id: {"decisions": <slice>}}``. + + Args: + decisions: Flat list of decisions in the order the SSE stream rendered + them. + pending: Ordered ``(tool_call_id, action_count)`` pairs in the same + order. The slicer consumes ``decisions`` left-to-right. + + Returns: + Per-``tool_call_id`` payload dict ready to be written to + ``configurable["surfsense_resume_value"]``. + + Raises: + ValueError: When the total expected action count differs from the + number of decisions provided. We fail loud rather than silently + dropping or padding so a frontend/backend contract drift surfaces + immediately. + """ + pending_list = list(pending) + expected = sum(count for _, count in pending_list) + if expected != len(decisions): + raise ValueError( + f"Decision count mismatch: pending tool calls expect " + f"{expected} actions but received {len(decisions)} decisions." + ) + + routed: dict[str, dict[str, Any]] = {} + cursor = 0 + for tool_call_id, action_count in pending_list: + routed[tool_call_id] = {"decisions": decisions[cursor : cursor + action_count]} + cursor += action_count + return routed + + +def collect_pending_tool_calls(state: Any) -> list[tuple[str, int]]: + """Extract ``[(tool_call_id, action_count), ...]`` from a paused parent state.
+ + Reads ``state.interrupts`` (the bundle langgraph aggregated from each + paused subagent's propagated interrupt). Each interrupt value carries the + ``tool_call_id`` that the parent's ``task`` tool was processing — see + ``propagation.wrap_with_tool_call_id`` and ``task_tool``'s + ``except GraphInterrupt`` chokepoint. + + Order is preserved from ``state.interrupts``, which is the order the SSE + stream emitted approval cards. The frontend submits decisions in that + same order, so the slicer can consume them left-to-right. + + Interrupts without a ``tool_call_id`` are skipped — they were not + produced by our task-routing layer (e.g. parent-side HITL middleware on + a different tool); ``stream_resume_chat`` is not responsible for routing + those. + + Args: + state: A langgraph ``StateSnapshot`` (or any object with an + ``interrupts`` attribute). + + Returns: + Ordered list of ``(tool_call_id, action_count)``. ``action_count`` is + ``len(value["action_requests"])`` for HITL-bundle values, or ``1`` for + scalar-style ``interrupt("...")`` values that were wrapped as + ``{"value": ..., "tool_call_id": ...}``. + + Raises: + ValueError: When an interrupt value carries a ``tool_call_id`` but + the action count cannot be determined (contract bug — every + propagated value should be either a HITL bundle or a wrapped + scalar). + """ + pending: list[tuple[str, int]] = [] + for idx, interrupt_obj in enumerate(getattr(state, "interrupts", ()) or ()): + value = getattr(interrupt_obj, "value", None) + if not isinstance(value, dict): + logger.warning( + "[hitl_route] interrupt[%d] skipped: value not a dict (type=%s)", + idx, + type(value).__name__, + ) + continue + tool_call_id = value.get("tool_call_id") + if not isinstance(tool_call_id, str): + # Should not happen post-stamping; flag loudly if a regression + # ever lets an unstamped value reach the parent state. 
+ logger.warning( + "[hitl_route] interrupt[%d] skipped: no tool_call_id stamp (keys=%s)", + idx, + sorted(value.keys()), + ) + continue + + action_requests = value.get("action_requests") + if isinstance(action_requests, list): + pending.append((tool_call_id, len(action_requests))) + continue + if "value" in value: + pending.append((tool_call_id, 1)) + continue + + raise ValueError( + f"Interrupt for tool_call_id={tool_call_id!r} has no " + "``action_requests`` list and is not a wrapped scalar value; " + "cannot determine action count for resume routing." + ) + + return pending + + +def build_lg_resume_map( + state: Any, by_tool_call_id: dict[str, dict[str, Any]] +) -> dict[str, dict[str, Any]]: + """Map ``Interrupt.id → resume_payload`` for langgraph's multi-interrupt resume. + + ``stream_resume_chat`` builds ``by_tool_call_id`` via + :func:`slice_decisions_by_tool_call`. Langgraph's ``Command(resume=...)`` + requires ``Interrupt.id`` keys (not our ``tool_call_id`` stamps) when the + parent state has multiple pending interrupts. This pure helper re-keys the + slice without mutating it, and skips entries that can't be paired (no + stamp, no slice) so contract drift surfaces as a count mismatch at the + call site instead of a silent mis-route. + + The two key spaces serve two different consumers: + - ``surfsense_resume_value`` (keyed by ``tool_call_id``): read by the + subagent bridge inside ``task_tool``. + - ``Command(resume=...)`` (keyed by ``Interrupt.id``): read by langgraph's + pregel to wake each pending interrupt site. + + Args: + state: A langgraph ``StateSnapshot`` (or any object with an + ``interrupts`` iterable). + by_tool_call_id: Output of :func:`slice_decisions_by_tool_call`. + + Returns: + Dict ready to be passed as ``Command(resume=<this dict>)``.
+ """ + out: dict[str, dict[str, Any]] = {} + for interrupt_obj in getattr(state, "interrupts", ()) or (): + value = getattr(interrupt_obj, "value", None) + if not isinstance(value, dict): + continue + tool_call_id = value.get("tool_call_id") + if not isinstance(tool_call_id, str): + continue + interrupt_id = getattr(interrupt_obj, "id", None) + if not isinstance(interrupt_id, str): + continue + payload = by_tool_call_id.get(tool_call_id) + if payload is None: + continue + out[interrupt_id] = payload + return out diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/task_tool.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/task_tool.py index 7c0dd8624..f9b316e23 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/task_tool.py +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/task_tool.py @@ -9,14 +9,15 @@ re-raises any new pending interrupt back to the parent. 
from __future__ import annotations import logging -from typing import Annotated, Any +from typing import Annotated, Any, NoReturn from deepagents.middleware.subagents import TASK_TOOL_DESCRIPTION from langchain.tools import BaseTool, ToolRuntime from langchain_core.messages import HumanMessage, ToolMessage from langchain_core.runnables import Runnable from langchain_core.tools import StructuredTool -from langgraph.types import Command +from langgraph.errors import GraphInterrupt +from langgraph.types import Command, Interrupt from .config import ( consume_surfsense_resume, @@ -25,10 +26,7 @@ from .config import ( subagent_invoke_config, ) from .constants import EXCLUDED_STATE_KEYS -from .propagation import ( - amaybe_propagate_subagent_interrupt, - maybe_propagate_subagent_interrupt, -) +from .propagation import wrap_with_tool_call_id from .resume import ( build_resume_command, fan_out_decisions_to_match, @@ -39,6 +37,31 @@ from .resume import ( logger = logging.getLogger(__name__) +def _reraise_stamped_subagent_interrupt( + gi: GraphInterrupt, tool_call_id: str +) -> NoReturn: + """Stamp ``tool_call_id`` onto each pending interrupt value and re-raise. + + See :mod:`...propagation` for why this stamp is required for resume routing. + Chained via ``from gi`` so tracebacks point at the subagent's original + ``interrupt(...)`` site. + """ + interrupts = gi.args[0] if gi.args else () + stamped = tuple( + Interrupt( + value=wrap_with_tool_call_id(i.value, tool_call_id), + id=i.id, + ) + for i in interrupts + ) + logger.info( + "[hitl_route] stamped %d subagent interrupt(s) with tool_call_id=%s", + len(stamped), + tool_call_id, + ) + raise GraphInterrupt(stamped) from gi + + def build_task_tool_with_parent_config( subagents: list[dict[str, Any]], task_description: str | None = None, @@ -161,13 +184,18 @@ def build_task_tool_with_parent_config( # Prevent the parent's resume payload from leaking into subagent # interrupts via langgraph's parent_scratchpad fallback. 
drain_parent_null_resume(runtime) - result = subagent.invoke( - build_resume_command(resume_value, pending_id), - config=sub_config, - ) + try: + result = subagent.invoke( + build_resume_command(resume_value, pending_id), + config=sub_config, + ) + except GraphInterrupt as gi: + _reraise_stamped_subagent_interrupt(gi, runtime.tool_call_id) else: - result = subagent.invoke(subagent_state, config=sub_config) - maybe_propagate_subagent_interrupt(subagent, sub_config, subagent_type) + try: + result = subagent.invoke(subagent_state, config=sub_config) + except GraphInterrupt as gi: + _reraise_stamped_subagent_interrupt(gi, runtime.tool_call_id) return _return_command_with_state_update(result, runtime.tool_call_id) async def atask( @@ -181,6 +209,11 @@ def build_task_tool_with_parent_config( ], runtime: ToolRuntime, ) -> str | Command: + logger.info( + "[hitl_route] atask ENTRY: subagent_type=%r tool_call_id=%s", + subagent_type, + runtime.tool_call_id, + ) if subagent_type not in subagent_graphs: allowed_types = ", ".join([f"`{k}`" for k in subagent_graphs]) return ( @@ -228,13 +261,18 @@ def build_task_tool_with_parent_config( # Prevent the parent's resume payload from leaking into subagent # interrupts via langgraph's parent_scratchpad fallback. 
drain_parent_null_resume(runtime) - result = await subagent.ainvoke( - build_resume_command(resume_value, pending_id), - config=sub_config, - ) + try: + result = await subagent.ainvoke( + build_resume_command(resume_value, pending_id), + config=sub_config, + ) + except GraphInterrupt as gi: + _reraise_stamped_subagent_interrupt(gi, runtime.tool_call_id) else: - result = await subagent.ainvoke(subagent_state, config=sub_config) - await amaybe_propagate_subagent_interrupt(subagent, sub_config, subagent_type) + try: + result = await subagent.ainvoke(subagent_state, config=sub_config) + except GraphInterrupt as gi: + _reraise_stamped_subagent_interrupt(gi, runtime.tool_call_id) return _return_command_with_state_update(result, runtime.tool_call_id) return StructuredTool.from_function( diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/middleware/namespace_policy.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/middleware/namespace_policy.py index 850f0953b..539050414 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/middleware/namespace_policy.py +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/middleware/namespace_policy.py @@ -22,7 +22,7 @@ if TYPE_CHECKING: def check_cloud_write_namespace( - mw: "SurfSenseFilesystemMiddleware", + mw: SurfSenseFilesystemMiddleware, path: str, runtime: ToolRuntime[None, SurfSenseFilesystemState], ) -> str | None: diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/middleware/path_resolution.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/middleware/path_resolution.py index a1f8e3f2c..2c8ec6b4d 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/middleware/path_resolution.py +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/middleware/path_resolution.py @@ -25,7 +25,7 
@@ if TYPE_CHECKING: def current_cwd( - mw: "SurfSenseFilesystemMiddleware", + mw: SurfSenseFilesystemMiddleware, runtime: ToolRuntime[None, SurfSenseFilesystemState], ) -> str: cwd = runtime.state.get("cwd") if hasattr(runtime, "state") else None @@ -35,7 +35,7 @@ def current_cwd( def get_contract_suggested_path( - mw: "SurfSenseFilesystemMiddleware", + mw: SurfSenseFilesystemMiddleware, runtime: ToolRuntime[None, SurfSenseFilesystemState], ) -> str: """Read the planner's suggested write path; otherwise default to ``notes.md``.""" @@ -47,7 +47,7 @@ def get_contract_suggested_path( def resolve_relative( - mw: "SurfSenseFilesystemMiddleware", + mw: SurfSenseFilesystemMiddleware, path: str, runtime: ToolRuntime[None, SurfSenseFilesystemState], ) -> str: @@ -63,7 +63,7 @@ def resolve_relative( def resolve_write_target_path( - mw: "SurfSenseFilesystemMiddleware", + mw: SurfSenseFilesystemMiddleware, file_path: str, runtime: ToolRuntime[None, SurfSenseFilesystemState], ) -> str: @@ -77,7 +77,7 @@ def resolve_write_target_path( def resolve_move_target_path( - mw: "SurfSenseFilesystemMiddleware", + mw: SurfSenseFilesystemMiddleware, file_path: str, runtime: ToolRuntime[None, SurfSenseFilesystemState], ) -> str: @@ -91,7 +91,7 @@ def resolve_move_target_path( def resolve_list_target_path( - mw: "SurfSenseFilesystemMiddleware", + mw: SurfSenseFilesystemMiddleware, path: str, runtime: ToolRuntime[None, SurfSenseFilesystemState], ) -> str: @@ -105,7 +105,7 @@ def resolve_list_target_path( def normalize_local_mount_path( - mw: "SurfSenseFilesystemMiddleware", + mw: SurfSenseFilesystemMiddleware, candidate: str, runtime: ToolRuntime[None, SurfSenseFilesystemState], ) -> str: diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/system_prompt/index.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/system_prompt/index.py index 74261c3f1..9d3cdbae3 100644 --- 
a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/system_prompt/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/system_prompt/index.py @@ -9,9 +9,7 @@ from .common import HEADER, SANDBOX_ADDENDUM from .desktop import BODY as DESKTOP_BODY -def build_system_prompt( - mode: FilesystemMode, *, sandbox_available: bool -) -> str: +def build_system_prompt(mode: FilesystemMode, *, sandbox_available: bool) -> str: """Assemble the FS prompt: common header + mode body + optional sandbox section.""" body = CLOUD_BODY if mode == FilesystemMode.CLOUD else DESKTOP_BODY base = HEADER + body diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/cd/index.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/cd/index.py index ac6b95805..8df6b9edb 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/cd/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/cd/index.py @@ -21,7 +21,7 @@ if TYPE_CHECKING: from ...middleware import SurfSenseFilesystemMiddleware -def create_cd_tool(mw: "SurfSenseFilesystemMiddleware") -> BaseTool: +def create_cd_tool(mw: SurfSenseFilesystemMiddleware) -> BaseTool: description = select_description(mw._filesystem_mode) async def async_cd( diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/edit_file/index.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/edit_file/index.py index 6506cf876..324ef09b0 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/edit_file/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/edit_file/index.py @@ -24,7 +24,7 @@ if TYPE_CHECKING: from ...middleware import SurfSenseFilesystemMiddleware -def create_edit_file_tool(mw: 
"SurfSenseFilesystemMiddleware") -> BaseTool: +def create_edit_file_tool(mw: SurfSenseFilesystemMiddleware) -> BaseTool: description = select_description(mw._filesystem_mode) async def async_edit_file( diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/execute_code/helpers.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/execute_code/helpers.py index 2b7ada887..cda9f535d 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/execute_code/helpers.py +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/execute_code/helpers.py @@ -36,7 +36,7 @@ def wrap_as_python(code: str) -> str: async def execute_in_sandbox( - mw: "SurfSenseFilesystemMiddleware", + mw: SurfSenseFilesystemMiddleware, command: str, runtime: ToolRuntime[None, SurfSenseFilesystemState], timeout: int | None, @@ -59,14 +59,12 @@ async def execute_in_sandbox( try: return await _try_sandbox_execute(mw, command, runtime, timeout) except Exception: - logger.exception( - "Sandbox retry also failed for thread %s", mw._thread_id - ) + logger.exception("Sandbox retry also failed for thread %s", mw._thread_id) return "Error: Code execution is temporarily unavailable. Please try again." 
async def _try_sandbox_execute( - mw: "SurfSenseFilesystemMiddleware", + mw: SurfSenseFilesystemMiddleware, command: str, runtime: ToolRuntime[None, SurfSenseFilesystemState], timeout: int | None, diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/execute_code/index.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/execute_code/index.py index f826c4fe9..2711636e4 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/execute_code/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/execute_code/index.py @@ -17,13 +17,11 @@ if TYPE_CHECKING: from ...middleware import SurfSenseFilesystemMiddleware -def create_execute_code_tool(mw: "SurfSenseFilesystemMiddleware") -> BaseTool: +def create_execute_code_tool(mw: SurfSenseFilesystemMiddleware) -> BaseTool: description = select_description(mw._filesystem_mode) def sync_execute_code( - command: Annotated[ - str, "Python code to execute. Use print() to see output." - ], + command: Annotated[str, "Python code to execute. Use print() to see output."], runtime: ToolRuntime[None, SurfSenseFilesystemState], timeout: Annotated[ int | None, @@ -35,14 +33,10 @@ def create_execute_code_tool(mw: "SurfSenseFilesystemMiddleware") -> BaseTool: return f"Error: timeout must be non-negative, got {timeout}." if timeout > MAX_EXECUTE_TIMEOUT: return f"Error: timeout {timeout}s exceeds maximum ({MAX_EXECUTE_TIMEOUT}s)." - return run_async_blocking( - execute_in_sandbox(mw, command, runtime, timeout) - ) + return run_async_blocking(execute_in_sandbox(mw, command, runtime, timeout)) async def async_execute_code( - command: Annotated[ - str, "Python code to execute. Use print() to see output." - ], + command: Annotated[str, "Python code to execute. 
Use print() to see output."], runtime: ToolRuntime[None, SurfSenseFilesystemState], timeout: Annotated[ int | None, diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/list_tree/index.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/list_tree/index.py index b17cdffe1..8bad88a74 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/list_tree/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/list_tree/index.py @@ -20,7 +20,7 @@ if TYPE_CHECKING: from ...middleware import SurfSenseFilesystemMiddleware -def create_list_tree_tool(mw: "SurfSenseFilesystemMiddleware") -> BaseTool: +def create_list_tree_tool(mw: SurfSenseFilesystemMiddleware) -> BaseTool: description = select_description(mw._filesystem_mode) async def async_list_tree( diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/ls/index.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/ls/index.py index bfae66416..70f31dd04 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/ls/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/ls/index.py @@ -19,7 +19,7 @@ if TYPE_CHECKING: from ...middleware import SurfSenseFilesystemMiddleware -def create_ls_tool(mw: "SurfSenseFilesystemMiddleware") -> BaseTool: +def create_ls_tool(mw: SurfSenseFilesystemMiddleware) -> BaseTool: description = select_description(mw._filesystem_mode) async def async_ls( diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/mkdir/index.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/mkdir/index.py index 768403e5b..788381faa 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/mkdir/index.py +++ 
b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/mkdir/index.py @@ -23,7 +23,7 @@ if TYPE_CHECKING: from ...middleware import SurfSenseFilesystemMiddleware -def create_mkdir_tool(mw: "SurfSenseFilesystemMiddleware") -> BaseTool: +def create_mkdir_tool(mw: SurfSenseFilesystemMiddleware) -> BaseTool: description = select_description(mw._filesystem_mode) async def async_mkdir( diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/move_file/helpers.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/move_file/helpers.py index 04c15d479..7613f62f1 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/move_file/helpers.py +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/move_file/helpers.py @@ -18,7 +18,7 @@ if TYPE_CHECKING: async def cloud_move_file( - mw: "SurfSenseFilesystemMiddleware", + mw: SurfSenseFilesystemMiddleware, runtime: ToolRuntime[None, SurfSenseFilesystemState], source: str, dest: str, @@ -39,8 +39,7 @@ async def cloud_move_file( ) if not source.startswith(DOCUMENTS_ROOT + "/"): return ( - "Error: cloud move_file source must be under /documents/ (got " - f"'{source}')." + f"Error: cloud move_file source must be under /documents/ (got '{source}')." ) if not dest.startswith(DOCUMENTS_ROOT + "/"): return ( @@ -89,9 +88,7 @@ async def cloud_move_file( ], "messages": [ ToolMessage( - content=( - f"Moved '{source}' to '{dest}' (will commit at end of turn)." 
- ), + content=(f"Moved '{source}' to '{dest}' (will commit at end of turn)."), tool_call_id=runtime.tool_call_id, ) ], diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/move_file/index.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/move_file/index.py index d04812775..d90535990 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/move_file/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/move_file/index.py @@ -23,7 +23,7 @@ if TYPE_CHECKING: from ...middleware import SurfSenseFilesystemMiddleware -def create_move_file_tool(mw: "SurfSenseFilesystemMiddleware") -> BaseTool: +def create_move_file_tool(mw: SurfSenseFilesystemMiddleware) -> BaseTool: description = select_description(mw._filesystem_mode) async def async_move_file( @@ -85,9 +85,7 @@ def create_move_file_tool(mw: "SurfSenseFilesystemMiddleware") -> BaseTool: ] = False, ) -> Command | str: return run_async_blocking( - async_move_file( - source_path, destination_path, runtime, overwrite=overwrite - ) + async_move_file(source_path, destination_path, runtime, overwrite=overwrite) ) return StructuredTool.from_function( diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/pwd/index.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/pwd/index.py index f4ca75067..c15b67114 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/pwd/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/pwd/index.py @@ -16,7 +16,7 @@ if TYPE_CHECKING: from ...middleware import SurfSenseFilesystemMiddleware -def create_pwd_tool(mw: "SurfSenseFilesystemMiddleware") -> BaseTool: +def create_pwd_tool(mw: SurfSenseFilesystemMiddleware) -> BaseTool: description = select_description(mw._filesystem_mode) def sync_pwd( 
diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/read_file/index.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/read_file/index.py index c6e62dd21..8b0a1a1c8 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/read_file/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/read_file/index.py @@ -21,7 +21,7 @@ if TYPE_CHECKING: from ...middleware import SurfSenseFilesystemMiddleware -def create_read_file_tool(mw: "SurfSenseFilesystemMiddleware") -> BaseTool: +def create_read_file_tool(mw: SurfSenseFilesystemMiddleware) -> BaseTool: description = select_description(mw._filesystem_mode) async def async_read_file( @@ -90,9 +90,7 @@ def create_read_file_tool(mw: "SurfSenseFilesystemMiddleware") -> BaseTool: "Maximum number of lines to read.", ] = 100, ) -> Command | str: - return run_async_blocking( - async_read_file(file_path, runtime, offset, limit) - ) + return run_async_blocking(async_read_file(file_path, runtime, offset, limit)) return StructuredTool.from_function( name="read_file", diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/rm/helpers.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/rm/helpers.py index cc125b181..8a02544d8 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/rm/helpers.py +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/rm/helpers.py @@ -22,7 +22,7 @@ if TYPE_CHECKING: async def cloud_rm( - mw: "SurfSenseFilesystemMiddleware", + mw: SurfSenseFilesystemMiddleware, runtime: ToolRuntime[None, SurfSenseFilesystemState], validated: str, ) -> Command | str: @@ -31,8 +31,7 @@ async def cloud_rm( return f"Error: refusing to rm '{validated}'." 
if not validated.startswith(DOCUMENTS_ROOT + "/"): return ( - "Error: cloud rm must target a path under /documents/ " - f"(got '{validated}')." + f"Error: cloud rm must target a path under /documents/ (got '{validated}')." ) anon = runtime.state.get("kb_anon_doc") or {} @@ -41,14 +40,10 @@ async def cloud_rm( staged_dirs = list(runtime.state.get("staged_dirs") or []) if validated in staged_dirs: - return ( - f"Error: '{validated}' is a directory. Use rmdir for " - "empty directories." - ) + return f"Error: '{validated}' is a directory. Use rmdir for empty directories." pending_dir_deletes = list(runtime.state.get("pending_dir_deletes") or []) if any( - isinstance(d, dict) and d.get("path") == validated - for d in pending_dir_deletes + isinstance(d, dict) and d.get("path") == validated for d in pending_dir_deletes ): return f"Error: '{validated}' is already queued for rmdir." @@ -57,14 +52,11 @@ async def cloud_rm( children = await backend.als_info(validated) if children: return ( - f"Error: '{validated}' is a directory. Use rmdir for " - "empty directories." + f"Error: '{validated}' is a directory. Use rmdir for empty directories." ) pending_deletes = list(runtime.state.get("pending_deletes") or []) - if any( - isinstance(d, dict) and d.get("path") == validated for d in pending_deletes - ): + if any(isinstance(d, dict) and d.get("path") == validated for d in pending_deletes): return f"'{validated}' is already queued for deletion." files_state = runtime.state.get("files") or {} @@ -93,8 +85,7 @@ async def cloud_rm( "messages": [ ToolMessage( content=( - f"Staged delete of '{validated}' (will commit at " - "end of turn)." + f"Staged delete of '{validated}' (will commit at end of turn)." 
), tool_call_id=runtime.tool_call_id, ) @@ -114,7 +105,7 @@ async def cloud_rm( async def desktop_rm( - mw: "SurfSenseFilesystemMiddleware", + mw: SurfSenseFilesystemMiddleware, runtime: ToolRuntime[None, SurfSenseFilesystemState], validated: str, ) -> Command | str: diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/rm/index.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/rm/index.py index 52d2e231e..0c4e2fc71 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/rm/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/rm/index.py @@ -21,7 +21,7 @@ if TYPE_CHECKING: from ...middleware import SurfSenseFilesystemMiddleware -def create_rm_tool(mw: "SurfSenseFilesystemMiddleware") -> BaseTool: +def create_rm_tool(mw: SurfSenseFilesystemMiddleware) -> BaseTool: description = select_description(mw._filesystem_mode) async def async_rm( diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/rmdir/helpers.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/rmdir/helpers.py index da986ac31..de5afe722 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/rmdir/helpers.py +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/rmdir/helpers.py @@ -26,7 +26,7 @@ if TYPE_CHECKING: async def cloud_rmdir( - mw: "SurfSenseFilesystemMiddleware", + mw: SurfSenseFilesystemMiddleware, runtime: ToolRuntime[None, SurfSenseFilesystemState], validated: str, ) -> Command | str: @@ -49,8 +49,7 @@ async def cloud_rmdir( staged_dirs = list(runtime.state.get("staged_dirs") or []) pending_dir_deletes = list(runtime.state.get("pending_dir_deletes") or []) if any( - isinstance(d, dict) and d.get("path") == validated - for d in pending_dir_deletes + isinstance(d, dict) and d.get("path") == 
validated for d in pending_dir_deletes ): return f"'{validated}' is already queued for deletion." @@ -61,11 +60,7 @@ async def cloud_rmdir( if isinstance(backend, KBPostgresBackend): children = list(await backend.als_info(validated)) - if ( - isinstance(backend, KBPostgresBackend) - and not children - and not exists_in_staged - ): + if isinstance(backend, KBPostgresBackend) and not children and not exists_in_staged: loaded = await backend._load_file_data(validated) if loaded is not None: return f"Error: '{validated}' is a file. Use rm to delete files." @@ -79,9 +74,7 @@ async def cloud_rmdir( return f"Error: directory '{validated}' not found." if children: - return ( - f"Error: directory '{validated}' is not empty. Remove contents first." - ) + return f"Error: directory '{validated}' is not empty. Remove contents first." if exists_in_staged: rest = [d for d in staged_dirs if d != validated] @@ -109,8 +102,7 @@ async def cloud_rmdir( "messages": [ ToolMessage( content=( - f"Staged rmdir of '{validated}' (will commit " - "at end of turn)." + f"Staged rmdir of '{validated}' (will commit at end of turn)." 
), tool_call_id=runtime.tool_call_id, ) @@ -120,7 +112,7 @@ async def cloud_rmdir( async def desktop_rmdir( - mw: "SurfSenseFilesystemMiddleware", + mw: SurfSenseFilesystemMiddleware, runtime: ToolRuntime[None, SurfSenseFilesystemState], validated: str, ) -> Command | str: diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/rmdir/index.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/rmdir/index.py index 457b3312c..cdf057353 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/rmdir/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/rmdir/index.py @@ -21,7 +21,7 @@ if TYPE_CHECKING: from ...middleware import SurfSenseFilesystemMiddleware -def create_rmdir_tool(mw: "SurfSenseFilesystemMiddleware") -> BaseTool: +def create_rmdir_tool(mw: SurfSenseFilesystemMiddleware) -> BaseTool: description = select_description(mw._filesystem_mode) async def async_rmdir( diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/write_file/index.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/write_file/index.py index 9d169e2c1..a42f7ed62 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/write_file/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/filesystem/tools/write_file/index.py @@ -23,7 +23,7 @@ if TYPE_CHECKING: from ...middleware import SurfSenseFilesystemMiddleware -def create_write_file_tool(mw: "SurfSenseFilesystemMiddleware") -> BaseTool: +def create_write_file_tool(mw: SurfSenseFilesystemMiddleware) -> BaseTool: description = select_description(mw._filesystem_mode) async def async_write_file( @@ -73,9 +73,7 @@ def create_write_file_tool(mw: "SurfSenseFilesystemMiddleware") -> BaseTool: content: Annotated[str, "Text content to write to the file."], runtime: 
ToolRuntime[None, SurfSenseFilesystemState], ) -> Command | str: - return run_async_blocking( - async_write_file(file_path, content, runtime) - ) + return run_async_blocking(async_write_file(file_path, content, runtime)) return StructuredTool.from_function( name="write_file", diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/__init__.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/__init__.py index 95f62d3f1..c25c2b281 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/__init__.py +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/__init__.py @@ -1,16 +1,11 @@ -"""Pattern-based allow/deny/ask middleware with HITL fallback. +"""Pattern-based allow/deny/ask middleware with HITL fallback (vertical slice). -Public surface: :class:`PermissionMiddleware` plus -:func:`normalize_permission_decision` for the streaming layer and the -:data:`PatternResolver` type for callers that register per-tool resolvers. +Public surface (one entry point only — every other symbol is an internal of +the rule engine and stays inside ``middleware/``, ``ask/``, or ``deny.py``): + +- :func:`build_permission_mw` — construction recipe shared by every stack. 
""" -from .decision import normalize_permission_decision -from .middleware import PermissionMiddleware -from .pattern_resolver import PatternResolver +from .middleware.factory import build_permission_mw -__all__ = [ - "PatternResolver", - "PermissionMiddleware", - "normalize_permission_decision", -] +__all__ = ["build_permission_mw"] diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/interrupt/__init__.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/ask/__init__.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/interrupt/__init__.py rename to surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/ask/__init__.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/ask/decision.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/ask/decision.py new file mode 100644 index 000000000..f507e85ff --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/ask/decision.py @@ -0,0 +1,74 @@ +"""Translate the unified langchain HITL envelope into permission-domain semantics. + +``PermissionMiddleware`` works with the canonical shape +``{decision_type: "once" | "approve_always" | "reject", feedback?: str, edited_args?: dict}``. +The wire envelope arriving from langgraph already lives in the LC HITL shape +(parsed once in :mod:`hitl_wire.decision`); this module performs the small +domain mapping (``approve|edit`` → ``once``, ``approve_always`` → +``approve_always``, anything else → ``reject``) without re-implementing the +envelope walk. + +Failing closed: any unrecognised decision becomes ``reject`` (with a warning) +so the middleware never proceeds on ambiguous input. 
+""" + +from __future__ import annotations + +import logging +from typing import Any + +from app.agents.multi_agent_chat.subagents.shared.hitl.wire import ( + LC_DECISION_APPROVE, + LC_DECISION_EDIT, + LC_DECISION_REJECT, + SURFSENSE_DECISION_APPROVE_ALWAYS, + parse_lc_envelope, +) + +logger = logging.getLogger(__name__) + + +# ``approve`` and ``edit`` both mean "let this call go through this once". The +# legacy SurfSense bare-scalar values (``once`` / ``approve_always`` / ``reject``) +# pass through unchanged so historical resume payloads still work. +_LC_TO_PERMISSION: dict[str, str] = { + LC_DECISION_APPROVE: "once", + LC_DECISION_EDIT: "once", + SURFSENSE_DECISION_APPROVE_ALWAYS: "approve_always", + LC_DECISION_REJECT: "reject", + "once": "once", + "approve_always": "approve_always", + "reject": "reject", +} + + +def normalize_permission_decision(envelope: Any) -> dict[str, Any]: + """Project the user's reply into the canonical permission decision shape. + + Args: + envelope: The raw resume value from langgraph (LC HITL envelope, a + bare scalar string, or a pre-canonical dict). + + Returns: + ``{"decision_type": "once"|"approve_always"|"reject"}`` plus optional + ``feedback`` (``reject`` with a user message) and ``edited_args`` + (``edit`` reply with non-empty arg overrides). 
+ """ + parsed = parse_lc_envelope(envelope) + mapped = _LC_TO_PERMISSION.get(parsed.decision_type) + if mapped is None: + logger.warning( + "Unknown permission decision %r; treating as reject", + parsed.decision_type, + ) + mapped = "reject" + + out: dict[str, Any] = {"decision_type": mapped} + if parsed.message: + out["feedback"] = parsed.message + if parsed.edited_args: + out["edited_args"] = parsed.edited_args + return out + + +__all__ = ["normalize_permission_decision"] diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/ask/edit/__init__.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/ask/edit/__init__.py new file mode 100644 index 000000000..2921cbe70 --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/ask/edit/__init__.py @@ -0,0 +1,10 @@ +"""Apply ``edit`` permission decisions to tool calls. + +Edited-arg extraction now lives in :mod:`hitl_wire.decision` (single parser +for all approval paths); this module owns the merge step that produces a +fresh tool-call dict for the orchestrator. +""" + +from .merge import merge_edited_args + +__all__ = ["merge_edited_args"] diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/ask/edit/merge.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/ask/edit/merge.py new file mode 100644 index 000000000..21474ad52 --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/ask/edit/merge.py @@ -0,0 +1,22 @@ +"""Apply edited args to a tool call (shallow merge, no mutation). + +Edited values override originals; keys absent from ``edited_args`` keep +their original values, so partial edits are safe. Returns a NEW tool-call +dict so the caller can swap it into ``AIMessage.tool_calls`` without +aliasing the live message object. 
+""" + +from __future__ import annotations + +from typing import Any + + +def merge_edited_args( + tool_call: dict[str, Any], edited_args: dict[str, Any] +) -> dict[str, Any]: + original_args = tool_call.get("args") or {} + merged_args = {**original_args, **edited_args} + return {**tool_call, "args": merged_args} + + +__all__ = ["merge_edited_args"] diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/ask/payload.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/ask/payload.py new file mode 100644 index 000000000..6c5d011df --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/ask/payload.py @@ -0,0 +1,89 @@ +"""Build the permission-ask interrupt payload (LC HITL wire + SurfSense context).""" + +from __future__ import annotations + +from typing import Any + +from langchain_core.tools import BaseTool + +from app.agents.multi_agent_chat.subagents.shared.hitl.wire import ( + LC_DECISION_APPROVE, + LC_DECISION_EDIT, + LC_DECISION_REJECT, + SURFSENSE_DECISION_APPROVE_ALWAYS, + build_lc_hitl_payload, +) +from app.agents.new_chat.permissions import Rule + +PERMISSION_ASK_INTERRUPT_TYPE = "permission_ask" + +_BASE_PERMISSION_ASK_DECISIONS: list[str] = [ + LC_DECISION_APPROVE, + LC_DECISION_REJECT, + LC_DECISION_EDIT, +] + + +def _is_mcp_tool(tool: BaseTool | None) -> bool: + """An MCP tool advertises a connector id in its langchain metadata.""" + if tool is None: + return False + metadata = getattr(tool, "metadata", None) or {} + return metadata.get("mcp_connector_id") is not None + + +def _card_fields_from_tool(tool: BaseTool | None) -> dict[str, Any]: + """Project the FE card's tool-scoped fields out of a BaseTool.""" + if tool is None: + return {} + metadata = getattr(tool, "metadata", None) or {} + fields: dict[str, Any] = {} + connector_id = metadata.get("mcp_connector_id") + if connector_id is not None: + fields["mcp_connector_id"] = connector_id + 
connector_name = metadata.get("mcp_connector_name") + if connector_name: + fields["mcp_server"] = connector_name + if tool.description: + fields["tool_description"] = tool.description + return fields + + +def build_permission_ask_payload( + *, + tool_name: str, + args: dict[str, Any], + patterns: list[str], + rules: list[Rule], + tool: BaseTool | None = None, +) -> dict[str, Any]: + """Build the permission-ask interrupt payload. + + ``approve_always`` is added to the palette only for MCP tools, since that + is the only case where the user's choice can persist beyond the current + agent instance (saved to the connector's trusted-tools list). Native + tools fall back to the approve/reject/edit triad. + """ + allowed_decisions = list(_BASE_PERMISSION_ASK_DECISIONS) + if _is_mcp_tool(tool): + allowed_decisions.append(SURFSENSE_DECISION_APPROVE_ALWAYS) + + context: dict[str, Any] = { + "patterns": patterns, + "rules": [ + {"permission": r.permission, "pattern": r.pattern, "action": r.action} + for r in rules + ], + "always": patterns, + **_card_fields_from_tool(tool), + } + return build_lc_hitl_payload( + tool_name=tool_name, + args=args, + allowed_decisions=allowed_decisions, + interrupt_type=PERMISSION_ASK_INTERRUPT_TYPE, + context=context, + ) + + +__all__ = ["PERMISSION_ASK_INTERRUPT_TYPE", "build_permission_ask_payload"] diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/ask/request.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/ask/request.py new file mode 100644 index 000000000..d61d38f34 --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/ask/request.py @@ -0,0 +1,59 @@ +"""Side-effectful entry point: pause the graph and return the permission decision. 
+ +Wraps :func:`langgraph.types.interrupt` with the OTel spans the SurfSense +dashboard expects, then projects the resume value through +:func:`normalize_permission_decision` so the middleware downstream only +sees the canonical permission-domain shape. + +When ``emit_interrupt`` is ``False`` the call short-circuits to ``reject``; +this is used by non-interactive deployments where ``ask`` must not block. +""" + +from __future__ import annotations + +from typing import Any + +from langchain_core.tools import BaseTool +from langgraph.types import interrupt + +from app.agents.new_chat.permissions import Rule +from app.observability import otel as ot + +from .decision import normalize_permission_decision +from .payload import PERMISSION_ASK_INTERRUPT_TYPE, build_permission_ask_payload + + +def request_permission_decision( + *, + tool_name: str, + args: dict[str, Any], + patterns: list[str], + rules: list[Rule], + emit_interrupt: bool, + tool: BaseTool | None = None, +) -> dict[str, Any]: + """Pause for an ``ask`` decision; return the canonical permission decision dict.""" + if not emit_interrupt: + return {"decision_type": "reject"} + + payload = build_permission_ask_payload( + tool_name=tool_name, + args=args, + patterns=patterns, + rules=rules, + tool=tool, + ) + + with ( + ot.permission_asked_span( + permission=tool_name, + pattern=patterns[0] if patterns else None, + extra={"permission.patterns": list(patterns)}, + ), + ot.interrupt_span(interrupt_type=PERMISSION_ASK_INTERRUPT_TYPE), + ): + decision = interrupt(payload) + return normalize_permission_decision(decision) + + +__all__ = ["request_permission_decision"] diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/decision.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/decision.py deleted file mode 100644 index bb8f9ea25..000000000 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/decision.py +++ /dev/null @@ 
-1,91 +0,0 @@ -"""Coerce inbound permission decisions to a canonical dict shape. - -Two wire formats are accepted: -- SurfSense legacy: ``{"decision_type": "once"|"always"|"reject", "feedback"?}``. -- LangChain HITL envelope: ``{"decisions": [{"type": "approve"|"edit"|"reject", ...}]}``. - -The middleware downstream only inspects the canonical shape returned here, -so adding a new envelope means changing this module alone. - -The middleware fails closed: any unrecognised payload becomes ``reject`` -(with a warning) so the agent never proceeds on ambiguous input. - -When the reply is an ``edit``, the result keeps ``decision_type="once"`` -(the call still goes through) and adds an ``edited_args`` key holding the -user-modified ``args`` dict. The orchestrator merges those into the -``tool_call`` before keeping it; see :mod:`interrupt.edit.merge`. -""" - -from __future__ import annotations - -import logging -from typing import Any - -from .interrupt.edit import extract_edited_args - -logger = logging.getLogger(__name__) - - -# ``edit`` collapses to ``once``; any ``edited_args`` ride on the result. 
-_LC_TYPE_TO_PERMISSION_DECISION: dict[str, str] = { - "approve": "once", - "reject": "reject", - "edit": "once", -} - - -def normalize_permission_decision(decision: Any) -> dict[str, Any]: - """Return ``{"decision_type": ..., "feedback"?: str, "edited_args"?: dict}``.""" - if isinstance(decision, str): - return {"decision_type": decision} - if not isinstance(decision, dict): - logger.warning( - "Unrecognized permission resume value (%s); treating as reject", - type(decision).__name__, - ) - return {"decision_type": "reject"} - - if decision.get("decision_type"): - return decision - - payload: dict[str, Any] = decision - decisions = decision.get("decisions") - if isinstance(decisions, list) and decisions: - first = decisions[0] - if isinstance(first, dict): - payload = first - - raw_type = payload.get("type") or payload.get("decision_type") - if not raw_type: - logger.warning( - "Permission resume missing decision type (keys=%s); treating as reject", - list(payload.keys()), - ) - return {"decision_type": "reject"} - - raw_type = str(raw_type).lower() - mapped = _LC_TYPE_TO_PERMISSION_DECISION.get(raw_type) - if mapped is None: - # Tolerate legacy values arriving without ``decision_type`` wrapping. 
- if raw_type in {"once", "always", "reject"}: - mapped = raw_type - else: - logger.warning( - "Unknown permission decision type %r; treating as reject", raw_type - ) - mapped = "reject" - - out: dict[str, Any] = {"decision_type": mapped} - feedback = payload.get("feedback") or payload.get("message") - if isinstance(feedback, str) and feedback.strip(): - out["feedback"] = feedback - - if raw_type == "edit": - edited = extract_edited_args(payload) - if edited: - out["edited_args"] = edited - - return out - - -__all__ = ["normalize_permission_decision"] diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/interrupt/edit/__init__.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/interrupt/edit/__init__.py deleted file mode 100644 index 993bc50b9..000000000 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/interrupt/edit/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Apply ``edit`` permission decisions to tool calls (extract + merge).""" - -from .extract import extract_edited_args -from .merge import merge_edited_args - -__all__ = ["extract_edited_args", "merge_edited_args"] diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/interrupt/edit/extract.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/interrupt/edit/extract.py deleted file mode 100644 index 85d365ece..000000000 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/interrupt/edit/extract.py +++ /dev/null @@ -1,34 +0,0 @@ -"""Extract edited args from a permission decision payload. - -Two shapes are accepted (mirrors :func:`app.agents.new_chat.tools.hitl._parse_decision`): - -- LangChain HITL envelope: ``{"edited_action": {"args": {...}}}``. -- Legacy flat shape: ``{"args": {...}}``. - -Returns ``None`` when no edited args are present. 
The orchestrator decides -whether to merge them (see :mod:`interrupt.edit.merge`); this module is pure parsing. -""" - -from __future__ import annotations - -from typing import Any - - -def extract_edited_args(decision_payload: dict[str, Any] | None) -> dict[str, Any] | None: - if not isinstance(decision_payload, dict): - return None - - edited_action = decision_payload.get("edited_action") - if isinstance(edited_action, dict): - edited_args = edited_action.get("args") - if isinstance(edited_args, dict): - return edited_args - - flat_args = decision_payload.get("args") - if isinstance(flat_args, dict): - return flat_args - - return None - - -__all__ = ["extract_edited_args"] diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/interrupt/edit/merge.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/interrupt/edit/merge.py deleted file mode 100644 index 6632c677c..000000000 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/interrupt/edit/merge.py +++ /dev/null @@ -1,25 +0,0 @@ -"""Apply edited args to a tool call. - -Semantics match :func:`app.agents.new_chat.tools.hitl.request_approval`'s -``final_params = {**params, **edited_params}`` — shallow merge, edited -values override originals. Keys absent from ``edited_args`` keep their -original values, so partial edits are safe. - -Returns a NEW ``tool_call`` dict (the input is not mutated) so the caller -can swap it into the ``AIMessage.tool_calls`` list without aliasing. 
-""" - -from __future__ import annotations - -from typing import Any - - -def merge_edited_args( - tool_call: dict[str, Any], edited_args: dict[str, Any] -) -> dict[str, Any]: - original_args = tool_call.get("args") or {} - merged_args = {**original_args, **edited_args} - return {**tool_call, "args": merged_args} - - -__all__ = ["merge_edited_args"] diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/interrupt/payload.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/interrupt/payload.py deleted file mode 100644 index d5de1c209..000000000 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/interrupt/payload.py +++ /dev/null @@ -1,43 +0,0 @@ -"""Build the ``permission_ask`` interrupt payload (pure data). - -The frontend's streaming layer keys off ``type`` and renders the approval -card from ``action`` (the tool call being reviewed) and ``context`` -(the matched rules and patterns that prompted the ask). ``context.always`` -lists the patterns the user can promote to a permanent allow rule with a -single ``"always"`` reply. -""" - -from __future__ import annotations - -from typing import Any - -from app.agents.new_chat.permissions import Rule - - -def build_permission_ask_payload( - *, - tool_name: str, - args: dict[str, Any], - patterns: list[str], - rules: list[Rule], -) -> dict[str, Any]: - return { - "type": "permission_ask", - # ``params`` (not ``args``) is what SurfSense's streaming normalizer forwards. 
- "action": {"tool": tool_name, "params": args or {}}, - "context": { - "patterns": patterns, - "rules": [ - { - "permission": r.permission, - "pattern": r.pattern, - "action": r.action, - } - for r in rules - ], - "always": patterns, - }, - } - - -__all__ = ["build_permission_ask_payload"] diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/interrupt/request.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/interrupt/request.py deleted file mode 100644 index abd2871b8..000000000 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/interrupt/request.py +++ /dev/null @@ -1,52 +0,0 @@ -"""Request a permission decision from the user (side-effectful entry point). - -Wraps :func:`langgraph.types.interrupt` with the OTel spans that the -SurfSense dashboard expects, then normalises the resume value through -:func:`decision.normalize_permission_decision`. - -When ``emit_interrupt`` is ``False`` the call short-circuits to -``reject``; this is used by non-interactive deployments where ``ask`` must -not block. 
-""" - -from __future__ import annotations - -from typing import Any - -from langgraph.types import interrupt - -from app.agents.new_chat.permissions import Rule -from app.observability import otel as ot - -from ..decision import normalize_permission_decision -from .payload import build_permission_ask_payload - - -def request_permission_decision( - *, - tool_name: str, - args: dict[str, Any], - patterns: list[str], - rules: list[Rule], - emit_interrupt: bool, -) -> dict[str, Any]: - if not emit_interrupt: - return {"decision_type": "reject"} - - payload = build_permission_ask_payload( - tool_name=tool_name, args=args, patterns=patterns, rules=rules - ) - - with ( - ot.permission_asked_span( - permission=tool_name, - pattern=patterns[0] if patterns else None, - extra={"permission.patterns": list(patterns)}, - ), - ot.interrupt_span(interrupt_type="permission_ask"), - ): - decision = interrupt(payload) - return normalize_permission_decision(decision) - - -__all__ = ["request_permission_decision"] diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/core.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/core.py index d2370889c..d2950c5b4 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/core.py +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/core.py @@ -5,35 +5,16 @@ LangChain's :class:`HumanInTheLoopMiddleware` only supports a static allow/deny/ask, no glob patterns, no per-space/per-thread overrides, and no auto-deny synthesis. -This middleware layers OpenCode's wildcard-ruleset model on top of -SurfSense's ``interrupt({type, action, context})`` payload shape (see -:mod:`app.agents.new_chat.tools.hitl`) so the frontend keeps working -unchanged. - -Per-tool-call flow inside :meth:`_process`: - -1. Skip when the last message has no tool calls. -2. For each call, evaluate the rules. 
``deny`` is replaced with a - synthetic :class:`ToolMessage` carrying a typed - :class:`StreamingError`. ``ask`` raises an interrupt via - :mod:`interrupt.request`; the resulting decision is dispatched here: - - - ``once`` → keep the call as-is. - - ``always`` → also extend the runtime ruleset. - - ``reject`` (with feedback) → :class:`CorrectedError`. - - ``reject`` (no feedback) → :class:`RejectedError`. - - ``allow`` keeps the call unchanged. - -3. Returns an updated ``AIMessage`` (tool calls minus the denied ones) - plus any deny ``ToolMessage`` entries appended after it. Tool-list - filtering at ``before_model`` is intentionally not done here — that - would invalidate provider prompt-cache prefixes. +This middleware layers OpenCode's wildcard-ruleset model on top of the +unified langchain HITL wire format (see :mod:`hitl_wire`), so it sits +beside ``HumanInTheLoopMiddleware`` and self-gated approvals on a single +parallel-HITL routing layer in ``task_tool`` + ``resume_routing``. """ from __future__ import annotations import logging +from dataclasses import dataclass from typing import Any from langchain.agents.middleware.types import ( @@ -42,22 +23,32 @@ from langchain.agents.middleware.types import ( ContextT, ) from langchain_core.messages import AIMessage, ToolMessage +from langchain_core.tools import BaseTool from langgraph.runtime import Runtime from app.agents.new_chat.errors import CorrectedError, RejectedError from app.agents.new_chat.permissions import Ruleset +from app.services.user_tool_allowlist import TrustedToolSaver +from ..ask.edit import merge_edited_args +from ..ask.request import request_permission_decision from ..deny import build_deny_message -from ..interrupt.edit import merge_edited_args -from ..interrupt.request import request_permission_decision -from ..pattern_resolver import PatternResolver -from ..runtime_promote import persist_always from .evaluation import evaluate_tool_call +from .pattern_resolver import PatternResolver from 
.ruleset_view import all_rulesets +from .runtime_promote import persist_always logger = logging.getLogger(__name__) +@dataclass(frozen=True) +class _AlwaysPromotion: + """A pending request to save an ``approve_always`` decision to the user's trust list.""" + + connector_id: int + tool_name: str + + class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg] """Allow/deny/ask layer over the agent's tool calls. @@ -68,10 +59,17 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg] to wildcard patterns. Tools without an entry use the bare tool name as the only pattern. runtime_ruleset: Mutable :class:`Ruleset` extended in-place when - the user replies ``"always"``. Reused across calls in the - same agent instance so newly-allowed rules apply downstream. + the user replies ``"approve_always"``. Reused across calls in + the same agent instance so newly-allowed rules apply downstream. always_emit_interrupt_payload: Set ``False`` to make ``ask`` collapse to ``deny`` (for non-interactive deployments). + tools_by_name: Map from tool name to :class:`BaseTool`, used to + decorate ``ask`` interrupts with the tool's description and + MCP metadata for the FE card. + trusted_tool_saver: Async callback invoked on ``approve_always`` + decisions for MCP tools (those whose ``metadata`` carries an + ``mcp_connector_id``). Without it the promotion only lives + in-memory for the current agent instance. 
""" tools = () @@ -83,6 +81,8 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg] pattern_resolvers: dict[str, PatternResolver] | None = None, runtime_ruleset: Ruleset | None = None, always_emit_interrupt_payload: bool = True, + tools_by_name: dict[str, BaseTool] | None = None, + trusted_tool_saver: TrustedToolSaver | None = None, ) -> None: super().__init__() self._static_rulesets: list[Ruleset] = list(rulesets or []) @@ -93,23 +93,33 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg] origin="runtime_approved" ) self._emit_interrupt = always_emit_interrupt_payload + self._tools_by_name: dict[str, BaseTool] = dict(tools_by_name or {}) + self._trusted_tool_saver: TrustedToolSaver | None = trusted_tool_saver def _process( self, state: AgentState, runtime: Runtime[Any], - ) -> dict[str, Any] | None: + ) -> tuple[dict[str, Any] | None, list[_AlwaysPromotion]]: + """Pure decision pass: returns ``(state_update, pending_promotions)``. + + Side effects performed here are in-memory only (rule promotion + into ``runtime_ruleset``). DB writes for ``approve_always`` + decisions are queued as ``_AlwaysPromotion`` and flushed by the + async hook. 
+ """ del runtime messages = state.get("messages") or [] if not messages: - return None + return None, [] last = messages[-1] if not isinstance(last, AIMessage) or not last.tool_calls: - return None + return None, [] rulesets = all_rulesets(self._static_rulesets, self._runtime_ruleset) deny_messages: list[ToolMessage] = [] kept_calls: list[dict[str, Any]] = [] + promotions: list[_AlwaysPromotion] = [] any_change = False for raw in last.tool_calls: @@ -142,10 +152,11 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg] patterns=patterns, rules=rules, emit_interrupt=self._emit_interrupt, + tool=self._tools_by_name.get(name), ) kind = str(decision.get("decision_type") or "reject").lower() edited_args = decision.get("edited_args") - if kind in ("once", "always"): + if kind in ("once", "approve_always"): final_call = ( merge_edited_args(call, edited_args) if isinstance(edited_args, dict) and edited_args @@ -153,8 +164,11 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg] ) if final_call is not call: any_change = True - if kind == "always": + if kind == "approve_always": persist_always(self._runtime_ruleset, name, patterns) + promotion = self._build_always_promotion(name) + if promotion is not None: + promotions.append(promotion) kept_calls.append(final_call) elif kind == "reject": feedback = decision.get("feedback") @@ -173,23 +187,39 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg] kept_calls.append(call) if not any_change and len(kept_calls) == len(last.tool_calls): - return None + return None, promotions updated = last.model_copy(update={"tool_calls": kept_calls}) result_messages: list[Any] = [updated] if deny_messages: result_messages.extend(deny_messages) - return {"messages": result_messages} + return {"messages": result_messages}, promotions + + def _build_always_promotion(self, tool_name: str) -> _AlwaysPromotion | None: + """Return a save request iff the tool exposes an ``mcp_connector_id``.""" + 
tool = self._tools_by_name.get(tool_name) + metadata = getattr(tool, "metadata", None) or {} + connector_id = metadata.get("mcp_connector_id") + if not isinstance(connector_id, int): + return None + return _AlwaysPromotion(connector_id=connector_id, tool_name=tool_name) def after_model( # type: ignore[override] self, state: AgentState, runtime: Runtime[ContextT] ) -> dict[str, Any] | None: - return self._process(state, runtime) + update, _ = self._process(state, runtime) + return update async def aafter_model( # type: ignore[override] self, state: AgentState, runtime: Runtime[ContextT] ) -> dict[str, Any] | None: - return self._process(state, runtime) + update, promotions = self._process(state, runtime) + if self._trusted_tool_saver is not None: + for promotion in promotions: + await self._trusted_tool_saver( + promotion.connector_id, promotion.tool_name + ) + return update __all__ = ["PermissionMiddleware"] diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/evaluation.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/evaluation.py index 6777aa093..51531c4eb 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/evaluation.py +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/evaluation.py @@ -24,7 +24,7 @@ from app.agents.new_chat.permissions import ( evaluate_many, ) -from ..pattern_resolver import PatternResolver, default_pattern_resolver +from .pattern_resolver import PatternResolver, default_pattern_resolver logger = logging.getLogger(__name__) diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/factory.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/factory.py new file mode 100644 index 000000000..3c061ded6 --- /dev/null +++ 
b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/factory.py @@ -0,0 +1,88 @@ +"""Construction recipe for :class:`PermissionMiddleware` shared across stacks. + +Single source of truth used by both the main-agent stack and every subagent +stack. Rule layers are evaluated earliest-to-latest (last match wins, +matching OpenCode's ``permission/index.ts`` evaluation order): + +1. ``surfsense_defaults`` — single ``allow */*`` rule. Connector tools + already self-gate via :func:`request_approval`, so the rule engine only + needs to *deny* what the user has explicitly forbidden; the default + ``ask`` fallback would otherwise double-prompt every safe read-only + call. +2. ``subagent_rulesets`` — caller-supplied rulesets contributed by the + consuming subagent. Each subagent passes its coded rules (KB: + destructive-FS ``ask`` rules; connectors: per-tool ``allow``/``ask``) + plus, when present, the user's persisted allow-list for that subagent. + +Connector deny synthesis from ``new_chat._synthesize_connector_deny_rules`` +is intentionally NOT replicated: the multi-agent orchestrator already +excludes entire subagents whose required connectors are missing +(``SUBAGENT_TO_REQUIRED_CONNECTOR_MAP``), so the per-tool deny pass is +redundant here. 
+""" + +from __future__ import annotations + +from collections.abc import Sequence + +from langchain_core.tools import BaseTool + +from app.agents.new_chat.feature_flags import AgentFeatureFlags +from app.agents.new_chat.permissions import Rule, Ruleset +from app.services.user_tool_allowlist import TrustedToolSaver + +from .core import PermissionMiddleware + +_SURFSENSE_DEFAULTS = Ruleset( + rules=[Rule(permission="*", pattern="*", action="allow")], + origin="surfsense_defaults", +) + + +def build_permission_mw( + *, + flags: AgentFeatureFlags, + subagent_rulesets: list[Ruleset] | None = None, + tools: Sequence[BaseTool] | None = None, + trusted_tool_saver: TrustedToolSaver | None = None, +) -> PermissionMiddleware | None: + """Return a configured :class:`PermissionMiddleware` or ``None`` when no work is needed. + + Args: + flags: Feature toggles. ``enable_permission`` switches the engine on; + ``disable_new_agent_stack`` overrides everything for safety. + subagent_rulesets: Caller-supplied rulesets layered after the + defaults. Subagents pass their own coded ruleset here (and, + when present, the user's persisted allow-list for that + subagent) so each subagent owns its own rule surface without + aliasing a shared engine. Presence of any subagent ruleset + forces the middleware on regardless of ``enable_permission`` — + an explicit ``ask`` rule always asks. + tools: Subagent tools used to decorate ``ask`` interrupts with + FE-card metadata (description, MCP connector). Optional. + trusted_tool_saver: Async callback invoked when an MCP tool's + ``approve_always`` decision lands; persists the user's preference to + ``connector.config['trusted_tools']``. Optional. + + Returns: + ``None`` when the engine has no rules to enforce + (``enable_permission=False`` and no subagent rulesets); a + configured middleware otherwise. 
+ """ + permission_enabled = flags.enable_permission and not flags.disable_new_agent_stack + has_subagent_rulesets = bool(subagent_rulesets) + if not (permission_enabled or has_subagent_rulesets): + return None + + rulesets: list[Ruleset] = [_SURFSENSE_DEFAULTS] + if subagent_rulesets: + rulesets.extend(subagent_rulesets) + tools_by_name = {t.name: t for t in (tools or [])} + return PermissionMiddleware( + rulesets=rulesets, + tools_by_name=tools_by_name, + trusted_tool_saver=trusted_tool_saver, + ) + + +__all__ = ["build_permission_mw"] diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/pattern_resolver.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/pattern_resolver.py similarity index 100% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/pattern_resolver.py rename to surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/pattern_resolver.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/ruleset_view.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/ruleset_view.py index 23fa9cf1c..fbb66d455 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/ruleset_view.py +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/ruleset_view.py @@ -3,8 +3,8 @@ Static rulesets come from the agent factory (defaults, space-scoped, thread-scoped, etc.). The runtime ruleset is the in-memory one that :func:`runtime_promote.persist_always` extends when the user replies -``"always"``. Evaluators always see them merged in this order so newly- -promoted rules apply to subsequent calls. +``"approve_always"``. Evaluators always see them merged in this order so +newly-promoted rules apply to subsequent calls. 
""" from __future__ import annotations diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/runtime_promote.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/runtime_promote.py similarity index 89% rename from surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/runtime_promote.py rename to surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/runtime_promote.py index d528010e0..afc65fdc0 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/runtime_promote.py +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/shared/permissions/middleware/runtime_promote.py @@ -1,4 +1,4 @@ -"""Promote an ``"always"`` reply into in-memory allow rules. +"""Promote an ``"approve_always"`` reply into in-memory allow rules. Subsequent calls within the same agent instance match these new rules and proceed without prompting. Durable persistence (to ``agent_permission_rules``) diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/stack.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/stack.py index cc52633fa..c1ebe31ca 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/stack.py +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/stack.py @@ -31,7 +31,6 @@ from app.agents.multi_agent_chat.subagents.builtins.knowledge_base.agent import from app.agents.multi_agent_chat.subagents.builtins.knowledge_base.ask_knowledge_base_tool import ( build_ask_knowledge_base_tool, ) -from app.agents.multi_agent_chat.subagents.shared.permissions import ToolsPermissions from app.agents.new_chat.feature_flags import AgentFeatureFlags from app.agents.new_chat.filesystem_selection import FilesystemMode from app.db import ChatVisibility @@ -61,6 +60,7 @@ from .shared.compaction import build_compaction_mw from .shared.kb_context_projection import build_kb_context_projection_mw 
from .shared.memory import build_memory_mw from .shared.patch_tool_calls import build_patch_tool_calls_mw +from .shared.permissions import build_permission_mw from .shared.resilience import build_resilience_middlewares from .shared.todos import build_todos_mw from .subagent.middleware_stack import build_subagent_middleware_stack @@ -84,7 +84,7 @@ def build_main_agent_deepagent_middleware( flags: AgentFeatureFlags, subagent_dependencies: dict[str, Any], checkpointer: Checkpointer, - mcp_tools_by_agent: dict[str, ToolsPermissions] | None = None, + mcp_tools_by_agent: dict[str, list[BaseTool]] | None = None, disabled_tools: list[str] | None = None, ) -> list[Any]: """Ordered middleware for ``create_agent`` (None entries already stripped).""" @@ -100,14 +100,19 @@ def build_main_agent_deepagent_middleware( **subagent_dependencies, "backend_resolver": backend_resolver, "filesystem_mode": filesystem_mode, + "flags": flags, } - shared_subagent_middleware = build_subagent_middleware_stack(resilience=resilience) + shared_subagent_middleware = build_subagent_middleware_stack( + resilience=resilience, + flags=flags, + ) - kb_readonly_spec = build_kb_readonly_subagent( + kb_readonly = build_kb_readonly_subagent( dependencies=subagent_dependencies, model=llm, middleware_stack=shared_subagent_middleware, ) + kb_readonly_spec = kb_readonly.spec kb_readonly_runnable = create_agent( llm, system_prompt=kb_readonly_spec["system_prompt"], @@ -182,6 +187,7 @@ def build_main_agent_deepagent_middleware( resilience.retry, resilience.fallback, build_repair_mw(flags=flags, tools=tools), + build_permission_mw(flags=flags), build_doom_loop_mw(flags), build_action_log_mw( flags=flags, diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/subagent/middleware_stack.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/subagent/middleware_stack.py index 9889e629a..aa6211fcc 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/subagent/middleware_stack.py 
+++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/subagent/middleware_stack.py @@ -3,7 +3,8 @@ Mirrors ``middleware/stack.py`` (the orchestrator's middleware stack) but exposes its contents as a dict keyed by purpose so specialists can pick the entries they need and decide ordering. The default consumer -(``pack_subagent``) prepends every non-``None`` value in insertion order. +(:func:`pack_subagent`) prepends every non-``None`` value in insertion +order, so ``None`` slots are silently skipped. Registry subagents never touch the SurfSense filesystem — that capability belongs to ``knowledge_base`` — so no FS middleware is exposed here. @@ -13,6 +14,9 @@ from __future__ import annotations from typing import Any +from app.agents.new_chat.feature_flags import AgentFeatureFlags + +from ..shared.permissions import build_permission_mw from ..shared.resilience import ResilienceMiddlewares from ..shared.todos import build_todos_mw @@ -20,9 +24,24 @@ from ..shared.todos import build_todos_mw def build_subagent_middleware_stack( *, resilience: ResilienceMiddlewares, + flags: AgentFeatureFlags | None = None, ) -> dict[str, Any]: + """Assemble the dict of middlewares prepended to every subagent's stack. + + Args: + resilience: Pre-built retry / fallback / call-limit middlewares + (shared with the orchestrator stack to keep behaviour symmetric). + flags: Feature flags driving optional layers. ``None`` disables the + permission layer (used in tests that only need todos+resilience). + + Returns: + Insertion-ordered dict; ``None`` values are tolerated and dropped by + the consumer so callers can flip slots on/off without reshaping. 
+ """ + permission = build_permission_mw(flags=flags) if flags is not None else None return { "todos": build_todos_mw(), + "permission": permission, "retry": resilience.retry, "fallback": resilience.fallback, "model_call_limit": resilience.model_call_limit, diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/agent.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/agent.py index 0baa6714f..396e0ec79 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/agent.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/agent.py @@ -1,27 +1,22 @@ -"""`deliverables` route: ``SubAgent`` spec for deepagents.""" +"""``deliverables`` route: ``SurfSenseSubagentSpec`` builder for deepagents. + +Tools self-gate inside their bodies via :func:`request_approval`; the +empty :data:`tools.index.RULESET` is layered into a per-subagent +:class:`PermissionMiddleware` for uniformity. 
+""" from __future__ import annotations from typing import Any -from deepagents import SubAgent from langchain_core.language_models import BaseChatModel +from langchain_core.tools import BaseTool -from app.agents.multi_agent_chat.subagents.shared.md_file_reader import ( - read_md_file, -) -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolsPermissions, - merge_tools_permissions, - middleware_gated_interrupt_on, -) -from app.agents.multi_agent_chat.subagents.shared.subagent_builder import ( - pack_subagent, -) +from app.agents.multi_agent_chat.subagents.shared.md_file_reader import read_md_file +from app.agents.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec +from app.agents.multi_agent_chat.subagents.shared.subagent_builder import pack_subagent -from .tools.index import load_tools - -NAME = "deliverables" +from .tools.index import NAME, RULESET, load_tools def build_subagent( @@ -29,26 +24,21 @@ def build_subagent( dependencies: dict[str, Any], model: BaseChatModel | None = None, middleware_stack: dict[str, Any] | None = None, - extra_tools_bucket: ToolsPermissions | None = None, -) -> SubAgent: - buckets = load_tools(dependencies=dependencies) - merged_tools_bucket = merge_tools_permissions(buckets, extra_tools_bucket) - tools = [ - row["tool"] - for row in (*merged_tools_bucket["allow"], *merged_tools_bucket["ask"]) - if row.get("tool") is not None - ] - interrupt_on = middleware_gated_interrupt_on(merged_tools_bucket) - description = read_md_file(__package__, "description").strip() - if not description: - description = "Handles deliverables tasks for this workspace." + mcp_tools: list[BaseTool] | None = None, +) -> SurfSenseSubagentSpec: + tools = [*load_tools(dependencies=dependencies), *(mcp_tools or [])] + description = ( + read_md_file(__package__, "description").strip() + or "Handles deliverables tasks for this workspace." 
+ ) system_prompt = read_md_file(__package__, "system_prompt").strip() return pack_subagent( name=NAME, description=description, system_prompt=system_prompt, tools=tools, - interrupt_on=interrupt_on, + ruleset=RULESET, + dependencies=dependencies, model=model, middleware_stack=middleware_stack, ) diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/index.py index 938e73bd4..5f76f1d52 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/index.py @@ -1,10 +1,15 @@ +"""``deliverables`` native tools and (empty) permission ruleset. + +Tools self-gate via :func:`request_approval` in their bodies. +""" + from __future__ import annotations from typing import Any -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolsPermissions, -) +from langchain_core.tools import BaseTool + +from app.agents.new_chat.permissions import Ruleset from .generate_image import create_generate_image_tool from .podcast import create_generate_podcast_tool @@ -12,43 +17,39 @@ from .report import create_generate_report_tool from .resume import create_generate_resume_tool from .video_presentation import create_generate_video_presentation_tool +NAME = "deliverables" + +RULESET = Ruleset(origin=NAME, rules=[]) + def load_tools( *, dependencies: dict[str, Any] | None = None, **kwargs: Any -) -> ToolsPermissions: - resolved_dependencies = {**(dependencies or {}), **kwargs} - podcast = create_generate_podcast_tool( - search_space_id=resolved_dependencies["search_space_id"], - db_session=resolved_dependencies["db_session"], - thread_id=resolved_dependencies["thread_id"], - ) - video = create_generate_video_presentation_tool( - search_space_id=resolved_dependencies["search_space_id"], - 
db_session=resolved_dependencies["db_session"], - thread_id=resolved_dependencies["thread_id"], - ) - report = create_generate_report_tool( - search_space_id=resolved_dependencies["search_space_id"], - thread_id=resolved_dependencies["thread_id"], - connector_service=resolved_dependencies.get("connector_service"), - available_connectors=resolved_dependencies.get("available_connectors"), - available_document_types=resolved_dependencies.get("available_document_types"), - ) - resume = create_generate_resume_tool( - search_space_id=resolved_dependencies["search_space_id"], - thread_id=resolved_dependencies["thread_id"], - ) - image = create_generate_image_tool( - search_space_id=resolved_dependencies["search_space_id"], - db_session=resolved_dependencies["db_session"], - ) - return { - "allow": [ - {"name": getattr(podcast, "name", "") or "", "tool": podcast}, - {"name": getattr(video, "name", "") or "", "tool": video}, - {"name": getattr(report, "name", "") or "", "tool": report}, - {"name": getattr(resume, "name", "") or "", "tool": resume}, - {"name": getattr(image, "name", "") or "", "tool": image}, - ], - "ask": [], - } +) -> list[BaseTool]: + d = {**(dependencies or {}), **kwargs} + return [ + create_generate_podcast_tool( + search_space_id=d["search_space_id"], + db_session=d["db_session"], + thread_id=d["thread_id"], + ), + create_generate_video_presentation_tool( + search_space_id=d["search_space_id"], + db_session=d["db_session"], + thread_id=d["thread_id"], + ), + create_generate_report_tool( + search_space_id=d["search_space_id"], + thread_id=d["thread_id"], + connector_service=d.get("connector_service"), + available_connectors=d.get("available_connectors"), + available_document_types=d.get("available_document_types"), + ), + create_generate_resume_tool( + search_space_id=d["search_space_id"], + thread_id=d["thread_id"], + ), + create_generate_image_tool( + search_space_id=d["search_space_id"], + db_session=d["db_session"], + ), + ] diff --git 
a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/agent.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/agent.py index 555911910..c6a0220ec 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/agent.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/agent.py @@ -1,4 +1,9 @@ -"""`knowledge_base` route: full and read-only ``SubAgent`` specs.""" +"""``knowledge_base`` route: full and read-only ``SurfSenseSubagentSpec`` builders. + +KB owns its destructive-FS approval ruleset (:data:`KB_RULESET`); rules +are layered into KB's :class:`PermissionMiddleware` (built inside +``build_kb_middleware``). One emitter, one wire format, one source of truth. +""" from __future__ import annotations @@ -6,42 +11,56 @@ from typing import Any, cast from deepagents import SubAgent from langchain_core.language_models import BaseChatModel +from langchain_core.tools import BaseTool -from app.agents.multi_agent_chat.subagents.shared.permissions import ToolsPermissions +from app.agents.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.agents.new_chat.permissions import Rule, Ruleset from .middleware_stack import build_kb_middleware from .prompts import load_description, load_readonly_system_prompt, load_system_prompt -from .tools.index import destructive_fs_interrupt_on +from .tools.index import DESTRUCTIVE_FS_OPS NAME = "knowledge_base" READONLY_NAME = "knowledge_base_readonly" +KB_RULESET = Ruleset( + origin=NAME, + rules=[Rule(permission=op, pattern="*", action="ask") for op in DESTRUCTIVE_FS_OPS], +) + +_KB_READONLY_RULESET = Ruleset(origin=READONLY_NAME, rules=[]) + def build_subagent( *, dependencies: dict[str, Any], model: BaseChatModel | None = None, middleware_stack: dict[str, Any] | None = None, - extra_tools_bucket: ToolsPermissions | None = 
None, # noqa: ARG001 — KB ships fixed tools -) -> SubAgent: + mcp_tools: list[BaseTool] | None = None, +) -> SurfSenseSubagentSpec: + del mcp_tools llm = model if model is not None else dependencies["llm"] filesystem_mode: FilesystemMode = dependencies["filesystem_mode"] - spec: dict[str, Any] = { - "name": NAME, - "description": load_description(), - "system_prompt": load_system_prompt(filesystem_mode), - "model": llm, - "tools": [], - "middleware": build_kb_middleware( - llm=llm, - dependencies=dependencies, - middleware_stack=middleware_stack, - read_only=False, - ), - "interrupt_on": destructive_fs_interrupt_on(), - } - return cast(SubAgent, spec) + spec = cast( + SubAgent, + { + "name": NAME, + "description": load_description(), + "system_prompt": load_system_prompt(filesystem_mode), + "model": llm, + "tools": [], + "middleware": build_kb_middleware( + llm=llm, + dependencies=dependencies, + middleware_stack=middleware_stack, + read_only=False, + subagent_name=NAME, + ruleset=KB_RULESET, + ), + }, + ) + return SurfSenseSubagentSpec(spec=spec, ruleset=KB_RULESET) def build_readonly_subagent( @@ -49,21 +68,25 @@ def build_readonly_subagent( dependencies: dict[str, Any], model: BaseChatModel | None = None, middleware_stack: dict[str, Any] | None = None, -) -> SubAgent: +) -> SurfSenseSubagentSpec: llm = model if model is not None else dependencies["llm"] filesystem_mode: FilesystemMode = dependencies["filesystem_mode"] - spec: dict[str, Any] = { - "name": READONLY_NAME, - "description": "Read-only knowledge_base specialist (invoked via ask_knowledge_base).", - "system_prompt": load_readonly_system_prompt(filesystem_mode), - "model": llm, - "tools": [], - "middleware": build_kb_middleware( - llm=llm, - dependencies=dependencies, - middleware_stack=middleware_stack, - read_only=True, - ), - "interrupt_on": {}, - } - return cast(SubAgent, spec) + spec = cast( + SubAgent, + { + "name": READONLY_NAME, + "description": "Read-only knowledge_base specialist (invoked via 
ask_knowledge_base).", + "system_prompt": load_readonly_system_prompt(filesystem_mode), + "model": llm, + "tools": [], + "middleware": build_kb_middleware( + llm=llm, + dependencies=dependencies, + middleware_stack=middleware_stack, + read_only=True, + subagent_name=READONLY_NAME, + ruleset=None, + ), + }, + ) + return SurfSenseSubagentSpec(spec=spec, ruleset=_KB_READONLY_RULESET) diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/middleware_stack.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/middleware_stack.py index 7b2d54c59..778bb250c 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/middleware_stack.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/middleware_stack.py @@ -1,4 +1,8 @@ -"""Middleware list shared by the full and read-only knowledge_base compiles.""" +"""Middleware list shared by the full and read-only knowledge_base compiles. + +The KB-owned :class:`PermissionMiddleware` slot is what enforces +"ask before destructive FS op" for KB tools. +""" from __future__ import annotations @@ -21,7 +25,29 @@ from app.agents.multi_agent_chat.middleware.shared.kb_context_projection import from app.agents.multi_agent_chat.middleware.shared.patch_tool_calls import ( build_patch_tool_calls_mw, ) +from app.agents.multi_agent_chat.middleware.shared.permissions import ( + build_permission_mw, +) +from app.agents.new_chat.feature_flags import AgentFeatureFlags from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.agents.new_chat.permissions import Ruleset + + +def _kb_user_allowlist( + dependencies: dict[str, Any], subagent_name: str +) -> Ruleset | None: + """Return the user's persisted allow-rules for ``subagent_name`` if any. 
+ + KB does not currently expose an "Always Allow" UI surface (the FE + button is MCP-only today), but the wiring is symmetrical with the + connector subagents so that adding KB trust later is a one-line + backend change. + """ + by_subagent = dependencies.get("user_allowlist_by_subagent") or {} + user_allowlist = by_subagent.get(subagent_name) + if isinstance(user_allowlist, Ruleset) and user_allowlist.rules: + return user_allowlist + return None def build_kb_middleware( @@ -30,9 +56,27 @@ def build_kb_middleware( dependencies: dict[str, Any], middleware_stack: dict[str, Any] | None, read_only: bool, + subagent_name: str, + ruleset: Ruleset | None = None, ) -> list[Any]: + """Compose the KB subagent's middleware list. + + Args: + subagent_name: Identity of the subagent being built (e.g. + ``"knowledge_base"``, ``"knowledge_base_readonly"``). Used to + look up the user's persistent allow-list bucket in + ``dependencies["user_allowlist_by_subagent"]``. + ruleset: The KB-owned permission ruleset (typically the + destructive-FS ``ask`` rules). When provided, a dedicated + :class:`PermissionMiddleware` is appended so KB enforces + approval at the rule layer. The user's persistent allow-list + for ``subagent_name`` is layered after ``ruleset`` so user + ``allow`` rules override coded ``ask`` rules via + last-match-wins. 
+ """ mws = middleware_stack or {} filesystem_mode: FilesystemMode = dependencies["filesystem_mode"] + flags: AgentFeatureFlags | None = dependencies.get("flags") resilience_mws = [ m for m in ( @@ -43,6 +87,17 @@ def build_kb_middleware( ) if m is not None ] + permission_mw = None + if ruleset is not None and flags is not None: + rulesets: list[Ruleset] = [ruleset] + user_allowlist = _kb_user_allowlist(dependencies, subagent_name) + if user_allowlist is not None: + rulesets.append(user_allowlist) + permission_mw = build_permission_mw( + flags=flags, + subagent_rulesets=rulesets, + trusted_tool_saver=dependencies.get("trusted_tool_saver"), + ) return [ mws["todos"], build_kb_context_projection_mw(), @@ -56,6 +111,7 @@ def build_kb_middleware( ), build_compaction_mw(llm), build_patch_tool_calls_mw(), + *([permission_mw] if permission_mw is not None else []), *resilience_mws, build_anthropic_cache_mw(), ] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/tools/__init__.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/tools/__init__.py index 616dfc814..5a83c68a3 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/tools/__init__.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/tools/__init__.py @@ -1 +1 @@ -"""Route-local tool policy for the ``knowledge_base`` subagent.""" +"""Route-local tool permissions for the ``knowledge_base`` subagent.""" diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/tools/index.py index 555160a64..55a9a4edf 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/tools/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/knowledge_base/tools/index.py @@ -1,14 +1,9 @@ -"""Route-local FS 
tool policy. +"""Route-local FS tool permissions. The KB subagent's actual ``BaseTool`` instances are provided at runtime by -``SurfSenseFilesystemMiddleware`` (mounted in ``agent.py``). This module only -carries policy that the subagent spec needs to declare up front — which -destructive ops require explicit user confirmation via ``interrupt_on``. - -Mirrors the ``desktop_safety`` ruleset in -``multi_agent_chat.middleware.shared.permissions.context``: in desktop mode -those rules guard the main-agent FS toolset; in cloud mode the same toolset -lives on the KB subagent and the same policy is enforced here instead. +``SurfSenseFilesystemMiddleware`` (mounted in ``agent.py``). This module +only carries the *names* of destructive ops so the agent can convert them +into permission rules — see :data:`KB_RULESET` in ``agent.py``. """ from __future__ import annotations @@ -22,9 +17,4 @@ DESTRUCTIVE_FS_OPS: tuple[str, ...] = ( ) -def destructive_fs_interrupt_on() -> dict[str, bool]: - """Fresh ``interrupt_on`` dict for the KB subagent spec.""" - return {op: True for op in DESTRUCTIVE_FS_OPS} - - -__all__ = ["DESTRUCTIVE_FS_OPS", "destructive_fs_interrupt_on"] +__all__ = ["DESTRUCTIVE_FS_OPS"] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/memory/agent.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/memory/agent.py index 2cd9e70a1..84ab0c2fb 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/memory/agent.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/memory/agent.py @@ -1,27 +1,17 @@ -"""`memory` route: ``SubAgent`` spec for deepagents.""" +"""``memory`` route: ``SurfSenseSubagentSpec`` builder for deepagents.""" from __future__ import annotations from typing import Any -from deepagents import SubAgent from langchain_core.language_models import BaseChatModel +from langchain_core.tools import BaseTool -from app.agents.multi_agent_chat.subagents.shared.md_file_reader 
import ( - read_md_file, -) -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolsPermissions, - merge_tools_permissions, - middleware_gated_interrupt_on, -) -from app.agents.multi_agent_chat.subagents.shared.subagent_builder import ( - pack_subagent, -) +from app.agents.multi_agent_chat.subagents.shared.md_file_reader import read_md_file +from app.agents.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec +from app.agents.multi_agent_chat.subagents.shared.subagent_builder import pack_subagent -from .tools.index import load_tools - -NAME = "memory" +from .tools.index import NAME, RULESET, load_tools def build_subagent( @@ -29,26 +19,21 @@ def build_subagent( dependencies: dict[str, Any], model: BaseChatModel | None = None, middleware_stack: dict[str, Any] | None = None, - extra_tools_bucket: ToolsPermissions | None = None, -) -> SubAgent: - buckets = load_tools(dependencies=dependencies) - merged_tools_bucket = merge_tools_permissions(buckets, extra_tools_bucket) - tools = [ - row["tool"] - for row in (*merged_tools_bucket["allow"], *merged_tools_bucket["ask"]) - if row.get("tool") is not None - ] - interrupt_on = middleware_gated_interrupt_on(merged_tools_bucket) - description = read_md_file(__package__, "description").strip() - if not description: - description = "Handles memory tasks for this workspace." + mcp_tools: list[BaseTool] | None = None, +) -> SurfSenseSubagentSpec: + tools = [*load_tools(dependencies=dependencies), *(mcp_tools or [])] + description = ( + read_md_file(__package__, "description").strip() + or "Handles memory tasks for this workspace." 
+ ) system_prompt = read_md_file(__package__, "system_prompt").strip() return pack_subagent( name=NAME, description=description, system_prompt=system_prompt, tools=tools, - interrupt_on=interrupt_on, + ruleset=RULESET, + dependencies=dependencies, model=model, middleware_stack=middleware_stack, ) diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/memory/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/memory/tools/index.py index 6c65b2cee..b6e06dcdd 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/memory/tools/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/memory/tools/index.py @@ -1,32 +1,37 @@ +"""``memory`` native tools and (empty) permission ruleset.""" + from __future__ import annotations from typing import Any -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolsPermissions, -) +from langchain_core.tools import BaseTool + +from app.agents.new_chat.permissions import Ruleset from app.db import ChatVisibility from .update_memory import create_update_memory_tool, create_update_team_memory_tool +NAME = "memory" + +RULESET = Ruleset(origin=NAME, rules=[]) + def load_tools( *, dependencies: dict[str, Any] | None = None, **kwargs: Any -) -> ToolsPermissions: - resolved_dependencies = {**(dependencies or {}), **kwargs} - if resolved_dependencies.get("thread_visibility") == ChatVisibility.SEARCH_SPACE: - mem = create_update_team_memory_tool( - search_space_id=resolved_dependencies["search_space_id"], - db_session=resolved_dependencies["db_session"], - llm=resolved_dependencies.get("llm"), +) -> list[BaseTool]: + d = {**(dependencies or {}), **kwargs} + if d.get("thread_visibility") == ChatVisibility.SEARCH_SPACE: + return [ + create_update_team_memory_tool( + search_space_id=d["search_space_id"], + db_session=d["db_session"], + llm=d.get("llm"), + ) + ] + return [ + create_update_memory_tool( + user_id=d["user_id"], + 
db_session=d["db_session"], + llm=d.get("llm"), ) - return { - "allow": [{"name": getattr(mem, "name", "") or "", "tool": mem}], - "ask": [], - } - mem = create_update_memory_tool( - user_id=resolved_dependencies["user_id"], - db_session=resolved_dependencies["db_session"], - llm=resolved_dependencies.get("llm"), - ) - return {"allow": [{"name": getattr(mem, "name", "") or "", "tool": mem}], "ask": []} + ] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/agent.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/agent.py index d38ab2af3..37026bebd 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/agent.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/agent.py @@ -1,27 +1,17 @@ -"""`research` route: ``SubAgent`` spec for deepagents.""" +"""``research`` route: ``SurfSenseSubagentSpec`` builder for deepagents.""" from __future__ import annotations from typing import Any -from deepagents import SubAgent from langchain_core.language_models import BaseChatModel +from langchain_core.tools import BaseTool -from app.agents.multi_agent_chat.subagents.shared.md_file_reader import ( - read_md_file, -) -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolsPermissions, - merge_tools_permissions, - middleware_gated_interrupt_on, -) -from app.agents.multi_agent_chat.subagents.shared.subagent_builder import ( - pack_subagent, -) +from app.agents.multi_agent_chat.subagents.shared.md_file_reader import read_md_file +from app.agents.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec +from app.agents.multi_agent_chat.subagents.shared.subagent_builder import pack_subagent -from .tools.index import load_tools - -NAME = "research" +from .tools.index import NAME, RULESET, load_tools def build_subagent( @@ -29,26 +19,21 @@ def build_subagent( dependencies: dict[str, Any], model: BaseChatModel | None = None, 
middleware_stack: dict[str, Any] | None = None, - extra_tools_bucket: ToolsPermissions | None = None, -) -> SubAgent: - buckets = load_tools(dependencies=dependencies) - merged_tools_bucket = merge_tools_permissions(buckets, extra_tools_bucket) - tools = [ - row["tool"] - for row in (*merged_tools_bucket["allow"], *merged_tools_bucket["ask"]) - if row.get("tool") is not None - ] - interrupt_on = middleware_gated_interrupt_on(merged_tools_bucket) - description = read_md_file(__package__, "description").strip() - if not description: - description = "Handles research tasks for this workspace." + mcp_tools: list[BaseTool] | None = None, +) -> SurfSenseSubagentSpec: + tools = [*load_tools(dependencies=dependencies), *(mcp_tools or [])] + description = ( + read_md_file(__package__, "description").strip() + or "Handles research tasks for this workspace." + ) system_prompt = read_md_file(__package__, "system_prompt").strip() return pack_subagent( name=NAME, description=description, system_prompt=system_prompt, tools=tools, - interrupt_on=interrupt_on, + ruleset=RULESET, + dependencies=dependencies, model=model, middleware_stack=middleware_stack, ) diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/tools/index.py index 3546d4d01..ea544a8da 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/tools/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/research/tools/index.py @@ -1,35 +1,31 @@ +"""``research`` native tools and (empty) permission ruleset.""" + from __future__ import annotations from typing import Any -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolsPermissions, -) +from langchain_core.tools import BaseTool + +from app.agents.new_chat.permissions import Ruleset from .scrape_webpage import create_scrape_webpage_tool from .search_surfsense_docs 
import create_search_surfsense_docs_tool from .web_search import create_web_search_tool +NAME = "research" + +RULESET = Ruleset(origin=NAME, rules=[]) + def load_tools( *, dependencies: dict[str, Any] | None = None, **kwargs: Any -) -> ToolsPermissions: - resolved_dependencies = {**(dependencies or {}), **kwargs} - web = create_web_search_tool( - search_space_id=resolved_dependencies.get("search_space_id"), - available_connectors=resolved_dependencies.get("available_connectors"), - ) - scrape = create_scrape_webpage_tool( - firecrawl_api_key=resolved_dependencies.get("firecrawl_api_key") - ) - docs = create_search_surfsense_docs_tool( - db_session=resolved_dependencies["db_session"] - ) - return { - "allow": [ - {"name": getattr(web, "name", "") or "", "tool": web}, - {"name": getattr(scrape, "name", "") or "", "tool": scrape}, - {"name": getattr(docs, "name", "") or "", "tool": docs}, - ], - "ask": [], - } +) -> list[BaseTool]: + d = {**(dependencies or {}), **kwargs} + return [ + create_web_search_tool( + search_space_id=d.get("search_space_id"), + available_connectors=d.get("available_connectors"), + ), + create_scrape_webpage_tool(firecrawl_api_key=d.get("firecrawl_api_key")), + create_search_surfsense_docs_tool(db_session=d["db_session"]), + ] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/airtable/agent.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/airtable/agent.py index c186684ab..d7648d407 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/airtable/agent.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/airtable/agent.py @@ -1,27 +1,22 @@ -"""`airtable` route: ``SubAgent`` spec for deepagents.""" +"""``airtable`` route: ``SurfSenseSubagentSpec`` builder for deepagents. + +Tools come exclusively from MCP. 
The connector's own approval ruleset is +declared in :data:`tools.index.RULESET`; the orchestrator layers it into +a per-subagent :class:`PermissionMiddleware`. +""" from __future__ import annotations from typing import Any -from deepagents import SubAgent from langchain_core.language_models import BaseChatModel +from langchain_core.tools import BaseTool -from app.agents.multi_agent_chat.subagents.shared.md_file_reader import ( - read_md_file, -) -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolsPermissions, - merge_tools_permissions, - middleware_gated_interrupt_on, -) -from app.agents.multi_agent_chat.subagents.shared.subagent_builder import ( - pack_subagent, -) +from app.agents.multi_agent_chat.subagents.shared.md_file_reader import read_md_file +from app.agents.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec +from app.agents.multi_agent_chat.subagents.shared.subagent_builder import pack_subagent -from .tools.index import load_tools - -NAME = "airtable" +from .tools.index import NAME, RULESET def build_subagent( @@ -29,26 +24,20 @@ def build_subagent( dependencies: dict[str, Any], model: BaseChatModel | None = None, middleware_stack: dict[str, Any] | None = None, - extra_tools_bucket: ToolsPermissions | None = None, -) -> SubAgent: - buckets = load_tools(dependencies=dependencies) - merged_tools_bucket = merge_tools_permissions(buckets, extra_tools_bucket) - tools = [ - row["tool"] - for row in (*merged_tools_bucket["allow"], *merged_tools_bucket["ask"]) - if row.get("tool") is not None - ] - interrupt_on = middleware_gated_interrupt_on(merged_tools_bucket) - description = read_md_file(__package__, "description").strip() - if not description: - description = "Handles airtable tasks for this workspace." + mcp_tools: list[BaseTool] | None = None, +) -> SurfSenseSubagentSpec: + description = ( + read_md_file(__package__, "description").strip() + or "Handles airtable tasks for this workspace." 
+ ) system_prompt = read_md_file(__package__, "system_prompt").strip() return pack_subagent( name=NAME, description=description, system_prompt=system_prompt, - tools=tools, - interrupt_on=interrupt_on, + tools=list(mcp_tools or []), + ruleset=RULESET, + dependencies=dependencies, model=model, middleware_stack=middleware_stack, ) diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/airtable/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/airtable/tools/index.py index 08b0e005e..9eebd2395 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/airtable/tools/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/airtable/tools/index.py @@ -1,14 +1,21 @@ +"""``airtable`` permission ruleset (rules over MCP tool names).""" + from __future__ import annotations -from typing import Any +from app.agents.new_chat.permissions import Rule, Ruleset -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolsPermissions, +NAME = "airtable" + +RULESET = Ruleset( + origin=NAME, + rules=[ + Rule(permission="list_bases", pattern="*", action="allow"), + Rule(permission="search_bases", pattern="*", action="allow"), + Rule(permission="list_tables_for_base", pattern="*", action="allow"), + Rule(permission="get_table_schema", pattern="*", action="allow"), + Rule(permission="list_records_for_table", pattern="*", action="allow"), + Rule(permission="search_records", pattern="*", action="allow"), + Rule(permission="create_records_for_table", pattern="*", action="ask"), + Rule(permission="update_records_for_table", pattern="*", action="ask"), + ], ) - - -def load_tools( - *, dependencies: dict[str, Any] | None = None, **kwargs: Any -) -> ToolsPermissions: - _ = {**(dependencies or {}), **kwargs} - return {"allow": [], "ask": []} diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/agent.py 
b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/agent.py index 0f00c68e8..7ef706c3d 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/agent.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/agent.py @@ -1,27 +1,22 @@ -"""`calendar` route: ``SubAgent`` spec for deepagents.""" +"""``calendar`` route: ``SurfSenseSubagentSpec`` builder for deepagents. + +Tools self-gate inside their bodies via :func:`request_approval`; the +empty :data:`tools.index.RULESET` is layered into a per-subagent +:class:`PermissionMiddleware` for uniformity with MCP-backed connectors. +""" from __future__ import annotations from typing import Any -from deepagents import SubAgent from langchain_core.language_models import BaseChatModel +from langchain_core.tools import BaseTool -from app.agents.multi_agent_chat.subagents.shared.md_file_reader import ( - read_md_file, -) -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolsPermissions, - merge_tools_permissions, - middleware_gated_interrupt_on, -) -from app.agents.multi_agent_chat.subagents.shared.subagent_builder import ( - pack_subagent, -) +from app.agents.multi_agent_chat.subagents.shared.md_file_reader import read_md_file +from app.agents.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec +from app.agents.multi_agent_chat.subagents.shared.subagent_builder import pack_subagent -from .tools.index import load_tools - -NAME = "calendar" +from .tools.index import NAME, RULESET, load_tools def build_subagent( @@ -29,26 +24,21 @@ def build_subagent( dependencies: dict[str, Any], model: BaseChatModel | None = None, middleware_stack: dict[str, Any] | None = None, - extra_tools_bucket: ToolsPermissions | None = None, -) -> SubAgent: - buckets = load_tools(dependencies=dependencies) - merged_tools_bucket = merge_tools_permissions(buckets, extra_tools_bucket) - tools = [ - row["tool"] - for row in 
(*merged_tools_bucket["allow"], *merged_tools_bucket["ask"]) - if row.get("tool") is not None - ] - interrupt_on = middleware_gated_interrupt_on(merged_tools_bucket) - description = read_md_file(__package__, "description").strip() - if not description: - description = "Handles calendar tasks for this workspace." + mcp_tools: list[BaseTool] | None = None, +) -> SurfSenseSubagentSpec: + tools = [*load_tools(dependencies=dependencies), *(mcp_tools or [])] + description = ( + read_md_file(__package__, "description").strip() + or "Handles calendar tasks for this workspace." + ) system_prompt = read_md_file(__package__, "system_prompt").strip() return pack_subagent( name=NAME, description=description, system_prompt=system_prompt, tools=tools, - interrupt_on=interrupt_on, + ruleset=RULESET, + dependencies=dependencies, model=model, middleware_stack=middleware_stack, ) diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/create_event.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/create_event.py index a8183314a..e5262bd43 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/create_event.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/create_event.py @@ -8,7 +8,9 @@ from googleapiclient.discovery import build from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession -from app.agents.new_chat.tools.hitl import request_approval +from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( + request_approval, +) from app.services.google_calendar import GoogleCalendarToolMetadataService logger = logging.getLogger(__name__) diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/delete_event.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/delete_event.py index 3d160e669..2f907e746 100644 
--- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/delete_event.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/delete_event.py @@ -8,7 +8,9 @@ from googleapiclient.discovery import build from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession -from app.agents.new_chat.tools.hitl import request_approval +from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( + request_approval, +) from app.services.google_calendar import GoogleCalendarToolMetadataService logger = logging.getLogger(__name__) diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/index.py index 2538a494b..2570a51b2 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/index.py @@ -1,35 +1,39 @@ +"""``calendar`` native tools and (empty) permission ruleset. + +Tools self-gate via :func:`request_approval` in their bodies, so the +ruleset just falls through to the SurfSense allow-by-default rules. 
+""" + from __future__ import annotations from typing import Any -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolsPermissions, -) +from langchain_core.tools import BaseTool + +from app.agents.new_chat.permissions import Ruleset from .create_event import create_create_calendar_event_tool from .delete_event import create_delete_calendar_event_tool from .search_events import create_search_calendar_events_tool from .update_event import create_update_calendar_event_tool +NAME = "calendar" + +RULESET = Ruleset(origin=NAME, rules=[]) + def load_tools( *, dependencies: dict[str, Any] | None = None, **kwargs: Any -) -> ToolsPermissions: - resolved_dependencies = {**(dependencies or {}), **kwargs} - session_dependencies = { - "db_session": resolved_dependencies["db_session"], - "search_space_id": resolved_dependencies["search_space_id"], - "user_id": resolved_dependencies["user_id"], - } - search = create_search_calendar_events_tool(**session_dependencies) - create = create_create_calendar_event_tool(**session_dependencies) - update = create_update_calendar_event_tool(**session_dependencies) - delete = create_delete_calendar_event_tool(**session_dependencies) - return { - "allow": [{"name": getattr(search, "name", "") or "", "tool": search}], - "ask": [ - {"name": getattr(create, "name", "") or "", "tool": create}, - {"name": getattr(update, "name", "") or "", "tool": update}, - {"name": getattr(delete, "name", "") or "", "tool": delete}, - ], +) -> list[BaseTool]: + d = {**(dependencies or {}), **kwargs} + common = { + "db_session": d["db_session"], + "search_space_id": d["search_space_id"], + "user_id": d["user_id"], } + return [ + create_search_calendar_events_tool(**common), + create_create_calendar_event_tool(**common), + create_update_calendar_event_tool(**common), + create_delete_calendar_event_tool(**common), + ] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/update_event.py 
b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/update_event.py index a74979484..e6f9f098e 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/update_event.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/calendar/tools/update_event.py @@ -8,7 +8,9 @@ from googleapiclient.discovery import build from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession -from app.agents.new_chat.tools.hitl import request_approval +from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( + request_approval, +) from app.services.google_calendar import GoogleCalendarToolMetadataService logger = logging.getLogger(__name__) diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/clickup/agent.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/clickup/agent.py index fb34aa938..e1308a100 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/clickup/agent.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/clickup/agent.py @@ -1,27 +1,22 @@ -"""`clickup` route: ``SubAgent`` spec for deepagents.""" +"""``clickup`` route: ``SurfSenseSubagentSpec`` builder for deepagents. + +Tools come exclusively from MCP. The connector's own approval ruleset is +declared in :data:`tools.index.RULESET`; the orchestrator layers it into +a per-subagent :class:`PermissionMiddleware`. 
+""" from __future__ import annotations from typing import Any -from deepagents import SubAgent from langchain_core.language_models import BaseChatModel +from langchain_core.tools import BaseTool -from app.agents.multi_agent_chat.subagents.shared.md_file_reader import ( - read_md_file, -) -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolsPermissions, - merge_tools_permissions, - middleware_gated_interrupt_on, -) -from app.agents.multi_agent_chat.subagents.shared.subagent_builder import ( - pack_subagent, -) +from app.agents.multi_agent_chat.subagents.shared.md_file_reader import read_md_file +from app.agents.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec +from app.agents.multi_agent_chat.subagents.shared.subagent_builder import pack_subagent -from .tools.index import load_tools - -NAME = "clickup" +from .tools.index import NAME, RULESET def build_subagent( @@ -29,26 +24,20 @@ def build_subagent( dependencies: dict[str, Any], model: BaseChatModel | None = None, middleware_stack: dict[str, Any] | None = None, - extra_tools_bucket: ToolsPermissions | None = None, -) -> SubAgent: - buckets = load_tools(dependencies=dependencies) - merged_tools_bucket = merge_tools_permissions(buckets, extra_tools_bucket) - tools = [ - row["tool"] - for row in (*merged_tools_bucket["allow"], *merged_tools_bucket["ask"]) - if row.get("tool") is not None - ] - interrupt_on = middleware_gated_interrupt_on(merged_tools_bucket) - description = read_md_file(__package__, "description").strip() - if not description: - description = "Handles clickup tasks for this workspace." + mcp_tools: list[BaseTool] | None = None, +) -> SurfSenseSubagentSpec: + description = ( + read_md_file(__package__, "description").strip() + or "Handles clickup tasks for this workspace." 
+ ) system_prompt = read_md_file(__package__, "system_prompt").strip() return pack_subagent( name=NAME, description=description, system_prompt=system_prompt, - tools=tools, - interrupt_on=interrupt_on, + tools=list(mcp_tools or []), + ruleset=RULESET, + dependencies=dependencies, model=model, middleware_stack=middleware_stack, ) diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/clickup/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/clickup/tools/index.py index 08b0e005e..b2c523080 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/clickup/tools/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/clickup/tools/index.py @@ -1,14 +1,20 @@ +"""``clickup`` permission ruleset (rules over MCP tool names).""" + from __future__ import annotations -from typing import Any +from app.agents.new_chat.permissions import Rule, Ruleset -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolsPermissions, +NAME = "clickup" + +RULESET = Ruleset( + origin=NAME, + rules=[ + Rule(permission="clickup_search", pattern="*", action="allow"), + Rule(permission="clickup_get_task", pattern="*", action="allow"), + Rule(permission="clickup_get_workspace_hierarchy", pattern="*", action="allow"), + Rule(permission="clickup_get_list", pattern="*", action="allow"), + Rule(permission="clickup_find_member_by_name", pattern="*", action="allow"), + Rule(permission="clickup_create_task", pattern="*", action="ask"), + Rule(permission="clickup_update_task", pattern="*", action="ask"), + ], ) - - -def load_tools( - *, dependencies: dict[str, Any] | None = None, **kwargs: Any -) -> ToolsPermissions: - _ = {**(dependencies or {}), **kwargs} - return {"allow": [], "ask": []} diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/agent.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/agent.py 
index 044fd7dc1..5e95c876d 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/agent.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/agent.py @@ -1,27 +1,22 @@ -"""`confluence` route: ``SubAgent`` spec for deepagents.""" +"""``confluence`` route: ``SurfSenseSubagentSpec`` builder for deepagents. + +Tools self-gate inside their bodies via :func:`request_approval`; the +empty :data:`tools.index.RULESET` is layered into a per-subagent +:class:`PermissionMiddleware` for uniformity. +""" from __future__ import annotations from typing import Any -from deepagents import SubAgent from langchain_core.language_models import BaseChatModel +from langchain_core.tools import BaseTool -from app.agents.multi_agent_chat.subagents.shared.md_file_reader import ( - read_md_file, -) -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolsPermissions, - merge_tools_permissions, - middleware_gated_interrupt_on, -) -from app.agents.multi_agent_chat.subagents.shared.subagent_builder import ( - pack_subagent, -) +from app.agents.multi_agent_chat.subagents.shared.md_file_reader import read_md_file +from app.agents.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec +from app.agents.multi_agent_chat.subagents.shared.subagent_builder import pack_subagent -from .tools.index import load_tools - -NAME = "confluence" +from .tools.index import NAME, RULESET, load_tools def build_subagent( @@ -29,26 +24,21 @@ def build_subagent( dependencies: dict[str, Any], model: BaseChatModel | None = None, middleware_stack: dict[str, Any] | None = None, - extra_tools_bucket: ToolsPermissions | None = None, -) -> SubAgent: - buckets = load_tools(dependencies=dependencies) - merged_tools_bucket = merge_tools_permissions(buckets, extra_tools_bucket) - tools = [ - row["tool"] - for row in (*merged_tools_bucket["allow"], *merged_tools_bucket["ask"]) - if row.get("tool") is not None - ] - interrupt_on = 
middleware_gated_interrupt_on(merged_tools_bucket) - description = read_md_file(__package__, "description").strip() - if not description: - description = "Handles confluence tasks for this workspace." + mcp_tools: list[BaseTool] | None = None, +) -> SurfSenseSubagentSpec: + tools = [*load_tools(dependencies=dependencies), *(mcp_tools or [])] + description = ( + read_md_file(__package__, "description").strip() + or "Handles confluence tasks for this workspace." + ) system_prompt = read_md_file(__package__, "system_prompt").strip() return pack_subagent( name=NAME, description=description, system_prompt=system_prompt, tools=tools, - interrupt_on=interrupt_on, + ruleset=RULESET, + dependencies=dependencies, model=model, middleware_stack=middleware_stack, ) diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/tools/create_page.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/tools/create_page.py index 095413bdb..f33dc8e23 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/tools/create_page.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/tools/create_page.py @@ -5,7 +5,9 @@ from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm.attributes import flag_modified -from app.agents.new_chat.tools.hitl import request_approval +from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( + request_approval, +) from app.connectors.confluence_history import ConfluenceHistoryConnector from app.services.confluence import ConfluenceToolMetadataService diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/tools/delete_page.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/tools/delete_page.py index 7c03c2760..7a3a4f2c7 100644 --- 
a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/tools/delete_page.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/tools/delete_page.py @@ -5,7 +5,9 @@ from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm.attributes import flag_modified -from app.agents.new_chat.tools.hitl import request_approval +from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( + request_approval, +) from app.connectors.confluence_history import ConfluenceHistoryConnector from app.services.confluence import ConfluenceToolMetadataService diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/tools/index.py index 28c4ee6ee..b38503c5c 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/tools/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/tools/index.py @@ -1,34 +1,37 @@ +"""``confluence`` native tools and (empty) permission ruleset. + +Tools self-gate via :func:`request_approval` in their bodies. 
+""" + from __future__ import annotations from typing import Any -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolsPermissions, -) +from langchain_core.tools import BaseTool + +from app.agents.new_chat.permissions import Ruleset from .create_page import create_create_confluence_page_tool from .delete_page import create_delete_confluence_page_tool from .update_page import create_update_confluence_page_tool +NAME = "confluence" + +RULESET = Ruleset(origin=NAME, rules=[]) + def load_tools( *, dependencies: dict[str, Any] | None = None, **kwargs: Any -) -> ToolsPermissions: - resolved_dependencies = {**(dependencies or {}), **kwargs} - session_dependencies = { - "db_session": resolved_dependencies["db_session"], - "search_space_id": resolved_dependencies["search_space_id"], - "user_id": resolved_dependencies["user_id"], - "connector_id": resolved_dependencies.get("connector_id"), - } - create = create_create_confluence_page_tool(**session_dependencies) - update = create_update_confluence_page_tool(**session_dependencies) - delete = create_delete_confluence_page_tool(**session_dependencies) - return { - "allow": [], - "ask": [ - {"name": getattr(create, "name", "") or "", "tool": create}, - {"name": getattr(update, "name", "") or "", "tool": update}, - {"name": getattr(delete, "name", "") or "", "tool": delete}, - ], +) -> list[BaseTool]: + d = {**(dependencies or {}), **kwargs} + common = { + "db_session": d["db_session"], + "search_space_id": d["search_space_id"], + "user_id": d["user_id"], + "connector_id": d.get("connector_id"), } + return [ + create_create_confluence_page_tool(**common), + create_update_confluence_page_tool(**common), + create_delete_confluence_page_tool(**common), + ] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/tools/update_page.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/tools/update_page.py index 791d0d8c5..7a8207a00 100644 --- 
a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/tools/update_page.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/confluence/tools/update_page.py @@ -5,7 +5,9 @@ from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm.attributes import flag_modified -from app.agents.new_chat.tools.hitl import request_approval +from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( + request_approval, +) from app.connectors.confluence_history import ConfluenceHistoryConnector from app.services.confluence import ConfluenceToolMetadataService diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/agent.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/agent.py index d2cb3a9b1..567e72973 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/agent.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/agent.py @@ -1,27 +1,22 @@ -"""`discord` route: ``SubAgent`` spec for deepagents.""" +"""``discord`` route: ``SurfSenseSubagentSpec`` builder for deepagents. + +Tools self-gate inside their bodies via :func:`request_approval`; the +empty :data:`tools.index.RULESET` is layered into a per-subagent +:class:`PermissionMiddleware` for uniformity. 
+""" from __future__ import annotations from typing import Any -from deepagents import SubAgent from langchain_core.language_models import BaseChatModel +from langchain_core.tools import BaseTool -from app.agents.multi_agent_chat.subagents.shared.md_file_reader import ( - read_md_file, -) -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolsPermissions, - merge_tools_permissions, - middleware_gated_interrupt_on, -) -from app.agents.multi_agent_chat.subagents.shared.subagent_builder import ( - pack_subagent, -) +from app.agents.multi_agent_chat.subagents.shared.md_file_reader import read_md_file +from app.agents.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec +from app.agents.multi_agent_chat.subagents.shared.subagent_builder import pack_subagent -from .tools.index import load_tools - -NAME = "discord" +from .tools.index import NAME, RULESET, load_tools def build_subagent( @@ -29,26 +24,21 @@ def build_subagent( dependencies: dict[str, Any], model: BaseChatModel | None = None, middleware_stack: dict[str, Any] | None = None, - extra_tools_bucket: ToolsPermissions | None = None, -) -> SubAgent: - buckets = load_tools(dependencies=dependencies) - merged_tools_bucket = merge_tools_permissions(buckets, extra_tools_bucket) - tools = [ - row["tool"] - for row in (*merged_tools_bucket["allow"], *merged_tools_bucket["ask"]) - if row.get("tool") is not None - ] - interrupt_on = middleware_gated_interrupt_on(merged_tools_bucket) - description = read_md_file(__package__, "description").strip() - if not description: - description = "Handles discord tasks for this workspace." + mcp_tools: list[BaseTool] | None = None, +) -> SurfSenseSubagentSpec: + tools = [*load_tools(dependencies=dependencies), *(mcp_tools or [])] + description = ( + read_md_file(__package__, "description").strip() + or "Handles discord tasks for this workspace." 
+ ) system_prompt = read_md_file(__package__, "system_prompt").strip() return pack_subagent( name=NAME, description=description, system_prompt=system_prompt, tools=tools, - interrupt_on=interrupt_on, + ruleset=RULESET, + dependencies=dependencies, model=model, middleware_stack=middleware_stack, ) diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/tools/index.py index c0a3bf3c9..c69ef3e5c 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/tools/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/tools/index.py @@ -1,32 +1,36 @@ +"""``discord`` native tools and (empty) permission ruleset. + +Tools self-gate via :func:`request_approval` in their bodies. +""" + from __future__ import annotations from typing import Any -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolsPermissions, -) +from langchain_core.tools import BaseTool + +from app.agents.new_chat.permissions import Ruleset from .list_channels import create_list_discord_channels_tool from .read_messages import create_read_discord_messages_tool from .send_message import create_send_discord_message_tool +NAME = "discord" + +RULESET = Ruleset(origin=NAME, rules=[]) + def load_tools( *, dependencies: dict[str, Any] | None = None, **kwargs: Any -) -> ToolsPermissions: +) -> list[BaseTool]: d = {**(dependencies or {}), **kwargs} common = { "db_session": d["db_session"], "search_space_id": d["search_space_id"], "user_id": d["user_id"], } - list_ch = create_list_discord_channels_tool(**common) - read_msg = create_read_discord_messages_tool(**common) - send = create_send_discord_message_tool(**common) - return { - "allow": [ - {"name": getattr(list_ch, "name", "") or "", "tool": list_ch}, - {"name": getattr(read_msg, "name", "") or "", "tool": read_msg}, - ], - "ask": [{"name": getattr(send, 
"name", "") or "", "tool": send}], - } + return [ + create_list_discord_channels_tool(**common), + create_read_discord_messages_tool(**common), + create_send_discord_message_tool(**common), + ] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/tools/send_message.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/tools/send_message.py index 236cd017a..95890ed10 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/tools/send_message.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/discord/tools/send_message.py @@ -5,7 +5,9 @@ import httpx from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession -from app.agents.new_chat.tools.hitl import request_approval +from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( + request_approval, +) from ._auth import DISCORD_API, get_bot_token, get_discord_connector diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/agent.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/agent.py index b9743c9d6..d3ae6dc83 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/agent.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/agent.py @@ -1,27 +1,22 @@ -"""`dropbox` route: ``SubAgent`` spec for deepagents.""" +"""``dropbox`` route: ``SurfSenseSubagentSpec`` builder for deepagents. + +Tools self-gate inside their bodies via :func:`request_approval`; the +empty :data:`tools.index.RULESET` is layered into a per-subagent +:class:`PermissionMiddleware` for uniformity. 
+""" from __future__ import annotations from typing import Any -from deepagents import SubAgent from langchain_core.language_models import BaseChatModel +from langchain_core.tools import BaseTool -from app.agents.multi_agent_chat.subagents.shared.md_file_reader import ( - read_md_file, -) -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolsPermissions, - merge_tools_permissions, - middleware_gated_interrupt_on, -) -from app.agents.multi_agent_chat.subagents.shared.subagent_builder import ( - pack_subagent, -) +from app.agents.multi_agent_chat.subagents.shared.md_file_reader import read_md_file +from app.agents.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec +from app.agents.multi_agent_chat.subagents.shared.subagent_builder import pack_subagent -from .tools.index import load_tools - -NAME = "dropbox" +from .tools.index import NAME, RULESET, load_tools def build_subagent( @@ -29,26 +24,21 @@ def build_subagent( dependencies: dict[str, Any], model: BaseChatModel | None = None, middleware_stack: dict[str, Any] | None = None, - extra_tools_bucket: ToolsPermissions | None = None, -) -> SubAgent: - buckets = load_tools(dependencies=dependencies) - merged_tools_bucket = merge_tools_permissions(buckets, extra_tools_bucket) - tools = [ - row["tool"] - for row in (*merged_tools_bucket["allow"], *merged_tools_bucket["ask"]) - if row.get("tool") is not None - ] - interrupt_on = middleware_gated_interrupt_on(merged_tools_bucket) - description = read_md_file(__package__, "description").strip() - if not description: - description = "Handles dropbox tasks for this workspace." + mcp_tools: list[BaseTool] | None = None, +) -> SurfSenseSubagentSpec: + tools = [*load_tools(dependencies=dependencies), *(mcp_tools or [])] + description = ( + read_md_file(__package__, "description").strip() + or "Handles dropbox tasks for this workspace." 
+ ) system_prompt = read_md_file(__package__, "system_prompt").strip() return pack_subagent( name=NAME, description=description, system_prompt=system_prompt, tools=tools, - interrupt_on=interrupt_on, + ruleset=RULESET, + dependencies=dependencies, model=model, middleware_stack=middleware_stack, ) diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/tools/create_file.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/tools/create_file.py index 22d8a8a27..2de7c301f 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/tools/create_file.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/tools/create_file.py @@ -8,7 +8,9 @@ from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select -from app.agents.new_chat.tools.hitl import request_approval +from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( + request_approval, +) from app.connectors.dropbox.client import DropboxClient from app.db import SearchSourceConnector, SearchSourceConnectorType diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/tools/index.py index 5864ae972..68e02866a 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/tools/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/tools/index.py @@ -1,30 +1,34 @@ +"""``dropbox`` native tools and (empty) permission ruleset. + +Tools self-gate via :func:`request_approval` in their bodies. 
+""" + from __future__ import annotations from typing import Any -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolsPermissions, -) +from langchain_core.tools import BaseTool + +from app.agents.new_chat.permissions import Ruleset from .create_file import create_create_dropbox_file_tool from .trash_file import create_delete_dropbox_file_tool +NAME = "dropbox" + +RULESET = Ruleset(origin=NAME, rules=[]) + def load_tools( *, dependencies: dict[str, Any] | None = None, **kwargs: Any -) -> ToolsPermissions: +) -> list[BaseTool]: d = {**(dependencies or {}), **kwargs} common = { "db_session": d["db_session"], "search_space_id": d["search_space_id"], "user_id": d["user_id"], } - create = create_create_dropbox_file_tool(**common) - delete = create_delete_dropbox_file_tool(**common) - return { - "allow": [], - "ask": [ - {"name": getattr(create, "name", "") or "", "tool": create}, - {"name": getattr(delete, "name", "") or "", "tool": delete}, - ], - } + return [ + create_create_dropbox_file_tool(**common), + create_delete_dropbox_file_tool(**common), + ] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/tools/trash_file.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/tools/trash_file.py index 12559b57a..7cb652d5d 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/tools/trash_file.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/dropbox/tools/trash_file.py @@ -6,7 +6,9 @@ from sqlalchemy import String, and_, cast, func from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select -from app.agents.new_chat.tools.hitl import request_approval +from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( + request_approval, +) from app.connectors.dropbox.client import DropboxClient from app.db import ( Document, diff --git 
a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/agent.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/agent.py index bd4bbc929..082400eb9 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/agent.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/agent.py @@ -1,27 +1,22 @@ -"""`gmail` route: ``SubAgent`` spec for deepagents.""" +"""``gmail`` route: ``SurfSenseSubagentSpec`` builder for deepagents. + +Tools self-gate inside their bodies via :func:`request_approval`; the +empty :data:`tools.index.RULESET` is layered into a per-subagent +:class:`PermissionMiddleware` for uniformity. +""" from __future__ import annotations from typing import Any -from deepagents import SubAgent from langchain_core.language_models import BaseChatModel +from langchain_core.tools import BaseTool -from app.agents.multi_agent_chat.subagents.shared.md_file_reader import ( - read_md_file, -) -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolsPermissions, - merge_tools_permissions, - middleware_gated_interrupt_on, -) -from app.agents.multi_agent_chat.subagents.shared.subagent_builder import ( - pack_subagent, -) +from app.agents.multi_agent_chat.subagents.shared.md_file_reader import read_md_file +from app.agents.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec +from app.agents.multi_agent_chat.subagents.shared.subagent_builder import pack_subagent -from .tools.index import load_tools - -NAME = "gmail" +from .tools.index import NAME, RULESET, load_tools def build_subagent( @@ -29,26 +24,21 @@ def build_subagent( dependencies: dict[str, Any], model: BaseChatModel | None = None, middleware_stack: dict[str, Any] | None = None, - extra_tools_bucket: ToolsPermissions | None = None, -) -> SubAgent: - buckets = load_tools(dependencies=dependencies) - merged_tools_bucket = merge_tools_permissions(buckets, extra_tools_bucket) - tools = 
[ - row["tool"] - for row in (*merged_tools_bucket["allow"], *merged_tools_bucket["ask"]) - if row.get("tool") is not None - ] - interrupt_on = middleware_gated_interrupt_on(merged_tools_bucket) - description = read_md_file(__package__, "description").strip() - if not description: - description = "Handles gmail tasks for this workspace." + mcp_tools: list[BaseTool] | None = None, +) -> SurfSenseSubagentSpec: + tools = [*load_tools(dependencies=dependencies), *(mcp_tools or [])] + description = ( + read_md_file(__package__, "description").strip() + or "Handles gmail tasks for this workspace." + ) system_prompt = read_md_file(__package__, "system_prompt").strip() return pack_subagent( name=NAME, description=description, system_prompt=system_prompt, tools=tools, - interrupt_on=interrupt_on, + ruleset=RULESET, + dependencies=dependencies, model=model, middleware_stack=middleware_stack, ) diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/create_draft.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/create_draft.py index 59e471097..fb1461d7c 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/create_draft.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/create_draft.py @@ -8,7 +8,9 @@ from typing import Any from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession -from app.agents.new_chat.tools.hitl import request_approval +from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( + request_approval, +) from app.services.gmail import GmailToolMetadataService logger = logging.getLogger(__name__) diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/index.py index 09082d091..020089ebb 100644 --- 
a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/index.py @@ -1,10 +1,15 @@ +"""``gmail`` native tools and (empty) permission ruleset. + +Tools self-gate via :func:`request_approval` in their bodies. +""" + from __future__ import annotations from typing import Any -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolsPermissions, -) +from langchain_core.tools import BaseTool + +from app.agents.new_chat.permissions import Ruleset from .create_draft import create_create_gmail_draft_tool from .read_email import create_read_gmail_email_tool @@ -13,31 +18,25 @@ from .send_email import create_send_gmail_email_tool from .trash_email import create_trash_gmail_email_tool from .update_draft import create_update_gmail_draft_tool +NAME = "gmail" + +RULESET = Ruleset(origin=NAME, rules=[]) + def load_tools( *, dependencies: dict[str, Any] | None = None, **kwargs: Any -) -> ToolsPermissions: +) -> list[BaseTool]: d = {**(dependencies or {}), **kwargs} common = { "db_session": d["db_session"], "search_space_id": d["search_space_id"], "user_id": d["user_id"], } - search = create_search_gmail_tool(**common) - read = create_read_gmail_email_tool(**common) - draft = create_create_gmail_draft_tool(**common) - send = create_send_gmail_email_tool(**common) - trash = create_trash_gmail_email_tool(**common) - updraft = create_update_gmail_draft_tool(**common) - return { - "allow": [ - {"name": getattr(search, "name", "") or "", "tool": search}, - {"name": getattr(read, "name", "") or "", "tool": read}, - ], - "ask": [ - {"name": getattr(draft, "name", "") or "", "tool": draft}, - {"name": getattr(send, "name", "") or "", "tool": send}, - {"name": getattr(trash, "name", "") or "", "tool": trash}, - {"name": getattr(updraft, "name", "") or "", "tool": updraft}, - ], - } + return [ + create_search_gmail_tool(**common), + 
create_read_gmail_email_tool(**common), + create_create_gmail_draft_tool(**common), + create_send_gmail_email_tool(**common), + create_trash_gmail_email_tool(**common), + create_update_gmail_draft_tool(**common), + ] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/send_email.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/send_email.py index d5de24b62..578233b57 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/send_email.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/send_email.py @@ -8,7 +8,9 @@ from typing import Any from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession -from app.agents.new_chat.tools.hitl import request_approval +from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( + request_approval, +) from app.services.gmail import GmailToolMetadataService logger = logging.getLogger(__name__) diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/trash_email.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/trash_email.py index b78f88934..b24e9ebe4 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/trash_email.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/trash_email.py @@ -6,7 +6,9 @@ from typing import Any from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession -from app.agents.new_chat.tools.hitl import request_approval +from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( + request_approval, +) from app.services.gmail import GmailToolMetadataService logger = logging.getLogger(__name__) diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/update_draft.py 
b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/update_draft.py index b6688ac53..1ab9d30cf 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/update_draft.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/gmail/tools/update_draft.py @@ -8,7 +8,9 @@ from typing import Any from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession -from app.agents.new_chat.tools.hitl import request_approval +from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( + request_approval, +) from app.services.gmail import GmailToolMetadataService logger = logging.getLogger(__name__) diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/agent.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/agent.py index 31d270b22..fb4a24ddd 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/agent.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/agent.py @@ -1,27 +1,22 @@ -"""`google_drive` route: ``SubAgent`` spec for deepagents.""" +"""``google_drive`` route: ``SurfSenseSubagentSpec`` builder for deepagents. + +Tools self-gate inside their bodies via :func:`request_approval`; the +empty :data:`tools.index.RULESET` is layered into a per-subagent +:class:`PermissionMiddleware` for uniformity. 
+""" from __future__ import annotations from typing import Any -from deepagents import SubAgent from langchain_core.language_models import BaseChatModel +from langchain_core.tools import BaseTool -from app.agents.multi_agent_chat.subagents.shared.md_file_reader import ( - read_md_file, -) -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolsPermissions, - merge_tools_permissions, - middleware_gated_interrupt_on, -) -from app.agents.multi_agent_chat.subagents.shared.subagent_builder import ( - pack_subagent, -) +from app.agents.multi_agent_chat.subagents.shared.md_file_reader import read_md_file +from app.agents.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec +from app.agents.multi_agent_chat.subagents.shared.subagent_builder import pack_subagent -from .tools.index import load_tools - -NAME = "google_drive" +from .tools.index import NAME, RULESET, load_tools def build_subagent( @@ -29,26 +24,21 @@ def build_subagent( dependencies: dict[str, Any], model: BaseChatModel | None = None, middleware_stack: dict[str, Any] | None = None, - extra_tools_bucket: ToolsPermissions | None = None, -) -> SubAgent: - buckets = load_tools(dependencies=dependencies) - merged_tools_bucket = merge_tools_permissions(buckets, extra_tools_bucket) - tools = [ - row["tool"] - for row in (*merged_tools_bucket["allow"], *merged_tools_bucket["ask"]) - if row.get("tool") is not None - ] - interrupt_on = middleware_gated_interrupt_on(merged_tools_bucket) - description = read_md_file(__package__, "description").strip() - if not description: - description = "Handles google drive tasks for this workspace." + mcp_tools: list[BaseTool] | None = None, +) -> SurfSenseSubagentSpec: + tools = [*load_tools(dependencies=dependencies), *(mcp_tools or [])] + description = ( + read_md_file(__package__, "description").strip() + or "Handles google drive tasks for this workspace." 
+ ) system_prompt = read_md_file(__package__, "system_prompt").strip() return pack_subagent( name=NAME, description=description, system_prompt=system_prompt, tools=tools, - interrupt_on=interrupt_on, + ruleset=RULESET, + dependencies=dependencies, model=model, middleware_stack=middleware_stack, ) diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/tools/create_file.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/tools/create_file.py index 9e9a30429..70f5eea74 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/tools/create_file.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/tools/create_file.py @@ -5,7 +5,9 @@ from googleapiclient.errors import HttpError from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession -from app.agents.new_chat.tools.hitl import request_approval +from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( + request_approval, +) from app.connectors.google_drive.client import GoogleDriveClient from app.connectors.google_drive.file_types import GOOGLE_DOC, GOOGLE_SHEET from app.services.google_drive import GoogleDriveToolMetadataService diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/tools/index.py index 7dbee87a0..dd05374a1 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/tools/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/tools/index.py @@ -1,30 +1,34 @@ +"""``google_drive`` native tools and (empty) permission ruleset. + +Tools self-gate via :func:`request_approval` in their bodies. 
+""" + from __future__ import annotations from typing import Any -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolsPermissions, -) +from langchain_core.tools import BaseTool + +from app.agents.new_chat.permissions import Ruleset from .create_file import create_create_google_drive_file_tool from .trash_file import create_delete_google_drive_file_tool +NAME = "google_drive" + +RULESET = Ruleset(origin=NAME, rules=[]) + def load_tools( *, dependencies: dict[str, Any] | None = None, **kwargs: Any -) -> ToolsPermissions: +) -> list[BaseTool]: d = {**(dependencies or {}), **kwargs} common = { "db_session": d["db_session"], "search_space_id": d["search_space_id"], "user_id": d["user_id"], } - create = create_create_google_drive_file_tool(**common) - delete = create_delete_google_drive_file_tool(**common) - return { - "allow": [], - "ask": [ - {"name": getattr(create, "name", "") or "", "tool": create}, - {"name": getattr(delete, "name", "") or "", "tool": delete}, - ], - } + return [ + create_create_google_drive_file_tool(**common), + create_delete_google_drive_file_tool(**common), + ] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/tools/trash_file.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/tools/trash_file.py index f7531cf3d..7fbcd74a3 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/tools/trash_file.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/google_drive/tools/trash_file.py @@ -5,7 +5,9 @@ from googleapiclient.errors import HttpError from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession -from app.agents.new_chat.tools.hitl import request_approval +from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( + request_approval, +) from app.connectors.google_drive.client import GoogleDriveClient from app.services.google_drive 
import GoogleDriveToolMetadataService diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/agent.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/agent.py index ae6573e4b..ff71d4cf7 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/agent.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/agent.py @@ -1,27 +1,22 @@ -"""`jira` route: ``SubAgent`` spec for deepagents.""" +"""``jira`` route: ``SurfSenseSubagentSpec`` builder for deepagents. + +Tools come exclusively from MCP. The connector's own approval ruleset is +declared in :data:`tools.index.RULESET`; the orchestrator layers it into +a per-subagent :class:`PermissionMiddleware`. +""" from __future__ import annotations from typing import Any -from deepagents import SubAgent from langchain_core.language_models import BaseChatModel +from langchain_core.tools import BaseTool -from app.agents.multi_agent_chat.subagents.shared.md_file_reader import ( - read_md_file, -) -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolsPermissions, - merge_tools_permissions, - middleware_gated_interrupt_on, -) -from app.agents.multi_agent_chat.subagents.shared.subagent_builder import ( - pack_subagent, -) +from app.agents.multi_agent_chat.subagents.shared.md_file_reader import read_md_file +from app.agents.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec +from app.agents.multi_agent_chat.subagents.shared.subagent_builder import pack_subagent -from .tools.index import load_tools - -NAME = "jira" +from .tools.index import NAME, RULESET def build_subagent( @@ -29,26 +24,20 @@ def build_subagent( dependencies: dict[str, Any], model: BaseChatModel | None = None, middleware_stack: dict[str, Any] | None = None, - extra_tools_bucket: ToolsPermissions | None = None, -) -> SubAgent: - buckets = load_tools(dependencies=dependencies) - merged_tools_bucket = 
merge_tools_permissions(buckets, extra_tools_bucket) - tools = [ - row["tool"] - for row in (*merged_tools_bucket["allow"], *merged_tools_bucket["ask"]) - if row.get("tool") is not None - ] - interrupt_on = middleware_gated_interrupt_on(merged_tools_bucket) - description = read_md_file(__package__, "description").strip() - if not description: - description = "Handles jira tasks for this workspace." + mcp_tools: list[BaseTool] | None = None, +) -> SurfSenseSubagentSpec: + description = ( + read_md_file(__package__, "description").strip() + or "Handles jira tasks for this workspace." + ) system_prompt = read_md_file(__package__, "system_prompt").strip() return pack_subagent( name=NAME, description=description, system_prompt=system_prompt, - tools=tools, - interrupt_on=interrupt_on, + tools=list(mcp_tools or []), + ruleset=RULESET, + dependencies=dependencies, model=model, middleware_stack=middleware_stack, ) diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/tools/index.py index 08b0e005e..13b2a073c 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/tools/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/jira/tools/index.py @@ -1,14 +1,24 @@ +"""``jira`` permission ruleset (rules over MCP tool names).""" + from __future__ import annotations -from typing import Any +from app.agents.new_chat.permissions import Rule, Ruleset -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolsPermissions, +NAME = "jira" + +RULESET = Ruleset( + origin=NAME, + rules=[ + Rule(permission="getAccessibleAtlassianResources", pattern="*", action="allow"), + Rule(permission="getVisibleJiraProjects", pattern="*", action="allow"), + Rule(permission="searchJiraIssuesUsingJql", pattern="*", action="allow"), + Rule(permission="getJiraIssue", pattern="*", action="allow"), + 
Rule(permission="getJiraProjectIssueTypesMetadata", pattern="*", action="allow"), + Rule(permission="getJiraIssueTypeMetaWithFields", pattern="*", action="allow"), + Rule(permission="getTransitionsForJiraIssue", pattern="*", action="allow"), + Rule(permission="lookupJiraAccountId", pattern="*", action="allow"), + Rule(permission="createJiraIssue", pattern="*", action="ask"), + Rule(permission="editJiraIssue", pattern="*", action="ask"), + Rule(permission="transitionJiraIssue", pattern="*", action="ask"), + ], ) - - -def load_tools( - *, dependencies: dict[str, Any] | None = None, **kwargs: Any -) -> ToolsPermissions: - _ = {**(dependencies or {}), **kwargs} - return {"allow": [], "ask": []} diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/agent.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/agent.py index f93d15b3c..d9b282f2b 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/agent.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/agent.py @@ -1,27 +1,22 @@ -"""`linear` route: ``SubAgent`` spec for deepagents.""" +"""``linear`` route: ``SurfSenseSubagentSpec`` builder for deepagents. + +Tools come exclusively from MCP. The connector's own approval ruleset is +declared in :data:`tools.index.RULESET`; the orchestrator layers it into +a per-subagent :class:`PermissionMiddleware`. 
+""" from __future__ import annotations from typing import Any -from deepagents import SubAgent from langchain_core.language_models import BaseChatModel +from langchain_core.tools import BaseTool -from app.agents.multi_agent_chat.subagents.shared.md_file_reader import ( - read_md_file, -) -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolsPermissions, - merge_tools_permissions, - middleware_gated_interrupt_on, -) -from app.agents.multi_agent_chat.subagents.shared.subagent_builder import ( - pack_subagent, -) +from app.agents.multi_agent_chat.subagents.shared.md_file_reader import read_md_file +from app.agents.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec +from app.agents.multi_agent_chat.subagents.shared.subagent_builder import pack_subagent -from .tools.index import load_tools - -NAME = "linear" +from .tools.index import NAME, RULESET def build_subagent( @@ -29,26 +24,20 @@ def build_subagent( dependencies: dict[str, Any], model: BaseChatModel | None = None, middleware_stack: dict[str, Any] | None = None, - extra_tools_bucket: ToolsPermissions | None = None, -) -> SubAgent: - buckets = load_tools(dependencies=dependencies) - merged_tools_bucket = merge_tools_permissions(buckets, extra_tools_bucket) - tools = [ - row["tool"] - for row in (*merged_tools_bucket["allow"], *merged_tools_bucket["ask"]) - if row.get("tool") is not None - ] - interrupt_on = middleware_gated_interrupt_on(merged_tools_bucket) - description = read_md_file(__package__, "description").strip() - if not description: - description = "Handles linear tasks for this workspace." + mcp_tools: list[BaseTool] | None = None, +) -> SurfSenseSubagentSpec: + description = ( + read_md_file(__package__, "description").strip() + or "Handles linear tasks for this workspace." 
+ ) system_prompt = read_md_file(__package__, "system_prompt").strip() return pack_subagent( name=NAME, description=description, system_prompt=system_prompt, - tools=tools, - interrupt_on=interrupt_on, + tools=list(mcp_tools or []), + ruleset=RULESET, + dependencies=dependencies, model=model, middleware_stack=middleware_stack, ) diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/tools/index.py index 08b0e005e..4a71a31b8 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/tools/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/linear/tools/index.py @@ -1,14 +1,31 @@ +"""``linear`` permission ruleset (rules over MCP tool names).""" + from __future__ import annotations -from typing import Any +from app.agents.new_chat.permissions import Rule, Ruleset -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolsPermissions, +NAME = "linear" + +RULESET = Ruleset( + origin=NAME, + rules=[ + Rule(permission="list_issues", pattern="*", action="allow"), + Rule(permission="get_issue", pattern="*", action="allow"), + Rule(permission="list_my_issues", pattern="*", action="allow"), + Rule(permission="list_issue_statuses", pattern="*", action="allow"), + Rule(permission="list_issue_labels", pattern="*", action="allow"), + Rule(permission="list_comments", pattern="*", action="allow"), + Rule(permission="list_users", pattern="*", action="allow"), + Rule(permission="get_user", pattern="*", action="allow"), + Rule(permission="list_teams", pattern="*", action="allow"), + Rule(permission="get_team", pattern="*", action="allow"), + Rule(permission="list_projects", pattern="*", action="allow"), + Rule(permission="get_project", pattern="*", action="allow"), + Rule(permission="list_project_labels", pattern="*", action="allow"), + Rule(permission="list_cycles", pattern="*", 
action="allow"), + Rule(permission="list_documents", pattern="*", action="allow"), + Rule(permission="get_document", pattern="*", action="allow"), + Rule(permission="search_documentation", pattern="*", action="allow"), + Rule(permission="save_issue", pattern="*", action="ask"), + ], ) - - -def load_tools( - *, dependencies: dict[str, Any] | None = None, **kwargs: Any -) -> ToolsPermissions: - _ = {**(dependencies or {}), **kwargs} - return {"allow": [], "ask": []} diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/agent.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/agent.py index afd5787ef..d84efaed8 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/agent.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/agent.py @@ -1,27 +1,22 @@ -"""`luma` route: ``SubAgent`` spec for deepagents.""" +"""``luma`` route: ``SurfSenseSubagentSpec`` builder for deepagents. + +Tools self-gate inside their bodies via :func:`request_approval`; the +empty :data:`tools.index.RULESET` is layered into a per-subagent +:class:`PermissionMiddleware` for uniformity. 
+""" from __future__ import annotations from typing import Any -from deepagents import SubAgent from langchain_core.language_models import BaseChatModel +from langchain_core.tools import BaseTool -from app.agents.multi_agent_chat.subagents.shared.md_file_reader import ( - read_md_file, -) -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolsPermissions, - merge_tools_permissions, - middleware_gated_interrupt_on, -) -from app.agents.multi_agent_chat.subagents.shared.subagent_builder import ( - pack_subagent, -) +from app.agents.multi_agent_chat.subagents.shared.md_file_reader import read_md_file +from app.agents.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec +from app.agents.multi_agent_chat.subagents.shared.subagent_builder import pack_subagent -from .tools.index import load_tools - -NAME = "luma" +from .tools.index import NAME, RULESET, load_tools def build_subagent( @@ -29,26 +24,21 @@ def build_subagent( dependencies: dict[str, Any], model: BaseChatModel | None = None, middleware_stack: dict[str, Any] | None = None, - extra_tools_bucket: ToolsPermissions | None = None, -) -> SubAgent: - buckets = load_tools(dependencies=dependencies) - merged_tools_bucket = merge_tools_permissions(buckets, extra_tools_bucket) - tools = [ - row["tool"] - for row in (*merged_tools_bucket["allow"], *merged_tools_bucket["ask"]) - if row.get("tool") is not None - ] - interrupt_on = middleware_gated_interrupt_on(merged_tools_bucket) - description = read_md_file(__package__, "description").strip() - if not description: - description = "Handles luma tasks for this workspace." + mcp_tools: list[BaseTool] | None = None, +) -> SurfSenseSubagentSpec: + tools = [*load_tools(dependencies=dependencies), *(mcp_tools or [])] + description = ( + read_md_file(__package__, "description").strip() + or "Handles luma tasks for this workspace." 
+ ) system_prompt = read_md_file(__package__, "system_prompt").strip() return pack_subagent( name=NAME, description=description, system_prompt=system_prompt, tools=tools, - interrupt_on=interrupt_on, + ruleset=RULESET, + dependencies=dependencies, model=model, middleware_stack=middleware_stack, ) diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/tools/create_event.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/tools/create_event.py index 0a24a988f..e3e1126fd 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/tools/create_event.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/tools/create_event.py @@ -5,7 +5,9 @@ import httpx from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession -from app.agents.new_chat.tools.hitl import request_approval +from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( + request_approval, +) from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/tools/index.py index 47b303295..dbde01061 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/tools/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/luma/tools/index.py @@ -1,32 +1,36 @@ +"""``luma`` native tools and (empty) permission ruleset. + +Tools self-gate via :func:`request_approval` in their bodies. 
+""" + from __future__ import annotations from typing import Any -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolsPermissions, -) +from langchain_core.tools import BaseTool + +from app.agents.new_chat.permissions import Ruleset from .create_event import create_create_luma_event_tool from .list_events import create_list_luma_events_tool from .read_event import create_read_luma_event_tool +NAME = "luma" + +RULESET = Ruleset(origin=NAME, rules=[]) + def load_tools( *, dependencies: dict[str, Any] | None = None, **kwargs: Any -) -> ToolsPermissions: +) -> list[BaseTool]: d = {**(dependencies or {}), **kwargs} common = { "db_session": d["db_session"], "search_space_id": d["search_space_id"], "user_id": d["user_id"], } - list_ev = create_list_luma_events_tool(**common) - read_ev = create_read_luma_event_tool(**common) - create = create_create_luma_event_tool(**common) - return { - "allow": [ - {"name": getattr(list_ev, "name", "") or "", "tool": list_ev}, - {"name": getattr(read_ev, "name", "") or "", "tool": read_ev}, - ], - "ask": [{"name": getattr(create, "name", "") or "", "tool": create}], - } + return [ + create_list_luma_events_tool(**common), + create_read_luma_event_tool(**common), + create_create_luma_event_tool(**common), + ] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/agent.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/agent.py index 7910eb450..8de86b2d8 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/agent.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/agent.py @@ -1,27 +1,22 @@ -"""`notion` route: ``SubAgent`` spec for deepagents.""" +"""``notion`` route: ``SurfSenseSubagentSpec`` builder for deepagents. 
+ +Tools self-gate inside their bodies via :func:`request_approval`; the +empty :data:`tools.index.RULESET` is layered into a per-subagent +:class:`PermissionMiddleware` for uniformity. +""" from __future__ import annotations from typing import Any -from deepagents import SubAgent from langchain_core.language_models import BaseChatModel +from langchain_core.tools import BaseTool -from app.agents.multi_agent_chat.subagents.shared.md_file_reader import ( - read_md_file, -) -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolsPermissions, - merge_tools_permissions, - middleware_gated_interrupt_on, -) -from app.agents.multi_agent_chat.subagents.shared.subagent_builder import ( - pack_subagent, -) +from app.agents.multi_agent_chat.subagents.shared.md_file_reader import read_md_file +from app.agents.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec +from app.agents.multi_agent_chat.subagents.shared.subagent_builder import pack_subagent -from .tools.index import load_tools - -NAME = "notion" +from .tools.index import NAME, RULESET, load_tools def build_subagent( @@ -29,26 +24,21 @@ def build_subagent( dependencies: dict[str, Any], model: BaseChatModel | None = None, middleware_stack: dict[str, Any] | None = None, - extra_tools_bucket: ToolsPermissions | None = None, -) -> SubAgent: - buckets = load_tools(dependencies=dependencies) - merged_tools_bucket = merge_tools_permissions(buckets, extra_tools_bucket) - tools = [ - row["tool"] - for row in (*merged_tools_bucket["allow"], *merged_tools_bucket["ask"]) - if row.get("tool") is not None - ] - interrupt_on = middleware_gated_interrupt_on(merged_tools_bucket) - description = read_md_file(__package__, "description").strip() - if not description: - description = "Handles notion tasks for this workspace." 
+ mcp_tools: list[BaseTool] | None = None, +) -> SurfSenseSubagentSpec: + tools = [*load_tools(dependencies=dependencies), *(mcp_tools or [])] + description = ( + read_md_file(__package__, "description").strip() + or "Handles notion tasks for this workspace." + ) system_prompt = read_md_file(__package__, "system_prompt").strip() return pack_subagent( name=NAME, description=description, system_prompt=system_prompt, tools=tools, - interrupt_on=interrupt_on, + ruleset=RULESET, + dependencies=dependencies, model=model, middleware_stack=middleware_stack, ) diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/tools/create_page.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/tools/create_page.py index 6efffe960..20862eb56 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/tools/create_page.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/tools/create_page.py @@ -4,7 +4,9 @@ from typing import Any from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession -from app.agents.new_chat.tools.hitl import request_approval +from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( + request_approval, +) from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector from app.services.notion import NotionToolMetadataService diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/tools/delete_page.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/tools/delete_page.py index 07f7583d2..85d0ef22e 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/tools/delete_page.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/tools/delete_page.py @@ -4,7 +4,9 @@ from typing import Any from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession 
-from app.agents.new_chat.tools.hitl import request_approval +from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( + request_approval, +) from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector from app.services.notion.tool_metadata_service import NotionToolMetadataService diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/tools/index.py index c78f630a1..0475e9dd0 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/tools/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/tools/index.py @@ -1,33 +1,36 @@ +"""``notion`` native tools and (empty) permission ruleset. + +Tools self-gate via :func:`request_approval` in their bodies. +""" + from __future__ import annotations from typing import Any -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolsPermissions, -) +from langchain_core.tools import BaseTool + +from app.agents.new_chat.permissions import Ruleset from .create_page import create_create_notion_page_tool from .delete_page import create_delete_notion_page_tool from .update_page import create_update_notion_page_tool +NAME = "notion" + +RULESET = Ruleset(origin=NAME, rules=[]) + def load_tools( *, dependencies: dict[str, Any] | None = None, **kwargs: Any -) -> ToolsPermissions: +) -> list[BaseTool]: d = {**(dependencies or {}), **kwargs} common = { "db_session": d["db_session"], "search_space_id": d["search_space_id"], "user_id": d["user_id"], } - create = create_create_notion_page_tool(**common) - update = create_update_notion_page_tool(**common) - delete = create_delete_notion_page_tool(**common) - return { - "allow": [], - "ask": [ - {"name": getattr(create, "name", "") or "", "tool": create}, - {"name": getattr(update, "name", "") or "", "tool": update}, - {"name": getattr(delete, 
"name", "") or "", "tool": delete}, - ], - } + return [ + create_create_notion_page_tool(**common), + create_update_notion_page_tool(**common), + create_delete_notion_page_tool(**common), + ] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/tools/update_page.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/tools/update_page.py index 85c08177c..2b9ce3a6c 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/tools/update_page.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/notion/tools/update_page.py @@ -4,7 +4,9 @@ from typing import Any from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession -from app.agents.new_chat.tools.hitl import request_approval +from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( + request_approval, +) from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector from app.services.notion import NotionToolMetadataService diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/agent.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/agent.py index 521c45958..f7634d8ef 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/agent.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/agent.py @@ -1,27 +1,22 @@ -"""`onedrive` route: ``SubAgent`` spec for deepagents.""" +"""``onedrive`` route: ``SurfSenseSubagentSpec`` builder for deepagents. + +Tools self-gate inside their bodies via :func:`request_approval`; the +empty :data:`tools.index.RULESET` is layered into a per-subagent +:class:`PermissionMiddleware` for uniformity. 
+""" from __future__ import annotations from typing import Any -from deepagents import SubAgent from langchain_core.language_models import BaseChatModel +from langchain_core.tools import BaseTool -from app.agents.multi_agent_chat.subagents.shared.md_file_reader import ( - read_md_file, -) -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolsPermissions, - merge_tools_permissions, - middleware_gated_interrupt_on, -) -from app.agents.multi_agent_chat.subagents.shared.subagent_builder import ( - pack_subagent, -) +from app.agents.multi_agent_chat.subagents.shared.md_file_reader import read_md_file +from app.agents.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec +from app.agents.multi_agent_chat.subagents.shared.subagent_builder import pack_subagent -from .tools.index import load_tools - -NAME = "onedrive" +from .tools.index import NAME, RULESET, load_tools def build_subagent( @@ -29,26 +24,21 @@ def build_subagent( dependencies: dict[str, Any], model: BaseChatModel | None = None, middleware_stack: dict[str, Any] | None = None, - extra_tools_bucket: ToolsPermissions | None = None, -) -> SubAgent: - buckets = load_tools(dependencies=dependencies) - merged_tools_bucket = merge_tools_permissions(buckets, extra_tools_bucket) - tools = [ - row["tool"] - for row in (*merged_tools_bucket["allow"], *merged_tools_bucket["ask"]) - if row.get("tool") is not None - ] - interrupt_on = middleware_gated_interrupt_on(merged_tools_bucket) - description = read_md_file(__package__, "description").strip() - if not description: - description = "Handles onedrive tasks for this workspace." + mcp_tools: list[BaseTool] | None = None, +) -> SurfSenseSubagentSpec: + tools = [*load_tools(dependencies=dependencies), *(mcp_tools or [])] + description = ( + read_md_file(__package__, "description").strip() + or "Handles onedrive tasks for this workspace." 
+ ) system_prompt = read_md_file(__package__, "system_prompt").strip() return pack_subagent( name=NAME, description=description, system_prompt=system_prompt, tools=tools, - interrupt_on=interrupt_on, + ruleset=RULESET, + dependencies=dependencies, model=model, middleware_stack=middleware_stack, ) diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/tools/create_file.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/tools/create_file.py index 21272e01d..41fa65787 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/tools/create_file.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/tools/create_file.py @@ -8,7 +8,9 @@ from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select -from app.agents.new_chat.tools.hitl import request_approval +from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( + request_approval, +) from app.connectors.onedrive.client import OneDriveClient from app.db import SearchSourceConnector, SearchSourceConnectorType diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/tools/index.py index 9a2dadd36..e09b43200 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/tools/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/tools/index.py @@ -1,30 +1,34 @@ +"""``onedrive`` native tools and (empty) permission ruleset. + +Tools self-gate via :func:`request_approval` in their bodies. 
+""" + from __future__ import annotations from typing import Any -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolsPermissions, -) +from langchain_core.tools import BaseTool + +from app.agents.new_chat.permissions import Ruleset from .create_file import create_create_onedrive_file_tool from .trash_file import create_delete_onedrive_file_tool +NAME = "onedrive" + +RULESET = Ruleset(origin=NAME, rules=[]) + def load_tools( *, dependencies: dict[str, Any] | None = None, **kwargs: Any -) -> ToolsPermissions: +) -> list[BaseTool]: d = {**(dependencies or {}), **kwargs} common = { "db_session": d["db_session"], "search_space_id": d["search_space_id"], "user_id": d["user_id"], } - create = create_create_onedrive_file_tool(**common) - delete = create_delete_onedrive_file_tool(**common) - return { - "allow": [], - "ask": [ - {"name": getattr(create, "name", "") or "", "tool": create}, - {"name": getattr(delete, "name", "") or "", "tool": delete}, - ], - } + return [ + create_create_onedrive_file_tool(**common), + create_delete_onedrive_file_tool(**common), + ] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/tools/trash_file.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/tools/trash_file.py index a7f13b5df..1f7c51ac5 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/tools/trash_file.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/onedrive/tools/trash_file.py @@ -6,7 +6,9 @@ from sqlalchemy import String, and_, cast, func from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select -from app.agents.new_chat.tools.hitl import request_approval +from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( + request_approval, +) from app.connectors.onedrive.client import OneDriveClient from app.db import ( Document, diff --git 
a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/slack/agent.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/slack/agent.py index 552070961..e16956b25 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/slack/agent.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/slack/agent.py @@ -1,27 +1,22 @@ -"""`slack` route: ``SubAgent`` spec for deepagents.""" +"""``slack`` route: ``SurfSenseSubagentSpec`` builder for deepagents. + +Tools come exclusively from MCP. The connector's own approval ruleset is +declared in :data:`tools.index.RULESET`; the orchestrator layers it into +a per-subagent :class:`PermissionMiddleware`. +""" from __future__ import annotations from typing import Any -from deepagents import SubAgent from langchain_core.language_models import BaseChatModel +from langchain_core.tools import BaseTool -from app.agents.multi_agent_chat.subagents.shared.md_file_reader import ( - read_md_file, -) -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolsPermissions, - merge_tools_permissions, - middleware_gated_interrupt_on, -) -from app.agents.multi_agent_chat.subagents.shared.subagent_builder import ( - pack_subagent, -) +from app.agents.multi_agent_chat.subagents.shared.md_file_reader import read_md_file +from app.agents.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec +from app.agents.multi_agent_chat.subagents.shared.subagent_builder import pack_subagent -from .tools.index import load_tools - -NAME = "slack" +from .tools.index import NAME, RULESET def build_subagent( @@ -29,26 +24,20 @@ def build_subagent( dependencies: dict[str, Any], model: BaseChatModel | None = None, middleware_stack: dict[str, Any] | None = None, - extra_tools_bucket: ToolsPermissions | None = None, -) -> SubAgent: - buckets = load_tools(dependencies=dependencies) - merged_tools_bucket = merge_tools_permissions(buckets, extra_tools_bucket) - tools = [ 
- row["tool"] - for row in (*merged_tools_bucket["allow"], *merged_tools_bucket["ask"]) - if row.get("tool") is not None - ] - interrupt_on = middleware_gated_interrupt_on(merged_tools_bucket) - description = read_md_file(__package__, "description").strip() - if not description: - description = "Handles slack tasks for this workspace." + mcp_tools: list[BaseTool] | None = None, +) -> SurfSenseSubagentSpec: + description = ( + read_md_file(__package__, "description").strip() + or "Handles slack tasks for this workspace." + ) system_prompt = read_md_file(__package__, "system_prompt").strip() return pack_subagent( name=NAME, description=description, system_prompt=system_prompt, - tools=tools, - interrupt_on=interrupt_on, + tools=list(mcp_tools or []), + ruleset=RULESET, + dependencies=dependencies, model=model, middleware_stack=middleware_stack, ) diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/slack/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/slack/tools/index.py index 08b0e005e..44b96661c 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/slack/tools/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/slack/tools/index.py @@ -1,14 +1,19 @@ +"""``slack`` permission ruleset (rules over MCP tool names).""" + from __future__ import annotations -from typing import Any +from app.agents.new_chat.permissions import Rule, Ruleset -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolsPermissions, +NAME = "slack" + +RULESET = Ruleset( + origin=NAME, + rules=[ + Rule(permission="slack_search_channels", pattern="*", action="allow"), + Rule(permission="slack_search_messages", pattern="*", action="allow"), + Rule(permission="slack_search_users", pattern="*", action="allow"), + Rule(permission="slack_read_channel", pattern="*", action="allow"), + Rule(permission="slack_read_thread", pattern="*", action="allow"), + 
Rule(permission="slack_send_message", pattern="*", action="ask"), + ], ) - - -def load_tools( - *, dependencies: dict[str, Any] | None = None, **kwargs: Any -) -> ToolsPermissions: - _ = {**(dependencies or {}), **kwargs} - return {"allow": [], "ask": []} diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/agent.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/agent.py index 0f7f7e2bc..ab808b745 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/agent.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/agent.py @@ -1,27 +1,22 @@ -"""`teams` route: ``SubAgent`` spec for deepagents.""" +"""``teams`` route: ``SurfSenseSubagentSpec`` builder for deepagents. + +Tools self-gate inside their bodies via :func:`request_approval`; the +empty :data:`tools.index.RULESET` is layered into a per-subagent +:class:`PermissionMiddleware` for uniformity. +""" from __future__ import annotations from typing import Any -from deepagents import SubAgent from langchain_core.language_models import BaseChatModel +from langchain_core.tools import BaseTool -from app.agents.multi_agent_chat.subagents.shared.md_file_reader import ( - read_md_file, -) -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolsPermissions, - merge_tools_permissions, - middleware_gated_interrupt_on, -) -from app.agents.multi_agent_chat.subagents.shared.subagent_builder import ( - pack_subagent, -) +from app.agents.multi_agent_chat.subagents.shared.md_file_reader import read_md_file +from app.agents.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec +from app.agents.multi_agent_chat.subagents.shared.subagent_builder import pack_subagent -from .tools.index import load_tools - -NAME = "teams" +from .tools.index import NAME, RULESET, load_tools def build_subagent( @@ -29,26 +24,21 @@ def build_subagent( dependencies: dict[str, Any], model: BaseChatModel | 
None = None, middleware_stack: dict[str, Any] | None = None, - extra_tools_bucket: ToolsPermissions | None = None, -) -> SubAgent: - buckets = load_tools(dependencies=dependencies) - merged_tools_bucket = merge_tools_permissions(buckets, extra_tools_bucket) - tools = [ - row["tool"] - for row in (*merged_tools_bucket["allow"], *merged_tools_bucket["ask"]) - if row.get("tool") is not None - ] - interrupt_on = middleware_gated_interrupt_on(merged_tools_bucket) - description = read_md_file(__package__, "description").strip() - if not description: - description = "Handles teams tasks for this workspace." + mcp_tools: list[BaseTool] | None = None, +) -> SurfSenseSubagentSpec: + tools = [*load_tools(dependencies=dependencies), *(mcp_tools or [])] + description = ( + read_md_file(__package__, "description").strip() + or "Handles teams tasks for this workspace." + ) system_prompt = read_md_file(__package__, "system_prompt").strip() return pack_subagent( name=NAME, description=description, system_prompt=system_prompt, tools=tools, - interrupt_on=interrupt_on, + ruleset=RULESET, + dependencies=dependencies, model=model, middleware_stack=middleware_stack, ) diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/tools/index.py index cbe76b040..41661651f 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/tools/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/tools/index.py @@ -1,32 +1,36 @@ +"""``teams`` native tools and (empty) permission ruleset. + +Tools self-gate via :func:`request_approval` in their bodies. 
+""" + from __future__ import annotations from typing import Any -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolsPermissions, -) +from langchain_core.tools import BaseTool + +from app.agents.new_chat.permissions import Ruleset from .list_channels import create_list_teams_channels_tool from .read_messages import create_read_teams_messages_tool from .send_message import create_send_teams_message_tool +NAME = "teams" + +RULESET = Ruleset(origin=NAME, rules=[]) + def load_tools( *, dependencies: dict[str, Any] | None = None, **kwargs: Any -) -> ToolsPermissions: +) -> list[BaseTool]: d = {**(dependencies or {}), **kwargs} common = { "db_session": d["db_session"], "search_space_id": d["search_space_id"], "user_id": d["user_id"], } - list_ch = create_list_teams_channels_tool(**common) - read_msg = create_read_teams_messages_tool(**common) - send = create_send_teams_message_tool(**common) - return { - "allow": [ - {"name": getattr(list_ch, "name", "") or "", "tool": list_ch}, - {"name": getattr(read_msg, "name", "") or "", "tool": read_msg}, - ], - "ask": [{"name": getattr(send, "name", "") or "", "tool": send}], - } + return [ + create_list_teams_channels_tool(**common), + create_read_teams_messages_tool(**common), + create_send_teams_message_tool(**common), + ] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/tools/send_message.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/tools/send_message.py index fd8d00870..f1469e3e1 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/tools/send_message.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/connectors/teams/tools/send_message.py @@ -5,7 +5,9 @@ import httpx from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession -from app.agents.new_chat.tools.hitl import request_approval +from 
app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( + request_approval, +) from ._auth import GRAPH_API, get_access_token, get_teams_connector diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/mcp_tools/__init__.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/mcp_tools/__init__.py index c8714cd04..5d0707610 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/mcp_tools/__init__.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/mcp_tools/__init__.py @@ -1,11 +1,7 @@ -"""Load MCP tools, partition by connector agent, apply allow/ask name rules.""" +"""Load MCP tools and partition them by connector agent.""" from __future__ import annotations -from app.agents.multi_agent_chat.subagents.mcp_tools.permissions import ( - TOOLS_PERMISSIONS_BY_AGENT, -) - from .index import ( fetch_mcp_connector_metadata_maps, load_mcp_tools_by_connector, @@ -13,7 +9,6 @@ from .index import ( ) __all__ = [ - "TOOLS_PERMISSIONS_BY_AGENT", "fetch_mcp_connector_metadata_maps", "load_mcp_tools_by_connector", "partition_mcp_tools_by_connector", diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/mcp_tools/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/mcp_tools/index.py index 79ab3db10..16dc09ac5 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/mcp_tools/index.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/mcp_tools/index.py @@ -1,4 +1,10 @@ -"""Discover MCP tools, bucket by connector agent, apply allow/ask from policy.""" +"""Discover MCP tools and bucket them by connector agent. + +Tool gating is no longer the loader's concern: each subagent declares its +own :class:`Ruleset` and the per-subagent :class:`PermissionMiddleware` +enforces it at runtime. This module just routes flat ``BaseTool`` lists +to the right subagents. 
+""" from __future__ import annotations @@ -15,23 +21,12 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.agents.multi_agent_chat.constants import ( CONNECTOR_TYPE_TO_CONNECTOR_AGENT_MAPS, ) -from app.agents.multi_agent_chat.subagents.mcp_tools.permissions import ( - TOOLS_PERMISSIONS_BY_AGENT, -) -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolPermissionItem, - ToolsPermissions, - mcp_tool_permission_row, -) from app.agents.new_chat.tools.mcp_tool import load_mcp_tools from app.db import SearchSourceConnector logger = logging.getLogger(__name__) -## Helper functions for fetching connector metadata maps - - async def fetch_mcp_connector_metadata_maps( session: AsyncSession, search_space_id: int, @@ -57,9 +52,6 @@ async def fetch_mcp_connector_metadata_maps( return id_to_type, name_to_type -## Helper functions for partitioning tools by connector agent - - def partition_mcp_tools_by_connector( tools: Sequence[BaseTool], connector_id_to_type: dict[int, str], @@ -107,59 +99,15 @@ def partition_mcp_tools_by_connector( return dict(buckets) -## Helper functions for splitting tools by permissions - - -def _get_mcp_tool_name(tool: BaseTool) -> str: - meta: dict[str, Any] = getattr(tool, "metadata", None) or {} - orig = meta.get("mcp_original_tool_name") - if isinstance(orig, str) and orig: - return orig - return getattr(tool, "name", "") or "" - - -def _split_tools_by_permissions( - tools: Sequence[BaseTool], - perms: ToolsPermissions, -) -> ToolsPermissions: - allow_names = frozenset(r["name"] for r in perms["allow"]) - ask_names = frozenset(r["name"] for r in perms["ask"]) - allow: list[ToolPermissionItem] = [] - ask: list[ToolPermissionItem] = [] - for t in tools: - meta: dict[str, Any] = getattr(t, "metadata", None) or {} - if meta.get("hitl") is False: - allow.append(mcp_tool_permission_row(t)) - continue - key = _get_mcp_tool_name(t) - if key in allow_names: - allow.append(mcp_tool_permission_row(t)) - elif key in ask_names: - 
ask.append(mcp_tool_permission_row(t)) - else: - ask.append(mcp_tool_permission_row(t)) - return {"allow": allow, "ask": ask} - - -## Main function to load MCP tools and split them by permissions for each connector agent - - async def load_mcp_tools_by_connector( session: AsyncSession, search_space_id: int, -) -> dict[str, ToolsPermissions]: - """Load MCP tools and split rows using ``TOOLS_PERMISSIONS_BY_AGENT`` name sets. +) -> dict[str, list[BaseTool]]: + """Load MCP tools and route them to each subagent as a flat list. - Pass ``bypass_internal_hitl=True`` so the subagent's - ``HumanInTheLoopMiddleware`` is the single HITL gate. + ``bypass_internal_hitl=True`` is set so tool gating is uniformly the + consuming subagent's :class:`PermissionMiddleware` responsibility. """ flat = await load_mcp_tools(session, search_space_id, bypass_internal_hitl=True) id_map, name_map = await fetch_mcp_connector_metadata_maps(session, search_space_id) - buckets = partition_mcp_tools_by_connector(flat, id_map, name_map) - return { - agent: _split_tools_by_permissions( - tools, - TOOLS_PERMISSIONS_BY_AGENT.get(agent, {"allow": [], "ask": []}), - ) - for agent, tools in buckets.items() - } + return partition_mcp_tools_by_connector(flat, id_map, name_map) diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/mcp_tools/permissions/__init__.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/mcp_tools/permissions/__init__.py deleted file mode 100644 index f24dedcf2..000000000 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/mcp_tools/permissions/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -"""Bundled MCP allow/ask name rows per connector agent (MCP-backed routes only).""" - -from __future__ import annotations - -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolsPermissions, -) - -from .airtable import TOOLS_PERMISSIONS as _AIRTABLE -from .clickup import TOOLS_PERMISSIONS as _CLICKUP -from .jira import TOOLS_PERMISSIONS as 
_JIRA -from .linear import TOOLS_PERMISSIONS as _LINEAR -from .slack import TOOLS_PERMISSIONS as _SLACK - -TOOLS_PERMISSIONS_BY_AGENT: dict[str, ToolsPermissions] = { - "airtable": _AIRTABLE, - "clickup": _CLICKUP, - "jira": _JIRA, - "linear": _LINEAR, - "slack": _SLACK, -} - -__all__ = ["TOOLS_PERMISSIONS_BY_AGENT"] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/mcp_tools/permissions/airtable.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/mcp_tools/permissions/airtable.py deleted file mode 100644 index 35028f1bc..000000000 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/mcp_tools/permissions/airtable.py +++ /dev/null @@ -1,22 +0,0 @@ -"""Airtable MCP: which server tool names are allow vs ask.""" - -from __future__ import annotations - -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolsPermissions, -) - -TOOLS_PERMISSIONS: ToolsPermissions = { - "allow": [ - {"name": "list_bases"}, - {"name": "search_bases"}, - {"name": "list_tables_for_base"}, - {"name": "get_table_schema"}, - {"name": "list_records_for_table"}, - {"name": "search_records"}, - ], - "ask": [ - {"name": "create_records_for_table"}, - {"name": "update_records_for_table"}, - ], -} diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/mcp_tools/permissions/clickup.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/mcp_tools/permissions/clickup.py deleted file mode 100644 index fb9e26661..000000000 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/mcp_tools/permissions/clickup.py +++ /dev/null @@ -1,21 +0,0 @@ -"""ClickUp MCP: which server tool names are allow vs ask.""" - -from __future__ import annotations - -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolsPermissions, -) - -TOOLS_PERMISSIONS: ToolsPermissions = { - "allow": [ - {"name": "clickup_search"}, - {"name": "clickup_get_task"}, - {"name": "clickup_get_workspace_hierarchy"}, - {"name": 
"clickup_get_list"}, - {"name": "clickup_find_member_by_name"}, - ], - "ask": [ - {"name": "clickup_create_task"}, - {"name": "clickup_update_task"}, - ], -} diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/mcp_tools/permissions/index.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/mcp_tools/permissions/index.py deleted file mode 100644 index 10781c9d9..000000000 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/mcp_tools/permissions/index.py +++ /dev/null @@ -1,10 +0,0 @@ -"""Re-exports permission row types for MCP policy modules.""" - -from __future__ import annotations - -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolPermissionItem, - ToolsPermissions, -) - -__all__ = ["ToolPermissionItem", "ToolsPermissions"] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/mcp_tools/permissions/jira.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/mcp_tools/permissions/jira.py deleted file mode 100644 index 5cbd72888..000000000 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/mcp_tools/permissions/jira.py +++ /dev/null @@ -1,25 +0,0 @@ -"""Jira MCP: which server tool names are allow vs ask.""" - -from __future__ import annotations - -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolsPermissions, -) - -TOOLS_PERMISSIONS: ToolsPermissions = { - "allow": [ - {"name": "getAccessibleAtlassianResources"}, - {"name": "getVisibleJiraProjects"}, - {"name": "searchJiraIssuesUsingJql"}, - {"name": "getJiraIssue"}, - {"name": "getJiraProjectIssueTypesMetadata"}, - {"name": "getJiraIssueTypeMetaWithFields"}, - {"name": "getTransitionsForJiraIssue"}, - {"name": "lookupJiraAccountId"}, - ], - "ask": [ - {"name": "createJiraIssue"}, - {"name": "editJiraIssue"}, - {"name": "transitionJiraIssue"}, - ], -} diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/mcp_tools/permissions/linear.py 
b/surfsense_backend/app/agents/multi_agent_chat/subagents/mcp_tools/permissions/linear.py deleted file mode 100644 index 18fd827dc..000000000 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/mcp_tools/permissions/linear.py +++ /dev/null @@ -1,32 +0,0 @@ -"""Linear MCP: which server tool names are allow vs ask.""" - -from __future__ import annotations - -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolsPermissions, -) - -_TOOLS_ALLOW = ( - "list_issues", - "get_issue", - "list_my_issues", - "list_issue_statuses", - "list_issue_labels", - "list_comments", - "list_users", - "get_user", - "list_teams", - "get_team", - "list_projects", - "get_project", - "list_project_labels", - "list_cycles", - "list_documents", - "get_document", - "search_documentation", -) - -TOOLS_PERMISSIONS: ToolsPermissions = { - "allow": [{"name": n} for n in _TOOLS_ALLOW], - "ask": [{"name": "save_issue"}], -} diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/mcp_tools/permissions/slack.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/mcp_tools/permissions/slack.py deleted file mode 100644 index 3b7847567..000000000 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/mcp_tools/permissions/slack.py +++ /dev/null @@ -1,20 +0,0 @@ -"""Slack MCP: which server tool names are allow vs ask.""" - -from __future__ import annotations - -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolsPermissions, -) - -TOOLS_PERMISSIONS: ToolsPermissions = { - "allow": [ - {"name": "slack_search_channels"}, - {"name": "slack_search_messages"}, - {"name": "slack_search_users"}, - {"name": "slack_read_channel"}, - {"name": "slack_read_thread"}, - ], - "ask": [ - {"name": "slack_send_message"}, - ], -} diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/registry.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/registry.py index e3f4ca83b..27c147672 100644 --- 
a/surfsense_backend/app/agents/multi_agent_chat/subagents/registry.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/registry.py @@ -71,9 +71,7 @@ from app.agents.multi_agent_chat.subagents.connectors.teams.agent import ( from app.agents.multi_agent_chat.subagents.shared.md_file_reader import ( read_md_file, ) -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolsPermissions, -) +from app.agents.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec class SubagentBuilder(Protocol): @@ -83,8 +81,8 @@ class SubagentBuilder(Protocol): dependencies: dict[str, Any], model: BaseChatModel | None = None, middleware_stack: dict[str, Any] | None = None, - extra_tools_bucket: ToolsPermissions | None = None, - ) -> SubAgent: ... + mcp_tools: list[BaseTool] | None = None, + ) -> SurfSenseSubagentSpec: ... SUBAGENT_BUILDERS_BY_NAME: dict[str, SubagentBuilder] = { @@ -154,7 +152,7 @@ def _filter_disabled_tools_in_place( spec: SubAgent, disabled_names: frozenset[str], ) -> None: - """Drop UI-disabled tools from ``spec["tools"]`` and ``spec["interrupt_on"]``.""" + """Drop UI-disabled tools from ``spec["tools"]``.""" if not disabled_names: return tools = spec.get("tools") # type: ignore[typeddict-item] @@ -162,11 +160,6 @@ def _filter_disabled_tools_in_place( spec["tools"] = [ # type: ignore[typeddict-unknown-key] t for t in tools if getattr(t, "name", None) not in disabled_names ] - interrupt_on = spec.get("interrupt_on") # type: ignore[typeddict-item] - if isinstance(interrupt_on, dict): - spec["interrupt_on"] = { # type: ignore[typeddict-unknown-key] - k: v for k, v in interrupt_on.items() if k not in disabled_names - } def _inject_ask_kb_tool_in_place(spec: SubAgent, ask_kb_tool: BaseTool) -> None: @@ -187,7 +180,7 @@ def build_subagents( dependencies: dict[str, Any], model: BaseChatModel | None = None, middleware_stack: dict[str, Any] | None = None, - mcp_tools_by_agent: dict[str, ToolsPermissions] | None = None, + 
mcp_tools_by_agent: dict[str, list[BaseTool]] | None = None, exclude: list[str] | None = None, disabled_tools: list[str] | None = None, ask_kb_tool: BaseTool | None = None, @@ -203,12 +196,13 @@ def build_subagents( if name in excluded: continue builder = SUBAGENT_BUILDERS_BY_NAME[name] - spec = builder( + result = builder( dependencies=dependencies, model=model, middleware_stack=middleware_stack, - extra_tools_bucket=mcp.get(name), + mcp_tools=mcp.get(name), ) + spec = result.spec _filter_disabled_tools_in_place(spec, disabled_names) if ask_kb_tool is not None: _inject_ask_kb_tool_in_place(spec, ask_kb_tool) diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/__init__.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/__init__.py index 12443da88..70d3dfe39 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/__init__.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/__init__.py @@ -5,21 +5,13 @@ from __future__ import annotations from app.agents.multi_agent_chat.subagents.shared.md_file_reader import ( read_md_file, ) -from app.agents.multi_agent_chat.subagents.shared.permissions import ( - ToolPermissionItem, - ToolsPermissions, - merge_tools_permissions, - tool_permission_row, -) +from app.agents.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec from app.agents.multi_agent_chat.subagents.shared.subagent_builder import ( pack_subagent, ) __all__ = [ - "ToolPermissionItem", - "ToolsPermissions", - "merge_tools_permissions", + "SurfSenseSubagentSpec", "pack_subagent", "read_md_file", - "tool_permission_row", ] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/hitl/approvals/self_gated/__init__.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/hitl/approvals/self_gated/__init__.py new file mode 100644 index 000000000..038ec5652 --- /dev/null +++ 
b/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/hitl/approvals/self_gated/__init__.py @@ -0,0 +1,17 @@ +"""Self-gated approval primitive — tools that pause from inside their own body. + +Public surface: +- :func:`request_approval` — entry point for sensitive tool bodies. +- :class:`HITLResult` — outcome contract. +- ``DEFAULT_AUTO_APPROVED_TOOLS`` — safe-by-construction allowlist. +""" + +from .auto_approved import DEFAULT_AUTO_APPROVED_TOOLS +from .request import request_approval +from .result import HITLResult + +__all__ = [ + "DEFAULT_AUTO_APPROVED_TOOLS", + "HITLResult", + "request_approval", +] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/hitl/approvals/self_gated/auto_approved.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/hitl/approvals/self_gated/auto_approved.py new file mode 100644 index 000000000..b99b26f3a --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/hitl/approvals/self_gated/auto_approved.py @@ -0,0 +1,35 @@ +"""Default safe-by-construction allowlist for self-gated approvals. + +Tools listed here mirror the safety profile of ``write_file`` against the +SurfSense KB: each call creates exactly one artifact in the user's own +workspace with no external visibility (drafts aren't sent; new files aren't +shared unless the user shares them later). Auto-approving them lets the agent +seed scratch artifacts without firing a popup on every call. + +Members still flow through :func:`request_approval` — the function returns +immediately with ``decision_type="auto_approved"`` and the original params +untouched. This keeps tool bodies (logging, metadata fetches, account +fallbacks) symmetrical with the prompted path; the only behavior change is +"no interrupt fires". + +Per-search-space ``agent_permission_rules`` (when wired) take precedence and +can re-enable prompting for any of these. 
+""" + +from __future__ import annotations + +DEFAULT_AUTO_APPROVED_TOOLS: frozenset[str] = frozenset( + { + "create_gmail_draft", + "update_gmail_draft", + "create_calendar_event", + "create_notion_page", + "create_confluence_page", + "create_google_drive_file", + "create_dropbox_file", + "create_onedrive_file", + } +) + + +__all__ = ["DEFAULT_AUTO_APPROVED_TOOLS"] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/hitl/approvals/self_gated/request.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/hitl/approvals/self_gated/request.py new file mode 100644 index 000000000..8729ea85b --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/hitl/approvals/self_gated/request.py @@ -0,0 +1,119 @@ +"""Self-gated approval entry point — pause from inside a tool body. + +Sensitive connector tools (Gmail send, Notion delete, Linear issue create…) +call :func:`request_approval` to ask the user before performing the side +effect. The function emits the unified langchain HITL wire payload (so the +parallel-HITL routing layer in ``task_tool`` and ``resume_routing`` sees the +same shape it sees for middleware-gated approvals) and returns a typed +:class:`HITLResult`. + +Synchronous on purpose: ``langgraph.types.interrupt`` raises ``GraphInterrupt`` +inline; the langgraph runtime catches it. Making this ``async`` would only +move the throw site without changing semantics. +""" + +from __future__ import annotations + +import logging +from typing import Any + +from langgraph.types import interrupt + +from app.agents.multi_agent_chat.subagents.shared.hitl.wire import ( + LC_DECISION_APPROVE, + LC_DECISION_EDIT, + LC_DECISION_REJECT, + build_lc_hitl_payload, + parse_lc_envelope, +) + +from .auto_approved import DEFAULT_AUTO_APPROVED_TOOLS +from .result import HITLResult + +logger = logging.getLogger(__name__) + +# Decisions a self-gated card may carry back. 
``"always"`` is reserved for +# permission-rule promotion (middleware-gated path) and intentionally absent +# here. +_SELF_GATED_DECISIONS: list[str] = [ + LC_DECISION_APPROVE, + LC_DECISION_REJECT, + LC_DECISION_EDIT, +] + + +def request_approval( + *, + action_type: str, + tool_name: str, + params: dict[str, Any], + context: dict[str, Any] | None = None, + trusted_tools: list[str] | None = None, +) -> HITLResult: + """Pause the graph for user approval and return the user's decision. + + Args: + action_type: FE card discriminator (``"gmail_email_send"``, + ``"mcp_tool_call"``…). Forwarded as ``interrupt_type`` on the + wire so the FE can mount the right card variant. + tool_name: Registered langchain tool name (``"send_gmail_email"``…) + shown in the card header and used for trust-list lookups. + params: Original tool arguments. Rendered to the user and used as + defaults when no edits are made. + context: Rich metadata (account info, folder lists, MCP server name…) + forwarded verbatim to the FE for richer card chrome. + trusted_tools: Per-session allowlist; when ``tool_name`` is in it the + interrupt is skipped and the tool runs immediately. + + Returns: + :class:`HITLResult` with ``rejected=True`` if the user declined or + the resume envelope was unparseable; otherwise ``params`` carries + the original args (or args shallow-merged with the user's edits on + ``"edit"``). 
+ """ + if trusted_tools and tool_name in trusted_tools: + logger.info("Tool %r is user-trusted — skipping HITL", tool_name) + return HITLResult(rejected=False, decision_type="trusted", params=dict(params)) + + if tool_name in DEFAULT_AUTO_APPROVED_TOOLS: + logger.info( + "Tool %r is in DEFAULT_AUTO_APPROVED_TOOLS — skipping HITL", tool_name + ) + return HITLResult( + rejected=False, decision_type="auto_approved", params=dict(params) + ) + + payload = build_lc_hitl_payload( + tool_name=tool_name, + args=params, + allowed_decisions=_SELF_GATED_DECISIONS, + interrupt_type=action_type, + context=context, + ) + approval = interrupt(payload) + + parsed = parse_lc_envelope(approval) + logger.info("User decision for %r: %s", tool_name, parsed.decision_type) + + if parsed.decision_type == LC_DECISION_REJECT: + return HITLResult(rejected=True, decision_type="reject", params=dict(params)) + + # Anything outside approve/edit at this point is unexpected — fail closed + # so a malformed FE envelope can't smuggle a side effect through. 
+ if parsed.decision_type not in (LC_DECISION_APPROVE, LC_DECISION_EDIT): + logger.warning( + "Unrecognized decision %r for %r — rejecting for safety", + parsed.decision_type, + tool_name, + ) + return HITLResult(rejected=True, decision_type="error", params=dict(params)) + + final_params = ( + {**params, **parsed.edited_args} if parsed.edited_args else dict(params) + ) + return HITLResult( + rejected=False, decision_type=parsed.decision_type, params=final_params + ) + + +__all__ = ["request_approval"] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/hitl/approvals/self_gated/result.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/hitl/approvals/self_gated/result.py new file mode 100644 index 000000000..645e6d47e --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/hitl/approvals/self_gated/result.py @@ -0,0 +1,34 @@ +"""Outcome contract returned by :func:`request_approval`. + +Lives in its own file so callers that only need the type for annotations don't +drag in ``langgraph`` imports through the entry-point module. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + + +@dataclass(frozen=True, slots=True) +class HITLResult: + """Outcome of a self-gated human-in-the-loop approval request. + + Attributes: + rejected: ``True`` when the tool MUST NOT execute (user said no, or + the wire envelope was unparseable). Always check this first. + decision_type: Reason tag for logging / metrics — + ``"approve" | "edit" | "reject" | "trusted" | "auto_approved" + | "error"``. Callers shouldn't branch on this for control flow; + use ``rejected`` for that. + params: Final parameters to pass to the underlying tool. On + ``"edit"`` this is the original ``params`` shallow-merged with + the user's edits; otherwise it's a copy of the originals. 
+ """ + + rejected: bool + decision_type: str + params: dict[str, Any] = field(default_factory=dict) + + +__all__ = ["HITLResult"] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/hitl/wire/__init__.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/hitl/wire/__init__.py new file mode 100644 index 000000000..2d35ac056 --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/hitl/wire/__init__.py @@ -0,0 +1,26 @@ +"""Single source of truth for the langchain HITL wire format used by every approval path. + +Public surface: +- :func:`build_lc_hitl_payload` — outbound (interrupt argument). +- :func:`parse_lc_envelope` + :class:`ParsedLcDecision` — inbound (resume value). +- Decision-type constants for callers that care about identity rather than literals. +""" + +from .decision import ParsedLcDecision, parse_lc_envelope +from .payload import ( + LC_DECISION_APPROVE, + LC_DECISION_EDIT, + LC_DECISION_REJECT, + SURFSENSE_DECISION_APPROVE_ALWAYS, + build_lc_hitl_payload, +) + +__all__ = [ + "LC_DECISION_APPROVE", + "LC_DECISION_EDIT", + "LC_DECISION_REJECT", + "SURFSENSE_DECISION_APPROVE_ALWAYS", + "ParsedLcDecision", + "build_lc_hitl_payload", + "parse_lc_envelope", +] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/hitl/wire/decision.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/hitl/wire/decision.py new file mode 100644 index 000000000..43fd0382c --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/hitl/wire/decision.py @@ -0,0 +1,121 @@ +"""Parse the langchain HITL resume envelope into a typed decision. + +Both self-gated approvals (``request_approval``) and middleware-gated paths +(``PermissionMiddleware``) receive the user's reply through langgraph's +``Command(resume=...)`` channel as ``{"decisions": [{"type": ..., ...}]}``. 
+This module owns the decoding so the wire-shape knowledge lives in exactly +one place; callers project the parsed values into their own domain decisions +(``HITLResult`` for self-gated, ``decision_type`` for permissions) without +re-implementing the envelope walk. + +Failing closed: any unrecognized envelope shape collapses to +``decision_type="reject"`` (with a warning) so callers never proceed on +ambiguous input. +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import Any + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True, slots=True) +class ParsedLcDecision: + """Decoded resume reply with the fields callers actually need. + + Attributes: + decision_type: Lower-cased decision identifier — ``"approve"``, + ``"reject"``, ``"edit"``, ``"approve_always"``, or any custom value + the FE may emit. Callers map this to their own domain semantics. + edited_args: Populated only on ``"edit"`` replies that actually carry + args; ``None`` otherwise so callers can use truthiness directly. + message: Free-form user feedback (typically attached to ``"reject"``). + ``None`` when absent or when the value isn't a non-empty string. + """ + + decision_type: str + edited_args: dict[str, Any] | None = None + message: str | None = None + + +def parse_lc_envelope(envelope: Any) -> ParsedLcDecision: + """Extract a typed decision from a langgraph resume envelope. + + Accepts: + + - ``{"decisions": [{"type": "approve" | "reject" | "edit", ...}]}`` — the + langchain HITL standard envelope. + - A bare scalar string (``"once"``, ``"approve_always"``, ``"reject"``) — + used by the legacy SurfSense permission wire. We tolerate it so the + parser can sit behind both call sites without a second adapter. 
+ + Edit args are read from the standard ``edited_action.args`` first, then + fall back to a flat ``args`` field for legacy compatibility — both shapes + are produced by the FE depending on which card variant was rendered. + + Args: + envelope: The raw resume value as it arrived from langgraph. + + Returns: + A :class:`ParsedLcDecision` describing the user's intent. + """ + if isinstance(envelope, str): + return ParsedLcDecision(decision_type=envelope.lower()) + + if not isinstance(envelope, dict): + logger.warning( + "Resume envelope is not a dict (got %s); treating as reject", + type(envelope).__name__, + ) + return ParsedLcDecision(decision_type="reject") + + payload: dict[str, Any] = envelope + decisions = envelope.get("decisions") + if isinstance(decisions, list) and decisions: + first = decisions[0] + if isinstance(first, dict): + payload = first + + raw_type = payload.get("type") or payload.get("decision_type") + if not raw_type: + logger.warning( + "Resume payload missing decision type (keys=%s); treating as reject", + list(payload.keys()), + ) + return ParsedLcDecision(decision_type="reject") + + decision_type = str(raw_type).lower() + edited_args = _extract_edited_args(payload) if decision_type == "edit" else None + message = _extract_message(payload) + return ParsedLcDecision( + decision_type=decision_type, + edited_args=edited_args, + message=message, + ) + + +def _extract_edited_args(payload: dict[str, Any]) -> dict[str, Any] | None: + """Pull non-empty edited args from either the LC nested or flat shape.""" + edited_action = payload.get("edited_action") + if isinstance(edited_action, dict): + nested = edited_action.get("args") + if isinstance(nested, dict) and nested: + return nested + flat = payload.get("args") + if isinstance(flat, dict) and flat: + return flat + return None + + +def _extract_message(payload: dict[str, Any]) -> str | None: + """Pull a non-empty user-feedback string, accepting either field name.""" + raw = payload.get("feedback") or 
payload.get("message") + if isinstance(raw, str) and raw.strip(): + return raw + return None + + +__all__ = ["ParsedLcDecision", "parse_lc_envelope"] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/hitl/wire/payload.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/hitl/wire/payload.py new file mode 100644 index 000000000..bac4a6677 --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/hitl/wire/payload.py @@ -0,0 +1,86 @@ +"""Build the langchain HITL ``interrupt(...)`` payload — single source of truth. + +Every approval path in the multi-agent stack — self-gated tool bodies that call +``request_approval``, and middleware-gated paths (``HumanInTheLoopMiddleware``, +``PermissionMiddleware``) — emits the SAME wire shape from this module so the +parallel-HITL routing layer (``task_tool``, ``resume_routing``) only ever sees +one format. SurfSense-specific extras (FE card discriminator, structured +context) ride alongside the langchain standard fields without colliding with +them. +""" + +from __future__ import annotations + +from typing import Any + +LC_DECISION_APPROVE = "approve" +LC_DECISION_REJECT = "reject" +LC_DECISION_EDIT = "edit" + +# ``approve_always`` is a SurfSense extension surfaced by ``PermissionMiddleware`` +# so a single click can promote the matched pattern to a runtime allow rule and +# (for MCP tools) save it to the user's trusted-tools list. The FE renders an +# extra button when it appears in ``allowed_decisions``. +SURFSENSE_DECISION_APPROVE_ALWAYS = "approve_always" + + +def build_lc_hitl_payload( + *, + tool_name: str, + args: dict[str, Any], + allowed_decisions: list[str], + interrupt_type: str, + description: str | None = None, + context: dict[str, Any] | None = None, +) -> dict[str, Any]: + """Build the unified langchain HITL interrupt payload. 
+ + Args: + tool_name: The langchain tool's registered name (drives both the action + request and the review config so the FE can pair them). + args: Tool call arguments shown to the user. ``None`` is normalized to + an empty dict so the FE always has a stable shape to render. + allowed_decisions: Subset of + ``[LC_DECISION_APPROVE, LC_DECISION_REJECT, LC_DECISION_EDIT, + SURFSENSE_DECISION_APPROVE_ALWAYS]``. Other values are passed through + but the FE may not render a control for them. + interrupt_type: SurfSense card discriminator (``"gmail_email_send"``, + ``"permission_ask"``, etc.); the FE keys off this to mount the + right card. + description: Optional human-readable line shown above the args block. + context: Optional structured metadata (account info, matched permission + rules, etc.) forwarded verbatim for richer card chrome. + + Returns: + A dict suitable for ``langgraph.types.interrupt(...)``. Top-level + ``action_requests`` and ``review_configs`` are what + ``collect_pending_tool_calls`` reads at the routing layer; the + SurfSense extensions (``interrupt_type``, ``context``) sit alongside + them — langchain ignores unknown keys, so the contract stays clean. 
+ """ + request: dict[str, Any] = {"name": tool_name, "args": args or {}} + if description: + request["description"] = description + + payload: dict[str, Any] = { + "action_requests": [request], + "review_configs": [ + { + "action_name": tool_name, + "allowed_decisions": list(allowed_decisions), + } + ], + "interrupt_type": interrupt_type, + } + if context: + payload["context"] = context + return payload + + +__all__ = [ + "LC_DECISION_APPROVE", + "LC_DECISION_EDIT", + "LC_DECISION_REJECT", + "SURFSENSE_DECISION_APPROVE_ALWAYS", + "build_lc_hitl_payload", +] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/permissions.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/permissions.py deleted file mode 100644 index 649478485..000000000 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/permissions.py +++ /dev/null @@ -1,60 +0,0 @@ -"""Typed tool-permission rows: allow vs ask (``name`` + optional ``tool``).""" - -from __future__ import annotations - -from typing import Literal, NotRequired, TypedDict - -from langchain_core.tools import BaseTool - -# ``native`` rows self-gate via ``request_approval`` in the tool body; -# ``mcp`` rows are gated by ``HumanInTheLoopMiddleware`` via ``interrupt_on``. 
-ToolKind = Literal["native", "mcp"] - - -class ToolPermissionItem(TypedDict): - """``name`` is always set; ``tool`` is present when a bound tool exists; ``kind`` defaults to ``native`` when absent.""" - - name: str - tool: NotRequired[BaseTool] - kind: NotRequired[ToolKind] - - -class ToolsPermissions(TypedDict): - """Same shape for native factories and MCP name-only policy rows.""" - - allow: list[ToolPermissionItem] - ask: list[ToolPermissionItem] - - -def tool_permission_row(tool: BaseTool) -> ToolPermissionItem: - """Build one allow/ask row for a loaded tool.""" - return {"name": getattr(tool, "name", "") or "", "tool": tool} - - -def mcp_tool_permission_row(tool: BaseTool) -> ToolPermissionItem: - """Build one allow/ask row tagged ``kind="mcp"`` so it routes through ``HumanInTheLoopMiddleware``.""" - return {"name": getattr(tool, "name", "") or "", "tool": tool, "kind": "mcp"} - - -def merge_tools_permissions( - base: ToolsPermissions, - extra: ToolsPermissions | None, -) -> ToolsPermissions: - """Concatenate allow/ask lists (e.g. 
native factory + MCP bucket) before building HITL maps.""" - if not extra: - return base - return { - "allow": [*base["allow"], *extra["allow"]], - "ask": [*base["ask"], *extra["ask"]], - } - - -def middleware_gated_interrupt_on( - bucket: ToolsPermissions, -) -> dict[str, bool]: - """``interrupt_on`` for ``ask`` rows whose bodies don't self-gate via ``request_approval``.""" - return { - r["name"]: True - for r in bucket["ask"] - if r.get("name") and r.get("kind") == "mcp" - } diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/spec.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/spec.py new file mode 100644 index 000000000..797ab535b --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/spec.py @@ -0,0 +1,29 @@ +"""SurfSense's subagent contribution: deepagents spec + permission ruleset.""" + +from __future__ import annotations + +from dataclasses import dataclass + +from deepagents import SubAgent + +from app.agents.new_chat.permissions import Ruleset + + +@dataclass(frozen=True, slots=True) +class SurfSenseSubagentSpec: + """A subagent contribution from a SurfSense route. + + Attributes: + spec: The deepagents-shaped dict handed to ``create_agent``. Holds + only fields ``deepagents.SubAgent`` recognises. + ruleset: Permission rules this subagent contributes. The orchestrator + layers them into the subagent's :class:`PermissionMiddleware`, + so each subagent owns its own ruleset without aliasing the + shared rule engine. 
+ """ + + spec: SubAgent + ruleset: Ruleset + + +__all__ = ["SurfSenseSubagentSpec"] diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/subagent_builder.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/subagent_builder.py index a4a1f84d4..7173901f9 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/subagent_builder.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/shared/subagent_builder.py @@ -9,7 +9,28 @@ from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware from langchain_core.language_models import BaseChatModel from langchain_core.tools import BaseTool -from app.agents.new_chat.middleware import DedupHITLToolCallsMiddleware +from app.agents.multi_agent_chat.middleware.shared.permissions import ( + build_permission_mw, +) +from app.agents.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec +from app.agents.new_chat.permissions import Ruleset + + +def _user_allowlist_for( + dependencies: dict[str, Any], subagent_name: str +) -> Ruleset | None: + """Return the user's persisted allow-rules for ``subagent_name`` if any. + + Populated by the agent factory from + :func:`app.services.user_tool_allowlist.fetch_user_allowlist_rulesets`. + Returning ``None`` is the common case (fresh accounts, non-MCP + subagents, or no "Always Allow" interactions yet). 
+ """ + by_subagent = dependencies.get("user_allowlist_by_subagent") or {} + user_allowlist = by_subagent.get(subagent_name) + if isinstance(user_allowlist, Ruleset) and user_allowlist.rules: + return user_allowlist + return None def pack_subagent( @@ -18,27 +39,58 @@ def pack_subagent( description: str, system_prompt: str, tools: list[BaseTool], + ruleset: Ruleset, + dependencies: dict[str, Any], model: BaseChatModel | None = None, middleware_stack: dict[str, Any] | None = None, - interrupt_on: dict[str, bool] | None = None, -) -> SubAgent: - """Pack the route-local pieces passed in into one sub-agent spec. +) -> SurfSenseSubagentSpec: + """Pack the route-local pieces into one sub-agent spec + its Ruleset. - ``middleware_stack`` is the shared subagent middleware stack (see - ``build_subagent_middleware_stack``). Every non-``None`` value is - prepended to this subagent's middleware list in insertion order. + Tool gating is uniformly performed by a per-subagent + :class:`PermissionMiddleware`. Three rule layers are evaluated + earliest-to-latest (last match wins): + + 1. SurfSense defaults — single ``allow */*`` rule (added by + :func:`build_permission_mw`). + 2. ``ruleset`` — the subagent's coded approval rules (e.g. KB's + destructive-FS ``ask`` rules, connector ``ask`` writes). + 3. The user's persisted allow-list for this subagent — pulled from + ``dependencies['user_allowlist_by_subagent'][name]``. User + ``allow`` rules layered last override coded ``ask`` rules, + implementing the "Always Allow" UX without re-asking on the + next turn. + + The shared ``permission`` slot from ``middleware_stack`` is dropped + so each subagent owns its own rule surface and cannot accidentally + share state with the main agent's permission middleware. 
""" if not system_prompt.strip(): msg = f"Subagent {name!r}: system_prompt is empty" raise ValueError(msg) - prepended = [m for m in (middleware_stack or {}).values() if m is not None] - middleware: list[Any] = [ - *prepended, - PatchToolCallsMiddleware(), - DedupHITLToolCallsMiddleware(agent_tools=tools), - ] - spec: dict[str, Any] = { + flags = dependencies["flags"] + user_allowlist = _user_allowlist_for(dependencies, name) + subagent_rulesets: list[Ruleset] = [ruleset] + if user_allowlist is not None: + subagent_rulesets.append(user_allowlist) + per_subagent_perm = build_permission_mw( + flags=flags, + subagent_rulesets=subagent_rulesets, + tools=tools, + trusted_tool_saver=dependencies.get("trusted_tool_saver"), + ) + + prepended: list[Any] = [] + for slot, mw in (middleware_stack or {}).items(): + if mw is None: + continue + if slot == "permission": + continue + prepended.append(mw) + if per_subagent_perm is not None: + prepended.append(per_subagent_perm) + middleware: list[Any] = [*prepended, PatchToolCallsMiddleware()] + spec_dict: dict[str, Any] = { "name": name, "description": description, "system_prompt": system_prompt, @@ -46,7 +98,5 @@ def pack_subagent( "middleware": middleware, } if model is not None: - spec["model"] = model - if interrupt_on: - spec["interrupt_on"] = interrupt_on - return cast(SubAgent, spec) + spec_dict["model"] = model + return SurfSenseSubagentSpec(spec=cast(SubAgent, spec_dict), ruleset=ruleset) diff --git a/surfsense_backend/app/agents/new_chat/middleware/permission.py b/surfsense_backend/app/agents/new_chat/middleware/permission.py index 5ea7f1740..f77b7e387 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/permission.py +++ b/surfsense_backend/app/agents/new_chat/middleware/permission.py @@ -23,7 +23,7 @@ Operation: SurfSense shape and LangChain HITL ``{"decisions": [{"type": ...}]}`` replies are accepted via :func:`_normalize_permission_decision`. - ``once``: proceed. 
- - ``always``: also persist allow rules for ``request.always`` patterns. + - ``approve_always``: also persist allow rules for ``request.always`` patterns. - ``reject`` w/o feedback: raise :class:`RejectedError`. - ``reject`` w/ feedback: raise :class:`CorrectedError`. 5. On ``allow``: proceed unchanged. @@ -90,6 +90,7 @@ _LC_TYPE_TO_PERMISSION_DECISION: dict[str, str] = { "approve": "once", "reject": "reject", "edit": "once", + "approve_always": "approve_always", } @@ -130,7 +131,7 @@ def _normalize_permission_decision(decision: Any) -> dict[str, Any]: mapped = _LC_TYPE_TO_PERMISSION_DECISION.get(raw_type) if mapped is None: # Tolerate legacy values arriving without ``decision_type`` wrapping. - if raw_type in {"once", "always", "reject"}: + if raw_type in {"once", "approve_always", "reject"}: mapped = raw_type else: logger.warning( @@ -162,8 +163,8 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg] of patterns to evaluate. When a tool isn't listed, the bare tool name is used as the only pattern. runtime_ruleset: Mutable :class:`Ruleset` that the middleware - extends in-place when the user replies ``"always"`` to an - ask interrupt. Reused across all calls in the same agent + extends in-place when the user replies ``"approve_always"`` to + an ask interrupt. Reused across all calls in the same agent instance so newly-allowed rules apply to subsequent calls. always_emit_interrupt_payload: If True, every ask uses the SurfSense interrupt wire format (default). Set False to @@ -268,7 +269,7 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg] for r in rules ], # Rules of thumb for the frontend: surface the patterns - # the user can promote to "always" with a single reply. + # the user can promote to "approve_always" with a single reply. 
"always": patterns, }, } @@ -287,12 +288,12 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg] return _normalize_permission_decision(decision) def _persist_always(self, tool_name: str, patterns: list[str]) -> None: - """Promote ``always`` reply into runtime allow rules. + """Promote ``approve_always`` reply into runtime allow rules. Persistence to ``agent_permission_rules`` is done by the streaming layer (``stream_new_chat``) once it observes the - ``always`` reply — the middleware just keeps an in-memory - copy so subsequent calls in the same stream see the rule. + ``approve_always`` reply — the middleware just keeps an + in-memory copy so subsequent calls in the same stream see the rule. """ for pattern in patterns: self._runtime_ruleset.rules.append( @@ -377,7 +378,7 @@ class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg] kind = str(decision.get("decision_type") or "reject").lower() if kind == "once": kept_calls.append(call) - elif kind == "always": + elif kind == "approve_always": self._persist_always(name, patterns) kept_calls.append(call) elif kind == "reject": diff --git a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py index 92a808a5e..64368a878 100644 --- a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py +++ b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py @@ -229,6 +229,7 @@ async def _create_mcp_tool_from_definition_stdio( "mcp_input_schema": input_schema, "mcp_transport": "stdio", "mcp_connector_name": connector_name or None, + "mcp_connector_id": connector_id, "mcp_is_generic": True, "hitl": True, # Full-args hash: shared identifiers (cloudId, workspaceId, …) diff --git a/surfsense_backend/app/routes/search_source_connectors_routes.py b/surfsense_backend/app/routes/search_source_connectors_routes.py index 9037d275a..e9ffb7050 100644 --- a/surfsense_backend/app/routes/search_source_connectors_routes.py +++ 
b/surfsense_backend/app/routes/search_source_connectors_routes.py @@ -3071,6 +3071,37 @@ class MCPTrustToolRequest(BaseModel): tool_name: str +async def _ensure_mcp_connector_for_user( + session: AsyncSession, *, user_id, connector_id: int +) -> int: + """Verify ``connector_id`` is an MCP-backed connector owned by ``user_id``. + + The trust-list feature is intentionally MCP-only; native connectors + (Gmail, Calendar, Notion, ...) do not have a "trust this tool" UI. + The JSONB ``has_key("server_config")`` filter is the same MCP marker + used elsewhere in this module. + + Returns the connector's ``search_space_id`` (needed downstream for + MCP tool cache invalidation). Raises ``HTTPException(404)`` when the + connector does not exist, is not owned by the user, or is not + MCP-backed. + """ + from sqlalchemy import cast + from sqlalchemy.dialects.postgresql import JSONB as PG_JSONB + + result = await session.execute( + select(SearchSourceConnector.search_space_id).where( + SearchSourceConnector.id == connector_id, + SearchSourceConnector.user_id == user_id, + cast(SearchSourceConnector.config, PG_JSONB).has_key("server_config"), + ) + ) + search_space_id = result.scalar_one_or_none() + if search_space_id is None: + raise HTTPException(status_code=404, detail="MCP connector not found") + return search_space_id + + @router.post("/connectors/mcp/{connector_id}/trust-tool") async def trust_mcp_tool( connector_id: int, @@ -3080,45 +3111,32 @@ async def trust_mcp_tool( ): """Add a tool to the MCP connector's trusted (always-allow) list. - Once trusted, the tool executes without HITL approval on subsequent calls. - Works for both generic MCP_CONNECTOR and OAuth-backed MCP connectors - (LINEAR_CONNECTOR, JIRA_CONNECTOR, etc.) by checking for ``server_config``. + Once trusted, the tool executes without HITL approval on subsequent + calls. Works for both generic ``MCP_CONNECTOR`` and OAuth-backed MCP + connectors (``LINEAR_CONNECTOR``, ``JIRA_CONNECTOR``, ...) 
— the + storage primitive is the same JSON list under ``config.trusted_tools``. """ + from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache + from app.services.user_tool_allowlist import add_user_trust + try: - from sqlalchemy import cast - from sqlalchemy.dialects.postgresql import JSONB as PG_JSONB - - result = await session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == connector_id, - SearchSourceConnector.user_id == user.id, - cast(SearchSourceConnector.config, PG_JSONB).has_key("server_config"), - ) + search_space_id = await _ensure_mcp_connector_for_user( + session, user_id=user.id, connector_id=connector_id + ) + trusted = await add_user_trust( + session, + user_id=user.id, + connector_id=connector_id, + tool_name=body.tool_name, ) - connector = result.scalars().first() - if not connector: - raise HTTPException(status_code=404, detail="MCP connector not found") - - config = dict(connector.config or {}) - trusted: list[str] = list(config.get("trusted_tools", [])) - if body.tool_name not in trusted: - trusted.append(body.tool_name) - config["trusted_tools"] = trusted - connector.config = config - - from sqlalchemy.orm.attributes import flag_modified - - flag_modified(connector, "config") await session.commit() - - from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache - - invalidate_mcp_tools_cache(connector.search_space_id) - + invalidate_mcp_tools_cache(search_space_id) return {"status": "ok", "trusted_tools": trusted} except HTTPException: raise + except LookupError as e: + raise HTTPException(status_code=404, detail="MCP connector not found") from e except Exception as e: logger.error(f"Failed to trust MCP tool: {e!s}", exc_info=True) await session.rollback() @@ -3137,43 +3155,28 @@ async def untrust_mcp_tool( """Remove a tool from the MCP connector's trusted list. The tool will require HITL approval again on subsequent calls. 
- Works for both generic MCP_CONNECTOR and OAuth-backed MCP connectors. """ + from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache + from app.services.user_tool_allowlist import remove_user_trust + try: - from sqlalchemy import cast - from sqlalchemy.dialects.postgresql import JSONB as PG_JSONB - - result = await session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.id == connector_id, - SearchSourceConnector.user_id == user.id, - cast(SearchSourceConnector.config, PG_JSONB).has_key("server_config"), - ) + search_space_id = await _ensure_mcp_connector_for_user( + session, user_id=user.id, connector_id=connector_id + ) + trusted = await remove_user_trust( + session, + user_id=user.id, + connector_id=connector_id, + tool_name=body.tool_name, ) - connector = result.scalars().first() - if not connector: - raise HTTPException(status_code=404, detail="MCP connector not found") - - config = dict(connector.config or {}) - trusted: list[str] = list(config.get("trusted_tools", [])) - if body.tool_name in trusted: - trusted.remove(body.tool_name) - config["trusted_tools"] = trusted - connector.config = config - - from sqlalchemy.orm.attributes import flag_modified - - flag_modified(connector, "config") await session.commit() - - from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache - - invalidate_mcp_tools_cache(connector.search_space_id) - + invalidate_mcp_tools_cache(search_space_id) return {"status": "ok", "trusted_tools": trusted} except HTTPException: raise + except LookupError as e: + raise HTTPException(status_code=404, detail="MCP connector not found") from e except Exception as e: logger.error(f"Failed to untrust MCP tool: {e!s}", exc_info=True) await session.rollback() diff --git a/surfsense_backend/app/schemas/new_chat.py b/surfsense_backend/app/schemas/new_chat.py index c809d6235..c5315cce5 100644 --- a/surfsense_backend/app/schemas/new_chat.py +++ b/surfsense_backend/app/schemas/new_chat.py @@ 
-395,7 +395,7 @@ class AgentToolInfo(BaseModel): class ResumeDecision(BaseModel): - type: Literal["approve", "edit", "reject"] + type: Literal["approve", "edit", "reject", "approve_always"] edited_action: dict[str, Any] | None = None diff --git a/surfsense_backend/app/services/user_tool_allowlist.py b/surfsense_backend/app/services/user_tool_allowlist.py new file mode 100644 index 000000000..fb21a7df2 --- /dev/null +++ b/surfsense_backend/app/services/user_tool_allowlist.py @@ -0,0 +1,188 @@ +"""User-scoped trusted-tools list backed by ``SearchSourceConnector.config``. + +Storage is per ``(user_id, search_space_id, connector_id)`` under +``connector.config['trusted_tools']``. The list only ever encodes +``allow`` decisions; coded ``deny`` rules cannot be overridden here. +""" + +from __future__ import annotations + +import logging +import uuid +from collections import defaultdict +from collections.abc import Awaitable, Callable + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm.attributes import flag_modified + +from app.agents.multi_agent_chat.constants import ( + CONNECTOR_TYPE_TO_CONNECTOR_AGENT_MAPS, +) +from app.agents.new_chat.permissions import Rule, Ruleset +from app.db import SearchSourceConnector, async_session_maker + +logger = logging.getLogger(__name__) + +_TRUSTED_TOOLS_KEY = "trusted_tools" + +TrustedToolSaver = Callable[[int, str], Awaitable[None]] + + +async def _load_owned_connector( + session: AsyncSession, + *, + user_id: uuid.UUID, + connector_id: int, +) -> SearchSourceConnector | None: + """Return the connector iff owned by ``user_id``, else ``None``.""" + result = await session.execute( + select(SearchSourceConnector).where( + SearchSourceConnector.id == connector_id, + SearchSourceConnector.user_id == user_id, + ) + ) + return result.scalars().first() + + +def _read_trusted(connector: SearchSourceConnector) -> list[str]: + config = connector.config or {} + raw = 
config.get(_TRUSTED_TOOLS_KEY, []) + if not isinstance(raw, list): + return [] + return [str(item) for item in raw if isinstance(item, str)] + + +def _write_trusted(connector: SearchSourceConnector, trusted: list[str]) -> None: + config = dict(connector.config or {}) + config[_TRUSTED_TOOLS_KEY] = trusted + connector.config = config + flag_modified(connector, "config") + + +async def add_user_trust( + session: AsyncSession, + *, + user_id: uuid.UUID, + connector_id: int, + tool_name: str, +) -> list[str]: + """Append ``tool_name`` to the connector's trusted list; raise ``LookupError`` if not owned.""" + connector = await _load_owned_connector( + session, user_id=user_id, connector_id=connector_id + ) + if connector is None: + raise LookupError( + f"connector {connector_id} not found for user {user_id}" + ) + + trusted = _read_trusted(connector) + if tool_name not in trusted: + trusted.append(tool_name) + _write_trusted(connector, trusted) + await session.flush() + return trusted + + +async def remove_user_trust( + session: AsyncSession, + *, + user_id: uuid.UUID, + connector_id: int, + tool_name: str, +) -> list[str]: + """Remove ``tool_name`` from the connector's trusted list; raise ``LookupError`` if not owned.""" + connector = await _load_owned_connector( + session, user_id=user_id, connector_id=connector_id + ) + if connector is None: + raise LookupError( + f"connector {connector_id} not found for user {user_id}" + ) + + trusted = _read_trusted(connector) + if tool_name in trusted: + trusted = [t for t in trusted if t != tool_name] + _write_trusted(connector, trusted) + await session.flush() + return trusted + + +async def fetch_user_allowlist_rulesets( + session: AsyncSession, + *, + user_id: uuid.UUID, + search_space_id: int, +) -> dict[str, Ruleset]: + """Project the user's trusted tools into per-subagent ``allow`` rulesets. + + Subagents with no trusted tools are absent from the result — + callers must treat ``missing == empty``. 
+ """ + result = await session.execute( + select( + SearchSourceConnector.id, + SearchSourceConnector.connector_type, + SearchSourceConnector.config, + ).where( + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.search_space_id == search_space_id, + ) + ) + + rules_by_subagent: dict[str, list[Rule]] = defaultdict(list) + for _connector_id, connector_type, config in result.all(): + subagent = CONNECTOR_TYPE_TO_CONNECTOR_AGENT_MAPS.get(str(connector_type)) + if subagent is None: + continue + + cfg = config or {} + raw = cfg.get(_TRUSTED_TOOLS_KEY, []) + if not isinstance(raw, list): + continue + + for tool in raw: + if not isinstance(tool, str) or not tool: + continue + rules_by_subagent[subagent].append( + Rule(permission=tool, pattern="*", action="allow") + ) + + return { + subagent: Ruleset(rules=rules, origin=f"user_allowlist:{subagent}") + for subagent, rules in rules_by_subagent.items() + } + + +def make_trusted_tool_saver(user_id: uuid.UUID) -> TrustedToolSaver: + """Bind ``user_id`` into a saver closure; failures are logged, never raised.""" + + async def trusted_tool_saver(connector_id: int, tool_name: str) -> None: + try: + async with async_session_maker() as session: + await add_user_trust( + session, + user_id=user_id, + connector_id=connector_id, + tool_name=tool_name, + ) + await session.commit() + except LookupError as exc: + logger.warning("trusted-tool save skipped: %s", exc) + except Exception: + logger.exception( + "trusted-tool save failed for connector=%s tool=%s", + connector_id, + tool_name, + ) + + return trusted_tool_saver + + +__all__ = [ + "TrustedToolSaver", + "add_user_trust", + "fetch_user_allowlist_rulesets", + "make_trusted_tool_saver", + "remove_user_trust", +] diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 818282996..2219ad022 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ 
b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -76,6 +76,9 @@ from app.services.chat_session_state_service import ( from app.services.connector_service import ConnectorService from app.services.new_streaming_service import VercelStreamingService from app.tasks.chat.streaming.graph_stream.event_stream import stream_output +from app.tasks.chat.streaming.helpers.interrupt_inspector import ( + all_interrupt_values, +) from app.utils.content_utils import bootstrap_history_from_db from app.utils.perf import get_perf_logger, log_system_snapshot, trim_native_heap from app.utils.user_message_multimodal import build_human_message_content @@ -89,6 +92,21 @@ TURN_CANCELLING_BACKOFF_FACTOR = 2 TURN_CANCELLING_MAX_DELAY_MS = 1500 +def _resume_step_prefix(turn_id: str) -> str: + """Build the per-turn ``step_prefix`` for a resume invocation. + + Each ``_stream_agent_events`` call constructs a fresh + :class:`AgentEventRelayState` with ``thinking_step_counter=0``, so two + consecutive resume turns would otherwise both emit ``thinking-resume-1``, + ``-2`` etc. The frontend rehydrates ``currentThinkingSteps`` from the + immediate prior assistant message at the start of every resume — if the + new stream's IDs collide with the seeded ones, React renders sibling + Timeline rows with the same key. Salting with ``turn_id`` guarantees + disjoint IDs across resumes within one thread. 
+ """ + return f"thinking-resume-{turn_id}" + + def _compute_turn_cancelling_retry_delay(attempt: int) -> int: if attempt < 1: attempt = 1 @@ -98,47 +116,6 @@ def _compute_turn_cancelling_retry_delay(attempt: int) -> int: return min(delay, TURN_CANCELLING_MAX_DELAY_MS) -def _first_interrupt_value(state: Any) -> dict[str, Any] | None: - """Return the first LangGraph interrupt payload across all snapshot tasks.""" - - def _extract_interrupt_value(candidate: Any) -> dict[str, Any] | None: - if isinstance(candidate, dict): - value = candidate.get("value", candidate) - return value if isinstance(value, dict) else None - value = getattr(candidate, "value", None) - if isinstance(value, dict): - return value - if isinstance(candidate, (list, tuple)): - for item in candidate: - extracted = _extract_interrupt_value(item) - if extracted is not None: - return extracted - return None - - for task in getattr(state, "tasks", ()) or (): - try: - interrupts = getattr(task, "interrupts", ()) or () - except (AttributeError, IndexError, TypeError): - interrupts = () - if not interrupts: - extracted = _extract_interrupt_value(task) - if extracted is not None: - return extracted - continue - for interrupt_item in interrupts: - extracted = _extract_interrupt_value(interrupt_item) - if extracted is not None: - return extracted - try: - state_interrupts = getattr(state, "interrupts", ()) or () - except (AttributeError, IndexError, TypeError): - state_interrupts = () - extracted = _extract_interrupt_value(state_interrupts) - if extracted is not None: - return extracted - return None - - def _extract_chunk_parts(chunk: Any) -> dict[str, Any]: """Decompose an ``AIMessageChunk`` into typed text/reasoning/tool-call parts. 
@@ -301,7 +278,6 @@ def extract_todos_from_deepagents(command_output) -> dict: class StreamResult: accumulated_text: str = "" is_interrupted: bool = False - interrupt_value: dict[str, Any] | None = None sandbox_files: list[str] = field(default_factory=list) agent_called_update_memory: bool = False request_id: str | None = None @@ -915,11 +891,15 @@ async def _stream_agent_events( result.accumulated_text = accumulated_text _log_file_contract("turn_outcome", result) - interrupt_value = _first_interrupt_value(state) - if interrupt_value is not None: + pending_values = all_interrupt_values(state) + if pending_values: result.is_interrupted = True - result.interrupt_value = interrupt_value - yield streaming_service.format_interrupt_request(result.interrupt_value) + # One frame per paused subagent so each parallel HITL renders its own + # approval card on the wire. Order matches ``state.interrupts``, which + # the resume slicer in ``checkpointed_subagent_middleware.resume_routing`` + # consumes in the same order — keeping emit and resume in lock-step. + for interrupt_value in pending_values: + yield streaming_service.format_interrupt_request(interrupt_value) async def stream_new_chat( @@ -2863,14 +2843,40 @@ async def stream_resume_chat( from langgraph.types import Command + from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import ( + build_lg_resume_map, + collect_pending_tool_calls, + slice_decisions_by_tool_call, + ) + + # Each pending interrupt is stamped with its originating ``tool_call_id`` + # (see ``checkpointed_subagent_middleware.propagation``) so we can route + # a flat ``decisions`` list back to the right paused subagent. 
+ parent_state = await agent.aget_state( + {"configurable": {"thread_id": str(chat_id)}} + ) + pending = collect_pending_tool_calls(parent_state) + _perf_log.info( + "[hitl_route] resume_entry chat_id=%s decisions=%d pending_subagents=%d", + chat_id, + len(decisions), + len(pending), + ) + routed_resume_value = slice_decisions_by_tool_call(decisions, pending) + # Langgraph rejects scalar ``Command(resume=...)`` when multiple + # interrupts are pending (parallel HITL); the mapped form works + # for the single-pause case too, so we always use it. + lg_resume_map = build_lg_resume_map(parent_state, routed_resume_value) + config = { "configurable": { "thread_id": str(chat_id), "request_id": request_id or "unknown", "turn_id": stream_result.turn_id, - # Side-channel consumed by ``SurfSenseCheckpointedSubAgentMiddleware`` - # to forward the resume into a subagent's pending ``interrupt()``. - "surfsense_resume_value": {"decisions": decisions}, + # Per-``tool_call_id`` resume slices read by + # ``SurfSenseCheckpointedSubAgentMiddleware``. Parallel + # siblings each pop their own entry, so they never race. 
+ "surfsense_resume_value": routed_resume_value, }, # See ``stream_new_chat`` above for rationale: effectively # uncapped to mirror the agent default and OpenCode's @@ -2952,10 +2958,10 @@ async def stream_resume_chat( async for sse in _stream_agent_events( agent=agent, config=config, - input_data=Command(resume={"decisions": decisions}), + input_data=Command(resume=lg_resume_map), streaming_service=streaming_service, result=stream_result, - step_prefix="thinking-resume", + step_prefix=_resume_step_prefix(stream_result.turn_id), fallback_commit_search_space_id=search_space_id, fallback_commit_created_by_id=user_id, fallback_commit_filesystem_mode=( diff --git a/surfsense_backend/app/tasks/chat/streaming/graph_stream/result.py b/surfsense_backend/app/tasks/chat/streaming/graph_stream/result.py index 40404e9d0..391f14f24 100644 --- a/surfsense_backend/app/tasks/chat/streaming/graph_stream/result.py +++ b/surfsense_backend/app/tasks/chat/streaming/graph_stream/result.py @@ -10,7 +10,6 @@ from typing import Any class StreamingResult: accumulated_text: str = "" is_interrupted: bool = False - interrupt_value: dict[str, Any] | None = None sandbox_files: list[str] = field(default_factory=list) agent_called_update_memory: bool = False request_id: str | None = None diff --git a/surfsense_backend/app/tasks/chat/streaming/helpers/interrupt_inspector.py b/surfsense_backend/app/tasks/chat/streaming/helpers/interrupt_inspector.py index dca099b3f..f4b00431c 100644 --- a/surfsense_backend/app/tasks/chat/streaming/helpers/interrupt_inspector.py +++ b/surfsense_backend/app/tasks/chat/streaming/helpers/interrupt_inspector.py @@ -1,12 +1,30 @@ -"""Read the first interrupt payload from a LangGraph state snapshot.""" +"""Read every pending interrupt payload from a LangGraph state snapshot. 
+ +The chat-stream emit loop yields one ``data-interrupt-request`` SSE frame per +pending interrupt so parallel HITL across siblings stays addressable on the +wire (the resume slicer in ``checkpointed_subagent_middleware.resume_routing`` +correlates each frame back to the right paused subagent via the stamped +``tool_call_id``). This helper produces that flat, ordered list. +""" from __future__ import annotations from typing import Any -def first_interrupt_value(state: Any) -> dict[str, Any] | None: - """Return the first interrupt payload across all snapshot tasks.""" +def all_interrupt_values(state: Any) -> list[dict[str, Any]]: + """Return every interrupt payload across the snapshot, in traversal order. + + Walks ``state.tasks[*].interrupts`` first (langgraph's per-task buckets, + which carry one interrupt per paused subagent) and falls back to + ``state.interrupts`` when the per-task lists are empty. Order matches the + snapshot's iteration order so the emit-time order on the SSE stream agrees + with ``collect_pending_tool_calls`` consumption order on resume. + + Defensive against malformed snapshots: tasks/interrupts that raise on + attribute access are skipped silently. Non-dict values are skipped — the + chat-stream contract requires structured interrupt payloads. 
+ """ def _extract(candidate: Any) -> dict[str, Any] | None: if isinstance(candidate, dict): @@ -15,33 +33,32 @@ def first_interrupt_value(state: Any) -> dict[str, Any] | None: value = getattr(candidate, "value", None) if isinstance(value, dict): return value - if isinstance(candidate, list | tuple): - for item in candidate: - extracted = _extract(item) - if extracted is not None: - return extracted return None + values: list[dict[str, Any]] = [] + saw_task_interrupt = False + for task in getattr(state, "tasks", ()) or (): try: interrupts = getattr(task, "interrupts", ()) or () except (AttributeError, IndexError, TypeError): interrupts = () - if not interrupts: - extracted = _extract(task) - if extracted is not None: - return extracted - continue - for interrupt_item in interrupts: - extracted = _extract(interrupt_item) - if extracted is not None: - return extracted + if interrupts: + saw_task_interrupt = True + for interrupt_item in interrupts: + extracted = _extract(interrupt_item) + if extracted is not None: + values.append(extracted) + + if saw_task_interrupt: + return values try: state_interrupts = getattr(state, "interrupts", ()) or () except (AttributeError, IndexError, TypeError): state_interrupts = () - extracted = _extract(state_interrupts) - if extracted is not None: - return extracted - return None + for interrupt_item in state_interrupts: + extracted = _extract(interrupt_item) + if extracted is not None: + values.append(extracted) + return values diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_hitl_bridge.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_hitl_bridge.py index dbc2c9c00..48eabbd7c 100644 --- a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_hitl_bridge.py +++ 
b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_hitl_bridge.py @@ -3,15 +3,24 @@ from __future__ import annotations import ast +import asyncio +from types import SimpleNamespace import pytest from langchain.tools import ToolRuntime -from langchain_core.messages import HumanMessage +from langchain_core.messages import AIMessage, HumanMessage from langgraph.checkpoint.memory import InMemorySaver from langgraph.graph import END, START, StateGraph from langgraph.types import Command, interrupt from typing_extensions import TypedDict +from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.config import ( + subagent_invoke_config, +) +from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import ( + collect_pending_tool_calls, + slice_decisions_by_tool_call, +) from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import ( build_task_tool_with_parent_config, ) @@ -24,8 +33,6 @@ class _SubagentState(TypedDict, total=False): def _build_single_interrupt_subagent(): def approve_node(state): - from langchain_core.messages import AIMessage - decision = interrupt( { "action_requests": [ @@ -50,17 +57,27 @@ def _build_single_interrupt_subagent(): return graph.compile(checkpointer=InMemorySaver()) -def _make_runtime(config: dict) -> ToolRuntime: +def _make_runtime(config: dict, *, tool_call_id: str = "parent-tcid-1") -> ToolRuntime: return ToolRuntime( state={"messages": [HumanMessage(content="seed")]}, context=None, config=config, stream_writer=None, - tool_call_id="parent-tcid-1", + tool_call_id=tool_call_id, store=None, ) +def _prime_subagent_at_runtime_thread(subagent, runtime: ToolRuntime) -> dict: + """Build the per-call ``RunnableConfig`` the production ``task`` tool will use. 
+ + Mirrors what the ``task`` tool does on first invocation so test fixtures + can prime the subagent's pending interrupt at the same checkpoint slot + (per-call ``thread_id``) the bridge looks at on resume. + """ + return subagent_invoke_config(runtime) + + @pytest.mark.asyncio async def test_resume_bridge_dispatches_decision_into_pending_subagent(): """Side-channel decision must reach the subagent's pending interrupt verbatim.""" @@ -79,16 +96,17 @@ async def test_resume_bridge_dispatches_decision_into_pending_subagent(): "configurable": {"thread_id": "shared-thread"}, "recursion_limit": 100, } - await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config) - snap = await subagent.aget_state(parent_config) + runtime = _make_runtime(parent_config) + sub_config = _prime_subagent_at_runtime_thread(subagent, runtime) + await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, sub_config) + snap = await subagent.aget_state(sub_config) assert snap.tasks and snap.tasks[0].interrupts, ( "fixture broken: subagent should be paused on its interrupt" ) parent_config["configurable"]["surfsense_resume_value"] = { - "decisions": ["APPROVED"] + runtime.tool_call_id: {"decisions": ["APPROVED"]} } - runtime = _make_runtime(parent_config) result = await task_tool.coroutine( description="please approve", @@ -101,7 +119,7 @@ async def test_resume_bridge_dispatches_decision_into_pending_subagent(): assert update["decision_text"] == repr({"decisions": ["APPROVED"]}) assert "surfsense_resume_value" not in parent_config["configurable"] - final = await subagent.aget_state(parent_config) + final = await subagent.aget_state(sub_config) assert not final.tasks or all(not t.interrupts for t in final.tasks) @@ -123,11 +141,11 @@ async def test_pending_interrupt_without_resume_value_raises_runtime_error(): "configurable": {"thread_id": "guard-thread"}, "recursion_limit": 100, } - await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config) 
- snap = await subagent.aget_state(parent_config) - assert snap.tasks and snap.tasks[0].interrupts, "fixture broken" - runtime = _make_runtime(parent_config) + sub_config = _prime_subagent_at_runtime_thread(subagent, runtime) + await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, sub_config) + snap = await subagent.aget_state(sub_config) + assert snap.tasks and snap.tasks[0].interrupts, "fixture broken" with pytest.raises(RuntimeError, match="resume bridge is broken"): await task_tool.coroutine( @@ -139,8 +157,6 @@ async def test_pending_interrupt_without_resume_value_raises_runtime_error(): def _build_bundle_subagent(): def bundle_node(state): - from langchain_core.messages import AIMessage - decision = interrupt( { "action_requests": [ @@ -181,7 +197,9 @@ async def test_bundle_three_mixed_decisions_arrive_in_order(): "configurable": {"thread_id": "bundle-thread"}, "recursion_limit": 100, } - await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config) + runtime = _make_runtime(parent_config) + sub_config = _prime_subagent_at_runtime_thread(subagent, runtime) + await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, sub_config) decisions_payload = { "decisions": [ @@ -190,8 +208,9 @@ async def test_bundle_three_mixed_decisions_arrive_in_order(): {"type": "reject", "args": {"message": "no thanks"}}, ] } - parent_config["configurable"]["surfsense_resume_value"] = decisions_payload - runtime = _make_runtime(parent_config) + parent_config["configurable"]["surfsense_resume_value"] = { + runtime.tool_call_id: decisions_payload + } result = await task_tool.coroutine( description="run bundle", @@ -206,3 +225,186 @@ async def test_bundle_three_mixed_decisions_arrive_in_order(): assert received["decisions"][1]["type"] == "edit" assert received["decisions"][1]["args"] == {"args": {"name": "edited-b"}} assert received["decisions"][2]["type"] == "reject" + + +@pytest.mark.asyncio +async def 
test_parallel_atask_routes_each_decision_to_its_own_subagent(): + """Two ``atask`` calls with distinct ``tool_call_id``s must each get their own decision. + + With per-call ``thread_id`` isolation and per-call resume keying, A's + decision must reach A's pending interrupt and B's must reach B's. They + must NOT cross-contaminate even though they share ``configurable``. + """ + subagent_a = _build_single_interrupt_subagent() + subagent_b = _build_single_interrupt_subagent() + task_tool = build_task_tool_with_parent_config( + [ + { + "name": "approver_a", + "description": "approves A", + "runnable": subagent_a, + }, + { + "name": "approver_b", + "description": "approves B", + "runnable": subagent_b, + }, + ] + ) + + parent_config: dict = { + "configurable": {"thread_id": "parallel-thread"}, + "recursion_limit": 100, + } + + runtime_a = _make_runtime(parent_config, tool_call_id="tcid-A") + runtime_b = _make_runtime(parent_config, tool_call_id="tcid-B") + + sub_config_a = _prime_subagent_at_runtime_thread(subagent_a, runtime_a) + sub_config_b = _prime_subagent_at_runtime_thread(subagent_b, runtime_b) + + await subagent_a.ainvoke( + {"messages": [HumanMessage(content="seed-A")]}, sub_config_a + ) + await subagent_b.ainvoke( + {"messages": [HumanMessage(content="seed-B")]}, sub_config_b + ) + + parent_config["configurable"]["surfsense_resume_value"] = { + "tcid-A": {"decisions": ["DECISION-A"]}, + "tcid-B": {"decisions": ["DECISION-B"]}, + } + + result_a, result_b = await asyncio.gather( + task_tool.coroutine( + description="please approve A", + subagent_type="approver_a", + runtime=runtime_a, + ), + task_tool.coroutine( + description="please approve B", + subagent_type="approver_b", + runtime=runtime_b, + ), + ) + + assert isinstance(result_a, Command) + assert isinstance(result_b, Command) + assert result_a.update["decision_text"] == repr({"decisions": ["DECISION-A"]}) + assert result_b.update["decision_text"] == repr({"decisions": ["DECISION-B"]}) + + assert 
"surfsense_resume_value" not in parent_config["configurable"] + + +@pytest.mark.asyncio +async def test_full_resume_routing_glue_for_two_paused_subagents(): + """End-to-end: extractor + slicer + bridge correctly route a flat decisions list. + + This simulates exactly what ``stream_resume_chat`` will do on resume: + given a paused parent state with two pending interrupts (one per + subagent) and a flat ``decisions`` list, build the per-tool-call dict + via ``collect_pending_tool_calls`` + ``slice_decisions_by_tool_call``, + then resume the bridge concurrently and verify each subagent received + only its own slice. + """ + subagent_a = _build_bundle_subagent() + subagent_b = _build_single_interrupt_subagent() + task_tool = build_task_tool_with_parent_config( + [ + { + "name": "bundler", + "description": "three-action bundle", + "runnable": subagent_a, + }, + { + "name": "approver", + "description": "single approval", + "runnable": subagent_b, + }, + ] + ) + + parent_config: dict = { + "configurable": {"thread_id": "glue-thread"}, + "recursion_limit": 100, + } + + runtime_a = _make_runtime(parent_config, tool_call_id="tcid-bundler") + runtime_b = _make_runtime(parent_config, tool_call_id="tcid-approver") + + sub_config_a = _prime_subagent_at_runtime_thread(subagent_a, runtime_a) + sub_config_b = _prime_subagent_at_runtime_thread(subagent_b, runtime_b) + + await subagent_a.ainvoke( + {"messages": [HumanMessage(content="seed-A")]}, sub_config_a + ) + await subagent_b.ainvoke( + {"messages": [HumanMessage(content="seed-B")]}, sub_config_b + ) + + # Synthetic parent state mirroring what the parent's pregel would have + # bundled: one Interrupt per subagent, value carrying tool_call_id + + # action_requests (exactly the shape ``propagation.wrap_with_tool_call_id`` + # produces). 
+ parent_interrupts = ( + SimpleNamespace( + id="i-bundler", + value={ + "action_requests": [ + {"name": "create_a", "args": {}, "description": ""}, + {"name": "create_b", "args": {}, "description": ""}, + {"name": "create_c", "args": {}, "description": ""}, + ], + "review_configs": [{}, {}, {}], + "tool_call_id": "tcid-bundler", + }, + ), + SimpleNamespace( + id="i-approver", + value={ + "action_requests": [ + {"name": "approve", "args": {}, "description": ""} + ], + "review_configs": [{}], + "tool_call_id": "tcid-approver", + }, + ), + ) + parent_state = SimpleNamespace(interrupts=parent_interrupts) + + flat_decisions = [ + {"type": "approve"}, + {"type": "edit", "args": {"args": {"name": "edited-b"}}}, + {"type": "reject", "args": {"message": "no thanks"}}, + {"type": "approve"}, + ] + + pending = collect_pending_tool_calls(parent_state) + assert pending == [("tcid-bundler", 3), ("tcid-approver", 1)] + + routed = slice_decisions_by_tool_call(flat_decisions, pending) + parent_config["configurable"]["surfsense_resume_value"] = routed + + result_a, result_b = await asyncio.gather( + task_tool.coroutine( + description="run bundle", + subagent_type="bundler", + runtime=runtime_a, + ), + task_tool.coroutine( + description="please approve", + subagent_type="approver", + runtime=runtime_b, + ), + ) + + assert isinstance(result_a, Command) + assert isinstance(result_b, Command) + + received_a = ast.literal_eval(result_a.update["decision_text"]) + assert received_a == {"decisions": flat_decisions[0:3]} + assert result_b.update["decision_text"] == repr( + {"decisions": flat_decisions[3:4]} + ) + + assert "surfsense_resume_value" not in parent_config["configurable"] diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_heterogeneous_decisions.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_heterogeneous_decisions.py new file mode 100644 
index 000000000..cd5000acd --- /dev/null +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_heterogeneous_decisions.py @@ -0,0 +1,260 @@ +"""Real-graph contract: heterogeneous decisions route correctly across parallel subagents. + +The simple "approve everything" parallel test (see +``test_parallel_resume_command_keying``) proves the routing wires up at all, +but it doesn't exercise the actual production user flow: rejecting one card +while approving another, or editing one action's args before approving the +rest. Those are the decisions ``HumanInTheLoopMiddleware`` differentiates on, +and they're exactly where a slicer/router bug silently mis-applies a reject +to the wrong subagent. + +This module pins: + +1. **Order preservation** across the slice boundary — flat decisions enter + in the order the SSE stream rendered cards; each subagent must receive + only its slice in the original order. +2. **Per-decision metadata pass-through** — ``message`` and ``edited_action`` + payloads must reach the subagent intact (not just the ``type`` discriminator). +3. **Off-by-one-sensitive bundle sizes** — both paused subagents have action + counts ``> 1`` (``2`` and ``3``). With those sizes a buggy + ``cursor += 1`` slicer (instead of ``cursor += action_count``) produces a + different B-slice from the correct one, so this test catches the most + common refactor mistake. A ``(1, 2)`` configuration would silently pass + such a bug because ``+= 1`` and ``+= count`` are arithmetically identical + when ``count == 1``. 
+""" + +from __future__ import annotations + +import json +from typing import Annotated + +import pytest +from langchain.tools import ToolRuntime +from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.runnables import RunnableConfig +from langgraph.checkpoint.memory import InMemorySaver +from langgraph.graph import END, START, StateGraph +from langgraph.graph.message import add_messages +from langgraph.types import Command, Send, interrupt +from typing_extensions import TypedDict + +from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import ( + build_lg_resume_map, + collect_pending_tool_calls, + slice_decisions_by_tool_call, +) +from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import ( + build_task_tool_with_parent_config, +) + + +class _SubState(TypedDict, total=False): + messages: list + + +class _DispatchState(TypedDict, total=False): + messages: Annotated[list, add_messages] + tcid: str + desc: str + subtype: str + + +def _build_capturing_subagent(checkpointer: InMemorySaver, *, action_count: int): + """Subagent that pauses with an N-action bundle and on resume records what it received. + + The recorded ``AIMessage`` content is the JSON-serialized resume payload, so + the assertions can inspect exactly which decisions reached this subagent + (vs. its sibling) — including the ``message`` and ``edited_action`` + metadata, not just the ``type``. 
+ """ + + def hitl_node(_state): + decision_payload = interrupt( + { + "action_requests": [ + { + "name": f"act_{i}", + "args": {"i": i}, + "description": f"action {i}", + } + for i in range(action_count) + ], + "review_configs": [ + { + "action_name": f"act_{i}", + "allowed_decisions": ["approve", "reject", "edit"], + } + for i in range(action_count) + ], + } + ) + return { + "messages": [ + AIMessage(content=json.dumps(decision_payload, sort_keys=True)) + ] + } + + g = StateGraph(_SubState) + g.add_node("hitl", hitl_node) + g.add_edge(START, "hitl") + g.add_edge("hitl", END) + return g.compile(checkpointer=checkpointer) + + +def _parent_dispatching_two_subagents( + task_tool, *, dispatches: list[dict[str, str]], checkpointer +): + """Parent that fans out to ``len(dispatches)`` parallel ``task`` tool calls. + + Each entry in ``dispatches`` is ``{"tcid": ..., "subtype": ..., "desc": ...}`` + so different parallel branches can target different subagent types — the + actual production scenario (Linear + Jira, etc.). + """ + + def fanout(_state) -> list[Send]: + return [Send("call_task", d) for d in dispatches] + + async def call_task(state: _DispatchState, config: RunnableConfig): + rt = ToolRuntime( + state=state, + config=config, + context=None, + stream_writer=None, + tool_call_id=state["tcid"], + store=None, + ) + return await task_tool.coroutine( + description=state["desc"], subagent_type=state["subtype"], runtime=rt + ) + + g = StateGraph(_DispatchState) + g.add_node("call_task", call_task) + g.add_conditional_edges(START, fanout, ["call_task"]) + g.add_edge("call_task", END) + return g.compile(checkpointer=checkpointer) + + +@pytest.mark.asyncio +async def test_heterogeneous_decisions_route_to_correct_subagents_with_metadata_intact(): + """Mixed approve/reject/edit decisions across two parallel subagents. + + Setup chosen so the slicer's cursor arithmetic is sensitive to off-by-one + refactors: + - Sub-A pauses with a 2-action bundle (``act_0``, ``act_1``). 
+ - Sub-B pauses with a 3-action bundle (``act_0``, ``act_1``, ``act_2``). + - Parent ends up with 2 pending interrupts (one per subagent). + + With both counts ``> 1``, a buggy ``cursor += 1`` (instead of + ``cursor += action_count``) produces a different B-slice from the correct + one, so the assertions catch it. A ``(1, 2)`` configuration would not + because ``+= 1`` and ``+= count`` are arithmetically identical when + ``count == 1``. + + The frontend submits a flat + ``[A_approve, A_reject, B_edit, B_approve, B_reject]`` list with distinct + ``message`` and ``edited_action`` payloads; our slicer must split into + ``{tcid_A: [A_approve, A_reject], tcid_B: [B_edit, B_approve, B_reject]}`` + and the bridge must forward each subagent's slice intact — including all + metadata, in original order. + """ + checkpointer = InMemorySaver() + + sub_a = _build_capturing_subagent(checkpointer, action_count=2) + sub_b = _build_capturing_subagent(checkpointer, action_count=3) + + task_tool = build_task_tool_with_parent_config( + [ + {"name": "agent-a", "description": "first", "runnable": sub_a}, + {"name": "agent-b", "description": "second", "runnable": sub_b}, + ] + ) + + parent = _parent_dispatching_two_subagents( + task_tool, + dispatches=[ + {"tcid": "tcid-A", "subtype": "agent-a", "desc": "do A"}, + {"tcid": "tcid-B", "subtype": "agent-b", "desc": "do B"}, + ], + checkpointer=checkpointer, + ) + + config: dict = { + "configurable": {"thread_id": "het-decisions-thread"}, + "recursion_limit": 100, + } + await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, config) + + paused_state = await parent.aget_state(config) + assert len(paused_state.interrupts) == 2, ( + f"fixture broken: expected 2 paused subagents, got {len(paused_state.interrupts)}" + ) + + pending = collect_pending_tool_calls(paused_state) + pending_by_tcid = dict(pending) + assert pending_by_tcid == {"tcid-A": 2, "tcid-B": 3}, ( + f"REGRESSION: action-count accounting wrong; got {pending_by_tcid!r}" 
+ ) + + a_approve = {"type": "approve"} + a_reject = {"type": "reject", "message": "A[1] looks redundant"} + b_edit = { + "type": "edit", + "edited_action": {"name": "act_0", "args": {"i": 0, "edited": True}}, + } + b_approve = {"type": "approve"} + b_reject = {"type": "reject", "message": "B[2] needs more context"} + flat_decisions = [a_approve, a_reject, b_edit, b_approve, b_reject] + + by_tool_call_id = slice_decisions_by_tool_call(flat_decisions, pending) + + assert by_tool_call_id == { + "tcid-A": {"decisions": [a_approve, a_reject]}, + "tcid-B": {"decisions": [b_edit, b_approve, b_reject]}, + }, f"REGRESSION: slicer mis-routed decisions: {by_tool_call_id!r}" + + config["configurable"]["surfsense_resume_value"] = by_tool_call_id + lg_resume_map = build_lg_resume_map(paused_state, by_tool_call_id) + + await parent.ainvoke(Command(resume=lg_resume_map), config) + + final_state = await parent.aget_state(config) + assert not final_state.interrupts, ( + f"REGRESSION: leftover pending interrupts after resume: {final_state.interrupts!r}" + ) + + payloads: list[dict] = [] + for msg in final_state.values.get("messages", []) or []: + content = getattr(msg, "content", None) + if isinstance(content, str): + try: + payloads.append(json.loads(content)) + except json.JSONDecodeError: + pass + + expected_a = {"decisions": [a_approve, a_reject]} + expected_b = {"decisions": [b_edit, b_approve, b_reject]} + + assert expected_a in payloads, ( + f"REGRESSION: sub-A did not receive its 2-decision slice in original order; " + f"payloads seen: {payloads!r}" + ) + assert expected_b in payloads, ( + f"REGRESSION: sub-B did not receive its 3-decision slice in original order; " + f"payloads seen: {payloads!r}" + ) + + +@pytest.mark.asyncio +async def test_decision_count_mismatch_fails_loud_before_dispatch(): + """The slicer must refuse a flat list whose total != sum(action_counts). 
+ + Otherwise a frontend/backend contract drift would silently send a + truncated/padded slice to one of the subagents — the worst possible + failure mode (mis-applied reject on a long-lived ticket). + """ + pending = [("tcid-A", 1), ("tcid-B", 2)] + decisions = [{"type": "approve"}, {"type": "approve"}] + + with pytest.raises(ValueError, match="Decision count mismatch"): + slice_decisions_by_tool_call(decisions, pending) diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_partial_pause_routing.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_partial_pause_routing.py new file mode 100644 index 000000000..79210032b --- /dev/null +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_partial_pause_routing.py @@ -0,0 +1,254 @@ +"""Real-graph contract: one parallel branch completes while a sibling pauses with HITL. + +The two existing parallel-routing tests +(``test_parallel_resume_command_keying`` and +``test_parallel_heterogeneous_decisions``) both pause **all** branches +simultaneously. That's the easy case — every dispatched ``task`` call has a +matching pending interrupt, and the routing helpers see a uniform shape. + +Production rarely matches that uniform shape. The orchestrator typically +delegates "create a Linear ticket and summarize the user's recent activity": +one branch needs HITL, the other returns its result and exits. At the pause +moment:: + + state.values["messages"] += [ToolMessage(from-A)] # A merged in + state.interrupts = [Interrupt(value-from-B)] # B alone is pending + +So ``len(state.interrupts) < num_dispatched_tasks``. The slicer and +``build_lg_resume_map`` must: + +1. 
**Key off ``state.interrupts``, never off the originally dispatched tcids.** + A flat decisions list of length 1 must route only to B; if anything tries + to look up A in the resume map, langgraph rejects an unknown + ``Interrupt.id``. +2. **Leave A's contributions intact across resume.** A's ToolMessage was + committed at the pause; resuming the paused branch must not re-run A nor + drop its message. +3. **Drain the single pending interrupt.** Final ``state.interrupts`` is + empty regardless of whether sibling branches were paused. + +The langgraph semantics this test relies on were verified empirically in the +exploratory probe before this test was authored. +""" + +from __future__ import annotations + +import json +from typing import Annotated + +import pytest +from langchain.tools import ToolRuntime +from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.runnables import RunnableConfig +from langgraph.checkpoint.memory import InMemorySaver +from langgraph.graph import END, START, StateGraph +from langgraph.graph.message import add_messages +from langgraph.types import Command, Send, interrupt +from typing_extensions import TypedDict + +from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import ( + build_lg_resume_map, + collect_pending_tool_calls, + slice_decisions_by_tool_call, +) +from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import ( + build_task_tool_with_parent_config, +) + + +class _SubState(TypedDict, total=False): + messages: Annotated[list, add_messages] + + +class _DispatchState(TypedDict, total=False): + messages: Annotated[list, add_messages] + tcid: str + desc: str + subtype: str + + +_QUICK_MARKER = "quick-subagent-finished-without-pausing" + + +def _build_quick_subagent(checkpointer: InMemorySaver): + """Subagent that completes synchronously without firing any interrupt.""" + + def quick_node(_state): + return 
{"messages": [AIMessage(content=_QUICK_MARKER)]} + + g = StateGraph(_SubState) + g.add_node("quick", quick_node) + g.add_edge(START, "quick") + g.add_edge("quick", END) + return g.compile(checkpointer=checkpointer) + + +def _build_pausing_subagent(checkpointer: InMemorySaver): + """Subagent that pauses with a single-action HITL bundle and records its resume payload.""" + + def hitl_node(_state): + decision = interrupt( + { + "action_requests": [ + {"name": "act_0", "args": {"i": 0}, "description": ""} + ], + "review_configs": [ + { + "action_name": "act_0", + "allowed_decisions": ["approve", "reject", "edit"], + } + ], + } + ) + return {"messages": [AIMessage(content=json.dumps(decision, sort_keys=True))]} + + g = StateGraph(_SubState) + g.add_node("hitl", hitl_node) + g.add_edge(START, "hitl") + g.add_edge("hitl", END) + return g.compile(checkpointer=checkpointer) + + +def _parent_with_two_branches(task_tool, *, dispatches, checkpointer): + def fanout(_state) -> list[Send]: + return [Send("call_task", d) for d in dispatches] + + async def call_task(state: _DispatchState, config: RunnableConfig): + rt = ToolRuntime( + state=state, + config=config, + context=None, + stream_writer=None, + tool_call_id=state["tcid"], + store=None, + ) + return await task_tool.coroutine( + description=state["desc"], subagent_type=state["subtype"], runtime=rt + ) + + g = StateGraph(_DispatchState) + g.add_node("call_task", call_task) + g.add_conditional_edges(START, fanout, ["call_task"]) + g.add_edge("call_task", END) + return g.compile(checkpointer=checkpointer) + + +def _quick_marker_count(state) -> int: + """How many messages anywhere in parent state contain the quick subagent's marker.""" + n = 0 + for msg in state.values.get("messages", []) or []: + content = getattr(msg, "content", "") + if isinstance(content, str) and _QUICK_MARKER in content: + n += 1 + return n + + +@pytest.mark.asyncio +async def 
test_partial_pause_routes_only_to_paused_branch_without_rerunning_completed_one(): + """One branch completes synchronously; the other pauses with HITL — resume routes only to B. + + Setup: + - Sub-A (``quick``): no interrupt, finishes immediately, writes a marker + message to parent state. + - Sub-B (``pausing``): interrupts with a 1-action HITL bundle. + + At pause, parent state has A's marker already merged in and exactly one + pending interrupt (B's). Resume sends a 1-element flat decisions list; + the routing helpers must not look up A in the resume map (would explode + with an unknown ``Interrupt.id``) and must not re-invoke A on resume + (would duplicate the marker). + """ + checkpointer = InMemorySaver() + + quick_sub = _build_quick_subagent(checkpointer) + pausing_sub = _build_pausing_subagent(checkpointer) + + task_tool = build_task_tool_with_parent_config( + [ + {"name": "quick-agent", "description": "instant", "runnable": quick_sub}, + { + "name": "pausing-agent", + "description": "needs review", + "runnable": pausing_sub, + }, + ] + ) + + parent = _parent_with_two_branches( + task_tool, + dispatches=[ + {"tcid": "tcid-A", "subtype": "quick-agent", "desc": "do A fast"}, + { + "tcid": "tcid-B", + "subtype": "pausing-agent", + "desc": "needs approval", + }, + ], + checkpointer=checkpointer, + ) + + config: dict = { + "configurable": {"thread_id": "partial-pause-thread"}, + "recursion_limit": 100, + } + await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, config) + + paused = await parent.aget_state(config) + + assert len(paused.interrupts) == 1, ( + f"REGRESSION: expected exactly 1 pending interrupt (sub-B alone), " + f"got {len(paused.interrupts)}" + ) + + pending = collect_pending_tool_calls(paused) + assert pending == [("tcid-B", 1)], ( + f"REGRESSION: pending list contains stale tcids; got {pending!r}" + ) + + pre_resume_marker_count = _quick_marker_count(paused) + assert pre_resume_marker_count == 1, ( + f"REGRESSION: sub-A's 
contribution missing or duplicated at pause " + f"(found {pre_resume_marker_count}, expected 1)" + ) + + flat_decisions = [{"type": "approve"}] + by_tool_call_id = slice_decisions_by_tool_call(flat_decisions, pending) + assert by_tool_call_id == {"tcid-B": {"decisions": [{"type": "approve"}]}}, ( + f"REGRESSION: slicer routed to a non-pending tcid: {by_tool_call_id!r}" + ) + + config["configurable"]["surfsense_resume_value"] = by_tool_call_id + lg_resume_map = build_lg_resume_map(paused, by_tool_call_id) + + assert set(lg_resume_map.keys()) == {paused.interrupts[0].id}, ( + f"REGRESSION: resume map keyed by an unknown Interrupt.id " + f"(would crash langgraph): {lg_resume_map!r}" + ) + + await parent.ainvoke(Command(resume=lg_resume_map), config) + + final = await parent.aget_state(config) + assert not final.interrupts, ( + f"REGRESSION: pending interrupts after resume: {final.interrupts!r}" + ) + + post_resume_marker_count = _quick_marker_count(final) + assert post_resume_marker_count == 1, ( + f"REGRESSION: sub-A re-ran on resume (marker count went " + f"{pre_resume_marker_count} → {post_resume_marker_count}); " + f"resume must touch only the paused branch." 
+ ) + + payloads: list[dict] = [] + for msg in final.values.get("messages", []) or []: + content = getattr(msg, "content", None) + if isinstance(content, str): + try: + payloads.append(json.loads(content)) + except json.JSONDecodeError: + pass + + assert {"decisions": [{"type": "approve"}]} in payloads, ( + f"REGRESSION: sub-B did not receive its single approve on resume; " + f"payloads seen: {payloads!r}" + ) diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_reject_only_routing.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_reject_only_routing.py new file mode 100644 index 000000000..f4ee947c6 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_reject_only_routing.py @@ -0,0 +1,216 @@ +"""Real-graph contract: all-reject decisions route correctly across parallel subagents. + +Heterogeneous routing is covered by ``test_parallel_heterogeneous_decisions``. +This module pins the narrower edge case where **every** card on **every** +paused subagent is rejected. + +Why a separate pin: + +1. **No approval-bias in the slicer.** A future "if no approvals, short-circuit + resume" optimization would be tempting (skips a langgraph round-trip) and + would also silently break this scenario. Pin it. +2. **``message`` metadata pass-through across a run of rejects.** The reject + ``message`` is the user-visible reason ("looks suspicious", "duplicate", + etc.). Losing it would silently swallow user intent — the worst UX + failure mode for HITL. Heterogeneous covers one reject; here we verify a + sequence of rejects survives the slicer + bridge with distinct messages + intact and in order. +3. **All branches complete with no leftover pending.** Even when nothing was + approved, the parent must drain every paused subagent so the SSE stream + can close cleanly. 
A bug that left one ``Interrupt.id`` un-keyed would + strand the conversation in "pending" forever. +""" + +from __future__ import annotations + +import json +from typing import Annotated + +import pytest +from langchain.tools import ToolRuntime +from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.runnables import RunnableConfig +from langgraph.checkpoint.memory import InMemorySaver +from langgraph.graph import END, START, StateGraph +from langgraph.graph.message import add_messages +from langgraph.types import Command, Send, interrupt +from typing_extensions import TypedDict + +from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import ( + build_lg_resume_map, + collect_pending_tool_calls, + slice_decisions_by_tool_call, +) +from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import ( + build_task_tool_with_parent_config, +) + + +class _SubState(TypedDict, total=False): + messages: list + + +class _DispatchState(TypedDict, total=False): + messages: Annotated[list, add_messages] + tcid: str + desc: str + subtype: str + + +def _build_recording_subagent(checkpointer: InMemorySaver, *, action_count: int): + """Subagent that pauses with ``action_count`` actions and records its resume payload. + + The recorded ``AIMessage`` content is the JSON-serialized payload, so the + test can match each subagent's slice by content. 
+ """ + + def hitl_node(_state): + decision_payload = interrupt( + { + "action_requests": [ + {"name": f"act_{i}", "args": {"i": i}, "description": ""} + for i in range(action_count) + ], + "review_configs": [ + { + "action_name": f"act_{i}", + "allowed_decisions": ["approve", "reject", "edit"], + } + for i in range(action_count) + ], + } + ) + return { + "messages": [ + AIMessage(content=json.dumps(decision_payload, sort_keys=True)) + ] + } + + g = StateGraph(_SubState) + g.add_node("hitl", hitl_node) + g.add_edge(START, "hitl") + g.add_edge("hitl", END) + return g.compile(checkpointer=checkpointer) + + +def _parent_two_branches(task_tool, *, dispatches, checkpointer): + def fanout(_state) -> list[Send]: + return [Send("call_task", d) for d in dispatches] + + async def call_task(state: _DispatchState, config: RunnableConfig): + rt = ToolRuntime( + state=state, + config=config, + context=None, + stream_writer=None, + tool_call_id=state["tcid"], + store=None, + ) + return await task_tool.coroutine( + description=state["desc"], subagent_type=state["subtype"], runtime=rt + ) + + g = StateGraph(_DispatchState) + g.add_node("call_task", call_task) + g.add_conditional_edges(START, fanout, ["call_task"]) + g.add_edge("call_task", END) + return g.compile(checkpointer=checkpointer) + + +@pytest.mark.asyncio +async def test_all_reject_decisions_route_to_each_subagent_with_messages_intact(): + """All cards rejected across two parallel subagents — order and messages preserved. + + Setup mirrors a real "user reviews two parallel ticket creations and + rejects everything with distinct reasons": + + - Sub-A pauses with 2 actions. + - Sub-B pauses with 1 action. + - Flat decisions: 3 rejects, each with a unique ``message``. + + Asserts each subagent receives only its slice, in original order, + with every ``message`` intact and no ``edited_action`` fields fabricated. 
+ """ + checkpointer = InMemorySaver() + + sub_a = _build_recording_subagent(checkpointer, action_count=2) + sub_b = _build_recording_subagent(checkpointer, action_count=1) + + task_tool = build_task_tool_with_parent_config( + [ + {"name": "agent-a", "description": "first", "runnable": sub_a}, + {"name": "agent-b", "description": "second", "runnable": sub_b}, + ] + ) + + parent = _parent_two_branches( + task_tool, + dispatches=[ + {"tcid": "tcid-A", "subtype": "agent-a", "desc": "do A"}, + {"tcid": "tcid-B", "subtype": "agent-b", "desc": "do B"}, + ], + checkpointer=checkpointer, + ) + + config: dict = { + "configurable": {"thread_id": "all-reject-thread"}, + "recursion_limit": 100, + } + await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, config) + + paused_state = await parent.aget_state(config) + assert len(paused_state.interrupts) == 2, ( + f"fixture broken: expected 2 paused subagents, got {len(paused_state.interrupts)}" + ) + + a_reject_0 = {"type": "reject", "message": "A[0] looks suspicious"} + a_reject_1 = {"type": "reject", "message": "A[1] duplicates A[0]"} + b_reject_0 = {"type": "reject", "message": "B[0] needs more context"} + flat_decisions = [a_reject_0, a_reject_1, b_reject_0] + + pending = collect_pending_tool_calls(paused_state) + by_tool_call_id = slice_decisions_by_tool_call(flat_decisions, pending) + + assert by_tool_call_id == { + "tcid-A": {"decisions": [a_reject_0, a_reject_1]}, + "tcid-B": {"decisions": [b_reject_0]}, + }, f"REGRESSION: slicer mis-routed all-reject decisions: {by_tool_call_id!r}" + + config["configurable"]["surfsense_resume_value"] = by_tool_call_id + lg_resume_map = build_lg_resume_map(paused_state, by_tool_call_id) + + await parent.ainvoke(Command(resume=lg_resume_map), config) + + final_state = await parent.aget_state(config) + assert not final_state.interrupts, ( + f"REGRESSION: leftover pending interrupts after all-reject resume: " + f"{final_state.interrupts!r}" + ) + + payloads: list[dict] = [] + for 
msg in final_state.values.get("messages", []) or []: + content = getattr(msg, "content", None) + if isinstance(content, str): + try: + payloads.append(json.loads(content)) + except json.JSONDecodeError: + pass + + expected_a = {"decisions": [a_reject_0, a_reject_1]} + expected_b = {"decisions": [b_reject_0]} + + assert expected_a in payloads, ( + f"REGRESSION: sub-A did not receive its 2-reject slice in order; " + f"payloads seen: {payloads!r}" + ) + assert expected_b in payloads, ( + f"REGRESSION: sub-B did not receive its single reject; " + f"payloads seen: {payloads!r}" + ) + + for p in payloads: + for d in p.get("decisions", []): + assert "edited_action" not in d, ( + f"REGRESSION: spurious ``edited_action`` on a reject — " + f"slicer/bridge mutated payload: {d!r}" + ) diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_resume_command_keying.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_resume_command_keying.py new file mode 100644 index 000000000..458a2539b --- /dev/null +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_resume_command_keying.py @@ -0,0 +1,301 @@ +"""Real-graph contract: parallel resume must key ``Command(resume=...)`` by ``Interrupt.id``. + +When the parent state has multiple pending interrupts, langgraph rejects a +scalar ``Command(resume=v)`` with:: + + RuntimeError: When there are multiple pending interrupts, you must specify + the interrupt id when resuming. + +The fix is to map each ``Interrupt.id`` from ``state.interrupts`` to the +per-subagent slice — orthogonal to our ``tool_call_id``-keyed +``surfsense_resume_value`` side-channel (different consumer: langgraph's +pregel vs. our subagent bridge). 
+ +This test reproduces the production failure with a real two-task parallel +``Send`` parent graph, exercises the full resume cycle, and asserts both +subagents complete cleanly with their per-subagent slice intact. + +Bundle sizes are chosen heterogeneous (``2`` and ``3``) so the assertions +also catch slicer arithmetic regressions (e.g., ``cursor += 1`` instead of +``cursor += action_count``). A symmetric ``(1, 1)`` configuration would +silently pass such a bug because the slices would coincide. +""" + +from __future__ import annotations + +import json +from typing import Annotated + +import pytest +from langchain.tools import ToolRuntime +from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.runnables import RunnableConfig +from langgraph.checkpoint.memory import InMemorySaver +from langgraph.graph import END, START, StateGraph +from langgraph.graph.message import add_messages +from langgraph.types import Command, Send, interrupt +from typing_extensions import TypedDict + +from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import ( + build_lg_resume_map, + collect_pending_tool_calls, + slice_decisions_by_tool_call, +) +from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import ( + build_task_tool_with_parent_config, +) + + +class _SubState(TypedDict, total=False): + messages: list + + +class _DispatchState(TypedDict, total=False): + # ``add_messages`` reducer matches production agent state shape and is + # required when two parallel ``Send`` branches both write to ``messages`` + # in the same superstep (post-resume both subagents return their own + # ``{"messages": [...]}``). Without a reducer langgraph raises + # ``InvalidUpdateError: At key 'messages': Can receive only one value``. 
+ messages: Annotated[list, add_messages] + tcid: str + desc: str + subtype: str + + +def _build_pausing_subagent(checkpointer: InMemorySaver, *, action_count: int): + """Subagent that pauses with an ``action_count``-action HITL bundle. + + On resume it captures the decision payload as a JSON-serialized + ``AIMessage`` content so the test can inspect exactly which slice + reached this subagent — the strongest assertion against slicer + routing regressions. + """ + + def approve_node(_state): + decision = interrupt( + { + "action_requests": [ + {"name": f"act_{i}", "args": {"i": i}, "description": ""} + for i in range(action_count) + ], + "review_configs": [{} for _ in range(action_count)], + } + ) + return { + "messages": [AIMessage(content=json.dumps(decision, sort_keys=True))] + } + + g = StateGraph(_SubState) + g.add_node("approve", approve_node) + g.add_edge(START, "approve") + g.add_edge("approve", END) + return g.compile(checkpointer=checkpointer) + + +def _parent_graph_dispatching_two_tasks_via_send( + task_tool, *, tool_call_id_a: str, tool_call_id_b: str, checkpointer +): + def fanout_edge(_state) -> list[Send]: + return [ + Send( + "call_task", + {"tcid": tool_call_id_a, "desc": "approve A", "subtype": "agent-a"}, + ), + Send( + "call_task", + {"tcid": tool_call_id_b, "desc": "approve B", "subtype": "agent-b"}, + ), + ] + + async def call_task(state: _DispatchState, config: RunnableConfig): + rt = ToolRuntime( + state=state, + config=config, + context=None, + stream_writer=None, + tool_call_id=state["tcid"], + store=None, + ) + return await task_tool.coroutine( + description=state["desc"], subagent_type=state["subtype"], runtime=rt + ) + + g = StateGraph(_DispatchState) + g.add_node("call_task", call_task) + g.add_conditional_edges(START, fanout_edge, ["call_task"]) + g.add_edge("call_task", END) + return g.compile(checkpointer=checkpointer) + + +def _build_two_subagents_task_tool(checkpointer: InMemorySaver): + """Register two subagents under distinct 
names with heterogeneous bundle sizes. + + Sub-A: 2-action bundle. Sub-B: 3-action bundle. Both ``> 1`` so the slice + arithmetic is sensitive to off-by-one mistakes. + """ + sub_a = _build_pausing_subagent(checkpointer, action_count=2) + sub_b = _build_pausing_subagent(checkpointer, action_count=3) + return build_task_tool_with_parent_config( + [ + {"name": "agent-a", "description": "first", "runnable": sub_a}, + {"name": "agent-b", "description": "second", "runnable": sub_b}, + ] + ) + + +@pytest.mark.asyncio +async def test_parallel_resume_with_command_resume_scalar_raises_lg_runtime_error(): + """Confirm the production failure mode: scalar resume on multi-pending state explodes. + + This is a contract pin: if langgraph relaxes the requirement in a future + release, this test starts passing and we know we can simplify + ``stream_resume_chat``. Until then, the keyed form is mandatory. + """ + checkpointer = InMemorySaver() + task_tool = _build_two_subagents_task_tool(checkpointer) + parent = _parent_graph_dispatching_two_tasks_via_send( + task_tool, + tool_call_id_a="parent-tcid-A", + tool_call_id_b="parent-tcid-B", + checkpointer=checkpointer, + ) + config: dict = { + "configurable": {"thread_id": "parallel-resume-scalar"}, + "recursion_limit": 100, + } + await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, config) + + with pytest.raises(RuntimeError, match="multiple pending interrupts"): + await parent.ainvoke(Command(resume={"decisions": ["A"]}), config) + + +@pytest.mark.asyncio +async def test_parallel_resume_with_per_interrupt_id_keying_completes_both_subagents(): + """Production-shape resume: builds the langgraph-keyed map and resumes both subagents. + + Mirrors what ``stream_resume_chat`` does: collects pending interrupts, + slices the flat decisions list by ``tool_call_id``, builds the + ``Interrupt.id``-keyed map for ``Command(resume=...)``, and resumes. + + Post-conditions checked: + 1. 
The langgraph-keyed map has exactly one entry per pending interrupt + id (``str`` keys, count matches). + 2. Both subagents complete with no leftover pending interrupts. + 3. **Each subagent receives its exact slice in the original order** — + this catches slicer arithmetic regressions (e.g., ``cursor += 1``) + that wouldn't surface by checking only "no leftover pending". + """ + checkpointer = InMemorySaver() + task_tool = _build_two_subagents_task_tool(checkpointer) + tcid_a = "parent-tcid-A" + tcid_b = "parent-tcid-B" + parent = _parent_graph_dispatching_two_tasks_via_send( + task_tool, + tool_call_id_a=tcid_a, + tool_call_id_b=tcid_b, + checkpointer=checkpointer, + ) + config: dict = { + "configurable": {"thread_id": "parallel-resume-keyed"}, + "recursion_limit": 100, + } + await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, config) + + paused_state = await parent.aget_state(config) + assert len(paused_state.interrupts) == 2, "fixture broken: expected 2 paused subagents" + + pending = collect_pending_tool_calls(paused_state) + assert dict(pending) == {tcid_a: 2, tcid_b: 3}, ( + f"fixture broken: heterogeneous bundle sizes not detected; got {pending!r}" + ) + + a_d0 = {"type": "approve"} + a_d1 = {"type": "reject", "message": "A[1] is redundant"} + b_d0 = { + "type": "edit", + "edited_action": {"name": "act_0", "args": {"i": 0, "edited": True}}, + } + b_d1 = {"type": "approve"} + b_d2 = {"type": "reject", "message": "B[2] needs more context"} + flat_decisions = [a_d0, a_d1, b_d0, b_d1, b_d2] + + by_tool_call_id = slice_decisions_by_tool_call(flat_decisions, pending) + lg_resume_map = build_lg_resume_map(paused_state, by_tool_call_id) + + assert len(lg_resume_map) == 2, ( + f"expected one entry per pending interrupt id, got {lg_resume_map!r}" + ) + assert all(isinstance(k, str) for k in lg_resume_map), ( + f"keys must be Interrupt.id strings, got {[type(k).__name__ for k in lg_resume_map]}" + ) + + config["configurable"]["surfsense_resume_value"] = 
by_tool_call_id + + await parent.ainvoke(Command(resume=lg_resume_map), config) + + final_state = await parent.aget_state(config) + assert not final_state.interrupts, ( + f"expected no leftover pending interrupts after resume, got " + f"{final_state.interrupts!r}" + ) + + payloads: list[dict] = [] + for msg in final_state.values.get("messages", []) or []: + content = getattr(msg, "content", None) + if isinstance(content, str): + try: + payloads.append(json.loads(content)) + except json.JSONDecodeError: + pass + + expected_a = {"decisions": [a_d0, a_d1]} + expected_b = {"decisions": [b_d0, b_d1, b_d2]} + assert expected_a in payloads, ( + f"REGRESSION: sub-A did not receive its 2-decision slice in order; " + f"payloads seen: {payloads!r}" + ) + assert expected_b in payloads, ( + f"REGRESSION: sub-B did not receive its 3-decision slice in order; " + f"payloads seen: {payloads!r}" + ) + + +def test_build_lg_resume_map_returns_empty_when_no_interrupts_carry_stamps(): + """Unstamped interrupts can't be routed; we don't fabricate keys for them. + + If a regression lets an unstamped interrupt reach the parent state, the + empty map propagates to the call site and surfaces as a clear count + mismatch instead of a silent mis-route. + """ + from types import SimpleNamespace + + fake_interrupt = SimpleNamespace(id="i-foreign", value={"action_requests": [{}]}) + state = SimpleNamespace(interrupts=(fake_interrupt,)) + + assert build_lg_resume_map(state, {"some-tcid": {"decisions": ["x"]}}) == {} + + +def test_build_lg_resume_map_skips_interrupts_without_corresponding_slice(): + """Skip rather than silently mis-route when the slice and interrupts disagree. + + Only emit a resume entry when both an interrupt id and a tool_call_id + slice are present; a mismatch indicates upstream contract drift and + should not be papered over. 
+ """ + from types import SimpleNamespace + + state = SimpleNamespace( + interrupts=( + SimpleNamespace( + id="i-A", + value={"action_requests": [{}], "tool_call_id": "tcid-A"}, + ), + SimpleNamespace( + id="i-B", + value={"action_requests": [{}], "tool_call_id": "tcid-B"}, + ), + ) + ) + + out = build_lg_resume_map(state, {"tcid-A": {"decisions": ["only-A"]}}) + assert out == {"i-A": {"decisions": ["only-A"]}} diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_self_and_middleware_gated.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_self_and_middleware_gated.py new file mode 100644 index 000000000..57969e8fa --- /dev/null +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_self_and_middleware_gated.py @@ -0,0 +1,272 @@ +"""Real-graph parallel HITL across both approval kinds — the keystone regression. + +Pre-fix bug: the parallel-HITL routing layer (``collect_pending_tool_calls`` ++ ``slice_decisions_by_tool_call`` + ``build_lg_resume_map``) only +recognized middleware-gated approvals (LC HITL shape from +``HumanInTheLoopMiddleware``). Self-gated approvals from +``request_approval`` and middleware-gated permission asks from +``PermissionMiddleware`` both used the SurfSense-specific +``{type, action, context}`` shape, so when the orchestrator dispatched +two parallel ``task`` calls — one self-gated, one middleware-gated — only +one interrupt was visible to the routing layer and resume crashed with +``Decision count mismatch``. + +This test fans out two real subagents via ``Send``: one calls +``request_approval`` (self-gated), the other calls +``request_permission_decision`` (middleware-gated). 
Both pause; the routing +layer must see TWO LC HITL interrupts, slice the decisions by +``tool_call_id``, key by ``Interrupt.id``, and resume both branches with +their per-slice payload. +""" + +from __future__ import annotations + +import json +from typing import Annotated + +import pytest +from langchain.tools import ToolRuntime +from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.runnables import RunnableConfig +from langgraph.checkpoint.memory import InMemorySaver +from langgraph.graph import END, START, StateGraph +from langgraph.graph.message import add_messages +from langgraph.types import Command, Send +from typing_extensions import TypedDict + +from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import ( + build_lg_resume_map, + collect_pending_tool_calls, + slice_decisions_by_tool_call, +) +from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import ( + build_task_tool_with_parent_config, +) +from app.agents.multi_agent_chat.middleware.shared.permissions.ask.request import ( + request_permission_decision, +) +from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( + request_approval, +) +from app.agents.new_chat.permissions import Rule + + +class _SubState(TypedDict, total=False): + messages: list + + +class _DispatchState(TypedDict, total=False): + # ``add_messages`` is mandatory: parallel ``Send`` branches both append + # to ``messages`` in the same superstep; without a reducer langgraph + # raises ``InvalidUpdateError``. 
+ messages: Annotated[list, add_messages] + tcid: str + desc: str + subtype: str + + +def _build_self_gated_subagent(checkpointer: InMemorySaver): + """Subagent that pauses via ``request_approval`` (self-gated path).""" + + def gate_node(_state): + result = request_approval( + action_type="gmail_email_send", + tool_name="send_gmail_email", + params={"to": "alice@example.com"}, + ) + return { + "messages": [ + AIMessage( + content=json.dumps( + { + "kind": "self_gated", + "decision_type": result.decision_type, + "params": result.params, + "rejected": result.rejected, + }, + sort_keys=True, + ) + ) + ] + } + + g = StateGraph(_SubState) + g.add_node("gate", gate_node) + g.add_edge(START, "gate") + g.add_edge("gate", END) + return g.compile(checkpointer=checkpointer) + + +def _build_middleware_gated_subagent(checkpointer: InMemorySaver): + """Subagent that pauses via ``request_permission_decision`` (middleware-gated path).""" + + def perm_node(_state): + decision = request_permission_decision( + tool_name="rm", + args={"path": "/tmp/file"}, + patterns=["rm/*"], + rules=[Rule(permission="rm", pattern="*", action="ask")], + emit_interrupt=True, + ) + return { + "messages": [ + AIMessage( + content=json.dumps( + {"kind": "middleware_gated", "decision": decision}, + sort_keys=True, + ) + ) + ] + } + + g = StateGraph(_SubState) + g.add_node("perm", perm_node) + g.add_edge(START, "perm") + g.add_edge("perm", END) + return g.compile(checkpointer=checkpointer) + + +def _build_mixed_task_tool(checkpointer: InMemorySaver): + """Two subagents, one per approval kind, registered under distinct names.""" + return build_task_tool_with_parent_config( + [ + { + "name": "self-gated-agent", + "description": "uses request_approval", + "runnable": _build_self_gated_subagent(checkpointer), + }, + { + "name": "middleware-gated-agent", + "description": "uses request_permission_decision", + "runnable": _build_middleware_gated_subagent(checkpointer), + }, + ] + ) + + +def 
_parent_dispatching_one_of_each( + task_tool, *, tcid_self: str, tcid_mw: str, checkpointer +): + def fanout_edge(_state) -> list[Send]: + return [ + Send( + "call_task", + {"tcid": tcid_self, "desc": "approve email", "subtype": "self-gated-agent"}, + ), + Send( + "call_task", + { + "tcid": tcid_mw, + "desc": "approve rm", + "subtype": "middleware-gated-agent", + }, + ), + ] + + async def call_task(state: _DispatchState, config: RunnableConfig): + rt = ToolRuntime( + state=state, + config=config, + context=None, + stream_writer=None, + tool_call_id=state["tcid"], + store=None, + ) + return await task_tool.coroutine( + description=state["desc"], subagent_type=state["subtype"], runtime=rt + ) + + g = StateGraph(_DispatchState) + g.add_node("call_task", call_task) + g.add_conditional_edges(START, fanout_edge, ["call_task"]) + g.add_edge("call_task", END) + return g.compile(checkpointer=checkpointer) + + +@pytest.mark.asyncio +async def test_parallel_self_gated_and_middleware_gated_route_and_resume_cleanly(): + """Both interrupt kinds must reach the routing layer in LC HITL shape and resume independently.""" + checkpointer = InMemorySaver() + task_tool = _build_mixed_task_tool(checkpointer) + tcid_self = "tcid-self-gated" + tcid_mw = "tcid-middleware-gated" + parent = _parent_dispatching_one_of_each( + task_tool, + tcid_self=tcid_self, + tcid_mw=tcid_mw, + checkpointer=checkpointer, + ) + config: dict = { + "configurable": {"thread_id": "mixed-parallel"}, + "recursion_limit": 100, + } + await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, config) + + paused = await parent.aget_state(config) + assert len(paused.interrupts) == 2, ( + "fixture broken: expected one paused interrupt per approval kind" + ) + + # Both interrupts must speak the same wire shape — the whole point of + # the unification. If either one regresses to the legacy SurfSense shape + # ``collect_pending_tool_calls`` would silently skip it and the count + # below would be 1. 
+ pending = collect_pending_tool_calls(paused) + assert dict(pending) == {tcid_self: 1, tcid_mw: 1}, ( + f"REGRESSION: not all interrupt kinds reached the routing layer; " + f"got {pending!r}" + ) + + # Verify the actual wire payloads carry the LC HITL standard fields + # (extra defensive assertion against partial regressions where one + # path stamps tool_call_id but reverts the body shape). + interrupt_types = {i.value.get("interrupt_type") for i in paused.interrupts} + assert interrupt_types == {"gmail_email_send", "permission_ask"} + + # Resume order: same order the SSE stream would emit (interrupts list). + decision_self = {"type": "approve"} + decision_mw = {"type": "approve_always"} + flat_decisions = [ + # Match `pending` order. + decision_self if pending[0][0] == tcid_self else decision_mw, + decision_mw if pending[0][0] == tcid_self else decision_self, + ] + by_tool_call_id = slice_decisions_by_tool_call(flat_decisions, pending) + lg_resume_map = build_lg_resume_map(paused, by_tool_call_id) + assert len(lg_resume_map) == 2 + + config["configurable"]["surfsense_resume_value"] = by_tool_call_id + await parent.ainvoke(Command(resume=lg_resume_map), config) + + final = await parent.aget_state(config) + assert not final.interrupts, ( + f"expected both branches resumed, but state still has interrupts: " + f"{final.interrupts!r}" + ) + + # Each subagent must have received its own slice — verify by inspecting + # the JSON-serialized result messages. 
+ payloads: list[dict] = [] + for msg in final.values.get("messages", []) or []: + content = getattr(msg, "content", None) + if isinstance(content, str): + try: + payloads.append(json.loads(content)) + except json.JSONDecodeError: + pass + + self_payloads = [p for p in payloads if p.get("kind") == "self_gated"] + mw_payloads = [p for p in payloads if p.get("kind") == "middleware_gated"] + assert len(self_payloads) == 1, ( + f"self-gated subagent did not complete; payloads: {payloads!r}" + ) + assert len(mw_payloads) == 1, ( + f"middleware-gated subagent did not complete; payloads: {payloads!r}" + ) + + # Self-gated approve → HITLResult(decision_type="approve", rejected=False). + assert self_payloads[0]["decision_type"] == "approve" + assert self_payloads[0]["rejected"] is False + + # Middleware-gated approve_always → canonical permission shape unchanged. + assert mw_payloads[0]["decision"] == {"decision_type": "approve_always"} diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_tasks.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_tasks.py new file mode 100644 index 000000000..9c067fc57 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_tasks.py @@ -0,0 +1,222 @@ +"""Behavioural guarantees for parallel ``task`` tool calls (non-HITL cases). + +The HITL bridge tests in ``test_hitl_bridge.py`` cover the parallel-interrupt +flow. This file covers the *normal* parallel paths (no interrupts) and the +failure-isolation guarantee — together they pin the behaviour we promise the +user about ``asyncio.gather`` over two ``atask`` coroutines. 
+""" + +from __future__ import annotations + +import asyncio + +import pytest +from langchain.tools import ToolRuntime +from langchain_core.messages import AIMessage, HumanMessage, ToolMessage +from langgraph.checkpoint.memory import InMemorySaver +from langgraph.graph import END, START, StateGraph +from langgraph.types import Command +from typing_extensions import TypedDict + +from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import ( + build_task_tool_with_parent_config, +) + + +class _SubState(TypedDict, total=False): + messages: list + + +def _build_success_subagent(reply: str): + """A subagent that completes immediately with ``reply``, never interrupts.""" + + def node(_state): + return {"messages": [AIMessage(content=reply)]} + + g = StateGraph(_SubState) + g.add_node("only", node) + g.add_edge(START, "only") + g.add_edge("only", END) + return g.compile(checkpointer=InMemorySaver()) + + +def _build_failing_subagent(exc: Exception): + """A subagent whose only node raises ``exc`` — simulates a tool-level failure.""" + + def node(_state): + raise exc + + g = StateGraph(_SubState) + g.add_node("only", node) + g.add_edge(START, "only") + g.add_edge("only", END) + return g.compile(checkpointer=InMemorySaver()) + + +def _make_runtime(parent_config: dict, *, tool_call_id: str) -> ToolRuntime: + return ToolRuntime( + state={"messages": [HumanMessage(content="seed")]}, + context=None, + config=parent_config, + stream_writer=None, + tool_call_id=tool_call_id, + store=None, + ) + + +def _tool_message_text(cmd: Command, *, expected_tcid: str) -> str: + """Return the ToolMessage content the task tool produced for ``expected_tcid``.""" + assert isinstance(cmd, Command), f"expected Command, got {type(cmd).__name__}" + messages = cmd.update["messages"] + assert len(messages) == 1, f"expected 1 ToolMessage, got {len(messages)}" + msg = messages[0] + assert isinstance(msg, ToolMessage) + assert msg.tool_call_id == expected_tcid + 
return msg.content + + +@pytest.mark.asyncio +async def test_two_parallel_atasks_to_different_subagents_both_succeed(): + """Normal happy-path: two distinct subagents complete in parallel without interrupting.""" + subagent_a = _build_success_subagent("A is done") + subagent_b = _build_success_subagent("B is done") + task_tool = build_task_tool_with_parent_config( + [ + {"name": "alpha", "description": "alpha agent", "runnable": subagent_a}, + {"name": "beta", "description": "beta agent", "runnable": subagent_b}, + ] + ) + + parent_config: dict = { + "configurable": {"thread_id": "ok-thread"}, + "recursion_limit": 100, + } + runtime_a = _make_runtime(parent_config, tool_call_id="tcid-A") + runtime_b = _make_runtime(parent_config, tool_call_id="tcid-B") + + result_a, result_b = await asyncio.gather( + task_tool.coroutine( + description="do A", + subagent_type="alpha", + runtime=runtime_a, + ), + task_tool.coroutine( + description="do B", + subagent_type="beta", + runtime=runtime_b, + ), + ) + + assert _tool_message_text(result_a, expected_tcid="tcid-A") == "A is done" + assert _tool_message_text(result_b, expected_tcid="tcid-B") == "B is done" + + +@pytest.mark.asyncio +async def test_two_parallel_atasks_same_subagent_type_different_tool_call_ids(): + """Per-call ``thread_id`` isolation: same compiled subagent invoked twice in parallel. + + Both calls share the same ``InMemorySaver`` instance but are namespaced by + distinct ``tool_call_id``s, so checkpoints land in disjoint thread slots. 
+ """ + shared_subagent = _build_success_subagent("ok") + task_tool = build_task_tool_with_parent_config( + [ + {"name": "approver", "description": "shared approver", "runnable": shared_subagent}, + ] + ) + + parent_config: dict = { + "configurable": {"thread_id": "shared-subagent-thread"}, + "recursion_limit": 100, + } + runtime_a = _make_runtime(parent_config, tool_call_id="tcid-A") + runtime_b = _make_runtime(parent_config, tool_call_id="tcid-B") + + result_a, result_b = await asyncio.gather( + task_tool.coroutine( + description="first request", + subagent_type="approver", + runtime=runtime_a, + ), + task_tool.coroutine( + description="second request", + subagent_type="approver", + runtime=runtime_b, + ), + ) + + # Both calls succeed and produce ToolMessages keyed by their own tool_call_id. + assert _tool_message_text(result_a, expected_tcid="tcid-A") == "ok" + assert _tool_message_text(result_b, expected_tcid="tcid-B") == "ok" + + # Verify checkpoint isolation: each call's state lives at its own thread_id. + state_a = await shared_subagent.aget_state( + {"configurable": {"thread_id": "shared-subagent-thread::task:tcid-A"}} + ) + state_b = await shared_subagent.aget_state( + {"configurable": {"thread_id": "shared-subagent-thread::task:tcid-B"}} + ) + assert state_a.values["messages"][-1].content == "ok" + assert state_b.values["messages"][-1].content == "ok" + + # The parent's own thread_id slot is untouched by either subagent. + state_parent = await shared_subagent.aget_state( + {"configurable": {"thread_id": "shared-subagent-thread"}} + ) + assert state_parent.values == {} or state_parent.values.get("messages") in (None, []) + + +@pytest.mark.asyncio +async def test_one_atask_failure_does_not_corrupt_sibling_atask(): + """Failure isolation: a sibling's exception must not poison the surviving atask's state. 
+ + Note: in production, langgraph's pregel runner cancels siblings when any + parallel task raises a non-``GraphBubbleUp`` exception (see + ``_should_stop_others`` in ``langgraph/pregel/_runner.py``). At our layer + that policy is invisible — what we *can* guarantee is that the two atask + coroutines have disjoint state, so the surviving one returns a valid + Command even when its sibling explodes. + """ + failing_subagent = _build_failing_subagent(ValueError("boom")) + surviving_subagent = _build_success_subagent("still here") + task_tool = build_task_tool_with_parent_config( + [ + {"name": "broken", "description": "always fails", "runnable": failing_subagent}, + {"name": "healthy", "description": "always succeeds", "runnable": surviving_subagent}, + ] + ) + + parent_config: dict = { + "configurable": {"thread_id": "iso-thread"}, + "recursion_limit": 100, + } + runtime_fail = _make_runtime(parent_config, tool_call_id="tcid-fail") + runtime_ok = _make_runtime(parent_config, tool_call_id="tcid-ok") + + results = await asyncio.gather( + task_tool.coroutine( + description="will explode", + subagent_type="broken", + runtime=runtime_fail, + ), + task_tool.coroutine( + description="will work", + subagent_type="healthy", + runtime=runtime_ok, + ), + return_exceptions=True, + ) + + fail_result, ok_result = results + + assert isinstance(fail_result, Exception), ( + f"expected the broken subagent to raise, got {fail_result!r}" + ) + # ValueError gets wrapped in langgraph's internal exception types — the + # important guarantee is "this path errored", not the specific class. + assert "boom" in str(fail_result) or isinstance(fail_result, ValueError) + + assert _tool_message_text(ok_result, expected_tcid="tcid-ok") == "still here" + + # Configurable side-channel must not have been corrupted by the failure. 
+ assert "surfsense_resume_value" not in parent_config["configurable"] diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_resume_decision_routing.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_resume_decision_routing.py new file mode 100644 index 000000000..ceb0df830 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_resume_decision_routing.py @@ -0,0 +1,154 @@ +"""Slicing helper that routes a flat decisions list to per-tool-call payloads. + +The frontend submits ``decisions: list[ResumeDecision]`` in the same order the +SSE stream emitted approval cards. When multiple parallel subagents are paused, +the backend slices that flat list into per-``tool_call_id`` payloads so each +``atask`` reads only its own decisions through ``consume_surfsense_resume``. + +The extractor reads ``state.interrupts[i].value["tool_call_id"]`` — which is +populated by ``propagation.wrap_with_tool_call_id`` inside ``task_tool``'s +``except GraphInterrupt`` chokepoint whenever a subagent interrupt bubbles up +through ``[a]task`` — to build the ordered ``pending`` list the slicer needs. 
+""" + +from __future__ import annotations + +from types import SimpleNamespace + +import pytest + +from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import ( + collect_pending_tool_calls, + slice_decisions_by_tool_call, +) + + +class TestSliceDecisionsByToolCall: + def test_splits_flat_decisions_across_two_pending_tool_calls(self): + decisions = [ + {"type": "approve"}, + {"type": "edit", "edited_action": {"name": "edited-b1"}}, + {"type": "reject"}, + {"type": "approve"}, + {"type": "approve"}, + ] + pending = [ + ("tcid-A", 3), + ("tcid-B", 2), + ] + + routed = slice_decisions_by_tool_call(decisions, pending) + + assert routed == { + "tcid-A": {"decisions": decisions[0:3]}, + "tcid-B": {"decisions": decisions[3:5]}, + } + + def test_raises_when_decision_count_less_than_total_actions(self): + decisions = [{"type": "approve"}, {"type": "approve"}] + pending = [("tcid-A", 3), ("tcid-B", 2)] + + with pytest.raises(ValueError, match=r"5 actions.*2 decisions"): + slice_decisions_by_tool_call(decisions, pending) + + def test_raises_when_decision_count_greater_than_total_actions(self): + decisions = [{"type": "approve"}] * 6 + pending = [("tcid-A", 3), ("tcid-B", 2)] + + with pytest.raises(ValueError, match=r"5 actions.*6 decisions"): + slice_decisions_by_tool_call(decisions, pending) + + def test_handles_single_pending_tool_call(self): + decisions = [{"type": "approve"}, {"type": "reject"}] + pending = [("tcid-only", 2)] + + routed = slice_decisions_by_tool_call(decisions, pending) + + assert routed == {"tcid-only": {"decisions": decisions}} + + def test_returns_empty_dict_for_no_pending(self): + routed = slice_decisions_by_tool_call([], []) + + assert routed == {} + + +def _interrupt_with(tool_call_id: str, action_count: int): + return SimpleNamespace( + id=f"i-{tool_call_id}", + value={ + "action_requests": [{"name": "n", "args": {}}] * action_count, + "review_configs": [{}] * action_count, + "tool_call_id": 
tool_call_id, + }, + ) + + +class TestCollectPendingToolCalls: + def test_single_pending_returns_one_pair(self): + state = SimpleNamespace(interrupts=(_interrupt_with("tcid-only", 3),)) + + assert collect_pending_tool_calls(state) == [("tcid-only", 3)] + + def test_multiple_pending_preserves_state_order(self): + """Order must match what the SSE stream emitted (= state.interrupts order).""" + state = SimpleNamespace( + interrupts=( + _interrupt_with("tcid-A", 2), + _interrupt_with("tcid-B", 3), + ) + ) + + assert collect_pending_tool_calls(state) == [("tcid-A", 2), ("tcid-B", 3)] + + def test_empty_when_no_interrupts(self): + state = SimpleNamespace(interrupts=()) + + assert collect_pending_tool_calls(state) == [] + + def test_skips_interrupts_without_tool_call_id(self): + """Defensive: interrupts not produced by our propagation layer are ignored. + + ``stream_resume_chat`` only owns the ``task``-routing slice; non-task + interrupts (e.g. parent-side HITL middleware on a different tool) are + not the slicer's responsibility. + """ + state = SimpleNamespace( + interrupts=( + _interrupt_with("tcid-A", 2), + SimpleNamespace(id="i-foreign", value={"action_requests": [{}]}), + _interrupt_with("tcid-B", 1), + ) + ) + + assert collect_pending_tool_calls(state) == [("tcid-A", 2), ("tcid-B", 1)] + + def test_handles_scalar_value_interrupt(self): + """Subagents using ``interrupt("approve?")`` style propagate as ``{"value": ..., "tool_call_id": ...}``. + + These have no ``action_requests`` — count them as a single action so + the frontend submits exactly one decision per such interrupt. 
+ """ + state = SimpleNamespace( + interrupts=( + SimpleNamespace( + id="i-A", + value={"value": "approve?", "tool_call_id": "tcid-A"}, + ), + ) + ) + + assert collect_pending_tool_calls(state) == [("tcid-A", 1)] + + def test_raises_when_interrupt_value_missing_action_count_keys(self): + """An interrupt with ``tool_call_id`` but no usable count signals a contract bug.""" + state = SimpleNamespace( + interrupts=( + SimpleNamespace( + id="i-A", + value={"tool_call_id": "tcid-A", "weird_shape": True}, + ), + ) + ) + + with pytest.raises(ValueError, match="action_requests"): + collect_pending_tool_calls(state) diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_resume_helpers.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_resume_helpers.py index 347b32dbd..e8aacfc5d 100644 --- a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_resume_helpers.py +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_resume_helpers.py @@ -1,4 +1,4 @@ -"""Resume side-channel must be read exactly once per turn.""" +"""Resume side-channel is keyed per ``tool_call_id`` so parallel siblings can resume independently.""" from __future__ import annotations @@ -10,33 +10,61 @@ from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_mid ) -def _runtime_with_config(config: dict) -> ToolRuntime: +def _runtime_with_config( + config: dict, *, tool_call_id: str = "tcid-test" +) -> ToolRuntime: return ToolRuntime( state=None, context=None, config=config, stream_writer=None, - tool_call_id="tcid-test", + tool_call_id=tool_call_id, store=None, ) class TestConsumeSurfsenseResume: - def test_pops_value_on_first_call(self): + def test_pops_only_entry_matching_runtime_tool_call_id(self): + configurable = { + "surfsense_resume_value": { + "tcid-A": {"decisions": 
["approve"]}, + "tcid-B": {"decisions": ["reject"]}, + } + } runtime = _runtime_with_config( - {"configurable": {"surfsense_resume_value": {"decisions": ["approve"]}}} + {"configurable": configurable}, tool_call_id="tcid-A" ) assert consume_surfsense_resume(runtime) == {"decisions": ["approve"]} - def test_second_call_returns_none(self): - configurable: dict = {"surfsense_resume_value": {"decisions": ["approve"]}} - runtime = _runtime_with_config({"configurable": configurable}) + def test_popping_one_entry_leaves_siblings_untouched(self): + configurable = { + "surfsense_resume_value": { + "tcid-A": {"decisions": ["approve"]}, + "tcid-B": {"decisions": ["reject"]}, + } + } + runtime_a = _runtime_with_config( + {"configurable": configurable}, tool_call_id="tcid-A" + ) - consume_surfsense_resume(runtime) + consume_surfsense_resume(runtime_a) + + assert configurable["surfsense_resume_value"] == { + "tcid-B": {"decisions": ["reject"]} + } + + def test_returns_none_when_no_entry_for_this_tool_call(self): + runtime = _runtime_with_config( + { + "configurable": { + "surfsense_resume_value": {"tcid-other": {"decisions": []}} + } + }, + tool_call_id="tcid-A", + ) assert consume_surfsense_resume(runtime) is None - assert "surfsense_resume_value" not in configurable def test_returns_none_when_no_payload_queued(self): runtime = _runtime_with_config({"configurable": {}}) @@ -48,22 +76,57 @@ class TestConsumeSurfsenseResume: assert consume_surfsense_resume(runtime) is None + def test_drops_empty_dict_after_last_entry_consumed(self): + configurable = { + "surfsense_resume_value": {"tcid-A": {"decisions": ["approve"]}} + } + runtime = _runtime_with_config( + {"configurable": configurable}, tool_call_id="tcid-A" + ) + + consume_surfsense_resume(runtime) + + assert "surfsense_resume_value" not in configurable + class TestHasSurfsenseResume: - def test_true_when_payload_queued(self): + def test_true_when_entry_for_this_tool_call_present(self): runtime = _runtime_with_config( - 
{"configurable": {"surfsense_resume_value": "approve"}} + { + "configurable": { + "surfsense_resume_value": {"tcid-A": {"decisions": ["approve"]}} + } + }, + tool_call_id="tcid-A", ) assert has_surfsense_resume(runtime) is True + def test_false_when_entry_for_other_tool_call_only(self): + runtime = _runtime_with_config( + { + "configurable": { + "surfsense_resume_value": {"tcid-other": {"decisions": []}} + } + }, + tool_call_id="tcid-A", + ) + + assert has_surfsense_resume(runtime) is False + def test_does_not_consume_payload(self): - configurable = {"surfsense_resume_value": "approve"} - runtime = _runtime_with_config({"configurable": configurable}) + configurable = { + "surfsense_resume_value": {"tcid-A": {"decisions": ["approve"]}} + } + runtime = _runtime_with_config( + {"configurable": configurable}, tool_call_id="tcid-A" + ) has_surfsense_resume(runtime) - assert configurable == {"surfsense_resume_value": "approve"} + assert configurable["surfsense_resume_value"] == { + "tcid-A": {"decisions": ["approve"]} + } def test_false_when_payload_absent(self): runtime = _runtime_with_config({"configurable": {}}) diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_subagent_interrupt_stamping.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_subagent_interrupt_stamping.py new file mode 100644 index 000000000..7df9dedc6 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_subagent_interrupt_stamping.py @@ -0,0 +1,284 @@ +"""Production-shape regression tests for ``tool_call_id`` stamping on subagent interrupts. + +The production bug we're pinning here: when the orchestrator dispatches one or +more ``task`` tool calls and the targeted subagents hit a HITL ``interrupt(...)``, +the parent's persisted ``state.interrupts`` must carry the parent's +``tool_call_id`` on each interrupt value. 
Without that stamp, +``stream_resume_chat`` cannot route a flat ``decisions`` list back to the right +paused subagent and resume fails with ``Decision count mismatch``. + +The tests in this module: + +- Build a **real** ``StateGraph`` subagent that calls real ``interrupt(...)`` + (no MagicMock, no patch of langgraph internals — those are exactly the kind + of fakes that hid this bug). +- Invoke the ``task`` tool from **inside a parent pregel** (via a tiny parent + ``StateGraph`` node) so the subagent invocation happens in the + production-shape "subgraph called from a parent tool node" context. +- Assert on ``parent.state.interrupts[*].value["tool_call_id"]`` — the + observable that ``stream_resume_chat`` reads. +""" + +from __future__ import annotations + +import pytest +from langchain.tools import ToolRuntime +from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.runnables import RunnableConfig +from langgraph.checkpoint.memory import InMemorySaver +from langgraph.graph import END, START, StateGraph +from langgraph.types import Send, interrupt +from typing_extensions import TypedDict + +from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import ( + build_task_tool_with_parent_config, +) + + +class _S(TypedDict, total=False): + messages: list + + +def _build_single_interrupt_subagent(checkpointer: InMemorySaver): + """Subagent that fires one HITL-bundle-shaped interrupt and waits for a decision.""" + + def approve_node(_state): + decision = interrupt( + { + "action_requests": [ + {"name": "do_thing", "args": {"x": 1}, "description": ""} + ], + "review_configs": [{}], + } + ) + return {"messages": [AIMessage(content=f"got:{decision}")]} + + g = StateGraph(_S) + g.add_node("approve", approve_node) + g.add_edge(START, "approve") + g.add_edge("approve", END) + return g.compile(checkpointer=checkpointer) + + +def _build_bundle_subagent(checkpointer: InMemorySaver): + """Subagent that fires one 
interrupt carrying a 3-action bundle.""" + + def bundle_node(_state): + decision = interrupt( + { + "action_requests": [ + {"name": "a", "args": {}, "description": ""}, + {"name": "b", "args": {}, "description": ""}, + {"name": "c", "args": {}, "description": ""}, + ], + "review_configs": [{}, {}, {}], + } + ) + return {"messages": [AIMessage(content=f"bundle:{decision}")]} + + g = StateGraph(_S) + g.add_node("bundle", bundle_node) + g.add_edge(START, "bundle") + g.add_edge("bundle", END) + return g.compile(checkpointer=checkpointer) + + +def _parent_graph_calling_task(task_tool, *, tool_call_id: str, checkpointer): + """A tiny parent graph whose only node invokes ``task_tool`` from inside the pregel runtime. + + This is the minimal reproduction of production's "subagent invoked from + inside a parent tool node" context — the *only* context where langgraph + treats the subagent as a subgraph and routes its interrupts back to the + parent's checkpoint. + """ + + async def call_task(state, config: RunnableConfig): + rt = ToolRuntime( + state=state, + config=config, + context=None, + stream_writer=None, + tool_call_id=tool_call_id, + store=None, + ) + return await task_tool.coroutine( + description="please approve", + subagent_type="approver", + runtime=rt, + ) + + g = StateGraph(_S) + g.add_node("call_task", call_task) + g.add_edge(START, "call_task") + g.add_edge("call_task", END) + return g.compile(checkpointer=checkpointer) + + +class _DispatchState(TypedDict, total=False): + messages: list + tcid: str + desc: str + + +def _parent_graph_dispatching_two_tasks_via_send( + task_tool, *, tool_call_id_a: str, tool_call_id_b: str, checkpointer +): + """A parent graph that dispatches two ``task`` calls as parallel pregel + tasks via :class:`~langgraph.types.Send`. 
+ + This mirrors the production dispatch mechanism: when the orchestrator's + LLM emits two ``task`` tool calls in one turn, langchain's tool node + fans them out as parallel pregel tasks (the same primitive as ``Send``) + so each tool call gets its own pregel task that can raise + ``GraphInterrupt`` independently — and pregel collects *all* of them + into the parent's snapshot at the end of the superstep. + """ + + def fanout_edge(_state) -> list[Send]: + return [ + Send("call_task", {"tcid": tool_call_id_a, "desc": "approve A"}), + Send("call_task", {"tcid": tool_call_id_b, "desc": "approve B"}), + ] + + async def call_task(state: _DispatchState, config: RunnableConfig): + rt = ToolRuntime( + state=state, + config=config, + context=None, + stream_writer=None, + tool_call_id=state["tcid"], + store=None, + ) + return await task_tool.coroutine( + description=state["desc"], subagent_type="approver", runtime=rt + ) + + g = StateGraph(_DispatchState) + g.add_node("call_task", call_task) + g.add_conditional_edges(START, fanout_edge, ["call_task"]) + g.add_edge("call_task", END) + return g.compile(checkpointer=checkpointer) + + +def _parent_interrupt_values(snapshot) -> list[dict]: + """Extract ``state.interrupts[*].value`` for assertions.""" + return [i.value for i in (snapshot.interrupts or ())] + + +@pytest.mark.asyncio +async def test_single_subagent_interrupt_stamps_parent_tool_call_id(): + """A single paused subagent must surface to the parent with ``tool_call_id`` stamped. + + Production bug regression: was producing + ``value={"action_requests": [...], "review_configs": [...]}`` (no + ``tool_call_id``), causing ``stream_resume_chat`` to skip the interrupt + and raise ``Decision count mismatch``. 
+ """ + checkpointer = InMemorySaver() + subagent = _build_single_interrupt_subagent(checkpointer) + task_tool = build_task_tool_with_parent_config( + [{"name": "approver", "description": "approves", "runnable": subagent}] + ) + parent = _parent_graph_calling_task( + task_tool, tool_call_id="parent-tcid-A", checkpointer=checkpointer + ) + + parent_config = { + "configurable": {"thread_id": "parent-thread"}, + "recursion_limit": 100, + } + await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config) + + snap = await parent.aget_state(parent_config) + values = _parent_interrupt_values(snap) + assert len(values) == 1, ( + f"expected exactly 1 parent interrupt, got {len(values)}: {values!r}" + ) + value = values[0] + assert isinstance(value, dict) + assert value.get("tool_call_id") == "parent-tcid-A", ( + f"REGRESSION: parent interrupt missing/wrong tool_call_id stamp. " + f"Expected 'parent-tcid-A', got {value.get('tool_call_id')!r}. " + f"Keys present: {sorted(value.keys())}" + ) + # The original HITL payload must still be intact alongside the stamp. + assert value.get("action_requests") == [ + {"name": "do_thing", "args": {"x": 1}, "description": ""} + ] + + +@pytest.mark.asyncio +async def test_two_parallel_subagents_each_stamp_their_own_tool_call_id(): + """Two ``task`` calls dispatched in parallel must each carry their own ``tool_call_id``. + + This is the actual production scenario (Linear + Jira ticket creation): + two parallel ``task`` tool calls, both subagents hit HITL, parent must + end up with two interrupts whose ``tool_call_id``s match the two + distinct parent-level ``tool_call_id``s the LLM emitted. 
+ """ + checkpointer = InMemorySaver() + subagent = _build_single_interrupt_subagent(checkpointer) + task_tool = build_task_tool_with_parent_config( + [{"name": "approver", "description": "approves", "runnable": subagent}] + ) + parent = _parent_graph_dispatching_two_tasks_via_send( + task_tool, + tool_call_id_a="parent-tcid-A", + tool_call_id_b="parent-tcid-B", + checkpointer=checkpointer, + ) + + parent_config = { + "configurable": {"thread_id": "parent-thread-parallel"}, + "recursion_limit": 100, + } + await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config) + + snap = await parent.aget_state(parent_config) + values = _parent_interrupt_values(snap) + assert len(values) == 2, ( + f"expected 2 parent interrupts (one per parallel task call), " + f"got {len(values)}: {values!r}" + ) + stamps = {v.get("tool_call_id") for v in values} + assert stamps == {"parent-tcid-A", "parent-tcid-B"}, ( + f"REGRESSION: parallel parent interrupts missing/wrong tool_call_id stamps. " + f"Expected {{'parent-tcid-A', 'parent-tcid-B'}}, got {stamps!r}. " + f"Values: {values!r}" + ) + + +@pytest.mark.asyncio +async def test_bundle_subagent_interrupt_stamps_tool_call_id_preserving_actions(): + """A subagent emitting a multi-action bundle must surface stamped, with all actions intact. + + The bundle shape (``action_requests=[3 items]``) drives the + ``slice_decisions_by_tool_call`` accounting in ``stream_resume_chat`` — + if either the stamp or the action count is lost, resume routing + miscounts and crashes. 
+ """ + checkpointer = InMemorySaver() + subagent = _build_bundle_subagent(checkpointer) + task_tool = build_task_tool_with_parent_config( + [{"name": "approver", "description": "approves", "runnable": subagent}] + ) + parent = _parent_graph_calling_task( + task_tool, tool_call_id="parent-tcid-bundle", checkpointer=checkpointer + ) + + parent_config = { + "configurable": {"thread_id": "parent-thread-bundle"}, + "recursion_limit": 100, + } + await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config) + + snap = await parent.aget_state(parent_config) + values = _parent_interrupt_values(snap) + assert len(values) == 1 + value = values[0] + assert value.get("tool_call_id") == "parent-tcid-bundle" + assert isinstance(value.get("action_requests"), list) + assert len(value["action_requests"]) == 3, ( + f"REGRESSION: bundle action_requests count changed during stamping; " + f"got {len(value['action_requests'])} actions: {value['action_requests']!r}" + ) diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_subagent_invoke_config.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_subagent_invoke_config.py new file mode 100644 index 000000000..3465dd1d8 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_subagent_invoke_config.py @@ -0,0 +1,94 @@ +"""Per-call ``thread_id`` derivation for nested subagent invocations. + +Parallel ``task`` (and ``ask_knowledge_base``) calls must land in disjoint +checkpoint slots so their nested pregel runs do not stomp on each other or on +the parent's checkpoint state. The slot key is derived from the runtime's +``tool_call_id`` so the same call across the resume cycle keeps reading from +the same snapshot. 
+ +Note: we namespace via ``thread_id`` rather than ``checkpoint_ns`` because +langgraph's ``aget_state`` interprets a non-empty ``checkpoint_ns`` as a +subgraph path and raises ``ValueError("Subgraph X not found")``. ``thread_id`` +is the primary checkpoint key and is free-form, so it's the right primitive. +""" + +from __future__ import annotations + +from langchain.tools import ToolRuntime + +from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.config import ( + subagent_invoke_config, +) + + +def _runtime(*, tool_call_id: str, config: dict | None = None) -> ToolRuntime: + return ToolRuntime( + state=None, + context=None, + config=config or {}, + stream_writer=None, + tool_call_id=tool_call_id, + store=None, + ) + + +class TestSubagentInvokeThreadId: + def test_sets_per_call_thread_id_under_parent(self): + runtime = _runtime( + tool_call_id="tcid-A", + config={"configurable": {"thread_id": "t1"}}, + ) + + sub_config = subagent_invoke_config(runtime) + + assert sub_config["configurable"]["thread_id"] == "t1::task:tcid-A" + + def test_per_call_thread_id_nests_under_already_namespaced_parent(self): + """A subagent that itself spawns a subagent must keep nesting cleanly.""" + runtime = _runtime( + tool_call_id="tcid-inner", + config={ + "configurable": { + "thread_id": "t1::task:tcid-outer", + } + }, + ) + + sub_config = subagent_invoke_config(runtime) + + assert ( + sub_config["configurable"]["thread_id"] + == "t1::task:tcid-outer::task:tcid-inner" + ) + + def test_different_tool_call_ids_produce_different_thread_ids(self): + config = {"configurable": {"thread_id": "t1"}} + rt_a = _runtime(tool_call_id="tcid-A", config=config) + rt_b = _runtime(tool_call_id="tcid-B", config=config) + + tid_a = subagent_invoke_config(rt_a)["configurable"]["thread_id"] + tid_b = subagent_invoke_config(rt_b)["configurable"]["thread_id"] + + assert tid_a != tid_b + + def test_same_tool_call_id_produces_same_thread_id_across_repeated_calls(self): + 
"""Resume bridge needs to find the snapshot it primed earlier.""" + config = {"configurable": {"thread_id": "t1"}} + rt_first = _runtime(tool_call_id="tcid-A", config=config) + rt_second = _runtime(tool_call_id="tcid-A", config=config) + + tid_first = subagent_invoke_config(rt_first)["configurable"]["thread_id"] + tid_second = subagent_invoke_config(rt_second)["configurable"]["thread_id"] + + assert tid_first == tid_second + + def test_does_not_mutate_caller_config(self): + """Repeated calls must not accumulate suffixes onto the parent's config.""" + original_thread_id = "t1" + config = {"configurable": {"thread_id": original_thread_id}} + runtime = _runtime(tool_call_id="tcid-A", config=config) + + subagent_invoke_config(runtime) + subagent_invoke_config(runtime) + + assert config["configurable"]["thread_id"] == original_thread_id diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/shared/__init__.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/shared/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/shared/permissions/__init__.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/shared/permissions/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/shared/permissions/test_lc_hitl_wire.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/shared/permissions/test_lc_hitl_wire.py new file mode 100644 index 000000000..f8399c031 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/shared/permissions/test_lc_hitl_wire.py @@ -0,0 +1,129 @@ +"""Regression: ``request_permission_decision`` must emit the unified LC HITL wire shape. 
+ +Same bug class as :mod:`test_lc_hitl_wire` for self-gated approvals: the +permission middleware previously fired the SurfSense-specific +``{type, action, context}`` shape, which the parallel-HITL routing layer +does not recognize. Standardizing on LC HITL keeps every approval kind on +one routing path. +""" + +from __future__ import annotations + +import pytest +from langchain_core.messages import HumanMessage +from langgraph.checkpoint.memory import InMemorySaver +from langgraph.graph import END, START, StateGraph +from langgraph.types import Command +from typing_extensions import TypedDict + +from app.agents.multi_agent_chat.middleware.shared.permissions.ask.request import ( + request_permission_decision, +) +from app.agents.new_chat.permissions import Rule + + +class _State(TypedDict, total=False): + messages: list + final_decision: dict + + +def _build_graph_calling_request_permission_decision(checkpointer: InMemorySaver): + """Real graph whose only node delegates to the permission ask primitive.""" + + def perm_node(_state): + decision = request_permission_decision( + tool_name="rm", + args={"path": "/tmp/file"}, + patterns=["rm/*"], + rules=[Rule(permission="rm", pattern="*", action="ask")], + emit_interrupt=True, + ) + return {"final_decision": decision} + + g = StateGraph(_State) + g.add_node("perm", perm_node) + g.add_edge(START, "perm") + g.add_edge("perm", END) + return g.compile(checkpointer=checkpointer) + + +@pytest.mark.asyncio +async def test_permission_ask_payload_uses_lc_hitl_shape(): + """The permission middleware now speaks the langchain HITL standard shape.""" + checkpointer = InMemorySaver() + graph = _build_graph_calling_request_permission_decision(checkpointer) + config = {"configurable": {"thread_id": "perm-wire"}} + await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config) + + snap = await graph.aget_state(config) + assert len(snap.interrupts) == 1 + value = snap.interrupts[0].value + + assert 
value.get("action_requests") == [ + {"name": "rm", "args": {"path": "/tmp/file"}} + ], f"REGRESSION: permission ask reverted to legacy shape; got {value!r}" + review = value.get("review_configs") + assert isinstance(review, list) and len(review) == 1 + palette = review[0]["allowed_decisions"] + # Native tool (no ``tool=`` argument): the palette must include the + # once/reject/edit triad. ``approve_always`` is gated on MCP-ness and + # therefore *omitted* here — palette content per tool kind is + # exercised in ``test_permission_ask_mcp_context``. + assert "approve" in palette and "reject" in palette and "edit" in palette + assert value.get("interrupt_type") == "permission_ask" + # SurfSense context rides through verbatim for FE explainability. + assert value["context"]["patterns"] == ["rm/*"] + assert value["context"]["always"] == ["rm/*"] + + +@pytest.mark.asyncio +async def test_resume_with_approve_envelope_returns_once_decision(): + """``approve`` from the LC envelope projects to permission-domain ``once``.""" + checkpointer = InMemorySaver() + graph = _build_graph_calling_request_permission_decision(checkpointer) + config = {"configurable": {"thread_id": "perm-once"}} + await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config) + + await graph.ainvoke( + Command(resume={"decisions": [{"type": "approve"}]}), config + ) + final = await graph.aget_state(config) + assert final.values.get("final_decision") == {"decision_type": "once"} + + +@pytest.mark.asyncio +async def test_resume_with_approve_always_envelope_projects_unchanged(): + """``approve_always`` reply must project unchanged so the middleware can promote the rule.""" + checkpointer = InMemorySaver() + graph = _build_graph_calling_request_permission_decision(checkpointer) + config = {"configurable": {"thread_id": "perm-approve-always"}} + await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config) + + await graph.ainvoke( + Command(resume={"decisions": [{"type": 
"approve_always"}]}), config + ) + final = await graph.aget_state(config) + assert final.values.get("final_decision") == {"decision_type": "approve_always"} + + +@pytest.mark.asyncio +async def test_resume_with_reject_and_feedback_carries_feedback_through(): + """Reject feedback must survive normalization for ``CorrectedError`` to fire downstream.""" + checkpointer = InMemorySaver() + graph = _build_graph_calling_request_permission_decision(checkpointer) + config = {"configurable": {"thread_id": "perm-reject"}} + await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config) + + await graph.ainvoke( + Command( + resume={ + "decisions": [{"type": "reject", "feedback": "use the trash bin"}] + } + ), + config, + ) + final = await graph.aget_state(config) + assert final.values.get("final_decision") == { + "decision_type": "reject", + "feedback": "use the trash bin", + } diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/shared/permissions/test_permission_ask_mcp_context.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/shared/permissions/test_permission_ask_mcp_context.py new file mode 100644 index 000000000..c9bd4e142 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/shared/permissions/test_permission_ask_mcp_context.py @@ -0,0 +1,232 @@ +"""Permission-ask payload surfaces tool metadata for the FE card.""" + +from __future__ import annotations + +from typing import Annotated, Any + +import pytest +from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.tools import StructuredTool +from langgraph.checkpoint.memory import InMemorySaver +from langgraph.graph import END, START, StateGraph +from langgraph.graph.message import add_messages +from pydantic import BaseModel +from typing_extensions import TypedDict + +from app.agents.multi_agent_chat.middleware.shared.permissions import ( + build_permission_mw, +) +from 
app.agents.multi_agent_chat.middleware.shared.permissions.ask.payload import ( + build_permission_ask_payload, +) +from app.agents.new_chat.feature_flags import AgentFeatureFlags +from app.agents.new_chat.permissions import Rule, Ruleset + + +class _NoArgs(BaseModel): + pass + + +async def _noop(**_kwargs) -> str: + return "" + + +def _ask_rule(tool_name: str) -> Rule: + return Rule(permission=tool_name, pattern="*", action="ask") + + +def _make_mcp_tool(*, name: str, connector_id: int, connector_name: str): + return StructuredTool( + name=name, + description=f"Run {name} via MCP.", + coroutine=_noop, + args_schema=_NoArgs, + metadata={ + "mcp_connector_id": connector_id, + "mcp_connector_name": connector_name, + "mcp_transport": "http", + "hitl": True, + }, + ) + + +def test_payload_surfaces_mcp_fields_from_tool(): + tool = _make_mcp_tool( + name="linear_create_issue", connector_id=42, connector_name="Linear (acme)" + ) + payload = build_permission_ask_payload( + tool_name=tool.name, + args={"title": "bug"}, + patterns=[tool.name], + rules=[_ask_rule(tool.name)], + tool=tool, + ) + ctx = payload["context"] + assert ctx["mcp_connector_id"] == 42 + assert ctx["mcp_server"] == "Linear (acme)" + assert ctx["tool_description"] == "Run linear_create_issue via MCP." 
+ + +def test_payload_omits_tool_fields_when_tool_is_none(): + payload = build_permission_ask_payload( + tool_name="rm", + args={"path": "/tmp/x"}, + patterns=["rm"], + rules=[_ask_rule("rm")], + tool=None, + ) + ctx = payload["context"] + assert "mcp_connector_id" not in ctx + assert "mcp_server" not in ctx + assert "tool_description" not in ctx + + +def test_palette_includes_approve_always_for_mcp_tool(): + """Saving to the connector's trusted-tools list is only possible for MCP tools.""" + tool = _make_mcp_tool( + name="linear_create_issue", connector_id=42, connector_name="Linear" + ) + palette = build_permission_ask_payload( + tool_name=tool.name, + args={}, + patterns=[tool.name], + rules=[_ask_rule(tool.name)], + tool=tool, + )["review_configs"][0]["allowed_decisions"] + assert "approve_always" in palette + + +def test_palette_excludes_approve_always_for_native_tool(): + """Native tools have no place to persist trust, so don't offer the button.""" + native = StructuredTool( + name="rm", + description="Remove a file.", + coroutine=_noop, + args_schema=_NoArgs, + metadata={"hitl": True}, + ) + palette = build_permission_ask_payload( + tool_name=native.name, + args={"path": "/tmp/x"}, + patterns=[native.name], + rules=[_ask_rule(native.name)], + tool=native, + )["review_configs"][0]["allowed_decisions"] + assert "approve_always" not in palette + assert palette == ["approve", "reject", "edit"] + + +def test_palette_excludes_approve_always_when_tool_is_none(): + """Without a tool object the middleware can't tell — fall back to the safe triad.""" + palette = build_permission_ask_payload( + tool_name="rm", + args={"path": "/tmp/x"}, + patterns=["rm"], + rules=[_ask_rule("rm")], + tool=None, + )["review_configs"][0]["allowed_decisions"] + assert palette == ["approve", "reject", "edit"] + + +def test_payload_omits_falsy_mcp_metadata_fields(): + tool = StructuredTool( + name="anon_tool", + description="", + coroutine=_noop, + args_schema=_NoArgs, + 
metadata={"mcp_connector_id": None, "mcp_connector_name": ""}, + ) + ctx = build_permission_ask_payload( + tool_name=tool.name, + args={}, + patterns=[tool.name], + rules=[_ask_rule(tool.name)], + tool=tool, + )["context"] + assert "mcp_connector_id" not in ctx + assert "mcp_server" not in ctx + assert "tool_description" not in ctx + + +class _State(TypedDict, total=False): + messages: Annotated[list, add_messages] + + +def _emit_tool_call(tool_name: str, args: dict[str, Any], call_id: str): + def _node(_state: _State) -> dict[str, Any]: + return { + "messages": [ + AIMessage( + content="", + tool_calls=[ + { + "name": tool_name, + "args": args, + "id": call_id, + "type": "tool_call", + } + ], + ) + ] + } + + return _node + + +def _compile_graph_with(pm, tool_name: str, args: dict[str, Any], call_id: str): + def after(state: _State) -> dict[str, Any] | None: + return pm.after_model(state, None) # type: ignore[arg-type] + + g = StateGraph(_State) + g.add_node("emit", _emit_tool_call(tool_name, args, call_id)) + g.add_node("permission", after) + g.add_edge(START, "emit") + g.add_edge("emit", "permission") + g.add_edge("permission", END) + return g.compile(checkpointer=InMemorySaver()) + + +@pytest.mark.asyncio +async def test_middleware_decorates_interrupt_with_mcp_tool_metadata(): + tool = _make_mcp_tool( + name="linear_create_issue", connector_id=7, connector_name="Linear" + ) + pm = build_permission_mw( + flags=AgentFeatureFlags(enable_permission=False), + subagent_rulesets=[ + Ruleset(origin="linear", rules=[_ask_rule(tool.name)]), + ], + tools=[tool], + ) + assert pm is not None + + graph = _compile_graph_with(pm, tool.name, {"title": "bug"}, "call-1") + config = {"configurable": {"thread_id": "linear-ask"}} + await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config) + + snap = await graph.aget_state(config) + assert len(snap.interrupts) == 1 + ctx = snap.interrupts[0].value["context"] + assert ctx["mcp_connector_id"] == 7 + assert 
ctx["mcp_server"] == "Linear" + assert ctx["tool_description"] == "Run linear_create_issue via MCP." + + +@pytest.mark.asyncio +async def test_middleware_without_tool_index_still_asks_without_tool_fields(): + pm = build_permission_mw( + flags=AgentFeatureFlags(enable_permission=False), + subagent_rulesets=[Ruleset(origin="kb", rules=[_ask_rule("rm")])], + ) + assert pm is not None + + graph = _compile_graph_with(pm, "rm", {"path": "/tmp/foo"}, "call-rm") + config = {"configurable": {"thread_id": "kb-rm"}} + await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config) + + snap = await graph.aget_state(config) + assert len(snap.interrupts) == 1 + ctx = snap.interrupts[0].value["context"] + assert "mcp_connector_id" not in ctx + assert "mcp_server" not in ctx + assert "tool_description" not in ctx diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/shared/permissions/test_subagent_owned_ruleset.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/shared/permissions/test_subagent_owned_ruleset.py new file mode 100644 index 000000000..6406fb09a --- /dev/null +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/shared/permissions/test_subagent_owned_ruleset.py @@ -0,0 +1,168 @@ +"""Regression: subagent-owned rulesets layer cleanly into ``PermissionMiddleware``. + +The KB unification swap (legacy ``interrupt_on`` map → KB-owned ``Ruleset`` +threaded through ``build_permission_mw(subagent_rulesets=...)``) must +produce *exactly one* interrupt per destructive FS call, in LC HITL +shape, even when ``enable_permission`` is False — destructive ops always +ask. + +We exercise the production factory and a real ``PermissionMiddleware`` on a +real ``StateGraph`` so the test catches regressions in factory gating, +ruleset layering, and interrupt emission together. 
+""" + +from __future__ import annotations + +from typing import Annotated, Any + +import pytest +from langchain_core.messages import AIMessage, HumanMessage +from langgraph.checkpoint.memory import InMemorySaver +from langgraph.graph import END, START, StateGraph +from langgraph.graph.message import add_messages +from langgraph.types import Command +from typing_extensions import TypedDict + +from app.agents.multi_agent_chat.middleware.shared.permissions import ( + build_permission_mw, +) +from app.agents.new_chat.feature_flags import AgentFeatureFlags +from app.agents.new_chat.permissions import Rule, Ruleset + + +def _kb_style_ruleset() -> Ruleset: + """Mirror :data:`knowledge_base.agent.KB_RULESET` without importing it. + + Importing the agent module pulls in deepagents and prompts; this test + is about the factory + middleware contract, not KB wiring. + """ + return Ruleset( + origin="knowledge_base", + rules=[ + Rule(permission="rm", pattern="*", action="ask"), + Rule(permission="rmdir", pattern="*", action="ask"), + Rule(permission="move_file", pattern="*", action="ask"), + Rule(permission="edit_file", pattern="*", action="ask"), + Rule(permission="write_file", pattern="*", action="ask"), + ], + ) + + +class _State(TypedDict, total=False): + messages: Annotated[list, add_messages] + + +def _build_graph_with_permission_middleware( + *, + flags: AgentFeatureFlags, + subagent_rulesets: list[Ruleset] | None, + checkpointer: InMemorySaver, +): + """Compile a one-node graph that emits a tool call for ``rm`` and + routes through the production ``PermissionMiddleware``. + + The node returns an ``AIMessage`` with a tool call. The middleware's + ``after_model`` hook intercepts and (if a rule says ``ask``) raises + a ``GraphInterrupt`` carrying the LC HITL payload. 
+ """ + pm = build_permission_mw(flags=flags, subagent_rulesets=subagent_rulesets) + + def node(_state: _State) -> dict[str, Any]: + msg = AIMessage( + content="", + tool_calls=[ + { + "name": "rm", + "args": {"path": "/tmp/foo"}, + "id": "call-rm-1", + "type": "tool_call", + } + ], + ) + return {"messages": [msg]} + + def after_node(state: _State) -> dict[str, Any] | None: + if pm is None: + return None + return pm.after_model(state, None) # type: ignore[arg-type] + + g = StateGraph(_State) + g.add_node("emit", node) + g.add_node("permission", after_node) + g.add_edge(START, "emit") + g.add_edge("emit", "permission") + g.add_edge("permission", END) + return g.compile(checkpointer=checkpointer), pm + + +@pytest.mark.asyncio +async def test_kb_ruleset_raises_one_lc_hitl_ask_for_rm_even_when_permission_flag_off(): + """KB ruleset: ``rm`` must ask once even with ``enable_permission=False``. + + This is the keystone of the unification: the legacy ``interrupt_on`` + map fired regardless of ``enable_permission``, so the migrated rules + must too. Otherwise users could opt out of "ask before rm". 
+ """ + flags = AgentFeatureFlags(enable_permission=False) + checkpointer = InMemorySaver() + graph, pm = _build_graph_with_permission_middleware( + flags=flags, + subagent_rulesets=[_kb_style_ruleset()], + checkpointer=checkpointer, + ) + assert pm is not None, "subagent rulesets must force the middleware on" + + config = {"configurable": {"thread_id": "kb-cloud-rm"}} + await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config) + + snap = await graph.aget_state(config) + assert len(snap.interrupts) == 1, ( + f"REGRESSION: KB ruleset should raise exactly one interrupt; got " + f"{[i.value for i in snap.interrupts]!r}" + ) + payload = snap.interrupts[0].value + requests = payload.get("action_requests") + assert requests == [{"name": "rm", "args": {"path": "/tmp/foo"}}], ( + f"interrupt must carry the rm call in LC HITL shape; got {payload!r}" + ) + assert payload.get("interrupt_type") == "permission_ask" + + +@pytest.mark.asyncio +async def test_kb_ruleset_resume_with_approve_lets_rm_through(): + """Resume with ``approve`` → call kept; the model continues normally.""" + flags = AgentFeatureFlags(enable_permission=False) + checkpointer = InMemorySaver() + graph, _ = _build_graph_with_permission_middleware( + flags=flags, + subagent_rulesets=[_kb_style_ruleset()], + checkpointer=checkpointer, + ) + config = {"configurable": {"thread_id": "kb-cloud-rm-approve"}} + await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config) + + await graph.ainvoke( + Command(resume={"decisions": [{"type": "approve"}]}), config + ) + final = await graph.aget_state(config) + assert final.next == (), "graph must complete after approve" + last_ai = next( + (m for m in reversed(final.values["messages"]) if isinstance(m, AIMessage)), + None, + ) + assert last_ai is not None + assert [tc["name"] for tc in last_ai.tool_calls] == ["rm"], ( + "approved rm call must remain on the AIMessage so the tool can run" + ) + + +@pytest.mark.asyncio +async def 
test_no_subagent_rulesets_with_permission_off_skips_middleware_entirely(): + """No subagent rulesets + permission off → factory returns ``None`` (no engine). + + The legacy gating is preserved when no caller asks for rules: nothing + runs, nothing pauses. + """ + flags = AgentFeatureFlags(enable_permission=False) + pm = build_permission_mw(flags=flags, subagent_rulesets=None) + assert pm is None diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/shared/permissions/test_trusted_tool_save_on_always.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/shared/permissions/test_trusted_tool_save_on_always.py new file mode 100644 index 000000000..47d3704ac --- /dev/null +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/shared/permissions/test_trusted_tool_save_on_always.py @@ -0,0 +1,186 @@ +"""``approve_always`` decisions for MCP tools are saved via the trusted-tool saver.""" + +from __future__ import annotations + +from typing import Annotated, Any + +import pytest +from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.tools import StructuredTool +from langgraph.checkpoint.memory import InMemorySaver +from langgraph.graph import END, START, StateGraph +from langgraph.graph.message import add_messages +from langgraph.types import Command +from pydantic import BaseModel +from typing_extensions import TypedDict + +from app.agents.multi_agent_chat.middleware.shared.permissions import ( + build_permission_mw, +) +from app.agents.new_chat.feature_flags import AgentFeatureFlags +from app.agents.new_chat.permissions import Rule, Ruleset + + +class _NoArgs(BaseModel): + pass + + +async def _noop(**_kwargs) -> str: + return "" + + +def _ask_rule(tool_name: str) -> Rule: + return Rule(permission=tool_name, pattern="*", action="ask") + + +def _make_mcp_tool(*, name: str, connector_id: int): + return StructuredTool( + name=name, + description=f"Run {name} via MCP.", + coroutine=_noop, + 
args_schema=_NoArgs, + metadata={ + "mcp_connector_id": connector_id, + "mcp_connector_name": "Linear", + "mcp_transport": "http", + "hitl": True, + }, + ) + + +def _make_native_tool(*, name: str): + return StructuredTool( + name=name, + description=f"Native {name}.", + coroutine=_noop, + args_schema=_NoArgs, + metadata={"hitl": True}, + ) + + +class _State(TypedDict, total=False): + messages: Annotated[list, add_messages] + + +def _build_graph(pm, tool_name: str): + def emit(_state: _State) -> dict[str, Any]: + return { + "messages": [ + AIMessage( + content="", + tool_calls=[ + { + "name": tool_name, + "args": {}, + "id": "call-1", + "type": "tool_call", + } + ], + ) + ] + } + + g = StateGraph(_State) + g.add_node("emit", emit) + g.add_node("permission", pm.aafter_model) # type: ignore[arg-type] + g.add_edge(START, "emit") + g.add_edge("emit", "permission") + g.add_edge("permission", END) + return g.compile(checkpointer=InMemorySaver()) + + +@pytest.mark.asyncio +async def test_approve_always_decision_saves_mcp_tool_via_callback(): + saved: list[tuple[int, str]] = [] + + async def trusted_tool_saver(connector_id: int, tool_name: str) -> None: + saved.append((connector_id, tool_name)) + + tool = _make_mcp_tool(name="linear_create_issue", connector_id=7) + pm = build_permission_mw( + flags=AgentFeatureFlags(enable_permission=False), + subagent_rulesets=[Ruleset(origin="linear", rules=[_ask_rule(tool.name)])], + tools=[tool], + trusted_tool_saver=trusted_tool_saver, + ) + assert pm is not None + + graph = _build_graph(pm, tool.name) + config = {"configurable": {"thread_id": "approve-always-mcp"}} + await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config) + await graph.ainvoke( + Command(resume={"decisions": [{"type": "approve_always"}]}), config + ) + + assert saved == [(7, "linear_create_issue")] + + +@pytest.mark.asyncio +async def test_once_decision_does_not_save(): + saved: list[tuple[int, str]] = [] + + async def 
trusted_tool_saver(connector_id: int, tool_name: str) -> None: + saved.append((connector_id, tool_name)) + + tool = _make_mcp_tool(name="linear_create_issue", connector_id=7) + pm = build_permission_mw( + flags=AgentFeatureFlags(enable_permission=False), + subagent_rulesets=[Ruleset(origin="linear", rules=[_ask_rule(tool.name)])], + tools=[tool], + trusted_tool_saver=trusted_tool_saver, + ) + assert pm is not None + + graph = _build_graph(pm, tool.name) + config = {"configurable": {"thread_id": "once-mcp"}} + await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config) + await graph.ainvoke(Command(resume={"decisions": [{"type": "approve"}]}), config) + + assert saved == [] + + +@pytest.mark.asyncio +async def test_approve_always_decision_for_native_tool_skips_save(): + """Native tools have no ``mcp_connector_id`` so there is nowhere to persist trust.""" + saved: list[tuple[int, str]] = [] + + async def trusted_tool_saver(connector_id: int, tool_name: str) -> None: + saved.append((connector_id, tool_name)) + + tool = _make_native_tool(name="rm") + pm = build_permission_mw( + flags=AgentFeatureFlags(enable_permission=False), + subagent_rulesets=[Ruleset(origin="kb", rules=[_ask_rule(tool.name)])], + tools=[tool], + trusted_tool_saver=trusted_tool_saver, + ) + assert pm is not None + + graph = _build_graph(pm, tool.name) + config = {"configurable": {"thread_id": "approve-always-native"}} + await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config) + await graph.ainvoke( + Command(resume={"decisions": [{"type": "approve_always"}]}), config + ) + + assert saved == [] + + +@pytest.mark.asyncio +async def test_approve_always_decision_with_no_saver_callback_is_a_noop(): + """Anonymous turns build the middleware without a ``trusted_tool_saver``; must not crash.""" + tool = _make_mcp_tool(name="linear_create_issue", connector_id=7) + pm = build_permission_mw( + flags=AgentFeatureFlags(enable_permission=False), + 
subagent_rulesets=[Ruleset(origin="linear", rules=[_ask_rule(tool.name)])], + tools=[tool], + trusted_tool_saver=None, + ) + assert pm is not None + + graph = _build_graph(pm, tool.name) + config = {"configurable": {"thread_id": "anon-approve-always"}} + await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config) + await graph.ainvoke( + Command(resume={"decisions": [{"type": "approve_always"}]}), config + ) diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/shared/hitl/__init__.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/shared/hitl/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/shared/hitl/approvals/__init__.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/shared/hitl/approvals/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/shared/hitl/approvals/self_gated/__init__.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/shared/hitl/approvals/self_gated/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/shared/hitl/approvals/self_gated/test_lc_hitl_wire.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/shared/hitl/approvals/self_gated/test_lc_hitl_wire.py new file mode 100644 index 000000000..195b1bc01 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/shared/hitl/approvals/self_gated/test_lc_hitl_wire.py @@ -0,0 +1,132 @@ +"""Regression: ``request_approval`` must emit the unified LC HITL wire shape. + +Before this fix, self-gated approvals fired the SurfSense-specific +``{type, action, context}`` shape which the parallel-HITL routing layer +(``collect_pending_tool_calls``) does not recognize. 
In a parallel HITL +scenario where one subagent used self-gated approvals (e.g. Gmail send) +and another used middleware-gated approvals (e.g. Linear via +``HumanInTheLoopMiddleware``), the routing layer would silently skip the +self-gated interrupt and crash on resume with ``Decision count mismatch``. + +This test pins the wire contract by running ``request_approval`` inside a +real ``StateGraph`` and asserting the paused parent observes the LC HITL +shape (``action_requests``, ``review_configs``, ``interrupt_type``). +""" + +from __future__ import annotations + +import pytest +from langchain_core.messages import HumanMessage +from langgraph.checkpoint.memory import InMemorySaver +from langgraph.graph import END, START, StateGraph +from langgraph.types import Command +from typing_extensions import TypedDict + +from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import ( + request_approval, +) + + +class _State(TypedDict, total=False): + messages: list + final_decision_type: str + final_params: dict + + +def _build_graph_calling_request_approval(checkpointer: InMemorySaver): + """A real graph whose only node delegates to ``request_approval``.""" + + def gate_node(_state): + result = request_approval( + action_type="gmail_email_send", + tool_name="send_gmail_email", + params={"to": "alice@example.com", "subject": "hi"}, + context={"account": "alice@gmail.com"}, + ) + return { + "final_decision_type": result.decision_type, + "final_params": result.params, + } + + g = StateGraph(_State) + g.add_node("gate", gate_node) + g.add_edge(START, "gate") + g.add_edge("gate", END) + return g.compile(checkpointer=checkpointer) + + +@pytest.mark.asyncio +async def test_paused_interrupt_uses_lc_hitl_action_requests_shape(): + """The paused interrupt must speak the langchain HITL standard shape.""" + checkpointer = InMemorySaver() + graph = _build_graph_calling_request_approval(checkpointer) + config = {"configurable": {"thread_id": "self-gated-wire"}} + 
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config) + + snap = await graph.aget_state(config) + assert len(snap.interrupts) == 1, ( + f"expected one paused interrupt, got {len(snap.interrupts)}" + ) + value = snap.interrupts[0].value + assert isinstance(value, dict) + + # Standard LC HITL fields the routing layer reads. + assert value.get("action_requests") == [ + { + "name": "send_gmail_email", + "args": {"to": "alice@example.com", "subject": "hi"}, + } + ], ( + "REGRESSION: self-gated approval reverted to legacy SurfSense shape; " + f"got {value!r}" + ) + assert value.get("review_configs") == [ + { + "action_name": "send_gmail_email", + "allowed_decisions": ["approve", "reject", "edit"], + } + ] + assert value.get("interrupt_type") == "gmail_email_send", ( + "FE card discriminator must travel as ``interrupt_type``." + ) + assert value.get("context") == {"account": "alice@gmail.com"} + + +@pytest.mark.asyncio +async def test_resume_with_lc_envelope_returns_hitl_result_with_edited_args(): + """Edit reply via the LC envelope must round-trip into ``HITLResult.params``.""" + checkpointer = InMemorySaver() + graph = _build_graph_calling_request_approval(checkpointer) + config = {"configurable": {"thread_id": "self-gated-resume"}} + await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config) + + edited = {"to": "alice@example.com", "subject": "EDITED"} + await graph.ainvoke( + Command( + resume={ + "decisions": [ + {"type": "edit", "edited_action": {"args": {"subject": "EDITED"}}} + ] + } + ), + config, + ) + final = await graph.aget_state(config) + assert final.values.get("final_decision_type") == "edit" + assert final.values.get("final_params") == edited + + +@pytest.mark.asyncio +async def test_reject_envelope_returns_rejected_hitl_result(): + """Reject reply must surface as ``HITLResult.rejected=True`` without invoking the tool.""" + checkpointer = InMemorySaver() + graph = _build_graph_calling_request_approval(checkpointer) + 
config = {"configurable": {"thread_id": "self-gated-reject"}} + await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config) + + await graph.ainvoke( + Command(resume={"decisions": [{"type": "reject", "feedback": "no"}]}), + config, + ) + final = await graph.aget_state(config) + assert final.values.get("final_decision_type") == "reject" diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/shared/hitl/wire/__init__.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/shared/hitl/wire/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/shared/hitl/wire/test_hitl_wire.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/shared/hitl/wire/test_hitl_wire.py new file mode 100644 index 000000000..e54dbbb5a --- /dev/null +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/shared/hitl/wire/test_hitl_wire.py @@ -0,0 +1,168 @@ +"""Unit contract for the unified LC HITL wire format. + +Both the self-gated approval primitive (``request_approval``) and the +middleware-gated permission ask (``PermissionMiddleware``) must serialize +to the same wire shape so the parallel-HITL routing layer +(``collect_pending_tool_calls`` + ``slice_decisions_by_tool_call`` + +``build_lg_resume_map``) sees one format. + +These tests pin the shape: + +- Builder always emits ``action_requests`` (1 entry) + ``review_configs`` + + ``interrupt_type``; ``context`` rides through verbatim when present. +- Parser tolerates the standard LC envelope, bare scalar strings, and + unrecognized shapes (failing closed to ``reject``). +- Edited args round-trip through both nested (``edited_action.args``) and + flat (``args``) shapes without inventing values for the empty case. 
+""" + +from __future__ import annotations + +from app.agents.multi_agent_chat.subagents.shared.hitl.wire import ( + LC_DECISION_APPROVE, + LC_DECISION_EDIT, + LC_DECISION_REJECT, + SURFSENSE_DECISION_APPROVE_ALWAYS, + build_lc_hitl_payload, + parse_lc_envelope, +) + + +class TestBuildLcHitlPayload: + def test_minimal_payload_has_one_action_request_and_one_review_config(self): + payload = build_lc_hitl_payload( + tool_name="send_email", + args={"to": "x@y.z"}, + allowed_decisions=[LC_DECISION_APPROVE, LC_DECISION_REJECT], + interrupt_type="gmail_email_send", + ) + assert payload["action_requests"] == [ + {"name": "send_email", "args": {"to": "x@y.z"}} + ] + assert payload["review_configs"] == [ + { + "action_name": "send_email", + "allowed_decisions": [LC_DECISION_APPROVE, LC_DECISION_REJECT], + } + ] + assert payload["interrupt_type"] == "gmail_email_send" + assert "context" not in payload, "context must be omitted when not provided" + + def test_none_args_normalized_to_empty_dict(self): + """FE expects a stable shape; ``None`` would crash card rendering.""" + payload = build_lc_hitl_payload( + tool_name="ping", + args=None, # type: ignore[arg-type] + allowed_decisions=[LC_DECISION_APPROVE], + interrupt_type="self_gated", + ) + assert payload["action_requests"][0]["args"] == {} + + def test_description_attached_only_when_provided(self): + with_desc = build_lc_hitl_payload( + tool_name="t", + args={}, + allowed_decisions=[LC_DECISION_APPROVE], + interrupt_type="x", + description="please review", + ) + without = build_lc_hitl_payload( + tool_name="t", + args={}, + allowed_decisions=[LC_DECISION_APPROVE], + interrupt_type="x", + ) + assert with_desc["action_requests"][0]["description"] == "please review" + assert "description" not in without["action_requests"][0] + + def test_context_passed_through_verbatim(self): + ctx = {"patterns": ["rm/*"], "rules": [], "always": ["rm/*"]} + payload = build_lc_hitl_payload( + tool_name="rm", + args={"path": "/tmp"}, + 
allowed_decisions=[ + LC_DECISION_APPROVE, + LC_DECISION_REJECT, + SURFSENSE_DECISION_APPROVE_ALWAYS, + ], + interrupt_type="permission_ask", + context=ctx, + ) + assert payload["context"] == ctx + + def test_allowed_decisions_list_is_copied_not_aliased(self): + """A caller mutating their original list must not corrupt the payload.""" + decisions = [LC_DECISION_APPROVE] + payload = build_lc_hitl_payload( + tool_name="t", + args={}, + allowed_decisions=decisions, + interrupt_type="x", + ) + decisions.append(LC_DECISION_REJECT) + assert payload["review_configs"][0]["allowed_decisions"] == [LC_DECISION_APPROVE] + + +class TestParseLcEnvelope: + def test_standard_lc_envelope_returns_typed_decision(self): + parsed = parse_lc_envelope({"decisions": [{"type": "approve"}]}) + assert parsed.decision_type == "approve" + assert parsed.edited_args is None + assert parsed.message is None + + def test_bare_scalar_string_passes_through_lowercased(self): + assert parse_lc_envelope("APPROVE_ALWAYS").decision_type == "approve_always" + assert parse_lc_envelope("once").decision_type == "once" + + def test_non_dict_non_string_collapses_to_reject(self): + """Failing closed: ambiguous input must never proceed.""" + assert parse_lc_envelope(42).decision_type == "reject" + assert parse_lc_envelope(None).decision_type == "reject" + assert parse_lc_envelope(["bogus"]).decision_type == "reject" + + def test_missing_decision_type_collapses_to_reject(self): + assert parse_lc_envelope({"decisions": [{}]}).decision_type == "reject" + assert parse_lc_envelope({"foo": "bar"}).decision_type == "reject" + + def test_edit_extracts_nested_args(self): + parsed = parse_lc_envelope( + { + "decisions": [ + { + "type": LC_DECISION_EDIT, + "edited_action": {"args": {"to": "edited@y.z"}}, + } + ] + } + ) + assert parsed.decision_type == "edit" + assert parsed.edited_args == {"to": "edited@y.z"} + + def test_edit_falls_back_to_flat_args(self): + parsed = parse_lc_envelope( + {"decisions": [{"type": "edit", 
"args": {"k": "v"}}]} + ) + assert parsed.edited_args == {"k": "v"} + + def test_edit_with_empty_args_yields_none_edited(self): + """Empty edited_args means "no edits" — caller treats as plain approve.""" + parsed = parse_lc_envelope( + {"decisions": [{"type": "edit", "edited_action": {"args": {}}}]} + ) + assert parsed.edited_args is None + + def test_message_picked_from_either_feedback_or_message_field(self): + with_feedback = parse_lc_envelope( + {"decisions": [{"type": "reject", "feedback": "no thanks"}]} + ) + with_message = parse_lc_envelope( + {"decisions": [{"type": "reject", "message": "no thanks"}]} + ) + assert with_feedback.message == "no thanks" + assert with_message.message == "no thanks" + + def test_blank_message_treated_as_absent(self): + parsed = parse_lc_envelope( + {"decisions": [{"type": "reject", "message": " "}]} + ) + assert parsed.message is None diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/shared/test_subagent_builder.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/shared/test_subagent_builder.py index 123bdc09f..062ea92ec 100644 --- a/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/shared/test_subagent_builder.py +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/shared/test_subagent_builder.py @@ -19,9 +19,14 @@ from langchain_core.language_models.fake_chat_models import ( from langchain_core.messages import AIMessage, BaseMessage, HumanMessage from langchain_core.outputs import ChatGeneration, ChatResult +from app.agents.multi_agent_chat.middleware.shared.permissions.middleware.core import ( + PermissionMiddleware, +) from app.agents.multi_agent_chat.subagents.shared.subagent_builder import ( pack_subagent, ) +from app.agents.new_chat.feature_flags import AgentFeatureFlags +from app.agents.new_chat.permissions import Rule, Ruleset, evaluate class RateLimitError(Exception): @@ -73,14 +78,17 @@ async def test_subagent_recovers_when_primary_llm_fails(): 
responses=[AIMessage(content="recovered via fallback")] ) - spec = pack_subagent( + result = pack_subagent( name="resilience_test", description="test subagent", system_prompt="be helpful", tools=[], + ruleset=Ruleset(origin="resilience_test", rules=[]), + dependencies={"flags": AgentFeatureFlags()}, model=primary, middleware_stack={"fallback": ModelFallbackMiddleware(fallback)}, ) + spec = result.spec agent = create_agent( model=spec["model"], @@ -94,3 +102,142 @@ async def test_subagent_recovers_when_primary_llm_fails(): final = result["messages"][-1] assert isinstance(final, AIMessage) assert final.content == "recovered via fallback" + + +def _extract_permission_mw(spec) -> PermissionMiddleware: + """Find the lone PermissionMiddleware in a subagent's middleware list.""" + matches = [m for m in spec["middleware"] if isinstance(m, PermissionMiddleware)] + assert len(matches) == 1, "expected exactly one PermissionMiddleware" + return matches[0] + + +def test_user_allowlist_overrides_coded_ask_via_last_match_wins(): + """User ``allow`` rules promoted via "Always Allow" must beat coded ``ask`` rules.""" + coded = Ruleset( + origin="connector", + rules=[Rule(permission="save_issue", pattern="*", action="ask")], + ) + user_allowlist = Ruleset( + origin="user_allowlist:connector", + rules=[Rule(permission="save_issue", pattern="*", action="allow")], + ) + + result = pack_subagent( + name="connector", + description="test connector", + system_prompt="x", + tools=[], + ruleset=coded, + dependencies={ + "flags": AgentFeatureFlags(), + "user_allowlist_by_subagent": {"connector": user_allowlist}, + }, + ) + + mw = _extract_permission_mw(result.spec) + decided = evaluate("save_issue", "*", *mw._static_rulesets) + assert decided.action == "allow", ( + f"user_allowlist must override coded ask; got {decided!r}" + ) + + +def test_coded_ask_stays_when_user_allowlist_unrelated(): + """User ``allow`` rules for OTHER tools must not leak into asked-tools.""" + coded = Ruleset( + 
origin="connector", + rules=[Rule(permission="delete_issue", pattern="*", action="ask")], + ) + user_allowlist = Ruleset( + origin="user_allowlist:connector", + rules=[Rule(permission="save_issue", pattern="*", action="allow")], + ) + + result = pack_subagent( + name="connector", + description="test", + system_prompt="x", + tools=[], + ruleset=coded, + dependencies={ + "flags": AgentFeatureFlags(), + "user_allowlist_by_subagent": {"connector": user_allowlist}, + }, + ) + + mw = _extract_permission_mw(result.spec) + decided = evaluate("delete_issue", "*", *mw._static_rulesets) + assert decided.action == "ask" + + +def test_missing_user_allowlist_keeps_coded_behaviour(): + """``dependencies`` without ``user_allowlist_by_subagent`` is the common case.""" + coded = Ruleset( + origin="connector", + rules=[Rule(permission="save_issue", pattern="*", action="ask")], + ) + + result = pack_subagent( + name="connector", + description="test", + system_prompt="x", + tools=[], + ruleset=coded, + dependencies={"flags": AgentFeatureFlags()}, + ) + + mw = _extract_permission_mw(result.spec) + decided = evaluate("save_issue", "*", *mw._static_rulesets) + assert decided.action == "ask" + + +def test_user_allowlist_for_different_subagent_does_not_leak(): + """User trust for ``linear`` must not affect a ``jira`` subagent compile.""" + coded = Ruleset( + origin="jira", + rules=[Rule(permission="save_issue", pattern="*", action="ask")], + ) + linear_allowlist = Ruleset( + origin="user_allowlist:linear", + rules=[Rule(permission="save_issue", pattern="*", action="allow")], + ) + + result = pack_subagent( + name="jira", + description="test", + system_prompt="x", + tools=[], + ruleset=coded, + dependencies={ + "flags": AgentFeatureFlags(), + "user_allowlist_by_subagent": {"linear": linear_allowlist}, + }, + ) + + mw = _extract_permission_mw(result.spec) + decided = evaluate("save_issue", "*", *mw._static_rulesets) + assert decided.action == "ask" + + +def 
test_empty_user_allowlist_is_tolerated(): + """An empty ``Ruleset`` (no rules) must not flip evaluation to allow-everything.""" + coded = Ruleset( + origin="connector", + rules=[Rule(permission="save_issue", pattern="*", action="ask")], + ) + empty = Ruleset(origin="user_allowlist:connector", rules=[]) + + result = pack_subagent( + name="connector", + description="test", + system_prompt="x", + tools=[], + ruleset=coded, + dependencies={ + "flags": AgentFeatureFlags(), + "user_allowlist_by_subagent": {"connector": empty}, + }, + ) + + mw = _extract_permission_mw(result.spec) + decided = evaluate("save_issue", "*", *mw._static_rulesets) + assert decided.action == "ask" diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_permission_middleware.py b/surfsense_backend/tests/unit/agents/new_chat/test_permission_middleware.py index 47059ade6..68db11ba6 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_permission_middleware.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_permission_middleware.py @@ -106,9 +106,9 @@ class TestAsk: # No new rule persisted assert mw._runtime_ruleset.rules == [] - def test_always_persists_runtime_rule(self) -> None: + def test_approve_always_persists_runtime_rule(self) -> None: mw = PermissionMiddleware(rulesets=[]) - mw._raise_interrupt = lambda **kw: {"decision_type": "always"} # type: ignore[assignment] + mw._raise_interrupt = lambda **kw: {"decision_type": "approve_always"} # type: ignore[assignment] state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]} out = mw.after_model(state, _FakeRuntime()) assert out is None # call kept diff --git a/surfsense_backend/tests/unit/tasks/chat/streaming/test_interrupt_inspector_all.py b/surfsense_backend/tests/unit/tasks/chat/streaming/test_interrupt_inspector_all.py new file mode 100644 index 000000000..348e49a4a --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/chat/streaming/test_interrupt_inspector_all.py @@ -0,0 +1,210 @@ +"""Real-graph 
contract: ``all_interrupt_values`` surfaces every pending interrupt. + +The chat-stream emit loop must yield one ``data-interrupt-request`` SSE frame +per paused subagent, in the same order ``state.interrupts`` reports them — +that's also the order the resume slicer consumes decisions. These tests pin +that contract against a **real** paused parent graph built via +:class:`~langgraph.types.Send` fan-out (no synthetic state mocks). +""" + +from __future__ import annotations + +import pytest +from langchain.tools import ToolRuntime +from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.runnables import RunnableConfig +from langgraph.checkpoint.memory import InMemorySaver +from langgraph.graph import END, START, StateGraph +from langgraph.types import Send, interrupt +from typing_extensions import TypedDict + +from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import ( + build_task_tool_with_parent_config, +) +from app.tasks.chat.streaming.helpers.interrupt_inspector import ( + all_interrupt_values, +) + + +class _SubState(TypedDict, total=False): + messages: list + + +class _DispatchState(TypedDict, total=False): + messages: list + tcid: str + desc: str + + +def _build_pausing_subagent(checkpointer: InMemorySaver): + def approve_node(_state): + decision = interrupt( + { + "action_requests": [ + {"name": "do_thing", "args": {"x": 1}, "description": ""} + ], + "review_configs": [{}], + } + ) + return {"messages": [AIMessage(content=f"got:{decision}")]} + + g = StateGraph(_SubState) + g.add_node("approve", approve_node) + g.add_edge(START, "approve") + g.add_edge("approve", END) + return g.compile(checkpointer=checkpointer) + + +def _parent_graph_dispatching_two_tasks_via_send( + task_tool, *, tool_call_id_a: str, tool_call_id_b: str, checkpointer +): + def fanout_edge(_state) -> list[Send]: + return [ + Send("call_task", {"tcid": tool_call_id_a, "desc": "approve A"}), + Send("call_task", {"tcid": 
tool_call_id_b, "desc": "approve B"}), + ] + + async def call_task(state: _DispatchState, config: RunnableConfig): + rt = ToolRuntime( + state=state, + config=config, + context=None, + stream_writer=None, + tool_call_id=state["tcid"], + store=None, + ) + return await task_tool.coroutine( + description=state["desc"], subagent_type="approver", runtime=rt + ) + + g = StateGraph(_DispatchState) + g.add_node("call_task", call_task) + g.add_conditional_edges(START, fanout_edge, ["call_task"]) + g.add_edge("call_task", END) + return g.compile(checkpointer=checkpointer) + + +@pytest.mark.asyncio +async def test_returns_every_pending_interrupt_for_two_paused_subagents(): + """Two parallel subagents -> ``all_interrupt_values`` returns 2 dicts.""" + checkpointer = InMemorySaver() + subagent = _build_pausing_subagent(checkpointer) + task_tool = build_task_tool_with_parent_config( + [{"name": "approver", "description": "approves", "runnable": subagent}] + ) + parent = _parent_graph_dispatching_two_tasks_via_send( + task_tool, + tool_call_id_a="parent-tcid-A", + tool_call_id_b="parent-tcid-B", + checkpointer=checkpointer, + ) + + parent_config = { + "configurable": {"thread_id": "all-iv-thread"}, + "recursion_limit": 100, + } + await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config) + state = await parent.aget_state(parent_config) + + values = all_interrupt_values(state) + + assert isinstance(values, list) + assert len(values) == 2, ( + f"REGRESSION: expected one value per pending subagent, got " + f"{len(values)}: {values!r}" + ) + stamps = [v.get("tool_call_id") for v in values] + assert sorted(stamps) == ["parent-tcid-A", "parent-tcid-B"] + for v in values: + assert isinstance(v.get("action_requests"), list) + assert len(v["action_requests"]) == 1 + + +@pytest.mark.asyncio +async def test_preserves_state_interrupts_traversal_order(): + """Order returned by inspector must match ``state.interrupts`` order. 
+ + The resume slicer consumes decisions left-to-right against + ``collect_pending_tool_calls(state)`` which walks ``state.interrupts`` + in iteration order — so the inspector (which drives the *emit* order) + must agree with that traversal or the slice and the wire fall out of sync. + """ + checkpointer = InMemorySaver() + subagent = _build_pausing_subagent(checkpointer) + task_tool = build_task_tool_with_parent_config( + [{"name": "approver", "description": "approves", "runnable": subagent}] + ) + parent = _parent_graph_dispatching_two_tasks_via_send( + task_tool, + tool_call_id_a="parent-tcid-A", + tool_call_id_b="parent-tcid-B", + checkpointer=checkpointer, + ) + parent_config = { + "configurable": {"thread_id": "order-thread"}, + "recursion_limit": 100, + } + await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config) + state = await parent.aget_state(parent_config) + + inspector_order = [v["tool_call_id"] for v in all_interrupt_values(state)] + state_order = [ + i.value["tool_call_id"] + for i in state.interrupts + if isinstance(getattr(i, "value", None), dict) + and "tool_call_id" in i.value + ] + + assert inspector_order == state_order, ( + f"inspector order {inspector_order!r} diverged from state.interrupts " + f"order {state_order!r}; the resume slicer would mis-route decisions." 
+ ) + + +@pytest.mark.asyncio +async def test_returns_empty_list_when_nothing_paused(): + """A graph that completes normally produces no interrupts to surface.""" + + def done_node(_state): + return {"messages": [AIMessage(content="done")]} + + g = StateGraph(_SubState) + g.add_node("done", done_node) + g.add_edge(START, "done") + g.add_edge("done", END) + graph = g.compile(checkpointer=InMemorySaver()) + config = {"configurable": {"thread_id": "no-pause-thread"}} + await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config) + state = await graph.aget_state(config) + + assert all_interrupt_values(state) == [] + + +@pytest.mark.asyncio +async def test_single_paused_subagent_returns_a_list_of_one(): + """Single-pause case must still return a list (not unwrap to a dict).""" + + def approve_node(_state): + decision = interrupt( + { + "action_requests": [{"name": "x", "args": {}, "description": ""}], + "review_configs": [{}], + "tool_call_id": "lonely-tcid", + } + ) + return {"messages": [AIMessage(content=f"got:{decision}")]} + + g = StateGraph(_SubState) + g.add_node("approve", approve_node) + g.add_edge(START, "approve") + g.add_edge("approve", END) + graph = g.compile(checkpointer=InMemorySaver()) + config = {"configurable": {"thread_id": "single-thread"}} + await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config) + state = await graph.aget_state(config) + + values = all_interrupt_values(state) + + assert isinstance(values, list) + assert len(values) == 1 + assert values[0].get("tool_call_id") == "lonely-tcid" diff --git a/surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_1_parity.py b/surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_1_parity.py index d598de492..8fde773e3 100644 --- a/surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_1_parity.py +++ b/surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_1_parity.py @@ -23,7 +23,6 @@ from app.tasks.chat.stream_new_chat import ( 
_emit_stream_terminal_error as old_emit_terminal_error, _extract_chunk_parts as old_extract_chunk_parts, _extract_resolved_file_path as old_extract_resolved_file_path, - _first_interrupt_value as old_first_interrupt_value, _tool_output_has_error as old_tool_output_has_error, _tool_output_to_text as old_tool_output_to_text, ) @@ -36,9 +35,6 @@ from app.tasks.chat.streaming.errors.emitter import ( from app.tasks.chat.streaming.helpers.chunk_parts import ( extract_chunk_parts as new_extract_chunk_parts, ) -from app.tasks.chat.streaming.helpers.interrupt_inspector import ( - first_interrupt_value as new_first_interrupt_value, -) from app.tasks.chat.streaming.helpers.tool_output import ( extract_resolved_file_path as new_extract_resolved_file_path, tool_output_has_error as new_tool_output_has_error, @@ -105,52 +101,6 @@ def test_extract_chunk_parts_matches_old_implementation(chunk: Any) -> None: assert new_extract_chunk_parts(chunk) == old_extract_chunk_parts(chunk) -# ---------------------------------------------------------- interrupt inspector - - -@dataclass -class _Interrupt: - value: dict[str, Any] - - -@dataclass -class _Task: - interrupts: tuple[Any, ...] = () - - -@dataclass -class _State: - tasks: tuple[Any, ...] = () - interrupts: tuple[Any, ...] = () - - -_INTERRUPT_CASES: list[Any] = [ - _State(), - _State(tasks=(_Task(interrupts=(_Interrupt(value={"name": "send"}),)),)), - # Multiple tasks: must return the FIRST one in iteration order. - _State( - tasks=( - _Task(interrupts=(_Interrupt(value={"name": "first"}),)), - _Task(interrupts=(_Interrupt(value={"name": "second"}),)), - ) - ), - # Empty task interrupts -> falls back to root state.interrupts. - _State( - tasks=(_Task(interrupts=()),), - interrupts=(_Interrupt(value={"name": "root"}),), - ), - # Interrupts as plain dicts (not wrapper objects). - _State(interrupts=({"value": {"name": "dict_root"}},)), - # A defective task whose `.interrupts` raises - must be tolerated. 
- _State(tasks=(object(),)), -] - - -@pytest.mark.parametrize("state", _INTERRUPT_CASES) -def test_first_interrupt_value_matches_old_implementation(state: Any) -> None: - assert new_first_interrupt_value(state) == old_first_interrupt_value(state) - - # ----------------------------------------------------------- error classifier diff --git a/surfsense_backend/tests/unit/tasks/chat/test_thinking_step_id_uniqueness.py b/surfsense_backend/tests/unit/tasks/chat/test_thinking_step_id_uniqueness.py new file mode 100644 index 000000000..4ce73bc2e --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/chat/test_thinking_step_id_uniqueness.py @@ -0,0 +1,169 @@ +"""Pin: thinking-step IDs must be globally unique within a thread. + +The frontend rehydrates ``currentThinkingSteps`` from the prior assistant +message when starting a resume. If two consecutive resume turns emit step IDs +that overlap (e.g. both produce ``thinking-resume-1`` because each invocation +constructs a fresh :class:`AgentEventRelayState` with +``thinking_step_counter=0``), React renders sibling timeline rows with the +same key — the warning the user reported in production. + +The contract this module pins: each ``_stream_agent_events`` invocation must +receive a ``step_prefix`` that is unique within the thread (we salt with the +per-turn ``turn_id``), so the resulting step IDs across consecutive turns +are always disjoint. 
+""" + +from __future__ import annotations + +import json +from collections.abc import AsyncGenerator +from dataclasses import dataclass, field +from typing import Any + +import pytest + +from app.services.new_streaming_service import VercelStreamingService +from app.tasks.chat.stream_new_chat import ( + StreamResult, + _resume_step_prefix, + _stream_agent_events, +) + +pytestmark = pytest.mark.unit + + +@dataclass +class _FakeChunk: + content: Any = "" + additional_kwargs: dict[str, Any] = field(default_factory=dict) + tool_call_chunks: list[dict[str, Any]] = field(default_factory=list) + + +class _FakeAgentState: + def __init__(self) -> None: + self.values: dict[str, Any] = {} + self.tasks: list[Any] = [] + + +class _FakeAgent: + def __init__(self, events: list[dict[str, Any]]) -> None: + self._events = events + self._state = _FakeAgentState() + + async def astream_events( # type: ignore[no-untyped-def] + self, _input_data: Any, *, config: dict[str, Any], version: str + ) -> AsyncGenerator[dict[str, Any], None]: + del config, version + for ev in self._events: + yield ev + + async def aget_state(self, _config: dict[str, Any]) -> _FakeAgentState: + return self._state + + +def _tool_start(*, name: str, run_id: str) -> dict[str, Any]: + return { + "event": "on_tool_start", + "name": name, + "run_id": run_id, + "data": {"input": {}}, + } + + +async def _drain_step_ids(events: list[dict[str, Any]], *, step_prefix: str) -> set[str]: + """Run ``_stream_agent_events`` once and return every emitted thinking-step ID.""" + agent = _FakeAgent(events) + service = VercelStreamingService() + result = StreamResult() + config = {"configurable": {"thread_id": "regression-thread"}} + + sse_lines: list[str] = [] + async for sse in _stream_agent_events( + agent, config, {}, service, result, step_prefix=step_prefix + ): + sse_lines.append(sse) + + ids: set[str] = set() + for line in sse_lines: + if not line.startswith("data: "): + continue + body = line[len("data: ") :].rstrip("\n") + 
if not body or body == "[DONE]": + continue + try: + payload = json.loads(body) + except json.JSONDecodeError: + continue + if payload.get("type") != "data-thinking-step": + continue + step_id = (payload.get("data") or {}).get("id") + if isinstance(step_id, str): + ids.add(step_id) + return ids + + +@pytest.mark.asyncio +async def test_consecutive_invocations_with_same_prefix_produce_overlapping_ids(): + """Pin the bug: identical ``step_prefix`` across two turns reuses ``-1``, ``-2``… + + This is what production was doing for resume — every resume invocation + passed ``step_prefix='thinking-resume'`` and the relay state's counter + restarted at 0. Two scrollback timelines built from such turns then + presented React with siblings keyed by the same ``thinking-resume-1``. + """ + events = [ + _tool_start(name="t1", run_id="run-A-1"), + _tool_start(name="t2", run_id="run-A-2"), + ] + + ids_turn_one = await _drain_step_ids(events, step_prefix="thinking-resume") + ids_turn_two = await _drain_step_ids(events, step_prefix="thinking-resume") + + assert ids_turn_one == ids_turn_two != set(), ( + "fixture broken: expected non-empty overlapping ids when prefix is reused" + ) + + +@pytest.mark.asyncio +async def test_per_turn_salted_prefix_yields_disjoint_step_ids_across_turns(): + """The fix: salting the prefix with the per-turn ``turn_id`` makes IDs disjoint. + + Two consecutive resume calls in the same thread feed two different + ``turn_id``s into the prefix, so the resulting step IDs cannot collide + no matter how many times the FE rehydrates from earlier assistant + messages — which is the precondition for the React duplicate-key warning. 
+ """ + events = [ + _tool_start(name="t1", run_id="run-A-1"), + _tool_start(name="t2", run_id="run-A-2"), + ] + + ids_turn_one = await _drain_step_ids( + events, step_prefix="thinking-resume-104:1778698228472" + ) + ids_turn_two = await _drain_step_ids( + events, step_prefix="thinking-resume-104:1778698244022" + ) + + assert ids_turn_one and ids_turn_two, "fixture broken: expected non-empty id sets" + assert ids_turn_one.isdisjoint(ids_turn_two), ( + f"REGRESSION: per-turn-salted prefixes produced overlapping step IDs: " + f"{ids_turn_one & ids_turn_two!r}" + ) + + +def test_resume_step_prefix_helper_includes_turn_id_verbatim(): + """Production call-site pin: ``stream_resume_chat`` builds the prefix via + this helper. Reverting it back to a hardcoded ``'thinking-resume'`` would + silently re-introduce the duplicate-key React warning across consecutive + resumes — this test fails first instead. + """ + a = _resume_step_prefix("104:1778698228472") + b = _resume_step_prefix("104:1778698244022") + + assert a.startswith("thinking-resume-"), ( + f"prefix shape changed; the FE log filters and the timeline contract " + f"expect the ``thinking-resume-`` head to remain stable: got {a!r}" + ) + assert "104:1778698228472" in a and "104:1778698244022" in b + assert a != b diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index 0ebd8dc9a..190ad745b 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -49,7 +49,11 @@ import { type TokenUsageData, TokenUsageProvider, } from "@/components/assistant-ui/token-usage-context"; -import { type HitlDecision, PendingInterruptProvider } from "@/features/chat-messages/hitl"; +import { + type HitlDecision, + PendingInterruptProvider, + type PendingInterruptState, +} from 
"@/features/chat-messages/hitl"; import { TimelineDataUI } from "@/features/chat-messages/timeline"; import { applyActionLogSse, @@ -272,12 +276,16 @@ export default function NewChatPage() { const [tokenUsageStore] = useState(() => createTokenUsageStore()); const abortControllerRef = useRef(null); const recentCancelRequestedAtRef = useRef(0); - const [pendingInterrupt, setPendingInterrupt] = useState<{ - threadId: number; - assistantMsgId: string; - interruptData: Record; - bundleToolCallIds: string[]; - } | null>(null); + // One entry per paused subagent, in receipt order (which matches the + // backend's ``state.interrupts`` traversal — and therefore the order + // ``slice_decisions_by_tool_call`` consumes on resume). Cleared on submit + // or on a fresh user turn. + const [pendingInterrupts, setPendingInterrupts] = useState([]); + // Per-card staged decisions held until every pending card has submitted, + // at which point we batch them into one ``hitl-decision`` event in the + // same order as ``pendingInterrupts``. Using a ref because partial + // progress should not re-render the page. + const stagedDecisionsByInterruptIdRef = useRef>(new Map()); const toolsWithUI = TOOLS_WITH_UI_ALL; const setMessageDocumentsMap = useSetAtom(messageDocumentsMapAtom); @@ -1194,12 +1202,24 @@ export default function NewChatPage() { ) ); if (currentThreadId) { - setPendingInterrupt({ - threadId: currentThreadId, - assistantMsgId, - interruptData, - bundleToolCallIds, - }); + // ``tool_call_id`` is stamped on the backend by + // ``checkpointed_subagent_middleware``. Without it we + // can't address the paused subagent on resume — skip + // rather than fabricate a synthetic key. + const interruptId = String(interruptData.tool_call_id ?? 
""); + if (interruptId) { + const incoming: PendingInterruptState = { + interruptId, + threadId: currentThreadId, + assistantMsgId, + interruptData, + bundleToolCallIds, + }; + setPendingInterrupts((prev) => { + const without = prev.filter((p) => p.interruptId !== interruptId); + return [...without, incoming]; + }); + } } break; } @@ -1275,7 +1295,7 @@ export default function NewChatPage() { // by ``persist_assistant_shell``. Rename the optimistic // id, migrate ``tokenUsageStore`` so any pending // ``data-token-usage`` payload binds to the new id, - // remap any in-flight ``pendingInterrupt`` reference, + // remap any in-flight ``pendingInterrupts`` entries, // and reassign the closure variable so the in-stream // flush callback (line ~1074) keeps writing to the // renamed message. @@ -1291,10 +1311,12 @@ export default function NewChatPage() { : m ) ); - setPendingInterrupt((prev) => - prev && prev.assistantMsgId === oldAssistantMsgId - ? { ...prev, assistantMsgId: newAssistantMsgId } - : prev + setPendingInterrupts((prev) => + prev.map((p) => + p.assistantMsgId === oldAssistantMsgId + ? { ...p, assistantMsgId: newAssistantMsgId } + : p + ) ); assistantMsgId = newAssistantMsgId; break; @@ -1381,14 +1403,23 @@ export default function NewChatPage() { edited_action?: { name: string; args: Record }; }> ) => { - if (!pendingInterrupt) return; - const { threadId: resumeThreadId } = pendingInterrupt; + if (pendingInterrupts.length === 0) return; + // All cards in this turn share the same threadId/assistantMsgId + // (they're siblings of one parent agent step), so reading from + // the first entry is safe. + const resumeThreadId = pendingInterrupts[0].threadId; // Destructured separately as ``let`` so the SSE // ``data-assistant-message-id`` handler (resume always // allocates a fresh server-side row) can rename it to // the canonical ``msg-{db_id}`` mid-stream. 
- let assistantMsgId = pendingInterrupt.assistantMsgId; - setPendingInterrupt(null); + let assistantMsgId = pendingInterrupts[0].assistantMsgId; + // Concatenate every card's tool-call ids in pendingInterrupts order; + // this matches the ``decisions`` ordering produced by + // ``handleApprovalSubmit`` and the backend slicer's traversal of + // ``state.interrupts``. + const allBundleToolCallIds = pendingInterrupts.flatMap((p) => p.bundleToolCallIds); + setPendingInterrupts([]); + stagedDecisionsByInterruptIdRef.current.clear(); setIsRunning(true); const token = getBearerToken(); @@ -1465,7 +1496,7 @@ export default function NewChatPage() { // collapse onto ``decisions[0]``. Cards outside the bundle are // untouched. Mirrors the host ``hitl-decision`` handler. const decisionByTcId = new Map(); - const tcIds = pendingInterrupt.bundleToolCallIds; + const tcIds = allBundleToolCallIds; if (decisions.length === tcIds.length) { for (let i = 0; i < tcIds.length; i++) decisionByTcId.set(tcIds[i], decisions[i]); } @@ -1477,7 +1508,7 @@ export default function NewChatPage() { if (!d) continue; if (typeof part.result !== "object" || part.result === null) continue; if (!("__interrupt__" in (part.result as Record))) continue; - const decided = d.type as "approve" | "reject" | "edit"; + const decided = d.type; if (decided === "edit" && d.edited_action) { const mergedArgs = { ...part.args, ...d.edited_action.args }; part.args = mergedArgs; @@ -1597,12 +1628,22 @@ export default function NewChatPage() { : m ) ); - setPendingInterrupt({ - threadId: resumeThreadId, - assistantMsgId, - interruptData, - bundleToolCallIds, - }); + { + const interruptId = String(interruptData.tool_call_id ?? 
""); + if (interruptId) { + const incoming: PendingInterruptState = { + interruptId, + threadId: resumeThreadId, + assistantMsgId, + interruptData, + bundleToolCallIds, + }; + setPendingInterrupts((prev) => { + const without = prev.filter((p) => p.interruptId !== interruptId); + return [...without, incoming]; + }); + } + } break; } @@ -1680,7 +1721,7 @@ export default function NewChatPage() { } }, [ - pendingInterrupt, + pendingInterrupts, messages, searchSpaceId, localFilesystemEnabled, @@ -1701,17 +1742,19 @@ export default function NewChatPage() { edited_action?: { name: string; args: Record }; }>; }; - if (!detail?.decisions || !pendingInterrupt) return; + if (!detail?.decisions || pendingInterrupts.length === 0) return; const incoming = detail.decisions; if (incoming.length === 0) return; - const tcIds = pendingInterrupt.bundleToolCallIds; + // Concatenated tool-call ids across every pending card, in the + // order ``handleApprovalSubmit`` produced ``incoming``. + const tcIds = pendingInterrupts.flatMap((p) => p.bundleToolCallIds); const N = tcIds.length; - // Bundles must submit exactly one decision per action_request. - // Refuse rather than silently broadcast a single decision across - // the bundle (would mis-apply rejects/edits and diverge from - // what handleResume sends to /resume). - if (N > 1 && incoming.length !== N) { + // Refuse rather than silently broadcast or drop. The orchestrator + // only fires ``hitl-decision`` once every pending card has + // submitted, so a count mismatch indicates a contract drift + // (and would later make the backend slicer raise). 
+ if (incoming.length !== N) { toast.error( `Cannot resume: ${incoming.length} decision(s) submitted for ${N} pending actions.` ); @@ -1722,9 +1765,12 @@ export default function NewChatPage() { for (let i = 0; i < tcIds.length; i++) byTcId.set(tcIds[i], incoming[i]); const submittedDecisions = tcIds.map((id) => byTcId.get(id)!); + // All pending cards belong to the same assistant message, so a + // single content-update pass suffices. + const targetAssistantMsgId = pendingInterrupts[0].assistantMsgId; setMessages((prev) => prev.map((m) => { - if (m.id !== pendingInterrupt.assistantMsgId) return m; + if (m.id !== targetAssistantMsgId) return m; const parts = m.content as unknown as Array>; const newContent = parts.map((part) => { const tcId = part.toolCallId as string | undefined; @@ -1732,7 +1778,7 @@ export default function NewChatPage() { if (!d || part.type !== "tool-call") return part; if (typeof part.result !== "object" || part.result === null) return part; if (!("__interrupt__" in (part.result as Record))) return part; - const decided = d.type as "approve" | "reject" | "edit"; + const decided = d.type; if (decided === "edit" && d.edited_action) { return { ...part, @@ -1761,7 +1807,7 @@ export default function NewChatPage() { }; window.addEventListener("hitl-decision", handler); return () => window.removeEventListener("hitl-decision", handler); - }, [handleResume, pendingInterrupt]); + }, [handleResume, pendingInterrupts]); // Convert message (pass through since already in correct format) const convertMessage = useCallback( @@ -2283,11 +2329,32 @@ export default function NewChatPage() { [handleRegenerate, messages, agentActionItems] ); - const handleApprovalSubmit = useCallback((orderedDecisions: HitlDecision[]) => { - window.dispatchEvent( - new CustomEvent("hitl-decision", { detail: { decisions: orderedDecisions } }) - ); - }, []); + const handleApprovalSubmit = useCallback( + (interruptId: string, decisions: HitlDecision[]) => { + // Stage this card's 
decisions; only fire the resume once every + // pending card in the current turn has submitted, so the + // backend slicer sees a single concatenated decisions list + // whose total matches the parent state's pending action count. + stagedDecisionsByInterruptIdRef.current.set(interruptId, decisions); + if (stagedDecisionsByInterruptIdRef.current.size < pendingInterrupts.length) { + return; + } + const ordered: HitlDecision[] = []; + for (const pi of pendingInterrupts) { + const staged = stagedDecisionsByInterruptIdRef.current.get(pi.interruptId); + if (!staged) { + // Defensive: a missing entry means the staging map and + // the pending list disagreed for one cycle. Bail rather + // than dispatch a count-mismatched batch. + return; + } + ordered.push(...staged); + } + stagedDecisionsByInterruptIdRef.current.clear(); + window.dispatchEvent(new CustomEvent("hitl-decision", { detail: { decisions: ordered } })); + }, + [pendingInterrupts] + ); const handleEditDialogChoice = useCallback( async (choice: EditMessageDialogChoice) => { @@ -2360,7 +2427,7 @@ export default function NewChatPage() {
diff --git a/surfsense_web/components/assistant-ui/connector-popup.tsx b/surfsense_web/components/assistant-ui/connector-popup.tsx index 32943142a..84361e25b 100644 --- a/surfsense_web/components/assistant-ui/connector-popup.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup.tsx @@ -124,7 +124,6 @@ export const ConnectorIndicator = forwardRef - handleDisconnectFromList(connector, () => refreshConnectors()) - } onAddAccount={handleAddNewMCPFromList} addButtonText="Add New MCP Server" /> @@ -247,9 +243,6 @@ export const ConnectorIndicator = forwardRef - handleDisconnectFromList(connector, () => refreshConnectors()) - } onAddAccount={() => { // Check both OAUTH_CONNECTORS and COMPOSIO_CONNECTORS const oauthConnector = diff --git a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/mcp-service-config.tsx b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/mcp-service-config.tsx index 71d0e31a8..cfa2cde38 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/mcp-service-config.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/mcp-service-config.tsx @@ -3,6 +3,7 @@ import { CheckCircle2 } from "lucide-react"; import type { FC } from "react"; import type { ConnectorConfigProps } from "../index"; +import { MCPTrustedTools } from "./mcp-trusted-tools"; export const MCPServiceConfig: FC = ({ connector }) => { const serviceName = connector.config?.mcp_service as string | undefined; @@ -11,7 +12,7 @@ export const MCPServiceConfig: FC = ({ connector }) => { : "this service"; return ( -
+
@@ -23,6 +24,8 @@ export const MCPServiceConfig: FC = ({ connector }) => {

+ + {connector.id > 0 && }
); }; diff --git a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/mcp-trusted-tools.tsx b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/mcp-trusted-tools.tsx new file mode 100644 index 000000000..ed01511ca --- /dev/null +++ b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/mcp-trusted-tools.tsx @@ -0,0 +1,89 @@ +"use client"; + +import { ShieldCheck, Trash2 } from "lucide-react"; +import type { FC } from "react"; +import { useState } from "react"; +import { toast } from "sonner"; +import { Button } from "@/components/ui/button"; +import type { SearchSourceConnector } from "@/contracts/types/connector.types"; +import { connectorsApiService } from "@/lib/apis/connectors-api.service"; + +interface MCPTrustedToolsProps { + connector: SearchSourceConnector; +} + +/** Audit + revoke surface for tools promoted via in-chat "Always Allow". */ +export const MCPTrustedTools: FC = ({ connector }) => { + const trustedTools = readTrustedTools(connector.config); + const [pending, setPending] = useState>(new Set()); + + const handleRevoke = async (toolName: string) => { + setPending((prev) => new Set(prev).add(toolName)); + try { + await connectorsApiService.untrustMCPTool(connector.id, toolName); + toast.success(`Removed ${toolName} from trusted tools`); + } catch { + toast.error(`Failed to remove ${toolName} from trusted tools`); + } finally { + setPending((prev) => { + const next = new Set(prev); + next.delete(toolName); + return next; + }); + } + }; + + return ( +
+

+ + Trusted Tools +

+ +
+

+ Tools listed here skip the approval prompt during chat. Trust is granted by clicking + "Always Allow" on an approval card; revoke it here to require approval again. +

+ + {trustedTools.length === 0 ? ( +

+ No trusted tools yet for this connector. +

+ ) : ( +
    + {trustedTools.map((toolName) => { + const isPending = pending.has(toolName); + return ( +
  • + {toolName} + +
  • + ); + })} +
+ )} +
+
+ ); +}; + +function readTrustedTools(config: Record | undefined | null): string[] { + const raw = config?.trusted_tools; + if (!Array.isArray(raw)) return []; + return raw.filter((item): item is string => typeof item === "string"); +} diff --git a/surfsense_web/components/assistant-ui/connector-popup/hooks/use-connector-dialog.ts b/surfsense_web/components/assistant-ui/connector-popup/hooks/use-connector-dialog.ts index ed9bf70a8..b49bfda96 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/hooks/use-connector-dialog.ts +++ b/surfsense_web/components/assistant-ui/connector-popup/hooks/use-connector-dialog.ts @@ -1288,25 +1288,6 @@ export const useConnectorDialog = () => { [editingConnector, searchSpaceId, deleteConnector, cameFromMCPList, setIsOpen] ); - const handleDisconnectFromList = useCallback( - async (connector: SearchSourceConnector, refreshConnectors: () => void) => { - if (!searchSpaceId) return; - try { - await deleteConnector({ id: connector.id }); - trackConnectorDeleted(Number(searchSpaceId), connector.connector_type, connector.id); - toast.success(`${connector.name} disconnected successfully`); - refreshConnectors(); - queryClient.invalidateQueries({ - queryKey: cacheKeys.logs.summary(Number(searchSpaceId)), - }); - } catch (error) { - console.error("Error disconnecting connector:", error); - toast.error("Failed to disconnect connector"); - } - }, - [searchSpaceId, deleteConnector] - ); - // Handle quick index (index with selected date range, or backend defaults if none selected) const handleQuickIndexConnector = useCallback( async ( @@ -1480,7 +1461,6 @@ export const useConnectorDialog = () => { handleStartEdit, handleSaveConnector, handleDisconnectConnector, - handleDisconnectFromList, handleBackFromEdit, handleBackFromConnect, handleBackFromYouTube, diff --git a/surfsense_web/components/assistant-ui/connector-popup/views/connector-accounts-list-view.tsx 
b/surfsense_web/components/assistant-ui/connector-popup/views/connector-accounts-list-view.tsx index 8aee7e005..f6291b64d 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/views/connector-accounts-list-view.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/views/connector-accounts-list-view.tsx @@ -1,7 +1,7 @@ "use client"; import { useAtomValue } from "jotai"; -import { ArrowLeft, Plus, RefreshCw, Server, Trash2 } from "lucide-react"; +import { ArrowLeft, Plus, RefreshCw, Server } from "lucide-react"; import { type FC, useCallback, useState } from "react"; import { toast } from "sonner"; import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms"; @@ -24,7 +24,6 @@ interface ConnectorAccountsListViewProps { indexingConnectorIds: Set; onBack: () => void; onManage: (connector: SearchSourceConnector) => void; - onDisconnect?: (connector: SearchSourceConnector) => Promise | void; onAddAccount: () => void; isConnecting?: boolean; addButtonText?: string; @@ -37,15 +36,12 @@ export const ConnectorAccountsListView: FC = ({ indexingConnectorIds, onBack, onManage, - onDisconnect, onAddAccount, isConnecting = false, addButtonText, }) => { const searchSpaceId = useAtomValue(activeSearchSpaceIdAtom); const [reauthingId, setReauthingId] = useState(null); - const [confirmDisconnectId, setConfirmDisconnectId] = useState(null); - const [disconnectingId, setDisconnectingId] = useState(null); // Get connector status const { isConnectorEnabled, getConnectorStatusMessage } = useConnectorStatus(); @@ -240,51 +236,6 @@ export const ConnectorAccountsListView: FC = ({ /> Re-authenticate - ) : isLive && onDisconnect ? ( - confirmDisconnectId === connector.id ? ( -
- - -
- ) : ( - - ) ) : ( )} - {isMCPTool && ( - )} diff --git a/surfsense_web/features/chat-messages/hitl/approval/pending-interrupt-context.tsx b/surfsense_web/features/chat-messages/hitl/approval/pending-interrupt-context.tsx index 2c193d952..9e4153ff5 100644 --- a/surfsense_web/features/chat-messages/hitl/approval/pending-interrupt-context.tsx +++ b/surfsense_web/features/chat-messages/hitl/approval/pending-interrupt-context.tsx @@ -3,8 +3,10 @@ import { createContext, type ReactNode, useContext } from "react"; import type { HitlDecision } from "../types"; -/** Snapshot of one in-flight HITL interrupt; ``null`` when nothing is pending. */ +/** One in-flight HITL interrupt (one paused subagent). */ export interface PendingInterruptState { + /** Stable id keyed by the parent ``tool_call_id`` stamped on the interrupt. */ + interruptId: string; threadId: number; assistantMsgId: string; interruptData: Record; @@ -12,8 +14,19 @@ export interface PendingInterruptState { } export interface PendingInterruptValue { - pendingInterrupt: PendingInterruptState | null; - onSubmit: (decisions: HitlDecision[]) => void; + /** + * Every paused subagent for the current turn, in the order the SSE stream + * delivered them — which matches ``state.interrupts`` traversal on the + * backend, which is the order ``slice_decisions_by_tool_call`` consumes. + */ + pendingInterrupts: PendingInterruptState[]; + /** + * Stage one card's decisions. The orchestrator (page-level) batches across + * cards and dispatches the resume only once every pending interrupt has + * submitted, so the backend slicer sees a single concatenated decisions + * list whose total matches the parent state's pending action count. + */ + onSubmit: (interruptId: string, decisions: HitlDecision[]) => void; } const PendingInterruptContext = createContext(null); @@ -24,16 +37,16 @@ const PendingInterruptContext = createContext(null * page root. 
*/ export function PendingInterruptProvider({ - pendingInterrupt, + pendingInterrupts, onSubmit, children, }: { - pendingInterrupt: PendingInterruptState | null; - onSubmit: (decisions: HitlDecision[]) => void; + pendingInterrupts: PendingInterruptState[]; + onSubmit: (interruptId: string, decisions: HitlDecision[]) => void; children: ReactNode; }) { return ( - + {children} ); diff --git a/surfsense_web/features/chat-messages/hitl/types.ts b/surfsense_web/features/chat-messages/hitl/types.ts index 03f00ba9d..76af439f1 100644 --- a/surfsense_web/features/chat-messages/hitl/types.ts +++ b/surfsense_web/features/chat-messages/hitl/types.ts @@ -7,12 +7,12 @@ export interface InterruptActionRequest { export interface InterruptReviewConfig { action_name: string; - allowed_decisions: Array<"approve" | "edit" | "reject">; + allowed_decisions: Array<"approve" | "edit" | "reject" | "approve_always">; } export interface InterruptResult = Record> { __interrupt__: true; - __decided__?: "approve" | "reject" | "edit"; + __decided__?: "approve" | "reject" | "edit" | "approve_always"; __completed__?: boolean; action_requests: InterruptActionRequest[]; review_configs: InterruptReviewConfig[]; @@ -31,7 +31,7 @@ export function isInterruptResult(result: unknown): result is InterruptResult { } export interface HitlDecision { - type: "approve" | "reject" | "edit"; + type: "approve" | "reject" | "edit" | "approve_always"; message?: string; edited_action?: { name: string; diff --git a/surfsense_web/features/chat-messages/timeline/data-renderer.tsx b/surfsense_web/features/chat-messages/timeline/data-renderer.tsx index 861e35fd2..fb3dda047 100644 --- a/surfsense_web/features/chat-messages/timeline/data-renderer.tsx +++ b/surfsense_web/features/chat-messages/timeline/data-renderer.tsx @@ -11,10 +11,9 @@ const noopSubmit = () => {}; /** * assistant-ui data UI for the ``thinking-steps`` data-part. 
* - * Re-scopes the global ``PendingInterruptProvider`` per message: the - * approval card only mounts under the assistant message that owns - * the interrupt (otherwise every message in scrollback would render - * its own card). + * Re-scopes the global ``PendingInterruptProvider`` per message: approval + * cards only mount under the assistant message that owns the interrupt + * (otherwise every message in scrollback would render its own cards). */ function TimelineDataRenderer({ data }: { name: string; data: unknown }) { const isThreadRunning = useAuiState(({ thread }) => thread.isRunning); @@ -23,10 +22,10 @@ function TimelineDataRenderer({ data }: { name: string; data: unknown }) { const content = useAuiState(({ message }) => message?.content); const messageId = useAuiState(({ message }) => message?.id); const pendingValue = usePendingInterrupt(); - const pendingForThisMessage = - pendingValue?.pendingInterrupt && pendingValue.pendingInterrupt.assistantMsgId === messageId - ? pendingValue.pendingInterrupt - : null; + const pendingForThisMessage = useMemo( + () => (pendingValue?.pendingInterrupts ?? []).filter((p) => p.assistantMsgId === messageId), + [pendingValue?.pendingInterrupts, messageId] + ); const onSubmit = pendingValue?.onSubmit ?? noopSubmit; const steps = useMemo( @@ -39,11 +38,11 @@ function TimelineDataRenderer({ data }: { name: string; data: unknown }) { [steps, content] ); - if (items.length === 0 && !pendingForThisMessage) return null; + if (items.length === 0 && pendingForThisMessage.length === 0) return null; return (
- +
diff --git a/surfsense_web/features/chat-messages/timeline/timeline.tsx b/surfsense_web/features/chat-messages/timeline/timeline.tsx index f51034733..31c86fb9c 100644 --- a/surfsense_web/features/chat-messages/timeline/timeline.tsx +++ b/surfsense_web/features/chat-messages/timeline/timeline.tsx @@ -32,9 +32,9 @@ export const Timeline: FC<{ isThreadRunning?: boolean; }> = ({ items, isThreadRunning = true }) => { const pendingValue = usePendingInterrupt(); - const pendingInterrupt = pendingValue?.pendingInterrupt ?? null; + const pendingInterrupts = pendingValue?.pendingInterrupts ?? []; const onSubmit = pendingValue?.onSubmit; - const hasPending = pendingInterrupt !== null; + const hasPending = pendingInterrupts.length > 0; // Apply the override here so downstream (grouping, headers, dots) // sees the corrected status without threading a callback. Keeps @@ -135,9 +135,15 @@ export const Timeline: FC<{ /> ); })} - {pendingInterrupt && onSubmit && ( -
- + {hasPending && onSubmit && ( +
+ {pendingInterrupts.map((pi) => ( + onSubmit(pi.interruptId, decisions)} + /> + ))}
)}
diff --git a/surfsense_web/lib/apis/connectors-api.service.ts b/surfsense_web/lib/apis/connectors-api.service.ts index a35e731a4..4b6d69883 100644 --- a/surfsense_web/lib/apis/connectors-api.service.ts +++ b/surfsense_web/lib/apis/connectors-api.service.ts @@ -405,35 +405,19 @@ class ConnectorsApiService { ); }; - // ============================================================================= - // MCP Tool Trust (Allow-List) Methods - // ============================================================================= - - /** - * Add a tool to the MCP connector's "Always Allow" list. - * Subsequent calls to this tool will skip HITL approval. - */ - trustMCPTool = async (connectorId: number, toolName: string): Promise => { - await baseApiService.post(`/api/v1/connectors/mcp/${connectorId}/trust-tool`, undefined, { - body: { tool_name: toolName }, - }); - }; - - /** - * Remove a tool from the MCP connector's "Always Allow" list. - */ - untrustMCPTool = async (connectorId: number, toolName: string): Promise => { - await baseApiService.post(`/api/v1/connectors/mcp/${connectorId}/untrust-tool`, undefined, { - body: { tool_name: toolName }, - }); - }; - /** Live stats for the Obsidian connector tile. */ getObsidianStats = async (vaultId: string): Promise => { return baseApiService.get( `/api/v1/obsidian/stats?vault_id=${encodeURIComponent(vaultId)}` ); }; + + /** Revoke a previously-trusted MCP tool so the next call asks again. 
*/ + untrustMCPTool = async (connectorId: number, toolName: string): Promise => { + await baseApiService.post(`/api/v1/connectors/mcp/${connectorId}/untrust-tool`, undefined, { + body: { tool_name: toolName }, + }); + }; } export interface ObsidianStats { diff --git a/surfsense_web/lib/chat/streaming-state.ts b/surfsense_web/lib/chat/streaming-state.ts index 1d057ef94..d9fb2ac99 100644 --- a/surfsense_web/lib/chat/streaming-state.ts +++ b/surfsense_web/lib/chat/streaming-state.ts @@ -565,7 +565,7 @@ export type SSEEvent = * the assistant-side row of the current turn. The frontend * renames its optimistic ``msg-assistant-XXX`` placeholder * id, migrates the local ``tokenUsageStore`` and - * ``pendingInterrupt`` references, and binds the running + * ``pendingInterrupts`` entries, and binds the running * mutable ``assistantMsgId`` closure variable to the * canonical id for the rest of the stream. */