feat: implement task dispatcher for document processing

- Introduced a TaskDispatcher abstraction to decouple the upload endpoint from Celery, allowing for easier testing with synchronous implementations.
- Updated the create_documents_file_upload function to utilize the new dispatcher for task management.
- Removed direct Celery task imports from the upload function, enhancing modularity.
- Added integration tests for document upload, including page limit enforcement and file size restrictions.
This commit is contained in:
Anish Sarkar 2026-02-26 23:55:47 +05:30
parent 30617c6e54
commit 3393e435f9
9 changed files with 380 additions and 280 deletions

View file

@ -0,0 +1,282 @@
"""Integration conftest — runs the FastAPI app in-process via ASGITransport.
Prerequisites: PostgreSQL + pgvector only.
External system boundaries are mocked:
- LLM summarization, text embedding, text chunking (external APIs)
- Redis heartbeat (external infrastructure)
- Task dispatch is swapped via DI (InlineTaskDispatcher)
"""
from __future__ import annotations
import contextlib
from collections.abc import AsyncGenerator
from unittest.mock import AsyncMock, MagicMock
import asyncpg
import httpx
import pytest
from httpx import ASGITransport
from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.pool import NullPool
from app.app import app
from app.config import config as app_config
from app.db import DATABASE_URL as APP_DB_URL, Base
from app.services.task_dispatcher import get_task_dispatcher
from tests.conftest import DATABASE_URL
from tests.utils.helpers import (
TEST_EMAIL,
auth_headers,
delete_document,
get_auth_token,
get_search_space_id,
)
_EMBEDDING_DIM = app_config.embedding_model_instance.dimension
pytestmark = pytest.mark.integration
# ---------------------------------------------------------------------------
# Inline task dispatcher (replaces Celery via DI — not a mock)
# ---------------------------------------------------------------------------
class InlineTaskDispatcher:
    """Processes files synchronously in the calling coroutine.

    Swapped in via FastAPI dependency_overrides so the upload endpoint
    processes documents inline instead of dispatching to Celery.

    Exceptions are caught to match Celery's fire-and-forget semantics —
    the processing function already marks documents as failed internally.
    """

    async def dispatch_file_processing(
        self,
        *,
        document_id: int,
        temp_path: str,
        filename: str,
        search_space_id: int,
        user_id: str,
    ) -> None:
        """Run the worker's processing coroutine inline (same event loop).

        The import is deferred so the Celery task module is only loaded
        when a file is actually dispatched.
        """
        from app.tasks.celery_tasks.document_tasks import (
            _process_file_with_document,
        )

        # Swallow failures to mirror Celery's fire-and-forget contract:
        # the processing function records failure on the document itself.
        with contextlib.suppress(Exception):
            await _process_file_with_document(
                document_id, temp_path, filename, search_space_id, user_id
            )


# Install the inline dispatcher for the whole test session; the lambda
# yields a fresh instance per request, mirroring normal DI behavior.
app.dependency_overrides[get_task_dispatcher] = lambda: InlineTaskDispatcher()
# ---------------------------------------------------------------------------
# Database setup (ASGITransport skips the app lifespan)
# ---------------------------------------------------------------------------
@pytest.fixture(scope="session")
async def _ensure_tables():
    """Create DB tables and required Postgres extensions once per session.

    The engine is always disposed — even if extension or table creation
    raises — so a broken setup does not leak database connections.
    """
    engine = create_async_engine(APP_DB_URL, poolclass=NullPool)
    try:
        async with engine.begin() as conn:
            # pgvector for embeddings, pg_trgm for trigram title search.
            await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
            await conn.execute(text("CREATE EXTENSION IF NOT EXISTS pg_trgm"))
            await conn.run_sync(Base.metadata.create_all)
    finally:
        await engine.dispose()
# ---------------------------------------------------------------------------
# Auth & search space (session-scoped, via the in-process app)
# ---------------------------------------------------------------------------
@pytest.fixture(scope="session")
async def auth_token(_ensure_tables) -> str:
    """Authenticate once per session, registering the user if needed."""
    transport = ASGITransport(app=app)
    async with httpx.AsyncClient(
        transport=transport, base_url="http://test", timeout=30.0
    ) as http:
        token = await get_auth_token(http)
    return token
@pytest.fixture(scope="session")
async def search_space_id(auth_token: str) -> int:
    """Discover the first search space belonging to the test user."""
    transport = ASGITransport(app=app)
    async with httpx.AsyncClient(
        transport=transport, base_url="http://test", timeout=30.0
    ) as http:
        space_id = await get_search_space_id(http, auth_token)
    return space_id
@pytest.fixture(scope="session")
def headers(auth_token: str) -> dict[str, str]:
    """Authorization headers built from the session-scoped token."""
    return auth_headers(auth_token)
# ---------------------------------------------------------------------------
# Per-test HTTP client & cleanup
# ---------------------------------------------------------------------------
@pytest.fixture
async def client() -> AsyncGenerator[httpx.AsyncClient]:
    """Per-test async HTTP client using ASGITransport (no running server)."""
    transport = ASGITransport(app=app)
    async with httpx.AsyncClient(
        transport=transport, base_url="http://test", timeout=180.0
    ) as http:
        yield http
