class _FakeSessionMaker:
    """Stand-in session factory that lends out one pre-existing session.

    ``async with factory()`` yields the wrapped :class:`AsyncSession` and
    deliberately does NOT close it on exit, so batch-mode DB operations are
    routed through the test's savepoint-wrapped session instead of opening
    real Celery-owned sessions.
    """

    def __init__(self, session: AsyncSession):
        # Shared test session; ownership (and closing) stays with the caller.
        self._session = session

    @asynccontextmanager
    async def _borrow(self):
        # Lend the session out; skip any teardown so the savepoint survives.
        yield self._session

    def __call__(self):
        return self._borrow()


@pytest.fixture
def patched_batch_sessions(monkeypatch, db_session):
    """Make ``_index_batch_files`` use the test session and run sequentially.

    Swaps the Celery session maker for a :class:`_FakeSessionMaker` around
    ``db_session`` and pins ``BATCH_CONCURRENCY`` to 1 so batch work is
    deterministic and fully contained in the test transaction.
    """
    monkeypatch.setattr(
        "app.tasks.connector_indexers.local_folder_indexer.get_celery_session_maker",
        lambda: _FakeSessionMaker(db_session),
    )
    monkeypatch.setattr(
        "app.tasks.connector_indexers.local_folder_indexer.BATCH_CONCURRENCY",
        1,
    )


# ====================================================================
# Tier 6: Batch Mode (B1-B2)
# ====================================================================
class TestBatchMode:
    """Batch-mode coverage: multi-file indexing (B1) and partial failure (B2)."""

    @pytest.mark.usefixtures(*UNIFIED_FIXTURES)
    async def test_b1_batch_indexes_multiple_files(
        self,
        db_session: AsyncSession,
        db_user: User,
        db_search_space: SearchSpace,
        tmp_path: Path,
        patched_batch_sessions,
    ):
        """B1: Batch with 3 files indexes all of them."""
        from app.tasks.connector_indexers.local_folder_indexer import index_local_folder

        fixtures = {
            "a.md": "File A content",
            "b.md": "File B content",
            "c.md": "File C content",
        }
        for name, body in fixtures.items():
            (tmp_path / name).write_text(body)

        count, failed, root_folder_id, err = await index_local_folder(
            session=db_session,
            search_space_id=db_search_space.id,
            user_id=str(db_user.id),
            folder_path=str(tmp_path),
            folder_name="test-folder",
            target_file_paths=[str(tmp_path / name) for name in fixtures],
        )

        assert count == 3
        assert failed == 0
        assert err is None

        result = await db_session.execute(
            select(Document).where(
                Document.document_type == DocumentType.LOCAL_FOLDER_FILE,
                Document.search_space_id == db_search_space.id,
            )
        )
        indexed = result.scalars().all()
        assert len(indexed) == 3
        assert {doc.title for doc in indexed} == set(fixtures)
        for doc in indexed:
            assert DocumentStatus.is_state(doc.status, DocumentStatus.READY)

    @pytest.mark.usefixtures(*UNIFIED_FIXTURES)
    async def test_b2_partial_failure(
        self,
        db_session: AsyncSession,
        db_user: User,
        db_search_space: SearchSpace,
        tmp_path: Path,
        patched_batch_sessions,
    ):
        """B2: One unreadable file fails gracefully; the other two still get indexed."""
        from app.tasks.connector_indexers.local_folder_indexer import index_local_folder

        (tmp_path / "good1.md").write_text("Good file one")
        (tmp_path / "good2.md").write_text("Good file two")
        # Raw NUL-prefixed bytes: unreadable as text, so indexing must fail it.
        (tmp_path / "bad.md").write_bytes(b"\x00binary garbage")

        # bad.md sits between the good files to exercise mid-batch recovery.
        targets = ["good1.md", "bad.md", "good2.md"]
        count, failed, _, err = await index_local_folder(
            session=db_session,
            search_space_id=db_search_space.id,
            user_id=str(db_user.id),
            folder_path=str(tmp_path),
            folder_name="test-folder",
            target_file_paths=[str(tmp_path / name) for name in targets],
        )

        assert count == 2
        assert failed == 1
        assert err is not None

        result = await db_session.execute(
            select(Document).where(
                Document.document_type == DocumentType.LOCAL_FOLDER_FILE,
                Document.search_space_id == db_search_space.id,
            )
        )
        survivors = result.scalars().all()
        assert len(survivors) == 2
        assert {doc.title for doc in survivors} == {"good1.md", "good2.md"}


# ====================================================================
# Tier 5: Pipeline Integration (P1)
# ====================================================================