feat: made chat fast

- Introduced lazy knowledge base retrieval mode, allowing the main agent to fetch KB content on demand via the `search_knowledge_base` tool, improving performance by skipping expensive pre-injection processes.
- Added cross-thread caching capability, enabling reuse of compiled graphs across different user chats, reducing latency for returning users.
- Updated middleware to support new lazy loading and caching features, ensuring efficient resource utilization and improved response times.
- Enhanced logging for performance tracking during knowledge retrieval and agent interactions.
This commit is contained in:
DESKTOP-RTLN3BA\$punk 2026-06-09 04:45:17 -07:00
parent ce952d2ad1
commit 41ff57101c
32 changed files with 979 additions and 169 deletions

View file

@ -362,6 +362,13 @@ LANGSMITH_PROJECT=surfsense
# SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS=false
# SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE=false
# KB retrieval mode (default OFF = lazy). When OFF, the main agent retrieves
# KB content on demand via the `search_knowledge_base` tool and skips the
# expensive per-turn pre-injection (planner LLM + embed + hybrid search,
# ~2.3s); explicit @-mentions are still surfaced cheaply. Set to true to
# restore the original eager `<priority_documents>` pre-injection.
# SURFSENSE_ENABLE_KB_PRIORITY_PREINJECTION=false
# Snapshot / revert
# SURFSENSE_ENABLE_ACTION_LOG=false
# SURFSENSE_ENABLE_REVERT_ROUTE=false # Backend-only; flip when UI ships
@ -382,6 +389,15 @@ LANGSMITH_PROJECT=surfsense
# rollback if you suspect cache-related staleness.
# SURFSENSE_ENABLE_AGENT_CACHE=true
# Cross-thread reuse (default ON). Drops thread_id from the cache key so a
# returning user's NEW chats (same user + search space + config + visibility)
# hit the already-compiled graph instead of paying a fresh ~4-5s compile —
# turning a cold first turn into a warm one. Safe because ActionLog,
# KB-persistence, and the deliverables tools now resolve the chat thread from
# the live RunnableConfig at call time rather than a build-time closure. Flip
# OFF to fall back to a per-thread cache key (instant rollback).
# SURFSENSE_ENABLE_CROSS_THREAD_AGENT_CACHE=true
# Cache capacity (max number of compiled-agent entries kept in memory)
# and TTL per entry (seconds). Working set is typically one entry per
# active thread on this replica; tune up for very large deployments.

View file

