diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_hitl_bridge.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_hitl_bridge.py index dbc2c9c00..48eabbd7c 100644 --- a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_hitl_bridge.py +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_hitl_bridge.py @@ -3,15 +3,24 @@ from __future__ import annotations import ast +import asyncio +from types import SimpleNamespace import pytest from langchain.tools import ToolRuntime -from langchain_core.messages import HumanMessage +from langchain_core.messages import AIMessage, HumanMessage from langgraph.checkpoint.memory import InMemorySaver from langgraph.graph import END, START, StateGraph from langgraph.types import Command, interrupt from typing_extensions import TypedDict +from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.config import ( + subagent_invoke_config, +) +from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import ( + collect_pending_tool_calls, + slice_decisions_by_tool_call, +) from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import ( build_task_tool_with_parent_config, ) @@ -24,8 +33,6 @@ class _SubagentState(TypedDict, total=False): def _build_single_interrupt_subagent(): def approve_node(state): - from langchain_core.messages import AIMessage - decision = interrupt( { "action_requests": [ @@ -50,17 +57,27 @@ def _build_single_interrupt_subagent(): return graph.compile(checkpointer=InMemorySaver()) -def _make_runtime(config: dict) -> ToolRuntime: +def _make_runtime(config: dict, *, tool_call_id: str = "parent-tcid-1") -> ToolRuntime: return ToolRuntime( state={"messages": [HumanMessage(content="seed")]}, context=None, config=config, stream_writer=None, - tool_call_id="parent-tcid-1", + tool_call_id=tool_call_id, store=None, ) +def _prime_subagent_at_runtime_thread(subagent, runtime: ToolRuntime) -> dict: + """Build the per-call ``RunnableConfig`` the production ``task`` tool will use. + + Mirrors what the ``task`` tool does on first invocation so test fixtures + can prime the subagent's pending interrupt at the same checkpoint slot + (per-call ``thread_id``) the bridge looks at on resume. + """ + return subagent_invoke_config(runtime) + + @pytest.mark.asyncio async def test_resume_bridge_dispatches_decision_into_pending_subagent(): """Side-channel decision must reach the subagent's pending interrupt verbatim.""" @@ -79,16 +96,17 @@ async def test_resume_bridge_dispatches_decision_into_pending_subagent(): "configurable": {"thread_id": "shared-thread"}, "recursion_limit": 100, } - await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config) - snap = await subagent.aget_state(parent_config) + runtime = _make_runtime(parent_config) + sub_config = _prime_subagent_at_runtime_thread(subagent, runtime) + await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, sub_config) + snap = await subagent.aget_state(sub_config) assert snap.tasks and snap.tasks[0].interrupts, ( "fixture broken: subagent should be paused on its interrupt" ) parent_config["configurable"]["surfsense_resume_value"] = { - "decisions": ["APPROVED"] + runtime.tool_call_id: {"decisions": ["APPROVED"]} } - runtime = _make_runtime(parent_config) result = await task_tool.coroutine( description="please approve", @@ -101,7 +119,7 @@ async def test_resume_bridge_dispatches_decision_into_pending_subagent(): assert update["decision_text"] == repr({"decisions": ["APPROVED"]}) assert "surfsense_resume_value" not in parent_config["configurable"] - final = await subagent.aget_state(parent_config) + final = await subagent.aget_state(sub_config) assert not final.tasks or all(not t.interrupts for t in final.tasks) @@ -123,11 +141,11 @@ async def test_pending_interrupt_without_resume_value_raises_runtime_error(): "configurable": {"thread_id": "guard-thread"}, "recursion_limit": 100, } - await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config) - snap = await subagent.aget_state(parent_config) - assert snap.tasks and snap.tasks[0].interrupts, "fixture broken" - runtime = _make_runtime(parent_config) + sub_config = _prime_subagent_at_runtime_thread(subagent, runtime) + await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, sub_config) + snap = await subagent.aget_state(sub_config) + assert snap.tasks and snap.tasks[0].interrupts, "fixture broken" with pytest.raises(RuntimeError, match="resume bridge is broken"): await task_tool.coroutine( @@ -139,8 +157,6 @@ async def test_pending_interrupt_without_resume_value_raises_runtime_error(): def _build_bundle_subagent(): def bundle_node(state): - from langchain_core.messages import AIMessage - decision = interrupt( { "action_requests": [ @@ -181,7 +197,9 @@ async def test_bundle_three_mixed_decisions_arrive_in_order(): "configurable": {"thread_id": "bundle-thread"}, "recursion_limit": 100, } - await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config) + runtime = _make_runtime(parent_config) + sub_config = _prime_subagent_at_runtime_thread(subagent, runtime) + await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, sub_config) decisions_payload = { "decisions": [ @@ -190,8 +208,9 @@ async def test_bundle_three_mixed_decisions_arrive_in_order(): {"type": "reject", "args": {"message": "no thanks"}}, ] } - parent_config["configurable"]["surfsense_resume_value"] = decisions_payload - runtime = _make_runtime(parent_config) + parent_config["configurable"]["surfsense_resume_value"] = { + runtime.tool_call_id: decisions_payload + } result = await task_tool.coroutine( description="run bundle", @@ -206,3 +225,186 @@ async def test_bundle_three_mixed_decisions_arrive_in_order(): assert received["decisions"][1]["type"] == "edit" assert received["decisions"][1]["args"] == {"args": {"name": "edited-b"}} assert received["decisions"][2]["type"] == "reject" + + +@pytest.mark.asyncio +async def test_parallel_atask_routes_each_decision_to_its_own_subagent(): + """Two ``atask`` calls with distinct ``tool_call_id``s must each get their own decision. + + With per-call ``thread_id`` isolation and per-call resume keying, A's + decision must reach A's pending interrupt and B's must reach B's. They + must NOT cross-contaminate even though they share ``configurable``. + """ + subagent_a = _build_single_interrupt_subagent() + subagent_b = _build_single_interrupt_subagent() + task_tool = build_task_tool_with_parent_config( + [ + { + "name": "approver_a", + "description": "approves A", + "runnable": subagent_a, + }, + { + "name": "approver_b", + "description": "approves B", + "runnable": subagent_b, + }, + ] + ) + + parent_config: dict = { + "configurable": {"thread_id": "parallel-thread"}, + "recursion_limit": 100, + } + + runtime_a = _make_runtime(parent_config, tool_call_id="tcid-A") + runtime_b = _make_runtime(parent_config, tool_call_id="tcid-B") + + sub_config_a = _prime_subagent_at_runtime_thread(subagent_a, runtime_a) + sub_config_b = _prime_subagent_at_runtime_thread(subagent_b, runtime_b) + + await subagent_a.ainvoke( + {"messages": [HumanMessage(content="seed-A")]}, sub_config_a + ) + await subagent_b.ainvoke( + {"messages": [HumanMessage(content="seed-B")]}, sub_config_b + ) + + parent_config["configurable"]["surfsense_resume_value"] = { + "tcid-A": {"decisions": ["DECISION-A"]}, + "tcid-B": {"decisions": ["DECISION-B"]}, + } + + result_a, result_b = await asyncio.gather( + task_tool.coroutine( + description="please approve A", + subagent_type="approver_a", + runtime=runtime_a, + ), + task_tool.coroutine( + description="please approve B", + subagent_type="approver_b", + runtime=runtime_b, + ), + ) + + assert isinstance(result_a, Command) + assert isinstance(result_b, Command) + assert result_a.update["decision_text"] == repr({"decisions": ["DECISION-A"]}) + assert result_b.update["decision_text"] == repr({"decisions": ["DECISION-B"]}) + + assert "surfsense_resume_value" not in parent_config["configurable"] + + +@pytest.mark.asyncio +async def test_full_resume_routing_glue_for_two_paused_subagents(): + """End-to-end: extractor + slicer + bridge correctly route a flat decisions list. + + This simulates exactly what ``stream_resume_chat`` will do on resume: + given a paused parent state with two pending interrupts (one per + subagent) and a flat ``decisions`` list, build the per-tool-call dict + via ``collect_pending_tool_calls`` + ``slice_decisions_by_tool_call``, + then resume the bridge concurrently and verify each subagent received + only its own slice. + """ + subagent_a = _build_bundle_subagent() + subagent_b = _build_single_interrupt_subagent() + task_tool = build_task_tool_with_parent_config( + [ + { + "name": "bundler", + "description": "three-action bundle", + "runnable": subagent_a, + }, + { + "name": "approver", + "description": "single approval", + "runnable": subagent_b, + }, + ] + ) + + parent_config: dict = { + "configurable": {"thread_id": "glue-thread"}, + "recursion_limit": 100, + } + + runtime_a = _make_runtime(parent_config, tool_call_id="tcid-bundler") + runtime_b = _make_runtime(parent_config, tool_call_id="tcid-approver") + + sub_config_a = _prime_subagent_at_runtime_thread(subagent_a, runtime_a) + sub_config_b = _prime_subagent_at_runtime_thread(subagent_b, runtime_b) + + await subagent_a.ainvoke( + {"messages": [HumanMessage(content="seed-A")]}, sub_config_a + ) + await subagent_b.ainvoke( + {"messages": [HumanMessage(content="seed-B")]}, sub_config_b + ) + + # Synthetic parent state mirroring what the parent's pregel would have + # bundled: one Interrupt per subagent, value carrying tool_call_id + + # action_requests (exactly the shape ``propagation.wrap_with_tool_call_id`` + # produces). + parent_interrupts = ( + SimpleNamespace( + id="i-bundler", + value={ + "action_requests": [ + {"name": "create_a", "args": {}, "description": ""}, + {"name": "create_b", "args": {}, "description": ""}, + {"name": "create_c", "args": {}, "description": ""}, + ], + "review_configs": [{}, {}, {}], + "tool_call_id": "tcid-bundler", + }, + ), + SimpleNamespace( + id="i-approver", + value={ + "action_requests": [ + {"name": "approve", "args": {}, "description": ""} + ], + "review_configs": [{}], + "tool_call_id": "tcid-approver", + }, + ), + ) + parent_state = SimpleNamespace(interrupts=parent_interrupts) + + flat_decisions = [ + {"type": "approve"}, + {"type": "edit", "args": {"args": {"name": "edited-b"}}}, + {"type": "reject", "args": {"message": "no thanks"}}, + {"type": "approve"}, + ] + + pending = collect_pending_tool_calls(parent_state) + assert pending == [("tcid-bundler", 3), ("tcid-approver", 1)] + + routed = slice_decisions_by_tool_call(flat_decisions, pending) + parent_config["configurable"]["surfsense_resume_value"] = routed + + result_a, result_b = await asyncio.gather( + task_tool.coroutine( + description="run bundle", + subagent_type="bundler", + runtime=runtime_a, + ), + task_tool.coroutine( + description="please approve", + subagent_type="approver", + runtime=runtime_b, + ), + ) + + assert isinstance(result_a, Command) + assert isinstance(result_b, Command) + + received_a = ast.literal_eval(result_a.update["decision_text"]) + assert received_a == {"decisions": flat_decisions[0:3]} + assert result_b.update["decision_text"] == repr( + {"decisions": flat_decisions[3:4]} + ) + + assert "surfsense_resume_value" not in parent_config["configurable"] diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_tasks.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_tasks.py new file mode 100644 index 000000000..9c067fc57 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_tasks.py @@ -0,0 +1,222 @@ +"""Behavioural guarantees for parallel ``task`` tool calls (non-HITL cases). + +The HITL bridge tests in ``test_hitl_bridge.py`` cover the parallel-interrupt +flow. This file covers the *normal* parallel paths (no interrupts) and the +failure-isolation guarantee — together they pin the behaviour we promise the +user about ``asyncio.gather`` over two ``atask`` coroutines. +""" + +from __future__ import annotations + +import asyncio + +import pytest +from langchain.tools import ToolRuntime +from langchain_core.messages import AIMessage, HumanMessage, ToolMessage +from langgraph.checkpoint.memory import InMemorySaver +from langgraph.graph import END, START, StateGraph +from langgraph.types import Command +from typing_extensions import TypedDict + +from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import ( + build_task_tool_with_parent_config, +) + + +class _SubState(TypedDict, total=False): + messages: list + + +def _build_success_subagent(reply: str): + """A subagent that completes immediately with ``reply``, never interrupts.""" + + def node(_state): + return {"messages": [AIMessage(content=reply)]} + + g = StateGraph(_SubState) + g.add_node("only", node) + g.add_edge(START, "only") + g.add_edge("only", END) + return g.compile(checkpointer=InMemorySaver()) + + +def _build_failing_subagent(exc: Exception): + """A subagent whose only node raises ``exc`` — simulates a tool-level failure.""" + + def node(_state): + raise exc + + g = StateGraph(_SubState) + g.add_node("only", node) + g.add_edge(START, "only") + g.add_edge("only", END) + return g.compile(checkpointer=InMemorySaver()) + + +def _make_runtime(parent_config: dict, *, tool_call_id: str) -> ToolRuntime: + return ToolRuntime( + state={"messages": [HumanMessage(content="seed")]}, + context=None, + config=parent_config, + stream_writer=None, + tool_call_id=tool_call_id, + store=None, + ) + + +def _tool_message_text(cmd: Command, *, expected_tcid: str) -> str: + """Return the ToolMessage content the task tool produced for ``expected_tcid``.""" + assert isinstance(cmd, Command), f"expected Command, got {type(cmd).__name__}" + messages = cmd.update["messages"] + assert len(messages) == 1, f"expected 1 ToolMessage, got {len(messages)}" + msg = messages[0] + assert isinstance(msg, ToolMessage) + assert msg.tool_call_id == expected_tcid + return msg.content + + +@pytest.mark.asyncio +async def test_two_parallel_atasks_to_different_subagents_both_succeed(): + """Normal happy-path: two distinct subagents complete in parallel without interrupting.""" + subagent_a = _build_success_subagent("A is done") + subagent_b = _build_success_subagent("B is done") + task_tool = build_task_tool_with_parent_config( + [ + {"name": "alpha", "description": "alpha agent", "runnable": subagent_a}, + {"name": "beta", "description": "beta agent", "runnable": subagent_b}, + ] + ) + + parent_config: dict = { + "configurable": {"thread_id": "ok-thread"}, + "recursion_limit": 100, + } + runtime_a = _make_runtime(parent_config, tool_call_id="tcid-A") + runtime_b = _make_runtime(parent_config, tool_call_id="tcid-B") + + result_a, result_b = await asyncio.gather( + task_tool.coroutine( + description="do A", + subagent_type="alpha", + runtime=runtime_a, + ), + task_tool.coroutine( + description="do B", + subagent_type="beta", + runtime=runtime_b, + ), + ) + + assert _tool_message_text(result_a, expected_tcid="tcid-A") == "A is done" + assert _tool_message_text(result_b, expected_tcid="tcid-B") == "B is done" + + +@pytest.mark.asyncio +async def test_two_parallel_atasks_same_subagent_type_different_tool_call_ids(): + """Per-call ``thread_id`` isolation: same compiled subagent invoked twice in parallel. + + Both calls share the same ``InMemorySaver`` instance but are namespaced by + distinct ``tool_call_id``s, so checkpoints land in disjoint thread slots. + """ + shared_subagent = _build_success_subagent("ok") + task_tool = build_task_tool_with_parent_config( + [ + {"name": "approver", "description": "shared approver", "runnable": shared_subagent}, + ] + ) + + parent_config: dict = { + "configurable": {"thread_id": "shared-subagent-thread"}, + "recursion_limit": 100, + } + runtime_a = _make_runtime(parent_config, tool_call_id="tcid-A") + runtime_b = _make_runtime(parent_config, tool_call_id="tcid-B") + + result_a, result_b = await asyncio.gather( + task_tool.coroutine( + description="first request", + subagent_type="approver", + runtime=runtime_a, + ), + task_tool.coroutine( + description="second request", + subagent_type="approver", + runtime=runtime_b, + ), + ) + + # Both calls succeed and produce ToolMessages keyed by their own tool_call_id. + assert _tool_message_text(result_a, expected_tcid="tcid-A") == "ok" + assert _tool_message_text(result_b, expected_tcid="tcid-B") == "ok" + + # Verify checkpoint isolation: each call's state lives at its own thread_id. + state_a = await shared_subagent.aget_state( + {"configurable": {"thread_id": "shared-subagent-thread::task:tcid-A"}} + ) + state_b = await shared_subagent.aget_state( + {"configurable": {"thread_id": "shared-subagent-thread::task:tcid-B"}} + ) + assert state_a.values["messages"][-1].content == "ok" + assert state_b.values["messages"][-1].content == "ok" + + # The parent's own thread_id slot is untouched by either subagent. + state_parent = await shared_subagent.aget_state( + {"configurable": {"thread_id": "shared-subagent-thread"}} + ) + assert state_parent.values == {} or state_parent.values.get("messages") in (None, []) + + +@pytest.mark.asyncio +async def test_one_atask_failure_does_not_corrupt_sibling_atask(): + """Failure isolation: a sibling's exception must not poison the surviving atask's state. + + Note: in production, langgraph's pregel runner cancels siblings when any + parallel task raises a non-``GraphBubbleUp`` exception (see + ``_should_stop_others`` in ``langgraph/pregel/_runner.py``). At our layer + that policy is invisible — what we *can* guarantee is that the two atask + coroutines have disjoint state, so the surviving one returns a valid + Command even when its sibling explodes. + """ + failing_subagent = _build_failing_subagent(ValueError("boom")) + surviving_subagent = _build_success_subagent("still here") + task_tool = build_task_tool_with_parent_config( + [ + {"name": "broken", "description": "always fails", "runnable": failing_subagent}, + {"name": "healthy", "description": "always succeeds", "runnable": surviving_subagent}, + ] + ) + + parent_config: dict = { + "configurable": {"thread_id": "iso-thread"}, + "recursion_limit": 100, + } + runtime_fail = _make_runtime(parent_config, tool_call_id="tcid-fail") + runtime_ok = _make_runtime(parent_config, tool_call_id="tcid-ok") + + results = await asyncio.gather( + task_tool.coroutine( + description="will explode", + subagent_type="broken", + runtime=runtime_fail, + ), + task_tool.coroutine( + description="will work", + subagent_type="healthy", + runtime=runtime_ok, + ), + return_exceptions=True, + ) + + fail_result, ok_result = results + + assert isinstance(fail_result, Exception), ( + f"expected the broken subagent to raise, got {fail_result!r}" + ) + # ValueError gets wrapped in langgraph's internal exception types — the + # important guarantee is "this path errored", not the specific class. + assert "boom" in str(fail_result) or isinstance(fail_result, ValueError) + + assert _tool_message_text(ok_result, expected_tcid="tcid-ok") == "still here" + + # Configurable side-channel must not have been corrupted by the failure. + assert "surfsense_resume_value" not in parent_config["configurable"]