multi_agent_chat/middleware: stamp tool_call_id on subagent interrupts at task chokepoint

This commit is contained in:
CREDO23 2026-05-13 19:57:02 +02:00
parent fc2c5b6445
commit e27883e88c
2 changed files with 84 additions and 82 deletions

View file

@ -1,74 +1,38 @@
"""Re-raise still-pending subagent interrupts at the parent graph level.
"""Stamp the parent's ``tool_call_id`` onto a subagent's pending interrupt value.
After ``subagent.[a]invoke(Command(resume=...))`` returns, the subagent may
still hold a pending interrupt (e.g. the LLM produced a follow-up tool call
that fired a fresh ``interrupt()``). The parent's pregel cannot see that
interrupt because it lives in a separate compiled graph; we re-raise it here
so the parent's SSE stream surfaces it as the next approval card.
When a subagent (compiled as a langgraph subgraph and invoked from a parent
tool node) hits an ``interrupt(...)`` from its HITL middleware, langgraph
raises ``GraphInterrupt`` out of ``subagent.[a]invoke(...)``. The parent's
``task`` tool catches that exception, stamps ``tool_call_id`` onto each
``Interrupt.value`` using :func:`wrap_with_tool_call_id`, and re-raises a
fresh ``GraphInterrupt`` whose values carry that stamp.
``stream_resume_chat`` then reads ``parent.state.interrupts[*].value["tool_call_id"]``
to route a flat ``decisions`` list back to the right paused subagent — without
the stamp, parallel HITL across siblings would collapse into an ambiguous
bucket and resume would fail.
This module hosts only the stamping helper; the catch/re-raise lives in
``task_tool.py`` since that's the single chokepoint where the raw exception
is in our hands.
"""
from __future__ import annotations
import logging
from typing import Any
from langchain_core.runnables import Runnable
from langgraph.types import interrupt as _lg_interrupt
from .resume import get_first_pending_subagent_interrupt
def wrap_with_tool_call_id(value: Any, tool_call_id: str) -> dict[str, Any]:
"""Return a value dict that always carries the parent's ``tool_call_id``.
logger = logging.getLogger(__name__)
Dict values are shallow-copied with ``tool_call_id`` stamped on top, so
any value the subagent may already carry under that key (from a deeper
HITL level) is overwritten — the parent's call id is the only one
``stream_resume_chat`` correlates against.
def maybe_propagate_subagent_interrupt(
    subagent: Runnable,
    sub_config: dict[str, Any],
    subagent_type: str,
) -> None:
    """Surface a subagent interrupt that is still pending to the parent graph.

    The subagent's state snapshot is read through its ``get_state`` method
    and handed to ``get_first_pending_subagent_interrupt``. When that helper
    reports a pending value, the value is passed to langgraph's
    ``interrupt`` so the parent sees it.

    Args:
        subagent: Runnable to inspect; it is skipped silently when it has no
            callable ``get_state`` attribute.
        sub_config: Config passed to ``get_state`` to select the snapshot.
        subagent_type: Subagent name, used only in the log message.
    """
    read_state = getattr(subagent, "get_state", None)
    if not callable(read_state):
        return

    try:
        snapshot = read_state(sub_config)
    except Exception:  # pragma: no cover - defensive
        # Best-effort check: a failing state lookup must not break the tool.
        logger.debug(
            "Subagent get_state failed during re-interrupt check",
            exc_info=True,
        )
    else:
        # Kept outside the ``try`` body so whatever ``interrupt`` raises is
        # never swallowed by the defensive handler above.
        _, still_pending = get_first_pending_subagent_interrupt(snapshot)
        if still_pending is not None:
            logger.info(
                "Re-raising subagent %r interrupt to parent (multi-step HITL)",
                subagent_type,
            )
            _lg_interrupt(still_pending)
async def amaybe_propagate_subagent_interrupt(
    subagent: Runnable,
    sub_config: dict[str, Any],
    subagent_type: str,
) -> None:
    """Async variant: surface a still-pending subagent interrupt to the parent.

    The subagent's state snapshot is awaited through its ``aget_state``
    method and handed to ``get_first_pending_subagent_interrupt``. When that
    helper reports a pending value, the value is passed to langgraph's
    ``interrupt`` so the parent sees it.

    Args:
        subagent: Runnable to inspect; it is skipped silently when it has no
            callable ``aget_state`` attribute.
        sub_config: Config passed to ``aget_state`` to select the snapshot.
        subagent_type: Subagent name, used only in the log message.
    """
    read_state_async = getattr(subagent, "aget_state", None)
    if not callable(read_state_async):
        return

    try:
        snapshot = await read_state_async(sub_config)
    except Exception:  # pragma: no cover - defensive
        # Best-effort check: a failing state lookup must not break the tool.
        logger.debug(
            "Subagent aget_state failed during re-interrupt check",
            exc_info=True,
        )
    else:
        # Kept outside the ``try`` body so whatever ``interrupt`` raises is
        # never swallowed by the defensive handler above.
        _, still_pending = get_first_pending_subagent_interrupt(snapshot)
        if still_pending is not None:
            logger.info(
                "Re-raising subagent %r interrupt to parent (multi-step HITL)",
                subagent_type,
            )
            _lg_interrupt(still_pending)