@ -2,6 +2,7 @@
from __future__ import annotations
import time
from collections.abc import Sequence
from typing import Any
@ -18,6 +19,9 @@ from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFl
from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode
from app.agents.chat.shared.context import SurfSenseContextSchema
from app.db import ChatVisibility
from app.utils.perf import get_perf_logger
_perf_log = get_perf_logger()
def build_compiled_agent_graph_sync(
@ -43,6 +47,7 @@ def build_compiled_agent_graph_sync(
disabled_tools: list[str] | None = None,
):
"""Sync compile: middleware + ``create_agent`` (run via ``asyncio.to_thread``)."""
mw_start = time.perf_counter()
main_agent_middleware = build_main_agent_deepagent_middleware(
llm=llm,
tools=tools,
@ -63,7 +68,9 @@ def build_compiled_agent_graph_sync(
mcp_tools_by_agent=mcp_tools_by_agent,
disabled_tools=disabled_tools,
)
mw_elapsed = time.perf_counter() - mw_start
create_start = time.perf_counter()
agent = create_agent(
llm,
system_prompt=final_system_prompt,
@ -72,6 +79,15 @@ def build_compiled_agent_graph_sync(
context_schema=SurfSenseContextSchema,
checkpointer=checkpointer,
)
create_elapsed = time.perf_counter() - create_start
_perf_log.info(
"[graph_compile] middleware_build=%.3fs main_create_agent=%.3fs "
"total=%.3fs mw_count=%d",
mw_elapsed,
create_elapsed,
mw_elapsed + create_elapsed,
len(main_agent_middleware),
)
return agent.with_config(
{
"recursion_limit": 10_000,

View file

@ -108,18 +108,32 @@ class ActionLogMiddleware(AgentMiddleware):
self._user_id = user_id
self._tool_definitions = dict(tool_definitions or {})
def _enabled(self) -> bool:
def _enabled(self, thread_id: int | None) -> bool:
flags = get_flags()
if flags.disable_new_agent_stack:
return False
return bool(flags.enable_action_log) and self._thread_id is not None
return bool(flags.enable_action_log) and thread_id is not None
def _resolve_thread_id(self, request: ToolCallRequest) -> int | None:
"""Resolve the live thread id, preferring the runtime config.
Reading ``configurable.thread_id`` from the active ``RunnableConfig``
(rather than the value captured at ``__init__``) lets a single cached
compiled graph safely serve many threads without it, a cache hit
would attribute action-log rows to whichever thread first built the
graph. Falls back to the constructor value for legacy/test runtimes
that don't surface a config.
"""
resolved = _resolve_thread_id(request)
return resolved if resolved is not None else self._thread_id
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
) -> ToolMessage | Command[Any]:
if not self._enabled():
thread_id = self._resolve_thread_id(request)
if not self._enabled(thread_id):
return await handler(request)
result: ToolMessage | Command[Any]
@ -134,10 +148,16 @@ class ActionLogMiddleware(AgentMiddleware):
request=request,
result=None,
error_payload=error_payload,
thread_id=thread_id,
)
raise
await self._record(request=request, result=result, error_payload=None)
await self._record(
request=request,
result=result,
error_payload=None,
thread_id=thread_id,
)
return result
async def _record(
@ -146,6 +166,7 @@ class ActionLogMiddleware(AgentMiddleware):
request: ToolCallRequest,
result: ToolMessage | Command[Any] | None,
error_payload: dict[str, Any] | None,
thread_id: int | None,
) -> None:
"""Persist one ``agent_action_log`` row. Defensive: never raises."""
try:
@ -164,7 +185,7 @@ class ActionLogMiddleware(AgentMiddleware):
chat_turn_id = _resolve_chat_turn_id(request)
row = AgentActionLog(
thread_id=self._thread_id,
thread_id=thread_id,
user_id=self._user_id,
search_space_id=self._search_space_id,
# ``turn_id`` is the deprecated alias of ``tool_call_id``
@ -350,6 +371,36 @@ def _resolve_chat_turn_id(request: Any) -> str | None:
return None
def _resolve_thread_id(request: Any) -> int | None:
"""Return ``configurable.thread_id`` (as int) for this request, if accessible.
Mirrors :func:`_resolve_chat_turn_id`: ``ToolRuntime.config`` is exposed by
LangGraph at ``request.runtime.config``, and the chat thread id lives at
``configurable.thread_id`` (a stringified ``chat_id`` at the main-graph
level). Returns ``None`` when absent or unparseable so the caller can fall
back to the constructor value.
"""
try:
runtime = getattr(request, "runtime", None)
if runtime is None:
return None
config = getattr(runtime, "config", None)
if not isinstance(config, dict):
return None
configurable = config.get("configurable")
if not isinstance(configurable, dict):
return None
value = configurable.get("thread_id")
if value is None:
return None
try:
return int(value)
except (TypeError, ValueError):
return None
except Exception: # pragma: no cover - defensive
return None
def _resolve_message_id(request: Any) -> str | None:
"""Tool-call IDs serve as best-available message correlator at this layer."""
return _resolve_tool_call_id(request)

View file

@ -3,6 +3,7 @@
from __future__ import annotations
import time
from collections.abc import Callable
from typing import Any, cast
from deepagents.backends.protocol import BackendFactory, BackendProtocol
@ -14,10 +15,12 @@ from deepagents.middleware.subagents import (
)
from langchain.agents import create_agent
from langchain.chat_models import init_chat_model
from langchain_core.runnables import Runnable
from langgraph.types import Checkpointer
from app.agents.chat.multi_agent_chat.subagents.shared.spec import (
SURF_CONTEXT_HINT_PROVIDER_KEY,
SURF_LAZY_SPEC_FACTORY_KEY,
)
from app.utils.perf import get_perf_logger
@ -52,15 +55,32 @@ class SurfSenseCheckpointedSubAgentMiddleware(SubAgentMiddleware):
# switch keys on it so an operator can quarantine one workspace
# without affecting the rest of the deployment.
self._search_space_id = search_space_id
subagent_specs = self._surf_compile_subagent_graphs()
# Lazy subagent compilation. Compiling a subagent graph via
# ``create_agent`` is expensive (~250-400ms each) and there can be up
# to ~17 of them. Doing it all in ``__init__`` put the full cost on
# every cold ``agent_cache`` miss (i.e. on time-to-first-token), even
# though a turn usually invokes zero or one subagent. We instead index
# the raw specs here and compile each graph on first ``task(name)``
# use, memoizing the result for the life of this (cached) instance.
self._compiled: dict[str, Runnable] = {}
self._lazy_specs: dict[str, dict[str, Any]] = {}
# Subagents whose *spec itself* is built lazily (not just compiled).
# Keyed by name → zero-arg factory returning the full spec dict. Used
# for the write knowledge_base subagent, whose filesystem middleware
# builds ~13 tool schemas (~2s) that almost never matter on turn 1.
self._lazy_spec_factories: dict[str, Callable[[], dict[str, Any]]] = {}
descriptors = self._build_subagent_registry()
task_tool = build_task_tool_with_parent_config(
subagent_specs,
descriptors,
task_description,
search_space_id=search_space_id,
resolve_subagent=self._resolve_subagent,
)
if system_prompt and subagent_specs:
if system_prompt and descriptors:
agents_desc = "\n".join(
f"- {s['name']}: {s['description']}" for s in subagent_specs
f"- {s['name']}: {s['description']}" for s in descriptors
)
self.system_prompt = (
system_prompt + "\n\nAvailable subagent types:\n" + agents_desc
@ -69,84 +89,100 @@ class SurfSenseCheckpointedSubAgentMiddleware(SubAgentMiddleware):
self.system_prompt = system_prompt
self.tools = [task_tool]
def _surf_compile_subagent_graphs(self) -> list[dict[str, Any]]:
"""Mirror of ``SubAgentMiddleware._get_subagents`` that threads the parent checkpointer."""
specs: list[dict[str, Any]] = []
loop_start = time.perf_counter()
timings: list[tuple[str, float, str]] = [] # (name, elapsed, source)
def _build_subagent_registry(self) -> list[dict[str, Any]]:
"""Index subagents for lazy compilation; return lightweight descriptors.
Pre-compiled specs (those carrying a ``runnable``) are seeded directly
into the memo. Lazy specs are stashed by name and compiled on first
``task(...)`` use via :meth:`_resolve_subagent`. The returned
descriptors carry only ``name``/``description`` plus the optional
context-hint provider everything the ``task`` tool needs to validate
names, render its catalog, and run hints, without paying the
``create_agent`` cost up front.
"""
descriptors: list[dict[str, Any]] = []
for spec in self._subagents:
spec_start = time.perf_counter()
# Provider may be ``None`` (no hint), in which case task_tool
# skips the prepend step. We forward the key unconditionally so
# the registry shape is uniform.
# Provider may be ``None`` (no hint), in which case task_tool skips
# the prepend step. We forward the key unconditionally so the
# descriptor shape is uniform.
hint_provider = cast(dict, spec).get(SURF_CONTEXT_HINT_PROVIDER_KEY)
if "runnable" in spec:
name = spec["name"]
spec_factory = cast(dict, spec).get(SURF_LAZY_SPEC_FACTORY_KEY)
if spec_factory is not None:
# Descriptor-only entry: the spec dict is built on first use.
self._lazy_spec_factories[name] = spec_factory
elif "runnable" in spec:
compiled = cast(CompiledSubAgent, spec)
specs.append(
{
"name": compiled["name"],
"description": compiled["description"],
"runnable": compiled["runnable"],
SURF_CONTEXT_HINT_PROVIDER_KEY: hint_provider,
}
)
timings.append(
(compiled["name"], time.perf_counter() - spec_start, "precompiled")
)
continue
if "model" not in spec:
msg = f"SubAgent '{spec['name']}' must specify 'model'"
raise ValueError(msg)
if "tools" not in spec:
msg = f"SubAgent '{spec['name']}' must specify 'tools'"
raise ValueError(msg)
model = spec["model"]
if isinstance(model, str):
model = init_chat_model(model)
middleware: list[Any] = list(spec.get("middleware", []))
tools_count = len(spec.get("tools") or [])
mw_count = len(middleware)
compile_start = time.perf_counter()
runnable = create_agent(
model,
system_prompt=spec["system_prompt"],
tools=spec["tools"],
middleware=middleware,
name=spec["name"],
checkpointer=self._surf_checkpointer,
)
compile_elapsed = time.perf_counter() - compile_start
specs.append(
self._compiled[name] = compiled["runnable"]
else:
if "model" not in spec:
msg = f"SubAgent '{name}' must specify 'model'"
raise ValueError(msg)
if "tools" not in spec:
msg = f"SubAgent '{name}' must specify 'tools'"
raise ValueError(msg)
self._lazy_specs[name] = cast(dict, spec)
descriptors.append(
{
"name": spec["name"],
"name": name,
"description": spec["description"],
"runnable": runnable,
SURF_CONTEXT_HINT_PROVIDER_KEY: hint_provider,
}
)
timings.append(
(
spec["name"],
compile_elapsed,
f"compiled tools={tools_count} mw={mw_count}",
)
)
return descriptors
total_elapsed = time.perf_counter() - loop_start
per_subagent = ", ".join(
f"{name}={elapsed * 1000:.0f}ms[{source}]"
for name, elapsed, source in timings
def _resolve_subagent(self, name: str) -> Runnable:
"""Return the compiled subagent graph for ``name``, compiling on first use.
Memoized: the ``create_agent`` cost is paid once per subagent per
cached middleware instance. Raises ``KeyError`` for unknown names
(callers in the ``task`` tool validate membership before resolving).
"""
cached = self._compiled.get(name)
if cached is not None:
return cached
spec = self._lazy_specs.get(name)
if spec is None:
factory = self._lazy_spec_factories.get(name)
if factory is None:
raise KeyError(name)
# Build the spec on first use (pays the deferred construction cost
# here, off the cold agent-build path), then compile and memoize.
build_start = time.perf_counter()
spec = factory()
_perf_log.info(
"[subagent_spec_lazy] name=%s (deferred spec build) in %.3fs",
name,
time.perf_counter() - build_start,
)
runnable = self._compile_one(spec)
self._compiled[name] = runnable
return runnable
def _compile_one(self, spec: dict[str, Any]) -> Runnable:
"""Compile a single subagent graph against the parent checkpointer."""
model = spec["model"]
if isinstance(model, str):
model = init_chat_model(model)
middleware: list[Any] = list(spec.get("middleware", []))
tools_count = len(spec.get("tools") or [])
mw_count = len(middleware)
compile_start = time.perf_counter()
runnable = create_agent(
model,
system_prompt=spec["system_prompt"],
tools=spec["tools"],
middleware=middleware,
name=spec["name"],
checkpointer=self._surf_checkpointer,
)
_perf_log.info(
"[subagent_compile] total=%.3fs count=%d details=[%s]",
total_elapsed,
len(timings),
per_subagent,
"[subagent_compile_lazy] name=%s in %.3fs tools=%d mw=%d",
spec["name"],
time.perf_counter() - compile_start,
tools_count,
mw_count,
)
return specs
return runnable

View file

@ -12,7 +12,7 @@ import asyncio
import json
import logging
import time
from collections.abc import Awaitable
from collections.abc import Awaitable, Callable
from typing import Annotated, Any, NoReturn, TypeVar
from deepagents.middleware.subagents import TASK_TOOL_DESCRIPTION
@ -143,11 +143,28 @@ def build_task_tool_with_parent_config(
task_description: str | None = None,
*,
search_space_id: int | None = None,
resolve_subagent: Callable[[str], Runnable] | None = None,
) -> BaseTool:
"""Upstream ``_build_task_tool`` + parent ``runtime.config`` propagation + resume bridging."""
subagent_graphs: dict[str, Runnable] = {
spec["name"]: spec["runnable"] for spec in subagents
}
"""Upstream ``_build_task_tool`` + parent ``runtime.config`` propagation + resume bridging.
``subagents`` are lightweight descriptors (``name``/``description`` + the
optional context-hint provider); the actual compiled graph is fetched
lazily via ``resolve_subagent(name)`` so subagent ``create_agent`` cost is
paid on first ``task(name)`` use rather than at graph-build time.
For backward compatibility (and tests), ``resolve_subagent`` may be omitted
when every descriptor already carries a pre-compiled ``runnable``; in that
case a trivial dict-backed resolver is used.
"""
subagent_names: set[str] = {spec["name"] for spec in subagents}
if resolve_subagent is None:
_eager_graphs: dict[str, Runnable] = {
spec["name"]: spec["runnable"] for spec in subagents if "runnable" in spec
}
def resolve_subagent(name: str) -> Runnable:
return _eager_graphs[name]
# Sparse map of opt-in context-hint providers; each runs once per task()
# call to prepend a string to the subagent's first HumanMessage. Failures
# are swallowed so a broken hint never blocks the task.
@ -329,7 +346,7 @@ def build_task_tool_with_parent_config(
def _validate_and_prepare_state(
subagent_type: str, description: str, runtime: ToolRuntime
) -> tuple[Runnable, dict]:
subagent = subagent_graphs[subagent_type]
subagent = resolve_subagent(subagent_type)
subagent_state = {
k: v for k, v in runtime.state.items() if k not in EXCLUDED_STATE_KEYS
}
@ -442,8 +459,8 @@ def build_task_tool_with_parent_config(
batched HITL is intentionally out of scope.
"""
async with semaphore:
if subagent_type not in subagent_graphs:
allowed_types = ", ".join([f"`{k}`" for k in subagent_graphs])
if subagent_type not in subagent_names:
allowed_types = ", ".join([f"`{k}`" for k in subagent_names])
return (
task_index,
subagent_type,
@ -618,8 +635,8 @@ def build_task_tool_with_parent_config(
"task: must provide either single-mode (`description`+`subagent_type`) "
"or batch-mode (`tasks`)."
)
if subagent_type not in subagent_graphs:
allowed_types = ", ".join([f"`{k}`" for k in subagent_graphs])
if subagent_type not in subagent_names:
allowed_types = ", ".join([f"`{k}`" for k in subagent_names])
return (
f"We cannot invoke subagent {subagent_type} because it does not exist, "
f"the only allowed types are {allowed_types}"
@ -827,8 +844,8 @@ def build_task_tool_with_parent_config(
subagent_type,
runtime.tool_call_id,
)
if subagent_type not in subagent_graphs:
allowed_types = ", ".join([f"`{k}`" for k in subagent_graphs])
if subagent_type not in subagent_names:
allowed_types = ", ".join([f"`{k}`" for k in subagent_names])
return (
f"We cannot invoke subagent {subagent_type} because it does not exist, "
f"the only allowed types are {allowed_types}"

View file

@ -26,6 +26,7 @@ from typing import Any
from fractional_indexing import generate_key_between
from langchain.agents.middleware import AgentMiddleware, AgentState
from langchain_core.callbacks import adispatch_custom_event, dispatch_custom_event
from langgraph.config import get_config
from langgraph.runtime import Runtime
from sqlalchemy import delete, select, update
from sqlalchemy.exc import IntegrityError
@ -1436,9 +1437,33 @@ class KnowledgeBasePersistenceMiddleware(AgentMiddleware): # type: ignore[type-
search_space_id=self.search_space_id,
created_by_id=self.created_by_id,
filesystem_mode=self.filesystem_mode,
thread_id=self.thread_id,
thread_id=self._resolve_thread_id(),
)
def _resolve_thread_id(self) -> int | None:
"""Resolve the live thread id from the active ``RunnableConfig``.
``aafter_agent`` only receives a ``Runtime`` (which does NOT carry the
config), so we read ``configurable.thread_id`` via
:func:`langgraph.config.get_config` the same node-context pattern used
by ``BusyMutexMiddleware``. Resolving at runtime (rather than using the
value captured at ``__init__``) lets one cached compiled graph commit
staged writes against the correct thread across many chats. Falls back
to the constructor value for legacy/test runtimes.
"""
try:
config = get_config()
except Exception:
config = None
if isinstance(config, dict):
value = (config.get("configurable") or {}).get("thread_id")
if value is not None:
try:
return int(value)
except (TypeError, ValueError):
return None
return self.thread_id
__all__ = [
"KnowledgeBasePersistenceMiddleware",

View file

@ -19,7 +19,16 @@ def build_knowledge_priority_mw(
available_connectors: list[str] | None,
available_document_types: list[str] | None,
mentioned_document_ids: list[int] | None,
preinjection_enabled: bool = True,
) -> KnowledgePriorityMiddleware:
"""Build the KB priority middleware.
When ``preinjection_enabled`` is False (the lazy default), the middleware
runs in mentions-only mode: it skips the expensive planner LLM + embedding
+ hybrid search and only surfaces explicit @-mentions. The main agent is
expected to pull relevant KB content on demand via the
``search_knowledge_base`` tool instead.
"""
return KnowledgePriorityMiddleware(
llm=llm,
planner_llm=get_planner_llm(),
@ -29,4 +38,5 @@ def build_knowledge_priority_mw(
available_document_types=available_document_types,
mentioned_document_ids=mentioned_document_ids,
inject_system_message=False,
mentions_only=not preinjection_enabled,
)

View file

@ -10,13 +10,15 @@ turn (cloud mode).
from __future__ import annotations
import logging
import time
from collections.abc import Sequence
from typing import Any
from typing import Any, cast
from deepagents import SubAgent
from deepagents.backends import StateBackend
from langchain.agents import create_agent
from langchain_core.language_models import BaseChatModel
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langgraph.types import Checkpointer
@ -49,16 +51,25 @@ from app.agents.chat.multi_agent_chat.subagents import (
get_subagents_to_exclude,
)
from app.agents.chat.multi_agent_chat.subagents.builtins.knowledge_base.agent import (
NAME as KB_WRITE_NAME,
READONLY_NAME as KB_READONLY_NAME,
build_readonly_subagent as build_kb_readonly_subagent,
build_subagent as build_kb_write_subagent,
)
from app.agents.chat.multi_agent_chat.subagents.builtins.knowledge_base.ask_knowledge_base_tool import (
build_ask_knowledge_base_tool,
)
from app.agents.chat.multi_agent_chat.subagents.builtins.knowledge_base.prompts import (
load_description as load_kb_write_description,
)
from app.agents.chat.multi_agent_chat.subagents.middleware_stack import (
build_subagent_middleware_stack,
)
from app.agents.chat.multi_agent_chat.subagents.shared.spec import (
SURF_LAZY_SPEC_FACTORY_KEY,
)
from app.db import ChatVisibility
from app.utils.perf import get_perf_logger
from .action_log import build_action_log_mw
from .anonymous_document import build_anonymous_doc_mw
@ -81,6 +92,8 @@ from .plugins import build_plugin_middlewares
from .skills import build_skills_mw
from .tool_call_repair import build_repair_mw
_perf_log = get_perf_logger()
def build_main_agent_deepagent_middleware(
*,
@ -104,6 +117,7 @@ def build_main_agent_deepagent_middleware(
disabled_tools: list[str] | None = None,
) -> list[Any]:
"""Ordered middleware for ``create_agent`` (None entries already stripped)."""
stack_build_start = time.perf_counter()
resilience = build_resilience_middlewares(flags)
memory_mw = build_memory_mw(
@ -118,38 +132,98 @@ def build_main_agent_deepagent_middleware(
"filesystem_mode": filesystem_mode,
"flags": flags,
}
shared_mw_start = time.perf_counter()
shared_subagent_middleware = build_subagent_middleware_stack(
resilience=resilience,
flags=flags,
)
shared_mw_elapsed = time.perf_counter() - shared_mw_start
kb_readonly = build_kb_readonly_subagent(
dependencies=subagent_dependencies,
model=llm,
middleware_stack=shared_subagent_middleware,
)
kb_readonly_spec = kb_readonly.spec
kb_readonly_runnable = create_agent(
llm,
system_prompt=kb_readonly_spec["system_prompt"],
tools=kb_readonly_spec["tools"],
middleware=kb_readonly_spec["middleware"],
name=KB_READONLY_NAME,
checkpointer=checkpointer,
)
ask_kb_tool = build_ask_knowledge_base_tool(kb_readonly_runnable)
def _compile_kb_readonly() -> Runnable:
"""Build *and* compile the read-only KB graph on first ``ask_knowledge_base`` use.
Both the spec build (``build_kb_readonly_subagent`` middleware +
tool-schema construction, ~the same cost as one regular subagent) and
the ``create_agent`` compile are deferred here (memoized by
``build_ask_knowledge_base_tool``) so neither is paid on the cold
agent-build / TTFT path; most first turns never call a subagent.
"""
build_start = time.perf_counter()
kb_readonly_spec = build_kb_readonly_subagent(
dependencies=subagent_dependencies,
model=llm,
middleware_stack=shared_subagent_middleware,
).spec
runnable = create_agent(
llm,
system_prompt=kb_readonly_spec["system_prompt"],
tools=kb_readonly_spec["tools"],
middleware=kb_readonly_spec["middleware"],
name=KB_READONLY_NAME,
checkpointer=checkpointer,
)
_perf_log.info(
"[subagent_compile_lazy] name=%s (spec+compile) in %.3fs",
KB_READONLY_NAME,
time.perf_counter() - build_start,
)
return runnable
ask_kb_tool = build_ask_knowledge_base_tool(_compile_kb_readonly)
def _build_kb_write_spec() -> dict[str, Any]:
"""Build the *write* knowledge_base subagent spec on first ``task`` use.
The KB filesystem middleware builds ~13 tool schemas at ~150ms each
(~2s total), all of which used to land on the cold agent-build / TTFT
path even though ``task("knowledge_base")`` is essentially never the
first thing a turn does. Deferring the whole spec build here (memoized
by the checkpointed subagent middleware) moves that cost to the first
actual KB-write delegation. Captures the same ``subagent_dependencies``
the eager build would have used, so cross-thread cache behaviour is
unchanged.
"""
spec = build_kb_write_subagent(
dependencies=subagent_dependencies,
model=llm,
middleware_stack=shared_subagent_middleware,
).spec
if disabled_tools:
disabled = frozenset(disabled_tools)
tools = spec.get("tools") # type: ignore[typeddict-item]
if isinstance(tools, list):
spec["tools"] = [ # type: ignore[typeddict-unknown-key]
t for t in tools if getattr(t, "name", None) not in disabled
]
return cast(dict[str, Any], spec)
subagents_start = time.perf_counter()
# The write knowledge_base subagent is excluded from the eager build and
# registered as a lazy descriptor (name + description cheap; spec built on
# first ``task("knowledge_base")`` use) — see ``_build_kb_write_spec``.
exclude_names = [*get_subagents_to_exclude(available_connectors), KB_WRITE_NAME]
subagents: list[SubAgent] = build_subagents(
dependencies=subagent_dependencies,
model=llm,
middleware_stack=shared_subagent_middleware,
mcp_tools_by_agent=mcp_tools_by_agent or {},
exclude=get_subagents_to_exclude(available_connectors),
exclude=exclude_names,
disabled_tools=disabled_tools,
ask_kb_tool=ask_kb_tool,
)
kb_write_descriptor = cast(
SubAgent,
{
"name": KB_WRITE_NAME,
"description": load_kb_write_description(),
SURF_LAZY_SPEC_FACTORY_KEY: _build_kb_write_spec,
},
)
subagents.append(kb_write_descriptor)
subagents_elapsed = time.perf_counter() - subagents_start
logging.debug("Subagents registry: %s", [s["name"] for s in subagents])
assembly_start = time.perf_counter()
stack: list[Any] = [
build_busy_mutex_mw(flags),
build_otel_mw(flags),
@ -170,6 +244,7 @@ def build_main_agent_deepagent_middleware(
available_connectors=available_connectors,
available_document_types=available_document_types,
mentioned_document_ids=mentioned_document_ids,
preinjection_enabled=flags.enable_kb_priority_preinjection,
),
build_kb_context_projection_mw(),
build_kb_persistence_mw(
@ -223,4 +298,17 @@ def build_main_agent_deepagent_middleware(
),
build_anthropic_cache_mw(),
]
return [m for m in stack if m is not None]
result = [m for m in stack if m is not None]
assembly_elapsed = time.perf_counter() - assembly_start
_perf_log.info(
"[stack_build] total=%.3fs shared_subagent_mw=%.3fs "
"build_subagents=%.3fs stack_assembly=%.3fs subagents=%d mw=%d "
"(kb_readonly deferred to first ask_knowledge_base)",
time.perf_counter() - stack_build_start,
shared_mw_elapsed,
subagents_elapsed,
assembly_elapsed,
len(subagents),
len(result),
)
return result

View file

@ -91,10 +91,18 @@ async def build_agent_with_cache(
# Every per-request value any middleware closes over at __init__ must be in
# the key, otherwise a hit will leak state across threads. Bump the schema
# version when the component list changes shape.
#
# Cross-thread reuse: when enabled, ``thread_id`` is dropped from the key so
# one compiled graph serves all of a user's (same space/config/visibility)
# chats. This is only safe because ActionLog, KB-persistence, and the
# deliverables tools now resolve the chat thread from the live
# RunnableConfig instead of a constructor closure; the schema tag is bumped
# so v2 (per-thread) entries are never confused with v3 (shared) ones.
cross_thread = flags.enable_cross_thread_agent_cache
cache_key = stable_hash(
"multi-agent-v2",
"multi-agent-v3" if cross_thread else "multi-agent-v2",
config_id,
thread_id,
None if cross_thread else thread_id,
user_id,
search_space_id,
visibility,

View file

@ -209,9 +209,6 @@ async def create_multi_agent_chat_deep_agent(
modified_disabled_tools = list(disabled_tools) if disabled_tools else []
if "search_knowledge_base" not in modified_disabled_tools:
modified_disabled_tools.append("search_knowledge_base")
if enabled_tools is not None:
main_agent_enabled_tools = [
n for n in enabled_tools if n in MAIN_AGENT_SURFSENSE_TOOL_NAMES

View file

@ -1,9 +1,17 @@
<knowledge_base_first>
CRITICAL — ground factual answers in what you actually receive this turn:
- the user's knowledge base via `search_knowledge_base` (your PRIMARY source
for anything about their documents, notes, or connected data — the
`<workspace_tree>` only lists what exists, so call the tool to read the
actual content before answering),
- injected workspace context (see `<dynamic_context>`),
- results from your own tool calls (`web_search`, `scrape_webpage`),
- results from your other tool calls (`web_search`, `scrape_webpage`),
- or substantive summaries returned by a `task` specialist you invoked.
For questions about the user's own workspace, call `search_knowledge_base`
first rather than answering from the tree or from memory. Use
`task(knowledge_base)` when you need a document's full text or deeper reads.
Do **not** answer factual or informational questions from general knowledge
unless the user explicitly authorises it after you say you couldn't find
enough in those sources. The flow when nothing is found:

View file

@ -0,0 +1,19 @@
- `search_knowledge_base` — Search the user's own knowledge base (their
indexed documents, notes, files, and connected sources) with hybrid
semantic + keyword retrieval.
- This is your PRIMARY way to ground factual answers about the user's
workspace. The `<workspace_tree>` shows what files exist; this tool pulls
the actual relevant content. Call it BEFORE answering any question about
the user's documents, notes, or connected data — don't answer from the
tree alone or from memory.
- Each hit returns the document's virtual path, a relevance score, and the
matched snippets. The snippets are often enough to answer directly with a
citation.
- When you need a document's full text (not just snippets), delegate a read
to the `knowledge_base` specialist via `task`, passing the path from the
results.
- Args: `query` (focused; include concrete entities, acronyms, people,
projects, or terms), `top_k` (default 5, max 20).
- If nothing relevant comes back, tell the user you couldn't find it in
their workspace before offering to search the web or answer from general
knowledge.

View file

@ -0,0 +1,13 @@
<example>
user: "What did our Q3 planning doc say about hiring?"
→ search_knowledge_base(query="Q3 planning hiring headcount plan")
(Answer from the returned snippets with a citation; if you need the full
document, task the knowledge_base specialist with the returned path.)
</example>
<example>
user: "Summarize my notes on the Acme migration."
→ search_knowledge_base(query="Acme migration notes")
→ task(subagent_type="knowledge_base", description="Read <path> and return a
detailed summary of the Acme migration plan, risks, and timeline.")
</example>

View file

@ -6,6 +6,7 @@ Connector integrations, MCP, deliverables, etc. are delegated via ``task`` subag
from __future__ import annotations
MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED: tuple[str, ...] = (
"search_knowledge_base",
"web_search",
"scrape_webpage",
"update_memory",

View file

@ -25,6 +25,7 @@ from app.agents.chat.shared.tools.web_search import create_web_search_tool
from app.db import ChatVisibility
from .scrape_webpage import create_scrape_webpage_tool
from .search_knowledge_base import create_search_knowledge_base_tool
from .update_memory import (
create_update_memory_tool,
create_update_team_memory_tool,
@ -35,6 +36,14 @@ def _build_scrape_webpage_tool(deps: dict[str, Any]) -> BaseTool:
return create_scrape_webpage_tool(firecrawl_api_key=deps.get("firecrawl_api_key"))
def _build_search_knowledge_base_tool(deps: dict[str, Any]) -> BaseTool:
return create_search_knowledge_base_tool(
search_space_id=deps["search_space_id"],
available_connectors=deps.get("available_connectors"),
available_document_types=deps.get("available_document_types"),
)
def _build_web_search_tool(deps: dict[str, Any]) -> BaseTool:
return create_web_search_tool(
search_space_id=deps.get("search_space_id"),
@ -75,6 +84,10 @@ def _build_update_memory_tool(deps: dict[str, Any]) -> BaseTool:
_MAIN_AGENT_TOOL_FACTORIES: dict[
str, tuple[Callable[[dict[str, Any]], BaseTool], tuple[str, ...]]
] = {
"search_knowledge_base": (
_build_search_knowledge_base_tool,
("search_space_id",),
),
"scrape_webpage": (_build_scrape_webpage_tool, ()),
"web_search": (_build_web_search_tool, ()),
"create_automation": (

View file

@ -0,0 +1,232 @@
"""On-demand ``search_knowledge_base`` main-agent tool (OpenCode-style lazy RAG).
The main agent no longer receives eagerly pre-injected KB context on every
turn (see :class:`KnowledgePriorityMiddleware`, now gated off by default).
Instead it calls this tool only when it decides it needs knowledge-base
content. The tool runs a single hybrid search (embed + DB search, ~0.5s),
formats the top matches for the model, and writes ``kb_matched_chunk_ids``
into graph state so matched-section highlighting is preserved when the agent
later reads a document via ``task(knowledge_base)``.
"""
from __future__ import annotations
import time
from typing import Annotated, Any
from langchain.tools import ToolRuntime
from langchain_core.messages import ToolMessage
from langchain_core.tools import BaseTool, StructuredTool
from langgraph.types import Command
from sqlalchemy import select
from app.agents.chat.multi_agent_chat.shared.middleware.knowledge_search import (
search_knowledge_base as _hybrid_search_kb,
)
from app.agents.chat.multi_agent_chat.shared.state.filesystem_state import (
SurfSenseFilesystemState,
)
from app.agents.chat.runtime.path_resolver import (
PathIndex,
build_path_index,
doc_to_virtual_path,
)
from app.db import Document, shielded_async_session
from app.utils.perf import get_perf_logger
_perf_log = get_perf_logger()
_DEFAULT_TOP_K = 5
_MAX_TOP_K = 20
_PER_DOC_SNIPPET_CHARS = 1200
_MAX_TOTAL_CHARS = 16_000
_TOOL_DESCRIPTION = (
"Search the user's knowledge base (their indexed documents, files, and "
"connector content) for passages relevant to a query, using hybrid "
"semantic + keyword retrieval.\n\n"
"Use this FIRST to ground any factual or informational answer about the "
"user's own documents, notes, or connected sources. The workspace tree "
"shows which files exist; this tool pulls the actual relevant content. "
"Each hit returns the document's virtual path, a relevance score, and the "
"matched snippets. If you need a document's full text, delegate a read to "
"the knowledge_base specialist via `task` using the returned path.\n\n"
"Write a focused, specific query containing the concrete entities, "
"acronyms, people, projects, or terms you are looking for."
)
async def _resolve_virtual_paths(
results: list[dict[str, Any]],
*,
search_space_id: int,
) -> dict[int, str]:
"""Resolve ``Document.id`` -> canonical virtual path for the search hits."""
doc_ids = [
doc_id
for doc_id in (
(doc.get("document") or {}).get("id")
for doc in results
if isinstance(doc, dict)
)
if isinstance(doc_id, int)
]
if not doc_ids:
return {}
async with shielded_async_session() as session:
index: PathIndex = await build_path_index(session, search_space_id)
folder_rows = await session.execute(
select(Document.id, Document.folder_id).where(
Document.search_space_id == search_space_id,
Document.id.in_(doc_ids),
)
)
folder_by_doc_id = {row.id: row.folder_id for row in folder_rows.all()}
paths: dict[int, str] = {}
for doc in results:
doc_meta = doc.get("document") or {}
doc_id = doc_meta.get("id")
if not isinstance(doc_id, int):
continue
folder_id = folder_by_doc_id.get(doc_id, doc_meta.get("folder_id"))
paths[doc_id] = doc_to_virtual_path(
doc_id=doc_id,
title=str(doc_meta.get("title") or "untitled"),
folder_id=folder_id if isinstance(folder_id, int) else None,
index=index,
)
return paths
def _format_hits(
results: list[dict[str, Any]],
*,
paths: dict[int, str],
query: str,
) -> str:
"""Render search hits as a compact, model-readable block."""
if not results:
return (
f"No knowledge-base matches found for query: {query!r}.\n"
"Tell the user nothing relevant was found in their workspace, or "
"try a different query."
)
lines: list[str] = [f"<knowledge_base_results query={query!r}>"]
total = len(lines[0])
for rank, doc in enumerate(results, start=1):
doc_meta = doc.get("document") or {}
doc_id = doc_meta.get("id")
title = str(doc_meta.get("title") or "untitled")
doc_type = doc_meta.get("document_type") or doc.get("source") or "document"
score = doc.get("score")
score_str = f"{score:.3f}" if isinstance(score, int | float) else "n/a"
path = paths.get(doc_id) if isinstance(doc_id, int) else None
header = f"\n{rank}. {title} (type={doc_type}, score={score_str})" + (
f"\n path: {path}" if path else ""
)
content = (doc.get("content") or "").strip()
if content:
snippet = content[:_PER_DOC_SNIPPET_CHARS].strip()
if len(content) > _PER_DOC_SNIPPET_CHARS:
snippet += " ..."
body = "\n " + snippet.replace("\n", "\n ")
else:
body = "\n (no preview available; read the document for details)"
entry = header + body
if total + len(entry) > _MAX_TOTAL_CHARS:
lines.append("\n<!-- additional matches truncated to fit context -->")
break
lines.append(entry)
total += len(entry)
lines.append(
"\n\nTo read a full document, delegate to the knowledge_base specialist "
"with `task`, referencing the path above."
)
lines.append("\n</knowledge_base_results>")
return "".join(lines)
def _matched_chunk_ids(results: list[dict[str, Any]]) -> dict[int, list[int]]:
"""Extract ``Document.id`` -> matched chunk ids for state hand-off."""
matched: dict[int, list[int]] = {}
for doc in results:
doc_id = (doc.get("document") or {}).get("id")
if not isinstance(doc_id, int):
continue
chunk_ids = doc.get("matched_chunk_ids") or []
normalized = [int(cid) for cid in chunk_ids if isinstance(cid, int | str)]
if normalized:
matched[doc_id] = normalized
return matched
def create_search_knowledge_base_tool(
*,
search_space_id: int,
available_connectors: list[str] | None = None,
available_document_types: list[str] | None = None,
) -> BaseTool:
"""Factory for the on-demand ``search_knowledge_base`` tool."""
_space_id = search_space_id
_connectors = available_connectors
_doc_types = available_document_types
async def _impl(
query: Annotated[
str,
"Focused search query with the concrete entities/terms to look for.",
],
runtime: ToolRuntime[None, SurfSenseFilesystemState],
top_k: Annotated[
int,
"Maximum number of documents to return (default 5).",
] = _DEFAULT_TOP_K,
) -> Command | str:
cleaned_query = (query or "").strip()
if not cleaned_query:
return "Error: provide a non-empty search query."
clamped_top_k = min(max(1, top_k), _MAX_TOP_K)
t0 = time.perf_counter()
results = await _hybrid_search_kb(
query=cleaned_query,
search_space_id=_space_id,
available_connectors=_connectors,
available_document_types=_doc_types,
top_k=clamped_top_k,
)
paths = await _resolve_virtual_paths(results, search_space_id=_space_id)
rendered = _format_hits(results, paths=paths, query=cleaned_query)
matched = _matched_chunk_ids(results)
_perf_log.info(
"[search_knowledge_base] tool query=%r results=%d chars=%d in %.3fs",
cleaned_query[:60],
len(results),
len(rendered),
time.perf_counter() - t0,
)
update: dict[str, Any] = {
"messages": [
ToolMessage(content=rendered, tool_call_id=runtime.tool_call_id)
],
}
if matched:
update["kb_matched_chunk_ids"] = matched
return Command(update=update)
return StructuredTool.from_function(
name="search_knowledge_base",
description=_TOOL_DESCRIPTION,
coroutine=_impl,
)

View file

@ -55,6 +55,13 @@ class AgentFeatureFlags:
enable_specialized_subagents: bool = True
enable_kb_planner_runnable: bool = True
# KB retrieval mode — when False (default), the main agent retrieves KB
# content lazily via the on-demand ``search_knowledge_base`` tool and the
# expensive per-turn pre-injection (planner LLM + embed + hybrid search,
# ~2.3s) is skipped; explicit @-mentions are still surfaced cheaply. Set
# True to restore the original eager ``<priority_documents>`` pre-injection.
enable_kb_priority_preinjection: bool = False
# Snapshot / revert
enable_action_log: bool = True
enable_revert_route: bool = True
@ -71,6 +78,14 @@ class AgentFeatureFlags:
# is read from runtime.context, not the constructor closure. Rollback via
# SURFSENSE_ENABLE_AGENT_CACHE=false.
enable_agent_cache: bool = True
# Reuse one compiled graph across a returning user's *new* chats by dropping
# ``thread_id`` from the agent_cache key. Safe because every middleware/tool
# that needs the chat thread now resolves it from the live RunnableConfig
# (ActionLog, KB-persistence, deliverables) rather than a constructor
# closure, and mutation tools open fresh per-call sessions. Turns a
# returning user's cold first turn into a cache hit (cold == warm).
# Rollback via SURFSENSE_ENABLE_CROSS_THREAD_AGENT_CACHE=false.
enable_cross_thread_agent_cache: bool = True
# Deferred: only helps on outer-cache MISSES, so off until data shows cold
# misses are frequent enough to justify the extra global state.
enable_agent_cache_share_gp_subagent: bool = False
@ -104,11 +119,14 @@ class AgentFeatureFlags:
enable_skills=False,
enable_specialized_subagents=False,
enable_kb_planner_runnable=False,
# Full rollback restores the original eager KB pre-injection.
enable_kb_priority_preinjection=True,
enable_action_log=False,
enable_revert_route=False,
enable_plugin_loader=False,
enable_otel=False,
enable_agent_cache=False,
enable_cross_thread_agent_cache=False,
enable_agent_cache_share_gp_subagent=False,
)
@ -141,6 +159,9 @@ class AgentFeatureFlags:
enable_kb_planner_runnable=_env_bool(
"SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", True
),
enable_kb_priority_preinjection=_env_bool(
"SURFSENSE_ENABLE_KB_PRIORITY_PREINJECTION", False
),
# Snapshot / revert
enable_action_log=_env_bool("SURFSENSE_ENABLE_ACTION_LOG", True),
enable_revert_route=_env_bool("SURFSENSE_ENABLE_REVERT_ROUTE", True),
@ -150,6 +171,9 @@ class AgentFeatureFlags:
enable_otel=_env_bool("SURFSENSE_ENABLE_OTEL", False),
# Performance
enable_agent_cache=_env_bool("SURFSENSE_ENABLE_AGENT_CACHE", True),
enable_cross_thread_agent_cache=_env_bool(
"SURFSENSE_ENABLE_CROSS_THREAD_AGENT_CACHE", True
),
enable_agent_cache_share_gp_subagent=_env_bool(
"SURFSENSE_ENABLE_AGENT_CACHE_SHARE_GP_SUBAGENT", False
),

View file

@ -2,6 +2,7 @@
from __future__ import annotations
import time as _perf_time
from typing import Any
from deepagents import FilesystemMiddleware
@ -14,6 +15,7 @@ from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.sandbox impor
from app.agents.chat.multi_agent_chat.shared.state.filesystem_state import (
SurfSenseFilesystemState,
)
from app.utils.perf import get_perf_logger
from ..system_prompt import build_system_prompt
from ..tools import (
@ -34,6 +36,8 @@ from ..tools.glob.description import select_description as glob_description
from ..tools.grep.description import select_description as grep_description
from .read_only_policy import READ_ONLY_TOOL_NAMES
_perf_log = get_perf_logger()
class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
"""SurfSense-specific filesystem middleware (cloud + desktop)."""
@ -60,16 +64,22 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
is_sandbox_enabled() and thread_id is not None and not read_only
)
_t0 = _perf_time.perf_counter()
system_prompt = build_system_prompt(
filesystem_mode,
sandbox_available=self._sandbox_available,
)
_t_prompt = _perf_time.perf_counter() - _t0
_t0 = _perf_time.perf_counter()
super().__init__(
backend=backend,
system_prompt=system_prompt,
tool_token_limit_before_evict=tool_token_limit_before_evict,
)
_t_super = _perf_time.perf_counter() - _t0
_t0 = _perf_time.perf_counter()
self.tools = [t for t in self.tools if t.name != "execute"]
self.tools.append(create_mkdir_tool(self))
self.tools.append(create_cd_tool(self))
@ -83,6 +93,15 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
if read_only:
self.tools = [t for t in self.tools if t.name in READ_ONLY_TOOL_NAMES]
_t_tools = _perf_time.perf_counter() - _t0
_perf_log.info(
"[fs_middleware_init] ro=%s system_prompt=%.3fs super_init=%.3fs "
"surf_tools=%.3fs",
read_only,
_t_prompt,
_t_super,
_t_tools,
)
# ----------------------------------------- base-class tool overrides

View file

@ -624,6 +624,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
top_k: int = 10,
mentioned_document_ids: list[int] | None = None,
inject_system_message: bool = True, # For backwards compatibility
mentions_only: bool = False,
) -> None:
self.llm = llm
# Cheap model for structured internal tasks (query rewrite, date
@ -637,6 +638,10 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
self.top_k = top_k
self.mentioned_document_ids = mentioned_document_ids or []
self.inject_system_message = inject_system_message
# Lazy mode: skip the planner LLM + embedding + hybrid search and only
# surface explicit @-mentions. The agent retrieves topical KB content on
# demand via the ``search_knowledge_base`` tool instead.
self.mentions_only = mentions_only
# Compiled lazily and memoized to avoid the per-turn create_agent cost.
self._planner: Runnable | None = None
self._planner_compile_failed = False
@ -825,15 +830,6 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
runtime: Runtime[Any] | None = None,
) -> dict[str, Any]:
t0 = asyncio.get_event_loop().time()
(
planned_query,
start_date,
end_date,
is_recency,
) = await self._plan_search_inputs(
messages=messages,
user_text=user_text,
)
# Prefer per-turn mentions from runtime.context (lets a cached graph
# serve different turns); fall back to the constructor closure, draining
@ -864,6 +860,52 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
if ctx_folders:
folder_mention_ids = list(ctx_folders)
# Lazy mode: skip the planner LLM + embedding + hybrid search entirely.
# With no explicit mentions there is nothing cheap to surface, so we bail
# out early and let the agent decide to call ``search_knowledge_base``.
if self.mentions_only:
if not mention_ids and not folder_mention_ids:
return None
planned_query = user_text
start_date = end_date = None
is_recency = False
search_results: list[dict[str, Any]] = []
_search_phase_elapsed = 0.0
else:
(
planned_query,
start_date,
end_date,
is_recency,
) = await self._plan_search_inputs(
messages=messages,
user_text=user_text,
)
_t_search_phase = time.perf_counter()
if is_recency:
doc_types = _resolve_search_types(
self.available_connectors, self.available_document_types
)
search_results = await browse_recent_documents(
search_space_id=self.search_space_id,
document_type=doc_types,
top_k=self.top_k,
start_date=start_date,
end_date=end_date,
)
else:
search_results = await search_knowledge_base(
query=planned_query,
search_space_id=self.search_space_id,
available_connectors=self.available_connectors,
available_document_types=self.available_document_types,
top_k=self.top_k,
start_date=start_date,
end_date=end_date,
)
_search_phase_elapsed = time.perf_counter() - _t_search_phase
mentioned_results: list[dict[str, Any]] = []
if mention_ids:
mentioned_results = await fetch_mentioned_documents(
@ -871,30 +913,6 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
search_space_id=self.search_space_id,
)
_t_search_phase = time.perf_counter()
if is_recency:
doc_types = _resolve_search_types(
self.available_connectors, self.available_document_types
)
search_results = await browse_recent_documents(
search_space_id=self.search_space_id,
document_type=doc_types,
top_k=self.top_k,
start_date=start_date,
end_date=end_date,
)
else:
search_results = await search_knowledge_base(
query=planned_query,
search_space_id=self.search_space_id,
available_connectors=self.available_connectors,
available_document_types=self.available_document_types,
top_k=self.top_k,
start_date=start_date,
end_date=end_date,
)
_search_phase_elapsed = time.perf_counter() - _t_search_phase
seen_doc_ids: set[int] = set()
merged: list[dict[str, Any]] = []
for doc in mentioned_results:

View file

@ -60,6 +60,10 @@ TOOL_CATALOG: list[ToolMetadata] = [
name="generate_image",
description="Generate images from text descriptions using AI image models",
),
ToolMetadata(
name="search_knowledge_base",
description="Search the user's knowledge base with hybrid semantic + keyword retrieval",
),
ToolMetadata(
name="scrape_webpage",
description="Scrape and extract the main content from a webpage",

View file

@ -21,6 +21,9 @@ from app.agents.chat.multi_agent_chat.shared.receipts.receipt import make_receip
from app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.deliverable_wait import (
wait_for_deliverable,
)
from app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.tools.thread_resolver import (
resolve_root_thread_id,
)
from app.db import Podcast, PodcastStatus, shielded_async_session
logger = logging.getLogger(__name__)
@ -71,7 +74,7 @@ def create_generate_podcast_tool(
title=podcast_title,
status=PodcastStatus.PENDING,
search_space_id=search_space_id,
thread_id=thread_id,
thread_id=resolve_root_thread_id(runtime, thread_id),
)
session.add(podcast)
await session.commit()

View file

@ -14,6 +14,9 @@ from langgraph.types import Command
from app.agents.chat.multi_agent_chat.shared.receipts.command import with_receipt
from app.agents.chat.multi_agent_chat.shared.receipts.receipt import make_receipt
from app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.tools.thread_resolver import (
resolve_root_thread_id,
)
from app.db import Report, shielded_async_session
from app.services.connector_service import ConnectorService
from app.services.llm_service import get_agent_llm
@ -687,7 +690,7 @@ def create_generate_report_tool(
},
report_style=report_style,
search_space_id=search_space_id,
thread_id=thread_id,
thread_id=resolve_root_thread_id(runtime, thread_id),
report_group_id=report_group_id,
)
session.add(failed_report)
@ -991,7 +994,7 @@ def create_generate_report_tool(
report_metadata=metadata,
report_style=report_style,
search_space_id=search_space_id,
thread_id=thread_id,
thread_id=resolve_root_thread_id(runtime, thread_id),
report_group_id=report_group_id,
)
write_session.add(report)

View file

@ -16,6 +16,9 @@ from langgraph.types import Command
from app.agents.chat.multi_agent_chat.shared.receipts.command import with_receipt
from app.agents.chat.multi_agent_chat.shared.receipts.receipt import make_receipt
from app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.tools.thread_resolver import (
resolve_root_thread_id,
)
from app.db import Report, shielded_async_session
from app.services.llm_service import get_agent_llm
@ -529,7 +532,7 @@ def create_generate_resume_tool(
},
report_style="resume",
search_space_id=search_space_id,
thread_id=thread_id,
thread_id=resolve_root_thread_id(runtime, thread_id),
report_group_id=report_group_id,
)
session.add(failed)
@ -817,7 +820,7 @@ def create_generate_resume_tool(
report_metadata=metadata,
report_style="resume",
search_space_id=search_space_id,
thread_id=thread_id,
thread_id=resolve_root_thread_id(runtime, thread_id),
report_group_id=report_group_id,
)
write_session.add(report)

View file

@ -0,0 +1,39 @@
"""Resolve the root chat ``thread_id`` from a deliverables tool's runtime.
Deliverables tools run inside the ``deliverables`` subagent, which is invoked
with a *namespaced* ``thread_id`` of the form ``{chat_id}::task:{tool_call_id}``
(see :func:`subagent_invoke_config`). To attribute a generated deliverable
(podcast / report / resume / video) to the correct chat, we parse the leading
segment of that namespaced id rather than trusting a ``thread_id`` captured at
tool-build time the latter would be stale once a single compiled agent graph
is reused across chats (cross-thread ``agent_cache`` reuse).
"""
from __future__ import annotations
from langchain.tools import ToolRuntime
def resolve_root_thread_id(runtime: ToolRuntime, fallback: int | None) -> int | None:
"""Return the root chat id from the live runtime config, else ``fallback``.
The subagent's ``configurable.thread_id`` looks like ``"2099::task:call_x"``;
the chat id is the segment before the first ``"::"``. Returns ``fallback``
when the config is absent or the leading segment is not an integer.
"""
try:
config = getattr(runtime, "config", None)
if not isinstance(config, dict):
return fallback
value = (config.get("configurable") or {}).get("thread_id")
if isinstance(value, int):
return value
if isinstance(value, str) and value:
root = value.split("::", 1)[0]
try:
return int(root)
except (TypeError, ValueError):
return fallback
except Exception: # pragma: no cover - defensive
return fallback
return fallback

View file

@ -22,6 +22,9 @@ from app.agents.chat.multi_agent_chat.shared.receipts.receipt import make_receip
from app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.deliverable_wait import (
wait_for_deliverable,
)
from app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.tools.thread_resolver import (
resolve_root_thread_id,
)
from app.db import VideoPresentation, VideoPresentationStatus, shielded_async_session
logger = logging.getLogger(__name__)
@ -58,7 +61,7 @@ def create_generate_video_presentation_tool(
title=video_title,
status=VideoPresentationStatus.PENDING,
search_space_id=search_space_id,
thread_id=thread_id,
thread_id=resolve_root_thread_id(runtime, thread_id),
)
session.add(video_pres)
await session.commit()

View file

@ -2,6 +2,7 @@
from __future__ import annotations
from collections.abc import Callable
from typing import Annotated
from langchain.tools import BaseTool, ToolRuntime
@ -39,7 +40,28 @@ def _wrap_result(result: dict, tool_call_id: str) -> Command:
)
def build_ask_knowledge_base_tool(kb_readonly_runnable: Runnable) -> BaseTool:
def build_ask_knowledge_base_tool(
kb_readonly: Runnable | Callable[[], Runnable],
) -> BaseTool:
"""Build the ``ask_knowledge_base`` tool backed by the read-only KB graph.
``kb_readonly`` may be a pre-compiled ``Runnable`` or a zero-arg factory
that compiles it on first use. Passing a factory defers the ~0.3-0.8s
``create_agent`` cost of the read-only knowledge_base graph until a subagent
actually calls ``ask_knowledge_base``, keeping it off the cold agent-build
(time-to-first-token) path. The factory result is memoized.
"""
_cache: dict[str, Runnable] = {}
def _resolve() -> Runnable:
if not callable(kb_readonly) or isinstance(kb_readonly, Runnable):
return kb_readonly # type: ignore[return-value]
cached = _cache.get("runnable")
if cached is None:
cached = kb_readonly()
_cache["runnable"] = cached
return cached
def ask_knowledge_base(
query: Annotated[
str,
@ -52,7 +74,7 @@ def build_ask_knowledge_base_tool(kb_readonly_runnable: Runnable) -> BaseTool:
raise ValueError("Tool call ID is required for ask_knowledge_base")
sub_state = _forward_state(runtime, query)
sub_config = subagent_invoke_config(runtime)
result = kb_readonly_runnable.invoke(sub_state, config=sub_config)
result = _resolve().invoke(sub_state, config=sub_config)
return _wrap_result(result, runtime.tool_call_id)
async def aask_knowledge_base(
@ -67,7 +89,7 @@ def build_ask_knowledge_base_tool(kb_readonly_runnable: Runnable) -> BaseTool:
raise ValueError("Tool call ID is required for ask_knowledge_base")
sub_state = _forward_state(runtime, query)
sub_config = subagent_invoke_config(runtime)
result = await kb_readonly_runnable.ainvoke(sub_state, config=sub_config)
result = await _resolve().ainvoke(sub_state, config=sub_config)
return _wrap_result(result, runtime.tool_call_id)
return StructuredTool.from_function(

View file

@ -6,6 +6,7 @@ The KB-owned :class:`PermissionMiddleware` slot is what enforces
from __future__ import annotations
import time as _perf_time
from typing import Any
from langchain_core.language_models import BaseChatModel
@ -31,6 +32,9 @@ from app.agents.chat.multi_agent_chat.shared.permissions import (
Ruleset,
build_permission_mw,
)
from app.utils.perf import get_perf_logger
_perf_log = get_perf_logger()
def _kb_user_allowlist(
@ -93,25 +97,62 @@ def build_kb_middleware(
user_allowlist = _kb_user_allowlist(dependencies, subagent_name)
if user_allowlist is not None:
rulesets.append(user_allowlist)
_t0 = _perf_time.perf_counter()
permission_mw = build_permission_mw(
flags=flags,
subagent_rulesets=rulesets,
trusted_tool_saver=dependencies.get("trusted_tool_saver"),
)
_t_perm = _perf_time.perf_counter() - _t0
else:
_t_perm = 0.0
_t0 = _perf_time.perf_counter()
kb_ctx_mw = build_kb_context_projection_mw()
_t_ctx = _perf_time.perf_counter() - _t0
_t0 = _perf_time.perf_counter()
fs_mw = build_filesystem_mw(
backend_resolver=dependencies["backend_resolver"],
filesystem_mode=filesystem_mode,
search_space_id=dependencies["search_space_id"],
user_id=dependencies.get("user_id"),
thread_id=dependencies.get("thread_id"),
read_only=read_only,
)
_t_fs = _perf_time.perf_counter() - _t0
_t0 = _perf_time.perf_counter()
compaction_mw = build_compaction_mw(llm)
_t_comp = _perf_time.perf_counter() - _t0
_t0 = _perf_time.perf_counter()
patch_mw = build_patch_tool_calls_mw()
_t_patch = _perf_time.perf_counter() - _t0
_t0 = _perf_time.perf_counter()
cache_mw = build_anthropic_cache_mw()
_t_cache = _perf_time.perf_counter() - _t0
_perf_log.info(
"[kb_middleware] name=%s ro=%s ctx=%.3fs filesystem=%.3fs "
"compaction=%.3fs patch=%.3fs anthropic_cache=%.3fs permission=%.3fs",
subagent_name,
read_only,
_t_ctx,
_t_fs,
_t_comp,
_t_patch,
_t_cache,
_t_perm,
)
return [
mws["todos"],
build_kb_context_projection_mw(),
build_filesystem_mw(
backend_resolver=dependencies["backend_resolver"],
filesystem_mode=filesystem_mode,
search_space_id=dependencies["search_space_id"],
user_id=dependencies.get("user_id"),
thread_id=dependencies.get("thread_id"),
read_only=read_only,
),
build_compaction_mw(llm),
build_patch_tool_calls_mw(),
kb_ctx_mw,
fs_mw,
compaction_mw,
patch_mw,
*([permission_mw] if permission_mw is not None else []),
*resilience_mws,
build_anthropic_cache_mw(),
cache_mw,
]

View file

@ -2,6 +2,7 @@
from __future__ import annotations
import time as _perf_time
from typing import Any, Protocol
from deepagents import SubAgent
@ -72,6 +73,9 @@ from app.agents.chat.multi_agent_chat.subagents.shared.md_file_reader import (
read_md_file,
)
from app.agents.chat.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec
from app.utils.perf import get_perf_logger
_perf_log = get_perf_logger()
class SubagentBuilder(Protocol):
@ -192,19 +196,25 @@ def build_subagents(
if exclude:
excluded.extend(exclude)
disabled_names = frozenset(disabled_tools or ())
_timings: list[tuple[str, float]] = []
for name in sorted(SUBAGENT_BUILDERS_BY_NAME):
if name in excluded:
continue
builder = SUBAGENT_BUILDERS_BY_NAME[name]
_t0 = _perf_time.perf_counter()
result = builder(
dependencies=dependencies,
model=model,
middleware_stack=middleware_stack,
mcp_tools=mcp.get(name),
)
_timings.append((name, _perf_time.perf_counter() - _t0))
spec = result.spec
_filter_disabled_tools_in_place(spec, disabled_names)
if ask_kb_tool is not None:
_inject_ask_kb_tool_in_place(spec, ask_kb_tool)
specs.append(spec)
if _timings:
_detail = " ".join(f"{n}={dt:.3f}s" for n, dt in _timings)
_perf_log.info("[build_subagents.detail] %s", _detail)
return specs

View file

@ -26,6 +26,16 @@ ContextHintProvider = Callable[[Mapping[str, Any], str], str | None]
# The prefix avoids any collision with future deepagents fields.
SURF_CONTEXT_HINT_PROVIDER_KEY = "surf_context_hint_provider"
# Custom key carrying a zero-arg callable that builds the full deepagents
# ``SubAgent`` spec dict on demand. A descriptor dict carrying only
# ``name`` / ``description`` / this key lets the checkpointed subagent
# middleware register a subagent's catalog entry cheaply while deferring the
# expensive spec construction (e.g. the knowledge_base filesystem middleware,
# which builds ~13 tool schemas at ~150ms each) until the first
# ``task(name)`` call. Most turns never invoke a subagent, so this keeps the
# cost off the cold agent-build / time-to-first-token path.
SURF_LAZY_SPEC_FACTORY_KEY = "surf_lazy_spec_factory"
@dataclass(frozen=True, slots=True)
class SurfSenseSubagentSpec:
@ -54,6 +64,7 @@ class SurfSenseSubagentSpec:
__all__ = [
"SURF_CONTEXT_HINT_PROVIDER_KEY",
"SURF_LAZY_SPEC_FACTORY_KEY",
"ContextHintProvider",
"SurfSenseSubagentSpec",
]

View file

@ -4,6 +4,7 @@ from __future__ import annotations
import logging
import re
import time as _perf_time
from typing import Any, cast
from deepagents import SubAgent
@ -23,8 +24,10 @@ from app.agents.chat.multi_agent_chat.subagents.shared.spec import (
ContextHintProvider,
SurfSenseSubagentSpec,
)
from app.utils.perf import get_perf_logger
logger = logging.getLogger(__name__)
_perf_log = get_perf_logger()
# ``<include snippet="NAME"/>`` directive. Matches an XML-style self-closing
# tag whose ``snippet`` attribute names a file in ``shared/snippets/``.
@ -110,19 +113,31 @@ def pack_subagent(
msg = f"Subagent {name!r}: system_prompt is empty"
raise ValueError(msg)
_t0 = _perf_time.perf_counter()
system_prompt = _resolve_includes(system_prompt, subagent_name=name)
_t_resolve = _perf_time.perf_counter() - _t0
flags = dependencies["flags"]
user_allowlist = _user_allowlist_for(dependencies, name)
subagent_rulesets: list[Ruleset] = [ruleset]
if user_allowlist is not None:
subagent_rulesets.append(user_allowlist)
_t0 = _perf_time.perf_counter()
per_subagent_perm = build_permission_mw(
flags=flags,
subagent_rulesets=subagent_rulesets,
tools=tools,
trusted_tool_saver=dependencies.get("trusted_tool_saver"),
)
_t_perm = _perf_time.perf_counter() - _t0
_perf_log.info(
"[pack_subagent] name=%s tools=%d resolve_includes=%.3fs "
"build_permission_mw=%.3fs",
name,
len(tools),
_t_resolve,
_t_perm,
)
prepended: list[Any] = []
for slot, mw in (middleware_stack or {}).items():

View file

@ -571,6 +571,41 @@ async def _warm_agent_jit_caches() -> None:
)
async def _warm_embedding_model() -> None:
"""Pre-load/JIT the embedding model so the first KB search is fast.
With lazy KB retrieval (OpenCode-style), the main agent no longer embeds
on every turn it calls the on-demand ``search_knowledge_base`` tool only
when it needs KB content, and that tool's first ``embed_texts`` call in a
fresh process pays the model's one-time load/JIT (local sentence-transformer
warm or API client init). Doing one throwaway embed at startup moves that
cost off the first real search.
Safety: behind the embedding global lock (run in a worker thread), bounded
by the caller's ``asyncio.wait_for``, and non-fatal — on any failure we log
and swallow so the worst case is the first real search pays the cold cost.
"""
import time as _time
logger = logging.getLogger(__name__)
t0 = _time.perf_counter()
try:
from app.utils.document_converters import embed_texts
await asyncio.to_thread(embed_texts, ["warmup"])
logger.info(
"[startup] Embedding model warmup completed in %.3fs",
_time.perf_counter() - t0,
)
except Exception:
logger.warning(
"[startup] Embedding model warmup failed in %.3fs (non-fatal — first "
"KB search will pay the cold embed cost)",
_time.perf_counter() - t0,
exc_info=True,
)
@asynccontextmanager
async def lifespan(app: FastAPI):
# Tune GC: lower gen-2 threshold so long-lived garbage is collected
@ -601,6 +636,16 @@ async def lifespan(app: FastAPI):
"first real request will pay the full compile cost."
)
# Phase 2 — embedding warmup so the first lazy ``search_knowledge_base``
# call doesn't pay the cold embed-model load. Bounded + non-fatal.
try:
await asyncio.wait_for(asyncio.shield(_warm_embedding_model()), timeout=20)
except (TimeoutError, Exception): # pragma: no cover - defensive
logging.getLogger(__name__).warning(
"[startup] Embedding warmup hit timeout/error — skipping; "
"first KB search will pay the cold embed cost."
)
register_session_hooks()
log_system_snapshot("startup_complete")
await start_gateway_inbox_worker()

View file

@ -12,7 +12,7 @@ import { schema } from "@/zero/schema";
// container and would make every authenticated Zero query fail with a 503.
const backendURL = (
process.env.FASTAPI_BACKEND_INTERNAL_URL ||
BACKEND_URL ||
process.env.BACKEND_URL ||
"http://localhost:8000"
).replace(/\/$/, "");