feat: enhance task management and timeout configurations in multi-agent chat

- Added new environment variables for controlling task execution limits, including `SURFSENSE_SUBAGENT_INVOKE_TIMEOUT_SECONDS`, `SURFSENSE_TASK_BATCH_CONCURRENCY`, and `SURFSENSE_TASK_BATCH_MAX_SIZE`.
- Updated documentation to reflect new batch processing capabilities for `task` calls, allowing for concurrent execution of multiple subagent tasks.
- Improved error handling and receipt generation for deliverables, ensuring consistent feedback on task status.
- Refactored middleware to incorporate search space ID for better task management.
This commit is contained in:
DESKTOP-RTLN3BA\$punk 2026-05-27 14:58:10 -07:00
parent 820f541f08
commit 9d6e9b7e2d
66 changed files with 2561 additions and 380 deletions

View file

@ -2,6 +2,8 @@
from __future__ import annotations
import os
# Mirror of deepagents.middleware.subagents._EXCLUDED_STATE_KEYS.
EXCLUDED_STATE_KEYS = frozenset(
{
@ -16,3 +18,72 @@ EXCLUDED_STATE_KEYS = frozenset(
# Match the parent graph's budget; the LangGraph default of 25 trips on
# multi-step subagent runs.
DEFAULT_SUBAGENT_RECURSION_LIMIT = 10_000
def _read_timeout_env(name: str, default: float) -> float:
"""Parse ``name`` from the environment; fall back to ``default`` on bad values.
Kept as a free function so the module-level constants stay constants
after import; tests can monkeypatch this and re-evaluate via
``importlib.reload`` if they need a different value mid-process.
"""
raw = os.environ.get(name)
if not raw:
return default
try:
value = float(raw)
except (TypeError, ValueError):
return default
return value if value > 0 else default
# Wall-clock budget for a single ``task(subagent, ...)`` invocation.
# Subagents that run hot (image generation with slow vendors, KB writes
# behind a sluggish embedder) can otherwise wedge the orchestrator until
# the next checkpoint heartbeat. ``0`` disables the timeout entirely.
DEFAULT_SUBAGENT_INVOKE_TIMEOUT_SECONDS: float = _read_timeout_env(
"SURFSENSE_SUBAGENT_INVOKE_TIMEOUT_SECONDS",
default=300.0,
)
def _read_int_env(name: str, default: int) -> int:
raw = os.environ.get(name)
if not raw:
return default
try:
value = int(raw)
except (TypeError, ValueError):
return default
return value if value > 0 else default
# Maximum number of children that ``task(..., tasks=[...])`` runs in
# parallel via ``asyncio.gather`` + ``Semaphore``. Bounded so a runaway
# fanout cannot starve unrelated subagents (each child still owns an
# LLM call + DB session). Set ``SURFSENSE_TASK_BATCH_CONCURRENCY=1`` to
# effectively serialise batches without changing the schema.
DEFAULT_SUBAGENT_BATCH_CONCURRENCY: int = _read_int_env(
"SURFSENSE_TASK_BATCH_CONCURRENCY",
default=3,
)
# Max number of children in a single batched ``task`` call. Hard upper
# bound is a safety net for prompt-injection / runaway loops; the orchestrator
# rarely needs more than a handful of concurrent specialists.
MAX_SUBAGENT_BATCH_SIZE: int = _read_int_env(
"SURFSENSE_TASK_BATCH_MAX_SIZE",
default=8,
)
# Soft threshold for per-turn cumulative ``task(...)`` invocations across
# **all** subagents. Once the sum of ``state['billable_calls']`` values
# crosses this number, the runtime appends a one-shot warning ToolMessage
# instructing the orchestrator to wrap up the turn. Tunable so heavy-research
# turns (which legitimately need 15+ specialist calls) don't trip the alarm
# in production. Set to ``0`` to disable the warning entirely.
DEFAULT_SUBAGENT_BILLABLE_THRESHOLD: int = _read_int_env(
"SURFSENSE_SUBAGENT_BILLABLE_THRESHOLD",
default=15,
)

View file

@ -16,6 +16,9 @@ from langchain.agents import create_agent
from langchain.chat_models import init_chat_model
from langgraph.types import Checkpointer
from app.agents.multi_agent_chat.subagents.shared.spec import (
SURF_CONTEXT_HINT_PROVIDER_KEY,
)
from app.utils.perf import get_perf_logger
from .task_tool import build_task_tool_with_parent_config
@ -34,6 +37,7 @@ class SurfSenseCheckpointedSubAgentMiddleware(SubAgentMiddleware):
subagents: list[SubAgent | CompiledSubAgent],
system_prompt: str | None = TASK_SYSTEM_PROMPT,
task_description: str | None = None,
search_space_id: int | None = None,
) -> None:
self._surf_checkpointer = checkpointer
super(SubAgentMiddleware, self).__init__()
@ -43,8 +47,17 @@ class SurfSenseCheckpointedSubAgentMiddleware(SubAgentMiddleware):
)
self._backend = backend
self._subagents = subagents
# Search-space id is captured at build time (the orchestrator runs in
# exactly one search space for its lifetime). The spawn-paused kill
# switch keys on it so an operator can quarantine one workspace
# without affecting the rest of the deployment.
self._search_space_id = search_space_id
subagent_specs = self._surf_compile_subagent_graphs()
task_tool = build_task_tool_with_parent_config(subagent_specs, task_description)
task_tool = build_task_tool_with_parent_config(
subagent_specs,
task_description,
search_space_id=search_space_id,
)
if system_prompt and subagent_specs:
agents_desc = "\n".join(
f"- {s['name']}: {s['description']}" for s in subagent_specs
@ -64,6 +77,10 @@ class SurfSenseCheckpointedSubAgentMiddleware(SubAgentMiddleware):
for spec in self._subagents:
spec_start = time.perf_counter()
# Provider may be ``None`` (no hint), in which case task_tool
# skips the prepend step. We forward the key unconditionally so
# the registry shape is uniform.
hint_provider = cast(dict, spec).get(SURF_CONTEXT_HINT_PROVIDER_KEY)
if "runnable" in spec:
compiled = cast(CompiledSubAgent, spec)
specs.append(
@ -71,6 +88,7 @@ class SurfSenseCheckpointedSubAgentMiddleware(SubAgentMiddleware):
"name": compiled["name"],
"description": compiled["description"],
"runnable": compiled["runnable"],
SURF_CONTEXT_HINT_PROVIDER_KEY: hint_provider,
}
)
timings.append(
@ -108,6 +126,7 @@ class SurfSenseCheckpointedSubAgentMiddleware(SubAgentMiddleware):
"name": spec["name"],
"description": spec["description"],
"runnable": runnable,
SURF_CONTEXT_HINT_PROVIDER_KEY: hint_provider,
}
)
timings.append(

View file

@ -0,0 +1,84 @@
"""Per-search-space spawn-paused kill switch for the ``task`` boundary.
When operators see a runaway loop, a vendor outage, or a billing event
that requires immediate cessation of subagent traffic for a specific
workspace, they flip a Redis flag and the ``task`` tool short-circuits
without touching downstream services. The flag is **per-search-space**
so one tenant's incident never silences the rest of the deployment.
Flag key: ``surfsense:spawn_paused:{search_space_id}``
Flag value: any string-truthy value (we read presence, not contents).
TTL: set by whoever toggles the flag this module never expires
keys on its own, since "the flag is on" is itself the signal
that a human (or alert) needs to investigate.
The check is best-effort: Redis errors are logged but do not block the
``task`` invocation. Failing closed (block-on-redis-error) would let a
single Redis blip take the whole orchestrator offline; failing open
preserves availability and the alarm bells (rate-limits, cost spikes)
will surface the underlying outage.
"""
from __future__ import annotations
import contextlib
import logging
import os
from app.config import config
logger = logging.getLogger(__name__)
# Operators can disable the check entirely (e.g. local dev without Redis)
# by setting ``SURFSENSE_TASK_SPAWN_PAUSED_DISABLED=1``. Default is
# enabled so production never relies on flipping an opt-out flag.
_DISABLED = os.environ.get(
"SURFSENSE_TASK_SPAWN_PAUSED_DISABLED", ""
).strip().lower() in {
"1",
"true",
"yes",
"on",
}
def _flag_key(search_space_id: int) -> str:
return f"surfsense:spawn_paused:{search_space_id}"
async def is_spawn_paused(search_space_id: int | None) -> bool:
"""Return ``True`` iff the workspace's spawn-paused flag is set in Redis.
A ``None`` search-space (e.g. dev paths that did not plumb the id
through yet) bypasses the check. So does a Redis outage see module
docstring for the fail-open rationale.
"""
if _DISABLED or search_space_id is None:
return False
try:
# Local import keeps the cold-path import cheap and lets routes
# that never call ``task`` skip the redis dependency entirely.
import redis.asyncio as aioredis # type: ignore[import-not-found]
client = aioredis.from_url(config.REDIS_APP_URL, decode_responses=True)
try:
raw = await client.get(_flag_key(search_space_id))
finally:
# ``aclose()`` is the async-safe variant on redis-py >=5; fall back
# to ``close()`` for older clients pinned in tests.
close = getattr(client, "aclose", None) or getattr(client, "close", None)
if callable(close):
with contextlib.suppress(Exception):
await close() # type: ignore[misc]
return bool(raw)
except Exception:
logger.warning(
"spawn_paused check failed for search_space_id=%s; failing open.",
search_space_id,
exc_info=True,
)
return False
__all__ = ["is_spawn_paused"]

View file

@ -8,9 +8,12 @@ re-raises any new pending interrupt back to the parent.
from __future__ import annotations
import asyncio
import json
import logging
import time
from typing import Annotated, Any, NoReturn
from collections.abc import Awaitable
from typing import Annotated, Any, NoReturn, TypeVar
from deepagents.middleware.subagents import TASK_TOOL_DESCRIPTION
from langchain.tools import BaseTool, ToolRuntime
@ -20,6 +23,10 @@ from langchain_core.tools import StructuredTool
from langgraph.errors import GraphInterrupt
from langgraph.types import Command, Interrupt
from app.agents.multi_agent_chat.subagents.shared.spec import (
SURF_CONTEXT_HINT_PROVIDER_KEY,
ContextHintProvider,
)
from app.observability import metrics as ot_metrics, otel as ot
from app.utils.perf import get_perf_logger
@ -29,7 +36,13 @@ from .config import (
has_surfsense_resume,
subagent_invoke_config,
)
from .constants import EXCLUDED_STATE_KEYS
from .constants import (
DEFAULT_SUBAGENT_BATCH_CONCURRENCY,
DEFAULT_SUBAGENT_BILLABLE_THRESHOLD,
DEFAULT_SUBAGENT_INVOKE_TIMEOUT_SECONDS,
EXCLUDED_STATE_KEYS,
MAX_SUBAGENT_BATCH_SIZE,
)
from .propagation import wrap_with_tool_call_id
from .resume import (
build_resume_command,
@ -37,11 +50,70 @@ from .resume import (
get_first_pending_subagent_interrupt,
hitlrequest_action_count,
)
from .spawn_paused import is_spawn_paused
logger = logging.getLogger(__name__)
_perf_log = get_perf_logger()
class SubagentInvokeTimeoutError(Exception):
"""Raised when ``subagent.ainvoke`` exceeds the configured wall-clock budget.
Carries the subagent name and the elapsed seconds so the caller can
synthesize a ToolMessage that the orchestrator can act on (re-route,
surface to the user, or retry with a smaller scope).
"""
def __init__(self, subagent_type: str, elapsed_seconds: float) -> None:
super().__init__(
f"subagent {subagent_type!r} exceeded "
f"{DEFAULT_SUBAGENT_INVOKE_TIMEOUT_SECONDS:.0f}s budget "
f"(elapsed={elapsed_seconds:.1f}s)"
)
self.subagent_type = subagent_type
self.elapsed_seconds = elapsed_seconds
_T = TypeVar("_T")
async def _ainvoke_with_timeout[T](
coro: Awaitable[_T], *, subagent_type: str, started_at: float
) -> _T:
"""Apply :data:`DEFAULT_SUBAGENT_INVOKE_TIMEOUT_SECONDS` to ``coro``.
A non-positive timeout disables the cap (configurable via the
``SURFSENSE_SUBAGENT_INVOKE_TIMEOUT_SECONDS`` env var). On expiry the
underlying task is cancelled and :class:`SubagentInvokeTimeoutError` is
raised the caller wraps it into a synthetic ToolMessage so the
orchestrator can decide what to do.
"""
timeout = DEFAULT_SUBAGENT_INVOKE_TIMEOUT_SECONDS
if timeout <= 0:
return await coro
try:
return await asyncio.wait_for(coro, timeout=timeout)
except TimeoutError as exc:
elapsed = time.perf_counter() - started_at
raise SubagentInvokeTimeoutError(subagent_type, elapsed) from exc
def _synthesize_timeout_command(
exc: SubagentInvokeTimeoutError, *, tool_call_id: str
) -> Command:
"""Turn a :class:`SubagentInvokeTimeoutError` into a ToolMessage the parent can read."""
content = (
f"Subagent {exc.subagent_type!r} timed out after "
f"{exc.elapsed_seconds:.1f}s (budget="
f"{DEFAULT_SUBAGENT_INVOKE_TIMEOUT_SECONDS:.0f}s). "
"The work was cancelled. Treat as status=error; re-route with a "
"narrower scope or different specialist."
)
return Command(
update={"messages": [ToolMessage(content=content, tool_call_id=tool_call_id)]}
)
def _reraise_stamped_subagent_interrupt(
gi: GraphInterrupt, tool_call_id: str
) -> NoReturn:
@ -70,11 +142,24 @@ def _reraise_stamped_subagent_interrupt(
def build_task_tool_with_parent_config(
subagents: list[dict[str, Any]],
task_description: str | None = None,
*,
search_space_id: int | None = None,
) -> BaseTool:
"""Upstream ``_build_task_tool`` + parent ``runtime.config`` propagation + resume bridging."""
subagent_graphs: dict[str, Runnable] = {
spec["name"]: spec["runnable"] for spec in subagents
}
# Per-subagent context-hint providers (see ``SurfSenseSubagentSpec``).
# The mapping is sparse: only routes that opted in via ``pack_subagent``
# appear here, and the value is invoked once per ``task(...)`` call to
# generate a short string prepended to the subagent's first
# ``HumanMessage``. Failures are logged and swallowed — a broken hint
# provider must never prevent the underlying task from running.
subagent_hint_providers: dict[str, ContextHintProvider] = {
spec["name"]: provider
for spec in subagents
if (provider := spec.get(SURF_CONTEXT_HINT_PROVIDER_KEY)) is not None
}
subagent_description_str = "\n".join(
f"- {s['name']}: {s['description']}" for s in subagents
)
@ -88,6 +173,120 @@ def build_task_tool_with_parent_config(
else:
description = task_description
def _billable_call_update(
subagent_type: str, runtime: ToolRuntime
) -> dict[str, Any]:
"""Build the per-call ``billable_calls`` delta + an optional warning.
The orchestrator's ``billable_calls`` map is summed by
:func:`_int_counter_merge_reducer`, so we always emit
``{subagent_type: 1}`` and let the reducer accumulate. If the
cumulative count *after* this call would cross the configured
threshold, we also slip a soft ``messages`` entry into the update
so the orchestrator can read it on its next step and self-limit.
Returning a plain ``dict`` (vs. an extra :class:`Command`) keeps
the helper composable with the existing single/batch return paths.
"""
delta: dict[str, Any] = {"billable_calls": {subagent_type: 1}}
threshold = DEFAULT_SUBAGENT_BILLABLE_THRESHOLD
if threshold <= 0:
return delta
prior = runtime.state.get("billable_calls") or {}
# ``prior`` may be a plain dict or a reducer-managed mapping; only
# int values are counted so a malformed checkpoint can't crash us.
prior_total = sum(v for v in prior.values() if isinstance(v, int))
new_total = prior_total + 1
if prior_total < threshold <= new_total:
warn = (
f"[budget warning] This turn has dispatched {new_total} "
f"subagent calls (soft cap = {threshold}). Wrap up the "
"user's request with what you have rather than launching "
"more specialists; surface a partial answer if needed."
)
delta["_billable_warn_text"] = warn
return delta
def _attach_billable(
cmd: Command, subagent_type: str, runtime: ToolRuntime
) -> Command:
"""Merge the per-call billable counter (and warning) into ``cmd``."""
delta = _billable_call_update(subagent_type, runtime)
warn_text = delta.pop("_billable_warn_text", None)
# ``cmd.update`` may be a dict or LangGraph ``UpdateDict``; defensively
# copy so we don't mutate state shared across other tool returns.
update = dict(getattr(cmd, "update", {}) or {})
for key, value in delta.items():
update[key] = value
if warn_text:
existing_msgs = list(update.get("messages") or [])
existing_msgs.append(
ToolMessage(content=warn_text, tool_call_id=runtime.tool_call_id)
)
update["messages"] = existing_msgs
return Command(update=update)
def _safe_message_text(msg: Any) -> str:
"""Pull text out of a BaseMessage without trusting the ``.text`` property.
``BaseMessage.text`` walks ``content_blocks`` and crashes with
``TypeError: 'NoneType' object is not iterable`` when ``content`` is
``None`` (common for tool-call AIMessages whose payload is purely
structured). ``getattr(msg, "text", None)`` does not catch this
because Python evaluates the property body before falling back to
the default. Read ``content`` directly and coerce defensively.
"""
try:
content = getattr(msg, "content", None)
except Exception:
content = None
if content is None:
return ""
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for block in content:
if isinstance(block, str):
parts.append(block)
elif isinstance(block, dict):
block_text = block.get("text") or block.get("content")
if isinstance(block_text, str):
parts.append(block_text)
return " ".join(parts)
return str(content)
def _build_tool_trace(messages: list[Any]) -> list[dict[str, Any]]:
"""Compress the subagent's message stream into a compact tool trace.
Each entry is ``{"tool": <name>, "status": "ok"|"error", "preview":
<120 chars>}`` so the orchestrator can show "this is what your
specialist actually did" without dumping the full message stream
back through the prompt. The list is attached to the returned
ToolMessage's ``additional_kwargs`` (under ``"surf_tool_trace"``);
the LLM never sees it, but UI / observability code can pluck it
out of the checkpoint.
"""
trace: list[dict[str, Any]] = []
for msg in messages:
tool_name = getattr(msg, "name", None)
tool_call_id_attr = getattr(msg, "tool_call_id", None)
if not tool_name and not tool_call_id_attr:
# Only ToolMessages have either field; skip AIMessage /
# HumanMessage / SystemMessage frames.
continue
status = getattr(msg, "status", None) or "ok"
preview = _safe_message_text(msg).strip().replace("\n", " ")
if len(preview) > 120:
preview = preview[:117] + "..."
trace.append(
{
"tool": tool_name or "<unknown>",
"status": status,
"preview": preview,
}
)
return trace
def _return_command_with_state_update(result: dict, tool_call_id: str) -> Command:
if "messages" not in result:
msg = (
@ -106,15 +305,51 @@ def build_task_tool_with_parent_config(
"output to forward back to the user."
)
raise ValueError(msg)
last_text = getattr(messages[-1], "text", None) or ""
message_text = last_text.rstrip()
message_text = _safe_message_text(messages[-1]).rstrip()
# Tool-trace is purely observability — wrap defensively so a single
# malformed frame never bubbles up and kills the whole user turn.
try:
tool_trace = _build_tool_trace(messages)
except Exception:
logger.exception(
"Failed to build tool_trace for subagent return; "
"continuing without trace."
)
tool_trace = []
tool_msg = ToolMessage(message_text, tool_call_id=tool_call_id)
if tool_trace:
# ``additional_kwargs`` is a free-form dict on BaseMessage; using
# a ``surf_`` prefix avoids collision with provider-specific keys
# (e.g. Anthropic's ``cache_control``). The LLM doesn't see it;
# consumers (UI, observability) read it off the checkpoint.
tool_msg.additional_kwargs["surf_tool_trace"] = tool_trace
return Command(
update={
**state_update,
"messages": [ToolMessage(message_text, tool_call_id=tool_call_id)],
"messages": [tool_msg],
}
)
def _resolve_context_hint(
subagent_type: str, description: str, runtime: ToolRuntime
) -> str | None:
"""Run the per-subagent hint provider; swallow & log any exception."""
provider = subagent_hint_providers.get(subagent_type)
if provider is None:
return None
try:
hint = provider(runtime.state, description)
except Exception:
logger.exception(
"Context-hint provider for subagent %r raised; skipping hint.",
subagent_type,
)
return None
if not hint or not isinstance(hint, str):
return None
cleaned = hint.strip()
return cleaned or None
def _validate_and_prepare_state(
subagent_type: str, description: str, runtime: ToolRuntime
) -> tuple[Runnable, dict]:
@ -122,20 +357,308 @@ def build_task_tool_with_parent_config(
subagent_state = {
k: v for k, v in runtime.state.items() if k not in EXCLUDED_STATE_KEYS
}
subagent_state["messages"] = [HumanMessage(content=description)]
hint = _resolve_context_hint(subagent_type, description, runtime)
if hint:
# Prepend as a tagged block so the subagent prompt can pattern-match
# on the section (and a future change can lift it into its own
# ``SystemMessage`` if needed).
payload = f"<context_hint>\n{hint}\n</context_hint>\n\n{description}"
else:
payload = description
subagent_state["messages"] = [HumanMessage(content=payload)]
return subagent, subagent_state
def _merge_batch_results(
results: list[tuple[int, str, dict | str, dict | None]],
runtime: ToolRuntime,
) -> Command:
"""Combine per-child results into one Command with a combined ToolMessage.
``results`` is a list of ``(task_index, subagent_type,
payload_or_error_text, child_state_update)`` tuples preserving the
input order so the orchestrator can map each block back to the task
it dispatched. State updates are merged by reducer for keys outside
:data:`EXCLUDED_STATE_KEYS`; everything else (``messages``, ``todos``,
etc.) is replaced by the synthesized aggregate ToolMessage. Every
child also contributes a ``billable_calls`` increment so cost
accounting matches single-mode dispatch.
"""
results.sort(key=lambda r: r[0])
merged_state: dict[str, Any] = {}
billable_delta: dict[str, int] = {}
message_blocks: list[str] = []
batch_trace: list[dict[str, Any]] = []
for task_index, subagent_type, payload, state_update in results:
billable_delta[subagent_type] = billable_delta.get(subagent_type, 0) + 1
if isinstance(payload, str):
# Pre-flight error or per-task exception text.
message_blocks.append(f"[task {task_index}] {payload}")
batch_trace.append(
{
"task_index": task_index,
"subagent_type": subagent_type,
"status": "error",
"tool_trace": [],
}
)
continue
messages = payload.get("messages") or []
last_text = _safe_message_text(messages[-1]).rstrip() if messages else ""
message_blocks.append(
f"[task {task_index}] {last_text or '<empty>'}"
)
try:
child_trace = _build_tool_trace(messages)
except Exception:
logger.exception(
"Failed to build tool_trace for batch task_index=%d; continuing.",
task_index,
)
child_trace = []
batch_trace.append(
{
"task_index": task_index,
"subagent_type": subagent_type,
"status": "ok",
"tool_trace": child_trace,
}
)
if state_update:
# Naive merge: later tasks win on scalar collisions; reducer-backed
# fields (``receipts``, ``files`` etc.) accumulate at apply time.
merged_state.update(state_update)
aggregate = "\n\n".join(message_blocks)
aggregate_msg = ToolMessage(
content=aggregate, tool_call_id=runtime.tool_call_id
)
if batch_trace:
aggregate_msg.additional_kwargs["surf_tool_trace"] = batch_trace
update: dict[str, Any] = {
**merged_state,
"billable_calls": billable_delta,
"messages": [aggregate_msg],
}
# Soft-cap warning: check the cumulative count after attribution.
threshold = DEFAULT_SUBAGENT_BILLABLE_THRESHOLD
if threshold > 0:
prior = runtime.state.get("billable_calls") or {}
prior_total = sum(v for v in prior.values() if isinstance(v, int))
new_total = prior_total + sum(billable_delta.values())
if prior_total < threshold <= new_total:
update["messages"].append(
ToolMessage(
content=(
f"[budget warning] This turn has dispatched "
f"{new_total} subagent calls (soft cap = "
f"{threshold}). Wrap up the user's request with "
"what you have rather than launching more "
"specialists; surface a partial answer if needed."
),
tool_call_id=runtime.tool_call_id,
)
)
return Command(update=update)
async def _ainvoke_one_batch_child(
*,
task_index: int,
subagent_type: str,
description: str,
runtime: ToolRuntime,
semaphore: asyncio.Semaphore,
) -> tuple[int, str, dict | str, dict | None]:
"""Run one child of a batched ``task`` call under the concurrency cap.
Errors are returned as plain text in slot 2 so a single child's
failure does not abort the whole batch. ``GraphInterrupt`` from a
batched child is currently treated as a hard failure for that child
only batched HITL is intentionally out of scope for the v1
rollout (see plan tier 2 item 4 risks).
"""
async with semaphore:
if subagent_type not in subagent_graphs:
allowed_types = ", ".join([f"`{k}`" for k in subagent_graphs])
return (
task_index,
subagent_type,
(
f"Subagent {subagent_type!r} does not exist; "
f"allowed: {allowed_types}"
),
None,
)
subagent, subagent_state = _validate_and_prepare_state(
subagent_type, description, runtime
)
sub_config = subagent_invoke_config(runtime)
started_at = time.perf_counter()
try:
result = await _ainvoke_with_timeout(
subagent.ainvoke(subagent_state, config=sub_config),
subagent_type=subagent_type,
started_at=started_at,
)
except SubagentInvokeTimeoutError as exc:
logger.warning(
"Batch child %d (%s) timed out after %.1fs",
task_index,
subagent_type,
exc.elapsed_seconds,
)
return (task_index, subagent_type, str(exc), None)
except GraphInterrupt:
# Batched HITL is unsupported in v1 — surface as a failure
# for this child so the rest of the batch still completes.
logger.warning(
"Batch child %d (%s) raised GraphInterrupt; batched HITL "
"is not supported. Re-dispatch this task as a single "
"(non-batched) `task(...)` call to get the HITL prompt.",
task_index,
subagent_type,
)
return (
task_index,
subagent_type,
(
f"Subagent {subagent_type!r} needs human approval. "
"Re-dispatch this task as a single (non-batched) "
"`task(...)` call so the approval card can be shown."
),
None,
)
except Exception as exc:
logger.exception(
"Batch child %d (%s) raised: %s",
task_index,
subagent_type,
exc,
)
return (
task_index,
subagent_type,
f"Subagent {subagent_type!r} error: {exc}",
None,
)
child_state_update = {
k: v for k, v in result.items() if k not in EXCLUDED_STATE_KEYS
}
return (task_index, subagent_type, result, child_state_update)
def _coerce_batch_arg(tasks: Any) -> list[dict] | str:
"""Rescue common LLM-side malformations of the ``tasks`` argument.
Some providers serialise an array argument as a JSON-encoded string,
and small models occasionally hand back a single ``{description,
subagent_type}`` dict instead of a one-element array. Both are
recovered here with a WARN log so the issue is visible in metrics
but the user's turn still completes; truly broken shapes return a
plain string that the caller surfaces as the tool error.
"""
if isinstance(tasks, list):
return tasks
if isinstance(tasks, dict):
logger.warning(
"task: `tasks` was a single dict; coercing to a 1-element list. "
"Orchestrators should send `tasks=[{...}]` directly."
)
return [tasks]
if isinstance(tasks, str):
stripped = tasks.strip()
if not stripped:
return "tasks: argument is empty."
try:
parsed = json.loads(stripped)
except json.JSONDecodeError as exc:
return (
f"tasks: argument is a string but not valid JSON ({exc.msg}). "
"Send a JSON array of `{description, subagent_type}` objects."
)
logger.warning(
"task: `tasks` was a JSON-encoded string; parsed to %s. "
"Orchestrators should send a JSON array directly.",
type(parsed).__name__,
)
return _coerce_batch_arg(parsed)
return (
f"tasks: unsupported type {type(tasks).__name__}; expected an array "
"of `{description, subagent_type}` objects."
)
async def _adispatch_batch(
tasks: list[dict], runtime: ToolRuntime
) -> Command | str:
"""Fan-out helper for the ``tasks`` array shape.
Bounded by :data:`MAX_SUBAGENT_BATCH_SIZE` and concurrency-capped
at :data:`DEFAULT_SUBAGENT_BATCH_CONCURRENCY`. Returns a single
:class:`Command` that the LLM sees as one ToolMessage per child,
prefixed with ``[task <index>]`` so it can map back to the input
order.
"""
if not tasks:
return "tasks: array is empty; nothing to dispatch."
if len(tasks) > MAX_SUBAGENT_BATCH_SIZE:
return (
f"tasks: too many children ({len(tasks)}); "
f"max is {MAX_SUBAGENT_BATCH_SIZE}. Split the batch."
)
normalized: list[tuple[int, str, str]] = []
for idx, item in enumerate(tasks):
if not isinstance(item, dict):
return (
f"tasks[{idx}]: must be an object with description+subagent_type."
)
description = item.get("description")
subagent_type = item.get("subagent_type")
if not isinstance(description, str) or not description.strip():
return f"tasks[{idx}]: missing or empty 'description'."
if not isinstance(subagent_type, str) or not subagent_type.strip():
return f"tasks[{idx}]: missing or empty 'subagent_type'."
normalized.append((idx, subagent_type.strip(), description))
semaphore = asyncio.Semaphore(DEFAULT_SUBAGENT_BATCH_CONCURRENCY)
coros = [
_ainvoke_one_batch_child(
task_index=idx,
subagent_type=subagent_type,
description=description,
runtime=runtime,
semaphore=semaphore,
)
for idx, subagent_type, description in normalized
]
results = await asyncio.gather(*coros)
return _merge_batch_results(list(results), runtime)
def task(
description: Annotated[
str,
"A detailed description of the task for the subagent to perform autonomously. Include all necessary context and specify the expected output format.",
],
str | None,
"Single-mode: a detailed task description for the subagent. Required unless `tasks` is provided.",
] = None,
subagent_type: Annotated[
str,
"The type of subagent to use. Must be one of the available agent types listed in the tool description.",
],
runtime: ToolRuntime,
str | None,
"Single-mode: the type of subagent to use. Required unless `tasks` is provided.",
] = None,
runtime: ToolRuntime = None, # type: ignore[assignment]
tasks: Annotated[
list[dict] | None,
(
"Batch-mode: array of `{description, subagent_type}` objects. "
"Synchronous path does not support batch mode; orchestrators "
"must use the async event loop to fan out."
),
] = None,
) -> str | Command:
if tasks is not None:
return (
"task: batch mode (`tasks=[...]`) is only supported on the async "
"path. SurfSense orchestrators always run in an event loop, so "
"this should never fire — file a bug if you see it."
)
if not description or not subagent_type:
return (
"task: must provide either single-mode (`description`+`subagent_type`) "
"or batch-mode (`tasks`)."
)
if subagent_type not in subagent_graphs:
allowed_types = ", ".join([f"`{k}`" for k in subagent_graphs])
return (
@ -284,16 +807,65 @@ def build_task_tool_with_parent_config(
async def atask(
description: Annotated[
str,
"A detailed description of the task for the subagent to perform autonomously. Include all necessary context and specify the expected output format.",
],
str | None,
"Single-mode: a detailed task description for the subagent. Required unless `tasks` is provided.",
] = None,
subagent_type: Annotated[
str,
"The type of subagent to use. Must be one of the available agent types listed in the tool description.",
],
runtime: ToolRuntime,
str | None,
"Single-mode: the type of subagent to use. Required unless `tasks` is provided.",
] = None,
runtime: ToolRuntime = None, # type: ignore[assignment]
tasks: Annotated[
list[dict] | None,
(
"Batch-mode: array of `{description, subagent_type}` objects "
"to fan out concurrently (max "
f"{MAX_SUBAGENT_BATCH_SIZE}, concurrency "
f"{DEFAULT_SUBAGENT_BATCH_CONCURRENCY}). Mutually exclusive "
"with single-mode args. Batched children do not support "
"human-in-the-loop interrupts; re-dispatch as single mode "
"if a child needs approval."
),
] = None,
) -> str | Command:
atask_start = time.perf_counter()
# Kill switch: when ops flips the spawn-paused flag for this
# workspace, every ``task(...)`` invocation (single- or batch-mode)
# short-circuits with a clear ToolMessage so the orchestrator can
# tell the user what happened and stop hammering downstream APIs.
if await is_spawn_paused(search_space_id):
logger.warning(
"[hitl_route] atask SPAWN_PAUSED: search_space_id=%s tool_call_id=%s",
search_space_id,
runtime.tool_call_id,
)
return (
"task: subagent dispatch is currently paused for this workspace. "
"Acknowledge to the user that delegation is temporarily disabled "
"(ops kill switch); do not retry until the pause is lifted."
)
if tasks is not None:
if description or subagent_type:
return (
"task: cannot combine `tasks` with `description`/`subagent_type`. "
"Use either single-mode (description+subagent_type) or batch-mode (tasks)."
)
if not runtime.tool_call_id:
raise ValueError("Tool call ID is required for subagent invocation")
coerced = _coerce_batch_arg(tasks)
if isinstance(coerced, str):
return coerced
logger.info(
"[hitl_route] atask BATCH ENTRY: size=%d tool_call_id=%s",
len(coerced),
runtime.tool_call_id,
)
return await _adispatch_batch(coerced, runtime)
if not description or not subagent_type:
return (
"task: must provide either single-mode (`description`+`subagent_type`) "
"or batch-mode (`tasks`)."
)
logger.info(
"[hitl_route] atask ENTRY: subagent_type=%r tool_call_id=%s",
subagent_type,
@ -358,11 +930,37 @@ def build_task_tool_with_parent_config(
subagent_type=subagent_type, path=invoke_path
) as sp:
try:
result = await subagent.ainvoke(
build_resume_command(resume_value, pending_id),
config=sub_config,
result = await _ainvoke_with_timeout(
subagent.ainvoke(
build_resume_command(resume_value, pending_id),
config=sub_config,
),
subagent_type=subagent_type,
started_at=ainvoke_start,
)
sp.set_attribute("subagent.outcome", ainvoke_outcome)
except SubagentInvokeTimeoutError as exc:
ainvoke_outcome = "timeout"
sp.set_attribute("subagent.outcome", ainvoke_outcome)
ot_metrics.record_subagent_invoke_duration(
(time.perf_counter() - ainvoke_start) * 1000,
subagent_type=subagent_type,
path=invoke_path,
outcome=ainvoke_outcome,
)
ot_metrics.record_subagent_invoke_outcome(
subagent_type=subagent_type,
path=invoke_path,
outcome=ainvoke_outcome,
)
logger.warning(
"Subagent %r ainvoke (resume) timed out after %.1fs",
subagent_type,
exc.elapsed_seconds,
)
return _synthesize_timeout_command(
exc, tool_call_id=runtime.tool_call_id
)
except GraphInterrupt as gi:
ainvoke_outcome = "interrupted"
sp.set_attribute("subagent.outcome", ainvoke_outcome)
@ -408,10 +1006,34 @@ def build_task_tool_with_parent_config(
subagent_type=subagent_type, path=invoke_path
) as sp:
try:
result = await subagent.ainvoke(
subagent_state, config=sub_config
result = await _ainvoke_with_timeout(
subagent.ainvoke(subagent_state, config=sub_config),
subagent_type=subagent_type,
started_at=ainvoke_start,
)
sp.set_attribute("subagent.outcome", ainvoke_outcome)
except SubagentInvokeTimeoutError as exc:
ainvoke_outcome = "timeout"
sp.set_attribute("subagent.outcome", ainvoke_outcome)
ot_metrics.record_subagent_invoke_duration(
(time.perf_counter() - ainvoke_start) * 1000,
subagent_type=subagent_type,
path=invoke_path,
outcome=ainvoke_outcome,
)
ot_metrics.record_subagent_invoke_outcome(
subagent_type=subagent_type,
path=invoke_path,
outcome=ainvoke_outcome,
)
logger.warning(
"Subagent %r ainvoke (fresh) timed out after %.1fs",
subagent_type,
exc.elapsed_seconds,
)
return _synthesize_timeout_command(
exc, tool_call_id=runtime.tool_call_id
)
except GraphInterrupt as gi:
ainvoke_outcome = "interrupted"
sp.set_attribute("subagent.outcome", ainvoke_outcome)
@ -481,7 +1103,7 @@ def build_task_tool_with_parent_config(
path=invoke_path,
outcome=ainvoke_outcome,
)
return cmd
return _attach_billable(cmd, subagent_type, runtime)
return StructuredTool.from_function(
name="task",

View file

@ -52,9 +52,7 @@ class KbContextProjectionMiddleware(AgentMiddleware): # type: ignore[type-arg]
messages.insert(insert_at, SystemMessage(content=tree_text))
priority_count = 0
if priority:
priority_count = (
len(priority) if hasattr(priority, "__len__") else 1
)
priority_count = len(priority) if hasattr(priority, "__len__") else 1
messages.insert(insert_at, _render_priority_message(priority))
_perf_log.info(
"[kb_context_projection] tree_chars=%d priority_items=%d elapsed=%.3fs",

View file

@ -17,8 +17,7 @@ from langchain_core.tools import BaseTool
from langgraph.types import interrupt
from app.agents.new_chat.permissions import Rule
from app.observability import metrics as ot_metrics
from app.observability import otel as ot
from app.observability import metrics as ot_metrics, otel as ot
from .decision import normalize_permission_decision
from .payload import PERMISSION_ASK_INTERRUPT_TYPE, build_permission_ask_payload

View file

@ -173,6 +173,7 @@ def build_main_agent_deepagent_middleware(
subagents=subagents,
system_prompt=None,
task_description=TASK_TOOL_DESCRIPTION,
search_space_id=search_space_id,
),
resilience.model_call_limit,
resilience.tool_call_limit,