Non-dict values are wrapped as ``{"value": <original>, "tool_call_id": ...}``
so simple ``interrupt("approve?")`` patterns still propagate cleanly.
"""
if isinstance(value, dict):
return {**value, "tool_call_id": tool_call_id}
return {"value": value, "tool_call_id": tool_call_id}

View file

@ -9,14 +9,15 @@ re-raises any new pending interrupt back to the parent.
from __future__ import annotations
import logging
from typing import Annotated, Any
from typing import Annotated, Any, NoReturn
from deepagents.middleware.subagents import TASK_TOOL_DESCRIPTION
from langchain.tools import BaseTool, ToolRuntime
from langchain_core.messages import HumanMessage, ToolMessage
from langchain_core.runnables import Runnable
from langchain_core.tools import StructuredTool
from langgraph.types import Command
from langgraph.errors import GraphInterrupt
from langgraph.types import Command, Interrupt
from .config import (
consume_surfsense_resume,
@ -25,10 +26,7 @@ from .config import (
subagent_invoke_config,
)
from .constants import EXCLUDED_STATE_KEYS
from .propagation import (
amaybe_propagate_subagent_interrupt,
maybe_propagate_subagent_interrupt,
)
from .propagation import wrap_with_tool_call_id
from .resume import (
build_resume_command,
fan_out_decisions_to_match,
@ -39,6 +37,31 @@ from .resume import (
logger = logging.getLogger(__name__)
def _reraise_stamped_subagent_interrupt(
    gi: GraphInterrupt, tool_call_id: str
) -> NoReturn:
    """Re-raise ``gi`` with ``tool_call_id`` stamped on every interrupt value.

    Each interrupt carried in ``gi.args[0]`` is rebuilt as a new
    ``Interrupt`` that keeps its ``id`` but whose ``value`` has been passed
    through ``wrap_with_tool_call_id``. The propagation module's docstring
    explains why resume routing needs this stamp.

    Args:
        gi: Interrupt exception caught from the subagent invocation.
        tool_call_id: Id of the parent tool call to stamp onto each value.

    Raises:
        GraphInterrupt: Always; it carries the stamped interrupts and is
            chained to ``gi`` with ``from`` so the original exception stays
            reachable as ``__cause__``.
    """
    # An exception constructed without arguments carries no interrupts.
    pending = gi.args[0] if gi.args else ()

    restamped = []
    for item in pending:
        stamped_value = wrap_with_tool_call_id(item.value, tool_call_id)
        restamped.append(Interrupt(value=stamped_value, id=item.id))
    stamped = tuple(restamped)

    logger.info(
        "[hitl_route] stamped %d subagent interrupt(s) with tool_call_id=%s",
        len(stamped),
        tool_call_id,
    )
    raise GraphInterrupt(stamped) from gi
def build_task_tool_with_parent_config(
subagents: list[dict[str, Any]],
task_description: str | None = None,
@ -161,13 +184,18 @@ def build_task_tool_with_parent_config(
# Prevent the parent's resume payload from leaking into subagent
# interrupts via langgraph's parent_scratchpad fallback.
drain_parent_null_resume(runtime)
result = subagent.invoke(
build_resume_command(resume_value, pending_id),
config=sub_config,
)
try:
result = subagent.invoke(
build_resume_command(resume_value, pending_id),
config=sub_config,
)
except GraphInterrupt as gi:
_reraise_stamped_subagent_interrupt(gi, runtime.tool_call_id)
else:
result = subagent.invoke(subagent_state, config=sub_config)
maybe_propagate_subagent_interrupt(subagent, sub_config, subagent_type)
try:
result = subagent.invoke(subagent_state, config=sub_config)
except GraphInterrupt as gi:
_reraise_stamped_subagent_interrupt(gi, runtime.tool_call_id)
return _return_command_with_state_update(result, runtime.tool_call_id)
async def atask(
@ -181,6 +209,11 @@ def build_task_tool_with_parent_config(
],
runtime: ToolRuntime,
) -> str | Command:
logger.info(
"[hitl_route] atask ENTRY: subagent_type=%r tool_call_id=%s",
subagent_type,
runtime.tool_call_id,
)
if subagent_type not in subagent_graphs:
allowed_types = ", ".join([f"`{k}`" for k in subagent_graphs])
return (
@ -228,13 +261,18 @@ def build_task_tool_with_parent_config(
# Prevent the parent's resume payload from leaking into subagent
# interrupts via langgraph's parent_scratchpad fallback.
drain_parent_null_resume(runtime)
result = await subagent.ainvoke(
build_resume_command(resume_value, pending_id),
config=sub_config,
)
try:
result = await subagent.ainvoke(
build_resume_command(resume_value, pending_id),
config=sub_config,
)
except GraphInterrupt as gi:
_reraise_stamped_subagent_interrupt(gi, runtime.tool_call_id)
else:
result = await subagent.ainvoke(subagent_state, config=sub_config)
await amaybe_propagate_subagent_interrupt(subagent, sub_config, subagent_type)
try:
result = await subagent.ainvoke(subagent_state, config=sub_config)
except GraphInterrupt as gi:
_reraise_stamped_subagent_interrupt(gi, runtime.tool_call_id)
return _return_command_with_state_update(result, runtime.tool_call_id)
return StructuredTool.from_function(