@pytest.fixture
def cleanup_doc_ids() -> list[int]:
    """Accumulator for document IDs that should be deleted after the test.

    Tests extend this list with uploaded IDs; an autouse teardown fixture
    deletes the accumulated documents after each test.
    """
    return []
@pytest.fixture(scope="session", autouse=True)
async def _purge_test_search_space(search_space_id: int):
    """Delete stale documents from previous runs before the session starts."""
    conn = await asyncpg.connect(DATABASE_URL)
    try:
        status_tag = await conn.execute(
            "DELETE FROM documents WHERE search_space_id = $1",
            search_space_id,
        )
        # asyncpg returns a command tag like "DELETE <n>"; extract the count.
        deleted = int(status_tag.rsplit(" ", 1)[-1])
        if deleted:
            print(
                f"\n[purge] Deleted {deleted} stale document(s) "
                f"from search space {search_space_id}"
            )
    finally:
        await conn.close()
    yield
@pytest.fixture(autouse=True)
async def _cleanup_documents(
    client: httpx.AsyncClient,
    headers: dict[str, str],
    cleanup_doc_ids: list[int],
):
    """Delete test documents after every test (API first, DB fallback).

    Any document the API cannot delete — conflict (409), server error, or
    a transport failure — falls back to a direct database DELETE so tests
    never leak rows into later runs. A 404 means the document is already
    gone and needs no fallback.
    """
    yield
    remaining_ids: list[int] = []
    for doc_id in cleanup_doc_ids:
        try:
            resp = await delete_document(client, headers, doc_id)
            # Previously only 409 triggered the fallback; any other error
            # status (e.g. 500) silently leaked the row. Treat every
            # non-404 error as "still present".
            if resp.status_code >= 400 and resp.status_code != 404:
                remaining_ids.append(doc_id)
        except Exception:
            remaining_ids.append(doc_id)
    if remaining_ids:
        conn = await asyncpg.connect(DATABASE_URL)
        try:
            await conn.execute(
                "DELETE FROM documents WHERE id = ANY($1::int[])",
                remaining_ids,
            )
        finally:
            await conn.close()
# ---------------------------------------------------------------------------
# Page-limit helpers (direct DB for setup, API for verification)
# ---------------------------------------------------------------------------
async def _get_user_page_usage(email: str) -> tuple[int, int]:
    """Read (pages_used, pages_limit) for *email* directly from the DB."""
    conn = await asyncpg.connect(DATABASE_URL)
    try:
        row = await conn.fetchrow(
            'SELECT pages_used, pages_limit FROM "user" WHERE email = $1',
            email,
        )
    finally:
        await conn.close()
    assert row is not None, f"User {email!r} not found in database"
    return row["pages_used"], row["pages_limit"]
async def _set_user_page_limits(
    email: str, *, pages_used: int, pages_limit: int
) -> None:
    """Write pages_used / pages_limit for *email* directly in the DB."""
    query = 'UPDATE "user" SET pages_used = $1, pages_limit = $2 WHERE email = $3'
    conn = await asyncpg.connect(DATABASE_URL)
    try:
        await conn.execute(query, pages_used, pages_limit, email)
    finally:
        await conn.close()
@pytest.fixture
async def page_limits():
    """Manipulate the test user's page limits (direct DB for setup only).

    Automatically restores original values after each test.
    """

    class _PageLimits:
        async def set(self, *, pages_used: int, pages_limit: int) -> None:
            await _set_user_page_limits(
                TEST_EMAIL, pages_used=pages_used, pages_limit=pages_limit
            )

    # Snapshot current values so teardown can restore them exactly.
    orig_used, orig_limit = await _get_user_page_usage(TEST_EMAIL)
    yield _PageLimits()
    await _set_user_page_limits(
        TEST_EMAIL, pages_used=orig_used, pages_limit=orig_limit
    )
# ---------------------------------------------------------------------------
# Mock external system boundaries
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def _mock_external_apis(monkeypatch):
    """Mock LLM, embedding, and chunking — these are external API boundaries."""
    replacements = [
        (
            "app.indexing_pipeline.indexing_pipeline_service.summarize_document",
            AsyncMock(return_value="Mocked summary."),
        ),
        (
            "app.indexing_pipeline.indexing_pipeline_service.embed_text",
            MagicMock(return_value=[0.1] * _EMBEDDING_DIM),
        ),
        (
            "app.indexing_pipeline.indexing_pipeline_service.chunk_text",
            MagicMock(return_value=["Test chunk content."]),
        ),
    ]
    for target, replacement in replacements:
        monkeypatch.setattr(target, replacement)
@pytest.fixture(autouse=True)
def _mock_redis_heartbeat(monkeypatch):
    """Mock Redis heartbeat — Redis is an external infrastructure boundary."""

    def _noop(notification_id):
        # Start/stop hooks become no-ops during tests.
        return None

    monkeypatch.setattr(
        "app.tasks.celery_tasks.document_tasks._start_heartbeat", _noop
    )
    monkeypatch.setattr(
        "app.tasks.celery_tasks.document_tasks._stop_heartbeat", _noop
    )
    monkeypatch.setattr(
        "app.tasks.celery_tasks.document_tasks._run_heartbeat_loop",
        AsyncMock(),
    )

