refactor(agents): colocate main-agent middleware under main_agent/ slice

Vertical-slice colocation: all main-agent code should live under
main_agent/ instead of being split across a parallel middleware/main_agent
tree. Move multi_agent_chat/middleware/main_agent/ -> main_agent/middleware/
and its assembler middleware/stack.py -> main_agent/middleware/stack.py, so
the main-agent slice is self-contained (graph, runtime, system_prompt, tools,
middleware).

Genuinely cross-slice middleware (middleware/shared/, middleware/subagent/)
stays under multi_agent_chat/middleware/ for a later slice; the moved builders
now reference it via absolute imports.

Pure move + import rewrite (git-tracked renames). Verified: full unit suite
green (2430 passed, 1 skipped), including test_import_all and the
checkpointed-subagent middleware suite.
This commit is contained in:
CREDO23 2026-06-04 18:03:49 +02:00
parent 1acde6a470
commit 9c845d562e
42 changed files with 60 additions and 58 deletions

View file

@ -1,36 +0,0 @@
"""Audit row per tool call (reversibility metadata)."""
from __future__ import annotations
import logging
from app.agents.shared.feature_flags import AgentFeatureFlags
from app.agents.shared.middleware import ActionLogMiddleware
from app.agents.shared.tools.registry import BUILTIN_TOOLS
from ..shared.flags import enabled
def build_action_log_mw(
*,
flags: AgentFeatureFlags,
thread_id: int | None,
search_space_id: int,
user_id: str | None,
) -> ActionLogMiddleware | None:
if not enabled(flags, "enable_action_log") or thread_id is None:
return None
try:
tool_defs_by_name = {td.name: td for td in BUILTIN_TOOLS}
return ActionLogMiddleware(
thread_id=thread_id,
search_space_id=search_space_id,
user_id=user_id,
tool_definitions=tool_defs_by_name,
)
except Exception: # pragma: no cover - defensive
logging.warning(
"ActionLogMiddleware init failed; running without it.",
exc_info=True,
)
return None

View file

@ -1,16 +0,0 @@
"""Anonymous document hydration from Redis (cloud only)."""
from __future__ import annotations
from app.agents.shared.filesystem_selection import FilesystemMode
from app.agents.shared.middleware import AnonymousDocumentMiddleware
def build_anonymous_doc_mw(
*,
filesystem_mode: FilesystemMode,
anon_session_id: str | None,
) -> AnonymousDocumentMiddleware | None:
if filesystem_mode != FilesystemMode.CLOUD:
return None
return AnonymousDocumentMiddleware(anon_session_id=anon_session_id)

View file

@ -1,12 +0,0 @@
"""Per-thread cooperative lock around the whole turn."""
from __future__ import annotations
from app.agents.shared.feature_flags import AgentFeatureFlags
from app.agents.shared.middleware import BusyMutexMiddleware
from ..shared.flags import enabled
def build_busy_mutex_mw(flags: AgentFeatureFlags) -> BusyMutexMiddleware | None:
return BusyMutexMiddleware() if enabled(flags, "enable_busy_mutex") else None

View file

@ -1,32 +0,0 @@
"""SubAgent ``task`` tool wiring required for HITL inside subagents.
Replaces upstream ``SubAgentMiddleware`` to:
- share the parent's checkpointer with each subagent,
- forward ``runtime.config`` (thread_id, recursion_limit, ) into nested invokes,
- isolate each parallel ``task`` call in its own checkpoint slot via
per-call ``thread_id`` namespacing,
- bridge ``Command(resume=...)`` from the parent into the subagent via the
``config["configurable"]["surfsense_resume_value"]`` side-channel, keyed by
``tool_call_id`` so parallel siblings never race on a shared scalar,
- target the resume at the captured interrupt id so a follow-up
``HumanInTheLoopMiddleware.after_model`` does not consume the same payload,
- stamp each subagent's pending interrupt with the parent's ``tool_call_id``
so ``stream_resume_chat`` can route a flat ``decisions`` list back to the
right paused subagent.
Module layout
-------------
- ``constants`` shared keys / limits.
- ``config`` RunnableConfig + side-channel resume read + per-call ``thread_id``.
- ``resume`` pending-interrupt detection, fan-out, ``Command(resume=...)`` builder.
- ``propagation`` ``wrap_with_tool_call_id`` helper for stamping interrupt values.
- ``resume_routing`` slice a flat decisions list to per-``tool_call_id`` payloads.
- ``task_tool`` the ``task`` tool factory (sync + async), and the catch-and-stamp chokepoint.
- ``middleware`` :class:`SurfSenseCheckpointedSubAgentMiddleware` itself.
"""
from .middleware import SurfSenseCheckpointedSubAgentMiddleware
__all__ = ["SurfSenseCheckpointedSubAgentMiddleware"]

View file

@ -1,125 +0,0 @@
"""RunnableConfig wiring for nested subagent invocations.
Forwards the parent's ``runtime.config`` (thread_id, …) into the subagent and
exposes the side-channel ``stream_resume_chat`` uses to ferry resume payloads.
"""
from __future__ import annotations
import logging
from typing import Any
from langchain.tools import ToolRuntime
from .constants import DEFAULT_SUBAGENT_RECURSION_LIMIT
logger = logging.getLogger(__name__)
# langgraph stores the parent task's scratchpad under this configurable key;
# subagents inherit the chain via ``parent_scratchpad`` fallback.
_LANGGRAPH_SCRATCHPAD_KEY = "__pregel_scratchpad"
def subagent_invoke_config(runtime: ToolRuntime) -> dict[str, Any]:
"""RunnableConfig for the nested invoke; raises ``recursion_limit`` and isolates ``thread_id``.
Each parallel subagent invocation lands in its own checkpoint slot keyed
by an extended ``thread_id`` of the form ``{parent_thread}::task:{tool_call_id}``.
The same call across the resume cycle keeps reading from the same snapshot
(``tool_call_id`` is stable per LLM-emitted call).
We namespace via ``thread_id`` rather than ``checkpoint_ns`` because
langgraph's ``aget_state`` interprets a non-empty ``checkpoint_ns`` as a
subgraph path and raises ``ValueError("Subgraph X not found")``.
"""
merged: dict[str, Any] = dict(runtime.config) if runtime.config else {}
current_limit = merged.get("recursion_limit")
try:
current_int = int(current_limit) if current_limit is not None else 0
except (TypeError, ValueError):
current_int = 0
if current_int < DEFAULT_SUBAGENT_RECURSION_LIMIT:
merged["recursion_limit"] = DEFAULT_SUBAGENT_RECURSION_LIMIT
configurable: dict[str, Any] = dict(merged.get("configurable") or {})
parent_thread_id = configurable.get("thread_id")
per_call_suffix = f"task:{runtime.tool_call_id}"
configurable["thread_id"] = (
f"{parent_thread_id}::{per_call_suffix}"
if parent_thread_id
else per_call_suffix
)
merged["configurable"] = configurable
return merged
def consume_surfsense_resume(runtime: ToolRuntime) -> Any:
"""Pop the resume payload for *this* call's ``tool_call_id``.
The configurable holds ``surfsense_resume_value: dict[tool_call_id, payload]``
so parallel sibling subagents (each with their own ``tool_call_id``) read
only their own decision and never race on a shared scalar.
"""
cfg = runtime.config or {}
configurable = cfg.get("configurable") if isinstance(cfg, dict) else None
if not isinstance(configurable, dict):
return None
by_tcid = configurable.get("surfsense_resume_value")
if not isinstance(by_tcid, dict):
return None
payload = by_tcid.pop(runtime.tool_call_id, None)
if not by_tcid:
configurable.pop("surfsense_resume_value", None)
return payload
def has_surfsense_resume(runtime: ToolRuntime) -> bool:
"""True iff a resume payload for this call's ``tool_call_id`` is queued (non-destructive)."""
cfg = runtime.config or {}
configurable = cfg.get("configurable") if isinstance(cfg, dict) else None
if not isinstance(configurable, dict):
return False
by_tcid = configurable.get("surfsense_resume_value")
if not isinstance(by_tcid, dict):
return False
return runtime.tool_call_id in by_tcid
def drain_parent_null_resume(runtime: ToolRuntime) -> None:
"""Consume the parent's lingering ``NULL_TASK_ID/RESUME`` write before delegating.
``stream_resume_chat`` wakes the main agent with
``Command(resume={tool_call_id: {"decisions": [...]}})`` so the previously
propagated parent-level interrupt can return. langgraph stores that
payload as the parent task's ``null_resume`` pending write. The ``task``
tool then forwards this turn's slice into the subagent via its own
``Command(resume=...)``. While the subagent is mid-execution, any *new*
``interrupt()`` inside it (e.g. a follow-up tool call after a mixed
approve/reject) walks ``subagent_scratchpad parent_scratchpad.get_null_resume``
and picks up the parent's still-live decisions — mismatching against a
different number of hanging tool calls and crashing
``HumanInTheLoopMiddleware``.
Draining the write here closes that cross-graph leak so subagent
interrupts pause cleanly and bubble back up as a fresh approval card.
"""
cfg = runtime.config or {}
configurable = cfg.get("configurable") if isinstance(cfg, dict) else None
if not isinstance(configurable, dict):
return
scratchpad = configurable.get(_LANGGRAPH_SCRATCHPAD_KEY)
if scratchpad is None:
return
consume = getattr(scratchpad, "get_null_resume", None)
if not callable(consume):
return
try:
consume(True)
except Exception:
# Defensive: if langgraph's internal scratchpad shape changes we don't
# want to break the resume path. Worst case the original ValueError
# still surfaces — same behavior as before this fix.
logger.debug(
"drain_parent_null_resume: scratchpad.get_null_resume raised",
exc_info=True,
)

View file

@ -1,89 +0,0 @@
"""Constants shared by the checkpointed subagent middleware."""
from __future__ import annotations
import os
# Mirror of deepagents.middleware.subagents._EXCLUDED_STATE_KEYS.
EXCLUDED_STATE_KEYS = frozenset(
{
"messages",
"todos",
"structured_response",
"skills_metadata",
"memory_contents",
}
)
# Match the parent graph's budget; the LangGraph default of 25 trips on
# multi-step subagent runs.
DEFAULT_SUBAGENT_RECURSION_LIMIT = 10_000
def _read_timeout_env(name: str, default: float) -> float:
"""Parse ``name`` from the environment; fall back to ``default`` on bad values.
Kept as a free function so the module-level constants stay constants
after import; tests can monkeypatch this and re-evaluate via
``importlib.reload`` if they need a different value mid-process.
"""
raw = os.environ.get(name)
if not raw:
return default
try:
value = float(raw)
except (TypeError, ValueError):
return default
return value if value > 0 else default
# Wall-clock budget for a single ``task(subagent, ...)`` invocation.
# Subagents that run hot (image generation with slow vendors, KB writes
# behind a sluggish embedder) can otherwise wedge the orchestrator until
# the next checkpoint heartbeat. ``0`` disables the timeout entirely.
DEFAULT_SUBAGENT_INVOKE_TIMEOUT_SECONDS: float = _read_timeout_env(
"SURFSENSE_SUBAGENT_INVOKE_TIMEOUT_SECONDS",
default=300.0,
)
def _read_int_env(name: str, default: int) -> int:
raw = os.environ.get(name)
if not raw:
return default
try:
value = int(raw)
except (TypeError, ValueError):
return default
return value if value > 0 else default
# Maximum number of children that ``task(..., tasks=[...])`` runs in
# parallel via ``asyncio.gather`` + ``Semaphore``. Bounded so a runaway
# fanout cannot starve unrelated subagents (each child still owns an
# LLM call + DB session). Set ``SURFSENSE_TASK_BATCH_CONCURRENCY=1`` to
# effectively serialise batches without changing the schema.
DEFAULT_SUBAGENT_BATCH_CONCURRENCY: int = _read_int_env(
"SURFSENSE_TASK_BATCH_CONCURRENCY",
default=3,
)
# Max number of children in a single batched ``task`` call. Hard upper
# bound is a safety net for prompt-injection / runaway loops; the orchestrator
# rarely needs more than a handful of concurrent specialists.
MAX_SUBAGENT_BATCH_SIZE: int = _read_int_env(
"SURFSENSE_TASK_BATCH_MAX_SIZE",
default=8,
)
# Soft threshold for per-turn cumulative ``task(...)`` invocations across
# **all** subagents. Once the sum of ``state['billable_calls']`` values
# crosses this number, the runtime appends a one-shot warning ToolMessage
# instructing the orchestrator to wrap up the turn. Tunable so heavy-research
# turns (which legitimately need 15+ specialist calls) don't trip the alarm
# in production. Set to ``0`` to disable the warning entirely.
DEFAULT_SUBAGENT_BILLABLE_THRESHOLD: int = _read_int_env(
"SURFSENSE_SUBAGENT_BILLABLE_THRESHOLD",
default=15,
)

