mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-17 18:35:19 +02:00
multi_agent_chat/middleware: per-call thread_id, tcid-keyed resume, decisions slicer
This commit is contained in:
parent
246dae40a8
commit
fc2c5b6445
5 changed files with 513 additions and 30 deletions
|
|
@ -21,7 +21,17 @@ _LANGGRAPH_SCRATCHPAD_KEY = "__pregel_scratchpad"
|
|||
|
||||
|
||||
def subagent_invoke_config(runtime: ToolRuntime) -> dict[str, Any]:
|
||||
"""RunnableConfig for the nested invoke; raises ``recursion_limit`` to the parent's budget."""
|
||||
"""RunnableConfig for the nested invoke; raises ``recursion_limit`` and isolates ``thread_id``.
|
||||
|
||||
Each parallel subagent invocation lands in its own checkpoint slot keyed
|
||||
by an extended ``thread_id`` of the form ``{parent_thread}::task:{tool_call_id}``.
|
||||
The same call across the resume cycle keeps reading from the same snapshot
|
||||
(``tool_call_id`` is stable per LLM-emitted call).
|
||||
|
||||
We namespace via ``thread_id`` rather than ``checkpoint_ns`` because
|
||||
langgraph's ``aget_state`` interprets a non-empty ``checkpoint_ns`` as a
|
||||
subgraph path and raises ``ValueError("Subgraph X not found")``.
|
||||
"""
|
||||
merged: dict[str, Any] = dict(runtime.config) if runtime.config else {}
|
||||
current_limit = merged.get("recursion_limit")
|
||||
try:
|
||||
|
|
@ -30,43 +40,68 @@ def subagent_invoke_config(runtime: ToolRuntime) -> dict[str, Any]:
|
|||
current_int = 0
|
||||
if current_int < DEFAULT_SUBAGENT_RECURSION_LIMIT:
|
||||
merged["recursion_limit"] = DEFAULT_SUBAGENT_RECURSION_LIMIT
|
||||
|
||||
configurable: dict[str, Any] = dict(merged.get("configurable") or {})
|
||||
parent_thread_id = configurable.get("thread_id")
|
||||
per_call_suffix = f"task:{runtime.tool_call_id}"
|
||||
configurable["thread_id"] = (
|
||||
f"{parent_thread_id}::{per_call_suffix}"
|
||||
if parent_thread_id
|
||||
else per_call_suffix
|
||||
)
|
||||
merged["configurable"] = configurable
|
||||
return merged
|
||||
|
||||
|
||||
def consume_surfsense_resume(runtime: ToolRuntime) -> Any:
|
||||
"""Pop the resume payload; siblings share ``configurable`` by reference."""
|
||||
"""Pop the resume payload for *this* call's ``tool_call_id``.
|
||||
|
||||
The configurable holds ``surfsense_resume_value: dict[tool_call_id, payload]``
|
||||
so parallel sibling subagents (each with their own ``tool_call_id``) read
|
||||
only their own decision and never race on a shared scalar.
|
||||
"""
|
||||
cfg = runtime.config or {}
|
||||
configurable = cfg.get("configurable") if isinstance(cfg, dict) else None
|
||||
if not isinstance(configurable, dict):
|
||||
return None
|
||||
return configurable.pop("surfsense_resume_value", None)
|
||||
by_tcid = configurable.get("surfsense_resume_value")
|
||||
if not isinstance(by_tcid, dict):
|
||||
return None
|
||||
payload = by_tcid.pop(runtime.tool_call_id, None)
|
||||
if not by_tcid:
|
||||
configurable.pop("surfsense_resume_value", None)
|
||||
return payload
|
||||
|
||||
|
||||
def has_surfsense_resume(runtime: ToolRuntime) -> bool:
|
||||
"""True iff a resume payload is queued on this runtime (non-destructive)."""
|
||||
"""True iff a resume payload for this call's ``tool_call_id`` is queued (non-destructive)."""
|
||||
cfg = runtime.config or {}
|
||||
configurable = cfg.get("configurable") if isinstance(cfg, dict) else None
|
||||
if not isinstance(configurable, dict):
|
||||
return False
|
||||
return "surfsense_resume_value" in configurable
|
||||
by_tcid = configurable.get("surfsense_resume_value")
|
||||
if not isinstance(by_tcid, dict):
|
||||
return False
|
||||
return runtime.tool_call_id in by_tcid
|
||||
|
||||
|
||||
def drain_parent_null_resume(runtime: ToolRuntime) -> None:
|
||||
"""Consume the parent's lingering ``NULL_TASK_ID/RESUME`` write before delegating.
|
||||
|
||||
``stream_resume_chat`` wakes the main agent with
|
||||
``Command(resume={"decisions": [...]})`` so the propagated
|
||||
``_lg_interrupt(...)`` can return. langgraph stores that payload as the
|
||||
parent task's ``null_resume`` pending write, which only gets consumed
|
||||
*after* ``subagent.[a]invoke`` returns (when the post-call propagation
|
||||
re-fires). While the subagent is mid-execution, any *new* ``interrupt()``
|
||||
inside it (e.g. a follow-up tool call after a mixed approve/reject) walks
|
||||
``subagent_scratchpad → parent_scratchpad.get_null_resume`` and picks up
|
||||
the parent's still-live decisions — mismatching against a different number
|
||||
of hanging tool calls and crashing ``HumanInTheLoopMiddleware``.
|
||||
``Command(resume={tool_call_id: {"decisions": [...]}})`` so the previously
|
||||
propagated parent-level interrupt can return. langgraph stores that
|
||||
payload as the parent task's ``null_resume`` pending write. The ``task``
|
||||
tool then forwards this turn's slice into the subagent via its own
|
||||
``Command(resume=...)``. While the subagent is mid-execution, any *new*
|
||||
``interrupt()`` inside it (e.g. a follow-up tool call after a mixed
|
||||
approve/reject) walks ``subagent_scratchpad → parent_scratchpad.get_null_resume``
|
||||
and picks up the parent's still-live decisions — mismatching against a
|
||||
different number of hanging tool calls and crashing
|
||||
``HumanInTheLoopMiddleware``.
|
||||
|
||||
Draining the write here closes that cross-graph leak so subagent
|
||||
interrupts pause cleanly and re-propagate as a fresh approval card.
|
||||
interrupts pause cleanly and bubble back up as a fresh approval card.
|
||||
"""
|
||||
cfg = runtime.config or {}
|
||||
configurable = cfg.get("configurable") if isinstance(cfg, dict) else None
|
||||
|
|
|
|||
|
|
@ -0,0 +1,137 @@
|
|||
"""Route a flat ``decisions`` list to per-``tool_call_id`` resume payloads.
|
||||
|
||||
The frontend submits decisions in the same order the SSE stream emitted
|
||||
approval cards. When multiple parallel subagents are paused, the backend uses
|
||||
this module to:
|
||||
|
||||
1. Read ``state.interrupts`` from the parent's paused snapshot, extracting
|
||||
``[(tool_call_id, action_count), ...]`` from each interrupt's value.
|
||||
The ``tool_call_id`` is stamped on by ``propagation.wrap_with_tool_call_id``
|
||||
inside ``task_tool``'s catch-and-stamp block when a subagent's
|
||||
``GraphInterrupt`` bubbles up through ``[a]task``.
|
||||
2. Slice the flat ``decisions`` list against that ordered pending list to
|
||||
produce the dict shape expected by ``consume_surfsense_resume``.
|
||||
|
||||
Both helpers are pure: callers own the state and the input decisions; we
|
||||
return new structures and never mutate.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def slice_decisions_by_tool_call(
|
||||
decisions: list[dict[str, Any]],
|
||||
pending: Iterable[tuple[str, int]],
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""Slice ``decisions`` into ``{tool_call_id: {"decisions": <slice>}}``.
|
||||
|
||||
Args:
|
||||
decisions: Flat list of decisions in the order the SSE stream rendered
|
||||
them.
|
||||
pending: Ordered ``(tool_call_id, action_count)`` pairs in the same
|
||||
order. The slicer consumes ``decisions`` left-to-right.
|
||||
|
||||
Returns:
|
||||
Per-``tool_call_id`` payload dict ready to be written to
|
||||
``configurable["surfsense_resume_value"]``.
|
||||
|
||||
Raises:
|
||||
ValueError: When the total expected action count differs from the
|
||||
number of decisions provided. We fail loud rather than silently
|
||||
dropping or padding so a frontend/backend contract drift surfaces
|
||||
immediately.
|
||||
"""
|
||||
pending_list = list(pending)
|
||||
expected = sum(count for _, count in pending_list)
|
||||
if expected != len(decisions):
|
||||
raise ValueError(
|
||||
f"Decision count mismatch: pending tool calls expect "
|
||||
f"{expected} actions but received {len(decisions)} decisions."
|
||||
)
|
||||
|
||||
routed: dict[str, dict[str, Any]] = {}
|
||||
cursor = 0
|
||||
for tool_call_id, action_count in pending_list:
|
||||
routed[tool_call_id] = {
|
||||
"decisions": decisions[cursor : cursor + action_count]
|
||||
}
|
||||
cursor += action_count
|
||||
return routed
|
||||
|
||||
|
||||
def collect_pending_tool_calls(state: Any) -> list[tuple[str, int]]:
|
||||
"""Extract ``[(tool_call_id, action_count), ...]`` from a paused parent state.
|
||||
|
||||
Reads ``state.interrupts`` (the bundle langgraph aggregated from each
|
||||
paused subagent's propagated interrupt). Each interrupt value carries the
|
||||
``tool_call_id`` that the parent's ``task`` tool was processing — see
|
||||
``propagation.wrap_with_tool_call_id`` and ``task_tool``'s
|
||||
``except GraphInterrupt`` chokepoint.
|
||||
|
||||
Order is preserved from ``state.interrupts``, which is the order the SSE
|
||||
stream emitted approval cards. The frontend submits decisions in that
|
||||
same order, so the slicer can consume them left-to-right.
|
||||
|
||||
Interrupts without a ``tool_call_id`` are skipped — they were not
|
||||
produced by our task-routing layer (e.g. parent-side HITL middleware on
|
||||
a different tool); ``stream_resume_chat`` is not responsible for routing
|
||||
those.
|
||||
|
||||
Args:
|
||||
state: A langgraph ``StateSnapshot`` (or any object with an
|
||||
``interrupts`` attribute).
|
||||
|
||||
Returns:
|
||||
Ordered list of ``(tool_call_id, action_count)``. ``action_count`` is
|
||||
``len(value["action_requests"])`` for HITL-bundle values, or ``1`` for
|
||||
scalar-style ``interrupt("...")`` values that were wrapped as
|
||||
``{"value": ..., "tool_call_id": ...}``.
|
||||
|
||||
Raises:
|
||||
ValueError: When an interrupt value carries a ``tool_call_id`` but
|
||||
the action count cannot be determined (contract bug — every
|
||||
propagated value should be either a HITL bundle or a wrapped
|
||||
scalar).
|
||||
"""
|
||||
pending: list[tuple[str, int]] = []
|
||||
for idx, interrupt_obj in enumerate(getattr(state, "interrupts", ()) or ()):
|
||||
value = getattr(interrupt_obj, "value", None)
|
||||
if not isinstance(value, dict):
|
||||
logger.warning(
|
||||
"[hitl_route] interrupt[%d] skipped: value not a dict (type=%s)",
|
||||
idx,
|
||||
type(value).__name__,
|
||||
)
|
||||
continue
|
||||
tool_call_id = value.get("tool_call_id")
|
||||
if not isinstance(tool_call_id, str):
|
||||
# Should not happen post-stamping; flag loudly if a regression
|
||||
# ever lets an unstamped value reach the parent state.
|
||||
logger.warning(
|
||||
"[hitl_route] interrupt[%d] skipped: no tool_call_id stamp (keys=%s)",
|
||||
idx,
|
||||
sorted(value.keys()),
|
||||
)
|
||||
continue
|
||||
|
||||
action_requests = value.get("action_requests")
|
||||
if isinstance(action_requests, list):
|
||||
pending.append((tool_call_id, len(action_requests)))
|
||||
continue
|
||||
if "value" in value:
|
||||
pending.append((tool_call_id, 1))
|
||||
continue
|
||||
|
||||
raise ValueError(
|
||||
f"Interrupt for tool_call_id={tool_call_id!r} has no "
|
||||
"``action_requests`` list and is not a wrapped scalar value; "
|
||||
"cannot determine action count for resume routing."
|
||||
)
|
||||
|
||||
return pending
|
||||
|
|
@ -0,0 +1,154 @@
|
|||
"""Slicing helper that routes a flat decisions list to per-tool-call payloads.
|
||||
|
||||
The frontend submits ``decisions: list[ResumeDecision]`` in the same order the
|
||||
SSE stream emitted approval cards. When multiple parallel subagents are paused,
|
||||
the backend slices that flat list into per-``tool_call_id`` payloads so each
|
||||
``atask`` reads only its own decisions through ``consume_surfsense_resume``.
|
||||
|
||||
The extractor reads ``state.interrupts[i].value["tool_call_id"]`` — which is
|
||||
populated by ``propagation.wrap_with_tool_call_id`` inside ``task_tool``'s
|
||||
``except GraphInterrupt`` chokepoint whenever a subagent interrupt bubbles up
|
||||
through ``[a]task`` — to build the ordered ``pending`` list the slicer needs.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import (
|
||||
collect_pending_tool_calls,
|
||||
slice_decisions_by_tool_call,
|
||||
)
|
||||
|
||||
|
||||
class TestSliceDecisionsByToolCall:
|
||||
def test_splits_flat_decisions_across_two_pending_tool_calls(self):
|
||||
decisions = [
|
||||
{"type": "approve"},
|
||||
{"type": "edit", "edited_action": {"name": "edited-b1"}},
|
||||
{"type": "reject"},
|
||||
{"type": "approve"},
|
||||
{"type": "approve"},
|
||||
]
|
||||
pending = [
|
||||
("tcid-A", 3),
|
||||
("tcid-B", 2),
|
||||
]
|
||||
|
||||
routed = slice_decisions_by_tool_call(decisions, pending)
|
||||
|
||||
assert routed == {
|
||||
"tcid-A": {"decisions": decisions[0:3]},
|
||||
"tcid-B": {"decisions": decisions[3:5]},
|
||||
}
|
||||
|
||||
def test_raises_when_decision_count_less_than_total_actions(self):
|
||||
decisions = [{"type": "approve"}, {"type": "approve"}]
|
||||
pending = [("tcid-A", 3), ("tcid-B", 2)]
|
||||
|
||||
with pytest.raises(ValueError, match=r"5 actions.*2 decisions"):
|
||||
slice_decisions_by_tool_call(decisions, pending)
|
||||
|
||||
def test_raises_when_decision_count_greater_than_total_actions(self):
|
||||
decisions = [{"type": "approve"}] * 6
|
||||
pending = [("tcid-A", 3), ("tcid-B", 2)]
|
||||
|
||||
with pytest.raises(ValueError, match=r"5 actions.*6 decisions"):
|
||||
slice_decisions_by_tool_call(decisions, pending)
|
||||
|
||||
def test_handles_single_pending_tool_call(self):
|
||||
decisions = [{"type": "approve"}, {"type": "reject"}]
|
||||
pending = [("tcid-only", 2)]
|
||||
|
||||
routed = slice_decisions_by_tool_call(decisions, pending)
|
||||
|
||||
assert routed == {"tcid-only": {"decisions": decisions}}
|
||||
|
||||
def test_returns_empty_dict_for_no_pending(self):
|
||||
routed = slice_decisions_by_tool_call([], [])
|
||||
|
||||
assert routed == {}
|
||||
|
||||
|
||||
def _interrupt_with(tool_call_id: str, action_count: int):
|
||||
return SimpleNamespace(
|
||||
id=f"i-{tool_call_id}",
|
||||
value={
|
||||
"action_requests": [{"name": "n", "args": {}}] * action_count,
|
||||
"review_configs": [{}] * action_count,
|
||||
"tool_call_id": tool_call_id,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class TestCollectPendingToolCalls:
|
||||
def test_single_pending_returns_one_pair(self):
|
||||
state = SimpleNamespace(interrupts=(_interrupt_with("tcid-only", 3),))
|
||||
|
||||
assert collect_pending_tool_calls(state) == [("tcid-only", 3)]
|
||||
|
||||
def test_multiple_pending_preserves_state_order(self):
|
||||
"""Order must match what the SSE stream emitted (= state.interrupts order)."""
|
||||
state = SimpleNamespace(
|
||||
interrupts=(
|
||||
_interrupt_with("tcid-A", 2),
|
||||
_interrupt_with("tcid-B", 3),
|
||||
)
|
||||
)
|
||||
|
||||
assert collect_pending_tool_calls(state) == [("tcid-A", 2), ("tcid-B", 3)]
|
||||
|
||||
def test_empty_when_no_interrupts(self):
|
||||
state = SimpleNamespace(interrupts=())
|
||||
|
||||
assert collect_pending_tool_calls(state) == []
|
||||
|
||||
def test_skips_interrupts_without_tool_call_id(self):
|
||||
"""Defensive: interrupts not produced by our propagation layer are ignored.
|
||||
|
||||
``stream_resume_chat`` only owns the ``task``-routing slice; non-task
|
||||
interrupts (e.g. parent-side HITL middleware on a different tool) are
|
||||
not the slicer's responsibility.
|
||||
"""
|
||||
state = SimpleNamespace(
|
||||
interrupts=(
|
||||
_interrupt_with("tcid-A", 2),
|
||||
SimpleNamespace(id="i-foreign", value={"action_requests": [{}]}),
|
||||
_interrupt_with("tcid-B", 1),
|
||||
)
|
||||
)
|
||||
|
||||
assert collect_pending_tool_calls(state) == [("tcid-A", 2), ("tcid-B", 1)]
|
||||
|
||||
def test_handles_scalar_value_interrupt(self):
|
||||
"""Subagents using ``interrupt("approve?")`` style propagate as ``{"value": ..., "tool_call_id": ...}``.
|
||||
|
||||
These have no ``action_requests`` — count them as a single action so
|
||||
the frontend submits exactly one decision per such interrupt.
|
||||
"""
|
||||
state = SimpleNamespace(
|
||||
interrupts=(
|
||||
SimpleNamespace(
|
||||
id="i-A",
|
||||
value={"value": "approve?", "tool_call_id": "tcid-A"},
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
assert collect_pending_tool_calls(state) == [("tcid-A", 1)]
|
||||
|
||||
def test_raises_when_interrupt_value_missing_action_count_keys(self):
|
||||
"""An interrupt with ``tool_call_id`` but no usable count signals a contract bug."""
|
||||
state = SimpleNamespace(
|
||||
interrupts=(
|
||||
SimpleNamespace(
|
||||
id="i-A",
|
||||
value={"tool_call_id": "tcid-A", "weird_shape": True},
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="action_requests"):
|
||||
collect_pending_tool_calls(state)
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
"""Resume side-channel must be read exactly once per turn."""
|
||||
"""Resume side-channel is keyed per ``tool_call_id`` so parallel siblings can resume independently."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
|
@ -10,33 +10,61 @@ from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_mid
|
|||
)
|
||||
|
||||
|
||||
def _runtime_with_config(config: dict) -> ToolRuntime:
|
||||
def _runtime_with_config(
|
||||
config: dict, *, tool_call_id: str = "tcid-test"
|
||||
) -> ToolRuntime:
|
||||
return ToolRuntime(
|
||||
state=None,
|
||||
context=None,
|
||||
config=config,
|
||||
stream_writer=None,
|
||||
tool_call_id="tcid-test",
|
||||
tool_call_id=tool_call_id,
|
||||
store=None,
|
||||
)
|
||||
|
||||
|
||||
class TestConsumeSurfsenseResume:
|
||||
def test_pops_value_on_first_call(self):
|
||||
def test_pops_only_entry_matching_runtime_tool_call_id(self):
|
||||
configurable = {
|
||||
"surfsense_resume_value": {
|
||||
"tcid-A": {"decisions": ["approve"]},
|
||||
"tcid-B": {"decisions": ["reject"]},
|
||||
}
|
||||
}
|
||||
runtime = _runtime_with_config(
|
||||
{"configurable": {"surfsense_resume_value": {"decisions": ["approve"]}}}
|
||||
{"configurable": configurable}, tool_call_id="tcid-A"
|
||||
)
|
||||
|
||||
assert consume_surfsense_resume(runtime) == {"decisions": ["approve"]}
|
||||
|
||||
def test_second_call_returns_none(self):
|
||||
configurable: dict = {"surfsense_resume_value": {"decisions": ["approve"]}}
|
||||
runtime = _runtime_with_config({"configurable": configurable})
|
||||
def test_popping_one_entry_leaves_siblings_untouched(self):
|
||||
configurable = {
|
||||
"surfsense_resume_value": {
|
||||
"tcid-A": {"decisions": ["approve"]},
|
||||
"tcid-B": {"decisions": ["reject"]},
|
||||
}
|
||||
}
|
||||
runtime_a = _runtime_with_config(
|
||||
{"configurable": configurable}, tool_call_id="tcid-A"
|
||||
)
|
||||
|
||||
consume_surfsense_resume(runtime)
|
||||
consume_surfsense_resume(runtime_a)
|
||||
|
||||
assert configurable["surfsense_resume_value"] == {
|
||||
"tcid-B": {"decisions": ["reject"]}
|
||||
}
|
||||
|
||||
def test_returns_none_when_no_entry_for_this_tool_call(self):
|
||||
runtime = _runtime_with_config(
|
||||
{
|
||||
"configurable": {
|
||||
"surfsense_resume_value": {"tcid-other": {"decisions": []}}
|
||||
}
|
||||
},
|
||||
tool_call_id="tcid-A",
|
||||
)
|
||||
|
||||
assert consume_surfsense_resume(runtime) is None
|
||||
assert "surfsense_resume_value" not in configurable
|
||||
|
||||
def test_returns_none_when_no_payload_queued(self):
|
||||
runtime = _runtime_with_config({"configurable": {}})
|
||||
|
|
@ -48,22 +76,57 @@ class TestConsumeSurfsenseResume:
|
|||
|
||||
assert consume_surfsense_resume(runtime) is None
|
||||
|
||||
def test_drops_empty_dict_after_last_entry_consumed(self):
|
||||
configurable = {
|
||||
"surfsense_resume_value": {"tcid-A": {"decisions": ["approve"]}}
|
||||
}
|
||||
runtime = _runtime_with_config(
|
||||
{"configurable": configurable}, tool_call_id="tcid-A"
|
||||
)
|
||||
|
||||
consume_surfsense_resume(runtime)
|
||||
|
||||
assert "surfsense_resume_value" not in configurable
|
||||
|
||||
|
||||
class TestHasSurfsenseResume:
|
||||
def test_true_when_payload_queued(self):
|
||||
def test_true_when_entry_for_this_tool_call_present(self):
|
||||
runtime = _runtime_with_config(
|
||||
{"configurable": {"surfsense_resume_value": "approve"}}
|
||||
{
|
||||
"configurable": {
|
||||
"surfsense_resume_value": {"tcid-A": {"decisions": ["approve"]}}
|
||||
}
|
||||
},
|
||||
tool_call_id="tcid-A",
|
||||
)
|
||||
|
||||
assert has_surfsense_resume(runtime) is True
|
||||
|
||||
def test_false_when_entry_for_other_tool_call_only(self):
|
||||
runtime = _runtime_with_config(
|
||||
{
|
||||
"configurable": {
|
||||
"surfsense_resume_value": {"tcid-other": {"decisions": []}}
|
||||
}
|
||||
},
|
||||
tool_call_id="tcid-A",
|
||||
)
|
||||
|
||||
assert has_surfsense_resume(runtime) is False
|
||||
|
||||
def test_does_not_consume_payload(self):
|
||||
configurable = {"surfsense_resume_value": "approve"}
|
||||
runtime = _runtime_with_config({"configurable": configurable})
|
||||
configurable = {
|
||||
"surfsense_resume_value": {"tcid-A": {"decisions": ["approve"]}}
|
||||
}
|
||||
runtime = _runtime_with_config(
|
||||
{"configurable": configurable}, tool_call_id="tcid-A"
|
||||
)
|
||||
|
||||
has_surfsense_resume(runtime)
|
||||
|
||||
assert configurable == {"surfsense_resume_value": "approve"}
|
||||
assert configurable["surfsense_resume_value"] == {
|
||||
"tcid-A": {"decisions": ["approve"]}
|
||||
}
|
||||
|
||||
def test_false_when_payload_absent(self):
|
||||
runtime = _runtime_with_config({"configurable": {}})
|
||||
|
|
|
|||
|
|
@ -0,0 +1,94 @@
|
|||
"""Per-call ``thread_id`` derivation for nested subagent invocations.
|
||||
|
||||
Parallel ``task`` (and ``ask_knowledge_base``) calls must land in disjoint
|
||||
checkpoint slots so their nested pregel runs do not stomp on each other or on
|
||||
the parent's checkpoint state. The slot key is derived from the runtime's
|
||||
``tool_call_id`` so the same call across the resume cycle keeps reading from
|
||||
the same snapshot.
|
||||
|
||||
Note: we namespace via ``thread_id`` rather than ``checkpoint_ns`` because
|
||||
langgraph's ``aget_state`` interprets a non-empty ``checkpoint_ns`` as a
|
||||
subgraph path and raises ``ValueError("Subgraph X not found")``. ``thread_id``
|
||||
is the primary checkpoint key and is free-form, so it's the right primitive.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain.tools import ToolRuntime
|
||||
|
||||
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.config import (
|
||||
subagent_invoke_config,
|
||||
)
|
||||
|
||||
|
||||
def _runtime(*, tool_call_id: str, config: dict | None = None) -> ToolRuntime:
|
||||
return ToolRuntime(
|
||||
state=None,
|
||||
context=None,
|
||||
config=config or {},
|
||||
stream_writer=None,
|
||||
tool_call_id=tool_call_id,
|
||||
store=None,
|
||||
)
|
||||
|
||||
|
||||
class TestSubagentInvokeThreadId:
|
||||
def test_sets_per_call_thread_id_under_parent(self):
|
||||
runtime = _runtime(
|
||||
tool_call_id="tcid-A",
|
||||
config={"configurable": {"thread_id": "t1"}},
|
||||
)
|
||||
|
||||
sub_config = subagent_invoke_config(runtime)
|
||||
|
||||
assert sub_config["configurable"]["thread_id"] == "t1::task:tcid-A"
|
||||
|
||||
def test_per_call_thread_id_nests_under_already_namespaced_parent(self):
|
||||
"""A subagent that itself spawns a subagent must keep nesting cleanly."""
|
||||
runtime = _runtime(
|
||||
tool_call_id="tcid-inner",
|
||||
config={
|
||||
"configurable": {
|
||||
"thread_id": "t1::task:tcid-outer",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
sub_config = subagent_invoke_config(runtime)
|
||||
|
||||
assert (
|
||||
sub_config["configurable"]["thread_id"]
|
||||
== "t1::task:tcid-outer::task:tcid-inner"
|
||||
)
|
||||
|
||||
def test_different_tool_call_ids_produce_different_thread_ids(self):
|
||||
config = {"configurable": {"thread_id": "t1"}}
|
||||
rt_a = _runtime(tool_call_id="tcid-A", config=config)
|
||||
rt_b = _runtime(tool_call_id="tcid-B", config=config)
|
||||
|
||||
tid_a = subagent_invoke_config(rt_a)["configurable"]["thread_id"]
|
||||
tid_b = subagent_invoke_config(rt_b)["configurable"]["thread_id"]
|
||||
|
||||
assert tid_a != tid_b
|
||||
|
||||
def test_same_tool_call_id_produces_same_thread_id_across_repeated_calls(self):
|
||||
"""Resume bridge needs to find the snapshot it primed earlier."""
|
||||
config = {"configurable": {"thread_id": "t1"}}
|
||||
rt_first = _runtime(tool_call_id="tcid-A", config=config)
|
||||
rt_second = _runtime(tool_call_id="tcid-A", config=config)
|
||||
|
||||
tid_first = subagent_invoke_config(rt_first)["configurable"]["thread_id"]
|
||||
tid_second = subagent_invoke_config(rt_second)["configurable"]["thread_id"]
|
||||
|
||||
assert tid_first == tid_second
|
||||
|
||||
def test_does_not_mutate_caller_config(self):
|
||||
"""Repeated calls must not accumulate suffixes onto the parent's config."""
|
||||
original_thread_id = "t1"
|
||||
config = {"configurable": {"thread_id": original_thread_id}}
|
||||
runtime = _runtime(tool_call_id="tcid-A", config=config)
|
||||
|
||||
subagent_invoke_config(runtime)
|
||||
subagent_invoke_config(runtime)
|
||||
|
||||
assert config["configurable"]["thread_id"] == original_thread_id
|
||||
Loading…
Add table
Add a link
Reference in a new issue