mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-08 20:25:19 +02:00
refactor(agents): colocate main-agent middleware under main_agent/ slice
Vertical-slice colocation: all main-agent code should live under main_agent/ instead of being split across a parallel middleware/main_agent tree. Move multi_agent_chat/middleware/main_agent/ -> main_agent/middleware/ and its assembler middleware/stack.py -> main_agent/middleware/stack.py, so the main-agent slice is self-contained (graph, runtime, system_prompt, tools, middleware). Genuinely cross-slice middleware (middleware/shared/, middleware/subagent/) stays under multi_agent_chat/middleware/ for a later slice; the moved builders now reference it via absolute imports. Pure move + import rewrite (git-tracked renames). Verified: full unit suite green (2430 passed, 1 skipped), including test_import_all and the checkpointed-subagent middleware suite.
This commit is contained in:
parent
1acde6a470
commit
9c845d562e
42 changed files with 60 additions and 58 deletions
|
|
@ -1,36 +0,0 @@
|
|||
"""Audit row per tool call (reversibility metadata)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from app.agents.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.shared.middleware import ActionLogMiddleware
|
||||
from app.agents.shared.tools.registry import BUILTIN_TOOLS
|
||||
|
||||
from ..shared.flags import enabled
|
||||
|
||||
|
||||
def build_action_log_mw(
    *,
    flags: AgentFeatureFlags,
    thread_id: int | None,
    search_space_id: int,
    user_id: str | None,
) -> ActionLogMiddleware | None:
    """Build the per-tool-call audit middleware, or ``None`` when it should not run.

    Args:
        flags: Feature flags; ``enable_action_log`` must be enabled.
        thread_id: Thread the audit rows belong to. ``None`` disables the
            middleware.
        search_space_id: Forwarded to ``ActionLogMiddleware``.
        user_id: Forwarded to ``ActionLogMiddleware``.

    Returns:
        A configured ``ActionLogMiddleware``, or ``None`` when the flag is off,
        ``thread_id`` is ``None``, or construction raised.
    """
    if not enabled(flags, "enable_action_log") or thread_id is None:
        return None
    try:
        tool_defs_by_name = {td.name: td for td in BUILTIN_TOOLS}
        return ActionLogMiddleware(
            thread_id=thread_id,
            search_space_id=search_space_id,
            user_id=user_id,
            tool_definitions=tool_defs_by_name,
        )
    except Exception:  # pragma: no cover - defensive
        # Best-effort: a broken audit middleware must not take the agent down.
        # Log through this module's logger rather than the root-level
        # ``logging.warning`` helper, which implicitly installs a default root
        # handler when none is configured and ignores per-module log settings.
        logging.getLogger(__name__).warning(
            "ActionLogMiddleware init failed; running without it.",
            exc_info=True,
        )
        return None
|
||||
|
|
@ -1,16 +0,0 @@
|
|||
"""Anonymous document hydration from Redis (cloud only)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.shared.middleware import AnonymousDocumentMiddleware
|
||||
|
||||
|
||||
def build_anonymous_doc_mw(
    *,
    filesystem_mode: FilesystemMode,
    anon_session_id: str | None,
) -> AnonymousDocumentMiddleware | None:
    """Return the anonymous-document middleware in cloud mode, else ``None``.

    Args:
        filesystem_mode: Active filesystem mode; only ``FilesystemMode.CLOUD``
            yields a middleware.
        anon_session_id: Forwarded to ``AnonymousDocumentMiddleware``.
    """
    return (
        None
        if filesystem_mode != FilesystemMode.CLOUD
        else AnonymousDocumentMiddleware(anon_session_id=anon_session_id)
    )
|
||||
|
|
@ -1,12 +0,0 @@
|
|||
"""Per-thread cooperative lock around the whole turn."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.shared.middleware import BusyMutexMiddleware
|
||||
|
||||
from ..shared.flags import enabled
|
||||
|
||||
|
||||
def build_busy_mutex_mw(flags: AgentFeatureFlags) -> BusyMutexMiddleware | None:
    """Return a ``BusyMutexMiddleware`` when ``enable_busy_mutex`` is on, else ``None``."""
    if enabled(flags, "enable_busy_mutex"):
        return BusyMutexMiddleware()
    return None
|
||||
|
|
@ -1,32 +0,0 @@
|
|||
"""SubAgent ``task`` tool wiring required for HITL inside subagents.
|
||||
|
||||
Replaces upstream ``SubAgentMiddleware`` to:
|
||||
|
||||
- share the parent's checkpointer with each subagent,
|
||||
- forward ``runtime.config`` (thread_id, recursion_limit, …) into nested invokes,
|
||||
- isolate each parallel ``task`` call in its own checkpoint slot via
|
||||
per-call ``thread_id`` namespacing,
|
||||
- bridge ``Command(resume=...)`` from the parent into the subagent via the
|
||||
``config["configurable"]["surfsense_resume_value"]`` side-channel, keyed by
|
||||
``tool_call_id`` so parallel siblings never race on a shared scalar,
|
||||
- target the resume at the captured interrupt id so a follow-up
|
||||
``HumanInTheLoopMiddleware.after_model`` does not consume the same payload,
|
||||
- stamp each subagent's pending interrupt with the parent's ``tool_call_id``
|
||||
so ``stream_resume_chat`` can route a flat ``decisions`` list back to the
|
||||
right paused subagent.
|
||||
|
||||
Module layout
|
||||
-------------
|
||||
|
||||
- ``constants`` — shared keys / limits.
|
||||
- ``config`` — RunnableConfig + side-channel resume read + per-call ``thread_id``.
|
||||
- ``resume`` — pending-interrupt detection, fan-out, ``Command(resume=...)`` builder.
|
||||
- ``propagation`` — ``wrap_with_tool_call_id`` helper for stamping interrupt values.
|
||||
- ``resume_routing``— slice a flat decisions list to per-``tool_call_id`` payloads.
|
||||
- ``task_tool`` — the ``task`` tool factory (sync + async), and the catch-and-stamp chokepoint.
|
||||
- ``middleware`` — :class:`SurfSenseCheckpointedSubAgentMiddleware` itself.
|
||||
"""
|
||||
|
||||
from .middleware import SurfSenseCheckpointedSubAgentMiddleware
|
||||
|
||||
__all__ = ["SurfSenseCheckpointedSubAgentMiddleware"]
|
||||
|
|
@ -1,125 +0,0 @@
|
|||
"""RunnableConfig wiring for nested subagent invocations.
|
||||
|
||||
Forwards the parent's ``runtime.config`` (thread_id, …) into the subagent and
|
||||
exposes the side-channel ``stream_resume_chat`` uses to ferry resume payloads.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain.tools import ToolRuntime
|
||||
|
||||
from .constants import DEFAULT_SUBAGENT_RECURSION_LIMIT
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# langgraph stores the parent task's scratchpad under this configurable key;
|
||||
# subagents inherit the chain via ``parent_scratchpad`` fallback.
|
||||
_LANGGRAPH_SCRATCHPAD_KEY = "__pregel_scratchpad"
|
||||
|
||||
|
||||
def subagent_invoke_config(runtime: ToolRuntime) -> dict[str, Any]:
    """Build the RunnableConfig used for a nested subagent invoke.

    Starts from a shallow copy of ``runtime.config``, then:

    - raises ``recursion_limit`` to at least
      ``DEFAULT_SUBAGENT_RECURSION_LIMIT`` (missing or unparsable limits count
      as ``0``);
    - rewrites ``configurable["thread_id"]`` to
      ``{parent_thread}::task:{tool_call_id}`` (or just ``task:{tool_call_id}``
      when the parent has no thread id) so each ``task`` call gets its own
      checkpoint slot, stable across the resume cycle because
      ``tool_call_id`` is stable per LLM-emitted call.

    ``thread_id`` is used for namespacing instead of ``checkpoint_ns`` because
    langgraph's ``aget_state`` treats a non-empty ``checkpoint_ns`` as a
    subgraph path and raises ``ValueError("Subgraph X not found")``.
    """
    invoke_cfg: dict[str, Any] = dict(runtime.config) if runtime.config else {}

    raw_limit = invoke_cfg.get("recursion_limit")
    try:
        effective_limit = 0 if raw_limit is None else int(raw_limit)
    except (TypeError, ValueError):
        effective_limit = 0
    if effective_limit < DEFAULT_SUBAGENT_RECURSION_LIMIT:
        invoke_cfg["recursion_limit"] = DEFAULT_SUBAGENT_RECURSION_LIMIT

    # Copy ``configurable`` so the parent's mapping is never mutated.
    scoped: dict[str, Any] = dict(invoke_cfg.get("configurable") or {})
    parent_thread = scoped.get("thread_id")
    call_slot = f"task:{runtime.tool_call_id}"
    if parent_thread:
        scoped["thread_id"] = f"{parent_thread}::{call_slot}"
    else:
        scoped["thread_id"] = call_slot
    invoke_cfg["configurable"] = scoped
    return invoke_cfg