View file

@ -1,152 +0,0 @@
"""SubAgent middleware that compiles each subagent against the parent checkpointer."""
from __future__ import annotations
import time
from typing import Any, cast
from deepagents.backends.protocol import BackendFactory, BackendProtocol
from deepagents.middleware.subagents import (
TASK_SYSTEM_PROMPT,
CompiledSubAgent,
SubAgent,
SubAgentMiddleware,
)
from langchain.agents import create_agent
from langchain.chat_models import init_chat_model
from langgraph.types import Checkpointer
from app.agents.multi_agent_chat.subagents.shared.spec import (
SURF_CONTEXT_HINT_PROVIDER_KEY,
)
from app.utils.perf import get_perf_logger
from .task_tool import build_task_tool_with_parent_config
_perf_log = get_perf_logger()
class SurfSenseCheckpointedSubAgentMiddleware(SubAgentMiddleware):
"""``SubAgentMiddleware`` variant that compiles each subagent against the parent checkpointer."""
def __init__(
self,
*,
checkpointer: Checkpointer,
backend: BackendProtocol | BackendFactory,
subagents: list[SubAgent | CompiledSubAgent],
system_prompt: str | None = TASK_SYSTEM_PROMPT,
task_description: str | None = None,
search_space_id: int | None = None,
) -> None:
self._surf_checkpointer = checkpointer
super(SubAgentMiddleware, self).__init__()
if not subagents:
raise ValueError(
"At least one subagent must be specified when using the new API"
)
self._backend = backend
self._subagents = subagents
# Search-space id is captured at build time (the orchestrator runs in
# exactly one search space for its lifetime). The spawn-paused kill
# switch keys on it so an operator can quarantine one workspace
# without affecting the rest of the deployment.
self._search_space_id = search_space_id
subagent_specs = self._surf_compile_subagent_graphs()
task_tool = build_task_tool_with_parent_config(
subagent_specs,
task_description,
search_space_id=search_space_id,
)
if system_prompt and subagent_specs:
agents_desc = "\n".join(
f"- {s['name']}: {s['description']}" for s in subagent_specs
)
self.system_prompt = (
system_prompt + "\n\nAvailable subagent types:\n" + agents_desc
)
else:
self.system_prompt = system_prompt
self.tools = [task_tool]
def _surf_compile_subagent_graphs(self) -> list[dict[str, Any]]:
"""Mirror of ``SubAgentMiddleware._get_subagents`` that threads the parent checkpointer."""
specs: list[dict[str, Any]] = []
loop_start = time.perf_counter()
timings: list[tuple[str, float, str]] = [] # (name, elapsed, source)
for spec in self._subagents:
spec_start = time.perf_counter()
# Provider may be ``None`` (no hint), in which case task_tool
# skips the prepend step. We forward the key unconditionally so
# the registry shape is uniform.
hint_provider = cast(dict, spec).get(SURF_CONTEXT_HINT_PROVIDER_KEY)
if "runnable" in spec:
compiled = cast(CompiledSubAgent, spec)
specs.append(
{
"name": compiled["name"],
"description": compiled["description"],
"runnable": compiled["runnable"],
SURF_CONTEXT_HINT_PROVIDER_KEY: hint_provider,
}
)
timings.append(
(compiled["name"], time.perf_counter() - spec_start, "precompiled")
)
continue
if "model" not in spec:
msg = f"SubAgent '{spec['name']}' must specify 'model'"
raise ValueError(msg)
if "tools" not in spec:
msg = f"SubAgent '{spec['name']}' must specify 'tools'"
raise ValueError(msg)
model = spec["model"]
if isinstance(model, str):
model = init_chat_model(model)
middleware: list[Any] = list(spec.get("middleware", []))
tools_count = len(spec.get("tools") or [])
mw_count = len(middleware)
compile_start = time.perf_counter()
runnable = create_agent(
model,
system_prompt=spec["system_prompt"],
tools=spec["tools"],
middleware=middleware,
name=spec["name"],
checkpointer=self._surf_checkpointer,
)
compile_elapsed = time.perf_counter() - compile_start
specs.append(
{
"name": spec["name"],
"description": spec["description"],
"runnable": runnable,
SURF_CONTEXT_HINT_PROVIDER_KEY: hint_provider,
}
)
timings.append(
(
spec["name"],
compile_elapsed,
f"compiled tools={tools_count} mw={mw_count}",
)
)
total_elapsed = time.perf_counter() - loop_start
per_subagent = ", ".join(
f"{name}={elapsed * 1000:.0f}ms[{source}]"
for name, elapsed, source in timings
)
_perf_log.info(
"[subagent_compile] total=%.3fs count=%d details=[%s]",
total_elapsed,
len(timings),
per_subagent,
)
return specs

View file

