feat: implement task dispatcher for document processing

- Introduced a TaskDispatcher abstraction to decouple the upload endpoint from Celery, allowing for easier testing with synchronous implementations.
- Updated the create_documents_file_upload function to utilize the new dispatcher for task management.
- Removed direct Celery task imports from the upload function, enhancing modularity.
- Added integration tests for document upload, including page limit enforcement and file size restrictions.
This commit is contained in:
Anish Sarkar 2026-02-26 23:55:47 +05:30
parent 30617c6e54
commit 3393e435f9
9 changed files with 380 additions and 280 deletions

View file

@ -1,198 +0,0 @@
"""E2e conftest — fixtures that require a running backend + database."""
from __future__ import annotations
from collections.abc import AsyncGenerator
import asyncpg
import httpx
import pytest
from tests.conftest import DATABASE_URL
from tests.utils.helpers import (
BACKEND_URL,
TEST_EMAIL,
auth_headers,
delete_document,
get_auth_token,
get_search_space_id,
)
# ---------------------------------------------------------------------------
# Backend connectivity fixtures
# ---------------------------------------------------------------------------
@pytest.fixture(scope="session")
def backend_url() -> str:
    """Base URL of the running backend under test."""
    return BACKEND_URL
@pytest.fixture(scope="session")
async def auth_token(backend_url: str) -> str:
    """Log in once per session (registering the test user on first run)."""
    async with httpx.AsyncClient(base_url=backend_url, timeout=30.0) as http:
        token = await get_auth_token(http)
    return token
@pytest.fixture(scope="session")
async def search_space_id(backend_url: str, auth_token: str) -> int:
    """ID of the first search space owned by the test user."""
    async with httpx.AsyncClient(base_url=backend_url, timeout=30.0) as http:
        space_id = await get_search_space_id(http, auth_token)
    return space_id
@pytest.fixture(scope="session", autouse=True)
async def _purge_test_search_space(
    search_space_id: int,
):
    """Session-start cleanup: wipe every document in the test search space.

    Goes straight to the database so that documents stuck in
    pending/processing (which the API protects with 409) from a crashed
    previous run are removed as well.
    """
    count = await _force_delete_documents_db(search_space_id)
    if count:
        print(
            f"\n[purge] Deleted {count} stale document(s) from search space {search_space_id}"
        )
    yield
@pytest.fixture(scope="session")
def headers(auth_token: str) -> dict[str, str]:
    """Session-wide Authorization headers for the test user."""
    return auth_headers(auth_token)
@pytest.fixture
async def client(backend_url: str) -> AsyncGenerator[httpx.AsyncClient]:
    """Fresh async HTTP client per test, pointed at the live backend."""
    async with httpx.AsyncClient(base_url=backend_url, timeout=180.0) as http:
        yield http
@pytest.fixture
def cleanup_doc_ids() -> list[int]:
    """Mutable list a test appends document IDs to for post-test deletion."""
    ids: list[int] = []
    return ids
@pytest.fixture(autouse=True)
async def _cleanup_documents(
    client: httpx.AsyncClient,
    headers: dict[str, str],
    search_space_id: int,
    cleanup_doc_ids: list[int],
):
    """
    Runs after every test. Tries the API first for clean deletes, then
    falls back to direct DB access for any stuck documents.
    """
    yield
    # IDs the API could not delete (409 conflict or transport failure).
    remaining_ids: list[int] = []
    for doc_id in cleanup_doc_ids:
        try:
            resp = await delete_document(client, headers, doc_id)
            # 409 = document still pending/processing; defer to the DB fallback.
            if resp.status_code == 409:
                remaining_ids.append(doc_id)
        except Exception:
            # Transport/HTTP error — also handled by the DB fallback below.
            remaining_ids.append(doc_id)
    if remaining_ids:
        # Direct delete bypasses the API's processing-state protection.
        conn = await asyncpg.connect(DATABASE_URL)
        try:
            await conn.execute(
                "DELETE FROM documents WHERE id = ANY($1::int[])",
                remaining_ids,
            )
        finally:
            await conn.close()
