mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-19 18:45:15 +02:00
multi_agent_chat/middleware: per-call thread_id, tcid-keyed resume, decisions slicer
This commit is contained in:
parent
246dae40a8
commit
fc2c5b6445
5 changed files with 513 additions and 30 deletions
|
|
@ -0,0 +1,154 @@
|
|||
"""Slicing helper that routes a flat decisions list to per-tool-call payloads.
|
||||
|
||||
The frontend submits ``decisions: list[ResumeDecision]`` in the same order the
|
||||
SSE stream emitted approval cards. When multiple parallel subagents are paused,
|
||||
the backend slices that flat list into per-``tool_call_id`` payloads so each
|
||||
``atask`` reads only its own decisions through ``consume_surfsense_resume``.
|
||||
|
||||
The extractor reads ``state.interrupts[i].value["tool_call_id"]`` — which is
|
||||
populated by ``propagation.wrap_with_tool_call_id`` inside ``task_tool``'s
|
||||
``except GraphInterrupt`` chokepoint whenever a subagent interrupt bubbles up
|
||||
through ``[a]task`` — to build the ordered ``pending`` list the slicer needs.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import (
|
||||
collect_pending_tool_calls,
|
||||
slice_decisions_by_tool_call,
|
||||
)
|
||||
|
||||
|
||||
class TestSliceDecisionsByToolCall:
|
||||
def test_splits_flat_decisions_across_two_pending_tool_calls(self):
|
||||
decisions = [
|
||||
{"type": "approve"},
|
||||
{"type": "edit", "edited_action": {"name": "edited-b1"}},
|
||||
{"type": "reject"},
|
||||
{"type": "approve"},
|
||||
{"type": "approve"},
|
||||
]
|
||||
pending = [
|
||||
("tcid-A", 3),
|
||||
("tcid-B", 2),
|
||||
]
|
||||
|
||||
routed = slice_decisions_by_tool_call(decisions, pending)
|
||||
|
||||
assert routed == {
|
||||
"tcid-A": {"decisions": decisions[0:3]},
|
||||
"tcid-B": {"decisions": decisions[3:5]},
|
||||
}
|
||||
|
||||
def test_raises_when_decision_count_less_than_total_actions(self):
|
||||
decisions = [{"type": "approve"}, {"type": "approve"}]
|
||||
pending = [("tcid-A", 3), ("tcid-B", 2)]
|
||||
|
||||
with pytest.raises(ValueError, match=r"5 actions.*2 decisions"):
|
||||
slice_decisions_by_tool_call(decisions, pending)
|
||||
|
||||
def test_raises_when_decision_count_greater_than_total_actions(self):
|
||||
decisions = [{"type": "approve"}] * 6
|
||||
pending = [("tcid-A", 3), ("tcid-B", 2)]
|
||||
|
||||
with pytest.raises(ValueError, match=r"5 actions.*6 decisions"):
|
||||
slice_decisions_by_tool_call(decisions, pending)
|
||||
|
||||
def test_handles_single_pending_tool_call(self):
|
||||
decisions = [{"type": "approve"}, {"type": "reject"}]
|
||||
pending = [("tcid-only", 2)]
|
||||
|
||||
routed = slice_decisions_by_tool_call(decisions, pending)
|
||||
|
||||
assert routed == {"tcid-only": {"decisions": decisions}}
|
||||
|
||||
def test_returns_empty_dict_for_no_pending(self):
|
||||
routed = slice_decisions_by_tool_call([], [])
|
||||
|
||||
assert routed == {}
|
||||
|
||||
|
||||
def _interrupt_with(tool_call_id: str, action_count: int):
|
||||
return SimpleNamespace(
|
||||
id=f"i-{tool_call_id}",
|
||||
value={
|
||||
"action_requests": [{"name": "n", "args": {}}] * action_count,
|
||||
"review_configs": [{}] * action_count,
|
||||
"tool_call_id": tool_call_id,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class TestCollectPendingToolCalls:
|
||||
def test_single_pending_returns_one_pair(self):
|
||||
state = SimpleNamespace(interrupts=(_interrupt_with("tcid-only", 3),))
|
||||
|
||||
assert collect_pending_tool_calls(state) == [("tcid-only", 3)]
|
||||
|
||||
def test_multiple_pending_preserves_state_order(self):
|
||||
"""Order must match what the SSE stream emitted (= state.interrupts order)."""
|
||||
state = SimpleNamespace(
|
||||
interrupts=(
|
||||
_interrupt_with("tcid-A", 2),
|
||||
_interrupt_with("tcid-B", 3),
|
||||
)
|
||||
)
|
||||
|
||||
assert collect_pending_tool_calls(state) == [("tcid-A", 2), ("tcid-B", 3)]
|
||||
|
||||
def test_empty_when_no_interrupts(self):
|
||||
state = SimpleNamespace(interrupts=())
|
||||
|
||||
assert collect_pending_tool_calls(state) == []
|
||||
|
||||
def test_skips_interrupts_without_tool_call_id(self):
|
||||
"""Defensive: interrupts not produced by our propagation layer are ignored.
|
||||
|
||||
``stream_resume_chat`` only owns the ``task``-routing slice; non-task
|
||||
interrupts (e.g. parent-side HITL middleware on a different tool) are
|
||||
not the slicer's responsibility.
|
||||
"""
|
||||
state = SimpleNamespace(
|
||||
interrupts=(
|
||||
_interrupt_with("tcid-A", 2),
|
||||
SimpleNamespace(id="i-foreign", value={"action_requests": [{}]}),
|
||||
_interrupt_with("tcid-B", 1),
|
||||
)
|
||||
)
|
||||
|
||||
assert collect_pending_tool_calls(state) == [("tcid-A", 2), ("tcid-B", 1)]
|
||||
|
||||
def test_handles_scalar_value_interrupt(self):
|
||||
"""Subagents using ``interrupt("approve?")`` style propagate as ``{"value": ..., "tool_call_id": ...}``.
|
||||
|
||||
These have no ``action_requests`` — count them as a single action so
|
||||
the frontend submits exactly one decision per such interrupt.
|
||||
"""
|
||||
state = SimpleNamespace(
|
||||
interrupts=(
|
||||
SimpleNamespace(
|
||||
id="i-A",
|
||||
value={"value": "approve?", "tool_call_id": "tcid-A"},
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
assert collect_pending_tool_calls(state) == [("tcid-A", 1)]
|
||||
|
||||
def test_raises_when_interrupt_value_missing_action_count_keys(self):
|
||||
"""An interrupt with ``tool_call_id`` but no usable count signals a contract bug."""
|
||||
state = SimpleNamespace(
|
||||
interrupts=(
|
||||
SimpleNamespace(
|
||||
id="i-A",
|
||||
value={"tool_call_id": "tcid-A", "weird_shape": True},
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="action_requests"):
|
||||
collect_pending_tool_calls(state)
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
"""Resume side-channel must be read exactly once per turn."""
|
||||
"""Resume side-channel is keyed per ``tool_call_id`` so parallel siblings can resume independently."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
|
@ -10,33 +10,61 @@ from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_mid
|
|||
)
|
||||
|
||||
|
||||
def _runtime_with_config(config: dict) -> ToolRuntime:
|
||||
def _runtime_with_config(
|
||||
config: dict, *, tool_call_id: str = "tcid-test"
|
||||
) -> ToolRuntime:
|
||||
return ToolRuntime(
|
||||
state=None,
|
||||
context=None,
|
||||
config=config,
|
||||
stream_writer=None,
|
||||
tool_call_id="tcid-test",
|
||||
tool_call_id=tool_call_id,
|
||||
store=None,
|
||||
)
|
||||
|
||||
|
||||
class TestConsumeSurfsenseResume:
|
||||
def test_pops_value_on_first_call(self):
|
||||
def test_pops_only_entry_matching_runtime_tool_call_id(self):
|
||||
configurable = {
|
||||
"surfsense_resume_value": {
|
||||
"tcid-A": {"decisions": ["approve"]},
|
||||
"tcid-B": {"decisions": ["reject"]},
|
||||
}
|
||||
}
|
||||
runtime = _runtime_with_config(
|
||||
{"configurable": {"surfsense_resume_value": {"decisions": ["approve"]}}}
|
||||
{"configurable": configurable}, tool_call_id="tcid-A"
|
||||
)
|
||||
|
||||
assert consume_surfsense_resume(runtime) == {"decisions": ["approve"]}
|
||||
|
||||
def test_second_call_returns_none(self):
|
||||
configurable: dict = {"surfsense_resume_value": {"decisions": ["approve"]}}
|
||||
runtime = _runtime_with_config({"configurable": configurable})
|
||||
def test_popping_one_entry_leaves_siblings_untouched(self):
|
||||
configurable = {
|
||||
"surfsense_resume_value": {
|
||||
"tcid-A": {"decisions": ["approve"]},
|
||||
"tcid-B": {"decisions": ["reject"]},
|
||||
}
|
||||
}
|
||||
runtime_a = _runtime_with_config(
|
||||
{"configurable": configurable}, tool_call_id="tcid-A"
|
||||
)
|
||||
|
||||
consume_surfsense_resume(runtime)
|
||||
consume_surfsense_resume(runtime_a)
|
||||
|
||||
assert configurable["surfsense_resume_value"] == {
|
||||
"tcid-B": {"decisions": ["reject"]}
|
||||
}
|
||||
|
||||
def test_returns_none_when_no_entry_for_this_tool_call(self):
|
||||
runtime = _runtime_with_config(
|
||||
{
|
||||
"configurable": {
|
||||
"surfsense_resume_value": {"tcid-other": {"decisions": []}}
|
||||
}
|
||||
},
|
||||
tool_call_id="tcid-A",
|
||||
)
|
||||
|
||||
assert consume_surfsense_resume(runtime) is None
|
||||
assert "surfsense_resume_value" not in configurable
|
||||
|
||||
def test_returns_none_when_no_payload_queued(self):
|
||||
runtime = _runtime_with_config({"configurable": {}})
|
||||
|
|
@ -48,22 +76,57 @@ class TestConsumeSurfsenseResume:
|
|||
|
||||
assert consume_surfsense_resume(runtime) is None
|
||||
|
||||
def test_drops_empty_dict_after_last_entry_consumed(self):
|
||||
configurable = {
|
||||
"surfsense_resume_value": {"tcid-A": {"decisions": ["approve"]}}
|
||||
}
|
||||
runtime = _runtime_with_config(
|
||||
{"configurable": configurable}, tool_call_id="tcid-A"
|
||||
)
|
||||
|
||||
consume_surfsense_resume(runtime)
|
||||
|
||||
assert "surfsense_resume_value" not in configurable
|
||||
|
||||
|
||||
class TestHasSurfsenseResume:
|
||||
def test_true_when_payload_queued(self):
|
||||
def test_true_when_entry_for_this_tool_call_present(self):
|
||||
runtime = _runtime_with_config(
|
||||
{"configurable": {"surfsense_resume_value": "approve"}}
|
||||
{
|
||||
"configurable": {
|
||||
"surfsense_resume_value": {"tcid-A": {"decisions": ["approve"]}}
|
||||
}
|
||||
},
|
||||
tool_call_id="tcid-A",
|
||||
)
|
||||
|
||||
assert has_surfsense_resume(runtime) is True
|
||||
|
||||
def test_false_when_entry_for_other_tool_call_only(self):
|
||||
runtime = _runtime_with_config(
|
||||
{
|
||||
"configurable": {
|
||||
"surfsense_resume_value": {"tcid-other": {"decisions": []}}
|
||||
}
|
||||
},
|
||||
tool_call_id="tcid-A",
|
||||
)
|
||||
|
||||
assert has_surfsense_resume(runtime) is False
|
||||
|
||||
def test_does_not_consume_payload(self):
|
||||
configurable = {"surfsense_resume_value": "approve"}
|
||||
runtime = _runtime_with_config({"configurable": configurable})
|
||||
configurable = {
|
||||
"surfsense_resume_value": {"tcid-A": {"decisions": ["approve"]}}
|
||||
}
|
||||
runtime = _runtime_with_config(
|
||||
{"configurable": configurable}, tool_call_id="tcid-A"
|
||||
)
|
||||
|
||||
has_surfsense_resume(runtime)
|
||||
|
||||
assert configurable == {"surfsense_resume_value": "approve"}
|
||||
assert configurable["surfsense_resume_value"] == {
|
||||
"tcid-A": {"decisions": ["approve"]}
|
||||
}
|
||||
|
||||
def test_false_when_payload_absent(self):
|
||||
runtime = _runtime_with_config({"configurable": {}})
|
||||
|
|
|
|||
|
|
@ -0,0 +1,94 @@
|
|||
"""Per-call ``thread_id`` derivation for nested subagent invocations.
|
||||
|
||||
Parallel ``task`` (and ``ask_knowledge_base``) calls must land in disjoint
|
||||
checkpoint slots so their nested pregel runs do not stomp on each other or on
|
||||
the parent's checkpoint state. The slot key is derived from the runtime's
|
||||
``tool_call_id`` so the same call across the resume cycle keeps reading from
|
||||
the same snapshot.
|
||||
|
||||
Note: we namespace via ``thread_id`` rather than ``checkpoint_ns`` because
|
||||
langgraph's ``aget_state`` interprets a non-empty ``checkpoint_ns`` as a
|
||||
subgraph path and raises ``ValueError("Subgraph X not found")``. ``thread_id``
|
||||
is the primary checkpoint key and is free-form, so it's the right primitive.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain.tools import ToolRuntime
|
||||
|
||||
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.config import (
|
||||
subagent_invoke_config,
|
||||
)
|
||||
|
||||
|
||||
def _runtime(*, tool_call_id: str, config: dict | None = None) -> ToolRuntime:
|
||||
return ToolRuntime(
|
||||
state=None,
|
||||
context=None,
|
||||
config=config or {},
|
||||
stream_writer=None,
|
||||
tool_call_id=tool_call_id,
|
||||
store=None,
|
||||
)
|
||||
|
||||
|
||||
class TestSubagentInvokeThreadId:
|
||||
def test_sets_per_call_thread_id_under_parent(self):
|
||||
runtime = _runtime(
|
||||
tool_call_id="tcid-A",
|
||||
config={"configurable": {"thread_id": "t1"}},
|
||||
)
|
||||
|
||||
sub_config = subagent_invoke_config(runtime)
|
||||
|
||||
assert sub_config["configurable"]["thread_id"] == "t1::task:tcid-A"
|
||||
|
||||
def test_per_call_thread_id_nests_under_already_namespaced_parent(self):
|
||||
"""A subagent that itself spawns a subagent must keep nesting cleanly."""
|
||||
runtime = _runtime(
|
||||
tool_call_id="tcid-inner",
|
||||
config={
|
||||
"configurable": {
|
||||
"thread_id": "t1::task:tcid-outer",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
sub_config = subagent_invoke_config(runtime)
|
||||
|
||||
assert (
|
||||
sub_config["configurable"]["thread_id"]
|
||||
== "t1::task:tcid-outer::task:tcid-inner"
|
||||
)
|
||||
|
||||
def test_different_tool_call_ids_produce_different_thread_ids(self):
|
||||
config = {"configurable": {"thread_id": "t1"}}
|
||||
rt_a = _runtime(tool_call_id="tcid-A", config=config)
|
||||
rt_b = _runtime(tool_call_id="tcid-B", config=config)
|
||||
|
||||
tid_a = subagent_invoke_config(rt_a)["configurable"]["thread_id"]
|
||||
tid_b = subagent_invoke_config(rt_b)["configurable"]["thread_id"]
|
||||
|
||||
assert tid_a != tid_b
|
||||
|
||||
def test_same_tool_call_id_produces_same_thread_id_across_repeated_calls(self):
|
||||
"""Resume bridge needs to find the snapshot it primed earlier."""
|
||||
config = {"configurable": {"thread_id": "t1"}}
|
||||
rt_first = _runtime(tool_call_id="tcid-A", config=config)
|
||||
rt_second = _runtime(tool_call_id="tcid-A", config=config)
|
||||
|
||||
tid_first = subagent_invoke_config(rt_first)["configurable"]["thread_id"]
|
||||
tid_second = subagent_invoke_config(rt_second)["configurable"]["thread_id"]
|
||||
|
||||
assert tid_first == tid_second
|
||||
|
||||
def test_does_not_mutate_caller_config(self):
|
||||
"""Repeated calls must not accumulate suffixes onto the parent's config."""
|
||||
original_thread_id = "t1"
|
||||
config = {"configurable": {"thread_id": original_thread_id}}
|
||||
runtime = _runtime(tool_call_id="tcid-A", config=config)
|
||||
|
||||
subagent_invoke_config(runtime)
|
||||
subagent_invoke_config(runtime)
|
||||
|
||||
assert config["configurable"]["thread_id"] == original_thread_id
|
||||
Loading…
Add table
Add a link
Reference in a new issue