@ -1,38 +0,0 @@
"""Stamp the parent's ``tool_call_id`` onto a subagent's pending interrupt value.
When a subagent (compiled as a langgraph subgraph and invoked from a parent
tool node) hits an ``interrupt(...)`` from its HITL middleware, langgraph
raises ``GraphInterrupt`` out of ``subagent.[a]invoke(...)``. The parent's
``task`` tool catches that exception, stamps ``tool_call_id`` onto each
``Interrupt.value`` using :func:`wrap_with_tool_call_id`, and re-raises a
fresh ``GraphInterrupt`` whose values carry that stamp.
``stream_resume_chat`` then reads ``parent.state.interrupts[*].value["tool_call_id"]``
to route a flat ``decisions`` list back to the right paused subagent without
the stamp, parallel HITL across siblings would collapse into an ambiguous
bucket and resume would fail.
This module hosts only the stamping helper; the catch/re-raise lives in
``task_tool.py`` since that's the single chokepoint where the raw exception
is in our hands.
"""
from __future__ import annotations
from typing import Any
def wrap_with_tool_call_id(value: Any, tool_call_id: str) -> dict[str, Any]:
"""Return a value dict that always carries the parent's ``tool_call_id``.
Dict values are shallow-copied with ``tool_call_id`` stamped on top, so
any value the subagent may already carry under that key (from a deeper
HITL level) is overwritten the parent's call id is the only one
``stream_resume_chat`` correlates against.
Non-dict values are wrapped as ``{"value": <original>, "tool_call_id": ...}``
so simple ``interrupt("approve?")`` patterns still propagate cleanly.
"""
if isinstance(value, dict):
return {**value, "tool_call_id": tool_call_id}
return {"value": value, "tool_call_id": tool_call_id}

View file

@ -1,76 +0,0 @@
"""Resume-payload shaping and pending-interrupt detection for subagents.
Splits the work of "given a state snapshot and a parent-stashed resume value,
produce the right ``Command(resume=...)`` for the subagent" into pure helpers.
"""
from __future__ import annotations
from typing import Any
from langgraph.types import Command
def hitlrequest_action_count(pending_value: Any) -> int:
"""Bundle size for a LangChain ``HITLRequest`` payload; ``0`` for non-bundle interrupts."""
if not isinstance(pending_value, dict):
return 0
actions = pending_value.get("action_requests")
if isinstance(actions, list):
return len(actions)
return 0
def fan_out_decisions_to_match(resume_value: Any, expected_count: int) -> Any:
"""Legacy fallback: pad a 1-decision resume to N for an ``action_requests=N`` bundle.
Modern frontend submits N decisions per bundle (one per action_request) so
this is a no-op; kept for backwards compatibility with old in-flight
threads or non-bundle clients that send a single decision.
"""
if expected_count <= 1:
return resume_value
if not isinstance(resume_value, dict):
return resume_value
decisions = resume_value.get("decisions")
if not isinstance(decisions, list) or len(decisions) >= expected_count:
return resume_value
if not decisions:
return resume_value
padded = list(decisions) + [decisions[-1]] * (expected_count - len(decisions))
return {**resume_value, "decisions": padded}
def get_first_pending_subagent_interrupt(state: Any) -> tuple[str | None, Any]:
"""First pending ``(interrupt_id, value)``; ``(None, None)`` if no interrupt.
Assumes at most one pending interrupt per snapshot (sequential tool nodes).
Parallel tool nodes would need an id-aware lookup instead of first-wins.
"""
if state is None:
return None, None
for it in getattr(state, "interrupts", None) or ():
value = getattr(it, "value", None)
interrupt_id = getattr(it, "id", None)
if value is not None:
return (
interrupt_id if isinstance(interrupt_id, str) else None,
value,
)
for sub_task in getattr(state, "tasks", None) or ():
for it in getattr(sub_task, "interrupts", None) or ():
value = getattr(it, "value", None)
interrupt_id = getattr(it, "id", None)
if value is not None:
return (
interrupt_id if isinstance(interrupt_id, str) else None,
value,
)
return None, None
def build_resume_command(resume_value: Any, pending_id: str | None) -> Command:
"""``Command(resume={id: value})`` when ``id`` is known, else fall back to scalar."""
if pending_id is None:
return Command(resume=resume_value)
return Command(resume={pending_id: resume_value})

View file

