diff --git a/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py b/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py
index 1a61e779e..bd6086892 100644
--- a/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py
+++ b/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py
@@ -1,6 +1,8 @@
 import asyncio
 import contextlib
+import logging
 import time
+from collections.abc import Awaitable, Callable
 from datetime import UTC, datetime
 
 from sqlalchemy import delete, select
@@ -327,3 +329,105 @@ class IndexingPipelineService:
         await self.session.refresh(document)
 
         return document
+
+    async def index_batch_parallel(
+        self,
+        connector_docs: list[ConnectorDocument],
+        get_llm: Callable[[AsyncSession], Awaitable],
+        *,
+        max_concurrency: int = 4,
+        on_heartbeat: Callable[[int], Awaitable[None]] | None = None,
+        heartbeat_interval: float = 30.0,
+    ) -> tuple[list[Document], int, int]:
+        """Index documents in parallel with bounded concurrency.
+
+        Phase 1 (serial): prepare_for_indexing using self.session.
+        Phase 2 (parallel): index each document in an isolated session,
+        bounded by a semaphore to avoid overwhelming APIs/DB.
+        """
+        logger = logging.getLogger(__name__)
+
+        doc_map = {
+            compute_unique_identifier_hash(cd): cd for cd in connector_docs
+        }
+        documents = await self.prepare_for_indexing(connector_docs)
+
+        if not documents:
+            return [], 0, 0
+
+        from app.tasks.celery_tasks import get_celery_session_maker
+
+        sem = asyncio.Semaphore(max_concurrency)
+        lock = asyncio.Lock()
+        indexed_count = 0
+        failed_count = 0
+        results: list[Document] = []
+        last_heartbeat = time.time()
+
+        async def _index_one(document: Document) -> Document | Exception:
+            nonlocal indexed_count, failed_count, last_heartbeat
+
+            connector_doc = doc_map.get(document.unique_identifier_hash)
+            if connector_doc is None:
+                logger.warning(
+                    "No matching ConnectorDocument for document %s, skipping",
+                    document.id,
+                )
+                async with lock:
+                    failed_count += 1
+                return document
+
+            async with sem:
+                session_maker = get_celery_session_maker()
+                async with session_maker() as isolated_session:
+                    try:
+                        refetched = await isolated_session.get(
+                            Document, document.id
+                        )
+                        if refetched is None:
+                            async with lock:
+                                failed_count += 1
+                            return document
+
+                        llm = await get_llm(isolated_session)
+                        iso_pipeline = IndexingPipelineService(isolated_session)
+                        result = await iso_pipeline.index(
+                            refetched, connector_doc, llm
+                        )
+
+                        async with lock:
+                            if DocumentStatus.is_state(
+                                result.status, DocumentStatus.READY
+                            ):
+                                indexed_count += 1
+                            else:
+                                failed_count += 1
+
+                        if on_heartbeat:
+                            now = time.time()
+                            if now - last_heartbeat >= heartbeat_interval:
+                                await on_heartbeat(indexed_count)
+                                last_heartbeat = now
+
+                        return result
+                    except Exception as exc:
+                        logger.error(
+                            "Parallel index failed for doc %s: %s",
+                            document.id,
+                            exc,
+                            exc_info=True,
+                        )
+                        async with lock:
+                            failed_count += 1
+                        return exc
+
+        tasks = [_index_one(doc) for doc in documents]
+        outcomes = await asyncio.gather(*tasks, return_exceptions=True)
+
+        for outcome in outcomes:
+            if isinstance(outcome, Document):
+                results.append(outcome)
+            elif isinstance(outcome, Exception):
+                pass
+
+        return results, indexed_count, failed_count
diff --git a/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py
index a69b33bdc..61b1ccb2b 100644
--- a/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py
+++ b/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py
@@ -5,7 +5,6 @@ Uses the shared IndexingPipelineService for document deduplication,
 summarization, chunking, and embedding.
 """
 
-import time
 from collections.abc import Awaitable, Callable
 from datetime import datetime, timedelta
 
@@ -16,10 +15,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
 from app.connectors.google_calendar_connector import GoogleCalendarConnector
 from app.db import DocumentType, SearchSourceConnectorType
 from app.indexing_pipeline.connector_document import ConnectorDocument
-from app.indexing_pipeline.document_hashing import (
-    compute_content_hash,
-    compute_unique_identifier_hash,
-)
+from app.indexing_pipeline.document_hashing import compute_content_hash
 from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
 from app.services.llm_service import get_user_long_context_llm
 from app.services.task_logging_service import TaskLoggingService
@@ -399,53 +395,21 @@ async def index_google_calendar_events(
             documents_skipped += 1
             continue
 
-        # ── Pipeline: migrate legacy docs + prepare + index ───────────
+        # ── Pipeline: migrate legacy docs + parallel index ─────────────
         pipeline = IndexingPipelineService(session)
         await pipeline.migrate_legacy_docs(connector_docs)
-        documents = await pipeline.prepare_for_indexing(connector_docs)
+        async def _get_llm(s):
+            return await get_user_long_context_llm(s, user_id, search_space_id)
 
-        doc_map = {
-            compute_unique_identifier_hash(cd): cd for cd in connector_docs
-        }
-
-        documents_indexed = 0
-        documents_failed = 0
-        last_heartbeat_time = time.time()
-
-        for document in documents:
-            if on_heartbeat_callback:
-                current_time = time.time()
-                if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS:
-                    await on_heartbeat_callback(documents_indexed)
-                    last_heartbeat_time = current_time
-
-            connector_doc = doc_map.get(document.unique_identifier_hash)
-            if connector_doc is None:
-                logger.warning(
-                    f"No matching ConnectorDocument for document {document.id}, skipping"
-                )
-                documents_failed += 1
-                continue
-
-            try:
-                user_llm = await get_user_long_context_llm(
-                    session, user_id, search_space_id
-                )
-                await pipeline.index(document, connector_doc, user_llm)
-                documents_indexed += 1
-
-                if documents_indexed % 10 == 0:
-                    logger.info(
-                        f"Committing batch: {documents_indexed} Google Calendar events processed so far"
-                    )
-                    await session.commit()
-
-            except Exception as e:
-                logger.error(f"Error processing Calendar event: {e!s}", exc_info=True)
-                documents_failed += 1
-                continue
+        _, documents_indexed, documents_failed = await pipeline.index_batch_parallel(
+            connector_docs,
+            _get_llm,
+            max_concurrency=3,
+            on_heartbeat=on_heartbeat_callback,
+            heartbeat_interval=HEARTBEAT_INTERVAL_SECONDS,
+        )
 
         # ── Finalize ──────────────────────────────────────────────────
         await update_connector_last_indexed(session, connector, update_last_indexed)
diff --git a/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py
index 96cc1cbb4..0d77ad3cd 100644
--- a/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py
+++ b/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py
@@ -5,8 +5,6 @@ Uses the shared IndexingPipelineService for document deduplication,
 summarization, chunking, and embedding.
""" -import logging -import time from collections.abc import Awaitable, Callable from datetime import datetime @@ -17,10 +15,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.connectors.google_gmail_connector import GoogleGmailConnector from app.db import DocumentType, SearchSourceConnectorType from app.indexing_pipeline.connector_document import ConnectorDocument -from app.indexing_pipeline.document_hashing import ( - compute_content_hash, - compute_unique_identifier_hash, -) +from app.indexing_pipeline.document_hashing import compute_content_hash from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService from app.services.llm_service import get_user_long_context_llm from app.services.task_logging_service import TaskLoggingService @@ -336,53 +331,21 @@ async def index_google_gmail_messages( documents_skipped += 1 continue - # ── Pipeline: migrate legacy docs + prepare + index ─────────── + # ── Pipeline: migrate legacy docs + parallel index ───────────── pipeline = IndexingPipelineService(session) await pipeline.migrate_legacy_docs(connector_docs) - documents = await pipeline.prepare_for_indexing(connector_docs) + async def _get_llm(s): + return await get_user_long_context_llm(s, user_id, search_space_id) - doc_map = { - compute_unique_identifier_hash(cd): cd for cd in connector_docs - } - - documents_indexed = 0 - documents_failed = 0 - last_heartbeat_time = time.time() - - for document in documents: - if on_heartbeat_callback: - current_time = time.time() - if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS: - await on_heartbeat_callback(documents_indexed) - last_heartbeat_time = current_time - - connector_doc = doc_map.get(document.unique_identifier_hash) - if connector_doc is None: - logger.warning( - f"No matching ConnectorDocument for document {document.id}, skipping" - ) - documents_failed += 1 - continue - - try: - user_llm = await get_user_long_context_llm( - session, user_id, search_space_id - ) - await pipeline.index(document, connector_doc, user_llm) - documents_indexed += 1 - - if documents_indexed % 10 == 0: - logger.info( - f"Committing batch: {documents_indexed} Gmail messages processed so far" - ) - await session.commit() - - except Exception as e: - logger.error(f"Error processing Gmail message: {e!s}", exc_info=True) - documents_failed += 1 - continue + _, documents_indexed, documents_failed = await pipeline.index_batch_parallel( + connector_docs, + _get_llm, + max_concurrency=3, + on_heartbeat=on_heartbeat_callback, + heartbeat_interval=HEARTBEAT_INTERVAL_SECONDS, + ) # ── Finalize ────────────────────────────────────────────────── await update_connector_last_indexed(session, connector, update_last_indexed) diff --git a/surfsense_backend/tests/unit/indexing_pipeline/test_index_batch_parallel.py b/surfsense_backend/tests/unit/indexing_pipeline/test_index_batch_parallel.py index 7e23383ac..3148812f8 100644 --- a/surfsense_backend/tests/unit/indexing_pipeline/test_index_batch_parallel.py +++ b/surfsense_backend/tests/unit/indexing_pipeline/test_index_batch_parallel.py @@ -25,6 +25,15 @@ def pipeline(mock_session): return IndexingPipelineService(mock_session) +def _make_orm_doc(connector_doc, doc_id): + """Create a MagicMock Document bound to a ConnectorDocument's hash.""" + doc = MagicMock(spec=Document) + doc.id = doc_id + doc.unique_identifier_hash = compute_unique_identifier_hash(connector_doc) + doc.status = DocumentStatus.pending() + return doc + + async def test_index_calls_embed_and_chunk_via_to_thread( 
     pipeline, make_connector_document, monkeypatch
 ):
@@ -68,3 +77,110 @@ async def test_index_calls_embed_and_chunk_via_to_thread(
 
     assert "chunk_text" in to_thread_calls
     assert "embed_texts" in to_thread_calls
+
+
+def _mock_session_factory(orm_docs_by_id):
+    """Replace get_celery_session_maker with a two-level callable.
+
+    get_celery_session_maker() -> session_maker
+    session_maker() -> async context manager yielding a mock session
+    """
+
+    def _get_maker():
+        def _make_session():
+            session = MagicMock()
+            session.get = AsyncMock(
+                side_effect=lambda model, doc_id: orm_docs_by_id.get(doc_id)
+            )
+            ctx = MagicMock()
+            ctx.__aenter__ = AsyncMock(return_value=session)
+            ctx.__aexit__ = AsyncMock(return_value=False)
+            return ctx
+
+        return _make_session
+
+    return _get_maker
+
+
+async def test_batch_parallel_indexes_all_documents(
+    pipeline, make_connector_document, monkeypatch
+):
+    """index_batch_parallel indexes all documents and returns correct counts."""
+    docs = [
+        make_connector_document(
+            document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
+            unique_id=f"msg-{i}",
+            search_space_id=1,
+        )
+        for i in range(3)
+    ]
+
+    orm_docs = [_make_orm_doc(cd, doc_id=i + 1) for i, cd in enumerate(docs)]
+    pipeline.prepare_for_indexing = AsyncMock(return_value=orm_docs)
+
+    orm_by_id = {d.id: d for d in orm_docs}
+    monkeypatch.setattr(
+        "app.tasks.celery_tasks.get_celery_session_maker",
+        _mock_session_factory(orm_by_id),
+    )
+
+    index_calls = []
+
+    async def fake_index(self, document, connector_doc, llm):
+        index_calls.append(document.id)
+        document.status = DocumentStatus.ready()
+        return document
+
+    monkeypatch.setattr(IndexingPipelineService, "index", fake_index)
+
+    async def mock_get_llm(session):
+        return MagicMock()
+
+    _, indexed, failed = await pipeline.index_batch_parallel(
+        docs, mock_get_llm, max_concurrency=2
+    )
+
+    assert indexed == 3
+    assert failed == 0
+    assert sorted(index_calls) == [1, 2, 3]
+
+
+async def test_batch_parallel_one_failure_does_not_affect_others(
+    pipeline, make_connector_document, monkeypatch
+):
+    """One document failure doesn't prevent other documents from being indexed."""
+    docs = [
+        make_connector_document(
+            document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
+            unique_id=f"msg-{i}",
+            search_space_id=1,
+        )
+        for i in range(3)
+    ]
+
+    orm_docs = [_make_orm_doc(cd, doc_id=i + 1) for i, cd in enumerate(docs)]
+    pipeline.prepare_for_indexing = AsyncMock(return_value=orm_docs)
+
+    orm_by_id = {d.id: d for d in orm_docs}
+    monkeypatch.setattr(
+        "app.tasks.celery_tasks.get_celery_session_maker",
+        _mock_session_factory(orm_by_id),
+    )
+
+    async def failing_index(self, document, connector_doc, llm):
+        if document.id == 2:
+            raise RuntimeError("LLM exploded")
+        document.status = DocumentStatus.ready()
+        return document
+
+    monkeypatch.setattr(IndexingPipelineService, "index", failing_index)
+
+    async def mock_get_llm(session):
+        return MagicMock()
+
+    _, indexed, failed = await pipeline.index_batch_parallel(
+        docs, mock_get_llm, max_concurrency=4
+    )
+
+    assert indexed == 2
+    assert failed == 1
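
For reference, a minimal caller sketch of the new index_batch_parallel API, assuming the caller already holds an AsyncSession and a list of ConnectorDocument. The wrapper function run_indexing and the heartbeat coroutine report_progress are hypothetical names for illustration; in the real connector indexers the heartbeat is the on_heartbeat_callback passed in by the Celery task.

# Sketch only: shows the expected call shape of index_batch_parallel.
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
from app.services.llm_service import get_user_long_context_llm


async def report_progress(indexed_so_far: int) -> None:
    # Hypothetical heartbeat; stands in for on_heartbeat_callback in the real indexers.
    print(f"indexed {indexed_so_far} documents so far")


async def run_indexing(session, connector_docs, user_id, search_space_id):
    pipeline = IndexingPipelineService(session)
    await pipeline.migrate_legacy_docs(connector_docs)

    async def _get_llm(s):
        # Resolved against each worker's isolated session, not the outer one.
        return await get_user_long_context_llm(s, user_id, search_space_id)

    results, indexed, failed = await pipeline.index_batch_parallel(
        connector_docs,
        _get_llm,
        max_concurrency=3,             # bounds pressure on external APIs and the DB
        on_heartbeat=report_progress,  # called at most once per heartbeat_interval
        heartbeat_interval=30.0,
    )
    return results, indexed, failed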