View file

@ -0,0 +1,554 @@
"""
Integration tests for manual document upload.
These tests exercise the full pipeline via the HTTP API:
API upload → inline task dispatch → ETL extraction → chunking → embedding → DB storage
External boundaries mocked: LLM summarization, text embedding, text chunking,
Redis heartbeat. Task dispatch is swapped via DI (InlineTaskDispatcher).
Prerequisites:
- PostgreSQL + pgvector
"""
from __future__ import annotations
import shutil
from pathlib import Path
import httpx
import pytest
from tests.utils.helpers import (
FIXTURES_DIR,
delete_document,
get_document,
poll_document_status,
upload_file,
upload_multiple_files,
)
pytestmark = pytest.mark.integration
# ---------------------------------------------------------------------------
# Helpers local to this module
# ---------------------------------------------------------------------------
def _assert_document_ready(doc: dict, *, expected_filename: str) -> None:
    """Common assertions for a successfully processed document."""
    title = doc["title"]
    doc_type = doc["document_type"]
    metadata = doc["document_metadata"]
    assert title == expected_filename
    assert doc_type == "FILE"
    assert doc["content"], "Document content (summary) should not be empty"
    assert doc["content_hash"], "content_hash should be set"
    assert metadata.get("FILE_NAME") == expected_filename
# ---------------------------------------------------------------------------
# Test A: Upload a .txt file (direct read path)
# ---------------------------------------------------------------------------
class TestTxtFileUpload:
    """Upload a plain-text file and verify the full pipeline."""

    async def test_upload_txt_returns_document_id(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
        cleanup_doc_ids: list[int],
    ):
        """The upload endpoint accepts the file and returns pending IDs."""
        resp = await upload_file(
            client, headers, "sample.txt", search_space_id=search_space_id
        )
        assert resp.status_code == 200
        body = resp.json()
        assert body["pending_files"] >= 1
        assert len(body["document_ids"]) >= 1
        # Register the IDs so the autouse teardown fixture deletes them.
        cleanup_doc_ids.extend(body["document_ids"])

    async def test_txt_processing_reaches_ready(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
        cleanup_doc_ids: list[int],
    ):
        """Inline processing must drive every uploaded document to 'ready'."""
        resp = await upload_file(
            client, headers, "sample.txt", search_space_id=search_space_id
        )
        assert resp.status_code == 200
        doc_ids = resp.json()["document_ids"]
        cleanup_doc_ids.extend(doc_ids)
        statuses = await poll_document_status(
            client, headers, doc_ids, search_space_id=search_space_id
        )
        for did in doc_ids:
            assert statuses[did]["status"]["state"] == "ready"

    async def test_txt_document_fields_populated(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
        cleanup_doc_ids: list[int],
    ):
        """After processing, the stored document has all expected fields."""
        resp = await upload_file(
            client, headers, "sample.txt", search_space_id=search_space_id
        )
        doc_ids = resp.json()["document_ids"]
        cleanup_doc_ids.extend(doc_ids)
        # Wait for processing to finish before inspecting the document.
        await poll_document_status(
            client, headers, doc_ids, search_space_id=search_space_id
        )
        doc = await get_document(client, headers, doc_ids[0])
        _assert_document_ready(doc, expected_filename="sample.txt")
# ---------------------------------------------------------------------------
# Test B: Upload a .md file (markdown direct-read path)
# ---------------------------------------------------------------------------
class TestMarkdownFileUpload:
    """Upload a Markdown file and verify the full pipeline."""

    async def test_md_processing_reaches_ready(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
        cleanup_doc_ids: list[int],
    ):
        """Markdown goes through the direct-read path and reaches 'ready'."""
        resp = await upload_file(
            client, headers, "sample.md", search_space_id=search_space_id
        )
        assert resp.status_code == 200
        doc_ids = resp.json()["document_ids"]
        cleanup_doc_ids.extend(doc_ids)
        statuses = await poll_document_status(
            client, headers, doc_ids, search_space_id=search_space_id
        )
        for did in doc_ids:
            assert statuses[did]["status"]["state"] == "ready"

    async def test_md_document_fields_populated(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
        cleanup_doc_ids: list[int],
    ):
        """After processing, the stored Markdown document is fully populated."""
        resp = await upload_file(
            client, headers, "sample.md", search_space_id=search_space_id
        )
        doc_ids = resp.json()["document_ids"]
        cleanup_doc_ids.extend(doc_ids)
        # Wait for processing to finish before inspecting the document.
        await poll_document_status(
            client, headers, doc_ids, search_space_id=search_space_id
        )
        doc = await get_document(client, headers, doc_ids[0])
        _assert_document_ready(doc, expected_filename="sample.md")
# ---------------------------------------------------------------------------
# Test C: Upload a .pdf file (ETL path)
# ---------------------------------------------------------------------------
class TestPdfFileUpload:
    """Upload a PDF and verify it goes through the ETL extraction pipeline."""

    async def test_pdf_processing_reaches_ready(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
        cleanup_doc_ids: list[int],
    ):
        """PDF extraction (slower ETL path) must still reach 'ready'."""
        resp = await upload_file(
            client, headers, "sample.pdf", search_space_id=search_space_id
        )
        assert resp.status_code == 200
        doc_ids = resp.json()["document_ids"]
        cleanup_doc_ids.extend(doc_ids)
        # ETL extraction is slower than the direct-read path: 300 s budget.
        statuses = await poll_document_status(
            client, headers, doc_ids, search_space_id=search_space_id, timeout=300.0
        )
        for did in doc_ids:
            assert statuses[did]["status"]["state"] == "ready"

    async def test_pdf_document_fields_populated(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
        cleanup_doc_ids: list[int],
    ):
        """After ETL, the stored PDF document has all expected fields."""
        resp = await upload_file(
            client, headers, "sample.pdf", search_space_id=search_space_id
        )
        doc_ids = resp.json()["document_ids"]
        cleanup_doc_ids.extend(doc_ids)
        await poll_document_status(
            client, headers, doc_ids, search_space_id=search_space_id, timeout=300.0
        )
        doc = await get_document(client, headers, doc_ids[0])
        _assert_document_ready(doc, expected_filename="sample.pdf")
# ---------------------------------------------------------------------------
# Test D: Upload multiple files in a single request
# ---------------------------------------------------------------------------
class TestMultiFileUpload:
    """Upload several files at once and verify all are processed."""

    async def test_multi_upload_returns_all_ids(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
        cleanup_doc_ids: list[int],
    ):
        """A two-file request must report exactly two pending documents."""
        resp = await upload_multiple_files(
            client,
            headers,
            ["sample.txt", "sample.md"],
            search_space_id=search_space_id,
        )
        assert resp.status_code == 200
        body = resp.json()
        assert body["pending_files"] == 2
        assert len(body["document_ids"]) == 2
        cleanup_doc_ids.extend(body["document_ids"])

    async def test_multi_upload_all_reach_ready(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
        cleanup_doc_ids: list[int],
    ):
        """Every document from a multi-file upload must reach 'ready'."""
        resp = await upload_multiple_files(
            client,
            headers,
            ["sample.txt", "sample.md"],
            search_space_id=search_space_id,
        )
        doc_ids = resp.json()["document_ids"]
        cleanup_doc_ids.extend(doc_ids)
        statuses = await poll_document_status(
            client, headers, doc_ids, search_space_id=search_space_id
        )
        for did in doc_ids:
            assert statuses[did]["status"]["state"] == "ready"
# ---------------------------------------------------------------------------
# Test E: Duplicate file upload (same file uploaded twice)
# ---------------------------------------------------------------------------
class TestDuplicateFileUpload:
    """
    Uploading the exact same file a second time should be detected as a
    duplicate via ``unique_identifier_hash``.
    """

    async def test_duplicate_file_is_skipped(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
        cleanup_doc_ids: list[int],
    ):
        """Second identical upload is reported as a skipped duplicate."""
        # First upload — must fully process so its hash is persisted.
        resp1 = await upload_file(
            client, headers, "sample.txt", search_space_id=search_space_id
        )
        assert resp1.status_code == 200
        first_ids = resp1.json()["document_ids"]
        cleanup_doc_ids.extend(first_ids)
        await poll_document_status(
            client, headers, first_ids, search_space_id=search_space_id
        )
        # Second, byte-identical upload — should be skipped, not re-ingested.
        resp2 = await upload_file(
            client, headers, "sample.txt", search_space_id=search_space_id
        )
        assert resp2.status_code == 200
        body2 = resp2.json()
        assert body2["skipped_duplicates"] >= 1
        assert len(body2["duplicate_document_ids"]) >= 1
        cleanup_doc_ids.extend(body2.get("document_ids", []))
# ---------------------------------------------------------------------------
# Test F: Duplicate content detection (different name, same content)
# ---------------------------------------------------------------------------
class TestDuplicateContentDetection:
    """
    Uploading a file with a different name but identical content should be
    detected as duplicate content via ``content_hash``.
    """

    async def test_same_content_different_name_detected(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
        cleanup_doc_ids: list[int],
    ):
        """Same bytes under a new name must fail with a duplicate reason."""
        # First upload — must fully process so its content_hash is stored.
        resp1 = await upload_file(
            client, headers, "sample.txt", search_space_id=search_space_id
        )
        assert resp1.status_code == 200
        first_ids = resp1.json()["document_ids"]
        cleanup_doc_ids.extend(first_ids)
        await poll_document_status(
            client, headers, first_ids, search_space_id=search_space_id
        )
        # Re-upload the same fixture under a different name. The multipart
        # filename comes from the files tuple, so the fixture can be
        # streamed directly — no temp-dir copy is needed.
        src = FIXTURES_DIR / "sample.txt"
        with open(src, "rb") as f:
            resp2 = await client.post(
                "/api/v1/documents/fileupload",
                headers=headers,
                files={"files": ("renamed_sample.txt", f)},
                data={"search_space_id": str(search_space_id)},
            )
        assert resp2.status_code == 200
        second_ids = resp2.json()["document_ids"]
        cleanup_doc_ids.extend(second_ids)
        assert second_ids, (
            "Expected at least one document id for renamed duplicate content upload"
        )
        statuses = await poll_document_status(
            client, headers, second_ids, search_space_id=search_space_id
        )
        for did in second_ids:
            assert statuses[did]["status"]["state"] == "failed"
            assert "duplicate" in statuses[did]["status"].get("reason", "").lower()