@ -1,183 +0,0 @@
"""Route a flat ``decisions`` list to per-``tool_call_id`` resume payloads.
The frontend submits decisions in the same order the SSE stream emitted
approval cards. When multiple parallel subagents are paused, the backend uses
this module to:
1. Read ``state.interrupts`` from the parent's paused snapshot, extracting
``[(tool_call_id, action_count), ...]`` from each interrupt's value.
The ``tool_call_id`` is stamped on by ``propagation.wrap_with_tool_call_id``
inside ``task_tool``'s catch-and-stamp block when a subagent's
``GraphInterrupt`` bubbles up through ``[a]task``.
2. Slice the flat ``decisions`` list against that ordered pending list to
produce the dict shape expected by ``consume_surfsense_resume``.
3. Re-key those slices by ``Interrupt.id`` (langgraph's primitive) for use as
the parent-level ``Command(resume={interrupt_id: payload})`` input the
only shape langgraph accepts when multiple interrupts are pending.
All helpers are pure: callers own the state and the input decisions; we
return new structures and never mutate.
"""
from __future__ import annotations
import logging
from collections.abc import Iterable
from typing import Any
logger = logging.getLogger(__name__)
def slice_decisions_by_tool_call(
decisions: list[dict[str, Any]],
pending: Iterable[tuple[str, int]],
) -> dict[str, dict[str, Any]]:
"""Slice ``decisions`` into ``{tool_call_id: {"decisions": <slice>}}``.
Args:
decisions: Flat list of decisions in the order the SSE stream rendered
them.
pending: Ordered ``(tool_call_id, action_count)`` pairs in the same
order. The slicer consumes ``decisions`` left-to-right.
Returns:
Per-``tool_call_id`` payload dict ready to be written to
``configurable["surfsense_resume_value"]``.
Raises:
ValueError: When the total expected action count differs from the
number of decisions provided. We fail loud rather than silently
dropping or padding so a frontend/backend contract drift surfaces
immediately.
"""
pending_list = list(pending)
expected = sum(count for _, count in pending_list)
if expected != len(decisions):
raise ValueError(
f"Decision count mismatch: pending tool calls expect "
f"{expected} actions but received {len(decisions)} decisions."
)
routed: dict[str, dict[str, Any]] = {}
cursor = 0
for tool_call_id, action_count in pending_list:
routed[tool_call_id] = {"decisions": decisions[cursor : cursor + action_count]}
cursor += action_count
return routed
def collect_pending_tool_calls(state: Any) -> list[tuple[str, int]]:
"""Extract ``[(tool_call_id, action_count), ...]`` from a paused parent state.
Reads ``state.interrupts`` (the bundle langgraph aggregated from each
paused subagent's propagated interrupt). Each interrupt value carries the
``tool_call_id`` that the parent's ``task`` tool was processing — see
``propagation.wrap_with_tool_call_id`` and ``task_tool``'s
``except GraphInterrupt`` chokepoint.
Order is preserved from ``state.interrupts``, which is the order the SSE
stream emitted approval cards. The frontend submits decisions in that
same order, so the slicer can consume them left-to-right.
Interrupts without a ``tool_call_id`` are skipped they were not
produced by our task-routing layer (e.g. parent-side HITL middleware on
a different tool); ``stream_resume_chat`` is not responsible for routing
those.
Args:
state: A langgraph ``StateSnapshot`` (or any object with an
``interrupts`` attribute).
Returns:
Ordered list of ``(tool_call_id, action_count)``. ``action_count`` is
``len(value["action_requests"])`` for HITL-bundle values, or ``1`` for
scalar-style ``interrupt("...")`` values that were wrapped as
``{"value": ..., "tool_call_id": ...}``.
Raises:
ValueError: When an interrupt value carries a ``tool_call_id`` but
the action count cannot be determined (contract bug every
propagated value should be either a HITL bundle or a wrapped
scalar).
"""
pending: list[tuple[str, int]] = []
for idx, interrupt_obj in enumerate(getattr(state, "interrupts", ()) or ()):
value = getattr(interrupt_obj, "value", None)
if not isinstance(value, dict):
logger.warning(
"[hitl_route] interrupt[%d] skipped: value not a dict (type=%s)",
idx,
type(value).__name__,
)
continue
tool_call_id = value.get("tool_call_id")
if not isinstance(tool_call_id, str):
# Should not happen post-stamping; flag loudly if a regression
# ever lets an unstamped value reach the parent state.
logger.warning(
"[hitl_route] interrupt[%d] skipped: no tool_call_id stamp (keys=%s)",
idx,
sorted(value.keys()),
)
continue
action_requests = value.get("action_requests")
if isinstance(action_requests, list):
pending.append((tool_call_id, len(action_requests)))
continue
if "value" in value:
pending.append((tool_call_id, 1))
continue
raise ValueError(
f"Interrupt for tool_call_id={tool_call_id!r} has no "
"``action_requests`` list and is not a wrapped scalar value; "
"cannot determine action count for resume routing."
)
return pending
def build_lg_resume_map(
state: Any, by_tool_call_id: dict[str, dict[str, Any]]
) -> dict[str, dict[str, Any]]:
"""Map ``Interrupt.id → resume_payload`` for langgraph's multi-interrupt resume.
``stream_resume_chat`` builds ``by_tool_call_id`` via
:func:`slice_decisions_by_tool_call`. Langgraph's ``Command(resume=...)``
requires ``Interrupt.id`` keys (not our ``tool_call_id`` stamps) when the
parent state has multiple pending interrupts. This pure helper re-keys the
slice without mutating it, and skips entries that can't be paired (no
stamp, no slice) so contract drift surfaces as a count mismatch at the
call site instead of a silent mis-route.
The two key spaces serve two different consumers:
- ``surfsense_resume_value`` (keyed by ``tool_call_id``): read by the
subagent bridge inside ``task_tool``.
- ``Command(resume=...)`` (keyed by ``Interrupt.id``): read by langgraph's
pregel to wake each pending interrupt site.
Args:
state: A langgraph ``StateSnapshot`` (or any object with an
``interrupts`` iterable).
by_tool_call_id: Output of :func:`slice_decisions_by_tool_call`.
Returns:
Dict ready to be passed as ``Command(resume=<this>)``.
"""
out: dict[str, dict[str, Any]] = {}
for interrupt_obj in getattr(state, "interrupts", ()) or ():
value = getattr(interrupt_obj, "value", None)
if not isinstance(value, dict):
continue
tool_call_id = value.get("tool_call_id")
if not isinstance(tool_call_id, str):
continue
interrupt_id = getattr(interrupt_obj, "id", None)
if not isinstance(interrupt_id, str):
continue
payload = by_tool_call_id.get(tool_call_id)
if payload is None:
continue
out[interrupt_id] = payload
return out

View file

@ -1,84 +0,0 @@
"""Per-search-space spawn-paused kill switch for the ``task`` boundary.
When operators see a runaway loop, a vendor outage, or a billing event
that requires immediate cessation of subagent traffic for a specific
workspace, they flip a Redis flag and the ``task`` tool short-circuits
without touching downstream services. The flag is **per-search-space**
so one tenant's incident never silences the rest of the deployment.
Flag key: ``surfsense:spawn_paused:{search_space_id}``
Flag value: any string-truthy value (we read presence, not contents).
TTL: set by whoever toggles the flag this module never expires
keys on its own, since "the flag is on" is itself the signal
that a human (or alert) needs to investigate.
The check is best-effort: Redis errors are logged but do not block the
``task`` invocation. Failing closed (block-on-redis-error) would let a
single Redis blip take the whole orchestrator offline; failing open
preserves availability and the alarm bells (rate-limits, cost spikes)
will surface the underlying outage.
"""
from __future__ import annotations
import contextlib
import logging
import os
from app.config import config
logger = logging.getLogger(__name__)
# Operators can disable the check entirely (e.g. local dev without Redis)
# by setting ``SURFSENSE_TASK_SPAWN_PAUSED_DISABLED=1``. Default is
# enabled so production never relies on flipping an opt-out flag.
_DISABLED = os.environ.get(
"SURFSENSE_TASK_SPAWN_PAUSED_DISABLED", ""
).strip().lower() in {
"1",
"true",
"yes",
"on",
}
def _flag_key(search_space_id: int) -> str:
return f"surfsense:spawn_paused:{search_space_id}"
async def is_spawn_paused(search_space_id: int | None) -> bool:
"""Return ``True`` iff the workspace's spawn-paused flag is set in Redis.
A ``None`` search-space (e.g. dev paths that did not plumb the id
through yet) bypasses the check. So does a Redis outage see module
docstring for the fail-open rationale.
"""
if _DISABLED or search_space_id is None:
return False
try:
# Local import keeps the cold-path import cheap and lets routes
# that never call ``task`` skip the redis dependency entirely.
import redis.asyncio as aioredis # type: ignore[import-not-found]
client = aioredis.from_url(config.REDIS_APP_URL, decode_responses=True)
try:
raw = await client.get(_flag_key(search_space_id))
finally:
# ``aclose()`` is the async-safe variant on redis-py >=5; fall back
# to ``close()`` for older clients pinned in tests.
close = getattr(client, "aclose", None) or getattr(client, "close", None)
if callable(close):
with contextlib.suppress(Exception):
await close() # type: ignore[misc]
return bool(raw)
except Exception:
logger.warning(
"spawn_paused check failed for search_space_id=%s; failing open.",
search_space_id,
exc_info=True,
)
return False
__all__ = ["is_spawn_paused"]

