diff --git a/surfsense_backend/app/agents/new_chat/chat_deepagent.py b/surfsense_backend/app/agents/new_chat/chat_deepagent.py index 5fcb8236d..241c4f343 100644 --- a/surfsense_backend/app/agents/new_chat/chat_deepagent.py +++ b/surfsense_backend/app/agents/new_chat/chat_deepagent.py @@ -6,6 +6,9 @@ with configurable tools via the tools registry and configurable prompts via NewLLMConfig. """ +import asyncio +import logging +import time from collections.abc import Sequence from typing import Any @@ -26,6 +29,8 @@ from app.agents.new_chat.tools.registry import build_tools_async from app.db import ChatVisibility from app.services.connector_service import ConnectorService +_perf_log = logging.getLogger("surfsense.perf") + # ============================================================================= # Connector Type Mapping # ============================================================================= @@ -210,29 +215,29 @@ async def create_surfsense_deep_agent( additional_tools=[my_custom_tool] ) """ + _t_agent_total = time.perf_counter() + # Discover available connectors and document types for this search space - # This enables dynamic tool docstrings that inform the LLM about what's actually available available_connectors: list[str] | None = None available_document_types: list[str] | None = None + _t0 = time.perf_counter() try: - # Get enabled search source connectors for this search space connector_types = await connector_service.get_available_connectors( search_space_id ) if connector_types: - # Convert enum values to strings and also include mapped document types available_connectors = _map_connectors_to_searchable_types(connector_types) - # Get document types that have at least one document indexed available_document_types = await connector_service.get_available_document_types( search_space_id ) except Exception as e: - # Log but don't fail - fall back to all connectors if discovery fails - import logging - logging.warning(f"Failed to discover available connectors/document types: {e}") + _perf_log.info( + "[create_agent] Connector/doc-type discovery in %.3fs", + time.perf_counter() - _t0, + ) # Build dependencies dict for the tools registry visibility = thread_visibility or ChatVisibility.PRIVATE @@ -274,14 +279,21 @@ async def create_surfsense_deep_agent( modified_disabled_tools.extend(linear_tools) # Build tools using the async registry (includes MCP tools) + _t0 = time.perf_counter() tools = await build_tools_async( dependencies=dependencies, enabled_tools=enabled_tools, disabled_tools=modified_disabled_tools, additional_tools=list(additional_tools) if additional_tools else None, ) + _perf_log.info( + "[create_agent] build_tools_async in %.3fs (%d tools)", + time.perf_counter() - _t0, + len(tools), + ) # Build system prompt based on agent_config + _t0 = time.perf_counter() _sandbox_enabled = sandbox_backend is not None if agent_config is not None: system_prompt = build_configurable_system_prompt( @@ -296,15 +308,18 @@ async def create_surfsense_deep_agent( thread_visibility=thread_visibility, sandbox_enabled=_sandbox_enabled, ) + _perf_log.info( + "[create_agent] System prompt built in %.3fs", time.perf_counter() - _t0 + ) # Build optional kwargs for the deep agent deep_agent_kwargs: dict[str, Any] = {} if sandbox_backend is not None: deep_agent_kwargs["backend"] = sandbox_backend - # Create the deep agent with system prompt and checkpointer - # Note: TodoListMiddleware (write_todos) is included by default in create_deep_agent - agent = create_deep_agent( + _t0 = time.perf_counter() + agent = await asyncio.to_thread( + create_deep_agent, model=llm, tools=tools, system_prompt=system_prompt, @@ -312,5 +327,13 @@ async def create_surfsense_deep_agent( checkpointer=checkpointer, **deep_agent_kwargs, ) + _perf_log.info( + "[create_agent] Graph compiled (create_deep_agent) in %.3fs", + time.perf_counter() - _t0, + ) + _perf_log.info( + "[create_agent] Total agent creation in %.3fs", + time.perf_counter() - _t_agent_total, + ) return agent diff --git a/surfsense_backend/app/agents/new_chat/sandbox.py b/surfsense_backend/app/agents/new_chat/sandbox.py index 24b380b0b..7696f67f2 100644 --- a/surfsense_backend/app/agents/new_chat/sandbox.py +++ b/surfsense_backend/app/agents/new_chat/sandbox.py @@ -12,6 +12,7 @@ the sandbox is deleted so they remain downloadable after cleanup. from __future__ import annotations import asyncio +import contextlib import logging import os import shutil @@ -56,6 +57,7 @@ class _TimeoutAwareSandbox(DaytonaSandbox): _daytona_client: Daytona | None = None +_sandbox_cache: dict[str, _TimeoutAwareSandbox] = {} THREAD_LABEL_KEY = "surfsense_thread" @@ -126,8 +128,8 @@ def _find_or_create(thread_id: str) -> _TimeoutAwareSandbox: async def get_or_create_sandbox(thread_id: int | str) -> _TimeoutAwareSandbox: """Get or create a sandbox for a conversation thread. - Uses the thread_id as a label so the same sandbox persists - across multiple messages within the same conversation. + Uses an in-process cache keyed by thread_id so subsequent messages + in the same conversation reuse the sandbox object without an API call. Args: thread_id: The conversation thread identifier. @@ -135,11 +137,19 @@ async def get_or_create_sandbox(thread_id: int | str) -> _TimeoutAwareSandbox: Returns: DaytonaSandbox connected to the sandbox. """ - return await asyncio.to_thread(_find_or_create, str(thread_id)) + key = str(thread_id) + cached = _sandbox_cache.get(key) + if cached is not None: + logger.info("Reusing cached sandbox for thread %s", key) + return cached + sandbox = await asyncio.to_thread(_find_or_create, key) + _sandbox_cache[key] = sandbox + return sandbox async def delete_sandbox(thread_id: int | str) -> None: """Delete the sandbox for a conversation thread.""" + _sandbox_cache.pop(str(thread_id), None) def _delete() -> None: client = _get_client() @@ -147,7 +157,9 @@ async def delete_sandbox(thread_id: int | str) -> None: try: sandbox = client.find_one(labels=labels) except DaytonaError: - logger.debug("No sandbox to delete for thread %s (already removed)", thread_id) + logger.debug( + "No sandbox to delete for thread %s (already removed)", thread_id + ) return try: client.delete(sandbox) @@ -166,6 +178,7 @@ async def delete_sandbox(thread_id: int | str) -> None: # Local file persistence # --------------------------------------------------------------------------- + def _get_sandbox_files_dir() -> Path: return Path(os.environ.get("SANDBOX_FILES_DIR", "sandbox_files")) @@ -206,6 +219,7 @@ async def persist_and_delete_sandbox( Per-file errors are logged but do **not** prevent the sandbox from being deleted — freeing Daytona storage is the priority. """ + _sandbox_cache.pop(str(thread_id), None) def _persist_and_delete() -> None: client = _get_client() @@ -229,10 +243,8 @@ async def persist_and_delete_sandbox( sandbox.id, exc_info=True, ) - try: + with contextlib.suppress(Exception): client.delete(sandbox) - except Exception: - pass return for path in sandbox_file_paths: diff --git a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py index 5ccd2e749..20cf3ec33 100644 --- a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py +++ b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py @@ -11,6 +11,7 @@ This implements real MCP protocol support similar to Cursor's implementation. """ import logging +import time from typing import Any from langchain_core.tools import StructuredTool @@ -25,6 +26,9 @@ from app.db import SearchSourceConnector, SearchSourceConnectorType logger = logging.getLogger(__name__) +_MCP_CACHE_TTL_SECONDS = 300 # 5 minutes +_mcp_tools_cache: dict[int, tuple[float, list[StructuredTool]]] = {} + def _create_dynamic_input_model_from_schema( tool_name: str, @@ -355,6 +359,19 @@ async def _load_http_mcp_tools( return tools +def invalidate_mcp_tools_cache(search_space_id: int | None = None) -> None: + """Invalidate cached MCP tools. + + Args: + search_space_id: If provided, only invalidate for this search space. + If None, invalidate all cached MCP tools. + """ + if search_space_id is not None: + _mcp_tools_cache.pop(search_space_id, None) + else: + _mcp_tools_cache.clear() + + async def load_mcp_tools( session: AsyncSession, search_space_id: int, @@ -364,6 +381,9 @@ async def load_mcp_tools( This discovers tools dynamically from MCP servers using the protocol. Supports both stdio (local process) and HTTP (remote server) transports. + Results are cached per search space for up to 5 minutes to avoid + re-spawning MCP server processes on every chat message. + Args: session: Database session search_space_id: User's search space ID @@ -372,8 +392,20 @@ async def load_mcp_tools( List of LangChain StructuredTool instances """ + now = time.monotonic() + cached = _mcp_tools_cache.get(search_space_id) + if cached is not None: + cached_at, cached_tools = cached + if now - cached_at < _MCP_CACHE_TTL_SECONDS: + logger.info( + "Using cached MCP tools for search space %s (%d tools, age=%.0fs)", + search_space_id, + len(cached_tools), + now - cached_at, + ) + return list(cached_tools) + try: - # Fetch all MCP connectors for this search space result = await session.execute( select(SearchSourceConnector).filter( SearchSourceConnector.connector_type @@ -385,27 +417,22 @@ async def load_mcp_tools( tools: list[StructuredTool] = [] for connector in result.scalars(): try: - # Early validation: Extract and validate connector config config = connector.config or {} server_config = config.get("server_config", {}) - # Validate server_config exists and is a dict if not server_config or not isinstance(server_config, dict): logger.warning( f"MCP connector {connector.id} (name: '{connector.name}') has invalid or missing server_config, skipping" ) continue - # Determine transport type transport = server_config.get("transport", "stdio") if transport in ("streamable-http", "http", "sse"): - # HTTP-based MCP server connector_tools = await _load_http_mcp_tools( connector.id, connector.name, server_config ) else: - # stdio-based MCP server (default) connector_tools = await _load_stdio_mcp_tools( connector.id, connector.name, server_config ) @@ -417,6 +444,7 @@ async def load_mcp_tools( f"Failed to load tools from MCP connector {connector.id}: {e!s}" ) + _mcp_tools_cache[search_space_id] = (now, tools) logger.info(f"Loaded {len(tools)} MCP tools for search space {search_space_id}") return tools diff --git a/surfsense_backend/app/agents/new_chat/tools/registry.py b/surfsense_backend/app/agents/new_chat/tools/registry.py index dffed5e86..59efc2efb 100644 --- a/surfsense_backend/app/agents/new_chat/tools/registry.py +++ b/surfsense_backend/app/agents/new_chat/tools/registry.py @@ -444,8 +444,18 @@ async def build_tools_async( List of configured tool instances ready for the agent, including MCP tools. """ - # Build standard tools + import time + + _perf_log = logging.getLogger("surfsense.perf") + _perf_log.setLevel(logging.DEBUG) + + _t0 = time.perf_counter() tools = build_tools(dependencies, enabled_tools, disabled_tools, additional_tools) + _perf_log.info( + "[build_tools_async] Built-in tools in %.3fs (%d tools)", + time.perf_counter() - _t0, + len(tools), + ) # Load MCP tools if requested and dependencies are available if ( @@ -454,10 +464,16 @@ async def build_tools_async( and "search_space_id" in dependencies ): try: + _t0 = time.perf_counter() mcp_tools = await load_mcp_tools( dependencies["db_session"], dependencies["search_space_id"], ) + _perf_log.info( + "[build_tools_async] MCP tools loaded in %.3fs (%d tools)", + time.perf_counter() - _t0, + len(mcp_tools), + ) tools.extend(mcp_tools) logging.info( f"Registered {len(mcp_tools)} MCP tools: {[t.name for t in mcp_tools]}", diff --git a/surfsense_backend/app/app.py b/surfsense_backend/app/app.py index 72475e3a7..0a549abe5 100644 --- a/surfsense_backend/app/app.py +++ b/surfsense_backend/app/app.py @@ -175,8 +175,39 @@ def rate_limit_password_reset(request: Request): ) +def _enable_slow_callback_logging(threshold_sec: float = 0.5) -> None: + """Monkey-patch the event loop to warn whenever a callback blocks longer than *threshold_sec*. + + This helps pinpoint synchronous code that freezes the entire FastAPI server. + Only active when the PERF_DEBUG env var is set (to avoid overhead in production). + """ + import os + + if not os.environ.get("PERF_DEBUG"): + return + + _slow_log = logging.getLogger("surfsense.perf.slow") + _slow_log.setLevel(logging.WARNING) + if not _slow_log.handlers: + _h = logging.StreamHandler() + _h.setFormatter(logging.Formatter("%(asctime)s [SLOW-CALLBACK] %(message)s")) + _slow_log.addHandler(_h) + _slow_log.propagate = False + + loop = asyncio.get_running_loop() + loop.slow_callback_duration = threshold_sec # type: ignore[attr-defined] + loop.set_debug(True) + _slow_log.warning( + "Event-loop slow-callback detector ENABLED (threshold=%.1fs). " + "Set PERF_DEBUG='' to disable.", + threshold_sec, + ) + + @asynccontextmanager async def lifespan(app: FastAPI): + # Enable slow-callback detection (set PERF_DEBUG=1 env var to activate) + _enable_slow_callback_logging(threshold_sec=0.5) # Not needed if you setup a migration system like Alembic await create_db_and_tables() # Setup LangGraph checkpointer tables for conversation persistence diff --git a/surfsense_backend/app/indexing_pipeline/connector_document.py b/surfsense_backend/app/indexing_pipeline/connector_document.py index ecd47bab2..019efe287 100644 --- a/surfsense_backend/app/indexing_pipeline/connector_document.py +++ b/surfsense_backend/app/indexing_pipeline/connector_document.py @@ -5,6 +5,7 @@ from app.db import DocumentType class ConnectorDocument(BaseModel): """Canonical data transfer object produced by connector adapters and consumed by the indexing pipeline.""" + title: str source_markdown: str unique_id: str diff --git a/surfsense_backend/app/indexing_pipeline/document_chunker.py b/surfsense_backend/app/indexing_pipeline/document_chunker.py index 719c9f4bb..4f3c698ef 100644 --- a/surfsense_backend/app/indexing_pipeline/document_chunker.py +++ b/surfsense_backend/app/indexing_pipeline/document_chunker.py @@ -3,5 +3,7 @@ from app.config import config def chunk_text(text: str, use_code_chunker: bool = False) -> list[str]: """Chunk a text string using the configured chunker and return the chunk texts.""" - chunker = config.code_chunker_instance if use_code_chunker else config.chunker_instance + chunker = ( + config.code_chunker_instance if use_code_chunker else config.chunker_instance + ) return [c.text for c in chunker.chunk(text)] diff --git a/surfsense_backend/app/indexing_pipeline/document_summarizer.py b/surfsense_backend/app/indexing_pipeline/document_summarizer.py index 1e708075e..76cc77377 100644 --- a/surfsense_backend/app/indexing_pipeline/document_summarizer.py +++ b/surfsense_backend/app/indexing_pipeline/document_summarizer.py @@ -2,7 +2,9 @@ from app.prompts import SUMMARY_PROMPT_TEMPLATE from app.utils.document_converters import optimize_content_for_context_window -async def summarize_document(source_markdown: str, llm, metadata: dict | None = None) -> str: +async def summarize_document( + source_markdown: str, llm, metadata: dict | None = None +) -> str: """Generate a text summary of a document using an LLM, prefixed with metadata when provided.""" model_name = getattr(llm, "model", "gpt-3.5-turbo") optimized_content = optimize_content_for_context_window( diff --git a/surfsense_backend/app/indexing_pipeline/exceptions.py b/surfsense_backend/app/indexing_pipeline/exceptions.py index 8c9c6f2d5..9366bbc3a 100644 --- a/surfsense_backend/app/indexing_pipeline/exceptions.py +++ b/surfsense_backend/app/indexing_pipeline/exceptions.py @@ -12,7 +12,7 @@ from litellm.exceptions import ( Timeout, UnprocessableEntityError, ) -from sqlalchemy.exc import IntegrityError +from sqlalchemy.exc import IntegrityError as IntegrityError # Tuples for use directly in except clauses. RETRYABLE_LLM_ERRORS = ( @@ -36,29 +36,33 @@ PERMANENT_LLM_ERRORS = ( # (LiteLLMEmbeddings, CohereEmbeddings, GeminiEmbeddings all normalize to RuntimeError). EMBEDDING_ERRORS = ( RuntimeError, # local device failure or API backend normalization - OSError, # model files missing or corrupted (local backends) - MemoryError, # document too large for available RAM + OSError, # model files missing or corrupted (local backends) + MemoryError, # document too large for available RAM ) class PipelineMessages: - RATE_LIMIT = "LLM rate limit exceeded. Will retry on next sync." - LLM_TIMEOUT = "LLM request timed out. Will retry on next sync." - LLM_UNAVAILABLE = "LLM service temporarily unavailable. Will retry on next sync." - LLM_BAD_GATEWAY = "LLM gateway error. Will retry on next sync." - LLM_SERVER_ERROR = "LLM internal server error. Will retry on next sync." - LLM_CONNECTION = "Could not reach the LLM service. Check network connectivity." + RATE_LIMIT = "LLM rate limit exceeded. Will retry on next sync." + LLM_TIMEOUT = "LLM request timed out. Will retry on next sync." + LLM_UNAVAILABLE = "LLM service temporarily unavailable. Will retry on next sync." + LLM_BAD_GATEWAY = "LLM gateway error. Will retry on next sync." + LLM_SERVER_ERROR = "LLM internal server error. Will retry on next sync." + LLM_CONNECTION = "Could not reach the LLM service. Check network connectivity." - LLM_AUTH = "LLM authentication failed. Check your API key." - LLM_PERMISSION = "LLM request denied. Check your account permissions." - LLM_NOT_FOUND = "LLM model not found. Check your model configuration." - LLM_BAD_REQUEST = "LLM rejected the request. Document content may be invalid." - LLM_UNPROCESSABLE = "Document exceeds the LLM context window even after optimization." - LLM_RESPONSE = "LLM returned an invalid response." + LLM_AUTH = "LLM authentication failed. Check your API key." + LLM_PERMISSION = "LLM request denied. Check your account permissions." + LLM_NOT_FOUND = "LLM model not found. Check your model configuration." + LLM_BAD_REQUEST = "LLM rejected the request. Document content may be invalid." + LLM_UNPROCESSABLE = ( + "Document exceeds the LLM context window even after optimization." + ) + LLM_RESPONSE = "LLM returned an invalid response." - EMBEDDING_FAILED = "Embedding failed. Check your embedding model configuration or service." - EMBEDDING_MODEL = "Embedding model files are missing or corrupted." - EMBEDDING_MEMORY = "Not enough memory to embed this document." + EMBEDDING_FAILED = ( + "Embedding failed. Check your embedding model configuration or service." + ) + EMBEDDING_MODEL = "Embedding model files are missing or corrupted." + EMBEDDING_MEMORY = "Not enough memory to embed this document." CHUNKING_OVERFLOW = "Document structure is too deeply nested to chunk." diff --git a/surfsense_backend/app/indexing_pipeline/pipeline_logger.py b/surfsense_backend/app/indexing_pipeline/pipeline_logger.py index 6571920cf..281a92c52 100644 --- a/surfsense_backend/app/indexing_pipeline/pipeline_logger.py +++ b/surfsense_backend/app/indexing_pipeline/pipeline_logger.py @@ -8,27 +8,29 @@ logger = logging.getLogger(__name__) class PipelineLogContext: connector_id: int | None search_space_id: int - unique_id: str # always available from ConnectorDocument - doc_id: int | None = None # set once the DB row exists (index phase only) + unique_id: str # always available from ConnectorDocument + doc_id: int | None = None # set once the DB row exists (index phase only) class LogMessages: # prepare_for_indexing - DOCUMENT_QUEUED = "New document queued for indexing." - DOCUMENT_UPDATED = "Document content changed, re-queued for indexing." - DOCUMENT_REQUEUED = "Stuck document re-queued for indexing." + DOCUMENT_QUEUED = "New document queued for indexing." + DOCUMENT_UPDATED = "Document content changed, re-queued for indexing." + DOCUMENT_REQUEUED = "Stuck document re-queued for indexing." DOC_SKIPPED_UNKNOWN = "Unexpected error — document skipped." - BATCH_ABORTED = "Fatal DB error — aborting prepare batch." - RACE_CONDITION = "Concurrent worker beat us to the commit — rolling back batch." + BATCH_ABORTED = "Fatal DB error — aborting prepare batch." + RACE_CONDITION = "Concurrent worker beat us to the commit — rolling back batch." # index - INDEX_STARTED = "Document indexing started." - INDEX_SUCCESS = "Document indexed successfully." - LLM_RETRYABLE = "Retryable LLM error — document marked failed, will retry on next sync." - LLM_PERMANENT = "Permanent LLM error — document marked failed." - EMBEDDING_FAILED = "Embedding error — document marked failed." - CHUNKING_OVERFLOW = "Chunking overflow — document marked failed." - UNEXPECTED = "Unexpected error — document marked failed." + INDEX_STARTED = "Document indexing started." + INDEX_SUCCESS = "Document indexed successfully." + LLM_RETRYABLE = ( + "Retryable LLM error — document marked failed, will retry on next sync." + ) + LLM_PERMANENT = "Permanent LLM error — document marked failed." + EMBEDDING_FAILED = "Embedding error — document marked failed." + CHUNKING_OVERFLOW = "Chunking overflow — document marked failed." + UNEXPECTED = "Unexpected error — document marked failed." def _format_context(ctx: PipelineLogContext) -> str: @@ -52,7 +54,9 @@ def _build_message(msg: str, ctx: PipelineLogContext, **extra) -> str: return msg -def _safe_log(level_fn, msg: str, ctx: PipelineLogContext, exc_info=None, **extra) -> None: +def _safe_log( + level_fn, msg: str, ctx: PipelineLogContext, exc_info=None, **extra +) -> None: # Logging must never raise — a broken log call inside an except block would # chain with the original exception and mask it entirely. try: @@ -64,6 +68,7 @@ def _safe_log(level_fn, msg: str, ctx: PipelineLogContext, exc_info=None, **extr # ── prepare_for_indexing ────────────────────────────────────────────────────── + def log_document_queued(ctx: PipelineLogContext) -> None: _safe_log(logger.info, LogMessages.DOCUMENT_QUEUED, ctx) @@ -77,7 +82,9 @@ def log_document_requeued(ctx: PipelineLogContext) -> None: def log_doc_skipped_unknown(ctx: PipelineLogContext, exc: Exception) -> None: - _safe_log(logger.warning, LogMessages.DOC_SKIPPED_UNKNOWN, ctx, exc_info=exc, error=exc) + _safe_log( + logger.warning, LogMessages.DOC_SKIPPED_UNKNOWN, ctx, exc_info=exc, error=exc + ) def log_race_condition(ctx: PipelineLogContext) -> None: @@ -90,6 +97,7 @@ def log_batch_aborted(ctx: PipelineLogContext, exc: Exception) -> None: # ── index ───────────────────────────────────────────────────────────────────── + def log_index_started(ctx: PipelineLogContext) -> None: _safe_log(logger.info, LogMessages.INDEX_STARTED, ctx) diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index 7856a2c17..c997cba68 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -10,6 +10,8 @@ These endpoints support the ThreadHistoryAdapter pattern from assistant-ui: - POST /threads/{thread_id}/messages - Append message """ +import asyncio +import logging from datetime import UTC, datetime from fastapi import APIRouter, Depends, HTTPException, Request @@ -52,10 +54,8 @@ from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat from app.users import current_active_user from app.utils.rbac import check_permission -import asyncio -import logging - _logger = logging.getLogger(__name__) +_background_tasks: set[asyncio.Task] = set() router = APIRouter() @@ -75,15 +75,25 @@ def _try_delete_sandbox(thread_id: int) -> None: try: await delete_sandbox(thread_id) except Exception: - _logger.warning("Background sandbox delete failed for thread %s", thread_id, exc_info=True) + _logger.warning( + "Background sandbox delete failed for thread %s", + thread_id, + exc_info=True, + ) try: delete_local_sandbox_files(thread_id) except Exception: - _logger.warning("Local sandbox file cleanup failed for thread %s", thread_id, exc_info=True) + _logger.warning( + "Local sandbox file cleanup failed for thread %s", + thread_id, + exc_info=True, + ) try: loop = asyncio.get_running_loop() - loop.create_task(_bg()) + task = loop.create_task(_bg()) + _background_tasks.add(task) + task.add_done_callback(_background_tasks.discard) except RuntimeError: pass diff --git a/surfsense_backend/app/routes/sandbox_routes.py b/surfsense_backend/app/routes/sandbox_routes.py index e5b737371..2c12c3a1e 100644 --- a/surfsense_backend/app/routes/sandbox_routes.py +++ b/surfsense_backend/app/routes/sandbox_routes.py @@ -87,7 +87,7 @@ async def download_sandbox_file( # Fall back to live sandbox download try: sandbox = await get_or_create_sandbox(thread_id) - raw_sandbox = sandbox._sandbox # noqa: SLF001 + raw_sandbox = sandbox._sandbox content: bytes = await asyncio.to_thread(raw_sandbox.fs.download_file, path) except Exception as exc: logger.warning("Sandbox file download failed for %s: %s", path, exc) diff --git a/surfsense_backend/app/routes/search_source_connectors_routes.py b/surfsense_backend/app/routes/search_source_connectors_routes.py index b69238837..e808635e6 100644 --- a/surfsense_backend/app/routes/search_source_connectors_routes.py +++ b/surfsense_backend/app/routes/search_source_connectors_routes.py @@ -2735,7 +2735,10 @@ async def create_mcp_connector( f"for user {user.id} in search space {search_space_id}" ) - # Convert to read schema + from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache + + invalidate_mcp_tools_cache(search_space_id) + connector_read = SearchSourceConnectorRead.model_validate(db_connector) return MCPConnectorRead.from_connector(connector_read) @@ -2910,6 +2913,10 @@ async def update_mcp_connector( logger.info(f"Updated MCP connector {connector_id}") + from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache + + invalidate_mcp_tools_cache(connector.search_space_id) + connector_read = SearchSourceConnectorRead.model_validate(connector) return MCPConnectorRead.from_connector(connector_read) @@ -2960,9 +2967,14 @@ async def delete_mcp_connector( "You don't have permission to delete this connector", ) + search_space_id = connector.search_space_id await session.delete(connector) await session.commit() + from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache + + invalidate_mcp_tools_cache(search_space_id) + logger.info(f"Deleted MCP connector {connector_id}") except HTTPException: diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index bf942f548..ddadbc48b 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -13,14 +13,17 @@ import asyncio import json import logging import re +import time from collections.abc import AsyncGenerator from dataclasses import dataclass, field from typing import Any from uuid import UUID from langchain_core.messages import HumanMessage +from sqlalchemy import func from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select +from sqlalchemy.orm import selectinload from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent from app.agents.new_chat.checkpointer import get_checkpointer @@ -31,10 +34,17 @@ from app.agents.new_chat.llm_config import ( load_agent_config, load_llm_config_from_yaml, ) +from app.agents.new_chat.sandbox import ( + get_or_create_sandbox, + is_sandbox_enabled, +) from app.db import ( ChatVisibility, Document, + NewChatMessage, + NewChatThread, Report, + SearchSourceConnectorType, SurfsenseDocsDocument, async_session_maker, ) @@ -47,6 +57,16 @@ from app.services.connector_service import ConnectorService from app.services.new_streaming_service import VercelStreamingService from app.utils.content_utils import bootstrap_history_from_db +_perf_log = logging.getLogger("surfsense.perf") +_perf_log.setLevel(logging.DEBUG) +if not _perf_log.handlers: + _h = logging.StreamHandler() + _h.setFormatter(logging.Formatter("%(asctime)s [PERF] %(message)s")) + _perf_log.addHandler(_h) + _perf_log.propagate = False + +_background_tasks: set[asyncio.Task] = set() + def format_mentioned_documents_as_context(documents: list[Document]) -> str: """ @@ -877,7 +897,9 @@ async def _stream_agent_events( output_text = om.group(1) if om else "" thread_id_str = config.get("configurable", {}).get("thread_id", "") - for sf_match in re.finditer(r"^SANDBOX_FILE:\s*(.+)$", output_text, re.MULTILINE): + for sf_match in re.finditer( + r"^SANDBOX_FILE:\s*(.+)$", output_text, re.MULTILINE + ): fpath = sf_match.group(1).strip() if fpath and fpath not in result.sandbox_files: result.sandbox_files.append(fpath) @@ -963,7 +985,10 @@ def _try_persist_and_delete_sandbox( sandbox_files: list[str], ) -> None: """Fire-and-forget: persist sandbox files locally then delete the sandbox.""" - from app.agents.new_chat.sandbox import is_sandbox_enabled, persist_and_delete_sandbox + from app.agents.new_chat.sandbox import ( + is_sandbox_enabled, + persist_and_delete_sandbox, + ) if not is_sandbox_enabled(): return @@ -980,7 +1005,9 @@ def _try_persist_and_delete_sandbox( try: loop = asyncio.get_running_loop() - loop.create_task(_run()) + task = loop.create_task(_run()) + _background_tasks.add(task) + task.add_done_callback(_background_tasks.discard) except RuntimeError: pass @@ -1022,6 +1049,7 @@ async def stream_new_chat( """ streaming_service = VercelStreamingService() stream_result = StreamResult() + _t_total = time.perf_counter() try: # Mark AI as responding to this user for live collaboration @@ -1030,6 +1058,7 @@ async def stream_new_chat( # Load LLM config - supports both YAML (negative IDs) and database (positive IDs) agent_config: AgentConfig | None = None + _t0 = time.perf_counter() if llm_config_id >= 0: # Positive ID: Load from NewLLMConfig database table agent_config = await load_agent_config( @@ -1060,6 +1089,11 @@ async def stream_new_chat( llm = create_chat_litellm_from_config(llm_config) # Create AgentConfig from YAML for consistency (uses defaults for prompt settings) agent_config = AgentConfig.from_yaml_config(llm_config) + _perf_log.info( + "[stream_new_chat] LLM config loaded in %.3fs (config_id=%s)", + time.perf_counter() - _t0, + llm_config_id, + ) if not llm: yield streaming_service.format_error("Failed to create LLM instance") @@ -1067,28 +1101,29 @@ async def stream_new_chat( return # Create connector service + _t0 = time.perf_counter() connector_service = ConnectorService(session, search_space_id=search_space_id) - # Get Firecrawl API key from webcrawler connector if configured - from app.db import SearchSourceConnectorType - firecrawl_api_key = None webcrawler_connector = await connector_service.get_connector_by_type( SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id ) if webcrawler_connector and webcrawler_connector.config: firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY") - - # Get the PostgreSQL checkpointer for persistent conversation memory - checkpointer = await get_checkpointer() - - # Optionally provision a sandboxed code execution environment - sandbox_backend = None - from app.agents.new_chat.sandbox import ( - get_or_create_sandbox, - is_sandbox_enabled, + _perf_log.info( + "[stream_new_chat] Connector service + firecrawl key in %.3fs", + time.perf_counter() - _t0, ) + # Get the PostgreSQL checkpointer for persistent conversation memory + _t0 = time.perf_counter() + checkpointer = await get_checkpointer() + _perf_log.info( + "[stream_new_chat] Checkpointer ready in %.3fs", time.perf_counter() - _t0 + ) + + sandbox_backend = None + _t0 = time.perf_counter() if is_sandbox_enabled(): try: sandbox_backend = await get_or_create_sandbox(chat_id) @@ -1097,8 +1132,14 @@ async def stream_new_chat( "Sandbox creation failed, continuing without execute tool: %s", sandbox_err, ) + _perf_log.info( + "[stream_new_chat] Sandbox provisioning in %.3fs (enabled=%s)", + time.perf_counter() - _t0, + sandbox_backend is not None, + ) visibility = thread_visibility or ChatVisibility.PRIVATE + _t0 = time.perf_counter() agent = await create_surfsense_deep_agent( llm=llm, search_space_id=search_space_id, @@ -1112,19 +1153,20 @@ async def stream_new_chat( thread_visibility=visibility, sandbox_backend=sandbox_backend, ) + _perf_log.info( + "[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0 + ) # Build input with message history langchain_messages = [] + _t0 = time.perf_counter() # Bootstrap history for cloned chats (no LangGraph checkpoint exists yet) if needs_history_bootstrap: langchain_messages = await bootstrap_history_from_db( session, chat_id, thread_visibility=visibility ) - # Clear the flag so we don't bootstrap again on next message - from app.db import NewChatThread - thread_result = await session.execute( select(NewChatThread).filter(NewChatThread.id == chat_id) ) @@ -1136,11 +1178,9 @@ async def stream_new_chat( # Fetch mentioned documents if any (with chunks for proper citations) mentioned_documents: list[Document] = [] if mentioned_document_ids: - from sqlalchemy.orm import selectinload as doc_selectinload - result = await session.execute( select(Document) - .options(doc_selectinload(Document.chunks)) + .options(selectinload(Document.chunks)) .filter( Document.id.in_(mentioned_document_ids), Document.search_space_id == search_space_id, @@ -1151,8 +1191,6 @@ async def stream_new_chat( # Fetch mentioned SurfSense docs if any mentioned_surfsense_docs: list[SurfsenseDocsDocument] = [] if mentioned_surfsense_doc_ids: - from sqlalchemy.orm import selectinload - result = await session.execute( select(SurfsenseDocsDocument) .options(selectinload(SurfsenseDocsDocument.chunks)) @@ -1236,6 +1274,11 @@ async def stream_new_chat( "search_space_id": search_space_id, } + _perf_log.info( + "[stream_new_chat] History bootstrap + doc/report queries in %.3fs", + time.perf_counter() - _t0, + ) + # All pre-streaming DB reads are done. Commit to release the # transaction and its ACCESS SHARE locks so we don't block DDL # (e.g. migrations) for the entire duration of LLM streaming. @@ -1243,6 +1286,12 @@ async def stream_new_chat( # short-lived transactions (or use isolated sessions). await session.commit() + _perf_log.info( + "[stream_new_chat] Total pre-stream setup in %.3fs (chat_id=%s)", + time.perf_counter() - _t_total, + chat_id, + ) + # Configure LangGraph with thread_id for memory # If checkpoint_id is provided, fork from that checkpoint (for edit/reload) configurable = {"thread_id": str(chat_id)} @@ -1304,6 +1353,8 @@ async def stream_new_chat( items=initial_items, ) + _t_stream_start = time.perf_counter() + _first_event_logged = False async for sse in _stream_agent_events( agent=agent, config=config, @@ -1315,8 +1366,23 @@ async def stream_new_chat( initial_step_title=initial_title, initial_step_items=initial_items, ): + if not _first_event_logged: + _perf_log.info( + "[stream_new_chat] First agent event in %.3fs (time since stream start), " + "%.3fs (total since request start) (chat_id=%s)", + time.perf_counter() - _t_stream_start, + time.perf_counter() - _t_total, + chat_id, + ) + _first_event_logged = True yield sse + _perf_log.info( + "[stream_new_chat] Agent stream completed in %.3fs (chat_id=%s)", + time.perf_counter() - _t_stream_start, + chat_id, + ) + if stream_result.is_interrupted: yield streaming_service.format_finish_step() yield streaming_service.format_finish() @@ -1325,12 +1391,6 @@ async def stream_new_chat( accumulated_text = stream_result.accumulated_text - # Generate LLM title for new chats after first response - # Check if this is the first assistant response by counting existing assistant messages - from sqlalchemy import func - - from app.db import NewChatMessage, NewChatThread - assistant_count_result = await session.execute( select(func.count(NewChatMessage.id)).filter( NewChatMessage.thread_id == chat_id, @@ -1431,12 +1491,14 @@ async def stream_resume_chat( ) -> AsyncGenerator[str, None]: streaming_service = VercelStreamingService() stream_result = StreamResult() + _t_total = time.perf_counter() try: if user_id: await set_ai_responding(session, chat_id, UUID(user_id)) agent_config: AgentConfig | None = None + _t0 = time.perf_counter() if llm_config_id >= 0: agent_config = await load_agent_config( session=session, @@ -1460,31 +1522,37 @@ async def stream_resume_chat( return llm = create_chat_litellm_from_config(llm_config) agent_config = AgentConfig.from_yaml_config(llm_config) + _perf_log.info( + "[stream_resume] LLM config loaded in %.3fs", time.perf_counter() - _t0 + ) if not llm: yield streaming_service.format_error("Failed to create LLM instance") yield streaming_service.format_done() return + _t0 = time.perf_counter() connector_service = ConnectorService(session, search_space_id=search_space_id) - from app.db import SearchSourceConnectorType - firecrawl_api_key = None webcrawler_connector = await connector_service.get_connector_by_type( SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id ) if webcrawler_connector and webcrawler_connector.config: firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY") - - checkpointer = await get_checkpointer() - - sandbox_backend = None - from app.agents.new_chat.sandbox import ( - get_or_create_sandbox, - is_sandbox_enabled, + _perf_log.info( + "[stream_resume] Connector service + firecrawl key in %.3fs", + time.perf_counter() - _t0, ) + _t0 = time.perf_counter() + checkpointer = await get_checkpointer() + _perf_log.info( + "[stream_resume] Checkpointer ready in %.3fs", time.perf_counter() - _t0 + ) + + sandbox_backend = None + _t0 = time.perf_counter() if is_sandbox_enabled(): try: sandbox_backend = await get_or_create_sandbox(chat_id) @@ -1493,9 +1561,15 @@ async def stream_resume_chat( "Sandbox creation failed, continuing without execute tool: %s", sandbox_err, ) + _perf_log.info( + "[stream_resume] Sandbox provisioning in %.3fs (enabled=%s)", + time.perf_counter() - _t0, + sandbox_backend is not None, + ) visibility = thread_visibility or ChatVisibility.PRIVATE + _t0 = time.perf_counter() agent = await create_surfsense_deep_agent( llm=llm, search_space_id=search_space_id, @@ -1509,10 +1583,19 @@ async def stream_resume_chat( thread_visibility=visibility, sandbox_backend=sandbox_backend, ) + _perf_log.info( + "[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0 + ) # Release the transaction before streaming (same rationale as stream_new_chat). await session.commit() + _perf_log.info( + "[stream_resume] Total pre-stream setup in %.3fs (chat_id=%s)", + time.perf_counter() - _t_total, + chat_id, + ) + from langgraph.types import Command config = { @@ -1523,6 +1606,8 @@ async def stream_resume_chat( yield streaming_service.format_message_start() yield streaming_service.format_start_step() + _t_stream_start = time.perf_counter() + _first_event_logged = False async for sse in _stream_agent_events( agent=agent, config=config, @@ -1531,7 +1616,20 @@ async def stream_resume_chat( result=stream_result, step_prefix="thinking-resume", ): + if not _first_event_logged: + _perf_log.info( + "[stream_resume] First agent event in %.3fs (stream), %.3fs (total) (chat_id=%s)", + time.perf_counter() - _t_stream_start, + time.perf_counter() - _t_total, + chat_id, + ) + _first_event_logged = True yield sse + _perf_log.info( + "[stream_resume] Agent stream completed in %.3fs (chat_id=%s)", + time.perf_counter() - _t_stream_start, + chat_id, + ) if stream_result.is_interrupted: yield streaming_service.format_finish_step() yield streaming_service.format_finish() diff --git a/surfsense_backend/tests/conftest.py b/surfsense_backend/tests/conftest.py index c5d3b191b..14e158c03 100644 --- a/surfsense_backend/tests/conftest.py +++ b/surfsense_backend/tests/conftest.py @@ -33,4 +33,5 @@ def make_connector_document(): } defaults.update(overrides) return ConnectorDocument(**defaults) + return _make diff --git a/surfsense_backend/tests/integration/conftest.py b/surfsense_backend/tests/integration/conftest.py index 99e182c6b..119045d29 100644 --- a/surfsense_backend/tests/integration/conftest.py +++ b/surfsense_backend/tests/integration/conftest.py @@ -1,4 +1,3 @@ - import os import uuid from unittest.mock import AsyncMock, MagicMock @@ -9,14 +8,21 @@ from sqlalchemy import text from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.pool import NullPool -from app.db import Base, SearchSpace, SearchSourceConnector, SearchSourceConnectorType -from app.db import User -from app.db import DocumentType +from app.db import ( + Base, + DocumentType, + SearchSourceConnector, + SearchSourceConnectorType, + SearchSpace, + User, +) from app.indexing_pipeline.connector_document import ConnectorDocument _EMBEDDING_DIM = 1024 # must match the Vector() dimension used in DB column creation -_DEFAULT_TEST_DB = "postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense_test" +_DEFAULT_TEST_DB = ( + "postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense_test" +) TEST_DATABASE_URL = os.environ.get("TEST_DATABASE_URL", _DEFAULT_TEST_DB) @@ -80,7 +86,9 @@ async def db_user(db_session: AsyncSession) -> User: @pytest_asyncio.fixture -async def db_connector(db_session: AsyncSession, db_user: User, db_search_space: "SearchSpace") -> SearchSourceConnector: +async def db_connector( + db_session: AsyncSession, db_user: User, db_search_space: "SearchSpace" +) -> SearchSourceConnector: connector = SearchSourceConnector( name="Test Connector", connector_type=SearchSourceConnectorType.CLICKUP_CONNECTOR, @@ -147,6 +155,7 @@ def patched_chunk_text(monkeypatch) -> MagicMock: @pytest.fixture def make_connector_document(db_connector, db_user): """Integration-scoped override: uses real DB connector and user IDs.""" + def _make(**overrides): defaults = { "title": "Test Document", @@ -159,6 +168,5 @@ def make_connector_document(db_connector, db_user): } defaults.update(overrides) return ConnectorDocument(**defaults) + return _make - - diff --git a/surfsense_backend/tests/integration/indexing_pipeline/adapters/test_file_upload_adapter.py b/surfsense_backend/tests/integration/indexing_pipeline/adapters/test_file_upload_adapter.py index 723c0e13b..193e4bd80 100644 --- a/surfsense_backend/tests/integration/indexing_pipeline/adapters/test_file_upload_adapter.py +++ b/surfsense_backend/tests/integration/indexing_pipeline/adapters/test_file_upload_adapter.py @@ -7,7 +7,9 @@ from app.indexing_pipeline.adapters.file_upload_adapter import index_uploaded_fi pytestmark = pytest.mark.integration -@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text") +@pytest.mark.usefixtures( + "patched_summarize", "patched_embed_text", "patched_chunk_text" +) async def test_sets_status_ready(db_session, db_search_space, db_user, mocker): """Document status is READY after successful indexing.""" await index_uploaded_file( @@ -28,7 +30,9 @@ async def test_sets_status_ready(db_session, db_search_space, db_user, mocker): assert DocumentStatus.is_state(document.status, DocumentStatus.READY) -@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text") +@pytest.mark.usefixtures( + "patched_summarize", "patched_embed_text", "patched_chunk_text" +) async def test_content_is_summary(db_session, db_search_space, db_user, mocker): """Document content is set to the LLM-generated summary.""" await index_uploaded_file( @@ -49,7 +53,9 @@ async def test_content_is_summary(db_session, db_search_space, db_user, mocker): assert document.content == "Mocked summary." -@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text") +@pytest.mark.usefixtures( + "patched_summarize", "patched_embed_text", "patched_chunk_text" +) async def test_chunks_written_to_db(db_session, db_search_space, db_user, mocker): """Chunks derived from the source markdown are persisted in the DB.""" await index_uploaded_file( @@ -76,7 +82,9 @@ async def test_chunks_written_to_db(db_session, db_search_space, db_user, mocker assert chunks[0].content == "Test chunk content." -@pytest.mark.usefixtures("patched_summarize_raises", "patched_embed_text", "patched_chunk_text") +@pytest.mark.usefixtures( + "patched_summarize_raises", "patched_embed_text", "patched_chunk_text" +) async def test_raises_on_indexing_failure(db_session, db_search_space, db_user, mocker): """RuntimeError is raised when the indexing step fails so the caller can fire a failure notification.""" with pytest.raises(RuntimeError): diff --git a/surfsense_backend/tests/integration/indexing_pipeline/test_index_document.py b/surfsense_backend/tests/integration/indexing_pipeline/test_index_document.py index 7c5e1e4f4..0065a03e1 100644 --- a/surfsense_backend/tests/integration/indexing_pipeline/test_index_document.py +++ b/surfsense_backend/tests/integration/indexing_pipeline/test_index_document.py @@ -7,9 +7,14 @@ from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineServ pytestmark = pytest.mark.integration -@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text") +@pytest.mark.usefixtures( + "patched_summarize", "patched_embed_text", "patched_chunk_text" +) async def test_sets_status_ready( - db_session, db_search_space, make_connector_document, mocker, + db_session, + db_search_space, + make_connector_document, + mocker, ): """Document status is READY after successful indexing.""" connector_doc = make_connector_document(search_space_id=db_search_space.id) @@ -21,15 +26,22 @@ async def test_sets_status_ready( await service.index(document, connector_doc, llm=mocker.Mock()) - result = await db_session.execute(select(Document).filter(Document.id == document_id)) + result = await db_session.execute( + select(Document).filter(Document.id == document_id) + ) reloaded = result.scalars().first() assert DocumentStatus.is_state(reloaded.status, DocumentStatus.READY) -@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text") +@pytest.mark.usefixtures( + "patched_summarize", "patched_embed_text", "patched_chunk_text" +) async def test_content_is_summary_when_should_summarize_true( - db_session, db_search_space, make_connector_document, mocker, + db_session, + db_search_space, + make_connector_document, + mocker, ): """Document content is set to the LLM-generated summary when should_summarize=True.""" connector_doc = make_connector_document(search_space_id=db_search_space.id) @@ -41,15 +53,21 @@ async def test_content_is_summary_when_should_summarize_true( await service.index(document, connector_doc, llm=mocker.Mock()) - result = await db_session.execute(select(Document).filter(Document.id == document_id)) + result = await db_session.execute( + select(Document).filter(Document.id == document_id) + ) reloaded = result.scalars().first() assert reloaded.content == "Mocked summary." -@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text") +@pytest.mark.usefixtures( + "patched_summarize", "patched_embed_text", "patched_chunk_text" +) async def test_content_is_source_markdown_when_should_summarize_false( - db_session, db_search_space, make_connector_document, + db_session, + db_search_space, + make_connector_document, ): """Document content is set to source_markdown verbatim when should_summarize=False.""" connector_doc = make_connector_document( @@ -65,15 +83,22 @@ async def test_content_is_source_markdown_when_should_summarize_false( await service.index(document, connector_doc, llm=None) - result = await db_session.execute(select(Document).filter(Document.id == document_id)) + result = await db_session.execute( + select(Document).filter(Document.id == document_id) + ) reloaded = result.scalars().first() assert reloaded.content == "## Raw content" -@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text") +@pytest.mark.usefixtures( + "patched_summarize", "patched_embed_text", "patched_chunk_text" +) async def test_chunks_written_to_db( - db_session, db_search_space, make_connector_document, mocker, + db_session, + db_search_space, + make_connector_document, + mocker, ): """Chunks derived from source_markdown are persisted in the DB.""" connector_doc = make_connector_document(search_space_id=db_search_space.id) @@ -94,9 +119,14 @@ async def test_chunks_written_to_db( assert chunks[0].content == "Test chunk content." -@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text") +@pytest.mark.usefixtures( + "patched_summarize", "patched_embed_text", "patched_chunk_text" +) async def test_embedding_written_to_db( - db_session, db_search_space, make_connector_document, mocker, + db_session, + db_search_space, + make_connector_document, + mocker, ): """Document embedding vector is persisted in the DB after indexing.""" connector_doc = make_connector_document(search_space_id=db_search_space.id) @@ -108,16 +138,23 @@ async def test_embedding_written_to_db( await service.index(document, connector_doc, llm=mocker.Mock()) - result = await db_session.execute(select(Document).filter(Document.id == document_id)) + result = await db_session.execute( + select(Document).filter(Document.id == document_id) + ) reloaded = result.scalars().first() assert reloaded.embedding is not None assert len(reloaded.embedding) == 1024 -@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text") +@pytest.mark.usefixtures( + "patched_summarize", "patched_embed_text", "patched_chunk_text" +) async def test_updated_at_advances_after_indexing( - db_session, db_search_space, make_connector_document, mocker, + db_session, + db_search_space, + make_connector_document, + mocker, ): """updated_at timestamp is later after indexing than it was at prepare time.""" connector_doc = make_connector_document(search_space_id=db_search_space.id) @@ -127,20 +164,28 @@ async def test_updated_at_advances_after_indexing( document = prepared[0] document_id = document.id - result = await db_session.execute(select(Document).filter(Document.id == document_id)) + result = await db_session.execute( + select(Document).filter(Document.id == document_id) + ) updated_at_pending = result.scalars().first().updated_at await service.index(document, connector_doc, llm=mocker.Mock()) - result = await db_session.execute(select(Document).filter(Document.id == document_id)) + result = await db_session.execute( + select(Document).filter(Document.id == document_id) + ) updated_at_ready = result.scalars().first().updated_at assert updated_at_ready > updated_at_pending -@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text") +@pytest.mark.usefixtures( + "patched_summarize", "patched_embed_text", "patched_chunk_text" +) async def test_no_llm_falls_back_to_source_markdown( - db_session, db_search_space, make_connector_document, + db_session, + db_search_space, + make_connector_document, ): """When llm=None and no fallback_summary, content falls back to source_markdown.""" connector_doc = make_connector_document( @@ -156,16 +201,22 @@ async def test_no_llm_falls_back_to_source_markdown( await service.index(document, connector_doc, llm=None) - result = await db_session.execute(select(Document).filter(Document.id == document_id)) + result = await db_session.execute( + select(Document).filter(Document.id == document_id) + ) reloaded = result.scalars().first() assert DocumentStatus.is_state(reloaded.status, DocumentStatus.READY) assert reloaded.content == "## Fallback content" -@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text") +@pytest.mark.usefixtures( + "patched_summarize", "patched_embed_text", "patched_chunk_text" +) async def test_fallback_summary_used_when_llm_unavailable( - db_session, db_search_space, make_connector_document, + db_session, + db_search_space, + make_connector_document, ): """fallback_summary is used as content when llm=None and should_summarize=True.""" connector_doc = make_connector_document( @@ -181,16 +232,23 @@ async def test_fallback_summary_used_when_llm_unavailable( await service.index(prepared[0], connector_doc, llm=None) - result = await db_session.execute(select(Document).filter(Document.id == document_id)) + result = await db_session.execute( + select(Document).filter(Document.id == document_id) + ) reloaded = result.scalars().first() assert DocumentStatus.is_state(reloaded.status, DocumentStatus.READY) assert reloaded.content == "Short pre-built summary." -@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text") +@pytest.mark.usefixtures( + "patched_summarize", "patched_embed_text", "patched_chunk_text" +) async def test_reindex_replaces_old_chunks( - db_session, db_search_space, make_connector_document, mocker, + db_session, + db_search_space, + make_connector_document, + mocker, ): """Re-indexing a document replaces its old chunks rather than appending.""" connector_doc = make_connector_document( @@ -220,9 +278,14 @@ async def test_reindex_replaces_old_chunks( assert len(chunks) == 1 -@pytest.mark.usefixtures("patched_summarize_raises", "patched_embed_text", "patched_chunk_text") +@pytest.mark.usefixtures( + "patched_summarize_raises", "patched_embed_text", "patched_chunk_text" +) async def test_llm_error_sets_status_failed( - db_session, db_search_space, make_connector_document, mocker, + db_session, + db_search_space, + make_connector_document, + mocker, ): """Document status is FAILED when the LLM raises during indexing.""" connector_doc = make_connector_document(search_space_id=db_search_space.id) @@ -234,15 +297,22 @@ async def test_llm_error_sets_status_failed( await service.index(document, connector_doc, llm=mocker.Mock()) - result = await db_session.execute(select(Document).filter(Document.id == document_id)) + result = await db_session.execute( + select(Document).filter(Document.id == document_id) + ) reloaded = result.scalars().first() assert DocumentStatus.is_state(reloaded.status, DocumentStatus.FAILED) -@pytest.mark.usefixtures("patched_summarize_raises", "patched_embed_text", "patched_chunk_text") +@pytest.mark.usefixtures( + "patched_summarize_raises", "patched_embed_text", "patched_chunk_text" +) async def test_llm_error_leaves_no_partial_data( - db_session, db_search_space, make_connector_document, mocker, + db_session, + db_search_space, + make_connector_document, + mocker, ): """A failed indexing attempt leaves no partial embedding or chunks in the DB.""" connector_doc = make_connector_document(search_space_id=db_search_space.id) @@ -254,7 +324,9 @@ async def test_llm_error_leaves_no_partial_data( await service.index(document, connector_doc, llm=mocker.Mock()) - result = await db_session.execute(select(Document).filter(Document.id == document_id)) + result = await db_session.execute( + select(Document).filter(Document.id == document_id) + ) reloaded = result.scalars().first() assert reloaded.embedding is None diff --git a/surfsense_backend/tests/integration/indexing_pipeline/test_prepare_for_indexing.py b/surfsense_backend/tests/integration/indexing_pipeline/test_prepare_for_indexing.py index b6d257f7a..837b02c9f 100644 --- a/surfsense_backend/tests/integration/indexing_pipeline/test_prepare_for_indexing.py +++ b/surfsense_backend/tests/integration/indexing_pipeline/test_prepare_for_indexing.py @@ -2,7 +2,9 @@ import pytest from sqlalchemy import select from app.db import Document, DocumentStatus -from app.indexing_pipeline.document_hashing import compute_content_hash as real_compute_content_hash +from app.indexing_pipeline.document_hashing import ( + compute_content_hash as real_compute_content_hash, +) from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService pytestmark = pytest.mark.integration @@ -20,7 +22,9 @@ async def test_new_document_is_persisted_with_pending_status( assert len(results) == 1 document_id = results[0].id - result = await db_session.execute(select(Document).filter(Document.id == document_id)) + result = await db_session.execute( + select(Document).filter(Document.id == document_id) + ) reloaded = result.scalars().first() assert reloaded is not None @@ -28,9 +32,14 @@ async def test_new_document_is_persisted_with_pending_status( assert reloaded.source_markdown == doc.source_markdown -@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text") +@pytest.mark.usefixtures( + "patched_summarize", "patched_embed_text", "patched_chunk_text" +) async def test_unchanged_ready_document_is_skipped( - db_session, db_search_space, make_connector_document, mocker, + db_session, + db_search_space, + make_connector_document, + mocker, ): """A READY document with unchanged content is not returned for re-indexing.""" doc = make_connector_document(search_space_id=db_search_space.id) @@ -46,24 +55,35 @@ async def test_unchanged_ready_document_is_skipped( assert results == [] -@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text") +@pytest.mark.usefixtures( + "patched_summarize", "patched_embed_text", "patched_chunk_text" +) async def test_title_only_change_updates_title_in_db( - db_session, db_search_space, make_connector_document, mocker, + db_session, + db_search_space, + make_connector_document, + mocker, ): """A title-only change updates the DB title without re-queuing the document.""" - original = make_connector_document(search_space_id=db_search_space.id, title="Original Title") + original = make_connector_document( + search_space_id=db_search_space.id, title="Original Title" + ) service = IndexingPipelineService(session=db_session) prepared = await service.prepare_for_indexing([original]) document_id = prepared[0].id await service.index(prepared[0], original, llm=mocker.Mock()) - renamed = make_connector_document(search_space_id=db_search_space.id, title="Updated Title") + renamed = make_connector_document( + search_space_id=db_search_space.id, title="Updated Title" + ) results = await service.prepare_for_indexing([renamed]) assert results == [] - result = await db_session.execute(select(Document).filter(Document.id == document_id)) + result = await db_session.execute( + select(Document).filter(Document.id == document_id) + ) reloaded = result.scalars().first() assert reloaded.title == "Updated Title" @@ -73,19 +93,25 @@ async def test_changed_content_is_returned_for_reprocessing( db_session, db_search_space, make_connector_document ): """A document with changed content is returned for re-indexing with updated markdown.""" - original = make_connector_document(search_space_id=db_search_space.id, source_markdown="## v1") + original = make_connector_document( + search_space_id=db_search_space.id, source_markdown="## v1" + ) service = IndexingPipelineService(session=db_session) first = await service.prepare_for_indexing([original]) original_id = first[0].id - updated = make_connector_document(search_space_id=db_search_space.id, source_markdown="## v2") + updated = make_connector_document( + search_space_id=db_search_space.id, source_markdown="## v2" + ) results = await service.prepare_for_indexing([updated]) assert len(results) == 1 assert results[0].id == original_id - result = await db_session.execute(select(Document).filter(Document.id == original_id)) + result = await db_session.execute( + select(Document).filter(Document.id == original_id) + ) reloaded = result.scalars().first() assert reloaded.source_markdown == "## v2" @@ -97,9 +123,24 @@ async def test_all_documents_in_batch_are_persisted( ): """All documents in a batch are persisted and returned.""" docs = [ - make_connector_document(search_space_id=db_search_space.id, unique_id="id-1", title="Doc 1", source_markdown="## Content 1"), - make_connector_document(search_space_id=db_search_space.id, unique_id="id-2", title="Doc 2", source_markdown="## Content 2"), - make_connector_document(search_space_id=db_search_space.id, unique_id="id-3", title="Doc 3", source_markdown="## Content 3"), + make_connector_document( + search_space_id=db_search_space.id, + unique_id="id-1", + title="Doc 1", + source_markdown="## Content 1", + ), + make_connector_document( + search_space_id=db_search_space.id, + unique_id="id-2", + title="Doc 2", + source_markdown="## Content 2", + ), + make_connector_document( + search_space_id=db_search_space.id, + unique_id="id-3", + title="Doc 3", + source_markdown="## Content 3", + ), ] service = IndexingPipelineService(session=db_session) @@ -107,7 +148,9 @@ async def test_all_documents_in_batch_are_persisted( assert len(results) == 3 - result = await db_session.execute(select(Document).filter(Document.search_space_id == db_search_space.id)) + result = await db_session.execute( + select(Document).filter(Document.search_space_id == db_search_space.id) + ) rows = result.scalars().all() assert len(rows) == 3 @@ -124,7 +167,9 @@ async def test_duplicate_in_batch_is_persisted_once( assert len(results) == 1 - result = await db_session.execute(select(Document).filter(Document.search_space_id == db_search_space.id)) + result = await db_session.execute( + select(Document).filter(Document.search_space_id == db_search_space.id) + ) rows = result.scalars().all() assert len(rows) == 1 @@ -143,7 +188,9 @@ async def test_created_by_id_is_persisted( results = await service.prepare_for_indexing([doc]) document_id = results[0].id - result = await db_session.execute(select(Document).filter(Document.id == document_id)) + result = await db_session.execute( + select(Document).filter(Document.id == document_id) + ) reloaded = result.scalars().first() assert str(reloaded.created_by_id) == str(db_user.id) @@ -170,7 +217,9 @@ async def test_metadata_is_updated_when_content_changes( ) await service.prepare_for_indexing([updated]) - result = await db_session.execute(select(Document).filter(Document.id == document_id)) + result = await db_session.execute( + select(Document).filter(Document.id == document_id) + ) reloaded = result.scalars().first() assert reloaded.document_metadata == {"status": "done"} @@ -180,19 +229,27 @@ async def test_updated_at_advances_when_title_only_changes( db_session, db_search_space, make_connector_document ): """updated_at advances even when only the title changes.""" - original = make_connector_document(search_space_id=db_search_space.id, title="Old Title") + original = make_connector_document( + search_space_id=db_search_space.id, title="Old Title" + ) service = IndexingPipelineService(session=db_session) first = await service.prepare_for_indexing([original]) document_id = first[0].id - result = await db_session.execute(select(Document).filter(Document.id == document_id)) + result = await db_session.execute( + select(Document).filter(Document.id == document_id) + ) updated_at_v1 = result.scalars().first().updated_at - renamed = make_connector_document(search_space_id=db_search_space.id, title="New Title") + renamed = make_connector_document( + search_space_id=db_search_space.id, title="New Title" + ) await service.prepare_for_indexing([renamed]) - result = await db_session.execute(select(Document).filter(Document.id == document_id)) + result = await db_session.execute( + select(Document).filter(Document.id == document_id) + ) updated_at_v2 = result.scalars().first().updated_at assert updated_at_v2 > updated_at_v1 @@ -202,19 +259,27 @@ async def test_updated_at_advances_when_content_changes( db_session, db_search_space, make_connector_document ): """updated_at advances when document content changes.""" - original = make_connector_document(search_space_id=db_search_space.id, source_markdown="## v1") + original = make_connector_document( + search_space_id=db_search_space.id, source_markdown="## v1" + ) service = IndexingPipelineService(session=db_session) first = await service.prepare_for_indexing([original]) document_id = first[0].id - result = await db_session.execute(select(Document).filter(Document.id == document_id)) + result = await db_session.execute( + select(Document).filter(Document.id == document_id) + ) updated_at_v1 = result.scalars().first().updated_at - updated = make_connector_document(search_space_id=db_search_space.id, source_markdown="## v2") + updated = make_connector_document( + search_space_id=db_search_space.id, source_markdown="## v2" + ) await service.prepare_for_indexing([updated]) - result = await db_session.execute(select(Document).filter(Document.id == document_id)) + result = await db_session.execute( + select(Document).filter(Document.id == document_id) + ) updated_at_v2 = result.scalars().first().updated_at assert updated_at_v2 > updated_at_v1 @@ -273,9 +338,14 @@ async def test_same_content_from_different_source_is_skipped( assert len(result.scalars().all()) == 1 -@pytest.mark.usefixtures("patched_summarize_raises", "patched_embed_text", "patched_chunk_text") +@pytest.mark.usefixtures( + "patched_summarize_raises", "patched_embed_text", "patched_chunk_text" +) async def test_failed_document_with_unchanged_content_is_requeued( - db_session, db_search_space, make_connector_document, mocker, + db_session, + db_search_space, + make_connector_document, + mocker, ): """A FAILED document with unchanged content is re-queued as PENDING on the next run.""" doc = make_connector_document(search_space_id=db_search_space.id) @@ -286,8 +356,12 @@ async def test_failed_document_with_unchanged_content_is_requeued( document_id = prepared[0].id await service.index(prepared[0], doc, llm=mocker.Mock()) - result = await db_session.execute(select(Document).filter(Document.id == document_id)) - assert DocumentStatus.is_state(result.scalars().first().status, DocumentStatus.FAILED) + result = await db_session.execute( + select(Document).filter(Document.id == document_id) + ) + assert DocumentStatus.is_state( + result.scalars().first().status, DocumentStatus.FAILED + ) # Next run: same content, pipeline must re-queue the failed document results = await service.prepare_for_indexing([doc]) @@ -295,8 +369,12 @@ async def test_failed_document_with_unchanged_content_is_requeued( assert len(results) == 1 assert results[0].id == document_id - result = await db_session.execute(select(Document).filter(Document.id == document_id)) - assert DocumentStatus.is_state(result.scalars().first().status, DocumentStatus.PENDING) + result = await db_session.execute( + select(Document).filter(Document.id == document_id) + ) + assert DocumentStatus.is_state( + result.scalars().first().status, DocumentStatus.PENDING + ) async def test_title_and_content_change_updates_both_and_returns_document( @@ -323,16 +401,20 @@ async def test_title_and_content_change_updates_both_and_returns_document( assert len(results) == 1 assert results[0].id == original_id - result = await db_session.execute(select(Document).filter(Document.id == original_id)) + result = await db_session.execute( + select(Document).filter(Document.id == original_id) + ) reloaded = result.scalars().first() assert reloaded.title == "Updated Title" assert reloaded.source_markdown == "## v2" - async def test_one_bad_document_in_batch_does_not_prevent_others_from_being_persisted( - db_session, db_search_space, make_connector_document, monkeypatch, + db_session, + db_search_space, + make_connector_document, + monkeypatch, ): """ A per-document error during prepare_for_indexing must be isolated. @@ -374,4 +456,4 @@ async def test_one_bad_document_in_batch_does_not_prevent_others_from_being_pers result = await db_session.execute( select(Document).filter(Document.search_space_id == db_search_space.id) ) - assert len(result.scalars().all()) == 2 \ No newline at end of file + assert len(result.scalars().all()) == 2 diff --git a/surfsense_backend/tests/unit/indexing_pipeline/conftest.py b/surfsense_backend/tests/unit/indexing_pipeline/conftest.py index 2147cfa3f..11f84dce5 100644 --- a/surfsense_backend/tests/unit/indexing_pipeline/conftest.py +++ b/surfsense_backend/tests/unit/indexing_pipeline/conftest.py @@ -1,6 +1,7 @@ -import pytest from unittest.mock import AsyncMock, MagicMock +import pytest + @pytest.fixture def patched_summarizer_chain(monkeypatch): @@ -21,7 +22,9 @@ def patched_summarizer_chain(monkeypatch): def patched_chunker_instance(monkeypatch): mock = MagicMock() mock.chunk.return_value = [MagicMock(text="prose chunk")] - monkeypatch.setattr("app.indexing_pipeline.document_chunker.config.chunker_instance", mock) + monkeypatch.setattr( + "app.indexing_pipeline.document_chunker.config.chunker_instance", mock + ) return mock @@ -29,5 +32,7 @@ def patched_chunker_instance(monkeypatch): def patched_code_chunker_instance(monkeypatch): mock = MagicMock() mock.chunk.return_value = [MagicMock(text="code chunk")] - monkeypatch.setattr("app.indexing_pipeline.document_chunker.config.code_chunker_instance", mock) + monkeypatch.setattr( + "app.indexing_pipeline.document_chunker.config.code_chunker_instance", mock + ) return mock diff --git a/surfsense_backend/tests/unit/indexing_pipeline/test_document_hashing.py b/surfsense_backend/tests/unit/indexing_pipeline/test_document_hashing.py index 6b7a47f51..fe536b066 100644 --- a/surfsense_backend/tests/unit/indexing_pipeline/test_document_hashing.py +++ b/surfsense_backend/tests/unit/indexing_pipeline/test_document_hashing.py @@ -1,7 +1,10 @@ import pytest from app.db import DocumentType -from app.indexing_pipeline.document_hashing import compute_content_hash, compute_unique_identifier_hash +from app.indexing_pipeline.document_hashing import ( + compute_content_hash, + compute_unique_identifier_hash, +) pytestmark = pytest.mark.unit @@ -10,21 +13,31 @@ def test_different_unique_id_produces_different_hash(make_connector_document): """Two documents with different unique_ids produce different identifier hashes.""" doc_a = make_connector_document(unique_id="id-001") doc_b = make_connector_document(unique_id="id-002") - assert compute_unique_identifier_hash(doc_a) != compute_unique_identifier_hash(doc_b) + assert compute_unique_identifier_hash(doc_a) != compute_unique_identifier_hash( + doc_b + ) -def test_different_search_space_produces_different_identifier_hash(make_connector_document): +def test_different_search_space_produces_different_identifier_hash( + make_connector_document, +): """Same document in different search spaces produces different identifier hashes.""" doc_a = make_connector_document(search_space_id=1) doc_b = make_connector_document(search_space_id=2) - assert compute_unique_identifier_hash(doc_a) != compute_unique_identifier_hash(doc_b) + assert compute_unique_identifier_hash(doc_a) != compute_unique_identifier_hash( + doc_b + ) -def test_different_document_type_produces_different_identifier_hash(make_connector_document): +def test_different_document_type_produces_different_identifier_hash( + make_connector_document, +): """Same unique_id with different document types produces different identifier hashes.""" doc_a = make_connector_document(document_type=DocumentType.CLICKUP_CONNECTOR) doc_b = make_connector_document(document_type=DocumentType.NOTION_CONNECTOR) - assert compute_unique_identifier_hash(doc_a) != compute_unique_identifier_hash(doc_b) + assert compute_unique_identifier_hash(doc_a) != compute_unique_identifier_hash( + doc_b + ) def test_same_content_same_space_produces_same_content_hash(make_connector_document): @@ -34,7 +47,9 @@ def test_same_content_same_space_produces_same_content_hash(make_connector_docum assert compute_content_hash(doc_a) == compute_content_hash(doc_b) -def test_same_content_different_space_produces_different_content_hash(make_connector_document): +def test_same_content_different_space_produces_different_content_hash( + make_connector_document, +): """Identical content in different search spaces produces different content hashes.""" doc_a = make_connector_document(source_markdown="Hello world", search_space_id=1) doc_b = make_connector_document(source_markdown="Hello world", search_space_id=2) diff --git a/surfsense_backend/tests/unit/indexing_pipeline/test_document_summarizer.py b/surfsense_backend/tests/unit/indexing_pipeline/test_document_summarizer.py index a3a8ecfc2..eee32357f 100644 --- a/surfsense_backend/tests/unit/indexing_pipeline/test_document_summarizer.py +++ b/surfsense_backend/tests/unit/indexing_pipeline/test_document_summarizer.py @@ -1,6 +1,7 @@ -import pytest from unittest.mock import MagicMock +import pytest + from app.indexing_pipeline.document_summarizer import summarize_document pytestmark = pytest.mark.unit @@ -38,5 +39,3 @@ async def test_with_metadata_omits_empty_fields_from_output(): assert "Alice" in result assert "description" not in result.lower() - -