From acd2fdda8aaac4090fa66408269ad4e27963972c Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Mon, 4 May 2026 18:42:39 +0200 Subject: [PATCH] Add SurfSenseCheckpointedSubAgentMiddleware to bridge HITL into deepagents subagents. --- .../__init__.py | 26 ++ .../config.py | 35 +++ .../constants.py | 18 ++ .../middleware.py | 103 ++++++++ .../propagation.py | 74 ++++++ .../resume.py | 71 ++++++ .../task_tool.py | 224 ++++++++++++++++++ 7 files changed, 551 insertions(+) create mode 100644 surfsense_backend/app/agents/multi_agent_with_deepagents/main_agent/graph/middleware/checkpointed_subagent_middleware/__init__.py create mode 100644 surfsense_backend/app/agents/multi_agent_with_deepagents/main_agent/graph/middleware/checkpointed_subagent_middleware/config.py create mode 100644 surfsense_backend/app/agents/multi_agent_with_deepagents/main_agent/graph/middleware/checkpointed_subagent_middleware/constants.py create mode 100644 surfsense_backend/app/agents/multi_agent_with_deepagents/main_agent/graph/middleware/checkpointed_subagent_middleware/middleware.py create mode 100644 surfsense_backend/app/agents/multi_agent_with_deepagents/main_agent/graph/middleware/checkpointed_subagent_middleware/propagation.py create mode 100644 surfsense_backend/app/agents/multi_agent_with_deepagents/main_agent/graph/middleware/checkpointed_subagent_middleware/resume.py create mode 100644 surfsense_backend/app/agents/multi_agent_with_deepagents/main_agent/graph/middleware/checkpointed_subagent_middleware/task_tool.py diff --git a/surfsense_backend/app/agents/multi_agent_with_deepagents/main_agent/graph/middleware/checkpointed_subagent_middleware/__init__.py b/surfsense_backend/app/agents/multi_agent_with_deepagents/main_agent/graph/middleware/checkpointed_subagent_middleware/__init__.py new file mode 100644 index 000000000..d03b571ca --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_with_deepagents/main_agent/graph/middleware/checkpointed_subagent_middleware/__init__.py @@ -0,0 
+1,26 @@ +"""SubAgent ``task`` tool wiring required for HITL inside subagents. + +Replaces upstream ``SubAgentMiddleware`` to: + +- share the parent's checkpointer with each subagent, +- forward ``runtime.config`` (thread_id, recursion_limit, …) into nested invokes, +- bridge ``Command(resume=...)`` from the parent into the subagent via the + ``config["configurable"]["surfsense_resume_value"]`` side-channel, +- target the resume at the captured interrupt id so a follow-up + ``HumanInTheLoopMiddleware.after_model`` does not consume the same payload, +- re-raise any new subagent interrupt at the parent so the SSE stream surfaces it. + +Module layout +------------- + +- ``constants`` — shared keys / limits. +- ``config`` — RunnableConfig + side-channel resume read. +- ``resume`` — pending-interrupt detection, fan-out, ``Command(resume=...)`` builder. +- ``propagation`` — re-raise pending subagent interrupts at the parent. +- ``task_tool`` — the ``task`` tool factory (sync + async). +- ``middleware`` — :class:`SurfSenseCheckpointedSubAgentMiddleware` itself. +""" + +from .middleware import SurfSenseCheckpointedSubAgentMiddleware + +__all__ = ["SurfSenseCheckpointedSubAgentMiddleware"] diff --git a/surfsense_backend/app/agents/multi_agent_with_deepagents/main_agent/graph/middleware/checkpointed_subagent_middleware/config.py b/surfsense_backend/app/agents/multi_agent_with_deepagents/main_agent/graph/middleware/checkpointed_subagent_middleware/config.py new file mode 100644 index 000000000..0312a2da5 --- /dev/null +++ b/surfsense_backend/app/agents/multi_agent_with_deepagents/main_agent/graph/middleware/checkpointed_subagent_middleware/config.py @@ -0,0 +1,35 @@ +"""RunnableConfig wiring for nested subagent invocations. + +Forwards the parent's ``runtime.config`` (thread_id, …) into the subagent and +exposes the side-channel ``stream_resume_chat`` uses to ferry resume payloads. 
def subagent_invoke_config(runtime: ToolRuntime) -> dict[str, Any]:
    """Build the RunnableConfig passed to a nested subagent invoke.

    The parent's ``runtime.config`` is shallow-copied (the parent's mapping is
    never mutated) and ``recursion_limit`` is raised to at least
    ``DEFAULT_SUBAGENT_RECURSION_LIMIT``.

    Args:
        runtime: Tool runtime of the parent ``task`` call; only ``.config`` is read.

    Returns:
        A new dict holding the parent's config keys plus the adjusted
        ``recursion_limit``.
    """
    merged: dict[str, Any] = dict(runtime.config) if runtime.config else {}
    current_limit = merged.get("recursion_limit")
    try:
        current_int = int(current_limit) if current_limit is not None else 0
    except (TypeError, ValueError, OverflowError):
        # Unparseable limits (including ``float("inf")``, which raises
        # OverflowError) are treated as "not set" and replaced by the default.
        current_int = 0
    if current_int < DEFAULT_SUBAGENT_RECURSION_LIMIT:
        merged["recursion_limit"] = DEFAULT_SUBAGENT_RECURSION_LIMIT
    return merged


def extract_surfsense_resume(runtime: ToolRuntime) -> Any:
    """Return the resume payload stored under ``configurable["surfsense_resume_value"]``.

    Any mapping type is accepted for both the config and its ``configurable``
    entry, matching ``subagent_invoke_config`` which also copies arbitrary
    mappings.

    Args:
        runtime: Tool runtime of the parent ``task`` call; only ``.config`` is read.

    Returns:
        The stored payload, or ``None`` when the config, the ``configurable``
        entry, or the key itself is missing or not a mapping.
    """
    from collections.abc import Mapping

    cfg = runtime.config or {}
    if not isinstance(cfg, Mapping):
        return None
    configurable = cfg.get("configurable")
    if not isinstance(configurable, Mapping):
        return None
    return configurable.get("surfsense_resume_value")


# Mirror of deepagents.middleware.subagents._EXCLUDED_STATE_KEYS.
EXCLUDED_STATE_KEYS = frozenset(
    {
        "messages",
        "todos",
        "structured_response",
        "skills_metadata",
        "memory_contents",
    }
)
# Recursion budget applied to nested subagent invokes; the LangGraph default
# of 25 trips on multi-step subagent runs.
DEFAULT_SUBAGENT_RECURSION_LIMIT = 10_000


class SurfSenseCheckpointedSubAgentMiddleware(SubAgentMiddleware):
    """``SubAgentMiddleware`` whose subagents are compiled with the parent's checkpointer.

    The instance exposes a single ``task`` tool (``self.tools``) and a system
    prompt (``self.system_prompt``) that lists the available subagents.
    """

    def __init__(
        self,
        *,
        checkpointer: Checkpointer,
        backend: BackendProtocol | BackendFactory,
        subagents: list[SubAgent | CompiledSubAgent],
        system_prompt: str | None = TASK_SYSTEM_PROMPT,
        task_description: str | None = None,
    ) -> None:
        """Compile the subagents and build the ``task`` tool.

        Args:
            checkpointer: Checkpointer handed to every subagent compiled here.
            backend: Stored on ``self._backend``; not otherwise used in this class.
            subagents: Specs to expose; entries that already carry a
                ``runnable`` are used as-is, the rest are compiled.
            system_prompt: Base prompt; the subagent list is appended when it
                is non-empty.
            task_description: Forwarded to the ``task`` tool factory.

        Raises:
            ValueError: If ``subagents`` is empty or a spec lacks ``model`` / ``tools``.
        """
        self._surf_checkpointer = checkpointer
        # Starting the MRO lookup *after* SubAgentMiddleware means its own
        # __init__ is not executed; this class performs the setup itself below.
        super(SubAgentMiddleware, self).__init__()
        if not subagents:
            raise ValueError(
                "At least one subagent must be specified when using the new API"
            )
        self._backend = backend
        self._subagents = subagents

        compiled_specs = self._surf_compile_subagent_graphs()
        task_tool = build_task_tool_with_parent_config(compiled_specs, task_description)

        if not (system_prompt and compiled_specs):
            self.system_prompt = system_prompt
        else:
            roster = "\n".join(
                f"- {entry['name']}: {entry['description']}" for entry in compiled_specs
            )
            self.system_prompt = (
                f"{system_prompt}\n\nAvailable subagent types:\n{roster}"
            )
        self.tools = [task_tool]

    def _surf_compile_subagent_graphs(self) -> list[dict[str, Any]]:
        """Return one ``{name, description, runnable}`` dict per configured subagent."""
        return [self._surf_to_runnable_spec(raw) for raw in self._subagents]

    def _surf_to_runnable_spec(self, raw: Any) -> dict[str, Any]:
        """Normalise a single spec, compiling it with the parent checkpointer if needed."""
        if "runnable" in raw:
            # Already compiled by the caller: pass the runnable through untouched.
            precompiled = cast(CompiledSubAgent, raw)
            return {
                "name": precompiled["name"],
                "description": precompiled["description"],
                "runnable": precompiled["runnable"],
            }

        for required in ("model", "tools"):
            if required not in raw:
                msg = f"SubAgent '{raw['name']}' must specify '{required}'"
                raise ValueError(msg)

        chat_model = raw["model"]
        if isinstance(chat_model, str):
            chat_model = init_chat_model(chat_model)

        stack: list[Any] = list(raw.get("middleware", []))
        approvals = raw.get("interrupt_on")
        if approvals:
            # Tool-approval rules become a HITL middleware appended last.
            stack.append(HumanInTheLoopMiddleware(interrupt_on=approvals))

        return {
            "name": raw["name"],
            "description": raw["description"],
            "runnable": create_agent(
                chat_model,
                system_prompt=raw["system_prompt"],
                tools=raw["tools"],
                middleware=stack,
                name=raw["name"],
                checkpointer=self._surf_checkpointer,
            ),
        }
logger = logging.getLogger(__name__)


def _reraise_pending_interrupt(snapshot: Any, subagent_type: str) -> None:
    """Call ``interrupt()`` with the snapshot's first pending interrupt value, if any.

    Shared tail of the sync and async propagation helpers. It is kept outside
    their ``try`` blocks so whatever ``interrupt()`` raises is never swallowed.
    """
    _pending_id, pending_value = get_first_pending_subagent_interrupt(snapshot)
    if pending_value is None:
        return
    logger.info(
        "Re-raising subagent %r interrupt to parent (multi-step HITL)",
        subagent_type,
    )
    _lg_interrupt(pending_value)


def maybe_propagate_subagent_interrupt(
    subagent: Runnable,
    sub_config: dict[str, Any],
    subagent_type: str,
) -> None:
    """Re-raise a still-pending subagent interrupt from the calling (parent) context.

    Args:
        subagent: The subagent runnable; skipped silently when it has no
            callable ``get_state``.
        sub_config: Config used for the nested invoke; passed to ``get_state``.
        subagent_type: Subagent name, used only for logging.

    A failing ``get_state`` is logged at debug level and treated as
    "nothing pending" (best effort).
    """
    get_state_sync = getattr(subagent, "get_state", None)
    if not callable(get_state_sync):
        return
    try:
        snapshot = get_state_sync(sub_config)
    except Exception:  # pragma: no cover - defensive
        logger.debug(
            "Subagent get_state failed during re-interrupt check",
            exc_info=True,
        )
        return
    _reraise_pending_interrupt(snapshot, subagent_type)


async def amaybe_propagate_subagent_interrupt(
    subagent: Runnable,
    sub_config: dict[str, Any],
    subagent_type: str,
) -> None:
    """Async counterpart of :func:`maybe_propagate_subagent_interrupt` using ``aget_state``."""
    aget_state = getattr(subagent, "aget_state", None)
    if not callable(aget_state):
        return
    try:
        snapshot = await aget_state(sub_config)
    except Exception:  # pragma: no cover - defensive
        logger.debug(
            "Subagent aget_state failed during re-interrupt check",
            exc_info=True,
        )
        return
    _reraise_pending_interrupt(snapshot, subagent_type)
def hitlrequest_action_count(pending_value: Any) -> int:
    """Return the number of ``action_requests`` in an interrupt payload.

    Returns ``0`` when the payload is not a dict or ``action_requests`` is not
    a list.
    """
    if not isinstance(pending_value, dict):
        return 0
    actions = pending_value.get("action_requests")
    return len(actions) if isinstance(actions, list) else 0


def fan_out_decisions_to_match(resume_value: Any, expected_count: int) -> Any:
    """Pad ``resume_value["decisions"]`` to ``expected_count`` entries.

    The last decision is repeated until the list is long enough. The input is
    returned unchanged when ``expected_count <= 1``, when it is not a dict,
    when ``decisions`` is not a list, is empty, or is already long enough.
    A padded result is a new dict; the input is never mutated.
    """
    if expected_count <= 1:
        return resume_value
    if not isinstance(resume_value, dict):
        return resume_value
    decisions = resume_value.get("decisions")
    if not isinstance(decisions, list) or len(decisions) >= expected_count:
        return resume_value
    if not decisions:
        return resume_value
    padded = list(decisions) + [decisions[-1]] * (expected_count - len(decisions))
    return {**resume_value, "decisions": padded}


def _iter_snapshot_interrupts(state: Any):
    """Yield interrupts lazily: ``state.interrupts`` first, then each task's ``interrupts``."""
    yield from getattr(state, "interrupts", None) or ()
    for sub_task in getattr(state, "tasks", None) or ():
        yield from getattr(sub_task, "interrupts", None) or ()


