feat: implement and test prepare_for_indexing

This commit is contained in:
CREDO23 2026-02-25 00:06:34 +02:00
parent a0134a5830
commit 579a9e2cb5
4 changed files with 243 additions and 3 deletions

View file

@@ -1,5 +1,6 @@
import os
import uuid
from unittest.mock import AsyncMock, MagicMock
import pytest
@@ -8,7 +9,8 @@ from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.pool import NullPool
from app.db import Base
from app.db import Base, SearchSpace
from app.db import User
_EMBEDDING_DIM = 4 # keep vectors tiny; real model uses 768+
@@ -18,7 +20,14 @@ TEST_DATABASE_URL = os.environ.get("TEST_DATABASE_URL", _DEFAULT_TEST_DB)
@pytest_asyncio.fixture(scope="session")
async def async_engine():
engine = create_async_engine(TEST_DATABASE_URL, poolclass=NullPool, echo=False)
engine = create_async_engine(
TEST_DATABASE_URL,
poolclass=NullPool,
echo=False,
# Required for asyncpg + savepoints: disables prepared statement cache
# to prevent "another operation is in progress" errors during savepoint rollbacks.
connect_args={"prepared_statement_cache_size": 0},
)
async with engine.begin() as conn:
await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
@@ -26,8 +35,11 @@ async def async_engine():
yield engine
# drop_all fails on circular FKs (new_chat_threads ↔ public_chat_snapshots).
# DROP SCHEMA CASCADE handles this without needing topological sort.
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
await conn.execute(text("DROP SCHEMA public CASCADE"))
await conn.execute(text("CREATE SCHEMA public"))
await engine.dispose()
@@ -50,6 +62,32 @@ async def db_session(async_engine) -> AsyncSession:
await transaction.rollback()
@pytest_asyncio.fixture
async def db_user(db_session: AsyncSession) -> User:
    """Persist and return an active, verified, non-superuser test user."""
    new_user = User(
        id=uuid.uuid4(),
        email="test@surfsense.net",
        hashed_password="hashed",
        is_active=True,
        is_superuser=False,
        is_verified=True,
    )
    db_session.add(new_user)
    # Flush (not commit) so the row is visible inside the test's transaction
    # and rolled back afterwards by the db_session fixture.
    await db_session.flush()
    return new_user
@pytest_asyncio.fixture
async def db_search_space(db_session: AsyncSession, db_user: User) -> SearchSpace:
    """Persist and return a search space owned by the ``db_user`` fixture."""
    new_space = SearchSpace(name="Test Space", user_id=db_user.id)
    db_session.add(new_space)
    # Flush only: the surrounding test transaction handles cleanup.
    await db_session.flush()
    return new_space
@pytest.fixture
def mock_llm() -> AsyncMock:
llm = AsyncMock()

View file

@@ -0,0 +1,145 @@
import pytest
from sqlalchemy import select
from app.db import Document, DocumentStatus
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
pytestmark = pytest.mark.integration
async def test_new_document_is_persisted_with_pending_status(
    db_session, db_search_space, make_connector_document
):
    """A never-seen connector document is stored and left in PENDING state."""
    connector_doc = make_connector_document(search_space_id=db_search_space.id)
    pipeline = IndexingPipelineService(session=db_session)

    prepared = await pipeline.prepare_for_indexing([connector_doc])
    assert len(prepared) == 1

    # Reload from the database to verify persistence, not just the return value.
    row = (
        await db_session.execute(select(Document).filter(Document.id == prepared[0].id))
    ).scalars().first()
    assert row is not None
    assert DocumentStatus.is_state(row.status, DocumentStatus.PENDING)
async def test_unchanged_document_is_skipped(
    db_session, db_search_space, make_connector_document
):
    """Re-submitting an identical document produces no work on the second pass."""
    connector_doc = make_connector_document(search_space_id=db_search_space.id)
    pipeline = IndexingPipelineService(session=db_session)

    await pipeline.prepare_for_indexing([connector_doc])
    second_pass = await pipeline.prepare_for_indexing([connector_doc])

    assert second_pass == []
async def test_title_only_change_updates_title_in_db(
    db_session, db_search_space, make_connector_document
):
    """A title-only change is written in place and does not re-queue the document."""
    pipeline = IndexingPipelineService(session=db_session)

    first_pass = await pipeline.prepare_for_indexing(
        [make_connector_document(search_space_id=db_search_space.id, title="Original Title")]
    )
    doc_id = first_pass[0].id

    second_pass = await pipeline.prepare_for_indexing(
        [make_connector_document(search_space_id=db_search_space.id, title="Updated Title")]
    )
    # No content change, so nothing is returned for reprocessing...
    assert second_pass == []

    # ...but the stored title was still updated.
    row = (
        await db_session.execute(select(Document).filter(Document.id == doc_id))
    ).scalars().first()
    assert row.title == "Updated Title"
async def test_changed_content_is_returned_for_reprocessing(
    db_session, db_search_space, make_connector_document
):
    """Changed markdown persists the new content and re-queues the same document."""
    pipeline = IndexingPipelineService(session=db_session)

    first_pass = await pipeline.prepare_for_indexing(
        [make_connector_document(search_space_id=db_search_space.id, source_markdown="## v1")]
    )
    doc_id = first_pass[0].id

    second_pass = await pipeline.prepare_for_indexing(
        [make_connector_document(search_space_id=db_search_space.id, source_markdown="## v2")]
    )
    # The existing row (same id) comes back for reprocessing — no duplicate is made.
    assert len(second_pass) == 1
    assert second_pass[0].id == doc_id

    row = (
        await db_session.execute(select(Document).filter(Document.id == doc_id))
    ).scalars().first()
    assert row.source_markdown == "## v2"
    assert DocumentStatus.is_state(row.status, DocumentStatus.PENDING)
async def test_all_documents_in_batch_are_persisted(
    db_session, db_search_space, make_connector_document
):
    """Every distinct document in a batch is returned and stored."""
    batch = [
        make_connector_document(
            search_space_id=db_search_space.id,
            unique_id=f"id-{i}",
            title=f"Doc {i}",
            source_markdown=f"## Content {i}",
        )
        for i in (1, 2, 3)
    ]
    pipeline = IndexingPipelineService(session=db_session)

    prepared = await pipeline.prepare_for_indexing(batch)
    assert len(prepared) == 3

    rows = (
        await db_session.execute(
            select(Document).filter(Document.search_space_id == db_search_space.id)
        )
    ).scalars().all()
    assert len(rows) == 3
async def test_duplicate_in_batch_is_persisted_once(
    db_session, db_search_space, make_connector_document
):
    """The same document repeated within one batch is deduplicated."""
    connector_doc = make_connector_document(search_space_id=db_search_space.id)
    pipeline = IndexingPipelineService(session=db_session)

    prepared = await pipeline.prepare_for_indexing([connector_doc, connector_doc])
    assert len(prepared) == 1

    rows = (
        await db_session.execute(
            select(Document).filter(Document.search_space_id == db_search_space.id)
        )
    ).scalars().all()
    assert len(rows) == 1
async def test_title_and_content_change_updates_both_and_returns_document(
    db_session, db_search_space, make_connector_document
):
    """Simultaneous title and content changes update both fields and re-queue the doc."""
    pipeline = IndexingPipelineService(session=db_session)

    first_pass = await pipeline.prepare_for_indexing(
        [
            make_connector_document(
                search_space_id=db_search_space.id,
                title="Original Title",
                source_markdown="## v1",
            )
        ]
    )
    doc_id = first_pass[0].id

    second_pass = await pipeline.prepare_for_indexing(
        [
            make_connector_document(
                search_space_id=db_search_space.id,
                title="Updated Title",
                source_markdown="## v2",
            )
        ]
    )
    # Content changed, so the same row comes back for reprocessing.
    assert len(second_pass) == 1
    assert second_pass[0].id == doc_id

    row = (
        await db_session.execute(select(Document).filter(Document.id == doc_id))
    ).scalars().first()
    assert row.title == "Updated Title"
    assert row.source_markdown == "## v2"