View file

@ -1,15 +0,0 @@
"""Schema-level description for the ``task`` tool.
Loaded from ``prompts/tools/task/description.md`` so the tool-schema text
and the ``<tools>`` block render from the same source.
"""
from __future__ import annotations
from app.agents.multi_agent_chat.main_agent.system_prompt.builder.load_md import (
read_prompt_md,
)
TASK_TOOL_DESCRIPTION: str = read_prompt_md("tools/task/description.md")
__all__ = ["TASK_TOOL_DESCRIPTION"]

View file

@ -1,50 +0,0 @@
"""Spill + clear-tool-uses passes to keep payloads under budget."""
from __future__ import annotations
from collections.abc import Sequence
from typing import Any
from langchain_core.tools import BaseTool
from app.agents.multi_agent_chat.main_agent.context_prune.prune_tool_names import (
safe_exclude_tools,
)
from app.agents.shared.feature_flags import AgentFeatureFlags
from app.agents.shared.middleware import (
ClearToolUsesEdit,
SpillingContextEditingMiddleware,
SpillToBackendEdit,
)
from ..shared.flags import enabled
def build_context_editing_mw(
*,
flags: AgentFeatureFlags,
max_input_tokens: int | None,
tools: Sequence[BaseTool],
backend_resolver: Any,
) -> SpillingContextEditingMiddleware | None:
if not enabled(flags, "enable_context_editing") or not max_input_tokens:
return None
spill_edit = SpillToBackendEdit(
trigger=int(max_input_tokens * 0.55),
clear_at_least=int(max_input_tokens * 0.15),
keep=5,
exclude_tools=safe_exclude_tools(tools),
clear_tool_inputs=True,
)
clear_edit = ClearToolUsesEdit(
trigger=int(max_input_tokens * 0.55),
clear_at_least=int(max_input_tokens * 0.15),
keep=5,
exclude_tools=safe_exclude_tools(tools),
clear_tool_inputs=True,
placeholder="[cleared - older tool output trimmed for context]",
)
return SpillingContextEditingMiddleware(
edits=[spill_edit, clear_edit],
backend_resolver=backend_resolver,
)

View file

@ -1,13 +0,0 @@
"""Drop duplicate HITL tool calls before execution."""
from __future__ import annotations
from collections.abc import Sequence
from langchain_core.tools import BaseTool
from app.agents.shared.middleware import DedupHITLToolCallsMiddleware
def build_dedup_hitl_mw(tools: Sequence[BaseTool]) -> DedupHITLToolCallsMiddleware:
return DedupHITLToolCallsMiddleware(agent_tools=list(tools))

View file

@ -1,14 +0,0 @@
"""Stop N identical tool calls in a row via interrupt."""
from __future__ import annotations
from app.agents.shared.feature_flags import AgentFeatureFlags
from app.agents.shared.middleware import DoomLoopMiddleware
from ..shared.flags import enabled
def build_doom_loop_mw(flags: AgentFeatureFlags) -> DoomLoopMiddleware | None:
return (
DoomLoopMiddleware(threshold=3) if enabled(flags, "enable_doom_loop") else None
)

View file

@ -1,23 +0,0 @@
"""Commit staged cloud filesystem mutations to Postgres at end of turn."""
from __future__ import annotations
from app.agents.shared.filesystem_selection import FilesystemMode
from app.agents.shared.middleware import KnowledgeBasePersistenceMiddleware
def build_kb_persistence_mw(
*,
filesystem_mode: FilesystemMode,
search_space_id: int,
user_id: str | None,
thread_id: int | None,
) -> KnowledgeBasePersistenceMiddleware | None:
if filesystem_mode != FilesystemMode.CLOUD:
return None
return KnowledgeBasePersistenceMiddleware(
search_space_id=search_space_id,
created_by_id=user_id,
filesystem_mode=filesystem_mode,
thread_id=thread_id,
)

View file

@ -1,30 +0,0 @@
"""KB priority planner: <priority_documents> injection."""
from __future__ import annotations
from langchain_core.language_models import BaseChatModel
from app.agents.shared.filesystem_selection import FilesystemMode
from app.agents.shared.middleware import KnowledgePriorityMiddleware
from app.services.llm_service import get_planner_llm
def build_knowledge_priority_mw(
*,
llm: BaseChatModel,
search_space_id: int,
filesystem_mode: FilesystemMode,
available_connectors: list[str] | None,
available_document_types: list[str] | None,
mentioned_document_ids: list[int] | None,
) -> KnowledgePriorityMiddleware:
return KnowledgePriorityMiddleware(
llm=llm,
planner_llm=get_planner_llm(),
search_space_id=search_space_id,
filesystem_mode=filesystem_mode,
available_connectors=available_connectors,
available_document_types=available_document_types,
mentioned_document_ids=mentioned_document_ids,
inject_system_message=False,
)

View file