# ---------------------------------------------------------------------------
# Test G: Empty / corrupt file handling
# ---------------------------------------------------------------------------
class TestEmptyFileUpload:
    """An empty file should be processed but ultimately fail gracefully."""

    async def test_empty_pdf_fails(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
        cleanup_doc_ids: list[int],
    ):
        """The upload is accepted (200) but processing must end 'failed'."""
        resp = await upload_file(
            client, headers, "empty.pdf", search_space_id=search_space_id
        )
        assert resp.status_code == 200
        doc_ids = resp.json()["document_ids"]
        cleanup_doc_ids.extend(doc_ids)
        assert doc_ids, "Expected at least one document id for empty PDF upload"
        statuses = await poll_document_status(
            client, headers, doc_ids, search_space_id=search_space_id, timeout=120.0
        )
        for did in doc_ids:
            assert statuses[did]["status"]["state"] == "failed"
            # The failure must be explained, not silent.
            assert statuses[did]["status"].get("reason"), (
                "Failed document should include a reason"
            )
# ---------------------------------------------------------------------------
# Test H: Upload without authentication
# ---------------------------------------------------------------------------
class TestUnauthenticatedUpload:
    """Requests without a valid JWT should be rejected."""

    async def test_upload_without_auth_returns_401(
        self,
        client: httpx.AsyncClient,
        search_space_id: int,
    ):
        """No Authorization header ⇒ 401 before any processing happens."""
        file_path = FIXTURES_DIR / "sample.txt"
        with open(file_path, "rb") as f:
            resp = await client.post(
                "/api/v1/documents/fileupload",
                files={"files": ("sample.txt", f)},
                data={"search_space_id": str(search_space_id)},
            )
        assert resp.status_code == 401
# ---------------------------------------------------------------------------
# Test I: Upload with no files attached
# ---------------------------------------------------------------------------
class TestNoFilesUpload:
    """Submitting the form with zero files should return a validation error."""

    async def test_no_files_returns_error(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
    ):
        """Missing ``files`` part ⇒ 400 or FastAPI's 422 validation error."""
        resp = await client.post(
            "/api/v1/documents/fileupload",
            headers=headers,
            data={"search_space_id": str(search_space_id)},
        )
        assert resp.status_code in {400, 422}
# ---------------------------------------------------------------------------
# Test J: Document deletion after successful upload
# ---------------------------------------------------------------------------
class TestDocumentDeletion:
    """Upload, wait for ready, delete, then verify it's gone."""

    async def test_delete_processed_document(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
        cleanup_doc_ids: list[int],
    ):
        """A ready document deletes with 200 and then GETs as 404."""
        resp = await upload_file(
            client, headers, "sample.txt", search_space_id=search_space_id
        )
        doc_ids = resp.json()["document_ids"]
        # Register with the autouse cleanup fixture so the document is
        # removed even if an assertion fails before the explicit delete
        # below (previously a mid-test failure leaked the document).
        cleanup_doc_ids.extend(doc_ids)
        await poll_document_status(
            client, headers, doc_ids, search_space_id=search_space_id
        )
        del_resp = await delete_document(client, headers, doc_ids[0])
        assert del_resp.status_code == 200
        get_resp = await client.get(
            f"/api/v1/documents/{doc_ids[0]}",
            headers=headers,
        )
        assert get_resp.status_code == 404
# ---------------------------------------------------------------------------
# Test K: Searchability after upload
# ---------------------------------------------------------------------------
class TestDocumentSearchability:
    """After upload reaches ready, the document must appear in the title search."""

    async def test_uploaded_document_appears_in_search(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
        cleanup_doc_ids: list[int],
    ):
        """Title search for 'sample' must include the freshly uploaded doc."""
        resp = await upload_file(
            client, headers, "sample.txt", search_space_id=search_space_id
        )
        assert resp.status_code == 200
        doc_ids = resp.json()["document_ids"]
        cleanup_doc_ids.extend(doc_ids)
        # Search is only expected to succeed once processing completed.
        await poll_document_status(
            client, headers, doc_ids, search_space_id=search_space_id
        )
        search_resp = await client.get(
            "/api/v1/documents/search",
            headers=headers,
            params={"title": "sample", "search_space_id": search_space_id},
        )
        assert search_resp.status_code == 200
        result_ids = [d["id"] for d in search_resp.json()["items"]]
        assert doc_ids[0] in result_ids, (
            f"Uploaded document {doc_ids[0]} not found in search results: {result_ids}"
        )
# ---------------------------------------------------------------------------
# Test L: Status polling returns correct structure
# ---------------------------------------------------------------------------
class TestStatusPolling:
    """Verify the status endpoint returns well-formed responses."""

    async def test_status_endpoint_returns_items(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
        cleanup_doc_ids: list[int],
    ):
        """Status response has one item per requested ID, each with a valid state."""
        resp = await upload_file(
            client, headers, "sample.txt", search_space_id=search_space_id
        )
        doc_ids = resp.json()["document_ids"]
        cleanup_doc_ids.extend(doc_ids)
        # Query status immediately — any lifecycle state is acceptable here;
        # only the response structure is under test.
        status_resp = await client.get(
            "/api/v1/documents/status",
            headers=headers,
            params={
                "search_space_id": search_space_id,
                "document_ids": ",".join(str(d) for d in doc_ids),
            },
        )
        assert status_resp.status_code == 200
        body = status_resp.json()
        assert "items" in body
        assert len(body["items"]) == len(doc_ids)
        for item in body["items"]:
            assert "id" in item
            assert "status" in item
            assert "state" in item["status"]
            assert item["status"]["state"] in {
                "pending",
                "processing",
                "ready",
                "failed",
            }
        # Let processing finish so teardown can delete the documents cleanly.
        await poll_document_status(
            client, headers, doc_ids, search_space_id=search_space_id
        )

View file

@ -0,0 +1,332 @@
"""
Integration tests for page-limit enforcement during document upload.
These tests manipulate the test user's ``pages_used`` / ``pages_limit``
columns directly in the database (setup only) and then exercise the upload
pipeline to verify that:
- Uploads are rejected *before* ETL when the limit is exhausted.
- ``pages_used`` increases after a successful upload (verified via API).
- A ``page_limit_exceeded`` notification is created on rejection.
- ``pages_used`` is not modified when a document fails processing.
All tests reuse the existing small fixtures (``sample.pdf``, ``sample.txt``)
so no additional processing time is introduced.
Prerequisites:
- PostgreSQL + pgvector
"""
from __future__ import annotations
import httpx
import pytest
from tests.utils.helpers import (
get_notifications,
poll_document_status,
upload_file,
)
pytestmark = pytest.mark.integration
# ---------------------------------------------------------------------------
# Helper: read pages_used through the public API
# ---------------------------------------------------------------------------
async def _get_pages_used(client: httpx.AsyncClient, headers: dict[str, str]) -> int:
    """Fetch the current user's pages_used via the /users/me API."""
    resp = await client.get("/users/me", headers=headers)
    assert resp.status_code == 200, (
        f"GET /users/me failed ({resp.status_code}): {resp.text}"
    )
    profile = resp.json()
    return profile["pages_used"]
# ---------------------------------------------------------------------------
# Test A: Successful upload increments pages_used
# ---------------------------------------------------------------------------
class TestPageUsageIncrementsOnSuccess:
    """After a successful PDF upload the user's ``pages_used`` must grow."""

    async def test_pages_used_increases_after_pdf_upload(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
        cleanup_doc_ids: list[int],
        page_limits,
    ):
        """With ample quota, a ready PDF must increment pages_used."""
        # Start from a known baseline with plenty of headroom.
        await page_limits.set(pages_used=0, pages_limit=1000)
        resp = await upload_file(
            client, headers, "sample.pdf", search_space_id=search_space_id
        )
        assert resp.status_code == 200
        doc_ids = resp.json()["document_ids"]
        cleanup_doc_ids.extend(doc_ids)
        statuses = await poll_document_status(
            client, headers, doc_ids, search_space_id=search_space_id, timeout=300.0
        )
        for did in doc_ids:
            assert statuses[did]["status"]["state"] == "ready"
        # Verified through the public API, not the database.
        used = await _get_pages_used(client, headers)
        assert used > 0, "pages_used should have increased after successful processing"
# ---------------------------------------------------------------------------
# Test B: Upload rejected when page limit is fully exhausted
# ---------------------------------------------------------------------------
class TestUploadRejectedWhenLimitExhausted:
    """
    When ``pages_used == pages_limit`` (zero remaining) the document
    should reach ``failed`` status with a page-limit reason.
    """

    async def test_pdf_fails_when_no_pages_remaining(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
        cleanup_doc_ids: list[int],
        page_limits,
    ):
        """Zero remaining quota ⇒ document fails with a page-limit reason."""
        # Exhaust the quota before uploading.
        await page_limits.set(pages_used=100, pages_limit=100)
        resp = await upload_file(
            client, headers, "sample.pdf", search_space_id=search_space_id
        )
        assert resp.status_code == 200
        doc_ids = resp.json()["document_ids"]
        cleanup_doc_ids.extend(doc_ids)
        statuses = await poll_document_status(
            client, headers, doc_ids, search_space_id=search_space_id, timeout=300.0
        )
        for did in doc_ids:
            assert statuses[did]["status"]["state"] == "failed"
            reason = statuses[did]["status"].get("reason", "").lower()
            assert "page limit" in reason, (
                f"Expected 'page limit' in failure reason, got: {reason!r}"
            )

    async def test_pages_used_unchanged_after_limit_rejection(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
        cleanup_doc_ids: list[int],
        page_limits,
    ):
        """A rejected upload must not consume any quota."""
        await page_limits.set(pages_used=50, pages_limit=50)
        resp = await upload_file(
            client, headers, "sample.pdf", search_space_id=search_space_id
        )
        assert resp.status_code == 200
        doc_ids = resp.json()["document_ids"]
        cleanup_doc_ids.extend(doc_ids)
        await poll_document_status(
            client, headers, doc_ids, search_space_id=search_space_id, timeout=300.0
        )
        used = await _get_pages_used(client, headers)
        assert used == 50, (
            f"pages_used should remain 50 after rejected upload, got {used}"
        )