# ---------------------------------------------------------------------------
# Page-limit helpers (direct DB access)
# ---------------------------------------------------------------------------
async def _force_delete_documents_db(search_space_id: int) -> int:
    """Delete every document row in *search_space_id* directly via asyncpg.

    Skips the API so documents stuck in pending/processing (which the API
    refuses to delete with 409 Conflict) are removed too. Chunks disappear
    through the cascading foreign key. Returns the number of deleted rows.
    """
    conn = await asyncpg.connect(DATABASE_URL)
    try:
        status = await conn.execute(
            "DELETE FROM documents WHERE search_space_id = $1",
            search_space_id,
        )
    finally:
        await conn.close()
    # asyncpg returns a status tag like "DELETE 3"; the count is the tail.
    return int(status.rsplit(" ", 1)[-1])
async def _get_user_page_usage(email: str) -> tuple[int, int]:
    """Return ``(pages_used, pages_limit)`` for *email*, read straight from the DB."""
    conn = await asyncpg.connect(DATABASE_URL)
    try:
        record = await conn.fetchrow(
            'SELECT pages_used, pages_limit FROM "user" WHERE email = $1',
            email,
        )
    finally:
        await conn.close()
    assert record is not None, f"User {email!r} not found in database"
    return record["pages_used"], record["pages_limit"]
async def _set_user_page_limits(
    email: str, *, pages_used: int, pages_limit: int
) -> None:
    """Overwrite ``pages_used`` and ``pages_limit`` for the given user."""
    conn = await asyncpg.connect(DATABASE_URL)
    try:
        # Quoted table name: "user" collides with the reserved SQL keyword.
        await conn.execute(
            'UPDATE "user" SET pages_used = $1, pages_limit = $2 WHERE email = $3',
            pages_used,
            pages_limit,
            email,
        )
    finally:
        await conn.close()
@pytest.fixture
async def page_limits():
    """
    Fixture that exposes helpers for manipulating the test user's page limits.
    Automatically restores the original values after each test.
    Usage inside a test::
        await page_limits.set(pages_used=0, pages_limit=100)
        used, limit = await page_limits.get()
    """
    class _PageLimits:
        # Thin async facade over the module-level DB helpers, bound to TEST_EMAIL.
        async def set(self, *, pages_used: int, pages_limit: int) -> None:
            await _set_user_page_limits(
                TEST_EMAIL, pages_used=pages_used, pages_limit=pages_limit
            )

        async def get(self) -> tuple[int, int]:
            return await _get_user_page_usage(TEST_EMAIL)

    # Snapshot current values so a test cannot leak limit changes.
    original = await _get_user_page_usage(TEST_EMAIL)
    yield _PageLimits()
    # Teardown: restore the pre-test (pages_used, pages_limit) pair.
    await _set_user_page_limits(
        TEST_EMAIL, pages_used=original[0], pages_limit=original[1]
    )

View file

@ -0,0 +1,282 @@
"""Integration conftest — runs the FastAPI app in-process via ASGITransport.
Prerequisites: PostgreSQL + pgvector only.
External system boundaries are mocked:
- LLM summarization, text embedding, text chunking (external APIs)
- Redis heartbeat (external infrastructure)
- Task dispatch is swapped via DI (InlineTaskDispatcher)
"""
from __future__ import annotations
import contextlib
from collections.abc import AsyncGenerator
from unittest.mock import AsyncMock, MagicMock
import asyncpg
import httpx
import pytest
from httpx import ASGITransport
from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.pool import NullPool
from app.app import app
from app.config import config as app_config
from app.db import DATABASE_URL as APP_DB_URL, Base
from app.services.task_dispatcher import get_task_dispatcher
from tests.conftest import DATABASE_URL
from tests.utils.helpers import (
TEST_EMAIL,
auth_headers,
delete_document,
get_auth_token,
get_search_space_id,
)
_EMBEDDING_DIM = app_config.embedding_model_instance.dimension
pytestmark = pytest.mark.integration
# ---------------------------------------------------------------------------
# Inline task dispatcher (replaces Celery via DI — not a mock)
# ---------------------------------------------------------------------------
class InlineTaskDispatcher:
    """Synchronous stand-in for the Celery-backed task dispatcher.

    Installed through FastAPI ``dependency_overrides`` so the upload
    endpoint runs document processing inline, in the calling coroutine.
    Exceptions are swallowed to mirror Celery's fire-and-forget contract —
    the processing function itself marks failed documents.
    """

    async def dispatch_file_processing(
        self,
        *,
        document_id: int,
        temp_path: str,
        filename: str,
        search_space_id: int,
        user_id: str,
    ) -> None:
        # Imported lazily so merely importing this conftest does not pull in
        # the Celery task module.
        from app.tasks.celery_tasks.document_tasks import (
            _process_file_with_document,
        )

        try:
            await _process_file_with_document(
                document_id, temp_path, filename, search_space_id, user_id
            )
        except Exception:
            # Fire-and-forget: the caller never observes task errors.
            pass
