mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-06 14:22:47 +02:00
Add SurfSenseCheckpointedSubAgentMiddleware to bridge HITL into deepagents subagents.
This commit is contained in:
parent
4fd3c4fb27
commit
acd2fdda8a
7 changed files with 551 additions and 0 deletions
|
|
@ -0,0 +1,26 @@
|
||||||
|
"""SubAgent ``task`` tool wiring required for HITL inside subagents.
|
||||||
|
|
||||||
|
Replaces upstream ``SubAgentMiddleware`` to:
|
||||||
|
|
||||||
|
- share the parent's checkpointer with each subagent,
|
||||||
|
- forward ``runtime.config`` (thread_id, recursion_limit, …) into nested invokes,
|
||||||
|
- bridge ``Command(resume=...)`` from the parent into the subagent via the
|
||||||
|
``config["configurable"]["surfsense_resume_value"]`` side-channel,
|
||||||
|
- target the resume at the captured interrupt id so a follow-up
|
||||||
|
``HumanInTheLoopMiddleware.after_model`` does not consume the same payload,
|
||||||
|
- re-raise any new subagent interrupt at the parent so the SSE stream surfaces it.
|
||||||
|
|
||||||
|
Module layout
|
||||||
|
-------------
|
||||||
|
|
||||||
|
- ``constants`` — shared keys / limits.
|
||||||
|
- ``config`` — RunnableConfig + side-channel resume read.
|
||||||
|
- ``resume`` — pending-interrupt detection, fan-out, ``Command(resume=...)`` builder.
|
||||||
|
- ``propagation`` — re-raise pending subagent interrupts at the parent.
|
||||||
|
- ``task_tool`` — the ``task`` tool factory (sync + async).
|
||||||
|
- ``middleware`` — :class:`SurfSenseCheckpointedSubAgentMiddleware` itself.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .middleware import SurfSenseCheckpointedSubAgentMiddleware
|
||||||
|
|
||||||
|
__all__ = ["SurfSenseCheckpointedSubAgentMiddleware"]
|
||||||
|
|
@ -0,0 +1,35 @@
|
||||||
|
"""RunnableConfig wiring for nested subagent invocations.
|
||||||
|
|
||||||
|
Forwards the parent's ``runtime.config`` (thread_id, …) into the subagent and
|
||||||
|
exposes the side-channel ``stream_resume_chat`` uses to ferry resume payloads.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain.tools import ToolRuntime
|
||||||
|
|
||||||
|
from .constants import DEFAULT_SUBAGENT_RECURSION_LIMIT
|
||||||
|
|
||||||
|
|
||||||
|
def subagent_invoke_config(runtime: ToolRuntime) -> dict[str, Any]:
|
||||||
|
"""RunnableConfig for the nested invoke; raises ``recursion_limit`` to the parent's budget."""
|
||||||
|
merged: dict[str, Any] = dict(runtime.config) if runtime.config else {}
|
||||||
|
current_limit = merged.get("recursion_limit")
|
||||||
|
try:
|
||||||
|
current_int = int(current_limit) if current_limit is not None else 0
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
current_int = 0
|
||||||
|
if current_int < DEFAULT_SUBAGENT_RECURSION_LIMIT:
|
||||||
|
merged["recursion_limit"] = DEFAULT_SUBAGENT_RECURSION_LIMIT
|
||||||
|
return merged
|
||||||
|
|
||||||
|
|
||||||
|
def extract_surfsense_resume(runtime: ToolRuntime) -> Any:
|
||||||
|
"""Resume payload stashed by ``stream_resume_chat``; ``None`` on a first-time call."""
|
||||||
|
cfg = runtime.config or {}
|
||||||
|
configurable = cfg.get("configurable") if isinstance(cfg, dict) else None
|
||||||
|
if not isinstance(configurable, dict):
|
||||||
|
return None
|
||||||
|
return configurable.get("surfsense_resume_value")
|
||||||
|
|
@ -0,0 +1,18 @@
|
||||||
|
"""Constants shared by the checkpointed subagent middleware."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
# Mirror of deepagents.middleware.subagents._EXCLUDED_STATE_KEYS.
|
||||||
|
EXCLUDED_STATE_KEYS = frozenset(
|
||||||
|
{
|
||||||
|
"messages",
|
||||||
|
"todos",
|
||||||
|
"structured_response",
|
||||||
|
"skills_metadata",
|
||||||
|
"memory_contents",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Match the parent graph's budget; the LangGraph default of 25 trips on
|
||||||
|
# multi-step subagent runs.
|
||||||
|
DEFAULT_SUBAGENT_RECURSION_LIMIT = 10_000
|
||||||
|
|
@ -0,0 +1,103 @@
|
||||||
|
"""SubAgent middleware that compiles each subagent against the parent checkpointer."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
|
from deepagents.backends.protocol import BackendFactory, BackendProtocol
|
||||||
|
from deepagents.middleware.subagents import (
|
||||||
|
TASK_SYSTEM_PROMPT,
|
||||||
|
CompiledSubAgent,
|
||||||
|
SubAgent,
|
||||||
|
SubAgentMiddleware,
|
||||||
|
)
|
||||||
|
from langchain.agents import create_agent
|
||||||
|
from langchain.agents.middleware import HumanInTheLoopMiddleware
|
||||||
|
from langchain.chat_models import init_chat_model
|
||||||
|
from langgraph.types import Checkpointer
|
||||||
|
|
||||||
|
from .task_tool import build_task_tool_with_parent_config
|
||||||
|
|
||||||
|
|
||||||
|
class SurfSenseCheckpointedSubAgentMiddleware(SubAgentMiddleware):
|
||||||
|
"""``SubAgentMiddleware`` variant that compiles each subagent against the parent checkpointer."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
checkpointer: Checkpointer,
|
||||||
|
backend: BackendProtocol | BackendFactory,
|
||||||
|
subagents: list[SubAgent | CompiledSubAgent],
|
||||||
|
system_prompt: str | None = TASK_SYSTEM_PROMPT,
|
||||||
|
task_description: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
self._surf_checkpointer = checkpointer
|
||||||
|
super(SubAgentMiddleware, self).__init__()
|
||||||
|
if not subagents:
|
||||||
|
raise ValueError(
|
||||||
|
"At least one subagent must be specified when using the new API"
|
||||||
|
)
|
||||||
|
self._backend = backend
|
||||||
|
self._subagents = subagents
|
||||||
|
subagent_specs = self._surf_compile_subagent_graphs()
|
||||||
|
task_tool = build_task_tool_with_parent_config(subagent_specs, task_description)
|
||||||
|
if system_prompt and subagent_specs:
|
||||||
|
agents_desc = "\n".join(
|
||||||
|
f"- {s['name']}: {s['description']}" for s in subagent_specs
|
||||||
|
)
|
||||||
|
self.system_prompt = (
|
||||||
|
system_prompt + "\n\nAvailable subagent types:\n" + agents_desc
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.system_prompt = system_prompt
|
||||||
|
self.tools = [task_tool]
|
||||||
|
|
||||||
|
def _surf_compile_subagent_graphs(self) -> list[dict[str, Any]]:
|
||||||
|
"""Mirror of ``SubAgentMiddleware._get_subagents`` that threads the parent checkpointer."""
|
||||||
|
specs: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
for spec in self._subagents:
|
||||||
|
if "runnable" in spec:
|
||||||
|
compiled = cast(CompiledSubAgent, spec)
|
||||||
|
specs.append(
|
||||||
|
{
|
||||||
|
"name": compiled["name"],
|
||||||
|
"description": compiled["description"],
|
||||||
|
"runnable": compiled["runnable"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if "model" not in spec:
|
||||||
|
msg = f"SubAgent '{spec['name']}' must specify 'model'"
|
||||||
|
raise ValueError(msg)
|
||||||
|
if "tools" not in spec:
|
||||||
|
msg = f"SubAgent '{spec['name']}' must specify 'tools'"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
model = spec["model"]
|
||||||
|
if isinstance(model, str):
|
||||||
|
model = init_chat_model(model)
|
||||||
|
|
||||||
|
middleware: list[Any] = list(spec.get("middleware", []))
|
||||||
|
|
||||||
|
interrupt_on = spec.get("interrupt_on")
|
||||||
|
if interrupt_on:
|
||||||
|
middleware.append(HumanInTheLoopMiddleware(interrupt_on=interrupt_on))
|
||||||
|
|
||||||
|
specs.append(
|
||||||
|
{
|
||||||
|
"name": spec["name"],
|
||||||
|
"description": spec["description"],
|
||||||
|
"runnable": create_agent(
|
||||||
|
model,
|
||||||
|
system_prompt=spec["system_prompt"],
|
||||||
|
tools=spec["tools"],
|
||||||
|
middleware=middleware,
|
||||||
|
name=spec["name"],
|
||||||
|
checkpointer=self._surf_checkpointer,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return specs
|
||||||
|
|
@ -0,0 +1,74 @@
|
||||||
|
"""Re-raise still-pending subagent interrupts at the parent graph level.
|
||||||
|
|
||||||
|
After ``subagent.[a]invoke(Command(resume=...))`` returns, the subagent may
|
||||||
|
still hold a pending interrupt (e.g. the LLM produced a follow-up tool call
|
||||||
|
that fired a fresh ``interrupt()``). The parent's pregel cannot see that
|
||||||
|
interrupt because it lives in a separate compiled graph; we re-raise it here
|
||||||
|
so the parent's SSE stream surfaces it as the next approval card.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.runnables import Runnable
|
||||||
|
from langgraph.types import interrupt as _lg_interrupt
|
||||||
|
|
||||||
|
from .resume import get_first_pending_subagent_interrupt
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_propagate_subagent_interrupt(
|
||||||
|
subagent: Runnable,
|
||||||
|
sub_config: dict[str, Any],
|
||||||
|
subagent_type: str,
|
||||||
|
) -> None:
|
||||||
|
"""Re-raise a still-pending subagent interrupt at the parent so the SSE stream surfaces it."""
|
||||||
|
get_state_sync = getattr(subagent, "get_state", None)
|
||||||
|
if not callable(get_state_sync):
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
snapshot = get_state_sync(sub_config)
|
||||||
|
except Exception: # pragma: no cover - defensive
|
||||||
|
logger.debug(
|
||||||
|
"Subagent get_state failed during re-interrupt check",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
_pending_id, pending_value = get_first_pending_subagent_interrupt(snapshot)
|
||||||
|
if pending_value is None:
|
||||||
|
return
|
||||||
|
logger.info(
|
||||||
|
"Re-raising subagent %r interrupt to parent (multi-step HITL)",
|
||||||
|
subagent_type,
|
||||||
|
)
|
||||||
|
_lg_interrupt(pending_value)
|
||||||
|
|
||||||
|
|
||||||
|
async def amaybe_propagate_subagent_interrupt(
|
||||||
|
subagent: Runnable,
|
||||||
|
sub_config: dict[str, Any],
|
||||||
|
subagent_type: str,
|
||||||
|
) -> None:
|
||||||
|
"""Async counterpart of :func:`maybe_propagate_subagent_interrupt`."""
|
||||||
|
aget_state = getattr(subagent, "aget_state", None)
|
||||||
|
if not callable(aget_state):
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
snapshot = await aget_state(sub_config)
|
||||||
|
except Exception: # pragma: no cover - defensive
|
||||||
|
logger.debug(
|
||||||
|
"Subagent aget_state failed during re-interrupt check",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
_pending_id, pending_value = get_first_pending_subagent_interrupt(snapshot)
|
||||||
|
if pending_value is None:
|
||||||
|
return
|
||||||
|
logger.info(
|
||||||
|
"Re-raising subagent %r interrupt to parent (multi-step HITL)",
|
||||||
|
subagent_type,
|
||||||
|
)
|
||||||
|
_lg_interrupt(pending_value)
|
||||||
|
|
@ -0,0 +1,71 @@
|
||||||
|
"""Resume-payload shaping and pending-interrupt detection for subagents.
|
||||||
|
|
||||||
|
Splits the work of "given a state snapshot and a parent-stashed resume value,
|
||||||
|
produce the right ``Command(resume=...)`` for the subagent" into pure helpers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langgraph.types import Command
|
||||||
|
|
||||||
|
|
||||||
|
def hitlrequest_action_count(pending_value: Any) -> int:
|
||||||
|
"""Bundle size for a LangChain ``HITLRequest`` payload; ``0`` for non-bundle interrupts."""
|
||||||
|
if not isinstance(pending_value, dict):
|
||||||
|
return 0
|
||||||
|
actions = pending_value.get("action_requests")
|
||||||
|
if isinstance(actions, list):
|
||||||
|
return len(actions)
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def fan_out_decisions_to_match(resume_value: Any, expected_count: int) -> Any:
|
||||||
|
"""Pad a single-decision resume to N entries so an ``action_requests=N`` bundle accepts it."""
|
||||||
|
if expected_count <= 1:
|
||||||
|
return resume_value
|
||||||
|
if not isinstance(resume_value, dict):
|
||||||
|
return resume_value
|
||||||
|
decisions = resume_value.get("decisions")
|
||||||
|
if not isinstance(decisions, list) or len(decisions) >= expected_count:
|
||||||
|
return resume_value
|
||||||
|
if not decisions:
|
||||||
|
return resume_value
|
||||||
|
padded = list(decisions) + [decisions[-1]] * (expected_count - len(decisions))
|
||||||
|
return {**resume_value, "decisions": padded}
|
||||||
|
|
||||||
|
|
||||||
|
def get_first_pending_subagent_interrupt(state: Any) -> tuple[str | None, Any]:
|
||||||
|
"""First pending ``(interrupt_id, value)`` in the snapshot, else ``(None, None)``.
|
||||||
|
|
||||||
|
The ``id`` lets the caller target ``Command(resume={id: value})`` so the
|
||||||
|
payload is not broadcast to a later fresh interrupt in the same run.
|
||||||
|
"""
|
||||||
|
if state is None:
|
||||||
|
return None, None
|
||||||
|
for it in getattr(state, "interrupts", None) or ():
|
||||||
|
value = getattr(it, "value", None)
|
||||||
|
interrupt_id = getattr(it, "id", None)
|
||||||
|
if value is not None:
|
||||||
|
return (
|
||||||
|
interrupt_id if isinstance(interrupt_id, str) else None,
|
||||||
|
value,
|
||||||
|
)
|
||||||
|
for sub_task in getattr(state, "tasks", None) or ():
|
||||||
|
for it in getattr(sub_task, "interrupts", None) or ():
|
||||||
|
value = getattr(it, "value", None)
|
||||||
|
interrupt_id = getattr(it, "id", None)
|
||||||
|
if value is not None:
|
||||||
|
return (
|
||||||
|
interrupt_id if isinstance(interrupt_id, str) else None,
|
||||||
|
value,
|
||||||
|
)
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
|
||||||
|
def build_resume_command(resume_value: Any, pending_id: str | None) -> Command:
|
||||||
|
"""``Command(resume={id: value})`` when ``id`` is known, else fall back to scalar."""
|
||||||
|
if pending_id is None:
|
||||||
|
return Command(resume=resume_value)
|
||||||
|
return Command(resume={pending_id: resume_value})
|
||||||
|
|
@ -0,0 +1,224 @@
|
||||||
|
"""Build the ``task`` tool that invokes subagents with HITL bridging.
|
||||||
|
|
||||||
|
The tool's body is the only place where the parent and the subagent meet at
|
||||||
|
runtime: it reads the parent's stashed resume value, decides whether to send
|
||||||
|
fresh state or a targeted ``Command(resume=...)`` to the subagent, then
|
||||||
|
re-raises any new pending interrupt back to the parent.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Annotated, Any
|
||||||
|
|
||||||
|
from deepagents.middleware.subagents import TASK_TOOL_DESCRIPTION
|
||||||
|
from langchain.tools import BaseTool, ToolRuntime
|
||||||
|
from langchain_core.messages import HumanMessage, ToolMessage
|
||||||
|
from langchain_core.runnables import Runnable
|
||||||
|
from langchain_core.tools import StructuredTool
|
||||||
|
from langgraph.types import Command
|
||||||
|
|
||||||
|
from .config import extract_surfsense_resume, subagent_invoke_config
|
||||||
|
from .constants import EXCLUDED_STATE_KEYS
|
||||||
|
from .propagation import (
|
||||||
|
amaybe_propagate_subagent_interrupt,
|
||||||
|
maybe_propagate_subagent_interrupt,
|
||||||
|
)
|
||||||
|
from .resume import (
|
||||||
|
build_resume_command,
|
||||||
|
fan_out_decisions_to_match,
|
||||||
|
hitlrequest_action_count,
|
||||||
|
get_first_pending_subagent_interrupt,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def build_task_tool_with_parent_config(
|
||||||
|
subagents: list[dict[str, Any]],
|
||||||
|
task_description: str | None = None,
|
||||||
|
) -> BaseTool:
|
||||||
|
"""Upstream ``_build_task_tool`` + parent ``runtime.config`` propagation + resume bridging."""
|
||||||
|
subagent_graphs: dict[str, Runnable] = {
|
||||||
|
spec["name"]: spec["runnable"] for spec in subagents
|
||||||
|
}
|
||||||
|
subagent_description_str = "\n".join(
|
||||||
|
f"- {s['name']}: {s['description']}" for s in subagents
|
||||||
|
)
|
||||||
|
|
||||||
|
if task_description is None:
|
||||||
|
description = TASK_TOOL_DESCRIPTION.format(available_agents=subagent_description_str)
|
||||||
|
elif "{available_agents}" in task_description:
|
||||||
|
description = task_description.format(available_agents=subagent_description_str)
|
||||||
|
else:
|
||||||
|
description = task_description
|
||||||
|
|
||||||
|
def _return_command_with_state_update(result: dict, tool_call_id: str) -> Command:
|
||||||
|
if "messages" not in result:
|
||||||
|
msg = (
|
||||||
|
"CompiledSubAgent must return a state containing a 'messages' key. "
|
||||||
|
"Custom StateGraphs used with CompiledSubAgent should include 'messages' "
|
||||||
|
"in their state schema to communicate results back to the main agent."
|
||||||
|
)
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
state_update = {k: v for k, v in result.items() if k not in EXCLUDED_STATE_KEYS}
|
||||||
|
message_text = (
|
||||||
|
result["messages"][-1].text.rstrip() if result["messages"][-1].text else ""
|
||||||
|
)
|
||||||
|
return Command(
|
||||||
|
update={
|
||||||
|
**state_update,
|
||||||
|
"messages": [ToolMessage(message_text, tool_call_id=tool_call_id)],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def _validate_and_prepare_state(
|
||||||
|
subagent_type: str, description: str, runtime: ToolRuntime
|
||||||
|
) -> tuple[Runnable, dict]:
|
||||||
|
subagent = subagent_graphs[subagent_type]
|
||||||
|
subagent_state = {
|
||||||
|
k: v for k, v in runtime.state.items() if k not in EXCLUDED_STATE_KEYS
|
||||||
|
}
|
||||||
|
subagent_state["messages"] = [HumanMessage(content=description)]
|
||||||
|
return subagent, subagent_state
|
||||||
|
|
||||||
|
def task(
|
||||||
|
description: Annotated[
|
||||||
|
str,
|
||||||
|
"A detailed description of the task for the subagent to perform autonomously. Include all necessary context and specify the expected output format.", # noqa: E501
|
||||||
|
],
|
||||||
|
subagent_type: Annotated[
|
||||||
|
str,
|
||||||
|
"The type of subagent to use. Must be one of the available agent types listed in the tool description.", # noqa: E501
|
||||||
|
],
|
||||||
|
runtime: ToolRuntime,
|
||||||
|
) -> str | Command:
|
||||||
|
if subagent_type not in subagent_graphs:
|
||||||
|
allowed_types = ", ".join([f"`{k}`" for k in subagent_graphs])
|
||||||
|
return (
|
||||||
|
f"We cannot invoke subagent {subagent_type} because it does not exist, "
|
||||||
|
f"the only allowed types are {allowed_types}"
|
||||||
|
)
|
||||||
|
if not runtime.tool_call_id:
|
||||||
|
raise ValueError("Tool call ID is required for subagent invocation")
|
||||||
|
subagent, subagent_state = _validate_and_prepare_state(
|
||||||
|
subagent_type, description, runtime
|
||||||
|
)
|
||||||
|
sub_config = subagent_invoke_config(runtime)
|
||||||
|
|
||||||
|
# Resume bridge: forward the parent's stashed decision into the
|
||||||
|
# subagent's pending ``interrupt()``, targeted by id.
|
||||||
|
pending_id: str | None = None
|
||||||
|
pending_value: Any = None
|
||||||
|
get_state = getattr(subagent, "get_state", None)
|
||||||
|
if callable(get_state):
|
||||||
|
try:
|
||||||
|
snapshot = get_state(sub_config)
|
||||||
|
pending_id, pending_value = get_first_pending_subagent_interrupt(snapshot)
|
||||||
|
except Exception: # pragma: no cover - defensive
|
||||||
|
logger.debug(
|
||||||
|
"Subagent get_state failed; falling back to fresh invoke",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if pending_value is not None:
|
||||||
|
resume_value = extract_surfsense_resume(runtime)
|
||||||
|
if resume_value is not None:
|
||||||
|
expected = hitlrequest_action_count(pending_value)
|
||||||
|
resume_value = fan_out_decisions_to_match(resume_value, expected)
|
||||||
|
logger.info(
|
||||||
|
"Forwarding surfsense_resume_value into subagent %r "
|
||||||
|
"(action_requests=%d, targeted_id=%s)",
|
||||||
|
subagent_type,
|
||||||
|
expected,
|
||||||
|
pending_id is not None,
|
||||||
|
)
|
||||||
|
result = subagent.invoke(
|
||||||
|
build_resume_command(resume_value, pending_id),
|
||||||
|
config=sub_config,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"Subagent %r has pending interrupt but no surfsense_resume_value "
|
||||||
|
"on config — replaying with fresh state (interrupt will re-fire).",
|
||||||
|
subagent_type,
|
||||||
|
)
|
||||||
|
result = subagent.invoke(subagent_state, config=sub_config)
|
||||||
|
else:
|
||||||
|
result = subagent.invoke(subagent_state, config=sub_config)
|
||||||
|
maybe_propagate_subagent_interrupt(subagent, sub_config, subagent_type)
|
||||||
|
return _return_command_with_state_update(result, runtime.tool_call_id)
|
||||||
|
|
||||||
|
async def atask(
|
||||||
|
description: Annotated[
|
||||||
|
str,
|
||||||
|
"A detailed description of the task for the subagent to perform autonomously. Include all necessary context and specify the expected output format.", # noqa: E501
|
||||||
|
],
|
||||||
|
subagent_type: Annotated[
|
||||||
|
str,
|
||||||
|
"The type of subagent to use. Must be one of the available agent types listed in the tool description.", # noqa: E501
|
||||||
|
],
|
||||||
|
runtime: ToolRuntime,
|
||||||
|
) -> str | Command:
|
||||||
|
if subagent_type not in subagent_graphs:
|
||||||
|
allowed_types = ", ".join([f"`{k}`" for k in subagent_graphs])
|
||||||
|
return (
|
||||||
|
f"We cannot invoke subagent {subagent_type} because it does not exist, "
|
||||||
|
f"the only allowed types are {allowed_types}"
|
||||||
|
)
|
||||||
|
if not runtime.tool_call_id:
|
||||||
|
raise ValueError("Tool call ID is required for subagent invocation")
|
||||||
|
subagent, subagent_state = _validate_and_prepare_state(
|
||||||
|
subagent_type, description, runtime
|
||||||
|
)
|
||||||
|
sub_config = subagent_invoke_config(runtime)
|
||||||
|
|
||||||
|
# Resume bridge — see ``task`` above.
|
||||||
|
pending_id: str | None = None
|
||||||
|
pending_value: Any = None
|
||||||
|
aget_state = getattr(subagent, "aget_state", None)
|
||||||
|
if callable(aget_state):
|
||||||
|
try:
|
||||||
|
snapshot = await aget_state(sub_config)
|
||||||
|
pending_id, pending_value = get_first_pending_subagent_interrupt(snapshot)
|
||||||
|
except Exception: # pragma: no cover - defensive
|
||||||
|
logger.debug(
|
||||||
|
"Subagent aget_state failed; falling back to fresh ainvoke",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if pending_value is not None:
|
||||||
|
resume_value = extract_surfsense_resume(runtime)
|
||||||
|
if resume_value is not None:
|
||||||
|
expected = hitlrequest_action_count(pending_value)
|
||||||
|
resume_value = fan_out_decisions_to_match(resume_value, expected)
|
||||||
|
logger.info(
|
||||||
|
"Forwarding surfsense_resume_value into subagent %r "
|
||||||
|
"(action_requests=%d, targeted_id=%s)",
|
||||||
|
subagent_type,
|
||||||
|
expected,
|
||||||
|
pending_id is not None,
|
||||||
|
)
|
||||||
|
result = await subagent.ainvoke(
|
||||||
|
build_resume_command(resume_value, pending_id),
|
||||||
|
config=sub_config,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"Subagent %r has pending interrupt but no surfsense_resume_value "
|
||||||
|
"on config — replaying with fresh state (interrupt will re-fire).",
|
||||||
|
subagent_type,
|
||||||
|
)
|
||||||
|
result = await subagent.ainvoke(subagent_state, config=sub_config)
|
||||||
|
else:
|
||||||
|
result = await subagent.ainvoke(subagent_state, config=sub_config)
|
||||||
|
await amaybe_propagate_subagent_interrupt(subagent, sub_config, subagent_type)
|
||||||
|
return _return_command_with_state_update(result, runtime.tool_call_id)
|
||||||
|
|
||||||
|
return StructuredTool.from_function(
|
||||||
|
name="task",
|
||||||
|
func=task,
|
||||||
|
coroutine=atask,
|
||||||
|
description=description,
|
||||||
|
)
|
||||||
Loading…
Add table
Add a link
Reference in a new issue