|
||||
|
||||
|
||||
def consume_surfsense_resume(runtime: ToolRuntime) -> Any:
    """Remove and return the resume payload queued for this ``tool_call_id``.

    ``configurable["surfsense_resume_value"]`` is a dict keyed by
    ``tool_call_id``, so parallel sibling subagents each read only their own
    decision. Once the dict is empty the key itself is removed from
    ``configurable``. Returns ``None`` when nothing is queued or the config
    does not have the expected dict shape.
    """
    cfg = runtime.config or {}
    configurable = cfg.get("configurable") if isinstance(cfg, dict) else None
    queued = (
        configurable.get("surfsense_resume_value")
        if isinstance(configurable, dict)
        else None
    )
    if not isinstance(queued, dict):
        return None

    payload = queued.pop(runtime.tool_call_id, None)
    if not queued:
        # Last entry consumed: drop the now-empty side-channel entirely.
        configurable.pop("surfsense_resume_value", None)
    return payload
|
||||
|
||||
|
||||
def has_surfsense_resume(runtime: ToolRuntime) -> bool:
    """Report whether a resume payload is queued for this ``tool_call_id``.

    Read-only counterpart of ``consume_surfsense_resume``: nothing is popped.
    Returns ``False`` when the config, ``configurable`` or the
    ``surfsense_resume_value`` entry is not a dict.
    """
    cfg = runtime.config or {}
    if not isinstance(cfg, dict):
        return False
    configurable = cfg.get("configurable")
    if not isinstance(configurable, dict):
        return False
    queued = configurable.get("surfsense_resume_value")
    return isinstance(queued, dict) and runtime.tool_call_id in queued
|
||||
|
||||
|
||||
def drain_parent_null_resume(runtime: ToolRuntime) -> None:
    """Consume the parent's lingering ``NULL_TASK_ID/RESUME`` write before delegating.

    Background: ``stream_resume_chat`` wakes the main agent with
    ``Command(resume={tool_call_id: {"decisions": [...]}})`` and langgraph
    keeps that payload as the parent task's ``null_resume`` pending write. The
    ``task`` tool forwards this turn's slice to the subagent through its own
    ``Command(resume=...)``. If the subagent then raises a *new*
    ``interrupt()`` (e.g. a follow-up tool call after a mixed approve/reject),
    the lookup walks ``subagent_scratchpad → parent_scratchpad.get_null_resume``
    and would pick up the parent's still-live decisions, which mismatch the
    number of hanging tool calls and crash ``HumanInTheLoopMiddleware``.

    Draining the write here closes that leak so subagent interrupts pause
    cleanly and surface as a fresh approval card. The function is a no-op when
    the config, the scratchpad, or a callable ``get_null_resume`` is missing.
    """
    cfg = runtime.config or {}
    if not isinstance(cfg, dict):
        return
    configurable = cfg.get("configurable")
    if not isinstance(configurable, dict):
        return

    scratchpad = configurable.get(_LANGGRAPH_SCRATCHPAD_KEY)
    drain = (
        None if scratchpad is None else getattr(scratchpad, "get_null_resume", None)
    )
    if not callable(drain):
        return

    try:
        drain(True)
    except Exception:
        # Defensive: a change in langgraph's internal scratchpad shape must not
        # break the resume path. Worst case the original ValueError still
        # surfaces, which is the behavior that existed before this drain.
        logger.debug(
            "drain_parent_null_resume: scratchpad.get_null_resume raised",
            exc_info=True,
        )
