Merge pull request #836 from AnishSarkar22/feat/document-test

feat: add document upload E2E tests
Authored by Rohan Verma on 2026-02-25 14:22:07 -08:00, committed via GitHub.
commit 30617c6e54
38 changed files with 5796 additions and 2874 deletions

View file

@@ -175,4 +175,13 @@ DAYTONA_API_KEY=dtn_asdasfasfafas
 DAYTONA_API_URL=https://app.daytona.io/api
 DAYTONA_TARGET=us
 # Directory for locally-persisted sandbox files (after sandbox deletion)
 SANDBOX_FILES_DIR=sandbox_files
+
+# ============================================================
+# Testing (optional — all have sensible defaults)
+# ============================================================
+# TEST_BACKEND_URL=http://localhost:8000
+# TEST_DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense
+# TEST_USER_EMAIL=testuser@surfsense.com
+# TEST_USER_PASSWORD=testpassword123
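
For reference, this is roughly how the test suite appears to consume these variables (the root `tests/conftest.py` later in this diff does the `TEST_DATABASE_URL` lookup and scheme rewrite; the other constant names below are assumptions, since the helpers module itself is not shown in this excerpt):

```python
# Sketch only — mirrors the lookup pattern used by the new tests;
# defaults match the commented values above.
import os

BACKEND_URL = os.environ.get("TEST_BACKEND_URL", "http://localhost:8000")
TEST_EMAIL = os.environ.get("TEST_USER_EMAIL", "testuser@surfsense.com")
TEST_PASSWORD = os.environ.get("TEST_USER_PASSWORD", "testpassword123")

# asyncpg wants a plain postgresql:// DSN, so the SQLAlchemy driver suffix
# is stripped — this mirrors what the root conftest in this PR does.
DATABASE_URL = os.environ.get(
    "TEST_DATABASE_URL",
    os.environ.get("DATABASE_URL", ""),
).replace("postgresql+asyncpg://", "postgresql://")
```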

View file

@@ -147,7 +147,9 @@ async def delete_sandbox(thread_id: int | str) -> None:
     try:
         sandbox = client.find_one(labels=labels)
     except DaytonaError:
-        logger.debug("No sandbox to delete for thread %s (already removed)", thread_id)
+        logger.debug(
+            "No sandbox to delete for thread %s (already removed)", thread_id
+        )
         return
     try:
         client.delete(sandbox)
@@ -166,6 +168,7 @@ async def delete_sandbox(thread_id: int | str) -> None:
 # Local file persistence
 # ---------------------------------------------------------------------------
 def _get_sandbox_files_dir() -> Path:
     return Path(os.environ.get("SANDBOX_FILES_DIR", "sandbox_files"))

View file

@@ -5,6 +5,7 @@ from app.db import DocumentType
 class ConnectorDocument(BaseModel):
     """Canonical data transfer object produced by connector adapters and consumed by the indexing pipeline."""
     title: str
     source_markdown: str
     unique_id: str

View file

@@ -3,5 +3,7 @@ from app.config import config
 def chunk_text(text: str, use_code_chunker: bool = False) -> list[str]:
     """Chunk a text string using the configured chunker and return the chunk texts."""
-    chunker = config.code_chunker_instance if use_code_chunker else config.chunker_instance
+    chunker = (
+        config.code_chunker_instance if use_code_chunker else config.chunker_instance
+    )
     return [c.text for c in chunker.chunk(text)]

View file

@@ -2,7 +2,9 @@ from app.prompts import SUMMARY_PROMPT_TEMPLATE
 from app.utils.document_converters import optimize_content_for_context_window
-async def summarize_document(source_markdown: str, llm, metadata: dict | None = None) -> str:
+async def summarize_document(
+    source_markdown: str, llm, metadata: dict | None = None
+) -> str:
     """Generate a text summary of a document using an LLM, prefixed with metadata when provided."""
     model_name = getattr(llm, "model", "gpt-3.5-turbo")
     optimized_content = optimize_content_for_context_window(

View file

@@ -12,7 +12,6 @@ from litellm.exceptions import (
     Timeout,
     UnprocessableEntityError,
 )
-from sqlalchemy.exc import IntegrityError
 # Tuples for use directly in except clauses.
 RETRYABLE_LLM_ERRORS = (
@@ -36,29 +35,33 @@ PERMANENT_LLM_ERRORS = (
 # (LiteLLMEmbeddings, CohereEmbeddings, GeminiEmbeddings all normalize to RuntimeError).
 EMBEDDING_ERRORS = (
     RuntimeError,  # local device failure or API backend normalization
     OSError,  # model files missing or corrupted (local backends)
     MemoryError,  # document too large for available RAM
 )
 class PipelineMessages:
     RATE_LIMIT = "LLM rate limit exceeded. Will retry on next sync."
     LLM_TIMEOUT = "LLM request timed out. Will retry on next sync."
     LLM_UNAVAILABLE = "LLM service temporarily unavailable. Will retry on next sync."
     LLM_BAD_GATEWAY = "LLM gateway error. Will retry on next sync."
     LLM_SERVER_ERROR = "LLM internal server error. Will retry on next sync."
     LLM_CONNECTION = "Could not reach the LLM service. Check network connectivity."
     LLM_AUTH = "LLM authentication failed. Check your API key."
     LLM_PERMISSION = "LLM request denied. Check your account permissions."
     LLM_NOT_FOUND = "LLM model not found. Check your model configuration."
     LLM_BAD_REQUEST = "LLM rejected the request. Document content may be invalid."
-    LLM_UNPROCESSABLE = "Document exceeds the LLM context window even after optimization."
+    LLM_UNPROCESSABLE = (
+        "Document exceeds the LLM context window even after optimization."
+    )
     LLM_RESPONSE = "LLM returned an invalid response."
-    EMBEDDING_FAILED = "Embedding failed. Check your embedding model configuration or service."
+    EMBEDDING_FAILED = (
+        "Embedding failed. Check your embedding model configuration or service."
+    )
     EMBEDDING_MODEL = "Embedding model files are missing or corrupted."
     EMBEDDING_MEMORY = "Not enough memory to embed this document."
     CHUNKING_OVERFLOW = "Document structure is too deeply nested to chunk."

View file

@@ -2,6 +2,7 @@ import contextlib
 from datetime import UTC, datetime
 from sqlalchemy import delete, select
+from sqlalchemy.exc import IntegrityError
 from sqlalchemy.ext.asyncio import AsyncSession
 from app.db import Chunk, Document, DocumentStatus
@@ -21,7 +22,6 @@ from app.indexing_pipeline.exceptions import (
     EMBEDDING_ERRORS,
     PERMANENT_LLM_ERRORS,
     RETRYABLE_LLM_ERRORS,
-    IntegrityError,
     PipelineMessages,
     embedding_message,
     llm_permanent_message,

View file

@@ -8,27 +8,29 @@ logger = logging.getLogger(__name__)
 class PipelineLogContext:
     connector_id: int | None
     search_space_id: int
     unique_id: str  # always available from ConnectorDocument
     doc_id: int | None = None  # set once the DB row exists (index phase only)
 class LogMessages:
     # prepare_for_indexing
     DOCUMENT_QUEUED = "New document queued for indexing."
     DOCUMENT_UPDATED = "Document content changed, re-queued for indexing."
     DOCUMENT_REQUEUED = "Stuck document re-queued for indexing."
     DOC_SKIPPED_UNKNOWN = "Unexpected error — document skipped."
     BATCH_ABORTED = "Fatal DB error — aborting prepare batch."
     RACE_CONDITION = "Concurrent worker beat us to the commit — rolling back batch."
     # index
     INDEX_STARTED = "Document indexing started."
     INDEX_SUCCESS = "Document indexed successfully."
-    LLM_RETRYABLE = "Retryable LLM error — document marked failed, will retry on next sync."
+    LLM_RETRYABLE = (
+        "Retryable LLM error — document marked failed, will retry on next sync."
+    )
     LLM_PERMANENT = "Permanent LLM error — document marked failed."
     EMBEDDING_FAILED = "Embedding error — document marked failed."
     CHUNKING_OVERFLOW = "Chunking overflow — document marked failed."
     UNEXPECTED = "Unexpected error — document marked failed."
 def _format_context(ctx: PipelineLogContext) -> str:
@@ -52,7 +54,9 @@ def _build_message(msg: str, ctx: PipelineLogContext, **extra) -> str:
     return msg
-def _safe_log(level_fn, msg: str, ctx: PipelineLogContext, exc_info=None, **extra) -> None:
+def _safe_log(
+    level_fn, msg: str, ctx: PipelineLogContext, exc_info=None, **extra
+) -> None:
     # Logging must never raise — a broken log call inside an except block would
     # chain with the original exception and mask it entirely.
     try:
@@ -64,6 +68,7 @@ def _safe_log(level_fn, msg: str, ctx: PipelineLogContext, exc_info=None, **extr
 # ── prepare_for_indexing ──────────────────────────────────────────────────────
 def log_document_queued(ctx: PipelineLogContext) -> None:
     _safe_log(logger.info, LogMessages.DOCUMENT_QUEUED, ctx)
@@ -77,7 +82,9 @@ def log_document_requeued(ctx: PipelineLogContext) -> None:
 def log_doc_skipped_unknown(ctx: PipelineLogContext, exc: Exception) -> None:
-    _safe_log(logger.warning, LogMessages.DOC_SKIPPED_UNKNOWN, ctx, exc_info=exc, error=exc)
+    _safe_log(
+        logger.warning, LogMessages.DOC_SKIPPED_UNKNOWN, ctx, exc_info=exc, error=exc
+    )
 def log_race_condition(ctx: PipelineLogContext) -> None:
@@ -90,6 +97,7 @@ def log_batch_aborted(ctx: PipelineLogContext, exc: Exception) -> None:
 # ── index ─────────────────────────────────────────────────────────────────────
 def log_index_started(ctx: PipelineLogContext) -> None:
     _safe_log(logger.info, LogMessages.INDEX_STARTED, ctx)

View file

@@ -44,6 +44,10 @@ os.environ["UNSTRUCTURED_HAS_PATCHED_LOOP"] = "1"
 router = APIRouter()
+MAX_FILES_PER_UPLOAD = 10
+MAX_FILE_SIZE_BYTES = 50 * 1024 * 1024  # 50 MB per file
+MAX_TOTAL_SIZE_BYTES = 200 * 1024 * 1024  # 200 MB total
 @router.post("/documents")
 async def create_documents(
@@ -148,12 +152,37 @@ async def create_documents_file_upload(
     if not files:
         raise HTTPException(status_code=400, detail="No files provided")
+    if len(files) > MAX_FILES_PER_UPLOAD:
+        raise HTTPException(
+            status_code=413,
+            detail=f"Too many files. Maximum {MAX_FILES_PER_UPLOAD} files per upload.",
+        )
+    total_size = 0
+    for file in files:
+        file_size = file.size or 0
+        if file_size > MAX_FILE_SIZE_BYTES:
+            raise HTTPException(
+                status_code=413,
+                detail=f"File '{file.filename}' ({file_size / (1024 * 1024):.1f} MB) "
+                f"exceeds the {MAX_FILE_SIZE_BYTES // (1024 * 1024)} MB per-file limit.",
+            )
+        total_size += file_size
+    if total_size > MAX_TOTAL_SIZE_BYTES:
+        raise HTTPException(
+            status_code=413,
+            detail=f"Total upload size ({total_size / (1024 * 1024):.1f} MB) "
+            f"exceeds the {MAX_TOTAL_SIZE_BYTES // (1024 * 1024)} MB limit.",
+        )
     created_documents: list[Document] = []
     files_to_process: list[
         tuple[Document, str, str]
     ] = []  # (document, temp_path, filename)
     skipped_duplicates = 0
     duplicate_document_ids: list[int] = []
+    actual_total_size = 0
     # ===== PHASE 1: Create pending documents for all files =====
     # This makes ALL documents visible in the UI immediately with pending status
@@ -169,11 +198,28 @@ async def create_documents_file_upload(
             temp_path = temp_file.name
             content = await file.read()
+            file_size = len(content)
+            if file_size > MAX_FILE_SIZE_BYTES:
+                os.unlink(temp_path)
+                raise HTTPException(
+                    status_code=413,
+                    detail=f"File '{file.filename}' ({file_size / (1024 * 1024):.1f} MB) "
+                    f"exceeds the {MAX_FILE_SIZE_BYTES // (1024 * 1024)} MB per-file limit.",
+                )
+            actual_total_size += file_size
+            if actual_total_size > MAX_TOTAL_SIZE_BYTES:
+                os.unlink(temp_path)
+                raise HTTPException(
+                    status_code=413,
+                    detail=f"Total upload size ({actual_total_size / (1024 * 1024):.1f} MB) "
+                    f"exceeds the {MAX_TOTAL_SIZE_BYTES // (1024 * 1024)} MB limit.",
+                )
             with open(temp_path, "wb") as f:
                 f.write(content)
-            file_size = len(content)
             # Generate unique identifier for deduplication check
             unique_identifier_hash = generate_unique_identifier_hash(
                 DocumentType.FILE, file.filename or "unknown", search_space_id

View file

@@ -10,6 +10,8 @@ These endpoints support the ThreadHistoryAdapter pattern from assistant-ui:
 - POST /threads/{thread_id}/messages - Append message
 """
+import asyncio
+import logging
 from datetime import UTC, datetime
 from fastapi import APIRouter, Depends, HTTPException, Request
@@ -52,9 +54,6 @@ from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
 from app.users import current_active_user
 from app.utils.rbac import check_permission
-import asyncio
-import logging
 _logger = logging.getLogger(__name__)
 router = APIRouter()
@@ -75,11 +74,19 @@ def _try_delete_sandbox(thread_id: int) -> None:
     try:
         await delete_sandbox(thread_id)
     except Exception:
-        _logger.warning("Background sandbox delete failed for thread %s", thread_id, exc_info=True)
+        _logger.warning(
+            "Background sandbox delete failed for thread %s",
+            thread_id,
+            exc_info=True,
+        )
     try:
         delete_local_sandbox_files(thread_id)
     except Exception:
-        _logger.warning("Local sandbox file cleanup failed for thread %s", thread_id, exc_info=True)
+        _logger.warning(
+            "Local sandbox file cleanup failed for thread %s",
+            thread_id,
+            exc_info=True,
+        )
     try:
         loop = asyncio.get_running_loop()

View file

@@ -87,7 +87,7 @@ async def download_sandbox_file(
     # Fall back to live sandbox download
     try:
         sandbox = await get_or_create_sandbox(thread_id)
-        raw_sandbox = sandbox._sandbox  # noqa: SLF001
+        raw_sandbox = sandbox._sandbox
         content: bytes = await asyncio.to_thread(raw_sandbox.fs.download_file, path)
     except Exception as exc:
         logger.warning("Sandbox file download failed for %s: %s", path, exc)

View file

@@ -1303,10 +1303,9 @@ class ConnectorService:
         sources_list = self._build_chunk_sources_from_documents(
             github_docs,
-            description_fn=lambda chunk, _doc_info, metadata: metadata.get(
-                "description"
-            )
-            or chunk.get("content", ""),
+            description_fn=lambda chunk, _doc_info, metadata: (
+                metadata.get("description") or chunk.get("content", "")
+            ),
             url_fn=lambda _doc_info, metadata: metadata.get("url", "") or "",
         )

View file

@@ -877,7 +877,9 @@ async def _stream_agent_events(
             output_text = om.group(1) if om else ""
             thread_id_str = config.get("configurable", {}).get("thread_id", "")
-            for sf_match in re.finditer(r"^SANDBOX_FILE:\s*(.+)$", output_text, re.MULTILINE):
+            for sf_match in re.finditer(
+                r"^SANDBOX_FILE:\s*(.+)$", output_text, re.MULTILINE
+            ):
                 fpath = sf_match.group(1).strip()
                 if fpath and fpath not in result.sandbox_files:
                     result.sandbox_files.append(fpath)
@@ -963,7 +965,10 @@ def _try_persist_and_delete_sandbox(
     sandbox_files: list[str],
 ) -> None:
     """Fire-and-forget: persist sandbox files locally then delete the sandbox."""
-    from app.agents.new_chat.sandbox import is_sandbox_enabled, persist_and_delete_sandbox
+    from app.agents.new_chat.sandbox import (
+        is_sandbox_enabled,
+        persist_and_delete_sandbox,
+    )
     if not is_sandbox_enabled():
         return

View file

@@ -1632,6 +1632,8 @@ async def process_file_in_background_with_document(
     from app.config import config as app_config
     from app.services.llm_service import get_user_long_context_llm
+    doc_id = document.id
     try:
         markdown_content = None
         etl_service = None
@@ -1855,7 +1857,7 @@ async def process_file_in_background_with_document(
         content_hash = generate_content_hash(markdown_content, search_space_id)
         existing_by_content = await check_duplicate_document(session, content_hash)
-        if existing_by_content and existing_by_content.id != document.id:
+        if existing_by_content and existing_by_content.id != doc_id:
             # Duplicate content found - mark this document as failed
             logging.info(
                 f"Duplicate content detected for {filename}, "
@@ -1885,7 +1887,8 @@ async def process_file_in_background_with_document(
             log_entry,
             f"Successfully processed file: {filename}",
             {
-                "document_id": document.id,
+                "document_id": doc_id,
+                "content_hash": content_hash,
                "file_type": etl_service,
            },
        )
@@ -1911,7 +1914,7 @@ async def process_file_in_background_with_document(
            {
                "error_type": type(e).__name__,
                "filename": filename,
-                "document_id": document.id,
+                "document_id": doc_id,
            },
        )
        logging.error(f"Error processing file with document: {error_message}")

View file

@@ -72,22 +72,10 @@ dependencies = [
 [dependency-groups]
 dev = [
     "ruff>=0.12.5",
-    "pytest>=8.0",
-    "pytest-asyncio>=0.25",
+    "pytest>=9.0.2",
+    "pytest-asyncio>=1.3.0",
     "pytest-mock>=3.14",
+    "httpx>=0.28.1",
 ]
-[tool.pytest.ini_options]
-asyncio_mode = "auto"
-asyncio_default_fixture_loop_scope = "session"
-asyncio_default_test_loop_scope = "session"
-testpaths = ["tests"]
-markers = [
-    "unit: pure logic tests, no DB or external services",
-    "integration: tests that require a real PostgreSQL database",
-]
-filterwarnings = [
-    "ignore::UserWarning:chonkie",
-]
 [tool.ruff]
@@ -175,10 +163,28 @@ line-ending = "auto"
 [tool.ruff.lint.isort]
 # Group imports by type
-known-first-party = ["app"]
+known-first-party = ["app", "tests"]
 force-single-line = false
 combine-as-imports = true
+[tool.pytest.ini_options]
+asyncio_mode = "auto"
+asyncio_default_fixture_loop_scope = "session"
+asyncio_default_test_loop_scope = "session"
+testpaths = ["tests"]
+python_files = ["test_*.py"]
+python_classes = ["Test*"]
+python_functions = ["test_*"]
+addopts = "-v --tb=short -x --strict-markers -ra --durations=5"
+markers = [
+    "unit: pure logic tests, no DB or external services",
+    "integration: tests that require a real PostgreSQL database",
+    "e2e: tests requiring a running backend and real HTTP calls"
+]
+filterwarnings = [
+    "ignore::UserWarning:chonkie",
+]
 [tool.setuptools.packages.find]
 where = ["."]
 include = ["app*", "alembic*"]
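
With `--strict-markers` in `addopts`, any marker not registered above fails collection. For reference, this is the opt-in pattern the new test modules in this PR use; the module and test names below are illustrative, not part of the diff:

```python
# Illustrative marker usage (same pattern as the new e2e test files).
import pytest

# Applies the marker to every test in the module:
# `pytest -m e2e` selects them, `pytest -m "not e2e"` skips them.
pytestmark = pytest.mark.e2e


@pytest.mark.integration
async def test_example_against_real_database():
    """Collected only when the integration marker is selected."""
    assert True
```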

View file

@@ -1,8 +1,29 @@
+"""Root conftest — shared fixtures available to all test modules."""
+from __future__ import annotations
+import os
+from pathlib import Path
 import pytest
+from dotenv import load_dotenv
 from app.db import DocumentType
 from app.indexing_pipeline.connector_document import ConnectorDocument
+load_dotenv(Path(__file__).resolve().parent.parent / ".env")
+# Shared DB URL referenced by both e2e and integration helper functions.
+DATABASE_URL = os.environ.get(
+    "TEST_DATABASE_URL",
+    os.environ.get("DATABASE_URL", ""),
+).replace("postgresql+asyncpg://", "postgresql://")
+# ---------------------------------------------------------------------------
+# Unit test fixtures
+# ---------------------------------------------------------------------------
 @pytest.fixture
 def sample_user_id() -> str:
@@ -21,6 +42,11 @@ def sample_connector_id() -> int:
 @pytest.fixture
 def make_connector_document():
+    """
+    Generic factory for unit tests. Overridden in tests/integration/conftest.py
+    with real DB-backed IDs for integration tests.
+    """
     def _make(**overrides):
         defaults = {
             "title": "Test Document",
@@ -33,4 +59,5 @@ def make_connector_document():
         }
         defaults.update(overrides)
         return ConnectorDocument(**defaults)
     return _make

View file

View file

@ -0,0 +1,198 @@
"""E2e conftest — fixtures that require a running backend + database."""
from __future__ import annotations
from collections.abc import AsyncGenerator
import asyncpg
import httpx
import pytest
from tests.conftest import DATABASE_URL
from tests.utils.helpers import (
BACKEND_URL,
TEST_EMAIL,
auth_headers,
delete_document,
get_auth_token,
get_search_space_id,
)
# ---------------------------------------------------------------------------
# Backend connectivity fixtures
# ---------------------------------------------------------------------------
@pytest.fixture(scope="session")
def backend_url() -> str:
return BACKEND_URL
@pytest.fixture(scope="session")
async def auth_token(backend_url: str) -> str:
"""Authenticate once per session, registering the user if needed."""
async with httpx.AsyncClient(base_url=backend_url, timeout=30.0) as client:
return await get_auth_token(client)
@pytest.fixture(scope="session")
async def search_space_id(backend_url: str, auth_token: str) -> int:
"""Discover the first search space belonging to the test user."""
async with httpx.AsyncClient(base_url=backend_url, timeout=30.0) as client:
return await get_search_space_id(client, auth_token)
@pytest.fixture(scope="session", autouse=True)
async def _purge_test_search_space(
search_space_id: int,
):
"""
Delete all documents in the test search space before the session starts.
Uses direct database access to bypass the API's 409 protection on
pending/processing documents. This ensures stuck documents from
previous crashed runs are always cleaned up.
"""
deleted = await _force_delete_documents_db(search_space_id)
if deleted:
print(
f"\n[purge] Deleted {deleted} stale document(s) from search space {search_space_id}"
)
yield
@pytest.fixture(scope="session")
def headers(auth_token: str) -> dict[str, str]:
"""Authorization headers reused across all tests in the session."""
return auth_headers(auth_token)
@pytest.fixture
async def client(backend_url: str) -> AsyncGenerator[httpx.AsyncClient]:
"""Per-test async HTTP client pointing at the running backend."""
async with httpx.AsyncClient(base_url=backend_url, timeout=180.0) as c:
yield c
@pytest.fixture
def cleanup_doc_ids() -> list[int]:
"""Accumulator for document IDs that should be deleted after the test."""
return []
@pytest.fixture(autouse=True)
async def _cleanup_documents(
client: httpx.AsyncClient,
headers: dict[str, str],
search_space_id: int,
cleanup_doc_ids: list[int],
):
"""
Runs after every test. Tries the API first for clean deletes, then
falls back to direct DB access for any stuck documents.
"""
yield
remaining_ids: list[int] = []
for doc_id in cleanup_doc_ids:
try:
resp = await delete_document(client, headers, doc_id)
if resp.status_code == 409:
remaining_ids.append(doc_id)
except Exception:
remaining_ids.append(doc_id)
if remaining_ids:
conn = await asyncpg.connect(DATABASE_URL)
try:
await conn.execute(
"DELETE FROM documents WHERE id = ANY($1::int[])",
remaining_ids,
)
finally:
await conn.close()
# ---------------------------------------------------------------------------
# Page-limit helpers (direct DB access)
# ---------------------------------------------------------------------------
async def _force_delete_documents_db(search_space_id: int) -> int:
"""
Bypass the API and delete documents directly from the database.
This handles stuck documents in pending/processing state that the API
refuses to delete (409 Conflict). Chunks are cascade-deleted by the
foreign key constraint.
Returns the number of deleted rows.
"""
conn = await asyncpg.connect(DATABASE_URL)
try:
result = await conn.execute(
"DELETE FROM documents WHERE search_space_id = $1",
search_space_id,
)
return int(result.split()[-1])
finally:
await conn.close()
async def _get_user_page_usage(email: str) -> tuple[int, int]:
"""Return ``(pages_used, pages_limit)`` for the given user."""
conn = await asyncpg.connect(DATABASE_URL)
try:
row = await conn.fetchrow(
'SELECT pages_used, pages_limit FROM "user" WHERE email = $1',
email,
)
assert row is not None, f"User {email!r} not found in database"
return row["pages_used"], row["pages_limit"]
finally:
await conn.close()
async def _set_user_page_limits(
email: str, *, pages_used: int, pages_limit: int
) -> None:
"""Overwrite ``pages_used`` and ``pages_limit`` for the given user."""
conn = await asyncpg.connect(DATABASE_URL)
try:
await conn.execute(
'UPDATE "user" SET pages_used = $1, pages_limit = $2 WHERE email = $3',
pages_used,
pages_limit,
email,
)
finally:
await conn.close()
@pytest.fixture
async def page_limits():
"""
Fixture that exposes helpers for manipulating the test user's page limits.
Automatically restores the original values after each test.
Usage inside a test::
await page_limits.set(pages_used=0, pages_limit=100)
used, limit = await page_limits.get()
"""
class _PageLimits:
async def set(self, *, pages_used: int, pages_limit: int) -> None:
await _set_user_page_limits(
TEST_EMAIL, pages_used=pages_used, pages_limit=pages_limit
)
async def get(self) -> tuple[int, int]:
return await _get_user_page_usage(TEST_EMAIL)
original = await _get_user_page_usage(TEST_EMAIL)
yield _PageLimits()
await _set_user_page_limits(
TEST_EMAIL, pages_used=original[0], pages_limit=original[1]
)
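
The fixtures above lean on `tests/utils/helpers.py`, which is part of the PR but not shown in this excerpt. A rough sketch of the two most heavily used helpers, inferred from how the tests call them and from the endpoints the suite also hits directly; signatures and internals are assumptions, not the actual implementation:

```python
# Hypothetical reconstruction of upload_file / poll_document_status from
# tests/utils/helpers.py — illustrative only.
from __future__ import annotations

import asyncio
import os
import time
from pathlib import Path

import httpx

BACKEND_URL = os.environ.get("TEST_BACKEND_URL", "http://localhost:8000")
FIXTURES_DIR = Path(__file__).resolve().parent.parent / "fixtures"  # assumed layout


async def upload_file(
    client: httpx.AsyncClient,
    headers: dict[str, str],
    fixture_name: str,
    *,
    search_space_id: int,
    filename_override: str | None = None,
) -> httpx.Response:
    """POST one fixture file to the documents upload endpoint."""
    with open(FIXTURES_DIR / fixture_name, "rb") as f:
        return await client.post(
            "/api/v1/documents/fileupload",
            headers=headers,
            files={"files": (filename_override or fixture_name, f)},
            data={"search_space_id": str(search_space_id)},
        )


async def poll_document_status(
    client: httpx.AsyncClient,
    headers: dict[str, str],
    doc_ids: list[int],
    *,
    search_space_id: int,
    timeout: float = 180.0,
    interval: float = 2.0,
) -> dict[int, dict]:
    """Poll /documents/status until every document is ready or failed."""
    deadline = time.monotonic() + timeout
    statuses: dict[int, dict] = {}
    while time.monotonic() < deadline:
        resp = await client.get(
            "/api/v1/documents/status",
            headers=headers,
            params={
                "search_space_id": search_space_id,
                "document_ids": ",".join(str(d) for d in doc_ids),
            },
        )
        resp.raise_for_status()
        statuses = {item["id"]: item for item in resp.json()["items"]}
        if statuses and all(
            s["status"]["state"] in {"ready", "failed"} for s in statuses.values()
        ):
            break
        await asyncio.sleep(interval)
    return statuses
```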

View file

@ -0,0 +1,592 @@
"""
End-to-end tests for manual document upload.
These tests exercise the full pipeline:
API upload → Celery task → ETL extraction → chunking → embedding → DB storage
Prerequisites (must be running):
- FastAPI backend
- PostgreSQL + pgvector
- Redis
- Celery worker
"""
from __future__ import annotations
import shutil
from pathlib import Path
import httpx
import pytest
from tests.utils.helpers import (
FIXTURES_DIR,
delete_document,
get_document,
poll_document_status,
upload_file,
upload_multiple_files,
)
pytestmark = pytest.mark.e2e
# ---------------------------------------------------------------------------
# Helpers local to this module
# ---------------------------------------------------------------------------
def _assert_document_ready(doc: dict, *, expected_filename: str) -> None:
"""Common assertions for a successfully processed document."""
assert doc["title"] == expected_filename
assert doc["document_type"] == "FILE"
assert doc["content"], "Document content (summary) should not be empty"
assert doc["content_hash"], "content_hash should be set"
assert doc["document_metadata"].get("FILE_NAME") == expected_filename
# ---------------------------------------------------------------------------
# Test A: Upload a .txt file (direct read path — no ETL service needed)
# ---------------------------------------------------------------------------
class TestTxtFileUpload:
"""Upload a plain-text file and verify the full pipeline."""
async def test_upload_txt_returns_document_id(
self,
client: httpx.AsyncClient,
headers: dict[str, str],
search_space_id: int,
cleanup_doc_ids: list[int],
):
resp = await upload_file(
client, headers, "sample.txt", search_space_id=search_space_id
)
assert resp.status_code == 200
body = resp.json()
assert body["pending_files"] >= 1
assert len(body["document_ids"]) >= 1
cleanup_doc_ids.extend(body["document_ids"])
async def test_txt_processing_reaches_ready(
self,
client: httpx.AsyncClient,
headers: dict[str, str],
search_space_id: int,
cleanup_doc_ids: list[int],
):
resp = await upload_file(
client, headers, "sample.txt", search_space_id=search_space_id
)
assert resp.status_code == 200
doc_ids = resp.json()["document_ids"]
cleanup_doc_ids.extend(doc_ids)
statuses = await poll_document_status(
client, headers, doc_ids, search_space_id=search_space_id
)
for did in doc_ids:
assert statuses[did]["status"]["state"] == "ready"
async def test_txt_document_fields_populated(
self,
client: httpx.AsyncClient,
headers: dict[str, str],
search_space_id: int,
cleanup_doc_ids: list[int],
):
resp = await upload_file(
client, headers, "sample.txt", search_space_id=search_space_id
)
doc_ids = resp.json()["document_ids"]
cleanup_doc_ids.extend(doc_ids)
await poll_document_status(
client, headers, doc_ids, search_space_id=search_space_id
)
doc = await get_document(client, headers, doc_ids[0])
_assert_document_ready(doc, expected_filename="sample.txt")
assert doc["document_metadata"]["ETL_SERVICE"] == "MARKDOWN"
# ---------------------------------------------------------------------------
# Test B: Upload a .md file (markdown direct-read path)
# ---------------------------------------------------------------------------
class TestMarkdownFileUpload:
"""Upload a Markdown file and verify the full pipeline."""
async def test_md_processing_reaches_ready(
self,
client: httpx.AsyncClient,
headers: dict[str, str],
search_space_id: int,
cleanup_doc_ids: list[int],
):
resp = await upload_file(
client, headers, "sample.md", search_space_id=search_space_id
)
assert resp.status_code == 200
doc_ids = resp.json()["document_ids"]
cleanup_doc_ids.extend(doc_ids)
statuses = await poll_document_status(
client, headers, doc_ids, search_space_id=search_space_id
)
for did in doc_ids:
assert statuses[did]["status"]["state"] == "ready"
async def test_md_document_fields_populated(
self,
client: httpx.AsyncClient,
headers: dict[str, str],
search_space_id: int,
cleanup_doc_ids: list[int],
):
resp = await upload_file(
client, headers, "sample.md", search_space_id=search_space_id
)
doc_ids = resp.json()["document_ids"]
cleanup_doc_ids.extend(doc_ids)
await poll_document_status(
client, headers, doc_ids, search_space_id=search_space_id
)
doc = await get_document(client, headers, doc_ids[0])
_assert_document_ready(doc, expected_filename="sample.md")
assert doc["document_metadata"]["ETL_SERVICE"] == "MARKDOWN"
# ---------------------------------------------------------------------------
# Test C: Upload a .pdf file (ETL path — Docling / Unstructured)
# ---------------------------------------------------------------------------
class TestPdfFileUpload:
"""Upload a PDF and verify it goes through the ETL extraction pipeline."""
async def test_pdf_processing_reaches_ready(
self,
client: httpx.AsyncClient,
headers: dict[str, str],
search_space_id: int,
cleanup_doc_ids: list[int],
):
resp = await upload_file(
client, headers, "sample.pdf", search_space_id=search_space_id
)
assert resp.status_code == 200
doc_ids = resp.json()["document_ids"]
cleanup_doc_ids.extend(doc_ids)
statuses = await poll_document_status(
client, headers, doc_ids, search_space_id=search_space_id, timeout=300.0
)
for did in doc_ids:
assert statuses[did]["status"]["state"] == "ready"
async def test_pdf_document_fields_populated(
self,
client: httpx.AsyncClient,
headers: dict[str, str],
search_space_id: int,
cleanup_doc_ids: list[int],
):
resp = await upload_file(
client, headers, "sample.pdf", search_space_id=search_space_id
)
doc_ids = resp.json()["document_ids"]
cleanup_doc_ids.extend(doc_ids)
await poll_document_status(
client, headers, doc_ids, search_space_id=search_space_id, timeout=300.0
)
doc = await get_document(client, headers, doc_ids[0])
_assert_document_ready(doc, expected_filename="sample.pdf")
assert doc["document_metadata"]["ETL_SERVICE"] in {
"DOCLING",
"UNSTRUCTURED",
"LLAMACLOUD",
}
# ---------------------------------------------------------------------------
# Test D: Upload multiple files in a single request
# ---------------------------------------------------------------------------
class TestMultiFileUpload:
"""Upload several files at once and verify all are processed."""
async def test_multi_upload_returns_all_ids(
self,
client: httpx.AsyncClient,
headers: dict[str, str],
search_space_id: int,
cleanup_doc_ids: list[int],
):
resp = await upload_multiple_files(
client,
headers,
["sample.txt", "sample.md"],
search_space_id=search_space_id,
)
assert resp.status_code == 200
body = resp.json()
assert body["pending_files"] == 2
assert len(body["document_ids"]) == 2
cleanup_doc_ids.extend(body["document_ids"])
async def test_multi_upload_all_reach_ready(
self,
client: httpx.AsyncClient,
headers: dict[str, str],
search_space_id: int,
cleanup_doc_ids: list[int],
):
resp = await upload_multiple_files(
client,
headers,
["sample.txt", "sample.md"],
search_space_id=search_space_id,
)
doc_ids = resp.json()["document_ids"]
cleanup_doc_ids.extend(doc_ids)
statuses = await poll_document_status(
client, headers, doc_ids, search_space_id=search_space_id
)
for did in doc_ids:
assert statuses[did]["status"]["state"] == "ready"
# ---------------------------------------------------------------------------
# Test E: Duplicate file upload (same file uploaded twice)
# ---------------------------------------------------------------------------
class TestDuplicateFileUpload:
"""
Uploading the exact same file a second time should be detected as a
duplicate via ``unique_identifier_hash``.
"""
async def test_duplicate_file_is_skipped(
self,
client: httpx.AsyncClient,
headers: dict[str, str],
search_space_id: int,
cleanup_doc_ids: list[int],
):
# First upload
resp1 = await upload_file(
client, headers, "sample.txt", search_space_id=search_space_id
)
assert resp1.status_code == 200
first_ids = resp1.json()["document_ids"]
cleanup_doc_ids.extend(first_ids)
await poll_document_status(
client, headers, first_ids, search_space_id=search_space_id
)
# Second upload of the same file
resp2 = await upload_file(
client, headers, "sample.txt", search_space_id=search_space_id
)
assert resp2.status_code == 200
body2 = resp2.json()
assert body2["skipped_duplicates"] >= 1
assert len(body2["duplicate_document_ids"]) >= 1
cleanup_doc_ids.extend(body2.get("document_ids", []))
# ---------------------------------------------------------------------------
# Test F: Duplicate content detection (different name, same content)
# ---------------------------------------------------------------------------
class TestDuplicateContentDetection:
"""
Uploading a file with a different name but identical content should be
detected as duplicate content via ``content_hash``.
"""
async def test_same_content_different_name_detected(
self,
client: httpx.AsyncClient,
headers: dict[str, str],
search_space_id: int,
cleanup_doc_ids: list[int],
tmp_path: Path,
):
# First upload
resp1 = await upload_file(
client, headers, "sample.txt", search_space_id=search_space_id
)
assert resp1.status_code == 200
first_ids = resp1.json()["document_ids"]
cleanup_doc_ids.extend(first_ids)
await poll_document_status(
client, headers, first_ids, search_space_id=search_space_id
)
# Copy fixture content to a differently named temp file
src = FIXTURES_DIR / "sample.txt"
dest = tmp_path / "renamed_sample.txt"
shutil.copy2(src, dest)
with open(dest, "rb") as f:
resp2 = await client.post(
"/api/v1/documents/fileupload",
headers=headers,
files={"files": ("renamed_sample.txt", f)},
data={"search_space_id": str(search_space_id)},
)
assert resp2.status_code == 200
second_ids = resp2.json()["document_ids"]
cleanup_doc_ids.extend(second_ids)
assert second_ids, (
"Expected at least one document id for renamed duplicate content upload"
)
statuses = await poll_document_status(
client, headers, second_ids, search_space_id=search_space_id
)
for did in second_ids:
assert statuses[did]["status"]["state"] == "failed"
assert "duplicate" in statuses[did]["status"].get("reason", "").lower()
# ---------------------------------------------------------------------------
# Test G: Empty / corrupt file handling
# ---------------------------------------------------------------------------
class TestEmptyFileUpload:
"""An empty file should be processed but ultimately fail gracefully."""
async def test_empty_pdf_fails(
self,
client: httpx.AsyncClient,
headers: dict[str, str],
search_space_id: int,
cleanup_doc_ids: list[int],
):
resp = await upload_file(
client, headers, "empty.pdf", search_space_id=search_space_id
)
assert resp.status_code == 200
doc_ids = resp.json()["document_ids"]
cleanup_doc_ids.extend(doc_ids)
assert doc_ids, "Expected at least one document id for empty PDF upload"
statuses = await poll_document_status(
client, headers, doc_ids, search_space_id=search_space_id, timeout=120.0
)
for did in doc_ids:
assert statuses[did]["status"]["state"] == "failed"
assert statuses[did]["status"].get("reason"), (
"Failed document should include a reason"
)
# ---------------------------------------------------------------------------
# Test H: Upload without authentication
# ---------------------------------------------------------------------------
class TestUnauthenticatedUpload:
"""Requests without a valid JWT should be rejected."""
async def test_upload_without_auth_returns_401(
self,
client: httpx.AsyncClient,
search_space_id: int,
):
file_path = FIXTURES_DIR / "sample.txt"
with open(file_path, "rb") as f:
resp = await client.post(
"/api/v1/documents/fileupload",
files={"files": ("sample.txt", f)},
data={"search_space_id": str(search_space_id)},
)
assert resp.status_code == 401
# ---------------------------------------------------------------------------
# Test I: Upload with no files attached
# ---------------------------------------------------------------------------
class TestNoFilesUpload:
"""Submitting the form with zero files should return a validation error."""
async def test_no_files_returns_error(
self,
client: httpx.AsyncClient,
headers: dict[str, str],
search_space_id: int,
):
resp = await client.post(
"/api/v1/documents/fileupload",
headers=headers,
data={"search_space_id": str(search_space_id)},
)
assert resp.status_code in {400, 422}
# ---------------------------------------------------------------------------
# Test J: Document deletion after successful upload
# ---------------------------------------------------------------------------
class TestDocumentDeletion:
"""Upload, wait for ready, delete, then verify it's gone."""
async def test_delete_processed_document(
self,
client: httpx.AsyncClient,
headers: dict[str, str],
search_space_id: int,
):
resp = await upload_file(
client, headers, "sample.txt", search_space_id=search_space_id
)
doc_ids = resp.json()["document_ids"]
await poll_document_status(
client, headers, doc_ids, search_space_id=search_space_id
)
del_resp = await delete_document(client, headers, doc_ids[0])
assert del_resp.status_code == 200
get_resp = await client.get(
f"/api/v1/documents/{doc_ids[0]}",
headers=headers,
)
assert get_resp.status_code == 404
# ---------------------------------------------------------------------------
# Test K: Cannot delete a document while it is still processing
# ---------------------------------------------------------------------------
class TestDeleteWhileProcessing:
"""Attempting to delete a pending/processing document should be rejected."""
async def test_delete_pending_document_returns_409(
self,
client: httpx.AsyncClient,
headers: dict[str, str],
search_space_id: int,
cleanup_doc_ids: list[int],
):
resp = await upload_file(
client, headers, "sample.pdf", search_space_id=search_space_id
)
assert resp.status_code == 200
doc_ids = resp.json()["document_ids"]
cleanup_doc_ids.extend(doc_ids)
# Immediately try to delete before processing finishes
del_resp = await delete_document(client, headers, doc_ids[0])
assert del_resp.status_code == 409
# Let it finish so cleanup can work
await poll_document_status(
client, headers, doc_ids, search_space_id=search_space_id, timeout=300.0
)
# ---------------------------------------------------------------------------
# Test L: Status polling returns correct structure
# ---------------------------------------------------------------------------
class TestDocumentSearchability:
"""After upload reaches ready, the document must appear in the title search."""
async def test_uploaded_document_appears_in_search(
self,
client: httpx.AsyncClient,
headers: dict[str, str],
search_space_id: int,
cleanup_doc_ids: list[int],
):
resp = await upload_file(
client, headers, "sample.txt", search_space_id=search_space_id
)
assert resp.status_code == 200
doc_ids = resp.json()["document_ids"]
cleanup_doc_ids.extend(doc_ids)
await poll_document_status(
client, headers, doc_ids, search_space_id=search_space_id
)
search_resp = await client.get(
"/api/v1/documents/search",
headers=headers,
params={"title": "sample", "search_space_id": search_space_id},
)
assert search_resp.status_code == 200
result_ids = [d["id"] for d in search_resp.json()["items"]]
assert doc_ids[0] in result_ids, (
f"Uploaded document {doc_ids[0]} not found in search results: {result_ids}"
)
class TestStatusPolling:
"""Verify the status endpoint returns well-formed responses."""
async def test_status_endpoint_returns_items(
self,
client: httpx.AsyncClient,
headers: dict[str, str],
search_space_id: int,
cleanup_doc_ids: list[int],
):
resp = await upload_file(
client, headers, "sample.txt", search_space_id=search_space_id
)
doc_ids = resp.json()["document_ids"]
cleanup_doc_ids.extend(doc_ids)
status_resp = await client.get(
"/api/v1/documents/status",
headers=headers,
params={
"search_space_id": search_space_id,
"document_ids": ",".join(str(d) for d in doc_ids),
},
)
assert status_resp.status_code == 200
body = status_resp.json()
assert "items" in body
assert len(body["items"]) == len(doc_ids)
for item in body["items"]:
assert "id" in item
assert "status" in item
assert "state" in item["status"]
assert item["status"]["state"] in {
"pending",
"processing",
"ready",
"failed",
}
await poll_document_status(
client, headers, doc_ids, search_space_id=search_space_id
)

View file

@ -0,0 +1,323 @@
"""
End-to-end tests for page-limit enforcement during document upload.
These tests manipulate the test user's ``pages_used`` / ``pages_limit``
columns directly in the database and then exercise the upload pipeline to
verify that:
- Uploads are rejected *before* ETL when the limit is exhausted.
- ``pages_used`` increases after a successful upload.
- A ``page_limit_exceeded`` notification is created on rejection.
- ``pages_used`` is not modified when a document fails processing.
All tests reuse the existing small fixtures (``sample.pdf``, ``sample.txt``)
so no additional processing time is introduced.
Prerequisites (must be running):
- FastAPI backend
- PostgreSQL + pgvector
- Redis
- Celery worker
"""
from __future__ import annotations
import httpx
import pytest
from tests.utils.helpers import (
get_notifications,
poll_document_status,
upload_file,
)
pytestmark = pytest.mark.e2e
# ---------------------------------------------------------------------------
# Test A: Successful upload increments pages_used
# ---------------------------------------------------------------------------
class TestPageUsageIncrementsOnSuccess:
"""After a successful PDF upload the user's ``pages_used`` must grow."""
async def test_pages_used_increases_after_pdf_upload(
self,
client: httpx.AsyncClient,
headers: dict[str, str],
search_space_id: int,
cleanup_doc_ids: list[int],
page_limits,
):
await page_limits.set(pages_used=0, pages_limit=1000)
resp = await upload_file(
client, headers, "sample.pdf", search_space_id=search_space_id
)
assert resp.status_code == 200
doc_ids = resp.json()["document_ids"]
cleanup_doc_ids.extend(doc_ids)
statuses = await poll_document_status(
client, headers, doc_ids, search_space_id=search_space_id, timeout=300.0
)
for did in doc_ids:
assert statuses[did]["status"]["state"] == "ready"
used, _ = await page_limits.get()
assert used > 0, "pages_used should have increased after successful processing"
# ---------------------------------------------------------------------------
# Test B: Upload rejected when page limit is fully exhausted
# ---------------------------------------------------------------------------
class TestUploadRejectedWhenLimitExhausted:
"""
When ``pages_used == pages_limit`` (zero remaining) the document
should reach ``failed`` status with a page-limit reason.
"""
async def test_pdf_fails_when_no_pages_remaining(
self,
client: httpx.AsyncClient,
headers: dict[str, str],
search_space_id: int,
cleanup_doc_ids: list[int],
page_limits,
):
await page_limits.set(pages_used=100, pages_limit=100)
resp = await upload_file(
client, headers, "sample.pdf", search_space_id=search_space_id
)
assert resp.status_code == 200
doc_ids = resp.json()["document_ids"]
cleanup_doc_ids.extend(doc_ids)
statuses = await poll_document_status(
client, headers, doc_ids, search_space_id=search_space_id, timeout=300.0
)
for did in doc_ids:
assert statuses[did]["status"]["state"] == "failed"
reason = statuses[did]["status"].get("reason", "").lower()
assert "page limit" in reason, (
f"Expected 'page limit' in failure reason, got: {reason!r}"
)
async def test_pages_used_unchanged_after_limit_rejection(
self,
client: httpx.AsyncClient,
headers: dict[str, str],
search_space_id: int,
cleanup_doc_ids: list[int],
page_limits,
):
await page_limits.set(pages_used=50, pages_limit=50)
resp = await upload_file(
client, headers, "sample.pdf", search_space_id=search_space_id
)
assert resp.status_code == 200
doc_ids = resp.json()["document_ids"]
cleanup_doc_ids.extend(doc_ids)
await poll_document_status(
client, headers, doc_ids, search_space_id=search_space_id, timeout=300.0
)
used, _ = await page_limits.get()
assert used == 50, (
f"pages_used should remain 50 after rejected upload, got {used}"
)
# ---------------------------------------------------------------------------
# Test C: Page-limit notification is created on rejection
# ---------------------------------------------------------------------------
class TestPageLimitNotification:
"""A ``page_limit_exceeded`` notification must be created when upload
is rejected due to the limit."""
async def test_page_limit_exceeded_notification_created(
self,
client: httpx.AsyncClient,
headers: dict[str, str],
search_space_id: int,
cleanup_doc_ids: list[int],
page_limits,
):
await page_limits.set(pages_used=100, pages_limit=100)
resp = await upload_file(
client, headers, "sample.pdf", search_space_id=search_space_id
)
assert resp.status_code == 200
doc_ids = resp.json()["document_ids"]
cleanup_doc_ids.extend(doc_ids)
await poll_document_status(
client, headers, doc_ids, search_space_id=search_space_id, timeout=300.0
)
notifications = await get_notifications(
client,
headers,
type_filter="page_limit_exceeded",
search_space_id=search_space_id,
)
assert len(notifications) >= 1, (
"Expected at least one page_limit_exceeded notification"
)
latest = notifications[0]
assert (
"page limit" in latest["title"].lower()
or "page limit" in latest["message"].lower()
), (
f"Notification should mention page limit: title={latest['title']!r}, "
f"message={latest['message']!r}"
)
# ---------------------------------------------------------------------------
# Test D: Successful upload creates a completed document_processing notification
# ---------------------------------------------------------------------------
class TestDocumentProcessingNotification:
"""A ``document_processing`` notification with ``completed`` status must
exist after a successful upload."""
async def test_processing_completed_notification_exists(
self,
client: httpx.AsyncClient,
headers: dict[str, str],
search_space_id: int,
cleanup_doc_ids: list[int],
page_limits,
):
await page_limits.set(pages_used=0, pages_limit=1000)
resp = await upload_file(
client, headers, "sample.txt", search_space_id=search_space_id
)
assert resp.status_code == 200
doc_ids = resp.json()["document_ids"]
cleanup_doc_ids.extend(doc_ids)
await poll_document_status(
client, headers, doc_ids, search_space_id=search_space_id
)
notifications = await get_notifications(
client,
headers,
type_filter="document_processing",
search_space_id=search_space_id,
)
completed = [
n
for n in notifications
if n.get("metadata", {}).get("processing_stage") == "completed"
]
assert len(completed) >= 1, (
"Expected at least one document_processing notification with 'completed' stage"
)
# ---------------------------------------------------------------------------
# Test E: pages_used unchanged when a document fails for non-limit reasons
# ---------------------------------------------------------------------------
class TestPagesUnchangedOnProcessingFailure:
"""If a document fails during ETL (e.g. empty/corrupt file) rather than
a page-limit rejection, ``pages_used`` should remain unchanged."""
async def test_pages_used_stable_on_etl_failure(
self,
client: httpx.AsyncClient,
headers: dict[str, str],
search_space_id: int,
cleanup_doc_ids: list[int],
page_limits,
):
await page_limits.set(pages_used=10, pages_limit=1000)
resp = await upload_file(
client, headers, "empty.pdf", search_space_id=search_space_id
)
assert resp.status_code == 200
doc_ids = resp.json()["document_ids"]
cleanup_doc_ids.extend(doc_ids)
if doc_ids:
statuses = await poll_document_status(
client, headers, doc_ids, search_space_id=search_space_id, timeout=120.0
)
for did in doc_ids:
assert statuses[did]["status"]["state"] == "failed"
used, _ = await page_limits.get()
assert used == 10, f"pages_used should remain 10 after ETL failure, got {used}"
# ---------------------------------------------------------------------------
# Test F: Second upload rejected after first consumes remaining quota
# ---------------------------------------------------------------------------
class TestSecondUploadExceedsLimit:
"""Upload one PDF successfully, consuming the quota, then verify a
second upload is rejected."""
async def test_second_upload_rejected_after_quota_consumed(
self,
client: httpx.AsyncClient,
headers: dict[str, str],
search_space_id: int,
cleanup_doc_ids: list[int],
page_limits,
):
# Give just enough room for one ~1-page PDF
await page_limits.set(pages_used=0, pages_limit=1)
resp1 = await upload_file(
client, headers, "sample.pdf", search_space_id=search_space_id
)
assert resp1.status_code == 200
first_ids = resp1.json()["document_ids"]
cleanup_doc_ids.extend(first_ids)
statuses1 = await poll_document_status(
client, headers, first_ids, search_space_id=search_space_id, timeout=300.0
)
for did in first_ids:
assert statuses1[did]["status"]["state"] == "ready"
# Second upload — should fail because quota is now consumed
resp2 = await upload_file(
client,
headers,
"sample.pdf",
search_space_id=search_space_id,
filename_override="sample_copy.pdf",
)
assert resp2.status_code == 200
second_ids = resp2.json()["document_ids"]
cleanup_doc_ids.extend(second_ids)
statuses2 = await poll_document_status(
client, headers, second_ids, search_space_id=search_space_id, timeout=300.0
)
for did in second_ids:
assert statuses2[did]["status"]["state"] == "failed"
reason = statuses2[did]["status"].get("reason", "").lower()
assert "page limit" in reason, (
f"Expected 'page limit' in failure reason, got: {reason!r}"
)

View file

@ -0,0 +1,146 @@
"""
End-to-end tests for backend file upload limit enforcement.
These tests verify that the API rejects uploads that exceed:
- Max files per upload (10)
- Max per-file size (50 MB)
- Max total upload size (200 MB)
The limits mirror the frontend's DocumentUploadTab.tsx constants and are
enforced server-side to protect against direct API calls.
Prerequisites (must be running):
- FastAPI backend
- PostgreSQL + pgvector
"""
from __future__ import annotations
import io
import httpx
import pytest
pytestmark = pytest.mark.e2e
# ---------------------------------------------------------------------------
# Test A: File count limit
# ---------------------------------------------------------------------------
class TestFileCountLimit:
"""Uploading more than 10 files in a single request should be rejected."""
async def test_11_files_returns_413(
self,
client: httpx.AsyncClient,
headers: dict[str, str],
search_space_id: int,
):
files = [
("files", (f"file_{i}.txt", io.BytesIO(b"test content"), "text/plain"))
for i in range(11)
]
resp = await client.post(
"/api/v1/documents/fileupload",
headers=headers,
files=files,
data={"search_space_id": str(search_space_id)},
)
assert resp.status_code == 413
assert "too many files" in resp.json()["detail"].lower()
async def test_10_files_accepted(
self,
client: httpx.AsyncClient,
headers: dict[str, str],
search_space_id: int,
cleanup_doc_ids: list[int],
):
files = [
("files", (f"file_{i}.txt", io.BytesIO(b"test content"), "text/plain"))
for i in range(10)
]
resp = await client.post(
"/api/v1/documents/fileupload",
headers=headers,
files=files,
data={"search_space_id": str(search_space_id)},
)
assert resp.status_code == 200
cleanup_doc_ids.extend(resp.json().get("document_ids", []))
# ---------------------------------------------------------------------------
# Test B: Per-file size limit
# ---------------------------------------------------------------------------
class TestPerFileSizeLimit:
"""A single file exceeding 50 MB should be rejected."""
async def test_oversized_file_returns_413(
self,
client: httpx.AsyncClient,
headers: dict[str, str],
search_space_id: int,
):
oversized = io.BytesIO(b"\x00" * (50 * 1024 * 1024 + 1))
resp = await client.post(
"/api/v1/documents/fileupload",
headers=headers,
files=[("files", ("big.pdf", oversized, "application/pdf"))],
data={"search_space_id": str(search_space_id)},
)
assert resp.status_code == 413
assert "per-file limit" in resp.json()["detail"].lower()
async def test_file_at_limit_accepted(
self,
client: httpx.AsyncClient,
headers: dict[str, str],
search_space_id: int,
cleanup_doc_ids: list[int],
):
at_limit = io.BytesIO(b"\x00" * (50 * 1024 * 1024))
resp = await client.post(
"/api/v1/documents/fileupload",
headers=headers,
files=[("files", ("exact50mb.txt", at_limit, "text/plain"))],
data={"search_space_id": str(search_space_id)},
)
assert resp.status_code == 200
cleanup_doc_ids.extend(resp.json().get("document_ids", []))
# ---------------------------------------------------------------------------
# Test C: Total upload size limit
# ---------------------------------------------------------------------------
class TestTotalSizeLimit:
"""Multiple files whose combined size exceeds 200 MB should be rejected."""
async def test_total_size_over_200mb_returns_413(
self,
client: httpx.AsyncClient,
headers: dict[str, str],
search_space_id: int,
):
chunk_size = 45 * 1024 * 1024 # 45 MB each
files = [
(
"files",
(f"chunk_{i}.txt", io.BytesIO(b"\x00" * chunk_size), "text/plain"),
)
for i in range(5) # 5 x 45 MB = 225 MB > 200 MB
]
resp = await client.post(
"/api/v1/documents/fileupload",
headers=headers,
files=files,
data={"search_space_id": str(search_space_id)},
)
assert resp.status_code == 413
assert "total upload size" in resp.json()["detail"].lower()

View file

View file

@ -0,0 +1,51 @@
# SurfSense Test Document
## Overview
This is a **sample markdown document** used for end-to-end testing of the manual
document upload pipeline. It includes various markdown formatting elements.
## Key Features
- Document upload and processing
- Automatic chunking of content
- Embedding generation for semantic search
- Real-time status tracking via ElectricSQL
## Technical Architecture
### Backend Stack
The SurfSense backend is built with:
1. **FastAPI** for the REST API
2. **PostgreSQL** with pgvector for vector storage
3. **Celery** with Redis for background task processing
4. **Docling/Unstructured** for document parsing (ETL)
### Processing Pipeline
Documents go through a multi-stage pipeline:
| Stage | Description |
|-------|-------------|
| Upload | File received via API endpoint |
| Parsing | Content extracted using ETL service |
| Chunking | Text split into semantic chunks |
| Embedding | Vector representations generated |
| Storage | Chunks stored with embeddings in pgvector |
## Code Example
```python
async def process_document(file_path: str) -> Document:
content = extract_content(file_path)
chunks = create_chunks(content)
embeddings = generate_embeddings(chunks)
return store_document(chunks, embeddings)
```
## Conclusion
This document serves as a test fixture to validate the complete document processing
pipeline from upload through to chunk creation and embedding storage.

Binary file not shown.

View file

@ -0,0 +1,34 @@
SurfSense Document Upload Test
This is a sample text document used for end-to-end testing of the manual document
upload pipeline in SurfSense. The document contains multiple paragraphs to ensure
that the chunking system has enough content to work with.
Artificial Intelligence and Machine Learning
Artificial intelligence (AI) is a broad field of computer science concerned with
building smart machines capable of performing tasks that typically require human
intelligence. Machine learning is a subset of AI that enables systems to learn and
improve from experience without being explicitly programmed.
Natural Language Processing
Natural language processing (NLP) is a subfield of linguistics, computer science,
and artificial intelligence concerned with the interactions between computers and
human language. Key applications include machine translation, sentiment analysis,
text summarization, and question answering systems.
Vector Databases and Semantic Search
Vector databases store data as high-dimensional vectors, enabling efficient
similarity search operations. When combined with embedding models, they power
semantic search systems that understand the meaning behind queries rather than
relying on exact keyword matches. This technology is fundamental to modern
retrieval-augmented generation (RAG) systems.
Document Processing Pipelines
Modern document processing pipelines involve several stages: extraction, transformation,
chunking, embedding generation, and storage. Each stage plays a critical role in
converting raw documents into searchable, structured knowledge that can be retrieved
and used by AI systems for accurate information retrieval and generation.

View file

@ -1,4 +1,3 @@
import os import os
import uuid import uuid
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock, MagicMock
@ -9,14 +8,21 @@ from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.pool import NullPool from sqlalchemy.pool import NullPool
from app.db import Base, SearchSpace, SearchSourceConnector, SearchSourceConnectorType from app.db import (
from app.db import User Base,
from app.db import DocumentType DocumentType,
SearchSourceConnector,
SearchSourceConnectorType,
SearchSpace,
User,
)
from app.indexing_pipeline.connector_document import ConnectorDocument from app.indexing_pipeline.connector_document import ConnectorDocument
_EMBEDDING_DIM = 1024 # must match the Vector() dimension used in DB column creation _EMBEDDING_DIM = 1024 # must match the Vector() dimension used in DB column creation
_DEFAULT_TEST_DB = "postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense_test" _DEFAULT_TEST_DB = (
"postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense_test"
)
TEST_DATABASE_URL = os.environ.get("TEST_DATABASE_URL", _DEFAULT_TEST_DB) TEST_DATABASE_URL = os.environ.get("TEST_DATABASE_URL", _DEFAULT_TEST_DB)
@ -80,7 +86,9 @@ async def db_user(db_session: AsyncSession) -> User:
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def db_connector(db_session: AsyncSession, db_user: User, db_search_space: "SearchSpace") -> SearchSourceConnector: async def db_connector(
db_session: AsyncSession, db_user: User, db_search_space: "SearchSpace"
) -> SearchSourceConnector:
connector = SearchSourceConnector( connector = SearchSourceConnector(
name="Test Connector", name="Test Connector",
connector_type=SearchSourceConnectorType.CLICKUP_CONNECTOR, connector_type=SearchSourceConnectorType.CLICKUP_CONNECTOR,
@ -147,6 +155,7 @@ def patched_chunk_text(monkeypatch) -> MagicMock:
@pytest.fixture @pytest.fixture
def make_connector_document(db_connector, db_user): def make_connector_document(db_connector, db_user):
"""Integration-scoped override: uses real DB connector and user IDs.""" """Integration-scoped override: uses real DB connector and user IDs."""
def _make(**overrides): def _make(**overrides):
defaults = { defaults = {
"title": "Test Document", "title": "Test Document",
@ -159,6 +168,5 @@ def make_connector_document(db_connector, db_user):
} }
defaults.update(overrides) defaults.update(overrides)
return ConnectorDocument(**defaults) return ConnectorDocument(**defaults)
return _make return _make

View file

@ -7,7 +7,9 @@ from app.indexing_pipeline.adapters.file_upload_adapter import index_uploaded_fi
pytestmark = pytest.mark.integration pytestmark = pytest.mark.integration
@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text") @pytest.mark.usefixtures(
"patched_summarize", "patched_embed_text", "patched_chunk_text"
)
async def test_sets_status_ready(db_session, db_search_space, db_user, mocker): async def test_sets_status_ready(db_session, db_search_space, db_user, mocker):
"""Document status is READY after successful indexing.""" """Document status is READY after successful indexing."""
await index_uploaded_file( await index_uploaded_file(
@ -28,7 +30,9 @@ async def test_sets_status_ready(db_session, db_search_space, db_user, mocker):
assert DocumentStatus.is_state(document.status, DocumentStatus.READY) assert DocumentStatus.is_state(document.status, DocumentStatus.READY)
@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text") @pytest.mark.usefixtures(
"patched_summarize", "patched_embed_text", "patched_chunk_text"
)
async def test_content_is_summary(db_session, db_search_space, db_user, mocker): async def test_content_is_summary(db_session, db_search_space, db_user, mocker):
"""Document content is set to the LLM-generated summary.""" """Document content is set to the LLM-generated summary."""
await index_uploaded_file( await index_uploaded_file(
@ -49,7 +53,9 @@ async def test_content_is_summary(db_session, db_search_space, db_user, mocker):
assert document.content == "Mocked summary." assert document.content == "Mocked summary."
@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text") @pytest.mark.usefixtures(
"patched_summarize", "patched_embed_text", "patched_chunk_text"
)
async def test_chunks_written_to_db(db_session, db_search_space, db_user, mocker): async def test_chunks_written_to_db(db_session, db_search_space, db_user, mocker):
"""Chunks derived from the source markdown are persisted in the DB.""" """Chunks derived from the source markdown are persisted in the DB."""
await index_uploaded_file( await index_uploaded_file(
@ -76,7 +82,9 @@ async def test_chunks_written_to_db(db_session, db_search_space, db_user, mocker
assert chunks[0].content == "Test chunk content." assert chunks[0].content == "Test chunk content."
@pytest.mark.usefixtures("patched_summarize_raises", "patched_embed_text", "patched_chunk_text") @pytest.mark.usefixtures(
"patched_summarize_raises", "patched_embed_text", "patched_chunk_text"
)
async def test_raises_on_indexing_failure(db_session, db_search_space, db_user, mocker): async def test_raises_on_indexing_failure(db_session, db_search_space, db_user, mocker):
"""RuntimeError is raised when the indexing step fails so the caller can fire a failure notification.""" """RuntimeError is raised when the indexing step fails so the caller can fire a failure notification."""
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):

View file

@ -7,9 +7,14 @@ from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineServ
pytestmark = pytest.mark.integration pytestmark = pytest.mark.integration
@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text") @pytest.mark.usefixtures(
"patched_summarize", "patched_embed_text", "patched_chunk_text"
)
async def test_sets_status_ready( async def test_sets_status_ready(
db_session, db_search_space, make_connector_document, mocker, db_session,
db_search_space,
make_connector_document,
mocker,
): ):
"""Document status is READY after successful indexing.""" """Document status is READY after successful indexing."""
connector_doc = make_connector_document(search_space_id=db_search_space.id) connector_doc = make_connector_document(search_space_id=db_search_space.id)
@ -21,15 +26,22 @@ async def test_sets_status_ready(
await service.index(document, connector_doc, llm=mocker.Mock()) await service.index(document, connector_doc, llm=mocker.Mock())
result = await db_session.execute(select(Document).filter(Document.id == document_id)) result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
reloaded = result.scalars().first() reloaded = result.scalars().first()
assert DocumentStatus.is_state(reloaded.status, DocumentStatus.READY) assert DocumentStatus.is_state(reloaded.status, DocumentStatus.READY)
@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text") @pytest.mark.usefixtures(
"patched_summarize", "patched_embed_text", "patched_chunk_text"
)
async def test_content_is_summary_when_should_summarize_true( async def test_content_is_summary_when_should_summarize_true(
db_session, db_search_space, make_connector_document, mocker, db_session,
db_search_space,
make_connector_document,
mocker,
): ):
"""Document content is set to the LLM-generated summary when should_summarize=True.""" """Document content is set to the LLM-generated summary when should_summarize=True."""
connector_doc = make_connector_document(search_space_id=db_search_space.id) connector_doc = make_connector_document(search_space_id=db_search_space.id)
@ -41,15 +53,21 @@ async def test_content_is_summary_when_should_summarize_true(
await service.index(document, connector_doc, llm=mocker.Mock()) await service.index(document, connector_doc, llm=mocker.Mock())
result = await db_session.execute(select(Document).filter(Document.id == document_id)) result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
reloaded = result.scalars().first() reloaded = result.scalars().first()
assert reloaded.content == "Mocked summary." assert reloaded.content == "Mocked summary."
@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text") @pytest.mark.usefixtures(
"patched_summarize", "patched_embed_text", "patched_chunk_text"
)
async def test_content_is_source_markdown_when_should_summarize_false( async def test_content_is_source_markdown_when_should_summarize_false(
db_session, db_search_space, make_connector_document, db_session,
db_search_space,
make_connector_document,
): ):
"""Document content is set to source_markdown verbatim when should_summarize=False.""" """Document content is set to source_markdown verbatim when should_summarize=False."""
connector_doc = make_connector_document( connector_doc = make_connector_document(
@ -65,15 +83,22 @@ async def test_content_is_source_markdown_when_should_summarize_false(
await service.index(document, connector_doc, llm=None) await service.index(document, connector_doc, llm=None)
result = await db_session.execute(select(Document).filter(Document.id == document_id)) result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
reloaded = result.scalars().first() reloaded = result.scalars().first()
assert reloaded.content == "## Raw content" assert reloaded.content == "## Raw content"
@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text") @pytest.mark.usefixtures(
"patched_summarize", "patched_embed_text", "patched_chunk_text"
)
async def test_chunks_written_to_db( async def test_chunks_written_to_db(
db_session, db_search_space, make_connector_document, mocker, db_session,
db_search_space,
make_connector_document,
mocker,
): ):
"""Chunks derived from source_markdown are persisted in the DB.""" """Chunks derived from source_markdown are persisted in the DB."""
connector_doc = make_connector_document(search_space_id=db_search_space.id) connector_doc = make_connector_document(search_space_id=db_search_space.id)
@ -94,9 +119,14 @@ async def test_chunks_written_to_db(
assert chunks[0].content == "Test chunk content." assert chunks[0].content == "Test chunk content."
@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text") @pytest.mark.usefixtures(
"patched_summarize", "patched_embed_text", "patched_chunk_text"
)
async def test_embedding_written_to_db( async def test_embedding_written_to_db(
db_session, db_search_space, make_connector_document, mocker, db_session,
db_search_space,
make_connector_document,
mocker,
): ):
"""Document embedding vector is persisted in the DB after indexing.""" """Document embedding vector is persisted in the DB after indexing."""
connector_doc = make_connector_document(search_space_id=db_search_space.id) connector_doc = make_connector_document(search_space_id=db_search_space.id)
@ -108,16 +138,23 @@ async def test_embedding_written_to_db(
await service.index(document, connector_doc, llm=mocker.Mock()) await service.index(document, connector_doc, llm=mocker.Mock())
result = await db_session.execute(select(Document).filter(Document.id == document_id)) result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
reloaded = result.scalars().first() reloaded = result.scalars().first()
assert reloaded.embedding is not None assert reloaded.embedding is not None
assert len(reloaded.embedding) == 1024 assert len(reloaded.embedding) == 1024
@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text") @pytest.mark.usefixtures(
"patched_summarize", "patched_embed_text", "patched_chunk_text"
)
async def test_updated_at_advances_after_indexing( async def test_updated_at_advances_after_indexing(
db_session, db_search_space, make_connector_document, mocker, db_session,
db_search_space,
make_connector_document,
mocker,
): ):
"""updated_at timestamp is later after indexing than it was at prepare time.""" """updated_at timestamp is later after indexing than it was at prepare time."""
connector_doc = make_connector_document(search_space_id=db_search_space.id) connector_doc = make_connector_document(search_space_id=db_search_space.id)
@ -127,20 +164,28 @@ async def test_updated_at_advances_after_indexing(
document = prepared[0] document = prepared[0]
document_id = document.id document_id = document.id
result = await db_session.execute(select(Document).filter(Document.id == document_id)) result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
updated_at_pending = result.scalars().first().updated_at updated_at_pending = result.scalars().first().updated_at
await service.index(document, connector_doc, llm=mocker.Mock()) await service.index(document, connector_doc, llm=mocker.Mock())
result = await db_session.execute(select(Document).filter(Document.id == document_id)) result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
updated_at_ready = result.scalars().first().updated_at updated_at_ready = result.scalars().first().updated_at
assert updated_at_ready > updated_at_pending assert updated_at_ready > updated_at_pending
@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text") @pytest.mark.usefixtures(
"patched_summarize", "patched_embed_text", "patched_chunk_text"
)
async def test_no_llm_falls_back_to_source_markdown( async def test_no_llm_falls_back_to_source_markdown(
db_session, db_search_space, make_connector_document, db_session,
db_search_space,
make_connector_document,
): ):
"""When llm=None and no fallback_summary, content falls back to source_markdown.""" """When llm=None and no fallback_summary, content falls back to source_markdown."""
connector_doc = make_connector_document( connector_doc = make_connector_document(
@ -156,16 +201,22 @@ async def test_no_llm_falls_back_to_source_markdown(
await service.index(document, connector_doc, llm=None) await service.index(document, connector_doc, llm=None)
result = await db_session.execute(select(Document).filter(Document.id == document_id)) result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
reloaded = result.scalars().first() reloaded = result.scalars().first()
assert DocumentStatus.is_state(reloaded.status, DocumentStatus.READY) assert DocumentStatus.is_state(reloaded.status, DocumentStatus.READY)
assert reloaded.content == "## Fallback content" assert reloaded.content == "## Fallback content"
@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text") @pytest.mark.usefixtures(
"patched_summarize", "patched_embed_text", "patched_chunk_text"
)
async def test_fallback_summary_used_when_llm_unavailable( async def test_fallback_summary_used_when_llm_unavailable(
db_session, db_search_space, make_connector_document, db_session,
db_search_space,
make_connector_document,
): ):
"""fallback_summary is used as content when llm=None and should_summarize=True.""" """fallback_summary is used as content when llm=None and should_summarize=True."""
connector_doc = make_connector_document( connector_doc = make_connector_document(
@ -181,16 +232,23 @@ async def test_fallback_summary_used_when_llm_unavailable(
await service.index(prepared[0], connector_doc, llm=None) await service.index(prepared[0], connector_doc, llm=None)
result = await db_session.execute(select(Document).filter(Document.id == document_id)) result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
reloaded = result.scalars().first() reloaded = result.scalars().first()
assert DocumentStatus.is_state(reloaded.status, DocumentStatus.READY) assert DocumentStatus.is_state(reloaded.status, DocumentStatus.READY)
assert reloaded.content == "Short pre-built summary." assert reloaded.content == "Short pre-built summary."
@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text") @pytest.mark.usefixtures(
"patched_summarize", "patched_embed_text", "patched_chunk_text"
)
async def test_reindex_replaces_old_chunks( async def test_reindex_replaces_old_chunks(
db_session, db_search_space, make_connector_document, mocker, db_session,
db_search_space,
make_connector_document,
mocker,
): ):
"""Re-indexing a document replaces its old chunks rather than appending.""" """Re-indexing a document replaces its old chunks rather than appending."""
connector_doc = make_connector_document( connector_doc = make_connector_document(
@ -220,9 +278,14 @@ async def test_reindex_replaces_old_chunks(
assert len(chunks) == 1 assert len(chunks) == 1
@pytest.mark.usefixtures("patched_summarize_raises", "patched_embed_text", "patched_chunk_text") @pytest.mark.usefixtures(
"patched_summarize_raises", "patched_embed_text", "patched_chunk_text"
)
async def test_llm_error_sets_status_failed( async def test_llm_error_sets_status_failed(
db_session, db_search_space, make_connector_document, mocker, db_session,
db_search_space,
make_connector_document,
mocker,
): ):
"""Document status is FAILED when the LLM raises during indexing.""" """Document status is FAILED when the LLM raises during indexing."""
connector_doc = make_connector_document(search_space_id=db_search_space.id) connector_doc = make_connector_document(search_space_id=db_search_space.id)
@ -234,15 +297,22 @@ async def test_llm_error_sets_status_failed(
await service.index(document, connector_doc, llm=mocker.Mock()) await service.index(document, connector_doc, llm=mocker.Mock())
result = await db_session.execute(select(Document).filter(Document.id == document_id)) result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
reloaded = result.scalars().first() reloaded = result.scalars().first()
assert DocumentStatus.is_state(reloaded.status, DocumentStatus.FAILED) assert DocumentStatus.is_state(reloaded.status, DocumentStatus.FAILED)
@pytest.mark.usefixtures("patched_summarize_raises", "patched_embed_text", "patched_chunk_text") @pytest.mark.usefixtures(
"patched_summarize_raises", "patched_embed_text", "patched_chunk_text"
)
async def test_llm_error_leaves_no_partial_data( async def test_llm_error_leaves_no_partial_data(
db_session, db_search_space, make_connector_document, mocker, db_session,
db_search_space,
make_connector_document,
mocker,
): ):
"""A failed indexing attempt leaves no partial embedding or chunks in the DB.""" """A failed indexing attempt leaves no partial embedding or chunks in the DB."""
connector_doc = make_connector_document(search_space_id=db_search_space.id) connector_doc = make_connector_document(search_space_id=db_search_space.id)
@ -254,7 +324,9 @@ async def test_llm_error_leaves_no_partial_data(
await service.index(document, connector_doc, llm=mocker.Mock()) await service.index(document, connector_doc, llm=mocker.Mock())
result = await db_session.execute(select(Document).filter(Document.id == document_id)) result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
reloaded = result.scalars().first() reloaded = result.scalars().first()
assert reloaded.embedding is None assert reloaded.embedding is None

View file

@ -2,7 +2,9 @@ import pytest
from sqlalchemy import select from sqlalchemy import select
from app.db import Document, DocumentStatus from app.db import Document, DocumentStatus
from app.indexing_pipeline.document_hashing import compute_content_hash as real_compute_content_hash from app.indexing_pipeline.document_hashing import (
compute_content_hash as real_compute_content_hash,
)
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
pytestmark = pytest.mark.integration pytestmark = pytest.mark.integration
@ -20,7 +22,9 @@ async def test_new_document_is_persisted_with_pending_status(
assert len(results) == 1 assert len(results) == 1
document_id = results[0].id document_id = results[0].id
result = await db_session.execute(select(Document).filter(Document.id == document_id)) result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
reloaded = result.scalars().first() reloaded = result.scalars().first()
assert reloaded is not None assert reloaded is not None
@ -28,9 +32,14 @@ async def test_new_document_is_persisted_with_pending_status(
assert reloaded.source_markdown == doc.source_markdown assert reloaded.source_markdown == doc.source_markdown
@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text") @pytest.mark.usefixtures(
"patched_summarize", "patched_embed_text", "patched_chunk_text"
)
async def test_unchanged_ready_document_is_skipped( async def test_unchanged_ready_document_is_skipped(
db_session, db_search_space, make_connector_document, mocker, db_session,
db_search_space,
make_connector_document,
mocker,
): ):
"""A READY document with unchanged content is not returned for re-indexing.""" """A READY document with unchanged content is not returned for re-indexing."""
doc = make_connector_document(search_space_id=db_search_space.id) doc = make_connector_document(search_space_id=db_search_space.id)
@ -46,24 +55,35 @@ async def test_unchanged_ready_document_is_skipped(
assert results == [] assert results == []
@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text") @pytest.mark.usefixtures(
"patched_summarize", "patched_embed_text", "patched_chunk_text"
)
async def test_title_only_change_updates_title_in_db( async def test_title_only_change_updates_title_in_db(
db_session, db_search_space, make_connector_document, mocker, db_session,
db_search_space,
make_connector_document,
mocker,
): ):
"""A title-only change updates the DB title without re-queuing the document.""" """A title-only change updates the DB title without re-queuing the document."""
original = make_connector_document(search_space_id=db_search_space.id, title="Original Title") original = make_connector_document(
search_space_id=db_search_space.id, title="Original Title"
)
service = IndexingPipelineService(session=db_session) service = IndexingPipelineService(session=db_session)
prepared = await service.prepare_for_indexing([original]) prepared = await service.prepare_for_indexing([original])
document_id = prepared[0].id document_id = prepared[0].id
await service.index(prepared[0], original, llm=mocker.Mock()) await service.index(prepared[0], original, llm=mocker.Mock())
renamed = make_connector_document(search_space_id=db_search_space.id, title="Updated Title") renamed = make_connector_document(
search_space_id=db_search_space.id, title="Updated Title"
)
results = await service.prepare_for_indexing([renamed]) results = await service.prepare_for_indexing([renamed])
assert results == [] assert results == []
result = await db_session.execute(select(Document).filter(Document.id == document_id)) result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
reloaded = result.scalars().first() reloaded = result.scalars().first()
assert reloaded.title == "Updated Title" assert reloaded.title == "Updated Title"
@ -73,19 +93,25 @@ async def test_changed_content_is_returned_for_reprocessing(
db_session, db_search_space, make_connector_document db_session, db_search_space, make_connector_document
): ):
"""A document with changed content is returned for re-indexing with updated markdown.""" """A document with changed content is returned for re-indexing with updated markdown."""
original = make_connector_document(search_space_id=db_search_space.id, source_markdown="## v1") original = make_connector_document(
search_space_id=db_search_space.id, source_markdown="## v1"
)
service = IndexingPipelineService(session=db_session) service = IndexingPipelineService(session=db_session)
first = await service.prepare_for_indexing([original]) first = await service.prepare_for_indexing([original])
original_id = first[0].id original_id = first[0].id
updated = make_connector_document(search_space_id=db_search_space.id, source_markdown="## v2") updated = make_connector_document(
search_space_id=db_search_space.id, source_markdown="## v2"
)
results = await service.prepare_for_indexing([updated]) results = await service.prepare_for_indexing([updated])
assert len(results) == 1 assert len(results) == 1
assert results[0].id == original_id assert results[0].id == original_id
result = await db_session.execute(select(Document).filter(Document.id == original_id)) result = await db_session.execute(
select(Document).filter(Document.id == original_id)
)
reloaded = result.scalars().first() reloaded = result.scalars().first()
assert reloaded.source_markdown == "## v2" assert reloaded.source_markdown == "## v2"
@ -97,9 +123,24 @@ async def test_all_documents_in_batch_are_persisted(
): ):
"""All documents in a batch are persisted and returned.""" """All documents in a batch are persisted and returned."""
docs = [ docs = [
make_connector_document(search_space_id=db_search_space.id, unique_id="id-1", title="Doc 1", source_markdown="## Content 1"), make_connector_document(
make_connector_document(search_space_id=db_search_space.id, unique_id="id-2", title="Doc 2", source_markdown="## Content 2"), search_space_id=db_search_space.id,
make_connector_document(search_space_id=db_search_space.id, unique_id="id-3", title="Doc 3", source_markdown="## Content 3"), unique_id="id-1",
title="Doc 1",
source_markdown="## Content 1",
),
make_connector_document(
search_space_id=db_search_space.id,
unique_id="id-2",
title="Doc 2",
source_markdown="## Content 2",
),
make_connector_document(
search_space_id=db_search_space.id,
unique_id="id-3",
title="Doc 3",
source_markdown="## Content 3",
),
] ]
service = IndexingPipelineService(session=db_session) service = IndexingPipelineService(session=db_session)
@ -107,7 +148,9 @@ async def test_all_documents_in_batch_are_persisted(
assert len(results) == 3 assert len(results) == 3
result = await db_session.execute(select(Document).filter(Document.search_space_id == db_search_space.id)) result = await db_session.execute(
select(Document).filter(Document.search_space_id == db_search_space.id)
)
rows = result.scalars().all() rows = result.scalars().all()
assert len(rows) == 3 assert len(rows) == 3
@ -124,7 +167,9 @@ async def test_duplicate_in_batch_is_persisted_once(
assert len(results) == 1 assert len(results) == 1
result = await db_session.execute(select(Document).filter(Document.search_space_id == db_search_space.id)) result = await db_session.execute(
select(Document).filter(Document.search_space_id == db_search_space.id)
)
rows = result.scalars().all() rows = result.scalars().all()
assert len(rows) == 1 assert len(rows) == 1
@ -143,7 +188,9 @@ async def test_created_by_id_is_persisted(
results = await service.prepare_for_indexing([doc]) results = await service.prepare_for_indexing([doc])
document_id = results[0].id document_id = results[0].id
result = await db_session.execute(select(Document).filter(Document.id == document_id)) result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
reloaded = result.scalars().first() reloaded = result.scalars().first()
assert str(reloaded.created_by_id) == str(db_user.id) assert str(reloaded.created_by_id) == str(db_user.id)
@ -170,7 +217,9 @@ async def test_metadata_is_updated_when_content_changes(
) )
await service.prepare_for_indexing([updated]) await service.prepare_for_indexing([updated])
result = await db_session.execute(select(Document).filter(Document.id == document_id)) result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
reloaded = result.scalars().first() reloaded = result.scalars().first()
assert reloaded.document_metadata == {"status": "done"} assert reloaded.document_metadata == {"status": "done"}
@ -180,19 +229,27 @@ async def test_updated_at_advances_when_title_only_changes(
db_session, db_search_space, make_connector_document db_session, db_search_space, make_connector_document
): ):
"""updated_at advances even when only the title changes.""" """updated_at advances even when only the title changes."""
original = make_connector_document(search_space_id=db_search_space.id, title="Old Title") original = make_connector_document(
search_space_id=db_search_space.id, title="Old Title"
)
service = IndexingPipelineService(session=db_session) service = IndexingPipelineService(session=db_session)
first = await service.prepare_for_indexing([original]) first = await service.prepare_for_indexing([original])
document_id = first[0].id document_id = first[0].id
result = await db_session.execute(select(Document).filter(Document.id == document_id)) result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
updated_at_v1 = result.scalars().first().updated_at updated_at_v1 = result.scalars().first().updated_at
renamed = make_connector_document(search_space_id=db_search_space.id, title="New Title") renamed = make_connector_document(
search_space_id=db_search_space.id, title="New Title"
)
await service.prepare_for_indexing([renamed]) await service.prepare_for_indexing([renamed])
result = await db_session.execute(select(Document).filter(Document.id == document_id)) result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
updated_at_v2 = result.scalars().first().updated_at updated_at_v2 = result.scalars().first().updated_at
assert updated_at_v2 > updated_at_v1 assert updated_at_v2 > updated_at_v1
@ -202,19 +259,27 @@ async def test_updated_at_advances_when_content_changes(
db_session, db_search_space, make_connector_document db_session, db_search_space, make_connector_document
): ):
"""updated_at advances when document content changes.""" """updated_at advances when document content changes."""
original = make_connector_document(search_space_id=db_search_space.id, source_markdown="## v1") original = make_connector_document(
search_space_id=db_search_space.id, source_markdown="## v1"
)
service = IndexingPipelineService(session=db_session) service = IndexingPipelineService(session=db_session)
first = await service.prepare_for_indexing([original]) first = await service.prepare_for_indexing([original])
document_id = first[0].id document_id = first[0].id
result = await db_session.execute(select(Document).filter(Document.id == document_id)) result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
updated_at_v1 = result.scalars().first().updated_at updated_at_v1 = result.scalars().first().updated_at
updated = make_connector_document(search_space_id=db_search_space.id, source_markdown="## v2") updated = make_connector_document(
search_space_id=db_search_space.id, source_markdown="## v2"
)
await service.prepare_for_indexing([updated]) await service.prepare_for_indexing([updated])
result = await db_session.execute(select(Document).filter(Document.id == document_id)) result = await db_session.execute(
select(Document).filter(Document.id == document_id)
)
updated_at_v2 = result.scalars().first().updated_at updated_at_v2 = result.scalars().first().updated_at
assert updated_at_v2 > updated_at_v1 assert updated_at_v2 > updated_at_v1
@ -273,9 +338,14 @@ async def test_same_content_from_different_source_is_skipped(
assert len(result.scalars().all()) == 1 assert len(result.scalars().all()) == 1
@pytest.mark.usefixtures("patched_summarize_raises", "patched_embed_text", "patched_chunk_text") @pytest.mark.usefixtures(
"patched_summarize_raises", "patched_embed_text", "patched_chunk_text"
)
async def test_failed_document_with_unchanged_content_is_requeued( async def test_failed_document_with_unchanged_content_is_requeued(
db_session, db_search_space, make_connector_document, mocker, db_session,
db_search_space,
make_connector_document,
mocker,
): ):
"""A FAILED document with unchanged content is re-queued as PENDING on the next run.""" """A FAILED document with unchanged content is re-queued as PENDING on the next run."""
doc = make_connector_document(search_space_id=db_search_space.id) doc = make_connector_document(search_space_id=db_search_space.id)
@ -286,8 +356,12 @@ async def test_failed_document_with_unchanged_content_is_requeued(
document_id = prepared[0].id document_id = prepared[0].id
await service.index(prepared[0], doc, llm=mocker.Mock()) await service.index(prepared[0], doc, llm=mocker.Mock())
result = await db_session.execute(select(Document).filter(Document.id == document_id)) result = await db_session.execute(
assert DocumentStatus.is_state(result.scalars().first().status, DocumentStatus.FAILED) select(Document).filter(Document.id == document_id)
)
assert DocumentStatus.is_state(
result.scalars().first().status, DocumentStatus.FAILED
)
# Next run: same content, pipeline must re-queue the failed document # Next run: same content, pipeline must re-queue the failed document
results = await service.prepare_for_indexing([doc]) results = await service.prepare_for_indexing([doc])
@ -295,8 +369,12 @@ async def test_failed_document_with_unchanged_content_is_requeued(
assert len(results) == 1 assert len(results) == 1
assert results[0].id == document_id assert results[0].id == document_id
result = await db_session.execute(select(Document).filter(Document.id == document_id)) result = await db_session.execute(
assert DocumentStatus.is_state(result.scalars().first().status, DocumentStatus.PENDING) select(Document).filter(Document.id == document_id)
)
assert DocumentStatus.is_state(
result.scalars().first().status, DocumentStatus.PENDING
)
async def test_title_and_content_change_updates_both_and_returns_document( async def test_title_and_content_change_updates_both_and_returns_document(
@ -323,16 +401,20 @@ async def test_title_and_content_change_updates_both_and_returns_document(
assert len(results) == 1 assert len(results) == 1
assert results[0].id == original_id assert results[0].id == original_id
result = await db_session.execute(select(Document).filter(Document.id == original_id)) result = await db_session.execute(
select(Document).filter(Document.id == original_id)
)
reloaded = result.scalars().first() reloaded = result.scalars().first()
assert reloaded.title == "Updated Title" assert reloaded.title == "Updated Title"
assert reloaded.source_markdown == "## v2" assert reloaded.source_markdown == "## v2"
async def test_one_bad_document_in_batch_does_not_prevent_others_from_being_persisted( async def test_one_bad_document_in_batch_does_not_prevent_others_from_being_persisted(
db_session, db_search_space, make_connector_document, monkeypatch, db_session,
db_search_space,
make_connector_document,
monkeypatch,
): ):
""" """
A per-document error during prepare_for_indexing must be isolated. A per-document error during prepare_for_indexing must be isolated.
@ -374,4 +456,4 @@ async def test_one_bad_document_in_batch_does_not_prevent_others_from_being_pers
result = await db_session.execute( result = await db_session.execute(
select(Document).filter(Document.search_space_id == db_search_space.id) select(Document).filter(Document.search_space_id == db_search_space.id)
) )
assert len(result.scalars().all()) == 2 assert len(result.scalars().all()) == 2

View file

@ -1,6 +1,7 @@
import pytest
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock, MagicMock
import pytest
@pytest.fixture @pytest.fixture
def patched_summarizer_chain(monkeypatch): def patched_summarizer_chain(monkeypatch):
@ -21,7 +22,9 @@ def patched_summarizer_chain(monkeypatch):
def patched_chunker_instance(monkeypatch): def patched_chunker_instance(monkeypatch):
mock = MagicMock() mock = MagicMock()
mock.chunk.return_value = [MagicMock(text="prose chunk")] mock.chunk.return_value = [MagicMock(text="prose chunk")]
monkeypatch.setattr("app.indexing_pipeline.document_chunker.config.chunker_instance", mock) monkeypatch.setattr(
"app.indexing_pipeline.document_chunker.config.chunker_instance", mock
)
return mock return mock
@ -29,5 +32,7 @@ def patched_chunker_instance(monkeypatch):
def patched_code_chunker_instance(monkeypatch): def patched_code_chunker_instance(monkeypatch):
mock = MagicMock() mock = MagicMock()
mock.chunk.return_value = [MagicMock(text="code chunk")] mock.chunk.return_value = [MagicMock(text="code chunk")]
monkeypatch.setattr("app.indexing_pipeline.document_chunker.config.code_chunker_instance", mock) monkeypatch.setattr(
"app.indexing_pipeline.document_chunker.config.code_chunker_instance", mock
)
return mock return mock

View file

@ -1,7 +1,10 @@
import pytest import pytest
from app.db import DocumentType from app.db import DocumentType
from app.indexing_pipeline.document_hashing import compute_content_hash, compute_unique_identifier_hash from app.indexing_pipeline.document_hashing import (
compute_content_hash,
compute_unique_identifier_hash,
)
pytestmark = pytest.mark.unit pytestmark = pytest.mark.unit
@ -10,21 +13,31 @@ def test_different_unique_id_produces_different_hash(make_connector_document):
"""Two documents with different unique_ids produce different identifier hashes.""" """Two documents with different unique_ids produce different identifier hashes."""
doc_a = make_connector_document(unique_id="id-001") doc_a = make_connector_document(unique_id="id-001")
doc_b = make_connector_document(unique_id="id-002") doc_b = make_connector_document(unique_id="id-002")
assert compute_unique_identifier_hash(doc_a) != compute_unique_identifier_hash(doc_b) assert compute_unique_identifier_hash(doc_a) != compute_unique_identifier_hash(
doc_b
)
def test_different_search_space_produces_different_identifier_hash(make_connector_document): def test_different_search_space_produces_different_identifier_hash(
make_connector_document,
):
"""Same document in different search spaces produces different identifier hashes.""" """Same document in different search spaces produces different identifier hashes."""
doc_a = make_connector_document(search_space_id=1) doc_a = make_connector_document(search_space_id=1)
doc_b = make_connector_document(search_space_id=2) doc_b = make_connector_document(search_space_id=2)
assert compute_unique_identifier_hash(doc_a) != compute_unique_identifier_hash(doc_b) assert compute_unique_identifier_hash(doc_a) != compute_unique_identifier_hash(
doc_b
)
def test_different_document_type_produces_different_identifier_hash(make_connector_document): def test_different_document_type_produces_different_identifier_hash(
make_connector_document,
):
"""Same unique_id with different document types produces different identifier hashes.""" """Same unique_id with different document types produces different identifier hashes."""
doc_a = make_connector_document(document_type=DocumentType.CLICKUP_CONNECTOR) doc_a = make_connector_document(document_type=DocumentType.CLICKUP_CONNECTOR)
doc_b = make_connector_document(document_type=DocumentType.NOTION_CONNECTOR) doc_b = make_connector_document(document_type=DocumentType.NOTION_CONNECTOR)
assert compute_unique_identifier_hash(doc_a) != compute_unique_identifier_hash(doc_b) assert compute_unique_identifier_hash(doc_a) != compute_unique_identifier_hash(
doc_b
)
def test_same_content_same_space_produces_same_content_hash(make_connector_document): def test_same_content_same_space_produces_same_content_hash(make_connector_document):
@ -34,7 +47,9 @@ def test_same_content_same_space_produces_same_content_hash(make_connector_docum
assert compute_content_hash(doc_a) == compute_content_hash(doc_b) assert compute_content_hash(doc_a) == compute_content_hash(doc_b)
def test_same_content_different_space_produces_different_content_hash(make_connector_document): def test_same_content_different_space_produces_different_content_hash(
make_connector_document,
):
"""Identical content in different search spaces produces different content hashes.""" """Identical content in different search spaces produces different content hashes."""
doc_a = make_connector_document(source_markdown="Hello world", search_space_id=1) doc_a = make_connector_document(source_markdown="Hello world", search_space_id=1)
doc_b = make_connector_document(source_markdown="Hello world", search_space_id=2) doc_b = make_connector_document(source_markdown="Hello world", search_space_id=2)

View file

@ -1,6 +1,7 @@
import pytest
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest
from app.indexing_pipeline.document_summarizer import summarize_document from app.indexing_pipeline.document_summarizer import summarize_document
pytestmark = pytest.mark.unit pytestmark = pytest.mark.unit
@ -38,5 +39,3 @@ async def test_with_metadata_omits_empty_fields_from_output():
assert "Alice" in result assert "Alice" in result
assert "description" not in result.lower() assert "description" not in result.lower()

View file

@ -0,0 +1,224 @@
"""Shared test helpers for authentication, polling, and cleanup."""
from __future__ import annotations
import asyncio
import os
from pathlib import Path
import httpx
FIXTURES_DIR = Path(__file__).resolve().parent.parent / "fixtures"
BACKEND_URL = os.environ.get("TEST_BACKEND_URL", "http://localhost:8000")
TEST_EMAIL = os.environ.get("TEST_USER_EMAIL", "testuser@surfsense.com")
TEST_PASSWORD = os.environ.get("TEST_USER_PASSWORD", "testpassword123")
async def get_auth_token(client: httpx.AsyncClient) -> str:
"""Log in and return a Bearer JWT token, registering the user first if needed."""
response = await client.post(
"/auth/jwt/login",
data={"username": TEST_EMAIL, "password": TEST_PASSWORD},
headers={"Content-Type": "application/x-www-form-urlencoded"},
)
if response.status_code == 200:
return response.json()["access_token"]
reg_response = await client.post(
"/auth/register",
json={"email": TEST_EMAIL, "password": TEST_PASSWORD},
)
assert reg_response.status_code == 201, (
f"Registration failed ({reg_response.status_code}): {reg_response.text}"
)
response = await client.post(
"/auth/jwt/login",
data={"username": TEST_EMAIL, "password": TEST_PASSWORD},
headers={"Content-Type": "application/x-www-form-urlencoded"},
)
assert response.status_code == 200, (
f"Login after registration failed ({response.status_code}): {response.text}"
)
return response.json()["access_token"]
async def get_search_space_id(client: httpx.AsyncClient, token: str) -> int:
"""Fetch the first search space owned by the test user."""
resp = await client.get(
"/api/v1/searchspaces",
headers=auth_headers(token),
)
assert resp.status_code == 200, (
f"Failed to list search spaces ({resp.status_code}): {resp.text}"
)
spaces = resp.json()
assert len(spaces) > 0, "No search spaces found for test user"
return spaces[0]["id"]
def auth_headers(token: str) -> dict[str, str]:
"""Return Authorization header dict for a Bearer token."""
return {"Authorization": f"Bearer {token}"}
async def upload_file(
client: httpx.AsyncClient,
headers: dict[str, str],
fixture_name: str,
*,
search_space_id: int,
filename_override: str | None = None,
) -> httpx.Response:
"""Upload a single fixture file and return the raw response."""
file_path = FIXTURES_DIR / fixture_name
upload_name = filename_override or fixture_name
with open(file_path, "rb") as f:
return await client.post(
"/api/v1/documents/fileupload",
headers=headers,
files={"files": (upload_name, f)},
data={"search_space_id": str(search_space_id)},
)
async def upload_multiple_files(
client: httpx.AsyncClient,
headers: dict[str, str],
fixture_names: list[str],
*,
search_space_id: int,
) -> httpx.Response:
"""Upload multiple fixture files in a single request."""
files = []
open_handles = []
try:
for name in fixture_names:
fh = open(FIXTURES_DIR / name, "rb") # noqa: SIM115
open_handles.append(fh)
files.append(("files", (name, fh)))
return await client.post(
"/api/v1/documents/fileupload",
headers=headers,
files=files,
data={"search_space_id": str(search_space_id)},
)
finally:
for fh in open_handles:
fh.close()
async def poll_document_status(
client: httpx.AsyncClient,
headers: dict[str, str],
document_ids: list[int],
*,
search_space_id: int,
timeout: float = 180.0,
interval: float = 3.0,
) -> dict[int, dict]:
"""
Poll ``GET /api/v1/documents/status`` until every document reaches a
terminal state (``ready`` or ``failed``) or *timeout* seconds elapse.
Returns a mapping of ``{document_id: status_item_dict}``.
Retries on transient transport errors until timeout.
"""
ids_param = ",".join(str(d) for d in document_ids)
terminal_states = {"ready", "failed"}
elapsed = 0.0
items: dict[int, dict] = {}
last_transport_error: Exception | None = None
while elapsed < timeout:
try:
resp = await client.get(
"/api/v1/documents/status",
headers=headers,
params={
"search_space_id": search_space_id,
"document_ids": ids_param,
},
)
except (httpx.ReadError, httpx.ConnectError, httpx.TimeoutException) as exc:
last_transport_error = exc
await asyncio.sleep(interval)
elapsed += interval
continue
assert resp.status_code == 200, (
f"Status poll failed ({resp.status_code}): {resp.text}"
)
items = {item["id"]: item for item in resp.json()["items"]}
if all(
items.get(did, {}).get("status", {}).get("state") in terminal_states
for did in document_ids
):
return items
await asyncio.sleep(interval)
elapsed += interval
raise TimeoutError(
f"Documents {document_ids} did not reach terminal state within {timeout}s. "
f"Last status: {items}. "
f"Last transport error: {last_transport_error!r}"
)
async def get_document(
client: httpx.AsyncClient,
headers: dict[str, str],
document_id: int,
) -> dict:
"""Fetch a single document by ID."""
resp = await client.get(
f"/api/v1/documents/{document_id}",
headers=headers,
)
assert resp.status_code == 200, (
f"GET document {document_id} failed ({resp.status_code}): {resp.text}"
)
return resp.json()
async def delete_document(
client: httpx.AsyncClient,
headers: dict[str, str],
document_id: int,
) -> httpx.Response:
"""Delete a document by ID, returning the raw response."""
return await client.delete(
f"/api/v1/documents/{document_id}",
headers=headers,
)
async def get_notifications(
client: httpx.AsyncClient,
headers: dict[str, str],
*,
type_filter: str | None = None,
search_space_id: int | None = None,
limit: int = 50,
) -> list[dict]:
"""Fetch notifications for the authenticated user, optionally filtered by type."""
params: dict[str, str | int] = {"limit": limit}
if type_filter:
params["type"] = type_filter
if search_space_id is not None:
params["search_space_id"] = search_space_id
resp = await client.get(
"/api/v1/notifications",
headers=headers,
params=params,
)
assert resp.status_code == 200, (
f"GET notifications failed ({resp.status_code}): {resp.text}"
)
return resp.json()["items"]

6326
surfsense_backend/uv.lock generated

File diff suppressed because it is too large

View file

@ -82,6 +82,10 @@ const CYCLING_PLACEHOLDERS = [
const CHAT_UPLOAD_ACCEPT = const CHAT_UPLOAD_ACCEPT =
".pdf,.doc,.docx,.txt,.md,.markdown,.ppt,.pptx,.xls,.xlsx,.xlsm,.xlsb,.csv,.html,.htm,.xml,.rtf,.epub,.jpg,.jpeg,.png,.bmp,.webp,.tiff,.tif,.mp3,.mp4,.mpeg,.mpga,.m4a,.wav,.webm"; ".pdf,.doc,.docx,.txt,.md,.markdown,.ppt,.pptx,.xls,.xlsx,.xlsm,.xlsb,.csv,.html,.htm,.xml,.rtf,.epub,.jpg,.jpeg,.png,.bmp,.webp,.tiff,.tif,.mp3,.mp4,.mpeg,.mpga,.m4a,.wav,.webm";
const CHAT_MAX_FILES = 10;
const CHAT_MAX_FILE_SIZE_BYTES = 50 * 1024 * 1024; // 50 MB per file
const CHAT_MAX_TOTAL_SIZE_BYTES = 200 * 1024 * 1024; // 200 MB total
type UploadState = "pending" | "processing" | "ready" | "failed"; type UploadState = "pending" | "processing" | "ready" | "failed";
interface UploadedMentionDoc { interface UploadedMentionDoc {
@ -534,6 +538,28 @@ const Composer: FC = () => {
event.target.value = ""; event.target.value = "";
if (files.length === 0 || !search_space_id) return; if (files.length === 0 || !search_space_id) return;
if (files.length > CHAT_MAX_FILES) {
toast.error(`Too many files. Maximum ${CHAT_MAX_FILES} files per upload.`);
return;
}
let totalSize = 0;
for (const file of files) {
if (file.size > CHAT_MAX_FILE_SIZE_BYTES) {
toast.error(
`File "${file.name}" (${(file.size / (1024 * 1024)).toFixed(1)} MB) exceeds the ${CHAT_MAX_FILE_SIZE_BYTES / (1024 * 1024)} MB per-file limit.`
);
return;
}
totalSize += file.size;
}
if (totalSize > CHAT_MAX_TOTAL_SIZE_BYTES) {
toast.error(
`Total upload size (${(totalSize / (1024 * 1024)).toFixed(1)} MB) exceeds the ${CHAT_MAX_TOTAL_SIZE_BYTES / (1024 * 1024)} MB limit.`
);
return;
}
setIsUploadingDocs(true); setIsUploadingDocs(true);
try { try {
const uploadResponse = await documentsApiService.uploadDocument({ const uploadResponse = await documentsApiService.uploadDocument({

View file

@ -9,6 +9,8 @@
"docker-installation", "docker-installation",
"manual-installation", "manual-installation",
"connectors", "connectors",
"how-to" "how-to",
"---Development---",
"testing"
] ]
} }

View file

@ -0,0 +1,104 @@
---
title: Testing
description: Running and writing end-to-end tests for SurfSense
---
SurfSense uses [pytest](https://docs.pytest.org/) for end-to-end testing. Tests are **self-bootstrapping** — they automatically register a test user and discover search spaces, so no manual database setup is required.
## Prerequisites
Before running tests, make sure the full backend stack is running:
- **FastAPI backend**
- **PostgreSQL + pgvector**
- **Redis**
- **Celery worker**
Your backend must have **`REGISTRATION_ENABLED=TRUE`** in its `.env` (this is the default). The tests register their own user on first run.
Your `global_llm_config.yaml` must have at least one working LLM model with a valid API key — document processing uses Auto mode, which routes through the global config.
## Running Tests
**Run all tests:**
```bash
uv run pytest
```
**Run by marker** (e.g., only document tests):
```bash
uv run pytest -m document
```
**Available markers:**
| Marker | Description |
|---|---|
| `document` | Document upload, processing, and deletion tests |
| `connector` | Connector indexing tests |
| `chat` | Chat and agent tests |
**Useful flags:**
| Flag | Description |
|---|---|
| `-s` | Show live output (useful for debugging polling loops) |
| `--tb=long` | Full tracebacks instead of short summaries |
| `-k "test_name"` | Run a single test by name |
| `-o addopts=""` | Override default flags from `pyproject.toml` |
## Configuration
Default pytest options are configured in `surfsense_backend/pyproject.toml`:
```toml
[tool.pytest.ini_options]
addopts = "-v --tb=short -x --strict-markers -ra --durations=10"
```
- `-v` — verbose test names
- `--tb=short` — concise tracebacks on failure
- `-x` — stop on first failure
- `--strict-markers` — reject unregistered markers
- `-ra` — show summary of all non-passing tests
- `--durations=10` — show the 10 slowest tests
## Environment Variables
All test configuration has sensible defaults. Override via environment variables if needed:
| Variable | Default | Description |
|---|---|---|
| `TEST_BACKEND_URL` | `http://localhost:8000` | Backend URL to test against |
| `TEST_DATABASE_URL` | Falls back to `DATABASE_URL` | Direct DB connection for test cleanup |
| `TEST_USER_EMAIL` | `testuser@surfsense.com` | Test user email |
| `TEST_USER_PASSWORD` | `testpassword123` | Test user password |
These can be configured in `surfsense_backend/.env` (see the Testing section at the bottom of `.env.example`).
## How It Works
Tests are fully self-bootstrapping:
1. **User creation** — on first run, tests try to log in. If the user doesn't exist, they register via `POST /auth/register`, then log in.
2. **Search space discovery** — after authentication, tests call `GET /api/v1/searchspaces` and use the first available search space (auto-created during registration).
3. **Session purge** — before any tests run, a session-scoped fixture deletes all documents in the test search space directly via the database. This handles stuck documents from previous crashed runs that the API refuses to delete (409 Conflict).
4. **Per-test cleanup** — every test that creates documents adds their IDs to a `cleanup_doc_ids` list. An autouse fixture deletes them after each test via the API, falling back to direct DB access for any stuck documents (see the sketch below).
This means tests work on both fresh databases and existing ones without any manual setup.
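
A rough sketch of the per-test cleanup fixture described in step 4, assuming the `client` and `headers` fixtures shown elsewhere in this guide; the real fixture in `conftest.py` additionally falls back to direct database deletes for stuck documents.

```python
import pytest_asyncio


@pytest_asyncio.fixture(autouse=True)
async def cleanup_doc_ids(client, headers):
    """Collect document IDs created by a test and delete them afterwards via the API."""
    doc_ids: list[int] = []
    yield doc_ids  # tests append the IDs of any documents they create
    for doc_id in doc_ids:
        # Best-effort API cleanup; stuck documents are handled by the DB fallback in the real fixture.
        await client.delete(f"/api/v1/documents/{doc_id}", headers=headers)
```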
## Writing New Tests
1. Create a test file in the appropriate directory (e.g., `tests/e2e/test_connectors.py`).
2. Add a module-level marker at the top:
```python
import pytest
pytestmark = pytest.mark.connector
```
3. Use fixtures from `conftest.py` — `client`, `headers`, `search_space_id`, and `cleanup_doc_ids` are available to all tests (see the example below).
4. Register any new markers in `pyproject.toml` under `markers`.
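
Putting it together, a minimal document-marked test could look like the sketch below. The helper import path and the `sample.md` fixture name are assumptions; the fixtures and helpers are the ones documented above.

```python
import pytest

# Import path is an assumption; adjust it to wherever the shared helpers live in the test tree.
from tests.e2e.utils.test_helpers import poll_document_status, upload_file

pytestmark = pytest.mark.document


async def test_markdown_upload_reaches_ready(
    client, headers, search_space_id, cleanup_doc_ids
):
    resp = await upload_file(client, headers, "sample.md", search_space_id=search_space_id)
    assert resp.status_code == 200

    doc_ids = resp.json()["document_ids"]
    cleanup_doc_ids.extend(doc_ids)  # removed automatically by the cleanup fixture after the test

    statuses = await poll_document_status(
        client, headers, doc_ids, search_space_id=search_space_id
    )
    assert all(statuses[did]["status"]["state"] == "ready" for did in doc_ids)
```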