feat: implement agent caches and fix invalid prompt cache configs
Some checks are pending
Build and Push Docker Images / tag_release (push) Waiting to run
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (backend, surfsense-backend) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (web, surfsense-web) (push) Blocked by required conditions

- Added a new function `_warm_agent_jit_caches` to pre-warm agent caches at startup, reducing cold invocation costs.
- Updated the `SurfSenseContextSchema` to include per-invocation fields for better state management during agent execution.
- Introduced compiled-agent caching while ensuring tools open a fresh database session per invocation, improving performance and reliability.
- Enhanced middleware to support new context features and improve error handling during connector and document type discovery.
This commit is contained in:
DESKTOP-RTLN3BA\$punk 2026-05-03 06:03:40 -07:00
parent 90a653c8c7
commit a34f1fb25c
60 changed files with 8477 additions and 5381 deletions

View file

@ -31,6 +31,7 @@ from sqlalchemy.orm import selectinload
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
from app.agents.new_chat.checkpointer import get_checkpointer
from app.agents.new_chat.context import SurfSenseContextSchema
from app.agents.new_chat.errors import BusyError
from app.agents.new_chat.feature_flags import get_flags
from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
@ -559,6 +560,29 @@ async def _preflight_llm(llm: Any) -> None:
)
async def _settle_speculative_agent_build(task: asyncio.Task[Any]) -> None:
"""Wait for a discarded speculative agent build to release shared state.
Used by the parallel preflight + agent-build path. The speculative build
closes over the request-scoped ``AsyncSession`` (for the brief connector
discovery / tool-factory window before its CPU work moves into a worker
thread). If preflight reports a 429 we want to fall back to the original
repin reload rebuild path, but we MUST NOT touch ``session`` again
until any in-flight session work owned by the speculative build has
fully settled :class:`sqlalchemy.ext.asyncio.AsyncSession` is not
concurrency-safe and the same hazard cost us a hard ``InvalidRequestError``
earlier in this PR (see ``connector_service`` parallel-gather revert).
We simply ``await`` the task and swallow any exception: in this path the
build's outcome is irrelevant — success populates the agent cache (a free
side effect), failure is discarded. The wasted CPU is acceptable since
429 fallbacks are rare and the original sequential code also paid the
full build cost on the same path.
"""
with contextlib.suppress(BaseException):
await task
def _classify_stream_exception(
exc: Exception,
*,
@ -696,6 +720,7 @@ async def _stream_agent_events(
fallback_commit_created_by_id: str | None = None,
fallback_commit_filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
fallback_commit_thread_id: int | None = None,
runtime_context: Any = None,
) -> AsyncGenerator[str, None]:
"""Shared async generator that streams and formats astream_events from the agent.
@ -801,7 +826,18 @@ async def _stream_agent_events(
return event
return None
async for event in agent.astream_events(input_data, config=config, version="v2"):
# Per-invocation runtime context (Phase 1.5). When supplied,
# ``KnowledgePriorityMiddleware`` reads ``mentioned_document_ids``
# from ``runtime.context`` instead of its constructor closure — the
# prerequisite that lets the compiled-agent cache (Phase 1) reuse a
# single graph across turns. ``astream_kwargs`` carries no ``context``
# entry when callers leave ``runtime_context`` as ``None``, preserving
# the legacy code path bit-for-bit.
astream_kwargs: dict[str, Any] = {"config": config, "version": "v2"}
if runtime_context is not None:
astream_kwargs["context"] = runtime_context
async for event in agent.astream_events(input_data, **astream_kwargs):
event_type = event.get("event", "")
if event_type == "on_chat_model_stream":
@ -2560,23 +2596,102 @@ async def stream_new_chat(
# Detecting a 429 here lets us repin BEFORE the planner/classifier/
# title-generation LLM calls fan out and each independently hit the
# same upstream rate limit.
if (
#
# PERF: preflight is a network round-trip to the LLM provider (~1-5s)
# and is independent of the agent build (CPU-bound, ~5-7s). They used
# to run sequentially → ``preflight + build`` on cold cache = 11.5s.
# We now kick off preflight as a background task FIRST, then run the
# synchronous setup work and the agent build in parallel. In the
# success path (the common case) total wall time drops to roughly
# ``max(preflight, build)`` — the preflight finishes during the
# agent compile and we just consume its result. In the rare 429
# path the speculative build is awaited to completion (so its
# session usage is fully released) via
# :func:`_settle_speculative_agent_build`, then discarded, and
# we fall back to the original repin-and-rebuild flow.
preflight_needed = (
requested_llm_config_id == 0
and llm_config_id < 0
and not is_recently_healthy(llm_config_id)
):
)
preflight_task: asyncio.Task[None] | None = None
_t_preflight = 0.0
if preflight_needed:
_t_preflight = time.perf_counter()
preflight_task = asyncio.create_task(
_preflight_llm(llm),
name=f"auto_pin_preflight:{llm_config_id}",
)
# Create connector service
_t0 = time.perf_counter()
connector_service = ConnectorService(session, search_space_id=search_space_id)
firecrawl_api_key = None
webcrawler_connector = await connector_service.get_connector_by_type(
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id
)
if webcrawler_connector and webcrawler_connector.config:
firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY")
_perf_log.info(
"[stream_new_chat] Connector service + firecrawl key in %.3fs",
time.perf_counter() - _t0,
)
# Get the PostgreSQL checkpointer for persistent conversation memory
_t0 = time.perf_counter()
checkpointer = await get_checkpointer()
_perf_log.info(
"[stream_new_chat] Checkpointer ready in %.3fs", time.perf_counter() - _t0
)
visibility = thread_visibility or ChatVisibility.PRIVATE
_t0 = time.perf_counter()
# Speculative agent build — runs in parallel with the preflight
# task (if any). Built with the *current* ``llm`` / ``agent_config``;
# if preflight reports 429 we will discard this future and rebuild
# against the freshly pinned config below.
agent_build_task = asyncio.create_task(
create_surfsense_deep_agent(
llm=llm,
search_space_id=search_space_id,
db_session=session,
connector_service=connector_service,
checkpointer=checkpointer,
user_id=user_id,
thread_id=chat_id,
agent_config=agent_config,
firecrawl_api_key=firecrawl_api_key,
thread_visibility=visibility,
disabled_tools=disabled_tools,
mentioned_document_ids=mentioned_document_ids,
filesystem_selection=filesystem_selection,
),
name="agent_build:stream_new_chat",
)
agent: Any = None
if preflight_task is not None:
try:
await _preflight_llm(llm)
await preflight_task
mark_healthy(llm_config_id)
_perf_log.info(
"[stream_new_chat] auto_pin_preflight ok config_id=%s took=%.3fs",
"[stream_new_chat] auto_pin_preflight ok config_id=%s took=%.3fs (parallel)",
llm_config_id,
time.perf_counter() - _t_preflight,
)
except Exception as preflight_exc:
# Both branches below need the session: the non-429 path
# may unwind via cleanup that uses ``session``, and the
# 429 path explicitly calls ``resolve_or_get_pinned_llm_config_id``
# against it. Wait for the speculative build to release its
# session usage before we proceed.
await _settle_speculative_agent_build(agent_build_task)
if not _is_provider_rate_limited(preflight_exc):
raise
# 429: speculative agent is discarded; run the original
# repin → reload → rebuild path against the freshly
# pinned config.
previous_config_id = llm_config_id
mark_runtime_cooldown(
previous_config_id, reason="preflight_rate_limited"
@ -2639,46 +2754,28 @@ async def stream_new_chat(
"fallback_config_id": llm_config_id,
},
)
# Rebuild against the new llm/agent_config. Sequential
# here because we no longer have anything to overlap with.
agent = await create_surfsense_deep_agent(
llm=llm,
search_space_id=search_space_id,
db_session=session,
connector_service=connector_service,
checkpointer=checkpointer,
user_id=user_id,
thread_id=chat_id,
agent_config=agent_config,
firecrawl_api_key=firecrawl_api_key,
thread_visibility=visibility,
disabled_tools=disabled_tools,
mentioned_document_ids=mentioned_document_ids,
filesystem_selection=filesystem_selection,
)
# Create connector service
_t0 = time.perf_counter()
connector_service = ConnectorService(session, search_space_id=search_space_id)
firecrawl_api_key = None
webcrawler_connector = await connector_service.get_connector_by_type(
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id
)
if webcrawler_connector and webcrawler_connector.config:
firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY")
_perf_log.info(
"[stream_new_chat] Connector service + firecrawl key in %.3fs",
time.perf_counter() - _t0,
)
# Get the PostgreSQL checkpointer for persistent conversation memory
_t0 = time.perf_counter()
checkpointer = await get_checkpointer()
_perf_log.info(
"[stream_new_chat] Checkpointer ready in %.3fs", time.perf_counter() - _t0
)
visibility = thread_visibility or ChatVisibility.PRIVATE
_t0 = time.perf_counter()
agent = await create_surfsense_deep_agent(
llm=llm,
search_space_id=search_space_id,
db_session=session,
connector_service=connector_service,
checkpointer=checkpointer,
user_id=user_id,
thread_id=chat_id,
agent_config=agent_config,
firecrawl_api_key=firecrawl_api_key,
thread_visibility=visibility,
disabled_tools=disabled_tools,
mentioned_document_ids=mentioned_document_ids,
filesystem_selection=filesystem_selection,
)
if agent is None:
# Either no preflight was needed, or preflight succeeded —
# in both cases the speculative build is the agent we want.
agent = await agent_build_task
_perf_log.info(
"[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0
)
@ -3005,6 +3102,18 @@ async def stream_new_chat(
title_emitted = False
# Build the per-invocation runtime context (Phase 1.5).
# ``mentioned_document_ids`` is read by ``KnowledgePriorityMiddleware``
# via ``runtime.context.mentioned_document_ids`` instead of its
# ``__init__`` closure — that way the same compiled-agent instance
# can serve multiple turns with different mention lists.
runtime_context = SurfSenseContextSchema(
search_space_id=search_space_id,
mentioned_document_ids=list(mentioned_document_ids or []),
request_id=request_id,
turn_id=stream_result.turn_id,
)
_t_stream_start = time.perf_counter()
_first_event_logged = False
runtime_rate_limit_recovered = False
@ -3028,6 +3137,7 @@ async def stream_new_chat(
else FilesystemMode.CLOUD
),
fallback_commit_thread_id=chat_id,
runtime_context=runtime_context,
):
if not _first_event_logged:
_perf_log.info(
@ -3643,21 +3753,75 @@ async def stream_resume_chat(
# Auto-mode preflight ping (resume path). Mirrors ``stream_new_chat``:
# one cheap probe before the agent is rebuilt so a 429'd pin gets
# repinned without burning planner/classifier/title calls first.
if (
# See ``stream_new_chat`` for the full rationale on the speculative
# parallel build pattern below.
preflight_needed = (
requested_llm_config_id == 0
and llm_config_id < 0
and not is_recently_healthy(llm_config_id)
):
)
preflight_task: asyncio.Task[None] | None = None
_t_preflight = 0.0
if preflight_needed:
_t_preflight = time.perf_counter()
preflight_task = asyncio.create_task(
_preflight_llm(llm),
name=f"auto_pin_preflight_resume:{llm_config_id}",
)
_t0 = time.perf_counter()
connector_service = ConnectorService(session, search_space_id=search_space_id)
firecrawl_api_key = None
webcrawler_connector = await connector_service.get_connector_by_type(
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id
)
if webcrawler_connector and webcrawler_connector.config:
firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY")
_perf_log.info(
"[stream_resume] Connector service + firecrawl key in %.3fs",
time.perf_counter() - _t0,
)
_t0 = time.perf_counter()
checkpointer = await get_checkpointer()
_perf_log.info(
"[stream_resume] Checkpointer ready in %.3fs", time.perf_counter() - _t0
)
visibility = thread_visibility or ChatVisibility.PRIVATE
_t0 = time.perf_counter()
agent_build_task = asyncio.create_task(
create_surfsense_deep_agent(
llm=llm,
search_space_id=search_space_id,
db_session=session,
connector_service=connector_service,
checkpointer=checkpointer,
user_id=user_id,
thread_id=chat_id,
agent_config=agent_config,
firecrawl_api_key=firecrawl_api_key,
thread_visibility=visibility,
filesystem_selection=filesystem_selection,
),
name="agent_build:stream_resume",
)
agent: Any = None
if preflight_task is not None:
try:
await _preflight_llm(llm)
await preflight_task
mark_healthy(llm_config_id)
_perf_log.info(
"[stream_resume] auto_pin_preflight ok config_id=%s took=%.3fs",
"[stream_resume] auto_pin_preflight ok config_id=%s took=%.3fs (parallel)",
llm_config_id,
time.perf_counter() - _t_preflight,
)
except Exception as preflight_exc:
# Same session-safety rationale as ``stream_new_chat``.
await _settle_speculative_agent_build(agent_build_task)
if not _is_provider_rate_limited(preflight_exc):
raise
previous_config_id = llm_config_id
@ -3717,43 +3881,22 @@ async def stream_resume_chat(
"fallback_config_id": llm_config_id,
},
)
agent = await create_surfsense_deep_agent(
llm=llm,
search_space_id=search_space_id,
db_session=session,
connector_service=connector_service,
checkpointer=checkpointer,
user_id=user_id,
thread_id=chat_id,
agent_config=agent_config,
firecrawl_api_key=firecrawl_api_key,
thread_visibility=visibility,
filesystem_selection=filesystem_selection,
)
_t0 = time.perf_counter()
connector_service = ConnectorService(session, search_space_id=search_space_id)
firecrawl_api_key = None
webcrawler_connector = await connector_service.get_connector_by_type(
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id
)
if webcrawler_connector and webcrawler_connector.config:
firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY")
_perf_log.info(
"[stream_resume] Connector service + firecrawl key in %.3fs",
time.perf_counter() - _t0,
)
_t0 = time.perf_counter()
checkpointer = await get_checkpointer()
_perf_log.info(
"[stream_resume] Checkpointer ready in %.3fs", time.perf_counter() - _t0
)
visibility = thread_visibility or ChatVisibility.PRIVATE
_t0 = time.perf_counter()
agent = await create_surfsense_deep_agent(
llm=llm,
search_space_id=search_space_id,
db_session=session,
connector_service=connector_service,
checkpointer=checkpointer,
user_id=user_id,
thread_id=chat_id,
agent_config=agent_config,
firecrawl_api_key=firecrawl_api_key,
thread_visibility=visibility,
filesystem_selection=filesystem_selection,
)
if agent is None:
agent = await agent_build_task
_perf_log.info(
"[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0
)
@ -3794,6 +3937,16 @@ async def stream_resume_chat(
)
yield streaming_service.format_data("turn-status", {"status": "busy"})
# Resume path doesn't carry new ``mentioned_document_ids`` —
# those are seeded in the original turn. We still pass a
# context so future middleware extensions (Phase 2) can rely on
# ``runtime.context`` always being populated.
runtime_context = SurfSenseContextSchema(
search_space_id=search_space_id,
request_id=request_id,
turn_id=stream_result.turn_id,
)
_t_stream_start = time.perf_counter()
_first_event_logged = False
runtime_rate_limit_recovered = False
@ -3814,6 +3967,7 @@ async def stream_resume_chat(
else FilesystemMode.CLOUD
),
fallback_commit_thread_id=chat_id,
runtime_context=runtime_context,
):
if not _first_event_logged:
_perf_log.info(