mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-17 18:35:19 +02:00
multi_agent_chat/middleware: stamp tool_call_id on subagent interrupts at task chokepoint
This commit is contained in:
parent
fc2c5b6445
commit
e27883e88c
2 changed files with 84 additions and 82 deletions
|
|
@ -1,74 +1,38 @@
|
||||||
"""Re-raise still-pending subagent interrupts at the parent graph level.
|
"""Stamp the parent's ``tool_call_id`` onto a subagent's pending interrupt value.
|
||||||
|
|
||||||
After ``subagent.[a]invoke(Command(resume=...))`` returns, the subagent may
|
When a subagent (compiled as a langgraph subgraph and invoked from a parent
|
||||||
still hold a pending interrupt (e.g. the LLM produced a follow-up tool call
|
tool node) hits an ``interrupt(...)`` from its HITL middleware, langgraph
|
||||||
that fired a fresh ``interrupt()``). The parent's pregel cannot see that
|
raises ``GraphInterrupt`` out of ``subagent.[a]invoke(...)``. The parent's
|
||||||
interrupt because it lives in a separate compiled graph; we re-raise it here
|
``task`` tool catches that exception, stamps ``tool_call_id`` onto each
|
||||||
so the parent's SSE stream surfaces it as the next approval card.
|
``Interrupt.value`` using :func:`wrap_with_tool_call_id`, and re-raises a
|
||||||
|
fresh ``GraphInterrupt`` whose values carry that stamp.
|
||||||
|
|
||||||
|
``stream_resume_chat`` then reads ``parent.state.interrupts[*].value["tool_call_id"]``
|
||||||
|
to route a flat ``decisions`` list back to the right paused subagent — without
|
||||||
|
the stamp, parallel HITL across siblings would collapse into an ambiguous
|
||||||
|
bucket and resume would fail.
|
||||||
|
|
||||||
|
This module hosts only the stamping helper; the catch/re-raise lives in
|
||||||
|
``task_tool.py`` since that's the single chokepoint where the raw exception
|
||||||
|
is in our hands.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.runnables import Runnable
|
|
||||||
from langgraph.types import interrupt as _lg_interrupt
|
|
||||||
|
|
||||||
from .resume import get_first_pending_subagent_interrupt
|
def wrap_with_tool_call_id(value: Any, tool_call_id: str) -> dict[str, Any]:
|
||||||
|
"""Return a value dict that always carries the parent's ``tool_call_id``.
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
Dict values are shallow-copied with ``tool_call_id`` stamped on top, so
|
||||||
|
any value the subagent may already carry under that key (from a deeper
|
||||||
|
HITL level) is overwritten — the parent's call id is the only one
|
||||||
|
``stream_resume_chat`` correlates against.
|
||||||
|
|
||||||
|
Non-dict values are wrapped as ``{"value": <original>, "tool_call_id": ...}``
|
||||||
def maybe_propagate_subagent_interrupt(
|
so simple ``interrupt("approve?")`` patterns still propagate cleanly.
|
||||||
subagent: Runnable,
|
"""
|
||||||
sub_config: dict[str, Any],
|
if isinstance(value, dict):
|
||||||
subagent_type: str,
|
return {**value, "tool_call_id": tool_call_id}
|
||||||
) -> None:
|
return {"value": value, "tool_call_id": tool_call_id}
|
||||||
"""Re-raise a still-pending subagent interrupt at the parent so the SSE stream surfaces it."""
|
|
||||||
get_state_sync = getattr(subagent, "get_state", None)
|
|
||||||
if not callable(get_state_sync):
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
snapshot = get_state_sync(sub_config)
|
|
||||||
except Exception: # pragma: no cover - defensive
|
|
||||||
logger.debug(
|
|
||||||
"Subagent get_state failed during re-interrupt check",
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
_pending_id, pending_value = get_first_pending_subagent_interrupt(snapshot)
|
|
||||||
if pending_value is None:
|
|
||||||
return
|
|
||||||
logger.info(
|
|
||||||
"Re-raising subagent %r interrupt to parent (multi-step HITL)",
|
|
||||||
subagent_type,
|
|
||||||
)
|
|
||||||
_lg_interrupt(pending_value)
|
|
||||||
|
|
||||||
|
|
||||||
async def amaybe_propagate_subagent_interrupt(
|
|
||||||
subagent: Runnable,
|
|
||||||
sub_config: dict[str, Any],
|
|
||||||
subagent_type: str,
|
|
||||||
) -> None:
|
|
||||||
"""Async counterpart of :func:`maybe_propagate_subagent_interrupt`."""
|
|
||||||
aget_state = getattr(subagent, "aget_state", None)
|
|
||||||
if not callable(aget_state):
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
snapshot = await aget_state(sub_config)
|
|
||||||
except Exception: # pragma: no cover - defensive
|
|
||||||
logger.debug(
|
|
||||||
"Subagent aget_state failed during re-interrupt check",
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
_pending_id, pending_value = get_first_pending_subagent_interrupt(snapshot)
|
|
||||||
if pending_value is None:
|
|
||||||
return
|
|
||||||
logger.info(
|
|
||||||
"Re-raising subagent %r interrupt to parent (multi-step HITL)",
|
|
||||||
subagent_type,
|
|
||||||
)
|
|
||||||
_lg_interrupt(pending_value)
|
|
||||||
|
|
|
||||||
|
|
@ -9,14 +9,15 @@ re-raises any new pending interrupt back to the parent.
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Annotated, Any
|
from typing import Annotated, Any, NoReturn
|
||||||
|
|
||||||
from deepagents.middleware.subagents import TASK_TOOL_DESCRIPTION
|
from deepagents.middleware.subagents import TASK_TOOL_DESCRIPTION
|
||||||
from langchain.tools import BaseTool, ToolRuntime
|
from langchain.tools import BaseTool, ToolRuntime
|
||||||
from langchain_core.messages import HumanMessage, ToolMessage
|
from langchain_core.messages import HumanMessage, ToolMessage
|
||||||
from langchain_core.runnables import Runnable
|
from langchain_core.runnables import Runnable
|
||||||
from langchain_core.tools import StructuredTool
|
from langchain_core.tools import StructuredTool
|
||||||
from langgraph.types import Command
|
from langgraph.errors import GraphInterrupt
|
||||||
|
from langgraph.types import Command, Interrupt
|
||||||
|
|
||||||
from .config import (
|
from .config import (
|
||||||
consume_surfsense_resume,
|
consume_surfsense_resume,
|
||||||
|
|
@ -25,10 +26,7 @@ from .config import (
|
||||||
subagent_invoke_config,
|
subagent_invoke_config,
|
||||||
)
|
)
|
||||||
from .constants import EXCLUDED_STATE_KEYS
|
from .constants import EXCLUDED_STATE_KEYS
|
||||||
from .propagation import (
|
from .propagation import wrap_with_tool_call_id
|
||||||
amaybe_propagate_subagent_interrupt,
|
|
||||||
maybe_propagate_subagent_interrupt,
|
|
||||||
)
|
|
||||||
from .resume import (
|
from .resume import (
|
||||||
build_resume_command,
|
build_resume_command,
|
||||||
fan_out_decisions_to_match,
|
fan_out_decisions_to_match,
|
||||||
|
|
@ -39,6 +37,31 @@ from .resume import (
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _reraise_stamped_subagent_interrupt(
|
||||||
|
gi: GraphInterrupt, tool_call_id: str
|
||||||
|
) -> NoReturn:
|
||||||
|
"""Stamp ``tool_call_id`` onto each pending interrupt value and re-raise.
|
||||||
|
|
||||||
|
See :mod:`.propagation` for why this stamp is required for resume routing.
|
||||||
|
Chained via ``from gi`` so tracebacks point at the subagent's original
|
||||||
|
``interrupt(...)`` site.
|
||||||
|
"""
|
||||||
|
interrupts = gi.args[0] if gi.args else ()
|
||||||
|
stamped = tuple(
|
||||||
|
Interrupt(
|
||||||
|
value=wrap_with_tool_call_id(i.value, tool_call_id),
|
||||||
|
id=i.id,
|
||||||
|
)
|
||||||
|
for i in interrupts
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"[hitl_route] stamped %d subagent interrupt(s) with tool_call_id=%s",
|
||||||
|
len(stamped),
|
||||||
|
tool_call_id,
|
||||||
|
)
|
||||||
|
raise GraphInterrupt(stamped) from gi
|
||||||
|
|
||||||
|
|
||||||
def build_task_tool_with_parent_config(
|
def build_task_tool_with_parent_config(
|
||||||
subagents: list[dict[str, Any]],
|
subagents: list[dict[str, Any]],
|
||||||
task_description: str | None = None,
|
task_description: str | None = None,
|
||||||
|
|
@ -161,13 +184,18 @@ def build_task_tool_with_parent_config(
|
||||||
# Prevent the parent's resume payload from leaking into subagent
|
# Prevent the parent's resume payload from leaking into subagent
|
||||||
# interrupts via langgraph's parent_scratchpad fallback.
|
# interrupts via langgraph's parent_scratchpad fallback.
|
||||||
drain_parent_null_resume(runtime)
|
drain_parent_null_resume(runtime)
|
||||||
result = subagent.invoke(
|
try:
|
||||||
build_resume_command(resume_value, pending_id),
|
result = subagent.invoke(
|
||||||
config=sub_config,
|
build_resume_command(resume_value, pending_id),
|
||||||
)
|
config=sub_config,
|
||||||
|
)
|
||||||
|
except GraphInterrupt as gi:
|
||||||
|
_reraise_stamped_subagent_interrupt(gi, runtime.tool_call_id)
|
||||||
else:
|
else:
|
||||||
result = subagent.invoke(subagent_state, config=sub_config)
|
try:
|
||||||
maybe_propagate_subagent_interrupt(subagent, sub_config, subagent_type)
|
result = subagent.invoke(subagent_state, config=sub_config)
|
||||||
|
except GraphInterrupt as gi:
|
||||||
|
_reraise_stamped_subagent_interrupt(gi, runtime.tool_call_id)
|
||||||
return _return_command_with_state_update(result, runtime.tool_call_id)
|
return _return_command_with_state_update(result, runtime.tool_call_id)
|
||||||
|
|
||||||
async def atask(
|
async def atask(
|
||||||
|
|
@ -181,6 +209,11 @@ def build_task_tool_with_parent_config(
|
||||||
],
|
],
|
||||||
runtime: ToolRuntime,
|
runtime: ToolRuntime,
|
||||||
) -> str | Command:
|
) -> str | Command:
|
||||||
|
logger.info(
|
||||||
|
"[hitl_route] atask ENTRY: subagent_type=%r tool_call_id=%s",
|
||||||
|
subagent_type,
|
||||||
|
runtime.tool_call_id,
|
||||||
|
)
|
||||||
if subagent_type not in subagent_graphs:
|
if subagent_type not in subagent_graphs:
|
||||||
allowed_types = ", ".join([f"`{k}`" for k in subagent_graphs])
|
allowed_types = ", ".join([f"`{k}`" for k in subagent_graphs])
|
||||||
return (
|
return (
|
||||||
|
|
@ -228,13 +261,18 @@ def build_task_tool_with_parent_config(
|
||||||
# Prevent the parent's resume payload from leaking into subagent
|
# Prevent the parent's resume payload from leaking into subagent
|
||||||
# interrupts via langgraph's parent_scratchpad fallback.
|
# interrupts via langgraph's parent_scratchpad fallback.
|
||||||
drain_parent_null_resume(runtime)
|
drain_parent_null_resume(runtime)
|
||||||
result = await subagent.ainvoke(
|
try:
|
||||||
build_resume_command(resume_value, pending_id),
|
result = await subagent.ainvoke(
|
||||||
config=sub_config,
|
build_resume_command(resume_value, pending_id),
|
||||||
)
|
config=sub_config,
|
||||||
|
)
|
||||||
|
except GraphInterrupt as gi:
|
||||||
|
_reraise_stamped_subagent_interrupt(gi, runtime.tool_call_id)
|
||||||
else:
|
else:
|
||||||
result = await subagent.ainvoke(subagent_state, config=sub_config)
|
try:
|
||||||
await amaybe_propagate_subagent_interrupt(subagent, sub_config, subagent_type)
|
result = await subagent.ainvoke(subagent_state, config=sub_config)
|
||||||
|
except GraphInterrupt as gi:
|
||||||
|
_reraise_stamped_subagent_interrupt(gi, runtime.tool_call_id)
|
||||||
return _return_command_with_state_update(result, runtime.tool_call_id)
|
return _return_command_with_state_update(result, runtime.tool_call_id)
|
||||||
|
|
||||||
return StructuredTool.from_function(
|
return StructuredTool.from_function(
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue