Merge remote-tracking branch 'upstream/dev' into fix/docker

This commit is contained in:
Anish Sarkar 2026-02-27 05:00:23 +05:30
commit f419efcde1
35 changed files with 771 additions and 623 deletions

View file

@ -176,12 +176,3 @@ DAYTONA_API_URL=https://app.daytona.io/api
DAYTONA_TARGET=us
# Directory for locally-persisted sandbox files (after sandbox deletion)
SANDBOX_FILES_DIR=sandbox_files
# ============================================================
# Testing (optional — all have sensible defaults)
# ============================================================
# TEST_BACKEND_URL=http://localhost:8000
# TEST_DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense
# TEST_USER_EMAIL=testuser@surfsense.com
# TEST_USER_PASSWORD=testpassword123

View file

@ -6,6 +6,9 @@ with configurable tools via the tools registry and configurable prompts
via NewLLMConfig.
"""
import asyncio
import logging
import time
from collections.abc import Sequence
from typing import Any
@ -26,6 +29,8 @@ from app.agents.new_chat.tools.registry import build_tools_async
from app.db import ChatVisibility
from app.services.connector_service import ConnectorService
_perf_log = logging.getLogger("surfsense.perf")
# =============================================================================
# Connector Type Mapping
# =============================================================================
@ -210,29 +215,29 @@ async def create_surfsense_deep_agent(
additional_tools=[my_custom_tool]
)
"""
_t_agent_total = time.perf_counter()
# Discover available connectors and document types for this search space
# This enables dynamic tool docstrings that inform the LLM about what's actually available
available_connectors: list[str] | None = None
available_document_types: list[str] | None = None
_t0 = time.perf_counter()
try:
# Get enabled search source connectors for this search space
connector_types = await connector_service.get_available_connectors(
search_space_id
)
if connector_types:
# Convert enum values to strings and also include mapped document types
available_connectors = _map_connectors_to_searchable_types(connector_types)
# Get document types that have at least one document indexed
available_document_types = await connector_service.get_available_document_types(
search_space_id
)
except Exception as e:
# Log but don't fail - fall back to all connectors if discovery fails
import logging
logging.warning(f"Failed to discover available connectors/document types: {e}")
_perf_log.info(
"[create_agent] Connector/doc-type discovery in %.3fs",
time.perf_counter() - _t0,
)
# Build dependencies dict for the tools registry
visibility = thread_visibility or ChatVisibility.PRIVATE
@ -274,14 +279,21 @@ async def create_surfsense_deep_agent(
modified_disabled_tools.extend(linear_tools)
# Build tools using the async registry (includes MCP tools)
_t0 = time.perf_counter()
tools = await build_tools_async(
dependencies=dependencies,
enabled_tools=enabled_tools,
disabled_tools=modified_disabled_tools,
additional_tools=list(additional_tools) if additional_tools else None,
)
_perf_log.info(
"[create_agent] build_tools_async in %.3fs (%d tools)",
time.perf_counter() - _t0,
len(tools),
)
# Build system prompt based on agent_config
_t0 = time.perf_counter()
_sandbox_enabled = sandbox_backend is not None
if agent_config is not None:
system_prompt = build_configurable_system_prompt(
@ -296,15 +308,18 @@ async def create_surfsense_deep_agent(
thread_visibility=thread_visibility,
sandbox_enabled=_sandbox_enabled,
)
_perf_log.info(
"[create_agent] System prompt built in %.3fs", time.perf_counter() - _t0
)
# Build optional kwargs for the deep agent
deep_agent_kwargs: dict[str, Any] = {}
if sandbox_backend is not None:
deep_agent_kwargs["backend"] = sandbox_backend
# Create the deep agent with system prompt and checkpointer
# Note: TodoListMiddleware (write_todos) is included by default in create_deep_agent
agent = create_deep_agent(
_t0 = time.perf_counter()
agent = await asyncio.to_thread(
create_deep_agent,
model=llm,
tools=tools,
system_prompt=system_prompt,
@ -312,5 +327,13 @@ async def create_surfsense_deep_agent(
checkpointer=checkpointer,
**deep_agent_kwargs,
)
_perf_log.info(
"[create_agent] Graph compiled (create_deep_agent) in %.3fs",
time.perf_counter() - _t0,
)
_perf_log.info(
"[create_agent] Total agent creation in %.3fs",
time.perf_counter() - _t_agent_total,
)
return agent

View file

@ -12,6 +12,7 @@ the sandbox is deleted so they remain downloadable after cleanup.
from __future__ import annotations
import asyncio
import contextlib
import logging
import os
import shutil
@ -56,6 +57,7 @@ class _TimeoutAwareSandbox(DaytonaSandbox):
_daytona_client: Daytona | None = None
_sandbox_cache: dict[str, _TimeoutAwareSandbox] = {}
THREAD_LABEL_KEY = "surfsense_thread"
@ -126,8 +128,8 @@ def _find_or_create(thread_id: str) -> _TimeoutAwareSandbox:
async def get_or_create_sandbox(thread_id: int | str) -> _TimeoutAwareSandbox:
"""Get or create a sandbox for a conversation thread.
Uses the thread_id as a label so the same sandbox persists
across multiple messages within the same conversation.
Uses an in-process cache keyed by thread_id so subsequent messages
in the same conversation reuse the sandbox object without an API call.
Args:
thread_id: The conversation thread identifier.
@ -135,11 +137,19 @@ async def get_or_create_sandbox(thread_id: int | str) -> _TimeoutAwareSandbox:
Returns:
DaytonaSandbox connected to the sandbox.
"""
return await asyncio.to_thread(_find_or_create, str(thread_id))
key = str(thread_id)
cached = _sandbox_cache.get(key)
if cached is not None:
logger.info("Reusing cached sandbox for thread %s", key)
return cached
sandbox = await asyncio.to_thread(_find_or_create, key)
_sandbox_cache[key] = sandbox
return sandbox
async def delete_sandbox(thread_id: int | str) -> None:
"""Delete the sandbox for a conversation thread."""
_sandbox_cache.pop(str(thread_id), None)
def _delete() -> None:
client = _get_client()
@ -209,6 +219,7 @@ async def persist_and_delete_sandbox(
Per-file errors are logged but do **not** prevent the sandbox from
being deleted freeing Daytona storage is the priority.
"""
_sandbox_cache.pop(str(thread_id), None)
def _persist_and_delete() -> None:
client = _get_client()
@ -232,10 +243,8 @@ async def persist_and_delete_sandbox(
sandbox.id,
exc_info=True,
)
try:
with contextlib.suppress(Exception):
client.delete(sandbox)
except Exception:
pass
return
for path in sandbox_file_paths:

View file

@ -11,6 +11,7 @@ This implements real MCP protocol support similar to Cursor's implementation.
"""
import logging
import time
from typing import Any
from langchain_core.tools import StructuredTool
@ -25,6 +26,9 @@ from app.db import SearchSourceConnector, SearchSourceConnectorType
logger = logging.getLogger(__name__)
_MCP_CACHE_TTL_SECONDS = 300 # 5 minutes
_mcp_tools_cache: dict[int, tuple[float, list[StructuredTool]]] = {}
def _create_dynamic_input_model_from_schema(
tool_name: str,
@ -355,6 +359,19 @@ async def _load_http_mcp_tools(
return tools
def invalidate_mcp_tools_cache(search_space_id: int | None = None) -> None:
"""Invalidate cached MCP tools.
Args:
search_space_id: If provided, only invalidate for this search space.
If None, invalidate all cached MCP tools.
"""
if search_space_id is not None:
_mcp_tools_cache.pop(search_space_id, None)
else:
_mcp_tools_cache.clear()
async def load_mcp_tools(
session: AsyncSession,
search_space_id: int,
@ -364,6 +381,9 @@ async def load_mcp_tools(
This discovers tools dynamically from MCP servers using the protocol.
Supports both stdio (local process) and HTTP (remote server) transports.
Results are cached per search space for up to 5 minutes to avoid
re-spawning MCP server processes on every chat message.
Args:
session: Database session
search_space_id: User's search space ID
@ -372,8 +392,20 @@ async def load_mcp_tools(
List of LangChain StructuredTool instances
"""
now = time.monotonic()
cached = _mcp_tools_cache.get(search_space_id)
if cached is not None:
cached_at, cached_tools = cached
if now - cached_at < _MCP_CACHE_TTL_SECONDS:
logger.info(
"Using cached MCP tools for search space %s (%d tools, age=%.0fs)",
search_space_id,
len(cached_tools),
now - cached_at,
)
return list(cached_tools)
try:
# Fetch all MCP connectors for this search space
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.connector_type
@ -385,27 +417,22 @@ async def load_mcp_tools(
tools: list[StructuredTool] = []
for connector in result.scalars():
try:
# Early validation: Extract and validate connector config
config = connector.config or {}
server_config = config.get("server_config", {})
# Validate server_config exists and is a dict
if not server_config or not isinstance(server_config, dict):
logger.warning(
f"MCP connector {connector.id} (name: '{connector.name}') has invalid or missing server_config, skipping"
)
continue
# Determine transport type
transport = server_config.get("transport", "stdio")
if transport in ("streamable-http", "http", "sse"):
# HTTP-based MCP server
connector_tools = await _load_http_mcp_tools(
connector.id, connector.name, server_config
)
else:
# stdio-based MCP server (default)
connector_tools = await _load_stdio_mcp_tools(
connector.id, connector.name, server_config
)
@ -417,6 +444,7 @@ async def load_mcp_tools(
f"Failed to load tools from MCP connector {connector.id}: {e!s}"
)
_mcp_tools_cache[search_space_id] = (now, tools)
logger.info(f"Loaded {len(tools)} MCP tools for search space {search_space_id}")
return tools

View file

@ -444,8 +444,18 @@ async def build_tools_async(
List of configured tool instances ready for the agent, including MCP tools.
"""
# Build standard tools
import time
_perf_log = logging.getLogger("surfsense.perf")
_perf_log.setLevel(logging.DEBUG)
_t0 = time.perf_counter()
tools = build_tools(dependencies, enabled_tools, disabled_tools, additional_tools)
_perf_log.info(
"[build_tools_async] Built-in tools in %.3fs (%d tools)",
time.perf_counter() - _t0,
len(tools),
)
# Load MCP tools if requested and dependencies are available
if (
@ -454,10 +464,16 @@ async def build_tools_async(
and "search_space_id" in dependencies
):
try:
_t0 = time.perf_counter()
mcp_tools = await load_mcp_tools(
dependencies["db_session"],
dependencies["search_space_id"],
)
_perf_log.info(
"[build_tools_async] MCP tools loaded in %.3fs (%d tools)",
time.perf_counter() - _t0,
len(mcp_tools),
)
tools.extend(mcp_tools)
logging.info(
f"Registered {len(mcp_tools)} MCP tools: {[t.name for t in mcp_tools]}",

View file

@ -175,8 +175,39 @@ def rate_limit_password_reset(request: Request):
)
def _enable_slow_callback_logging(threshold_sec: float = 0.5) -> None:
"""Monkey-patch the event loop to warn whenever a callback blocks longer than *threshold_sec*.
This helps pinpoint synchronous code that freezes the entire FastAPI server.
Only active when the PERF_DEBUG env var is set (to avoid overhead in production).
"""
import os
if not os.environ.get("PERF_DEBUG"):
return
_slow_log = logging.getLogger("surfsense.perf.slow")
_slow_log.setLevel(logging.WARNING)
if not _slow_log.handlers:
_h = logging.StreamHandler()
_h.setFormatter(logging.Formatter("%(asctime)s [SLOW-CALLBACK] %(message)s"))
_slow_log.addHandler(_h)
_slow_log.propagate = False
loop = asyncio.get_running_loop()
loop.slow_callback_duration = threshold_sec # type: ignore[attr-defined]
loop.set_debug(True)
_slow_log.warning(
"Event-loop slow-callback detector ENABLED (threshold=%.1fs). "
"Set PERF_DEBUG='' to disable.",
threshold_sec,
)
@asynccontextmanager
async def lifespan(app: FastAPI):
# Enable slow-callback detection (set PERF_DEBUG=1 env var to activate)
_enable_slow_callback_logging(threshold_sec=0.5)
# Not needed if you setup a migration system like Alembic
await create_db_and_tables()
# Setup LangGraph checkpointer tables for conversation persistence

View file

@ -12,6 +12,7 @@ from litellm.exceptions import (
Timeout,
UnprocessableEntityError,
)
from sqlalchemy.exc import IntegrityError as IntegrityError
# Tuples for use directly in except clauses.
RETRYABLE_LLM_ERRORS = (
@ -37,6 +38,8 @@ EMBEDDING_ERRORS = (
RuntimeError, # local device failure or API backend normalization
OSError, # model files missing or corrupted (local backends)
MemoryError, # document too large for available RAM
OSError, # model files missing or corrupted (local backends)
MemoryError, # document too large for available RAM
)
@ -47,7 +50,21 @@ class PipelineMessages:
LLM_BAD_GATEWAY = "LLM gateway error. Will retry on next sync."
LLM_SERVER_ERROR = "LLM internal server error. Will retry on next sync."
LLM_CONNECTION = "Could not reach the LLM service. Check network connectivity."
RATE_LIMIT = "LLM rate limit exceeded. Will retry on next sync."
LLM_TIMEOUT = "LLM request timed out. Will retry on next sync."
LLM_UNAVAILABLE = "LLM service temporarily unavailable. Will retry on next sync."
LLM_BAD_GATEWAY = "LLM gateway error. Will retry on next sync."
LLM_SERVER_ERROR = "LLM internal server error. Will retry on next sync."
LLM_CONNECTION = "Could not reach the LLM service. Check network connectivity."
LLM_AUTH = "LLM authentication failed. Check your API key."
LLM_PERMISSION = "LLM request denied. Check your account permissions."
LLM_NOT_FOUND = "LLM model not found. Check your model configuration."
LLM_BAD_REQUEST = "LLM rejected the request. Document content may be invalid."
LLM_UNPROCESSABLE = (
"Document exceeds the LLM context window even after optimization."
)
LLM_RESPONSE = "LLM returned an invalid response."
LLM_AUTH = "LLM authentication failed. Check your API key."
LLM_PERMISSION = "LLM request denied. Check your account permissions."
LLM_NOT_FOUND = "LLM model not found. Check your model configuration."
@ -57,6 +74,11 @@ class PipelineMessages:
)
LLM_RESPONSE = "LLM returned an invalid response."
EMBEDDING_FAILED = (
"Embedding failed. Check your embedding model configuration or service."
)
EMBEDDING_MODEL = "Embedding model files are missing or corrupted."
EMBEDDING_MEMORY = "Not enough memory to embed this document."
EMBEDDING_FAILED = (
"Embedding failed. Check your embedding model configuration or service."
)

View file

@ -28,6 +28,7 @@ from app.schemas import (
DocumentWithChunksRead,
PaginatedResponse,
)
from app.services.task_dispatcher import TaskDispatcher, get_task_dispatcher
from app.users import current_active_user
from app.utils.rbac import check_permission
@ -120,6 +121,7 @@ async def create_documents_file_upload(
search_space_id: int = Form(...),
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
dispatcher: TaskDispatcher = Depends(get_task_dispatcher),
):
"""
Upload files as documents with real-time status tracking.
@ -290,14 +292,10 @@ async def create_documents_file_upload(
for doc in created_documents:
await session.refresh(doc)
# ===== PHASE 2: Dispatch Celery tasks for each file =====
# ===== PHASE 2: Dispatch tasks for each file =====
# Each task will update document status: pending → processing → ready/failed
from app.tasks.celery_tasks.document_tasks import (
process_file_upload_with_document_task,
)
for document, temp_path, filename in files_to_process:
process_file_upload_with_document_task.delay(
await dispatcher.dispatch_file_processing(
document_id=document.id,
temp_path=temp_path,
filename=filename,

View file

@ -55,6 +55,7 @@ from app.users import current_active_user
from app.utils.rbac import check_permission
_logger = logging.getLogger(__name__)
_background_tasks: set[asyncio.Task] = set()
router = APIRouter()
@ -90,7 +91,9 @@ def _try_delete_sandbox(thread_id: int) -> None:
try:
loop = asyncio.get_running_loop()
loop.create_task(_bg())
task = loop.create_task(_bg())
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
except RuntimeError:
pass

View file

@ -2735,7 +2735,10 @@ async def create_mcp_connector(
f"for user {user.id} in search space {search_space_id}"
)
# Convert to read schema
from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache
invalidate_mcp_tools_cache(search_space_id)
connector_read = SearchSourceConnectorRead.model_validate(db_connector)
return MCPConnectorRead.from_connector(connector_read)
@ -2910,6 +2913,10 @@ async def update_mcp_connector(
logger.info(f"Updated MCP connector {connector_id}")
from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache
invalidate_mcp_tools_cache(connector.search_space_id)
connector_read = SearchSourceConnectorRead.model_validate(connector)
return MCPConnectorRead.from_connector(connector_read)
@ -2960,9 +2967,14 @@ async def delete_mcp_connector(
"You don't have permission to delete this connector",
)
search_space_id = connector.search_space_id
await session.delete(connector)
await session.commit()
from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache
invalidate_mcp_tools_cache(search_space_id)
logger.info(f"Deleted MCP connector {connector_id}")
except HTTPException:

View file

@ -0,0 +1,50 @@
"""Task dispatcher abstraction for background document processing.
Decouples the upload endpoint from Celery so tests can swap in a
synchronous (inline) implementation that needs only PostgreSQL.
"""
from __future__ import annotations
from typing import Protocol
class TaskDispatcher(Protocol):
async def dispatch_file_processing(
self,
*,
document_id: int,
temp_path: str,
filename: str,
search_space_id: int,
user_id: str,
) -> None: ...
class CeleryTaskDispatcher:
"""Production dispatcher — fires Celery tasks via Redis broker."""
async def dispatch_file_processing(
self,
*,
document_id: int,
temp_path: str,
filename: str,
search_space_id: int,
user_id: str,
) -> None:
from app.tasks.celery_tasks.document_tasks import (
process_file_upload_with_document_task,
)
process_file_upload_with_document_task.delay(
document_id=document_id,
temp_path=temp_path,
filename=filename,
search_space_id=search_space_id,
user_id=user_id,
)
async def get_task_dispatcher() -> TaskDispatcher:
return CeleryTaskDispatcher()

View file

@ -13,14 +13,17 @@ import asyncio
import json
import logging
import re
import time
from collections.abc import AsyncGenerator
from dataclasses import dataclass, field
from typing import Any
from uuid import UUID
from langchain_core.messages import HumanMessage
from sqlalchemy import func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
from app.agents.new_chat.checkpointer import get_checkpointer
@ -31,10 +34,17 @@ from app.agents.new_chat.llm_config import (
load_agent_config,
load_llm_config_from_yaml,
)
from app.agents.new_chat.sandbox import (
get_or_create_sandbox,
is_sandbox_enabled,
)
from app.db import (
ChatVisibility,
Document,
NewChatMessage,
NewChatThread,
Report,
SearchSourceConnectorType,
SurfsenseDocsDocument,
async_session_maker,
)
@ -47,6 +57,16 @@ from app.services.connector_service import ConnectorService
from app.services.new_streaming_service import VercelStreamingService
from app.utils.content_utils import bootstrap_history_from_db
_perf_log = logging.getLogger("surfsense.perf")
_perf_log.setLevel(logging.DEBUG)
if not _perf_log.handlers:
_h = logging.StreamHandler()
_h.setFormatter(logging.Formatter("%(asctime)s [PERF] %(message)s"))
_perf_log.addHandler(_h)
_perf_log.propagate = False
_background_tasks: set[asyncio.Task] = set()
def format_mentioned_documents_as_context(documents: list[Document]) -> str:
"""
@ -985,7 +1005,9 @@ def _try_persist_and_delete_sandbox(
try:
loop = asyncio.get_running_loop()
loop.create_task(_run())
task = loop.create_task(_run())
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
except RuntimeError:
pass
@ -1027,6 +1049,7 @@ async def stream_new_chat(
"""
streaming_service = VercelStreamingService()
stream_result = StreamResult()
_t_total = time.perf_counter()
try:
# Mark AI as responding to this user for live collaboration
@ -1035,6 +1058,7 @@ async def stream_new_chat(
# Load LLM config - supports both YAML (negative IDs) and database (positive IDs)
agent_config: AgentConfig | None = None
_t0 = time.perf_counter()
if llm_config_id >= 0:
# Positive ID: Load from NewLLMConfig database table
agent_config = await load_agent_config(
@ -1065,6 +1089,11 @@ async def stream_new_chat(
llm = create_chat_litellm_from_config(llm_config)
# Create AgentConfig from YAML for consistency (uses defaults for prompt settings)
agent_config = AgentConfig.from_yaml_config(llm_config)
_perf_log.info(
"[stream_new_chat] LLM config loaded in %.3fs (config_id=%s)",
time.perf_counter() - _t0,
llm_config_id,
)
if not llm:
yield streaming_service.format_error("Failed to create LLM instance")
@ -1072,28 +1101,29 @@ async def stream_new_chat(
return
# Create connector service
_t0 = time.perf_counter()
connector_service = ConnectorService(session, search_space_id=search_space_id)
# Get Firecrawl API key from webcrawler connector if configured
from app.db import SearchSourceConnectorType
firecrawl_api_key = None
webcrawler_connector = await connector_service.get_connector_by_type(
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id
)
if webcrawler_connector and webcrawler_connector.config:
firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY")
# Get the PostgreSQL checkpointer for persistent conversation memory
checkpointer = await get_checkpointer()
# Optionally provision a sandboxed code execution environment
sandbox_backend = None
from app.agents.new_chat.sandbox import (
get_or_create_sandbox,
is_sandbox_enabled,
_perf_log.info(
"[stream_new_chat] Connector service + firecrawl key in %.3fs",
time.perf_counter() - _t0,
)
# Get the PostgreSQL checkpointer for persistent conversation memory
_t0 = time.perf_counter()
checkpointer = await get_checkpointer()
_perf_log.info(
"[stream_new_chat] Checkpointer ready in %.3fs", time.perf_counter() - _t0
)
sandbox_backend = None
_t0 = time.perf_counter()
if is_sandbox_enabled():
try:
sandbox_backend = await get_or_create_sandbox(chat_id)
@ -1102,8 +1132,14 @@ async def stream_new_chat(
"Sandbox creation failed, continuing without execute tool: %s",
sandbox_err,
)
_perf_log.info(
"[stream_new_chat] Sandbox provisioning in %.3fs (enabled=%s)",
time.perf_counter() - _t0,
sandbox_backend is not None,
)
visibility = thread_visibility or ChatVisibility.PRIVATE
_t0 = time.perf_counter()
agent = await create_surfsense_deep_agent(
llm=llm,
search_space_id=search_space_id,
@ -1117,19 +1153,20 @@ async def stream_new_chat(
thread_visibility=visibility,
sandbox_backend=sandbox_backend,
)
_perf_log.info(
"[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0
)
# Build input with message history
langchain_messages = []
_t0 = time.perf_counter()
# Bootstrap history for cloned chats (no LangGraph checkpoint exists yet)
if needs_history_bootstrap:
langchain_messages = await bootstrap_history_from_db(
session, chat_id, thread_visibility=visibility
)
# Clear the flag so we don't bootstrap again on next message
from app.db import NewChatThread
thread_result = await session.execute(
select(NewChatThread).filter(NewChatThread.id == chat_id)
)
@ -1141,11 +1178,9 @@ async def stream_new_chat(
# Fetch mentioned documents if any (with chunks for proper citations)
mentioned_documents: list[Document] = []
if mentioned_document_ids:
from sqlalchemy.orm import selectinload as doc_selectinload
result = await session.execute(
select(Document)
.options(doc_selectinload(Document.chunks))
.options(selectinload(Document.chunks))
.filter(
Document.id.in_(mentioned_document_ids),
Document.search_space_id == search_space_id,
@ -1156,8 +1191,6 @@ async def stream_new_chat(
# Fetch mentioned SurfSense docs if any
mentioned_surfsense_docs: list[SurfsenseDocsDocument] = []
if mentioned_surfsense_doc_ids:
from sqlalchemy.orm import selectinload
result = await session.execute(
select(SurfsenseDocsDocument)
.options(selectinload(SurfsenseDocsDocument.chunks))
@ -1241,6 +1274,11 @@ async def stream_new_chat(
"search_space_id": search_space_id,
}
_perf_log.info(
"[stream_new_chat] History bootstrap + doc/report queries in %.3fs",
time.perf_counter() - _t0,
)
# All pre-streaming DB reads are done. Commit to release the
# transaction and its ACCESS SHARE locks so we don't block DDL
# (e.g. migrations) for the entire duration of LLM streaming.
@ -1248,6 +1286,12 @@ async def stream_new_chat(
# short-lived transactions (or use isolated sessions).
await session.commit()
_perf_log.info(
"[stream_new_chat] Total pre-stream setup in %.3fs (chat_id=%s)",
time.perf_counter() - _t_total,
chat_id,
)
# Configure LangGraph with thread_id for memory
# If checkpoint_id is provided, fork from that checkpoint (for edit/reload)
configurable = {"thread_id": str(chat_id)}
@ -1309,6 +1353,8 @@ async def stream_new_chat(
items=initial_items,
)
_t_stream_start = time.perf_counter()
_first_event_logged = False
async for sse in _stream_agent_events(
agent=agent,
config=config,
@ -1320,8 +1366,23 @@ async def stream_new_chat(
initial_step_title=initial_title,
initial_step_items=initial_items,
):
if not _first_event_logged:
_perf_log.info(
"[stream_new_chat] First agent event in %.3fs (time since stream start), "
"%.3fs (total since request start) (chat_id=%s)",
time.perf_counter() - _t_stream_start,
time.perf_counter() - _t_total,
chat_id,
)
_first_event_logged = True
yield sse
_perf_log.info(
"[stream_new_chat] Agent stream completed in %.3fs (chat_id=%s)",
time.perf_counter() - _t_stream_start,
chat_id,
)
if stream_result.is_interrupted:
yield streaming_service.format_finish_step()
yield streaming_service.format_finish()
@ -1330,12 +1391,6 @@ async def stream_new_chat(
accumulated_text = stream_result.accumulated_text
# Generate LLM title for new chats after first response
# Check if this is the first assistant response by counting existing assistant messages
from sqlalchemy import func
from app.db import NewChatMessage, NewChatThread
assistant_count_result = await session.execute(
select(func.count(NewChatMessage.id)).filter(
NewChatMessage.thread_id == chat_id,
@ -1436,12 +1491,14 @@ async def stream_resume_chat(
) -> AsyncGenerator[str, None]:
streaming_service = VercelStreamingService()
stream_result = StreamResult()
_t_total = time.perf_counter()
try:
if user_id:
await set_ai_responding(session, chat_id, UUID(user_id))
agent_config: AgentConfig | None = None
_t0 = time.perf_counter()
if llm_config_id >= 0:
agent_config = await load_agent_config(
session=session,
@ -1465,31 +1522,37 @@ async def stream_resume_chat(
return
llm = create_chat_litellm_from_config(llm_config)
agent_config = AgentConfig.from_yaml_config(llm_config)
_perf_log.info(
"[stream_resume] LLM config loaded in %.3fs", time.perf_counter() - _t0
)
if not llm:
yield streaming_service.format_error("Failed to create LLM instance")
yield streaming_service.format_done()
return
_t0 = time.perf_counter()
connector_service = ConnectorService(session, search_space_id=search_space_id)
from app.db import SearchSourceConnectorType
firecrawl_api_key = None
webcrawler_connector = await connector_service.get_connector_by_type(
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id
)
if webcrawler_connector and webcrawler_connector.config:
firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY")
checkpointer = await get_checkpointer()
sandbox_backend = None
from app.agents.new_chat.sandbox import (
get_or_create_sandbox,
is_sandbox_enabled,
_perf_log.info(
"[stream_resume] Connector service + firecrawl key in %.3fs",
time.perf_counter() - _t0,
)
_t0 = time.perf_counter()
checkpointer = await get_checkpointer()
_perf_log.info(
"[stream_resume] Checkpointer ready in %.3fs", time.perf_counter() - _t0
)
sandbox_backend = None
_t0 = time.perf_counter()
if is_sandbox_enabled():
try:
sandbox_backend = await get_or_create_sandbox(chat_id)
@ -1498,9 +1561,15 @@ async def stream_resume_chat(
"Sandbox creation failed, continuing without execute tool: %s",
sandbox_err,
)
_perf_log.info(
"[stream_resume] Sandbox provisioning in %.3fs (enabled=%s)",
time.perf_counter() - _t0,
sandbox_backend is not None,
)
visibility = thread_visibility or ChatVisibility.PRIVATE
_t0 = time.perf_counter()
agent = await create_surfsense_deep_agent(
llm=llm,
search_space_id=search_space_id,
@ -1514,10 +1583,19 @@ async def stream_resume_chat(
thread_visibility=visibility,
sandbox_backend=sandbox_backend,
)
_perf_log.info(
"[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0
)
# Release the transaction before streaming (same rationale as stream_new_chat).
await session.commit()
_perf_log.info(
"[stream_resume] Total pre-stream setup in %.3fs (chat_id=%s)",
time.perf_counter() - _t_total,
chat_id,
)
from langgraph.types import Command
config = {
@ -1528,6 +1606,8 @@ async def stream_resume_chat(
yield streaming_service.format_message_start()
yield streaming_service.format_start_step()
_t_stream_start = time.perf_counter()
_first_event_logged = False
async for sse in _stream_agent_events(
agent=agent,
config=config,
@ -1536,7 +1616,20 @@ async def stream_resume_chat(
result=stream_result,
step_prefix="thinking-resume",
):
if not _first_event_logged:
_perf_log.info(
"[stream_resume] First agent event in %.3fs (stream), %.3fs (total) (chat_id=%s)",
time.perf_counter() - _t_stream_start,
time.perf_counter() - _t_total,
chat_id,
)
_first_event_logged = True
yield sse
_perf_log.info(
"[stream_resume] Agent stream completed in %.3fs (chat_id=%s)",
time.perf_counter() - _t_stream_start,
chat_id,
)
if stream_result.is_interrupted:
yield streaming_service.format_finish_step()
yield streaming_service.format_finish()

View file

@ -178,8 +178,7 @@ python_functions = ["test_*"]
addopts = "-v --tb=short -x --strict-markers -ra --durations=5"
markers = [
"unit: pure logic tests, no DB or external services",
"integration: tests that require a real PostgreSQL database",
"e2e: tests requiring a running backend and real HTTP calls"
"integration: tests that require a real PostgreSQL database"
]
filterwarnings = [
"ignore::UserWarning:chonkie",

View file

@ -3,23 +3,21 @@
from __future__ import annotations
import os
from pathlib import Path
_DEFAULT_TEST_DB = (
"postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense_test"
)
TEST_DATABASE_URL = os.environ.get("TEST_DATABASE_URL", _DEFAULT_TEST_DB)
# Force the app to use the test database regardless of any pre-existing
# DATABASE_URL in the environment (e.g. from .env or shell profile).
os.environ["DATABASE_URL"] = TEST_DATABASE_URL
import pytest
from dotenv import load_dotenv
from app.db import DocumentType
from app.indexing_pipeline.connector_document import ConnectorDocument
load_dotenv(Path(__file__).resolve().parent.parent / ".env")
# Shared DB URL referenced by both e2e and integration helper functions.
DATABASE_URL = os.environ.get(
"TEST_DATABASE_URL",
os.environ.get("DATABASE_URL", ""),
).replace("postgresql+asyncpg://", "postgresql://")
# ---------------------------------------------------------------------------
# Unit test fixtures
# ---------------------------------------------------------------------------

View file

@ -1,198 +0,0 @@
"""E2e conftest — fixtures that require a running backend + database."""
from __future__ import annotations
from collections.abc import AsyncGenerator
import asyncpg
import httpx
import pytest
from tests.conftest import DATABASE_URL
from tests.utils.helpers import (
BACKEND_URL,
TEST_EMAIL,
auth_headers,
delete_document,
get_auth_token,
get_search_space_id,
)
# ---------------------------------------------------------------------------
# Backend connectivity fixtures
# ---------------------------------------------------------------------------
@pytest.fixture(scope="session")
def backend_url() -> str:
return BACKEND_URL
@pytest.fixture(scope="session")
async def auth_token(backend_url: str) -> str:
"""Authenticate once per session, registering the user if needed."""
async with httpx.AsyncClient(base_url=backend_url, timeout=30.0) as client:
return await get_auth_token(client)
@pytest.fixture(scope="session")
async def search_space_id(backend_url: str, auth_token: str) -> int:
"""Discover the first search space belonging to the test user."""
async with httpx.AsyncClient(base_url=backend_url, timeout=30.0) as client:
return await get_search_space_id(client, auth_token)
@pytest.fixture(scope="session", autouse=True)
async def _purge_test_search_space(
search_space_id: int,
):
"""
Delete all documents in the test search space before the session starts.
Uses direct database access to bypass the API's 409 protection on
pending/processing documents. This ensures stuck documents from
previous crashed runs are always cleaned up.
"""
deleted = await _force_delete_documents_db(search_space_id)
if deleted:
print(
f"\n[purge] Deleted {deleted} stale document(s) from search space {search_space_id}"
)
yield
@pytest.fixture(scope="session")
def headers(auth_token: str) -> dict[str, str]:
"""Authorization headers reused across all tests in the session."""
return auth_headers(auth_token)
@pytest.fixture
async def client(backend_url: str) -> AsyncGenerator[httpx.AsyncClient]:
"""Per-test async HTTP client pointing at the running backend."""
async with httpx.AsyncClient(base_url=backend_url, timeout=180.0) as c:
yield c
@pytest.fixture
def cleanup_doc_ids() -> list[int]:
"""Accumulator for document IDs that should be deleted after the test."""
return []
@pytest.fixture(autouse=True)
async def _cleanup_documents(
client: httpx.AsyncClient,
headers: dict[str, str],
search_space_id: int,
cleanup_doc_ids: list[int],
):
"""
Runs after every test. Tries the API first for clean deletes, then
falls back to direct DB access for any stuck documents.
"""
yield
remaining_ids: list[int] = []
for doc_id in cleanup_doc_ids:
try:
resp = await delete_document(client, headers, doc_id)
if resp.status_code == 409:
remaining_ids.append(doc_id)
except Exception:
remaining_ids.append(doc_id)
if remaining_ids:
conn = await asyncpg.connect(DATABASE_URL)
try:
await conn.execute(
"DELETE FROM documents WHERE id = ANY($1::int[])",
remaining_ids,
)
finally:
await conn.close()
# ---------------------------------------------------------------------------
# Page-limit helpers (direct DB access)
# ---------------------------------------------------------------------------
async def _force_delete_documents_db(search_space_id: int) -> int:
"""
Bypass the API and delete documents directly from the database.
This handles stuck documents in pending/processing state that the API
refuses to delete (409 Conflict). Chunks are cascade-deleted by the
foreign key constraint.
Returns the number of deleted rows.
"""
conn = await asyncpg.connect(DATABASE_URL)
try:
result = await conn.execute(
"DELETE FROM documents WHERE search_space_id = $1",
search_space_id,
)
return int(result.split()[-1])
finally:
await conn.close()
async def _get_user_page_usage(email: str) -> tuple[int, int]:
"""Return ``(pages_used, pages_limit)`` for the given user."""
conn = await asyncpg.connect(DATABASE_URL)
try:
row = await conn.fetchrow(
'SELECT pages_used, pages_limit FROM "user" WHERE email = $1',
email,
)
assert row is not None, f"User {email!r} not found in database"
return row["pages_used"], row["pages_limit"]
finally:
await conn.close()
async def _set_user_page_limits(
email: str, *, pages_used: int, pages_limit: int
) -> None:
"""Overwrite ``pages_used`` and ``pages_limit`` for the given user."""
conn = await asyncpg.connect(DATABASE_URL)
try:
await conn.execute(
'UPDATE "user" SET pages_used = $1, pages_limit = $2 WHERE email = $3',
pages_used,
pages_limit,
email,
)
finally:
await conn.close()
@pytest.fixture
async def page_limits():
"""
Fixture that exposes helpers for manipulating the test user's page limits.
Automatically restores the original values after each test.
Usage inside a test::
await page_limits.set(pages_used=0, pages_limit=100)
used, limit = await page_limits.get()
"""
class _PageLimits:
async def set(self, *, pages_used: int, pages_limit: int) -> None:
await _set_user_page_limits(
TEST_EMAIL, pages_used=pages_used, pages_limit=pages_limit
)
async def get(self) -> tuple[int, int]:
return await _get_user_page_usage(TEST_EMAIL)
original = await _get_user_page_usage(TEST_EMAIL)
yield _PageLimits()
await _set_user_page_limits(
TEST_EMAIL, pages_used=original[0], pages_limit=original[1]
)

View file

@ -1,4 +1,3 @@
import os
import uuid
from unittest.mock import AsyncMock, MagicMock
@ -8,6 +7,7 @@ from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.pool import NullPool
from app.config import config as app_config
from app.db import (
Base,
DocumentType,
@ -17,13 +17,9 @@ from app.db import (
User,
)
from app.indexing_pipeline.connector_document import ConnectorDocument
from tests.conftest import TEST_DATABASE_URL
_EMBEDDING_DIM = 1024 # must match the Vector() dimension used in DB column creation
_DEFAULT_TEST_DB = (
"postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense_test"
)
TEST_DATABASE_URL = os.environ.get("TEST_DATABASE_URL", _DEFAULT_TEST_DB)
_EMBEDDING_DIM = app_config.embedding_model_instance.dimension
@pytest_asyncio.fixture(scope="session")

View file

@ -0,0 +1,283 @@
"""Integration conftest — runs the FastAPI app in-process via ASGITransport.
Prerequisites: PostgreSQL + pgvector only.
External system boundaries are mocked:
- LLM summarization, text embedding, text chunking (external APIs)
- Redis heartbeat (external infrastructure)
- Task dispatch is swapped via DI (InlineTaskDispatcher)
"""
from __future__ import annotations
import contextlib
from collections.abc import AsyncGenerator
from unittest.mock import AsyncMock, MagicMock
import asyncpg
import httpx
import pytest
from httpx import ASGITransport
from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.pool import NullPool
from app.app import app
from app.config import config as app_config
from app.db import Base
from app.services.task_dispatcher import get_task_dispatcher
from tests.integration.conftest import TEST_DATABASE_URL
from tests.utils.helpers import (
TEST_EMAIL,
auth_headers,
delete_document,
get_auth_token,
get_search_space_id,
)
_EMBEDDING_DIM = app_config.embedding_model_instance.dimension
_ASYNCPG_URL = TEST_DATABASE_URL.replace("postgresql+asyncpg://", "postgresql://")
pytestmark = pytest.mark.integration
# ---------------------------------------------------------------------------
# Inline task dispatcher (replaces Celery via DI — not a mock)
# ---------------------------------------------------------------------------
class InlineTaskDispatcher:
"""Processes files synchronously in the calling coroutine.
Swapped in via FastAPI dependency_overrides so the upload endpoint
processes documents inline instead of dispatching to Celery.
Exceptions are caught to match Celery's fire-and-forget semantics —
the processing function already marks documents as failed internally.
"""
async def dispatch_file_processing(
self,
*,
document_id: int,
temp_path: str,
filename: str,
search_space_id: int,
user_id: str,
) -> None:
from app.tasks.celery_tasks.document_tasks import (
_process_file_with_document,
)
with contextlib.suppress(Exception):
await _process_file_with_document(
document_id, temp_path, filename, search_space_id, user_id
)
app.dependency_overrides[get_task_dispatcher] = lambda: InlineTaskDispatcher()
# ---------------------------------------------------------------------------
# Database setup (ASGITransport skips the app lifespan)
# ---------------------------------------------------------------------------
@pytest.fixture(scope="session")
async def _ensure_tables():
"""Create DB tables and extensions once per session."""
engine = create_async_engine(TEST_DATABASE_URL, poolclass=NullPool)
async with engine.begin() as conn:
await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
await conn.execute(text("CREATE EXTENSION IF NOT EXISTS pg_trgm"))
await conn.run_sync(Base.metadata.create_all)
await engine.dispose()
# ---------------------------------------------------------------------------
# Auth & search space (session-scoped, via the in-process app)
# ---------------------------------------------------------------------------
@pytest.fixture(scope="session")
async def auth_token(_ensure_tables) -> str:
"""Authenticate once per session, registering the user if needed."""
async with httpx.AsyncClient(
transport=ASGITransport(app=app), base_url="http://test", timeout=30.0
) as c:
return await get_auth_token(c)
@pytest.fixture(scope="session")
async def search_space_id(auth_token: str) -> int:
"""Discover the first search space belonging to the test user."""
async with httpx.AsyncClient(
transport=ASGITransport(app=app), base_url="http://test", timeout=30.0
) as c:
return await get_search_space_id(c, auth_token)
@pytest.fixture(scope="session")
def headers(auth_token: str) -> dict[str, str]:
return auth_headers(auth_token)
# ---------------------------------------------------------------------------
# Per-test HTTP client & cleanup
# ---------------------------------------------------------------------------
@pytest.fixture
async def client() -> AsyncGenerator[httpx.AsyncClient]:
"""Per-test async HTTP client using ASGITransport (no running server)."""
async with httpx.AsyncClient(
transport=ASGITransport(app=app), base_url="http://test", timeout=180.0
) as c:
yield c
@pytest.fixture
def cleanup_doc_ids() -> list[int]:
"""Accumulator for document IDs that should be deleted after the test."""
return []
@pytest.fixture(scope="session", autouse=True)
async def _purge_test_search_space(search_space_id: int):
"""Delete stale documents from previous runs before the session starts."""
conn = await asyncpg.connect(_ASYNCPG_URL)
try:
result = await conn.execute(
"DELETE FROM documents WHERE search_space_id = $1",
search_space_id,
)
deleted = int(result.split()[-1])
if deleted:
print(
f"\n[purge] Deleted {deleted} stale document(s) "
f"from search space {search_space_id}"
)
finally:
await conn.close()
yield
@pytest.fixture(autouse=True)
async def _cleanup_documents(
client: httpx.AsyncClient,
headers: dict[str, str],
cleanup_doc_ids: list[int],
):
"""Delete test documents after every test (API first, DB fallback)."""
yield
remaining_ids: list[int] = []
for doc_id in cleanup_doc_ids:
try:
resp = await delete_document(client, headers, doc_id)
if resp.status_code == 409:
remaining_ids.append(doc_id)
except Exception:
remaining_ids.append(doc_id)
if remaining_ids:
conn = await asyncpg.connect(_ASYNCPG_URL)
try:
await conn.execute(
"DELETE FROM documents WHERE id = ANY($1::int[])",
remaining_ids,
)
finally:
await conn.close()
# ---------------------------------------------------------------------------
# Page-limit helpers (direct DB for setup, API for verification)
# ---------------------------------------------------------------------------
async def _get_user_page_usage(email: str) -> tuple[int, int]:
conn = await asyncpg.connect(_ASYNCPG_URL)
try:
row = await conn.fetchrow(
'SELECT pages_used, pages_limit FROM "user" WHERE email = $1',
email,
)
assert row is not None, f"User {email!r} not found in database"
return row["pages_used"], row["pages_limit"]
finally:
await conn.close()
async def _set_user_page_limits(
email: str, *, pages_used: int, pages_limit: int
) -> None:
conn = await asyncpg.connect(_ASYNCPG_URL)
try:
await conn.execute(
'UPDATE "user" SET pages_used = $1, pages_limit = $2 WHERE email = $3',
pages_used,
pages_limit,
email,
)
finally:
await conn.close()
@pytest.fixture
async def page_limits():
"""Manipulate the test user's page limits (direct DB for setup only).
Automatically restores original values after each test.
"""
class _PageLimits:
async def set(self, *, pages_used: int, pages_limit: int) -> None:
await _set_user_page_limits(
TEST_EMAIL, pages_used=pages_used, pages_limit=pages_limit
)
original = await _get_user_page_usage(TEST_EMAIL)
yield _PageLimits()
await _set_user_page_limits(
TEST_EMAIL, pages_used=original[0], pages_limit=original[1]
)
# ---------------------------------------------------------------------------
# Mock external system boundaries
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def _mock_external_apis(monkeypatch):
"""Mock LLM, embedding, and chunking — these are external API boundaries."""
monkeypatch.setattr(
"app.indexing_pipeline.indexing_pipeline_service.summarize_document",
AsyncMock(return_value="Mocked summary."),
)
monkeypatch.setattr(
"app.indexing_pipeline.indexing_pipeline_service.embed_text",
MagicMock(return_value=[0.1] * _EMBEDDING_DIM),
)
monkeypatch.setattr(
"app.indexing_pipeline.indexing_pipeline_service.chunk_text",
MagicMock(return_value=["Test chunk content."]),
)
@pytest.fixture(autouse=True)
def _mock_redis_heartbeat(monkeypatch):
"""Mock Redis heartbeat — Redis is an external infrastructure boundary."""
monkeypatch.setattr(
"app.tasks.celery_tasks.document_tasks._start_heartbeat",
lambda notification_id: None,
)
monkeypatch.setattr(
"app.tasks.celery_tasks.document_tasks._stop_heartbeat",
lambda notification_id: None,
)
monkeypatch.setattr(
"app.tasks.celery_tasks.document_tasks._run_heartbeat_loop",
AsyncMock(),
)

View file

@ -1,14 +1,10 @@
"""
End-to-end tests for manual document upload.
Integration tests for the document upload HTTP API.
These tests exercise the full pipeline:
API upload Celery task ETL extraction chunking embedding DB storage
Covers the API contract, auth, duplicate detection, and error handling.
Pipeline internals are tested in the ``indexing_pipeline`` suite.
Prerequisites (must be running):
- FastAPI backend
- PostgreSQL + pgvector
- Redis
- Celery worker
Requires PostgreSQL + pgvector.
"""
from __future__ import annotations
@ -21,36 +17,21 @@ import pytest
from tests.utils.helpers import (
FIXTURES_DIR,
delete_document,
get_document,
poll_document_status,
upload_file,
upload_multiple_files,
)
pytestmark = pytest.mark.e2e
# ---------------------------------------------------------------------------
# Helpers local to this module
# ---------------------------------------------------------------------------
def _assert_document_ready(doc: dict, *, expected_filename: str) -> None:
"""Common assertions for a successfully processed document."""
assert doc["title"] == expected_filename
assert doc["document_type"] == "FILE"
assert doc["content"], "Document content (summary) should not be empty"
assert doc["content_hash"], "content_hash should be set"
assert doc["document_metadata"].get("FILE_NAME") == expected_filename
pytestmark = pytest.mark.integration
# ---------------------------------------------------------------------------
# Test A: Upload a .txt file (direct read path — no ETL service needed)
# Upload smoke tests (one per distinct code-path: direct-read & ETL)
# ---------------------------------------------------------------------------
class TestTxtFileUpload:
"""Upload a plain-text file and verify the full pipeline."""
"""Upload a plain-text file (direct-read path) via the HTTP API."""
async def test_upload_txt_returns_document_id(
self,
@ -89,85 +70,9 @@ class TestTxtFileUpload:
for did in doc_ids:
assert statuses[did]["status"]["state"] == "ready"
async def test_txt_document_fields_populated(
self,
client: httpx.AsyncClient,
headers: dict[str, str],
search_space_id: int,
cleanup_doc_ids: list[int],
):
resp = await upload_file(
client, headers, "sample.txt", search_space_id=search_space_id
)
doc_ids = resp.json()["document_ids"]
cleanup_doc_ids.extend(doc_ids)
await poll_document_status(
client, headers, doc_ids, search_space_id=search_space_id
)
doc = await get_document(client, headers, doc_ids[0])
_assert_document_ready(doc, expected_filename="sample.txt")
assert doc["document_metadata"]["ETL_SERVICE"] == "MARKDOWN"
# ---------------------------------------------------------------------------
# Test B: Upload a .md file (markdown direct-read path)
# ---------------------------------------------------------------------------
class TestMarkdownFileUpload:
"""Upload a Markdown file and verify the full pipeline."""
async def test_md_processing_reaches_ready(
self,
client: httpx.AsyncClient,
headers: dict[str, str],
search_space_id: int,
cleanup_doc_ids: list[int],
):
resp = await upload_file(
client, headers, "sample.md", search_space_id=search_space_id
)
assert resp.status_code == 200
doc_ids = resp.json()["document_ids"]
cleanup_doc_ids.extend(doc_ids)
statuses = await poll_document_status(
client, headers, doc_ids, search_space_id=search_space_id
)
for did in doc_ids:
assert statuses[did]["status"]["state"] == "ready"
async def test_md_document_fields_populated(
self,
client: httpx.AsyncClient,
headers: dict[str, str],
search_space_id: int,
cleanup_doc_ids: list[int],
):
resp = await upload_file(
client, headers, "sample.md", search_space_id=search_space_id
)
doc_ids = resp.json()["document_ids"]
cleanup_doc_ids.extend(doc_ids)
await poll_document_status(
client, headers, doc_ids, search_space_id=search_space_id
)
doc = await get_document(client, headers, doc_ids[0])
_assert_document_ready(doc, expected_filename="sample.md")
assert doc["document_metadata"]["ETL_SERVICE"] == "MARKDOWN"
# ---------------------------------------------------------------------------
# Test C: Upload a .pdf file (ETL path — Docling / Unstructured)
# ---------------------------------------------------------------------------
class TestPdfFileUpload:
"""Upload a PDF and verify it goes through the ETL extraction pipeline."""
"""Upload a PDF (ETL extraction path) via the HTTP API."""
async def test_pdf_processing_reaches_ready(
self,
@ -189,31 +94,6 @@ class TestPdfFileUpload:
for did in doc_ids:
assert statuses[did]["status"]["state"] == "ready"
async def test_pdf_document_fields_populated(
self,
client: httpx.AsyncClient,
headers: dict[str, str],
search_space_id: int,
cleanup_doc_ids: list[int],
):
resp = await upload_file(
client, headers, "sample.pdf", search_space_id=search_space_id
)
doc_ids = resp.json()["document_ids"]
cleanup_doc_ids.extend(doc_ids)
await poll_document_status(
client, headers, doc_ids, search_space_id=search_space_id, timeout=300.0
)
doc = await get_document(client, headers, doc_ids[0])
_assert_document_ready(doc, expected_filename="sample.pdf")
assert doc["document_metadata"]["ETL_SERVICE"] in {
"DOCLING",
"UNSTRUCTURED",
"LLAMACLOUD",
}
# ---------------------------------------------------------------------------
# Test D: Upload multiple files in a single request
@ -221,7 +101,7 @@ class TestPdfFileUpload:
class TestMultiFileUpload:
"""Upload several files at once and verify all are processed."""
"""Upload several files at once and verify the API response contract."""
async def test_multi_upload_returns_all_ids(
self,
@ -243,28 +123,6 @@ class TestMultiFileUpload:
assert len(body["document_ids"]) == 2
cleanup_doc_ids.extend(body["document_ids"])
async def test_multi_upload_all_reach_ready(
self,
client: httpx.AsyncClient,
headers: dict[str, str],
search_space_id: int,
cleanup_doc_ids: list[int],
):
resp = await upload_multiple_files(
client,
headers,
["sample.txt", "sample.md"],
search_space_id=search_space_id,
)
doc_ids = resp.json()["document_ids"]
cleanup_doc_ids.extend(doc_ids)
statuses = await poll_document_status(
client, headers, doc_ids, search_space_id=search_space_id
)
for did in doc_ids:
assert statuses[did]["status"]["state"] == "ready"
# ---------------------------------------------------------------------------
# Test E: Duplicate file upload (same file uploaded twice)
@ -284,7 +142,6 @@ class TestDuplicateFileUpload:
search_space_id: int,
cleanup_doc_ids: list[int],
):
# First upload
resp1 = await upload_file(
client, headers, "sample.txt", search_space_id=search_space_id
)
@ -296,7 +153,6 @@ class TestDuplicateFileUpload:
client, headers, first_ids, search_space_id=search_space_id
)
# Second upload of the same file
resp2 = await upload_file(
client, headers, "sample.txt", search_space_id=search_space_id
)
@ -327,7 +183,6 @@ class TestDuplicateContentDetection:
cleanup_doc_ids: list[int],
tmp_path: Path,
):
# First upload
resp1 = await upload_file(
client, headers, "sample.txt", search_space_id=search_space_id
)
@ -338,7 +193,6 @@ class TestDuplicateContentDetection:
client, headers, first_ids, search_space_id=search_space_id
)
# Copy fixture content to a differently named temp file
src = FIXTURES_DIR / "sample.txt"
dest = tmp_path / "renamed_sample.txt"
shutil.copy2(src, dest)
@ -445,71 +299,7 @@ class TestNoFilesUpload:
# ---------------------------------------------------------------------------
# Test J: Document deletion after successful upload
# ---------------------------------------------------------------------------
class TestDocumentDeletion:
"""Upload, wait for ready, delete, then verify it's gone."""
async def test_delete_processed_document(
self,
client: httpx.AsyncClient,
headers: dict[str, str],
search_space_id: int,
):
resp = await upload_file(
client, headers, "sample.txt", search_space_id=search_space_id
)
doc_ids = resp.json()["document_ids"]
await poll_document_status(
client, headers, doc_ids, search_space_id=search_space_id
)
del_resp = await delete_document(client, headers, doc_ids[0])
assert del_resp.status_code == 200
get_resp = await client.get(
f"/api/v1/documents/{doc_ids[0]}",
headers=headers,
)
assert get_resp.status_code == 404
# ---------------------------------------------------------------------------
# Test K: Cannot delete a document while it is still processing
# ---------------------------------------------------------------------------
class TestDeleteWhileProcessing:
"""Attempting to delete a pending/processing document should be rejected."""
async def test_delete_pending_document_returns_409(
self,
client: httpx.AsyncClient,
headers: dict[str, str],
search_space_id: int,
cleanup_doc_ids: list[int],
):
resp = await upload_file(
client, headers, "sample.pdf", search_space_id=search_space_id
)
assert resp.status_code == 200
doc_ids = resp.json()["document_ids"]
cleanup_doc_ids.extend(doc_ids)
# Immediately try to delete before processing finishes
del_resp = await delete_document(client, headers, doc_ids[0])
assert del_resp.status_code == 409
# Let it finish so cleanup can work
await poll_document_status(
client, headers, doc_ids, search_space_id=search_space_id, timeout=300.0
)
# ---------------------------------------------------------------------------
# Test L: Status polling returns correct structure
# Test K: Searchability after upload
# ---------------------------------------------------------------------------
@ -545,48 +335,3 @@ class TestDocumentSearchability:
assert doc_ids[0] in result_ids, (
f"Uploaded document {doc_ids[0]} not found in search results: {result_ids}"
)
class TestStatusPolling:
"""Verify the status endpoint returns well-formed responses."""
async def test_status_endpoint_returns_items(
self,
client: httpx.AsyncClient,
headers: dict[str, str],
search_space_id: int,
cleanup_doc_ids: list[int],
):
resp = await upload_file(
client, headers, "sample.txt", search_space_id=search_space_id
)
doc_ids = resp.json()["document_ids"]
cleanup_doc_ids.extend(doc_ids)
status_resp = await client.get(
"/api/v1/documents/status",
headers=headers,
params={
"search_space_id": search_space_id,
"document_ids": ",".join(str(d) for d in doc_ids),
},
)
assert status_resp.status_code == 200
body = status_resp.json()
assert "items" in body
assert len(body["items"]) == len(doc_ids)
for item in body["items"]:
assert "id" in item
assert "status" in item
assert "state" in item["status"]
assert item["status"]["state"] in {
"pending",
"processing",
"ready",
"failed",
}
await poll_document_status(
client, headers, doc_ids, search_space_id=search_space_id
)

View file

@ -1,23 +1,20 @@
"""
End-to-end tests for page-limit enforcement during document upload.
Integration tests for page-limit enforcement during document upload.
These tests manipulate the test user's ``pages_used`` / ``pages_limit``
columns directly in the database and then exercise the upload pipeline to
verify that:
columns directly in the database (setup only) and then exercise the upload
pipeline to verify that:
- Uploads are rejected *before* ETL when the limit is exhausted.
- ``pages_used`` increases after a successful upload.
- ``pages_used`` increases after a successful upload (verified via API).
- A ``page_limit_exceeded`` notification is created on rejection.
- ``pages_used`` is not modified when a document fails processing.
All tests reuse the existing small fixtures (``sample.pdf``, ``sample.txt``)
so no additional processing time is introduced.
Prerequisites (must be running):
- FastAPI backend
Prerequisites:
- PostgreSQL + pgvector
- Redis
- Celery worker
"""
from __future__ import annotations
@ -31,7 +28,21 @@ from tests.utils.helpers import (
upload_file,
)
pytestmark = pytest.mark.e2e
pytestmark = pytest.mark.integration
# ---------------------------------------------------------------------------
# Helper: read pages_used through the public API
# ---------------------------------------------------------------------------
async def _get_pages_used(client: httpx.AsyncClient, headers: dict[str, str]) -> int:
"""Fetch the current user's pages_used via the /users/me API."""
resp = await client.get("/users/me", headers=headers)
assert resp.status_code == 200, (
f"GET /users/me failed ({resp.status_code}): {resp.text}"
)
return resp.json()["pages_used"]
# ---------------------------------------------------------------------------
@ -65,7 +76,7 @@ class TestPageUsageIncrementsOnSuccess:
for did in doc_ids:
assert statuses[did]["status"]["state"] == "ready"
used, _ = await page_limits.get()
used = await _get_pages_used(client, headers)
assert used > 0, "pages_used should have increased after successful processing"
@ -128,7 +139,7 @@ class TestUploadRejectedWhenLimitExhausted:
client, headers, doc_ids, search_space_id=search_space_id, timeout=300.0
)
used, _ = await page_limits.get()
used = await _get_pages_used(client, headers)
assert used == 50, (
f"pages_used should remain 50 after rejected upload, got {used}"
)
@ -263,7 +274,7 @@ class TestPagesUnchangedOnProcessingFailure:
for did in doc_ids:
assert statuses[did]["status"]["state"] == "failed"
used, _ = await page_limits.get()
used = await _get_pages_used(client, headers)
assert used == 10, f"pages_used should remain 10 after ETL failure, got {used}"
@ -284,7 +295,6 @@ class TestSecondUploadExceedsLimit:
cleanup_doc_ids: list[int],
page_limits,
):
# Give just enough room for one ~1-page PDF
await page_limits.set(pages_used=0, pages_limit=1)
resp1 = await upload_file(
@ -300,7 +310,6 @@ class TestSecondUploadExceedsLimit:
for did in first_ids:
assert statuses1[did]["status"]["state"] == "ready"
# Second upload — should fail because quota is now consumed
resp2 = await upload_file(
client,
headers,

View file

@ -1,5 +1,5 @@
"""
End-to-end tests for backend file upload limit enforcement.
Integration tests for backend file upload limit enforcement.
These tests verify that the API rejects uploads that exceed:
- Max files per upload (10)
@ -9,8 +9,7 @@ These tests verify that the API rejects uploads that exceed:
The limits mirror the frontend's DocumentUploadTab.tsx constants and are
enforced server-side to protect against direct API calls.
Prerequisites (must be running):
- FastAPI backend
Prerequisites:
- PostgreSQL + pgvector
"""
@ -21,7 +20,7 @@ import io
import httpx
import pytest
pytestmark = pytest.mark.e2e
pytestmark = pytest.mark.integration
# ---------------------------------------------------------------------------

View file

@ -1,9 +1,12 @@
import pytest
from sqlalchemy import select
from app.config import config as app_config
from app.db import Chunk, Document, DocumentStatus
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
_EMBEDDING_DIM = app_config.embedding_model_instance.dimension
pytestmark = pytest.mark.integration
@ -144,7 +147,7 @@ async def test_embedding_written_to_db(
reloaded = result.scalars().first()
assert reloaded.embedding is not None
assert len(reloaded.embedding) == 1024
assert len(reloaded.embedding) == _EMBEDDING_DIM
@pytest.mark.usefixtures(

View file

@ -3,16 +3,14 @@
from __future__ import annotations
import asyncio
import os
from pathlib import Path
import httpx
FIXTURES_DIR = Path(__file__).resolve().parent.parent / "fixtures"
BACKEND_URL = os.environ.get("TEST_BACKEND_URL", "http://localhost:8000")
TEST_EMAIL = os.environ.get("TEST_USER_EMAIL", "testuser@surfsense.com")
TEST_PASSWORD = os.environ.get("TEST_USER_PASSWORD", "testpassword123")
TEST_EMAIL = "testuser@surfsense.com"
TEST_PASSWORD = "testpassword123"
async def get_auth_token(client: httpx.AsyncClient) -> str:

View file

@ -2,10 +2,24 @@ import { DocsLayout } from "fumadocs-ui/layouts/docs";
import type { ReactNode } from "react";
import { baseOptions } from "@/app/layout.config";
import { source } from "@/lib/source";
import { SidebarSeparator } from "./sidebar-separator";
const gridTemplate = `"sidebar header toc"
"sidebar toc-popover toc"
"sidebar main toc" 1fr / var(--fd-sidebar-col) minmax(0, 1fr) min-content`;
export default function Layout({ children }: { children: ReactNode }) {
return (
<DocsLayout tree={source.pageTree} {...baseOptions}>
<DocsLayout
tree={source.pageTree}
{...baseOptions}
containerProps={{ style: { gridTemplate }, className: "bg-fd-card" }}
sidebar={{
components: {
Separator: SidebarSeparator,
},
}}
>
{children}
</DocsLayout>
);

View file

@ -0,0 +1,12 @@
"use client";
import type { Separator } from "fumadocs-core/page-tree";
export function SidebarSeparator({ item }: { item: Separator }) {
return (
<p className="inline-flex items-center gap-2 mb-1.5 px-2 mt-6 font-semibold first:mt-0 empty:mb-0">
{item.icon}
{item.name}
</p>
);
}

View file

@ -235,3 +235,4 @@ button {
@source '../node_modules/streamdown/dist/*.js';
@source '../node_modules/@streamdown/code/dist/*.js';
@source '../node_modules/@streamdown/math/dist/*.js';

View file

@ -1,7 +1,7 @@
import type { BaseLayoutProps } from "fumadocs-ui/layouts/shared";
export const baseOptions: BaseLayoutProps = {
nav: {
title: "SurfSense Documentation",
title: "SurfSense Docs",
},
githubUrl: "https://github.com/MODSetter/SurfSense",
};

View file

@ -1,5 +1,6 @@
{
"title": "Connectors",
"icon": "Cable",
"pages": [
"google-drive",
"gmail",

View file

@ -1,6 +1,7 @@
---
title: Docker Installation
description: Setting up SurfSense using Docker
icon: Container
---
This guide explains how to run SurfSense using Docker, with options ranging from a single-command install to a fully manual setup.

View file

@ -1,5 +1,6 @@
{
"title": "How to",
"pages": ["electric-sql", "realtime-collaboration", "migrate-from-allinone"],
"icon": "BookOpen",
"defaultOpen": false
}

View file

@ -1,6 +1,7 @@
---
title: Prerequisites
description: Required setup's before setting up SurfSense
icon: ClipboardCheck
---

View file

@ -1,6 +1,7 @@
---
title: Installation
description: Current ways to use SurfSense
icon: Download
---
# Installing SurfSense

View file

@ -1,6 +1,7 @@
---
title: Manual Installation
description: Setting up SurfSense manually for customized deployments (Preferred)
icon: Wrench
---
# Manual Installation (Preferred)

View file

@ -1,22 +1,18 @@
---
title: Testing
description: Running and writing end-to-end tests for SurfSense
description: Running and writing tests for SurfSense
icon: FlaskConical
---
SurfSense uses [pytest](https://docs.pytest.org/) for end-to-end testing. Tests are **self-bootstrapping** — they automatically register a test user and discover search spaces, so no manual database setup is required.
SurfSense uses [pytest](https://docs.pytest.org/) with two test layers: **unit** tests (no database) and **integration** tests (require PostgreSQL + pgvector). Tests are self-bootstrapping — they configure the test database, register a user, and clean up automatically.
## Prerequisites
Before running tests, make sure the full backend stack is running:
- **PostgreSQL + pgvector** running locally (database `surfsense_test` will be used)
- **`REGISTRATION_ENABLED=TRUE`** in your `.env` (this is the default)
- A working LLM model with a valid API key in `global_llm_config.yaml` (for integration tests)
- **FastAPI backend**
- **PostgreSQL + pgvector**
- **Redis**
- **Celery worker**
Your backend must have **`REGISTRATION_ENABLED=TRUE`** in its `.env` (this is the default). The tests register their own user on first run.
Your `global_llm_config.yaml` must have at least one working LLM model with a valid API key — document processing uses Auto mode, which routes through the global config.
No Redis or Celery is required — integration tests use an inline task dispatcher.
## Running Tests
@ -26,19 +22,19 @@ Your `global_llm_config.yaml` must have at least one working LLM model with a va
uv run pytest
```
**Run by marker** (e.g., only document tests):
**Run by marker:**
```bash
uv run pytest -m document
uv run pytest -m unit # fast, no DB needed
uv run pytest -m integration # requires PostgreSQL + pgvector
```
**Available markers:**
| Marker | Description |
|---|---|
| `document` | Document upload, processing, and deletion tests |
| `connector` | Connector indexing tests |
| `chat` | Chat and agent tests |
| `unit` | Pure logic tests, no DB or external services |
| `integration` | Tests that require a real PostgreSQL database |
**Useful flags:**
@ -51,11 +47,11 @@ uv run pytest -m document
## Configuration
Default pytest options are configured in `surfsense_backend/pyproject.toml`:
Default pytest options are in `surfsense_backend/pyproject.toml`:
```toml
[tool.pytest.ini_options]
addopts = "-v --tb=short -x --strict-markers -ra --durations=10"
addopts = "-v --tb=short -x --strict-markers -ra --durations=5"
```
- `-v` — verbose test names
@ -63,42 +59,47 @@ addopts = "-v --tb=short -x --strict-markers -ra --durations=10"
- `-x` — stop on first failure
- `--strict-markers` — reject unregistered markers
- `-ra` — show summary of all non-passing tests
- `--durations=10` — show the 10 slowest tests
- `--durations=5` — show the 5 slowest tests
## Environment Variables
All test configuration has sensible defaults. Override via environment variables if needed:
| Variable | Default | Description |
|---|---|---|
| `TEST_BACKEND_URL` | `http://localhost:8000` | Backend URL to test against |
| `TEST_DATABASE_URL` | Falls back to `DATABASE_URL` | Direct DB connection for test cleanup |
| `TEST_USER_EMAIL` | `testuser@surfsense.com` | Test user email |
| `TEST_USER_PASSWORD` | `testpassword123` | Test user password |
| `TEST_DATABASE_URL` | `postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense_test` | Database URL for tests |
These can be configured in `surfsense_backend/.env` (see the Testing section at the bottom of `.env.example`).
The test suite forces `DATABASE_URL` to point at the test database, so your production database is never touched.
### Unit Tests
Pure logic tests that run without a database. Cover model validation, chunking, hashing, and summarization.
### Integration Tests
Require PostgreSQL + pgvector. Split into two suites:
- **`document_upload/`** — Tests the HTTP API through public endpoints: upload, multi-file, duplicate detection, auth, error handling, page limits, and file size limits. Uses an in-process FastAPI client with `ASGITransport`.
- **`indexing_pipeline/`** — Tests pipeline internals directly: `prepare_for_indexing`, `index()`, and `index_uploaded_file()` covering chunking, embedding, summarization, fallbacks, and error handling.
External boundaries (LLM, embedding, chunking, Redis) are mocked in both suites.
## How It Works
Tests are fully self-bootstrapping:
1. **User creation** — on first run, tests try to log in. If the user doesn't exist, they register via `POST /auth/register`, then log in.
2. **Search space discovery** — after authentication, tests call `GET /api/v1/searchspaces` and use the first available search space (auto-created during registration).
3. **Session purge** — before any tests run, a session-scoped fixture deletes all documents in the test search space directly via the database. This handles stuck documents from previous crashed runs that the API refuses to delete (409 Conflict).
4. **Per-test cleanup** — every test that creates documents adds their IDs to a `cleanup_doc_ids` list. An autouse fixture deletes them after each test via the API, falling back to direct DB access for any stuck documents.
This means tests work on both fresh databases and existing ones without any manual setup.
1. **Database setup** — `TEST_DATABASE_URL` defaults to `surfsense_test`. Tables and extensions (`vector`, `pg_trgm`) are created once per session and dropped after.
2. **Transaction isolation** — Each test runs inside a savepoint that rolls back, so tests don't affect each other.
3. **User creation** — Integration tests register a test user via `POST /auth/register` on first run, then log in for subsequent requests.
4. **Search space discovery** — Tests call `GET /api/v1/searchspaces` and use the first available space.
5. **Cleanup** — A session fixture purges stale documents before tests run. Per-test cleanup deletes documents via API, falling back to direct DB access for stuck records.
## Writing New Tests
1. Create a test file in the appropriate directory (e.g., `tests/e2e/test_connectors.py`).
2. Add a module-level marker at the top:
1. Create a test file in the appropriate directory (`unit/` or `integration/`).
2. Add the marker at the top of the file:
```python
import pytest
pytestmark = pytest.mark.connector
pytestmark = pytest.mark.integration # or pytest.mark.unit
```
3. Use fixtures from `conftest.py` — `client`, `headers`, `search_space_id`, and `cleanup_doc_ids` are available to all tests.
3. Use fixtures from `conftest.py` — `client`, `headers`, `search_space_id`, and `cleanup_doc_ids` are available to integration tests. Unit tests get `make_connector_document` and sample ID fixtures.
4. Register any new markers in `pyproject.toml` under `markers`.

View file

@ -1,7 +1,13 @@
import { loader } from "fumadocs-core/source";
import { docs } from "@/.source/server";
import { icons } from "lucide-react";
import { createElement } from "react";
export const source = loader({
baseUrl: "/docs",
source: docs.toFumadocsSource(),
icon(icon) {
if (icon && icon in icons)
return createElement(icons[icon as keyof typeof icons]);
},
});