# The class itself is a zero-argument callable returning a fresh instance,
# so it can serve directly as the dependency override.
app.dependency_overrides[get_task_dispatcher] = InlineTaskDispatcher
# ---------------------------------------------------------------------------
# Database setup (ASGITransport skips the app lifespan)
# ---------------------------------------------------------------------------
@pytest.fixture(scope="session")
async def _ensure_tables():
    """One-time schema setup: required extensions plus all ORM tables.

    Needed because ASGITransport never runs the app lifespan, which would
    otherwise perform this on startup.
    """
    engine = create_async_engine(APP_DB_URL, poolclass=NullPool)
    async with engine.begin() as conn:
        for extension in ("vector", "pg_trgm"):
            await conn.execute(text(f"CREATE EXTENSION IF NOT EXISTS {extension}"))
        await conn.run_sync(Base.metadata.create_all)
    await engine.dispose()
# ---------------------------------------------------------------------------
# Auth & search space (session-scoped, via the in-process app)
# ---------------------------------------------------------------------------
@pytest.fixture(scope="session")
async def auth_token(_ensure_tables) -> str:
    """Session-scoped bearer token; registers the test user if missing."""
    transport = ASGITransport(app=app)
    async with httpx.AsyncClient(
        transport=transport, base_url="http://test", timeout=30.0
    ) as http:
        return await get_auth_token(http)
@pytest.fixture(scope="session")
async def search_space_id(auth_token: str) -> int:
    """ID of the first search space owned by the test user."""
    transport = ASGITransport(app=app)
    async with httpx.AsyncClient(
        transport=transport, base_url="http://test", timeout=30.0
    ) as http:
        return await get_search_space_id(http, auth_token)
@pytest.fixture(scope="session")
def headers(auth_token: str) -> dict[str, str]:
    """Authorization headers shared by every test in the session."""
    return auth_headers(auth_token)
# ---------------------------------------------------------------------------
# Per-test HTTP client & cleanup
# ---------------------------------------------------------------------------
@pytest.fixture
async def client() -> AsyncGenerator[httpx.AsyncClient]:
    """Per-test client that talks to the app in-process (no server required)."""
    transport = ASGITransport(app=app)
    async with httpx.AsyncClient(
        transport=transport, base_url="http://test", timeout=180.0
    ) as http:
        yield http
@pytest.fixture
def cleanup_doc_ids() -> list[int]:
    """Mutable list a test appends document IDs to for post-test deletion."""
    ids: list[int] = []
    return ids
@pytest.fixture(scope="session", autouse=True)
async def _purge_test_search_space(search_space_id: int):
    """Remove documents left behind by earlier (possibly crashed) runs."""
    conn = await asyncpg.connect(DATABASE_URL)
    try:
        status = await conn.execute(
            "DELETE FROM documents WHERE search_space_id = $1",
            search_space_id,
        )
        # asyncpg status tag looks like "DELETE <n>".
        purged = int(status.rsplit(" ", 1)[-1])
        if purged:
            print(
                f"\n[purge] Deleted {purged} stale document(s) "
                f"from search space {search_space_id}"
            )
    finally:
        await conn.close()
    yield
@pytest.fixture(autouse=True)
async def _cleanup_documents(
    client: httpx.AsyncClient,
    headers: dict[str, str],
    cleanup_doc_ids: list[int],
):
    """Post-test cleanup: delete via the API, fall back to the DB for stuck docs."""
    yield
    stuck: list[int] = []
    for doc_id in cleanup_doc_ids:
        try:
            response = await delete_document(client, headers, doc_id)
        except Exception:
            # Transport/HTTP failure — handled by the DB fallback below.
            stuck.append(doc_id)
            continue
        if response.status_code == 409:
            # Still pending/processing; the API refuses, so use the DB.
            stuck.append(doc_id)
    if not stuck:
        return
    conn = await asyncpg.connect(DATABASE_URL)
    try:
        await conn.execute(
            "DELETE FROM documents WHERE id = ANY($1::int[])",
            stuck,
        )
    finally:
        await conn.close()
