mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-17 18:35:19 +02:00
multi_agent_chat/middleware: stamp tool_call_id on subagent interrupts at task chokepoint
This commit is contained in:
parent
fc2c5b6445
commit
e27883e88c
2 changed files with 84 additions and 82 deletions
|
|
@ -1,74 +1,38 @@
|
|||
"""Re-raise still-pending subagent interrupts at the parent graph level.
|
||||
"""Stamp the parent's ``tool_call_id`` onto a subagent's pending interrupt value.
|
||||
|
||||
After ``subagent.[a]invoke(Command(resume=...))`` returns, the subagent may
|
||||
still hold a pending interrupt (e.g. the LLM produced a follow-up tool call
|
||||
that fired a fresh ``interrupt()``). The parent's pregel cannot see that
|
||||
interrupt because it lives in a separate compiled graph; we re-raise it here
|
||||
so the parent's SSE stream surfaces it as the next approval card.
|
||||
When a subagent (compiled as a langgraph subgraph and invoked from a parent
|
||||
tool node) hits an ``interrupt(...)`` from its HITL middleware, langgraph
|
||||
raises ``GraphInterrupt`` out of ``subagent.[a]invoke(...)``. The parent's
|
||||
``task`` tool catches that exception, stamps ``tool_call_id`` onto each
|
||||
``Interrupt.value`` using :func:`wrap_with_tool_call_id`, and re-raises a
|
||||
fresh ``GraphInterrupt`` whose values carry that stamp.
|
||||
|
||||
``stream_resume_chat`` then reads ``parent.state.interrupts[*].value["tool_call_id"]``
|
||||
to route a flat ``decisions`` list back to the right paused subagent — without
|
||||
the stamp, parallel HITL across siblings would collapse into an ambiguous
|
||||
bucket and resume would fail.
|
||||
|
||||
This module hosts only the stamping helper; the catch/re-raise lives in
|
||||
``task_tool.py`` since that's the single chokepoint where the raw exception
|
||||
is in our hands.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.runnables import Runnable
|
||||
from langgraph.types import interrupt as _lg_interrupt
|
||||
|
||||
from .resume import get_first_pending_subagent_interrupt
|
||||
def wrap_with_tool_call_id(value: Any, tool_call_id: str) -> dict[str, Any]:
|
||||
"""Return a value dict that always carries the parent's ``tool_call_id``.
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
Dict values are shallow-copied with ``tool_call_id`` stamped on top, so
|
||||
any value the subagent may already carry under that key (from a deeper
|
||||
HITL level) is overwritten — the parent's call id is the only one
|
||||
``stream_resume_chat`` correlates against.
|
||||
|
||||
|
||||
def maybe_propagate_subagent_interrupt(
|
||||
subagent: Runnable,
|
||||
sub_config: dict[str, Any],
|
||||
subagent_type: str,
|
||||
) -> None:
|
||||
"""Re-raise a still-pending subagent interrupt at the parent so the SSE stream surfaces it."""
|
||||
get_state_sync = getattr(subagent, "get_state", None)
|
||||
if not callable(get_state_sync):
|
||||
return
|
||||
try:
|
||||
snapshot = get_state_sync(sub_config)
|
||||
except Exception: # pragma: no cover - defensive
|
||||
logger.debug(
|
||||
"Subagent get_state failed during re-interrupt check",
|
||||
exc_info=True,
|
||||
)
|
||||
return
|
||||
_pending_id, pending_value = get_first_pending_subagent_interrupt(snapshot)
|
||||
if pending_value is None:
|
||||
return
|
||||
logger.info(
|
||||
"Re-raising subagent %r interrupt to parent (multi-step HITL)",
|
||||
subagent_type,
|
||||
)
|
||||
_lg_interrupt(pending_value)
|
||||
|
||||
|
||||
async def amaybe_propagate_subagent_interrupt(
|
||||
subagent: Runnable,
|
||||
sub_config: dict[str, Any],
|
||||
subagent_type: str,
|
||||
) -> None:
|
||||
"""Async counterpart of :func:`maybe_propagate_subagent_interrupt`."""
|
||||
aget_state = getattr(subagent, "aget_state", None)
|
||||
if not callable(aget_state):
|
||||
return
|
||||
try:
|
||||
snapshot = await aget_state(sub_config)
|
||||
except Exception: # pragma: no cover - defensive
|
||||
logger.debug(
|
||||
"Subagent aget_state failed during re-interrupt check",
|
||||
exc_info=True,
|
||||
)
|
||||
return
|
||||
_pending_id, pending_value = get_first_pending_subagent_interrupt(snapshot)
|
||||
if pending_value is None:
|
||||
return
|
||||
logger.info(
|
||||
"Re-raising subagent %r interrupt to parent (multi-step HITL)",
|
||||
subagent_type,
|
||||
)
|
||||
_lg_interrupt(pending_value)
|
||||
Non-dict values are wrapped as ``{"value": <original>, "tool_call_id": ...}``
|
||||
so simple ``interrupt("approve?")`` patterns still propagate cleanly.
|
||||
"""
|
||||
if isinstance(value, dict):
|
||||
return {**value, "tool_call_id": tool_call_id}
|
||||
return {"value": value, "tool_call_id": tool_call_id}
|
||||
|
|
|
|||
|
|
@ -9,14 +9,15 @@ re-raises any new pending interrupt back to the parent.
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Annotated, Any
|
||||
from typing import Annotated, Any, NoReturn
|
||||
|
||||
from deepagents.middleware.subagents import TASK_TOOL_DESCRIPTION
|
||||
from langchain.tools import BaseTool, ToolRuntime
|
||||
from langchain_core.messages import HumanMessage, ToolMessage
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.tools import StructuredTool
|
||||
from langgraph.types import Command
|
||||
from langgraph.errors import GraphInterrupt
|
||||
from langgraph.types import Command, Interrupt
|
||||
|
||||
from .config import (
|
||||
consume_surfsense_resume,
|
||||
|
|
@ -25,10 +26,7 @@ from .config import (
|
|||
subagent_invoke_config,
|
||||
)
|
||||
from .constants import EXCLUDED_STATE_KEYS
|
||||
from .propagation import (
|
||||
amaybe_propagate_subagent_interrupt,
|
||||
maybe_propagate_subagent_interrupt,
|
||||
)
|
||||
from .propagation import wrap_with_tool_call_id
|
||||
from .resume import (
|
||||
build_resume_command,
|
||||
fan_out_decisions_to_match,
|
||||
|
|
@ -39,6 +37,31 @@ from .resume import (
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _reraise_stamped_subagent_interrupt(
|
||||
gi: GraphInterrupt, tool_call_id: str
|
||||
) -> NoReturn:
|
||||
"""Stamp ``tool_call_id`` onto each pending interrupt value and re-raise.
|
||||
|
||||
See :mod:`...propagation` for why this stamp is required for resume routing.
|
||||
Chained via ``from gi`` so tracebacks point at the subagent's original
|
||||
``interrupt(...)`` site.
|
||||
"""
|
||||
interrupts = gi.args[0] if gi.args else ()
|
||||
stamped = tuple(
|
||||
Interrupt(
|
||||
value=wrap_with_tool_call_id(i.value, tool_call_id),
|
||||
id=i.id,
|
||||
)
|
||||
for i in interrupts
|
||||
)
|
||||
logger.info(
|
||||
"[hitl_route] stamped %d subagent interrupt(s) with tool_call_id=%s",
|
||||
len(stamped),
|
||||
tool_call_id,
|
||||
)
|
||||
raise GraphInterrupt(stamped) from gi
|
||||
|
||||
|
||||
def build_task_tool_with_parent_config(
|
||||
subagents: list[dict[str, Any]],
|
||||
task_description: str | None = None,
|
||||
|
|
@ -161,13 +184,18 @@ def build_task_tool_with_parent_config(
|
|||
# Prevent the parent's resume payload from leaking into subagent
|
||||
# interrupts via langgraph's parent_scratchpad fallback.
|
||||
drain_parent_null_resume(runtime)
|
||||
result = subagent.invoke(
|
||||
build_resume_command(resume_value, pending_id),
|
||||
config=sub_config,
|
||||
)
|
||||
try:
|
||||
result = subagent.invoke(
|
||||
build_resume_command(resume_value, pending_id),
|
||||
config=sub_config,
|
||||
)
|
||||
except GraphInterrupt as gi:
|
||||
_reraise_stamped_subagent_interrupt(gi, runtime.tool_call_id)
|
||||
else:
|
||||
result = subagent.invoke(subagent_state, config=sub_config)
|
||||
maybe_propagate_subagent_interrupt(subagent, sub_config, subagent_type)
|
||||
try:
|
||||
result = subagent.invoke(subagent_state, config=sub_config)
|
||||
except GraphInterrupt as gi:
|
||||
_reraise_stamped_subagent_interrupt(gi, runtime.tool_call_id)
|
||||
return _return_command_with_state_update(result, runtime.tool_call_id)
|
||||
|
||||
async def atask(
|
||||
|
|
@ -181,6 +209,11 @@ def build_task_tool_with_parent_config(
|
|||
],
|
||||
runtime: ToolRuntime,
|
||||
) -> str | Command:
|
||||
logger.info(
|
||||
"[hitl_route] atask ENTRY: subagent_type=%r tool_call_id=%s",
|
||||
subagent_type,
|
||||
runtime.tool_call_id,
|
||||
)
|
||||
if subagent_type not in subagent_graphs:
|
||||
allowed_types = ", ".join([f"`{k}`" for k in subagent_graphs])
|
||||
return (
|
||||
|
|
@ -228,13 +261,18 @@ def build_task_tool_with_parent_config(
|
|||
# Prevent the parent's resume payload from leaking into subagent
|
||||
# interrupts via langgraph's parent_scratchpad fallback.
|
||||
drain_parent_null_resume(runtime)
|
||||
result = await subagent.ainvoke(
|
||||
build_resume_command(resume_value, pending_id),
|
||||
config=sub_config,
|
||||
)
|
||||
try:
|
||||
result = await subagent.ainvoke(
|
||||
build_resume_command(resume_value, pending_id),
|
||||
config=sub_config,
|
||||
)
|
||||
except GraphInterrupt as gi:
|
||||
_reraise_stamped_subagent_interrupt(gi, runtime.tool_call_id)
|
||||
else:
|
||||
result = await subagent.ainvoke(subagent_state, config=sub_config)
|
||||
await amaybe_propagate_subagent_interrupt(subagent, sub_config, subagent_type)
|
||||
try:
|
||||
result = await subagent.ainvoke(subagent_state, config=sub_config)
|
||||
except GraphInterrupt as gi:
|
||||
_reraise_stamped_subagent_interrupt(gi, runtime.tool_call_id)
|
||||
return _return_command_with_state_update(result, runtime.tool_call_id)
|
||||
|
||||
return StructuredTool.from_function(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue