mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-17 18:35:19 +02:00
multi_agent_chat/middleware: per-call thread_id, tcid-keyed resume, decisions slicer
This commit is contained in:
parent
246dae40a8
commit
fc2c5b6445
5 changed files with 513 additions and 30 deletions
|
|
@ -21,7 +21,17 @@ _LANGGRAPH_SCRATCHPAD_KEY = "__pregel_scratchpad"
|
|||
|
||||
|
||||
def subagent_invoke_config(runtime: ToolRuntime) -> dict[str, Any]:
|
||||
"""RunnableConfig for the nested invoke; raises ``recursion_limit`` to the parent's budget."""
|
||||
"""RunnableConfig for the nested invoke; raises ``recursion_limit`` and isolates ``thread_id``.
|
||||
|
||||
Each parallel subagent invocation lands in its own checkpoint slot keyed
|
||||
by an extended ``thread_id`` of the form ``{parent_thread}::task:{tool_call_id}``.
|
||||
The same call across the resume cycle keeps reading from the same snapshot
|
||||
(``tool_call_id`` is stable per LLM-emitted call).
|
||||
|
||||
We namespace via ``thread_id`` rather than ``checkpoint_ns`` because
|
||||
langgraph's ``aget_state`` interprets a non-empty ``checkpoint_ns`` as a
|
||||
subgraph path and raises ``ValueError("Subgraph X not found")``.
|
||||
"""
|
||||
merged: dict[str, Any] = dict(runtime.config) if runtime.config else {}
|
||||
current_limit = merged.get("recursion_limit")
|
||||
try:
|
||||
|
|
@ -30,43 +40,68 @@ def subagent_invoke_config(runtime: ToolRuntime) -> dict[str, Any]:
|
|||
current_int = 0
|
||||
if current_int < DEFAULT_SUBAGENT_RECURSION_LIMIT:
|
||||
merged["recursion_limit"] = DEFAULT_SUBAGENT_RECURSION_LIMIT
|
||||
|
||||
configurable: dict[str, Any] = dict(merged.get("configurable") or {})
|
||||
parent_thread_id = configurable.get("thread_id")
|
||||
per_call_suffix = f"task:{runtime.tool_call_id}"
|
||||
configurable["thread_id"] = (
|
||||
f"{parent_thread_id}::{per_call_suffix}"
|
||||
if parent_thread_id
|
||||
else per_call_suffix
|
||||
)
|
||||
merged["configurable"] = configurable
|
||||
return merged
|
||||
|
||||
|
||||
def consume_surfsense_resume(runtime: ToolRuntime) -> Any:
|
||||
"""Pop the resume payload; siblings share ``configurable`` by reference."""
|
||||
"""Pop the resume payload for *this* call's ``tool_call_id``.
|
||||
|
||||
The configurable holds ``surfsense_resume_value: dict[tool_call_id, payload]``
|
||||
so parallel sibling subagents (each with their own ``tool_call_id``) read
|
||||
only their own decision and never race on a shared scalar.
|
||||
"""
|
||||
cfg = runtime.config or {}
|
||||
configurable = cfg.get("configurable") if isinstance(cfg, dict) else None
|
||||
if not isinstance(configurable, dict):
|
||||
return None
|
||||
return configurable.pop("surfsense_resume_value", None)
|
||||
by_tcid = configurable.get("surfsense_resume_value")
|
||||
if not isinstance(by_tcid, dict):
|
||||
return None
|
||||
payload = by_tcid.pop(runtime.tool_call_id, None)
|
||||
if not by_tcid:
|
||||
configurable.pop("surfsense_resume_value", None)
|
||||
return payload
|
||||
|
||||
|
||||
def has_surfsense_resume(runtime: ToolRuntime) -> bool:
|
||||
"""True iff a resume payload is queued on this runtime (non-destructive)."""
|
||||
"""True iff a resume payload for this call's ``tool_call_id`` is queued (non-destructive)."""
|
||||
cfg = runtime.config or {}
|
||||
configurable = cfg.get("configurable") if isinstance(cfg, dict) else None
|
||||
if not isinstance(configurable, dict):
|
||||
return False
|
||||
return "surfsense_resume_value" in configurable
|
||||
by_tcid = configurable.get("surfsense_resume_value")
|
||||
if not isinstance(by_tcid, dict):
|
||||
return False
|
||||
return runtime.tool_call_id in by_tcid
|
||||
|
||||
|
||||
def drain_parent_null_resume(runtime: ToolRuntime) -> None:
|
||||
"""Consume the parent's lingering ``NULL_TASK_ID/RESUME`` write before delegating.
|
||||
|
||||
``stream_resume_chat`` wakes the main agent with
|
||||
``Command(resume={"decisions": [...]})`` so the propagated
|
||||
``_lg_interrupt(...)`` can return. langgraph stores that payload as the
|
||||
parent task's ``null_resume`` pending write, which only gets consumed
|
||||
*after* ``subagent.[a]invoke`` returns (when the post-call propagation
|
||||
re-fires). While the subagent is mid-execution, any *new* ``interrupt()``
|
||||
inside it (e.g. a follow-up tool call after a mixed approve/reject) walks
|
||||
``subagent_scratchpad → parent_scratchpad.get_null_resume`` and picks up
|
||||
the parent's still-live decisions — mismatching against a different number
|
||||
of hanging tool calls and crashing ``HumanInTheLoopMiddleware``.
|
||||
``Command(resume={tool_call_id: {"decisions": [...]}})`` so the previously
|
||||
propagated parent-level interrupt can return. langgraph stores that
|
||||
payload as the parent task's ``null_resume`` pending write. The ``task``
|
||||
tool then forwards this turn's slice into the subagent via its own
|
||||
``Command(resume=...)``. While the subagent is mid-execution, any *new*
|
||||
``interrupt()`` inside it (e.g. a follow-up tool call after a mixed
|
||||
approve/reject) walks ``subagent_scratchpad → parent_scratchpad.get_null_resume``
|
||||
and picks up the parent's still-live decisions — mismatching against a
|
||||
different number of hanging tool calls and crashing
|
||||
``HumanInTheLoopMiddleware``.
|
||||
|
||||
Draining the write here closes that cross-graph leak so subagent
|
||||
interrupts pause cleanly and re-propagate as a fresh approval card.
|
||||
interrupts pause cleanly and bubble back up as a fresh approval card.
|
||||
"""
|
||||
cfg = runtime.config or {}
|
||||
configurable = cfg.get("configurable") if isinstance(cfg, dict) else None
|
||||
|
|
|
|||
|
|
@ -0,0 +1,137 @@
|
|||
"""Route a flat ``decisions`` list to per-``tool_call_id`` resume payloads.
|
||||
|
||||
The frontend submits decisions in the same order the SSE stream emitted
|
||||
approval cards. When multiple parallel subagents are paused, the backend uses
|
||||
this module to:
|
||||
|
||||
1. Read ``state.interrupts`` from the parent's paused snapshot, extracting
|
||||
``[(tool_call_id, action_count), ...]`` from each interrupt's value.
|
||||
The ``tool_call_id`` is stamped on by ``propagation.wrap_with_tool_call_id``
|
||||
inside ``task_tool``'s catch-and-stamp block when a subagent's
|
||||
``GraphInterrupt`` bubbles up through ``[a]task``.
|
||||
2. Slice the flat ``decisions`` list against that ordered pending list to
|
||||
produce the dict shape expected by ``consume_surfsense_resume``.
|
||||
|
||||
Both helpers are pure: callers own the state and the input decisions; we
|
||||
return new structures and never mutate.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def slice_decisions_by_tool_call(
|
||||
decisions: list[dict[str, Any]],
|
||||
pending: Iterable[tuple[str, int]],
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""Slice ``decisions`` into ``{tool_call_id: {"decisions": <slice>}}``.
|
||||
|
||||
Args:
|
||||
decisions: Flat list of decisions in the order the SSE stream rendered
|
||||
them.
|
||||
pending: Ordered ``(tool_call_id, action_count)`` pairs in the same
|
||||
order. The slicer consumes ``decisions`` left-to-right.
|
||||
|
||||
Returns:
|
||||
Per-``tool_call_id`` payload dict ready to be written to
|
||||
``configurable["surfsense_resume_value"]``.
|
||||
|
||||
Raises:
|
||||
ValueError: When the total expected action count differs from the
|
||||
number of decisions provided. We fail loud rather than silently
|
||||
dropping or padding so a frontend/backend contract drift surfaces
|
||||
immediately.
|
||||
"""
|
||||
pending_list = list(pending)
|
||||
expected = sum(count for _, count in pending_list)
|
||||
if expected != len(decisions):
|
||||
raise ValueError(
|
||||
f"Decision count mismatch: pending tool calls expect "
|
||||
f"{expected} actions but received {len(decisions)} decisions."
|
||||
)
|
||||
|
||||
routed: dict[str, dict[str, Any]] = {}
|
||||
cursor = 0
|
||||
for tool_call_id, action_count in pending_list:
|
||||
routed[tool_call_id] = {
|
||||
"decisions": decisions[cursor : cursor + action_count]
|
||||
}
|
||||
cursor += action_count
|
||||
return routed
|
||||
|
||||
|
||||
def collect_pending_tool_calls(state: Any) -> list[tuple[str, int]]:
|
||||
"""Extract ``[(tool_call_id, action_count), ...]`` from a paused parent state.
|
||||
|
||||
Reads ``state.interrupts`` (the bundle langgraph aggregated from each
|
||||
paused subagent's propagated interrupt). Each interrupt value carries the
|
||||
``tool_call_id`` that the parent's ``task`` tool was processing — see
|
||||
``propagation.wrap_with_tool_call_id`` and ``task_tool``'s
|
||||
``except GraphInterrupt`` chokepoint.
|
||||
|
||||
Order is preserved from ``state.interrupts``, which is the order the SSE
|
||||
stream emitted approval cards. The frontend submits decisions in that
|
||||
same order, so the slicer can consume them left-to-right.
|
||||
|
||||
Interrupts without a ``tool_call_id`` are skipped — they were not
|
||||
produced by our task-routing layer (e.g. parent-side HITL middleware on
|
||||
a different tool); ``stream_resume_chat`` is not responsible for routing
|
||||
those.
|
||||
|
||||
Args:
|
||||
state: A langgraph ``StateSnapshot`` (or any object with an
|
||||
``interrupts`` attribute).
|
||||
|
||||
Returns:
|
||||
Ordered list of ``(tool_call_id, action_count)``. ``action_count`` is
|
||||
``len(value["action_requests"])`` for HITL-bundle values, or ``1`` for
|
||||
scalar-style ``interrupt("...")`` values that were wrapped as
|
||||
``{"value": ..., "tool_call_id": ...}``.
|
||||
|
||||
Raises:
|
||||
ValueError: When an interrupt value carries a ``tool_call_id`` but
|
||||
the action count cannot be determined (contract bug — every
|
||||
propagated value should be either a HITL bundle or a wrapped
|
||||
scalar).
|
||||
"""
|
||||
pending: list[tuple[str, int]] = []
|
||||
for idx, interrupt_obj in enumerate(getattr(state, "interrupts", ()) or ()):
|
||||
value = getattr(interrupt_obj, "value", None)
|
||||
if not isinstance(value, dict):
|
||||
logger.warning(
|
||||
"[hitl_route] interrupt[%d] skipped: value not a dict (type=%s)",
|
||||
idx,
|
||||
type(value).__name__,
|
||||
)
|
||||
continue
|
||||
tool_call_id = value.get("tool_call_id")
|
||||
if not isinstance(tool_call_id, str):
|
||||
# Should not happen post-stamping; flag loudly if a regression
|
||||
# ever lets an unstamped value reach the parent state.
|
||||
logger.warning(
|
||||
"[hitl_route] interrupt[%d] skipped: no tool_call_id stamp (keys=%s)",
|
||||
idx,
|
||||
sorted(value.keys()),
|
||||
)
|
||||
continue
|
||||
|
||||
action_requests = value.get("action_requests")
|
||||
if isinstance(action_requests, list):
|
||||
pending.append((tool_call_id, len(action_requests)))
|
||||
continue
|
||||
if "value" in value:
|
||||
pending.append((tool_call_id, 1))
|
||||
continue
|
||||
|
||||
raise ValueError(
|
||||
f"Interrupt for tool_call_id={tool_call_id!r} has no "
|
||||
"``action_requests`` list and is not a wrapped scalar value; "
|
||||
"cannot determine action count for resume routing."
|
||||
)
|
||||
|
||||
return pending
|
||||
Loading…
Add table
Add a link
Reference in a new issue