# ---------------------------------------------------------------------------
# Page-limit helpers (direct DB for setup, API for verification)
# ---------------------------------------------------------------------------
async def _get_user_page_usage(email: str) -> tuple[int, int]:
    """Return ``(pages_used, pages_limit)`` for *email*, read straight from the DB."""
    conn = await asyncpg.connect(DATABASE_URL)
    try:
        record = await conn.fetchrow(
            'SELECT pages_used, pages_limit FROM "user" WHERE email = $1',
            email,
        )
    finally:
        await conn.close()
    assert record is not None, f"User {email!r} not found in database"
    return record["pages_used"], record["pages_limit"]
async def _set_user_page_limits(
    email: str, *, pages_used: int, pages_limit: int
) -> None:
    """Overwrite ``pages_used``/``pages_limit`` for *email* directly in the DB."""
    conn = await asyncpg.connect(DATABASE_URL)
    try:
        # Quoted table name: "user" collides with the reserved SQL keyword.
        await conn.execute(
            'UPDATE "user" SET pages_used = $1, pages_limit = $2 WHERE email = $3',
            pages_used,
            pages_limit,
            email,
        )
    finally:
        await conn.close()
@pytest.fixture
async def page_limits():
    """Expose a setter for the test user's page limits (direct DB, setup only).

    Whatever a test changes is rolled back after it finishes.
    """
    class _PageLimits:
        async def set(self, *, pages_used: int, pages_limit: int) -> None:
            await _set_user_page_limits(
                TEST_EMAIL, pages_used=pages_used, pages_limit=pages_limit
            )

    # Snapshot current values so the teardown can restore them exactly.
    saved_used, saved_limit = await _get_user_page_usage(TEST_EMAIL)
    yield _PageLimits()
    await _set_user_page_limits(
        TEST_EMAIL, pages_used=saved_used, pages_limit=saved_limit
    )
# ---------------------------------------------------------------------------
# Mock external system boundaries
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def _mock_external_apis(monkeypatch):
    """Stub out the external API boundaries of the indexing pipeline."""
    pipeline = "app.indexing_pipeline.indexing_pipeline_service"
    # Summarization is awaited (AsyncMock); embedding and chunking are sync.
    monkeypatch.setattr(
        f"{pipeline}.summarize_document",
        AsyncMock(return_value="Mocked summary."),
    )
    monkeypatch.setattr(
        f"{pipeline}.embed_text",
        MagicMock(return_value=[0.1] * _EMBEDDING_DIM),
    )
    monkeypatch.setattr(
        f"{pipeline}.chunk_text",
        MagicMock(return_value=["Test chunk content."]),
    )
@pytest.fixture(autouse=True)
def _mock_redis_heartbeat(monkeypatch):
    """Neutralize the Redis heartbeat — Redis is external infrastructure."""
    tasks_mod = "app.tasks.celery_tasks.document_tasks"

    def _noop(notification_id) -> None:
        return None

    monkeypatch.setattr(f"{tasks_mod}._start_heartbeat", _noop)
    monkeypatch.setattr(f"{tasks_mod}._stop_heartbeat", _noop)
    monkeypatch.setattr(f"{tasks_mod}._run_heartbeat_loop", AsyncMock())

View file

@ -1,14 +1,14 @@
"""
End-to-end tests for manual document upload.
Integration tests for manual document upload.
These tests exercise the full pipeline:
API upload → Celery task → ETL extraction → chunking → embedding → DB storage
These tests exercise the full pipeline via the HTTP API:
API upload → inline task dispatch → ETL extraction → chunking → embedding → DB storage
Prerequisites (must be running):
- FastAPI backend
External boundaries mocked: LLM summarization, text embedding, text chunking,
Redis heartbeat. Task dispatch is swapped via DI (InlineTaskDispatcher).
Prerequisites:
- PostgreSQL + pgvector
- Redis
- Celery worker
"""
from __future__ import annotations
@ -28,7 +28,7 @@ from tests.utils.helpers import (
upload_multiple_files,
)
pytestmark = pytest.mark.e2e
pytestmark = pytest.mark.integration
# ---------------------------------------------------------------------------
# Helpers local to this module
@ -45,7 +45,7 @@ def _assert_document_ready(doc: dict, *, expected_filename: str) -> None:
# ---------------------------------------------------------------------------
# Test A: Upload a .txt file (direct read path — no ETL service needed)
# Test A: Upload a .txt file (direct read path)
# ---------------------------------------------------------------------------
@ -108,7 +108,6 @@ class TestTxtFileUpload:
doc = await get_document(client, headers, doc_ids[0])
_assert_document_ready(doc, expected_filename="sample.txt")
assert doc["document_metadata"]["ETL_SERVICE"] == "MARKDOWN"
# ---------------------------------------------------------------------------
@ -158,11 +157,10 @@ class TestMarkdownFileUpload:
doc = await get_document(client, headers, doc_ids[0])
_assert_document_ready(doc, expected_filename="sample.md")
assert doc["document_metadata"]["ETL_SERVICE"] == "MARKDOWN"
# ---------------------------------------------------------------------------
# Test C: Upload a .pdf file (ETL path — Docling / Unstructured)
# Test C: Upload a .pdf file (ETL path)
# ---------------------------------------------------------------------------
@ -208,11 +206,6 @@ class TestPdfFileUpload:
doc = await get_document(client, headers, doc_ids[0])
_assert_document_ready(doc, expected_filename="sample.pdf")
assert doc["document_metadata"]["ETL_SERVICE"] in {
"DOCLING",
"UNSTRUCTURED",
"LLAMACLOUD",
}
# ---------------------------------------------------------------------------
@ -284,7 +277,6 @@ class TestDuplicateFileUpload:
search_space_id: int,
cleanup_doc_ids: list[int],
):
# First upload
resp1 = await upload_file(
client, headers, "sample.txt", search_space_id=search_space_id
)
@ -296,7 +288,6 @@ class TestDuplicateFileUpload:
client, headers, first_ids, search_space_id=search_space_id
)
# Second upload of the same file
resp2 = await upload_file(
client, headers, "sample.txt", search_space_id=search_space_id
)
@ -327,7 +318,6 @@ class TestDuplicateContentDetection:
cleanup_doc_ids: list[int],
tmp_path: Path,
):
# First upload
resp1 = await upload_file(
client, headers, "sample.txt", search_space_id=search_space_id
)
@ -338,7 +328,6 @@ class TestDuplicateContentDetection:
client, headers, first_ids, search_space_id=search_space_id
)
# Copy fixture content to a differently named temp file
src = FIXTURES_DIR / "sample.txt"
dest = tmp_path / "renamed_sample.txt"
shutil.copy2(src, dest)
@ -477,39 +466,7 @@ class TestDocumentDeletion:
# ---------------------------------------------------------------------------
# Test K: Cannot delete a document while it is still processing
# ---------------------------------------------------------------------------
class TestDeleteWhileProcessing:
    """Attempting to delete a pending/processing document should be rejected."""

    async def test_delete_pending_document_returns_409(
        self,
        client: httpx.AsyncClient,
        headers: dict[str, str],
        search_space_id: int,
        cleanup_doc_ids: list[int],
    ):
        # Kick off a PDF upload; processing continues asynchronously.
        resp = await upload_file(
            client, headers, "sample.pdf", search_space_id=search_space_id
        )
        assert resp.status_code == 200
        doc_ids = resp.json()["document_ids"]
        cleanup_doc_ids.extend(doc_ids)
        # Immediately try to delete before processing finishes
        del_resp = await delete_document(client, headers, doc_ids[0])
        assert del_resp.status_code == 409
        # Let it finish so cleanup can work
        await poll_document_status(
            client, headers, doc_ids, search_space_id=search_space_id, timeout=300.0
        )
# ---------------------------------------------------------------------------
# Test L: Status polling returns correct structure
# Test K: Searchability after upload
# ---------------------------------------------------------------------------
@ -547,6 +504,11 @@ class TestDocumentSearchability:
)
# ---------------------------------------------------------------------------
# Test L: Status polling returns correct structure
# ---------------------------------------------------------------------------
class TestStatusPolling:
"""Verify the status endpoint returns well-formed responses."""

View file