# ---------------------------------------------------------------------------
# Test C: Page-limit notification is created on rejection
# ---------------------------------------------------------------------------
class TestPageLimitNotification:
    """A ``page_limit_exceeded`` notification must be created when upload
    is rejected due to the limit."""

    async def test_page_limit_exceeded_notification_created(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
        cleanup_doc_ids: list[int],
        page_limits,
    ):
        """Rejection must produce a notification mentioning the page limit."""
        # Exhaust the quota so the upload is rejected.
        await page_limits.set(pages_used=100, pages_limit=100)
        resp = await upload_file(
            client, headers, "sample.pdf", search_space_id=search_space_id
        )
        assert resp.status_code == 200
        doc_ids = resp.json()["document_ids"]
        cleanup_doc_ids.extend(doc_ids)
        # Wait until processing settles so the notification exists.
        await poll_document_status(
            client, headers, doc_ids, search_space_id=search_space_id, timeout=300.0
        )
        notifications = await get_notifications(
            client,
            headers,
            type_filter="page_limit_exceeded",
            search_space_id=search_space_id,
        )
        assert len(notifications) >= 1, (
            "Expected at least one page_limit_exceeded notification"
        )
        # The newest notification is first; its text must mention the limit.
        latest = notifications[0]
        assert (
            "page limit" in latest["title"].lower()
            or "page limit" in latest["message"].lower()
        ), (
            f"Notification should mention page limit: title={latest['title']!r}, "
            f"message={latest['message']!r}"
        )
# ---------------------------------------------------------------------------
# Test D: Successful upload creates a completed document_processing notification
# ---------------------------------------------------------------------------
class TestDocumentProcessingNotification:
    """A ``document_processing`` notification with ``completed`` status must
    exist after a successful upload."""

    async def test_processing_completed_notification_exists(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
        cleanup_doc_ids: list[int],
        page_limits,
    ):
        """Successful processing must emit a 'completed' stage notification."""
        # Ensure ample quota so the upload cannot be limit-rejected.
        await page_limits.set(pages_used=0, pages_limit=1000)
        resp = await upload_file(
            client, headers, "sample.txt", search_space_id=search_space_id
        )
        assert resp.status_code == 200
        doc_ids = resp.json()["document_ids"]
        cleanup_doc_ids.extend(doc_ids)
        await poll_document_status(
            client, headers, doc_ids, search_space_id=search_space_id
        )
        notifications = await get_notifications(
            client,
            headers,
            type_filter="document_processing",
            search_space_id=search_space_id,
        )
        # Filter to notifications whose metadata marks the final stage.
        completed = [
            n
            for n in notifications
            if n.get("metadata", {}).get("processing_stage") == "completed"
        ]
        assert len(completed) >= 1, (
            "Expected at least one document_processing notification with 'completed' stage"
        )
# ---------------------------------------------------------------------------
# Test E: pages_used unchanged when a document fails for non-limit reasons
# ---------------------------------------------------------------------------
class TestPagesUnchangedOnProcessingFailure:
    """If a document fails during ETL (e.g. empty/corrupt file) rather than
    a page-limit rejection, ``pages_used`` should remain unchanged."""

    async def test_pages_used_stable_on_etl_failure(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
        cleanup_doc_ids: list[int],
        page_limits,
    ):
        """An ETL failure must not alter the user's page usage."""
        # Known baseline: 10 pages used, ample remaining quota.
        await page_limits.set(pages_used=10, pages_limit=1000)
        resp = await upload_file(
            client, headers, "empty.pdf", search_space_id=search_space_id
        )
        assert resp.status_code == 200
        doc_ids = resp.json()["document_ids"]
        cleanup_doc_ids.extend(doc_ids)
        if doc_ids:
            statuses = await poll_document_status(
                client, headers, doc_ids, search_space_id=search_space_id, timeout=120.0
            )
            for did in doc_ids:
                assert statuses[did]["status"]["state"] == "failed"
        used = await _get_pages_used(client, headers)
        assert used == 10, f"pages_used should remain 10 after ETL failure, got {used}"
# ---------------------------------------------------------------------------
# Test F: Second upload rejected after first consumes remaining quota
# ---------------------------------------------------------------------------