@ -1,24 +0,0 @@
"""<workspace_tree> injection (cloud only)."""
from __future__ import annotations
from langchain_core.language_models import BaseChatModel
from app.agents.shared.filesystem_selection import FilesystemMode
from app.agents.shared.middleware import KnowledgeTreeMiddleware
def build_knowledge_tree_mw(
*,
filesystem_mode: FilesystemMode,
search_space_id: int,
llm: BaseChatModel,
) -> KnowledgeTreeMiddleware | None:
if filesystem_mode != FilesystemMode.CLOUD:
return None
return KnowledgeTreeMiddleware(
search_space_id=search_space_id,
filesystem_mode=filesystem_mode,
llm=llm,
inject_system_message=False,
)

View file

@ -1,12 +0,0 @@
"""Provider-compat: append a `_noop` tool when tools=[] but history has tool calls."""
from __future__ import annotations
from app.agents.shared.feature_flags import AgentFeatureFlags
from app.agents.shared.middleware import NoopInjectionMiddleware
from ..shared.flags import enabled
def build_noop_injection_mw(flags: AgentFeatureFlags) -> NoopInjectionMiddleware | None:
return NoopInjectionMiddleware() if enabled(flags, "enable_compaction_v2") else None

View file

@ -1,12 +0,0 @@
"""OTel spans on model and tool calls."""
from __future__ import annotations
from app.agents.shared.feature_flags import AgentFeatureFlags
from app.agents.shared.middleware import OtelSpanMiddleware
from ..shared.flags import enabled
def build_otel_mw(flags: AgentFeatureFlags) -> OtelSpanMiddleware | None:
return OtelSpanMiddleware() if enabled(flags, "enable_otel") else None

View file

@ -1,49 +0,0 @@
"""Tail-of-stack plugin slot driven by env allowlist."""
from __future__ import annotations
import logging
from typing import Any
from langchain_core.language_models import BaseChatModel
from app.agents.shared.feature_flags import AgentFeatureFlags
from app.agents.shared.plugin_loader import (
PluginContext,
load_allowed_plugin_names_from_env,
load_plugin_middlewares,
)
from app.db import ChatVisibility
from ..shared.flags import enabled
def build_plugin_middlewares(
*,
flags: AgentFeatureFlags,
search_space_id: int,
user_id: str | None,
visibility: ChatVisibility,
llm: BaseChatModel,
) -> list[Any]:
if not enabled(flags, "enable_plugin_loader"):
return []
try:
allowed_names = load_allowed_plugin_names_from_env()
if not allowed_names:
return []
return load_plugin_middlewares(
PluginContext.build(
search_space_id=search_space_id,
user_id=user_id,
thread_visibility=visibility,
llm=llm,
),
allowed_plugin_names=allowed_names,
)
except Exception: # pragma: no cover - defensive
logging.warning(
"Plugin loader failed; continuing without plugins.",
exc_info=True,
)
return []

View file

@ -1,50 +0,0 @@
"""Repair miscased / unknown tool names to the registered set or invalid_tool."""
from __future__ import annotations
from collections.abc import Sequence
from langchain_core.tools import BaseTool
from app.agents.shared.feature_flags import AgentFeatureFlags
from app.agents.shared.middleware import ToolCallNameRepairMiddleware
from ..shared.flags import enabled
# deepagents-built-in tool names the repair pass treats as known.
_DEEPAGENT_BUILTIN_TOOL_NAMES: frozenset[str] = frozenset(
{
"write_todos",
"ls",
"read_file",
"write_file",
"edit_file",
"glob",
"grep",
"execute",
"task",
"mkdir",
"cd",
"pwd",
"move_file",
"rm",
"rmdir",
"list_tree",
"execute_code",
}
)
def build_repair_mw(
*,
flags: AgentFeatureFlags,
tools: Sequence[BaseTool],
) -> ToolCallNameRepairMiddleware | None:
if not enabled(flags, "enable_tool_call_repair"):
return None
registered_names: set[str] = {t.name for t in tools}
registered_names |= _DEEPAGENT_BUILTIN_TOOL_NAMES
return ToolCallNameRepairMiddleware(
registered_tool_names=registered_names,
fuzzy_match_threshold=None,
)

View file

@ -1,39 +0,0 @@
"""Skill discovery + injection."""
from __future__ import annotations
import logging
from deepagents.middleware.skills import SkillsMiddleware
from app.agents.shared.feature_flags import AgentFeatureFlags
from app.agents.shared.filesystem_selection import FilesystemMode
from app.agents.shared.middleware import (
build_skills_backend_factory,
default_skills_sources,
)
from ..shared.flags import enabled
def build_skills_mw(
*,
flags: AgentFeatureFlags,
filesystem_mode: FilesystemMode,
search_space_id: int,
) -> SkillsMiddleware | None:
if not enabled(flags, "enable_skills"):
return None
try:
skills_factory = build_skills_backend_factory(
search_space_id=search_space_id
if filesystem_mode == FilesystemMode.CLOUD
else None,
)
return SkillsMiddleware(
backend=skills_factory,
sources=default_skills_sources(),
)
except Exception as exc: # pragma: no cover - defensive
logging.warning("SkillsMiddleware init failed; skipping: %s", exc)
return None

View file

