From e27883e88cb14c49d96d4b866656e1b2b104c6a4 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Wed, 13 May 2026 19:57:02 +0200 Subject: [PATCH] multi_agent_chat/middleware: stamp tool_call_id on subagent interrupts at task chokepoint --- .../propagation.py | 92 ++++++------------- .../task_tool.py | 74 +++++++++++---- 2 files changed, 84 insertions(+), 82 deletions(-) diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/propagation.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/propagation.py index 55aae7201..cfebe1fd9 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/propagation.py +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/propagation.py @@ -1,74 +1,38 @@ -"""Re-raise still-pending subagent interrupts at the parent graph level. +"""Stamp the parent's ``tool_call_id`` onto a subagent's pending interrupt value. -After ``subagent.[a]invoke(Command(resume=...))`` returns, the subagent may -still hold a pending interrupt (e.g. the LLM produced a follow-up tool call -that fired a fresh ``interrupt()``). The parent's pregel cannot see that -interrupt because it lives in a separate compiled graph; we re-raise it here -so the parent's SSE stream surfaces it as the next approval card. +When a subagent (compiled as a langgraph subgraph and invoked from a parent +tool node) hits an ``interrupt(...)`` from its HITL middleware, langgraph +raises ``GraphInterrupt`` out of ``subagent.[a]invoke(...)``. The parent's +``task`` tool catches that exception, stamps ``tool_call_id`` onto each +``Interrupt.value`` using :func:`wrap_with_tool_call_id`, and re-raises a +fresh ``GraphInterrupt`` whose values carry that stamp. + +``stream_resume_chat`` then reads ``parent.state.interrupts[*].value["tool_call_id"]`` +to route a flat ``decisions`` list back to the right paused subagent — without +the stamp, parallel HITL across siblings would collapse into an ambiguous +bucket and resume would fail. + +This module hosts only the stamping helper; the catch/re-raise lives in +``task_tool.py`` since that's the single chokepoint where the raw exception +is in our hands. """ from __future__ import annotations -import logging from typing import Any -from langchain_core.runnables import Runnable -from langgraph.types import interrupt as _lg_interrupt -from .resume import get_first_pending_subagent_interrupt +def wrap_with_tool_call_id(value: Any, tool_call_id: str) -> dict[str, Any]: + """Return a value dict that always carries the parent's ``tool_call_id``. -logger = logging.getLogger(__name__) + Dict values are shallow-copied with ``tool_call_id`` stamped on top, so + any value the subagent may already carry under that key (from a deeper + HITL level) is overwritten — the parent's call id is the only one + ``stream_resume_chat`` correlates against. - -def maybe_propagate_subagent_interrupt( - subagent: Runnable, - sub_config: dict[str, Any], - subagent_type: str, -) -> None: - """Re-raise a still-pending subagent interrupt at the parent so the SSE stream surfaces it.""" - get_state_sync = getattr(subagent, "get_state", None) - if not callable(get_state_sync): - return - try: - snapshot = get_state_sync(sub_config) - except Exception: # pragma: no cover - defensive - logger.debug( - "Subagent get_state failed during re-interrupt check", - exc_info=True, - ) - return - _pending_id, pending_value = get_first_pending_subagent_interrupt(snapshot) - if pending_value is None: - return - logger.info( - "Re-raising subagent %r interrupt to parent (multi-step HITL)", - subagent_type, - ) - _lg_interrupt(pending_value) - - -async def amaybe_propagate_subagent_interrupt( - subagent: Runnable, - sub_config: dict[str, Any], - subagent_type: str, -) -> None: - """Async counterpart of :func:`maybe_propagate_subagent_interrupt`.""" - aget_state = getattr(subagent, "aget_state", None) - if not callable(aget_state): - return - try: - snapshot = await aget_state(sub_config) - except Exception: # pragma: no cover - defensive - logger.debug( - "Subagent aget_state failed during re-interrupt check", - exc_info=True, - ) - return - _pending_id, pending_value = get_first_pending_subagent_interrupt(snapshot) - if pending_value is None: - return - logger.info( - "Re-raising subagent %r interrupt to parent (multi-step HITL)", - subagent_type, - ) - _lg_interrupt(pending_value) + Non-dict values are wrapped as ``{"value": , "tool_call_id": ...}`` + so simple ``interrupt("approve?")`` patterns still propagate cleanly. + """ + if isinstance(value, dict): + return {**value, "tool_call_id": tool_call_id} + return {"value": value, "tool_call_id": tool_call_id} diff --git a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/task_tool.py b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/task_tool.py index 7c0dd8624..f9b316e23 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/task_tool.py +++ b/surfsense_backend/app/agents/multi_agent_chat/middleware/main_agent/checkpointed_subagent_middleware/task_tool.py @@ -9,14 +9,15 @@ re-raises any new pending interrupt back to the parent. from __future__ import annotations import logging -from typing import Annotated, Any +from typing import Annotated, Any, NoReturn from deepagents.middleware.subagents import TASK_TOOL_DESCRIPTION from langchain.tools import BaseTool, ToolRuntime from langchain_core.messages import HumanMessage, ToolMessage from langchain_core.runnables import Runnable from langchain_core.tools import StructuredTool -from langgraph.types import Command +from langgraph.errors import GraphInterrupt +from langgraph.types import Command, Interrupt from .config import ( consume_surfsense_resume, @@ -25,10 +26,7 @@ from .config import ( subagent_invoke_config, ) from .constants import EXCLUDED_STATE_KEYS -from .propagation import ( - amaybe_propagate_subagent_interrupt, - maybe_propagate_subagent_interrupt, -) +from .propagation import wrap_with_tool_call_id from .resume import ( build_resume_command, fan_out_decisions_to_match, @@ -39,6 +37,31 @@ from .resume import ( logger = logging.getLogger(__name__) +def _reraise_stamped_subagent_interrupt( + gi: GraphInterrupt, tool_call_id: str +) -> NoReturn: + """Stamp ``tool_call_id`` onto each pending interrupt value and re-raise. + + See :mod:`...propagation` for why this stamp is required for resume routing. + Chained via ``from gi`` so tracebacks point at the subagent's original + ``interrupt(...)`` site. + """ + interrupts = gi.args[0] if gi.args else () + stamped = tuple( + Interrupt( + value=wrap_with_tool_call_id(i.value, tool_call_id), + id=i.id, + ) + for i in interrupts + ) + logger.info( + "[hitl_route] stamped %d subagent interrupt(s) with tool_call_id=%s", + len(stamped), + tool_call_id, + ) + raise GraphInterrupt(stamped) from gi + + def build_task_tool_with_parent_config( subagents: list[dict[str, Any]], task_description: str | None = None, @@ -161,13 +184,18 @@ def build_task_tool_with_parent_config( # Prevent the parent's resume payload from leaking into subagent # interrupts via langgraph's parent_scratchpad fallback. drain_parent_null_resume(runtime) - result = subagent.invoke( - build_resume_command(resume_value, pending_id), - config=sub_config, - ) + try: + result = subagent.invoke( + build_resume_command(resume_value, pending_id), + config=sub_config, + ) + except GraphInterrupt as gi: + _reraise_stamped_subagent_interrupt(gi, runtime.tool_call_id) else: - result = subagent.invoke(subagent_state, config=sub_config) - maybe_propagate_subagent_interrupt(subagent, sub_config, subagent_type) + try: + result = subagent.invoke(subagent_state, config=sub_config) + except GraphInterrupt as gi: + _reraise_stamped_subagent_interrupt(gi, runtime.tool_call_id) return _return_command_with_state_update(result, runtime.tool_call_id) async def atask( @@ -181,6 +209,11 @@ def build_task_tool_with_parent_config( ], runtime: ToolRuntime, ) -> str | Command: + logger.info( + "[hitl_route] atask ENTRY: subagent_type=%r tool_call_id=%s", + subagent_type, + runtime.tool_call_id, + ) if subagent_type not in subagent_graphs: allowed_types = ", ".join([f"`{k}`" for k in subagent_graphs]) return ( @@ -228,13 +261,18 @@ def build_task_tool_with_parent_config( # Prevent the parent's resume payload from leaking into subagent # interrupts via langgraph's parent_scratchpad fallback. drain_parent_null_resume(runtime) - result = await subagent.ainvoke( - build_resume_command(resume_value, pending_id), - config=sub_config, - ) + try: + result = await subagent.ainvoke( + build_resume_command(resume_value, pending_id), + config=sub_config, + ) + except GraphInterrupt as gi: + _reraise_stamped_subagent_interrupt(gi, runtime.tool_call_id) else: - result = await subagent.ainvoke(subagent_state, config=sub_config) - await amaybe_propagate_subagent_interrupt(subagent, sub_config, subagent_type) + try: + result = await subagent.ainvoke(subagent_state, config=sub_config) + except GraphInterrupt as gi: + _reraise_stamped_subagent_interrupt(gi, runtime.tool_call_id) return _return_command_with_state_update(result, runtime.tool_call_id) return StructuredTool.from_function(