mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-17 18:35:19 +02:00
multi_agent_chat/middleware: per-call thread_id, tcid-keyed resume, decisions slicer
This commit is contained in:
parent
246dae40a8
commit
fc2c5b6445
5 changed files with 513 additions and 30 deletions
|
|
@ -21,7 +21,17 @@ _LANGGRAPH_SCRATCHPAD_KEY = "__pregel_scratchpad"
|
||||||
|
|
||||||
|
|
||||||
def subagent_invoke_config(runtime: ToolRuntime) -> dict[str, Any]:
|
def subagent_invoke_config(runtime: ToolRuntime) -> dict[str, Any]:
|
||||||
"""RunnableConfig for the nested invoke; raises ``recursion_limit`` to the parent's budget."""
|
"""RunnableConfig for the nested invoke; raises ``recursion_limit`` and isolates ``thread_id``.
|
||||||
|
|
||||||
|
Each parallel subagent invocation lands in its own checkpoint slot keyed
|
||||||
|
by an extended ``thread_id`` of the form ``{parent_thread}::task:{tool_call_id}``.
|
||||||
|
The same call across the resume cycle keeps reading from the same snapshot
|
||||||
|
(``tool_call_id`` is stable per LLM-emitted call).
|
||||||
|
|
||||||
|
We namespace via ``thread_id`` rather than ``checkpoint_ns`` because
|
||||||
|
langgraph's ``aget_state`` interprets a non-empty ``checkpoint_ns`` as a
|
||||||
|
subgraph path and raises ``ValueError("Subgraph X not found")``.
|
||||||
|
"""
|
||||||
merged: dict[str, Any] = dict(runtime.config) if runtime.config else {}
|
merged: dict[str, Any] = dict(runtime.config) if runtime.config else {}
|
||||||
current_limit = merged.get("recursion_limit")
|
current_limit = merged.get("recursion_limit")
|
||||||
try:
|
try:
|
||||||
|
|
@ -30,43 +40,68 @@ def subagent_invoke_config(runtime: ToolRuntime) -> dict[str, Any]:
|
||||||
current_int = 0
|
current_int = 0
|
||||||
if current_int < DEFAULT_SUBAGENT_RECURSION_LIMIT:
|
if current_int < DEFAULT_SUBAGENT_RECURSION_LIMIT:
|
||||||
merged["recursion_limit"] = DEFAULT_SUBAGENT_RECURSION_LIMIT
|
merged["recursion_limit"] = DEFAULT_SUBAGENT_RECURSION_LIMIT
|
||||||
|
|
||||||
|
configurable: dict[str, Any] = dict(merged.get("configurable") or {})
|
||||||
|
parent_thread_id = configurable.get("thread_id")
|
||||||
|
per_call_suffix = f"task:{runtime.tool_call_id}"
|
||||||
|
configurable["thread_id"] = (
|
||||||
|
f"{parent_thread_id}::{per_call_suffix}"
|
||||||
|
if parent_thread_id
|
||||||
|
else per_call_suffix
|
||||||
|
)
|
||||||
|
merged["configurable"] = configurable
|
||||||
return merged
|
return merged
|
||||||
|
|
||||||
|
|
||||||
def consume_surfsense_resume(runtime: ToolRuntime) -> Any:
    """Pop the resume payload for *this* call's ``tool_call_id``.

    The configurable holds ``surfsense_resume_value: dict[tool_call_id, payload]``
    so parallel sibling subagents (each with their own ``tool_call_id``) read
    only their own decision and never race on a shared scalar.

    Args:
        runtime: Tool runtime; only ``runtime.config`` and
            ``runtime.tool_call_id`` are read.

    Returns:
        The payload queued under ``runtime.tool_call_id``, or ``None`` when the
        config has no ``configurable`` dict, when ``surfsense_resume_value`` is
        missing or is not a dict, or when no entry exists for this call.
    """
    cfg = runtime.config or {}
    configurable = cfg.get("configurable") if isinstance(cfg, dict) else None
    if not isinstance(configurable, dict):
        return None
    by_tcid = configurable.get("surfsense_resume_value")
    if not isinstance(by_tcid, dict):
        return None
    # Remove only this call's entry; sibling entries stay queued.
    payload = by_tcid.pop(runtime.tool_call_id, None)
    # Once the mapping is empty, drop the key itself from the configurable.
    if not by_tcid:
        configurable.pop("surfsense_resume_value", None)
    return payload