class TestSecondUploadExceedsLimit:
    """A successful upload that exhausts the quota must block the next one."""

    async def test_second_upload_rejected_after_quota_consumed(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
        cleanup_doc_ids: list[int],
        page_limits,
    ):
        # One page of quota: the first PDF is expected to consume all of it.
        await page_limits.set(pages_used=0, pages_limit=1)

        first_resp = await upload_file(
            client, headers, "sample.pdf", search_space_id=search_space_id
        )
        assert first_resp.status_code == 200
        first_ids = first_resp.json()["document_ids"]
        cleanup_doc_ids.extend(first_ids)

        first_statuses = await poll_document_status(
            client, headers, first_ids, search_space_id=search_space_id, timeout=300.0
        )
        for document_id in first_ids:
            assert first_statuses[document_id]["status"]["state"] == "ready"

        # Same content under a different filename, so deduplication does not
        # short-circuit the second upload before the quota check runs.
        second_resp = await upload_file(
            client,
            headers,
            "sample.pdf",
            search_space_id=search_space_id,
            filename_override="sample_copy.pdf",
        )
        assert second_resp.status_code == 200
        second_ids = second_resp.json()["document_ids"]
        cleanup_doc_ids.extend(second_ids)

        second_statuses = await poll_document_status(
            client, headers, second_ids, search_space_id=search_space_id, timeout=300.0
        )
        for document_id in second_ids:
            assert second_statuses[document_id]["status"]["state"] == "failed"
            reason = second_statuses[document_id]["status"].get("reason", "").lower()
            assert "page limit" in reason, (
                f"Expected 'page limit' in failure reason, got: {reason!r}"
            )

View file

@ -0,0 +1,145 @@
"""
Integration tests for backend file upload limit enforcement.
These tests verify that the API rejects uploads that exceed:
- Max files per upload (10)
- Max per-file size (50 MB)
- Max total upload size (200 MB)
The limits mirror the frontend's DocumentUploadTab.tsx constants and are
enforced server-side to protect against direct API calls.
Prerequisites:
- PostgreSQL + pgvector
"""
from __future__ import annotations
import io
import httpx
import pytest
pytestmark = pytest.mark.integration
# ---------------------------------------------------------------------------
# Test A: File count limit
# ---------------------------------------------------------------------------


class TestFileCountLimit:
    """The API caps a single upload request at 10 files."""

    async def test_11_files_returns_413(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
    ):
        # One file beyond the 10-file cap must trip the server-side guard.
        payload = []
        for i in range(11):
            payload.append(
                ("files", (f"file_{i}.txt", io.BytesIO(b"test content"), "text/plain"))
            )

        resp = await client.post(
            "/api/v1/documents/fileupload",
            headers=headers,
            files=payload,
            data={"search_space_id": str(search_space_id)},
        )

        assert resp.status_code == 413
        detail = resp.json()["detail"]
        assert "too many files" in detail.lower()

    async def test_10_files_accepted(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
        cleanup_doc_ids: list[int],
    ):
        # Exactly at the cap — the boundary case must still be accepted.
        payload = []
        for i in range(10):
            payload.append(
                ("files", (f"file_{i}.txt", io.BytesIO(b"test content"), "text/plain"))
            )

        resp = await client.post(
            "/api/v1/documents/fileupload",
            headers=headers,
            files=payload,
            data={"search_space_id": str(search_space_id)},
        )

        assert resp.status_code == 200
        cleanup_doc_ids.extend(resp.json().get("document_ids", []))
# ---------------------------------------------------------------------------
# Test B: Per-file size limit
# ---------------------------------------------------------------------------


class TestPerFileSizeLimit:
    """A single file may not exceed the 50 MB per-file cap."""

    async def test_oversized_file_returns_413(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
    ):
        # One byte over the 50 MB cap.
        over_limit_bytes = 50 * 1024 * 1024 + 1
        oversized = io.BytesIO(b"\x00" * over_limit_bytes)

        resp = await client.post(
            "/api/v1/documents/fileupload",
            headers=headers,
            files=[("files", ("big.pdf", oversized, "application/pdf"))],
            data={"search_space_id": str(search_space_id)},
        )

        assert resp.status_code == 413
        assert "per-file limit" in resp.json()["detail"].lower()

    async def test_file_at_limit_accepted(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
        cleanup_doc_ids: list[int],
    ):
        # Exactly 50 MB — the boundary must still be accepted.
        at_limit = io.BytesIO(b"\x00" * (50 * 1024 * 1024))

        resp = await client.post(
            "/api/v1/documents/fileupload",
            headers=headers,
            files=[("files", ("exact50mb.txt", at_limit, "text/plain"))],
            data={"search_space_id": str(search_space_id)},
        )

        assert resp.status_code == 200
        cleanup_doc_ids.extend(resp.json().get("document_ids", []))
# ---------------------------------------------------------------------------
# Test C: Total upload size limit
# ---------------------------------------------------------------------------


class TestTotalSizeLimit:
    """The combined size of all files in one request may not exceed 200 MB."""

    async def test_total_size_over_200mb_returns_413(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
    ):
        # Each file stays under the 50 MB per-file cap, but five of them
        # together (5 x 45 MB = 225 MB) exceed the 200 MB aggregate cap.
        chunk_size = 45 * 1024 * 1024  # 45 MB each
        upload_files = []
        for i in range(5):
            upload_files.append(
                (
                    "files",
                    (f"chunk_{i}.txt", io.BytesIO(b"\x00" * chunk_size), "text/plain"),
                )
            )

        resp = await client.post(
            "/api/v1/documents/fileupload",
            headers=headers,
            files=upload_files,
            data={"search_space_id": str(search_space_id)},
        )

        assert resp.status_code == 413
        assert "total upload size" in resp.json()["detail"].lower()