mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-19 18:45:15 +02:00
Merge remote-tracking branch 'upstream/dev' into fix/backend-tests
This commit is contained in:
commit
8de7d86d56
603 changed files with 45074 additions and 4695 deletions
|
|
@ -3,15 +3,24 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from langchain.tools import ToolRuntime
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.types import Command, interrupt
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.config import (
|
||||
subagent_invoke_config,
|
||||
)
|
||||
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import (
|
||||
collect_pending_tool_calls,
|
||||
slice_decisions_by_tool_call,
|
||||
)
|
||||
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import (
|
||||
build_task_tool_with_parent_config,
|
||||
)
|
||||
|
|
@ -24,8 +33,6 @@ class _SubagentState(TypedDict, total=False):
|
|||
|
||||
def _build_single_interrupt_subagent():
|
||||
def approve_node(state):
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
decision = interrupt(
|
||||
{
|
||||
"action_requests": [
|
||||
|
|
@ -50,17 +57,27 @@ def _build_single_interrupt_subagent():
|
|||
return graph.compile(checkpointer=InMemorySaver())
|
||||
|
||||
|
||||
def _make_runtime(config: dict) -> ToolRuntime:
|
||||
def _make_runtime(config: dict, *, tool_call_id: str = "parent-tcid-1") -> ToolRuntime:
|
||||
return ToolRuntime(
|
||||
state={"messages": [HumanMessage(content="seed")]},
|
||||
context=None,
|
||||
config=config,
|
||||
stream_writer=None,
|
||||
tool_call_id="parent-tcid-1",
|
||||
tool_call_id=tool_call_id,
|
||||
store=None,
|
||||
)
|
||||
|
||||
|
||||
def _prime_subagent_at_runtime_thread(subagent, runtime: ToolRuntime) -> dict:
|
||||
"""Build the per-call ``RunnableConfig`` the production ``task`` tool will use.
|
||||
|
||||
Mirrors what the ``task`` tool does on first invocation so test fixtures
|
||||
can prime the subagent's pending interrupt at the same checkpoint slot
|
||||
(per-call ``thread_id``) the bridge looks at on resume.
|
||||
"""
|
||||
return subagent_invoke_config(runtime)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_bridge_dispatches_decision_into_pending_subagent():
|
||||
"""Side-channel decision must reach the subagent's pending interrupt verbatim."""
|
||||
|
|
@ -79,16 +96,17 @@ async def test_resume_bridge_dispatches_decision_into_pending_subagent():
|
|||
"configurable": {"thread_id": "shared-thread"},
|
||||
"recursion_limit": 100,
|
||||
}
|
||||
await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config)
|
||||
snap = await subagent.aget_state(parent_config)
|
||||
runtime = _make_runtime(parent_config)
|
||||
sub_config = _prime_subagent_at_runtime_thread(subagent, runtime)
|
||||
await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, sub_config)
|
||||
snap = await subagent.aget_state(sub_config)
|
||||
assert snap.tasks and snap.tasks[0].interrupts, (
|
||||
"fixture broken: subagent should be paused on its interrupt"
|
||||
)
|
||||
|
||||
parent_config["configurable"]["surfsense_resume_value"] = {
|
||||
"decisions": ["APPROVED"]
|
||||
runtime.tool_call_id: {"decisions": ["APPROVED"]}
|
||||
}
|
||||
runtime = _make_runtime(parent_config)
|
||||
|
||||
result = await task_tool.coroutine(
|
||||
description="please approve",
|
||||
|
|
@ -101,7 +119,7 @@ async def test_resume_bridge_dispatches_decision_into_pending_subagent():
|
|||
assert update["decision_text"] == repr({"decisions": ["APPROVED"]})
|
||||
assert "surfsense_resume_value" not in parent_config["configurable"]
|
||||
|
||||
final = await subagent.aget_state(parent_config)
|
||||
final = await subagent.aget_state(sub_config)
|
||||
assert not final.tasks or all(not t.interrupts for t in final.tasks)
|
||||
|
||||
|
||||
|
|
@ -123,11 +141,11 @@ async def test_pending_interrupt_without_resume_value_raises_runtime_error():
|
|||
"configurable": {"thread_id": "guard-thread"},
|
||||
"recursion_limit": 100,
|
||||
}
|
||||
await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config)
|
||||
snap = await subagent.aget_state(parent_config)
|
||||
assert snap.tasks and snap.tasks[0].interrupts, "fixture broken"
|
||||
|
||||
runtime = _make_runtime(parent_config)
|
||||
sub_config = _prime_subagent_at_runtime_thread(subagent, runtime)
|
||||
await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, sub_config)
|
||||
snap = await subagent.aget_state(sub_config)
|
||||
assert snap.tasks and snap.tasks[0].interrupts, "fixture broken"
|
||||
|
||||
with pytest.raises(RuntimeError, match="resume bridge is broken"):
|
||||
await task_tool.coroutine(
|
||||
|
|
@ -139,8 +157,6 @@ async def test_pending_interrupt_without_resume_value_raises_runtime_error():
|
|||
|
||||
def _build_bundle_subagent():
|
||||
def bundle_node(state):
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
decision = interrupt(
|
||||
{
|
||||
"action_requests": [
|
||||
|
|
@ -181,7 +197,9 @@ async def test_bundle_three_mixed_decisions_arrive_in_order():
|
|||
"configurable": {"thread_id": "bundle-thread"},
|
||||
"recursion_limit": 100,
|
||||
}
|
||||
await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config)
|
||||
runtime = _make_runtime(parent_config)
|
||||
sub_config = _prime_subagent_at_runtime_thread(subagent, runtime)
|
||||
await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, sub_config)
|
||||
|
||||
decisions_payload = {
|
||||
"decisions": [
|
||||
|
|
@ -190,8 +208,9 @@ async def test_bundle_three_mixed_decisions_arrive_in_order():
|
|||
{"type": "reject", "args": {"message": "no thanks"}},
|
||||
]
|
||||
}
|
||||
parent_config["configurable"]["surfsense_resume_value"] = decisions_payload
|
||||
runtime = _make_runtime(parent_config)
|
||||
parent_config["configurable"]["surfsense_resume_value"] = {
|
||||
runtime.tool_call_id: decisions_payload
|
||||
}
|
||||
|
||||
result = await task_tool.coroutine(
|
||||
description="run bundle",
|
||||
|
|
@ -206,3 +225,182 @@ async def test_bundle_three_mixed_decisions_arrive_in_order():
|
|||
assert received["decisions"][1]["type"] == "edit"
|
||||
assert received["decisions"][1]["args"] == {"args": {"name": "edited-b"}}
|
||||
assert received["decisions"][2]["type"] == "reject"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parallel_atask_routes_each_decision_to_its_own_subagent():
|
||||
"""Two ``atask`` calls with distinct ``tool_call_id``s must each get their own decision.
|
||||
|
||||
With per-call ``thread_id`` isolation and per-call resume keying, A's
|
||||
decision must reach A's pending interrupt and B's must reach B's. They
|
||||
must NOT cross-contaminate even though they share ``configurable``.
|
||||
"""
|
||||
subagent_a = _build_single_interrupt_subagent()
|
||||
subagent_b = _build_single_interrupt_subagent()
|
||||
task_tool = build_task_tool_with_parent_config(
|
||||
[
|
||||
{
|
||||
"name": "approver_a",
|
||||
"description": "approves A",
|
||||
"runnable": subagent_a,
|
||||
},
|
||||
{
|
||||
"name": "approver_b",
|
||||
"description": "approves B",
|
||||
"runnable": subagent_b,
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
parent_config: dict = {
|
||||
"configurable": {"thread_id": "parallel-thread"},
|
||||
"recursion_limit": 100,
|
||||
}
|
||||
|
||||
runtime_a = _make_runtime(parent_config, tool_call_id="tcid-A")
|
||||
runtime_b = _make_runtime(parent_config, tool_call_id="tcid-B")
|
||||
|
||||
sub_config_a = _prime_subagent_at_runtime_thread(subagent_a, runtime_a)
|
||||
sub_config_b = _prime_subagent_at_runtime_thread(subagent_b, runtime_b)
|
||||
|
||||
await subagent_a.ainvoke(
|
||||
{"messages": [HumanMessage(content="seed-A")]}, sub_config_a
|
||||
)
|
||||
await subagent_b.ainvoke(
|
||||
{"messages": [HumanMessage(content="seed-B")]}, sub_config_b
|
||||
)
|
||||
|
||||
parent_config["configurable"]["surfsense_resume_value"] = {
|
||||
"tcid-A": {"decisions": ["DECISION-A"]},
|
||||
"tcid-B": {"decisions": ["DECISION-B"]},
|
||||
}
|
||||
|
||||
result_a, result_b = await asyncio.gather(
|
||||
task_tool.coroutine(
|
||||
description="please approve A",
|
||||
subagent_type="approver_a",
|
||||
runtime=runtime_a,
|
||||
),
|
||||
task_tool.coroutine(
|
||||
description="please approve B",
|
||||
subagent_type="approver_b",
|
||||
runtime=runtime_b,
|
||||
),
|
||||
)
|
||||
|
||||
assert isinstance(result_a, Command)
|
||||
assert isinstance(result_b, Command)
|
||||
assert result_a.update["decision_text"] == repr({"decisions": ["DECISION-A"]})
|
||||
assert result_b.update["decision_text"] == repr({"decisions": ["DECISION-B"]})
|
||||
|
||||
assert "surfsense_resume_value" not in parent_config["configurable"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_resume_routing_glue_for_two_paused_subagents():
|
||||
"""End-to-end: extractor + slicer + bridge correctly route a flat decisions list.
|
||||
|
||||
This simulates exactly what ``stream_resume_chat`` will do on resume:
|
||||
given a paused parent state with two pending interrupts (one per
|
||||
subagent) and a flat ``decisions`` list, build the per-tool-call dict
|
||||
via ``collect_pending_tool_calls`` + ``slice_decisions_by_tool_call``,
|
||||
then resume the bridge concurrently and verify each subagent received
|
||||
only its own slice.
|
||||
"""
|
||||
subagent_a = _build_bundle_subagent()
|
||||
subagent_b = _build_single_interrupt_subagent()
|
||||
task_tool = build_task_tool_with_parent_config(
|
||||
[
|
||||
{
|
||||
"name": "bundler",
|
||||
"description": "three-action bundle",
|
||||
"runnable": subagent_a,
|
||||
},
|
||||
{
|
||||
"name": "approver",
|
||||
"description": "single approval",
|
||||
"runnable": subagent_b,
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
parent_config: dict = {
|
||||
"configurable": {"thread_id": "glue-thread"},
|
||||
"recursion_limit": 100,
|
||||
}
|
||||
|
||||
runtime_a = _make_runtime(parent_config, tool_call_id="tcid-bundler")
|
||||
runtime_b = _make_runtime(parent_config, tool_call_id="tcid-approver")
|
||||
|
||||
sub_config_a = _prime_subagent_at_runtime_thread(subagent_a, runtime_a)
|
||||
sub_config_b = _prime_subagent_at_runtime_thread(subagent_b, runtime_b)
|
||||
|
||||
await subagent_a.ainvoke(
|
||||
{"messages": [HumanMessage(content="seed-A")]}, sub_config_a
|
||||
)
|
||||
await subagent_b.ainvoke(
|
||||
{"messages": [HumanMessage(content="seed-B")]}, sub_config_b
|
||||
)
|
||||
|
||||
# Synthetic parent state mirroring what the parent's pregel would have
|
||||
# bundled: one Interrupt per subagent, value carrying tool_call_id +
|
||||
# action_requests (exactly the shape ``propagation.wrap_with_tool_call_id``
|
||||
# produces).
|
||||
parent_interrupts = (
|
||||
SimpleNamespace(
|
||||
id="i-bundler",
|
||||
value={
|
||||
"action_requests": [
|
||||
{"name": "create_a", "args": {}, "description": ""},
|
||||
{"name": "create_b", "args": {}, "description": ""},
|
||||
{"name": "create_c", "args": {}, "description": ""},
|
||||
],
|
||||
"review_configs": [{}, {}, {}],
|
||||
"tool_call_id": "tcid-bundler",
|
||||
},
|
||||
),
|
||||
SimpleNamespace(
|
||||
id="i-approver",
|
||||
value={
|
||||
"action_requests": [{"name": "approve", "args": {}, "description": ""}],
|
||||
"review_configs": [{}],
|
||||
"tool_call_id": "tcid-approver",
|
||||
},
|
||||
),
|
||||
)
|
||||
parent_state = SimpleNamespace(interrupts=parent_interrupts)
|
||||
|
||||
flat_decisions = [
|
||||
{"type": "approve"},
|
||||
{"type": "edit", "args": {"args": {"name": "edited-b"}}},
|
||||
{"type": "reject", "args": {"message": "no thanks"}},
|
||||
{"type": "approve"},
|
||||
]
|
||||
|
||||
pending = collect_pending_tool_calls(parent_state)
|
||||
assert pending == [("tcid-bundler", 3), ("tcid-approver", 1)]
|
||||
|
||||
routed = slice_decisions_by_tool_call(flat_decisions, pending)
|
||||
parent_config["configurable"]["surfsense_resume_value"] = routed
|
||||
|
||||
result_a, result_b = await asyncio.gather(
|
||||
task_tool.coroutine(
|
||||
description="run bundle",
|
||||
subagent_type="bundler",
|
||||
runtime=runtime_a,
|
||||
),
|
||||
task_tool.coroutine(
|
||||
description="please approve",
|
||||
subagent_type="approver",
|
||||
runtime=runtime_b,
|
||||
),
|
||||
)
|
||||
|
||||
assert isinstance(result_a, Command)
|
||||
assert isinstance(result_b, Command)
|
||||
|
||||
received_a = ast.literal_eval(result_a.update["decision_text"])
|
||||
assert received_a == {"decisions": flat_decisions[0:3]}
|
||||
assert result_b.update["decision_text"] == repr({"decisions": flat_decisions[3:4]})
|
||||
|
||||
assert "surfsense_resume_value" not in parent_config["configurable"]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,259 @@
|
|||
"""Real-graph contract: heterogeneous decisions route correctly across parallel subagents.
|
||||
|
||||
The simple "approve everything" parallel test (see
|
||||
``test_parallel_resume_command_keying``) proves the routing wires up at all,
|
||||
but it doesn't exercise the actual production user flow: rejecting one card
|
||||
while approving another, or editing one action's args before approving the
|
||||
rest. Those are the decisions ``HumanInTheLoopMiddleware`` differentiates on,
|
||||
and they're exactly where a slicer/router bug silently mis-applies a reject
|
||||
to the wrong subagent.
|
||||
|
||||
This module pins:
|
||||
|
||||
1. **Order preservation** across the slice boundary — flat decisions enter
|
||||
in the order the SSE stream rendered cards; each subagent must receive
|
||||
only its slice in the original order.
|
||||
2. **Per-decision metadata pass-through** — ``message`` and ``edited_action``
|
||||
payloads must reach the subagent intact (not just the ``type`` discriminator).
|
||||
3. **Off-by-one-sensitive bundle sizes** — both paused subagents have action
|
||||
counts ``> 1`` (``2`` and ``3``). With those sizes a buggy
|
||||
``cursor += 1`` slicer (instead of ``cursor += action_count``) produces a
|
||||
different B-slice from the correct one, so this test catches the most
|
||||
common refactor mistake. A ``(1, 2)`` configuration would silently pass
|
||||
such a bug because ``+= 1`` and ``+= count`` are arithmetically identical
|
||||
when ``count == 1``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import json
|
||||
from typing import Annotated
|
||||
|
||||
import pytest
|
||||
from langchain.tools import ToolRuntime
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.graph.message import add_messages
|
||||
from langgraph.types import Command, Send, interrupt
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import (
|
||||
build_lg_resume_map,
|
||||
collect_pending_tool_calls,
|
||||
slice_decisions_by_tool_call,
|
||||
)
|
||||
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import (
|
||||
build_task_tool_with_parent_config,
|
||||
)
|
||||
|
||||
|
||||
class _SubState(TypedDict, total=False):
|
||||
messages: list
|
||||
|
||||
|
||||
class _DispatchState(TypedDict, total=False):
|
||||
messages: Annotated[list, add_messages]
|
||||
tcid: str
|
||||
desc: str
|
||||
subtype: str
|
||||
|
||||
|
||||
def _build_capturing_subagent(checkpointer: InMemorySaver, *, action_count: int):
|
||||
"""Subagent that pauses with an N-action bundle and on resume records what it received.
|
||||
|
||||
The recorded ``AIMessage`` content is the JSON-serialized resume payload, so
|
||||
the assertions can inspect exactly which decisions reached this subagent
|
||||
(vs. its sibling) — including the ``message`` and ``edited_action``
|
||||
metadata, not just the ``type``.
|
||||
"""
|
||||
|
||||
def hitl_node(_state):
|
||||
decision_payload = interrupt(
|
||||
{
|
||||
"action_requests": [
|
||||
{
|
||||
"name": f"act_{i}",
|
||||
"args": {"i": i},
|
||||
"description": f"action {i}",
|
||||
}
|
||||
for i in range(action_count)
|
||||
],
|
||||
"review_configs": [
|
||||
{
|
||||
"action_name": f"act_{i}",
|
||||
"allowed_decisions": ["approve", "reject", "edit"],
|
||||
}
|
||||
for i in range(action_count)
|
||||
],
|
||||
}
|
||||
)
|
||||
return {
|
||||
"messages": [
|
||||
AIMessage(content=json.dumps(decision_payload, sort_keys=True))
|
||||
]
|
||||
}
|
||||
|
||||
g = StateGraph(_SubState)
|
||||
g.add_node("hitl", hitl_node)
|
||||
g.add_edge(START, "hitl")
|
||||
g.add_edge("hitl", END)
|
||||
return g.compile(checkpointer=checkpointer)
|
||||
|
||||
|
||||
def _parent_dispatching_two_subagents(
|
||||
task_tool, *, dispatches: list[dict[str, str]], checkpointer
|
||||
):
|
||||
"""Parent that fans out to ``len(dispatches)`` parallel ``task`` tool calls.
|
||||
|
||||
Each entry in ``dispatches`` is ``{"tcid": ..., "subtype": ..., "desc": ...}``
|
||||
so different parallel branches can target different subagent types — the
|
||||
actual production scenario (Linear + Jira, etc.).
|
||||
"""
|
||||
|
||||
def fanout(_state) -> list[Send]:
|
||||
return [Send("call_task", d) for d in dispatches]
|
||||
|
||||
async def call_task(state: _DispatchState, config: RunnableConfig):
|
||||
rt = ToolRuntime(
|
||||
state=state,
|
||||
config=config,
|
||||
context=None,
|
||||
stream_writer=None,
|
||||
tool_call_id=state["tcid"],
|
||||
store=None,
|
||||
)
|
||||
return await task_tool.coroutine(
|
||||
description=state["desc"], subagent_type=state["subtype"], runtime=rt
|
||||
)
|
||||
|
||||
g = StateGraph(_DispatchState)
|
||||
g.add_node("call_task", call_task)
|
||||
g.add_conditional_edges(START, fanout, ["call_task"])
|
||||
g.add_edge("call_task", END)
|
||||
return g.compile(checkpointer=checkpointer)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_heterogeneous_decisions_route_to_correct_subagents_with_metadata_intact():
|
||||
"""Mixed approve/reject/edit decisions across two parallel subagents.
|
||||
|
||||
Setup chosen so the slicer's cursor arithmetic is sensitive to off-by-one
|
||||
refactors:
|
||||
- Sub-A pauses with a 2-action bundle (``act_0``, ``act_1``).
|
||||
- Sub-B pauses with a 3-action bundle (``act_0``, ``act_1``, ``act_2``).
|
||||
- Parent ends up with 2 pending interrupts (one per subagent).
|
||||
|
||||
With both counts ``> 1``, a buggy ``cursor += 1`` (instead of
|
||||
``cursor += action_count``) produces a different B-slice from the correct
|
||||
one, so the assertions catch it. A ``(1, 2)`` configuration would not
|
||||
because ``+= 1`` and ``+= count`` are arithmetically identical when
|
||||
``count == 1``.
|
||||
|
||||
The frontend submits a flat
|
||||
``[A_approve, A_reject, B_edit, B_approve, B_reject]`` list with distinct
|
||||
``message`` and ``edited_action`` payloads; our slicer must split into
|
||||
``{tcid_A: [A_approve, A_reject], tcid_B: [B_edit, B_approve, B_reject]}``
|
||||
and the bridge must forward each subagent's slice intact — including all
|
||||
metadata, in original order.
|
||||
"""
|
||||
checkpointer = InMemorySaver()
|
||||
|
||||
sub_a = _build_capturing_subagent(checkpointer, action_count=2)
|
||||
sub_b = _build_capturing_subagent(checkpointer, action_count=3)
|
||||
|
||||
task_tool = build_task_tool_with_parent_config(
|
||||
[
|
||||
{"name": "agent-a", "description": "first", "runnable": sub_a},
|
||||
{"name": "agent-b", "description": "second", "runnable": sub_b},
|
||||
]
|
||||
)
|
||||
|
||||
parent = _parent_dispatching_two_subagents(
|
||||
task_tool,
|
||||
dispatches=[
|
||||
{"tcid": "tcid-A", "subtype": "agent-a", "desc": "do A"},
|
||||
{"tcid": "tcid-B", "subtype": "agent-b", "desc": "do B"},
|
||||
],
|
||||
checkpointer=checkpointer,
|
||||
)
|
||||
|
||||
config: dict = {
|
||||
"configurable": {"thread_id": "het-decisions-thread"},
|
||||
"recursion_limit": 100,
|
||||
}
|
||||
await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||
|
||||
paused_state = await parent.aget_state(config)
|
||||
assert len(paused_state.interrupts) == 2, (
|
||||
f"fixture broken: expected 2 paused subagents, got {len(paused_state.interrupts)}"
|
||||
)
|
||||
|
||||
pending = collect_pending_tool_calls(paused_state)
|
||||
pending_by_tcid = dict(pending)
|
||||
assert pending_by_tcid == {"tcid-A": 2, "tcid-B": 3}, (
|
||||
f"REGRESSION: action-count accounting wrong; got {pending_by_tcid!r}"
|
||||
)
|
||||
|
||||
a_approve = {"type": "approve"}
|
||||
a_reject = {"type": "reject", "message": "A[1] looks redundant"}
|
||||
b_edit = {
|
||||
"type": "edit",
|
||||
"edited_action": {"name": "act_0", "args": {"i": 0, "edited": True}},
|
||||
}
|
||||
b_approve = {"type": "approve"}
|
||||
b_reject = {"type": "reject", "message": "B[2] needs more context"}
|
||||
flat_decisions = [a_approve, a_reject, b_edit, b_approve, b_reject]
|
||||
|
||||
by_tool_call_id = slice_decisions_by_tool_call(flat_decisions, pending)
|
||||
|
||||
assert by_tool_call_id == {
|
||||
"tcid-A": {"decisions": [a_approve, a_reject]},
|
||||
"tcid-B": {"decisions": [b_edit, b_approve, b_reject]},
|
||||
}, f"REGRESSION: slicer mis-routed decisions: {by_tool_call_id!r}"
|
||||
|
||||
config["configurable"]["surfsense_resume_value"] = by_tool_call_id
|
||||
lg_resume_map = build_lg_resume_map(paused_state, by_tool_call_id)
|
||||
|
||||
await parent.ainvoke(Command(resume=lg_resume_map), config)
|
||||
|
||||
final_state = await parent.aget_state(config)
|
||||
assert not final_state.interrupts, (
|
||||
f"REGRESSION: leftover pending interrupts after resume: {final_state.interrupts!r}"
|
||||
)
|
||||
|
||||
payloads: list[dict] = []
|
||||
for msg in final_state.values.get("messages", []) or []:
|
||||
content = getattr(msg, "content", None)
|
||||
if isinstance(content, str):
|
||||
with contextlib.suppress(json.JSONDecodeError):
|
||||
payloads.append(json.loads(content))
|
||||
|
||||
expected_a = {"decisions": [a_approve, a_reject]}
|
||||
expected_b = {"decisions": [b_edit, b_approve, b_reject]}
|
||||
|
||||
assert expected_a in payloads, (
|
||||
f"REGRESSION: sub-A did not receive its 2-decision slice in original order; "
|
||||
f"payloads seen: {payloads!r}"
|
||||
)
|
||||
assert expected_b in payloads, (
|
||||
f"REGRESSION: sub-B did not receive its 3-decision slice in original order; "
|
||||
f"payloads seen: {payloads!r}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decision_count_mismatch_fails_loud_before_dispatch():
|
||||
"""The slicer must refuse a flat list whose total != sum(action_counts).
|
||||
|
||||
Otherwise a frontend/backend contract drift would silently send a
|
||||
truncated/padded slice to one of the subagents — the worst possible
|
||||
failure mode (mis-applied reject on a long-lived ticket).
|
||||
"""
|
||||
pending = [("tcid-A", 1), ("tcid-B", 2)]
|
||||
decisions = [{"type": "approve"}, {"type": "approve"}]
|
||||
|
||||
with pytest.raises(ValueError, match="Decision count mismatch"):
|
||||
slice_decisions_by_tool_call(decisions, pending)
|
||||
|
|
@ -0,0 +1,253 @@
|
|||
"""Real-graph contract: one parallel branch completes while a sibling pauses with HITL.
|
||||
|
||||
The two existing parallel-routing tests
|
||||
(``test_parallel_resume_command_keying`` and
|
||||
``test_parallel_heterogeneous_decisions``) both pause **all** branches
|
||||
simultaneously. That's the easy case — every dispatched ``task`` call has a
|
||||
matching pending interrupt, and the routing helpers see a uniform shape.
|
||||
|
||||
Production rarely matches that uniform shape. The orchestrator typically
|
||||
delegates "create a Linear ticket and summarize the user's recent activity":
|
||||
one branch needs HITL, the other returns its result and exits. At the pause
|
||||
moment::
|
||||
|
||||
state.values["messages"] += [ToolMessage(from-A)] # A merged in
|
||||
state.interrupts = [Interrupt(value-from-B)] # B alone is pending
|
||||
|
||||
So ``len(state.interrupts) < num_dispatched_tasks``. The slicer and
|
||||
``build_lg_resume_map`` must:
|
||||
|
||||
1. **Key off ``state.interrupts``, never off the originally dispatched tcids.**
|
||||
A flat decisions list of length 1 must route only to B; if anything tries
|
||||
to look up A in the resume map, langgraph rejects an unknown
|
||||
``Interrupt.id``.
|
||||
2. **Leave A's contributions intact across resume.** A's ToolMessage was
|
||||
committed at the pause; resuming the paused branch must not re-run A nor
|
||||
drop its message.
|
||||
3. **Drain the single pending interrupt.** Final ``state.interrupts`` is
|
||||
empty regardless of whether sibling branches were paused.
|
||||
|
||||
The langgraph semantics this test relies on were verified empirically in the
|
||||
exploratory probe before this test was authored.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import json
|
||||
from typing import Annotated
|
||||
|
||||
import pytest
|
||||
from langchain.tools import ToolRuntime
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.graph.message import add_messages
|
||||
from langgraph.types import Command, Send, interrupt
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import (
|
||||
build_lg_resume_map,
|
||||
collect_pending_tool_calls,
|
||||
slice_decisions_by_tool_call,
|
||||
)
|
||||
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import (
|
||||
build_task_tool_with_parent_config,
|
||||
)
|
||||
|
||||
|
||||
class _SubState(TypedDict, total=False):
|
||||
messages: Annotated[list, add_messages]
|
||||
|
||||
|
||||
class _DispatchState(TypedDict, total=False):
|
||||
messages: Annotated[list, add_messages]
|
||||
tcid: str
|
||||
desc: str
|
||||
subtype: str
|
||||
|
||||
|
||||
_QUICK_MARKER = "quick-subagent-finished-without-pausing"
|
||||
|
||||
|
||||
def _build_quick_subagent(checkpointer: InMemorySaver):
|
||||
"""Subagent that completes synchronously without firing any interrupt."""
|
||||
|
||||
def quick_node(_state):
|
||||
return {"messages": [AIMessage(content=_QUICK_MARKER)]}
|
||||
|
||||
g = StateGraph(_SubState)
|
||||
g.add_node("quick", quick_node)
|
||||
g.add_edge(START, "quick")
|
||||
g.add_edge("quick", END)
|
||||
return g.compile(checkpointer=checkpointer)
|
||||
|
||||
|
||||
def _build_pausing_subagent(checkpointer: InMemorySaver):
|
||||
"""Subagent that pauses with a single-action HITL bundle and records its resume payload."""
|
||||
|
||||
def hitl_node(_state):
|
||||
decision = interrupt(
|
||||
{
|
||||
"action_requests": [
|
||||
{"name": "act_0", "args": {"i": 0}, "description": ""}
|
||||
],
|
||||
"review_configs": [
|
||||
{
|
||||
"action_name": "act_0",
|
||||
"allowed_decisions": ["approve", "reject", "edit"],
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
return {"messages": [AIMessage(content=json.dumps(decision, sort_keys=True))]}
|
||||
|
||||
g = StateGraph(_SubState)
|
||||
g.add_node("hitl", hitl_node)
|
||||
g.add_edge(START, "hitl")
|
||||
g.add_edge("hitl", END)
|
||||
return g.compile(checkpointer=checkpointer)
|
||||
|
||||
|
||||
def _parent_with_two_branches(task_tool, *, dispatches, checkpointer):
|
||||
def fanout(_state) -> list[Send]:
|
||||
return [Send("call_task", d) for d in dispatches]
|
||||
|
||||
async def call_task(state: _DispatchState, config: RunnableConfig):
|
||||
rt = ToolRuntime(
|
||||
state=state,
|
||||
config=config,
|
||||
context=None,
|
||||
stream_writer=None,
|
||||
tool_call_id=state["tcid"],
|
||||
store=None,
|
||||
)
|
||||
return await task_tool.coroutine(
|
||||
description=state["desc"], subagent_type=state["subtype"], runtime=rt
|
||||
)
|
||||
|
||||
g = StateGraph(_DispatchState)
|
||||
g.add_node("call_task", call_task)
|
||||
g.add_conditional_edges(START, fanout, ["call_task"])
|
||||
g.add_edge("call_task", END)
|
||||
return g.compile(checkpointer=checkpointer)
|
||||
|
||||
|
||||
def _quick_marker_count(state) -> int:
|
||||
"""How many messages anywhere in parent state contain the quick subagent's marker."""
|
||||
n = 0
|
||||
for msg in state.values.get("messages", []) or []:
|
||||
content = getattr(msg, "content", "")
|
||||
if isinstance(content, str) and _QUICK_MARKER in content:
|
||||
n += 1
|
||||
return n
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_partial_pause_routes_only_to_paused_branch_without_rerunning_completed_one():
|
||||
"""One branch completes synchronously; the other pauses with HITL — resume routes only to B.
|
||||
|
||||
Setup:
|
||||
- Sub-A (``quick``): no interrupt, finishes immediately, writes a marker
|
||||
message to parent state.
|
||||
- Sub-B (``pausing``): interrupts with a 1-action HITL bundle.
|
||||
|
||||
At pause, parent state has A's marker already merged in and exactly one
|
||||
pending interrupt (B's). Resume sends a 1-element flat decisions list;
|
||||
the routing helpers must not look up A in the resume map (would explode
|
||||
with an unknown ``Interrupt.id``) and must not re-invoke A on resume
|
||||
(would duplicate the marker).
|
||||
"""
|
||||
checkpointer = InMemorySaver()
|
||||
|
||||
quick_sub = _build_quick_subagent(checkpointer)
|
||||
pausing_sub = _build_pausing_subagent(checkpointer)
|
||||
|
||||
task_tool = build_task_tool_with_parent_config(
|
||||
[
|
||||
{"name": "quick-agent", "description": "instant", "runnable": quick_sub},
|
||||
{
|
||||
"name": "pausing-agent",
|
||||
"description": "needs review",
|
||||
"runnable": pausing_sub,
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
parent = _parent_with_two_branches(
|
||||
task_tool,
|
||||
dispatches=[
|
||||
{"tcid": "tcid-A", "subtype": "quick-agent", "desc": "do A fast"},
|
||||
{
|
||||
"tcid": "tcid-B",
|
||||
"subtype": "pausing-agent",
|
||||
"desc": "needs approval",
|
||||
},
|
||||
],
|
||||
checkpointer=checkpointer,
|
||||
)
|
||||
|
||||
config: dict = {
|
||||
"configurable": {"thread_id": "partial-pause-thread"},
|
||||
"recursion_limit": 100,
|
||||
}
|
||||
await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||
|
||||
paused = await parent.aget_state(config)
|
||||
|
||||
assert len(paused.interrupts) == 1, (
|
||||
f"REGRESSION: expected exactly 1 pending interrupt (sub-B alone), "
|
||||
f"got {len(paused.interrupts)}"
|
||||
)
|
||||
|
||||
pending = collect_pending_tool_calls(paused)
|
||||
assert pending == [("tcid-B", 1)], (
|
||||
f"REGRESSION: pending list contains stale tcids; got {pending!r}"
|
||||
)
|
||||
|
||||
pre_resume_marker_count = _quick_marker_count(paused)
|
||||
assert pre_resume_marker_count == 1, (
|
||||
f"REGRESSION: sub-A's contribution missing or duplicated at pause "
|
||||
f"(found {pre_resume_marker_count}, expected 1)"
|
||||
)
|
||||
|
||||
flat_decisions = [{"type": "approve"}]
|
||||
by_tool_call_id = slice_decisions_by_tool_call(flat_decisions, pending)
|
||||
assert by_tool_call_id == {"tcid-B": {"decisions": [{"type": "approve"}]}}, (
|
||||
f"REGRESSION: slicer routed to a non-pending tcid: {by_tool_call_id!r}"
|
||||
)
|
||||
|
||||
config["configurable"]["surfsense_resume_value"] = by_tool_call_id
|
||||
lg_resume_map = build_lg_resume_map(paused, by_tool_call_id)
|
||||
|
||||
assert set(lg_resume_map.keys()) == {paused.interrupts[0].id}, (
|
||||
f"REGRESSION: resume map keyed by an unknown Interrupt.id "
|
||||
f"(would crash langgraph): {lg_resume_map!r}"
|
||||
)
|
||||
|
||||
await parent.ainvoke(Command(resume=lg_resume_map), config)
|
||||
|
||||
final = await parent.aget_state(config)
|
||||
assert not final.interrupts, (
|
||||
f"REGRESSION: pending interrupts after resume: {final.interrupts!r}"
|
||||
)
|
||||
|
||||
post_resume_marker_count = _quick_marker_count(final)
|
||||
assert post_resume_marker_count == 1, (
|
||||
f"REGRESSION: sub-A re-ran on resume (marker count went "
|
||||
f"{pre_resume_marker_count} → {post_resume_marker_count}); "
|
||||
f"resume must touch only the paused branch."
|
||||
)
|
||||
|
||||
payloads: list[dict] = []
|
||||
for msg in final.values.get("messages", []) or []:
|
||||
content = getattr(msg, "content", None)
|
||||
if isinstance(content, str):
|
||||
with contextlib.suppress(json.JSONDecodeError):
|
||||
payloads.append(json.loads(content))
|
||||
|
||||
assert {"decisions": [{"type": "approve"}]} in payloads, (
|
||||
f"REGRESSION: sub-B did not receive its single approve on resume; "
|
||||
f"payloads seen: {payloads!r}"
|
||||
)
|
||||
|
|
@ -0,0 +1,215 @@
|
|||
"""Real-graph contract: all-reject decisions route correctly across parallel subagents.
|
||||
|
||||
Heterogeneous routing is covered by ``test_parallel_heterogeneous_decisions``.
|
||||
This module pins the narrower edge case where **every** card on **every**
|
||||
paused subagent is rejected.
|
||||
|
||||
Why a separate pin:
|
||||
|
||||
1. **No approval-bias in the slicer.** A future "if no approvals, short-circuit
|
||||
resume" optimization would be tempting (skips a langgraph round-trip) and
|
||||
would also silently break this scenario. Pin it.
|
||||
2. **``message`` metadata pass-through across a run of rejects.** The reject
|
||||
``message`` is the user-visible reason ("looks suspicious", "duplicate",
|
||||
etc.). Losing it would silently swallow user intent — the worst UX
|
||||
failure mode for HITL. Heterogeneous covers one reject; here we verify a
|
||||
sequence of rejects survives the slicer + bridge with distinct messages
|
||||
intact and in order.
|
||||
3. **All branches complete with no leftover pending.** Even when nothing was
|
||||
approved, the parent must drain every paused subagent so the SSE stream
|
||||
can close cleanly. A bug that left one ``Interrupt.id`` un-keyed would
|
||||
strand the conversation in "pending" forever.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import json
|
||||
from typing import Annotated
|
||||
|
||||
import pytest
|
||||
from langchain.tools import ToolRuntime
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.graph.message import add_messages
|
||||
from langgraph.types import Command, Send, interrupt
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import (
|
||||
build_lg_resume_map,
|
||||
collect_pending_tool_calls,
|
||||
slice_decisions_by_tool_call,
|
||||
)
|
||||
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import (
|
||||
build_task_tool_with_parent_config,
|
||||
)
|
||||
|
||||
|
||||
class _SubState(TypedDict, total=False):
|
||||
messages: list
|
||||
|
||||
|
||||
class _DispatchState(TypedDict, total=False):
|
||||
messages: Annotated[list, add_messages]
|
||||
tcid: str
|
||||
desc: str
|
||||
subtype: str
|
||||
|
||||
|
||||
def _build_recording_subagent(checkpointer: InMemorySaver, *, action_count: int):
|
||||
"""Subagent that pauses with ``action_count`` actions and records its resume payload.
|
||||
|
||||
The recorded ``AIMessage`` content is the JSON-serialized payload, so the
|
||||
test can match each subagent's slice by content.
|
||||
"""
|
||||
|
||||
def hitl_node(_state):
|
||||
decision_payload = interrupt(
|
||||
{
|
||||
"action_requests": [
|
||||
{"name": f"act_{i}", "args": {"i": i}, "description": ""}
|
||||
for i in range(action_count)
|
||||
],
|
||||
"review_configs": [
|
||||
{
|
||||
"action_name": f"act_{i}",
|
||||
"allowed_decisions": ["approve", "reject", "edit"],
|
||||
}
|
||||
for i in range(action_count)
|
||||
],
|
||||
}
|
||||
)
|
||||
return {
|
||||
"messages": [
|
||||
AIMessage(content=json.dumps(decision_payload, sort_keys=True))
|
||||
]
|
||||
}
|
||||
|
||||
g = StateGraph(_SubState)
|
||||
g.add_node("hitl", hitl_node)
|
||||
g.add_edge(START, "hitl")
|
||||
g.add_edge("hitl", END)
|
||||
return g.compile(checkpointer=checkpointer)
|
||||
|
||||
|
||||
def _parent_two_branches(task_tool, *, dispatches, checkpointer):
|
||||
def fanout(_state) -> list[Send]:
|
||||
return [Send("call_task", d) for d in dispatches]
|
||||
|
||||
async def call_task(state: _DispatchState, config: RunnableConfig):
|
||||
rt = ToolRuntime(
|
||||
state=state,
|
||||
config=config,
|
||||
context=None,
|
||||
stream_writer=None,
|
||||
tool_call_id=state["tcid"],
|
||||
store=None,
|
||||
)
|
||||
return await task_tool.coroutine(
|
||||
description=state["desc"], subagent_type=state["subtype"], runtime=rt
|
||||
)
|
||||
|
||||
g = StateGraph(_DispatchState)
|
||||
g.add_node("call_task", call_task)
|
||||
g.add_conditional_edges(START, fanout, ["call_task"])
|
||||
g.add_edge("call_task", END)
|
||||
return g.compile(checkpointer=checkpointer)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_reject_decisions_route_to_each_subagent_with_messages_intact():
|
||||
"""All cards rejected across two parallel subagents — order and messages preserved.
|
||||
|
||||
Setup mirrors a real "user reviews two parallel ticket creations and
|
||||
rejects everything with distinct reasons":
|
||||
|
||||
- Sub-A pauses with 2 actions.
|
||||
- Sub-B pauses with 1 action.
|
||||
- Flat decisions: 3 rejects, each with a unique ``message``.
|
||||
|
||||
Asserts each subagent receives only its slice, in original order,
|
||||
with every ``message`` intact and no ``edited_action`` fields fabricated.
|
||||
"""
|
||||
checkpointer = InMemorySaver()
|
||||
|
||||
sub_a = _build_recording_subagent(checkpointer, action_count=2)
|
||||
sub_b = _build_recording_subagent(checkpointer, action_count=1)
|
||||
|
||||
task_tool = build_task_tool_with_parent_config(
|
||||
[
|
||||
{"name": "agent-a", "description": "first", "runnable": sub_a},
|
||||
{"name": "agent-b", "description": "second", "runnable": sub_b},
|
||||
]
|
||||
)
|
||||
|
||||
parent = _parent_two_branches(
|
||||
task_tool,
|
||||
dispatches=[
|
||||
{"tcid": "tcid-A", "subtype": "agent-a", "desc": "do A"},
|
||||
{"tcid": "tcid-B", "subtype": "agent-b", "desc": "do B"},
|
||||
],
|
||||
checkpointer=checkpointer,
|
||||
)
|
||||
|
||||
config: dict = {
|
||||
"configurable": {"thread_id": "all-reject-thread"},
|
||||
"recursion_limit": 100,
|
||||
}
|
||||
await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||
|
||||
paused_state = await parent.aget_state(config)
|
||||
assert len(paused_state.interrupts) == 2, (
|
||||
f"fixture broken: expected 2 paused subagents, got {len(paused_state.interrupts)}"
|
||||
)
|
||||
|
||||
a_reject_0 = {"type": "reject", "message": "A[0] looks suspicious"}
|
||||
a_reject_1 = {"type": "reject", "message": "A[1] duplicates A[0]"}
|
||||
b_reject_0 = {"type": "reject", "message": "B[0] needs more context"}
|
||||
flat_decisions = [a_reject_0, a_reject_1, b_reject_0]
|
||||
|
||||
pending = collect_pending_tool_calls(paused_state)
|
||||
by_tool_call_id = slice_decisions_by_tool_call(flat_decisions, pending)
|
||||
|
||||
assert by_tool_call_id == {
|
||||
"tcid-A": {"decisions": [a_reject_0, a_reject_1]},
|
||||
"tcid-B": {"decisions": [b_reject_0]},
|
||||
}, f"REGRESSION: slicer mis-routed all-reject decisions: {by_tool_call_id!r}"
|
||||
|
||||
config["configurable"]["surfsense_resume_value"] = by_tool_call_id
|
||||
lg_resume_map = build_lg_resume_map(paused_state, by_tool_call_id)
|
||||
|
||||
await parent.ainvoke(Command(resume=lg_resume_map), config)
|
||||
|
||||
final_state = await parent.aget_state(config)
|
||||
assert not final_state.interrupts, (
|
||||
f"REGRESSION: leftover pending interrupts after all-reject resume: "
|
||||
f"{final_state.interrupts!r}"
|
||||
)
|
||||
|
||||
payloads: list[dict] = []
|
||||
for msg in final_state.values.get("messages", []) or []:
|
||||
content = getattr(msg, "content", None)
|
||||
if isinstance(content, str):
|
||||
with contextlib.suppress(json.JSONDecodeError):
|
||||
payloads.append(json.loads(content))
|
||||
|
||||
expected_a = {"decisions": [a_reject_0, a_reject_1]}
|
||||
expected_b = {"decisions": [b_reject_0]}
|
||||
|
||||
assert expected_a in payloads, (
|
||||
f"REGRESSION: sub-A did not receive its 2-reject slice in order; "
|
||||
f"payloads seen: {payloads!r}"
|
||||
)
|
||||
assert expected_b in payloads, (
|
||||
f"REGRESSION: sub-B did not receive its single reject; "
|
||||
f"payloads seen: {payloads!r}"
|
||||
)
|
||||
|
||||
for p in payloads:
|
||||
for d in p.get("decisions", []):
|
||||
assert "edited_action" not in d, (
|
||||
f"REGRESSION: spurious ``edited_action`` on a reject — "
|
||||
f"slicer/bridge mutated payload: {d!r}"
|
||||
)
|
||||
|
|
@ -0,0 +1,300 @@
|
|||
"""Real-graph contract: parallel resume must key ``Command(resume=...)`` by ``Interrupt.id``.
|
||||
|
||||
When the parent state has multiple pending interrupts, langgraph rejects a
|
||||
scalar ``Command(resume=v)`` with::
|
||||
|
||||
RuntimeError: When there are multiple pending interrupts, you must specify
|
||||
the interrupt id when resuming.
|
||||
|
||||
The fix is to map each ``Interrupt.id`` from ``state.interrupts`` to the
|
||||
per-subagent slice — orthogonal to our ``tool_call_id``-keyed
|
||||
``surfsense_resume_value`` side-channel (different consumer: langgraph's
|
||||
pregel vs. our subagent bridge).
|
||||
|
||||
This test reproduces the production failure with a real two-task parallel
|
||||
``Send`` parent graph, exercises the full resume cycle, and asserts both
|
||||
subagents complete cleanly with their per-subagent slice intact.
|
||||
|
||||
Bundle sizes are chosen heterogeneous (``2`` and ``3``) so the assertions
|
||||
also catch slicer arithmetic regressions (e.g., ``cursor += 1`` instead of
|
||||
``cursor += action_count``). A symmetric ``(1, 1)`` configuration would
|
||||
silently pass such a bug because the slices would coincide.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import json
|
||||
from typing import Annotated
|
||||
|
||||
import pytest
|
||||
from langchain.tools import ToolRuntime
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.graph.message import add_messages
|
||||
from langgraph.types import Command, Send, interrupt
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import (
|
||||
build_lg_resume_map,
|
||||
collect_pending_tool_calls,
|
||||
slice_decisions_by_tool_call,
|
||||
)
|
||||
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import (
|
||||
build_task_tool_with_parent_config,
|
||||
)
|
||||
|
||||
|
||||
class _SubState(TypedDict, total=False):
|
||||
messages: list
|
||||
|
||||
|
||||
class _DispatchState(TypedDict, total=False):
|
||||
# ``add_messages`` reducer matches production agent state shape and is
|
||||
# required when two parallel ``Send`` branches both write to ``messages``
|
||||
# in the same superstep (post-resume both subagents return their own
|
||||
# ``{"messages": [...]}``). Without a reducer langgraph raises
|
||||
# ``InvalidUpdateError: At key 'messages': Can receive only one value``.
|
||||
messages: Annotated[list, add_messages]
|
||||
tcid: str
|
||||
desc: str
|
||||
subtype: str
|
||||
|
||||
|
||||
def _build_pausing_subagent(checkpointer: InMemorySaver, *, action_count: int):
|
||||
"""Subagent that pauses with an ``action_count``-action HITL bundle.
|
||||
|
||||
On resume it captures the decision payload as a JSON-serialized
|
||||
``AIMessage`` content so the test can inspect exactly which slice
|
||||
reached this subagent — the strongest assertion against slicer
|
||||
routing regressions.
|
||||
"""
|
||||
|
||||
def approve_node(_state):
|
||||
decision = interrupt(
|
||||
{
|
||||
"action_requests": [
|
||||
{"name": f"act_{i}", "args": {"i": i}, "description": ""}
|
||||
for i in range(action_count)
|
||||
],
|
||||
"review_configs": [{} for _ in range(action_count)],
|
||||
}
|
||||
)
|
||||
return {"messages": [AIMessage(content=json.dumps(decision, sort_keys=True))]}
|
||||
|
||||
g = StateGraph(_SubState)
|
||||
g.add_node("approve", approve_node)
|
||||
g.add_edge(START, "approve")
|
||||
g.add_edge("approve", END)
|
||||
return g.compile(checkpointer=checkpointer)
|
||||
|
||||
|
||||
def _parent_graph_dispatching_two_tasks_via_send(
|
||||
task_tool, *, tool_call_id_a: str, tool_call_id_b: str, checkpointer
|
||||
):
|
||||
def fanout_edge(_state) -> list[Send]:
|
||||
return [
|
||||
Send(
|
||||
"call_task",
|
||||
{"tcid": tool_call_id_a, "desc": "approve A", "subtype": "agent-a"},
|
||||
),
|
||||
Send(
|
||||
"call_task",
|
||||
{"tcid": tool_call_id_b, "desc": "approve B", "subtype": "agent-b"},
|
||||
),
|
||||
]
|
||||
|
||||
async def call_task(state: _DispatchState, config: RunnableConfig):
|
||||
rt = ToolRuntime(
|
||||
state=state,
|
||||
config=config,
|
||||
context=None,
|
||||
stream_writer=None,
|
||||
tool_call_id=state["tcid"],
|
||||
store=None,
|
||||
)
|
||||
return await task_tool.coroutine(
|
||||
description=state["desc"], subagent_type=state["subtype"], runtime=rt
|
||||
)
|
||||
|
||||
g = StateGraph(_DispatchState)
|
||||
g.add_node("call_task", call_task)
|
||||
g.add_conditional_edges(START, fanout_edge, ["call_task"])
|
||||
g.add_edge("call_task", END)
|
||||
return g.compile(checkpointer=checkpointer)
|
||||
|
||||
|
||||
def _build_two_subagents_task_tool(checkpointer: InMemorySaver):
|
||||
"""Register two subagents under distinct names with heterogeneous bundle sizes.
|
||||
|
||||
Sub-A: 2-action bundle. Sub-B: 3-action bundle. Both ``> 1`` so the slice
|
||||
arithmetic is sensitive to off-by-one mistakes.
|
||||
"""
|
||||
sub_a = _build_pausing_subagent(checkpointer, action_count=2)
|
||||
sub_b = _build_pausing_subagent(checkpointer, action_count=3)
|
||||
return build_task_tool_with_parent_config(
|
||||
[
|
||||
{"name": "agent-a", "description": "first", "runnable": sub_a},
|
||||
{"name": "agent-b", "description": "second", "runnable": sub_b},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parallel_resume_with_command_resume_scalar_raises_lg_runtime_error():
|
||||
"""Confirm the production failure mode: scalar resume on multi-pending state explodes.
|
||||
|
||||
This is a contract pin: if langgraph relaxes the requirement in a future
|
||||
release, this test starts passing and we know we can simplify
|
||||
``stream_resume_chat``. Until then, the keyed form is mandatory.
|
||||
"""
|
||||
checkpointer = InMemorySaver()
|
||||
task_tool = _build_two_subagents_task_tool(checkpointer)
|
||||
parent = _parent_graph_dispatching_two_tasks_via_send(
|
||||
task_tool,
|
||||
tool_call_id_a="parent-tcid-A",
|
||||
tool_call_id_b="parent-tcid-B",
|
||||
checkpointer=checkpointer,
|
||||
)
|
||||
config: dict = {
|
||||
"configurable": {"thread_id": "parallel-resume-scalar"},
|
||||
"recursion_limit": 100,
|
||||
}
|
||||
await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||
|
||||
with pytest.raises(RuntimeError, match="multiple pending interrupts"):
|
||||
await parent.ainvoke(Command(resume={"decisions": ["A"]}), config)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parallel_resume_with_per_interrupt_id_keying_completes_both_subagents():
|
||||
"""Production-shape resume: builds the langgraph-keyed map and resumes both subagents.
|
||||
|
||||
Mirrors what ``stream_resume_chat`` does: collects pending interrupts,
|
||||
slices the flat decisions list by ``tool_call_id``, builds the
|
||||
``Interrupt.id``-keyed map for ``Command(resume=...)``, and resumes.
|
||||
|
||||
Post-conditions checked:
|
||||
1. The langgraph-keyed map has exactly one entry per pending interrupt
|
||||
id (``str`` keys, count matches).
|
||||
2. Both subagents complete with no leftover pending interrupts.
|
||||
3. **Each subagent receives its exact slice in the original order** —
|
||||
this catches slicer arithmetic regressions (e.g., ``cursor += 1``)
|
||||
that wouldn't surface by checking only "no leftover pending".
|
||||
"""
|
||||
checkpointer = InMemorySaver()
|
||||
task_tool = _build_two_subagents_task_tool(checkpointer)
|
||||
tcid_a = "parent-tcid-A"
|
||||
tcid_b = "parent-tcid-B"
|
||||
parent = _parent_graph_dispatching_two_tasks_via_send(
|
||||
task_tool,
|
||||
tool_call_id_a=tcid_a,
|
||||
tool_call_id_b=tcid_b,
|
||||
checkpointer=checkpointer,
|
||||
)
|
||||
config: dict = {
|
||||
"configurable": {"thread_id": "parallel-resume-keyed"},
|
||||
"recursion_limit": 100,
|
||||
}
|
||||
await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||
|
||||
paused_state = await parent.aget_state(config)
|
||||
assert len(paused_state.interrupts) == 2, (
|
||||
"fixture broken: expected 2 paused subagents"
|
||||
)
|
||||
|
||||
pending = collect_pending_tool_calls(paused_state)
|
||||
assert dict(pending) == {tcid_a: 2, tcid_b: 3}, (
|
||||
f"fixture broken: heterogeneous bundle sizes not detected; got {pending!r}"
|
||||
)
|
||||
|
||||
a_d0 = {"type": "approve"}
|
||||
a_d1 = {"type": "reject", "message": "A[1] is redundant"}
|
||||
b_d0 = {
|
||||
"type": "edit",
|
||||
"edited_action": {"name": "act_0", "args": {"i": 0, "edited": True}},
|
||||
}
|
||||
b_d1 = {"type": "approve"}
|
||||
b_d2 = {"type": "reject", "message": "B[2] needs more context"}
|
||||
flat_decisions = [a_d0, a_d1, b_d0, b_d1, b_d2]
|
||||
|
||||
by_tool_call_id = slice_decisions_by_tool_call(flat_decisions, pending)
|
||||
lg_resume_map = build_lg_resume_map(paused_state, by_tool_call_id)
|
||||
|
||||
assert len(lg_resume_map) == 2, (
|
||||
f"expected one entry per pending interrupt id, got {lg_resume_map!r}"
|
||||
)
|
||||
assert all(isinstance(k, str) for k in lg_resume_map), (
|
||||
f"keys must be Interrupt.id strings, got {[type(k).__name__ for k in lg_resume_map]}"
|
||||
)
|
||||
|
||||
config["configurable"]["surfsense_resume_value"] = by_tool_call_id
|
||||
|
||||
await parent.ainvoke(Command(resume=lg_resume_map), config)
|
||||
|
||||
final_state = await parent.aget_state(config)
|
||||
assert not final_state.interrupts, (
|
||||
f"expected no leftover pending interrupts after resume, got "
|
||||
f"{final_state.interrupts!r}"
|
||||
)
|
||||
|
||||
payloads: list[dict] = []
|
||||
for msg in final_state.values.get("messages", []) or []:
|
||||
content = getattr(msg, "content", None)
|
||||
if isinstance(content, str):
|
||||
with contextlib.suppress(json.JSONDecodeError):
|
||||
payloads.append(json.loads(content))
|
||||
|
||||
expected_a = {"decisions": [a_d0, a_d1]}
|
||||
expected_b = {"decisions": [b_d0, b_d1, b_d2]}
|
||||
assert expected_a in payloads, (
|
||||
f"REGRESSION: sub-A did not receive its 2-decision slice in order; "
|
||||
f"payloads seen: {payloads!r}"
|
||||
)
|
||||
assert expected_b in payloads, (
|
||||
f"REGRESSION: sub-B did not receive its 3-decision slice in order; "
|
||||
f"payloads seen: {payloads!r}"
|
||||
)
|
||||
|
||||
|
||||
def test_build_lg_resume_map_returns_empty_when_no_interrupts_carry_stamps():
|
||||
"""Unstamped interrupts can't be routed; we don't fabricate keys for them.
|
||||
|
||||
If a regression lets an unstamped interrupt reach the parent state, the
|
||||
empty map propagates to the call site and surfaces as a clear count
|
||||
mismatch instead of a silent mis-route.
|
||||
"""
|
||||
from types import SimpleNamespace
|
||||
|
||||
fake_interrupt = SimpleNamespace(id="i-foreign", value={"action_requests": [{}]})
|
||||
state = SimpleNamespace(interrupts=(fake_interrupt,))
|
||||
|
||||
assert build_lg_resume_map(state, {"some-tcid": {"decisions": ["x"]}}) == {}
|
||||
|
||||
|
||||
def test_build_lg_resume_map_skips_interrupts_without_corresponding_slice():
|
||||
"""Skip rather than silently mis-route when the slice and interrupts disagree.
|
||||
|
||||
Only emit a resume entry when both an interrupt id and a tool_call_id
|
||||
slice are present; a mismatch indicates upstream contract drift and
|
||||
should not be papered over.
|
||||
"""
|
||||
from types import SimpleNamespace
|
||||
|
||||
state = SimpleNamespace(
|
||||
interrupts=(
|
||||
SimpleNamespace(
|
||||
id="i-A",
|
||||
value={"action_requests": [{}], "tool_call_id": "tcid-A"},
|
||||
),
|
||||
SimpleNamespace(
|
||||
id="i-B",
|
||||
value={"action_requests": [{}], "tool_call_id": "tcid-B"},
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
out = build_lg_resume_map(state, {"tcid-A": {"decisions": ["only-A"]}})
|
||||
assert out == {"i-A": {"decisions": ["only-A"]}}
|
||||
|
|
@ -0,0 +1,275 @@
|
|||
"""Real-graph parallel HITL across both approval kinds — the keystone regression.
|
||||
|
||||
Pre-fix bug: the parallel-HITL routing layer (``collect_pending_tool_calls``
|
||||
+ ``slice_decisions_by_tool_call`` + ``build_lg_resume_map``) only
|
||||
recognized middleware-gated approvals (LC HITL shape from
|
||||
``HumanInTheLoopMiddleware``). Self-gated approvals from
|
||||
``request_approval`` and middleware-gated permission asks from
|
||||
``PermissionMiddleware`` both used the SurfSense-specific
|
||||
``{type, action, context}`` shape, so when the orchestrator dispatched
|
||||
two parallel ``task`` calls — one self-gated, one middleware-gated — only
|
||||
one interrupt was visible to the routing layer and resume crashed with
|
||||
``Decision count mismatch``.
|
||||
|
||||
This test fans out two real subagents via ``Send``: one calls
|
||||
``request_approval`` (self-gated), the other calls
|
||||
``request_permission_decision`` (middleware-gated). Both pause; the routing
|
||||
layer must see TWO LC HITL interrupts, slice the decisions by
|
||||
``tool_call_id``, key by ``Interrupt.id``, and resume both branches with
|
||||
their per-slice payload.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import json
|
||||
from typing import Annotated
|
||||
|
||||
import pytest
|
||||
from langchain.tools import ToolRuntime
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.graph.message import add_messages
|
||||
from langgraph.types import Command, Send
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import (
|
||||
build_lg_resume_map,
|
||||
collect_pending_tool_calls,
|
||||
slice_decisions_by_tool_call,
|
||||
)
|
||||
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import (
|
||||
build_task_tool_with_parent_config,
|
||||
)
|
||||
from app.agents.multi_agent_chat.middleware.shared.permissions.ask.request import (
|
||||
request_permission_decision,
|
||||
)
|
||||
from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import (
|
||||
request_approval,
|
||||
)
|
||||
from app.agents.new_chat.permissions import Rule
|
||||
|
||||
|
||||
class _SubState(TypedDict, total=False):
|
||||
messages: list
|
||||
|
||||
|
||||
class _DispatchState(TypedDict, total=False):
|
||||
# ``add_messages`` is mandatory: parallel ``Send`` branches both append
|
||||
# to ``messages`` in the same superstep; without a reducer langgraph
|
||||
# raises ``InvalidUpdateError``.
|
||||
messages: Annotated[list, add_messages]
|
||||
tcid: str
|
||||
desc: str
|
||||
subtype: str
|
||||
|
||||
|
||||
def _build_self_gated_subagent(checkpointer: InMemorySaver):
|
||||
"""Subagent that pauses via ``request_approval`` (self-gated path)."""
|
||||
|
||||
def gate_node(_state):
|
||||
result = request_approval(
|
||||
action_type="gmail_email_send",
|
||||
tool_name="send_gmail_email",
|
||||
params={"to": "alice@example.com"},
|
||||
)
|
||||
return {
|
||||
"messages": [
|
||||
AIMessage(
|
||||
content=json.dumps(
|
||||
{
|
||||
"kind": "self_gated",
|
||||
"decision_type": result.decision_type,
|
||||
"params": result.params,
|
||||
"rejected": result.rejected,
|
||||
},
|
||||
sort_keys=True,
|
||||
)
|
||||
)
|
||||
]
|
||||
}
|
||||
|
||||
g = StateGraph(_SubState)
|
||||
g.add_node("gate", gate_node)
|
||||
g.add_edge(START, "gate")
|
||||
g.add_edge("gate", END)
|
||||
return g.compile(checkpointer=checkpointer)
|
||||
|
||||
|
||||
def _build_middleware_gated_subagent(checkpointer: InMemorySaver):
|
||||
"""Subagent that pauses via ``request_permission_decision`` (middleware-gated path)."""
|
||||
|
||||
def perm_node(_state):
|
||||
decision = request_permission_decision(
|
||||
tool_name="rm",
|
||||
args={"path": "/tmp/file"},
|
||||
patterns=["rm/*"],
|
||||
rules=[Rule(permission="rm", pattern="*", action="ask")],
|
||||
emit_interrupt=True,
|
||||
)
|
||||
return {
|
||||
"messages": [
|
||||
AIMessage(
|
||||
content=json.dumps(
|
||||
{"kind": "middleware_gated", "decision": decision},
|
||||
sort_keys=True,
|
||||
)
|
||||
)
|
||||
]
|
||||
}
|
||||
|
||||
g = StateGraph(_SubState)
|
||||
g.add_node("perm", perm_node)
|
||||
g.add_edge(START, "perm")
|
||||
g.add_edge("perm", END)
|
||||
return g.compile(checkpointer=checkpointer)
|
||||
|
||||
|
||||
def _build_mixed_task_tool(checkpointer: InMemorySaver):
|
||||
"""Two subagents, one per approval kind, registered under distinct names."""
|
||||
return build_task_tool_with_parent_config(
|
||||
[
|
||||
{
|
||||
"name": "self-gated-agent",
|
||||
"description": "uses request_approval",
|
||||
"runnable": _build_self_gated_subagent(checkpointer),
|
||||
},
|
||||
{
|
||||
"name": "middleware-gated-agent",
|
||||
"description": "uses request_permission_decision",
|
||||
"runnable": _build_middleware_gated_subagent(checkpointer),
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _parent_dispatching_one_of_each(
|
||||
task_tool, *, tcid_self: str, tcid_mw: str, checkpointer
|
||||
):
|
||||
def fanout_edge(_state) -> list[Send]:
|
||||
return [
|
||||
Send(
|
||||
"call_task",
|
||||
{
|
||||
"tcid": tcid_self,
|
||||
"desc": "approve email",
|
||||
"subtype": "self-gated-agent",
|
||||
},
|
||||
),
|
||||
Send(
|
||||
"call_task",
|
||||
{
|
||||
"tcid": tcid_mw,
|
||||
"desc": "approve rm",
|
||||
"subtype": "middleware-gated-agent",
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
async def call_task(state: _DispatchState, config: RunnableConfig):
|
||||
rt = ToolRuntime(
|
||||
state=state,
|
||||
config=config,
|
||||
context=None,
|
||||
stream_writer=None,
|
||||
tool_call_id=state["tcid"],
|
||||
store=None,
|
||||
)
|
||||
return await task_tool.coroutine(
|
||||
description=state["desc"], subagent_type=state["subtype"], runtime=rt
|
||||
)
|
||||
|
||||
g = StateGraph(_DispatchState)
|
||||
g.add_node("call_task", call_task)
|
||||
g.add_conditional_edges(START, fanout_edge, ["call_task"])
|
||||
g.add_edge("call_task", END)
|
||||
return g.compile(checkpointer=checkpointer)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parallel_self_gated_and_middleware_gated_route_and_resume_cleanly():
|
||||
"""Both interrupt kinds must reach the routing layer in LC HITL shape and resume independently."""
|
||||
checkpointer = InMemorySaver()
|
||||
task_tool = _build_mixed_task_tool(checkpointer)
|
||||
tcid_self = "tcid-self-gated"
|
||||
tcid_mw = "tcid-middleware-gated"
|
||||
parent = _parent_dispatching_one_of_each(
|
||||
task_tool,
|
||||
tcid_self=tcid_self,
|
||||
tcid_mw=tcid_mw,
|
||||
checkpointer=checkpointer,
|
||||
)
|
||||
config: dict = {
|
||||
"configurable": {"thread_id": "mixed-parallel"},
|
||||
"recursion_limit": 100,
|
||||
}
|
||||
await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||
|
||||
paused = await parent.aget_state(config)
|
||||
assert len(paused.interrupts) == 2, (
|
||||
"fixture broken: expected one paused interrupt per approval kind"
|
||||
)
|
||||
|
||||
# Both interrupts must speak the same wire shape — the whole point of
|
||||
# the unification. If either one regresses to the legacy SurfSense shape
|
||||
# ``collect_pending_tool_calls`` would silently skip it and the count
|
||||
# below would be 1.
|
||||
pending = collect_pending_tool_calls(paused)
|
||||
assert dict(pending) == {tcid_self: 1, tcid_mw: 1}, (
|
||||
f"REGRESSION: not all interrupt kinds reached the routing layer; "
|
||||
f"got {pending!r}"
|
||||
)
|
||||
|
||||
# Verify the actual wire payloads carry the LC HITL standard fields
|
||||
# (extra defensive assertion against partial regressions where one
|
||||
# path stamps tool_call_id but reverts the body shape).
|
||||
interrupt_types = {i.value.get("interrupt_type") for i in paused.interrupts}
|
||||
assert interrupt_types == {"gmail_email_send", "permission_ask"}
|
||||
|
||||
# Resume order: same order the SSE stream would emit (interrupts list).
|
||||
decision_self = {"type": "approve"}
|
||||
decision_mw = {"type": "approve_always"}
|
||||
flat_decisions = [
|
||||
# Match `pending` order.
|
||||
decision_self if pending[0][0] == tcid_self else decision_mw,
|
||||
decision_mw if pending[0][0] == tcid_self else decision_self,
|
||||
]
|
||||
by_tool_call_id = slice_decisions_by_tool_call(flat_decisions, pending)
|
||||
lg_resume_map = build_lg_resume_map(paused, by_tool_call_id)
|
||||
assert len(lg_resume_map) == 2
|
||||
|
||||
config["configurable"]["surfsense_resume_value"] = by_tool_call_id
|
||||
await parent.ainvoke(Command(resume=lg_resume_map), config)
|
||||
|
||||
final = await parent.aget_state(config)
|
||||
assert not final.interrupts, (
|
||||
f"expected both branches resumed, but state still has interrupts: "
|
||||
f"{final.interrupts!r}"
|
||||
)
|
||||
|
||||
# Each subagent must have received its own slice — verify by inspecting
|
||||
# the JSON-serialized result messages.
|
||||
payloads: list[dict] = []
|
||||
for msg in final.values.get("messages", []) or []:
|
||||
content = getattr(msg, "content", None)
|
||||
if isinstance(content, str):
|
||||
with contextlib.suppress(json.JSONDecodeError):
|
||||
payloads.append(json.loads(content))
|
||||
|
||||
self_payloads = [p for p in payloads if p.get("kind") == "self_gated"]
|
||||
mw_payloads = [p for p in payloads if p.get("kind") == "middleware_gated"]
|
||||
assert len(self_payloads) == 1, (
|
||||
f"self-gated subagent did not complete; payloads: {payloads!r}"
|
||||
)
|
||||
assert len(mw_payloads) == 1, (
|
||||
f"middleware-gated subagent did not complete; payloads: {payloads!r}"
|
||||
)
|
||||
|
||||
# Self-gated approve → HITLResult(decision_type="approve", rejected=False).
|
||||
assert self_payloads[0]["decision_type"] == "approve"
|
||||
assert self_payloads[0]["rejected"] is False
|
||||
|
||||
# Middleware-gated approve_always → canonical permission shape unchanged.
|
||||
assert mw_payloads[0]["decision"] == {"decision_type": "approve_always"}
|
||||
|
|
@ -0,0 +1,237 @@
|
|||
"""Behavioural guarantees for parallel ``task`` tool calls (non-HITL cases).
|
||||
|
||||
The HITL bridge tests in ``test_hitl_bridge.py`` cover the parallel-interrupt
|
||||
flow. This file covers the *normal* parallel paths (no interrupts) and the
|
||||
failure-isolation guarantee — together they pin the behaviour we promise the
|
||||
user about ``asyncio.gather`` over two ``atask`` coroutines.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
from langchain.tools import ToolRuntime
|
||||
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.types import Command
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import (
|
||||
build_task_tool_with_parent_config,
|
||||
)
|
||||
|
||||
|
||||
class _SubState(TypedDict, total=False):
|
||||
messages: list
|
||||
|
||||
|
||||
def _build_success_subagent(reply: str):
|
||||
"""A subagent that completes immediately with ``reply``, never interrupts."""
|
||||
|
||||
def node(_state):
|
||||
return {"messages": [AIMessage(content=reply)]}
|
||||
|
||||
g = StateGraph(_SubState)
|
||||
g.add_node("only", node)
|
||||
g.add_edge(START, "only")
|
||||
g.add_edge("only", END)
|
||||
return g.compile(checkpointer=InMemorySaver())
|
||||
|
||||
|
||||
def _build_failing_subagent(exc: Exception):
|
||||
"""A subagent whose only node raises ``exc`` — simulates a tool-level failure."""
|
||||
|
||||
def node(_state):
|
||||
raise exc
|
||||
|
||||
g = StateGraph(_SubState)
|
||||
g.add_node("only", node)
|
||||
g.add_edge(START, "only")
|
||||
g.add_edge("only", END)
|
||||
return g.compile(checkpointer=InMemorySaver())
|
||||
|
||||
|
||||
def _make_runtime(parent_config: dict, *, tool_call_id: str) -> ToolRuntime:
|
||||
return ToolRuntime(
|
||||
state={"messages": [HumanMessage(content="seed")]},
|
||||
context=None,
|
||||
config=parent_config,
|
||||
stream_writer=None,
|
||||
tool_call_id=tool_call_id,
|
||||
store=None,
|
||||
)
|
||||
|
||||
|
||||
def _tool_message_text(cmd: Command, *, expected_tcid: str) -> str:
|
||||
"""Return the ToolMessage content the task tool produced for ``expected_tcid``."""
|
||||
assert isinstance(cmd, Command), f"expected Command, got {type(cmd).__name__}"
|
||||
messages = cmd.update["messages"]
|
||||
assert len(messages) == 1, f"expected 1 ToolMessage, got {len(messages)}"
|
||||
msg = messages[0]
|
||||
assert isinstance(msg, ToolMessage)
|
||||
assert msg.tool_call_id == expected_tcid
|
||||
return msg.content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_two_parallel_atasks_to_different_subagents_both_succeed():
|
||||
"""Normal happy-path: two distinct subagents complete in parallel without interrupting."""
|
||||
subagent_a = _build_success_subagent("A is done")
|
||||
subagent_b = _build_success_subagent("B is done")
|
||||
task_tool = build_task_tool_with_parent_config(
|
||||
[
|
||||
{"name": "alpha", "description": "alpha agent", "runnable": subagent_a},
|
||||
{"name": "beta", "description": "beta agent", "runnable": subagent_b},
|
||||
]
|
||||
)
|
||||
|
||||
parent_config: dict = {
|
||||
"configurable": {"thread_id": "ok-thread"},
|
||||
"recursion_limit": 100,
|
||||
}
|
||||
runtime_a = _make_runtime(parent_config, tool_call_id="tcid-A")
|
||||
runtime_b = _make_runtime(parent_config, tool_call_id="tcid-B")
|
||||
|
||||
result_a, result_b = await asyncio.gather(
|
||||
task_tool.coroutine(
|
||||
description="do A",
|
||||
subagent_type="alpha",
|
||||
runtime=runtime_a,
|
||||
),
|
||||
task_tool.coroutine(
|
||||
description="do B",
|
||||
subagent_type="beta",
|
||||
runtime=runtime_b,
|
||||
),
|
||||
)
|
||||
|
||||
assert _tool_message_text(result_a, expected_tcid="tcid-A") == "A is done"
|
||||
assert _tool_message_text(result_b, expected_tcid="tcid-B") == "B is done"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_two_parallel_atasks_same_subagent_type_different_tool_call_ids():
|
||||
"""Per-call ``thread_id`` isolation: same compiled subagent invoked twice in parallel.
|
||||
|
||||
Both calls share the same ``InMemorySaver`` instance but are namespaced by
|
||||
distinct ``tool_call_id``s, so checkpoints land in disjoint thread slots.
|
||||
"""
|
||||
shared_subagent = _build_success_subagent("ok")
|
||||
task_tool = build_task_tool_with_parent_config(
|
||||
[
|
||||
{
|
||||
"name": "approver",
|
||||
"description": "shared approver",
|
||||
"runnable": shared_subagent,
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
parent_config: dict = {
|
||||
"configurable": {"thread_id": "shared-subagent-thread"},
|
||||
"recursion_limit": 100,
|
||||
}
|
||||
runtime_a = _make_runtime(parent_config, tool_call_id="tcid-A")
|
||||
runtime_b = _make_runtime(parent_config, tool_call_id="tcid-B")
|
||||
|
||||
result_a, result_b = await asyncio.gather(
|
||||
task_tool.coroutine(
|
||||
description="first request",
|
||||
subagent_type="approver",
|
||||
runtime=runtime_a,
|
||||
),
|
||||
task_tool.coroutine(
|
||||
description="second request",
|
||||
subagent_type="approver",
|
||||
runtime=runtime_b,
|
||||
),
|
||||
)
|
||||
|
||||
# Both calls succeed and produce ToolMessages keyed by their own tool_call_id.
|
||||
assert _tool_message_text(result_a, expected_tcid="tcid-A") == "ok"
|
||||
assert _tool_message_text(result_b, expected_tcid="tcid-B") == "ok"
|
||||
|
||||
# Verify checkpoint isolation: each call's state lives at its own thread_id.
|
||||
state_a = await shared_subagent.aget_state(
|
||||
{"configurable": {"thread_id": "shared-subagent-thread::task:tcid-A"}}
|
||||
)
|
||||
state_b = await shared_subagent.aget_state(
|
||||
{"configurable": {"thread_id": "shared-subagent-thread::task:tcid-B"}}
|
||||
)
|
||||
assert state_a.values["messages"][-1].content == "ok"
|
||||
assert state_b.values["messages"][-1].content == "ok"
|
||||
|
||||
# The parent's own thread_id slot is untouched by either subagent.
|
||||
state_parent = await shared_subagent.aget_state(
|
||||
{"configurable": {"thread_id": "shared-subagent-thread"}}
|
||||
)
|
||||
assert state_parent.values == {} or state_parent.values.get("messages") in (
|
||||
None,
|
||||
[],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_one_atask_failure_does_not_corrupt_sibling_atask():
|
||||
"""Failure isolation: a sibling's exception must not poison the surviving atask's state.
|
||||
|
||||
Note: in production, langgraph's pregel runner cancels siblings when any
|
||||
parallel task raises a non-``GraphBubbleUp`` exception (see
|
||||
``_should_stop_others`` in ``langgraph/pregel/_runner.py``). At our layer
|
||||
that policy is invisible — what we *can* guarantee is that the two atask
|
||||
coroutines have disjoint state, so the surviving one returns a valid
|
||||
Command even when its sibling explodes.
|
||||
"""
|
||||
failing_subagent = _build_failing_subagent(ValueError("boom"))
|
||||
surviving_subagent = _build_success_subagent("still here")
|
||||
task_tool = build_task_tool_with_parent_config(
|
||||
[
|
||||
{
|
||||
"name": "broken",
|
||||
"description": "always fails",
|
||||
"runnable": failing_subagent,
|
||||
},
|
||||
{
|
||||
"name": "healthy",
|
||||
"description": "always succeeds",
|
||||
"runnable": surviving_subagent,
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
parent_config: dict = {
|
||||
"configurable": {"thread_id": "iso-thread"},
|
||||
"recursion_limit": 100,
|
||||
}
|
||||
runtime_fail = _make_runtime(parent_config, tool_call_id="tcid-fail")
|
||||
runtime_ok = _make_runtime(parent_config, tool_call_id="tcid-ok")
|
||||
|
||||
results = await asyncio.gather(
|
||||
task_tool.coroutine(
|
||||
description="will explode",
|
||||
subagent_type="broken",
|
||||
runtime=runtime_fail,
|
||||
),
|
||||
task_tool.coroutine(
|
||||
description="will work",
|
||||
subagent_type="healthy",
|
||||
runtime=runtime_ok,
|
||||
),
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
fail_result, ok_result = results
|
||||
|
||||
assert isinstance(fail_result, Exception), (
|
||||
f"expected the broken subagent to raise, got {fail_result!r}"
|
||||
)
|
||||
# ValueError gets wrapped in langgraph's internal exception types — the
|
||||
# important guarantee is "this path errored", not the specific class.
|
||||
assert "boom" in str(fail_result) or isinstance(fail_result, ValueError)
|
||||
|
||||
assert _tool_message_text(ok_result, expected_tcid="tcid-ok") == "still here"
|
||||
|
||||
# Configurable side-channel must not have been corrupted by the failure.
|
||||
assert "surfsense_resume_value" not in parent_config["configurable"]
|
||||
|
|
@ -0,0 +1,154 @@
|
|||
"""Slicing helper that routes a flat decisions list to per-tool-call payloads.
|
||||
|
||||
The frontend submits ``decisions: list[ResumeDecision]`` in the same order the
|
||||
SSE stream emitted approval cards. When multiple parallel subagents are paused,
|
||||
the backend slices that flat list into per-``tool_call_id`` payloads so each
|
||||
``atask`` reads only its own decisions through ``consume_surfsense_resume``.
|
||||
|
||||
The extractor reads ``state.interrupts[i].value["tool_call_id"]`` — which is
|
||||
populated by ``propagation.wrap_with_tool_call_id`` inside ``task_tool``'s
|
||||
``except GraphInterrupt`` chokepoint whenever a subagent interrupt bubbles up
|
||||
through ``[a]task`` — to build the ordered ``pending`` list the slicer needs.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import (
|
||||
collect_pending_tool_calls,
|
||||
slice_decisions_by_tool_call,
|
||||
)
|
||||
|
||||
|
||||
class TestSliceDecisionsByToolCall:
|
||||
def test_splits_flat_decisions_across_two_pending_tool_calls(self):
|
||||
decisions = [
|
||||
{"type": "approve"},
|
||||
{"type": "edit", "edited_action": {"name": "edited-b1"}},
|
||||
{"type": "reject"},
|
||||
{"type": "approve"},
|
||||
{"type": "approve"},
|
||||
]
|
||||
pending = [
|
||||
("tcid-A", 3),
|
||||
("tcid-B", 2),
|
||||
]
|
||||
|
||||
routed = slice_decisions_by_tool_call(decisions, pending)
|
||||
|
||||
assert routed == {
|
||||
"tcid-A": {"decisions": decisions[0:3]},
|
||||
"tcid-B": {"decisions": decisions[3:5]},
|
||||
}
|
||||
|
||||
def test_raises_when_decision_count_less_than_total_actions(self):
|
||||
decisions = [{"type": "approve"}, {"type": "approve"}]
|
||||
pending = [("tcid-A", 3), ("tcid-B", 2)]
|
||||
|
||||
with pytest.raises(ValueError, match=r"5 actions.*2 decisions"):
|
||||
slice_decisions_by_tool_call(decisions, pending)
|
||||
|
||||
def test_raises_when_decision_count_greater_than_total_actions(self):
|
||||
decisions = [{"type": "approve"}] * 6
|
||||
pending = [("tcid-A", 3), ("tcid-B", 2)]
|
||||
|
||||
with pytest.raises(ValueError, match=r"5 actions.*6 decisions"):
|
||||
slice_decisions_by_tool_call(decisions, pending)
|
||||
|
||||
def test_handles_single_pending_tool_call(self):
|
||||
decisions = [{"type": "approve"}, {"type": "reject"}]
|
||||
pending = [("tcid-only", 2)]
|
||||
|
||||
routed = slice_decisions_by_tool_call(decisions, pending)
|
||||
|
||||
assert routed == {"tcid-only": {"decisions": decisions}}
|
||||
|
||||
def test_returns_empty_dict_for_no_pending(self):
|
||||
routed = slice_decisions_by_tool_call([], [])
|
||||
|
||||
assert routed == {}
|
||||
|
||||
|
||||
def _interrupt_with(tool_call_id: str, action_count: int):
|
||||
return SimpleNamespace(
|
||||
id=f"i-{tool_call_id}",
|
||||
value={
|
||||
"action_requests": [{"name": "n", "args": {}}] * action_count,
|
||||
"review_configs": [{}] * action_count,
|
||||
"tool_call_id": tool_call_id,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class TestCollectPendingToolCalls:
|
||||
def test_single_pending_returns_one_pair(self):
|
||||
state = SimpleNamespace(interrupts=(_interrupt_with("tcid-only", 3),))
|
||||
|
||||
assert collect_pending_tool_calls(state) == [("tcid-only", 3)]
|
||||
|
||||
def test_multiple_pending_preserves_state_order(self):
|
||||
"""Order must match what the SSE stream emitted (= state.interrupts order)."""
|
||||
state = SimpleNamespace(
|
||||
interrupts=(
|
||||
_interrupt_with("tcid-A", 2),
|
||||
_interrupt_with("tcid-B", 3),
|
||||
)
|
||||
)
|
||||
|
||||
assert collect_pending_tool_calls(state) == [("tcid-A", 2), ("tcid-B", 3)]
|
||||
|
||||
def test_empty_when_no_interrupts(self):
|
||||
state = SimpleNamespace(interrupts=())
|
||||
|
||||
assert collect_pending_tool_calls(state) == []
|
||||
|
||||
def test_skips_interrupts_without_tool_call_id(self):
|
||||
"""Defensive: interrupts not produced by our propagation layer are ignored.
|
||||
|
||||
``stream_resume_chat`` only owns the ``task``-routing slice; non-task
|
||||
interrupts (e.g. parent-side HITL middleware on a different tool) are
|
||||
not the slicer's responsibility.
|
||||
"""
|
||||
state = SimpleNamespace(
|
||||
interrupts=(
|
||||
_interrupt_with("tcid-A", 2),
|
||||
SimpleNamespace(id="i-foreign", value={"action_requests": [{}]}),
|
||||
_interrupt_with("tcid-B", 1),
|
||||
)
|
||||
)
|
||||
|
||||
assert collect_pending_tool_calls(state) == [("tcid-A", 2), ("tcid-B", 1)]
|
||||
|
||||
def test_handles_scalar_value_interrupt(self):
|
||||
"""Subagents using ``interrupt("approve?")`` style propagate as ``{"value": ..., "tool_call_id": ...}``.
|
||||
|
||||
These have no ``action_requests`` — count them as a single action so
|
||||
the frontend submits exactly one decision per such interrupt.
|
||||
"""
|
||||
state = SimpleNamespace(
|
||||
interrupts=(
|
||||
SimpleNamespace(
|
||||
id="i-A",
|
||||
value={"value": "approve?", "tool_call_id": "tcid-A"},
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
assert collect_pending_tool_calls(state) == [("tcid-A", 1)]
|
||||
|
||||
def test_raises_when_interrupt_value_missing_action_count_keys(self):
|
||||
"""An interrupt with ``tool_call_id`` but no usable count signals a contract bug."""
|
||||
state = SimpleNamespace(
|
||||
interrupts=(
|
||||
SimpleNamespace(
|
||||
id="i-A",
|
||||
value={"tool_call_id": "tcid-A", "weird_shape": True},
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="action_requests"):
|
||||
collect_pending_tool_calls(state)
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
"""Resume side-channel must be read exactly once per turn."""
|
||||
"""Resume side-channel is keyed per ``tool_call_id`` so parallel siblings can resume independently."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
|
@ -10,33 +10,61 @@ from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_mid
|
|||
)
|
||||
|
||||
|
||||
def _runtime_with_config(config: dict) -> ToolRuntime:
|
||||
def _runtime_with_config(
|
||||
config: dict, *, tool_call_id: str = "tcid-test"
|
||||
) -> ToolRuntime:
|
||||
return ToolRuntime(
|
||||
state=None,
|
||||
context=None,
|
||||
config=config,
|
||||
stream_writer=None,
|
||||
tool_call_id="tcid-test",
|
||||
tool_call_id=tool_call_id,
|
||||
store=None,
|
||||
)
|
||||
|
||||
|
||||
class TestConsumeSurfsenseResume:
|
||||
def test_pops_value_on_first_call(self):
|
||||
def test_pops_only_entry_matching_runtime_tool_call_id(self):
|
||||
configurable = {
|
||||
"surfsense_resume_value": {
|
||||
"tcid-A": {"decisions": ["approve"]},
|
||||
"tcid-B": {"decisions": ["reject"]},
|
||||
}
|
||||
}
|
||||
runtime = _runtime_with_config(
|
||||
{"configurable": {"surfsense_resume_value": {"decisions": ["approve"]}}}
|
||||
{"configurable": configurable}, tool_call_id="tcid-A"
|
||||
)
|
||||
|
||||
assert consume_surfsense_resume(runtime) == {"decisions": ["approve"]}
|
||||
|
||||
def test_second_call_returns_none(self):
|
||||
configurable: dict = {"surfsense_resume_value": {"decisions": ["approve"]}}
|
||||
runtime = _runtime_with_config({"configurable": configurable})
|
||||
def test_popping_one_entry_leaves_siblings_untouched(self):
|
||||
configurable = {
|
||||
"surfsense_resume_value": {
|
||||
"tcid-A": {"decisions": ["approve"]},
|
||||
"tcid-B": {"decisions": ["reject"]},
|
||||
}
|
||||
}
|
||||
runtime_a = _runtime_with_config(
|
||||
{"configurable": configurable}, tool_call_id="tcid-A"
|
||||
)
|
||||
|
||||
consume_surfsense_resume(runtime)
|
||||
consume_surfsense_resume(runtime_a)
|
||||
|
||||
assert configurable["surfsense_resume_value"] == {
|
||||
"tcid-B": {"decisions": ["reject"]}
|
||||
}
|
||||
|
||||
def test_returns_none_when_no_entry_for_this_tool_call(self):
|
||||
runtime = _runtime_with_config(
|
||||
{
|
||||
"configurable": {
|
||||
"surfsense_resume_value": {"tcid-other": {"decisions": []}}
|
||||
}
|
||||
},
|
||||
tool_call_id="tcid-A",
|
||||
)
|
||||
|
||||
assert consume_surfsense_resume(runtime) is None
|
||||
assert "surfsense_resume_value" not in configurable
|
||||
|
||||
def test_returns_none_when_no_payload_queued(self):
|
||||
runtime = _runtime_with_config({"configurable": {}})
|
||||
|
|
@ -48,22 +76,57 @@ class TestConsumeSurfsenseResume:
|
|||
|
||||
assert consume_surfsense_resume(runtime) is None
|
||||
|
||||
def test_drops_empty_dict_after_last_entry_consumed(self):
|
||||
configurable = {
|
||||
"surfsense_resume_value": {"tcid-A": {"decisions": ["approve"]}}
|
||||
}
|
||||
runtime = _runtime_with_config(
|
||||
{"configurable": configurable}, tool_call_id="tcid-A"
|
||||
)
|
||||
|
||||
consume_surfsense_resume(runtime)
|
||||
|
||||
assert "surfsense_resume_value" not in configurable
|
||||
|
||||
|
||||
class TestHasSurfsenseResume:
|
||||
def test_true_when_payload_queued(self):
|
||||
def test_true_when_entry_for_this_tool_call_present(self):
|
||||
runtime = _runtime_with_config(
|
||||
{"configurable": {"surfsense_resume_value": "approve"}}
|
||||
{
|
||||
"configurable": {
|
||||
"surfsense_resume_value": {"tcid-A": {"decisions": ["approve"]}}
|
||||
}
|
||||
},
|
||||
tool_call_id="tcid-A",
|
||||
)
|
||||
|
||||
assert has_surfsense_resume(runtime) is True
|
||||
|
||||
def test_false_when_entry_for_other_tool_call_only(self):
|
||||
runtime = _runtime_with_config(
|
||||
{
|
||||
"configurable": {
|
||||
"surfsense_resume_value": {"tcid-other": {"decisions": []}}
|
||||
}
|
||||
},
|
||||
tool_call_id="tcid-A",
|
||||
)
|
||||
|
||||
assert has_surfsense_resume(runtime) is False
|
||||
|
||||
def test_does_not_consume_payload(self):
|
||||
configurable = {"surfsense_resume_value": "approve"}
|
||||
runtime = _runtime_with_config({"configurable": configurable})
|
||||
configurable = {
|
||||
"surfsense_resume_value": {"tcid-A": {"decisions": ["approve"]}}
|
||||
}
|
||||
runtime = _runtime_with_config(
|
||||
{"configurable": configurable}, tool_call_id="tcid-A"
|
||||
)
|
||||
|
||||
has_surfsense_resume(runtime)
|
||||
|
||||
assert configurable == {"surfsense_resume_value": "approve"}
|
||||
assert configurable["surfsense_resume_value"] == {
|
||||
"tcid-A": {"decisions": ["approve"]}
|
||||
}
|
||||
|
||||
def test_false_when_payload_absent(self):
|
||||
runtime = _runtime_with_config({"configurable": {}})
|
||||
|
|
|
|||
|
|
@ -0,0 +1,284 @@
|
|||
"""Production-shape regression tests for ``tool_call_id`` stamping on subagent interrupts.
|
||||
|
||||
The production bug we're pinning here: when the orchestrator dispatches one or
|
||||
more ``task`` tool calls and the targeted subagents hit a HITL ``interrupt(...)``,
|
||||
the parent's persisted ``state.interrupts`` must carry the parent's
|
||||
``tool_call_id`` on each interrupt value. Without that stamp,
|
||||
``stream_resume_chat`` cannot route a flat ``decisions`` list back to the right
|
||||
paused subagent and resume fails with ``Decision count mismatch``.
|
||||
|
||||
The tests in this module:
|
||||
|
||||
- Build a **real** ``StateGraph`` subagent that calls real ``interrupt(...)``
|
||||
(no MagicMock, no patch of langgraph internals — those are exactly the kind
|
||||
of fakes that hid this bug).
|
||||
- Invoke the ``task`` tool from **inside a parent pregel** (via a tiny parent
|
||||
``StateGraph`` node) so the subagent invocation happens in the
|
||||
production-shape "subgraph called from a parent tool node" context.
|
||||
- Assert on ``parent.state.interrupts[*].value["tool_call_id"]`` — the
|
||||
observable that ``stream_resume_chat`` reads.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from langchain.tools import ToolRuntime
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.types import Send, interrupt
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import (
|
||||
build_task_tool_with_parent_config,
|
||||
)
|
||||
|
||||
|
||||
class _S(TypedDict, total=False):
|
||||
messages: list
|
||||
|
||||
|
||||
def _build_single_interrupt_subagent(checkpointer: InMemorySaver):
|
||||
"""Subagent that fires one HITL-bundle-shaped interrupt and waits for a decision."""
|
||||
|
||||
def approve_node(_state):
|
||||
decision = interrupt(
|
||||
{
|
||||
"action_requests": [
|
||||
{"name": "do_thing", "args": {"x": 1}, "description": ""}
|
||||
],
|
||||
"review_configs": [{}],
|
||||
}
|
||||
)
|
||||
return {"messages": [AIMessage(content=f"got:{decision}")]}
|
||||
|
||||
g = StateGraph(_S)
|
||||
g.add_node("approve", approve_node)
|
||||
g.add_edge(START, "approve")
|
||||
g.add_edge("approve", END)
|
||||
return g.compile(checkpointer=checkpointer)
|
||||
|
||||
|
||||
def _build_bundle_subagent(checkpointer: InMemorySaver):
|
||||
"""Subagent that fires one interrupt carrying a 3-action bundle."""
|
||||
|
||||
def bundle_node(_state):
|
||||
decision = interrupt(
|
||||
{
|
||||
"action_requests": [
|
||||
{"name": "a", "args": {}, "description": ""},
|
||||
{"name": "b", "args": {}, "description": ""},
|
||||
{"name": "c", "args": {}, "description": ""},
|
||||
],
|
||||
"review_configs": [{}, {}, {}],
|
||||
}
|
||||
)
|
||||
return {"messages": [AIMessage(content=f"bundle:{decision}")]}
|
||||
|
||||
g = StateGraph(_S)
|
||||
g.add_node("bundle", bundle_node)
|
||||
g.add_edge(START, "bundle")
|
||||
g.add_edge("bundle", END)
|
||||
return g.compile(checkpointer=checkpointer)
|
||||
|
||||
|
||||
def _parent_graph_calling_task(task_tool, *, tool_call_id: str, checkpointer):
|
||||
"""A tiny parent graph whose only node invokes ``task_tool`` from inside the pregel runtime.
|
||||
|
||||
This is the minimal reproduction of production's "subagent invoked from
|
||||
inside a parent tool node" context — the *only* context where langgraph
|
||||
treats the subagent as a subgraph and routes its interrupts back to the
|
||||
parent's checkpoint.
|
||||
"""
|
||||
|
||||
async def call_task(state, config: RunnableConfig):
|
||||
rt = ToolRuntime(
|
||||
state=state,
|
||||
config=config,
|
||||
context=None,
|
||||
stream_writer=None,
|
||||
tool_call_id=tool_call_id,
|
||||
store=None,
|
||||
)
|
||||
return await task_tool.coroutine(
|
||||
description="please approve",
|
||||
subagent_type="approver",
|
||||
runtime=rt,
|
||||
)
|
||||
|
||||
g = StateGraph(_S)
|
||||
g.add_node("call_task", call_task)
|
||||
g.add_edge(START, "call_task")
|
||||
g.add_edge("call_task", END)
|
||||
return g.compile(checkpointer=checkpointer)
|
||||
|
||||
|
||||
class _DispatchState(TypedDict, total=False):
|
||||
messages: list
|
||||
tcid: str
|
||||
desc: str
|
||||
|
||||
|
||||
def _parent_graph_dispatching_two_tasks_via_send(
|
||||
task_tool, *, tool_call_id_a: str, tool_call_id_b: str, checkpointer
|
||||
):
|
||||
"""A parent graph that dispatches two ``task`` calls as parallel pregel
|
||||
tasks via :class:`~langgraph.types.Send`.
|
||||
|
||||
This mirrors the production dispatch mechanism: when the orchestrator's
|
||||
LLM emits two ``task`` tool calls in one turn, langchain's tool node
|
||||
fans them out as parallel pregel tasks (the same primitive as ``Send``)
|
||||
so each tool call gets its own pregel task that can raise
|
||||
``GraphInterrupt`` independently — and pregel collects *all* of them
|
||||
into the parent's snapshot at the end of the superstep.
|
||||
"""
|
||||
|
||||
def fanout_edge(_state) -> list[Send]:
|
||||
return [
|
||||
Send("call_task", {"tcid": tool_call_id_a, "desc": "approve A"}),
|
||||
Send("call_task", {"tcid": tool_call_id_b, "desc": "approve B"}),
|
||||
]
|
||||
|
||||
async def call_task(state: _DispatchState, config: RunnableConfig):
|
||||
rt = ToolRuntime(
|
||||
state=state,
|
||||
config=config,
|
||||
context=None,
|
||||
stream_writer=None,
|
||||
tool_call_id=state["tcid"],
|
||||
store=None,
|
||||
)
|
||||
return await task_tool.coroutine(
|
||||
description=state["desc"], subagent_type="approver", runtime=rt
|
||||
)
|
||||
|
||||
g = StateGraph(_DispatchState)
|
||||
g.add_node("call_task", call_task)
|
||||
g.add_conditional_edges(START, fanout_edge, ["call_task"])
|
||||
g.add_edge("call_task", END)
|
||||
return g.compile(checkpointer=checkpointer)
|
||||
|
||||
|
||||
def _parent_interrupt_values(snapshot) -> list[dict]:
|
||||
"""Extract ``state.interrupts[*].value`` for assertions."""
|
||||
return [i.value for i in (snapshot.interrupts or ())]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_subagent_interrupt_stamps_parent_tool_call_id():
|
||||
"""A single paused subagent must surface to the parent with ``tool_call_id`` stamped.
|
||||
|
||||
Production bug regression: was producing
|
||||
``value={"action_requests": [...], "review_configs": [...]}`` (no
|
||||
``tool_call_id``), causing ``stream_resume_chat`` to skip the interrupt
|
||||
and raise ``Decision count mismatch``.
|
||||
"""
|
||||
checkpointer = InMemorySaver()
|
||||
subagent = _build_single_interrupt_subagent(checkpointer)
|
||||
task_tool = build_task_tool_with_parent_config(
|
||||
[{"name": "approver", "description": "approves", "runnable": subagent}]
|
||||
)
|
||||
parent = _parent_graph_calling_task(
|
||||
task_tool, tool_call_id="parent-tcid-A", checkpointer=checkpointer
|
||||
)
|
||||
|
||||
parent_config = {
|
||||
"configurable": {"thread_id": "parent-thread"},
|
||||
"recursion_limit": 100,
|
||||
}
|
||||
await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config)
|
||||
|
||||
snap = await parent.aget_state(parent_config)
|
||||
values = _parent_interrupt_values(snap)
|
||||
assert len(values) == 1, (
|
||||
f"expected exactly 1 parent interrupt, got {len(values)}: {values!r}"
|
||||
)
|
||||
value = values[0]
|
||||
assert isinstance(value, dict)
|
||||
assert value.get("tool_call_id") == "parent-tcid-A", (
|
||||
f"REGRESSION: parent interrupt missing/wrong tool_call_id stamp. "
|
||||
f"Expected 'parent-tcid-A', got {value.get('tool_call_id')!r}. "
|
||||
f"Keys present: {sorted(value.keys())}"
|
||||
)
|
||||
# The original HITL payload must still be intact alongside the stamp.
|
||||
assert value.get("action_requests") == [
|
||||
{"name": "do_thing", "args": {"x": 1}, "description": ""}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_two_parallel_subagents_each_stamp_their_own_tool_call_id():
|
||||
"""Two ``task`` calls dispatched in parallel must each carry their own ``tool_call_id``.
|
||||
|
||||
This is the actual production scenario (Linear + Jira ticket creation):
|
||||
two parallel ``task`` tool calls, both subagents hit HITL, parent must
|
||||
end up with two interrupts whose ``tool_call_id``s match the two
|
||||
distinct parent-level ``tool_call_id``s the LLM emitted.
|
||||
"""
|
||||
checkpointer = InMemorySaver()
|
||||
subagent = _build_single_interrupt_subagent(checkpointer)
|
||||
task_tool = build_task_tool_with_parent_config(
|
||||
[{"name": "approver", "description": "approves", "runnable": subagent}]
|
||||
)
|
||||
parent = _parent_graph_dispatching_two_tasks_via_send(
|
||||
task_tool,
|
||||
tool_call_id_a="parent-tcid-A",
|
||||
tool_call_id_b="parent-tcid-B",
|
||||
checkpointer=checkpointer,
|
||||
)
|
||||
|
||||
parent_config = {
|
||||
"configurable": {"thread_id": "parent-thread-parallel"},
|
||||
"recursion_limit": 100,
|
||||
}
|
||||
await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config)
|
||||
|
||||
snap = await parent.aget_state(parent_config)
|
||||
values = _parent_interrupt_values(snap)
|
||||
assert len(values) == 2, (
|
||||
f"expected 2 parent interrupts (one per parallel task call), "
|
||||
f"got {len(values)}: {values!r}"
|
||||
)
|
||||
stamps = {v.get("tool_call_id") for v in values}
|
||||
assert stamps == {"parent-tcid-A", "parent-tcid-B"}, (
|
||||
f"REGRESSION: parallel parent interrupts missing/wrong tool_call_id stamps. "
|
||||
f"Expected {{'parent-tcid-A', 'parent-tcid-B'}}, got {stamps!r}. "
|
||||
f"Values: {values!r}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bundle_subagent_interrupt_stamps_tool_call_id_preserving_actions():
|
||||
"""A subagent emitting a multi-action bundle must surface stamped, with all actions intact.
|
||||
|
||||
The bundle shape (``action_requests=[3 items]``) drives the
|
||||
``slice_decisions_by_tool_call`` accounting in ``stream_resume_chat`` —
|
||||
if either the stamp or the action count is lost, resume routing
|
||||
miscounts and crashes.
|
||||
"""
|
||||
checkpointer = InMemorySaver()
|
||||
subagent = _build_bundle_subagent(checkpointer)
|
||||
task_tool = build_task_tool_with_parent_config(
|
||||
[{"name": "approver", "description": "approves", "runnable": subagent}]
|
||||
)
|
||||
parent = _parent_graph_calling_task(
|
||||
task_tool, tool_call_id="parent-tcid-bundle", checkpointer=checkpointer
|
||||
)
|
||||
|
||||
parent_config = {
|
||||
"configurable": {"thread_id": "parent-thread-bundle"},
|
||||
"recursion_limit": 100,
|
||||
}
|
||||
await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config)
|
||||
|
||||
snap = await parent.aget_state(parent_config)
|
||||
values = _parent_interrupt_values(snap)
|
||||
assert len(values) == 1
|
||||
value = values[0]
|
||||
assert value.get("tool_call_id") == "parent-tcid-bundle"
|
||||
assert isinstance(value.get("action_requests"), list)
|
||||
assert len(value["action_requests"]) == 3, (
|
||||
f"REGRESSION: bundle action_requests count changed during stamping; "
|
||||
f"got {len(value['action_requests'])} actions: {value['action_requests']!r}"
|
||||
)
|
||||
|
|
@ -0,0 +1,94 @@
|
|||
"""Per-call ``thread_id`` derivation for nested subagent invocations.
|
||||
|
||||
Parallel ``task`` (and ``ask_knowledge_base``) calls must land in disjoint
|
||||
checkpoint slots so their nested pregel runs do not stomp on each other or on
|
||||
the parent's checkpoint state. The slot key is derived from the runtime's
|
||||
``tool_call_id`` so the same call across the resume cycle keeps reading from
|
||||
the same snapshot.
|
||||
|
||||
Note: we namespace via ``thread_id`` rather than ``checkpoint_ns`` because
|
||||
langgraph's ``aget_state`` interprets a non-empty ``checkpoint_ns`` as a
|
||||
subgraph path and raises ``ValueError("Subgraph X not found")``. ``thread_id``
|
||||
is the primary checkpoint key and is free-form, so it's the right primitive.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain.tools import ToolRuntime
|
||||
|
||||
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.config import (
|
||||
subagent_invoke_config,
|
||||
)
|
||||
|
||||
|
||||
def _runtime(*, tool_call_id: str, config: dict | None = None) -> ToolRuntime:
|
||||
return ToolRuntime(
|
||||
state=None,
|
||||
context=None,
|
||||
config=config or {},
|
||||
stream_writer=None,
|
||||
tool_call_id=tool_call_id,
|
||||
store=None,
|
||||
)
|
||||
|
||||
|
||||
class TestSubagentInvokeThreadId:
|
||||
def test_sets_per_call_thread_id_under_parent(self):
|
||||
runtime = _runtime(
|
||||
tool_call_id="tcid-A",
|
||||
config={"configurable": {"thread_id": "t1"}},
|
||||
)
|
||||
|
||||
sub_config = subagent_invoke_config(runtime)
|
||||
|
||||
assert sub_config["configurable"]["thread_id"] == "t1::task:tcid-A"
|
||||
|
||||
def test_per_call_thread_id_nests_under_already_namespaced_parent(self):
|
||||
"""A subagent that itself spawns a subagent must keep nesting cleanly."""
|
||||
runtime = _runtime(
|
||||
tool_call_id="tcid-inner",
|
||||
config={
|
||||
"configurable": {
|
||||
"thread_id": "t1::task:tcid-outer",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
sub_config = subagent_invoke_config(runtime)
|
||||
|
||||
assert (
|
||||
sub_config["configurable"]["thread_id"]
|
||||
== "t1::task:tcid-outer::task:tcid-inner"
|
||||
)
|
||||
|
||||
def test_different_tool_call_ids_produce_different_thread_ids(self):
|
||||
config = {"configurable": {"thread_id": "t1"}}
|
||||
rt_a = _runtime(tool_call_id="tcid-A", config=config)
|
||||
rt_b = _runtime(tool_call_id="tcid-B", config=config)
|
||||
|
||||
tid_a = subagent_invoke_config(rt_a)["configurable"]["thread_id"]
|
||||
tid_b = subagent_invoke_config(rt_b)["configurable"]["thread_id"]
|
||||
|
||||
assert tid_a != tid_b
|
||||
|
||||
def test_same_tool_call_id_produces_same_thread_id_across_repeated_calls(self):
|
||||
"""Resume bridge needs to find the snapshot it primed earlier."""
|
||||
config = {"configurable": {"thread_id": "t1"}}
|
||||
rt_first = _runtime(tool_call_id="tcid-A", config=config)
|
||||
rt_second = _runtime(tool_call_id="tcid-A", config=config)
|
||||
|
||||
tid_first = subagent_invoke_config(rt_first)["configurable"]["thread_id"]
|
||||
tid_second = subagent_invoke_config(rt_second)["configurable"]["thread_id"]
|
||||
|
||||
assert tid_first == tid_second
|
||||
|
||||
def test_does_not_mutate_caller_config(self):
|
||||
"""Repeated calls must not accumulate suffixes onto the parent's config."""
|
||||
original_thread_id = "t1"
|
||||
config = {"configurable": {"thread_id": original_thread_id}}
|
||||
runtime = _runtime(tool_call_id="tcid-A", config=config)
|
||||
|
||||
subagent_invoke_config(runtime)
|
||||
subagent_invoke_config(runtime)
|
||||
|
||||
assert config["configurable"]["thread_id"] == original_thread_id
|
||||
|
|
@ -0,0 +1,125 @@
|
|||
"""Regression: ``request_permission_decision`` must emit the unified LC HITL wire shape.
|
||||
|
||||
Same bug class as :mod:`test_lc_hitl_wire` for self-gated approvals: the
|
||||
permission middleware previously fired the SurfSense-specific
|
||||
``{type, action, context}`` shape, which the parallel-HITL routing layer
|
||||
does not recognize. Standardizing on LC HITL keeps every approval kind on
|
||||
one routing path.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.types import Command
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from app.agents.multi_agent_chat.middleware.shared.permissions.ask.request import (
|
||||
request_permission_decision,
|
||||
)
|
||||
from app.agents.new_chat.permissions import Rule
|
||||
|
||||
|
||||
class _State(TypedDict, total=False):
|
||||
messages: list
|
||||
final_decision: dict
|
||||
|
||||
|
||||
def _build_graph_calling_request_permission_decision(checkpointer: InMemorySaver):
|
||||
"""Real graph whose only node delegates to the permission ask primitive."""
|
||||
|
||||
def perm_node(_state):
|
||||
decision = request_permission_decision(
|
||||
tool_name="rm",
|
||||
args={"path": "/tmp/file"},
|
||||
patterns=["rm/*"],
|
||||
rules=[Rule(permission="rm", pattern="*", action="ask")],
|
||||
emit_interrupt=True,
|
||||
)
|
||||
return {"final_decision": decision}
|
||||
|
||||
g = StateGraph(_State)
|
||||
g.add_node("perm", perm_node)
|
||||
g.add_edge(START, "perm")
|
||||
g.add_edge("perm", END)
|
||||
return g.compile(checkpointer=checkpointer)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_permission_ask_payload_uses_lc_hitl_shape():
|
||||
"""The permission middleware now speaks the langchain HITL standard shape."""
|
||||
checkpointer = InMemorySaver()
|
||||
graph = _build_graph_calling_request_permission_decision(checkpointer)
|
||||
config = {"configurable": {"thread_id": "perm-wire"}}
|
||||
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||
|
||||
snap = await graph.aget_state(config)
|
||||
assert len(snap.interrupts) == 1
|
||||
value = snap.interrupts[0].value
|
||||
|
||||
assert value.get("action_requests") == [
|
||||
{"name": "rm", "args": {"path": "/tmp/file"}}
|
||||
], f"REGRESSION: permission ask reverted to legacy shape; got {value!r}"
|
||||
review = value.get("review_configs")
|
||||
assert isinstance(review, list) and len(review) == 1
|
||||
palette = review[0]["allowed_decisions"]
|
||||
# Native tool (no ``tool=`` argument): the palette must include the
|
||||
# once/reject/edit triad. ``approve_always`` is gated on MCP-ness and
|
||||
# therefore *omitted* here — palette content per tool kind is
|
||||
# exercised in ``test_permission_ask_mcp_context``.
|
||||
assert "approve" in palette and "reject" in palette and "edit" in palette
|
||||
assert value.get("interrupt_type") == "permission_ask"
|
||||
# SurfSense context rides through verbatim for FE explainability.
|
||||
assert value["context"]["patterns"] == ["rm/*"]
|
||||
assert value["context"]["always"] == ["rm/*"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_with_approve_envelope_returns_once_decision():
|
||||
"""``approve`` from the LC envelope projects to permission-domain ``once``."""
|
||||
checkpointer = InMemorySaver()
|
||||
graph = _build_graph_calling_request_permission_decision(checkpointer)
|
||||
config = {"configurable": {"thread_id": "perm-once"}}
|
||||
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||
|
||||
await graph.ainvoke(Command(resume={"decisions": [{"type": "approve"}]}), config)
|
||||
final = await graph.aget_state(config)
|
||||
assert final.values.get("final_decision") == {"decision_type": "once"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_with_approve_always_envelope_projects_unchanged():
|
||||
"""``approve_always`` reply must project unchanged so the middleware can promote the rule."""
|
||||
checkpointer = InMemorySaver()
|
||||
graph = _build_graph_calling_request_permission_decision(checkpointer)
|
||||
config = {"configurable": {"thread_id": "perm-approve-always"}}
|
||||
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||
|
||||
await graph.ainvoke(
|
||||
Command(resume={"decisions": [{"type": "approve_always"}]}), config
|
||||
)
|
||||
final = await graph.aget_state(config)
|
||||
assert final.values.get("final_decision") == {"decision_type": "approve_always"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_with_reject_and_feedback_carries_feedback_through():
|
||||
"""Reject feedback must survive normalization for ``CorrectedError`` to fire downstream."""
|
||||
checkpointer = InMemorySaver()
|
||||
graph = _build_graph_calling_request_permission_decision(checkpointer)
|
||||
config = {"configurable": {"thread_id": "perm-reject"}}
|
||||
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||
|
||||
await graph.ainvoke(
|
||||
Command(
|
||||
resume={"decisions": [{"type": "reject", "feedback": "use the trash bin"}]}
|
||||
),
|
||||
config,
|
||||
)
|
||||
final = await graph.aget_state(config)
|
||||
assert final.values.get("final_decision") == {
|
||||
"decision_type": "reject",
|
||||
"feedback": "use the trash bin",
|
||||
}
|
||||
|
|
@ -0,0 +1,232 @@
|
|||
"""Permission-ask payload surfaces tool metadata for the FE card."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated, Any
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.tools import StructuredTool
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.graph.message import add_messages
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from app.agents.multi_agent_chat.middleware.shared.permissions import (
|
||||
build_permission_mw,
|
||||
)
|
||||
from app.agents.multi_agent_chat.middleware.shared.permissions.ask.payload import (
|
||||
build_permission_ask_payload,
|
||||
)
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.permissions import Rule, Ruleset
|
||||
|
||||
|
||||
class _NoArgs(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
async def _noop(**_kwargs) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
def _ask_rule(tool_name: str) -> Rule:
|
||||
return Rule(permission=tool_name, pattern="*", action="ask")
|
||||
|
||||
|
||||
def _make_mcp_tool(*, name: str, connector_id: int, connector_name: str):
|
||||
return StructuredTool(
|
||||
name=name,
|
||||
description=f"Run {name} via MCP.",
|
||||
coroutine=_noop,
|
||||
args_schema=_NoArgs,
|
||||
metadata={
|
||||
"mcp_connector_id": connector_id,
|
||||
"mcp_connector_name": connector_name,
|
||||
"mcp_transport": "http",
|
||||
"hitl": True,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def test_payload_surfaces_mcp_fields_from_tool():
|
||||
tool = _make_mcp_tool(
|
||||
name="linear_create_issue", connector_id=42, connector_name="Linear (acme)"
|
||||
)
|
||||
payload = build_permission_ask_payload(
|
||||
tool_name=tool.name,
|
||||
args={"title": "bug"},
|
||||
patterns=[tool.name],
|
||||
rules=[_ask_rule(tool.name)],
|
||||
tool=tool,
|
||||
)
|
||||
ctx = payload["context"]
|
||||
assert ctx["mcp_connector_id"] == 42
|
||||
assert ctx["mcp_server"] == "Linear (acme)"
|
||||
assert ctx["tool_description"] == "Run linear_create_issue via MCP."
|
||||
|
||||
|
||||
def test_payload_omits_tool_fields_when_tool_is_none():
|
||||
payload = build_permission_ask_payload(
|
||||
tool_name="rm",
|
||||
args={"path": "/tmp/x"},
|
||||
patterns=["rm"],
|
||||
rules=[_ask_rule("rm")],
|
||||
tool=None,
|
||||
)
|
||||
ctx = payload["context"]
|
||||
assert "mcp_connector_id" not in ctx
|
||||
assert "mcp_server" not in ctx
|
||||
assert "tool_description" not in ctx
|
||||
|
||||
|
||||
def test_palette_includes_approve_always_for_mcp_tool():
|
||||
"""Saving to the connector's trusted-tools list is only possible for MCP tools."""
|
||||
tool = _make_mcp_tool(
|
||||
name="linear_create_issue", connector_id=42, connector_name="Linear"
|
||||
)
|
||||
palette = build_permission_ask_payload(
|
||||
tool_name=tool.name,
|
||||
args={},
|
||||
patterns=[tool.name],
|
||||
rules=[_ask_rule(tool.name)],
|
||||
tool=tool,
|
||||
)["review_configs"][0]["allowed_decisions"]
|
||||
assert "approve_always" in palette
|
||||
|
||||
|
||||
def test_palette_excludes_approve_always_for_native_tool():
|
||||
"""Native tools have no place to persist trust, so don't offer the button."""
|
||||
native = StructuredTool(
|
||||
name="rm",
|
||||
description="Remove a file.",
|
||||
coroutine=_noop,
|
||||
args_schema=_NoArgs,
|
||||
metadata={"hitl": True},
|
||||
)
|
||||
palette = build_permission_ask_payload(
|
||||
tool_name=native.name,
|
||||
args={"path": "/tmp/x"},
|
||||
patterns=[native.name],
|
||||
rules=[_ask_rule(native.name)],
|
||||
tool=native,
|
||||
)["review_configs"][0]["allowed_decisions"]
|
||||
assert "approve_always" not in palette
|
||||
assert palette == ["approve", "reject", "edit"]
|
||||
|
||||
|
||||
def test_palette_excludes_approve_always_when_tool_is_none():
|
||||
"""Without a tool object the middleware can't tell — fall back to the safe triad."""
|
||||
palette = build_permission_ask_payload(
|
||||
tool_name="rm",
|
||||
args={"path": "/tmp/x"},
|
||||
patterns=["rm"],
|
||||
rules=[_ask_rule("rm")],
|
||||
tool=None,
|
||||
)["review_configs"][0]["allowed_decisions"]
|
||||
assert palette == ["approve", "reject", "edit"]
|
||||
|
||||
|
||||
def test_payload_omits_falsy_mcp_metadata_fields():
|
||||
tool = StructuredTool(
|
||||
name="anon_tool",
|
||||
description="",
|
||||
coroutine=_noop,
|
||||
args_schema=_NoArgs,
|
||||
metadata={"mcp_connector_id": None, "mcp_connector_name": ""},
|
||||
)
|
||||
ctx = build_permission_ask_payload(
|
||||
tool_name=tool.name,
|
||||
args={},
|
||||
patterns=[tool.name],
|
||||
rules=[_ask_rule(tool.name)],
|
||||
tool=tool,
|
||||
)["context"]
|
||||
assert "mcp_connector_id" not in ctx
|
||||
assert "mcp_server" not in ctx
|
||||
assert "tool_description" not in ctx
|
||||
|
||||
|
||||
class _State(TypedDict, total=False):
|
||||
messages: Annotated[list, add_messages]
|
||||
|
||||
|
||||
def _emit_tool_call(tool_name: str, args: dict[str, Any], call_id: str):
|
||||
def _node(_state: _State) -> dict[str, Any]:
|
||||
return {
|
||||
"messages": [
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"name": tool_name,
|
||||
"args": args,
|
||||
"id": call_id,
|
||||
"type": "tool_call",
|
||||
}
|
||||
],
|
||||
)
|
||||
]
|
||||
}
|
||||
|
||||
return _node
|
||||
|
||||
|
||||
def _compile_graph_with(pm, tool_name: str, args: dict[str, Any], call_id: str):
|
||||
def after(state: _State) -> dict[str, Any] | None:
|
||||
return pm.after_model(state, None) # type: ignore[arg-type]
|
||||
|
||||
g = StateGraph(_State)
|
||||
g.add_node("emit", _emit_tool_call(tool_name, args, call_id))
|
||||
g.add_node("permission", after)
|
||||
g.add_edge(START, "emit")
|
||||
g.add_edge("emit", "permission")
|
||||
g.add_edge("permission", END)
|
||||
return g.compile(checkpointer=InMemorySaver())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middleware_decorates_interrupt_with_mcp_tool_metadata():
|
||||
tool = _make_mcp_tool(
|
||||
name="linear_create_issue", connector_id=7, connector_name="Linear"
|
||||
)
|
||||
pm = build_permission_mw(
|
||||
flags=AgentFeatureFlags(enable_permission=False),
|
||||
subagent_rulesets=[
|
||||
Ruleset(origin="linear", rules=[_ask_rule(tool.name)]),
|
||||
],
|
||||
tools=[tool],
|
||||
)
|
||||
assert pm is not None
|
||||
|
||||
graph = _compile_graph_with(pm, tool.name, {"title": "bug"}, "call-1")
|
||||
config = {"configurable": {"thread_id": "linear-ask"}}
|
||||
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||
|
||||
snap = await graph.aget_state(config)
|
||||
assert len(snap.interrupts) == 1
|
||||
ctx = snap.interrupts[0].value["context"]
|
||||
assert ctx["mcp_connector_id"] == 7
|
||||
assert ctx["mcp_server"] == "Linear"
|
||||
assert ctx["tool_description"] == "Run linear_create_issue via MCP."
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middleware_without_tool_index_still_asks_without_tool_fields():
|
||||
pm = build_permission_mw(
|
||||
flags=AgentFeatureFlags(enable_permission=False),
|
||||
subagent_rulesets=[Ruleset(origin="kb", rules=[_ask_rule("rm")])],
|
||||
)
|
||||
assert pm is not None
|
||||
|
||||
graph = _compile_graph_with(pm, "rm", {"path": "/tmp/foo"}, "call-rm")
|
||||
config = {"configurable": {"thread_id": "kb-rm"}}
|
||||
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||
|
||||
snap = await graph.aget_state(config)
|
||||
assert len(snap.interrupts) == 1
|
||||
ctx = snap.interrupts[0].value["context"]
|
||||
assert "mcp_connector_id" not in ctx
|
||||
assert "mcp_server" not in ctx
|
||||
assert "tool_description" not in ctx
|
||||
|
|
@ -0,0 +1,166 @@
|
|||
"""Regression: subagent-owned rulesets layer cleanly into ``PermissionMiddleware``.
|
||||
|
||||
The KB unification swap (legacy ``interrupt_on`` map → KB-owned ``Ruleset``
|
||||
threaded through ``build_permission_mw(subagent_rulesets=...)``) must
|
||||
produce *exactly one* interrupt per destructive FS call, in LC HITL
|
||||
shape, even when ``enable_permission`` is False — destructive ops always
|
||||
ask.
|
||||
|
||||
We exercise the production factory and a real ``PermissionMiddleware`` on a
|
||||
real ``StateGraph`` so the test catches regressions in factory gating,
|
||||
ruleset layering, and interrupt emission together.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated, Any
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.graph.message import add_messages
|
||||
from langgraph.types import Command
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from app.agents.multi_agent_chat.middleware.shared.permissions import (
|
||||
build_permission_mw,
|
||||
)
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.permissions import Rule, Ruleset
|
||||
|
||||
|
||||
def _kb_style_ruleset() -> Ruleset:
|
||||
"""Mirror :data:`knowledge_base.agent.KB_RULESET` without importing it.
|
||||
|
||||
Importing the agent module pulls in deepagents and prompts; this test
|
||||
is about the factory + middleware contract, not KB wiring.
|
||||
"""
|
||||
return Ruleset(
|
||||
origin="knowledge_base",
|
||||
rules=[
|
||||
Rule(permission="rm", pattern="*", action="ask"),
|
||||
Rule(permission="rmdir", pattern="*", action="ask"),
|
||||
Rule(permission="move_file", pattern="*", action="ask"),
|
||||
Rule(permission="edit_file", pattern="*", action="ask"),
|
||||
Rule(permission="write_file", pattern="*", action="ask"),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class _State(TypedDict, total=False):
|
||||
messages: Annotated[list, add_messages]
|
||||
|
||||
|
||||
def _build_graph_with_permission_middleware(
|
||||
*,
|
||||
flags: AgentFeatureFlags,
|
||||
subagent_rulesets: list[Ruleset] | None,
|
||||
checkpointer: InMemorySaver,
|
||||
):
|
||||
"""Compile a one-node graph that emits a tool call for ``rm`` and
|
||||
routes through the production ``PermissionMiddleware``.
|
||||
|
||||
The node returns an ``AIMessage`` with a tool call. The middleware's
|
||||
``after_model`` hook intercepts and (if a rule says ``ask``) raises
|
||||
a ``GraphInterrupt`` carrying the LC HITL payload.
|
||||
"""
|
||||
pm = build_permission_mw(flags=flags, subagent_rulesets=subagent_rulesets)
|
||||
|
||||
def node(_state: _State) -> dict[str, Any]:
|
||||
msg = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"name": "rm",
|
||||
"args": {"path": "/tmp/foo"},
|
||||
"id": "call-rm-1",
|
||||
"type": "tool_call",
|
||||
}
|
||||
],
|
||||
)
|
||||
return {"messages": [msg]}
|
||||
|
||||
def after_node(state: _State) -> dict[str, Any] | None:
|
||||
if pm is None:
|
||||
return None
|
||||
return pm.after_model(state, None) # type: ignore[arg-type]
|
||||
|
||||
g = StateGraph(_State)
|
||||
g.add_node("emit", node)
|
||||
g.add_node("permission", after_node)
|
||||
g.add_edge(START, "emit")
|
||||
g.add_edge("emit", "permission")
|
||||
g.add_edge("permission", END)
|
||||
return g.compile(checkpointer=checkpointer), pm
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_kb_ruleset_raises_one_lc_hitl_ask_for_rm_even_when_permission_flag_off():
|
||||
"""KB ruleset: ``rm`` must ask once even with ``enable_permission=False``.
|
||||
|
||||
This is the keystone of the unification: the legacy ``interrupt_on``
|
||||
map fired regardless of ``enable_permission``, so the migrated rules
|
||||
must too. Otherwise users could opt out of "ask before rm".
|
||||
"""
|
||||
flags = AgentFeatureFlags(enable_permission=False)
|
||||
checkpointer = InMemorySaver()
|
||||
graph, pm = _build_graph_with_permission_middleware(
|
||||
flags=flags,
|
||||
subagent_rulesets=[_kb_style_ruleset()],
|
||||
checkpointer=checkpointer,
|
||||
)
|
||||
assert pm is not None, "subagent rulesets must force the middleware on"
|
||||
|
||||
config = {"configurable": {"thread_id": "kb-cloud-rm"}}
|
||||
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||
|
||||
snap = await graph.aget_state(config)
|
||||
assert len(snap.interrupts) == 1, (
|
||||
f"REGRESSION: KB ruleset should raise exactly one interrupt; got "
|
||||
f"{[i.value for i in snap.interrupts]!r}"
|
||||
)
|
||||
payload = snap.interrupts[0].value
|
||||
requests = payload.get("action_requests")
|
||||
assert requests == [{"name": "rm", "args": {"path": "/tmp/foo"}}], (
|
||||
f"interrupt must carry the rm call in LC HITL shape; got {payload!r}"
|
||||
)
|
||||
assert payload.get("interrupt_type") == "permission_ask"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_kb_ruleset_resume_with_approve_lets_rm_through():
|
||||
"""Resume with ``approve`` → call kept; the model continues normally."""
|
||||
flags = AgentFeatureFlags(enable_permission=False)
|
||||
checkpointer = InMemorySaver()
|
||||
graph, _ = _build_graph_with_permission_middleware(
|
||||
flags=flags,
|
||||
subagent_rulesets=[_kb_style_ruleset()],
|
||||
checkpointer=checkpointer,
|
||||
)
|
||||
config = {"configurable": {"thread_id": "kb-cloud-rm-approve"}}
|
||||
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||
|
||||
await graph.ainvoke(Command(resume={"decisions": [{"type": "approve"}]}), config)
|
||||
final = await graph.aget_state(config)
|
||||
assert final.next == (), "graph must complete after approve"
|
||||
last_ai = next(
|
||||
(m for m in reversed(final.values["messages"]) if isinstance(m, AIMessage)),
|
||||
None,
|
||||
)
|
||||
assert last_ai is not None
|
||||
assert [tc["name"] for tc in last_ai.tool_calls] == ["rm"], (
|
||||
"approved rm call must remain on the AIMessage so the tool can run"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_subagent_rulesets_with_permission_off_skips_middleware_entirely():
|
||||
"""No subagent rulesets + permission off → factory returns ``None`` (no engine).
|
||||
|
||||
The legacy gating is preserved when no caller asks for rules: nothing
|
||||
runs, nothing pauses.
|
||||
"""
|
||||
flags = AgentFeatureFlags(enable_permission=False)
|
||||
pm = build_permission_mw(flags=flags, subagent_rulesets=None)
|
||||
assert pm is None
|
||||
|
|
@ -0,0 +1,186 @@
|
|||
"""``approve_always`` decisions for MCP tools are saved via the trusted-tool saver."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated, Any
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.tools import StructuredTool
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.graph.message import add_messages
|
||||
from langgraph.types import Command
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from app.agents.multi_agent_chat.middleware.shared.permissions import (
|
||||
build_permission_mw,
|
||||
)
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.permissions import Rule, Ruleset
|
||||
|
||||
|
||||
class _NoArgs(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
async def _noop(**_kwargs) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
def _ask_rule(tool_name: str) -> Rule:
|
||||
return Rule(permission=tool_name, pattern="*", action="ask")
|
||||
|
||||
|
||||
def _make_mcp_tool(*, name: str, connector_id: int):
|
||||
return StructuredTool(
|
||||
name=name,
|
||||
description=f"Run {name} via MCP.",
|
||||
coroutine=_noop,
|
||||
args_schema=_NoArgs,
|
||||
metadata={
|
||||
"mcp_connector_id": connector_id,
|
||||
"mcp_connector_name": "Linear",
|
||||
"mcp_transport": "http",
|
||||
"hitl": True,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _make_native_tool(*, name: str):
|
||||
return StructuredTool(
|
||||
name=name,
|
||||
description=f"Native {name}.",
|
||||
coroutine=_noop,
|
||||
args_schema=_NoArgs,
|
||||
metadata={"hitl": True},
|
||||
)
|
||||
|
||||
|
||||
class _State(TypedDict, total=False):
|
||||
messages: Annotated[list, add_messages]
|
||||
|
||||
|
||||
def _build_graph(pm, tool_name: str):
|
||||
def emit(_state: _State) -> dict[str, Any]:
|
||||
return {
|
||||
"messages": [
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"name": tool_name,
|
||||
"args": {},
|
||||
"id": "call-1",
|
||||
"type": "tool_call",
|
||||
}
|
||||
],
|
||||
)
|
||||
]
|
||||
}
|
||||
|
||||
g = StateGraph(_State)
|
||||
g.add_node("emit", emit)
|
||||
g.add_node("permission", pm.aafter_model) # type: ignore[arg-type]
|
||||
g.add_edge(START, "emit")
|
||||
g.add_edge("emit", "permission")
|
||||
g.add_edge("permission", END)
|
||||
return g.compile(checkpointer=InMemorySaver())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_approve_always_decision_saves_mcp_tool_via_callback():
|
||||
saved: list[tuple[int, str]] = []
|
||||
|
||||
async def trusted_tool_saver(connector_id: int, tool_name: str) -> None:
|
||||
saved.append((connector_id, tool_name))
|
||||
|
||||
tool = _make_mcp_tool(name="linear_create_issue", connector_id=7)
|
||||
pm = build_permission_mw(
|
||||
flags=AgentFeatureFlags(enable_permission=False),
|
||||
subagent_rulesets=[Ruleset(origin="linear", rules=[_ask_rule(tool.name)])],
|
||||
tools=[tool],
|
||||
trusted_tool_saver=trusted_tool_saver,
|
||||
)
|
||||
assert pm is not None
|
||||
|
||||
graph = _build_graph(pm, tool.name)
|
||||
config = {"configurable": {"thread_id": "approve-always-mcp"}}
|
||||
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||
await graph.ainvoke(
|
||||
Command(resume={"decisions": [{"type": "approve_always"}]}), config
|
||||
)
|
||||
|
||||
assert saved == [(7, "linear_create_issue")]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_once_decision_does_not_save():
|
||||
saved: list[tuple[int, str]] = []
|
||||
|
||||
async def trusted_tool_saver(connector_id: int, tool_name: str) -> None:
|
||||
saved.append((connector_id, tool_name))
|
||||
|
||||
tool = _make_mcp_tool(name="linear_create_issue", connector_id=7)
|
||||
pm = build_permission_mw(
|
||||
flags=AgentFeatureFlags(enable_permission=False),
|
||||
subagent_rulesets=[Ruleset(origin="linear", rules=[_ask_rule(tool.name)])],
|
||||
tools=[tool],
|
||||
trusted_tool_saver=trusted_tool_saver,
|
||||
)
|
||||
assert pm is not None
|
||||
|
||||
graph = _build_graph(pm, tool.name)
|
||||
config = {"configurable": {"thread_id": "once-mcp"}}
|
||||
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||
await graph.ainvoke(Command(resume={"decisions": [{"type": "approve"}]}), config)
|
||||
|
||||
assert saved == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_approve_always_decision_for_native_tool_skips_save():
|
||||
"""Native tools have no ``mcp_connector_id`` so there is nowhere to persist trust."""
|
||||
saved: list[tuple[int, str]] = []
|
||||
|
||||
async def trusted_tool_saver(connector_id: int, tool_name: str) -> None:
|
||||
saved.append((connector_id, tool_name))
|
||||
|
||||
tool = _make_native_tool(name="rm")
|
||||
pm = build_permission_mw(
|
||||
flags=AgentFeatureFlags(enable_permission=False),
|
||||
subagent_rulesets=[Ruleset(origin="kb", rules=[_ask_rule(tool.name)])],
|
||||
tools=[tool],
|
||||
trusted_tool_saver=trusted_tool_saver,
|
||||
)
|
||||
assert pm is not None
|
||||
|
||||
graph = _build_graph(pm, tool.name)
|
||||
config = {"configurable": {"thread_id": "approve-always-native"}}
|
||||
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||
await graph.ainvoke(
|
||||
Command(resume={"decisions": [{"type": "approve_always"}]}), config
|
||||
)
|
||||
|
||||
assert saved == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_approve_always_decision_with_no_saver_callback_is_a_noop():
|
||||
"""Anonymous turns build the middleware without a ``trusted_tool_saver``; must not crash."""
|
||||
tool = _make_mcp_tool(name="linear_create_issue", connector_id=7)
|
||||
pm = build_permission_mw(
|
||||
flags=AgentFeatureFlags(enable_permission=False),
|
||||
subagent_rulesets=[Ruleset(origin="linear", rules=[_ask_rule(tool.name)])],
|
||||
tools=[tool],
|
||||
trusted_tool_saver=None,
|
||||
)
|
||||
assert pm is not None
|
||||
|
||||
graph = _build_graph(pm, tool.name)
|
||||
config = {"configurable": {"thread_id": "anon-approve-always"}}
|
||||
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||
await graph.ainvoke(
|
||||
Command(resume={"decisions": [{"type": "approve_always"}]}), config
|
||||
)
|
||||
|
|
@ -0,0 +1,132 @@
|
|||
"""Regression: ``request_approval`` must emit the unified LC HITL wire shape.
|
||||
|
||||
Before this fix, self-gated approvals fired the SurfSense-specific
|
||||
``{type, action, context}`` shape which the parallel-HITL routing layer
|
||||
(``collect_pending_tool_calls``) does not recognize. In a parallel HITL
|
||||
scenario where one subagent used self-gated approvals (e.g. Gmail send)
|
||||
and another used middleware-gated approvals (e.g. Linear via
|
||||
``HumanInTheLoopMiddleware``), the routing layer would silently skip the
|
||||
self-gated interrupt and crash on resume with ``Decision count mismatch``.
|
||||
|
||||
This test pins the wire contract by running ``request_approval`` inside a
|
||||
real ``StateGraph`` and asserting the paused parent observes the LC HITL
|
||||
shape (``action_requests``, ``review_configs``, ``interrupt_type``).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.types import Command
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import (
|
||||
request_approval,
|
||||
)
|
||||
|
||||
|
||||
class _State(TypedDict, total=False):
|
||||
messages: list
|
||||
final_decision_type: str
|
||||
final_params: dict
|
||||
|
||||
|
||||
def _build_graph_calling_request_approval(checkpointer: InMemorySaver):
|
||||
"""A real graph whose only node delegates to ``request_approval``."""
|
||||
|
||||
def gate_node(_state):
|
||||
result = request_approval(
|
||||
action_type="gmail_email_send",
|
||||
tool_name="send_gmail_email",
|
||||
params={"to": "alice@example.com", "subject": "hi"},
|
||||
context={"account": "alice@gmail.com"},
|
||||
)
|
||||
return {
|
||||
"final_decision_type": result.decision_type,
|
||||
"final_params": result.params,
|
||||
}
|
||||
|
||||
g = StateGraph(_State)
|
||||
g.add_node("gate", gate_node)
|
||||
g.add_edge(START, "gate")
|
||||
g.add_edge("gate", END)
|
||||
return g.compile(checkpointer=checkpointer)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_paused_interrupt_uses_lc_hitl_action_requests_shape():
|
||||
"""The paused interrupt must speak the langchain HITL standard shape."""
|
||||
checkpointer = InMemorySaver()
|
||||
graph = _build_graph_calling_request_approval(checkpointer)
|
||||
config = {"configurable": {"thread_id": "self-gated-wire"}}
|
||||
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||
|
||||
snap = await graph.aget_state(config)
|
||||
assert len(snap.interrupts) == 1, (
|
||||
f"expected one paused interrupt, got {len(snap.interrupts)}"
|
||||
)
|
||||
value = snap.interrupts[0].value
|
||||
assert isinstance(value, dict)
|
||||
|
||||
# Standard LC HITL fields the routing layer reads.
|
||||
assert value.get("action_requests") == [
|
||||
{
|
||||
"name": "send_gmail_email",
|
||||
"args": {"to": "alice@example.com", "subject": "hi"},
|
||||
}
|
||||
], (
|
||||
"REGRESSION: self-gated approval reverted to legacy SurfSense shape; "
|
||||
f"got {value!r}"
|
||||
)
|
||||
assert value.get("review_configs") == [
|
||||
{
|
||||
"action_name": "send_gmail_email",
|
||||
"allowed_decisions": ["approve", "reject", "edit"],
|
||||
}
|
||||
]
|
||||
assert value.get("interrupt_type") == "gmail_email_send", (
|
||||
"FE card discriminator must travel as ``interrupt_type``."
|
||||
)
|
||||
assert value.get("context") == {"account": "alice@gmail.com"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_with_lc_envelope_returns_hitl_result_with_edited_args():
|
||||
"""Edit reply via the LC envelope must round-trip into ``HITLResult.params``."""
|
||||
checkpointer = InMemorySaver()
|
||||
graph = _build_graph_calling_request_approval(checkpointer)
|
||||
config = {"configurable": {"thread_id": "self-gated-resume"}}
|
||||
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||
|
||||
edited = {"to": "alice@example.com", "subject": "EDITED"}
|
||||
await graph.ainvoke(
|
||||
Command(
|
||||
resume={
|
||||
"decisions": [
|
||||
{"type": "edit", "edited_action": {"args": {"subject": "EDITED"}}}
|
||||
]
|
||||
}
|
||||
),
|
||||
config,
|
||||
)
|
||||
final = await graph.aget_state(config)
|
||||
assert final.values.get("final_decision_type") == "edit"
|
||||
assert final.values.get("final_params") == edited
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reject_envelope_returns_rejected_hitl_result():
|
||||
"""Reject reply must surface as ``HITLResult.rejected=True`` without invoking the tool."""
|
||||
checkpointer = InMemorySaver()
|
||||
graph = _build_graph_calling_request_approval(checkpointer)
|
||||
config = {"configurable": {"thread_id": "self-gated-reject"}}
|
||||
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||
|
||||
await graph.ainvoke(
|
||||
Command(resume={"decisions": [{"type": "reject", "feedback": "no"}]}),
|
||||
config,
|
||||
)
|
||||
final = await graph.aget_state(config)
|
||||
assert final.values.get("final_decision_type") == "reject"
|
||||
|
|
@ -0,0 +1,170 @@
|
|||
"""Unit contract for the unified LC HITL wire format.
|
||||
|
||||
Both the self-gated approval primitive (``request_approval``) and the
|
||||
middleware-gated permission ask (``PermissionMiddleware``) must serialize
|
||||
to the same wire shape so the parallel-HITL routing layer
|
||||
(``collect_pending_tool_calls`` + ``slice_decisions_by_tool_call`` +
|
||||
``build_lg_resume_map``) sees one format.
|
||||
|
||||
These tests pin the shape:
|
||||
|
||||
- Builder always emits ``action_requests`` (1 entry) + ``review_configs``
|
||||
+ ``interrupt_type``; ``context`` rides through verbatim when present.
|
||||
- Parser tolerates the standard LC envelope, bare scalar strings, and
|
||||
unrecognized shapes (failing closed to ``reject``).
|
||||
- Edited args round-trip through both nested (``edited_action.args``) and
|
||||
flat (``args``) shapes without inventing values for the empty case.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.multi_agent_chat.subagents.shared.hitl.wire import (
|
||||
LC_DECISION_APPROVE,
|
||||
LC_DECISION_EDIT,
|
||||
LC_DECISION_REJECT,
|
||||
SURFSENSE_DECISION_APPROVE_ALWAYS,
|
||||
build_lc_hitl_payload,
|
||||
parse_lc_envelope,
|
||||
)
|
||||
|
||||
|
||||
class TestBuildLcHitlPayload:
|
||||
def test_minimal_payload_has_one_action_request_and_one_review_config(self):
|
||||
payload = build_lc_hitl_payload(
|
||||
tool_name="send_email",
|
||||
args={"to": "x@y.z"},
|
||||
allowed_decisions=[LC_DECISION_APPROVE, LC_DECISION_REJECT],
|
||||
interrupt_type="gmail_email_send",
|
||||
)
|
||||
assert payload["action_requests"] == [
|
||||
{"name": "send_email", "args": {"to": "x@y.z"}}
|
||||
]
|
||||
assert payload["review_configs"] == [
|
||||
{
|
||||
"action_name": "send_email",
|
||||
"allowed_decisions": [LC_DECISION_APPROVE, LC_DECISION_REJECT],
|
||||
}
|
||||
]
|
||||
assert payload["interrupt_type"] == "gmail_email_send"
|
||||
assert "context" not in payload, "context must be omitted when not provided"
|
||||
|
||||
def test_none_args_normalized_to_empty_dict(self):
|
||||
"""FE expects a stable shape; ``None`` would crash card rendering."""
|
||||
payload = build_lc_hitl_payload(
|
||||
tool_name="ping",
|
||||
args=None, # type: ignore[arg-type]
|
||||
allowed_decisions=[LC_DECISION_APPROVE],
|
||||
interrupt_type="self_gated",
|
||||
)
|
||||
assert payload["action_requests"][0]["args"] == {}
|
||||
|
||||
def test_description_attached_only_when_provided(self):
|
||||
with_desc = build_lc_hitl_payload(
|
||||
tool_name="t",
|
||||
args={},
|
||||
allowed_decisions=[LC_DECISION_APPROVE],
|
||||
interrupt_type="x",
|
||||
description="please review",
|
||||
)
|
||||
without = build_lc_hitl_payload(
|
||||
tool_name="t",
|
||||
args={},
|
||||
allowed_decisions=[LC_DECISION_APPROVE],
|
||||
interrupt_type="x",
|
||||
)
|
||||
assert with_desc["action_requests"][0]["description"] == "please review"
|
||||
assert "description" not in without["action_requests"][0]
|
||||
|
||||
def test_context_passed_through_verbatim(self):
|
||||
ctx = {"patterns": ["rm/*"], "rules": [], "always": ["rm/*"]}
|
||||
payload = build_lc_hitl_payload(
|
||||
tool_name="rm",
|
||||
args={"path": "/tmp"},
|
||||
allowed_decisions=[
|
||||
LC_DECISION_APPROVE,
|
||||
LC_DECISION_REJECT,
|
||||
SURFSENSE_DECISION_APPROVE_ALWAYS,
|
||||
],
|
||||
interrupt_type="permission_ask",
|
||||
context=ctx,
|
||||
)
|
||||
assert payload["context"] == ctx
|
||||
|
||||
def test_allowed_decisions_list_is_copied_not_aliased(self):
|
||||
"""A caller mutating their original list must not corrupt the payload."""
|
||||
decisions = [LC_DECISION_APPROVE]
|
||||
payload = build_lc_hitl_payload(
|
||||
tool_name="t",
|
||||
args={},
|
||||
allowed_decisions=decisions,
|
||||
interrupt_type="x",
|
||||
)
|
||||
decisions.append(LC_DECISION_REJECT)
|
||||
assert payload["review_configs"][0]["allowed_decisions"] == [
|
||||
LC_DECISION_APPROVE
|
||||
]
|
||||
|
||||
|
||||
class TestParseLcEnvelope:
|
||||
def test_standard_lc_envelope_returns_typed_decision(self):
|
||||
parsed = parse_lc_envelope({"decisions": [{"type": "approve"}]})
|
||||
assert parsed.decision_type == "approve"
|
||||
assert parsed.edited_args is None
|
||||
assert parsed.message is None
|
||||
|
||||
def test_bare_scalar_string_passes_through_lowercased(self):
|
||||
assert parse_lc_envelope("APPROVE_ALWAYS").decision_type == "approve_always"
|
||||
assert parse_lc_envelope("once").decision_type == "once"
|
||||
|
||||
def test_non_dict_non_string_collapses_to_reject(self):
|
||||
"""Failing closed: ambiguous input must never proceed."""
|
||||
assert parse_lc_envelope(42).decision_type == "reject"
|
||||
assert parse_lc_envelope(None).decision_type == "reject"
|
||||
assert parse_lc_envelope(["bogus"]).decision_type == "reject"
|
||||
|
||||
def test_missing_decision_type_collapses_to_reject(self):
|
||||
assert parse_lc_envelope({"decisions": [{}]}).decision_type == "reject"
|
||||
assert parse_lc_envelope({"foo": "bar"}).decision_type == "reject"
|
||||
|
||||
def test_edit_extracts_nested_args(self):
|
||||
parsed = parse_lc_envelope(
|
||||
{
|
||||
"decisions": [
|
||||
{
|
||||
"type": LC_DECISION_EDIT,
|
||||
"edited_action": {"args": {"to": "edited@y.z"}},
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
assert parsed.decision_type == "edit"
|
||||
assert parsed.edited_args == {"to": "edited@y.z"}
|
||||
|
||||
def test_edit_falls_back_to_flat_args(self):
|
||||
parsed = parse_lc_envelope(
|
||||
{"decisions": [{"type": "edit", "args": {"k": "v"}}]}
|
||||
)
|
||||
assert parsed.edited_args == {"k": "v"}
|
||||
|
||||
def test_edit_with_empty_args_yields_none_edited(self):
|
||||
"""Empty edited_args means "no edits" — caller treats as plain approve."""
|
||||
parsed = parse_lc_envelope(
|
||||
{"decisions": [{"type": "edit", "edited_action": {"args": {}}}]}
|
||||
)
|
||||
assert parsed.edited_args is None
|
||||
|
||||
def test_message_picked_from_either_feedback_or_message_field(self):
|
||||
with_feedback = parse_lc_envelope(
|
||||
{"decisions": [{"type": "reject", "feedback": "no thanks"}]}
|
||||
)
|
||||
with_message = parse_lc_envelope(
|
||||
{"decisions": [{"type": "reject", "message": "no thanks"}]}
|
||||
)
|
||||
assert with_feedback.message == "no thanks"
|
||||
assert with_message.message == "no thanks"
|
||||
|
||||
def test_blank_message_treated_as_absent(self):
|
||||
parsed = parse_lc_envelope(
|
||||
{"decisions": [{"type": "reject", "message": " "}]}
|
||||
)
|
||||
assert parsed.message is None
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
"""Subagent resilience contract: ``extra_middleware`` reaches the agent chain."""
|
||||
"""Subagent resilience contract: ``middleware_stack`` reaches the agent chain."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
|
@ -19,9 +19,14 @@ from langchain_core.language_models.fake_chat_models import (
|
|||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
|
||||
from app.agents.multi_agent_chat.middleware.shared.permissions.middleware.core import (
|
||||
PermissionMiddleware,
|
||||
)
|
||||
from app.agents.multi_agent_chat.subagents.shared.subagent_builder import (
|
||||
pack_subagent,
|
||||
)
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.permissions import Rule, Ruleset, evaluate
|
||||
|
||||
|
||||
class RateLimitError(Exception):
|
||||
|
|
@ -67,20 +72,23 @@ class _AlwaysFailingChatModel(BaseChatModel):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subagent_recovers_when_primary_llm_fails():
|
||||
"""Fallback in ``extra_middleware`` must finish the turn when primary raises."""
|
||||
"""Fallback in ``middleware_stack`` must finish the turn when primary raises."""
|
||||
primary = _AlwaysFailingChatModel()
|
||||
fallback = FakeMessagesListChatModel(
|
||||
responses=[AIMessage(content="recovered via fallback")]
|
||||
)
|
||||
|
||||
spec = pack_subagent(
|
||||
result = pack_subagent(
|
||||
name="resilience_test",
|
||||
description="test subagent",
|
||||
system_prompt="be helpful",
|
||||
tools=[],
|
||||
ruleset=Ruleset(origin="resilience_test", rules=[]),
|
||||
dependencies={"flags": AgentFeatureFlags()},
|
||||
model=primary,
|
||||
extra_middleware=[ModelFallbackMiddleware(fallback)],
|
||||
middleware_stack={"fallback": ModelFallbackMiddleware(fallback)},
|
||||
)
|
||||
spec = result.spec
|
||||
|
||||
agent = create_agent(
|
||||
model=spec["model"],
|
||||
|
|
@ -94,3 +102,142 @@ async def test_subagent_recovers_when_primary_llm_fails():
|
|||
final = result["messages"][-1]
|
||||
assert isinstance(final, AIMessage)
|
||||
assert final.content == "recovered via fallback"
|
||||
|
||||
|
||||
def _extract_permission_mw(spec) -> PermissionMiddleware:
|
||||
"""Find the lone PermissionMiddleware in a subagent's middleware list."""
|
||||
matches = [m for m in spec["middleware"] if isinstance(m, PermissionMiddleware)]
|
||||
assert len(matches) == 1, "expected exactly one PermissionMiddleware"
|
||||
return matches[0]
|
||||
|
||||
|
||||
def test_user_allowlist_overrides_coded_ask_via_last_match_wins():
|
||||
"""User ``allow`` rules promoted via "Always Allow" must beat coded ``ask`` rules."""
|
||||
coded = Ruleset(
|
||||
origin="connector",
|
||||
rules=[Rule(permission="save_issue", pattern="*", action="ask")],
|
||||
)
|
||||
user_allowlist = Ruleset(
|
||||
origin="user_allowlist:connector",
|
||||
rules=[Rule(permission="save_issue", pattern="*", action="allow")],
|
||||
)
|
||||
|
||||
result = pack_subagent(
|
||||
name="connector",
|
||||
description="test connector",
|
||||
system_prompt="x",
|
||||
tools=[],
|
||||
ruleset=coded,
|
||||
dependencies={
|
||||
"flags": AgentFeatureFlags(),
|
||||
"user_allowlist_by_subagent": {"connector": user_allowlist},
|
||||
},
|
||||
)
|
||||
|
||||
mw = _extract_permission_mw(result.spec)
|
||||
decided = evaluate("save_issue", "*", *mw._static_rulesets)
|
||||
assert decided.action == "allow", (
|
||||
f"user_allowlist must override coded ask; got {decided!r}"
|
||||
)
|
||||
|
||||
|
||||
def test_coded_ask_stays_when_user_allowlist_unrelated():
|
||||
"""User ``allow`` rules for OTHER tools must not leak into asked-tools."""
|
||||
coded = Ruleset(
|
||||
origin="connector",
|
||||
rules=[Rule(permission="delete_issue", pattern="*", action="ask")],
|
||||
)
|
||||
user_allowlist = Ruleset(
|
||||
origin="user_allowlist:connector",
|
||||
rules=[Rule(permission="save_issue", pattern="*", action="allow")],
|
||||
)
|
||||
|
||||
result = pack_subagent(
|
||||
name="connector",
|
||||
description="test",
|
||||
system_prompt="x",
|
||||
tools=[],
|
||||
ruleset=coded,
|
||||
dependencies={
|
||||
"flags": AgentFeatureFlags(),
|
||||
"user_allowlist_by_subagent": {"connector": user_allowlist},
|
||||
},
|
||||
)
|
||||
|
||||
mw = _extract_permission_mw(result.spec)
|
||||
decided = evaluate("delete_issue", "*", *mw._static_rulesets)
|
||||
assert decided.action == "ask"
|
||||
|
||||
|
||||
def test_missing_user_allowlist_keeps_coded_behaviour():
|
||||
"""``dependencies`` without ``user_allowlist_by_subagent`` is the common case."""
|
||||
coded = Ruleset(
|
||||
origin="connector",
|
||||
rules=[Rule(permission="save_issue", pattern="*", action="ask")],
|
||||
)
|
||||
|
||||
result = pack_subagent(
|
||||
name="connector",
|
||||
description="test",
|
||||
system_prompt="x",
|
||||
tools=[],
|
||||
ruleset=coded,
|
||||
dependencies={"flags": AgentFeatureFlags()},
|
||||
)
|
||||
|
||||
mw = _extract_permission_mw(result.spec)
|
||||
decided = evaluate("save_issue", "*", *mw._static_rulesets)
|
||||
assert decided.action == "ask"
|
||||
|
||||
|
||||
def test_user_allowlist_for_different_subagent_does_not_leak():
|
||||
"""User trust for ``linear`` must not affect a ``jira`` subagent compile."""
|
||||
coded = Ruleset(
|
||||
origin="jira",
|
||||
rules=[Rule(permission="save_issue", pattern="*", action="ask")],
|
||||
)
|
||||
linear_allowlist = Ruleset(
|
||||
origin="user_allowlist:linear",
|
||||
rules=[Rule(permission="save_issue", pattern="*", action="allow")],
|
||||
)
|
||||
|
||||
result = pack_subagent(
|
||||
name="jira",
|
||||
description="test",
|
||||
system_prompt="x",
|
||||
tools=[],
|
||||
ruleset=coded,
|
||||
dependencies={
|
||||
"flags": AgentFeatureFlags(),
|
||||
"user_allowlist_by_subagent": {"linear": linear_allowlist},
|
||||
},
|
||||
)
|
||||
|
||||
mw = _extract_permission_mw(result.spec)
|
||||
decided = evaluate("save_issue", "*", *mw._static_rulesets)
|
||||
assert decided.action == "ask"
|
||||
|
||||
|
||||
def test_empty_user_allowlist_is_tolerated():
|
||||
"""An empty ``Ruleset`` (no rules) must not flip evaluation to allow-everything."""
|
||||
coded = Ruleset(
|
||||
origin="connector",
|
||||
rules=[Rule(permission="save_issue", pattern="*", action="ask")],
|
||||
)
|
||||
empty = Ruleset(origin="user_allowlist:connector", rules=[])
|
||||
|
||||
result = pack_subagent(
|
||||
name="connector",
|
||||
description="test",
|
||||
system_prompt="x",
|
||||
tools=[],
|
||||
ruleset=coded,
|
||||
dependencies={
|
||||
"flags": AgentFeatureFlags(),
|
||||
"user_allowlist_by_subagent": {"connector": empty},
|
||||
},
|
||||
)
|
||||
|
||||
mw = _extract_permission_mw(result.spec)
|
||||
decided = evaluate("save_issue", "*", *mw._static_rulesets)
|
||||
assert decided.action == "ask"
|
||||
|
|
|
|||
|
|
@ -106,9 +106,9 @@ class TestAsk:
|
|||
# No new rule persisted
|
||||
assert mw._runtime_ruleset.rules == []
|
||||
|
||||
def test_always_persists_runtime_rule(self) -> None:
|
||||
def test_approve_always_persists_runtime_rule(self) -> None:
|
||||
mw = PermissionMiddleware(rulesets=[])
|
||||
mw._raise_interrupt = lambda **kw: {"decision_type": "always"} # type: ignore[assignment]
|
||||
mw._raise_interrupt = lambda **kw: {"decision_type": "approve_always"} # type: ignore[assignment]
|
||||
state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]}
|
||||
out = mw.after_model(state, _FakeRuntime())
|
||||
assert out is None # call kept
|
||||
|
|
|
|||
|
|
@ -741,6 +741,366 @@ async def test_extract_image_falls_back_to_document_without_vision_llm(
|
|||
assert result.content_type == "document"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Document path with vision LLM: per-image descriptions are appended
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _fake_extraction_result(*descriptions):
|
||||
from app.etl_pipeline.picture_describer import (
|
||||
PictureDescription,
|
||||
PictureExtractionResult,
|
||||
)
|
||||
|
||||
return PictureExtractionResult(
|
||||
descriptions=[
|
||||
PictureDescription(
|
||||
page_number=d["page"],
|
||||
ordinal_in_page=d.get("ordinal", 0),
|
||||
name=d["name"],
|
||||
sha256=d.get("sha", "deadbeef"),
|
||||
description=d["desc"],
|
||||
)
|
||||
for d in descriptions
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
async def test_extract_pdf_with_vision_llm_inlines_image_blocks(tmp_path, mocker):
|
||||
"""A PDF with an `<!-- image -->` placeholder + caption gets the
|
||||
block spliced inline (no orphaned ``## Image Content`` section).
|
||||
|
||||
This is the headline scenario for the medxpertqa benchmark: the
|
||||
image content lives in the same chunk as the surrounding case text
|
||||
so retrieval pulls the question, image, and answer options together.
|
||||
"""
|
||||
pdf_file = tmp_path / "report.pdf"
|
||||
pdf_file.write_bytes(b"%PDF-1.4 fake content")
|
||||
|
||||
mocker.patch("app.config.config.ETL_SERVICE", "DOCLING")
|
||||
|
||||
fake_docling = mocker.AsyncMock()
|
||||
fake_docling.process_document.return_value = {
|
||||
"content": (
|
||||
"# MedXpertQA-MM MM-130\n\n"
|
||||
"## Clinical case\n\nA 44-year-old man...\n\n"
|
||||
"<!-- image -->\nImage: MM-130-a.jpeg\n\n"
|
||||
"## Answer choices\n\nA) ...\n"
|
||||
)
|
||||
}
|
||||
mocker.patch(
|
||||
"app.services.docling_service.create_docling_service",
|
||||
return_value=fake_docling,
|
||||
)
|
||||
|
||||
extraction = _fake_extraction_result(
|
||||
{
|
||||
"page": 1,
|
||||
"name": "Im0",
|
||||
"desc": "Axial CT showing a large cystic mass.",
|
||||
}
|
||||
)
|
||||
mocker.patch(
|
||||
"app.etl_pipeline.picture_describer.describe_pictures",
|
||||
new=mocker.AsyncMock(return_value=extraction),
|
||||
)
|
||||
|
||||
fake_llm = mocker.MagicMock()
|
||||
result = await EtlPipelineService(vision_llm=fake_llm).extract(
|
||||
EtlRequest(file_path=str(pdf_file), filename="report.pdf")
|
||||
)
|
||||
|
||||
md = result.markdown_content
|
||||
# The placeholder + caption are gone, replaced by a horizontal-
|
||||
# rule-delimited section with the captioned filename.
|
||||
assert "<!-- image -->" not in md
|
||||
assert "Image: MM-130-a.jpeg" not in md
|
||||
assert "**Embedded image:** `MM-130-a.jpeg`" in md
|
||||
assert "**Visual description:**" in md
|
||||
assert "Axial CT showing a large cystic mass." in md
|
||||
# No OCR section -- our fake_extraction_result has no ocr_text,
|
||||
# and the format omits the section when there's no text to show.
|
||||
assert "**OCR text:**" not in md
|
||||
# No raw HTML / XML tags or blockquote wrapping leak.
|
||||
assert "<image" not in md
|
||||
assert "> **Embedded image:**" not in md
|
||||
# No appended section -- everything went inline.
|
||||
assert "## Image Content" not in md
|
||||
# Surrounding case text + answer options are preserved.
|
||||
assert "A 44-year-old man..." in md
|
||||
assert "## Answer choices" in md
|
||||
assert "A) ..." in md
|
||||
|
||||
|
||||
async def test_extract_pdf_with_vision_llm_appends_when_no_marker(tmp_path, mocker):
|
||||
"""When parser markdown has no image markers, descriptions get appended.
|
||||
|
||||
This is the fallback path for parsers that drop image placeholders
|
||||
entirely. The image content still ends up in the markdown -- just
|
||||
in a clearly-labeled section rather than inline.
|
||||
"""
|
||||
pdf_file = tmp_path / "report.pdf"
|
||||
pdf_file.write_bytes(b"%PDF-1.4 fake content")
|
||||
|
||||
mocker.patch("app.config.config.ETL_SERVICE", "DOCLING")
|
||||
|
||||
fake_docling = mocker.AsyncMock()
|
||||
fake_docling.process_document.return_value = {
|
||||
"content": "# Parsed PDF text\n\nNo image markers anywhere.\n"
|
||||
}
|
||||
mocker.patch(
|
||||
"app.services.docling_service.create_docling_service",
|
||||
return_value=fake_docling,
|
||||
)
|
||||
|
||||
extraction = _fake_extraction_result(
|
||||
{"page": 1, "name": "Im0", "desc": "An image description."}
|
||||
)
|
||||
mocker.patch(
|
||||
"app.etl_pipeline.picture_describer.describe_pictures",
|
||||
new=mocker.AsyncMock(return_value=extraction),
|
||||
)
|
||||
|
||||
fake_llm = mocker.MagicMock()
|
||||
result = await EtlPipelineService(vision_llm=fake_llm).extract(
|
||||
EtlRequest(file_path=str(pdf_file), filename="report.pdf")
|
||||
)
|
||||
|
||||
md = result.markdown_content
|
||||
assert "# Parsed PDF text" in md
|
||||
assert "## Image Content (vision-LLM extracted)" in md
|
||||
assert "**Embedded image:** `Im0`" in md
|
||||
assert "An image description." in md
|
||||
|
||||
|
||||
async def test_extract_pdf_without_vision_llm_skips_picture_descriptions(
|
||||
tmp_path, mocker
|
||||
):
|
||||
"""No vision LLM -> parser markdown returned as-is."""
|
||||
pdf_file = tmp_path / "report.pdf"
|
||||
pdf_file.write_bytes(b"%PDF-1.4 fake content")
|
||||
|
||||
mocker.patch("app.config.config.ETL_SERVICE", "DOCLING")
|
||||
|
||||
fake_docling = mocker.AsyncMock()
|
||||
fake_docling.process_document.return_value = {"content": "# Parsed PDF text"}
|
||||
mocker.patch(
|
||||
"app.services.docling_service.create_docling_service",
|
||||
return_value=fake_docling,
|
||||
)
|
||||
|
||||
describe_mock = mocker.patch(
|
||||
"app.etl_pipeline.picture_describer.describe_pictures",
|
||||
new=mocker.AsyncMock(),
|
||||
)
|
||||
|
||||
result = await EtlPipelineService().extract(
|
||||
EtlRequest(file_path=str(pdf_file), filename="report.pdf")
|
||||
)
|
||||
|
||||
assert result.markdown_content == "# Parsed PDF text"
|
||||
assert "<image" not in result.markdown_content
|
||||
describe_mock.assert_not_called()
|
||||
|
||||
|
||||
async def test_extract_pdf_with_vision_llm_swallows_describe_failure(tmp_path, mocker):
|
||||
"""A pypdf or vision LLM blow-up never fails the document upload."""
|
||||
pdf_file = tmp_path / "report.pdf"
|
||||
pdf_file.write_bytes(b"%PDF-1.4 fake content")
|
||||
|
||||
mocker.patch("app.config.config.ETL_SERVICE", "DOCLING")
|
||||
|
||||
fake_docling = mocker.AsyncMock()
|
||||
fake_docling.process_document.return_value = {"content": "# Parsed PDF text"}
|
||||
mocker.patch(
|
||||
"app.services.docling_service.create_docling_service",
|
||||
return_value=fake_docling,
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"app.etl_pipeline.picture_describer.describe_pictures",
|
||||
new=mocker.AsyncMock(side_effect=RuntimeError("pypdf exploded")),
|
||||
)
|
||||
|
||||
fake_llm = mocker.MagicMock()
|
||||
result = await EtlPipelineService(vision_llm=fake_llm).extract(
|
||||
EtlRequest(file_path=str(pdf_file), filename="report.pdf")
|
||||
)
|
||||
|
||||
assert result.markdown_content == "# Parsed PDF text"
|
||||
assert result.etl_service == "DOCLING"
|
||||
|
||||
|
||||
async def test_extract_pdf_with_vision_llm_no_images_returns_parser_text(
|
||||
tmp_path, mocker
|
||||
):
|
||||
"""Vision-LLM-enabled PDF with zero extracted images is unchanged."""
|
||||
pdf_file = tmp_path / "report.pdf"
|
||||
pdf_file.write_bytes(b"%PDF-1.4 fake content")
|
||||
|
||||
mocker.patch("app.config.config.ETL_SERVICE", "DOCLING")
|
||||
|
||||
fake_docling = mocker.AsyncMock()
|
||||
fake_docling.process_document.return_value = {"content": "# Just text, no images"}
|
||||
mocker.patch(
|
||||
"app.services.docling_service.create_docling_service",
|
||||
return_value=fake_docling,
|
||||
)
|
||||
|
||||
empty = _fake_extraction_result()
|
||||
mocker.patch(
|
||||
"app.etl_pipeline.picture_describer.describe_pictures",
|
||||
new=mocker.AsyncMock(return_value=empty),
|
||||
)
|
||||
|
||||
fake_llm = mocker.MagicMock()
|
||||
result = await EtlPipelineService(vision_llm=fake_llm).extract(
|
||||
EtlRequest(file_path=str(pdf_file), filename="report.pdf")
|
||||
)
|
||||
|
||||
assert result.markdown_content == "# Just text, no images"
|
||||
assert "<image" not in result.markdown_content
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-image OCR runner: wiring + behaviour
|
||||
#
|
||||
# When extracting a PDF with a vision LLM, the ETL service must ALSO
|
||||
# pass an ``ocr_runner`` to picture_describer. The runner is a closure
|
||||
# that re-feeds each extracted image through a vision-LLM-less
|
||||
# EtlPipelineService -- i.e. the same OCR engine that handles
|
||||
# standalone image uploads (Docling/Azure DI/LlamaCloud) gets a crack
|
||||
# at each embedded image, with the text attached to the inline block.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_extract_pdf_passes_ocr_runner_to_describe_pictures(tmp_path, mocker):
|
||||
"""The ETL service must wire an ocr_runner kwarg to describe_pictures."""
|
||||
pdf_file = tmp_path / "report.pdf"
|
||||
pdf_file.write_bytes(b"%PDF-1.4 fake content")
|
||||
|
||||
mocker.patch("app.config.config.ETL_SERVICE", "DOCLING")
|
||||
|
||||
fake_docling = mocker.AsyncMock()
|
||||
fake_docling.process_document.return_value = {"content": "# Parsed PDF text"}
|
||||
mocker.patch(
|
||||
"app.services.docling_service.create_docling_service",
|
||||
return_value=fake_docling,
|
||||
)
|
||||
|
||||
describe_mock = mocker.patch(
|
||||
"app.etl_pipeline.picture_describer.describe_pictures",
|
||||
new=mocker.AsyncMock(return_value=_fake_extraction_result()),
|
||||
)
|
||||
|
||||
fake_llm = mocker.MagicMock()
|
||||
await EtlPipelineService(vision_llm=fake_llm).extract(
|
||||
EtlRequest(file_path=str(pdf_file), filename="report.pdf")
|
||||
)
|
||||
|
||||
describe_mock.assert_awaited_once()
|
||||
_, kwargs = describe_mock.await_args
|
||||
assert "ocr_runner" in kwargs
|
||||
assert callable(kwargs["ocr_runner"])
|
||||
|
||||
|
||||
async def test_extract_pdf_ocr_runner_invokes_document_parser_on_image(
|
||||
tmp_path, mocker
|
||||
):
|
||||
"""The OCR runner closure should re-extract each image via the parser.
|
||||
|
||||
We capture the runner that the ETL service passes to
|
||||
describe_pictures, invoke it with a fake image path, and assert
|
||||
that Docling was called with that image. This proves the closure
|
||||
is wired to a vision-LLM-less sub-pipeline (otherwise it would
|
||||
recurse into the vision LLM and never hit the OCR engine).
|
||||
"""
|
||||
pdf_file = tmp_path / "report.pdf"
|
||||
pdf_file.write_bytes(b"%PDF-1.4 fake content")
|
||||
image_file = tmp_path / "Im0.png"
|
||||
image_file.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 100)
|
||||
|
||||
mocker.patch("app.config.config.ETL_SERVICE", "DOCLING")
|
||||
|
||||
fake_docling = mocker.AsyncMock()
|
||||
fake_docling.process_document.return_value = {"content": "Slice 24 / 60 L R"}
|
||||
mocker.patch(
|
||||
"app.services.docling_service.create_docling_service",
|
||||
return_value=fake_docling,
|
||||
)
|
||||
|
||||
captured: dict = {}
|
||||
|
||||
async def capture_runner(*args, **kwargs):
|
||||
captured["runner"] = kwargs["ocr_runner"]
|
||||
return _fake_extraction_result()
|
||||
|
||||
mocker.patch(
|
||||
"app.etl_pipeline.picture_describer.describe_pictures",
|
||||
new=capture_runner,
|
||||
)
|
||||
|
||||
fake_llm = mocker.MagicMock()
|
||||
await EtlPipelineService(vision_llm=fake_llm).extract(
|
||||
EtlRequest(file_path=str(pdf_file), filename="report.pdf")
|
||||
)
|
||||
|
||||
runner = captured["runner"]
|
||||
ocr_text = await runner(str(image_file), "Im0.png")
|
||||
|
||||
assert ocr_text == "Slice 24 / 60 L R"
|
||||
# Docling was invoked twice in total: once for the PDF, once for
|
||||
# the image we re-fed via the runner.
|
||||
assert fake_docling.process_document.await_count == 2
|
||||
|
||||
|
||||
async def test_extract_pdf_ocr_runner_returns_empty_on_unsupported_image(
|
||||
tmp_path, mocker
|
||||
):
|
||||
"""Unsupported image format → runner returns empty string, doesn't raise.
|
||||
|
||||
Common case: a PDF embeds a JPEG2000 or CCITT-TIFF image that
|
||||
Docling can't load. We don't want an unsupported format on ONE
|
||||
embedded image to spoil the whole PDF extraction; the runner
|
||||
should swallow the EtlUnsupportedFileError and return "" so the
|
||||
image gets a description but no OCR tag.
|
||||
"""
|
||||
pdf_file = tmp_path / "report.pdf"
|
||||
pdf_file.write_bytes(b"%PDF-1.4 fake content")
|
||||
weird_image = tmp_path / "Im0.jp2" # JPEG2000, unlikely to be supported
|
||||
weird_image.write_bytes(b"\x00\x00\x00\x0cjP" + b"\x00" * 50)
|
||||
|
||||
mocker.patch("app.config.config.ETL_SERVICE", "DOCLING")
|
||||
|
||||
fake_docling = mocker.AsyncMock()
|
||||
fake_docling.process_document.return_value = {"content": "# Parsed PDF text"}
|
||||
mocker.patch(
|
||||
"app.services.docling_service.create_docling_service",
|
||||
return_value=fake_docling,
|
||||
)
|
||||
|
||||
captured: dict = {}
|
||||
|
||||
async def capture_runner(*args, **kwargs):
|
||||
captured["runner"] = kwargs["ocr_runner"]
|
||||
return _fake_extraction_result()
|
||||
|
||||
mocker.patch(
|
||||
"app.etl_pipeline.picture_describer.describe_pictures",
|
||||
new=capture_runner,
|
||||
)
|
||||
|
||||
fake_llm = mocker.MagicMock()
|
||||
await EtlPipelineService(vision_llm=fake_llm).extract(
|
||||
EtlRequest(file_path=str(pdf_file), filename="report.pdf")
|
||||
)
|
||||
|
||||
runner = captured["runner"]
|
||||
ocr_text = await runner(str(weird_image), "Im0.jp2")
|
||||
|
||||
assert ocr_text == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Processing Mode enum tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -0,0 +1,972 @@
|
|||
"""Unit tests for the picture_describer module.
|
||||
|
||||
Covers:
|
||||
|
||||
- :func:`describe_pictures` -- the PDF image walker + per-image vision
|
||||
LLM call (structured output split into ``ocr_text`` and
|
||||
``description``);
|
||||
- :func:`inject_descriptions_inline` -- in-place replacement of image
|
||||
placeholders / captions in the parser markdown;
|
||||
- :func:`merge_descriptions_into_markdown` -- the top-level helper
|
||||
that inlines what it can and appends what it can't;
|
||||
- :func:`render_appended_section` -- the appended-fallback renderer.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.etl_pipeline.picture_describer import (
|
||||
PictureDescription,
|
||||
PictureExtractionResult,
|
||||
describe_pictures,
|
||||
inject_descriptions_inline,
|
||||
merge_descriptions_into_markdown,
|
||||
render_appended_section,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _make_image_obj(name: str, data: bytes):
|
||||
"""Mimic pypdf's ImageFile object shape for the bits we use."""
|
||||
img = MagicMock()
|
||||
img.name = name
|
||||
img.data = data
|
||||
return img
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# describe_pictures: short-circuits
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_describe_pictures_no_op_for_non_pdf(tmp_path):
|
||||
"""Non-PDF files are silently no-op'd; we don't try to extract images."""
|
||||
docx_file = tmp_path / "report.docx"
|
||||
docx_file.write_bytes(b"PK fake docx")
|
||||
|
||||
fake_llm = AsyncMock()
|
||||
result = await describe_pictures(str(docx_file), "report.docx", fake_llm)
|
||||
|
||||
assert result.descriptions == []
|
||||
assert result.skipped_too_large == 0
|
||||
fake_llm.ainvoke.assert_not_called()
|
||||
|
||||
|
||||
async def test_describe_pictures_no_op_when_vision_llm_is_none(tmp_path):
|
||||
"""If the caller didn't provide a vision LLM, we no-op even for PDFs."""
|
||||
pdf_file = tmp_path / "report.pdf"
|
||||
pdf_file.write_bytes(b"%PDF-1.4 fake")
|
||||
|
||||
result = await describe_pictures(str(pdf_file), "report.pdf", None)
|
||||
assert result.descriptions == []
|
||||
|
||||
|
||||
async def test_describe_pictures_no_op_for_pdf_with_no_images(tmp_path, mocker):
|
||||
"""A PDF that pypdf can open but contains zero images returns empty."""
|
||||
pdf_file = tmp_path / "report.pdf"
|
||||
pdf_file.write_bytes(b"%PDF-1.4 fake")
|
||||
|
||||
fake_reader = MagicMock()
|
||||
fake_reader.pages = [MagicMock(images=[]), MagicMock(images=[])]
|
||||
mocker.patch("pypdf.PdfReader", return_value=fake_reader)
|
||||
|
||||
fake_llm = AsyncMock()
|
||||
result = await describe_pictures(str(pdf_file), "report.pdf", fake_llm)
|
||||
|
||||
assert result.descriptions == []
|
||||
fake_llm.ainvoke.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# describe_pictures: happy paths
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_describe_pictures_runs_vision_llm_per_image(tmp_path, mocker):
|
||||
"""Every eligible image gets exactly one description-only vision call."""
|
||||
pdf_file = tmp_path / "report.pdf"
|
||||
pdf_file.write_bytes(b"%PDF-1.4 fake")
|
||||
|
||||
img_a = _make_image_obj("Im0.jpeg", b"\xff\xd8\xff\xe0" + b"\xab" * 2000)
|
||||
img_b = _make_image_obj("Im1.png", b"\x89PNG\r\n\x1a\n" + b"\xcd" * 2000)
|
||||
page1 = MagicMock(images=[img_a])
|
||||
page2 = MagicMock(images=[img_b])
|
||||
|
||||
fake_reader = MagicMock()
|
||||
fake_reader.pages = [page1, page2]
|
||||
mocker.patch("pypdf.PdfReader", return_value=fake_reader)
|
||||
|
||||
parse_mock = mocker.patch(
|
||||
"app.etl_pipeline.parsers.vision_llm.parse_image_for_description",
|
||||
new=AsyncMock(side_effect=["Description A", "Description B"]),
|
||||
)
|
||||
|
||||
fake_llm = MagicMock()
|
||||
result = await describe_pictures(str(pdf_file), "report.pdf", fake_llm)
|
||||
|
||||
assert len(result.descriptions) == 2
|
||||
by_name = {d.name: d.description for d in result.descriptions}
|
||||
assert by_name == {"Im0.jpeg": "Description A", "Im1.png": "Description B"}
|
||||
assert all(d.page_number in (1, 2) for d in result.descriptions)
|
||||
assert parse_mock.await_count == 2
|
||||
|
||||
|
||||
async def test_describe_pictures_dedups_by_hash(tmp_path, mocker):
|
||||
"""An image that appears N times in the PDF is described once."""
|
||||
pdf_file = tmp_path / "report.pdf"
|
||||
pdf_file.write_bytes(b"%PDF-1.4 fake")
|
||||
|
||||
payload = b"\x89PNG\r\n\x1a\n" + b"\x42" * 2000
|
||||
img = _make_image_obj("logo.png", payload)
|
||||
page1 = MagicMock(images=[img])
|
||||
page2 = MagicMock(images=[_make_image_obj("logo.png", payload)])
|
||||
page3 = MagicMock(images=[_make_image_obj("logo.png", payload)])
|
||||
|
||||
fake_reader = MagicMock()
|
||||
fake_reader.pages = [page1, page2, page3]
|
||||
mocker.patch("pypdf.PdfReader", return_value=fake_reader)
|
||||
|
||||
parse_mock = mocker.patch(
|
||||
"app.etl_pipeline.parsers.vision_llm.parse_image_for_description",
|
||||
new=AsyncMock(return_value="Logo desc"),
|
||||
)
|
||||
|
||||
fake_llm = MagicMock()
|
||||
result = await describe_pictures(str(pdf_file), "report.pdf", fake_llm)
|
||||
|
||||
assert len(result.descriptions) == 1
|
||||
assert result.skipped_duplicate == 2
|
||||
assert parse_mock.await_count == 1
|
||||
|
||||
|
||||
async def test_describe_pictures_skips_too_small_images(tmp_path, mocker):
|
||||
"""Sub-1KB images (tracking pixels, dots, etc.) are skipped."""
|
||||
pdf_file = tmp_path / "report.pdf"
|
||||
pdf_file.write_bytes(b"%PDF-1.4 fake")
|
||||
|
||||
tiny = _make_image_obj("dot.png", b"\x89PNG\r\n\x1a\n")
|
||||
big = _make_image_obj("ct.jpeg", b"\xff\xd8\xff\xe0" + b"\xab" * 3000)
|
||||
page = MagicMock(images=[tiny, big])
|
||||
|
||||
fake_reader = MagicMock()
|
||||
fake_reader.pages = [page]
|
||||
mocker.patch("pypdf.PdfReader", return_value=fake_reader)
|
||||
|
||||
parse_mock = mocker.patch(
|
||||
"app.etl_pipeline.parsers.vision_llm.parse_image_for_description",
|
||||
new=AsyncMock(return_value="CT scan"),
|
||||
)
|
||||
|
||||
fake_llm = MagicMock()
|
||||
result = await describe_pictures(str(pdf_file), "report.pdf", fake_llm)
|
||||
|
||||
assert len(result.descriptions) == 1
|
||||
assert result.descriptions[0].name == "ct.jpeg"
|
||||
assert result.skipped_too_small == 1
|
||||
assert parse_mock.await_count == 1
|
||||
|
||||
|
||||
async def test_describe_pictures_skips_too_large_images(tmp_path, mocker):
|
||||
"""Images larger than the vision LLM's per-image cap are skipped."""
|
||||
pdf_file = tmp_path / "report.pdf"
|
||||
pdf_file.write_bytes(b"%PDF-1.4 fake")
|
||||
|
||||
huge = _make_image_obj("huge.jpeg", b"\xff" * (6 * 1024 * 1024))
|
||||
ok = _make_image_obj("ok.jpeg", b"\xff\xd8\xff\xe0" + b"\xab" * 2000)
|
||||
page = MagicMock(images=[huge, ok])
|
||||
|
||||
fake_reader = MagicMock()
|
||||
fake_reader.pages = [page]
|
||||
mocker.patch("pypdf.PdfReader", return_value=fake_reader)
|
||||
|
||||
parse_mock = mocker.patch(
|
||||
"app.etl_pipeline.parsers.vision_llm.parse_image_for_description",
|
||||
new=AsyncMock(return_value="OK image"),
|
||||
)
|
||||
|
||||
fake_llm = MagicMock()
|
||||
result = await describe_pictures(str(pdf_file), "report.pdf", fake_llm)
|
||||
|
||||
assert len(result.descriptions) == 1
|
||||
assert result.descriptions[0].name == "ok.jpeg"
|
||||
assert result.skipped_too_large == 1
|
||||
assert parse_mock.await_count == 1
|
||||
|
||||
|
||||
async def test_describe_pictures_swallows_per_image_failure(tmp_path, mocker):
|
||||
"""A vision LLM failure on one image must not kill the whole document."""
|
||||
pdf_file = tmp_path / "report.pdf"
|
||||
pdf_file.write_bytes(b"%PDF-1.4 fake")
|
||||
|
||||
img_a = _make_image_obj("a.jpeg", b"\xff\xd8" + b"\xab" * 2000)
|
||||
img_b = _make_image_obj("b.jpeg", b"\xff\xd8" + b"\xcd" * 2000)
|
||||
page = MagicMock(images=[img_a, img_b])
|
||||
|
||||
fake_reader = MagicMock()
|
||||
fake_reader.pages = [page]
|
||||
mocker.patch("pypdf.PdfReader", return_value=fake_reader)
|
||||
|
||||
mocker.patch(
|
||||
"app.etl_pipeline.parsers.vision_llm.parse_image_for_description",
|
||||
new=AsyncMock(side_effect=[RuntimeError("vision blew up"), "Success"]),
|
||||
)
|
||||
|
||||
fake_llm = MagicMock()
|
||||
result = await describe_pictures(str(pdf_file), "report.pdf", fake_llm)
|
||||
|
||||
assert len(result.descriptions) == 1
|
||||
assert result.descriptions[0].description == "Success"
|
||||
assert result.failed == 1
|
||||
|
||||
|
||||
async def test_describe_pictures_handles_pypdf_open_failure(tmp_path, mocker):
|
||||
"""A malformed PDF that pypdf can't open returns an empty result."""
|
||||
pdf_file = tmp_path / "broken.pdf"
|
||||
pdf_file.write_bytes(b"not a pdf")
|
||||
|
||||
mocker.patch("pypdf.PdfReader", side_effect=ValueError("EOF marker not found"))
|
||||
|
||||
fake_llm = MagicMock()
|
||||
result = await describe_pictures(str(pdf_file), "broken.pdf", fake_llm)
|
||||
assert result.descriptions == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# inject_descriptions_inline: replacement patterns
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _desc(name="Im0", description="A CT scan."):
|
||||
return PictureDescription(
|
||||
page_number=1,
|
||||
ordinal_in_page=0,
|
||||
name=name,
|
||||
sha256="aa",
|
||||
description=description,
|
||||
)
|
||||
|
||||
|
||||
def test_inject_no_op_when_no_descriptions():
|
||||
markdown = "# Title\n\nbody text\n"
|
||||
result = PictureExtractionResult()
|
||||
out, n = inject_descriptions_inline(markdown, result)
|
||||
assert out == markdown
|
||||
assert n == 0
|
||||
|
||||
|
||||
def test_inject_replaces_placeholder_with_caption():
|
||||
"""`<!-- image -->` + `Image: <name>` together becomes one block.
|
||||
|
||||
This is the most common medxpertqa case: our renderer puts a caption
|
||||
line right below the embedded JPEG, and Docling preserves both.
|
||||
"""
|
||||
markdown = (
|
||||
"# Case\n\n"
|
||||
"Clinical text...\n\n"
|
||||
"<!-- image -->\nImage: MM-130-a.jpeg\n\n"
|
||||
"Answer choices: A) ...\n"
|
||||
)
|
||||
result = PictureExtractionResult(descriptions=[_desc(name="Im0")])
|
||||
|
||||
out, n = inject_descriptions_inline(markdown, result)
|
||||
|
||||
assert n == 1
|
||||
assert "<!-- image -->" not in out
|
||||
assert "Image: MM-130-a.jpeg" not in out # caption consumed
|
||||
# New format: horizontal-rule-delimited section with "Embedded
|
||||
# image:" anchor and named "Visual description:" section. No
|
||||
# blockquote wrapping -- nested blocks (lists, code, tables) inside
|
||||
# a blockquote are silently dropped by Streamdown / remark.
|
||||
assert "**Embedded image:** `MM-130-a.jpeg`" in out
|
||||
assert "**Visual description:**" in out
|
||||
assert "A CT scan." in out
|
||||
# Block is delimited by horizontal rules so it stands out from
|
||||
# surrounding paragraphs.
|
||||
assert "\n---\n" in out
|
||||
# No OCR section -- this fixture has no ocr_text on its descriptions.
|
||||
assert "**OCR text:**" not in out
|
||||
# No raw HTML tags / blockquote prefixes leak.
|
||||
assert "<image" not in out
|
||||
assert "</image>" not in out
|
||||
assert "> **Embedded image:**" not in out # we no longer wrap in `>`
|
||||
# Surrounding context is preserved.
|
||||
assert "Clinical text..." in out
|
||||
assert "Answer choices: A) ..." in out
|
||||
|
||||
|
||||
def test_inject_uses_pypdf_name_when_no_caption():
|
||||
"""`<!-- image -->` alone uses the pypdf-given name as the attribute."""
|
||||
markdown = "# Case\n\n<!-- image -->\n\nMore text\n"
|
||||
result = PictureExtractionResult(descriptions=[_desc(name="Im0")])
|
||||
|
||||
out, n = inject_descriptions_inline(markdown, result)
|
||||
|
||||
assert n == 1
|
||||
assert "**Embedded image:** `Im0`" in out
|
||||
|
||||
|
||||
def test_inject_replaces_bare_caption():
|
||||
"""A bare `Image: <name>` line (no placeholder) still gets replaced."""
|
||||
markdown = "# Case\n\nText...\nImage: scan.jpeg\nMore text\n"
|
||||
result = PictureExtractionResult(descriptions=[_desc(name="Im0")])
|
||||
|
||||
out, n = inject_descriptions_inline(markdown, result)
|
||||
|
||||
assert n == 1
|
||||
assert "**Embedded image:** `scan.jpeg`" in out
|
||||
assert "Image: scan.jpeg" not in out
|
||||
|
||||
|
||||
def test_inject_handles_multiple_images_in_order():
|
||||
"""Two placeholders + two descriptions: each consumed in document order."""
|
||||
markdown = (
|
||||
"Page 1\n\n<!-- image -->\nImage: a.jpeg\n\n"
|
||||
"Between\n\n<!-- image -->\nImage: b.jpeg\n\nEnd\n"
|
||||
)
|
||||
result = PictureExtractionResult(
|
||||
descriptions=[
|
||||
PictureDescription(
|
||||
page_number=1,
|
||||
ordinal_in_page=0,
|
||||
name="Im0",
|
||||
sha256="aa",
|
||||
description="Desc A",
|
||||
),
|
||||
PictureDescription(
|
||||
page_number=2,
|
||||
ordinal_in_page=0,
|
||||
name="Im1",
|
||||
sha256="bb",
|
||||
description="Desc B",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
out, n = inject_descriptions_inline(markdown, result)
|
||||
|
||||
assert n == 2
|
||||
assert "**Embedded image:** `a.jpeg`" in out
|
||||
assert "**Embedded image:** `b.jpeg`" in out
|
||||
assert out.index("a.jpeg") < out.index("b.jpeg")
|
||||
assert "Desc A" in out and "Desc B" in out
|
||||
|
||||
|
||||
def test_inject_returns_remaining_count_when_more_descriptions_than_markers():
|
||||
"""Three descriptions, one marker -> only one inlined, two leftover."""
|
||||
markdown = "Just one <!-- image --> here.\n"
|
||||
result = PictureExtractionResult(
|
||||
descriptions=[
|
||||
_desc(name="Im0", description="First"),
|
||||
_desc(name="Im1", description="Second"),
|
||||
_desc(name="Im2", description="Third"),
|
||||
]
|
||||
)
|
||||
|
||||
out, n = inject_descriptions_inline(markdown, result)
|
||||
|
||||
assert n == 1
|
||||
assert "**Embedded image:** `Im0`" in out
|
||||
assert "**Embedded image:** `Im1`" not in out
|
||||
|
||||
|
||||
def test_inject_returns_zero_when_no_markers_present():
|
||||
"""Markdown with no image markers at all returns the input unchanged."""
|
||||
markdown = "# Title\n\nJust text. No images mentioned at all.\n"
|
||||
result = PictureExtractionResult(descriptions=[_desc(name="Im0")])
|
||||
|
||||
out, n = inject_descriptions_inline(markdown, result)
|
||||
|
||||
assert n == 0
|
||||
assert out == markdown
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# render_appended_section
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_render_appended_empty_when_nothing_passed():
|
||||
assert render_appended_section([]) == ""
|
||||
|
||||
|
||||
def test_render_appended_renders_each_image_as_block():
|
||||
descriptions = [
|
||||
_desc(name="MM-130-a.jpeg", description="CT scan"),
|
||||
_desc(name="MM-130-b.jpeg", description="Bar chart"),
|
||||
]
|
||||
rendered = render_appended_section(descriptions)
|
||||
assert "## Image Content (vision-LLM extracted)" in rendered
|
||||
assert "**Embedded image:** `MM-130-a.jpeg`" in rendered
|
||||
assert "CT scan" in rendered
|
||||
assert "**Embedded image:** `MM-130-b.jpeg`" in rendered
|
||||
assert "Bar chart" in rendered
|
||||
# Each image block is delimited by horizontal rules.
|
||||
assert rendered.count("\n---\n") >= 2
|
||||
# No raw HTML / XML / blockquote prefixes.
|
||||
assert "<image" not in rendered
|
||||
assert "> **Embedded image:**" not in rendered
|
||||
assert "**OCR text:**" not in rendered
|
||||
|
||||
|
||||
def test_render_appended_includes_skip_notes():
|
||||
descriptions = [_desc()]
|
||||
skip_result = PictureExtractionResult(
|
||||
descriptions=descriptions,
|
||||
skipped_too_small=2,
|
||||
skipped_too_large=1,
|
||||
skipped_duplicate=3,
|
||||
failed=1,
|
||||
)
|
||||
rendered = render_appended_section(descriptions, skip_notes=skip_result)
|
||||
assert "_Note:" in rendered
|
||||
assert "2 too small" in rendered
|
||||
assert "1 too large" in rendered
|
||||
assert "3 duplicate" in rendered
|
||||
assert "1 failed" in rendered
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# merge_descriptions_into_markdown: top-level
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_merge_inlines_when_marker_present():
|
||||
markdown = "Text...\n\n<!-- image -->\nImage: scan.jpeg\n\nMore text\n"
|
||||
result = PictureExtractionResult(descriptions=[_desc(name="Im0")])
|
||||
|
||||
out = merge_descriptions_into_markdown(markdown, result)
|
||||
|
||||
assert "**Embedded image:** `scan.jpeg`" in out
|
||||
# Nothing leaked into an appended section -- we should NOT see the
|
||||
# appended-section heading because everything went inline.
|
||||
assert "## Image Content" not in out
|
||||
|
||||
|
||||
def test_merge_appends_when_no_marker_present():
|
||||
"""Zero markers means everything goes into an appended section."""
|
||||
markdown = "Pure text doc, no image markers.\n"
|
||||
result = PictureExtractionResult(
|
||||
descriptions=[_desc(name="Im0", description="An image desc.")]
|
||||
)
|
||||
|
||||
out = merge_descriptions_into_markdown(markdown, result)
|
||||
|
||||
assert "Pure text doc" in out
|
||||
assert "## Image Content (vision-LLM extracted)" in out
|
||||
assert "**Embedded image:** `Im0`" in out
|
||||
|
||||
|
||||
def test_merge_appends_leftovers_with_distinct_heading():
|
||||
"""One marker, two descriptions -> one inline, second appended under
|
||||
a heading that signals it's a leftover.
|
||||
"""
|
||||
markdown = "Text\n\n<!-- image -->\nImage: a.jpeg\n\nEnd\n"
|
||||
result = PictureExtractionResult(
|
||||
descriptions=[
|
||||
_desc(name="Im0", description="First"),
|
||||
_desc(name="Im1", description="Second"),
|
||||
]
|
||||
)
|
||||
|
||||
out = merge_descriptions_into_markdown(markdown, result)
|
||||
|
||||
assert "**Embedded image:** `a.jpeg`" in out # inlined
|
||||
assert "## Image Content (additional, no inline marker found)" in out
|
||||
assert "**Embedded image:** `Im1`" in out # appended
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# describe_pictures: ocr_runner integration
|
||||
#
|
||||
# These tests cover the per-image OCR side-channel: when the caller
|
||||
# supplies an ``ocr_runner`` callable, each extracted image is sent
|
||||
# both to the vision LLM (visual description) and to the OCR runner
|
||||
# (text-in-image), in parallel. The OCR text -- if any -- is recorded
|
||||
# on the PictureDescription and rendered in the inline block.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_describe_pictures_calls_ocr_runner_per_image(tmp_path, mocker):
|
||||
"""When an ocr_runner is provided, it's invoked once per eligible image."""
|
||||
pdf_file = tmp_path / "report.pdf"
|
||||
pdf_file.write_bytes(b"%PDF-1.4 fake")
|
||||
|
||||
img_a = _make_image_obj("Im0.jpeg", b"\xff\xd8\xff\xe0" + b"\xab" * 2000)
|
||||
img_b = _make_image_obj("Im1.png", b"\x89PNG\r\n\x1a\n" + b"\xcd" * 2000)
|
||||
fake_reader = MagicMock()
|
||||
fake_reader.pages = [MagicMock(images=[img_a, img_b])]
|
||||
mocker.patch("pypdf.PdfReader", return_value=fake_reader)
|
||||
|
||||
mocker.patch(
|
||||
"app.etl_pipeline.parsers.vision_llm.parse_image_for_description",
|
||||
new=AsyncMock(side_effect=["Visual A", "Visual B"]),
|
||||
)
|
||||
ocr_runner = AsyncMock(side_effect=["OCR text A", "OCR text B"])
|
||||
|
||||
fake_llm = MagicMock()
|
||||
result = await describe_pictures(
|
||||
str(pdf_file), "report.pdf", fake_llm, ocr_runner=ocr_runner
|
||||
)
|
||||
|
||||
assert ocr_runner.await_count == 2
|
||||
by_name = {d.name: d.ocr_text for d in result.descriptions}
|
||||
assert by_name == {"Im0.jpeg": "OCR text A", "Im1.png": "OCR text B"}
|
||||
|
||||
|
||||
async def test_describe_pictures_runs_vision_and_ocr_in_parallel(tmp_path, mocker):
|
||||
"""Vision LLM and OCR run concurrently per image, not sequentially.
|
||||
|
||||
We verify this by recording call timestamps: if both finish within
|
||||
a small window relative to the per-call sleep, they ran in parallel.
|
||||
"""
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
pdf_file = tmp_path / "report.pdf"
|
||||
pdf_file.write_bytes(b"%PDF-1.4 fake")
|
||||
|
||||
img = _make_image_obj("Im0.jpeg", b"\xff\xd8\xff\xe0" + b"\xab" * 2000)
|
||||
fake_reader = MagicMock()
|
||||
fake_reader.pages = [MagicMock(images=[img])]
|
||||
mocker.patch("pypdf.PdfReader", return_value=fake_reader)
|
||||
|
||||
sleep_each = 0.05 # 50ms per call
|
||||
|
||||
async def slow_vision(*args, **kwargs):
|
||||
await asyncio.sleep(sleep_each)
|
||||
return "Visual"
|
||||
|
||||
async def slow_ocr(*args, **kwargs):
|
||||
await asyncio.sleep(sleep_each)
|
||||
return "OCR"
|
||||
|
||||
mocker.patch(
|
||||
"app.etl_pipeline.parsers.vision_llm.parse_image_for_description",
|
||||
new=slow_vision,
|
||||
)
|
||||
|
||||
fake_llm = MagicMock()
|
||||
started = time.perf_counter()
|
||||
result = await describe_pictures(
|
||||
str(pdf_file), "report.pdf", fake_llm, ocr_runner=slow_ocr
|
||||
)
|
||||
elapsed = time.perf_counter() - started
|
||||
|
||||
assert len(result.descriptions) == 1
|
||||
assert result.descriptions[0].ocr_text == "OCR"
|
||||
# Sequential would be ~2*sleep_each. Parallel is ~1*sleep_each + overhead.
|
||||
# Be generous with the bound so we're not flaky on slow CI.
|
||||
assert elapsed < 1.5 * sleep_each, (
|
||||
f"vision+OCR appear to be sequential (took {elapsed:.3f}s)"
|
||||
)
|
||||
|
||||
|
||||
async def test_describe_pictures_treats_empty_ocr_as_none(tmp_path, mocker):
|
||||
"""Empty / whitespace-only OCR result is normalised to None.
|
||||
|
||||
This means the rendered image block won't carry an empty
|
||||
"OCR text" section for images that contain no text at all
|
||||
(e.g. a clean radiograph).
|
||||
"""
|
||||
pdf_file = tmp_path / "report.pdf"
|
||||
pdf_file.write_bytes(b"%PDF-1.4 fake")
|
||||
|
||||
img = _make_image_obj("scan.jpeg", b"\xff\xd8\xff\xe0" + b"\xab" * 2000)
|
||||
fake_reader = MagicMock()
|
||||
fake_reader.pages = [MagicMock(images=[img])]
|
||||
mocker.patch("pypdf.PdfReader", return_value=fake_reader)
|
||||
|
||||
mocker.patch(
|
||||
"app.etl_pipeline.parsers.vision_llm.parse_image_for_description",
|
||||
new=AsyncMock(return_value="A radiograph."),
|
||||
)
|
||||
ocr_runner = AsyncMock(return_value=" \n \n")
|
||||
|
||||
fake_llm = MagicMock()
|
||||
result = await describe_pictures(
|
||||
str(pdf_file), "report.pdf", fake_llm, ocr_runner=ocr_runner
|
||||
)
|
||||
|
||||
assert len(result.descriptions) == 1
|
||||
assert result.descriptions[0].ocr_text is None
|
||||
|
||||
|
||||
async def test_describe_pictures_swallows_ocr_runner_failure(tmp_path, mocker):
|
||||
"""An OCR runner exception must not kill the description for that image.
|
||||
|
||||
OCR is supplementary; the vision LLM's description is the primary
|
||||
payload. If OCR blows up we drop the OCR field for that image and
|
||||
keep the description.
|
||||
"""
|
||||
pdf_file = tmp_path / "report.pdf"
|
||||
pdf_file.write_bytes(b"%PDF-1.4 fake")
|
||||
|
||||
img = _make_image_obj("scan.jpeg", b"\xff\xd8\xff\xe0" + b"\xab" * 2000)
|
||||
fake_reader = MagicMock()
|
||||
fake_reader.pages = [MagicMock(images=[img])]
|
||||
mocker.patch("pypdf.PdfReader", return_value=fake_reader)
|
||||
|
||||
mocker.patch(
|
||||
"app.etl_pipeline.parsers.vision_llm.parse_image_for_description",
|
||||
new=AsyncMock(return_value="A radiograph."),
|
||||
)
|
||||
ocr_runner = AsyncMock(side_effect=RuntimeError("OCR backend down"))
|
||||
|
||||
fake_llm = MagicMock()
|
||||
result = await describe_pictures(
|
||||
str(pdf_file), "report.pdf", fake_llm, ocr_runner=ocr_runner
|
||||
)
|
||||
|
||||
assert len(result.descriptions) == 1
|
||||
assert result.descriptions[0].description == "A radiograph."
|
||||
assert result.descriptions[0].ocr_text is None
|
||||
assert result.failed == 0 # the IMAGE didn't fail; only its OCR did
|
||||
|
||||
|
||||
async def test_describe_pictures_vision_failure_with_ocr_runner_skips_image(
|
||||
tmp_path, mocker
|
||||
):
|
||||
"""If the vision LLM fails, the image is skipped even if OCR succeeded.
|
||||
|
||||
The inline block's primary purpose is the visual description; an
|
||||
OCR-only block would be misleading (it'd look like the vision
|
||||
pipeline ran when it didn't), so we treat vision failure as image
|
||||
failure regardless of OCR outcome.
|
||||
"""
|
||||
pdf_file = tmp_path / "report.pdf"
|
||||
pdf_file.write_bytes(b"%PDF-1.4 fake")
|
||||
|
||||
img = _make_image_obj("scan.jpeg", b"\xff\xd8\xff\xe0" + b"\xab" * 2000)
|
||||
fake_reader = MagicMock()
|
||||
fake_reader.pages = [MagicMock(images=[img])]
|
||||
mocker.patch("pypdf.PdfReader", return_value=fake_reader)
|
||||
|
||||
mocker.patch(
|
||||
"app.etl_pipeline.parsers.vision_llm.parse_image_for_description",
|
||||
new=AsyncMock(side_effect=RuntimeError("vision blew up")),
|
||||
)
|
||||
ocr_runner = AsyncMock(return_value="OCR text")
|
||||
|
||||
fake_llm = MagicMock()
|
||||
result = await describe_pictures(
|
||||
str(pdf_file), "report.pdf", fake_llm, ocr_runner=ocr_runner
|
||||
)
|
||||
|
||||
assert result.descriptions == []
|
||||
assert result.failed == 1
|
||||
|
||||
|
||||
async def test_describe_pictures_no_ocr_runner_keeps_ocr_text_none(tmp_path, mocker):
|
||||
"""Backward compat: omitting ocr_runner produces description-only blocks."""
|
||||
pdf_file = tmp_path / "report.pdf"
|
||||
pdf_file.write_bytes(b"%PDF-1.4 fake")
|
||||
|
||||
img = _make_image_obj("Im0.jpeg", b"\xff\xd8\xff\xe0" + b"\xab" * 2000)
|
||||
fake_reader = MagicMock()
|
||||
fake_reader.pages = [MagicMock(images=[img])]
|
||||
mocker.patch("pypdf.PdfReader", return_value=fake_reader)
|
||||
|
||||
mocker.patch(
|
||||
"app.etl_pipeline.parsers.vision_llm.parse_image_for_description",
|
||||
new=AsyncMock(return_value="Visual"),
|
||||
)
|
||||
|
||||
fake_llm = MagicMock()
|
||||
result = await describe_pictures(str(pdf_file), "report.pdf", fake_llm)
|
||||
|
||||
assert len(result.descriptions) == 1
|
||||
assert result.descriptions[0].ocr_text is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Rendering: "OCR text" section appears iff PictureDescription.ocr_text is set
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _desc_with_ocr(name="Im0", description="A CT scan.", ocr_text="L R 10mm"):
|
||||
return PictureDescription(
|
||||
page_number=1,
|
||||
ordinal_in_page=0,
|
||||
name=name,
|
||||
sha256="aa",
|
||||
description=description,
|
||||
ocr_text=ocr_text,
|
||||
)
|
||||
|
||||
|
||||
def test_inject_renders_ocr_section_when_ocr_text_present():
|
||||
markdown = "Text\n\n<!-- image -->\nImage: scan.jpeg\n\nMore\n"
|
||||
result = PictureExtractionResult(
|
||||
descriptions=[_desc_with_ocr(name="Im0", ocr_text="L R 10mm")]
|
||||
)
|
||||
|
||||
out, n = inject_descriptions_inline(markdown, result)
|
||||
|
||||
assert n == 1
|
||||
assert "**Embedded image:** `scan.jpeg`" in out
|
||||
assert "**OCR text:**" in out
|
||||
assert "L R 10mm" in out
|
||||
# OCR section comes before the visual description (literal text
|
||||
# first, interpretation second).
|
||||
assert out.index("**OCR text:**") < out.index("**Visual description:**")
|
||||
# Critical: no nested-block constructs (fenced code, blockquote)
|
||||
# that previous formats relied on -- both broke in Streamdown /
|
||||
# PlateJS by escaping their container and dropping content.
|
||||
assert "```" not in out
|
||||
assert "> **" not in out
|
||||
|
||||
|
||||
def test_inject_renders_multiline_ocr_with_hard_breaks():
|
||||
"""Multi-line OCR uses trailing-two-spaces hard breaks so each
|
||||
line renders on its own row, without needing a fragile fenced
|
||||
code block or blockquote wrapper."""
|
||||
markdown = "Text\n\n<!-- image -->\nImage: scan.jpeg\n\nMore\n"
|
||||
ocr_multi = "Slice 24 / 60\nL\nR\n10 mm"
|
||||
result = PictureExtractionResult(
|
||||
descriptions=[_desc_with_ocr(name="Im0", ocr_text=ocr_multi)]
|
||||
)
|
||||
|
||||
out, _ = inject_descriptions_inline(markdown, result)
|
||||
|
||||
# Every OCR line is present.
|
||||
for line in ("Slice 24 / 60", "L", "R", "10 mm"):
|
||||
assert line in out
|
||||
# Non-last OCR lines get the trailing two-space hard break.
|
||||
assert "Slice 24 / 60 \n" in out
|
||||
assert "\nL \n" in out
|
||||
assert "\nR \n" in out
|
||||
# Last OCR line must NOT carry the two-space hard break (no stray <br>).
|
||||
assert "10 mm \n" not in out
|
||||
assert "10 mm\n" in out
|
||||
|
||||
|
||||
def test_render_appended_renders_ocr_section_when_ocr_text_present():
|
||||
descriptions = [
|
||||
_desc_with_ocr(
|
||||
name="MM-130-a.jpeg",
|
||||
description="Axial CT.",
|
||||
ocr_text="Slice 24 / 60",
|
||||
),
|
||||
]
|
||||
rendered = render_appended_section(descriptions)
|
||||
|
||||
assert "**OCR text:**" in rendered
|
||||
assert "Slice 24 / 60" in rendered
|
||||
assert "Axial CT." in rendered
|
||||
|
||||
|
||||
def test_render_omits_ocr_section_when_ocr_text_is_none():
|
||||
descriptions = [_desc(name="Im0", description="A clean radiograph.")]
|
||||
rendered = render_appended_section(descriptions)
|
||||
|
||||
assert "**Embedded image:** `Im0`" in rendered
|
||||
assert "**OCR text:**" not in rendered
|
||||
assert "**Visual description:**" in rendered
|
||||
# No raw HTML / blockquote prefixes.
|
||||
assert "<image" not in rendered
|
||||
assert "> **" not in rendered
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# inject_descriptions_inline: <figure> blocks (layout-aware parsers)
|
||||
#
|
||||
# Azure Document Intelligence's ``prebuilt-layout`` and LlamaCloud
|
||||
# premium both emit ``<figure>...</figure>`` blocks that already contain
|
||||
# the parser's own OCR of the figure (chart bar values, axis labels,
|
||||
# inline ``<figcaption>``, embedded ``<table>`` for tabular figures).
|
||||
# That parser-side content is useful for retrieval on its own, so we
|
||||
# PRESERVE the figure verbatim and append our vision-LLM block
|
||||
# immediately after rather than substituting for it.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_inject_appends_block_after_figure_preserving_parser_content():
|
||||
"""Figure block stays intact; vision-LLM block goes right after it."""
|
||||
markdown = (
|
||||
"Some narrative text.\n\n"
|
||||
"<figure>\n\n"
|
||||
"Republican\n68\nDemocrat\n30\n"
|
||||
"\n</figure>\n\n"
|
||||
"Following paragraph.\n"
|
||||
)
|
||||
result = PictureExtractionResult(
|
||||
descriptions=[_desc(name="Im0", description="Bar chart of party ID.")]
|
||||
)
|
||||
|
||||
out, n = inject_descriptions_inline(markdown, result)
|
||||
|
||||
assert n == 1
|
||||
# Original figure is preserved verbatim -- the parser's OCR'd
|
||||
# numbers must still be searchable.
|
||||
assert "<figure>" in out
|
||||
assert "</figure>" in out
|
||||
assert "Republican" in out and "68" in out
|
||||
# Our vision-LLM block follows the figure, not before / inside it.
|
||||
assert "**Embedded image:** `Im0`" in out
|
||||
assert "Bar chart of party ID." in out
|
||||
figure_close = out.index("</figure>")
|
||||
embedded_at = out.index("**Embedded image:** `Im0`")
|
||||
assert figure_close < embedded_at, "block must be appended AFTER </figure>"
|
||||
# Surrounding narrative is preserved.
|
||||
assert "Some narrative text." in out
|
||||
assert "Following paragraph." in out
|
||||
|
||||
|
||||
def test_inject_handles_multiple_figures_in_document_order():
|
||||
"""N figures + N descriptions: each pair lands in the right place."""
|
||||
markdown = (
|
||||
"Page 1\n\n<figure>\nChart A bars\n</figure>\n\n"
|
||||
"Between\n\n<figure>\nChart B bars\n</figure>\n\n"
|
||||
"End.\n"
|
||||
)
|
||||
result = PictureExtractionResult(
|
||||
descriptions=[
|
||||
PictureDescription(
|
||||
page_number=1,
|
||||
ordinal_in_page=0,
|
||||
name="Im0",
|
||||
sha256="aa",
|
||||
description="Description of chart A.",
|
||||
),
|
||||
PictureDescription(
|
||||
page_number=2,
|
||||
ordinal_in_page=0,
|
||||
name="Im1",
|
||||
sha256="bb",
|
||||
description="Description of chart B.",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
out, n = inject_descriptions_inline(markdown, result)
|
||||
|
||||
assert n == 2
|
||||
# Both figures preserved; both descriptions inlined; order matches.
|
||||
assert out.count("<figure>") == 2
|
||||
assert out.count("</figure>") == 2
|
||||
assert "Description of chart A." in out
|
||||
assert "Description of chart B." in out
|
||||
assert out.index("Description of chart A.") < out.index("Description of chart B.")
|
||||
# Each description appears AFTER its corresponding </figure>.
|
||||
first_close = out.index("</figure>")
|
||||
assert first_close < out.index("Description of chart A.")
|
||||
second_close = out.index("</figure>", first_close + 1)
|
||||
assert second_close < out.index("Description of chart B.")
|
||||
|
||||
|
||||
def test_inject_figures_with_attributes_and_nested_tags():
|
||||
"""``<figure>`` with attributes and nested tags is matched and preserved."""
|
||||
markdown = (
|
||||
'<figure id="fig-3" class="chart">\n'
|
||||
"<figcaption>Source: Pew Research</figcaption>\n"
|
||||
"<table><tr><td>Republican</td><td>57</td></tr></table>\n"
|
||||
"</figure>\n"
|
||||
)
|
||||
result = PictureExtractionResult(
|
||||
descriptions=[_desc(name="Im0", description="Survey table.")]
|
||||
)
|
||||
|
||||
out, n = inject_descriptions_inline(markdown, result)
|
||||
|
||||
assert n == 1
|
||||
# All nested HTML is preserved (chunking will pick it up).
|
||||
assert 'id="fig-3"' in out
|
||||
assert "<figcaption>Source: Pew Research</figcaption>" in out
|
||||
assert "<table>" in out and "Republican" in out and "57" in out
|
||||
# Our block sits after the closing tag.
|
||||
assert out.index("</figure>") < out.index("**Embedded image:** `Im0`")
|
||||
|
||||
|
||||
def test_inject_figures_more_descriptions_than_figures_returns_remaining():
|
||||
"""Three descriptions, one figure -> one inlined, two left for caller."""
|
||||
markdown = "Text.\n<figure>\nbar values\n</figure>\nMore.\n"
|
||||
result = PictureExtractionResult(
|
||||
descriptions=[
|
||||
_desc(name="Im0", description="First desc."),
|
||||
_desc(name="Im1", description="Second desc."),
|
||||
_desc(name="Im2", description="Third desc."),
|
||||
]
|
||||
)
|
||||
|
||||
out, n = inject_descriptions_inline(markdown, result)
|
||||
|
||||
assert n == 1
|
||||
assert "First desc." in out
|
||||
# Leftovers are the caller's job; inject_descriptions_inline does
|
||||
# not append them on its own.
|
||||
assert "Second desc." not in out
|
||||
assert "Third desc." not in out
|
||||
|
||||
|
||||
def test_inject_figures_more_figures_than_descriptions_leaves_extras_untouched():
|
||||
"""Two figures, one description -> first figure enriched, second left raw."""
|
||||
markdown = (
|
||||
"<figure>\nfigure 1 content\n</figure>\n<figure>\nfigure 2 content\n</figure>\n"
|
||||
)
|
||||
result = PictureExtractionResult(
|
||||
descriptions=[_desc(name="Im0", description="Only description.")]
|
||||
)
|
||||
|
||||
out, n = inject_descriptions_inline(markdown, result)
|
||||
|
||||
assert n == 1
|
||||
# Both figures still present; only the first one was enriched.
|
||||
assert out.count("<figure>") == 2
|
||||
assert "Only description." in out
|
||||
# Second figure has no embedded-image block immediately after it.
|
||||
second_open = out.index("<figure>", out.index("<figure>") + 1)
|
||||
second_close = out.index("</figure>", second_open)
|
||||
after_second = out[second_close:]
|
||||
assert "**Embedded image:**" not in after_second
|
||||
|
||||
|
||||
def test_merge_inlines_at_figure_boundary():
|
||||
"""Top-level helper does the right thing with figures (no leftover section)."""
|
||||
markdown = "Lead.\n<figure>\nbars\n</figure>\nTrailer.\n"
|
||||
result = PictureExtractionResult(
|
||||
descriptions=[_desc(name="Im0", description="Bar chart.")]
|
||||
)
|
||||
|
||||
out = merge_descriptions_into_markdown(markdown, result)
|
||||
|
||||
# Inline succeeded -> no appended-section heading.
|
||||
assert "## Image Content" not in out
|
||||
assert "Bar chart." in out
|
||||
assert "<figure>" in out and "</figure>" in out
|
||||
|
||||
|
||||
def test_inject_figures_then_falls_through_to_docling_marker():
|
||||
"""Mixed-marker doc: figure consumed first, then Docling placeholder.
|
||||
|
||||
Defensive -- single docs are usually one parser's output, but if a
|
||||
pipeline ever stitches two parsers' markdowns together the inliner
|
||||
should still place each description.
|
||||
"""
|
||||
markdown = (
|
||||
"<figure>\nChart bars: 50, 40, 30\n</figure>\n\n"
|
||||
"Later in the doc:\n\n"
|
||||
"<!-- image -->\nImage: scan.jpeg\n\n"
|
||||
"End.\n"
|
||||
)
|
||||
result = PictureExtractionResult(
|
||||
descriptions=[
|
||||
_desc(name="Im0", description="Chart description."),
|
||||
_desc(name="Im1", description="Scan description."),
|
||||
]
|
||||
)
|
||||
|
||||
out, n = inject_descriptions_inline(markdown, result)
|
||||
|
||||
assert n == 2
|
||||
# Figure preserved + augmented.
|
||||
assert "<figure>" in out and "Chart bars: 50, 40, 30" in out
|
||||
assert "Chart description." in out
|
||||
# Docling placeholder + caption replaced.
|
||||
assert "<!-- image -->" not in out
|
||||
assert "Image: scan.jpeg" not in out
|
||||
assert "**Embedded image:** `scan.jpeg`" in out
|
||||
assert "Scan description." in out
|
||||
146
surfsense_backend/tests/unit/etl_pipeline/test_vision_llm.py
Normal file
146
surfsense_backend/tests/unit/etl_pipeline/test_vision_llm.py
Normal file
|
|
@ -0,0 +1,146 @@
|
|||
"""Unit tests for the vision_llm parser helpers.
|
||||
|
||||
Two helpers exist:
|
||||
|
||||
- :func:`parse_with_vision_llm` -- single-shot for standalone image
|
||||
uploads (.png/.jpg/etc). Returns combined markdown (description +
|
||||
verbatim OCR mixed) since the image *is* the document.
|
||||
- :func:`parse_image_for_description` -- per-image-in-PDF call. Returns
|
||||
visual description only; OCR is the ETL service's job.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_with_vision_llm: legacy single-shot path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_parse_with_vision_llm_returns_combined_markdown(tmp_path):
|
||||
"""Standalone image uploads still go through the combined-markdown path."""
|
||||
from app.etl_pipeline.parsers.vision_llm import parse_with_vision_llm
|
||||
|
||||
img = tmp_path / "scan.png"
|
||||
img.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 200)
|
||||
|
||||
fake_response = MagicMock()
|
||||
fake_response.content = "# A scan of something."
|
||||
fake_llm = AsyncMock()
|
||||
fake_llm.ainvoke.return_value = fake_response
|
||||
|
||||
out = await parse_with_vision_llm(str(img), "scan.png", fake_llm)
|
||||
assert out == "# A scan of something."
|
||||
fake_llm.ainvoke.assert_awaited_once()
|
||||
|
||||
|
||||
async def test_parse_with_vision_llm_rejects_empty_response(tmp_path):
|
||||
"""An empty model response raises rather than silently returning blanks."""
|
||||
from app.etl_pipeline.parsers.vision_llm import parse_with_vision_llm
|
||||
|
||||
img = tmp_path / "scan.png"
|
||||
img.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 200)
|
||||
|
||||
fake_response = MagicMock()
|
||||
fake_response.content = ""
|
||||
fake_llm = AsyncMock()
|
||||
fake_llm.ainvoke.return_value = fake_response
|
||||
|
||||
with pytest.raises(ValueError, match="empty content"):
|
||||
await parse_with_vision_llm(str(img), "scan.png", fake_llm)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_image_for_description: per-image-in-PDF, description only
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_parse_image_for_description_returns_description(tmp_path):
|
||||
"""Description-only path returns the model's markdown unchanged."""
|
||||
from app.etl_pipeline.parsers.vision_llm import parse_image_for_description
|
||||
|
||||
img = tmp_path / "scan.png"
|
||||
img.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 200)
|
||||
|
||||
fake_response = MagicMock()
|
||||
fake_response.content = "Axial CT showing a large cystic mass."
|
||||
fake_llm = AsyncMock()
|
||||
fake_llm.ainvoke.return_value = fake_response
|
||||
|
||||
out = await parse_image_for_description(str(img), "scan.png", fake_llm)
|
||||
assert out == "Axial CT showing a large cystic mass."
|
||||
|
||||
|
||||
async def test_parse_image_for_description_uses_description_only_prompt(tmp_path):
|
||||
"""The prompt explicitly tells the model NOT to transcribe text.
|
||||
|
||||
This is the contract that lets us drop OCR from the response: the
|
||||
ETL pipeline already has the text (from page-level OCR), so asking
|
||||
the vision LLM for it would be redundant cost.
|
||||
"""
|
||||
from app.etl_pipeline.parsers.vision_llm import parse_image_for_description
|
||||
|
||||
img = tmp_path / "scan.png"
|
||||
img.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 200)
|
||||
|
||||
fake_response = MagicMock()
|
||||
fake_response.content = "A description"
|
||||
fake_llm = AsyncMock()
|
||||
fake_llm.ainvoke.return_value = fake_response
|
||||
|
||||
await parse_image_for_description(str(img), "scan.png", fake_llm)
|
||||
|
||||
# The prompt is the first text part of the message we sent.
|
||||
sent_messages = fake_llm.ainvoke.call_args.args[0]
|
||||
prompt_text = sent_messages[0].content[0]["text"].lower()
|
||||
assert "describe what this image visually depicts" in prompt_text
|
||||
assert "do not transcribe text" in prompt_text
|
||||
|
||||
|
||||
async def test_parse_image_for_description_rejects_empty(tmp_path):
|
||||
"""Empty response surfaces as ValueError so the caller can skip the image."""
|
||||
from app.etl_pipeline.parsers.vision_llm import parse_image_for_description
|
||||
|
||||
img = tmp_path / "scan.png"
|
||||
img.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 200)
|
||||
|
||||
fake_response = MagicMock()
|
||||
fake_response.content = " " # whitespace-only counts as empty
|
||||
fake_llm = AsyncMock()
|
||||
fake_llm.ainvoke.return_value = fake_response
|
||||
|
||||
with pytest.raises(ValueError, match="empty content"):
|
||||
await parse_image_for_description(str(img), "scan.png", fake_llm)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Image size + extension validation (shared by both paths)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_image_to_data_url_rejects_oversized(tmp_path):
|
||||
"""Images larger than 5 MB raise before any LLM call is made."""
|
||||
from app.etl_pipeline.parsers.vision_llm import _image_to_data_url
|
||||
|
||||
big = tmp_path / "huge.png"
|
||||
big.write_bytes(b"\x89PNG" + b"\x00" * (6 * 1024 * 1024))
|
||||
|
||||
with pytest.raises(ValueError, match="Image too large"):
|
||||
_image_to_data_url(str(big))
|
||||
|
||||
|
||||
def test_image_to_data_url_rejects_unsupported_extension(tmp_path):
|
||||
"""Unknown extensions raise rather than guessing a MIME type."""
|
||||
from app.etl_pipeline.parsers.vision_llm import _image_to_data_url
|
||||
|
||||
weird = tmp_path / "scan.xyz"
|
||||
weird.write_bytes(b"\x00" * 100)
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported image extension"):
|
||||
_image_to_data_url(str(weird))
|
||||
|
|
@ -39,8 +39,9 @@ async def test_index_calls_embed_and_chunk_via_to_thread(
|
|||
):
|
||||
"""index() runs the chunker and embed_texts via asyncio.to_thread, not blocking the loop.
|
||||
|
||||
The default (non-code) path uses ``chunk_text_hybrid`` so Markdown tables stay
|
||||
intact (see issue #1334); ``chunk_text`` is reserved for the code-chunker branch.
|
||||
Routing between ``chunk_text`` (code path) and ``chunk_text_hybrid`` (default
|
||||
path, see issue #1334) is verified separately in
|
||||
``test_non_code_documents_use_hybrid_chunker``.
|
||||
"""
|
||||
to_thread_calls = []
|
||||
original_to_thread = asyncio.to_thread
|
||||
|
|
@ -86,11 +87,64 @@ async def test_index_calls_embed_and_chunk_via_to_thread(
|
|||
|
||||
await pipeline.index(document, connector_doc, llm=MagicMock())
|
||||
|
||||
assert "chunk_text_hybrid" in to_thread_calls
|
||||
# Either chunker entry point satisfies the "chunking runs off the event
|
||||
# loop" contract this test guards. Routing between the two is verified
|
||||
# in test_non_code_documents_use_hybrid_chunker.
|
||||
assert {"chunk_text", "chunk_text_hybrid"} & set(to_thread_calls)
|
||||
assert "embed_texts" in to_thread_calls
|
||||
assert document.status == DocumentStatus.ready()
|
||||
|
||||
|
||||
async def test_non_code_documents_use_hybrid_chunker(
|
||||
pipeline, make_connector_document, monkeypatch
|
||||
):
|
||||
"""Non-code documents route through ``chunk_text_hybrid`` (issue #1334).
|
||||
|
||||
The hybrid chunker preserves Markdown table integrity by avoiding splits
|
||||
mid-row. Only documents flagged with ``should_use_code_chunker=True``
|
||||
should take the ``chunk_text`` path.
|
||||
"""
|
||||
monkeypatch.setattr(
|
||||
"app.indexing_pipeline.indexing_pipeline_service.summarize_document",
|
||||
AsyncMock(return_value="Summary."),
|
||||
)
|
||||
mock_chunk_hybrid = MagicMock(return_value=["chunk1"])
|
||||
mock_chunk_hybrid.__name__ = "chunk_text_hybrid"
|
||||
monkeypatch.setattr(
|
||||
"app.indexing_pipeline.indexing_pipeline_service.chunk_text_hybrid",
|
||||
mock_chunk_hybrid,
|
||||
)
|
||||
mock_chunk_code = MagicMock(return_value=["chunk1"])
|
||||
mock_chunk_code.__name__ = "chunk_text"
|
||||
monkeypatch.setattr(
|
||||
"app.indexing_pipeline.indexing_pipeline_service.chunk_text",
|
||||
mock_chunk_code,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"app.indexing_pipeline.indexing_pipeline_service.embed_texts",
|
||||
MagicMock(side_effect=lambda texts: [[0.1] * _EMBEDDING_DIM for _ in texts]),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"app.indexing_pipeline.indexing_pipeline_service.attach_chunks_to_document",
|
||||
MagicMock(),
|
||||
)
|
||||
|
||||
connector_doc = make_connector_document(
|
||||
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
|
||||
unique_id="msg-1",
|
||||
search_space_id=1,
|
||||
should_use_code_chunker=False,
|
||||
)
|
||||
document = MagicMock(spec=Document)
|
||||
document.id = 1
|
||||
document.status = DocumentStatus.pending()
|
||||
|
||||
await pipeline.index(document, connector_doc, llm=MagicMock())
|
||||
|
||||
mock_chunk_hybrid.assert_called_once()
|
||||
mock_chunk_code.assert_not_called()
|
||||
|
||||
|
||||
def _mock_session_factory(orm_docs_by_id):
|
||||
"""Replace get_celery_session_maker with a two-level callable.
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,209 @@
|
|||
"""Real-graph contract: ``all_interrupt_values`` surfaces every pending interrupt.
|
||||
|
||||
The chat-stream emit loop must yield one ``data-interrupt-request`` SSE frame
|
||||
per paused subagent, in the same order ``state.interrupts`` reports them —
|
||||
that's also the order the resume slicer consumes decisions. These tests pin
|
||||
that contract against a **real** paused parent graph built via
|
||||
:class:`~langgraph.types.Send` fan-out (no synthetic state mocks).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from langchain.tools import ToolRuntime
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.types import Send, interrupt
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import (
|
||||
build_task_tool_with_parent_config,
|
||||
)
|
||||
from app.tasks.chat.streaming.helpers.interrupt_inspector import (
|
||||
all_interrupt_values,
|
||||
)
|
||||
|
||||
|
||||
class _SubState(TypedDict, total=False):
|
||||
messages: list
|
||||
|
||||
|
||||
class _DispatchState(TypedDict, total=False):
|
||||
messages: list
|
||||
tcid: str
|
||||
desc: str
|
||||
|
||||
|
||||
def _build_pausing_subagent(checkpointer: InMemorySaver):
|
||||
def approve_node(_state):
|
||||
decision = interrupt(
|
||||
{
|
||||
"action_requests": [
|
||||
{"name": "do_thing", "args": {"x": 1}, "description": ""}
|
||||
],
|
||||
"review_configs": [{}],
|
||||
}
|
||||
)
|
||||
return {"messages": [AIMessage(content=f"got:{decision}")]}
|
||||
|
||||
g = StateGraph(_SubState)
|
||||
g.add_node("approve", approve_node)
|
||||
g.add_edge(START, "approve")
|
||||
g.add_edge("approve", END)
|
||||
return g.compile(checkpointer=checkpointer)
|
||||
|
||||
|
||||
def _parent_graph_dispatching_two_tasks_via_send(
|
||||
task_tool, *, tool_call_id_a: str, tool_call_id_b: str, checkpointer
|
||||
):
|
||||
def fanout_edge(_state) -> list[Send]:
|
||||
return [
|
||||
Send("call_task", {"tcid": tool_call_id_a, "desc": "approve A"}),
|
||||
Send("call_task", {"tcid": tool_call_id_b, "desc": "approve B"}),
|
||||
]
|
||||
|
||||
async def call_task(state: _DispatchState, config: RunnableConfig):
|
||||
rt = ToolRuntime(
|
||||
state=state,
|
||||
config=config,
|
||||
context=None,
|
||||
stream_writer=None,
|
||||
tool_call_id=state["tcid"],
|
||||
store=None,
|
||||
)
|
||||
return await task_tool.coroutine(
|
||||
description=state["desc"], subagent_type="approver", runtime=rt
|
||||
)
|
||||
|
||||
g = StateGraph(_DispatchState)
|
||||
g.add_node("call_task", call_task)
|
||||
g.add_conditional_edges(START, fanout_edge, ["call_task"])
|
||||
g.add_edge("call_task", END)
|
||||
return g.compile(checkpointer=checkpointer)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_every_pending_interrupt_for_two_paused_subagents():
|
||||
"""Two parallel subagents -> ``all_interrupt_values`` returns 2 dicts."""
|
||||
checkpointer = InMemorySaver()
|
||||
subagent = _build_pausing_subagent(checkpointer)
|
||||
task_tool = build_task_tool_with_parent_config(
|
||||
[{"name": "approver", "description": "approves", "runnable": subagent}]
|
||||
)
|
||||
parent = _parent_graph_dispatching_two_tasks_via_send(
|
||||
task_tool,
|
||||
tool_call_id_a="parent-tcid-A",
|
||||
tool_call_id_b="parent-tcid-B",
|
||||
checkpointer=checkpointer,
|
||||
)
|
||||
|
||||
parent_config = {
|
||||
"configurable": {"thread_id": "all-iv-thread"},
|
||||
"recursion_limit": 100,
|
||||
}
|
||||
await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config)
|
||||
state = await parent.aget_state(parent_config)
|
||||
|
||||
values = all_interrupt_values(state)
|
||||
|
||||
assert isinstance(values, list)
|
||||
assert len(values) == 2, (
|
||||
f"REGRESSION: expected one value per pending subagent, got "
|
||||
f"{len(values)}: {values!r}"
|
||||
)
|
||||
stamps = [v.get("tool_call_id") for v in values]
|
||||
assert sorted(stamps) == ["parent-tcid-A", "parent-tcid-B"]
|
||||
for v in values:
|
||||
assert isinstance(v.get("action_requests"), list)
|
||||
assert len(v["action_requests"]) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preserves_state_interrupts_traversal_order():
|
||||
"""Order returned by inspector must match ``state.interrupts`` order.
|
||||
|
||||
The resume slicer consumes decisions left-to-right against
|
||||
``collect_pending_tool_calls(state)`` which walks ``state.interrupts``
|
||||
in iteration order — so the inspector (which drives the *emit* order)
|
||||
must agree with that traversal or the slice and the wire fall out of sync.
|
||||
"""
|
||||
checkpointer = InMemorySaver()
|
||||
subagent = _build_pausing_subagent(checkpointer)
|
||||
task_tool = build_task_tool_with_parent_config(
|
||||
[{"name": "approver", "description": "approves", "runnable": subagent}]
|
||||
)
|
||||
parent = _parent_graph_dispatching_two_tasks_via_send(
|
||||
task_tool,
|
||||
tool_call_id_a="parent-tcid-A",
|
||||
tool_call_id_b="parent-tcid-B",
|
||||
checkpointer=checkpointer,
|
||||
)
|
||||
parent_config = {
|
||||
"configurable": {"thread_id": "order-thread"},
|
||||
"recursion_limit": 100,
|
||||
}
|
||||
await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config)
|
||||
state = await parent.aget_state(parent_config)
|
||||
|
||||
inspector_order = [v["tool_call_id"] for v in all_interrupt_values(state)]
|
||||
state_order = [
|
||||
i.value["tool_call_id"]
|
||||
for i in state.interrupts
|
||||
if isinstance(getattr(i, "value", None), dict) and "tool_call_id" in i.value
|
||||
]
|
||||
|
||||
assert inspector_order == state_order, (
|
||||
f"inspector order {inspector_order!r} diverged from state.interrupts "
|
||||
f"order {state_order!r}; the resume slicer would mis-route decisions."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_empty_list_when_nothing_paused():
|
||||
"""A graph that completes normally produces no interrupts to surface."""
|
||||
|
||||
def done_node(_state):
|
||||
return {"messages": [AIMessage(content="done")]}
|
||||
|
||||
g = StateGraph(_SubState)
|
||||
g.add_node("done", done_node)
|
||||
g.add_edge(START, "done")
|
||||
g.add_edge("done", END)
|
||||
graph = g.compile(checkpointer=InMemorySaver())
|
||||
config = {"configurable": {"thread_id": "no-pause-thread"}}
|
||||
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||
state = await graph.aget_state(config)
|
||||
|
||||
assert all_interrupt_values(state) == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_paused_subagent_returns_a_list_of_one():
|
||||
"""Single-pause case must still return a list (not unwrap to a dict)."""
|
||||
|
||||
def approve_node(_state):
|
||||
decision = interrupt(
|
||||
{
|
||||
"action_requests": [{"name": "x", "args": {}, "description": ""}],
|
||||
"review_configs": [{}],
|
||||
"tool_call_id": "lonely-tcid",
|
||||
}
|
||||
)
|
||||
return {"messages": [AIMessage(content=f"got:{decision}")]}
|
||||
|
||||
g = StateGraph(_SubState)
|
||||
g.add_node("approve", approve_node)
|
||||
g.add_edge(START, "approve")
|
||||
g.add_edge("approve", END)
|
||||
graph = g.compile(checkpointer=InMemorySaver())
|
||||
config = {"configurable": {"thread_id": "single-thread"}}
|
||||
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||
state = await graph.aget_state(config)
|
||||
|
||||
values = all_interrupt_values(state)
|
||||
|
||||
assert isinstance(values, list)
|
||||
assert len(values) == 1
|
||||
assert values[0].get("tool_call_id") == "lonely-tcid"
|
||||
|
|
@ -23,7 +23,6 @@ from app.tasks.chat.stream_new_chat import (
|
|||
_emit_stream_terminal_error as old_emit_terminal_error,
|
||||
_extract_chunk_parts as old_extract_chunk_parts,
|
||||
_extract_resolved_file_path as old_extract_resolved_file_path,
|
||||
_first_interrupt_value as old_first_interrupt_value,
|
||||
_tool_output_has_error as old_tool_output_has_error,
|
||||
_tool_output_to_text as old_tool_output_to_text,
|
||||
)
|
||||
|
|
@ -36,9 +35,6 @@ from app.tasks.chat.streaming.errors.emitter import (
|
|||
from app.tasks.chat.streaming.helpers.chunk_parts import (
|
||||
extract_chunk_parts as new_extract_chunk_parts,
|
||||
)
|
||||
from app.tasks.chat.streaming.helpers.interrupt_inspector import (
|
||||
first_interrupt_value as new_first_interrupt_value,
|
||||
)
|
||||
from app.tasks.chat.streaming.helpers.tool_output import (
|
||||
extract_resolved_file_path as new_extract_resolved_file_path,
|
||||
tool_output_has_error as new_tool_output_has_error,
|
||||
|
|
@ -105,52 +101,6 @@ def test_extract_chunk_parts_matches_old_implementation(chunk: Any) -> None:
|
|||
assert new_extract_chunk_parts(chunk) == old_extract_chunk_parts(chunk)
|
||||
|
||||
|
||||
# ---------------------------------------------------------- interrupt inspector
|
||||
|
||||
|
||||
@dataclass
|
||||
class _Interrupt:
|
||||
value: dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class _Task:
|
||||
interrupts: tuple[Any, ...] = ()
|
||||
|
||||
|
||||
@dataclass
|
||||
class _State:
|
||||
tasks: tuple[Any, ...] = ()
|
||||
interrupts: tuple[Any, ...] = ()
|
||||
|
||||
|
||||
_INTERRUPT_CASES: list[Any] = [
|
||||
_State(),
|
||||
_State(tasks=(_Task(interrupts=(_Interrupt(value={"name": "send"}),)),)),
|
||||
# Multiple tasks: must return the FIRST one in iteration order.
|
||||
_State(
|
||||
tasks=(
|
||||
_Task(interrupts=(_Interrupt(value={"name": "first"}),)),
|
||||
_Task(interrupts=(_Interrupt(value={"name": "second"}),)),
|
||||
)
|
||||
),
|
||||
# Empty task interrupts -> falls back to root state.interrupts.
|
||||
_State(
|
||||
tasks=(_Task(interrupts=()),),
|
||||
interrupts=(_Interrupt(value={"name": "root"}),),
|
||||
),
|
||||
# Interrupts as plain dicts (not wrapper objects).
|
||||
_State(interrupts=({"value": {"name": "dict_root"}},)),
|
||||
# A defective task whose `.interrupts` raises - must be tolerated.
|
||||
_State(tasks=(object(),)),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("state", _INTERRUPT_CASES)
|
||||
def test_first_interrupt_value_matches_old_implementation(state: Any) -> None:
|
||||
assert new_first_interrupt_value(state) == old_first_interrupt_value(state)
|
||||
|
||||
|
||||
# ----------------------------------------------------------- error classifier
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,171 @@
|
|||
"""Pin: thinking-step IDs must be globally unique within a thread.
|
||||
|
||||
The frontend rehydrates ``currentThinkingSteps`` from the prior assistant
|
||||
message when starting a resume. If two consecutive resume turns emit step IDs
|
||||
that overlap (e.g. both produce ``thinking-resume-1`` because each invocation
|
||||
constructs a fresh :class:`AgentEventRelayState` with
|
||||
``thinking_step_counter=0``), React renders sibling timeline rows with the
|
||||
same key — the warning the user reported in production.
|
||||
|
||||
The contract this module pins: each ``_stream_agent_events`` invocation must
|
||||
receive a ``step_prefix`` that is unique within the thread (we salt with the
|
||||
per-turn ``turn_id``), so the resulting step IDs across consecutive turns
|
||||
are always disjoint.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
from app.tasks.chat.stream_new_chat import (
|
||||
StreamResult,
|
||||
_resume_step_prefix,
|
||||
_stream_agent_events,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FakeChunk:
|
||||
content: Any = ""
|
||||
additional_kwargs: dict[str, Any] = field(default_factory=dict)
|
||||
tool_call_chunks: list[dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
|
||||
class _FakeAgentState:
|
||||
def __init__(self) -> None:
|
||||
self.values: dict[str, Any] = {}
|
||||
self.tasks: list[Any] = []
|
||||
|
||||
|
||||
class _FakeAgent:
|
||||
def __init__(self, events: list[dict[str, Any]]) -> None:
|
||||
self._events = events
|
||||
self._state = _FakeAgentState()
|
||||
|
||||
async def astream_events( # type: ignore[no-untyped-def]
|
||||
self, _input_data: Any, *, config: dict[str, Any], version: str
|
||||
) -> AsyncGenerator[dict[str, Any], None]:
|
||||
del config, version
|
||||
for ev in self._events:
|
||||
yield ev
|
||||
|
||||
async def aget_state(self, _config: dict[str, Any]) -> _FakeAgentState:
|
||||
return self._state
|
||||
|
||||
|
||||
def _tool_start(*, name: str, run_id: str) -> dict[str, Any]:
|
||||
return {
|
||||
"event": "on_tool_start",
|
||||
"name": name,
|
||||
"run_id": run_id,
|
||||
"data": {"input": {}},
|
||||
}
|
||||
|
||||
|
||||
async def _drain_step_ids(
|
||||
events: list[dict[str, Any]], *, step_prefix: str
|
||||
) -> set[str]:
|
||||
"""Run ``_stream_agent_events`` once and return every emitted thinking-step ID."""
|
||||
agent = _FakeAgent(events)
|
||||
service = VercelStreamingService()
|
||||
result = StreamResult()
|
||||
config = {"configurable": {"thread_id": "regression-thread"}}
|
||||
|
||||
sse_lines: list[str] = []
|
||||
async for sse in _stream_agent_events(
|
||||
agent, config, {}, service, result, step_prefix=step_prefix
|
||||
):
|
||||
sse_lines.append(sse)
|
||||
|
||||
ids: set[str] = set()
|
||||
for line in sse_lines:
|
||||
if not line.startswith("data: "):
|
||||
continue
|
||||
body = line[len("data: ") :].rstrip("\n")
|
||||
if not body or body == "[DONE]":
|
||||
continue
|
||||
try:
|
||||
payload = json.loads(body)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
if payload.get("type") != "data-thinking-step":
|
||||
continue
|
||||
step_id = (payload.get("data") or {}).get("id")
|
||||
if isinstance(step_id, str):
|
||||
ids.add(step_id)
|
||||
return ids
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_consecutive_invocations_with_same_prefix_produce_overlapping_ids():
|
||||
"""Pin the bug: identical ``step_prefix`` across two turns reuses ``-1``, ``-2``…
|
||||
|
||||
This is what production was doing for resume — every resume invocation
|
||||
passed ``step_prefix='thinking-resume'`` and the relay state's counter
|
||||
restarted at 0. Two scrollback timelines built from such turns then
|
||||
presented React with siblings keyed by the same ``thinking-resume-1``.
|
||||
"""
|
||||
events = [
|
||||
_tool_start(name="t1", run_id="run-A-1"),
|
||||
_tool_start(name="t2", run_id="run-A-2"),
|
||||
]
|
||||
|
||||
ids_turn_one = await _drain_step_ids(events, step_prefix="thinking-resume")
|
||||
ids_turn_two = await _drain_step_ids(events, step_prefix="thinking-resume")
|
||||
|
||||
assert ids_turn_one == ids_turn_two != set(), (
|
||||
"fixture broken: expected non-empty overlapping ids when prefix is reused"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_per_turn_salted_prefix_yields_disjoint_step_ids_across_turns():
|
||||
"""The fix: salting the prefix with the per-turn ``turn_id`` makes IDs disjoint.
|
||||
|
||||
Two consecutive resume calls in the same thread feed two different
|
||||
``turn_id``s into the prefix, so the resulting step IDs cannot collide
|
||||
no matter how many times the FE rehydrates from earlier assistant
|
||||
messages — which is the precondition for the React duplicate-key warning.
|
||||
"""
|
||||
events = [
|
||||
_tool_start(name="t1", run_id="run-A-1"),
|
||||
_tool_start(name="t2", run_id="run-A-2"),
|
||||
]
|
||||
|
||||
ids_turn_one = await _drain_step_ids(
|
||||
events, step_prefix="thinking-resume-104:1778698228472"
|
||||
)
|
||||
ids_turn_two = await _drain_step_ids(
|
||||
events, step_prefix="thinking-resume-104:1778698244022"
|
||||
)
|
||||
|
||||
assert ids_turn_one and ids_turn_two, "fixture broken: expected non-empty id sets"
|
||||
assert ids_turn_one.isdisjoint(ids_turn_two), (
|
||||
f"REGRESSION: per-turn-salted prefixes produced overlapping step IDs: "
|
||||
f"{ids_turn_one & ids_turn_two!r}"
|
||||
)
|
||||
|
||||
|
||||
def test_resume_step_prefix_helper_includes_turn_id_verbatim():
|
||||
"""Production call-site pin: ``stream_resume_chat`` builds the prefix via
|
||||
this helper. Reverting it back to a hardcoded ``'thinking-resume'`` would
|
||||
silently re-introduce the duplicate-key React warning across consecutive
|
||||
resumes — this test fails first instead.
|
||||
"""
|
||||
a = _resume_step_prefix("104:1778698228472")
|
||||
b = _resume_step_prefix("104:1778698244022")
|
||||
|
||||
assert a.startswith("thinking-resume-"), (
|
||||
f"prefix shape changed; the FE log filters and the timeline contract "
|
||||
f"expect the ``thinking-resume-`` head to remain stable: got {a!r}"
|
||||
)
|
||||
assert "104:1778698228472" in a and "104:1778698244022" in b
|
||||
assert a != b
|
||||
Loading…
Add table
Add a link
Reference in a new issue