multi_agent_chat/middleware: parallel task tests and full bridge coverage

This commit is contained in:
CREDO23 2026-05-13 19:57:57 +02:00
parent 6fb011c95c
commit 1001f56206
2 changed files with 443 additions and 19 deletions

View file

@ -3,15 +3,24 @@
from __future__ import annotations
import ast
import asyncio
from types import SimpleNamespace
import pytest
from langchain.tools import ToolRuntime
from langchain_core.messages import HumanMessage
from langchain_core.messages import AIMessage, HumanMessage
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.types import Command, interrupt
from typing_extensions import TypedDict
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.config import (
subagent_invoke_config,
)
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import (
collect_pending_tool_calls,
slice_decisions_by_tool_call,
)
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import (
build_task_tool_with_parent_config,
)
@ -24,8 +33,6 @@ class _SubagentState(TypedDict, total=False):
def _build_single_interrupt_subagent():
def approve_node(state):
from langchain_core.messages import AIMessage
decision = interrupt(
{
"action_requests": [
@ -50,17 +57,27 @@ def _build_single_interrupt_subagent():
return graph.compile(checkpointer=InMemorySaver())
def _make_runtime(config: dict) -> ToolRuntime:
def _make_runtime(config: dict, *, tool_call_id: str = "parent-tcid-1") -> ToolRuntime:
return ToolRuntime(
state={"messages": [HumanMessage(content="seed")]},
context=None,
config=config,
stream_writer=None,
tool_call_id="parent-tcid-1",
tool_call_id=tool_call_id,
store=None,
)
def _prime_subagent_at_runtime_thread(subagent, runtime: ToolRuntime) -> dict:
"""Build the per-call ``RunnableConfig`` the production ``task`` tool will use.
Mirrors what the ``task`` tool does on first invocation so test fixtures
can prime the subagent's pending interrupt at the same checkpoint slot
(per-call ``thread_id``) the bridge looks at on resume.
"""
return subagent_invoke_config(runtime)
@pytest.mark.asyncio
async def test_resume_bridge_dispatches_decision_into_pending_subagent():
"""Side-channel decision must reach the subagent's pending interrupt verbatim."""
@ -79,16 +96,17 @@ async def test_resume_bridge_dispatches_decision_into_pending_subagent():
"configurable": {"thread_id": "shared-thread"},
"recursion_limit": 100,
}
await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config)
snap = await subagent.aget_state(parent_config)
runtime = _make_runtime(parent_config)
sub_config = _prime_subagent_at_runtime_thread(subagent, runtime)
await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, sub_config)
snap = await subagent.aget_state(sub_config)
assert snap.tasks and snap.tasks[0].interrupts, (
"fixture broken: subagent should be paused on its interrupt"
)
parent_config["configurable"]["surfsense_resume_value"] = {
"decisions": ["APPROVED"]
runtime.tool_call_id: {"decisions": ["APPROVED"]}
}
runtime = _make_runtime(parent_config)
result = await task_tool.coroutine(
description="please approve",
@ -101,7 +119,7 @@ async def test_resume_bridge_dispatches_decision_into_pending_subagent():
assert update["decision_text"] == repr({"decisions": ["APPROVED"]})
assert "surfsense_resume_value" not in parent_config["configurable"]
final = await subagent.aget_state(parent_config)
final = await subagent.aget_state(sub_config)
assert not final.tasks or all(not t.interrupts for t in final.tasks)
@ -123,11 +141,11 @@ async def test_pending_interrupt_without_resume_value_raises_runtime_error():
"configurable": {"thread_id": "guard-thread"},
"recursion_limit": 100,
}
await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config)
snap = await subagent.aget_state(parent_config)
assert snap.tasks and snap.tasks[0].interrupts, "fixture broken"
runtime = _make_runtime(parent_config)
sub_config = _prime_subagent_at_runtime_thread(subagent, runtime)
await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, sub_config)
snap = await subagent.aget_state(sub_config)
assert snap.tasks and snap.tasks[0].interrupts, "fixture broken"
with pytest.raises(RuntimeError, match="resume bridge is broken"):
await task_tool.coroutine(
@ -139,8 +157,6 @@ async def test_pending_interrupt_without_resume_value_raises_runtime_error():
def _build_bundle_subagent():
def bundle_node(state):
from langchain_core.messages import AIMessage
decision = interrupt(
{
"action_requests": [
@ -181,7 +197,9 @@ async def test_bundle_three_mixed_decisions_arrive_in_order():
"configurable": {"thread_id": "bundle-thread"},
"recursion_limit": 100,
}
await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config)
runtime = _make_runtime(parent_config)
sub_config = _prime_subagent_at_runtime_thread(subagent, runtime)
await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, sub_config)
decisions_payload = {
"decisions": [
@ -190,8 +208,9 @@ async def test_bundle_three_mixed_decisions_arrive_in_order():
{"type": "reject", "args": {"message": "no thanks"}},
]
}
parent_config["configurable"]["surfsense_resume_value"] = decisions_payload
runtime = _make_runtime(parent_config)
parent_config["configurable"]["surfsense_resume_value"] = {
runtime.tool_call_id: decisions_payload
}
result = await task_tool.coroutine(
description="run bundle",
@ -206,3 +225,186 @@ async def test_bundle_three_mixed_decisions_arrive_in_order():
assert received["decisions"][1]["type"] == "edit"
assert received["decisions"][1]["args"] == {"args": {"name": "edited-b"}}
assert received["decisions"][2]["type"] == "reject"
@pytest.mark.asyncio
async def test_parallel_atask_routes_each_decision_to_its_own_subagent():
"""Two ``atask`` calls with distinct ``tool_call_id``s must each get their own decision.
With per-call ``thread_id`` isolation and per-call resume keying, A's
decision must reach A's pending interrupt and B's must reach B's. They
must NOT cross-contaminate even though they share ``configurable``.
"""
subagent_a = _build_single_interrupt_subagent()
subagent_b = _build_single_interrupt_subagent()
task_tool = build_task_tool_with_parent_config(
[
{
"name": "approver_a",
"description": "approves A",
"runnable": subagent_a,
},
{
"name": "approver_b",
"description": "approves B",
"runnable": subagent_b,
},
]
)
parent_config: dict = {
"configurable": {"thread_id": "parallel-thread"},
"recursion_limit": 100,
}
runtime_a = _make_runtime(parent_config, tool_call_id="tcid-A")
runtime_b = _make_runtime(parent_config, tool_call_id="tcid-B")
sub_config_a = _prime_subagent_at_runtime_thread(subagent_a, runtime_a)
sub_config_b = _prime_subagent_at_runtime_thread(subagent_b, runtime_b)
await subagent_a.ainvoke(
{"messages": [HumanMessage(content="seed-A")]}, sub_config_a
)
await subagent_b.ainvoke(
{"messages": [HumanMessage(content="seed-B")]}, sub_config_b
)
parent_config["configurable"]["surfsense_resume_value"] = {
"tcid-A": {"decisions": ["DECISION-A"]},
"tcid-B": {"decisions": ["DECISION-B"]},
}
result_a, result_b = await asyncio.gather(
task_tool.coroutine(
description="please approve A",
subagent_type="approver_a",
runtime=runtime_a,
),
task_tool.coroutine(
description="please approve B",
subagent_type="approver_b",
runtime=runtime_b,
),
)
assert isinstance(result_a, Command)
assert isinstance(result_b, Command)
assert result_a.update["decision_text"] == repr({"decisions": ["DECISION-A"]})
assert result_b.update["decision_text"] == repr({"decisions": ["DECISION-B"]})
assert "surfsense_resume_value" not in parent_config["configurable"]
@pytest.mark.asyncio
async def test_full_resume_routing_glue_for_two_paused_subagents():
"""End-to-end: extractor + slicer + bridge correctly route a flat decisions list.
This simulates exactly what ``stream_resume_chat`` will do on resume:
given a paused parent state with two pending interrupts (one per
subagent) and a flat ``decisions`` list, build the per-tool-call dict
via ``collect_pending_tool_calls`` + ``slice_decisions_by_tool_call``,
then resume the bridge concurrently and verify each subagent received
only its own slice.
"""
subagent_a = _build_bundle_subagent()
subagent_b = _build_single_interrupt_subagent()
task_tool = build_task_tool_with_parent_config(
[
{
"name": "bundler",
"description": "three-action bundle",
"runnable": subagent_a,
},
{
"name": "approver",
"description": "single approval",
"runnable": subagent_b,
},
]
)
parent_config: dict = {
"configurable": {"thread_id": "glue-thread"},
"recursion_limit": 100,
}
runtime_a = _make_runtime(parent_config, tool_call_id="tcid-bundler")
runtime_b = _make_runtime(parent_config, tool_call_id="tcid-approver")
sub_config_a = _prime_subagent_at_runtime_thread(subagent_a, runtime_a)
sub_config_b = _prime_subagent_at_runtime_thread(subagent_b, runtime_b)
await subagent_a.ainvoke(
{"messages": [HumanMessage(content="seed-A")]}, sub_config_a
)
await subagent_b.ainvoke(
{"messages": [HumanMessage(content="seed-B")]}, sub_config_b
)
# Synthetic parent state mirroring what the parent's pregel would have
# bundled: one Interrupt per subagent, value carrying tool_call_id +
# action_requests (exactly the shape ``propagation.wrap_with_tool_call_id``
# produces).
parent_interrupts = (
SimpleNamespace(
id="i-bundler",
value={
"action_requests": [
{"name": "create_a", "args": {}, "description": ""},
{"name": "create_b", "args": {}, "description": ""},
{"name": "create_c", "args": {}, "description": ""},
],
"review_configs": [{}, {}, {}],
"tool_call_id": "tcid-bundler",
},
),
SimpleNamespace(
id="i-approver",
value={
"action_requests": [
{"name": "approve", "args": {}, "description": ""}
],
"review_configs": [{}],
"tool_call_id": "tcid-approver",
},
),
)
parent_state = SimpleNamespace(interrupts=parent_interrupts)
flat_decisions = [
{"type": "approve"},
{"type": "edit", "args": {"args": {"name": "edited-b"}}},
{"type": "reject", "args": {"message": "no thanks"}},
{"type": "approve"},
]
pending = collect_pending_tool_calls(parent_state)
assert pending == [("tcid-bundler", 3), ("tcid-approver", 1)]
routed = slice_decisions_by_tool_call(flat_decisions, pending)
parent_config["configurable"]["surfsense_resume_value"] = routed
result_a, result_b = await asyncio.gather(
task_tool.coroutine(
description="run bundle",
subagent_type="bundler",
runtime=runtime_a,
),
task_tool.coroutine(
description="please approve",
subagent_type="approver",
runtime=runtime_b,
),
)
assert isinstance(result_a, Command)
assert isinstance(result_b, Command)
received_a = ast.literal_eval(result_a.update["decision_text"])
assert received_a == {"decisions": flat_decisions[0:3]}
assert result_b.update["decision_text"] == repr(
{"decisions": flat_decisions[3:4]}
)
assert "surfsense_resume_value" not in parent_config["configurable"]

View file

@ -0,0 +1,222 @@
"""Behavioural guarantees for parallel ``task`` tool calls (non-HITL cases).
The HITL bridge tests in ``test_hitl_bridge.py`` cover the parallel-interrupt
flow. This file covers the *normal* parallel paths (no interrupts) and the
failure-isolation guarantee together they pin the behaviour we promise the
user about ``asyncio.gather`` over two ``atask`` coroutines.
"""
from __future__ import annotations
import asyncio
import pytest
from langchain.tools import ToolRuntime
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.types import Command
from typing_extensions import TypedDict
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import (
build_task_tool_with_parent_config,
)
class _SubState(TypedDict, total=False):
messages: list
def _build_success_subagent(reply: str):
"""A subagent that completes immediately with ``reply``, never interrupts."""
def node(_state):
return {"messages": [AIMessage(content=reply)]}
g = StateGraph(_SubState)
g.add_node("only", node)
g.add_edge(START, "only")
g.add_edge("only", END)
return g.compile(checkpointer=InMemorySaver())
def _build_failing_subagent(exc: Exception):
"""A subagent whose only node raises ``exc`` — simulates a tool-level failure."""
def node(_state):
raise exc
g = StateGraph(_SubState)
g.add_node("only", node)
g.add_edge(START, "only")
g.add_edge("only", END)
return g.compile(checkpointer=InMemorySaver())
def _make_runtime(parent_config: dict, *, tool_call_id: str) -> ToolRuntime:
return ToolRuntime(
state={"messages": [HumanMessage(content="seed")]},
context=None,
config=parent_config,
stream_writer=None,
tool_call_id=tool_call_id,
store=None,
)
def _tool_message_text(cmd: Command, *, expected_tcid: str) -> str:
"""Return the ToolMessage content the task tool produced for ``expected_tcid``."""
assert isinstance(cmd, Command), f"expected Command, got {type(cmd).__name__}"
messages = cmd.update["messages"]
assert len(messages) == 1, f"expected 1 ToolMessage, got {len(messages)}"
msg = messages[0]
assert isinstance(msg, ToolMessage)
assert msg.tool_call_id == expected_tcid
return msg.content
@pytest.mark.asyncio
async def test_two_parallel_atasks_to_different_subagents_both_succeed():
"""Normal happy-path: two distinct subagents complete in parallel without interrupting."""
subagent_a = _build_success_subagent("A is done")
subagent_b = _build_success_subagent("B is done")
task_tool = build_task_tool_with_parent_config(
[
{"name": "alpha", "description": "alpha agent", "runnable": subagent_a},
{"name": "beta", "description": "beta agent", "runnable": subagent_b},
]
)
parent_config: dict = {
"configurable": {"thread_id": "ok-thread"},
"recursion_limit": 100,
}
runtime_a = _make_runtime(parent_config, tool_call_id="tcid-A")
runtime_b = _make_runtime(parent_config, tool_call_id="tcid-B")
result_a, result_b = await asyncio.gather(
task_tool.coroutine(
description="do A",
subagent_type="alpha",
runtime=runtime_a,
),
task_tool.coroutine(
description="do B",
subagent_type="beta",
runtime=runtime_b,
),
)
assert _tool_message_text(result_a, expected_tcid="tcid-A") == "A is done"
assert _tool_message_text(result_b, expected_tcid="tcid-B") == "B is done"
@pytest.mark.asyncio
async def test_two_parallel_atasks_same_subagent_type_different_tool_call_ids():
"""Per-call ``thread_id`` isolation: same compiled subagent invoked twice in parallel.
Both calls share the same ``InMemorySaver`` instance but are namespaced by
distinct ``tool_call_id``s, so checkpoints land in disjoint thread slots.
"""
shared_subagent = _build_success_subagent("ok")
task_tool = build_task_tool_with_parent_config(
[
{"name": "approver", "description": "shared approver", "runnable": shared_subagent},
]
)
parent_config: dict = {
"configurable": {"thread_id": "shared-subagent-thread"},
"recursion_limit": 100,
}
runtime_a = _make_runtime(parent_config, tool_call_id="tcid-A")
runtime_b = _make_runtime(parent_config, tool_call_id="tcid-B")
result_a, result_b = await asyncio.gather(
task_tool.coroutine(
description="first request",
subagent_type="approver",
runtime=runtime_a,
),
task_tool.coroutine(
description="second request",
subagent_type="approver",
runtime=runtime_b,
),
)
# Both calls succeed and produce ToolMessages keyed by their own tool_call_id.
assert _tool_message_text(result_a, expected_tcid="tcid-A") == "ok"
assert _tool_message_text(result_b, expected_tcid="tcid-B") == "ok"
# Verify checkpoint isolation: each call's state lives at its own thread_id.
state_a = await shared_subagent.aget_state(
{"configurable": {"thread_id": "shared-subagent-thread::task:tcid-A"}}
)
state_b = await shared_subagent.aget_state(
{"configurable": {"thread_id": "shared-subagent-thread::task:tcid-B"}}
)
assert state_a.values["messages"][-1].content == "ok"
assert state_b.values["messages"][-1].content == "ok"
# The parent's own thread_id slot is untouched by either subagent.
state_parent = await shared_subagent.aget_state(
{"configurable": {"thread_id": "shared-subagent-thread"}}
)
assert state_parent.values == {} or state_parent.values.get("messages") in (None, [])
@pytest.mark.asyncio
async def test_one_atask_failure_does_not_corrupt_sibling_atask():
"""Failure isolation: a sibling's exception must not poison the surviving atask's state.
Note: in production, langgraph's pregel runner cancels siblings when any
parallel task raises a non-``GraphBubbleUp`` exception (see
``_should_stop_others`` in ``langgraph/pregel/_runner.py``). At our layer
that policy is invisible what we *can* guarantee is that the two atask
coroutines have disjoint state, so the surviving one returns a valid
Command even when its sibling explodes.
"""
failing_subagent = _build_failing_subagent(ValueError("boom"))
surviving_subagent = _build_success_subagent("still here")
task_tool = build_task_tool_with_parent_config(
[
{"name": "broken", "description": "always fails", "runnable": failing_subagent},
{"name": "healthy", "description": "always succeeds", "runnable": surviving_subagent},
]
)
parent_config: dict = {
"configurable": {"thread_id": "iso-thread"},
"recursion_limit": 100,
}
runtime_fail = _make_runtime(parent_config, tool_call_id="tcid-fail")
runtime_ok = _make_runtime(parent_config, tool_call_id="tcid-ok")
results = await asyncio.gather(
task_tool.coroutine(
description="will explode",
subagent_type="broken",
runtime=runtime_fail,
),
task_tool.coroutine(
description="will work",
subagent_type="healthy",
runtime=runtime_ok,
),
return_exceptions=True,
)
fail_result, ok_result = results
assert isinstance(fail_result, Exception), (
f"expected the broken subagent to raise, got {fail_result!r}"
)
# ValueError gets wrapped in langgraph's internal exception types — the
# important guarantee is "this path errored", not the specific class.
assert "boom" in str(fail_result) or isinstance(fail_result, ValueError)
assert _tool_message_text(ok_result, expected_tcid="tcid-ok") == "still here"
# Configurable side-channel must not have been corrupted by the failure.
assert "surfsense_resume_value" not in parent_config["configurable"]