multi_agent_chat/middleware: parallel task tests and full bridge coverage

This commit is contained in:
CREDO23 2026-05-13 19:57:57 +02:00
parent 6fb011c95c
commit 1001f56206
2 changed files with 443 additions and 19 deletions

View file

@ -3,15 +3,24 @@
from __future__ import annotations from __future__ import annotations
import ast import ast
import asyncio
from types import SimpleNamespace
import pytest import pytest
from langchain.tools import ToolRuntime from langchain.tools import ToolRuntime
from langchain_core.messages import HumanMessage from langchain_core.messages import AIMessage, HumanMessage
from langgraph.checkpoint.memory import InMemorySaver from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import END, START, StateGraph from langgraph.graph import END, START, StateGraph
from langgraph.types import Command, interrupt from langgraph.types import Command, interrupt
from typing_extensions import TypedDict from typing_extensions import TypedDict
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.config import (
subagent_invoke_config,
)
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import (
collect_pending_tool_calls,
slice_decisions_by_tool_call,
)
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import ( from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import (
build_task_tool_with_parent_config, build_task_tool_with_parent_config,
) )
@ -24,8 +33,6 @@ class _SubagentState(TypedDict, total=False):
def _build_single_interrupt_subagent(): def _build_single_interrupt_subagent():
def approve_node(state): def approve_node(state):
from langchain_core.messages import AIMessage
decision = interrupt( decision = interrupt(
{ {
"action_requests": [ "action_requests": [
@ -50,17 +57,27 @@ def _build_single_interrupt_subagent():
return graph.compile(checkpointer=InMemorySaver()) return graph.compile(checkpointer=InMemorySaver())
def _make_runtime(config: dict) -> ToolRuntime: def _make_runtime(config: dict, *, tool_call_id: str = "parent-tcid-1") -> ToolRuntime:
return ToolRuntime( return ToolRuntime(
state={"messages": [HumanMessage(content="seed")]}, state={"messages": [HumanMessage(content="seed")]},
context=None, context=None,
config=config, config=config,
stream_writer=None, stream_writer=None,
tool_call_id="parent-tcid-1", tool_call_id=tool_call_id,
store=None, store=None,
) )
def _prime_subagent_at_runtime_thread(subagent, runtime: ToolRuntime) -> dict:
"""Build the per-call ``RunnableConfig`` the production ``task`` tool will use.
Mirrors what the ``task`` tool does on first invocation so test fixtures
can prime the subagent's pending interrupt at the same checkpoint slot
(per-call ``thread_id``) the bridge looks at on resume.
"""
return subagent_invoke_config(runtime)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_resume_bridge_dispatches_decision_into_pending_subagent(): async def test_resume_bridge_dispatches_decision_into_pending_subagent():
"""Side-channel decision must reach the subagent's pending interrupt verbatim.""" """Side-channel decision must reach the subagent's pending interrupt verbatim."""
@ -79,16 +96,17 @@ async def test_resume_bridge_dispatches_decision_into_pending_subagent():
"configurable": {"thread_id": "shared-thread"}, "configurable": {"thread_id": "shared-thread"},
"recursion_limit": 100, "recursion_limit": 100,
} }
await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config) runtime = _make_runtime(parent_config)
snap = await subagent.aget_state(parent_config) sub_config = _prime_subagent_at_runtime_thread(subagent, runtime)
await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, sub_config)
snap = await subagent.aget_state(sub_config)
assert snap.tasks and snap.tasks[0].interrupts, ( assert snap.tasks and snap.tasks[0].interrupts, (
"fixture broken: subagent should be paused on its interrupt" "fixture broken: subagent should be paused on its interrupt"
) )
parent_config["configurable"]["surfsense_resume_value"] = { parent_config["configurable"]["surfsense_resume_value"] = {
"decisions": ["APPROVED"] runtime.tool_call_id: {"decisions": ["APPROVED"]}
} }
runtime = _make_runtime(parent_config)
result = await task_tool.coroutine( result = await task_tool.coroutine(
description="please approve", description="please approve",
@ -101,7 +119,7 @@ async def test_resume_bridge_dispatches_decision_into_pending_subagent():
assert update["decision_text"] == repr({"decisions": ["APPROVED"]}) assert update["decision_text"] == repr({"decisions": ["APPROVED"]})
assert "surfsense_resume_value" not in parent_config["configurable"] assert "surfsense_resume_value" not in parent_config["configurable"]
final = await subagent.aget_state(parent_config) final = await subagent.aget_state(sub_config)
assert not final.tasks or all(not t.interrupts for t in final.tasks) assert not final.tasks or all(not t.interrupts for t in final.tasks)
@ -123,11 +141,11 @@ async def test_pending_interrupt_without_resume_value_raises_runtime_error():
"configurable": {"thread_id": "guard-thread"}, "configurable": {"thread_id": "guard-thread"},
"recursion_limit": 100, "recursion_limit": 100,
} }
await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config)
snap = await subagent.aget_state(parent_config)
assert snap.tasks and snap.tasks[0].interrupts, "fixture broken"
runtime = _make_runtime(parent_config) runtime = _make_runtime(parent_config)
sub_config = _prime_subagent_at_runtime_thread(subagent, runtime)
await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, sub_config)
snap = await subagent.aget_state(sub_config)
assert snap.tasks and snap.tasks[0].interrupts, "fixture broken"
with pytest.raises(RuntimeError, match="resume bridge is broken"): with pytest.raises(RuntimeError, match="resume bridge is broken"):
await task_tool.coroutine( await task_tool.coroutine(
@ -139,8 +157,6 @@ async def test_pending_interrupt_without_resume_value_raises_runtime_error():
def _build_bundle_subagent(): def _build_bundle_subagent():
def bundle_node(state): def bundle_node(state):
from langchain_core.messages import AIMessage
decision = interrupt( decision = interrupt(
{ {
"action_requests": [ "action_requests": [
@ -181,7 +197,9 @@ async def test_bundle_three_mixed_decisions_arrive_in_order():
"configurable": {"thread_id": "bundle-thread"}, "configurable": {"thread_id": "bundle-thread"},
"recursion_limit": 100, "recursion_limit": 100,
} }
await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config) runtime = _make_runtime(parent_config)
sub_config = _prime_subagent_at_runtime_thread(subagent, runtime)
await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, sub_config)
decisions_payload = { decisions_payload = {
"decisions": [ "decisions": [
@ -190,8 +208,9 @@ async def test_bundle_three_mixed_decisions_arrive_in_order():
{"type": "reject", "args": {"message": "no thanks"}}, {"type": "reject", "args": {"message": "no thanks"}},
] ]
} }
parent_config["configurable"]["surfsense_resume_value"] = decisions_payload parent_config["configurable"]["surfsense_resume_value"] = {
runtime = _make_runtime(parent_config) runtime.tool_call_id: decisions_payload
}
result = await task_tool.coroutine( result = await task_tool.coroutine(
description="run bundle", description="run bundle",
@ -206,3 +225,186 @@ async def test_bundle_three_mixed_decisions_arrive_in_order():
assert received["decisions"][1]["type"] == "edit" assert received["decisions"][1]["type"] == "edit"
assert received["decisions"][1]["args"] == {"args": {"name": "edited-b"}} assert received["decisions"][1]["args"] == {"args": {"name": "edited-b"}}
assert received["decisions"][2]["type"] == "reject" assert received["decisions"][2]["type"] == "reject"
@pytest.mark.asyncio
async def test_parallel_atask_routes_each_decision_to_its_own_subagent():
"""Two ``atask`` calls with distinct ``tool_call_id``s must each get their own decision.
With per-call ``thread_id`` isolation and per-call resume keying, A's
decision must reach A's pending interrupt and B's must reach B's. They
must NOT cross-contaminate even though they share ``configurable``.
"""
subagent_a = _build_single_interrupt_subagent()
subagent_b = _build_single_interrupt_subagent()
task_tool = build_task_tool_with_parent_config(
[
{
"name": "approver_a",
"description": "approves A",
"runnable": subagent_a,
},
{
"name": "approver_b",
"description": "approves B",
"runnable": subagent_b,
},
]
)
parent_config: dict = {
"configurable": {"thread_id": "parallel-thread"},
"recursion_limit": 100,
}
runtime_a = _make_runtime(parent_config, tool_call_id="tcid-A")
runtime_b = _make_runtime(parent_config, tool_call_id="tcid-B")
sub_config_a = _prime_subagent_at_runtime_thread(subagent_a, runtime_a)
sub_config_b = _prime_subagent_at_runtime_thread(subagent_b, runtime_b)
await subagent_a.ainvoke(
{"messages": [HumanMessage(content="seed-A")]}, sub_config_a
)
await subagent_b.ainvoke(
{"messages": [HumanMessage(content="seed-B")]}, sub_config_b
)
parent_config["configurable"]["surfsense_resume_value"] = {
"tcid-A": {"decisions": ["DECISION-A"]},
"tcid-B": {"decisions": ["DECISION-B"]},
}
result_a, result_b = await asyncio.gather(
task_tool.coroutine(
description="please approve A",
subagent_type="approver_a",
runtime=runtime_a,
),
task_tool.coroutine(
description="please approve B",
subagent_type="approver_b",
runtime=runtime_b,
),
)
assert isinstance(result_a, Command)
assert isinstance(result_b, Command)
assert result_a.update["decision_text"] == repr({"decisions": ["DECISION-A"]})
assert result_b.update["decision_text"] == repr({"decisions": ["DECISION-B"]})
assert "surfsense_resume_value" not in parent_config["configurable"]
@pytest.mark.asyncio
async def test_full_resume_routing_glue_for_two_paused_subagents():
"""End-to-end: extractor + slicer + bridge correctly route a flat decisions list.
This simulates exactly what ``stream_resume_chat`` will do on resume:
given a paused parent state with two pending interrupts (one per
subagent) and a flat ``decisions`` list, build the per-tool-call dict
via ``collect_pending_tool_calls`` + ``slice_decisions_by_tool_call``,
then resume the bridge concurrently and verify each subagent received
only its own slice.
"""
subagent_a = _build_bundle_subagent()
subagent_b = _build_single_interrupt_subagent()
task_tool = build_task_tool_with_parent_config(
[
{
"name": "bundler",
"description": "three-action bundle",
"runnable": subagent_a,
},
{
"name": "approver",
"description": "single approval",
"runnable": subagent_b,
},
]
)
parent_config: dict = {
"configurable": {"thread_id": "glue-thread"},
"recursion_limit": 100,
}
runtime_a = _make_runtime(parent_config, tool_call_id="tcid-bundler")
runtime_b = _make_runtime(parent_config, tool_call_id="tcid-approver")
sub_config_a = _prime_subagent_at_runtime_thread(subagent_a, runtime_a)
sub_config_b = _prime_subagent_at_runtime_thread(subagent_b, runtime_b)
await subagent_a.ainvoke(
{"messages": [HumanMessage(content="seed-A")]}, sub_config_a
)
await subagent_b.ainvoke(
{"messages": [HumanMessage(content="seed-B")]}, sub_config_b
)
# Synthetic parent state mirroring what the parent's pregel would have
# bundled: one Interrupt per subagent, value carrying tool_call_id +
# action_requests (exactly the shape ``propagation.wrap_with_tool_call_id``
# produces).
parent_interrupts = (
SimpleNamespace(
id="i-bundler",
value={
"action_requests": [
{"name": "create_a", "args": {}, "description": ""},
{"name": "create_b", "args": {}, "description": ""},
{"name": "create_c", "args": {}, "description": ""},
],
"review_configs": [{}, {}, {}],
"tool_call_id": "tcid-bundler",
},
),
SimpleNamespace(
id="i-approver",
value={
"action_requests": [
{"name": "approve", "args": {}, "description": ""}
],
"review_configs": [{}],
"tool_call_id": "tcid-approver",
},
),
)
parent_state = SimpleNamespace(interrupts=parent_interrupts)
flat_decisions = [
{"type": "approve"},
{"type": "edit", "args": {"args": {"name": "edited-b"}}},
{"type": "reject", "args": {"message": "no thanks"}},
{"type": "approve"},
]
pending = collect_pending_tool_calls(parent_state)
assert pending == [("tcid-bundler", 3), ("tcid-approver", 1)]
routed = slice_decisions_by_tool_call(flat_decisions, pending)
parent_config["configurable"]["surfsense_resume_value"] = routed
result_a, result_b = await asyncio.gather(
task_tool.coroutine(
description="run bundle",
subagent_type="bundler",
runtime=runtime_a,
),
task_tool.coroutine(
description="please approve",
subagent_type="approver",
runtime=runtime_b,
),
)
assert isinstance(result_a, Command)
assert isinstance(result_b, Command)
received_a = ast.literal_eval(result_a.update["decision_text"])
assert received_a == {"decisions": flat_decisions[0:3]}
assert result_b.update["decision_text"] == repr(
{"decisions": flat_decisions[3:4]}
)
assert "surfsense_resume_value" not in parent_config["configurable"]

View file

@ -0,0 +1,222 @@
"""Behavioural guarantees for parallel ``task`` tool calls (non-HITL cases).
The HITL bridge tests in ``test_hitl_bridge.py`` cover the parallel-interrupt
flow. This file covers the *normal* parallel paths (no interrupts) and the
failure-isolation guarantee together they pin the behaviour we promise the
user about ``asyncio.gather`` over two ``atask`` coroutines.
"""
from __future__ import annotations
import asyncio
import pytest
from langchain.tools import ToolRuntime
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.types import Command
from typing_extensions import TypedDict
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import (
build_task_tool_with_parent_config,
)
class _SubState(TypedDict, total=False):
messages: list
def _build_success_subagent(reply: str):
"""A subagent that completes immediately with ``reply``, never interrupts."""
def node(_state):
return {"messages": [AIMessage(content=reply)]}
g = StateGraph(_SubState)
g.add_node("only", node)
g.add_edge(START, "only")
g.add_edge("only", END)
return g.compile(checkpointer=InMemorySaver())
def _build_failing_subagent(exc: Exception):
"""A subagent whose only node raises ``exc`` — simulates a tool-level failure."""
def node(_state):
raise exc
g = StateGraph(_SubState)
g.add_node("only", node)
g.add_edge(START, "only")
g.add_edge("only", END)
return g.compile(checkpointer=InMemorySaver())
def _make_runtime(parent_config: dict, *, tool_call_id: str) -> ToolRuntime:
return ToolRuntime(
state={"messages": [HumanMessage(content="seed")]},
context=None,
config=parent_config,
stream_writer=None,
tool_call_id=tool_call_id,
store=None,
)
def _tool_message_text(cmd: Command, *, expected_tcid: str) -> str:
"""Return the ToolMessage content the task tool produced for ``expected_tcid``."""
assert isinstance(cmd, Command), f"expected Command, got {type(cmd).__name__}"
messages = cmd.update["messages"]
assert len(messages) == 1, f"expected 1 ToolMessage, got {len(messages)}"
msg = messages[0]
assert isinstance(msg, ToolMessage)
assert msg.tool_call_id == expected_tcid
return msg.content
@pytest.mark.asyncio
async def test_two_parallel_atasks_to_different_subagents_both_succeed():
"""Normal happy-path: two distinct subagents complete in parallel without interrupting."""
subagent_a = _build_success_subagent("A is done")
subagent_b = _build_success_subagent("B is done")
task_tool = build_task_tool_with_parent_config(
[
{"name": "alpha", "description": "alpha agent", "runnable": subagent_a},
{"name": "beta", "description": "beta agent", "runnable": subagent_b},
]
)
parent_config: dict = {
"configurable": {"thread_id": "ok-thread"},
"recursion_limit": 100,
}
runtime_a = _make_runtime(parent_config, tool_call_id="tcid-A")
runtime_b = _make_runtime(parent_config, tool_call_id="tcid-B")
result_a, result_b = await asyncio.gather(
task_tool.coroutine(
description="do A",
subagent_type="alpha",
runtime=runtime_a,
),
task_tool.coroutine(
description="do B",
subagent_type="beta",
runtime=runtime_b,
),
)
assert _tool_message_text(result_a, expected_tcid="tcid-A") == "A is done"
assert _tool_message_text(result_b, expected_tcid="tcid-B") == "B is done"
@pytest.mark.asyncio
async def test_two_parallel_atasks_same_subagent_type_different_tool_call_ids():
"""Per-call ``thread_id`` isolation: same compiled subagent invoked twice in parallel.
Both calls share the same ``InMemorySaver`` instance but are namespaced by
distinct ``tool_call_id``s, so checkpoints land in disjoint thread slots.
"""
shared_subagent = _build_success_subagent("ok")
task_tool = build_task_tool_with_parent_config(
[
{"name": "approver", "description": "shared approver", "runnable": shared_subagent},
]
)
parent_config: dict = {
"configurable": {"thread_id": "shared-subagent-thread"},
"recursion_limit": 100,
}
runtime_a = _make_runtime(parent_config, tool_call_id="tcid-A")
runtime_b = _make_runtime(parent_config, tool_call_id="tcid-B")
result_a, result_b = await asyncio.gather(
task_tool.coroutine(
description="first request",
subagent_type="approver",
runtime=runtime_a,
),
task_tool.coroutine(
description="second request",
subagent_type="approver",
runtime=runtime_b,
),
)
# Both calls succeed and produce ToolMessages keyed by their own tool_call_id.
assert _tool_message_text(result_a, expected_tcid="tcid-A") == "ok"
assert _tool_message_text(result_b, expected_tcid="tcid-B") == "ok"
# Verify checkpoint isolation: each call's state lives at its own thread_id.
state_a = await shared_subagent.aget_state(
{"configurable": {"thread_id": "shared-subagent-thread::task:tcid-A"}}
)
state_b = await shared_subagent.aget_state(
{"configurable": {"thread_id": "shared-subagent-thread::task:tcid-B"}}
)
assert state_a.values["messages"][-1].content == "ok"
assert state_b.values["messages"][-1].content == "ok"
# The parent's own thread_id slot is untouched by either subagent.
state_parent = await shared_subagent.aget_state(
{"configurable": {"thread_id": "shared-subagent-thread"}}
)
assert state_parent.values == {} or state_parent.values.get("messages") in (None, [])
@pytest.mark.asyncio
async def test_one_atask_failure_does_not_corrupt_sibling_atask():
"""Failure isolation: a sibling's exception must not poison the surviving atask's state.
Note: in production, langgraph's pregel runner cancels siblings when any
parallel task raises a non-``GraphBubbleUp`` exception (see
``_should_stop_others`` in ``langgraph/pregel/_runner.py``). At our layer
that policy is invisible what we *can* guarantee is that the two atask
coroutines have disjoint state, so the surviving one returns a valid
Command even when its sibling explodes.
"""
failing_subagent = _build_failing_subagent(ValueError("boom"))
surviving_subagent = _build_success_subagent("still here")
task_tool = build_task_tool_with_parent_config(
[
{"name": "broken", "description": "always fails", "runnable": failing_subagent},
{"name": "healthy", "description": "always succeeds", "runnable": surviving_subagent},
]
)
parent_config: dict = {
"configurable": {"thread_id": "iso-thread"},
"recursion_limit": 100,
}
runtime_fail = _make_runtime(parent_config, tool_call_id="tcid-fail")
runtime_ok = _make_runtime(parent_config, tool_call_id="tcid-ok")
results = await asyncio.gather(
task_tool.coroutine(
description="will explode",
subagent_type="broken",
runtime=runtime_fail,
),
task_tool.coroutine(
description="will work",
subagent_type="healthy",
runtime=runtime_ok,
),
return_exceptions=True,
)
fail_result, ok_result = results
assert isinstance(fail_result, Exception), (
f"expected the broken subagent to raise, got {fail_result!r}"
)
# ValueError gets wrapped in langgraph's internal exception types — the
# important guarantee is "this path errored", not the specific class.
assert "boom" in str(fail_result) or isinstance(fail_result, ValueError)
assert _tool_message_text(ok_result, expected_tcid="tcid-ok") == "still here"
# Configurable side-channel must not have been corrupted by the failure.
assert "surfsense_resume_value" not in parent_config["configurable"]