refactor(multi-agent): switch graph compilation to the new orchestrator and drop deepagent_stack

CREDO23 2026-05-05 20:55:38 +02:00
parent 5abae09435
commit 73272ce348
12 changed files with 5 additions and 550 deletions

View file

@@ -0,0 +1,26 @@
"""SubAgent ``task`` tool wiring required for HITL inside subagents.
Replaces upstream ``SubAgentMiddleware`` to:
- share the parent's checkpointer with each subagent,
- forward ``runtime.config`` (thread_id, recursion_limit, …) into nested invokes,
- bridge ``Command(resume=...)`` from the parent into the subagent via the
``config["configurable"]["surfsense_resume_value"]`` side-channel,
- target the resume at the captured interrupt id so a follow-up
``HumanInTheLoopMiddleware.after_model`` does not consume the same payload,
- re-raise any new subagent interrupt at the parent so the SSE stream surfaces it.
Module layout
-------------
- ``constants``: shared keys / limits.
- ``config``: RunnableConfig + side-channel resume read.
- ``resume``: pending-interrupt detection, fan-out, ``Command(resume=...)`` builder.
- ``propagation``: re-raise pending subagent interrupts at the parent.
- ``task_tool``: the ``task`` tool factory (sync + async).
- ``middleware``: :class:`SurfSenseCheckpointedSubAgentMiddleware` itself.
"""
from .middleware import SurfSenseCheckpointedSubAgentMiddleware
__all__ = ["SurfSenseCheckpointedSubAgentMiddleware"]
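For orientation (not part of this diff): a minimal wiring sketch, assuming a parent agent built with ``create_agent`` and an ``InMemorySaver`` checkpointer; the model ids and the ``search_docs`` tool are placeholders, and ``my_backend`` stands in for a ``BackendProtocol`` implementation that is not defined here.

from langchain.agents import create_agent
from langchain.tools import tool
from langgraph.checkpoint.memory import InMemorySaver

from .middleware import SurfSenseCheckpointedSubAgentMiddleware  # this package

@tool
def search_docs(query: str) -> str:
    """Placeholder tool so the subagent has something to gate behind HITL."""
    return f"results for {query}"

checkpointer = InMemorySaver()

middleware = SurfSenseCheckpointedSubAgentMiddleware(
    checkpointer=checkpointer,   # shared with every compiled subagent
    backend=my_backend,          # hypothetical BackendProtocol implementation
    subagents=[
        {
            "name": "researcher",
            "description": "Searches connected sources and summarizes findings.",
            "model": "openai:gpt-4o-mini",
            "tools": [search_docs],
            "system_prompt": "You are a research subagent.",
            "interrupt_on": {"search_docs": True},  # approval card before the tool runs
        }
    ],
)

parent = create_agent(
    "openai:gpt-4o-mini",
    tools=[],
    middleware=[middleware],
    checkpointer=checkpointer,   # same checkpointer, so interrupts survive across turns
)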

View file

@@ -0,0 +1,44 @@
"""RunnableConfig wiring for nested subagent invocations.
Forwards the parent's ``runtime.config`` (thread_id, …) into the subagent and
exposes the side-channel ``stream_resume_chat`` uses to ferry resume payloads.
"""
from __future__ import annotations
from typing import Any
from langchain.tools import ToolRuntime
from .constants import DEFAULT_SUBAGENT_RECURSION_LIMIT
def subagent_invoke_config(runtime: ToolRuntime) -> dict[str, Any]:
"""RunnableConfig for the nested invoke; raises ``recursion_limit`` to the parent's budget."""
merged: dict[str, Any] = dict(runtime.config) if runtime.config else {}
current_limit = merged.get("recursion_limit")
try:
current_int = int(current_limit) if current_limit is not None else 0
except (TypeError, ValueError):
current_int = 0
if current_int < DEFAULT_SUBAGENT_RECURSION_LIMIT:
merged["recursion_limit"] = DEFAULT_SUBAGENT_RECURSION_LIMIT
return merged
def consume_surfsense_resume(runtime: ToolRuntime) -> Any:
"""Pop the resume payload; siblings share ``configurable`` by reference."""
cfg = runtime.config or {}
configurable = cfg.get("configurable") if isinstance(cfg, dict) else None
if not isinstance(configurable, dict):
return None
return configurable.pop("surfsense_resume_value", None)
def has_surfsense_resume(runtime: ToolRuntime) -> bool:
"""True iff a resume payload is queued on this runtime (non-destructive)."""
cfg = runtime.config or {}
configurable = cfg.get("configurable") if isinstance(cfg, dict) else None
if not isinstance(configurable, dict):
return False
return "surfsense_resume_value" in configurable

View file

@@ -0,0 +1,18 @@
"""Constants shared by the checkpointed subagent middleware."""
from __future__ import annotations
# Mirror of deepagents.middleware.subagents._EXCLUDED_STATE_KEYS.
EXCLUDED_STATE_KEYS = frozenset(
{
"messages",
"todos",
"structured_response",
"skills_metadata",
"memory_contents",
}
)
# Match the parent graph's budget; the LangGraph default of 25 trips on
# multi-step subagent runs.
DEFAULT_SUBAGENT_RECURSION_LIMIT = 10_000

View file

@@ -0,0 +1,103 @@
"""SubAgent middleware that compiles each subagent against the parent checkpointer."""
from __future__ import annotations
from typing import Any, cast
from deepagents.backends.protocol import BackendFactory, BackendProtocol
from deepagents.middleware.subagents import (
TASK_SYSTEM_PROMPT,
CompiledSubAgent,
SubAgent,
SubAgentMiddleware,
)
from langchain.agents import create_agent
from langchain.agents.middleware import HumanInTheLoopMiddleware
from langchain.chat_models import init_chat_model
from langgraph.types import Checkpointer
from .task_tool import build_task_tool_with_parent_config
class SurfSenseCheckpointedSubAgentMiddleware(SubAgentMiddleware):
"""``SubAgentMiddleware`` variant that compiles each subagent against the parent checkpointer."""
def __init__(
self,
*,
checkpointer: Checkpointer,
backend: BackendProtocol | BackendFactory,
subagents: list[SubAgent | CompiledSubAgent],
system_prompt: str | None = TASK_SYSTEM_PROMPT,
task_description: str | None = None,
) -> None:
self._surf_checkpointer = checkpointer
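# Deliberately skip SubAgentMiddleware.__init__ (which would compile the subagents
# without a checkpointer) and call its base class instead; the compilation is
# redone below with the parent checkpointer threaded in.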
super(SubAgentMiddleware, self).__init__()
if not subagents:
raise ValueError(
"At least one subagent must be specified when using the new API"
)
self._backend = backend
self._subagents = subagents
subagent_specs = self._surf_compile_subagent_graphs()
task_tool = build_task_tool_with_parent_config(subagent_specs, task_description)
if system_prompt and subagent_specs:
agents_desc = "\n".join(
f"- {s['name']}: {s['description']}" for s in subagent_specs
)
self.system_prompt = (
system_prompt + "\n\nAvailable subagent types:\n" + agents_desc
)
else:
self.system_prompt = system_prompt
self.tools = [task_tool]
def _surf_compile_subagent_graphs(self) -> list[dict[str, Any]]:
"""Mirror of ``SubAgentMiddleware._get_subagents`` that threads the parent checkpointer."""
specs: list[dict[str, Any]] = []
for spec in self._subagents:
if "runnable" in spec:
compiled = cast(CompiledSubAgent, spec)
specs.append(
{
"name": compiled["name"],
"description": compiled["description"],
"runnable": compiled["runnable"],
}
)
continue
if "model" not in spec:
msg = f"SubAgent '{spec['name']}' must specify 'model'"
raise ValueError(msg)
if "tools" not in spec:
msg = f"SubAgent '{spec['name']}' must specify 'tools'"
raise ValueError(msg)
model = spec["model"]
if isinstance(model, str):
model = init_chat_model(model)
middleware: list[Any] = list(spec.get("middleware", []))
interrupt_on = spec.get("interrupt_on")
if interrupt_on:
middleware.append(HumanInTheLoopMiddleware(interrupt_on=interrupt_on))
specs.append(
{
"name": spec["name"],
"description": spec["description"],
"runnable": create_agent(
model,
system_prompt=spec["system_prompt"],
tools=spec["tools"],
middleware=middleware,
name=spec["name"],
checkpointer=self._surf_checkpointer,
),
}
)
return specs
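For reference, the two spec shapes ``subagents=`` accepts; all names, prompts, and the ``my_compiled_graph`` object are placeholders.

declarative_spec = {          # SubAgent: compiled here with the parent checkpointer
    "name": "qna",
    "description": "Answers questions over the user's documents.",
    "model": "openai:gpt-4o-mini",
    "tools": [],
    "system_prompt": "Answer strictly from the provided context.",
    # optional: "interrupt_on" and "middleware" keys, handled above
}
precompiled_spec = {          # CompiledSubAgent: used as-is; its checkpointing is not managed here
    "name": "custom",
    "description": "A custom graph compiled elsewhere.",
    "runnable": my_compiled_graph,   # hypothetical compiled StateGraph / Runnable
}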

View file

@@ -0,0 +1,74 @@
"""Re-raise still-pending subagent interrupts at the parent graph level.
After ``subagent.[a]invoke(Command(resume=...))`` returns, the subagent may
still hold a pending interrupt (e.g. the LLM produced a follow-up tool call
that fired a fresh ``interrupt()``). The parent's pregel cannot see that
interrupt because it lives in a separate compiled graph; we re-raise it here
so the parent's SSE stream surfaces it as the next approval card.
"""
from __future__ import annotations
import logging
from typing import Any
from langchain_core.runnables import Runnable
from langgraph.types import interrupt as _lg_interrupt
from .resume import get_first_pending_subagent_interrupt
logger = logging.getLogger(__name__)
def maybe_propagate_subagent_interrupt(
subagent: Runnable,
sub_config: dict[str, Any],
subagent_type: str,
) -> None:
"""Re-raise a still-pending subagent interrupt at the parent so the SSE stream surfaces it."""
get_state_sync = getattr(subagent, "get_state", None)
if not callable(get_state_sync):
return
try:
snapshot = get_state_sync(sub_config)
except Exception: # pragma: no cover - defensive
logger.debug(
"Subagent get_state failed during re-interrupt check",
exc_info=True,
)
return
_pending_id, pending_value = get_first_pending_subagent_interrupt(snapshot)
if pending_value is None:
return
logger.info(
"Re-raising subagent %r interrupt to parent (multi-step HITL)",
subagent_type,
)
_lg_interrupt(pending_value)
async def amaybe_propagate_subagent_interrupt(
subagent: Runnable,
sub_config: dict[str, Any],
subagent_type: str,
) -> None:
"""Async counterpart of :func:`maybe_propagate_subagent_interrupt`."""
aget_state = getattr(subagent, "aget_state", None)
if not callable(aget_state):
return
try:
snapshot = await aget_state(sub_config)
except Exception: # pragma: no cover - defensive
logger.debug(
"Subagent aget_state failed during re-interrupt check",
exc_info=True,
)
return
_pending_id, pending_value = get_first_pending_subagent_interrupt(snapshot)
if pending_value is None:
return
logger.info(
"Re-raising subagent %r interrupt to parent (multi-step HITL)",
subagent_type,
)
_lg_interrupt(pending_value)
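From the parent's side, a re-raised interrupt is indistinguishable from one raised by the parent's own tools. A rough non-streaming driver loop for illustration, assuming the real caller is the SSE endpoint and ``ask_user`` is a hypothetical UI hook:

from langgraph.types import Command

def run_with_approvals(parent_graph, first_input, config):
    result = parent_graph.invoke(first_input, config=config)
    while True:
        snapshot = parent_graph.get_state(config)
        pending = [i for task in snapshot.tasks for i in task.interrupts]
        if not pending:
            return result
        decision = ask_user(pending[0].value)   # hypothetical UI hook
        # Feed the decision to both the parent (Command) and the subagent (side-channel).
        config["configurable"]["surfsense_resume_value"] = decision
        result = parent_graph.invoke(Command(resume=decision), config=config)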

View file

@@ -0,0 +1,76 @@
"""Resume-payload shaping and pending-interrupt detection for subagents.
Splits the work of "given a state snapshot and a parent-stashed resume value,
produce the right ``Command(resume=...)`` for the subagent" into pure helpers.
"""
from __future__ import annotations
from typing import Any
from langgraph.types import Command
def hitlrequest_action_count(pending_value: Any) -> int:
"""Bundle size for a LangChain ``HITLRequest`` payload; ``0`` for non-bundle interrupts."""
if not isinstance(pending_value, dict):
return 0
actions = pending_value.get("action_requests")
if isinstance(actions, list):
return len(actions)
return 0
def fan_out_decisions_to_match(resume_value: Any, expected_count: int) -> Any:
"""Legacy fallback: pad a 1-decision resume to N for an ``action_requests=N`` bundle.
The modern frontend submits N decisions per bundle (one per action_request), so
this is normally a no-op; it is kept for backwards compatibility with old in-flight
threads or non-bundle clients that send a single decision.
"""
if expected_count <= 1:
return resume_value
if not isinstance(resume_value, dict):
return resume_value
decisions = resume_value.get("decisions")
if not isinstance(decisions, list) or len(decisions) >= expected_count:
return resume_value
if not decisions:
return resume_value
padded = list(decisions) + [decisions[-1]] * (expected_count - len(decisions))
return {**resume_value, "decisions": padded}
def get_first_pending_subagent_interrupt(state: Any) -> tuple[str | None, Any]:
"""First pending ``(interrupt_id, value)``; ``(None, None)`` if no interrupt.
Assumes at most one pending interrupt per snapshot (sequential tool nodes).
Parallel tool nodes would need an id-aware lookup instead of first-wins.
"""
if state is None:
return None, None
for it in getattr(state, "interrupts", None) or ():
value = getattr(it, "value", None)
interrupt_id = getattr(it, "id", None)
if value is not None:
return (
interrupt_id if isinstance(interrupt_id, str) else None,
value,
)
for sub_task in getattr(state, "tasks", None) or ():
for it in getattr(sub_task, "interrupts", None) or ():
value = getattr(it, "value", None)
interrupt_id = getattr(it, "id", None)
if value is not None:
return (
interrupt_id if isinstance(interrupt_id, str) else None,
value,
)
return None, None
def build_resume_command(resume_value: Any, pending_id: str | None) -> Command:
"""``Command(resume={id: value})`` when ``id`` is known, else fall back to scalar."""
if pending_id is None:
return Command(resume=resume_value)
return Command(resume={pending_id: resume_value})
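A worked example with the key names these helpers read; the payload contents and the interrupt id are illustrative, not taken verbatim from LangChain's HITL types.

pending_value = {"action_requests": [{"action": "write_file"}, {"action": "delete_file"}]}
resume_value = {"decisions": [{"type": "approve"}]}        # legacy single-decision client

hitlrequest_action_count(pending_value)                     # -> 2
fan_out_decisions_to_match(resume_value, 2)
# -> {"decisions": [{"type": "approve"}, {"type": "approve"}]}  (last decision repeated)

build_resume_command(resume_value, "interrupt-id-123")      # hypothetical interrupt id
# -> Command(resume={"interrupt-id-123": resume_value})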

View file

@@ -0,0 +1,231 @@
"""Build the ``task`` tool that invokes subagents with HITL bridging.
The tool's body is the only place where the parent and the subagent meet at
runtime: it reads the parent's stashed resume value, decides whether to send
fresh state or a targeted ``Command(resume=...)`` to the subagent, then
re-raises any new pending interrupt back to the parent.
"""
from __future__ import annotations
import logging
from typing import Annotated, Any
from deepagents.middleware.subagents import TASK_TOOL_DESCRIPTION
from langchain.tools import BaseTool, ToolRuntime
from langchain_core.messages import HumanMessage, ToolMessage
from langchain_core.runnables import Runnable
from langchain_core.tools import StructuredTool
from langgraph.types import Command
from .config import (
consume_surfsense_resume,
has_surfsense_resume,
subagent_invoke_config,
)
from .constants import EXCLUDED_STATE_KEYS
from .propagation import (
amaybe_propagate_subagent_interrupt,
maybe_propagate_subagent_interrupt,
)
from .resume import (
build_resume_command,
fan_out_decisions_to_match,
get_first_pending_subagent_interrupt,
hitlrequest_action_count,
)
logger = logging.getLogger(__name__)
def build_task_tool_with_parent_config(
subagents: list[dict[str, Any]],
task_description: str | None = None,
) -> BaseTool:
"""Upstream ``_build_task_tool`` + parent ``runtime.config`` propagation + resume bridging."""
subagent_graphs: dict[str, Runnable] = {
spec["name"]: spec["runnable"] for spec in subagents
}
subagent_description_str = "\n".join(
f"- {s['name']}: {s['description']}" for s in subagents
)
if task_description is None:
description = TASK_TOOL_DESCRIPTION.format(
available_agents=subagent_description_str
)
elif "{available_agents}" in task_description:
description = task_description.format(available_agents=subagent_description_str)
else:
description = task_description
def _return_command_with_state_update(result: dict, tool_call_id: str) -> Command:
if "messages" not in result:
msg = (
"CompiledSubAgent must return a state containing a 'messages' key. "
"Custom StateGraphs used with CompiledSubAgent should include 'messages' "
"in their state schema to communicate results back to the main agent."
)
raise ValueError(msg)
state_update = {k: v for k, v in result.items() if k not in EXCLUDED_STATE_KEYS}
message_text = (
result["messages"][-1].text.rstrip() if result["messages"][-1].text else ""
)
return Command(
update={
**state_update,
"messages": [ToolMessage(message_text, tool_call_id=tool_call_id)],
}
)
def _validate_and_prepare_state(
subagent_type: str, description: str, runtime: ToolRuntime
) -> tuple[Runnable, dict]:
subagent = subagent_graphs[subagent_type]
subagent_state = {
k: v for k, v in runtime.state.items() if k not in EXCLUDED_STATE_KEYS
}
subagent_state["messages"] = [HumanMessage(content=description)]
return subagent, subagent_state
def task(
description: Annotated[
str,
"A detailed description of the task for the subagent to perform autonomously. Include all necessary context and specify the expected output format.",
],
subagent_type: Annotated[
str,
"The type of subagent to use. Must be one of the available agent types listed in the tool description.",
],
runtime: ToolRuntime,
) -> str | Command:
if subagent_type not in subagent_graphs:
allowed_types = ", ".join([f"`{k}`" for k in subagent_graphs])
return (
f"We cannot invoke subagent {subagent_type} because it does not exist, "
f"the only allowed types are {allowed_types}"
)
if not runtime.tool_call_id:
raise ValueError("Tool call ID is required for subagent invocation")
subagent, subagent_state = _validate_and_prepare_state(
subagent_type, description, runtime
)
sub_config = subagent_invoke_config(runtime)
# Resume bridge: forward the parent's stashed decision into the
# subagent's pending ``interrupt()``, targeted by id.
pending_id: str | None = None
pending_value: Any = None
get_state = getattr(subagent, "get_state", None)
if callable(get_state):
try:
snapshot = get_state(sub_config)
pending_id, pending_value = get_first_pending_subagent_interrupt(
snapshot
)
except Exception:
# Fail loud if a resume is queued: silent fallback would
# replay the original interrupt to the user.
if has_surfsense_resume(runtime):
logger.exception(
"Subagent %r get_state raised with resume queued; re-raising.",
subagent_type,
)
raise
logger.debug(
"Subagent get_state failed; falling back to fresh invoke",
exc_info=True,
)
if pending_value is not None:
resume_value = consume_surfsense_resume(runtime)
if resume_value is None:
# Bridge invariant: a queued resume must accompany any pending
# subagent interrupt. Fall-through replay would silently re-prompt
# the user; raise so the streaming layer surfaces a clear error.
raise RuntimeError(
f"Subagent {subagent_type!r} has a pending interrupt but no "
"surfsense_resume_value on config; resume bridge is broken."
)
expected = hitlrequest_action_count(pending_value)
resume_value = fan_out_decisions_to_match(resume_value, expected)
result = subagent.invoke(
build_resume_command(resume_value, pending_id),
config=sub_config,
)
else:
result = subagent.invoke(subagent_state, config=sub_config)
maybe_propagate_subagent_interrupt(subagent, sub_config, subagent_type)
return _return_command_with_state_update(result, runtime.tool_call_id)
async def atask(
description: Annotated[
str,
"A detailed description of the task for the subagent to perform autonomously. Include all necessary context and specify the expected output format.",
],
subagent_type: Annotated[
str,
"The type of subagent to use. Must be one of the available agent types listed in the tool description.",
],
runtime: ToolRuntime,
) -> str | Command:
if subagent_type not in subagent_graphs:
allowed_types = ", ".join([f"`{k}`" for k in subagent_graphs])
return (
f"We cannot invoke subagent {subagent_type} because it does not exist, "
f"the only allowed types are {allowed_types}"
)
if not runtime.tool_call_id:
raise ValueError("Tool call ID is required for subagent invocation")
subagent, subagent_state = _validate_and_prepare_state(
subagent_type, description, runtime
)
sub_config = subagent_invoke_config(runtime)
# Resume bridge — see ``task`` above.
pending_id: str | None = None
pending_value: Any = None
aget_state = getattr(subagent, "aget_state", None)
if callable(aget_state):
try:
snapshot = await aget_state(sub_config)
pending_id, pending_value = get_first_pending_subagent_interrupt(
snapshot
)
except Exception:
if has_surfsense_resume(runtime):
logger.exception(
"Subagent %r aget_state raised with resume queued; re-raising.",
subagent_type,
)
raise
logger.debug(
"Subagent aget_state failed; falling back to fresh ainvoke",
exc_info=True,
)
if pending_value is not None:
resume_value = consume_surfsense_resume(runtime)
if resume_value is None:
raise RuntimeError(
f"Subagent {subagent_type!r} has a pending interrupt but no "
"surfsense_resume_value on config; resume bridge is broken."
)
expected = hitlrequest_action_count(pending_value)
resume_value = fan_out_decisions_to_match(resume_value, expected)
result = await subagent.ainvoke(
build_resume_command(resume_value, pending_id),
config=sub_config,
)
else:
result = await subagent.ainvoke(subagent_state, config=sub_config)
await amaybe_propagate_subagent_interrupt(subagent, sub_config, subagent_type)
return _return_command_with_state_update(result, runtime.tool_call_id)
return StructuredTool.from_function(
name="task",
func=task,
coroutine=atask,
description=description,
)
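At runtime, the parent model drives this tool with calls shaped roughly like the following; the argument text and call id are made up, and ``runtime`` is injected by LangChain rather than supplied by the model.

{
    "name": "task",
    "args": {
        "description": "Find the three most relevant billing documents and summarize them.",
        "subagent_type": "researcher",
    },
    "id": "call_abc123",
}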