refactor(agents): introduce chat/ category; dissolve top-level agents/shared

Recursive shared-folder rule: a shared/ must be shared by ALL siblings at its
level. The kernel (context, compaction, retry_after, web_search) was shared by
only 2 of the agents -- anonymous_chat + multi_agent_chat -- never by podcaster
or video_presentation. Those 2 are the "chat" category, so their shared code
belongs in that category's shared/, not the top-level one.

  app/agents/anonymous_chat/   -> app/agents/chat/anonymous_chat/
  app/agents/multi_agent_chat/ -> app/agents/chat/multi_agent_chat/
  app/agents/shared/           -> app/agents/chat/shared/   (anon<->mac kernel)

Top-level app/agents/shared/ is gone: nothing was shared across all three
categories (chat / podcaster / video_presentation).

~289 import sites rewritten (app.agents.{anonymous_chat,multi_agent_chat,shared}
-> app.agents.chat.*); all moves are git renames (history preserved).
app/agents/ now: chat/, podcaster/, video_presentation/, runtime/.
CREDO23 2026-06-05 12:54:02 +02:00
parent d59bb2b5aa
commit 24b62a63b4
570 changed files with 712 additions and 613 deletions
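
The import rewrite described in the commit message can be pictured as a single regex pass over each source line. This is a minimal sketch under stated assumptions: `rewrite_import_path` and its pattern are illustrative names, not the tooling actually used to produce this commit.

```python
import re

# Old top-level packages that moved under app.agents.chat (per the commit message).
_MOVED = ("anonymous_chat", "multi_agent_chat", "shared")

# Match "app.agents.<moved>" only when the package name ends there, so a
# hypothetical "app.agents.shared_utils" would be left untouched.
_PATTERN = re.compile(
    r"\bapp\.agents\.(" + "|".join(_MOVED) + r")(?![A-Za-z0-9_])"
)


def rewrite_import_path(line: str) -> str:
    """Rewrite one source line from the old layout to the chat/ layout."""
    return _PATTERN.sub(r"app.agents.chat.\1", line)


old = "from app.agents.shared.context import SurfSenseContextSchema"
new = rewrite_import_path(old)
```

Because `chat` is not in the moved list, running the function on an already rewritten line leaves it unchanged, so the pass is safe to repeat.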


@ -0,0 +1,7 @@
"""Deepagents-backed routes: ``subagents/``; main-agent graph under ``main_agent/`` (SRP subpackages)."""
from __future__ import annotations
from .main_agent import create_multi_agent_chat_deep_agent
__all__ = ["create_multi_agent_chat_deep_agent"]


@ -0,0 +1,44 @@
"""Connector-type to subagent name; subagent name to availability tokens for build_subagents."""
from __future__ import annotations
CONNECTOR_TYPE_TO_CONNECTOR_AGENT_MAPS: dict[str, str] = {
"GOOGLE_GMAIL_CONNECTOR": "gmail",
"COMPOSIO_GMAIL_CONNECTOR": "gmail",
"GOOGLE_CALENDAR_CONNECTOR": "calendar",
"COMPOSIO_GOOGLE_CALENDAR_CONNECTOR": "calendar",
"DISCORD_CONNECTOR": "discord",
"TEAMS_CONNECTOR": "teams",
"LUMA_CONNECTOR": "luma",
"LINEAR_CONNECTOR": "linear",
"JIRA_CONNECTOR": "jira",
"CLICKUP_CONNECTOR": "clickup",
"SLACK_CONNECTOR": "slack",
"AIRTABLE_CONNECTOR": "airtable",
"NOTION_CONNECTOR": "notion",
"CONFLUENCE_CONNECTOR": "confluence",
"GOOGLE_DRIVE_CONNECTOR": "google_drive",
"COMPOSIO_GOOGLE_DRIVE_CONNECTOR": "google_drive",
"DROPBOX_CONNECTOR": "dropbox",
"ONEDRIVE_CONNECTOR": "onedrive",
}
SUBAGENT_TO_REQUIRED_CONNECTOR_MAP: dict[str, frozenset[str]] = {
"deliverables": frozenset(),
"knowledge_base": frozenset(),
"airtable": frozenset({"AIRTABLE_CONNECTOR"}),
"calendar": frozenset({"GOOGLE_CALENDAR_CONNECTOR"}),
"clickup": frozenset({"CLICKUP_CONNECTOR"}),
"confluence": frozenset({"CONFLUENCE_CONNECTOR"}),
"discord": frozenset({"DISCORD_CONNECTOR"}),
"dropbox": frozenset({"DROPBOX_FILE"}),
"gmail": frozenset({"GOOGLE_GMAIL_CONNECTOR"}),
"google_drive": frozenset({"GOOGLE_DRIVE_FILE"}),
"jira": frozenset({"JIRA_CONNECTOR"}),
"linear": frozenset({"LINEAR_CONNECTOR"}),
"luma": frozenset({"LUMA_CONNECTOR"}),
"notion": frozenset({"NOTION_CONNECTOR"}),
"onedrive": frozenset({"ONEDRIVE_FILE"}),
"slack": frozenset({"SLACK_CONNECTOR"}),
"teams": frozenset({"TEAMS_CONNECTOR"}),
}
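
A minimal sketch of how the second map might be consumed, assuming a subagent is available when every token in its required set is present (the source does not show `build_subagents`, so this rule and the helper name `available_subagents` are assumptions). The map below is a trimmed copy for illustration.

```python
# Trimmed, illustrative copy of the mapping above (not the full table).
SUBAGENT_TO_REQUIRED_CONNECTOR_MAP: dict[str, frozenset[str]] = {
    "deliverables": frozenset(),
    "knowledge_base": frozenset(),
    "gmail": frozenset({"GOOGLE_GMAIL_CONNECTOR"}),
    "dropbox": frozenset({"DROPBOX_FILE"}),
}


def available_subagents(tokens: set[str]) -> list[str]:
    """Hypothetical helper: subagents whose required tokens are all present."""
    return sorted(
        name
        for name, required in SUBAGENT_TO_REQUIRED_CONNECTOR_MAP.items()
        if required <= tokens  # an empty requirement set is always satisfied
    )


with_dropbox = available_subagents({"DROPBOX_FILE"})
```

Under this assumption, subagents mapped to an empty `frozenset()` are always available, which matches how `deliverables` and `knowledge_base` are declared.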


@ -0,0 +1,7 @@
"""Main-agent deep agent: ``runtime/`` (factory), ``graph/`` (compile), ``system_prompt/``, etc."""
from __future__ import annotations
from .runtime import create_multi_agent_chat_deep_agent
__all__ = ["create_multi_agent_chat_deep_agent"]


@ -0,0 +1,7 @@
"""Tool-name pruning for context editing (exclude lists without dropping protected tools)."""
from __future__ import annotations
from .prune_tool_names import PRUNE_PROTECTED_TOOL_NAMES, safe_exclude_tools
__all__ = ["PRUNE_PROTECTED_TOOL_NAMES", "safe_exclude_tools"]


@ -0,0 +1,26 @@
"""Tool names excluded from context-editing prune when bound."""
from __future__ import annotations
from collections.abc import Sequence
from langchain_core.tools import BaseTool
PRUNE_PROTECTED_TOOL_NAMES: frozenset[str] = frozenset(
{
"generate_report",
"generate_resume",
"generate_podcast",
"generate_video_presentation",
"generate_image",
"read_email",
"search_emails",
"invalid",
},
)


def safe_exclude_tools(tools: Sequence[BaseTool]) -> tuple[str, ...]:
    """Names from ``PRUNE_PROTECTED_TOOL_NAMES`` that appear in ``tools``."""
    enabled = {t.name for t in tools}
    return tuple(n for n in PRUNE_PROTECTED_TOOL_NAMES if n in enabled)
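
A small usage sketch of the filter above, assuming a stand-in `FakeTool` class instead of a real `BaseTool` (only the `name` attribute is read). The tool name `web_search` and the trimmed protected set are illustrative.

```python
from collections.abc import Sequence
from dataclasses import dataclass

# Trimmed copy of the protected set, for illustration only.
PRUNE_PROTECTED_TOOL_NAMES = frozenset({"generate_report", "read_email", "invalid"})


@dataclass
class FakeTool:
    """Hypothetical stand-in for BaseTool; only ``name`` is used."""

    name: str


def safe_exclude_tools(tools: Sequence[FakeTool]) -> tuple[str, ...]:
    """Protected names that are actually bound in ``tools``."""
    enabled = {t.name for t in tools}
    return tuple(n for n in PRUNE_PROTECTED_TOOL_NAMES if n in enabled)


bound = [FakeTool("read_email"), FakeTool("web_search"), FakeTool("generate_report")]
excluded = safe_exclude_tools(bound)
```

The result iterates a `frozenset`, so the order of the returned names is not guaranteed; callers that need a stable order would have to sort it themselves.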


@ -0,0 +1,7 @@
"""Sync compile of the main-agent LangGraph graph (middleware + ``create_agent``)."""
from __future__ import annotations
from .compile_graph_sync import build_compiled_agent_graph_sync
__all__ = ["build_compiled_agent_graph_sync"]


@ -0,0 +1,83 @@
"""Synchronous graph compile (middleware + ``create_agent``)."""
from __future__ import annotations
from collections.abc import Sequence
from typing import Any
from deepagents import __version__ as deepagents_version
from langchain.agents import create_agent
from langchain_core.language_models import BaseChatModel
from langchain_core.tools import BaseTool
from langgraph.types import Checkpointer
from app.agents.chat.multi_agent_chat.main_agent.middleware.stack import (
build_main_agent_deepagent_middleware,
)
from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode
from app.agents.chat.shared.context import SurfSenseContextSchema
from app.db import ChatVisibility


def build_compiled_agent_graph_sync(
    *,
    llm: BaseChatModel,
    tools: Sequence[BaseTool],
    final_system_prompt: str,
    backend_resolver: Any,
    filesystem_mode: FilesystemMode,
    search_space_id: int,
    user_id: str | None,
    thread_id: int | None,
    visibility: ChatVisibility,
    anon_session_id: str | None,
    available_connectors: list[str] | None,
    available_document_types: list[str] | None,
    mentioned_document_ids: list[int] | None,
    max_input_tokens: int | None,
    flags: AgentFeatureFlags,
    checkpointer: Checkpointer,
    subagent_dependencies: dict[str, Any],
    mcp_tools_by_agent: dict[str, list[BaseTool]] | None = None,
    disabled_tools: list[str] | None = None,
):
    """Sync compile: middleware + ``create_agent`` (run via ``asyncio.to_thread``)."""
    main_agent_middleware = build_main_agent_deepagent_middleware(
        llm=llm,
        tools=tools,
        backend_resolver=backend_resolver,
        filesystem_mode=filesystem_mode,
        search_space_id=search_space_id,
        user_id=user_id,
        thread_id=thread_id,
        visibility=visibility,
        anon_session_id=anon_session_id,
        available_connectors=available_connectors,
        available_document_types=available_document_types,
        mentioned_document_ids=mentioned_document_ids,
        max_input_tokens=max_input_tokens,
        flags=flags,
        subagent_dependencies=subagent_dependencies,
        checkpointer=checkpointer,
        mcp_tools_by_agent=mcp_tools_by_agent,
        disabled_tools=disabled_tools,
    )
    agent = create_agent(
        llm,
        system_prompt=final_system_prompt,
        tools=list(tools),
        middleware=main_agent_middleware,
        context_schema=SurfSenseContextSchema,
        checkpointer=checkpointer,
    )
    return agent.with_config(
        {
            "recursion_limit": 10_000,
            "metadata": {
                "ls_integration": "deepagents",
                "versions": {"deepagents": deepagents_version},
            },
        }
    )
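
The docstring above says the synchronous compile is run via ``asyncio.to_thread``. A minimal sketch of that calling pattern, assuming a hypothetical blocking builder `build_graph_sync` in place of the real compile step:

```python
import asyncio
import threading


def build_graph_sync(name: str) -> dict[str, str]:
    """Hypothetical stand-in for a blocking, synchronous compile step."""
    return {"graph": name, "thread": threading.current_thread().name}


async def build_graph(name: str) -> dict[str, str]:
    # Run the blocking builder in a worker thread so the event loop
    # stays free to serve other coroutines while it compiles.
    return await asyncio.to_thread(build_graph_sync, name)


result = asyncio.run(build_graph("main_agent"))
```

The real factory that wraps `build_compiled_agent_graph_sync` is not shown in this diff, so the async wrapper here only illustrates the general shape.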


@ -0,0 +1,10 @@
"""Action-log middleware: audit row per tool call (impl + builder)."""
from .builder import build_action_log_mw
from .middleware import ActionLogMiddleware, ToolDefinition
__all__ = [
"ActionLogMiddleware",
"ToolDefinition",
"build_action_log_mw",
]


@ -0,0 +1,36 @@
"""Audit row per tool call (reversibility metadata)."""
from __future__ import annotations
import logging
from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
from app.agents.chat.multi_agent_chat.shared.middleware.flags import enabled
from .middleware import ActionLogMiddleware


def build_action_log_mw(
    *,
    flags: AgentFeatureFlags,
    thread_id: int | None,
    search_space_id: int,
    user_id: str | None,
) -> ActionLogMiddleware | None:
    if not enabled(flags, "enable_action_log") or thread_id is None:
        return None
    try:
        # No built-in tool declares a ``reverse`` callable yet, so the action
        # log runs without a tool_definitions map. Reversibility is opt-in per
        # tool via ``ToolDefinition.reverse`` and can be wired here when used.
        return ActionLogMiddleware(
            thread_id=thread_id,
            search_space_id=search_space_id,
            user_id=user_id,
        )
    except Exception:  # pragma: no cover - defensive
        logging.warning(
            "ActionLogMiddleware init failed; running without it.",
            exc_info=True,
        )
        return None


@ -0,0 +1,388 @@
"""Append-only action-log middleware for the SurfSense agent.
Wraps every tool call via :meth:`AgentMiddleware.awrap_tool_call` and writes
a row to :class:`~app.db.AgentActionLog` after the tool returns. Tools opt
into reversibility by declaring a ``reverse`` callable on their
:class:`ToolDefinition`; the rendered descriptor is persisted in
``reverse_descriptor`` for use by
``/api/threads/{thread_id}/revert/{action_id}``.
Design points:
* **Defensive.** Logging never blocks the agent. We catch every exception
on the DB write path and emit a warning; the tool's ``ToolMessage``
result is always returned untouched.
* **Lightweight payload.** Only the tool ``name`` + ``args`` (capped) +
``result_id`` + ``reverse_descriptor`` are stored. Tool output text
remains in the LangGraph checkpoint / spilled tool-output files.
* **Best-effort reversibility.** We invoke ``reverse(args, result_obj)``
with the parsed JSON result when the tool's content is a JSON object;
otherwise the raw text is passed. Exceptions in the reverse callable
are swallowed and logged; a failed descriptor render simply means the
action is NOT marked reversible.
"""
from __future__ import annotations
import json
import logging
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
from langchain.agents.middleware import AgentMiddleware
from langchain_core.callbacks import adispatch_custom_event
from langchain_core.messages import ToolMessage
from app.agents.chat.multi_agent_chat.shared.feature_flags import get_flags

if TYPE_CHECKING:  # pragma: no cover - type-only
    from langchain.agents.middleware.types import ToolCallRequest
    from langgraph.types import Command

logger = logging.getLogger(__name__)


@dataclass
class ToolDefinition:
    """Reversibility descriptor consumed by :class:`ActionLogMiddleware`.

    Only ``name`` and ``reverse`` are read by the middleware; the remaining
    fields let callers and tests describe a tool declaratively. A tool is
    marked reversible in the action log when ``reverse`` is set and renders a
    descriptor without raising.

    Attributes:
        name: Unique identifier for the tool.
        description: Human-readable description of what the tool does.
        factory: Optional callable that builds the tool (unused by the
            middleware; retained for declarative call sites/tests).
        reverse: Optional callable that, given the tool's ``(args, result)``,
            returns a ``ReverseDescriptor`` describing the inverse invocation.
    """

    name: str
    description: str = ""
    factory: Callable[[dict[str, Any]], Any] | None = None
    reverse: Callable[[dict[str, Any], Any], dict[str, Any]] | None = None


# Cap for the persisted ``args`` JSON to avoid bloating the action log with
# accidentally-huge inputs. Values are truncated and a flag is set in the
# stored payload so consumers can detect truncation.
_MAX_ARGS_PERSIST_BYTES = 32 * 1024  # 32KB


class ActionLogMiddleware(AgentMiddleware):
    """Persist a row in :class:`AgentActionLog` after every tool call.

    Should be placed near the OUTERMOST end of the tool-call wrapping stack
    so that it sees the *final* :class:`ToolMessage` after all retries,
    permission checks, and dedup logic have run. In practice that means
    placing it just inside :class:`PermissionMiddleware` and outside
    :class:`DedupHITLToolCallsMiddleware`.

    The middleware is fully a no-op when:

    * the master kill-switch ``SURFSENSE_DISABLE_NEW_AGENT_STACK`` is set
      (checked via :func:`get_flags`),
    * the per-feature flag ``enable_action_log`` is off, or
    * persistence raises (defensive: tool-call dispatch always succeeds).

    Args:
        thread_id: The current chat thread's primary-key id. Required to
            persist a row; if ``None`` the middleware silently no-ops.
        search_space_id: Search-space id for cascade-on-delete safety.
        user_id: UUID string of the user driving this turn (nullable in
            anonymous mode).
        tool_definitions: Optional mapping of tool name -> :class:`ToolDefinition`
            so the middleware can look up the tool's ``reverse`` callable.
            When omitted, no actions are marked reversible.
    """

    tools = ()

    def __init__(
        self,
        *,
        thread_id: int | None,
        search_space_id: int,
        user_id: str | None,
        tool_definitions: dict[str, ToolDefinition] | None = None,
    ) -> None:
        super().__init__()
        self._thread_id = thread_id
        self._search_space_id = search_space_id
        self._user_id = user_id
        self._tool_definitions = dict(tool_definitions or {})

    def _enabled(self) -> bool:
        flags = get_flags()
        if flags.disable_new_agent_stack:
            return False
        return bool(flags.enable_action_log) and self._thread_id is not None

    async def awrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
    ) -> ToolMessage | Command[Any]:
        if not self._enabled():
            return await handler(request)
        result: ToolMessage | Command[Any]
        error_payload: dict[str, Any] | None = None
        try:
            result = await handler(request)
        except Exception as exc:
            # Persist the failure too so revert/audit can see it, then
            # re-raise so downstream middleware (RetryAfter, etc.) handles it.
            error_payload = {"type": type(exc).__name__, "message": str(exc)}
            await self._record(
                request=request,
                result=None,
                error_payload=error_payload,
            )
            raise
        await self._record(request=request, result=result, error_payload=None)
        return result

    async def _record(
        self,
        *,
        request: ToolCallRequest,
        result: ToolMessage | Command[Any] | None,
        error_payload: dict[str, Any] | None,
    ) -> None:
        """Persist one ``agent_action_log`` row. Defensive: never raises."""
        try:
            from app.db import AgentActionLog, shielded_async_session

            tool_name = _resolve_tool_name(request)
            args_payload = _resolve_args_payload(request)
            result_id = _resolve_result_id(result)
            reverse_descriptor, reversible = self._render_reverse(
                tool_name=tool_name,
                args=_resolve_args_dict(request),
                result=result,
            )
            tool_call_id = _resolve_tool_call_id(request)
            chat_turn_id = _resolve_chat_turn_id(request)
            row = AgentActionLog(
                thread_id=self._thread_id,
                user_id=self._user_id,
                search_space_id=self._search_space_id,
                # ``turn_id`` is the deprecated alias of ``tool_call_id``
                # kept for one release for safe rollback. New consumers
                # should read ``tool_call_id`` directly.
                turn_id=tool_call_id,
                tool_call_id=tool_call_id,
                chat_turn_id=chat_turn_id,
                message_id=_resolve_message_id(request),
                tool_name=tool_name,
                args=args_payload,
                result_id=result_id,
                reversible=reversible,
                reverse_descriptor=reverse_descriptor,
                error=error_payload,
            )
            async with shielded_async_session() as session:
                session.add(row)
                await session.commit()
                row_id = int(row.id) if row.id is not None else None
                row_created_at = row.created_at
        except Exception:
            logger.warning(
                "ActionLogMiddleware failed to persist action log row",
                exc_info=True,
            )
            return
        # Surface a side-channel SSE event so the chat tool card can
        # render a Revert button immediately after the row is durable.
        # ``stream_new_chat`` translates this into a
        # ``data-action-log`` SSE event. We DO NOT include the
        # ``reverse_descriptor`` payload here; only a presence flag.
        try:
            await adispatch_custom_event(
                "action_log",
                {
                    "id": row_id,
                    "lc_tool_call_id": tool_call_id,
                    "chat_turn_id": chat_turn_id,
                    "tool_name": tool_name,
                    "reversible": bool(reversible),
                    "reverse_descriptor_present": reverse_descriptor is not None,
                    "created_at": row_created_at.isoformat()
                    if row_created_at
                    else None,
                    "error": error_payload is not None,
                },
            )
        except Exception:
            logger.debug(
                "ActionLogMiddleware failed to dispatch action_log event",
                exc_info=True,
            )

    def _render_reverse(
        self,
        *,
        tool_name: str,
        args: dict[str, Any] | None,
        result: ToolMessage | Command[Any] | None,
    ) -> tuple[dict[str, Any] | None, bool]:
        """Run the tool's ``reverse`` callable and return its descriptor.

        Returns a tuple of ``(descriptor_or_None, reversible_bool)``. When
        the tool has no ``reverse`` callable, or when the callable raises,
        the action is marked non-reversible.
        """
        if not result or not isinstance(result, ToolMessage):
            return None, False
        if args is None:
            return None, False
        tool_def = self._tool_definitions.get(tool_name)
        if tool_def is None or tool_def.reverse is None:
            return None, False
        try:
            parsed_result = _parse_tool_result_content(result)
            descriptor = tool_def.reverse(args, parsed_result)
        except Exception:
            logger.warning(
                "Reverse descriptor render failed for tool %s",
                tool_name,
                exc_info=True,
            )
            return None, False
        if not isinstance(descriptor, dict):
            return None, False
        return descriptor, True

# ---------------------------------------------------------------------------
# Resolution helpers — defensive against tool_call request shape variation.
# ---------------------------------------------------------------------------


def _resolve_tool_name(request: Any) -> str:
    try:
        tool = getattr(request, "tool", None)
        if tool is not None:
            name = getattr(tool, "name", None)
            if isinstance(name, str) and name:
                return name
        call = getattr(request, "tool_call", None) or {}
        if isinstance(call, dict):
            name = call.get("name")
            if isinstance(name, str) and name:
                return name
    except Exception:  # pragma: no cover - defensive
        pass
    return "unknown"


def _resolve_args_dict(request: Any) -> dict[str, Any] | None:
    try:
        call = getattr(request, "tool_call", None)
        if not isinstance(call, dict):
            return None
        args = call.get("args")
        if isinstance(args, dict):
            return args
        return None
    except Exception:  # pragma: no cover - defensive
        return None


def _resolve_args_payload(request: Any) -> dict[str, Any] | None:
    """Return a JSON-serializable args dict, truncated if too big."""
    args = _resolve_args_dict(request)
    if args is None:
        return None
    try:
        encoded = json.dumps(args, default=str)
    except Exception:
        return {"_repr": repr(args)[:_MAX_ARGS_PERSIST_BYTES]}
    if len(encoded) <= _MAX_ARGS_PERSIST_BYTES:
        return args
    return {
        "_truncated": True,
        "_size": len(encoded),
        "_preview": encoded[:_MAX_ARGS_PERSIST_BYTES],
    }


def _resolve_tool_call_id(request: Any) -> str | None:
    """Return the LangChain ``tool_call.id`` for this request, if any."""
    try:
        call = getattr(request, "tool_call", None) or {}
        if isinstance(call, dict):
            tid = call.get("id")
            if isinstance(tid, str):
                return tid
    except Exception:  # pragma: no cover
        pass
    return None


# Deprecated alias kept for one release. Old callers and tests treated
# ``turn_id`` as if it carried the LangChain tool_call id; the new column
# lives under ``tool_call_id``. Both resolve to the same value today.
_resolve_turn_id = _resolve_tool_call_id


def _resolve_chat_turn_id(request: Any) -> str | None:
    """Return ``configurable.turn_id`` for this request, if accessible.

    ``ToolRuntime.config`` is exposed by LangGraph (see
    ``langgraph/prebuilt/tool_node.py``); the chat-turn correlation id
    lives at ``runtime.config["configurable"]["turn_id"]``.
    """
    try:
        runtime = getattr(request, "runtime", None)
        if runtime is None:
            return None
        config = getattr(runtime, "config", None)
        if not isinstance(config, dict):
            return None
        configurable = config.get("configurable")
        if not isinstance(configurable, dict):
            return None
        value = configurable.get("turn_id")
        if isinstance(value, str) and value:
            return value
    except Exception:  # pragma: no cover - defensive
        pass
    return None


def _resolve_message_id(request: Any) -> str | None:
    """Tool-call IDs serve as best-available message correlator at this layer."""
    return _resolve_tool_call_id(request)


def _resolve_result_id(result: Any) -> str | None:
    if isinstance(result, ToolMessage):
        msg_id = getattr(result, "id", None)
        if isinstance(msg_id, str):
            return msg_id
    return None


def _parse_tool_result_content(result: ToolMessage) -> Any:
    content = result.content
    if isinstance(content, str):
        try:
            return json.loads(content)
        except (json.JSONDecodeError, ValueError):
            return content
    return content

__all__ = ["ActionLogMiddleware", "ToolDefinition"]
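
A minimal sketch of how a tool could opt into reversibility, assuming a hypothetical `create_note` tool whose inverse is a `delete_note` call. The descriptor shape, both tool names, and the helper `render_reverse` are illustrative; the `ToolDefinition` here is a trimmed local copy so the example runs on its own.

```python
import json
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class ToolDefinition:
    """Trimmed local copy of the dataclass above, for illustration only."""

    name: str
    reverse: Optional[Callable[[dict[str, Any], Any], dict[str, Any]]] = None


def reverse_create_note(args: dict[str, Any], result: Any) -> dict[str, Any]:
    """Hypothetical reverse: undo ``create_note`` by deleting the new note."""
    return {"tool": "delete_note", "args": {"note_id": result["note_id"]}}


def render_reverse(
    tool_def: ToolDefinition, args: dict[str, Any], content: str
) -> tuple[Optional[dict[str, Any]], bool]:
    """Mirror the middleware's flow: parse JSON if possible, swallow errors."""
    if tool_def.reverse is None:
        return None, False
    try:
        parsed: Any = json.loads(content)
    except ValueError:
        parsed = content  # non-JSON output is passed through as raw text
    try:
        descriptor = tool_def.reverse(args, parsed)
    except Exception:
        return None, False  # a failed render means "not reversible"
    if not isinstance(descriptor, dict):
        return None, False
    return descriptor, True


definition = ToolDefinition("create_note", reverse_create_note)
ok = render_reverse(definition, {"title": "x"}, '{"note_id": 7}')
failed = render_reverse(definition, {"title": "x"}, "plain text output")
```

In the second call the tool output is not JSON, the reverse callable raises while indexing the string, and the action is simply recorded as non-reversible.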


@ -0,0 +1,9 @@
"""Anonymous-document middleware: Redis hydration, cloud only (impl + builder)."""
from .builder import build_anonymous_doc_mw
from .middleware import AnonymousDocumentMiddleware
__all__ = [
"AnonymousDocumentMiddleware",
"build_anonymous_doc_mw",
]


@ -0,0 +1,17 @@
"""Anonymous document hydration from Redis (cloud only)."""
from __future__ import annotations
from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode
from .middleware import AnonymousDocumentMiddleware


def build_anonymous_doc_mw(
    *,
    filesystem_mode: FilesystemMode,
    anon_session_id: str | None,
) -> AnonymousDocumentMiddleware | None:
    if filesystem_mode != FilesystemMode.CLOUD:
        return None
    return AnonymousDocumentMiddleware(anon_session_id=anon_session_id)


@ -0,0 +1,96 @@
"""Lightweight middleware that loads the anonymous-session document into state.
Anonymous chats receive a single uploaded document via Redis (no DB row,
read-only). This middleware loads it once on the first turn into
``state['kb_anon_doc']`` so:
* :class:`KnowledgeTreeMiddleware` can render the synthetic ``/documents``
view without touching the DB.
* :class:`KnowledgePriorityMiddleware` skips hybrid search and emits a
degenerate priority list.
* :class:`KBPostgresBackend` (``als_info`` / ``aread`` / ``_load_file_data``)
recognises the synthetic path.
The middleware is a no-op when ``anon_session_id`` is not provided or when
the document is already cached in state.
"""
from __future__ import annotations
import json
import logging
from typing import Any
from langchain.agents.middleware import AgentMiddleware, AgentState
from langgraph.runtime import Runtime
from app.agents.chat.multi_agent_chat.shared.path_resolver import (
DOCUMENTS_ROOT,
safe_filename,
)
from app.agents.chat.multi_agent_chat.shared.state.filesystem_state import (
SurfSenseFilesystemState,
)
logger = logging.getLogger(__name__)


class AnonymousDocumentMiddleware(AgentMiddleware):  # type: ignore[type-arg]
    """Load the anonymous user's uploaded document from Redis into state."""

    tools = ()
    state_schema = SurfSenseFilesystemState

    def __init__(self, *, anon_session_id: str | None) -> None:
        self.anon_session_id = anon_session_id

    async def abefore_agent(  # type: ignore[override]
        self,
        state: AgentState,
        runtime: Runtime[Any],
    ) -> dict[str, Any] | None:
        del runtime
        if not self.anon_session_id:
            return None
        if state.get("kb_anon_doc"):
            return None
        anon_doc = await self._load_anon_document()
        if anon_doc is None:
            return None
        return {"kb_anon_doc": anon_doc}

    async def _load_anon_document(self) -> dict[str, Any] | None:
        """Read ``anon:doc:<session_id>`` from Redis."""
        try:
            import redis.asyncio as aioredis  # local import to keep cold paths cheap

            from app.config import config

            redis_client = aioredis.from_url(
                config.REDIS_APP_URL, decode_responses=True
            )
            try:
                redis_key = f"anon:doc:{self.anon_session_id}"
                data = await redis_client.get(redis_key)
                if not data:
                    return None
                payload = json.loads(data)
            finally:
                await redis_client.aclose()
        except Exception as exc:
            logger.warning("Failed to load anonymous document from Redis: %s", exc)
            return None
        title = str(payload.get("filename") or "uploaded_document")
        content = str(payload.get("content") or "")
        path = f"{DOCUMENTS_ROOT}/{safe_filename(title)}"
        return {
            "path": path,
            "title": title,
            "content": content,
            "chunks": [{"chunk_id": -1, "content": content}] if content else [],
        }

__all__ = ["AnonymousDocumentMiddleware"]
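
A sketch of the state payload that `_load_anon_document` builds, pulled out as a pure function so it can run without Redis. The value of `DOCUMENTS_ROOT` and the behaviour of `safe_filename` are assumptions here; the real ones live in `path_resolver` and are not shown in this diff.

```python
import json
import re
from typing import Any

DOCUMENTS_ROOT = "/documents"  # assumption: real constant lives in path_resolver


def safe_filename(title: str) -> str:
    """Hypothetical stand-in: replace anything outside [A-Za-z0-9._-]."""
    return re.sub(r"[^A-Za-z0-9._-]", "_", title) or "uploaded_document"


def build_anon_doc(raw: str) -> dict[str, Any]:
    """Turn the Redis JSON string into the ``kb_anon_doc`` state entry."""
    payload = json.loads(raw)
    title = str(payload.get("filename") or "uploaded_document")
    content = str(payload.get("content") or "")
    return {
        "path": f"{DOCUMENTS_ROOT}/{safe_filename(title)}",
        "title": title,
        "content": content,
        # A single synthetic chunk with id -1; no chunks for empty content.
        "chunks": [{"chunk_id": -1, "content": content}] if content else [],
    }


doc = build_anon_doc('{"filename": "my notes.txt", "content": "hello"}')
empty = build_anon_doc("{}")
```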


@ -0,0 +1,13 @@
"""Per-thread cooperative lock around the whole turn."""
from __future__ import annotations
from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
from app.agents.chat.multi_agent_chat.shared.middleware.busy_mutex import (
BusyMutexMiddleware,
)
from app.agents.chat.multi_agent_chat.shared.middleware.flags import enabled


def build_busy_mutex_mw(flags: AgentFeatureFlags) -> BusyMutexMiddleware | None:
    return BusyMutexMiddleware() if enabled(flags, "enable_busy_mutex") else None


@ -0,0 +1,32 @@
"""SubAgent ``task`` tool wiring required for HITL inside subagents.
Replaces upstream ``SubAgentMiddleware`` to:
- share the parent's checkpointer with each subagent,
- forward ``runtime.config`` (thread_id, recursion_limit, …) into nested invokes,
- isolate each parallel ``task`` call in its own checkpoint slot via
per-call ``thread_id`` namespacing,
- bridge ``Command(resume=...)`` from the parent into the subagent via the
``config["configurable"]["surfsense_resume_value"]`` side-channel, keyed by
``tool_call_id`` so parallel siblings never race on a shared scalar,
- target the resume at the captured interrupt id so a follow-up
``HumanInTheLoopMiddleware.after_model`` does not consume the same payload,
- stamp each subagent's pending interrupt with the parent's ``tool_call_id``
so ``stream_resume_chat`` can route a flat ``decisions`` list back to the
right paused subagent.
Module layout
-------------
- ``constants``: shared keys / limits.
- ``config``: RunnableConfig + side-channel resume read + per-call ``thread_id``.
- ``resume``: pending-interrupt detection, fan-out, ``Command(resume=...)`` builder.
- ``propagation``: ``wrap_with_tool_call_id`` helper for stamping interrupt values.
- ``resume_routing``: slice a flat decisions list to per-``tool_call_id`` payloads.
- ``task_tool``: the ``task`` tool factory (sync + async), and the catch-and-stamp chokepoint.
- ``middleware``: :class:`SurfSenseCheckpointedSubAgentMiddleware` itself.
"""
from .middleware import SurfSenseCheckpointedSubAgentMiddleware
__all__ = ["SurfSenseCheckpointedSubAgentMiddleware"]
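
The ``resume_routing`` module is described above but not shown in this excerpt. The following is only a guess at the idea, assuming each paused subagent is known by its ``tool_call_id`` and the number of decisions it expects, and that the flat list is ordered the same way. The function name `slice_decisions` and its input shape are hypothetical.

```python
from typing import Any


def slice_decisions(
    pending: list[tuple[str, int]],
    decisions: list[dict[str, Any]],
) -> dict[str, dict[str, list[dict[str, Any]]]]:
    """Hypothetical: split a flat decisions list per ``tool_call_id``.

    ``pending`` holds ``(tool_call_id, expected_decision_count)`` pairs in
    the same order the decisions were collected.
    """
    expected = sum(count for _, count in pending)
    if expected != len(decisions):
        raise ValueError(f"expected {expected} decisions, got {len(decisions)}")
    routed: dict[str, dict[str, list[dict[str, Any]]]] = {}
    offset = 0
    for tool_call_id, count in pending:
        routed[tool_call_id] = {"decisions": decisions[offset : offset + count]}
        offset += count
    return routed


routed = slice_decisions(
    [("call_a", 2), ("call_b", 1)],
    [{"type": "approve"}, {"type": "reject"}, {"type": "approve"}],
)
```

The output shape follows the ``{tool_call_id: {"decisions": [...]}}`` form mentioned in the ``config`` module's docstrings, so each subagent reads only its own slice.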


@ -0,0 +1,125 @@
"""RunnableConfig wiring for nested subagent invocations.
Forwards the parent's ``runtime.config`` (thread_id, …) into the subagent and
exposes the side-channel ``stream_resume_chat`` uses to ferry resume payloads.
"""
from __future__ import annotations
import logging
from typing import Any
from langchain.tools import ToolRuntime
from .constants import DEFAULT_SUBAGENT_RECURSION_LIMIT
logger = logging.getLogger(__name__)
# langgraph stores the parent task's scratchpad under this configurable key;
# subagents inherit the chain via ``parent_scratchpad`` fallback.
_LANGGRAPH_SCRATCHPAD_KEY = "__pregel_scratchpad"


def subagent_invoke_config(runtime: ToolRuntime) -> dict[str, Any]:
    """RunnableConfig for the nested invoke; raises ``recursion_limit`` and isolates ``thread_id``.

    Each parallel subagent invocation lands in its own checkpoint slot keyed
    by an extended ``thread_id`` of the form ``{parent_thread}::task:{tool_call_id}``.
    The same call across the resume cycle keeps reading from the same snapshot
    (``tool_call_id`` is stable per LLM-emitted call).

    We namespace via ``thread_id`` rather than ``checkpoint_ns`` because
    langgraph's ``aget_state`` interprets a non-empty ``checkpoint_ns`` as a
    subgraph path and raises ``ValueError("Subgraph X not found")``.
    """
    merged: dict[str, Any] = dict(runtime.config) if runtime.config else {}
    current_limit = merged.get("recursion_limit")
    try:
        current_int = int(current_limit) if current_limit is not None else 0
    except (TypeError, ValueError):
        current_int = 0
    if current_int < DEFAULT_SUBAGENT_RECURSION_LIMIT:
        merged["recursion_limit"] = DEFAULT_SUBAGENT_RECURSION_LIMIT
    configurable: dict[str, Any] = dict(merged.get("configurable") or {})
    parent_thread_id = configurable.get("thread_id")
    per_call_suffix = f"task:{runtime.tool_call_id}"
    configurable["thread_id"] = (
        f"{parent_thread_id}::{per_call_suffix}"
        if parent_thread_id
        else per_call_suffix
    )
    merged["configurable"] = configurable
    return merged


def consume_surfsense_resume(runtime: ToolRuntime) -> Any:
    """Pop the resume payload for *this* call's ``tool_call_id``.

    The configurable holds ``surfsense_resume_value: dict[tool_call_id, payload]``
    so parallel sibling subagents (each with their own ``tool_call_id``) read
    only their own decision and never race on a shared scalar.
    """
    cfg = runtime.config or {}
    configurable = cfg.get("configurable") if isinstance(cfg, dict) else None
    if not isinstance(configurable, dict):
        return None
    by_tcid = configurable.get("surfsense_resume_value")
    if not isinstance(by_tcid, dict):
        return None
    payload = by_tcid.pop(runtime.tool_call_id, None)
    if not by_tcid:
        configurable.pop("surfsense_resume_value", None)
    return payload


def has_surfsense_resume(runtime: ToolRuntime) -> bool:
    """True iff a resume payload for this call's ``tool_call_id`` is queued (non-destructive)."""
    cfg = runtime.config or {}
    configurable = cfg.get("configurable") if isinstance(cfg, dict) else None
    if not isinstance(configurable, dict):
        return False
    by_tcid = configurable.get("surfsense_resume_value")
    if not isinstance(by_tcid, dict):
        return False
    return runtime.tool_call_id in by_tcid


def drain_parent_null_resume(runtime: ToolRuntime) -> None:
    """Consume the parent's lingering ``NULL_TASK_ID/RESUME`` write before delegating.

    ``stream_resume_chat`` wakes the main agent with
    ``Command(resume={tool_call_id: {"decisions": [...]}})`` so the previously
    propagated parent-level interrupt can return. langgraph stores that
    payload as the parent task's ``null_resume`` pending write. The ``task``
    tool then forwards this turn's slice into the subagent via its own
    ``Command(resume=...)``. While the subagent is mid-execution, any *new*
    ``interrupt()`` inside it (e.g. a follow-up tool call after a mixed
    approve/reject) walks ``subagent_scratchpad -> parent_scratchpad.get_null_resume``
    and picks up the parent's still-live decisions, which mismatch against a
    different number of hanging tool calls and crash
    ``HumanInTheLoopMiddleware``.

    Draining the write here closes that cross-graph leak so subagent
    interrupts pause cleanly and bubble back up as a fresh approval card.
    """
    cfg = runtime.config or {}
    configurable = cfg.get("configurable") if isinstance(cfg, dict) else None
    if not isinstance(configurable, dict):
        return
    scratchpad = configurable.get(_LANGGRAPH_SCRATCHPAD_KEY)
    if scratchpad is None:
        return
    consume = getattr(scratchpad, "get_null_resume", None)
    if not callable(consume):
        return
    try:
        consume(True)
    except Exception:
        # Defensive: if langgraph's internal scratchpad shape changes we don't
        # want to break the resume path. Worst case the original ValueError
        # still surfaces; same behavior as before this fix.
        logger.debug(
            "drain_parent_null_resume: scratchpad.get_null_resume raised",
            exc_info=True,
        )
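
The two mechanisms in this module, per-call ``thread_id`` namespacing and the keyed resume side-channel, can be sketched with plain dicts. The helper names below are illustrative and take the ``configurable`` dict directly instead of a ``ToolRuntime``.

```python
from typing import Any


def namespaced_thread_id(parent_thread_id: Any, tool_call_id: str) -> str:
    """Build the per-call checkpoint key ``{parent}::task:{tool_call_id}``."""
    suffix = f"task:{tool_call_id}"
    return f"{parent_thread_id}::{suffix}" if parent_thread_id else suffix


def consume_resume(configurable: dict[str, Any], tool_call_id: str) -> Any:
    """Pop only this call's payload; drop the key once the map is empty."""
    by_tcid = configurable.get("surfsense_resume_value")
    if not isinstance(by_tcid, dict):
        return None
    payload = by_tcid.pop(tool_call_id, None)
    if not by_tcid:
        configurable.pop("surfsense_resume_value", None)
    return payload


configurable: dict[str, Any] = {
    "thread_id": 42,
    "surfsense_resume_value": {
        "call_a": {"decisions": [{"type": "approve"}]},
        "call_b": {"decisions": [{"type": "reject"}]},
    },
}
key_a = namespaced_thread_id(configurable["thread_id"], "call_a")
first = consume_resume(configurable, "call_a")
still_queued = "call_b" in configurable["surfsense_resume_value"]
second = consume_resume(configurable, "call_b")
```

Each sibling reads only the entry stored under its own ``tool_call_id``, and the side-channel key disappears after the last payload is consumed.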


@ -0,0 +1,89 @@
"""Constants shared by the checkpointed subagent middleware."""
from __future__ import annotations
import os
# Mirror of deepagents.middleware.subagents._EXCLUDED_STATE_KEYS.
EXCLUDED_STATE_KEYS = frozenset(
{
"messages",
"todos",
"structured_response",
"skills_metadata",
"memory_contents",
}
)
# Match the parent graph's budget; the LangGraph default of 25 trips on
# multi-step subagent runs.
DEFAULT_SUBAGENT_RECURSION_LIMIT = 10_000


def _read_timeout_env(name: str, default: float) -> float:
    """Parse ``name`` from the environment; fall back to ``default`` on bad values.

    Kept as a free function so the module-level constants stay constants
    after import; tests can monkeypatch this and re-evaluate via
    ``importlib.reload`` if they need a different value mid-process.
    """
    raw = os.environ.get(name)
    if not raw:
        return default
    try:
        value = float(raw)
    except (TypeError, ValueError):
        return default
    return value if value > 0 else default


# Wall-clock budget for a single ``task(subagent, ...)`` invocation.
# Subagents that run hot (image generation with slow vendors, KB writes
# behind a sluggish embedder) can otherwise wedge the orchestrator until
# the next checkpoint heartbeat. Note that ``_read_timeout_env`` treats ``0``
# (and any non-positive or unparsable value) as "use the default", so the
# environment variable cannot disable the timeout on its own.
DEFAULT_SUBAGENT_INVOKE_TIMEOUT_SECONDS: float = _read_timeout_env(
    "SURFSENSE_SUBAGENT_INVOKE_TIMEOUT_SECONDS",
    default=300.0,
)


def _read_int_env(name: str, default: int) -> int:
    raw = os.environ.get(name)
    if not raw:
        return default
    try:
        value = int(raw)
    except (TypeError, ValueError):
        return default
    return value if value > 0 else default

# Maximum number of children that ``task(..., tasks=[...])`` runs in
# parallel via ``asyncio.gather`` + ``Semaphore``. Bounded so a runaway
# fanout cannot starve unrelated subagents (each child still owns an
# LLM call + DB session). Set ``SURFSENSE_TASK_BATCH_CONCURRENCY=1`` to
# effectively serialise batches without changing the schema.
DEFAULT_SUBAGENT_BATCH_CONCURRENCY: int = _read_int_env(
"SURFSENSE_TASK_BATCH_CONCURRENCY",
default=3,
)
# Max number of children in a single batched ``task`` call. Hard upper
# bound is a safety net for prompt-injection / runaway loops; the orchestrator
# rarely needs more than a handful of concurrent specialists.
MAX_SUBAGENT_BATCH_SIZE: int = _read_int_env(
"SURFSENSE_TASK_BATCH_MAX_SIZE",
default=8,
)
# Soft threshold for per-turn cumulative ``task(...)`` invocations across
# **all** subagents. Once the sum of ``state['billable_calls']`` values
# crosses this number, the runtime appends a one-shot warning ToolMessage
# instructing the orchestrator to wrap up the turn. Tunable so heavy-research
# turns (which legitimately need 15+ specialist calls) don't trip the alarm
# in production. As parsed by ``_read_int_env``, ``0`` falls back to the
# default rather than disabling the warning.
DEFAULT_SUBAGENT_BILLABLE_THRESHOLD: int = _read_int_env(
"SURFSENSE_SUBAGENT_BILLABLE_THRESHOLD",
default=15,
)
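
A minimal sketch of how the env readers above treat edge values. The helper body is copied from the module; the variable name ``EXAMPLE_LIMIT`` and the ``resolve`` wrapper are hypothetical and exist only for illustration. As written, the ``value > 0`` guard means ``0`` resolves to the default.

```python
from __future__ import annotations

import os


def _read_int_env(name: str, default: int) -> int:
    # Copied from the module above: unset, empty, unparsable, or
    # non-positive values all resolve to ``default``.
    raw = os.environ.get(name)
    if not raw:
        return default
    try:
        value = int(raw)
    except (TypeError, ValueError):
        return default
    return value if value > 0 else default


def resolve(raw: str | None, default: int = 3) -> int:
    # ``EXAMPLE_LIMIT`` is a hypothetical variable name used only here.
    if raw is None:
        os.environ.pop("EXAMPLE_LIMIT", None)
    else:
        os.environ["EXAMPLE_LIMIT"] = raw
    return _read_int_env("EXAMPLE_LIMIT", default)
```

Calling ``resolve("8")`` yields the parsed value, while ``resolve("0")``, ``resolve("-2")``, and ``resolve("abc")`` all yield the default.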

@ -0,0 +1,152 @@
"""SubAgent middleware that compiles each subagent against the parent checkpointer."""
from __future__ import annotations
import time
from typing import Any, cast
from deepagents.backends.protocol import BackendFactory, BackendProtocol
from deepagents.middleware.subagents import (
TASK_SYSTEM_PROMPT,
CompiledSubAgent,
SubAgent,
SubAgentMiddleware,
)
from langchain.agents import create_agent
from langchain.chat_models import init_chat_model
from langgraph.types import Checkpointer
from app.agents.chat.multi_agent_chat.subagents.shared.spec import (
SURF_CONTEXT_HINT_PROVIDER_KEY,
)
from app.utils.perf import get_perf_logger
from .task_tool import build_task_tool_with_parent_config
_perf_log = get_perf_logger()
class SurfSenseCheckpointedSubAgentMiddleware(SubAgentMiddleware):
"""``SubAgentMiddleware`` variant that compiles each subagent against the parent checkpointer."""
def __init__(
self,
*,
checkpointer: Checkpointer,
backend: BackendProtocol | BackendFactory,
subagents: list[SubAgent | CompiledSubAgent],
system_prompt: str | None = TASK_SYSTEM_PROMPT,
task_description: str | None = None,
search_space_id: int | None = None,
) -> None:
self._surf_checkpointer = checkpointer
super(SubAgentMiddleware, self).__init__()
if not subagents:
raise ValueError(
"At least one subagent must be specified when using the new API"
)
self._backend = backend
self._subagents = subagents
# Search-space id is captured at build time (the orchestrator runs in
# exactly one search space for its lifetime). The spawn-paused kill
# switch keys on it so an operator can quarantine one workspace
# without affecting the rest of the deployment.
self._search_space_id = search_space_id
subagent_specs = self._surf_compile_subagent_graphs()
task_tool = build_task_tool_with_parent_config(
subagent_specs,
task_description,
search_space_id=search_space_id,
)
if system_prompt and subagent_specs:
agents_desc = "\n".join(
f"- {s['name']}: {s['description']}" for s in subagent_specs
)
self.system_prompt = (
system_prompt + "\n\nAvailable subagent types:\n" + agents_desc
)
else:
self.system_prompt = system_prompt
self.tools = [task_tool]
def _surf_compile_subagent_graphs(self) -> list[dict[str, Any]]:
"""Mirror of ``SubAgentMiddleware._get_subagents`` that threads the parent checkpointer."""
specs: list[dict[str, Any]] = []
loop_start = time.perf_counter()
timings: list[tuple[str, float, str]] = [] # (name, elapsed, source)
for spec in self._subagents:
spec_start = time.perf_counter()
# Provider may be ``None`` (no hint), in which case task_tool
# skips the prepend step. We forward the key unconditionally so
# the registry shape is uniform.
hint_provider = cast(dict, spec).get(SURF_CONTEXT_HINT_PROVIDER_KEY)
if "runnable" in spec:
compiled = cast(CompiledSubAgent, spec)
specs.append(
{
"name": compiled["name"],
"description": compiled["description"],
"runnable": compiled["runnable"],
SURF_CONTEXT_HINT_PROVIDER_KEY: hint_provider,
}
)
timings.append(
(compiled["name"], time.perf_counter() - spec_start, "precompiled")
)
continue
if "model" not in spec:
msg = f"SubAgent '{spec['name']}' must specify 'model'"
raise ValueError(msg)
if "tools" not in spec:
msg = f"SubAgent '{spec['name']}' must specify 'tools'"
raise ValueError(msg)
model = spec["model"]
if isinstance(model, str):
model = init_chat_model(model)
middleware: list[Any] = list(spec.get("middleware", []))
tools_count = len(spec.get("tools") or [])
mw_count = len(middleware)
compile_start = time.perf_counter()
runnable = create_agent(
model,
system_prompt=spec["system_prompt"],
tools=spec["tools"],
middleware=middleware,
name=spec["name"],
checkpointer=self._surf_checkpointer,
)
compile_elapsed = time.perf_counter() - compile_start
specs.append(
{
"name": spec["name"],
"description": spec["description"],
"runnable": runnable,
SURF_CONTEXT_HINT_PROVIDER_KEY: hint_provider,
}
)
timings.append(
(
spec["name"],
compile_elapsed,
f"compiled tools={tools_count} mw={mw_count}",
)
)
total_elapsed = time.perf_counter() - loop_start
per_subagent = ", ".join(
f"{name}={elapsed * 1000:.0f}ms[{source}]"
for name, elapsed, source in timings
)
_perf_log.info(
"[subagent_compile] total=%.3fs count=%d details=[%s]",
total_elapsed,
len(timings),
per_subagent,
)
return specs

@ -0,0 +1,38 @@
"""Stamp the parent's ``tool_call_id`` onto a subagent's pending interrupt value.
When a subagent (compiled as a langgraph subgraph and invoked from a parent
tool node) hits an ``interrupt(...)`` from its HITL middleware, langgraph
raises ``GraphInterrupt`` out of ``subagent.[a]invoke(...)``. The parent's
``task`` tool catches that exception, stamps ``tool_call_id`` onto each
``Interrupt.value`` using :func:`wrap_with_tool_call_id`, and re-raises a
fresh ``GraphInterrupt`` whose values carry that stamp.
``stream_resume_chat`` then reads ``parent.state.interrupts[*].value["tool_call_id"]``
to route a flat ``decisions`` list back to the right paused subagent. Without
the stamp, parallel HITL across siblings would collapse into an ambiguous
bucket and resume would fail.
This module hosts only the stamping helper; the catch/re-raise lives in
``task_tool.py`` since that's the single chokepoint where the raw exception
is in our hands.
"""
from __future__ import annotations
from typing import Any
def wrap_with_tool_call_id(value: Any, tool_call_id: str) -> dict[str, Any]:
"""Return a value dict that always carries the parent's ``tool_call_id``.
Dict values are shallow-copied with ``tool_call_id`` stamped on top, so
any value the subagent may already carry under that key (from a deeper
HITL level) is overwritten: the parent's call id is the only one
``stream_resume_chat`` correlates against.
Non-dict values are wrapped as ``{"value": <original>, "tool_call_id": ...}``
so simple ``interrupt("approve?")`` patterns still propagate cleanly.
"""
if isinstance(value, dict):
return {**value, "tool_call_id": tool_call_id}
return {"value": value, "tool_call_id": tool_call_id}
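
A short sketch of the two stamping cases the docstring describes. The helper body is the same as above; the tool name ``send_email`` and the call ids are made-up sample values.

```python
from typing import Any


def wrap_with_tool_call_id(value: Any, tool_call_id: str) -> dict[str, Any]:
    # Same body as the helper above.
    if isinstance(value, dict):
        return {**value, "tool_call_id": tool_call_id}
    return {"value": value, "tool_call_id": tool_call_id}


# HITL bundle that already carries a deeper stamp: the parent's id wins,
# and the original dict is left unmodified (shallow copy).
bundle = {"action_requests": [{"name": "send_email"}], "tool_call_id": "inner"}
stamped_bundle = wrap_with_tool_call_id(bundle, "call_parent_1")

# Scalar interrupt such as ``interrupt("approve?")`` gets wrapped in a dict.
stamped_scalar = wrap_with_tool_call_id("approve?", "call_parent_2")
```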

@ -0,0 +1,76 @@
"""Resume-payload shaping and pending-interrupt detection for subagents.
Splits the work of "given a state snapshot and a parent-stashed resume value,
produce the right ``Command(resume=...)`` for the subagent" into pure helpers.
"""
from __future__ import annotations
from typing import Any
from langgraph.types import Command
def hitlrequest_action_count(pending_value: Any) -> int:
"""Bundle size for a LangChain ``HITLRequest`` payload; ``0`` for non-bundle interrupts."""
if not isinstance(pending_value, dict):
return 0
actions = pending_value.get("action_requests")
if isinstance(actions, list):
return len(actions)
return 0
def fan_out_decisions_to_match(resume_value: Any, expected_count: int) -> Any:
"""Legacy fallback: pad a 1-decision resume to N for an ``action_requests=N`` bundle.
Modern frontend submits N decisions per bundle (one per action_request) so
this is a no-op; kept for backwards compatibility with old in-flight
threads or non-bundle clients that send a single decision.
"""
if expected_count <= 1:
return resume_value
if not isinstance(resume_value, dict):
return resume_value
decisions = resume_value.get("decisions")
if not isinstance(decisions, list) or len(decisions) >= expected_count:
return resume_value
if not decisions:
return resume_value
padded = list(decisions) + [decisions[-1]] * (expected_count - len(decisions))
return {**resume_value, "decisions": padded}
def get_first_pending_subagent_interrupt(state: Any) -> tuple[str | None, Any]:
"""First pending ``(interrupt_id, value)``; ``(None, None)`` if no interrupt.
Assumes at most one pending interrupt per snapshot (sequential tool nodes).
Parallel tool nodes would need an id-aware lookup instead of first-wins.
"""
if state is None:
return None, None
for it in getattr(state, "interrupts", None) or ():
value = getattr(it, "value", None)
interrupt_id = getattr(it, "id", None)
if value is not None:
return (
interrupt_id if isinstance(interrupt_id, str) else None,
value,
)
for sub_task in getattr(state, "tasks", None) or ():
for it in getattr(sub_task, "interrupts", None) or ():
value = getattr(it, "value", None)
interrupt_id = getattr(it, "id", None)
if value is not None:
return (
interrupt_id if isinstance(interrupt_id, str) else None,
value,
)
return None, None
def build_resume_command(resume_value: Any, pending_id: str | None) -> Command:
"""``Command(resume={id: value})`` when ``id`` is known, else fall back to scalar."""
if pending_id is None:
return Command(resume=resume_value)
return Command(resume={pending_id: resume_value})
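
To illustrate the legacy padding path, here is a sketch that exercises ``fan_out_decisions_to_match`` on its own (it has no langgraph dependency). The function body is copied from above; the decision payloads are made-up sample values.

```python
from typing import Any


def fan_out_decisions_to_match(resume_value: Any, expected_count: int) -> Any:
    # Copied from the module above.
    if expected_count <= 1:
        return resume_value
    if not isinstance(resume_value, dict):
        return resume_value
    decisions = resume_value.get("decisions")
    if not isinstance(decisions, list) or len(decisions) >= expected_count:
        return resume_value
    if not decisions:
        return resume_value
    padded = list(decisions) + [decisions[-1]] * (expected_count - len(decisions))
    return {**resume_value, "decisions": padded}


# Legacy client: one decision for a bundle of three actions. The last
# decision is repeated until the list matches the bundle size.
single = {"decisions": [{"type": "approve"}]}
padded = fan_out_decisions_to_match(single, 3)

# Current client: already one decision per action, returned unchanged.
full = {"decisions": [{"type": "approve"}, {"type": "reject"}]}
unchanged = fan_out_decisions_to_match(full, 2)
```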

@ -0,0 +1,183 @@
"""Route a flat ``decisions`` list to per-``tool_call_id`` resume payloads.
The frontend submits decisions in the same order the SSE stream emitted
approval cards. When multiple parallel subagents are paused, the backend uses
this module to:
1. Read ``state.interrupts`` from the parent's paused snapshot, extracting
``[(tool_call_id, action_count), ...]`` from each interrupt's value.
The ``tool_call_id`` is stamped on by ``propagation.wrap_with_tool_call_id``
inside ``task_tool``'s catch-and-stamp block when a subagent's
``GraphInterrupt`` bubbles up through ``[a]task``.
2. Slice the flat ``decisions`` list against that ordered pending list to
produce the dict shape expected by ``consume_surfsense_resume``.
3. Re-key those slices by ``Interrupt.id`` (langgraph's primitive) for use as
the parent-level ``Command(resume={interrupt_id: payload})`` input, the
only shape langgraph accepts when multiple interrupts are pending.
All helpers are pure: callers own the state and the input decisions; we
return new structures and never mutate.
"""
from __future__ import annotations
import logging
from collections.abc import Iterable
from typing import Any
logger = logging.getLogger(__name__)
def slice_decisions_by_tool_call(
decisions: list[dict[str, Any]],
pending: Iterable[tuple[str, int]],
) -> dict[str, dict[str, Any]]:
"""Slice ``decisions`` into ``{tool_call_id: {"decisions": <slice>}}``.
Args:
decisions: Flat list of decisions in the order the SSE stream rendered
them.
pending: Ordered ``(tool_call_id, action_count)`` pairs in the same
order. The slicer consumes ``decisions`` left-to-right.
Returns:
Per-``tool_call_id`` payload dict ready to be written to
``configurable["surfsense_resume_value"]``.
Raises:
ValueError: When the total expected action count differs from the
number of decisions provided. We fail loud rather than silently
dropping or padding so a frontend/backend contract drift surfaces
immediately.
"""
pending_list = list(pending)
expected = sum(count for _, count in pending_list)
if expected != len(decisions):
raise ValueError(
f"Decision count mismatch: pending tool calls expect "
f"{expected} actions but received {len(decisions)} decisions."
)
routed: dict[str, dict[str, Any]] = {}
cursor = 0
for tool_call_id, action_count in pending_list:
routed[tool_call_id] = {"decisions": decisions[cursor : cursor + action_count]}
cursor += action_count
return routed
def collect_pending_tool_calls(state: Any) -> list[tuple[str, int]]:
"""Extract ``[(tool_call_id, action_count), ...]`` from a paused parent state.
Reads ``state.interrupts`` (the bundle langgraph aggregated from each
paused subagent's propagated interrupt). Each interrupt value carries the
``tool_call_id`` that the parent's ``task`` tool was processing — see
``propagation.wrap_with_tool_call_id`` and ``task_tool``'s
``except GraphInterrupt`` chokepoint.
Order is preserved from ``state.interrupts``, which is the order the SSE
stream emitted approval cards. The frontend submits decisions in that
same order, so the slicer can consume them left-to-right.
Interrupts without a ``tool_call_id`` are skipped: they were not
produced by our task-routing layer (e.g. parent-side HITL middleware on
a different tool); ``stream_resume_chat`` is not responsible for routing
those.
Args:
state: A langgraph ``StateSnapshot`` (or any object with an
``interrupts`` attribute).
Returns:
Ordered list of ``(tool_call_id, action_count)``. ``action_count`` is
``len(value["action_requests"])`` for HITL-bundle values, or ``1`` for
scalar-style ``interrupt("...")`` values that were wrapped as
``{"value": ..., "tool_call_id": ...}``.
Raises:
ValueError: When an interrupt value carries a ``tool_call_id`` but
the action count cannot be determined (contract bug: every
propagated value should be either a HITL bundle or a wrapped
scalar).
"""
pending: list[tuple[str, int]] = []
for idx, interrupt_obj in enumerate(getattr(state, "interrupts", ()) or ()):
value = getattr(interrupt_obj, "value", None)
if not isinstance(value, dict):
logger.warning(
"[hitl_route] interrupt[%d] skipped: value not a dict (type=%s)",
idx,
type(value).__name__,
)
continue
tool_call_id = value.get("tool_call_id")
if not isinstance(tool_call_id, str):
# Should not happen post-stamping; flag loudly if a regression
# ever lets an unstamped value reach the parent state.
logger.warning(
"[hitl_route] interrupt[%d] skipped: no tool_call_id stamp (keys=%s)",
idx,
sorted(value.keys()),
)
continue
action_requests = value.get("action_requests")
if isinstance(action_requests, list):
pending.append((tool_call_id, len(action_requests)))
continue
if "value" in value:
pending.append((tool_call_id, 1))
continue
raise ValueError(
f"Interrupt for tool_call_id={tool_call_id!r} has no "
"``action_requests`` list and is not a wrapped scalar value; "
"cannot determine action count for resume routing."
)
return pending
def build_lg_resume_map(
state: Any, by_tool_call_id: dict[str, dict[str, Any]]
) -> dict[str, dict[str, Any]]:
"""Map ``Interrupt.id → resume_payload`` for langgraph's multi-interrupt resume.
``stream_resume_chat`` builds ``by_tool_call_id`` via
:func:`slice_decisions_by_tool_call`. Langgraph's ``Command(resume=...)``
requires ``Interrupt.id`` keys (not our ``tool_call_id`` stamps) when the
parent state has multiple pending interrupts. This pure helper re-keys the
slice without mutating it, and skips entries that can't be paired (no
stamp, no slice) so contract drift surfaces as a count mismatch at the
call site instead of a silent mis-route.
The two key spaces serve two different consumers:
- ``surfsense_resume_value`` (keyed by ``tool_call_id``): read by the
subagent bridge inside ``task_tool``.
- ``Command(resume=...)`` (keyed by ``Interrupt.id``): read by langgraph's
pregel to wake each pending interrupt site.
Args:
state: A langgraph ``StateSnapshot`` (or any object with an
``interrupts`` iterable).
by_tool_call_id: Output of :func:`slice_decisions_by_tool_call`.
Returns:
Dict ready to be passed as ``Command(resume=<this>)``.
"""
out: dict[str, dict[str, Any]] = {}
for interrupt_obj in getattr(state, "interrupts", ()) or ():
value = getattr(interrupt_obj, "value", None)
if not isinstance(value, dict):
continue
tool_call_id = value.get("tool_call_id")
if not isinstance(tool_call_id, str):
continue
interrupt_id = getattr(interrupt_obj, "id", None)
if not isinstance(interrupt_id, str):
continue
payload = by_tool_call_id.get(tool_call_id)
if payload is None:
continue
out[interrupt_id] = payload
return out
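
The three routing steps can be sketched end to end against a fake paused state. This is a condensed restatement under stated assumptions, not the module itself: ``collect_pending`` and ``rekey_by_interrupt_id`` are hypothetical shortened stand-ins for ``collect_pending_tool_calls`` and ``build_lg_resume_map`` (logging and the ``ValueError`` branch for undeterminable counts are omitted), and the fake state uses ``SimpleNamespace`` in place of a langgraph ``StateSnapshot``.

```python
from types import SimpleNamespace
from typing import Any


def collect_pending(state: Any) -> list[tuple[str, int]]:
    # Condensed: one (tool_call_id, action_count) pair per stamped interrupt.
    pending: list[tuple[str, int]] = []
    for it in getattr(state, "interrupts", ()) or ():
        value = getattr(it, "value", None)
        if not isinstance(value, dict):
            continue
        tool_call_id = value.get("tool_call_id")
        if not isinstance(tool_call_id, str):
            continue
        actions = value.get("action_requests")
        pending.append((tool_call_id, len(actions) if isinstance(actions, list) else 1))
    return pending


def slice_decisions_by_tool_call(decisions, pending):
    # Same logic as above: fail loudly on a count mismatch, then slice
    # the flat list left to right.
    pending_list = list(pending)
    expected = sum(count for _, count in pending_list)
    if expected != len(decisions):
        raise ValueError(f"expected {expected} decisions, got {len(decisions)}")
    routed, cursor = {}, 0
    for tool_call_id, count in pending_list:
        routed[tool_call_id] = {"decisions": decisions[cursor : cursor + count]}
        cursor += count
    return routed


def rekey_by_interrupt_id(state: Any, by_tool_call_id: dict) -> dict:
    # Condensed: swap tool_call_id keys for Interrupt.id keys.
    out = {}
    for it in getattr(state, "interrupts", ()) or ():
        value = getattr(it, "value", None)
        if not isinstance(value, dict):
            continue
        payload = by_tool_call_id.get(value.get("tool_call_id"))
        if payload is not None and isinstance(getattr(it, "id", None), str):
            out[it.id] = payload
    return out


# Fake paused parent state: one two-action bundle and one wrapped scalar.
state = SimpleNamespace(
    interrupts=[
        SimpleNamespace(id="int-a", value={"action_requests": [{}, {}], "tool_call_id": "call_1"}),
        SimpleNamespace(id="int-b", value={"value": "approve?", "tool_call_id": "call_2"}),
    ]
)
decisions = [{"type": "approve"}, {"type": "reject"}, {"type": "approve"}]

pending = collect_pending(state)
by_call = slice_decisions_by_tool_call(decisions, pending)
resume_map = rekey_by_interrupt_id(state, by_call)
```

Here ``by_call`` is the ``tool_call_id``-keyed shape and ``resume_map`` is the ``Interrupt.id``-keyed shape; both carry the same decision slices.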

@ -0,0 +1,84 @@
"""Per-search-space spawn-paused kill switch for the ``task`` boundary.
When operators see a runaway loop, a vendor outage, or a billing event
that requires immediate cessation of subagent traffic for a specific
workspace, they flip a Redis flag and the ``task`` tool short-circuits
without touching downstream services. The flag is **per-search-space**
so one tenant's incident never silences the rest of the deployment.
Flag key: ``surfsense:spawn_paused:{search_space_id}``
Flag value: any string-truthy value (we read presence, not contents).
TTL: set by whoever toggles the flag; this module never expires
keys on its own, since "the flag is on" is itself the signal
that a human (or alert) needs to investigate.
The check is best-effort: Redis errors are logged but do not block the
``task`` invocation. Failing closed (block-on-redis-error) would let a
single Redis blip take the whole orchestrator offline; failing open
preserves availability and the alarm bells (rate-limits, cost spikes)
will surface the underlying outage.
"""
from __future__ import annotations
import contextlib
import logging
import os
from app.config import config
logger = logging.getLogger(__name__)
# Operators can disable the check entirely (e.g. local dev without Redis)
# by setting ``SURFSENSE_TASK_SPAWN_PAUSED_DISABLED=1``. Default is
# enabled so production never relies on flipping an opt-out flag.
_DISABLED = os.environ.get(
"SURFSENSE_TASK_SPAWN_PAUSED_DISABLED", ""
).strip().lower() in {
"1",
"true",
"yes",
"on",
}
def _flag_key(search_space_id: int) -> str:
return f"surfsense:spawn_paused:{search_space_id}"
async def is_spawn_paused(search_space_id: int | None) -> bool:
"""Return ``True`` iff the workspace's spawn-paused flag is set in Redis.
A ``None`` search-space (e.g. dev paths that did not plumb the id
through yet) bypasses the check. So does a Redis outage; see the module
docstring for the fail-open rationale.
"""
if _DISABLED or search_space_id is None:
return False
try:
# Local import keeps the cold-path import cheap and lets routes
# that never call ``task`` skip the redis dependency entirely.
import redis.asyncio as aioredis # type: ignore[import-not-found]
client = aioredis.from_url(config.REDIS_APP_URL, decode_responses=True)
try:
raw = await client.get(_flag_key(search_space_id))
finally:
# ``aclose()`` is the async-safe variant on redis-py >=5; fall back
# to ``close()`` for older clients pinned in tests.
close = getattr(client, "aclose", None) or getattr(client, "close", None)
if callable(close):
with contextlib.suppress(Exception):
await close() # type: ignore[misc]
return bool(raw)
except Exception:
logger.warning(
"spawn_paused check failed for search_space_id=%s; failing open.",
search_space_id,
exc_info=True,
)
return False
__all__ = ["is_spawn_paused"]
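
The fail-open behaviour can be sketched without a real Redis. Everything below is an illustration under assumptions: ``FakeRedis`` is a hypothetical in-memory stand-in, and ``is_paused`` is a simplified variant that takes the client as an argument instead of building one from ``config.REDIS_APP_URL``.

```python
import asyncio


class FakeRedis:
    """Hypothetical in-memory stand-in for an async Redis client."""

    def __init__(self, data=None, fail=False):
        self.data = dict(data or {})
        self.fail = fail

    async def get(self, key):
        if self.fail:
            raise ConnectionError("redis down")
        return self.data.get(key)


async def is_paused(client, search_space_id):
    # Simplified variant of the check above: presence of the key means
    # paused; a missing id or any client error fails open.
    if search_space_id is None:
        return False
    try:
        return bool(await client.get(f"surfsense:spawn_paused:{search_space_id}"))
    except Exception:
        return False


def check(client, search_space_id) -> bool:
    # Synchronous convenience wrapper for the example.
    return asyncio.run(is_paused(client, search_space_id))


flagged = FakeRedis({"surfsense:spawn_paused:7": "1"})
```

With this setup, ``check(flagged, 7)`` reports paused, while an unflagged workspace, a ``None`` id, or a failing client all report not paused.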

@ -0,0 +1,15 @@
"""Schema-level description for the ``task`` tool.
Loaded from ``prompts/tools/task/description.md`` so the tool-schema text
and the ``<tools>`` block render from the same source.
"""
from __future__ import annotations
from app.agents.chat.multi_agent_chat.main_agent.system_prompt.builder.load_md import (
read_prompt_md,
)
TASK_TOOL_DESCRIPTION: str = read_prompt_md("tools/task/description.md")
__all__ = ["TASK_TOOL_DESCRIPTION"]

@ -0,0 +1,15 @@
"""Context-editing middleware: spill + clear-tool-uses passes (impl + builder)."""
from .builder import build_context_editing_mw
from .middleware import (
ClearToolUsesEdit,
SpillingContextEditingMiddleware,
SpillToBackendEdit,
)
__all__ = [
"ClearToolUsesEdit",
"SpillToBackendEdit",
"SpillingContextEditingMiddleware",
"build_context_editing_mw",
]

@ -0,0 +1,50 @@
"""Spill + clear-tool-uses passes to keep payloads under budget."""
from __future__ import annotations
from collections.abc import Sequence
from typing import Any
from langchain_core.tools import BaseTool
from app.agents.chat.multi_agent_chat.main_agent.context_prune.prune_tool_names import (
safe_exclude_tools,
)
from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
from app.agents.chat.multi_agent_chat.shared.middleware.flags import enabled
from .middleware import (
ClearToolUsesEdit,
SpillingContextEditingMiddleware,
SpillToBackendEdit,
)
def build_context_editing_mw(
*,
flags: AgentFeatureFlags,
max_input_tokens: int | None,
tools: Sequence[BaseTool],
backend_resolver: Any,
) -> SpillingContextEditingMiddleware | None:
if not enabled(flags, "enable_context_editing") or not max_input_tokens:
return None
spill_edit = SpillToBackendEdit(
trigger=int(max_input_tokens * 0.55),
clear_at_least=int(max_input_tokens * 0.15),
keep=5,
exclude_tools=safe_exclude_tools(tools),
clear_tool_inputs=True,
)
clear_edit = ClearToolUsesEdit(
trigger=int(max_input_tokens * 0.55),
clear_at_least=int(max_input_tokens * 0.15),
keep=5,
exclude_tools=safe_exclude_tools(tools),
clear_tool_inputs=True,
placeholder="[cleared - older tool output trimmed for context]",
)
return SpillingContextEditingMiddleware(
edits=[spill_edit, clear_edit],
backend_resolver=backend_resolver,
)
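
As a worked example of the budget arithmetic above, assuming a hypothetical 128,000-token input window: both edits trigger at roughly 70,400 tokens and aim to reclaim at least roughly 19,200. The ``context_edit_budgets`` helper is illustrative only and not part of the module.

```python
def context_edit_budgets(max_input_tokens: int) -> dict[str, int]:
    # Same arithmetic as build_context_editing_mw: 55% of the window
    # triggers an edit, 15% is the minimum reclaim target, and the five
    # most recent tool messages are kept untouched.
    return {
        "trigger": int(max_input_tokens * 0.55),
        "clear_at_least": int(max_input_tokens * 0.15),
        "keep": 5,
    }


budgets = context_edit_budgets(128_000)  # hypothetical window size
```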

@ -0,0 +1,350 @@
"""
SpillToBackendEdit + SpillingContextEditingMiddleware.
LangChain's :class:`ClearToolUsesEdit` discards old ``ToolMessage.content``
when the context-editing budget triggers, replacing the body with a fixed
placeholder. That's lossy: anything the agent might want to revisit is
gone. The spill-to-disk pattern (originally from OpenCode's
``opencode/packages/opencode/src/tool/truncate.ts``) keeps the prune
behavior but writes the full original payload to the runtime backend
under ``/tool_outputs/{thread_id}/{message_id}.txt`` first. The
placeholder is then upgraded to point at the spill path so the agent
(or a subagent) can read it back on demand.
Why this is a middleware subclass instead of a plain ``ContextEdit``:
``ContextEdit.apply`` is sync, but writing to the backend is async. We
capture the spill payloads inside ``apply`` and flush them via
``await backend.aupload_files(...)`` from ``awrap_model_call`` *before*
delegating to the handler, so the explore subagent can always read what
the placeholder advertises.
"""
from __future__ import annotations
import logging
import threading
from collections.abc import Awaitable, Callable, Sequence
from copy import deepcopy
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from langchain.agents.middleware.context_editing import (
ClearToolUsesEdit,
ContextEdit,
ContextEditingMiddleware,
TokenCounter,
)
from langchain_core.messages import (
AIMessage,
AnyMessage,
BaseMessage,
ToolMessage,
)
from langchain_core.messages.utils import count_tokens_approximately
from langgraph.config import get_config
if TYPE_CHECKING:
from deepagents.backends.protocol import BackendProtocol
from langchain.agents.middleware.types import (
ModelRequest,
ModelResponse,
)
logger = logging.getLogger(__name__)
DEFAULT_SPILL_PREFIX = "/tool_outputs"
def _build_spill_placeholder(spill_path: str) -> str:
"""Build the user-facing placeholder text shown to the model."""
return (
f"[cleared — full output at {spill_path}; ask the explore subagent to read it]"
)
def _get_thread_id_or_session() -> str:
"""Best-effort thread_id discovery for the spill path.
Falls back to a process-stable string if no LangGraph config is
available (e.g. unit tests). The exact value doesn't matter as long
as it's stable within one stream so the placeholder paths line up
with the actual upload path.
"""
try:
config = get_config()
thread_id = config.get("configurable", {}).get("thread_id")
if thread_id is not None:
return str(thread_id)
except RuntimeError:
pass
return "no_thread"
@dataclass(slots=True)
class SpillToBackendEdit(ContextEdit):
"""Capture-and-replace context edit that spills full tool output to the backend.
Behaves like :class:`ClearToolUsesEdit` (same trigger / keep / exclude
semantics) **and** records the original ``ToolMessage.content`` in
:attr:`pending_spills` so the wrapping middleware can flush them
before the model call.
Args:
trigger: Token threshold above which the edit fires.
clear_at_least: Minimum number of tokens to reclaim (best effort).
keep: Number of most-recent ``ToolMessage`` instances to leave
untouched.
exclude_tools: Names of tools whose output is NOT spilled.
clear_tool_inputs: Also clear the originating ``AIMessage.tool_calls``
args when their pair is cleared.
path_prefix: Path under the backend where spills are written.
Default ``"/tool_outputs"``.
"""
trigger: int = 100_000
clear_at_least: int = 0
keep: int = 3
clear_tool_inputs: bool = False
exclude_tools: Sequence[str] = ()
path_prefix: str = DEFAULT_SPILL_PREFIX
pending_spills: list[tuple[str, bytes]] = field(default_factory=list)
_lock: threading.Lock = field(default_factory=threading.Lock)
def drain_pending(self) -> list[tuple[str, bytes]]:
"""Return and clear the pending-spill list atomically."""
with self._lock:
out = list(self.pending_spills)
self.pending_spills.clear()
return out
def apply(
self,
messages: list[AnyMessage],
*,
count_tokens: TokenCounter,
) -> None:
"""Mirror ``ClearToolUsesEdit.apply`` but capture originals first."""
tokens = count_tokens(messages)
if tokens <= self.trigger:
return
candidates = [
(idx, msg)
for idx, msg in enumerate(messages)
if isinstance(msg, ToolMessage)
]
if self.keep >= len(candidates):
return
if self.keep:
candidates = candidates[: -self.keep]
thread_id = _get_thread_id_or_session()
excluded_tools = set(self.exclude_tools)
for idx, tool_message in candidates:
if tool_message.response_metadata.get("context_editing", {}).get("cleared"):
continue
ai_message = next(
(m for m in reversed(messages[:idx]) if isinstance(m, AIMessage)),
None,
)
if ai_message is None:
continue
tool_call = next(
(
call
for call in ai_message.tool_calls
if call.get("id") == tool_message.tool_call_id
),
None,
)
if tool_call is None:
continue
tool_name = tool_message.name or tool_call["name"]
if tool_name in excluded_tools:
continue
message_id = tool_message.id or tool_message.tool_call_id or "unknown"
spill_path = f"{self.path_prefix}/{thread_id}/{message_id}.txt"
original = tool_message.content
payload = self._encode_payload(original)
with self._lock:
self.pending_spills.append((spill_path, payload))
messages[idx] = tool_message.model_copy(
update={
"artifact": None,
"content": _build_spill_placeholder(spill_path),
"response_metadata": {
**tool_message.response_metadata,
"context_editing": {
"cleared": True,
"strategy": "spill_to_backend",
"spill_path": spill_path,
},
},
}
)
if self.clear_tool_inputs:
ai_idx = messages.index(ai_message)
messages[ai_idx] = self._clear_input_args(
ai_message, tool_message.tool_call_id or ""
)
if self.clear_at_least > 0:
new_token_count = count_tokens(messages)
cleared_tokens = max(0, tokens - new_token_count)
if cleared_tokens >= self.clear_at_least:
break
@staticmethod
def _encode_payload(content: Any) -> bytes:
"""Serialize ``ToolMessage.content`` to bytes for upload."""
if isinstance(content, bytes):
return content
if isinstance(content, str):
return content.encode("utf-8")
try:
import json
return json.dumps(content, default=str).encode("utf-8")
except Exception:
return str(content).encode("utf-8")
@staticmethod
def _clear_input_args(message: AIMessage, tool_call_id: str) -> AIMessage:
updated_tool_calls: list[dict[str, Any]] = []
cleared_any = False
for tool_call in message.tool_calls:
updated = dict(tool_call)
if updated.get("id") == tool_call_id:
updated["args"] = {}
cleared_any = True
updated_tool_calls.append(updated)
metadata = dict(getattr(message, "response_metadata", {}))
if cleared_any:
ctx = dict(metadata.get("context_editing", {}))
ids = set(ctx.get("cleared_tool_inputs", []))
ids.add(tool_call_id)
ctx["cleared_tool_inputs"] = sorted(ids)
metadata["context_editing"] = ctx
return message.model_copy(
update={
"tool_calls": updated_tool_calls,
"response_metadata": metadata,
}
)
BackendResolver = "Callable[[Any], BackendProtocol] | BackendProtocol"
class SpillingContextEditingMiddleware(ContextEditingMiddleware):
""":class:`ContextEditingMiddleware` that flushes :class:`SpillToBackendEdit` writes.
Runs the configured edits as the parent does, then flushes any
pending spills via the supplied backend resolver before delegating
to the model handler. Spill failures are logged but never abort the
model call: the placeholder text is already in the message, so the
worst case is the agent gets a placeholder it cannot follow up on.
"""
def __init__(
self,
*,
edits: Sequence[ContextEdit],
backend_resolver: BackendResolver | None = None,
token_count_method: str = "approximate",
) -> None:
super().__init__(edits=list(edits), token_count_method=token_count_method) # type: ignore[arg-type]
self._backend_resolver = backend_resolver
def _resolve_backend(self, request: ModelRequest) -> BackendProtocol | None:
if self._backend_resolver is None:
return None
if callable(self._backend_resolver):
try:
from langchain.tools import ToolRuntime
tool_runtime = ToolRuntime(
state=getattr(request, "state", {}),
context=getattr(request.runtime, "context", None),
stream_writer=getattr(request.runtime, "stream_writer", None),
store=getattr(request.runtime, "store", None),
config=getattr(request.runtime, "config", None) or {},
tool_call_id=None,
)
return self._backend_resolver(tool_runtime)
except Exception:
logger.exception("Failed to resolve spill backend")
return None
return self._backend_resolver # type: ignore[return-value]
def _collect_pending(self) -> list[tuple[str, bytes]]:
out: list[tuple[str, bytes]] = []
for edit in self.edits:
if isinstance(edit, SpillToBackendEdit):
out.extend(edit.drain_pending())
return out
async def awrap_model_call( # type: ignore[override]
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> Any:
if not request.messages:
return await handler(request)
if self.token_count_method == "approximate":
def count_tokens(messages: Sequence[BaseMessage]) -> int:
return count_tokens_approximately(messages)
else:
system_msg = [request.system_message] if request.system_message else []
def count_tokens(messages: Sequence[BaseMessage]) -> int:
return request.model.get_num_tokens_from_messages(
system_msg + list(messages), request.tools
)
edited_messages = deepcopy(list(request.messages))
for edit in self.edits:
edit.apply(edited_messages, count_tokens=count_tokens)
pending = self._collect_pending()
if pending:
backend = self._resolve_backend(request)
if backend is not None:
try:
await backend.aupload_files(pending)
except Exception:
logger.exception(
"Spill-to-backend upload failed (%d files); placeholders "
"remain in messages but content is unrecoverable",
len(pending),
)
else:
logger.warning(
"SpillToBackendEdit produced %d pending spills but no backend "
"resolver was configured; content is unrecoverable",
len(pending),
)
return await handler(request.override(messages=edited_messages))
__all__ = [
"DEFAULT_SPILL_PREFIX",
"ClearToolUsesEdit",
"SpillToBackendEdit",
"SpillingContextEditingMiddleware",
"_build_spill_placeholder",
]

@ -0,0 +1,128 @@
"""Drop duplicate HITL tool calls before execution.
When the LLM emits multiple calls to the same HITL tool with the same
primary argument (e.g. two ``delete_calendar_event("Doctor Appointment")``),
only the first call is kept. Non-HITL tools are never touched.
This runs in the ``after_model`` hook **before** any tool executes so
the duplicate call is stripped from the AIMessage that gets checkpointed.
That means it is also safe across LangGraph ``interrupt()`` boundaries:
the removed call will never appear on graph resume.
Dedup-key resolution order (read from each tool's own ``metadata``):
1. ``tool.metadata["dedup_key"]``: a callable mapping the args dict to a
stable signature string. This is the canonical mechanism.
2. ``tool.metadata["hitl_dedup_key"]``: a string naming a primary arg;
used by MCP / Composio tools that only expose a single key field, and
honored only when ``tool.metadata["hitl"]`` is truthy.
A tool with no resolver from either path simply opts out of dedup.
"""
from __future__ import annotations
import logging
from collections.abc import Sequence
from typing import Any
from langchain.agents.middleware import AgentMiddleware, AgentState
from langchain_core.tools import BaseTool
from langgraph.runtime import Runtime
from app.agents.chat.multi_agent_chat.shared.middleware.dedup_tool_calls import (
DedupResolver,
wrap_dedup_key_by_arg_name,
)
logger = logging.getLogger(__name__)
class DedupHITLToolCallsMiddleware(AgentMiddleware): # type: ignore[type-arg]
"""Remove duplicate HITL tool calls from a single LLM response.
Only the **first** occurrence of each ``(tool-name, dedup_key)``
pair is kept; subsequent duplicates are silently dropped.
The dedup-resolver map is built from two sources, in priority order:
1. ``tool.metadata["dedup_key"]``: a callable that receives the args dict
and returns a string signature. This is the canonical mechanism.
2. ``tool.metadata["hitl_dedup_key"]``: a string with a primary arg
name; primarily used by MCP / Composio tools.
"""
tools = ()
def __init__(self, *, agent_tools: list[Any] | None = None) -> None:
self._resolvers: dict[str, DedupResolver] = {}
for t in agent_tools or []:
meta = getattr(t, "metadata", None) or {}
callable_key = meta.get("dedup_key")
if callable(callable_key):
self._resolvers[t.name] = callable_key
continue
if meta.get("hitl") and meta.get("hitl_dedup_key"):
self._resolvers[t.name] = wrap_dedup_key_by_arg_name(
meta["hitl_dedup_key"]
)
def after_model(
self, state: AgentState, runtime: Runtime[Any]
) -> dict[str, Any] | None:
return self._dedup(state, self._resolvers)
async def aafter_model(
self, state: AgentState, runtime: Runtime[Any]
) -> dict[str, Any] | None:
return self._dedup(state, self._resolvers)
@staticmethod
def _dedup(
state: AgentState,
resolvers: dict[str, DedupResolver],
) -> dict[str, Any] | None:
messages = state.get("messages")
if not messages:
return None
last_msg = messages[-1]
if last_msg.type != "ai" or not getattr(last_msg, "tool_calls", None):
return None
tool_calls: list[dict[str, Any]] = last_msg.tool_calls
seen: set[tuple[str, str]] = set()
deduped: list[dict[str, Any]] = []
for tc in tool_calls:
name = tc.get("name", "")
resolver = resolvers.get(name)
if resolver is not None:
try:
arg_val = resolver(tc.get("args", {}) or {})
except Exception:
logger.exception(
"Dedup resolver for tool %s raised; keeping call", name
)
deduped.append(tc)
continue
key = (name, arg_val)
if key in seen:
logger.info(
"Dedup: dropped duplicate HITL tool call %s(%s)",
name,
arg_val,
)
continue
seen.add(key)
deduped.append(tc)
if len(deduped) == len(tool_calls):
return None
updated_msg = last_msg.model_copy(update={"tool_calls": deduped})
return {"messages": [updated_msg]}
def build_dedup_hitl_mw(tools: Sequence[BaseTool]) -> DedupHITLToolCallsMiddleware:
return DedupHITLToolCallsMiddleware(agent_tools=list(tools))
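The first-occurrence-wins rule implemented by ``_dedup`` can be exercised without LangChain. The following is a minimal sketch, assuming tool calls are plain dicts with ``name`` and ``args`` keys; ``by_arg_name`` is a hypothetical stand-in for ``wrap_dedup_key_by_arg_name``, whose real implementation is not shown in this diff.

```python
from collections.abc import Callable
from typing import Any

Resolver = Callable[[dict[str, Any]], str]


def by_arg_name(arg_name: str) -> Resolver:
    """Hypothetical stand-in for wrap_dedup_key_by_arg_name (assumed behavior)."""

    def resolve(args: dict[str, Any]) -> str:
        return str(args.get(arg_name, ""))

    return resolve


def dedup_tool_calls(
    tool_calls: list[dict[str, Any]],
    resolvers: dict[str, Resolver],
) -> list[dict[str, Any]]:
    """Keep the first call per (tool name, dedup key); pass everything else through."""
    seen: set[tuple[str, str]] = set()
    kept: list[dict[str, Any]] = []
    for call in tool_calls:
        name = call.get("name", "")
        resolver = resolvers.get(name)
        if resolver is not None:
            try:
                key = (name, resolver(call.get("args") or {}))
            except Exception:
                # A failing resolver must never drop a call.
                kept.append(call)
                continue
            if key in seen:
                continue  # duplicate of an earlier call: drop it
            seen.add(key)
        kept.append(call)
    return kept


calls = [
    {"name": "delete_calendar_event", "args": {"title": "Doctor Appointment"}},
    {"name": "delete_calendar_event", "args": {"title": "Doctor Appointment"}},
    {"name": "search", "args": {"q": "x"}},
    {"name": "search", "args": {"q": "x"}},
]
result = dedup_tool_calls(calls, {"delete_calendar_event": by_arg_name("title")})
```

Only the tool with a registered resolver is deduplicated; the two identical ``search`` calls both survive because ``search`` has no resolver and therefore opts out.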


@ -0,0 +1,9 @@
"""Doom-loop middleware: detect repeated identical tool calls (impl + builder)."""
from .builder import build_doom_loop_mw
from .middleware import DoomLoopMiddleware
__all__ = [
"DoomLoopMiddleware",
"build_doom_loop_mw",
]


@ -0,0 +1,14 @@
"""Stop N identical tool calls in a row via interrupt."""
from __future__ import annotations
from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
from app.agents.chat.multi_agent_chat.shared.middleware.flags import enabled
from .middleware import DoomLoopMiddleware
def build_doom_loop_mw(flags: AgentFeatureFlags) -> DoomLoopMiddleware | None:
return (
DoomLoopMiddleware(threshold=3) if enabled(flags, "enable_doom_loop") else None
)
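The builders in this package share one convention: return a middleware when its feature flag is on, otherwise ``None``. A rough sketch of how a caller might assemble a stack from such builders follows. ``Flags``, ``enabled`` and ``build_stack`` are hypothetical stand-ins (the real ``AgentFeatureFlags`` and ``enabled`` helper are not shown in this diff), and strings stand in for middleware instances.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Flags:
    """Hypothetical stand-in for AgentFeatureFlags."""

    enable_doom_loop: bool = False
    enable_otel: bool = False


def enabled(flags: Flags, name: str) -> bool:
    """Assumed behavior of the shared helper: a truthy attribute lookup."""
    return bool(getattr(flags, name, False))


def build_stack(flags: Flags) -> list[str]:
    # Each entry mimics a builder result: a middleware or None.
    candidates = [
        "doom_loop(threshold=3)" if enabled(flags, "enable_doom_loop") else None,
        "otel" if enabled(flags, "enable_otel") else None,
    ]
    # Drop disabled entries so only real middleware reaches the agent.
    return [mw for mw in candidates if mw is not None]
```

Returning ``None`` rather than raising keeps every builder call site uniform; the filtering step is the only place that needs to know a middleware can be switched off.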


@ -0,0 +1,238 @@
"""
DoomLoopMiddleware: pattern-based detector for repeated identical tool calls.
LangChain has :class:`ToolCallLimitMiddleware`, which caps the *total* number
of tool calls per turn, but it can't tell apart "10 distinct, useful
calls" from "the same call 10 times in a row". This middleware fills that
gap with a sliding-window check on tool-call signatures, ported from
OpenCode's ``packages/opencode/src/session/processor.ts``.
When the same tool with the same arguments is called N times in a row,
the agent has likely entered an infinite loop. We surface this to the
user as an interrupt with ``permission="doom_loop"`` so the UI can
render an "Are you stuck? Continue / cancel?" affordance.
This ships **OFF by default** until the frontend explicitly handles
``context.permission == "doom_loop"`` interrupts.
Wire format: uses SurfSense's existing ``interrupt()`` payload shape
(see ``app/agents/shared/tools/hitl.py``):
{
"type": "permission_ask",
"action": {"tool": <name>, "params": <args>},
"context": {"permission": "doom_loop", "recent_signatures": [...]},
}
so the frontend that already handles HITL prompts can render this with
no changes beyond a string check.
"""
from __future__ import annotations
import hashlib
import json
import logging
from collections import deque
from typing import Any
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ContextT,
ResponseT,
)
from langchain_core.messages import AIMessage
from langgraph.config import get_config
from langgraph.runtime import Runtime
from langgraph.types import interrupt
from app.observability import metrics as ot_metrics, otel as ot
logger = logging.getLogger(__name__)
def _signature(name: str, args: Any) -> str:
"""Hash a tool call ``(name, args)`` to a short signature."""
try:
canonical = json.dumps(args, sort_keys=True, default=str)
except (TypeError, ValueError):
canonical = repr(args)
digest = hashlib.sha1(f"{name}::{canonical}".encode()).hexdigest()
return digest[:16]
class DoomLoopMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
"""Detect repeated identical tool calls and prompt the user.
Tracks a sliding window of the most-recent ``threshold`` tool-call
signatures across the live request. When all entries match, raise
a SurfSense-style HITL interrupt with ``permission="doom_loop"``.
Args:
threshold: How many consecutive identical signatures count as a
doom loop. Default 3 (matches OpenCode's processor.ts).
"""
def __init__(self, *, threshold: int = 3) -> None:
super().__init__()
if threshold < 2:
raise ValueError("DoomLoopMiddleware threshold must be >= 2")
self._threshold = threshold
self.tools = []
# Per-thread sliding windows. We can't put this in graph state
# without state-schema gymnastics; for one process-lifetime it's
# fine to keep an in-memory map keyed by thread_id.
self._windows: dict[str, deque[str]] = {}
@staticmethod
def _thread_id_from_runtime(runtime: Runtime[ContextT]) -> str:
"""Resolve the thread id for sliding-window keying.
Prefer LangGraph's ``get_config()`` (the only way to read
``RunnableConfig`` inside a node; :class:`Runtime` does NOT carry
a ``config`` attribute). Fall back to ``runtime.config`` for unit
tests that synthesize a config-bearing stub. The default
``"no_thread"`` is intentionally used only when both lookups fail;
it would collapse all threads into one window, so the fallback is
always logged at debug level.
"""
def _from_dict(cfg: Any) -> str | None:
if not isinstance(cfg, dict):
return None
tid = (cfg.get("configurable") or {}).get("thread_id")
return str(tid) if tid is not None else None
try:
tid = _from_dict(get_config())
except Exception:
tid = None
if tid is not None:
return tid
tid = _from_dict(getattr(runtime, "config", None))
if tid is not None:
return tid
logger.debug(
"DoomLoopMiddleware: no thread_id resolved from RunnableConfig; "
"falling back to shared 'no_thread' window."
)
return "no_thread"
def _window(self, thread_id: str) -> deque[str]:
win = self._windows.get(thread_id)
if win is None:
win = deque(maxlen=self._threshold)
self._windows[thread_id] = win
return win
def _detect(
self, message: AIMessage, runtime: Runtime[ContextT]
) -> tuple[bool, list[str], dict[str, Any] | None]:
if not message.tool_calls:
return False, [], None
thread_id = self._thread_id_from_runtime(runtime)
window = self._window(thread_id)
triggered_call: dict[str, Any] | None = None
for call in message.tool_calls:
name = (
call.get("name")
if isinstance(call, dict)
else getattr(call, "name", None)
)
args = (
call.get("args")
if isinstance(call, dict)
else getattr(call, "args", {})
)
if not isinstance(name, str):
continue
sig = _signature(name, args)
window.append(sig)
if len(window) >= self._threshold and len(set(window)) == 1:
triggered_call = {"name": name, "params": args or {}}
break
if triggered_call is None:
return False, list(window), None
return True, list(window), triggered_call
def after_model( # type: ignore[override]
self,
state: AgentState[ResponseT],
runtime: Runtime[ContextT],
) -> dict[str, Any] | None:
messages = state.get("messages") or []
if not messages:
return None
last = messages[-1]
if not isinstance(last, AIMessage):
return None
triggered, signatures, action = self._detect(last, runtime)
if not triggered:
return None
logger.warning(
"Doom loop detected: tool %s called %d times in a row (sig=%s)",
action["name"] if action else "<unknown>",
self._threshold,
signatures[-1] if signatures else "<empty>",
)
# Open an interrupt.raised span with permission=doom_loop attribute
# so dashboards can break out doom-loop interrupts from regular
# permission asks via the ``interrupt.permission`` attribute.
with ot.interrupt_span(
interrupt_type="permission_ask",
extra={
"interrupt.permission": "doom_loop",
"interrupt.threshold": self._threshold,
# ``action`` is built in ``_detect`` with a "name" key, not "tool".
"interrupt.tool": (action or {}).get("name", "<unknown>"),
},
):
ot_metrics.record_interrupt(interrupt_type="permission_ask")
decision = interrupt(
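{
"type": "permission_ask",
# Emit the documented wire shape: {"tool": <name>, "params": <args>}.
"action": {
"tool": (action or {}).get("name", "<unknown>"),
"params": (action or {}).get("params", {}),
},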
"context": {
"permission": "doom_loop",
"recent_signatures": signatures,
"threshold": self._threshold,
},
}
)
# Reset window so the next decision (continue/cancel) starts fresh.
thread_id = self._thread_id_from_runtime(runtime)
self._windows.pop(thread_id, None)
# Decision shape mirrors ``tools/hitl.py``: {"decision_type": "..."}
# If the user cancelled, jump to end. Otherwise return ``None`` so the
# tool call proceeds. The frontend's exact reply names may differ —
# we tolerate any shape that contains a string with "reject"/"cancel".
if isinstance(decision, dict):
kind = str(
decision.get("decision_type") or decision.get("type") or ""
).lower()
if "reject" in kind or "cancel" in kind:
return {"jump_to": "end"}
return None
async def aafter_model( # type: ignore[override]
self,
state: AgentState[ResponseT],
runtime: Runtime[ContextT],
) -> dict[str, Any] | None:
return self.after_model(state, runtime)
__all__ = [
"DoomLoopMiddleware",
"_signature",
]
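The detection logic is independent of LangGraph and can be tried in isolation. Below is a minimal sketch of the signature-plus-sliding-window check, assuming a single thread (the middleware keeps one window per ``thread_id``). ``LoopDetector`` is a hypothetical name used only for illustration; ``signature`` mirrors ``_signature`` above.

```python
import hashlib
import json
from collections import deque
from typing import Any


def signature(name: str, args: Any) -> str:
    """Hash (name, args) to a short, order-insensitive signature."""
    try:
        canonical = json.dumps(args, sort_keys=True, default=str)
    except (TypeError, ValueError):
        canonical = repr(args)
    return hashlib.sha1(f"{name}::{canonical}".encode()).hexdigest()[:16]


class LoopDetector:
    """Hypothetical single-thread version of the sliding-window check."""

    def __init__(self, threshold: int = 3) -> None:
        if threshold < 2:
            raise ValueError("threshold must be >= 2")
        self._threshold = threshold
        # maxlen makes the deque discard the oldest signature automatically.
        self._window: deque[str] = deque(maxlen=threshold)

    def observe(self, name: str, args: Any) -> bool:
        """Record one tool call; return True once the window is full of one signature."""
        self._window.append(signature(name, args))
        return len(self._window) == self._threshold and len(set(self._window)) == 1


detector = LoopDetector(threshold=3)
first = detector.observe("ls", {"path": "/documents"})
second = detector.observe("ls", {"path": "/documents"})
third = detector.observe("ls", {"path": "/documents"})
```

Because ``json.dumps`` is called with ``sort_keys=True``, two calls whose arguments differ only in key order hash to the same signature, so reordering arguments does not hide a loop.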


@ -0,0 +1,25 @@
"""Commit staged cloud filesystem mutations to Postgres at end of turn."""
from __future__ import annotations
from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode
from app.agents.chat.multi_agent_chat.shared.middleware.kb_persistence import (
KnowledgeBasePersistenceMiddleware,
)
def build_kb_persistence_mw(
*,
filesystem_mode: FilesystemMode,
search_space_id: int,
user_id: str | None,
thread_id: int | None,
) -> KnowledgeBasePersistenceMiddleware | None:
if filesystem_mode != FilesystemMode.CLOUD:
return None
return KnowledgeBasePersistenceMiddleware(
search_space_id=search_space_id,
created_by_id=user_id,
filesystem_mode=filesystem_mode,
thread_id=thread_id,
)


@ -0,0 +1,32 @@
"""KB priority planner: <priority_documents> injection."""
from __future__ import annotations
from langchain_core.language_models import BaseChatModel
from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode
from app.agents.chat.multi_agent_chat.shared.middleware.knowledge_search import (
KnowledgePriorityMiddleware,
)
from app.services.llm_service import get_planner_llm
def build_knowledge_priority_mw(
*,
llm: BaseChatModel,
search_space_id: int,
filesystem_mode: FilesystemMode,
available_connectors: list[str] | None,
available_document_types: list[str] | None,
mentioned_document_ids: list[int] | None,
) -> KnowledgePriorityMiddleware:
return KnowledgePriorityMiddleware(
llm=llm,
planner_llm=get_planner_llm(),
search_space_id=search_space_id,
filesystem_mode=filesystem_mode,
available_connectors=available_connectors,
available_document_types=available_document_types,
mentioned_document_ids=mentioned_document_ids,
inject_system_message=False,
)


@ -0,0 +1,9 @@
"""Knowledge-tree middleware: <workspace_tree> injection, cloud only (impl + builder)."""
from .builder import build_knowledge_tree_mw
from .middleware import KnowledgeTreeMiddleware
__all__ = [
"KnowledgeTreeMiddleware",
"build_knowledge_tree_mw",
]


@ -0,0 +1,25 @@
"""<workspace_tree> injection (cloud only)."""
from __future__ import annotations
from langchain_core.language_models import BaseChatModel
from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode
from .middleware import KnowledgeTreeMiddleware
def build_knowledge_tree_mw(
*,
filesystem_mode: FilesystemMode,
search_space_id: int,
llm: BaseChatModel,
) -> KnowledgeTreeMiddleware | None:
if filesystem_mode != FilesystemMode.CLOUD:
return None
return KnowledgeTreeMiddleware(
search_space_id=search_space_id,
filesystem_mode=filesystem_mode,
llm=llm,
inject_system_message=False,
)


@ -0,0 +1,334 @@
"""Workspace-tree middleware for the SurfSense agent.
Renders the full ``Folder``+``Document`` tree under ``/documents/`` once per
turn (cloud only), caches it by ``(search_space_id, tree_version)``, and
injects the result as a ``<workspace_tree>`` system message immediately
before the latest human turn.
The render is bounded by two truncation layers:
1. **Entry cap**: at most ``MAX_TREE_ENTRIES`` lines. The remainder is
replaced with a "use ls" hint.
2. **Token cap**: at most ``MAX_TREE_TOKENS`` tokens (using the LLM's
token-count profile when available). If the entry-truncated tree still
exceeds the token cap we fall back to a root-only summary.
Anonymous mode renders only ``state['kb_anon_doc']`` (no DB calls).
This middleware also performs a one-time initialization of ``state['cwd']``
to ``"/documents"`` so subsequent middlewares and tools always see a valid
cwd in cloud mode.
"""
from __future__ import annotations
import asyncio
import logging
import time
from typing import Any
from langchain.agents.middleware import AgentMiddleware, AgentState
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import SystemMessage
from langgraph.runtime import Runtime
from sqlalchemy import select
from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode
from app.agents.chat.multi_agent_chat.shared.path_resolver import (
DOCUMENTS_ROOT,
PathIndex,
build_path_index,
doc_to_virtual_path,
)
from app.agents.chat.multi_agent_chat.shared.state.filesystem_state import (
SurfSenseFilesystemState,
)
from app.db import Document, shielded_async_session
from app.utils.perf import get_perf_logger
_perf_log = get_perf_logger()
try:
from litellm import token_counter
except Exception: # pragma: no cover - optional dep
token_counter = None # type: ignore[assignment]
logger = logging.getLogger(__name__)
MAX_TREE_ENTRIES = 500
MAX_TREE_TOKENS = 4000
def _approx_tokens(text: str) -> int:
"""Cheap fallback token estimate (1 token ~= 4 chars)."""
return max(1, (len(text) + 3) // 4)
def _count_tokens(text: str, *, llm: BaseChatModel | None) -> int:
if llm is None:
return _approx_tokens(text)
count_fn = getattr(llm, "_count_tokens", None)
if callable(count_fn):
try:
return int(count_fn([{"role": "user", "content": text}]))
except Exception:
pass
profile = getattr(llm, "profile", None)
model_names: list[str] = []
if isinstance(profile, dict):
tcms = profile.get("token_count_models")
if isinstance(tcms, list):
model_names.extend(name for name in tcms if isinstance(name, str) and name)
tcm = profile.get("token_count_model")
if isinstance(tcm, str) and tcm and tcm not in model_names:
model_names.append(tcm)
model_name = model_names[0] if model_names else getattr(llm, "model", None)
if not isinstance(model_name, str) or not model_name or token_counter is None:
return _approx_tokens(text)
try:
return int(
token_counter(
messages=[{"role": "user", "content": text}],
model=model_name,
)
)
except Exception:
return _approx_tokens(text)
class KnowledgeTreeMiddleware(AgentMiddleware): # type: ignore[type-arg]
"""Inject the workspace folder/document tree into the agent's context."""
tools = ()
state_schema = SurfSenseFilesystemState
def __init__(
self,
*,
search_space_id: int,
filesystem_mode: FilesystemMode,
llm: BaseChatModel | None = None,
max_entries: int = MAX_TREE_ENTRIES,
max_tokens: int = MAX_TREE_TOKENS,
inject_system_message: bool = True, # For backwards compatibility
) -> None:
self.search_space_id = search_space_id
self.filesystem_mode = filesystem_mode
self.llm = llm
self.max_entries = max_entries
self.max_tokens = max_tokens
self.inject_system_message = inject_system_message
self._cache: dict[tuple[int, int, bool], str] = {}
async def abefore_agent( # type: ignore[override]
self,
state: AgentState,
runtime: Runtime[Any],
) -> dict[str, Any] | None:
del runtime
if self.filesystem_mode != FilesystemMode.CLOUD:
return None
start = time.perf_counter()
update: dict[str, Any] = {}
if not state.get("cwd"):
update["cwd"] = DOCUMENTS_ROOT
anon_doc = state.get("kb_anon_doc")
if anon_doc:
tree_msg = self._render_anon_tree(anon_doc)
cache_outcome = "anon"
else:
version = int(state.get("tree_version") or 0)
cache_key = (self.search_space_id, version, False)
cache_outcome = "hit" if cache_key in self._cache else "miss"
tree_msg = await self._render_kb_tree(state)
update["workspace_tree_text"] = tree_msg
if self.inject_system_message:
messages = list(state.get("messages") or [])
insert_at = max(len(messages) - 1, 0)
messages.insert(insert_at, SystemMessage(content=tree_msg))
update["messages"] = messages
_perf_log.info(
"[knowledge_tree] cache=%s chars=%d elapsed=%.3fs space=%d",
cache_outcome,
len(tree_msg),
time.perf_counter() - start,
self.search_space_id,
)
return update
def before_agent( # type: ignore[override]
self,
state: AgentState,
runtime: Runtime[Any],
) -> dict[str, Any] | None:
try:
loop = asyncio.get_running_loop()
if loop.is_running():
return None
except RuntimeError:
pass
return asyncio.run(self.abefore_agent(state, runtime))
# ------------------------------------------------------------------ render
def _render_anon_tree(self, anon_doc: dict[str, Any]) -> str:
path = str(anon_doc.get("path") or "")
title = str(anon_doc.get("title") or "uploaded_document")
return (
"<workspace_tree>\n"
"Anonymous session — only one read-only document is available.\n"
f"{DOCUMENTS_ROOT}/\n"
f" {path}{title}\n"
"</workspace_tree>"
)
async def _render_kb_tree(self, state: AgentState) -> str:
version = int(state.get("tree_version") or 0)
cache_key = (self.search_space_id, version, False)
cached = self._cache.get(cache_key)
if cached is not None:
return cached
try:
async with shielded_async_session() as session:
index = await build_path_index(session, self.search_space_id)
doc_rows = await session.execute(
select(Document.id, Document.title, Document.folder_id).where(
Document.search_space_id == self.search_space_id
)
)
docs = list(doc_rows.all())
except Exception as exc: # pragma: no cover - defensive
logger.warning("knowledge_tree: DB error %s", exc)
return "<workspace_tree>\n(unavailable)\n</workspace_tree>"
rendered = self._format_tree(index, docs)
self._cache[cache_key] = rendered
return rendered
def _format_tree(self, index: PathIndex, docs: list[Any]) -> str:
folder_paths = sorted(set(index.folder_paths.values()))
doc_paths = sorted(
doc_to_virtual_path(
doc_id=row.id,
title=str(row.title or "untitled"),
folder_id=row.folder_id,
index=index,
)
for row in docs
)
all_paths = sorted(set(folder_paths + doc_paths + [DOCUMENTS_ROOT]))
# Pre-compute which folders have at least one descendant (folder or doc).
# A folder is "empty" iff no path in `all_paths` is strictly under it.
# Used to emit an explicit "(empty)" marker so the LLM doesn't have to
# infer emptiness from indentation alone.
non_empty_folders = self._compute_non_empty_folders(folder_paths, doc_paths)
lines: list[str] = []
for path in all_paths:
depth = (
0
if path == DOCUMENTS_ROOT
else len([p for p in path[len(DOCUMENTS_ROOT) :].split("/") if p])
)
indent = " " * depth
is_dir = path == DOCUMENTS_ROOT or path in folder_paths
display = (
path.rsplit("/", 1)[-1] if path != DOCUMENTS_ROOT else "/documents"
)
if is_dir:
if path != DOCUMENTS_ROOT and path not in non_empty_folders:
lines.append(f"{indent}{display}/ (empty)")
else:
lines.append(f"{indent}{display}/")
else:
lines.append(f"{indent}{display}")
if len(lines) >= self.max_entries:
remaining = len(all_paths) - len(lines)
if remaining > 0:
lines.append(
f"... {remaining} more entries — use "
"ls('/documents/<folder>', offset, limit) to expand"
)
break
body = "\n".join(lines)
rendered = f"<workspace_tree>\n{body}\n</workspace_tree>"
token_count = _count_tokens(rendered, llm=self.llm)
if token_count <= self.max_tokens:
return rendered
return self._format_root_summary(folder_paths, doc_paths)
@staticmethod
def _compute_non_empty_folders(
folder_paths: list[str], doc_paths: list[str]
) -> set[str]:
"""Return the set of folder paths that contain at least one descendant.
A folder is "non-empty" if any document path or any other folder path
is strictly under it. Both documents and sub-folders mark their ancestor
folders as non-empty, so in a chain of nested folders that hold no
documents only the innermost one reads ``(empty)``.
"""
non_empty: set[str] = set()
folder_set = set(folder_paths)
for doc_path in doc_paths:
parent = doc_path.rsplit("/", 1)[0]
while parent and parent != DOCUMENTS_ROOT:
if parent in folder_set:
non_empty.add(parent)
parent = parent.rsplit("/", 1)[0]
for child in folder_paths:
parent = child.rsplit("/", 1)[0]
while parent and parent != DOCUMENTS_ROOT and parent in folder_set:
non_empty.add(parent)
parent = parent.rsplit("/", 1)[0]
return non_empty
def _format_root_summary(
self, folder_paths: list[str], doc_paths: list[str]
) -> str:
top_level: dict[str, int] = {}
loose_docs = 0
for path in doc_paths:
rel = path[len(DOCUMENTS_ROOT) :].lstrip("/")
if "/" in rel:
top = rel.split("/", 1)[0]
top_level[top] = top_level.get(top, 0) + 1
else:
loose_docs += 1
for path in folder_paths:
rel = path[len(DOCUMENTS_ROOT) :].lstrip("/")
if not rel:
continue
top = rel.split("/", 1)[0]
top_level.setdefault(top, 0)
lines = [DOCUMENTS_ROOT + "/"]
for name in sorted(top_level):
count = top_level[name]
lines.append(f" {name}/ ({count} document{'s' if count != 1 else ''})")
if loose_docs:
lines.append(
f" ({loose_docs} loose document{'s' if loose_docs != 1 else ''})"
)
lines.append(
"Tree is large; use list_tree('/documents/<folder>') to drill in "
"or ls('/documents/<folder>', offset, limit) for paginated listings."
)
return "<workspace_tree>\n" + "\n".join(lines) + "\n</workspace_tree>"
__all__ = ["KnowledgeTreeMiddleware"]
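The ``(empty)`` marker depends on which folders have descendants. The following is a simplified sketch of that computation, assuming all paths are absolute and rooted at ``/documents`` (taken here as the value of ``DOCUMENTS_ROOT``). It merges the two loops of ``_compute_non_empty_folders`` into one and adds a guard for slash-less paths, so it is an approximation rather than a copy.

```python
DOCUMENTS_ROOT = "/documents"  # assumed value of the real constant


def non_empty_folders(folder_paths: list[str], doc_paths: list[str]) -> set[str]:
    """Return folders that have at least one document or sub-folder beneath them."""
    folder_set = set(folder_paths)
    non_empty: set[str] = set()
    for path in [*doc_paths, *folder_paths]:
        parent = path.rsplit("/", 1)[0]
        # Walk up towards the root, marking every known ancestor folder.
        while parent and parent != DOCUMENTS_ROOT and "/" in parent:
            if parent in folder_set:
                non_empty.add(parent)
            parent = parent.rsplit("/", 1)[0]
    return non_empty


folders = [
    "/documents/a",
    "/documents/a/b",
    "/documents/c",
    "/documents/d",
    "/documents/d/e",
]
docs = ["/documents/a/b/notes.md"]
marked = non_empty_folders(folders, docs)
empty = sorted(set(folders) - marked)
```

In this example ``/documents/c`` and ``/documents/d/e`` end up empty, while ``/documents/d`` counts as non-empty because it contains a sub-folder even though it holds no documents.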


@ -0,0 +1,9 @@
"""Noop-injection middleware: provider-compat _noop tool (impl + builder)."""
from .builder import build_noop_injection_mw
from .middleware import NoopInjectionMiddleware
__all__ = [
"NoopInjectionMiddleware",
"build_noop_injection_mw",
]


@ -0,0 +1,12 @@
"""Provider-compat: append a `_noop` tool when tools=[] but history has tool calls."""
from __future__ import annotations
from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
from app.agents.chat.multi_agent_chat.shared.middleware.flags import enabled
from .middleware import NoopInjectionMiddleware
def build_noop_injection_mw(flags: AgentFeatureFlags) -> NoopInjectionMiddleware | None:
return NoopInjectionMiddleware() if enabled(flags, "enable_compaction_v2") else None


@ -0,0 +1,141 @@
"""
``_noop`` provider-compatibility tool + injection middleware.
Some providers (LiteLLM, Bedrock, Copilot) respond with a 400 error when a
model call has empty ``tools`` but the message history includes prior
``tool_calls``; they treat that shape as malformed even though it's
perfectly valid LangChain. SurfSense hits this on the compaction summarize
call (no tools, history full of tool calls).
Ported from OpenCode's ``packages/opencode/src/session/llm.ts:209-228``,
which discovered and codified the workaround: inject a no-op tool *only*
on those provider shapes so the request validates without the tool ever
being called.
Operation: a :class:`NoopInjectionMiddleware` ``wrap_model_call`` checks
whether the request has zero tools, the last AI message in history includes
``tool_calls``, and the model's provider is one of the affected ones. Only
then does it inject the ``_noop`` tool (never globally), mirroring OpenCode's
gating exactly. The :func:`noop_tool` returns empty content when called
(which it should never be in practice).
"""
from __future__ import annotations
import logging
from collections.abc import Awaitable, Callable
from typing import Any
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ContextT,
ModelRequest,
ModelResponse,
ResponseT,
)
from langchain_core.messages import AIMessage
from langchain_core.tools import tool
logger = logging.getLogger(__name__)
NOOP_TOOL_NAME = "_noop"
NOOP_TOOL_DESCRIPTION = "Do not call this tool. It exists only for API compatibility."
@tool(name_or_callable=NOOP_TOOL_NAME, description=NOOP_TOOL_DESCRIPTION)
def noop_tool() -> str:
"""Return empty content. Never expected to be called."""
return ""
# Provider markers that benefit from ``_noop`` injection. These match
# OpenCode's gating list (``llm.ts:209-228``). We also accept any string
# containing one of these substrings so e.g. ``litellm`` matches
# ``ChatLiteLLM``.
_NOOP_NEEDED_PROVIDERS: tuple[str, ...] = (
"litellm",
"bedrock",
"copilot",
)
def _provider_needs_noop(model: Any) -> bool:
"""Heuristic: does this model's provider need the _noop injection?"""
try:
ls_params = model._get_ls_params()
provider = str(ls_params.get("ls_provider", "")).lower()
except Exception:
provider = ""
if not provider:
cls_name = type(model).__name__.lower()
provider = cls_name
return any(needle in provider for needle in _NOOP_NEEDED_PROVIDERS)
def _last_ai_has_tool_calls(messages: list[Any]) -> bool:
for msg in reversed(messages):
if isinstance(msg, AIMessage):
return bool(msg.tool_calls)
return False
class NoopInjectionMiddleware(
AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]
):
"""Inject the ``_noop`` tool only when the provider would otherwise 400.
The check fires per model call, not at agent build time, because the
summarization path generates a no-tool subcall at runtime. The
extra tool is supplied as the request's only tool, as an instance; the
actual ``langchain_core.tools.BaseTool`` is bound on every call site
that creates the agent.
"""
def __init__(self, *, noop_tool_instance: Any | None = None) -> None:
super().__init__()
self._noop_tool = noop_tool_instance or noop_tool
self.tools = []
def _should_inject(self, request: ModelRequest[ContextT]) -> bool:
if request.tools:
return False
if not _last_ai_has_tool_calls(request.messages):
return False
return _provider_needs_noop(request.model)
def _augmented(self, request: ModelRequest[ContextT]) -> ModelRequest[ContextT]:
return request.override(tools=[self._noop_tool])
def wrap_model_call( # type: ignore[override]
self,
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
) -> Any:
if self._should_inject(request):
logger.debug("Injecting _noop tool for provider compatibility")
return handler(self._augmented(request))
return handler(request)
async def awrap_model_call( # type: ignore[override]
self,
request: ModelRequest[ContextT],
handler: Callable[
[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
],
) -> Any:
if self._should_inject(request):
logger.debug("Injecting _noop tool for provider compatibility")
return await handler(self._augmented(request))
return await handler(request)
__all__ = [
"NOOP_TOOL_DESCRIPTION",
"NOOP_TOOL_NAME",
"NoopInjectionMiddleware",
"_provider_needs_noop",
"noop_tool",
]
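The three-part gate in ``_should_inject`` can be checked without any provider SDK. This is a minimal sketch under simplifying assumptions: ``Msg`` is a hypothetical stand-in for LangChain message objects, and the provider is passed in as a plain string instead of being read from the model.

```python
from dataclasses import dataclass, field
from typing import Any

NOOP_NEEDED_PROVIDERS = ("litellm", "bedrock", "copilot")


@dataclass
class Msg:
    """Hypothetical stand-in for a chat message."""

    role: str
    tool_calls: list[dict[str, Any]] = field(default_factory=list)


def last_ai_has_tool_calls(messages: list[Msg]) -> bool:
    # Only the most recent AI message decides the outcome.
    for msg in reversed(messages):
        if msg.role == "ai":
            return bool(msg.tool_calls)
    return False


def should_inject(tools: list[Any], messages: list[Msg], provider: str) -> bool:
    """Inject only for: no tools AND prior tool calls AND an affected provider."""
    if tools:
        return False
    if not last_ai_has_tool_calls(messages):
        return False
    return any(needle in provider.lower() for needle in NOOP_NEEDED_PROVIDERS)


history = [Msg("human"), Msg("ai", tool_calls=[{"name": "ls", "args": {}}]), Msg("tool")]
summarize_call = should_inject([], history, "ChatLiteLLM")
normal_call = should_inject(["some_tool"], history, "ChatLiteLLM")
other_provider = should_inject([], history, "openai")
```

The substring match is what lets a class name such as ``ChatLiteLLM`` qualify when no explicit provider string is available, at the cost of being a heuristic rather than an exact provider check.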


@ -0,0 +1,9 @@
"""OTel-span middleware: spans on model and tool calls (impl + builder)."""
from .builder import build_otel_mw
from .middleware import OtelSpanMiddleware
__all__ = [
"OtelSpanMiddleware",
"build_otel_mw",
]


@ -0,0 +1,12 @@
"""OTel spans on model and tool calls."""
from __future__ import annotations
from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
from app.agents.chat.multi_agent_chat.shared.middleware.flags import enabled
from .middleware import OtelSpanMiddleware
def build_otel_mw(flags: AgentFeatureFlags) -> OtelSpanMiddleware | None:
return OtelSpanMiddleware() if enabled(flags, "enable_otel") else None


@ -0,0 +1,285 @@
"""
OpenTelemetry span middleware for the SurfSense ``new_chat`` agent.
Wraps both ``model.call`` (LLM invocations) and ``tool.call`` (tool
executions) with OTel spans, attaching low-cardinality span names and
high-cardinality identifiers as attributes.
This middleware is intentionally a thin adapter over
:mod:`app.observability.otel`; when OTel is not configured all spans
collapse to no-ops and the wrapper adds <1µs overhead per call. When
OTel **is** configured (``OTEL_EXPORTER_OTLP_ENDPOINT`` set), every
model and tool call gets a span with the standard attributes our
dashboards expect.
"""
from __future__ import annotations
import logging
import time
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Any
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import AIMessage, ToolMessage
from app.observability import metrics as ot_metrics, otel as ot
if TYPE_CHECKING: # pragma: no cover — type-only
from langchain.agents.middleware.types import (
ModelRequest,
ModelResponse,
ToolCallRequest,
)
from langgraph.types import Command
logger = logging.getLogger(__name__)
class OtelSpanMiddleware(AgentMiddleware):
"""Emit ``model.call`` and ``tool.call`` OTel spans for every invocation.
Should be placed near the **outer** end of the middleware list so
that the spans encompass retry/fallback wrapper effects (i.e. ``N``
model.call spans for ``N`` retry attempts) but inside any concurrency/
auth gate. Empirically this means **between** ``BusyMutex`` and
``RetryAfter``.
"""
def __init__(self, *, instrumentation_name: str = "surfsense.new_chat") -> None:
super().__init__()
self._instrumentation_name = instrumentation_name
# ------------------------------------------------------------------
# Model call spans
# ------------------------------------------------------------------
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse | AIMessage | Any]],
) -> ModelResponse | AIMessage | Any:
if not ot.is_enabled():
return await handler(request)
model_id, provider = _resolve_model_attrs(request)
t0 = time.perf_counter()
with ot.model_call_span(model_id=model_id, provider=provider) as sp:
_annotate_model_request(sp, model_id=model_id, provider=provider)
try:
result = await handler(request)
except Exception:
ot_metrics.record_model_call_duration(
(time.perf_counter() - t0) * 1000,
model=model_id,
provider=provider,
)
# span context manager records + re-raises
raise
else:
input_tokens, output_tokens = _annotate_model_response(
sp,
result,
model_id=model_id,
provider=provider,
)
ot_metrics.record_model_call_duration(
(time.perf_counter() - t0) * 1000,
model=model_id,
provider=provider,
)
ot_metrics.record_model_token_usage(
input_tokens=input_tokens,
output_tokens=output_tokens,
model=model_id,
provider=provider,
)
return result
# ------------------------------------------------------------------
# Tool call spans
# ------------------------------------------------------------------
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
) -> ToolMessage | Command[Any]:
if not ot.is_enabled():
return await handler(request)
tool_name = _resolve_tool_name(request)
input_size = _resolve_input_size(request)
t0 = time.perf_counter()
with ot.tool_call_span(tool_name, input_size=input_size) as sp:
try:
result = await handler(request)
except Exception:
ot_metrics.record_tool_call_duration(
(time.perf_counter() - t0) * 1000,
tool_name=tool_name,
)
ot_metrics.record_tool_call_error(tool_name=tool_name)
raise
errored = _annotate_tool_result(sp, result)
ot_metrics.record_tool_call_duration(
(time.perf_counter() - t0) * 1000,
tool_name=tool_name,
)
if errored:
ot_metrics.record_tool_call_error(tool_name=tool_name)
return result
# ---------------------------------------------------------------------------
# Attribute helpers (kept defensive; we never want OTel bookkeeping to break
# a real model/tool call).
# ---------------------------------------------------------------------------
def _resolve_model_attrs(request: Any) -> tuple[str | None, str | None]:
"""Extract ``model.id`` and ``model.provider`` from a ``ModelRequest``."""
model_id: str | None = None
provider: str | None = None
try:
model = getattr(request, "model", None)
if model is None:
return None, None
# langchain BaseChatModel exposes a few different identifiers
for attr in ("model_name", "model", "model_id"):
value = getattr(model, attr, None)
if value:
model_id = str(value)
break
# provider sometimes lives on ``_llm_type`` (legacy) or ``provider``
for attr in ("provider", "_llm_type"):
value = getattr(model, attr, None)
if value:
provider = str(value)
break
except Exception: # pragma: no cover — defensive
pass
return model_id, provider
def _resolve_tool_name(request: Any) -> str:
try:
tool = getattr(request, "tool", None)
if tool is not None:
name = getattr(tool, "name", None)
if isinstance(name, str) and name:
return name
# Fall back to the tool_call dict
call = getattr(request, "tool_call", None) or {}
name = call.get("name") if isinstance(call, dict) else None
if isinstance(name, str) and name:
return name
except Exception: # pragma: no cover — defensive
pass
return "unknown"
def _resolve_input_size(request: Any) -> int | None:
try:
call = getattr(request, "tool_call", None)
if not isinstance(call, dict) or not call:
return None
args = call.get("args")
if args is None:
return None
return len(repr(args))
except Exception: # pragma: no cover — defensive
return None
def _annotate_model_request(
span: Any, *, model_id: str | None, provider: str | None
) -> None:
try:
span.set_attribute("gen_ai.operation.name", "chat")
if model_id:
span.set_attribute("gen_ai.request.model", model_id)
if provider:
span.set_attribute("gen_ai.provider.name", provider)
except Exception: # pragma: no cover — defensive
pass
def _annotate_model_response(
span: Any,
result: Any,
*,
model_id: str | None = None,
provider: str | None = None,
) -> tuple[int | None, int | None]:
"""Best-effort: attach prompt/completion token counts when available."""
input_tokens: int | None = None
output_tokens: int | None = None
try:
# ModelResponse may be a dataclass with .result containing AIMessage
msg: Any
if isinstance(result, AIMessage):
msg = result
else:
inner = getattr(result, "result", None)
msg = inner[-1] if isinstance(inner, list) and inner else inner
if msg is None:
return None, None
if provider:
span.set_attribute("gen_ai.provider.name", provider)
if model_id:
span.set_attribute("gen_ai.request.model", model_id)
response_model = getattr(msg, "response_metadata", {}) or {}
if isinstance(response_model, dict):
response_model = (
response_model.get("model_name")
or response_model.get("model")
or response_model.get("model_id")
)
if not response_model:
response_model = model_id
if response_model:
span.set_attribute("gen_ai.response.model", str(response_model))
span.set_attribute("gen_ai.operation.name", "chat")
usage = getattr(msg, "usage_metadata", None) or {}
if isinstance(usage, dict):
if (n := usage.get("input_tokens")) is not None:
input_tokens = int(n)
span.set_attribute("gen_ai.usage.input_tokens", input_tokens)
if (n := usage.get("output_tokens")) is not None:
output_tokens = int(n)
span.set_attribute("gen_ai.usage.output_tokens", output_tokens)
if (n := usage.get("total_tokens")) is not None:
span.set_attribute("gen_ai.usage.total_tokens", int(n))
tool_calls = getattr(msg, "tool_calls", None) or []
span.set_attribute("model.tool_calls", len(tool_calls))
except Exception: # pragma: no cover — defensive
pass
return input_tokens, output_tokens
def _annotate_tool_result(span: Any, result: Any) -> bool:
errored = False
try:
if isinstance(result, ToolMessage):
content = (
result.content
if isinstance(result.content, str)
else repr(result.content)
)
span.set_attribute("tool.output.size", len(content))
status = getattr(result, "status", None)
if isinstance(status, str):
span.set_attribute("tool.status", status)
errored = status.lower() == "error"
kwargs = getattr(result, "additional_kwargs", None) or {}
if isinstance(kwargs, dict) and kwargs.get("error"):
span.set_attribute("tool.error", True)
errored = True
except Exception: # pragma: no cover — defensive
pass
return errored
__all__ = ["OtelSpanMiddleware"]
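
The timing pattern used by `awrap_model_call` and `awrap_tool_call` above (record the duration on both the success and the failure path, then let the exception propagate) can be sketched in isolation. This is a minimal sketch, not the project's implementation: `record_duration`, `timed_call`, and `DURATIONS_MS` are hypothetical stand-ins for the `ot_metrics` recorders and the middleware hook, the span context manager is omitted, and `try/finally` is swapped in where the original repeats the recording call in its `except` and `else` branches.

```python
import asyncio
import time
from collections.abc import Awaitable, Callable
from typing import Any

# Hypothetical in-memory sink standing in for the real metrics module.
DURATIONS_MS: list[tuple[str, float]] = []


def record_duration(name: str, elapsed_ms: float) -> None:
    DURATIONS_MS.append((name, elapsed_ms))


async def timed_call(name: str, handler: Callable[[], Awaitable[Any]]) -> Any:
    """Await ``handler`` and record its wall-clock duration exactly once."""
    t0 = time.perf_counter()
    try:
        return await handler()
    finally:
        # Runs on success and on failure; an exception keeps propagating.
        record_duration(name, (time.perf_counter() - t0) * 1000)


async def _ok() -> str:
    return "done"


async def _boom() -> str:
    raise RuntimeError("tool failed")


async def demo() -> tuple[str, bool]:
    result = await timed_call("ok_tool", _ok)
    try:
        await timed_call("bad_tool", _boom)
        raised = False
    except RuntimeError:
        raised = True
    return result, raised


RESULT, RAISED = asyncio.run(demo())
```

The failing call still leaves one duration sample behind, which is the property the middleware relies on for latency metrics.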


@ -0,0 +1,49 @@
"""Tail-of-stack plugin slot driven by env allowlist."""
from __future__ import annotations
import logging
from typing import Any
from langchain_core.language_models import BaseChatModel
from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
from app.agents.chat.multi_agent_chat.shared.middleware.flags import enabled
from app.db import ChatVisibility
from ..plugins.loader import (
PluginContext,
load_allowed_plugin_names_from_env,
load_plugin_middlewares,
)
def build_plugin_middlewares(
*,
flags: AgentFeatureFlags,
search_space_id: int,
user_id: str | None,
visibility: ChatVisibility,
llm: BaseChatModel,
) -> list[Any]:
if not enabled(flags, "enable_plugin_loader"):
return []
try:
allowed_names = load_allowed_plugin_names_from_env()
if not allowed_names:
return []
return load_plugin_middlewares(
PluginContext.build(
search_space_id=search_space_id,
user_id=user_id,
thread_visibility=visibility,
llm=llm,
),
allowed_plugin_names=allowed_names,
)
except Exception: # pragma: no cover - defensive
logging.warning(
"Plugin loader failed; continuing without plugins.",
exc_info=True,
)
return []


@ -0,0 +1,36 @@
"""Skill discovery + injection."""
from __future__ import annotations
import logging
from deepagents.middleware.skills import SkillsMiddleware
from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode
from app.agents.chat.multi_agent_chat.shared.middleware.flags import enabled
from ..skills.backends import build_skills_backend_factory, default_skills_sources
def build_skills_mw(
*,
flags: AgentFeatureFlags,
filesystem_mode: FilesystemMode,
search_space_id: int,
) -> SkillsMiddleware | None:
if not enabled(flags, "enable_skills"):
return None
try:
skills_factory = build_skills_backend_factory(
search_space_id=search_space_id
if filesystem_mode == FilesystemMode.CLOUD
else None,
)
return SkillsMiddleware(
backend=skills_factory,
sources=default_skills_sources(),
)
except Exception as exc: # pragma: no cover - defensive
logging.warning("SkillsMiddleware init failed; skipping: %s", exc)
return None


@ -0,0 +1,224 @@
"""Main-agent middleware list assembly: one line per slot.
The main agent is a pure router: filesystem reads/writes are owned by the
``knowledge_base`` subagent and delegated via the ``task`` tool. The stack
here only renders KB context (workspace tree + priority docs), projects it
into system messages, and commits any subagent-side staged writes at end of
turn (cloud mode).
"""
from __future__ import annotations
import logging
from collections.abc import Sequence
from typing import Any
from deepagents import SubAgent
from deepagents.backends import StateBackend
from langchain.agents import create_agent
from langchain_core.language_models import BaseChatModel
from langchain_core.tools import BaseTool
from langgraph.types import Checkpointer
from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode
from app.agents.chat.multi_agent_chat.shared.middleware.anthropic_cache import (
build_anthropic_cache_mw,
)
from app.agents.chat.multi_agent_chat.shared.middleware.compaction import (
build_compaction_mw,
)
from app.agents.chat.multi_agent_chat.shared.middleware.kb_context_projection import (
build_kb_context_projection_mw,
)
from app.agents.chat.multi_agent_chat.shared.middleware.memory import build_memory_mw
from app.agents.chat.multi_agent_chat.shared.middleware.patch_tool_calls import (
build_patch_tool_calls_mw,
)
from app.agents.chat.multi_agent_chat.shared.middleware.permissions import (
build_permission_mw,
)
from app.agents.chat.multi_agent_chat.shared.middleware.resilience import (
build_resilience_middlewares,
)
from app.agents.chat.multi_agent_chat.shared.middleware.todos import build_todos_mw
from app.agents.chat.multi_agent_chat.subagents import (
build_subagents,
get_subagents_to_exclude,
)
from app.agents.chat.multi_agent_chat.subagents.builtins.knowledge_base.agent import (
READONLY_NAME as KB_READONLY_NAME,
build_readonly_subagent as build_kb_readonly_subagent,
)
from app.agents.chat.multi_agent_chat.subagents.builtins.knowledge_base.ask_knowledge_base_tool import (
build_ask_knowledge_base_tool,
)
from app.agents.chat.multi_agent_chat.subagents.shared.middleware.middleware_stack import (
build_subagent_middleware_stack,
)
from app.db import ChatVisibility
from .action_log import build_action_log_mw
from .anonymous_document import build_anonymous_doc_mw
from .busy_mutex import build_busy_mutex_mw
from .checkpointed_subagent_middleware import (
SurfSenseCheckpointedSubAgentMiddleware,
)
from .checkpointed_subagent_middleware.task_description import (
TASK_TOOL_DESCRIPTION,
)
from .context_editing import build_context_editing_mw
from .dedup_hitl import build_dedup_hitl_mw
from .doom_loop import build_doom_loop_mw
from .kb_persistence import build_kb_persistence_mw
from .knowledge_priority import build_knowledge_priority_mw
from .knowledge_tree import build_knowledge_tree_mw
from .noop_injection import build_noop_injection_mw
from .otel_span import build_otel_mw
from .plugins import build_plugin_middlewares
from .skills import build_skills_mw
from .tool_call_repair import build_repair_mw
def build_main_agent_deepagent_middleware(
*,
llm: BaseChatModel,
tools: Sequence[BaseTool],
backend_resolver: Any,
filesystem_mode: FilesystemMode,
search_space_id: int,
user_id: str | None,
thread_id: int | None,
visibility: ChatVisibility,
anon_session_id: str | None,
available_connectors: list[str] | None,
available_document_types: list[str] | None,
mentioned_document_ids: list[int] | None,
max_input_tokens: int | None,
flags: AgentFeatureFlags,
subagent_dependencies: dict[str, Any],
checkpointer: Checkpointer,
mcp_tools_by_agent: dict[str, list[BaseTool]] | None = None,
disabled_tools: list[str] | None = None,
) -> list[Any]:
"""Ordered middleware for ``create_agent`` (None entries already stripped)."""
resilience = build_resilience_middlewares(flags)
memory_mw = build_memory_mw(
user_id=user_id,
search_space_id=search_space_id,
visibility=visibility,
)
subagent_dependencies = {
**subagent_dependencies,
"backend_resolver": backend_resolver,
"filesystem_mode": filesystem_mode,
"flags": flags,
}
shared_subagent_middleware = build_subagent_middleware_stack(
resilience=resilience,
flags=flags,
)
kb_readonly = build_kb_readonly_subagent(
dependencies=subagent_dependencies,
model=llm,
middleware_stack=shared_subagent_middleware,
)
kb_readonly_spec = kb_readonly.spec
kb_readonly_runnable = create_agent(
llm,
system_prompt=kb_readonly_spec["system_prompt"],
tools=kb_readonly_spec["tools"],
middleware=kb_readonly_spec["middleware"],
name=KB_READONLY_NAME,
checkpointer=checkpointer,
)
ask_kb_tool = build_ask_knowledge_base_tool(kb_readonly_runnable)
subagents: list[SubAgent] = build_subagents(
dependencies=subagent_dependencies,
model=llm,
middleware_stack=shared_subagent_middleware,
mcp_tools_by_agent=mcp_tools_by_agent or {},
exclude=get_subagents_to_exclude(available_connectors),
disabled_tools=disabled_tools,
ask_kb_tool=ask_kb_tool,
)
logging.debug("Subagents registry: %s", [s["name"] for s in subagents])
stack: list[Any] = [
build_busy_mutex_mw(flags),
build_otel_mw(flags),
build_todos_mw(system_prompt=""),
memory_mw,
build_anonymous_doc_mw(
filesystem_mode=filesystem_mode, anon_session_id=anon_session_id
),
build_knowledge_tree_mw(
filesystem_mode=filesystem_mode,
search_space_id=search_space_id,
llm=llm,
),
build_knowledge_priority_mw(
llm=llm,
search_space_id=search_space_id,
filesystem_mode=filesystem_mode,
available_connectors=available_connectors,
available_document_types=available_document_types,
mentioned_document_ids=mentioned_document_ids,
),
build_kb_context_projection_mw(),
build_kb_persistence_mw(
filesystem_mode=filesystem_mode,
search_space_id=search_space_id,
user_id=user_id,
thread_id=thread_id,
),
build_skills_mw(
flags=flags,
filesystem_mode=filesystem_mode,
search_space_id=search_space_id,
),
SurfSenseCheckpointedSubAgentMiddleware(
checkpointer=checkpointer,
backend=StateBackend,
subagents=subagents,
system_prompt=None,
task_description=TASK_TOOL_DESCRIPTION,
search_space_id=search_space_id,
),
resilience.model_call_limit,
resilience.tool_call_limit,
build_context_editing_mw(
flags=flags,
max_input_tokens=max_input_tokens,
tools=tools,
backend_resolver=backend_resolver,
),
build_compaction_mw(llm),
build_noop_injection_mw(flags),
resilience.retry,
resilience.fallback,
build_repair_mw(flags=flags, tools=tools),
build_permission_mw(flags=flags),
build_doom_loop_mw(flags),
build_action_log_mw(
flags=flags,
thread_id=thread_id,
search_space_id=search_space_id,
user_id=user_id,
),
build_patch_tool_calls_mw(),
build_dedup_hitl_mw(tools),
*build_plugin_middlewares(
flags=flags,
search_space_id=search_space_id,
user_id=user_id,
visibility=visibility,
llm=llm,
),
build_anthropic_cache_mw(),
]
return [m for m in stack if m is not None]
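
The "one line per slot" assembly above relies on builders that return `None` when their feature flag is off, with a single filter at the end. A minimal sketch of that idea, assuming hypothetical `Flags` and `Middleware` stand-ins (the real `AgentFeatureFlags` and middleware classes are not reproduced here):

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class Flags:
    """Hypothetical stand-in for AgentFeatureFlags."""

    enable_otel: bool = True
    enable_skills: bool = False


@dataclass(frozen=True)
class Middleware:
    """Hypothetical stand-in for an AgentMiddleware instance."""

    name: str


def build_otel(flags: Flags) -> Middleware | None:
    return Middleware("otel") if flags.enable_otel else None


def build_skills(flags: Flags) -> Middleware | None:
    return Middleware("skills") if flags.enable_skills else None


def build_plugins(names: list[str]) -> list[Middleware]:
    # Multi-valued slots return a list and are splatted into the stack.
    return [Middleware(f"plugin:{n}") for n in names]


def build_stack(flags: Flags, plugin_names: list[str]) -> list[Middleware]:
    stack: list[Middleware | None] = [
        Middleware("busy_mutex"),  # first slot
        build_otel(flags),
        build_skills(flags),
        *build_plugins(plugin_names),
        Middleware("anthropic_cache"),  # last slot
    ]
    # Disabled slots drop out; the relative order of the rest is preserved.
    return [m for m in stack if m is not None]


STACK_NAMES = [m.name for m in build_stack(Flags(), ["year_substituter"])]
```

Keeping the ordering in one literal list means a disabled feature never shifts the position of its neighbours.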


@ -0,0 +1,9 @@
"""Tool-call-repair middleware: fix miscased/unknown tool names (impl + builder)."""
from .builder import build_repair_mw
from .middleware import ToolCallNameRepairMiddleware
__all__ = [
"ToolCallNameRepairMiddleware",
"build_repair_mw",
]


@ -0,0 +1,50 @@
"""Repair miscased / unknown tool names to the registered set or invalid_tool."""
from __future__ import annotations
from collections.abc import Sequence
from langchain_core.tools import BaseTool
from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
from app.agents.chat.multi_agent_chat.shared.middleware.flags import enabled
from .middleware import ToolCallNameRepairMiddleware
# deepagents-built-in tool names the repair pass treats as known.
_DEEPAGENT_BUILTIN_TOOL_NAMES: frozenset[str] = frozenset(
{
"write_todos",
"ls",
"read_file",
"write_file",
"edit_file",
"glob",
"grep",
"execute",
"task",
"mkdir",
"cd",
"pwd",
"move_file",
"rm",
"rmdir",
"list_tree",
"execute_code",
}
)
def build_repair_mw(
*,
flags: AgentFeatureFlags,
tools: Sequence[BaseTool],
) -> ToolCallNameRepairMiddleware | None:
if not enabled(flags, "enable_tool_call_repair"):
return None
registered_names: set[str] = {t.name for t in tools}
registered_names |= _DEEPAGENT_BUILTIN_TOOL_NAMES
return ToolCallNameRepairMiddleware(
registered_tool_names=registered_names,
fuzzy_match_threshold=None,
)


@ -0,0 +1,197 @@
"""
ToolCallNameRepairMiddleware: two-stage tool-name repair.
Operation:
1. **Stage 1, lowercase repair:** if a tool call's ``name`` is not in
the registry but ``name.lower()`` is, rewrite in place. Catches
models that emit ``Search`` instead of ``search``.
2. **Stage 2, invalid fallback:** if still unmatched, rewrite the call
to ``invalid`` with ``args={"tool": original_name, "error": <error>}``
so the registered :func:`invalid_tool` returns the error to the model
for self-correction.
Ported from OpenCode's ``packages/opencode/src/session/llm.ts:339-358``
+ ``packages/opencode/src/tool/invalid.ts``. LangChain has no equivalent:
:class:`deepagents.middleware.PatchToolCallsMiddleware` patches
*dangling* tool calls (no matching ToolMessage) but does nothing about
wrong names, and the model framework's default behavior on an unknown
name is to crash the turn rather than route to a self-correction
fallback.
"""
from __future__ import annotations
import difflib
import logging
from typing import Any
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ContextT,
ResponseT,
)
from langchain_core.messages import AIMessage
from langgraph.runtime import Runtime
logger = logging.getLogger(__name__)
def _coerce_existing_tool_call(call: Any) -> dict[str, Any]:
"""Normalize a tool call entry to a mutable dict."""
if isinstance(call, dict):
return dict(call)
return {
"name": getattr(call, "name", None),
"args": getattr(call, "args", {}),
"id": getattr(call, "id", None),
"type": "tool_call",
}
class ToolCallNameRepairMiddleware(
AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]
):
"""Two-stage tool-name repair on the most recent ``AIMessage``.
Args:
registered_tool_names: Set of canonically-registered tool names.
``invalid`` should be in this set so the fallback dispatches.
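fuzzy_match_threshold: Optional ``difflib`` cutoff (0-1) for the
fuzzy-match step that runs *between* lowercase and invalid.
``None`` disables fuzzy matching, which mirrors OpenCode and avoids
silent rewrites; ``build_repair_mw`` passes ``None``. Note that the
constructor default is ``0.85``, so direct callers get fuzzy
matching unless they opt out.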
"""
def __init__(
self,
*,
registered_tool_names: set[str],
fuzzy_match_threshold: float | None = 0.85,
) -> None:
super().__init__()
self._registered = set(registered_tool_names)
self._registered_lower = {name.lower(): name for name in self._registered}
self._fuzzy_threshold = fuzzy_match_threshold
self.tools = []
def _registered_for_runtime(self, runtime: Runtime[ContextT]) -> set[str]:
"""Allow runtime overrides to expand the set (e.g. dynamic MCP tools)."""
ctx_tools = getattr(runtime.context, "registered_tool_names", None)
if isinstance(ctx_tools, set | frozenset):
return self._registered | set(ctx_tools)
if isinstance(ctx_tools, list | tuple):
return self._registered | set(ctx_tools)
return self._registered
def _repair_one(
self,
call: dict[str, Any],
registered: set[str],
) -> dict[str, Any]:
name = call.get("name")
if not isinstance(name, str):
return call
if name in registered:
return call
# Stage 1 — lowercase
lowered = name.lower()
if lowered in registered:
call["name"] = lowered
metadata = dict(call.get("response_metadata") or {})
metadata.setdefault("repair", "lowercase")
call["response_metadata"] = metadata
return call
# Optional fuzzy step (skipped when the threshold is None; build_repair_mw passes None)
if self._fuzzy_threshold is not None:
close = difflib.get_close_matches(
name, registered, n=1, cutoff=self._fuzzy_threshold
)
if close:
call["name"] = close[0]
metadata = dict(call.get("response_metadata") or {})
metadata.setdefault("repair", f"fuzzy:{name}->{close[0]}")
call["response_metadata"] = metadata
return call
# Stage 2 — invalid fallback
# Local import keeps the middleware module import-light and avoids any
# tools <-> middleware import-order coupling at module scope.
from app.agents.chat.multi_agent_chat.main_agent.tools.invalid_tool import (
INVALID_TOOL_NAME,
)
if INVALID_TOOL_NAME in registered:
original_args = call.get("args") or {}
error_msg = (
f"Tool name '{name}' is not registered. "
f"Original arguments were: {original_args!r}."
)
call["name"] = INVALID_TOOL_NAME
call["args"] = {"tool": name, "error": error_msg}
metadata = dict(call.get("response_metadata") or {})
metadata.setdefault("repair", f"invalid_fallback:{name}")
call["response_metadata"] = metadata
else:
logger.warning(
"Could not repair unknown tool call %r; 'invalid' tool not registered",
name,
)
return call
def _maybe_repair(
self,
message: AIMessage,
registered: set[str],
) -> AIMessage | None:
if not message.tool_calls:
return None
new_calls: list[dict[str, Any]] = []
any_changed = False
for raw in message.tool_calls:
call = _coerce_existing_tool_call(raw)
before = (call.get("name"), call.get("args"))
repaired = self._repair_one(call, registered)
after = (repaired.get("name"), repaired.get("args"))
if before != after:
any_changed = True
new_calls.append(repaired)
if not any_changed:
return None
return message.model_copy(update={"tool_calls": new_calls})
def after_model( # type: ignore[override]
self,
state: AgentState[ResponseT],
runtime: Runtime[ContextT],
) -> dict[str, Any] | None:
messages = state.get("messages") or []
if not messages:
return None
last = messages[-1]
if not isinstance(last, AIMessage):
return None
registered = self._registered_for_runtime(runtime)
repaired = self._maybe_repair(last, registered)
if repaired is None:
return None
return {"messages": [repaired]}
async def aafter_model( # type: ignore[override]
self,
state: AgentState[ResponseT],
runtime: Runtime[ContextT],
) -> dict[str, Any] | None:
return self.after_model(state, runtime)
__all__ = [
"ToolCallNameRepairMiddleware",
]
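
The repair order described in the module docstring (exact match, lowercase, optional fuzzy match, then the invalid fallback) can be illustrated as a pure function. This is a sketch under stated assumptions rather than the middleware itself: `repair_tool_call` is a hypothetical helper, the fallback tool is assumed to be named `invalid` as the docstring suggests, and the `response_metadata` bookkeeping is left out.

```python
from __future__ import annotations

import difflib
from typing import Any

INVALID = "invalid"  # assumed name of the fallback tool


def repair_tool_call(
    call: dict[str, Any],
    registered: set[str],
    fuzzy_threshold: float | None = None,
) -> dict[str, Any]:
    """Return a repaired copy of ``call``; known names pass through."""
    call = dict(call)
    name = call.get("name")
    if not isinstance(name, str) or name in registered:
        return call
    # Stage 1: case-only mismatch, e.g. "Search" -> "search".
    if name.lower() in registered:
        call["name"] = name.lower()
        return call
    # Optional fuzzy step, skipped when the threshold is None.
    if fuzzy_threshold is not None:
        close = difflib.get_close_matches(
            name, sorted(registered), n=1, cutoff=fuzzy_threshold
        )
        if close:
            call["name"] = close[0]
            return call
    # Stage 2: route to the fallback tool so the model sees the error.
    if INVALID in registered:
        call["args"] = {"tool": name, "error": f"Tool name '{name}' is not registered."}
        call["name"] = INVALID
    return call


TOOLS = {"search", "read_file", INVALID}
CASE_FIXED = repair_tool_call({"name": "Search", "args": {}}, TOOLS)
FUZZY_FIXED = repair_tool_call({"name": "read_fil", "args": {}}, TOOLS, 0.85)
FALLBACK = repair_tool_call({"name": "teleport", "args": {"x": 1}}, TOOLS)
```

With the threshold left at `None`, a near-miss such as `read_fil` would go to the fallback instead of being silently rewritten, which is the trade-off the class docstring mentions.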


@ -0,0 +1,6 @@
"""Reference plugins bundled with SurfSense.
These plugins are intentionally small and demonstrative. They are NOT
auto-loaded; they ship as examples that a deployment can opt into via
``global_llm_config.yaml`` or ``SURFSENSE_ALLOWED_PLUGINS``.
"""


@ -0,0 +1,158 @@
"""Entry-point based plugin loader for SurfSense agent middleware.
LangChain's :class:`AgentMiddleware` ABC already covers the practical
surface most plugins need (``before_agent`` / ``before_model`` /
``wrap_tool_call`` / their async counterparts), so a SurfSense-specific
plugin protocol would be redundant. We just need a way to discover and
admit third-party middleware safely.
A plugin is therefore just an installable Python package that registers a
factory callable under the ``surfsense.plugins`` entry-point group:
.. code-block:: toml
# in a plugin package's pyproject.toml
[project.entry-points."surfsense.plugins"]
year_substituter = "my_plugin:make_middleware"
The factory has the signature ``Callable[[PluginContext], AgentMiddleware]``.
It receives a small, sanitized :class:`PluginContext` with the IDs and the
LLM the plugin is allowed to talk to, and **never** raw secrets, DB
sessions, or other connectors.
## Trust model
Plugins are loaded **only if** their entry-point ``name`` appears in
``allowed_plugins`` (admin-controlled, sourced from
``global_llm_config.yaml`` or :func:`load_allowed_plugin_names_from_env`).
Nothing is auto-loaded: an installed plugin that is not allowlisted is
skipped before ``ep.load()`` runs. A plugin failure is logged and
isolated; it does not break agent construction.
"""
from __future__ import annotations
import logging
import os
from collections.abc import Iterable
from importlib.metadata import entry_points
from typing import TYPE_CHECKING
from langchain.agents.middleware import AgentMiddleware
if TYPE_CHECKING: # pragma: no cover - type-only
from langchain_core.language_models import BaseChatModel
from app.db import ChatVisibility
logger = logging.getLogger(__name__)
PLUGIN_ENTRY_POINT_GROUP = "surfsense.plugins"
class PluginContext(dict):
"""Sanitized DI bag handed to each plugin factory.
Backed by ``dict`` so plugins can inspect the keys they care about
without coupling to a concrete dataclass shape. Required keys:
* ``search_space_id`` (int)
* ``user_id`` (str | None)
* ``thread_visibility`` (:class:`app.db.ChatVisibility`)
* ``llm`` (:class:`langchain_core.language_models.BaseChatModel`)
The context **never** carries DB sessions, raw secrets, or other
connectors. If a future plugin genuinely needs DB access, that
integration goes through a rate-limited service interface, not
through this bag.
"""
@classmethod
def build(
cls,
*,
search_space_id: int,
user_id: str | None,
thread_visibility: ChatVisibility,
llm: BaseChatModel,
) -> PluginContext:
return cls(
search_space_id=search_space_id,
user_id=user_id,
thread_visibility=thread_visibility,
llm=llm,
)
def load_plugin_middlewares(
ctx: PluginContext,
allowed_plugin_names: Iterable[str],
) -> list[AgentMiddleware]:
"""Discover, allowlist-filter, and instantiate plugin middleware.
For each entry-point in :data:`PLUGIN_ENTRY_POINT_GROUP` whose name is
in ``allowed_plugin_names``, load the factory and call it with ``ctx``.
The factory's return value must be an :class:`AgentMiddleware` instance;
anything else is logged and skipped.
Errors are isolated: a plugin that raises during ``ep.load()`` or
factory invocation is logged at ``ERROR`` and ignored. Agent
construction continues with whatever plugins did succeed.
"""
allowed = {name for name in allowed_plugin_names if name}
if not allowed:
return []
out: list[AgentMiddleware] = []
try:
eps = entry_points(group=PLUGIN_ENTRY_POINT_GROUP)
except Exception: # pragma: no cover - defensive (entry_points is robust)
logger.exception("Failed to enumerate plugin entry points")
return []
for ep in eps:
if ep.name not in allowed:
logger.info("Skipping non-allowlisted plugin %s", ep.name)
continue
try:
factory = ep.load()
except Exception:
logger.exception("Failed to load plugin %s", ep.name)
continue
try:
mw = factory(ctx)
except Exception:
logger.exception("Plugin %s factory raised", ep.name)
continue
if not isinstance(mw, AgentMiddleware):
logger.warning(
"Plugin %s returned %s, expected AgentMiddleware; skipping",
ep.name,
type(mw).__name__,
)
continue
out.append(mw)
logger.info("Loaded plugin %s as %s", ep.name, type(mw).__name__)
return out
def load_allowed_plugin_names_from_env() -> set[str]:
"""Read ``SURFSENSE_ALLOWED_PLUGINS`` (comma-separated) into a set.
Provided as a thin convenience for deployments that don't surface plugins
through ``global_llm_config.yaml`` yet. Whitespace is stripped and empty
entries are dropped.
"""
raw = os.environ.get("SURFSENSE_ALLOWED_PLUGINS", "").strip()
if not raw:
return set()
return {token.strip() for token in raw.split(",") if token.strip()}
__all__ = [
"PLUGIN_ENTRY_POINT_GROUP",
"PluginContext",
"load_allowed_plugin_names_from_env",
"load_plugin_middlewares",
]
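
The allowlist-then-load flow above can be exercised without installing any package. The sketch below is illustrative only: `FakeEntryPoint`, `parse_allowlist`, and `load_allowed` are hypothetical names, the fake class stands in for `importlib.metadata.EntryPoint`, and the `AgentMiddleware` type check and logging are omitted.

```python
from __future__ import annotations

from collections.abc import Callable, Iterable
from dataclasses import dataclass
from typing import Any


def parse_allowlist(raw: str) -> set[str]:
    """Split a comma-separated value, dropping whitespace and empty entries."""
    return {token.strip() for token in raw.split(",") if token.strip()}


@dataclass(frozen=True)
class FakeEntryPoint:
    """Hypothetical stand-in for importlib.metadata.EntryPoint."""

    name: str
    factory: Callable[[dict[str, Any]], Any]

    def load(self) -> Callable[[dict[str, Any]], Any]:
        return self.factory


def load_allowed(
    eps: Iterable[FakeEntryPoint],
    allowed: set[str],
    ctx: dict[str, Any],
) -> list[Any]:
    loaded: list[Any] = []
    for ep in eps:
        if ep.name not in allowed:
            continue  # skipped before load(), so its factory is never called
        try:
            loaded.append(ep.load()(ctx))
        except Exception:
            # One broken plugin must not block the others.
            continue
    return loaded


def _good(ctx: dict[str, Any]) -> str:
    return f"mw-for-space-{ctx['search_space_id']}"


def _broken(ctx: dict[str, Any]) -> str:
    raise RuntimeError("plugin bug")


ALLOWED = parse_allowlist(" good , broken ,, ")
LOADED = load_allowed(
    [
        FakeEntryPoint("good", _good),
        FakeEntryPoint("broken", _broken),
        FakeEntryPoint("other", _good),
    ],
    ALLOWED,
    {"search_space_id": 7},
)
```

Here `other` is installed but not allowlisted, and `broken` is allowlisted but raises; only `good` ends up in the result.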


@ -0,0 +1,88 @@
"""Reference plugin: substitute ``{{year}}`` in tool results.
Demonstrates the :meth:`AgentMiddleware.awrap_tool_call` hook -- the
plugin sees every tool invocation and can rewrite the request *or* the
result. This particular plugin leaves the request untouched and only
rewrites ``{{year}}`` in string ``ToolMessage`` content returned by the
handler.
The plugin is built as a factory function so the entry-point loader can
inject :class:`PluginContext` (containing the agent's LLM, search-space
ID, etc.). The factory signature
``Callable[[PluginContext], AgentMiddleware]`` is the only contract --
SurfSense doesn't define a custom plugin protocol on top of LangChain's
:class:`AgentMiddleware`.
Wire-up in ``pyproject.toml`` (illustrative; the in-repo plugin doesn't
need this -- it's already on the import path)::
[project.entry-points."surfsense.plugins"]
year_substituter = "app.agents.chat.multi_agent_chat.main_agent.plugins.year_substituter:make_middleware"
"""
from __future__ import annotations
import logging
from collections.abc import Awaitable, Callable
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any
from langchain.agents.middleware import AgentMiddleware
if TYPE_CHECKING: # pragma: no cover - type-only
from langchain.agents.middleware.types import ToolCallRequest
from langchain_core.messages import ToolMessage
from langgraph.types import Command
from .loader import PluginContext
logger = logging.getLogger(__name__)
class _YearSubstituterMiddleware(AgentMiddleware):
"""Replace ``{{year}}`` in the result text with the current UTC year."""
tools = ()
def __init__(self, year: int | None = None) -> None:
super().__init__()
self._year = str(year if year is not None else datetime.now(UTC).year)
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
) -> ToolMessage | Command[Any]:
result = await handler(request)
try:
from langchain_core.messages import ToolMessage
if (
isinstance(result, ToolMessage)
and isinstance(result.content, str)
and "{{year}}" in result.content
):
new_text = result.content.replace("{{year}}", self._year)
result = ToolMessage(
content=new_text,
tool_call_id=result.tool_call_id,
id=result.id,
name=result.name,
status=result.status,
artifact=result.artifact,
)
except Exception: # pragma: no cover - defensive
logger.exception("year_substituter plugin failed; passing original result")
return result
def make_middleware(ctx: PluginContext) -> AgentMiddleware:
"""Plugin factory used by :func:`load_plugin_middlewares`."""
# The plugin is intentionally small: it holds no mutable state that needs
# thread-safety protection, and it ignores ``ctx`` beyond demonstrating that
# the loader passes it in.
_ = ctx
return _YearSubstituterMiddleware()
__all__ = ["make_middleware"]
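
The string transformation at the heart of this plugin can be shown on its own. A minimal sketch, assuming a hypothetical `substitute_year` helper; unlike the middleware above, which fixes the year once in `__init__`, this version resolves the current UTC year on each call when none is given.

```python
from __future__ import annotations

from datetime import datetime, timezone

PLACEHOLDER = "{{year}}"


def substitute_year(text: str, year: int | None = None) -> str:
    """Replace every ``{{year}}`` with ``year`` (current UTC year if omitted)."""
    resolved = year if year is not None else datetime.now(timezone.utc).year
    return text.replace(PLACEHOLDER, str(resolved))


FIXED = substitute_year("Report for {{year}} (c) {{year}}", year=2031)
UNTOUCHED = substitute_year("no placeholder here")
```

Passing an explicit `year` keeps the behaviour deterministic in tests, which is presumably why the middleware constructor accepts one as well.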


@ -0,0 +1,7 @@
"""Async factory: wiring tools, prompts, MCP buckets, then graph compile."""
from __future__ import annotations
from .factory import create_multi_agent_chat_deep_agent
__all__ = ["create_multi_agent_chat_deep_agent"]


@ -0,0 +1,121 @@
"""Compiled agent graph caching for the multi-agent path."""
from __future__ import annotations
import asyncio
from collections.abc import Sequence
from typing import Any
from langchain_core.language_models import BaseChatModel
from langchain_core.tools import BaseTool
from langgraph.types import Checkpointer
from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode
from app.db import ChatVisibility
from ..graph.compile_graph_sync import build_compiled_agent_graph_sync
from .agent_cache_store import (
flags_signature,
get_cache,
stable_hash,
system_prompt_hash,
tools_signature,
)
def mcp_signature(mcp_tools_by_agent: dict[str, list[BaseTool]]) -> str:
"""Hash the per-agent MCP tool surface so a change rotates the cache key."""
rows = []
for agent_name in sorted(mcp_tools_by_agent.keys()):
names = sorted(
getattr(t, "name", "") or "" for t in mcp_tools_by_agent[agent_name]
)
rows.append((agent_name, names))
return stable_hash(rows)
async def build_agent_with_cache(
*,
llm: BaseChatModel,
tools: Sequence[BaseTool],
final_system_prompt: str,
backend_resolver: Any,
filesystem_mode: FilesystemMode,
search_space_id: int,
user_id: str | None,
thread_id: int | None,
visibility: ChatVisibility,
anon_session_id: str | None,
available_connectors: list[str],
available_document_types: list[str],
mentioned_document_ids: list[int] | None,
max_input_tokens: int | None,
flags: AgentFeatureFlags,
checkpointer: Checkpointer,
subagent_dependencies: dict[str, Any],
mcp_tools_by_agent: dict[str, list[BaseTool]],
disabled_tools: list[str] | None,
config_id: str | None,
image_generation_config_id_override: int | None = None,
) -> Any:
"""Compile the multi-agent graph, serving from cache when key components are stable."""
async def _build() -> Any:
return await asyncio.to_thread(
build_compiled_agent_graph_sync,
llm=llm,
tools=tools,
final_system_prompt=final_system_prompt,
backend_resolver=backend_resolver,
filesystem_mode=filesystem_mode,
search_space_id=search_space_id,
user_id=user_id,
thread_id=thread_id,
visibility=visibility,
anon_session_id=anon_session_id,
available_connectors=available_connectors,
available_document_types=available_document_types,
mentioned_document_ids=mentioned_document_ids,
max_input_tokens=max_input_tokens,
flags=flags,
checkpointer=checkpointer,
subagent_dependencies=subagent_dependencies,
mcp_tools_by_agent=mcp_tools_by_agent,
disabled_tools=disabled_tools,
)
if not (flags.enable_agent_cache and not flags.disable_new_agent_stack):
return await _build()
# Every per-request value any middleware closes over at __init__ must be in
# the key, otherwise a hit will leak state across threads. Bump the schema
# version when the component list changes shape.
cache_key = stable_hash(
"multi-agent-v2",
config_id,
thread_id,
user_id,
search_space_id,
visibility,
filesystem_mode,
anon_session_id,
tools_signature(
tools,
available_connectors=available_connectors,
available_document_types=available_document_types,
),
mcp_signature(mcp_tools_by_agent),
flags_signature(flags),
system_prompt_hash(final_system_prompt),
max_input_tokens,
sorted(disabled_tools) if disabled_tools else None,
# Bound into the generate_image subagent tool at construction time, so it
# must key the compiled-agent cache to avoid leaking one automation's
# image model into another with the same config_id/search_space.
image_generation_config_id_override,
)
return await get_cache().get_or_build(cache_key, builder=_build)
__all__ = ["build_agent_with_cache", "mcp_signature"]
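
Review note: the key composition above can be tried in isolation. The sketch below copies `stable_hash` from the cache-store module in this commit and pairs it with a hypothetical `demo_key` helper that uses a reduced component list (the real key has many more components). It only aims to show that list ordering is normalized and that changing one component rotates the key.

```python
import hashlib
from typing import Any, Optional


def stable_hash(*parts: Any) -> str:
    """Fingerprint the repr of each part, separated by an ASCII unit separator."""
    h = hashlib.sha1(usedforsecurity=False)
    for p in parts:
        h.update(repr(p).encode("utf-8", errors="replace"))
        h.update(b"\x1f")
    return h.hexdigest()


def demo_key(
    *,
    thread_id: int,
    disabled_tools: Optional[list],
    image_override: Optional[int],
) -> str:
    # Hypothetical, reduced component list for illustration only.
    return stable_hash(
        "demo-v1",  # schema version, bumped when the component list changes
        thread_id,
        sorted(disabled_tools) if disabled_tools else None,
        image_override,
    )


# Same components in a different list order produce the same key.
same_a = demo_key(thread_id=7, disabled_tools=["b", "a"], image_override=None)
same_b = demo_key(thread_id=7, disabled_tools=["a", "b"], image_override=None)
# Changing a single component (the image override) rotates the key.
other = demo_key(thread_id=7, disabled_tools=["a", "b"], image_override=3)
```

Sorting `disabled_tools` before hashing is what keeps the key stable when callers pass the same set in a different order.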

@ -0,0 +1,356 @@
"""TTL-LRU cache for compiled SurfSense deep agents.
Why this exists
---------------
``create_surfsense_deep_agent`` runs a 4-5 second pipeline on EVERY chat
turn:
1. Discover connectors & document types from Postgres (~50-200ms)
2. Build the tool list (built-in + MCP) (~200ms-1.7s)
3. Compose the system prompt
4. Construct ~15 middleware instances (CPU)
5. Eagerly compile the general-purpose subagent
(``SubAgentMiddleware.__init__`` calls ``create_agent`` synchronously,
which builds a second LangGraph + Pydantic schemas: ~1.5-2s of pure
CPU work)
6. Compile the outer LangGraph
For a single thread, all six steps produce the SAME object on every turn
unless the user has changed their LLM config, toggled a feature flag,
added a connector, etc. The right answer is to compile ONCE per
"agent shape" and reuse the resulting :class:`CompiledStateGraph` for
every subsequent turn on the same thread.
Why a per-thread key (not a global pool)
----------------------------------------
Most middleware in the SurfSense stack captures per-thread state in
``__init__`` closures (``thread_id``, ``user_id``, ``search_space_id``,
``filesystem_mode``, ``mentioned_document_ids``). Cross-thread reuse
would silently leak state across users and threads. Keying the cache on
``(llm_config_id, thread_id, ...)`` gives us safe reuse for repeated
turns on the same thread without changing any middleware's behavior.
Phase 2 will move those captured fields onto :class:`SurfSenseContextSchema`
(read via ``runtime.context``) so the cache can collapse to a single
``(llm_config_id, search_space_id, ...)`` key shared across threads. Until
then, per-thread keying is the only safe option.
Cache shape
-----------
* TTL-LRU: entries auto-expire after ``ttl_seconds`` (default 1800s; 30
minutes matches a typical chat session). ``maxsize`` (default 256)
caps memory; LRU evicts least-recently-used on overflow.
* In-flight de-duplication: per-key :class:`asyncio.Lock` so concurrent
cold misses on the same key wait for the first build instead of
building N times.
* Process-local: this is an in-memory cache. Multi-replica deployments
pay the build cost once per replica per key. That's fine; the working
set per replica is small (one entry per active thread on that replica).
Telemetry
---------
Every lookup logs ``[agent_cache]`` lines through ``surfsense.perf``:
* ``hit``: cache hit, microseconds-fast
* ``miss``: first build for this key, includes build duration
* ``stale``: entry was found but expired; rebuilt
* ``evict``: LRU eviction (size-limited)
* ``size``: current cache occupancy at lookup time
"""
from __future__ import annotations
import asyncio
import hashlib
import logging
import os
import time
from collections import OrderedDict
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import Any
from app.utils.perf import get_perf_logger
logger = logging.getLogger(__name__)
_perf_log = get_perf_logger()
# ---------------------------------------------------------------------------
# Public API: signature helpers (cache key components)
# ---------------------------------------------------------------------------
def stable_hash(*parts: Any) -> str:
"""Compute a deterministic SHA1 of the str repr of ``parts``.
Used for cache key components that need a fixed-width representation
(system prompt, tool list, etc.). SHA1 is fine here: this is not a
security boundary, just a content fingerprint.
"""
h = hashlib.sha1(usedforsecurity=False)
for p in parts:
h.update(repr(p).encode("utf-8", errors="replace"))
h.update(b"\x1f") # ASCII unit separator between parts
return h.hexdigest()
def tools_signature(
tools: list[Any] | tuple[Any, ...],
*,
available_connectors: list[str] | None,
available_document_types: list[str] | None,
) -> str:
"""Hash the bound-tool surface for cache-key purposes.
The signature changes whenever:
* A tool is added to or removed from the bound list (built-in toggles,
the set of MCP tools loaded for the user changes, gating rules flip,
etc.), or a bound tool's name or description changes.
* The available connectors / document types for the search space
change (new connector added, last connector removed, new document
type indexed). Connector gating derives disabled tools from
``available_connectors``, so the tool surface is technically already
covered, but we hash the connector list separately so that a change
which leaves the bound tool list untouched still rotates the key, for
example when the user re-adds a connector that gates a tool we were
already not exposing.
Stays stable across:
* Process restarts (tool names + descriptions are static).
* Different replicas (everyone gets the same hash for the same
inputs).
"""
tool_descriptors = sorted(
(getattr(t, "name", repr(t)), getattr(t, "description", "")) for t in tools
)
connectors = sorted(available_connectors or [])
doc_types = sorted(available_document_types or [])
return stable_hash(tool_descriptors, connectors, doc_types)
def flags_signature(flags: Any) -> str:
"""Hash the resolved :class:`AgentFeatureFlags` dataclass.
Frozen dataclasses are deterministically reprable, so a SHA1 of their
repr is a stable fingerprint. Restart-safe (flags are read once at
process boot).
"""
return stable_hash(repr(flags))
def system_prompt_hash(system_prompt: str) -> str:
"""Hash a system prompt string. Cheap, ~30µs for typical prompts."""
return hashlib.sha1(
system_prompt.encode("utf-8", errors="replace"),
usedforsecurity=False,
).hexdigest()
# ---------------------------------------------------------------------------
# Cache implementation
# ---------------------------------------------------------------------------
@dataclass
class _Entry:
value: Any
created_at: float
last_used_at: float
class _AgentCache:
"""In-process TTL-LRU cache with per-key in-flight de-duplication.
NOT THREAD-SAFE in the multithreading sense; designed for a single
asyncio event loop. Uvicorn runs one event loop per worker process,
so this is fine; multi-worker deployments simply each maintain their
own cache.
"""
def __init__(self, *, maxsize: int, ttl_seconds: float) -> None:
self._maxsize = maxsize
self._ttl = ttl_seconds
self._entries: OrderedDict[str, _Entry] = OrderedDict()
# One lock per key — guards "build" so concurrent cold misses on
# the same key wait for the first build instead of all racing.
self._locks: dict[str, asyncio.Lock] = {}
def _now(self) -> float:
return time.monotonic()
def _is_fresh(self, entry: _Entry) -> bool:
return (self._now() - entry.created_at) < self._ttl
def _evict_if_full(self) -> None:
while len(self._entries) >= self._maxsize:
evicted_key, _ = self._entries.popitem(last=False)
self._locks.pop(evicted_key, None)
_perf_log.info(
"[agent_cache] evict key=%s reason=lru size=%d",
_short(evicted_key),
len(self._entries),
)
def _touch(self, key: str, entry: _Entry) -> None:
entry.last_used_at = self._now()
self._entries.move_to_end(key, last=True)
async def get_or_build(
self,
key: str,
*,
builder: Callable[[], Awaitable[Any]],
) -> Any:
"""Return the cached value for ``key`` or call ``builder()`` to make it.
``builder`` MUST be idempotent: concurrent cold misses on the
same key collapse to a single ``builder()`` call (the others
wait on the in-flight lock and observe the populated entry on
wake).
"""
# Fast path: hot hit.
entry = self._entries.get(key)
if entry is not None and self._is_fresh(entry):
self._touch(key, entry)
_perf_log.info(
"[agent_cache] hit key=%s age=%.1fs size=%d",
_short(key),
self._now() - entry.created_at,
len(self._entries),
)
return entry.value
# Stale entry — drop it; rebuild below.
if entry is not None and not self._is_fresh(entry):
_perf_log.info(
"[agent_cache] stale key=%s age=%.1fs ttl=%.0fs",
_short(key),
self._now() - entry.created_at,
self._ttl,
)
self._entries.pop(key, None)
# Slow path: serialize concurrent misses for the same key.
lock = self._locks.setdefault(key, asyncio.Lock())
async with lock:
# Double-check after acquiring the lock — another waiter may
# have populated the entry while we slept.
entry = self._entries.get(key)
if entry is not None and self._is_fresh(entry):
self._touch(key, entry)
_perf_log.info(
"[agent_cache] hit key=%s age=%.1fs size=%d coalesced=true",
_short(key),
self._now() - entry.created_at,
len(self._entries),
)
return entry.value
t0 = time.perf_counter()
try:
value = await builder()
except BaseException:
# Don't cache failed builds; let the next caller retry.
_perf_log.warning(
"[agent_cache] build_failed key=%s elapsed=%.3fs",
_short(key),
time.perf_counter() - t0,
)
raise
elapsed = time.perf_counter() - t0
# Insert + evict.
self._evict_if_full()
now = self._now()
self._entries[key] = _Entry(value=value, created_at=now, last_used_at=now)
self._entries.move_to_end(key, last=True)
_perf_log.info(
"[agent_cache] miss key=%s build=%.3fs size=%d",
_short(key),
elapsed,
len(self._entries),
)
return value
def invalidate(self, key: str) -> bool:
"""Drop a single entry; return True if anything was removed."""
removed = self._entries.pop(key, None) is not None
self._locks.pop(key, None)
if removed:
_perf_log.info(
"[agent_cache] invalidate key=%s size=%d",
_short(key),
len(self._entries),
)
return removed
def invalidate_prefix(self, prefix: str) -> int:
"""Drop every entry whose key starts with ``prefix``. Returns count."""
keys = [k for k in self._entries if k.startswith(prefix)]
for k in keys:
self._entries.pop(k, None)
self._locks.pop(k, None)
if keys:
_perf_log.info(
"[agent_cache] invalidate_prefix prefix=%s removed=%d size=%d",
_short(prefix),
len(keys),
len(self._entries),
)
return len(keys)
def clear(self) -> None:
n = len(self._entries)
self._entries.clear()
self._locks.clear()
if n:
_perf_log.info("[agent_cache] clear removed=%d", n)
def stats(self) -> dict[str, Any]:
return {
"size": len(self._entries),
"maxsize": self._maxsize,
"ttl_seconds": self._ttl,
}
def _short(key: str, n: int = 16) -> str:
"""Truncate keys for log lines so they don't blow up log volume."""
return key if len(key) <= n else f"{key[:n]}..."
# ---------------------------------------------------------------------------
# Module-level singleton
# ---------------------------------------------------------------------------
_DEFAULT_MAXSIZE = int(os.getenv("SURFSENSE_AGENT_CACHE_MAXSIZE", "256"))
_DEFAULT_TTL = float(os.getenv("SURFSENSE_AGENT_CACHE_TTL_SECONDS", "1800"))
_cache: _AgentCache = _AgentCache(maxsize=_DEFAULT_MAXSIZE, ttl_seconds=_DEFAULT_TTL)
def get_cache() -> _AgentCache:
"""Return the process-wide compiled-agent cache singleton."""
return _cache
def reload_for_tests(*, maxsize: int = 256, ttl_seconds: float = 1800.0) -> _AgentCache:
"""Replace the singleton with a fresh cache. Tests only."""
global _cache
_cache = _AgentCache(maxsize=maxsize, ttl_seconds=ttl_seconds)
return _cache
__all__ = [
"flags_signature",
"get_cache",
"reload_for_tests",
"stable_hash",
"system_prompt_hash",
"tools_signature",
]
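
Review note: the in-flight de-duplication in `get_or_build` can be checked with a stripped-down sketch. `TinyCoalescingCache` below is a hypothetical stand-in with no TTL, no LRU, and no logging; it keeps only the fast path, the per-key lock, and the double-check after acquiring the lock, which together are what should collapse concurrent cold misses into one build.

```python
import asyncio
from typing import Any, Awaitable, Callable


class TinyCoalescingCache:
    """Hypothetical, reduced cache: only per-key build locks, no TTL or LRU."""

    def __init__(self) -> None:
        self._entries: dict = {}
        self._locks: dict = {}

    async def get_or_build(
        self, key: str, builder: Callable[[], Awaitable[Any]]
    ) -> Any:
        if key in self._entries:  # fast path: already built
            return self._entries[key]
        lock = self._locks.setdefault(key, asyncio.Lock())
        async with lock:
            # Double-check: another waiter may have built it while we waited.
            if key in self._entries:
                return self._entries[key]
            value = await builder()
            self._entries[key] = value
            return value


async def demo():
    cache = TinyCoalescingCache()
    build_calls = 0

    async def slow_builder() -> str:
        nonlocal build_calls
        build_calls += 1
        await asyncio.sleep(0.01)  # stand-in for the multi-second graph compile
        return "compiled-agent"

    # Five concurrent cold misses on the same key.
    results = await asyncio.gather(
        *(cache.get_or_build("k", slow_builder) for _ in range(5))
    )
    return build_calls, list(results)


build_calls, results = asyncio.run(demo())
```

Under these assumptions all five callers receive the same value while the builder runs once; the real cache adds freshness checks and eviction around the same skeleton.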

@ -0,0 +1,100 @@
"""Map configured connectors to the searchable document/connector types.
This is agent-agnostic infrastructure shared by every agent factory (single-
and multi-agent). It translates the connectors a search space has enabled into
the set of searchable type strings that pre-search middleware and ``web_search``
understand, and always layers in the document types that exist independently of
any connector (uploads, notes, extension captures, YouTube).
It lives in its own module rather than inside a specific agent factory so
that retiring or moving any single agent never disturbs the others' access to
this mapping.
"""
from __future__ import annotations
from typing import Any
# Maps SearchSourceConnectorType enum values to the searchable document/connector types
# used by pre-search middleware and web_search.
# Live search connectors (TAVILY_API, LINKUP_API, BAIDU_SEARCH_API) are routed to
# the web_search tool; all others are considered local/indexed data.
_CONNECTOR_TYPE_TO_SEARCHABLE: dict[str, str] = {
# Live search connectors (handled by web_search tool)
"TAVILY_API": "TAVILY_API",
"LINKUP_API": "LINKUP_API",
"BAIDU_SEARCH_API": "BAIDU_SEARCH_API",
# Local/indexed connectors (handled by KB pre-search middleware)
"SLACK_CONNECTOR": "SLACK_CONNECTOR",
"TEAMS_CONNECTOR": "TEAMS_CONNECTOR",
"NOTION_CONNECTOR": "NOTION_CONNECTOR",
"GITHUB_CONNECTOR": "GITHUB_CONNECTOR",
"LINEAR_CONNECTOR": "LINEAR_CONNECTOR",
"DISCORD_CONNECTOR": "DISCORD_CONNECTOR",
"JIRA_CONNECTOR": "JIRA_CONNECTOR",
"CONFLUENCE_CONNECTOR": "CONFLUENCE_CONNECTOR",
"CLICKUP_CONNECTOR": "CLICKUP_CONNECTOR",
"GOOGLE_CALENDAR_CONNECTOR": "GOOGLE_CALENDAR_CONNECTOR",
"GOOGLE_GMAIL_CONNECTOR": "GOOGLE_GMAIL_CONNECTOR",
"GOOGLE_DRIVE_CONNECTOR": "GOOGLE_DRIVE_FILE", # Connector type differs from document type
"AIRTABLE_CONNECTOR": "AIRTABLE_CONNECTOR",
"LUMA_CONNECTOR": "LUMA_CONNECTOR",
"ELASTICSEARCH_CONNECTOR": "ELASTICSEARCH_CONNECTOR",
"WEBCRAWLER_CONNECTOR": "CRAWLED_URL", # Maps to document type
"BOOKSTACK_CONNECTOR": "BOOKSTACK_CONNECTOR",
"CIRCLEBACK_CONNECTOR": "CIRCLEBACK", # Connector type differs from document type
"OBSIDIAN_CONNECTOR": "OBSIDIAN_CONNECTOR",
"DROPBOX_CONNECTOR": "DROPBOX_FILE", # Connector type differs from document type
"ONEDRIVE_CONNECTOR": "ONEDRIVE_FILE", # Connector type differs from document type
# Composio connectors (unified to native document types).
# Reverse of NATIVE_TO_LEGACY_DOCTYPE in app.db.
"COMPOSIO_GOOGLE_DRIVE_CONNECTOR": "GOOGLE_DRIVE_FILE",
"COMPOSIO_GMAIL_CONNECTOR": "GOOGLE_GMAIL_CONNECTOR",
"COMPOSIO_GOOGLE_CALENDAR_CONNECTOR": "GOOGLE_CALENDAR_CONNECTOR",
}
# Document types that don't come from SearchSourceConnector but should always be searchable
_ALWAYS_AVAILABLE_DOC_TYPES: list[str] = [
"EXTENSION", # Browser extension data
"FILE", # Uploaded files
"NOTE", # User notes
"YOUTUBE_VIDEO", # YouTube videos
]
def map_connectors_to_searchable_types(
connector_types: list[Any],
) -> list[str]:
"""
Map SearchSourceConnectorType enums to searchable document/connector types.
This function:
1. Converts connector type enums to their searchable counterparts
2. Includes always-available document types (EXTENSION, FILE, NOTE, YOUTUBE_VIDEO)
3. Deduplicates while preserving order
Args:
connector_types: List of SearchSourceConnectorType enum values
Returns:
List of searchable connector/document type strings
"""
result_set: set[str] = set()
result_list: list[str] = []
# Add always-available document types first
for doc_type in _ALWAYS_AVAILABLE_DOC_TYPES:
if doc_type not in result_set:
result_set.add(doc_type)
result_list.append(doc_type)
# Map each connector type to its searchable equivalent
for ct in connector_types:
# Handle both enum and string types
ct_str = ct.value if hasattr(ct, "value") else str(ct)
searchable = _CONNECTOR_TYPE_TO_SEARCHABLE.get(ct_str)
if searchable and searchable not in result_set:
result_set.add(searchable)
result_list.append(searchable)
return result_list
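
Review note: a usage sketch for the mapping above. It uses reduced copies of the two tables and a hypothetical `DemoConnectorType` enum in place of `SearchSourceConnectorType`. It also swaps the set-plus-list bookkeeping for an insertion-ordered `dict`, which is a different technique from the one in the module but should yield the same order-preserving de-duplication.

```python
from enum import Enum
from typing import Any

# Reduced copies of the module's tables, for illustration only.
_DEMO_MAP = {
    "GOOGLE_DRIVE_CONNECTOR": "GOOGLE_DRIVE_FILE",
    "COMPOSIO_GOOGLE_DRIVE_CONNECTOR": "GOOGLE_DRIVE_FILE",
    "SLACK_CONNECTOR": "SLACK_CONNECTOR",
}
_DEMO_ALWAYS = ["EXTENSION", "FILE", "NOTE", "YOUTUBE_VIDEO"]


class DemoConnectorType(Enum):  # hypothetical stand-in for the real enum
    GOOGLE_DRIVE_CONNECTOR = "GOOGLE_DRIVE_CONNECTOR"
    COMPOSIO_GOOGLE_DRIVE_CONNECTOR = "COMPOSIO_GOOGLE_DRIVE_CONNECTOR"


def demo_map(connector_types: list) -> list:
    # dict preserves insertion order, so it doubles as an ordered set.
    seen = dict.fromkeys(_DEMO_ALWAYS)
    for ct in connector_types:
        # Accept both enum members and plain strings, as the module does.
        ct_str = ct.value if hasattr(ct, "value") else str(ct)
        searchable = _DEMO_MAP.get(ct_str)
        if searchable:
            seen.setdefault(searchable, None)
    return list(seen)


result = demo_map(
    [
        DemoConnectorType.GOOGLE_DRIVE_CONNECTOR,
        DemoConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,  # same doc type
        "SLACK_CONNECTOR",
        "UNKNOWN_CONNECTOR",  # unmapped types are dropped silently
    ]
)
```

The native and Composio Drive connectors collapse to a single `GOOGLE_DRIVE_FILE` entry, and the always-available document types come first.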

@ -0,0 +1,320 @@
"""Async factory: tools, system prompt, MCP buckets for subagents, then sync graph compile."""
from __future__ import annotations
import logging
import time
from collections.abc import Sequence
from typing import Any
from langchain_core.language_models import BaseChatModel
from langchain_core.tools import BaseTool
from langgraph.types import Checkpointer
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.chat.multi_agent_chat.shared.feature_flags import (
AgentFeatureFlags,
get_flags,
)
from app.agents.chat.multi_agent_chat.shared.filesystem_selection import (
FilesystemMode,
FilesystemSelection,
)
from app.agents.chat.multi_agent_chat.shared.llm_config import AgentConfig
from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.backends.resolver import (
build_backend_resolver,
)
from app.agents.chat.multi_agent_chat.shared.prompt_caching import (
apply_litellm_prompt_caching,
)
from app.agents.chat.multi_agent_chat.subagents import (
get_subagents_to_exclude,
main_prompt_registry_subagent_lines,
)
from app.agents.chat.multi_agent_chat.subagents.mcp_tools.index import (
load_mcp_tools_by_connector,
)
from app.db import ChatVisibility
from app.services.connector_service import ConnectorService
from app.services.user_tool_allowlist import (
fetch_user_allowlist_rulesets,
make_trusted_tool_saver,
)
from app.utils.perf import get_perf_logger
from ..system_prompt import build_main_agent_system_prompt
from ..tools import (
MAIN_AGENT_SURFSENSE_TOOL_NAMES,
MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED,
)
from ..tools.invalid_tool import INVALID_TOOL_NAME, invalid_tool
from ..tools.registry import build_main_agent_tools
from .agent_cache import build_agent_with_cache
from .connector_searchable_types import map_connectors_to_searchable_types
_perf_log = get_perf_logger()
async def create_multi_agent_chat_deep_agent(
llm: BaseChatModel,
search_space_id: int,
db_session: AsyncSession,
connector_service: ConnectorService,
checkpointer: Checkpointer,
user_id: str | None = None,
thread_id: int | None = None,
agent_config: AgentConfig | None = None,
enabled_tools: list[str] | None = None,
disabled_tools: list[str] | None = None,
additional_tools: Sequence[BaseTool] | None = None,
firecrawl_api_key: str | None = None,
thread_visibility: ChatVisibility | None = None,
mentioned_document_ids: list[int] | None = None,
anon_session_id: str | None = None,
filesystem_selection: FilesystemSelection | None = None,
image_generation_config_id: int | None = None,
):
"""Deep agent with SurfSense tools/middleware; registry route subagents behind ``task`` when enabled.
``image_generation_config_id`` overrides the search space's image model for
this invocation (used by automations to run on their captured model). When
``None``, the ``generate_image`` tool resolves the live search-space pref.
"""
_t_agent_total = time.perf_counter()
apply_litellm_prompt_caching(llm, agent_config=agent_config, thread_id=thread_id)
filesystem_selection = filesystem_selection or FilesystemSelection()
backend_resolver = build_backend_resolver(
filesystem_selection,
search_space_id=search_space_id
if filesystem_selection.mode == FilesystemMode.CLOUD
else None,
)
available_connectors: list[str] | None = None
available_document_types: list[str] | None = None
_t0 = time.perf_counter()
try:
connector_types = await connector_service.get_available_connectors(
search_space_id
)
available_connectors = map_connectors_to_searchable_types(connector_types)
available_document_types = await connector_service.get_available_document_types(
search_space_id
)
except Exception as e:
logging.warning(
"Connector/doc-type discovery failed; excluding connector subagents this turn: %s",
e,
)
# Fail closed: a None list short-circuits ``get_subagents_to_exclude`` to "exclude
# nothing", which would silently advertise every connector specialist on a flaky
# discovery call. Empty list excludes connector-gated subagents while keeping builtins.
if available_connectors is None:
available_connectors = []
if available_document_types is None:
available_document_types = []
_perf_log.info(
"[create_agent] Connector/doc-type discovery in %.3fs",
time.perf_counter() - _t0,
)
visibility = thread_visibility or ChatVisibility.PRIVATE
_model_profile = getattr(llm, "profile", None)
_max_input_tokens: int | None = (
_model_profile.get("max_input_tokens")
if isinstance(_model_profile, dict)
else None
)
dependencies: dict[str, Any] = {
"search_space_id": search_space_id,
"db_session": db_session,
"connector_service": connector_service,
"firecrawl_api_key": firecrawl_api_key,
"user_id": user_id,
"thread_id": thread_id,
"thread_visibility": visibility,
"available_connectors": available_connectors,
"available_document_types": available_document_types,
"max_input_tokens": _max_input_tokens,
"llm": llm,
# Per-invocation image model override (automations run on their captured
# model). Reaches the generate_image subagent tool via subagent_dependencies.
"image_generation_config_id_override": image_generation_config_id,
}
_t0 = time.perf_counter()
try:
mcp_tools_by_agent = await load_mcp_tools_by_connector(
db_session, search_space_id
)
except Exception as e:
# Degrade to builtins-only rather than aborting the turn: a transient
# DB or MCP-server hiccup should not deny the user a response.
logging.warning(
"MCP tool discovery failed; subagents will run without MCP tools this turn: %s",
e,
)
mcp_tools_by_agent = {}
_perf_log.info(
"[create_agent] load_mcp_tools_by_connector in %.3fs (%d agents)",
time.perf_counter() - _t0,
len(mcp_tools_by_agent),
)
# User-scoped allow-list ("Always Allow" persisted to
# ``SearchSourceConnector.config.trusted_tools``). Layered last in each
# subagent's PermissionMiddleware so user ``allow`` overrides coded
# ``ask`` via last-match-wins. Anonymous turns and read failures both
# degrade to "no user rules" rather than blocking the turn.
user_allowlist_by_subagent: dict[str, Any] = {}
trusted_tool_saver = None
if user_id:
try:
import uuid as _uuid
user_uuid = _uuid.UUID(user_id)
except (TypeError, ValueError):
user_uuid = None
if user_uuid is not None:
_t0 = time.perf_counter()
try:
user_allowlist_by_subagent = await fetch_user_allowlist_rulesets(
db_session,
user_id=user_uuid,
search_space_id=search_space_id,
)
except Exception as e:
logging.warning(
"User allow-list fetch failed; subagents will run without user trust rules this turn: %s",
e,
)
user_allowlist_by_subagent = {}
_perf_log.info(
"[create_agent] fetch_user_allowlist_rulesets in %.3fs (%d subagents have rules)",
time.perf_counter() - _t0,
len(user_allowlist_by_subagent),
)
trusted_tool_saver = make_trusted_tool_saver(user_uuid)
dependencies["user_allowlist_by_subagent"] = user_allowlist_by_subagent
dependencies["trusted_tool_saver"] = trusted_tool_saver
modified_disabled_tools = list(disabled_tools) if disabled_tools else []
if "search_knowledge_base" not in modified_disabled_tools:
modified_disabled_tools.append("search_knowledge_base")
if enabled_tools is not None:
main_agent_enabled_tools = [
n for n in enabled_tools if n in MAIN_AGENT_SURFSENSE_TOOL_NAMES
]
else:
main_agent_enabled_tools = list(MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED)
_t0 = time.perf_counter()
# Main agent builds only its own small SurfSense toolset via the SRP
# main-agent registry; connectors/MCP/deliverables are delegated to
# subagents, so no MCP loading or connector construction happens here.
tools = build_main_agent_tools(
dependencies=dependencies,
enabled_tools=main_agent_enabled_tools,
disabled_tools=modified_disabled_tools,
additional_tools=list(additional_tools) if additional_tools else None,
)
_flags: AgentFeatureFlags = get_flags()
if _flags.enable_tool_call_repair and INVALID_TOOL_NAME not in {
t.name for t in tools
}:
tools = [*list(tools), invalid_tool]
_perf_log.info(
"[create_agent] build_tools_async in %.3fs (%d tools)",
time.perf_counter() - _t0,
len(tools),
)
_t0 = time.perf_counter()
_enabled_tool_names = {t.name for t in tools}
_user_disabled_tool_names = set(disabled_tools) if disabled_tools else set()
_model_name: str | None = None
prof = getattr(llm, "model_name", None) or getattr(llm, "model", None)
if isinstance(prof, str):
_model_name = prof
_connector_exclude = get_subagents_to_exclude(available_connectors)
_registry_subagent_prompt_lines = main_prompt_registry_subagent_lines(
_connector_exclude
)
if agent_config is not None:
system_prompt = build_main_agent_system_prompt(
today=None,
thread_visibility=thread_visibility,
enabled_tool_names=_enabled_tool_names,
disabled_tool_names=_user_disabled_tool_names,
custom_system_instructions=agent_config.system_instructions,
use_default_system_instructions=agent_config.use_default_system_instructions,
citations_enabled=agent_config.citations_enabled,
model_name=_model_name or getattr(agent_config, "model_name", None),
registry_subagent_prompt_lines=_registry_subagent_prompt_lines,
)
else:
system_prompt = build_main_agent_system_prompt(
thread_visibility=thread_visibility,
enabled_tool_names=_enabled_tool_names,
disabled_tool_names=_user_disabled_tool_names,
citations_enabled=True,
model_name=_model_name,
registry_subagent_prompt_lines=_registry_subagent_prompt_lines,
)
_perf_log.info(
"[create_agent] System prompt built in %.3fs", time.perf_counter() - _t0
)
final_system_prompt = system_prompt
config_id = agent_config.config_id if agent_config is not None else None
_t0 = time.perf_counter()
agent = await build_agent_with_cache(
llm=llm,
tools=tools,
final_system_prompt=final_system_prompt,
backend_resolver=backend_resolver,
filesystem_mode=filesystem_selection.mode,
search_space_id=search_space_id,
user_id=user_id,
thread_id=thread_id,
visibility=visibility,
anon_session_id=anon_session_id,
available_connectors=available_connectors,
available_document_types=available_document_types,
mentioned_document_ids=mentioned_document_ids,
max_input_tokens=_max_input_tokens,
flags=_flags,
checkpointer=checkpointer,
subagent_dependencies=dependencies,
mcp_tools_by_agent=mcp_tools_by_agent,
disabled_tools=disabled_tools,
config_id=config_id,
image_generation_config_id_override=image_generation_config_id,
)
_perf_log.info(
"[create_agent] Middleware stack + graph compiled in %.3fs",
time.perf_counter() - _t0,
)
_perf_log.info(
"[create_agent] Total agent creation in %.3fs",
time.perf_counter() - _t_agent_total,
)
return agent
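
Review note: the factory degrades to empty results on discovery failures instead of aborting the turn or leaving a `None` that downstream code treats as "exclude nothing". A minimal sketch of that fail-closed pattern follows; `discover_or_empty`, `_flaky`, and `_healthy` are hypothetical names for illustration and do not exist in this commit.

```python
import asyncio
import logging
from typing import Awaitable, Callable


async def discover_or_empty(
    discover: Callable[[], Awaitable[list]],
    *,
    what: str,
) -> list:
    """Return the discovered list, or [] (never None) when discovery fails."""
    try:
        return await discover()
    except Exception as exc:
        # Fail closed: an empty list gates connector-dependent features off.
        logging.getLogger(__name__).warning("%s discovery failed: %s", what, exc)
        return []


async def _flaky() -> list:
    raise RuntimeError("database unavailable")  # simulated outage


async def _healthy() -> list:
    return ["SLACK_CONNECTOR"]


failed = asyncio.run(discover_or_empty(_flaky, what="connector"))
ok = asyncio.run(discover_or_empty(_healthy, what="connector"))
```

Returning `[]` rather than `None` keeps the "nothing available" case distinct from the "unknown" case for any caller that branches on `None`.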

@ -0,0 +1,7 @@
"""SurfSense built-in agent skills (Anthropic Skills format).
Each subdirectory corresponds to one skill and contains a ``SKILL.md`` file
with YAML frontmatter (name, description, allowed_tools) plus markdown
instructions. The :class:`BuiltinSkillsBackend` exposes them to the
deepagents :class:`SkillsMiddleware`.
"""
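
Review note: for readers unfamiliar with the layout, the sketch below shows one way a `SKILL.md` file with frontmatter could be split into metadata and instructions. `split_frontmatter` is a hypothetical helper, not the deepagents parser; it handles only flat `key: value` lines rather than full YAML, and the sample values (including the `allowed_tools` format) are assumptions.

```python
def split_frontmatter(text: str):
    """Split '---' delimited frontmatter from the markdown body.

    Only flat ``key: value`` lines are handled; this is not a YAML parser.
    """
    lines = text.splitlines()
    if not lines or lines[0].strip() != "---":
        return {}, text
    meta = {}
    for i, line in enumerate(lines[1:], start=1):
        if line.strip() == "---":
            # Everything after the closing delimiter is the markdown body.
            return meta, "\n".join(lines[i + 1:])
        key, sep, value = line.partition(":")
        if sep:
            meta[key.strip()] = value.strip()
    return {}, text  # no closing delimiter: treat everything as body


SAMPLE = """---
name: kb-research
description: Hypothetical example skill.
allowed_tools: read_file
---
# Instructions
Search the knowledge base first.
"""

meta, body = split_frontmatter(SAMPLE)
```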

@ -0,0 +1,342 @@
"""Skills backends for SurfSense.
Implements two minimal :class:`deepagents.backends.protocol.BackendProtocol`
subclasses tailored for use with :class:`deepagents.middleware.skills.SkillsMiddleware`.
The middleware only needs four methods to load skills from a backend:
* ``ls_info`` / ``als_info``: list directories under a source path.
* ``download_files`` / ``adownload_files``: fetch ``SKILL.md`` bytes.
Other ``BackendProtocol`` methods (``read``/``write``/``edit``/``grep_raw``, etc.)
default to ``NotImplementedError`` from the base class. They are never reached
by the skills middleware because skill content is rendered into the system
prompt at agent build time, not edited at runtime.
Two backends are provided:
* :class:`BuiltinSkillsBackend`: disk-backed read of bundled skills from
``app/agents/chat/multi_agent_chat/main_agent/skills/builtin/``.
* :class:`SearchSpaceSkillsBackend`: a thin read-only wrapper over
:class:`KBPostgresBackend` that filters notes under the privileged folder
``/documents/_skills/``.
Both backends are intentionally read-only: skill authoring happens out of band
(via filesystem or a search-space-admin route), so we never expose
``write`` / ``edit`` / ``upload_files``. The base class' ``NotImplementedError``
gives a clean failure mode if anything tries.
"""
from __future__ import annotations
import contextlib
import logging
from collections.abc import Callable
from dataclasses import replace
from pathlib import Path
from typing import TYPE_CHECKING
from deepagents.backends.composite import CompositeBackend
from deepagents.backends.protocol import (
BackendProtocol,
FileDownloadResponse,
FileInfo,
)
from deepagents.backends.state import StateBackend
if TYPE_CHECKING:
from langchain.tools import ToolRuntime
from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.backends.kb_postgres import (
KBPostgresBackend,
)
logger = logging.getLogger(__name__)
# Limit per Agent Skills spec; matches deepagents.middleware.skills.MAX_SKILL_FILE_SIZE.
_MAX_SKILL_FILE_SIZE = 10 * 1024 * 1024
def _default_builtin_root() -> Path:
"""Return the absolute path to the bundled builtin skills directory.
Located at ``builtin/`` next to this module (this module lives at
``app/agents/chat/multi_agent_chat/main_agent/skills/backends.py``).
"""
return (Path(__file__).resolve().parent / "builtin").resolve()
class BuiltinSkillsBackend(BackendProtocol):
"""Read-only disk-backed skills source.
Maps a virtual ``/skills/builtin/`` namespace onto a directory on local disk,
where each skill is its own subdirectory containing a ``SKILL.md`` file::
<root>/<skill-name>/SKILL.md
The middleware calls :meth:`als_info` with the source path and expects a
``list[FileInfo]`` whose ``is_dir=True`` entries are descended into. Then it
calls :meth:`adownload_files` with the synthesized ``SKILL.md`` paths and
parses YAML frontmatter from the returned ``content`` bytes.
Mounting under :class:`~deepagents.backends.composite.CompositeBackend` at
prefix ``/skills/builtin/`` means the middleware can issue paths like
``/skills/builtin/kb-research/SKILL.md`` which the composite strips down to
``/kb-research/SKILL.md`` before forwarding here. We treat any leading
slash as anchoring at :attr:`root`.
"""
    def __init__(self, root: Path | str | None = None) -> None:
        self.root: Path = Path(root).resolve() if root else _default_builtin_root()
        if not self.root.exists():
            logger.info(
                "BuiltinSkillsBackend root %s does not exist; skills will be empty.",
                self.root,
            )

    def _resolve(self, path: str) -> Path:
        """Resolve a virtual posix path under :attr:`root`, refusing escapes."""
        bare = path.lstrip("/")
        candidate = (self.root / bare).resolve() if bare else self.root
        # Refuse symlink/.. traversal that escapes the root.
        try:
            candidate.relative_to(self.root)
        except ValueError as exc:
            raise ValueError(f"path {path!r} escapes builtin skills root") from exc
        return candidate

    def ls_info(self, path: str) -> list[FileInfo]:
        try:
            target = self._resolve(path)
        except ValueError as exc:
            logger.warning("BuiltinSkillsBackend.ls_info refused: %s", exc)
            return []
        if not target.exists() or not target.is_dir():
            return []
        infos: list[FileInfo] = []
        # Build virtual paths anchored at "/" because CompositeBackend already
        # stripped the route prefix before calling us.
        target_virtual = (
            "/"
            if target == self.root
            else ("/" + str(target.relative_to(self.root)).replace("\\", "/"))
        )
        for child in sorted(target.iterdir()):
            if child.name == "__pycache__" or child.name.startswith("."):
                continue
            child_virtual = (
                target_virtual.rstrip("/") + "/" + child.name
                if target_virtual != "/"
                else "/" + child.name
            )
            info: FileInfo = {
                "path": child_virtual,
                "is_dir": child.is_dir(),
            }
            if child.is_file():
                with contextlib.suppress(OSError):  # pragma: no cover - defensive
                    info["size"] = child.stat().st_size
            infos.append(info)
        return infos

    def download_files(self, paths: list[str]) -> list[FileDownloadResponse]:
        responses: list[FileDownloadResponse] = []
        for p in paths:
            try:
                target = self._resolve(p)
            except ValueError:
                responses.append(FileDownloadResponse(path=p, error="invalid_path"))
                continue
            if not target.exists():
                responses.append(FileDownloadResponse(path=p, error="file_not_found"))
                continue
            if target.is_dir():
                responses.append(FileDownloadResponse(path=p, error="is_directory"))
                continue
            try:
                # Hard cap to avoid loading rogue mega-files into memory.
                size = target.stat().st_size
                if size > _MAX_SKILL_FILE_SIZE:
                    logger.warning(
                        "Builtin skill file %s exceeds %d bytes; truncating.",
                        target,
                        _MAX_SKILL_FILE_SIZE,
                    )
                    with target.open("rb") as fh:
                        content = fh.read(_MAX_SKILL_FILE_SIZE)
                else:
                    content = target.read_bytes()
            except PermissionError:
                responses.append(
                    FileDownloadResponse(path=p, error="permission_denied")
                )
                continue
            except OSError as exc:  # pragma: no cover - defensive
                logger.warning("Builtin skill read failed %s: %s", target, exc)
                responses.append(FileDownloadResponse(path=p, error="file_not_found"))
                continue
            responses.append(FileDownloadResponse(path=p, content=content, error=None))
        return responses


class SearchSpaceSkillsBackend(BackendProtocol):
    """Read-only view of search-space-authored skills.

    Wraps a :class:`KBPostgresBackend` and only ever reads under the privileged
    folder ``/documents/_skills/`` (configurable). The folder is intended to be
    writable only by search-space admins; this backend never writes.

    The skills middleware expects a layout like::

        /<source_root>/<skill-name>/SKILL.md

    But the KB stores documents like ``/documents/_skills/<name>/SKILL.md``.
    We expose the inner namespace by remapping each path. When mounted under
    :class:`CompositeBackend` at prefix ``/skills/space/`` the paths the
    middleware sees become ``/skills/space/<name>/SKILL.md``; the composite
    strips ``/skills/space/`` and hands us ``/<name>/SKILL.md``, which we
    rewrite to ``/documents/_skills/<name>/SKILL.md`` before forwarding to the
    KB.

    No new database table is needed: the privileged folder convention is
    enforced server-side outside of this class. We intentionally do not
    implement write/edit operations; any such attempt reaches the base class,
    which raises ``NotImplementedError``.
    """

    DEFAULT_KB_ROOT: str = "/documents/_skills"

    def __init__(
        self,
        kb_backend: KBPostgresBackend,
        *,
        kb_root: str = DEFAULT_KB_ROOT,
    ) -> None:
        self._kb = kb_backend
        # Normalize trailing slash off so we can join cleanly.
        self._kb_root = kb_root.rstrip("/") or "/"

    def _to_kb(self, path: str) -> str:
        """Rewrite a virtual path into the underlying KB namespace."""
        bare = path.lstrip("/")
        if not bare:
            return self._kb_root
        return f"{self._kb_root}/{bare}"

    def _from_kb(self, kb_path: str) -> str:
        """Rewrite a KB path back into our virtual namespace."""
        if not kb_path.startswith(self._kb_root):
            return kb_path  # pragma: no cover - defensive
        rel = kb_path[len(self._kb_root) :]
        return rel if rel.startswith("/") else "/" + rel

    def ls_info(self, path: str) -> list[FileInfo]:
        # KBPostgresBackend exposes only the async API meaningfully; the sync
        # path falls back to ``asyncio.to_thread(...)`` in the base class. We
        # keep this stub to satisfy abstract resolution; the middleware calls
        # ``als_info``.
        raise NotImplementedError("SearchSpaceSkillsBackend is async-only")

    async def als_info(self, path: str) -> list[FileInfo]:
        kb_path = self._to_kb(path)
        try:
            infos = await self._kb.als_info(kb_path)
        except Exception as exc:  # pragma: no cover - defensive
            logger.warning("SearchSpaceSkillsBackend.als_info failed: %s", exc)
            return []
        remapped: list[FileInfo] = []
        for info in infos:
            kb_p = info.get("path", "")
            if not kb_p.startswith(self._kb_root):
                continue
            remapped.append({**info, "path": self._from_kb(kb_p)})
        return remapped

    def download_files(self, paths: list[str]) -> list[FileDownloadResponse]:
        raise NotImplementedError("SearchSpaceSkillsBackend is async-only")

    async def adownload_files(self, paths: list[str]) -> list[FileDownloadResponse]:
        kb_paths = [self._to_kb(p) for p in paths]
        responses = await self._kb.adownload_files(kb_paths)
        # Re-map response paths back to the virtual namespace so the middleware
        # correlates them to the input list correctly.
        remapped: list[FileDownloadResponse] = []
        for original, resp in zip(paths, responses, strict=True):
            remapped.append(replace(resp, path=original))
        return remapped


SKILLS_BUILTIN_PREFIX = "/skills/builtin/"
SKILLS_SPACE_PREFIX = "/skills/space/"


def build_skills_backend_factory(
    *,
    builtin_root: Path | str | None = None,
    search_space_id: int | None = None,
) -> Callable[[ToolRuntime], BackendProtocol]:
    """Return a runtime-aware factory for the skills :class:`CompositeBackend`.

    When ``search_space_id`` is provided the composite includes a
    :class:`SearchSpaceSkillsBackend` route at ``/skills/space/`` over a fresh
    per-runtime :class:`KBPostgresBackend`, mirroring how
    :func:`build_backend_resolver` constructs the main filesystem backend.

    When ``search_space_id`` is ``None`` (e.g., desktop-local mode or unit
    tests) only the bundled :class:`BuiltinSkillsBackend` is exposed.

    Returning a factory rather than a fixed instance is intentional: the
    underlying KB backend depends on per-call ``ToolRuntime`` state
    (``staged_dirs``, ``files`` cache, runtime config), so a single shared
    instance cannot serve multiple concurrent agent runs.
    """
    builtin = BuiltinSkillsBackend(builtin_root)

    if search_space_id is None:

        def _factory_builtin_only(runtime: ToolRuntime) -> BackendProtocol:
            # Default StateBackend is intentionally inert: any path outside the
            # ``/skills/builtin/`` route resolves to an empty per-runtime state
            # so the SkillsMiddleware can iterate sources without raising.
            return CompositeBackend(
                default=StateBackend(runtime),
                routes={SKILLS_BUILTIN_PREFIX: builtin},
            )

        return _factory_builtin_only

    def _factory_with_space(runtime: ToolRuntime) -> BackendProtocol:
        # Imported lazily to avoid a hard dependency at module import time:
        # ``KBPostgresBackend`` pulls in DB models, which are unnecessary for
        # the unit-tested builtin path.
        from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.backends.kb_postgres import (
            KBPostgresBackend,
        )

        kb = KBPostgresBackend(search_space_id, runtime)
        space = SearchSpaceSkillsBackend(kb)
        return CompositeBackend(
            default=StateBackend(runtime),
            routes={
                SKILLS_BUILTIN_PREFIX: builtin,
                SKILLS_SPACE_PREFIX: space,
            },
        )

    return _factory_with_space


def default_skills_sources() -> list[str]:
    """Return the canonical source list for SkillsMiddleware (built-in then space)."""
    return [SKILLS_BUILTIN_PREFIX, SKILLS_SPACE_PREFIX]


__all__ = [
    "SKILLS_BUILTIN_PREFIX",
    "SKILLS_SPACE_PREFIX",
    "BuiltinSkillsBackend",
    "SearchSpaceSkillsBackend",
    "build_skills_backend_factory",
    "default_skills_sources",
]
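
The root-containment check that `BuiltinSkillsBackend._resolve` performs can be tried in isolation. The following is a minimal sketch, not the module's code: `resolve_under_root` and `demo` are hypothetical names, and a temporary directory stands in for the real builtin skills root.

```python
from pathlib import Path
import tempfile


def resolve_under_root(root: Path, virtual_path: str) -> Path:
    """Map a virtual posix path onto ``root`` and reject anything outside it."""
    root = root.resolve()
    bare = virtual_path.lstrip("/")
    candidate = (root / bare).resolve() if bare else root
    try:
        # relative_to raises ValueError when candidate is not inside root.
        candidate.relative_to(root)
    except ValueError as exc:
        raise ValueError(f"path {virtual_path!r} escapes root") from exc
    return candidate


def demo() -> tuple[bool, bool]:
    """Return (inside path accepted, escaping path rejected)."""
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp) / "skills"
        (root / "kb-research").mkdir(parents=True)
        inside = resolve_under_root(root, "/kb-research/SKILL.md")
        accepted = inside == root.resolve() / "kb-research" / "SKILL.md"
        try:
            resolve_under_root(root, "/../outside.txt")
            rejected = False
        except ValueError:
            rejected = True
        return accepted, rejected
```

Resolving before comparing is what makes both `..` segments and symlinks count as escapes, since the comparison happens on the final filesystem location rather than on the string the caller supplied.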

View file

@ -0,0 +1,24 @@
---
name: email-drafting
description: Draft an email matching the user's voice, with structured intent and CTA
---
# Email drafting
## When to use this skill
"Draft an email to ...", "reply to this thread", "write a follow-up to X". Plain "summarize the email" is **not** in scope — that's a comprehension task.
## Voice
Search the KB for prior emails from the user to similar audiences (same recipient, same topic class). Mirror tone, opening style, sign-off, and length distribution. If there is no precedent, default to: warm, direct, no filler, short paragraphs, one clear ask.
## Required structure
Every draft includes, in this order:
1. **Subject line** — concrete, ≤ 8 words, no clickbait, no `Re:` unless replying.
2. **Opening (1 sentence)** — context the recipient already shares; never restate what they wrote unless the thread is long.
3. **Body** — the actual point in one short paragraph. Bullets only if there are >3 discrete items.
4. **Single explicit CTA** — what you want the recipient to do, with a soft deadline if relevant.
5. **Sign-off** — match the user's prior closing style.
## Always offer alternatives
End your message with: "Want me to make it shorter, more formal, or add a different angle?" — give the user one obvious next step.

View file

@ -0,0 +1,23 @@
---
name: kb-research
description: Structured approach to finding and synthesizing information from the user's knowledge base
allowed-tools: scrape_webpage, read_file, ls_tree, grep, web_search
---
# Knowledge-base research
## When to use this skill
- The user asks "find/look up/research" something specifically inside their knowledge base.
- The user references documents, notes, repos, or connector data they expect to exist already.
- A multi-document synthesis is required (e.g., "summarize what we've discussed about X across all my notes").
## Plan
1. Decompose the user's question into 2-4 specific, citation-worthy sub-questions.
2. For each sub-question, run **one** targeted KB search (focused on terms the user would have written, not synonyms). Open the most relevant 2-3 documents fully via `read_file` if their excerpts are too short.
3. Use `grep` to find supporting passages in long files instead of re-reading them end to end.
4. Cite every claim with `[citation:chunk_id]` exactly as the chunk tag specifies.
## What good output looks like
- Short paragraphs with inline citations.
- Quoted phrases when wording matters.
- An explicit "Not found in your knowledge base" callout when a sub-question has no support — never fabricate.
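
The skill files in this commit all open with a `---` delimited header carrying `name`, `description`, and sometimes `allowed-tools`. As a rough illustration of that layout only, here is a sketch of a header splitter. It is an assumption-laden stand-in (`split_frontmatter` and `allowed_tools` are hypothetical helpers), it handles flat `key: value` lines only, and it is not how the skills middleware itself parses these files.

```python
def split_frontmatter(text: str) -> tuple[dict[str, str], str]:
    """Split a SKILL.md-style document into (header fields, body)."""
    lines = text.splitlines()
    if not lines or lines[0].strip() != "---":
        return {}, text
    meta: dict[str, str] = {}
    for index, line in enumerate(lines[1:], start=1):
        if line.strip() == "---":
            # Everything after the closing delimiter is the skill body.
            return meta, "\n".join(lines[index + 1 :])
        key, sep, value = line.partition(":")
        if sep:
            meta[key.strip()] = value.strip()
    # No closing delimiter: treat the whole text as body.
    return {}, text


def allowed_tools(meta: dict[str, str]) -> list[str]:
    """Read the comma-separated ``allowed-tools`` field, if present."""
    raw = meta.get("allowed-tools", "")
    return [tool.strip() for tool in raw.split(",") if tool.strip()]
```

A skill without an `allowed-tools` line, such as `email-drafting`, simply yields an empty list here; what the real middleware does in that case is not stated in this diff.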

View file

@ -0,0 +1,22 @@
---
name: meeting-prep
description: Pull together briefing materials before a scheduled meeting
allowed-tools: web_search, scrape_webpage, read_file
---
# Meeting preparation
## When to use this skill
The user mentions an upcoming meeting, call, or interview and asks you to "prep", "brief me", "pull background", or "what do I need to know about X before tomorrow".
## Output structure
Always produce these sections (omit any with no signal — don't pad):
1. **Attendees & context** — who's in the room, their roles, what they care about. Pull from KB notes about prior interactions; supplement with public profile facts via `web_search` when names or companies are unfamiliar.
2. **Open threads** — outstanding action items, unresolved decisions, last-mentioned blockers from prior conversation history.
3. **Recent moves** — within the last 30 days: relevant launches, hires, news. Cite KB chunks when present, otherwise external sources.
4. **Suggested questions** — 3-5 questions the user could ask, tailored to the open threads and the attendees' likely priorities.
## Source ordering
- Always check the user's KB **first** for prior meeting notes, internal docs, or Slack threads about these attendees.
- Only fall back to `web_search` for *publicly verifiable* facts — never to fabricate a participant's preferences or relationships.

View file

@ -0,0 +1,23 @@
---
name: report-writing
description: How to scope, draft, and revise a Markdown report artifact via generate_report
allowed-tools: generate_report, read_file
---
# Report writing
## When to use this skill
The user explicitly requests a deliverable: "write a report on …", "draft a memo", "produce a brief", "expand the previous report". A creation or modification verb pointed at an artifact is required (see `generate_report`'s when-to-call rules).
## Decision flow
1. **Source strategy.** Decide which `source_strategy` fits:
- `conversation` — substantive Q&A on the topic already in chat.
- `kb_search` — fresh topic; supply 1-5 precise `search_queries`.
- `auto` — partial conversation context; let the tool fall back.
- `provided` — verbatim source text only.
2. **Style.** Default to `report_style="detailed"` unless the user explicitly asks for "brief", "one page", "500 words".
3. **Revisions.** When modifying an existing report from this conversation, set `parent_report_id` and put the change list in `user_instructions` ("add carbon-capture section", "tighten conclusion").
4. **Never paste the report back into chat** after `generate_report` returns — confirm and let the artifact card render itself.
## Hooks for KB-only mode
If `kb_search`/`auto` returns no results, do **not** silently switch to general knowledge. Surface the gap in your confirmation message.

View file

@ -0,0 +1,25 @@
---
name: slack-summary
description: Distill a Slack channel or thread into actionable summary
---
# Slack summarization
## When to use this skill
The user asks to summarize Slack ("what happened in #eng-platform this week", "what did Alice say about the launch", "catch me up on the design channel").
## Required inputs
Confirm before searching:
- **Which channel(s) or thread(s)?** Don't guess if ambiguous.
- **What time window?** Default to the last 7 days when not specified, but say so.
## Output shape
Produce three concise sections:
1. **Key decisions** — explicit choices that were made, with the deciding message cited.
2. **Open questions** — things asked but not answered, with the asking message cited.
3. **Action items** — `@mention` who owes what by when, *only if explicitly stated*. Don't invent assignees.
## What not to do
- Never produce a chronological play-by-play of every message — distill.
- Never quote private messages without flagging them as such.
- If the channel was empty in the time window, say so — don't fabricate filler.

View file

@ -0,0 +1,7 @@
"""Main-agent system prompt — not shared verbatim with single-agent ``new_chat``."""
from __future__ import annotations
from .builder import build_main_agent_system_prompt
__all__ = ["build_main_agent_system_prompt"]

View file

@ -0,0 +1,7 @@
"""Assemble the main-agent system prompt from ``prompts/`` fragments."""
from __future__ import annotations
from .compose import build_main_agent_system_prompt
__all__ = ["build_main_agent_system_prompt"]

View file

@ -0,0 +1,100 @@
"""Assemble the main-agent system prompt from ``prompts/``.

Section order (default flow)::

    <agent_identity>
    [user's custom_system_instructions, if any]
    <core_behavior>           # default body
    <knowledge_base_first>    # default body
    <dynamic_context>         # always
    <routing>                 # default body
    <specialists>             # always (dynamic roster)
    <tools>                   # always (vertical-slice)
    <memory_protocol>         # default body
    <citations>               # always
    <output_format>           # always
    <refusal_and_limits>      # always
    <reminder>                # always

``custom_system_instructions`` is **additive**, not a replacement: it slots
between identity and the default body so platform safety nets (KB-first,
routing, citations, output formatting, refusal rules) always apply.

``use_default_system_instructions=False`` skips the four "default body"
sections but keeps all the always-on platform sections.
"""

from __future__ import annotations

from datetime import UTC, datetime

from app.db import ChatVisibility

from .load_md import read_prompt_md
from .sections.citations import build_citations_section
from .sections.dynamic_context import build_dynamic_context_section
from .sections.identity import build_identity_section
from .sections.memory_protocol import build_memory_protocol_section
from .sections.specialists import build_specialists_section
from .sections.tools import build_tools_section


def build_main_agent_system_prompt(
    *,
    registry_subagent_prompt_lines: list[tuple[str, str]],
    today: datetime | None = None,
    thread_visibility: ChatVisibility | None = None,
    enabled_tool_names: set[str] | None = None,
    disabled_tool_names: set[str] | None = None,
    custom_system_instructions: str | None = None,
    use_default_system_instructions: bool = True,
    citations_enabled: bool = True,
    model_name: str | None = None,
) -> str:
    resolved_today = (today or datetime.now(UTC)).astimezone(UTC).date().isoformat()
    visibility = thread_visibility or ChatVisibility.PRIVATE

    parts: list[str] = []

    parts.append(
        build_identity_section(visibility=visibility, resolved_today=resolved_today)
    )

    if custom_system_instructions and custom_system_instructions.strip():
        parts.append(
            "\n"
            + custom_system_instructions.format(resolved_today=resolved_today)
            + "\n"
        )

    if use_default_system_instructions:
        parts.append(_wrap(read_prompt_md("core_behavior.md")))
        parts.append(_wrap(read_prompt_md("kb_first.md")))

    parts.append(build_dynamic_context_section(visibility=visibility))

    if use_default_system_instructions:
        parts.append(_wrap(read_prompt_md("routing.md")))

    parts.append(build_specialists_section(registry_subagent_prompt_lines))
    parts.append(
        build_tools_section(
            visibility=visibility,
            enabled_tool_names=enabled_tool_names,
            disabled_tool_names=disabled_tool_names,
        )
    )

    if use_default_system_instructions:
        parts.append(build_memory_protocol_section(visibility=visibility))

    parts.append(build_citations_section(citations_enabled=citations_enabled))
    parts.append(_wrap(read_prompt_md("output_format.md")))
    parts.append(_wrap(read_prompt_md("refusal_and_limits.md")))
    parts.append(_wrap(read_prompt_md("reminder.md")))

    return "".join(p for p in parts if p)


def _wrap(fragment: str) -> str:
    return f"\n{fragment}\n" if fragment else ""
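
The section ordering described in the module docstring can be modelled without any prompt files. The sketch below assumes nothing beyond that docstring: `SECTION_PLAN` and `plan_sections` are hypothetical names, and the function returns section names rather than rendered text.

```python
from __future__ import annotations

ALWAYS = "always"
DEFAULT = "default"

# Sections that follow <agent_identity>, in docstring order.
SECTION_PLAN: list[tuple[str, str]] = [
    ("core_behavior", DEFAULT),
    ("knowledge_base_first", DEFAULT),
    ("dynamic_context", ALWAYS),
    ("routing", DEFAULT),
    ("specialists", ALWAYS),
    ("tools", ALWAYS),
    ("memory_protocol", DEFAULT),
    ("citations", ALWAYS),
    ("output_format", ALWAYS),
    ("refusal_and_limits", ALWAYS),
    ("reminder", ALWAYS),
]


def plan_sections(*, custom_instructions: str | None, use_defaults: bool) -> list[str]:
    """List the section names that would be emitted, in order."""
    order = ["agent_identity"]
    # Custom instructions are additive: they slot in right after identity.
    if custom_instructions and custom_instructions.strip():
        order.append("custom_system_instructions")
    for name, kind in SECTION_PLAN:
        if kind == ALWAYS or use_defaults:
            order.append(name)
    return order
```

With `use_defaults=False` the four default-body sections drop out while the always-on sections stay, which is the additive behaviour the docstring promises.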

View file

@ -0,0 +1,16 @@
"""Load main-agent prompt fragments from ``system_prompt/prompts/``."""

from __future__ import annotations

from importlib import resources

_PROMPTS_PACKAGE = "app.agents.chat.multi_agent_chat.main_agent.system_prompt.prompts"


def read_prompt_md(filename: str) -> str:
    """Load ``prompts/{filename}`` (e.g. ``core_behavior.md`` or ``tools/web_search/description.md``)."""
    ref = resources.files(_PROMPTS_PACKAGE).joinpath(filename)
    if not ref.is_file():
        return ""
    text = ref.read_text(encoding="utf-8")
    return text[:-1] if text.endswith("\n") else text
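
Two details of `read_prompt_md` are easy to miss: a missing fragment yields an empty string rather than an error, and exactly one trailing newline is trimmed. A small sketch of the same behaviour, using a plain directory instead of `importlib.resources` (that substitution, and the names `read_fragment` and `demo`, are assumptions made for illustration):

```python
from pathlib import Path
import tempfile


def read_fragment(base: Path, filename: str) -> str:
    """Read ``base/filename``; empty string if absent, one trailing newline trimmed."""
    ref = base / filename
    if not ref.is_file():
        return ""
    text = ref.read_text(encoding="utf-8")
    return text[:-1] if text.endswith("\n") else text


def demo() -> tuple[str, str, str]:
    with tempfile.TemporaryDirectory() as tmp:
        base = Path(tmp)
        (base / "one.md").write_text("hello\n", encoding="utf-8")
        (base / "two.md").write_text("hello\n\n", encoding="utf-8")
        return (
            read_fragment(base, "one.md"),
            read_fragment(base, "two.md"),
            read_fragment(base, "missing.md"),
        )
```

Because callers test the result for truthiness before wrapping it, a missing file quietly removes its section from the prompt instead of failing the build.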

View file

@ -0,0 +1 @@
"""Rendered slices of the main-agent system prompt."""

View file

@ -0,0 +1,11 @@
"""``<citations>`` section — on/off variant based on workspace configuration."""

from __future__ import annotations

from ..load_md import read_prompt_md


def build_citations_section(*, citations_enabled: bool) -> str:
    variant = "on" if citations_enabled else "off"
    fragment = read_prompt_md(f"citations/{variant}.md")
    return f"\n{fragment}\n" if fragment else ""

View file

@ -0,0 +1,13 @@
"""``<dynamic_context>`` section — visibility-aware (private vs team thread)."""

from __future__ import annotations

from app.db import ChatVisibility

from ..load_md import read_prompt_md


def build_dynamic_context_section(*, visibility: ChatVisibility) -> str:
    variant = "team" if visibility == ChatVisibility.SEARCH_SPACE else "private"
    fragment = read_prompt_md(f"dynamic_context/{variant}.md")
    return f"\n{fragment}\n" if fragment else ""

View file

@ -0,0 +1,19 @@
"""``<agent_identity>`` section — visibility-aware, with ``{resolved_today}`` injection."""

from __future__ import annotations

from app.db import ChatVisibility

from ..load_md import read_prompt_md


def build_identity_section(
    *,
    visibility: ChatVisibility,
    resolved_today: str,
) -> str:
    variant = "team" if visibility == ChatVisibility.SEARCH_SPACE else "private"
    fragment = read_prompt_md(f"identity/{variant}.md")
    if not fragment:
        return ""
    return "\n" + fragment.format(resolved_today=resolved_today) + "\n"
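
The identity builder combines two ideas: picking a `team` or `private` fragment from the thread visibility, and injecting the date with `str.format`. A self-contained sketch follows, assuming an in-memory `TEMPLATES` dict and a stand-in `Visibility` enum in place of `app.db.ChatVisibility` (both hypothetical). It also shows a property of `str.format` worth keeping in mind for any template that passes through it: a stray `{name}` placeholder raises.

```python
from enum import Enum


class Visibility(Enum):
    """Hypothetical stand-in for ``app.db.ChatVisibility``."""

    PRIVATE = "PRIVATE"
    SEARCH_SPACE = "SEARCH_SPACE"


# Shortened, illustrative templates; the real fragments live in identity/*.md.
TEMPLATES = {
    "private": "<agent_identity>\nToday (UTC): {resolved_today}\n</agent_identity>",
    "team": (
        "<agent_identity>\nToday (UTC): {resolved_today}\n"
        "You are in a team thread.\n</agent_identity>"
    ),
}


def render_identity(visibility: Visibility, resolved_today: str) -> str:
    variant = "team" if visibility == Visibility.SEARCH_SPACE else "private"
    fragment = TEMPLATES.get(variant, "")
    if not fragment:
        return ""
    return "\n" + fragment.format(resolved_today=resolved_today) + "\n"


def format_rejects_unknown_placeholder() -> bool:
    """str.format raises KeyError for a placeholder it has no value for."""
    try:
        "Use {curly} braces".format(resolved_today="2026-06-05")
    except KeyError:
        return True
    return False
```

Text that needs literal braces has to double them (`{{` and `}}`) before it goes through `str.format`.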

View file

@ -0,0 +1,13 @@
"""``<memory_protocol>`` section — visibility-aware (user vs team memory)."""

from __future__ import annotations

from app.db import ChatVisibility

from ..load_md import read_prompt_md


def build_memory_protocol_section(*, visibility: ChatVisibility) -> str:
    variant = "team" if visibility == ChatVisibility.SEARCH_SPACE else "private"
    fragment = read_prompt_md(f"memory_protocol/{variant}.md")
    return f"\n{fragment}\n" if fragment else ""

View file

@ -0,0 +1,15 @@
"""``<specialists>`` section — live ``task`` roster for this workspace.
The roster is non-empty by contract: ``deliverables`` and ``knowledge_base``
both declare ``frozenset()`` in ``SUBAGENT_TO_REQUIRED_CONNECTOR_MAP``, so
they survive every connector-based exclusion pass.
"""

from __future__ import annotations


def build_specialists_section(
    specialist_lines: list[tuple[str, str]],
) -> str:
    bullets = "\n".join(f"- **{name}** — {desc}" for name, desc in specialist_lines)
    return f"\n<specialists>\n{bullets}\n</specialists>\n"

View file

@ -0,0 +1,20 @@
"""Main-agent ``<tools>`` block (memory + research builtins + ``task``)."""

from __future__ import annotations

from app.db import ChatVisibility

from ..tool_instruction_block import build_tools_instruction_block


def build_tools_section(
    *,
    visibility: ChatVisibility,
    enabled_tool_names: set[str] | None,
    disabled_tool_names: set[str] | None,
) -> str:
    return build_tools_instruction_block(
        visibility=visibility,
        enabled_tool_names=enabled_tool_names,
        disabled_tool_names=disabled_tool_names,
    )

View file

@ -0,0 +1,84 @@
"""Compose the ``<tools>`` block from per-tool vertical-slice folders.
Each tool lives in ``prompts/tools/<name>/`` with ``description.md`` and an
``example.md``. Visibility variants live in ``{private,team}/`` subfolders.
"""

from __future__ import annotations

from app.db import ChatVisibility

from ...tools import MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED
from .load_md import read_prompt_md

_MEMORY_VARIANT_TOOLS: frozenset[str] = frozenset({"update_memory"})


def _tool_fragment(tool_name: str, variant: str, leaf: str) -> str:
    if tool_name in _MEMORY_VARIANT_TOOLS:
        return read_prompt_md(f"tools/{tool_name}/{variant}/{leaf}")
    return read_prompt_md(f"tools/{tool_name}/{leaf}")


def _format_tool_label(tool_name: str) -> str:
    return tool_name.replace("_", " ").title()


def build_tools_instruction_block(
    *,
    visibility: ChatVisibility,
    enabled_tool_names: set[str] | None,
    disabled_tool_names: set[str] | None,
) -> str:
    """Render ``<tools>``. ``task`` is always included: at least ``deliverables``
    and ``knowledge_base`` are always in ``<specialists>`` (see constants)."""
    variant = "team" if visibility == ChatVisibility.SEARCH_SPACE else "private"

    parts: list[str] = ["\n<tools>\n"]

    for tool_name in MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED:
        if enabled_tool_names is not None and tool_name not in enabled_tool_names:
            continue
        description = _tool_fragment(tool_name, variant, "description.md")
        example = _tool_fragment(tool_name, variant, "example.md")
        if not description and not example:
            continue
        if description:
            parts.append(description + "\n")
        if example:
            parts.append("\n" + example + "\n")
        parts.append("\n")

    task_description = read_prompt_md("tools/task/description.md")
    task_example = read_prompt_md("tools/task/example.md")
    if task_description:
        parts.append(task_description + "\n")
    if task_example:
        parts.append("\n" + task_example + "\n")
    parts.append("\n")

    known_disabled = (
        set(disabled_tool_names) & set(MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED)
        if disabled_tool_names
        else set()
    )
    if known_disabled:
        disabled_list = ", ".join(
            _format_tool_label(n)
            for n in MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED
            if n in known_disabled
        )
        parts.append(
            "<disabled_tools>\n"
            f"Disabled for this session: {disabled_list}.\n"
            "Don't claim you can use them. If the user needs that capability,\n"
            "delegate with `task` when a specialist covers it; otherwise say\n"
            "the tool is disabled.\n"
            "</disabled_tools>\n"
        )

    parts.append("</tools>\n")
    return "".join(parts)
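
The enabled/disabled handling in `build_tools_instruction_block` follows one rule throughout: the canonical tool order decides the output order, never the caller's sets. A minimal sketch of that filtering, where the `ORDERED` tuple is a made-up stand-in for `MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED` (its real contents are not shown in this diff):

```python
from __future__ import annotations

# Hypothetical stand-in for MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED.
ORDERED: tuple[str, ...] = ("web_search", "scrape_webpage", "update_memory")


def format_tool_label(tool_name: str) -> str:
    return tool_name.replace("_", " ").title()


def visible_tools(enabled: set[str] | None) -> list[str]:
    """``None`` means no allow-list: every known tool is rendered."""
    return [name for name in ORDERED if enabled is None or name in enabled]


def disabled_notice(disabled: set[str] | None) -> str:
    """Name only disabled tools the agent knows about, in canonical order."""
    known = set(disabled) & set(ORDERED) if disabled else set()
    if not known:
        return ""
    labels = ", ".join(format_tool_label(n) for n in ORDERED if n in known)
    return f"Disabled for this session: {labels}."
```

Intersecting with the known names means an unrecognised entry in the disabled set is ignored silently rather than being echoed into the prompt.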

View file

@ -0,0 +1 @@
"""Main-agent prompt fragments loaded by :mod:`...system_prompt.builder.load_md`."""

View file

@ -0,0 +1 @@
"""``<citations>`` block — ``on`` (cite chunk ids) and ``off`` (hard suppression)."""

View file

@ -0,0 +1,12 @@
<citations>
Citation markers are **disabled** in this configuration.
Do NOT include `[citation:…]` markers anywhere, even if tool descriptions or
examples reference them. Ignore citation-format reminders elsewhere in this
prompt when they conflict with this block.
1. Answer in plain prose. Optional markdown links to public URLs when
sources are URLs.
2. Do not expose raw chunk ids, document ids, or internal ids to the user.
3. Present KB or docs facts naturally without attribution markers.
</citations>

View file

@ -0,0 +1,42 @@
<citations>
Citations reach the answer through two channels. Use whichever applies — and
never invent ids you didn't see. Citation ids are resolved by exact-match
lookup; a wrong id silently breaks the link, so when in doubt, omit.
### Channel A — chunk blocks injected this turn
When `web_search` returns `<document>` / `<chunk id='…'>` blocks in this
turn:
1. For each factual statement taken from those chunks, add
`[citation:chunk_id]` using the **exact** id from a visible
`<chunk id='…'>` tag. Copy digit-for-digit (or the URL verbatim);
do not retype from memory.
2. `<document_id>` is the parent doc id, **not** a citation source —
only ids inside `<chunk id='…'>` count.
3. Multiple chunks → `[citation:id1], [citation:id2]` (comma-separated,
each id copied individually).
4. Never invent, normalise, or guess at adjacent ids; if unsure, omit.
5. Plain brackets only — no markdown links, no footnote numbering.
### Channel B — citations relayed by a `task` specialist
A `task(...)` tool message may contain `[citation:<chunk_id>]` markers
the specialist already attached to its prose. The specialist saw the
underlying `<chunk id='…'>` blocks; you didn't. So:
1. **Preserve those markers verbatim** in your final answer — do not
reformat, renumber, drop, or wrap them in markdown links. When you
paraphrase a specialist sentence, copy the marker character-for-
character; do not regenerate the id from memory (LLMs reliably
corrupt nearby digits).
2. Keep each marker attached to the sentence the specialist attached
it to.
3. Do **not** add new `[citation:…]` markers of your own to a
specialist's prose; if a fact has no marker, the specialist
couldn't tie it to a chunk and neither can you.
4. When a specialist returns JSON, the citation markers live inside
the prose-bearing fields (e.g. a summary or excerpt). Pull them
along with the surrounding sentence when you quote.
If neither channel surfaces citation markers this turn, do not fabricate
them.
</citations>
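
Since the prompt above says citation ids are resolved by exact-match lookup, one way to picture the failure mode is a checker that compares the markers in an answer against the chunk ids that were actually visible. This is purely an illustrative sketch: the regexes and the helper names `visible_chunk_ids` and `unknown_citations` are assumptions, and nothing in this diff shows how SurfSense resolves citations server-side.

```python
import re

# Matches [citation:<id>] markers; ids may be numeric or URLs.
CITATION_RE = re.compile(r"\[citation:([^\]\s]+)\]")
# Matches the opening tag of a chunk block, e.g. <chunk id='4821'>.
CHUNK_RE = re.compile(r"<chunk id='([^']+)'>")


def visible_chunk_ids(context: str) -> set[str]:
    """Collect every chunk id that appears in the injected context."""
    return set(CHUNK_RE.findall(context))


def unknown_citations(answer: str, context: str) -> list[str]:
    """Return cited ids that do not exactly match a visible chunk id."""
    known = visible_chunk_ids(context)
    return [cid for cid in CITATION_RE.findall(answer) if cid not in known]
```

A single transposed digit is enough to land an id in the unknown list, which is why the prompt asks for ids to be copied rather than retyped.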

View file

@ -0,0 +1,13 @@
<core_behavior>
- Be concise and direct. No preamble ("Sure!", "Great question!", "I'll now…").
- Don't narrate intent — just act. State the outcome, not the plan.
- If the request is ambiguous, ask before acting. If asked *how* to do
something, explain first, then act.
- Prioritise accuracy over agreement. Disagree respectfully when the user is
wrong; avoid unnecessary superlatives or emotional validation.
- Persist until the task is done or you are genuinely blocked. Don't stop
partway and describe what you *would* do.
- For longer work, give brief progress updates only when they add new
information (a discovery, a tradeoff, a blocker, the start of a non-trivial
step). Don't narrate routine reads.
</core_behavior>

View file

@ -0,0 +1 @@
"""``<dynamic_context>`` block — private and team variants."""

View file

@ -0,0 +1,27 @@
<dynamic_context>
The runtime inserts these system messages each turn. They are authoritative
for *this* turn only.
`<user_memory>` carries the durable personal context the user has accumulated
across sessions — role, interests, preferences, projects, background,
standing instructions. It also reports current character usage versus the
hard limit so you can manage the budget. Treat it as background colour for
your answer, not as the task itself.
`<priority_documents>` lists the workspace documents most relevant to the
latest user message, ranked by relevance score, with `[USER-MENTIONED]`
flagged on anything the user explicitly referenced. When the task is about
workspace content, read these first; matched passages inside each document
are flagged via `<chunk_index>` so you can jump straight to them.
`<workspace_tree>` shows the full `/documents/` folder and file layout. Use
it to resolve paths the user describes in natural language ("my Q2 roadmap",
"last week's meeting notes") into concrete document references before
delegating to a specialist.
`<document>` and `<chunk id='…'>` blocks are chunked indexed content returned
by KB search (backing `<priority_documents>`). Each chunk carries a stable
`id` attribute.
If a block doesn't appear this turn, work from the conversation alone.
</dynamic_context>

View file

@ -0,0 +1,27 @@
<dynamic_context>
The runtime inserts these system messages each turn. They are authoritative
for *this* turn only.
`<team_memory>` carries the durable shared context this team has built up —
decisions, conventions, architecture notes, processes, key facts. It also
reports current character usage versus the hard limit so you can manage the
budget. Treat it as background colour for your answer, not as the task itself.
`<priority_documents>` lists the workspace documents most relevant to the
latest user message, ranked by relevance score, with `[USER-MENTIONED]`
flagged on anything someone in the thread explicitly referenced. When the
task is about workspace content, read these first; matched passages inside
each document are flagged via `<chunk_index>` so you can jump straight to
them.
`<workspace_tree>` shows the full `/documents/` folder and file layout. Use
it to resolve paths described in natural language ("the Q2 roadmap", "last
week's planning notes") into concrete document references before delegating
to a specialist.
`<document>` and `<chunk id='…'>` blocks are chunked indexed content returned
by KB search (backing `<priority_documents>`). Each chunk carries a stable
`id` attribute.
If a block doesn't appear this turn, work from the conversation alone.
</dynamic_context>

View file

@ -0,0 +1 @@
"""``<agent_identity>`` block — private and team variants."""

View file

@ -0,0 +1,8 @@
<agent_identity>
You are **SurfSense's main agent**. Your job is to answer the user using their
knowledge base, lightweight web research, persistent memory, and **specialist
subagents** invoked via the `task` tool. You are an orchestrator — most
non-trivial work belongs on a specialist.
Today (UTC): {resolved_today}
</agent_identity>

View file

@ -0,0 +1,11 @@
<agent_identity>
You are **SurfSense's main agent**. Your job is to answer the user using their
shared team knowledge base, lightweight web research, persistent memory, and
**specialist subagents** invoked via the `task` tool. You are an orchestrator
— most non-trivial work belongs on a specialist.
Today (UTC): {resolved_today}
You are in a **team thread**. Each message is prefixed with `[DisplayName]`.
Attribute quotes and decisions to the named author when relevant.
</agent_identity>

View file

@ -0,0 +1,21 @@
<knowledge_base_first>
CRITICAL — ground factual answers in what you actually receive this turn:
- injected workspace context (see `<dynamic_context>`),
- results from your own tool calls (`web_search`, `scrape_webpage`),
- or substantive summaries returned by a `task` specialist you invoked.
Do **not** answer factual or informational questions from general knowledge
unless the user explicitly authorises it after you say you couldn't find
enough in those sources. The flow when nothing is found:
1. Say you couldn't find enough in their workspace or tool output.
2. Ask: *"Would you like me to answer from my general knowledge instead?"*
3. Only answer from general knowledge after a clear yes.
This rule does NOT apply to: casual conversation · meta-questions about
SurfSense ("what can you do?") · formatting or analysis of content already
in chat · clear rewrite/edit instructions · lightweight web research.
For "how do I use SurfSense" / product-documentation questions, point the
user to https://www.surfsense.com/docs.
</knowledge_base_first>

View file

@ -0,0 +1 @@
"""``<memory_protocol>`` block — private and team variants."""

View file

@ -0,0 +1,15 @@
<memory_protocol>
After understanding each user message, check: does it reveal durable facts
about the user — role, interests, preferences, projects, background, or
standing instructions?
If yes, call `update_memory` **alongside** your normal response — don't
defer it to a later turn. Skip ephemeral chat noise (one-off Q/A, greetings,
session logistics). Stay within the budget shown in `<user_memory>`.
Memory is heading-based markdown. New entries should be under `##` headings
such as `## Facts`, `## Preferences`, or `## Instructions`, with bullets like
`- YYYY-MM-DD: text`. If existing memory contains legacy
`(YYYY-MM-DD) [fact|pref|instr]` markers, preserve the information but write
new saves in the heading-based format.
</memory_protocol>
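
The heading-based memory layout described above (`##` headings with `- YYYY-MM-DD: text` bullets) can be sketched as a small text transformation. The helper `add_memory_entry` is hypothetical and is not the `update_memory` tool; it ignores the character budget and only illustrates where a new bullet would land.

```python
from datetime import date


def add_memory_entry(memory: str, heading: str, text: str, when: date) -> str:
    """Append a dated bullet under ``## heading``, creating the heading if needed."""
    bullet = f"- {when.isoformat()}: {text}"
    title = f"## {heading}"
    lines = memory.splitlines()
    if title not in lines:
        # New heading goes at the end, separated by a blank line if needed.
        prefix = memory.rstrip("\n")
        separator = "\n\n" if prefix else ""
        return f"{prefix}{separator}{title}\n{bullet}\n"
    start = lines.index(title) + 1
    end = start
    # The section runs until the next "## " heading or the end of the text.
    while end < len(lines) and not lines[end].startswith("## "):
        end += 1
    # Insert after the last non-blank line of the section.
    insert_at = end
    while insert_at > start and not lines[insert_at - 1].strip():
        insert_at -= 1
    lines.insert(insert_at, bullet)
    return "\n".join(lines) + "\n"
```

Legacy `(YYYY-MM-DD) [fact|pref|instr]` lines are left untouched by this sketch, in line with the instruction to preserve them while writing new saves in the heading-based format.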

View file

@ -0,0 +1,17 @@
<memory_protocol>
After understanding each user message, check: does it reveal durable facts
about the team — decisions, conventions, architecture notes, processes, or
key facts?
If yes, call `update_memory` **alongside** your normal response — don't
defer it to a later turn. Skip ephemeral chat noise (one-off Q/A, greetings,
session logistics). Stay within the budget shown in `<team_memory>`.
Team memory is heading-based markdown. New entries should be under `##`
headings such as `## Product Decisions`, `## Engineering Conventions`,
`## Project Facts`, or `## Open Questions`, with bullets like
`- YYYY-MM-DD: text`. If existing memory contains legacy `(YYYY-MM-DD) [fact]`
markers, preserve the information but write new saves in the heading-based
format. Do not create personal headings such as `## Preferences` or
`## Instructions`.
</memory_protocol>
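
The ban on personal headings in team memory could be checked mechanically. The sketch below is one possible guard, not something this commit adds: `find_personal_headings` is a hypothetical name, and matching headings case-insensitively is an assumption.

```python
from __future__ import annotations

# Headings the team variant says must not be created (compared in lowercase).
PERSONAL_HEADINGS = frozenset({"## preferences", "## instructions"})


def find_personal_headings(team_memory: str) -> list[str]:
    """Return any personal headings present in a team-memory document."""
    found = []
    for line in team_memory.splitlines():
        heading = line.strip()
        if heading.lower() in PERSONAL_HEADINGS:
            found.append(heading)
    return found
```

An empty result means the document only uses team-level headings such as `## Product Decisions` or `## Open Questions`.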

View file

@ -0,0 +1,7 @@
<output_format>
- Mathematical formulas: **always** LaTeX. Never backtick code spans or
Unicode symbols for math.
- Never expose internal tool parameter names, backend IDs, or
implementation details. Use natural, user-friendly language.
- External sources: markdown links `[label](url)`, never bare URLs.
</output_format>
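
The external-source rule above (markdown links, never bare URLs) lends itself to a simple lint. This is a rough sketch under stated assumptions: `find_bare_urls` is a hypothetical helper, and the regular expression only treats a URL as linked when it directly follows `](`, so it will miss edge cases such as URLs used as link labels.

```python
import re

# A URL counts as "bare" here when it is not the target of a markdown link,
# i.e. it is not immediately preceded by "](".
BARE_URL_RE = re.compile(r"(?<!\]\()https?://\S+")


def find_bare_urls(text: str) -> list[str]:
    """Return URLs that appear outside the `[label](url)` form."""
    # Trailing sentence punctuation is trimmed from each match.
    return [m.group(0).rstrip(".,;:!?)") for m in BARE_URL_RE.finditer(text)]
```

A reply such as `See [docs](https://www.surfsense.com/docs).` passes with no findings, while `Read https://example.com/a, then stop.` reports the one bare URL.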

View file

@ -0,0 +1,16 @@
<provider_hints>
You are running on an Anthropic Claude model (SurfSense **main agent**).
Structured reasoning:
- For non-trivial work, `<thinking>` / short `<plan>` before tool calls is fine.
Professional objectivity:
- Accuracy over flattery; verify with **web_search**, **scrape_webpage**, or **task** when unsure. Don't invent connector access.
Task management:
- For 3+ steps, use todo tooling; update statuses promptly.
Tool calls:
- Parallelise independent calls; sequence only when outputs chain.
- Never pretend you can run connector-specific tools directly — route through **task** when needed.
</provider_hints>
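
The "parallelise independent calls; sequence only when outputs chain" guidance can be illustrated outside the agent with plain `asyncio`. Everything here is a stand-in: `fake_tool` and `run_turn` are hypothetical, and the real tool dispatch is done by the agent runtime, not by code like this.

```python
import asyncio


async def fake_tool(name: str, delay: float) -> str:
    """Stand-in for a tool call; sleeps briefly, then reports completion."""
    await asyncio.sleep(delay)
    return f"{name}: done"


async def run_turn() -> list[str]:
    # Independent calls: neither needs the other's output, so run them together.
    search, scrape = await asyncio.gather(
        fake_tool("web_search", 0.01),
        fake_tool("scrape_webpage", 0.01),
    )
    # Chained call: it consumes the search result, so it has to wait for it.
    summary = await fake_tool(f"task({search})", 0.01)
    return [search, scrape, summary]
```

`asyncio.run(run_turn())` runs the two independent calls concurrently and only then issues the dependent one.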

View file

@ -0,0 +1,18 @@
<provider_hints>
You are running on a DeepSeek model (SurfSense **main agent**).
Reasoning hygiene (R1-aware):
- Keep internal scratch separate from the user-facing answer; don't leak chain-of-thought into tool arguments.
Output style:
- Concise; lead with the answer or the next action; avoid sycophantic openers.
Attribution:
- When citations are **enabled** and facts come from chunk-tagged context, follow the citation block above.
- When citations are **disabled**, do not use `[citation:…]`.
Tool calls:
- Parallelise independent calls.
- For SurfSense docs/product questions, point the user to https://www.surfsense.com/docs.
- Don't invent paths, chunk ids, or URLs; use only values from tools or the user.
</provider_hints>

View file

@ -0,0 +1,18 @@
<provider_hints>
You are running on a Google Gemini model (SurfSense **main agent**).
Output style:
- Concise & direct. Fewer than ~3 lines of prose when the task allows (excluding tool output and code).
- No filler openers/closers — move straight to the answer or the tool call.
- GitHub-flavoured Markdown; monospace-friendly.
Workflow (Understand → Plan → Act → Verify):
1. **Understand:** parse the ask; use injected workspace context before guessing.
2. **Plan:** for multi-step work, a short plan first.
3. **Act:** only with tools you actually have on this agent (see `<tools>` and `<tool_routing>`). Connector work → **task**.
4. **Verify:** re-read or re-search only when it materially reduces risk.
Discipline:
- Do not imply access to connectors, MCP tools, or deliverable generators except via **task**.
- Pass paths to **task(knowledge_base, …)** only when you saw them in `<workspace_tree>` or `<priority_documents>`. Otherwise describe the document in natural language and let the subagent resolve it.
</provider_hints>

View file

@ -0,0 +1,16 @@
<provider_hints>
You are running on an xAI Grok model (SurfSense **main agent**).
Maximum terseness:
- Fewer than 4 lines unless detail is requested; skip preamble/postamble.
Tool discipline:
- Typically one investigative tool per turn unless several independent read-only queries are clearly needed; don't repeat identical calls.
Attribution:
- When citations are **enabled** (see citation block above) and you answer from chunk-tagged documents, use `[citation:chunk_id]` exactly as specified there.
- When citations are **disabled**, never emit `[citation:…]` — plain prose and links per tool guidance.
Style:
- No emojis unless asked; flat lists for short answers.
</provider_hints>
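
When citations are disabled, a post-processing guard could drop any `[citation:chunk_id]` markers that slip through. The sketch below is an assumption, not something this commit implements: `strip_citations` is a hypothetical helper, and it assumes a marker never contains a closing bracket.

```python
import re

# Matches "[citation:<id>]" plus any whitespace immediately before it.
CITATION_RE = re.compile(r"\s*\[citation:[^\]]+\]")


def strip_citations(text: str) -> str:
    """Remove citation markers, leaving the surrounding prose intact."""
    return CITATION_RE.sub("", text)
```

With citations enabled the text would be left untouched; this guard only applies to the disabled case.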

View file

@ -0,0 +1,21 @@
<provider_hints>
You are running on a Moonshot Kimi model (Kimi-K1.5 / Kimi-K2 / Kimi-K2.5+), SurfSense **main agent**.
Action bias:
- Default to taking action with tools rather than describing solutions in prose. If a tool can answer the question, call the tool.
- Don't narrate routine reads, searches, or obvious next steps. Combine related progress into one short status line.
- Be thorough in actions (test what you build, verify what you change). Be brief in explanations.
Tool calls:
- Output multiple non-interfering tool calls in a SINGLE response — parallelism is a major efficiency win on this model.
- When the `task` tool is available, delegate focused subtasks to a subagent with full context (subagents don't inherit yours).
- Don't apologise or pre-announce tool calls. The tool call itself is self-explanatory.
Language:
- Respond in the SAME language as the user's most recent turn unless explicitly instructed otherwise.
Discipline:
- Stay on track. Never give the user more than what they asked for.
- Fact-check with tools; don't fabricate chunk ids or connector outcomes.
- Keep it stupidly simple. Don't overcomplicate.
</provider_hints>

Some files were not shown because too many files have changed in this diff.