@ -1,210 +0,0 @@
"""Main-agent middleware list assembly: one line per slot.
The main agent is a pure router filesystem reads/writes are owned by the
``knowledge_base`` subagent and delegated via the ``task`` tool. The stack
here only renders KB context (workspace tree + priority docs), projects it
into system messages, and commits any subagent-side staged writes at end of
turn (cloud mode).
"""
from __future__ import annotations
import logging
from collections.abc import Sequence
from typing import Any
from deepagents import SubAgent
from deepagents.backends import StateBackend
from langchain.agents import create_agent
from langchain_core.language_models import BaseChatModel
from langchain_core.tools import BaseTool
from langgraph.types import Checkpointer
from app.agents.multi_agent_chat.subagents import (
build_subagents,
get_subagents_to_exclude,
)
from app.agents.multi_agent_chat.subagents.builtins.knowledge_base.agent import (
READONLY_NAME as KB_READONLY_NAME,
build_readonly_subagent as build_kb_readonly_subagent,
)
from app.agents.multi_agent_chat.subagents.builtins.knowledge_base.ask_knowledge_base_tool import (
build_ask_knowledge_base_tool,
)
from app.agents.shared.feature_flags import AgentFeatureFlags
from app.agents.shared.filesystem_selection import FilesystemMode
from app.db import ChatVisibility
from .main_agent.action_log import build_action_log_mw
from .main_agent.anonymous_doc import build_anonymous_doc_mw
from .main_agent.busy_mutex import build_busy_mutex_mw
from .main_agent.checkpointed_subagent_middleware import (
SurfSenseCheckpointedSubAgentMiddleware,
)
from .main_agent.checkpointed_subagent_middleware.task_description import (
TASK_TOOL_DESCRIPTION,
)
from .main_agent.context_editing import build_context_editing_mw
from .main_agent.dedup_hitl import build_dedup_hitl_mw
from .main_agent.doom_loop import build_doom_loop_mw
from .main_agent.kb_persistence import build_kb_persistence_mw
from .main_agent.knowledge_priority import build_knowledge_priority_mw
from .main_agent.knowledge_tree import build_knowledge_tree_mw
from .main_agent.noop_injection import build_noop_injection_mw
from .main_agent.otel import build_otel_mw
from .main_agent.plugins import build_plugin_middlewares
from .main_agent.repair import build_repair_mw
from .main_agent.skills import build_skills_mw
from .shared.anthropic_cache import build_anthropic_cache_mw
from .shared.compaction import build_compaction_mw
from .shared.kb_context_projection import build_kb_context_projection_mw
from .shared.memory import build_memory_mw
from .shared.patch_tool_calls import build_patch_tool_calls_mw
from .shared.permissions import build_permission_mw
from .shared.resilience import build_resilience_middlewares
from .shared.todos import build_todos_mw
from .subagent.middleware_stack import build_subagent_middleware_stack
def build_main_agent_deepagent_middleware(
*,
llm: BaseChatModel,
tools: Sequence[BaseTool],
backend_resolver: Any,
filesystem_mode: FilesystemMode,
search_space_id: int,
user_id: str | None,
thread_id: int | None,
visibility: ChatVisibility,
anon_session_id: str | None,
available_connectors: list[str] | None,
available_document_types: list[str] | None,
mentioned_document_ids: list[int] | None,
max_input_tokens: int | None,
flags: AgentFeatureFlags,
subagent_dependencies: dict[str, Any],
checkpointer: Checkpointer,
mcp_tools_by_agent: dict[str, list[BaseTool]] | None = None,
disabled_tools: list[str] | None = None,
) -> list[Any]:
"""Ordered middleware for ``create_agent`` (None entries already stripped)."""
resilience = build_resilience_middlewares(flags)
memory_mw = build_memory_mw(
user_id=user_id,
search_space_id=search_space_id,
visibility=visibility,
)
subagent_dependencies = {
**subagent_dependencies,
"backend_resolver": backend_resolver,
"filesystem_mode": filesystem_mode,
"flags": flags,
}
shared_subagent_middleware = build_subagent_middleware_stack(
resilience=resilience,
flags=flags,
)
kb_readonly = build_kb_readonly_subagent(
dependencies=subagent_dependencies,
model=llm,
middleware_stack=shared_subagent_middleware,
)
kb_readonly_spec = kb_readonly.spec
kb_readonly_runnable = create_agent(
llm,
system_prompt=kb_readonly_spec["system_prompt"],
tools=kb_readonly_spec["tools"],
middleware=kb_readonly_spec["middleware"],
name=KB_READONLY_NAME,
checkpointer=checkpointer,
)
ask_kb_tool = build_ask_knowledge_base_tool(kb_readonly_runnable)
subagents: list[SubAgent] = build_subagents(
dependencies=subagent_dependencies,
model=llm,
middleware_stack=shared_subagent_middleware,
mcp_tools_by_agent=mcp_tools_by_agent or {},
exclude=get_subagents_to_exclude(available_connectors),
disabled_tools=disabled_tools,
ask_kb_tool=ask_kb_tool,
)
logging.debug("Subagents registry: %s", [s["name"] for s in subagents])
stack: list[Any] = [
build_busy_mutex_mw(flags),
build_otel_mw(flags),
build_todos_mw(system_prompt=""),
memory_mw,
build_anonymous_doc_mw(
filesystem_mode=filesystem_mode, anon_session_id=anon_session_id
),
build_knowledge_tree_mw(
filesystem_mode=filesystem_mode,
search_space_id=search_space_id,
llm=llm,
),
build_knowledge_priority_mw(
llm=llm,
search_space_id=search_space_id,
filesystem_mode=filesystem_mode,
available_connectors=available_connectors,
available_document_types=available_document_types,
mentioned_document_ids=mentioned_document_ids,
),
build_kb_context_projection_mw(),
build_kb_persistence_mw(
filesystem_mode=filesystem_mode,
search_space_id=search_space_id,
user_id=user_id,
thread_id=thread_id,
),
build_skills_mw(
flags=flags,
filesystem_mode=filesystem_mode,
search_space_id=search_space_id,
),
SurfSenseCheckpointedSubAgentMiddleware(
checkpointer=checkpointer,
backend=StateBackend,
subagents=subagents,
system_prompt=None,
task_description=TASK_TOOL_DESCRIPTION,
search_space_id=search_space_id,
),
resilience.model_call_limit,
resilience.tool_call_limit,
build_context_editing_mw(
flags=flags,
max_input_tokens=max_input_tokens,
tools=tools,
backend_resolver=backend_resolver,
),
build_compaction_mw(llm),
build_noop_injection_mw(flags),
resilience.retry,
resilience.fallback,
build_repair_mw(flags=flags, tools=tools),
build_permission_mw(flags=flags),
build_doom_loop_mw(flags),
build_action_log_mw(
flags=flags,
thread_id=thread_id,
search_space_id=search_space_id,
user_id=user_id,
),
build_patch_tool_calls_mw(),
build_dedup_hitl_mw(tools),
*build_plugin_middlewares(
flags=flags,
search_space_id=search_space_id,
user_id=user_id,
visibility=visibility,
llm=llm,
),
build_anthropic_cache_mw(),
]
return [m for m in stack if m is not None]