|
||||||
|
|
||||||
|
|
||||||
def has_surfsense_resume(runtime: ToolRuntime) -> bool:
    """True iff a resume payload for this call's ``tool_call_id`` is queued (non-destructive).

    Args:
        runtime: Tool runtime; only ``runtime.config`` and
            ``runtime.tool_call_id`` are read.

    Returns:
        ``True`` when ``configurable["surfsense_resume_value"]`` is a dict that
        contains ``runtime.tool_call_id``; ``False`` in every other case.
        Nothing is removed from the config.
    """
    cfg = runtime.config or {}
    configurable = cfg.get("configurable") if isinstance(cfg, dict) else None
    if not isinstance(configurable, dict):
        return False
    by_tcid = configurable.get("surfsense_resume_value")
    if not isinstance(by_tcid, dict):
        return False
    return runtime.tool_call_id in by_tcid
|
||||||
|
|
||||||
|
|
||||||
def drain_parent_null_resume(runtime: ToolRuntime) -> None:
|
def drain_parent_null_resume(runtime: ToolRuntime) -> None:
|
||||||
"""Consume the parent's lingering ``NULL_TASK_ID/RESUME`` write before delegating.
|
"""Consume the parent's lingering ``NULL_TASK_ID/RESUME`` write before delegating.
|
||||||
|
|
||||||
``stream_resume_chat`` wakes the main agent with
|
``stream_resume_chat`` wakes the main agent with
|
||||||
``Command(resume={"decisions": [...]})`` so the propagated
|
``Command(resume={tool_call_id: {"decisions": [...]}})`` so the previously
|
||||||
``_lg_interrupt(...)`` can return. langgraph stores that payload as the
|
propagated parent-level interrupt can return. langgraph stores that
|
||||||
parent task's ``null_resume`` pending write, which only gets consumed
|
payload as the parent task's ``null_resume`` pending write. The ``task``
|
||||||
*after* ``subagent.[a]invoke`` returns (when the post-call propagation
|
tool then forwards this turn's slice into the subagent via its own
|
||||||
re-fires). While the subagent is mid-execution, any *new* ``interrupt()``
|
``Command(resume=...)``. While the subagent is mid-execution, any *new*
|
||||||
inside it (e.g. a follow-up tool call after a mixed approve/reject) walks
|
``interrupt()`` inside it (e.g. a follow-up tool call after a mixed
|
||||||
``subagent_scratchpad → parent_scratchpad.get_null_resume`` and picks up
|
approve/reject) walks ``subagent_scratchpad → parent_scratchpad.get_null_resume``
|
||||||
the parent's still-live decisions — mismatching against a different number
|
and picks up the parent's still-live decisions — mismatching against a
|
||||||
of hanging tool calls and crashing ``HumanInTheLoopMiddleware``.
|
different number of hanging tool calls and crashing
|
||||||
|
``HumanInTheLoopMiddleware``.
|
||||||
|
|
||||||
Draining the write here closes that cross-graph leak so subagent
|
Draining the write here closes that cross-graph leak so subagent
|
||||||
interrupts pause cleanly and re-propagate as a fresh approval card.
|
interrupts pause cleanly and bubble back up as a fresh approval card.
|
||||||
"""
|
"""
|
||||||
cfg = runtime.config or {}
|
cfg = runtime.config or {}
|
||||||
configurable = cfg.get("configurable") if isinstance(cfg, dict) else None
|
configurable = cfg.get("configurable") if isinstance(cfg, dict) else None
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,137 @@
|
||||||
|
"""Route a flat ``decisions`` list to per-``tool_call_id`` resume payloads.
|
||||||
|
|
||||||
|
The frontend submits decisions in the same order the SSE stream emitted
|
||||||
|
approval cards. When multiple parallel subagents are paused, the backend uses
|
||||||
|
this module to:
|
||||||
|
|
||||||
|
1. Read ``state.interrupts`` from the parent's paused snapshot, extracting
|
||||||
|
``[(tool_call_id, action_count), ...]`` from each interrupt's value.
|
||||||
|
The ``tool_call_id`` is stamped on by ``propagation.wrap_with_tool_call_id``
|
||||||
|
inside ``task_tool``'s catch-and-stamp block when a subagent's
|
||||||
|
``GraphInterrupt`` bubbles up through ``[a]task``.
|
||||||
|
2. Slice the flat ``decisions`` list against that ordered pending list to
|
||||||
|
produce the dict shape expected by ``consume_surfsense_resume``.
|
||||||
|
|
||||||
|
Both helpers are pure: callers own the state and the input decisions; we
|
||||||
|
return new structures and never mutate.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def slice_decisions_by_tool_call(
    decisions: list[dict[str, Any]],
    pending: Iterable[tuple[str, int]],
) -> dict[str, dict[str, Any]]:
    """Slice ``decisions`` into ``{tool_call_id: {"decisions": <slice>}}``.

    Args:
        decisions: Flat list of decisions in the order the SSE stream rendered
            them.
        pending: Ordered ``(tool_call_id, action_count)`` pairs in the same
            order. The slicer consumes ``decisions`` left-to-right.

    Returns:
        Per-``tool_call_id`` payload dict ready to be written to
        ``configurable["surfsense_resume_value"]``. If the same
        ``tool_call_id`` appears more than once in ``pending``, its slices are
        concatenated in order so that no decision is dropped. The input list
        is never mutated; every slice is a new list.

    Raises:
        ValueError: When an ``action_count`` is negative, or when the total
            expected action count differs from the number of decisions
            provided. We fail loud rather than silently dropping or padding so
            a frontend/backend contract drift surfaces immediately.
    """
    # Materialise once: ``pending`` may be a one-shot iterator and is walked twice.
    pending_list = list(pending)

    expected = 0
    for tool_call_id, action_count in pending_list:
        # A negative count would turn the slice bounds below into
        # from-the-end indices and route the wrong decisions.
        if action_count < 0:
            raise ValueError(
                f"Invalid action count {action_count} for "
                f"tool_call_id={tool_call_id!r}; counts must be non-negative."
            )
        expected += action_count

    if expected != len(decisions):
        raise ValueError(
            f"Decision count mismatch: pending tool calls expect "
            f"{expected} actions but received {len(decisions)} decisions."
        )

    routed: dict[str, dict[str, Any]] = {}
    cursor = 0
    for tool_call_id, action_count in pending_list:
        chunk = decisions[cursor : cursor + action_count]
        cursor += action_count
        existing = routed.get(tool_call_id)
        if existing is None:
            routed[tool_call_id] = {"decisions": chunk}
        else:
            # Repeated id: append instead of overwriting, otherwise the
            # earlier slice would be lost even though the totals matched.
            existing["decisions"].extend(chunk)
    return routed
|
||||||
|
|
||||||
|
|
||||||
|
def collect_pending_tool_calls(state: Any) -> list[tuple[str, int]]:
    """Extract ``[(tool_call_id, action_count), ...]`` from a paused parent state.

    Reads ``state.interrupts`` (the bundle langgraph aggregated from each
    paused subagent's propagated interrupt). Each interrupt value carries the
    ``tool_call_id`` that the parent's ``task`` tool was processing — see
    ``propagation.wrap_with_tool_call_id`` and ``task_tool``'s
    ``except GraphInterrupt`` chokepoint.

    Order is preserved from ``state.interrupts``, which is the order the SSE
    stream emitted approval cards. The frontend submits decisions in that
    same order, so the slicer can consume them left-to-right.

    Interrupts without a ``tool_call_id`` are skipped — they were not
    produced by our task-routing layer (e.g. parent-side HITL middleware on
    a different tool); ``stream_resume_chat`` is not responsible for routing
    those.

    Args:
        state: A langgraph ``StateSnapshot`` (or any object with an
            ``interrupts`` attribute). A missing or falsy ``interrupts``
            attribute is treated as "no interrupts".

    Returns:
        Ordered list of ``(tool_call_id, action_count)``. ``action_count`` is
        ``len(value["action_requests"])`` for HITL-bundle values, or ``1`` for
        scalar-style ``interrupt("...")`` values that were wrapped as
        ``{"value": ..., "tool_call_id": ...}``.

    Raises:
        ValueError: When an interrupt value carries a ``tool_call_id`` but
            the action count cannot be determined (contract bug — every
            propagated value should be either a HITL bundle or a wrapped
            scalar).
    """
    pending: list[tuple[str, int]] = []
    for idx, interrupt_obj in enumerate(getattr(state, "interrupts", ()) or ()):
        value = getattr(interrupt_obj, "value", None)
        if not isinstance(value, dict):
            logger.warning(
                "[hitl_route] interrupt[%d] skipped: value not a dict (type=%s)",
                idx,
                type(value).__name__,
            )
            continue
        tool_call_id = value.get("tool_call_id")
        if not isinstance(tool_call_id, str):
            # Unstamped values are not ours to route. Log them anyway so a
            # stamping regression in the task layer is visible in the logs.
            logger.warning(
                "[hitl_route] interrupt[%d] skipped: no tool_call_id stamp (keys=%s)",
                idx,
                sorted(value.keys()),
            )
            continue

        # HITL bundle: one decision is expected per action request.
        action_requests = value.get("action_requests")
        if isinstance(action_requests, list):
            pending.append((tool_call_id, len(action_requests)))
            continue
        # Wrapped scalar interrupt: counts as exactly one decision.
        if "value" in value:
            pending.append((tool_call_id, 1))
            continue

        raise ValueError(
            f"Interrupt for tool_call_id={tool_call_id!r} has no "
            "``action_requests`` list and is not a wrapped scalar value; "
            "cannot determine action count for resume routing."
        )

    return pending
|
||||||
|
|
@ -0,0 +1,154 @@
|
||||||
|
"""Slicing helper that routes a flat decisions list to per-tool-call payloads.
|
||||||
|
|
||||||
|
The frontend submits ``decisions: list[ResumeDecision]`` in the same order the
|
||||||
|
SSE stream emitted approval cards. When multiple parallel subagents are paused,
|
||||||
|
the backend slices that flat list into per-``tool_call_id`` payloads so each
|
||||||
|
``atask`` reads only its own decisions through ``consume_surfsense_resume``.
|
||||||
|
|
||||||
|
The extractor reads ``state.interrupts[i].value["tool_call_id"]`` — which is
|
||||||
|
populated by ``propagation.wrap_with_tool_call_id`` inside ``task_tool``'s
|
||||||
|
``except GraphInterrupt`` chokepoint whenever a subagent interrupt bubbles up
|
||||||
|
through ``[a]task`` — to build the ordered ``pending`` list the slicer needs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import (
|
||||||
|
collect_pending_tool_calls,
|
||||||
|
slice_decisions_by_tool_call,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSliceDecisionsByToolCall:
    """Behaviour of ``slice_decisions_by_tool_call`` on flat decision lists."""

    def test_splits_flat_decisions_across_two_pending_tool_calls(self):
        decisions = [
            {"type": "approve"},
            {"type": "edit", "edited_action": {"name": "edited-b1"}},
            {"type": "reject"},
            {"type": "approve"},
            {"type": "approve"},
        ]
        pending = [
            ("tcid-A", 3),
            ("tcid-B", 2),
        ]

        routed = slice_decisions_by_tool_call(decisions, pending)

        assert routed == {
            "tcid-A": {"decisions": decisions[0:3]},
            "tcid-B": {"decisions": decisions[3:5]},
        }

    def test_raises_when_decision_count_less_than_total_actions(self):
        decisions = [{"type": "approve"}, {"type": "approve"}]
        pending = [("tcid-A", 3), ("tcid-B", 2)]

        with pytest.raises(ValueError, match=r"5 actions.*2 decisions"):
            slice_decisions_by_tool_call(decisions, pending)

    def test_raises_when_decision_count_greater_than_total_actions(self):
        decisions = [{"type": "approve"}] * 6
        pending = [("tcid-A", 3), ("tcid-B", 2)]

        with pytest.raises(ValueError, match=r"5 actions.*6 decisions"):
            slice_decisions_by_tool_call(decisions, pending)

    def test_handles_single_pending_tool_call(self):
        decisions = [{"type": "approve"}, {"type": "reject"}]
        pending = [("tcid-only", 2)]

        routed = slice_decisions_by_tool_call(decisions, pending)

        assert routed == {"tcid-only": {"decisions": decisions}}

    def test_returns_empty_dict_for_no_pending(self):
        routed = slice_decisions_by_tool_call([], [])

        assert routed == {}
|
||||||
|
|
||||||
|
|
||||||
|
def _interrupt_with(tool_call_id: str, action_count: int):
|
||||||
|
return SimpleNamespace(
|
||||||
|
id=f"i-{tool_call_id}",
|
||||||
|
value={
|
||||||
|
"action_requests": [{"name": "n", "args": {}}] * action_count,
|
||||||
|
"review_configs": [{}] * action_count,
|
||||||
|
"tool_call_id": tool_call_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCollectPendingToolCalls:
    """Behaviour of ``collect_pending_tool_calls`` on stand-in state snapshots."""

    def test_single_pending_returns_one_pair(self):
        state = SimpleNamespace(interrupts=(_interrupt_with("tcid-only", 3),))

        assert collect_pending_tool_calls(state) == [("tcid-only", 3)]

    def test_multiple_pending_preserves_state_order(self):
        """Order must match what the SSE stream emitted (= state.interrupts order)."""
        state = SimpleNamespace(
            interrupts=(
                _interrupt_with("tcid-A", 2),
                _interrupt_with("tcid-B", 3),
            )
        )

        assert collect_pending_tool_calls(state) == [("tcid-A", 2), ("tcid-B", 3)]

    def test_empty_when_no_interrupts(self):
        state = SimpleNamespace(interrupts=())

        assert collect_pending_tool_calls(state) == []

    def test_skips_interrupts_without_tool_call_id(self):
        """Defensive: interrupts not produced by our propagation layer are ignored.

        ``stream_resume_chat`` only owns the ``task``-routing slice; non-task
        interrupts (e.g. parent-side HITL middleware on a different tool) are
        not the slicer's responsibility.
        """
        state = SimpleNamespace(
            interrupts=(
                _interrupt_with("tcid-A", 2),
                SimpleNamespace(id="i-foreign", value={"action_requests": [{}]}),
                _interrupt_with("tcid-B", 1),
            )
        )

        assert collect_pending_tool_calls(state) == [("tcid-A", 2), ("tcid-B", 1)]

    def test_handles_scalar_value_interrupt(self):
        """Subagents using ``interrupt("approve?")`` style propagate as ``{"value": ..., "tool_call_id": ...}``.

        These have no ``action_requests`` — count them as a single action so
        the frontend submits exactly one decision per such interrupt.
        """
        state = SimpleNamespace(
            interrupts=(
                SimpleNamespace(
                    id="i-A",
                    value={"value": "approve?", "tool_call_id": "tcid-A"},
                ),
            )
        )

        assert collect_pending_tool_calls(state) == [("tcid-A", 1)]

    def test_raises_when_interrupt_value_missing_action_count_keys(self):
        """An interrupt with ``tool_call_id`` but no usable count signals a contract bug."""
        state = SimpleNamespace(
            interrupts=(
                SimpleNamespace(
                    id="i-A",
                    value={"tool_call_id": "tcid-A", "weird_shape": True},
                ),
            )
        )

        with pytest.raises(ValueError, match="action_requests"):
            collect_pending_tool_calls(state)
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
"""Resume side-channel must be read exactly once per turn."""
|
"""Resume side-channel is keyed per ``tool_call_id`` so parallel siblings can resume independently."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
@ -10,33 +10,61 @@ from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_mid
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _runtime_with_config(
    config: dict, *, tool_call_id: str = "tcid-test"
) -> ToolRuntime:
    """Build a ``ToolRuntime`` carrying ``config`` and the given ``tool_call_id``.

    ``tool_call_id`` is keyword-only and defaults to ``"tcid-test"``, so
    callers that pass only ``config`` keep working. All other runtime fields
    are set to ``None``.
    """
    return ToolRuntime(
        state=None,
        context=None,
        config=config,
        stream_writer=None,
        tool_call_id=tool_call_id,
        store=None,
    )
|
||||||
|
|
||||||
|
|
||||||
class TestConsumeSurfsenseResume:
|
class TestConsumeSurfsenseResume:
|
||||||
def test_pops_value_on_first_call(self):
|
def test_pops_only_entry_matching_runtime_tool_call_id(self):
|
||||||
|
configurable = {
|
||||||
|
"surfsense_resume_value": {
|
||||||
|
"tcid-A": {"decisions": ["approve"]},
|
||||||
|
"tcid-B": {"decisions": ["reject"]},
|
||||||
|
}
|
||||||
|
}
|
||||||
runtime = _runtime_with_config(
|
runtime = _runtime_with_config(
|
||||||
{"configurable": {"surfsense_resume_value": {"decisions": ["approve"]}}}
|
{"configurable": configurable}, tool_call_id="tcid-A"
|
||||||
)
|
)
|
||||||
|
|
||||||
assert consume_surfsense_resume(runtime) == {"decisions": ["approve"]}
|
assert consume_surfsense_resume(runtime) == {"decisions": ["approve"]}
|
||||||
|
|
||||||
def test_second_call_returns_none(self):
|
def test_popping_one_entry_leaves_siblings_untouched(self):
|
||||||
configurable: dict = {"surfsense_resume_value": {"decisions": ["approve"]}}
|
configurable = {
|
||||||
runtime = _runtime_with_config({"configurable": configurable})
|
"surfsense_resume_value": {
|
||||||
|
"tcid-A": {"decisions": ["approve"]},
|
||||||
|
"tcid-B": {"decisions": ["reject"]},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
runtime_a = _runtime_with_config(
|
||||||
|
{"configurable": configurable}, tool_call_id="tcid-A"
|
||||||
|
)
|
||||||
|
|
||||||
consume_surfsense_resume(runtime)
|
consume_surfsense_resume(runtime_a)
|
||||||
|
|
||||||
|
assert configurable["surfsense_resume_value"] == {
|
||||||
|
"tcid-B": {"decisions": ["reject"]}
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_returns_none_when_no_entry_for_this_tool_call(self):
|
||||||
|
runtime = _runtime_with_config(
|
||||||
|
{
|
||||||
|
"configurable": {
|
||||||
|
"surfsense_resume_value": {"tcid-other": {"decisions": []}}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
tool_call_id="tcid-A",
|
||||||
|
)
|
||||||
|
|
||||||
assert consume_surfsense_resume(runtime) is None
|
assert consume_surfsense_resume(runtime) is None
|
||||||
assert "surfsense_resume_value" not in configurable
|
|
||||||
|
|
||||||
def test_returns_none_when_no_payload_queued(self):
|
def test_returns_none_when_no_payload_queued(self):
|
||||||
runtime = _runtime_with_config({"configurable": {}})
|
runtime = _runtime_with_config({"configurable": {}})
|
||||||
|
|
@ -48,22 +76,57 @@ class TestConsumeSurfsenseResume:
|
||||||
|
|
||||||
assert consume_surfsense_resume(runtime) is None
|
assert consume_surfsense_resume(runtime) is None
|
||||||
|
|
||||||
|
def test_drops_empty_dict_after_last_entry_consumed(self):
|
||||||
|
configurable = {
|
||||||
|
"surfsense_resume_value": {"tcid-A": {"decisions": ["approve"]}}
|
||||||
|
}
|
||||||
|
runtime = _runtime_with_config(
|
||||||
|
{"configurable": configurable}, tool_call_id="tcid-A"
|
||||||
|
)
|
||||||
|
|
||||||
|
consume_surfsense_resume(runtime)
|
||||||
|
|
||||||
|
assert "surfsense_resume_value" not in configurable
|
||||||
|
|
||||||
|
|
||||||
class TestHasSurfsenseResume:
|
class TestHasSurfsenseResume:
|
||||||
def test_true_when_payload_queued(self):
|
def test_true_when_entry_for_this_tool_call_present(self):
|
||||||
runtime = _runtime_with_config(
|
runtime = _runtime_with_config(
|
||||||
{"configurable": {"surfsense_resume_value": "approve"}}
|
{
|
||||||
|
"configurable": {
|
||||||
|
"surfsense_resume_value": {"tcid-A": {"decisions": ["approve"]}}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
tool_call_id="tcid-A",
|
||||||
)
|
)
|
||||||
|
|
||||||
assert has_surfsense_resume(runtime) is True
|
assert has_surfsense_resume(runtime) is True
|
||||||
|
|
||||||
|
def test_false_when_entry_for_other_tool_call_only(self):
|
||||||
|
runtime = _runtime_with_config(
|
||||||
|
{
|
||||||
|
"configurable": {
|
||||||
|
"surfsense_resume_value": {"tcid-other": {"decisions": []}}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
tool_call_id="tcid-A",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert has_surfsense_resume(runtime) is False
|
||||||
|
|
||||||
def test_does_not_consume_payload(self):
|
def test_does_not_consume_payload(self):
|
||||||
configurable = {"surfsense_resume_value": "approve"}
|
configurable = {
|
||||||
runtime = _runtime_with_config({"configurable": configurable})
|
"surfsense_resume_value": {"tcid-A": {"decisions": ["approve"]}}
|
||||||
|
}
|
||||||
|
runtime = _runtime_with_config(
|
||||||
|
{"configurable": configurable}, tool_call_id="tcid-A"
|
||||||
|
)
|
||||||
|
|
||||||
has_surfsense_resume(runtime)
|
has_surfsense_resume(runtime)
|
||||||
|
|
||||||
assert configurable == {"surfsense_resume_value": "approve"}
|
assert configurable["surfsense_resume_value"] == {
|
||||||
|
"tcid-A": {"decisions": ["approve"]}
|
||||||
|
}
|
||||||
|
|
||||||
def test_false_when_payload_absent(self):
|
def test_false_when_payload_absent(self):
|
||||||
runtime = _runtime_with_config({"configurable": {}})
|
runtime = _runtime_with_config({"configurable": {}})
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,94 @@
|
||||||
|
"""Per-call ``thread_id`` derivation for nested subagent invocations.
|
||||||
|
|
||||||
|
Parallel ``task`` (and ``ask_knowledge_base``) calls must land in disjoint
|
||||||
|
checkpoint slots so their nested pregel runs do not stomp on each other or on
|
||||||
|
the parent's checkpoint state. The slot key is derived from the runtime's
|
||||||
|
``tool_call_id`` so the same call across the resume cycle keeps reading from
|
||||||
|
the same snapshot.
|
||||||
|
|
||||||
|
Note: we namespace via ``thread_id`` rather than ``checkpoint_ns`` because
|
||||||
|
langgraph's ``aget_state`` interprets a non-empty ``checkpoint_ns`` as a
|
||||||
|
subgraph path and raises ``ValueError("Subgraph X not found")``. ``thread_id``
|
||||||
|
is the primary checkpoint key and is free-form, so it's the right primitive.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from langchain.tools import ToolRuntime
|
||||||
|
|
||||||
|
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.config import (
|
||||||
|
subagent_invoke_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _runtime(*, tool_call_id: str, config: dict | None = None) -> ToolRuntime:
    """Build a ``ToolRuntime`` for the given ``tool_call_id``.

    A missing (or empty) ``config`` is replaced by a fresh empty dict; all
    other runtime fields are set to ``None``.
    """
    return ToolRuntime(
        state=None,
        context=None,
        config=config or {},
        stream_writer=None,
        tool_call_id=tool_call_id,
        store=None,
    )
|
||||||
|
|
||||||
|
|
||||||
|
class TestSubagentInvokeThreadId:
    """``subagent_invoke_config`` derives a stable, per-call ``thread_id``."""

    def test_sets_per_call_thread_id_under_parent(self):
        runtime = _runtime(
            tool_call_id="tcid-A",
            config={"configurable": {"thread_id": "t1"}},
        )

        sub_config = subagent_invoke_config(runtime)

        assert sub_config["configurable"]["thread_id"] == "t1::task:tcid-A"

    def test_per_call_thread_id_nests_under_already_namespaced_parent(self):
        """A subagent that itself spawns a subagent must keep nesting cleanly."""
        runtime = _runtime(
            tool_call_id="tcid-inner",
            config={
                "configurable": {
                    "thread_id": "t1::task:tcid-outer",
                }
            },
        )

        sub_config = subagent_invoke_config(runtime)

        assert (
            sub_config["configurable"]["thread_id"]
            == "t1::task:tcid-outer::task:tcid-inner"
        )

    def test_different_tool_call_ids_produce_different_thread_ids(self):
        config = {"configurable": {"thread_id": "t1"}}
        rt_a = _runtime(tool_call_id="tcid-A", config=config)
        rt_b = _runtime(tool_call_id="tcid-B", config=config)

        tid_a = subagent_invoke_config(rt_a)["configurable"]["thread_id"]
        tid_b = subagent_invoke_config(rt_b)["configurable"]["thread_id"]

        assert tid_a != tid_b

    def test_same_tool_call_id_produces_same_thread_id_across_repeated_calls(self):
        """Resume bridge needs to find the snapshot it primed earlier."""
        config = {"configurable": {"thread_id": "t1"}}
        rt_first = _runtime(tool_call_id="tcid-A", config=config)
        rt_second = _runtime(tool_call_id="tcid-A", config=config)

        tid_first = subagent_invoke_config(rt_first)["configurable"]["thread_id"]
        tid_second = subagent_invoke_config(rt_second)["configurable"]["thread_id"]

        assert tid_first == tid_second

    def test_does_not_mutate_caller_config(self):
        """Repeated calls must not accumulate suffixes onto the parent's config."""
        original_thread_id = "t1"
        config = {"configurable": {"thread_id": original_thread_id}}
        runtime = _runtime(tool_call_id="tcid-A", config=config)

        subagent_invoke_config(runtime)
        subagent_invoke_config(runtime)

        assert config["configurable"]["thread_id"] == original_thread_id
|
||||||
Loading…
Add table
Add a link
Reference in a new issue