def get_first_pending_subagent_interrupt(state: Any) -> tuple[str | None, Any]:
    """Return the first pending ``(interrupt_id, value)`` in the snapshot.

    Top-level ``state.interrupts`` are scanned before per-task interrupts, and
    entries whose ``value`` is ``None`` are skipped. The id is returned only
    when it is a ``str``; callers use it to target ``Command(resume={id: value})``.

    Returns:
        ``(id_or_None, value)`` for the first match, else ``(None, None)``.
    """
    if state is None:
        return None, None
    for pending in _iter_snapshot_interrupts(state):
        value = getattr(pending, "value", None)
        if value is None:
            continue
        interrupt_id = getattr(pending, "id", None)
        return (interrupt_id if isinstance(interrupt_id, str) else None, value)
    return None, None


def build_resume_command(resume_value: Any, pending_id: str | None) -> Command:
    """Build ``Command(resume={pending_id: value})``, or ``Command(resume=value)`` without an id."""
    if pending_id is None:
        return Command(resume=resume_value)
    return Command(resume={pending_id: resume_value})
logger = logging.getLogger(__name__)


def build_task_tool_with_parent_config(
    subagents: list[dict[str, Any]],
    task_description: str | None = None,
) -> BaseTool:
    """Build the ``task`` tool that runs a named subagent and bridges HITL resumes.

    Args:
        subagents: Compiled specs; each must carry ``name``, ``description``
            and ``runnable``.
        task_description: Tool description override. ``None`` uses
            ``TASK_TOOL_DESCRIPTION``; a string containing
            ``{available_agents}`` is formatted with the subagent list; any
            other string is used verbatim.

    Returns:
        A ``StructuredTool`` named ``task`` with sync and async implementations.
    """
    subagent_graphs: dict[str, Runnable] = {
        spec["name"]: spec["runnable"] for spec in subagents
    }
    subagent_description_str = "\n".join(
        f"- {s['name']}: {s['description']}" for s in subagents
    )

    # Named ``tool_description`` so it is not shadowed by the ``description``
    # argument of the tool functions below.
    if task_description is None:
        tool_description = TASK_TOOL_DESCRIPTION.format(
            available_agents=subagent_description_str
        )
    elif "{available_agents}" in task_description:
        tool_description = task_description.format(
            available_agents=subagent_description_str
        )
    else:
        tool_description = task_description

    def _unknown_subagent_message(subagent_type: str) -> str:
        """Text returned to the model when it names a subagent that does not exist."""
        allowed_types = ", ".join([f"`{k}`" for k in subagent_graphs])
        return (
            f"We cannot invoke subagent {subagent_type} because it does not exist, "
            f"the only allowed types are {allowed_types}"
        )

    def _return_command_with_state_update(result: dict, tool_call_id: str) -> Command:
        """Turn a subagent result into a parent state update plus one ``ToolMessage``."""
        if "messages" not in result:
            msg = (
                "CompiledSubAgent must return a state containing a 'messages' key. "
                "Custom StateGraphs used with CompiledSubAgent should include 'messages' "
                "in their state schema to communicate results back to the main agent."
            )
            raise ValueError(msg)

        state_update = {k: v for k, v in result.items() if k not in EXCLUDED_STATE_KEYS}
        messages = result["messages"]
        # An empty message list yields an empty tool message instead of IndexError.
        last_text = messages[-1].text if messages else ""
        message_text = last_text.rstrip() if last_text else ""
        return Command(
            update={
                **state_update,
                "messages": [ToolMessage(message_text, tool_call_id=tool_call_id)],
            }
        )

    def _validate_and_prepare_state(
        subagent_type: str, task_text: str, runtime: ToolRuntime
    ) -> tuple[Runnable, dict]:
        """Pick the subagent and build its fresh input state from the parent state."""
        subagent = subagent_graphs[subagent_type]
        subagent_state = {
            k: v for k, v in runtime.state.items() if k not in EXCLUDED_STATE_KEYS
        }
        subagent_state["messages"] = [HumanMessage(content=task_text)]
        return subagent, subagent_state

    def _select_subagent_input(
        subagent_type: str,
        subagent_state: dict,
        pending_id: str | None,
        pending_value: Any,
        runtime: ToolRuntime,
    ) -> Any:
        """Choose between a fresh state and a targeted ``Command(resume=...)``.

        A resume command is built only when the subagent has a pending
        interrupt *and* the parent config carries a resume value; otherwise the
        fresh state is returned.
        """
        if pending_value is None:
            return subagent_state
        resume_value = extract_surfsense_resume(runtime)
        if resume_value is None:
            logger.warning(
                "Subagent %r has pending interrupt but no surfsense_resume_value "
                "on config — replaying with fresh state (interrupt will re-fire).",
                subagent_type,
            )
            return subagent_state
        expected = hitlrequest_action_count(pending_value)
        resume_value = fan_out_decisions_to_match(resume_value, expected)
        logger.info(
            "Forwarding surfsense_resume_value into subagent %r "
            "(action_requests=%d, targeted_id=%s)",
            subagent_type,
            expected,
            pending_id is not None,
        )
        return build_resume_command(resume_value, pending_id)

    def task(
        description: Annotated[
            str,
            "A detailed description of the task for the subagent to perform autonomously. Include all necessary context and specify the expected output format.",  # noqa: E501
        ],
        subagent_type: Annotated[
            str,
            "The type of subagent to use. Must be one of the available agent types listed in the tool description.",  # noqa: E501
        ],
        runtime: ToolRuntime,
    ) -> str | Command:
        if subagent_type not in subagent_graphs:
            return _unknown_subagent_message(subagent_type)
        if not runtime.tool_call_id:
            raise ValueError("Tool call ID is required for subagent invocation")
        subagent, subagent_state = _validate_and_prepare_state(
            subagent_type, description, runtime
        )
        sub_config = subagent_invoke_config(runtime)

        # Resume bridge: look for an interrupt the subagent is still waiting on.
        pending_id: str | None = None
        pending_value: Any = None
        get_state = getattr(subagent, "get_state", None)
        if callable(get_state):
            try:
                snapshot = get_state(sub_config)
                pending_id, pending_value = get_first_pending_subagent_interrupt(snapshot)
            except Exception:  # pragma: no cover - defensive
                logger.debug(
                    "Subagent get_state failed; falling back to fresh invoke",
                    exc_info=True,
                )

        subagent_input = _select_subagent_input(
            subagent_type, subagent_state, pending_id, pending_value, runtime
        )
        result = subagent.invoke(subagent_input, config=sub_config)
        # A follow-up interrupt raised inside the subagent is re-raised here.
        maybe_propagate_subagent_interrupt(subagent, sub_config, subagent_type)
        return _return_command_with_state_update(result, runtime.tool_call_id)

    async def atask(
        description: Annotated[
            str,
            "A detailed description of the task for the subagent to perform autonomously. Include all necessary context and specify the expected output format.",  # noqa: E501
        ],
        subagent_type: Annotated[
            str,
            "The type of subagent to use. Must be one of the available agent types listed in the tool description.",  # noqa: E501
        ],
        runtime: ToolRuntime,
    ) -> str | Command:
        if subagent_type not in subagent_graphs:
            return _unknown_subagent_message(subagent_type)
        if not runtime.tool_call_id:
            raise ValueError("Tool call ID is required for subagent invocation")
        subagent, subagent_state = _validate_and_prepare_state(
            subagent_type, description, runtime
        )
        sub_config = subagent_invoke_config(runtime)

        # Resume bridge: async variant of the lookup done in ``task``.
        pending_id: str | None = None
        pending_value: Any = None
        aget_state = getattr(subagent, "aget_state", None)
        if callable(aget_state):
            try:
                snapshot = await aget_state(sub_config)
                pending_id, pending_value = get_first_pending_subagent_interrupt(snapshot)
            except Exception:  # pragma: no cover - defensive
                logger.debug(
                    "Subagent aget_state failed; falling back to fresh ainvoke",
                    exc_info=True,
                )

        subagent_input = _select_subagent_input(
            subagent_type, subagent_state, pending_id, pending_value, runtime
        )
        result = await subagent.ainvoke(subagent_input, config=sub_config)
        await amaybe_propagate_subagent_interrupt(subagent, sub_config, subagent_type)
        return _return_command_with_state_update(result, runtime.tool_call_id)

    return StructuredTool.from_function(
        name="task",
        func=task,
        coroutine=atask,
        description=tool_description,
    )