diff --git a/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py b/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py
new file mode 100644
index 000000000..d388c96ed
--- /dev/null
+++ b/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py
@@ -0,0 +1,65 @@
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.db import Document, DocumentStatus
+from app.indexing_pipeline.connector_document import ConnectorDocument
+from app.indexing_pipeline.document_hashing import compute_content_hash, compute_unique_identifier_hash
+
+
+class IndexingPipelineService:
+    def __init__(self, session: AsyncSession) -> None:
+        self.session = session
+
+    async def prepare_for_indexing(
+        self, connector_docs: list[ConnectorDocument]
+    ) -> list[Document]:
+        """Deduplicate connector documents against the database.
+
+        Unchanged documents are skipped, title-only changes are applied in
+        place without reprocessing, and new or content-changed documents are
+        persisted with pending status and returned for indexing.
+        """
+        documents = []
+
+        for connector_doc in connector_docs:
+            unique_identifier_hash = compute_unique_identifier_hash(connector_doc)
+            content_hash = compute_content_hash(connector_doc)
+
+            result = await self.session.execute(
+                select(Document).filter(Document.unique_identifier_hash == unique_identifier_hash)
+            )
+            existing = result.scalars().first()
+
+            if existing is not None:
+                # Same identifier and same content: at most refresh the title.
+                if existing.content_hash == content_hash:
+                    if existing.title != connector_doc.title:
+                        existing.title = connector_doc.title
+                    continue
+
+                # Content changed: update in place and reset to pending.
+                existing.title = connector_doc.title
+                existing.content_hash = content_hash
+                existing.source_markdown = connector_doc.source_markdown
+                existing.status = DocumentStatus.pending()
+                documents.append(existing)
+                continue
+
+            # Unseen identifier: persist a new pending document.
+            document = Document(
+                title=connector_doc.title,
+                document_type=connector_doc.document_type,
+                content="Pending...",
+                content_hash=content_hash,
+                unique_identifier_hash=unique_identifier_hash,
+                source_markdown=connector_doc.source_markdown,
+                document_metadata=connector_doc.metadata,
+                search_space_id=connector_doc.search_space_id,
+                connector_id=connector_doc.connector_id,
+                status=DocumentStatus.pending(),
+            )
+            self.session.add(document)
+            documents.append(document)
+
+        await self.session.commit()
+        return documents
diff --git a/surfsense_backend/pyproject.toml b/surfsense_backend/pyproject.toml
index af3abc457..79277b36d 100644
--- a/surfsense_backend/pyproject.toml
+++ b/surfsense_backend/pyproject.toml
@@ -78,6 +78,7 @@ dev = [
 [tool.pytest.ini_options]
 asyncio_mode = "auto"
 asyncio_default_fixture_loop_scope = "session"
+asyncio_default_test_loop_scope = "session"
 testpaths = ["tests"]
 markers = [
     "unit: pure logic tests, no DB or external services",
diff --git a/surfsense_backend/tests/integration/conftest.py b/surfsense_backend/tests/integration/conftest.py
index 3fd4e1a31..fb6fca773 100644
--- a/surfsense_backend/tests/integration/conftest.py
+++ b/surfsense_backend/tests/integration/conftest.py
@@ -1,5 +1,6 @@
 import os
+import uuid
 from unittest.mock import AsyncMock, MagicMock
 
 import pytest
 
@@ -8,7 +9,7 @@
 from sqlalchemy import text
 from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
 from sqlalchemy.pool import NullPool
 
-from app.db import Base
+from app.db import Base, SearchSpace, User
 
 _EMBEDDING_DIM = 4  # keep vectors tiny; real model uses 768+
@@ -18,7 +19,14 @@ TEST_DATABASE_URL = os.environ.get("TEST_DATABASE_URL", _DEFAULT_TEST_DB)
 
 @pytest_asyncio.fixture(scope="session")
 async def async_engine():
-    engine = create_async_engine(TEST_DATABASE_URL, poolclass=NullPool, echo=False)
+    engine = create_async_engine(
+        TEST_DATABASE_URL,
+        poolclass=NullPool,
+        echo=False,
+        # Required for asyncpg + savepoints: disables prepared statement cache
+        # to prevent "another operation is in progress" errors during savepoint rollbacks.
+        connect_args={"prepared_statement_cache_size": 0},
+    )
 
     async with engine.begin() as conn:
         await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
@@ -26,8 +34,11 @@
 
     yield engine
 
+    # drop_all fails on circular FKs (new_chat_threads ↔ public_chat_snapshots).
+    # DROP SCHEMA CASCADE handles this without needing a topological sort.
     async with engine.begin() as conn:
-        await conn.run_sync(Base.metadata.drop_all)
+        await conn.execute(text("DROP SCHEMA public CASCADE"))
+        await conn.execute(text("CREATE SCHEMA public"))
 
     await engine.dispose()
 
@@ -50,6 +61,32 @@ async def db_session(async_engine) -> AsyncSession:
         await transaction.rollback()
 
 
+@pytest_asyncio.fixture
+async def db_user(db_session: AsyncSession) -> User:
+    user = User(
+        id=uuid.uuid4(),
+        email="test@surfsense.net",
+        hashed_password="hashed",
+        is_active=True,
+        is_superuser=False,
+        is_verified=True,
+    )
+    db_session.add(user)
+    await db_session.flush()
+    return user
+
+
+@pytest_asyncio.fixture
+async def db_search_space(db_session: AsyncSession, db_user: User) -> SearchSpace:
+    space = SearchSpace(
+        name="Test Space",
+        user_id=db_user.id,
+    )
+    db_session.add(space)
+    await db_session.flush()
+    return space
+
+
 @pytest.fixture
 def mock_llm() -> AsyncMock:
     llm = AsyncMock()
diff --git a/surfsense_backend/tests/integration/indexing_pipeline/test_prepare_for_indexing.py b/surfsense_backend/tests/integration/indexing_pipeline/test_prepare_for_indexing.py
new file mode 100644
index 000000000..9ab46943e
--- /dev/null
+++ b/surfsense_backend/tests/integration/indexing_pipeline/test_prepare_for_indexing.py
@@ -0,0 +1,145 @@
+import pytest
+from sqlalchemy import select
+
+from app.db import Document, DocumentStatus
+from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
+
+pytestmark = pytest.mark.integration
+
+
+async def test_new_document_is_persisted_with_pending_status(
+    db_session, db_search_space, make_connector_document
+):
+    doc = make_connector_document(search_space_id=db_search_space.id)
+    service = IndexingPipelineService(session=db_session)
+
+    results = await service.prepare_for_indexing([doc])
+
+    assert len(results) == 1
+    document_id = results[0].id
+
+    result = await db_session.execute(select(Document).filter(Document.id == document_id))
+    reloaded = result.scalars().first()
+
+    assert reloaded is not None
+    assert DocumentStatus.is_state(reloaded.status, DocumentStatus.PENDING)
+
+
+async def test_unchanged_document_is_skipped(
+    db_session, db_search_space, make_connector_document
+):
+    doc = make_connector_document(search_space_id=db_search_space.id)
+    service = IndexingPipelineService(session=db_session)
+
+    await service.prepare_for_indexing([doc])
+    results = await service.prepare_for_indexing([doc])
+
+    assert results == []
+
+
+async def test_title_only_change_updates_title_in_db(
+    db_session, db_search_space, make_connector_document
+):
+    original = make_connector_document(search_space_id=db_search_space.id, title="Original Title")
+    service = IndexingPipelineService(session=db_session)
+
+    first = await service.prepare_for_indexing([original])
+    document_id = first[0].id
+
+    renamed = make_connector_document(search_space_id=db_search_space.id, title="Updated Title")
+    results = await service.prepare_for_indexing([renamed])
+
+    assert results == []
+
+    result = await db_session.execute(select(Document).filter(Document.id == document_id))
+    reloaded = result.scalars().first()
+
+    assert reloaded.title == "Updated Title"
+
+
+async def test_changed_content_is_returned_for_reprocessing(
+    db_session, db_search_space, make_connector_document
+):
+    original = make_connector_document(search_space_id=db_search_space.id, source_markdown="## v1")
+    service = IndexingPipelineService(session=db_session)
+
+    first = await service.prepare_for_indexing([original])
+    original_id = first[0].id
+
+    updated = make_connector_document(search_space_id=db_search_space.id, source_markdown="## v2")
+    results = await service.prepare_for_indexing([updated])
+
+    assert len(results) == 1
+    assert results[0].id == original_id
+
+    result = await db_session.execute(select(Document).filter(Document.id == original_id))
+    reloaded = result.scalars().first()
+
+    assert reloaded.source_markdown == "## v2"
+    assert DocumentStatus.is_state(reloaded.status, DocumentStatus.PENDING)
+
+
+async def test_all_documents_in_batch_are_persisted(
+    db_session, db_search_space, make_connector_document
+):
+    docs = [
+        make_connector_document(search_space_id=db_search_space.id, unique_id="id-1", title="Doc 1", source_markdown="## Content 1"),
+        make_connector_document(search_space_id=db_search_space.id, unique_id="id-2", title="Doc 2", source_markdown="## Content 2"),
+        make_connector_document(search_space_id=db_search_space.id, unique_id="id-3", title="Doc 3", source_markdown="## Content 3"),
+    ]
+    service = IndexingPipelineService(session=db_session)
+
+    results = await service.prepare_for_indexing(docs)
+
+    assert len(results) == 3
+
+    result = await db_session.execute(select(Document).filter(Document.search_space_id == db_search_space.id))
+    rows = result.scalars().all()
+
+    assert len(rows) == 3
+
+
+async def test_duplicate_in_batch_is_persisted_once(
+    db_session, db_search_space, make_connector_document
+):
+    doc = make_connector_document(search_space_id=db_search_space.id)
+    service = IndexingPipelineService(session=db_session)
+
+    results = await service.prepare_for_indexing([doc, doc])
+
+    assert len(results) == 1
+
+    result = await db_session.execute(select(Document).filter(Document.search_space_id == db_search_space.id))
+    rows = result.scalars().all()
+
+    assert len(rows) == 1
+
+
+async def test_title_and_content_change_updates_both_and_returns_document(
+    db_session, db_search_space, make_connector_document
+):
+    original = make_connector_document(
+        search_space_id=db_search_space.id,
+        title="Original Title",
+        source_markdown="## v1",
+    )
+    service = IndexingPipelineService(session=db_session)
+
+    first = await service.prepare_for_indexing([original])
+    original_id = first[0].id
+
+    updated = make_connector_document(
+        search_space_id=db_search_space.id,
+        title="Updated Title",
+        source_markdown="## v2",
+    )
+    results = await service.prepare_for_indexing([updated])
+
+    assert len(results) == 1
+    assert results[0].id == original_id
+
+    result = await db_session.execute(select(Document).filter(Document.id == original_id))
+    reloaded = result.scalars().first()
+
+    assert reloaded.title == "Updated Title"
+    assert reloaded.source_markdown == "## v2"
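
Review note: `compute_unique_identifier_hash` and `compute_content_hash` are imported from `app.indexing_pipeline.document_hashing`, which is not part of this diff. The behavior the service and tests pin down: the identifier hash must stay stable across edits (so re-syncs update the existing row instead of duplicating it), and the content hash must exclude the title (otherwise `test_title_only_change_updates_title_in_db` would see a reprocess). A minimal sketch of that contract; the exact hashed fields are assumptions, not the real module:

```python
# Hypothetical sketch of app/indexing_pipeline/document_hashing.py -- not the real module.
import hashlib

from app.indexing_pipeline.connector_document import ConnectorDocument


def _sha256(*parts: str) -> str:
    digest = hashlib.sha256()
    for part in parts:
        digest.update(part.encode("utf-8"))
        digest.update(b"\x1f")  # unit separator so ("ab", "c") != ("a", "bc")
    return digest.hexdigest()


def compute_unique_identifier_hash(doc: ConnectorDocument) -> str:
    # Stable identity per source document: unchanged by title or content
    # edits, so re-syncs update the existing row rather than insert a new one.
    return _sha256(str(doc.search_space_id), str(doc.connector_id), doc.unique_id)


def compute_content_hash(doc: ConnectorDocument) -> str:
    # Deliberately excludes the title: a title-only change must not
    # reset the document to pending for reprocessing.
    return _sha256(doc.source_markdown)
```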
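The `prepared_statement_cache_size=0` setting in conftest pairs with the savepoint-based `db_session` fixture, of which this diff only shows the tail (`await transaction.rollback()`). For context, a sketch of the usual SQLAlchemy 2.0 "external transaction + savepoint" pattern that tail implies; the actual fixture may differ:

```python
# Hypothetical sketch -- the real db_session fixture is only partially visible in this diff.
import pytest_asyncio
from sqlalchemy.ext.asyncio import AsyncSession


@pytest_asyncio.fixture
async def db_session(async_engine) -> AsyncSession:
    async with async_engine.connect() as connection:
        transaction = await connection.begin()

        # create_savepoint turns each session.commit() into a SAVEPOINT release,
        # which is why asyncpg's prepared statement cache is disabled on the engine.
        session = AsyncSession(bind=connection, join_transaction_mode="create_savepoint")

        yield session

        await session.close()
        await transaction.rollback()  # discard everything the test wrote
```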
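The new tests also rely on a `make_connector_document` factory fixture defined elsewhere in the test suite. A sketch of the shape its call sites imply; the field names, defaults, and `document_type` value are assumptions:

```python
# Hypothetical sketch only -- the real fixture lives elsewhere in tests/ and may differ.
import pytest

from app.indexing_pipeline.connector_document import ConnectorDocument


@pytest.fixture
def make_connector_document():
    def _make(
        search_space_id: int,
        unique_id: str = "id-default",  # stable dedup key per connector document
        title: str = "Default Title",
        source_markdown: str = "## Default content",
    ) -> ConnectorDocument:
        # Field names below are assumed from what prepare_for_indexing reads.
        return ConnectorDocument(
            unique_id=unique_id,
            title=title,
            source_markdown=source_markdown,
            document_type="CONNECTOR",  # assumed placeholder value
            metadata={},
            search_space_id=search_space_id,
            connector_id=None,
        )

    return _make
```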