|
||||
|
|
@ -1,89 +0,0 @@
|
|||
"""Constants shared by the checkpointed subagent middleware."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
# Mirror of deepagents.middleware.subagents._EXCLUDED_STATE_KEYS.
EXCLUDED_STATE_KEYS = frozenset(
    (
        "messages",
        "todos",
        "structured_response",
        "skills_metadata",
        "memory_contents",
    )
)

# Match the parent graph's budget; the LangGraph default of 25 trips on
# multi-step subagent runs.
DEFAULT_SUBAGENT_RECURSION_LIMIT = 10_000
|
||||
|
||||
|
||||
def _read_timeout_env(name: str, default: float) -> float:
    """Parse a timeout (seconds) from environment variable ``name``.

    Returns ``default`` when the variable is unset, empty, not parseable as a
    float, negative, or NaN. An explicit ``0`` is returned as-is so the
    documented "``0`` disables the timeout" opt-out is actually reachable from
    the environment (a ``> 0`` check would silently swap it for ``default``).

    Kept as a free function so the module-level constants stay constants
    after import; tests can monkeypatch this and re-evaluate via
    ``importlib.reload`` if they need a different value mid-process.

    Args:
        name: Environment variable to read.
        default: Value used when the variable is missing or invalid.

    Returns:
        The parsed non-negative float, or ``default``.
    """
    raw = os.environ.get(name)
    if not raw:
        return default
    try:
        value = float(raw)
    except (TypeError, ValueError):
        return default
    # ``nan >= 0`` is False, so NaN also falls back to the default.
    return value if value >= 0 else default


# Wall-clock budget for a single ``task(subagent, ...)`` invocation.
# Subagents that run hot (image generation with slow vendors, KB writes
# behind a sluggish embedder) can otherwise wedge the orchestrator until
# the next checkpoint heartbeat. ``0`` disables the timeout entirely.
# NOTE(review): confirm the consumer of this constant treats ``0`` as
# "no timeout" rather than passing it straight to a wait primitive.
DEFAULT_SUBAGENT_INVOKE_TIMEOUT_SECONDS: float = _read_timeout_env(
    "SURFSENSE_SUBAGENT_INVOKE_TIMEOUT_SECONDS",
    default=300.0,
)
|
||||
|
||||
|
||||
def _read_int_env(name: str, default: int, *, allow_zero: bool = False) -> int:
    """Parse an integer setting from environment variable ``name``.

    Args:
        name: Environment variable to read.
        default: Value used when the variable is unset, empty, not an integer,
            or below the accepted minimum.
        allow_zero: When ``True`` an explicit ``0`` is accepted (for settings
            whose documentation defines ``0`` as "disabled"). The default
            ``False`` keeps the strict "must be positive" behavior.

    Returns:
        The parsed integer, or ``default``.
    """
    raw = os.environ.get(name)
    if not raw:
        return default
    try:
        value = int(raw)
    except (TypeError, ValueError):
        return default
    minimum = 0 if allow_zero else 1
    return value if value >= minimum else default


# Maximum number of children that ``task(..., tasks=[...])`` runs in
# parallel via ``asyncio.gather`` + ``Semaphore``. Bounded so a runaway
# fanout cannot starve unrelated subagents (each child still owns an
# LLM call + DB session). Set ``SURFSENSE_TASK_BATCH_CONCURRENCY=1`` to
# effectively serialise batches without changing the schema.
DEFAULT_SUBAGENT_BATCH_CONCURRENCY: int = _read_int_env(
    "SURFSENSE_TASK_BATCH_CONCURRENCY",
    default=3,
)

# Max number of children in a single batched ``task`` call. Hard upper
# bound is a safety net for prompt-injection / runaway loops; the orchestrator
# rarely needs more than a handful of concurrent specialists.
MAX_SUBAGENT_BATCH_SIZE: int = _read_int_env(
    "SURFSENSE_TASK_BATCH_MAX_SIZE",
    default=8,
)


# Soft threshold for per-turn cumulative ``task(...)`` invocations across
# **all** subagents. Once the sum of ``state['billable_calls']`` values
# crosses this number, the runtime appends a one-shot warning ToolMessage
# instructing the orchestrator to wrap up the turn. Tunable so heavy-research
# turns (which legitimately need 15+ specialist calls) don't trip the alarm
# in production. Set to ``0`` to disable the warning entirely.
# ``allow_zero=True`` makes that documented ``0`` opt-out reachable; without
# it a ``0`` in the environment would be replaced by the default.
# NOTE(review): confirm the consumer treats ``0`` as "warning disabled".
DEFAULT_SUBAGENT_BILLABLE_THRESHOLD: int = _read_int_env(
    "SURFSENSE_SUBAGENT_BILLABLE_THRESHOLD",
    default=15,
    allow_zero=True,
)
|
||||
|
|
@ -1,152 +0,0 @@
|
|||
"""SubAgent middleware that compiles each subagent against the parent checkpointer."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import Any, cast
|
||||
|
||||
from deepagents.backends.protocol import BackendFactory, BackendProtocol
|
||||
from deepagents.middleware.subagents import (
|
||||
TASK_SYSTEM_PROMPT,
|
||||
CompiledSubAgent,
|
||||
SubAgent,
|
||||
SubAgentMiddleware,
|
||||
)
|
||||
from langchain.agents import create_agent
|
||||
from langchain.chat_models import init_chat_model
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
from app.agents.multi_agent_chat.subagents.shared.spec import (
|
||||
SURF_CONTEXT_HINT_PROVIDER_KEY,
|
||||
)
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
from .task_tool import build_task_tool_with_parent_config
|
||||
|
||||
_perf_log = get_perf_logger()
|
||||
|
||||
|
||||
class SurfSenseCheckpointedSubAgentMiddleware(SubAgentMiddleware):
    """``SubAgentMiddleware`` variant that compiles each subagent against the parent checkpointer.

    The constructor builds one spec per subagent (compiling non-precompiled
    ones with ``create_agent(..., checkpointer=<parent checkpointer>)``),
    wraps them in a single ``task`` tool via
    ``build_task_tool_with_parent_config``, and exposes that tool plus the
    augmented system prompt as ``self.tools`` / ``self.system_prompt``.
    """

    def __init__(
        self,
        *,
        checkpointer: Checkpointer,
        backend: BackendProtocol | BackendFactory,
        subagents: list[SubAgent | CompiledSubAgent],
        system_prompt: str | None = TASK_SYSTEM_PROMPT,
        task_description: str | None = None,
        search_space_id: int | None = None,
    ) -> None:
        """Compile the subagents and assemble the ``task`` tool.

        Args:
            checkpointer: Checkpointer passed to ``create_agent`` for every
                subagent compiled here.
            backend: Stored on ``self._backend``; not otherwise used in this
                class body.
            subagents: Subagent specs; entries carrying a ``"runnable"`` key
                are used as-is, the rest are compiled.
            system_prompt: Base prompt; the list of available subagents is
                appended when both the prompt and the spec list are non-empty.
            task_description: Forwarded to ``build_task_tool_with_parent_config``.
            search_space_id: Stored and forwarded to the task-tool builder.

        Raises:
            ValueError: If ``subagents`` is empty, or (from
                ``_surf_compile_subagent_graphs``) a non-precompiled spec
                lacks ``model`` or ``tools``.
        """
        # Must be set before ``_surf_compile_subagent_graphs`` runs below.
        self._surf_checkpointer = checkpointer
        # ``super(SubAgentMiddleware, self)`` starts the MRO lookup *after*
        # ``SubAgentMiddleware``, so that class's own ``__init__`` is skipped
        # and its setup is re-done by hand in the statements that follow.
        super(SubAgentMiddleware, self).__init__()
        if not subagents:
            raise ValueError(
                "At least one subagent must be specified when using the new API"
            )
        self._backend = backend
        self._subagents = subagents
        # Search-space id is captured at build time (the orchestrator runs in
        # exactly one search space for its lifetime). The spawn-paused kill
        # switch keys on it so an operator can quarantine one workspace
        # without affecting the rest of the deployment.
        self._search_space_id = search_space_id
        subagent_specs = self._surf_compile_subagent_graphs()
        task_tool = build_task_tool_with_parent_config(
            subagent_specs,
            task_description,
            search_space_id=search_space_id,
        )
        if system_prompt and subagent_specs:
            # One "- name: description" bullet per subagent, appended to the prompt.
            agents_desc = "\n".join(
                f"- {s['name']}: {s['description']}" for s in subagent_specs
            )
            self.system_prompt = (
                system_prompt + "\n\nAvailable subagent types:\n" + agents_desc
            )
        else:
            self.system_prompt = system_prompt
        self.tools = [task_tool]

    def _surf_compile_subagent_graphs(self) -> list[dict[str, Any]]:
        """Mirror of ``SubAgentMiddleware._get_subagents`` that threads the parent checkpointer.

        Returns:
            One dict per entry of ``self._subagents`` with keys ``name``,
            ``description``, ``runnable`` and ``SURF_CONTEXT_HINT_PROVIDER_KEY``.

        Raises:
            ValueError: If a spec without ``"runnable"`` lacks ``"model"`` or
                ``"tools"``.

        Also emits one ``[subagent_compile]`` perf-log line with per-subagent
        timings.
        """
        specs: list[dict[str, Any]] = []
        loop_start = time.perf_counter()
        timings: list[tuple[str, float, str]] = []  # (name, elapsed, source)

        for spec in self._subagents:
            spec_start = time.perf_counter()
            # Provider may be ``None`` (no hint), in which case task_tool
            # skips the prepend step. We forward the key unconditionally so
            # the registry shape is uniform.
            hint_provider = cast(dict, spec).get(SURF_CONTEXT_HINT_PROVIDER_KEY)
            if "runnable" in spec:
                # Precompiled subagent: reuse its runnable, nothing to compile.
                compiled = cast(CompiledSubAgent, spec)
                specs.append(
                    {
                        "name": compiled["name"],
                        "description": compiled["description"],
                        "runnable": compiled["runnable"],
                        SURF_CONTEXT_HINT_PROVIDER_KEY: hint_provider,
                    }
                )
                timings.append(
                    (compiled["name"], time.perf_counter() - spec_start, "precompiled")
                )
                continue

            if "model" not in spec:
                msg = f"SubAgent '{spec['name']}' must specify 'model'"
                raise ValueError(msg)
            if "tools" not in spec:
                msg = f"SubAgent '{spec['name']}' must specify 'tools'"
                raise ValueError(msg)

            model = spec["model"]
            if isinstance(model, str):
                # String model identifiers are resolved to a chat model instance.
                model = init_chat_model(model)

            # Copy so the spec's own middleware list is not shared.
            middleware: list[Any] = list(spec.get("middleware", []))
            tools_count = len(spec.get("tools") or [])
            mw_count = len(middleware)

            # Only the ``create_agent`` call is timed for compiled subagents.
            compile_start = time.perf_counter()
            runnable = create_agent(
                model,
                system_prompt=spec["system_prompt"],
                tools=spec["tools"],
                middleware=middleware,
                name=spec["name"],
                checkpointer=self._surf_checkpointer,
            )
            compile_elapsed = time.perf_counter() - compile_start
            specs.append(
                {
                    "name": spec["name"],
                    "description": spec["description"],
                    "runnable": runnable,
                    SURF_CONTEXT_HINT_PROVIDER_KEY: hint_provider,
                }
            )
            timings.append(
                (
                    spec["name"],
                    compile_elapsed,
                    f"compiled tools={tools_count} mw={mw_count}",
                )
            )

        total_elapsed = time.perf_counter() - loop_start
        per_subagent = ", ".join(
            f"{name}={elapsed * 1000:.0f}ms[{source}]"
            for name, elapsed, source in timings
        )
        _perf_log.info(
            "[subagent_compile] total=%.3fs count=%d details=[%s]",
            total_elapsed,
            len(timings),
            per_subagent,
        )

        return specs
|
||||
|
|
@ -1,38 +0,0 @@
|
|||
"""Stamp the parent's ``tool_call_id`` onto a subagent's pending interrupt value.
|
||||
|
||||
When a subagent (compiled as a langgraph subgraph and invoked from a parent
|
||||
tool node) hits an ``interrupt(...)`` from its HITL middleware, langgraph
|
||||
raises ``GraphInterrupt`` out of ``subagent.[a]invoke(...)``. The parent's
|
||||
``task`` tool catches that exception, stamps ``tool_call_id`` onto each
|
||||
``Interrupt.value`` using :func:`wrap_with_tool_call_id`, and re-raises a
|
||||
fresh ``GraphInterrupt`` whose values carry that stamp.
|
||||
|
||||
``stream_resume_chat`` then reads ``parent.state.interrupts[*].value["tool_call_id"]``
|
||||
to route a flat ``decisions`` list back to the right paused subagent — without
|
||||
the stamp, parallel HITL across siblings would collapse into an ambiguous
|
||||
bucket and resume would fail.
|
||||
|
||||
This module hosts only the stamping helper; the catch/re-raise lives in
|
||||
``task_tool.py`` since that's the single chokepoint where the raw exception
|
||||
is in our hands.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def wrap_with_tool_call_id(value: Any, tool_call_id: str) -> dict[str, Any]:
    """Produce an interrupt value dict stamped with the parent's ``tool_call_id``.

    A dict ``value`` is shallow-copied and its ``tool_call_id`` entry is set
    (overwriting any id a deeper HITL level may have put there), because the
    parent's call id is the only one ``stream_resume_chat`` correlates
    against. Any other ``value`` is wrapped as
    ``{"value": <original>, "tool_call_id": ...}`` so plain
    ``interrupt("approve?")`` patterns still propagate.
    """
    stamped: dict[str, Any] = dict(value) if isinstance(value, dict) else {"value": value}
    stamped["tool_call_id"] = tool_call_id
    return stamped
|
||||
|
|
@ -1,76 +0,0 @@
|
|||
"""Resume-payload shaping and pending-interrupt detection for subagents.
|
||||
|
||||
Splits the work of "given a state snapshot and a parent-stashed resume value,
|
||||
produce the right ``Command(resume=...)`` for the subagent" into pure helpers.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langgraph.types import Command
|
||||
|
||||
|
||||
def hitlrequest_action_count(pending_value: Any) -> int:
    """Return the number of ``action_requests`` in a ``HITLRequest``-style payload.

    Yields ``0`` for anything that is not a dict carrying a list under
    ``"action_requests"`` (i.e. non-bundle interrupts).
    """
    if isinstance(pending_value, dict):
        requests = pending_value.get("action_requests")
        if isinstance(requests, list):
            return len(requests)
    return 0
|
||||
|
||||
|
||||
def fan_out_decisions_to_match(resume_value: Any, expected_count: int) -> Any:
    """Legacy fallback that pads a short ``decisions`` list up to ``expected_count``.

    Current frontends submit one decision per action request, making this a
    no-op; it remains for old in-flight threads or non-bundle clients that
    send fewer decisions. Padding repeats the last supplied decision.

    ``resume_value`` is returned unchanged when ``expected_count <= 1``, when
    it is not a dict, when ``decisions`` is not a list, when the list is
    empty, or when it already holds at least ``expected_count`` entries.
    Otherwise a new dict is returned; the input is not mutated.
    """
    if expected_count > 1 and isinstance(resume_value, dict):
        decisions = resume_value.get("decisions")
        if isinstance(decisions, list) and 0 < len(decisions) < expected_count:
            shortfall = expected_count - len(decisions)
            filler = [decisions[-1]] * shortfall
            return {**resume_value, "decisions": [*decisions, *filler]}
    return resume_value
|
||||
|
||||
|
||||
def get_first_pending_subagent_interrupt(state: Any) -> tuple[str | None, Any]:
    """Return the first pending ``(interrupt_id, value)``, or ``(None, None)``.

    ``state.interrupts`` is searched first, then the ``interrupts`` of each
    entry in ``state.tasks``. Interrupts whose ``value`` is ``None`` are
    ignored; a non-string ``id`` is reported as ``None``.

    Assumes at most one pending interrupt per snapshot (sequential tool
    nodes). Parallel tool nodes would need an id-aware lookup instead of
    first-wins.
    """
    if state is None:
        return None, None

    def _candidates():
        # Lazily walk top-level interrupts, then each task's interrupts.
        yield from getattr(state, "interrupts", None) or ()
        for task in getattr(state, "tasks", None) or ():
            yield from getattr(task, "interrupts", None) or ()

    for candidate in _candidates():
        payload = getattr(candidate, "value", None)
        ident = getattr(candidate, "id", None)
        if payload is None:
            continue
        return (ident if isinstance(ident, str) else None), payload
    return None, None
|
||||
|
||||
|
||||
def build_resume_command(resume_value: Any, pending_id: str | None) -> Command:
    """Build the ``Command`` that resumes a paused subagent.

    When ``pending_id`` is known the payload is targeted at that interrupt as
    ``{pending_id: resume_value}``; otherwise the bare ``resume_value`` is used.
    """
    resume_arg = resume_value if pending_id is None else {pending_id: resume_value}
    return Command(resume=resume_arg)
|
||||
|
|
@ -1,183 +0,0 @@
|
|||
"""Route a flat ``decisions`` list to per-``tool_call_id`` resume payloads.
|
||||
|
||||
The frontend submits decisions in the same order the SSE stream emitted
|
||||
approval cards. When multiple parallel subagents are paused, the backend uses
|
||||
this module to:
|
||||
|
||||
1. Read ``state.interrupts`` from the parent's paused snapshot, extracting
|
||||
``[(tool_call_id, action_count), ...]`` from each interrupt's value.
|
||||
The ``tool_call_id`` is stamped on by ``propagation.wrap_with_tool_call_id``
|
||||
inside ``task_tool``'s catch-and-stamp block when a subagent's
|
||||
``GraphInterrupt`` bubbles up through ``[a]task``.
|
||||
2. Slice the flat ``decisions`` list against that ordered pending list to
|
||||
produce the dict shape expected by ``consume_surfsense_resume``.
|
||||
3. Re-key those slices by ``Interrupt.id`` (langgraph's primitive) for use as
|
||||
the parent-level ``Command(resume={interrupt_id: payload})`` input — the
|
||||
only shape langgraph accepts when multiple interrupts are pending.
|
||||
|
||||
All helpers are pure: callers own the state and the input decisions; we
|
||||
return new structures and never mutate.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def slice_decisions_by_tool_call(
    decisions: list[dict[str, Any]],
    pending: Iterable[tuple[str, int]],
) -> dict[str, dict[str, Any]]:
    """Split a flat ``decisions`` list into ``{tool_call_id: {"decisions": <slice>}}``.

    Args:
        decisions: Flat decision list, in the order the SSE stream rendered
            the approval cards.
        pending: Ordered ``(tool_call_id, action_count)`` pairs in that same
            order; ``decisions`` is consumed left-to-right against them.

    Returns:
        Per-``tool_call_id`` payloads, ready to be stored under
        ``configurable["surfsense_resume_value"]``.

    Raises:
        ValueError: If the summed action counts differ from
            ``len(decisions)``. Failing loudly (instead of dropping or
            padding) surfaces frontend/backend contract drift immediately.
    """
    ordered = list(pending)  # materialise: iterated twice below
    total_expected = sum(action_count for _, action_count in ordered)
    received = len(decisions)
    if total_expected != received:
        raise ValueError(
            f"Decision count mismatch: pending tool calls expect "
            f"{total_expected} actions but received {received} decisions."
        )

    routed: dict[str, dict[str, Any]] = {}
    start = 0
    for call_id, action_count in ordered:
        stop = start + action_count
        routed[call_id] = {"decisions": decisions[start:stop]}
        start = stop
    return routed
|
||||
|
||||
|
||||
def collect_pending_tool_calls(state: Any) -> list[tuple[str, int]]:
    """Derive ``[(tool_call_id, action_count), ...]`` from a paused parent state.

    Walks ``state.interrupts`` in order. Each routable interrupt value is a
    dict carrying the ``tool_call_id`` stamped by
    ``propagation.wrap_with_tool_call_id`` in ``task_tool``'s
    ``except GraphInterrupt`` chokepoint. The order of ``state.interrupts`` is
    the order the SSE stream emitted approval cards, and the frontend replies
    in that order, so the result can be sliced left-to-right.

    Interrupts whose value is not a dict, or that lack a string
    ``tool_call_id``, are skipped with a warning: they did not come from the
    task-routing layer (e.g. parent-side HITL on another tool) and
    ``stream_resume_chat`` does not route them.

    Args:
        state: A langgraph ``StateSnapshot`` (or any object with an
            ``interrupts`` attribute).

    Returns:
        Ordered ``(tool_call_id, action_count)`` pairs. The count is
        ``len(value["action_requests"])`` for HITL bundles and ``1`` for
        wrapped scalars of the form ``{"value": ..., "tool_call_id": ...}``.

    Raises:
        ValueError: If a stamped value is neither a HITL bundle nor a wrapped
            scalar, so no action count can be determined (contract bug).
    """
    collected: list[tuple[str, int]] = []
    interrupts = getattr(state, "interrupts", ()) or ()
    for position, pending_interrupt in enumerate(interrupts):
        payload = getattr(pending_interrupt, "value", None)
        if not isinstance(payload, dict):
            logger.warning(
                "[hitl_route] interrupt[%d] skipped: value not a dict (type=%s)",
                position,
                type(payload).__name__,
            )
            continue

        call_id = payload.get("tool_call_id")
        if not isinstance(call_id, str):
            # Should not happen post-stamping; flag loudly if a regression
            # ever lets an unstamped value reach the parent state.
            logger.warning(
                "[hitl_route] interrupt[%d] skipped: no tool_call_id stamp (keys=%s)",
                position,
                sorted(payload.keys()),
            )
            continue

        requests = payload.get("action_requests")
        if isinstance(requests, list):
            action_count = len(requests)
        elif "value" in payload:
            action_count = 1
        else:
            raise ValueError(
                f"Interrupt for tool_call_id={call_id!r} has no "
                "``action_requests`` list and is not a wrapped scalar value; "
                "cannot determine action count for resume routing."
            )
        collected.append((call_id, action_count))

    return collected
|
||||
|
||||
|
||||
def build_lg_resume_map(
    state: Any, by_tool_call_id: dict[str, dict[str, Any]]
) -> dict[str, dict[str, Any]]:
    """Re-key per-``tool_call_id`` payloads by ``Interrupt.id`` for langgraph.

    ``stream_resume_chat`` produces ``by_tool_call_id`` with
    :func:`slice_decisions_by_tool_call`, but ``Command(resume=...)`` needs
    ``Interrupt.id`` keys when the parent has several pending interrupts. The
    two key spaces have different readers:

    - ``surfsense_resume_value`` (keyed by ``tool_call_id``) is read by the
      subagent bridge inside ``task_tool``;
    - ``Command(resume=...)`` (keyed by ``Interrupt.id``) is read by
      langgraph's pregel to wake each pending interrupt site.

    Interrupts that cannot be paired (non-dict value, no string
    ``tool_call_id`` stamp, non-string id, or no matching slice) are left out,
    so contract drift shows up as a count mismatch at the call site rather
    than a silent mis-route. ``by_tool_call_id`` is not mutated.

    Args:
        state: A langgraph ``StateSnapshot`` (or any object with an
            ``interrupts`` iterable).
        by_tool_call_id: Output of :func:`slice_decisions_by_tool_call`.

    Returns:
        Dict suitable for ``Command(resume=<this>)``.
    """
    resume_map: dict[str, dict[str, Any]] = {}
    for pending_interrupt in getattr(state, "interrupts", ()) or ():
        payload = getattr(pending_interrupt, "value", None)
        call_id = payload.get("tool_call_id") if isinstance(payload, dict) else None
        interrupt_id = getattr(pending_interrupt, "id", None)
        if not (isinstance(call_id, str) and isinstance(interrupt_id, str)):
            continue
        routed = by_tool_call_id.get(call_id)
        if routed is not None:
            resume_map[interrupt_id] = routed
    return resume_map
|
||||
|
|
@ -1,84 +0,0 @@
|
|||
"""Per-search-space spawn-paused kill switch for the ``task`` boundary.
|
||||
|
||||
When operators see a runaway loop, a vendor outage, or a billing event
|
||||
that requires immediate cessation of subagent traffic for a specific
|
||||
workspace, they flip a Redis flag and the ``task`` tool short-circuits
|
||||
without touching downstream services. The flag is **per-search-space**
|
||||
so one tenant's incident never silences the rest of the deployment.
|
||||
|
||||
Flag key: ``surfsense:spawn_paused:{search_space_id}``
|
||||
Flag value: any string-truthy value (we read presence, not contents).
|
||||
TTL: set by whoever toggles the flag — this module never expires
|
||||
keys on its own, since "the flag is on" is itself the signal
|
||||
that a human (or alert) needs to investigate.
|
||||
|
||||
The check is best-effort: Redis errors are logged but do not block the
|
||||
``task`` invocation. Failing closed (block-on-redis-error) would let a
|
||||
single Redis blip take the whole orchestrator offline; failing open
|
||||
preserves availability and the alarm bells (rate-limits, cost spikes)
|
||||
will surface the underlying outage.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
|
||||
from app.config import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Opt-out switch for environments without Redis (e.g. local dev): setting
# ``SURFSENSE_TASK_SPAWN_PAUSED_DISABLED=1`` turns the check off. The check is
# on by default so production never depends on someone flipping an opt-out.
# The variable is read once, when this module is imported.
_DISABLED = os.getenv("SURFSENSE_TASK_SPAWN_PAUSED_DISABLED", "").strip().lower() in {
    "1",
    "true",
    "yes",
    "on",
}
|
||||
|
||||
|
||||
def _flag_key(search_space_id: int) -> str:
    """Return the Redis key holding the spawn-paused flag for one search space."""
    return "surfsense:spawn_paused:{}".format(search_space_id)
|
||||
|
||||
|
||||
async def is_spawn_paused(search_space_id: int | None) -> bool:
    """Report whether the spawn-paused flag for a search space is set in Redis.

    The check is skipped (returns ``False``) when it is disabled through the
    environment or when ``search_space_id`` is ``None`` (e.g. dev paths that
    have not plumbed the id through yet). Any exception raised while talking
    to Redis is logged with its traceback and also yields ``False``; the
    module docstring explains why this fails open.

    Args:
        search_space_id: Workspace whose flag should be read, or ``None``.

    Returns:
        ``True`` only when Redis returned a truthy value for the flag key.
    """
    if _DISABLED or search_space_id is None:
        return False
    try:
        # Imported lazily so the cold-path import stays cheap and routes that
        # never call ``task`` do not need the redis dependency at all.
        import redis.asyncio as aioredis  # type: ignore[import-not-found]

        redis_client = aioredis.from_url(config.REDIS_APP_URL, decode_responses=True)
        try:
            flag_value = await redis_client.get(_flag_key(search_space_id))
        finally:
            # Prefer ``aclose()`` (the async-safe variant on redis-py >=5) and
            # fall back to ``close()`` for older clients pinned in tests.
            closer = getattr(redis_client, "aclose", None) or getattr(
                redis_client, "close", None
            )
            if callable(closer):
                # A failed close must never mask the result of the read.
                with contextlib.suppress(Exception):
                    await closer()  # type: ignore[misc]
        return bool(flag_value)
    except Exception:
        logger.warning(
            "spawn_paused check failed for search_space_id=%s; failing open.",
            search_space_id,
            exc_info=True,
        )
        return False
|
||||
|
||||
|
||||
__all__ = ["is_spawn_paused"]
|
||||
|
|
@ -1,15 +0,0 @@
|
|||
"""Schema-level description for the ``task`` tool.
|
||||
|
||||
Loaded from ``prompts/tools/task/description.md`` so the tool-schema text
|
||||
and the ``<tools>`` block render from the same source.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.multi_agent_chat.main_agent.system_prompt.builder.load_md import (
|
||||
read_prompt_md,
|
||||
)
|
||||
|
||||
# Evaluated once, when this module is first imported; ``read_prompt_md`` is
# handed the path of the description file (presumably resolved relative to the
# prompts directory named in the module docstring -- confirm in ``load_md``).
TASK_TOOL_DESCRIPTION: str = read_prompt_md("tools/task/description.md")
|
||||
|
||||
__all__ = ["TASK_TOOL_DESCRIPTION"]
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -1,50 +0,0 @@
|
|||
"""Spill + clear-tool-uses passes to keep payloads under budget."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.agents.multi_agent_chat.main_agent.context_prune.prune_tool_names import (
|
||||
safe_exclude_tools,
|
||||
)
|
||||
from app.agents.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.shared.middleware import (
|
||||
ClearToolUsesEdit,
|
||||
SpillingContextEditingMiddleware,
|
||||
SpillToBackendEdit,
|
||||
)
|
||||
|
||||
from ..shared.flags import enabled
|
||||
|
||||
|
||||
def build_context_editing_mw(
    *,
    flags: AgentFeatureFlags,
    max_input_tokens: int | None,
    tools: Sequence[BaseTool],
    backend_resolver: Any,
) -> SpillingContextEditingMiddleware | None:
    """Build the spill + clear-tool-uses context-editing middleware.

    Returns ``None`` when the ``enable_context_editing`` flag is off or when
    ``max_input_tokens`` is falsy (``None`` or ``0``), since both edits are
    sized as fractions of that budget.

    Args:
        flags: Feature flags consulted for ``enable_context_editing``.
        max_input_tokens: Input budget the thresholds are derived from.
        tools: Agent tools; passed to ``safe_exclude_tools`` to derive the
            ``exclude_tools`` argument of each edit.
        backend_resolver: Forwarded unchanged to the middleware.

    Returns:
        A middleware running the spill edit followed by the clear edit, or
        ``None`` when disabled.
    """
    if not (enabled(flags, "enable_context_editing") and max_input_tokens):
        return None

    # Both edits use the same thresholds: trigger at 55% of the budget and
    # clear at least 15% of it.
    trigger_tokens = int(max_input_tokens * 0.55)
    min_cleared_tokens = int(max_input_tokens * 0.15)

    spill_pass = SpillToBackendEdit(
        trigger=trigger_tokens,
        clear_at_least=min_cleared_tokens,
        keep=5,
        exclude_tools=safe_exclude_tools(tools),
        clear_tool_inputs=True,
    )
    clear_pass = ClearToolUsesEdit(
        trigger=trigger_tokens,
        clear_at_least=min_cleared_tokens,
        keep=5,
        exclude_tools=safe_exclude_tools(tools),
        clear_tool_inputs=True,
        placeholder="[cleared - older tool output trimmed for context]",
    )
    return SpillingContextEditingMiddleware(
        edits=[spill_pass, clear_pass],
        backend_resolver=backend_resolver,
    )
|
||||
|
|
@ -1,13 +0,0 @@
|
|||
"""Drop duplicate HITL tool calls before execution."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.agents.shared.middleware import DedupHITLToolCallsMiddleware
|
||||
|
||||
|
||||
def build_dedup_hitl_mw(tools: Sequence[BaseTool]) -> DedupHITLToolCallsMiddleware:
    """Build the dedup middleware over a fresh list copy of ``tools``."""
    agent_tools = [*tools]
    return DedupHITLToolCallsMiddleware(agent_tools=agent_tools)
|
||||
|
|
@ -1,14 +0,0 @@
|
|||
"""Stop N identical tool calls in a row via interrupt."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.shared.middleware import DoomLoopMiddleware
|
||||
|
||||
from ..shared.flags import enabled
|
||||
|
||||
|
||||
def build_doom_loop_mw(flags: AgentFeatureFlags) -> DoomLoopMiddleware | None:
    """Return a ``DoomLoopMiddleware`` (threshold 3) when ``enable_doom_loop`` is on.

    Returns ``None`` when the flag is off.
    """
    if enabled(flags, "enable_doom_loop"):
        return DoomLoopMiddleware(threshold=3)
    return None
|
||||
|
|
@ -1,23 +0,0 @@
|
|||
"""Commit staged cloud filesystem mutations to Postgres at end of turn."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.shared.middleware import KnowledgeBasePersistenceMiddleware
|
||||
|
||||
|
||||
def build_kb_persistence_mw(
    *,
    filesystem_mode: FilesystemMode,
    search_space_id: int,
    user_id: str | None,
    thread_id: int | None,
) -> KnowledgeBasePersistenceMiddleware | None:
    """Build the knowledge-base persistence middleware for cloud mode only.

    Args:
        filesystem_mode: Active filesystem mode; anything other than
            ``FilesystemMode.CLOUD`` disables the middleware.
        search_space_id: Forwarded to the middleware.
        user_id: Forwarded as ``created_by_id``.
        thread_id: Forwarded to the middleware.

    Returns:
        The middleware in cloud mode, otherwise ``None``.
    """
    if filesystem_mode != FilesystemMode.CLOUD:
        persistence_mw = None
    else:
        persistence_mw = KnowledgeBasePersistenceMiddleware(
            search_space_id=search_space_id,
            created_by_id=user_id,
            filesystem_mode=filesystem_mode,
            thread_id=thread_id,
        )
    return persistence_mw
|
||||
|
|
@ -1,30 +0,0 @@
|
|||
"""KB priority planner: <priority_documents> injection."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
|
||||
from app.agents.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.shared.middleware import KnowledgePriorityMiddleware
|
||||
from app.services.llm_service import get_planner_llm
|
||||
|
||||
|
||||
def build_knowledge_priority_mw(
    *,
    llm: BaseChatModel,
    search_space_id: int,
    filesystem_mode: FilesystemMode,
    available_connectors: list[str] | None,
    available_document_types: list[str] | None,
    mentioned_document_ids: list[int] | None,
) -> KnowledgePriorityMiddleware:
    """Build the KB priority planner middleware.

    Every argument is forwarded unchanged. The planner model comes from
    ``get_planner_llm()``, and ``inject_system_message`` is always ``False``.
    """
    planner = get_planner_llm()
    return KnowledgePriorityMiddleware(
        llm=llm,
        planner_llm=planner,
        search_space_id=search_space_id,
        filesystem_mode=filesystem_mode,
        available_connectors=available_connectors,
        available_document_types=available_document_types,
        mentioned_document_ids=mentioned_document_ids,
        inject_system_message=False,
    )
|
||||
|
|
@ -1,24 +0,0 @@
|
|||
"""<workspace_tree> injection (cloud only)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
|
||||
from app.agents.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.shared.middleware import KnowledgeTreeMiddleware
|
||||
|
||||
|
||||
def build_knowledge_tree_mw(
    *,
    filesystem_mode: FilesystemMode,
    search_space_id: int,
    llm: BaseChatModel,
) -> KnowledgeTreeMiddleware | None:
    """Build the workspace-tree middleware for cloud mode only.

    Returns ``None`` for any mode other than ``FilesystemMode.CLOUD``. The
    middleware is always created with ``inject_system_message=False``.
    """
    if filesystem_mode != FilesystemMode.CLOUD:
        tree_mw = None
    else:
        tree_mw = KnowledgeTreeMiddleware(
            search_space_id=search_space_id,
            filesystem_mode=filesystem_mode,
            llm=llm,
            inject_system_message=False,
        )
    return tree_mw
|
||||
|
|
@ -1,12 +0,0 @@
|
|||
"""Provider-compat: append a `_noop` tool when tools=[] but history has tool calls."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.shared.middleware import NoopInjectionMiddleware
|
||||
|
||||
from ..shared.flags import enabled
|
||||
|
||||
|
||||
def build_noop_injection_mw(flags: AgentFeatureFlags) -> NoopInjectionMiddleware | None:
    """Return a ``NoopInjectionMiddleware`` when ``enable_compaction_v2`` is on.

    Returns ``None`` when the flag is off.
    """
    if enabled(flags, "enable_compaction_v2"):
        return NoopInjectionMiddleware()
    return None
|
||||
|
|
@ -1,12 +0,0 @@
|
|||
"""OTel spans on model and tool calls."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.shared.middleware import OtelSpanMiddleware
|
||||
|
||||
from ..shared.flags import enabled
|
||||
|
||||
|
||||
def build_otel_mw(flags: AgentFeatureFlags) -> OtelSpanMiddleware | None:
    """Return an ``OtelSpanMiddleware`` when ``enable_otel`` is on, else ``None``."""
    if enabled(flags, "enable_otel"):
        return OtelSpanMiddleware()
    return None
|
||||
|
|
@ -1,49 +0,0 @@
|
|||
"""Tail-of-stack plugin slot driven by env allowlist."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
|
||||
from app.agents.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.shared.plugin_loader import (
|
||||
PluginContext,
|
||||
load_allowed_plugin_names_from_env,
|
||||
load_plugin_middlewares,
|
||||
)
|
||||
from app.db import ChatVisibility
|
||||
|
||||
from ..shared.flags import enabled
|
||||
|
||||
|
||||
def build_plugin_middlewares(
    *,
    flags: AgentFeatureFlags,
    search_space_id: int,
    user_id: str | None,
    visibility: ChatVisibility,
    llm: BaseChatModel,
) -> list[Any]:
    """Load the plugin middlewares named by the environment allowlist.

    An empty list is returned when the ``enable_plugin_loader`` flag is off,
    when the allowlist is empty, or when loading raises. In the last case the
    failure is logged with its traceback and the agent continues without
    plugins.

    Args:
        flags: Feature flags consulted for ``enable_plugin_loader``.
        search_space_id: Passed into the ``PluginContext``.
        user_id: Passed into the ``PluginContext``.
        visibility: Passed into the ``PluginContext`` as ``thread_visibility``.
        llm: Passed into the ``PluginContext``.

    Returns:
        The loaded plugin middlewares, possibly empty.
    """
    if not enabled(flags, "enable_plugin_loader"):
        return []
    try:
        allowlist = load_allowed_plugin_names_from_env()
        if allowlist:
            plugin_ctx = PluginContext.build(
                search_space_id=search_space_id,
                user_id=user_id,
                thread_visibility=visibility,
                llm=llm,
            )
            return load_plugin_middlewares(
                plugin_ctx,
                allowed_plugin_names=allowlist,
            )
    except Exception:  # pragma: no cover - defensive
        logging.warning(
            "Plugin loader failed; continuing without plugins.",
            exc_info=True,
        )
    # Reached for an empty allowlist and after a logged failure alike.
    return []
|
||||
|
|
@ -1,50 +0,0 @@
|
|||
"""Repair miscased / unknown tool names to the registered set or invalid_tool."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.agents.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.shared.middleware import ToolCallNameRepairMiddleware
|
||||
|
||||
from ..shared.flags import enabled
|
||||
|
||||
# Names of the tools deepagents ships built in; the repair pass counts these
# as known in addition to the tools registered on the agent.
_DEEPAGENT_BUILTIN_TOOL_NAMES: frozenset[str] = frozenset(
    (
        "write_todos",
        "ls",
        "read_file",
        "write_file",
        "edit_file",
        "glob",
        "grep",
        "execute",
        "task",
        "mkdir",
        "cd",
        "pwd",
        "move_file",
        "rm",
        "rmdir",
        "list_tree",
        "execute_code",
    )
)
|
||||
|
||||
|
||||
def build_repair_mw(
    *,
    flags: AgentFeatureFlags,
    tools: Sequence[BaseTool],
) -> ToolCallNameRepairMiddleware | None:
    """Build the tool-call name repair middleware.

    The known-name set is the names of ``tools`` plus the deepagents
    built-in names. ``fuzzy_match_threshold`` is passed as ``None``.

    Returns:
        The middleware, or ``None`` when ``enable_tool_call_repair`` is off.
    """
    if not enabled(flags, "enable_tool_call_repair"):
        return None
    # ``set.union`` returns a plain mutable ``set``, as the middleware was
    # given before.
    known_names: set[str] = {tool.name for tool in tools}.union(
        _DEEPAGENT_BUILTIN_TOOL_NAMES
    )
    return ToolCallNameRepairMiddleware(
        registered_tool_names=known_names,
        fuzzy_match_threshold=None,
    )
|
||||
|
|
@ -1,39 +0,0 @@
|
|||
"""Skill discovery + injection."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from deepagents.middleware.skills import SkillsMiddleware
|
||||
|
||||
from app.agents.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.shared.middleware import (
|
||||
build_skills_backend_factory,
|
||||
default_skills_sources,
|
||||
)
|
||||
|
||||
from ..shared.flags import enabled
|
||||
|
||||
|
||||
def build_skills_mw(
    *,
    flags: AgentFeatureFlags,
    filesystem_mode: FilesystemMode,
    search_space_id: int,
) -> SkillsMiddleware | None:
    """Build the skills discovery/injection middleware.

    Args:
        flags: Feature flags consulted for ``enable_skills``.
        filesystem_mode: The backend factory receives ``search_space_id``
            only in ``FilesystemMode.CLOUD``; every other mode passes ``None``.
        search_space_id: Search space handed to the backend factory in
            cloud mode.

    Returns:
        The middleware, or ``None`` when the flag is off or construction
        raised. A construction failure is logged and never propagated.
    """
    if not enabled(flags, "enable_skills"):
        return None
    try:
        skills_factory = build_skills_backend_factory(
            search_space_id=search_space_id
            if filesystem_mode == FilesystemMode.CLOUD
            else None,
        )
        return SkillsMiddleware(
            backend=skills_factory,
            sources=default_skills_sources(),
        )
    except Exception as exc:  # pragma: no cover - defensive
        # Keep the traceback: the message alone rarely shows which step of the
        # construction failed.
        logging.warning(
            "SkillsMiddleware init failed; skipping: %s", exc, exc_info=True
        )
        return None
|
||||
|
|
@ -1,210 +0,0 @@
|
|||
"""Main-agent middleware list assembly: one line per slot.
|
||||
|
||||
The main agent is a pure router — filesystem reads/writes are owned by the
|
||||
``knowledge_base`` subagent and delegated via the ``task`` tool. The stack
|
||||
here only renders KB context (workspace tree + priority docs), projects it
|
||||
into system messages, and commits any subagent-side staged writes at end of
|
||||
turn (cloud mode).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from deepagents import SubAgent
|
||||
from deepagents.backends import StateBackend
|
||||
from langchain.agents import create_agent
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.tools import BaseTool
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
from app.agents.multi_agent_chat.subagents import (
|
||||
build_subagents,
|
||||
get_subagents_to_exclude,
|
||||
)
|
||||
from app.agents.multi_agent_chat.subagents.builtins.knowledge_base.agent import (
|
||||
READONLY_NAME as KB_READONLY_NAME,
|
||||
build_readonly_subagent as build_kb_readonly_subagent,
|
||||
)
|
||||
from app.agents.multi_agent_chat.subagents.builtins.knowledge_base.ask_knowledge_base_tool import (
|
||||
build_ask_knowledge_base_tool,
|
||||
)
|
||||
from app.agents.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.shared.filesystem_selection import FilesystemMode
|
||||
from app.db import ChatVisibility
|
||||
|
||||
from .main_agent.action_log import build_action_log_mw
|
||||
from .main_agent.anonymous_doc import build_anonymous_doc_mw
|
||||
from .main_agent.busy_mutex import build_busy_mutex_mw
|
||||
from .main_agent.checkpointed_subagent_middleware import (
|
||||
SurfSenseCheckpointedSubAgentMiddleware,
|
||||
)
|
||||
from .main_agent.checkpointed_subagent_middleware.task_description import (
|
||||
TASK_TOOL_DESCRIPTION,
|
||||
)
|
||||
from .main_agent.context_editing import build_context_editing_mw
|
||||
from .main_agent.dedup_hitl import build_dedup_hitl_mw
|
||||
from .main_agent.doom_loop import build_doom_loop_mw
|
||||
from .main_agent.kb_persistence import build_kb_persistence_mw
|
||||
from .main_agent.knowledge_priority import build_knowledge_priority_mw
|
||||
from .main_agent.knowledge_tree import build_knowledge_tree_mw
|
||||
from .main_agent.noop_injection import build_noop_injection_mw
|
||||
from .main_agent.otel import build_otel_mw
|
||||
from .main_agent.plugins import build_plugin_middlewares
|
||||
from .main_agent.repair import build_repair_mw
|
||||
from .main_agent.skills import build_skills_mw
|
||||
from .shared.anthropic_cache import build_anthropic_cache_mw
|
||||
from .shared.compaction import build_compaction_mw
|
||||
from .shared.kb_context_projection import build_kb_context_projection_mw
|
||||
from .shared.memory import build_memory_mw
|
||||
from .shared.patch_tool_calls import build_patch_tool_calls_mw
|
||||
from .shared.permissions import build_permission_mw
|
||||
from .shared.resilience import build_resilience_middlewares
|
||||
from .shared.todos import build_todos_mw
|
||||
from .subagent.middleware_stack import build_subagent_middleware_stack
|
||||
|
||||
|
||||
def build_main_agent_deepagent_middleware(
    *,
    llm: BaseChatModel,
    tools: Sequence[BaseTool],
    backend_resolver: Any,
    filesystem_mode: FilesystemMode,
    search_space_id: int,
    user_id: str | None,
    thread_id: int | None,
    visibility: ChatVisibility,
    anon_session_id: str | None,
    available_connectors: list[str] | None,
    available_document_types: list[str] | None,
    mentioned_document_ids: list[int] | None,
    max_input_tokens: int | None,
    flags: AgentFeatureFlags,
    subagent_dependencies: dict[str, Any],
    checkpointer: Checkpointer,
    mcp_tools_by_agent: dict[str, list[BaseTool]] | None = None,
    disabled_tools: list[str] | None = None,
) -> list[Any]:
    """Assemble the main agent's ordered middleware list for ``create_agent``.

    Preparation happens in this order: resilience bundle, memory middleware,
    subagent dependencies and shared subagent stack, the read-only
    knowledge-base agent with its ``ask`` tool, then the subagent registry.
    The slot list is built last, and any slot whose builder returned ``None``
    is dropped from the result.

    The caller's ``subagent_dependencies`` dict is not mutated; a copy
    extended with ``backend_resolver``, ``filesystem_mode`` and ``flags`` is
    used instead.

    Returns:
        The middleware instances in stack order, with ``None`` entries
        removed.
    """
    resilience_bundle = build_resilience_middlewares(flags)

    memory_middleware = build_memory_mw(
        user_id=user_id,
        search_space_id=search_space_id,
        visibility=visibility,
    )

    # Work on a copy so the caller's dict stays untouched.
    deps = {
        **subagent_dependencies,
        "backend_resolver": backend_resolver,
        "filesystem_mode": filesystem_mode,
        "flags": flags,
    }
    subagent_stack = build_subagent_middleware_stack(
        resilience=resilience_bundle,
        flags=flags,
    )

    # The read-only KB subagent is compiled into its own runnable and then
    # wrapped by the ask-knowledge-base tool.
    readonly_kb = build_kb_readonly_subagent(
        dependencies=deps,
        model=llm,
        middleware_stack=subagent_stack,
    )
    readonly_spec = readonly_kb.spec
    readonly_agent = create_agent(
        llm,
        system_prompt=readonly_spec["system_prompt"],
        tools=readonly_spec["tools"],
        middleware=readonly_spec["middleware"],
        name=KB_READONLY_NAME,
        checkpointer=checkpointer,
    )
    ask_kb = build_ask_knowledge_base_tool(readonly_agent)

    registry: list[SubAgent] = build_subagents(
        dependencies=deps,
        model=llm,
        middleware_stack=subagent_stack,
        mcp_tools_by_agent=mcp_tools_by_agent or {},
        exclude=get_subagents_to_exclude(available_connectors),
        disabled_tools=disabled_tools,
        ask_kb_tool=ask_kb,
    )
    logging.debug("Subagents registry: %s", [s["name"] for s in registry])

    # One entry per slot, in stack order. Builders may return None.
    ordered_slots: list[Any] = [
        build_busy_mutex_mw(flags),
        build_otel_mw(flags),
        build_todos_mw(system_prompt=""),
        memory_middleware,
        build_anonymous_doc_mw(
            filesystem_mode=filesystem_mode, anon_session_id=anon_session_id
        ),
        # Knowledge-base context and persistence.
        build_knowledge_tree_mw(
            filesystem_mode=filesystem_mode,
            search_space_id=search_space_id,
            llm=llm,
        ),
        build_knowledge_priority_mw(
            llm=llm,
            search_space_id=search_space_id,
            filesystem_mode=filesystem_mode,
            available_connectors=available_connectors,
            available_document_types=available_document_types,
            mentioned_document_ids=mentioned_document_ids,
        ),
        build_kb_context_projection_mw(),
        build_kb_persistence_mw(
            filesystem_mode=filesystem_mode,
            search_space_id=search_space_id,
            user_id=user_id,
            thread_id=thread_id,
        ),
        build_skills_mw(
            flags=flags,
            filesystem_mode=filesystem_mode,
            search_space_id=search_space_id,
        ),
        # Subagent delegation through the ``task`` tool.
        SurfSenseCheckpointedSubAgentMiddleware(
            checkpointer=checkpointer,
            backend=StateBackend,
            subagents=registry,
            system_prompt=None,
            task_description=TASK_TOOL_DESCRIPTION,
            search_space_id=search_space_id,
        ),
        # Limits, context budget and provider resilience.
        resilience_bundle.model_call_limit,
        resilience_bundle.tool_call_limit,
        build_context_editing_mw(
            flags=flags,
            max_input_tokens=max_input_tokens,
            tools=tools,
            backend_resolver=backend_resolver,
        ),
        build_compaction_mw(llm),
        build_noop_injection_mw(flags),
        resilience_bundle.retry,
        resilience_bundle.fallback,
        # Tool-call hygiene and auditing.
        build_repair_mw(flags=flags, tools=tools),
        build_permission_mw(flags=flags),
        build_doom_loop_mw(flags),
        build_action_log_mw(
            flags=flags,
            thread_id=thread_id,
            search_space_id=search_space_id,
            user_id=user_id,
        ),
        build_patch_tool_calls_mw(),
        build_dedup_hitl_mw(tools),
        # Plugin middlewares are spliced in just before the final slot.
        *build_plugin_middlewares(
            flags=flags,
            search_space_id=search_space_id,
            user_id=user_id,
            visibility=visibility,
            llm=llm,
        ),
        build_anthropic_cache_mw(),
    ]
    return [slot for slot in ordered_slots if slot is not None]
|
||||
Loading…
Add table
Add a link
Reference in a new issue