@ -1,23 +1,20 @@
"""
End-to-end tests for page-limit enforcement during document upload.
Integration tests for page-limit enforcement during document upload.
These tests manipulate the test user's ``pages_used`` / ``pages_limit``
columns directly in the database and then exercise the upload pipeline to
verify that:
columns directly in the database (setup only) and then exercise the upload
pipeline to verify that:
- Uploads are rejected *before* ETL when the limit is exhausted.
- ``pages_used`` increases after a successful upload.
- ``pages_used`` increases after a successful upload (verified via API).
- A ``page_limit_exceeded`` notification is created on rejection.
- ``pages_used`` is not modified when a document fails processing.
All tests reuse the existing small fixtures (``sample.pdf``, ``sample.txt``)
so no additional processing time is introduced.
Prerequisites (must be running):
- FastAPI backend
Prerequisites:
- PostgreSQL + pgvector
- Redis
- Celery worker
"""
from __future__ import annotations
@ -31,7 +28,21 @@ from tests.utils.helpers import (
upload_file,
)
pytestmark = pytest.mark.e2e
pytestmark = pytest.mark.integration
# ---------------------------------------------------------------------------
# Helper: read pages_used through the public API
# ---------------------------------------------------------------------------
async def _get_pages_used(client: httpx.AsyncClient, headers: dict[str, str]) -> int:
    """Read the authenticated user's ``pages_used`` counter via /users/me."""
    response = await client.get("/users/me", headers=headers)
    assert response.status_code == 200, (
        f"GET /users/me failed ({response.status_code}): {response.text}"
    )
    body = response.json()
    return body["pages_used"]
# ---------------------------------------------------------------------------
@ -65,7 +76,7 @@ class TestPageUsageIncrementsOnSuccess:
for did in doc_ids:
assert statuses[did]["status"]["state"] == "ready"
used, _ = await page_limits.get()
used = await _get_pages_used(client, headers)
assert used > 0, "pages_used should have increased after successful processing"
@ -128,7 +139,7 @@ class TestUploadRejectedWhenLimitExhausted:
client, headers, doc_ids, search_space_id=search_space_id, timeout=300.0
)
used, _ = await page_limits.get()
used = await _get_pages_used(client, headers)
assert used == 50, (
f"pages_used should remain 50 after rejected upload, got {used}"
)
@ -263,7 +274,7 @@ class TestPagesUnchangedOnProcessingFailure:
for did in doc_ids:
assert statuses[did]["status"]["state"] == "failed"
used, _ = await page_limits.get()
used = await _get_pages_used(client, headers)
assert used == 10, f"pages_used should remain 10 after ETL failure, got {used}"
@ -284,7 +295,6 @@ class TestSecondUploadExceedsLimit:
cleanup_doc_ids: list[int],
page_limits,
):
# Give just enough room for one ~1-page PDF
await page_limits.set(pages_used=0, pages_limit=1)
resp1 = await upload_file(
@ -300,7 +310,6 @@ class TestSecondUploadExceedsLimit:
for did in first_ids:
assert statuses1[did]["status"]["state"] == "ready"
# Second upload — should fail because quota is now consumed
resp2 = await upload_file(
client,
headers,

View file

@ -1,5 +1,5 @@
"""
End-to-end tests for backend file upload limit enforcement.
Integration tests for backend file upload limit enforcement.
These tests verify that the API rejects uploads that exceed:
- Max files per upload (10)
@ -9,8 +9,7 @@ These tests verify that the API rejects uploads that exceed:
The limits mirror the frontend's DocumentUploadTab.tsx constants and are
enforced server-side to protect against direct API calls.
Prerequisites (must be running):
- FastAPI backend
Prerequisites:
- PostgreSQL + pgvector
"""
@ -21,7 +20,7 @@ import io
import httpx
import pytest
pytestmark = pytest.mark.e2e
pytestmark = pytest.mark.integration
# ---------------------------------------------------------------------------

View file

@ -3,16 +3,14 @@
from __future__ import annotations
import asyncio
import os
from pathlib import Path
import httpx
FIXTURES_DIR = Path(__file__).resolve().parent.parent / "fixtures"
BACKEND_URL = os.environ.get("TEST_BACKEND_URL", "http://localhost:8000")
TEST_EMAIL = os.environ.get("TEST_USER_EMAIL", "testuser@surfsense.com")
TEST_PASSWORD = os.environ.get("TEST_USER_PASSWORD", "testpassword123")
TEST_EMAIL = "testuser@surfsense.com"
TEST_PASSWORD = "testpassword123"
async def get_auth_token(client: httpx.AsyncClient) -> str: