mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-08 20:25:19 +02:00
feat: implement parallel indexing for Google Calendar and Gmail connectors
- Refactored Google Calendar and Gmail indexers to utilize the new `index_batch_parallel` method for concurrent document indexing, enhancing performance. - Updated the indexing logic to replace serial processing with parallel execution, allowing for improved efficiency in handling multiple documents. - Adjusted logging and error handling to accommodate the new parallel processing approach, ensuring robust operation during indexing. - Enhanced unit tests to validate the functionality of the parallel indexing method and its integration with existing workflows.
This commit is contained in:
parent
e5cb6bfacf
commit
4fd776e7ef
4 changed files with 242 additions and 95 deletions
|
|
@ -1,6 +1,8 @@
|
|||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
|
|
@ -327,3 +329,105 @@ class IndexingPipelineService:
|
|||
await self.session.refresh(document)
|
||||
|
||||
return document
|
||||
|
||||
async def index_batch_parallel(
|
||||
self,
|
||||
connector_docs: list[ConnectorDocument],
|
||||
get_llm: Callable[[AsyncSession], Awaitable],
|
||||
*,
|
||||
max_concurrency: int = 4,
|
||||
on_heartbeat: Callable[[int], Awaitable[None]] | None = None,
|
||||
heartbeat_interval: float = 30.0,
|
||||
) -> tuple[list[Document], int, int]:
|
||||
"""Index documents in parallel with bounded concurrency.
|
||||
|
||||
Phase 1 (serial): prepare_for_indexing using self.session.
|
||||
Phase 2 (parallel): index each document in an isolated session,
|
||||
bounded by a semaphore to avoid overwhelming APIs/DB.
|
||||
"""
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
doc_map = {
|
||||
compute_unique_identifier_hash(cd): cd for cd in connector_docs
|
||||
}
|
||||
documents = await self.prepare_for_indexing(connector_docs)
|
||||
|
||||
if not documents:
|
||||
return [], 0, 0
|
||||
|
||||
from app.tasks.celery_tasks import get_celery_session_maker
|
||||
|
||||
sem = asyncio.Semaphore(max_concurrency)
|
||||
lock = asyncio.Lock()
|
||||
indexed_count = 0
|
||||
failed_count = 0
|
||||
results: list[Document] = []
|
||||
last_heartbeat = time.time()
|
||||
|
||||
async def _index_one(document: Document) -> Document | Exception:
|
||||
nonlocal indexed_count, failed_count, last_heartbeat
|
||||
|
||||
connector_doc = doc_map.get(document.unique_identifier_hash)
|
||||
if connector_doc is None:
|
||||
logger.warning(
|
||||
"No matching ConnectorDocument for document %s, skipping",
|
||||
document.id,
|
||||
)
|
||||
async with lock:
|
||||
failed_count += 1
|
||||
return document
|
||||
|
||||
async with sem:
|
||||
session_maker = get_celery_session_maker()
|
||||
async with session_maker() as isolated_session:
|
||||
try:
|
||||
refetched = await isolated_session.get(
|
||||
Document, document.id
|
||||
)
|
||||
if refetched is None:
|
||||
async with lock:
|
||||
failed_count += 1
|
||||
return document
|
||||
|
||||
llm = await get_llm(isolated_session)
|
||||
iso_pipeline = IndexingPipelineService(isolated_session)
|
||||
result = await iso_pipeline.index(
|
||||
refetched, connector_doc, llm
|
||||
)
|
||||
|
||||
async with lock:
|
||||
if DocumentStatus.is_state(
|
||||
result.status, DocumentStatus.READY
|
||||
):
|
||||
indexed_count += 1
|
||||
else:
|
||||
failed_count += 1
|
||||
|
||||
if on_heartbeat:
|
||||
now = time.time()
|
||||
if now - last_heartbeat >= heartbeat_interval:
|
||||
await on_heartbeat(indexed_count)
|
||||
last_heartbeat = now
|
||||
|
||||
return result
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"Parallel index failed for doc %s: %s",
|
||||
document.id,
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
async with lock:
|
||||
failed_count += 1
|
||||
return exc
|
||||
|
||||
tasks = [_index_one(doc) for doc in documents]
|
||||
outcomes = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
for outcome in outcomes:
|
||||
if isinstance(outcome, Document):
|
||||
results.append(outcome)
|
||||
elif isinstance(outcome, Exception):
|
||||
pass
|
||||
|
||||
return results, indexed_count, failed_count
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ Uses the shared IndexingPipelineService for document deduplication,
|
|||
summarization, chunking, and embedding.
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
|
@ -16,10 +15,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
from app.connectors.google_calendar_connector import GoogleCalendarConnector
|
||||
from app.db import DocumentType, SearchSourceConnectorType
|
||||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||
from app.indexing_pipeline.document_hashing import (
|
||||
compute_content_hash,
|
||||
compute_unique_identifier_hash,
|
||||
)
|
||||
from app.indexing_pipeline.document_hashing import compute_content_hash
|
||||
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
|
|
@ -399,53 +395,21 @@ async def index_google_calendar_events(
|
|||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
# ── Pipeline: migrate legacy docs + prepare + index ───────────
|
||||
# ── Pipeline: migrate legacy docs + parallel index ─────────────
|
||||
pipeline = IndexingPipelineService(session)
|
||||
|
||||
await pipeline.migrate_legacy_docs(connector_docs)
|
||||
|
||||
documents = await pipeline.prepare_for_indexing(connector_docs)
|
||||
async def _get_llm(s):
|
||||
return await get_user_long_context_llm(s, user_id, search_space_id)
|
||||
|
||||
doc_map = {
|
||||
compute_unique_identifier_hash(cd): cd for cd in connector_docs
|
||||
}
|
||||
|
||||
documents_indexed = 0
|
||||
documents_failed = 0
|
||||
last_heartbeat_time = time.time()
|
||||
|
||||
for document in documents:
|
||||
if on_heartbeat_callback:
|
||||
current_time = time.time()
|
||||
if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS:
|
||||
await on_heartbeat_callback(documents_indexed)
|
||||
last_heartbeat_time = current_time
|
||||
|
||||
connector_doc = doc_map.get(document.unique_identifier_hash)
|
||||
if connector_doc is None:
|
||||
logger.warning(
|
||||
f"No matching ConnectorDocument for document {document.id}, skipping"
|
||||
)
|
||||
documents_failed += 1
|
||||
continue
|
||||
|
||||
try:
|
||||
user_llm = await get_user_long_context_llm(
|
||||
session, user_id, search_space_id
|
||||
)
|
||||
await pipeline.index(document, connector_doc, user_llm)
|
||||
documents_indexed += 1
|
||||
|
||||
if documents_indexed % 10 == 0:
|
||||
logger.info(
|
||||
f"Committing batch: {documents_indexed} Google Calendar events processed so far"
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing Calendar event: {e!s}", exc_info=True)
|
||||
documents_failed += 1
|
||||
continue
|
||||
_, documents_indexed, documents_failed = await pipeline.index_batch_parallel(
|
||||
connector_docs,
|
||||
_get_llm,
|
||||
max_concurrency=3,
|
||||
on_heartbeat=on_heartbeat_callback,
|
||||
heartbeat_interval=HEARTBEAT_INTERVAL_SECONDS,
|
||||
)
|
||||
|
||||
# ── Finalize ──────────────────────────────────────────────────
|
||||
await update_connector_last_indexed(session, connector, update_last_indexed)
|
||||
|
|
|
|||
|
|
@ -5,8 +5,6 @@ Uses the shared IndexingPipelineService for document deduplication,
|
|||
summarization, chunking, and embedding.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import datetime
|
||||
|
||||
|
|
@ -17,10 +15,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
from app.connectors.google_gmail_connector import GoogleGmailConnector
|
||||
from app.db import DocumentType, SearchSourceConnectorType
|
||||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||
from app.indexing_pipeline.document_hashing import (
|
||||
compute_content_hash,
|
||||
compute_unique_identifier_hash,
|
||||
)
|
||||
from app.indexing_pipeline.document_hashing import compute_content_hash
|
||||
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
|
|
@ -336,53 +331,21 @@ async def index_google_gmail_messages(
|
|||
documents_skipped += 1
|
||||
continue
|
||||
|
||||
# ── Pipeline: migrate legacy docs + prepare + index ───────────
|
||||
# ── Pipeline: migrate legacy docs + parallel index ─────────────
|
||||
pipeline = IndexingPipelineService(session)
|
||||
|
||||
await pipeline.migrate_legacy_docs(connector_docs)
|
||||
|
||||
documents = await pipeline.prepare_for_indexing(connector_docs)
|
||||
async def _get_llm(s):
|
||||
return await get_user_long_context_llm(s, user_id, search_space_id)
|
||||
|
||||
doc_map = {
|
||||
compute_unique_identifier_hash(cd): cd for cd in connector_docs
|
||||
}
|
||||
|
||||
documents_indexed = 0
|
||||
documents_failed = 0
|
||||
last_heartbeat_time = time.time()
|
||||
|
||||
for document in documents:
|
||||
if on_heartbeat_callback:
|
||||
current_time = time.time()
|
||||
if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS:
|
||||
await on_heartbeat_callback(documents_indexed)
|
||||
last_heartbeat_time = current_time
|
||||
|
||||
connector_doc = doc_map.get(document.unique_identifier_hash)
|
||||
if connector_doc is None:
|
||||
logger.warning(
|
||||
f"No matching ConnectorDocument for document {document.id}, skipping"
|
||||
)
|
||||
documents_failed += 1
|
||||
continue
|
||||
|
||||
try:
|
||||
user_llm = await get_user_long_context_llm(
|
||||
session, user_id, search_space_id
|
||||
)
|
||||
await pipeline.index(document, connector_doc, user_llm)
|
||||
documents_indexed += 1
|
||||
|
||||
if documents_indexed % 10 == 0:
|
||||
logger.info(
|
||||
f"Committing batch: {documents_indexed} Gmail messages processed so far"
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing Gmail message: {e!s}", exc_info=True)
|
||||
documents_failed += 1
|
||||
continue
|
||||
_, documents_indexed, documents_failed = await pipeline.index_batch_parallel(
|
||||
connector_docs,
|
||||
_get_llm,
|
||||
max_concurrency=3,
|
||||
on_heartbeat=on_heartbeat_callback,
|
||||
heartbeat_interval=HEARTBEAT_INTERVAL_SECONDS,
|
||||
)
|
||||
|
||||
# ── Finalize ──────────────────────────────────────────────────
|
||||
await update_connector_last_indexed(session, connector, update_last_indexed)
|
||||
|
|
|
|||
|
|
@ -25,6 +25,15 @@ def pipeline(mock_session):
|
|||
return IndexingPipelineService(mock_session)
|
||||
|
||||
|
||||
def _make_orm_doc(connector_doc, doc_id):
|
||||
"""Create a MagicMock Document bound to a ConnectorDocument's hash."""
|
||||
doc = MagicMock(spec=Document)
|
||||
doc.id = doc_id
|
||||
doc.unique_identifier_hash = compute_unique_identifier_hash(connector_doc)
|
||||
doc.status = DocumentStatus.pending()
|
||||
return doc
|
||||
|
||||
|
||||
async def test_index_calls_embed_and_chunk_via_to_thread(
|
||||
pipeline, make_connector_document, monkeypatch
|
||||
):
|
||||
|
|
@ -68,3 +77,110 @@ async def test_index_calls_embed_and_chunk_via_to_thread(
|
|||
|
||||
assert "chunk_text" in to_thread_calls
|
||||
assert "embed_texts" in to_thread_calls
|
||||
|
||||
|
||||
def _mock_session_factory(orm_docs_by_id):
|
||||
"""Replace get_celery_session_maker with a two-level callable.
|
||||
|
||||
get_celery_session_maker() -> session_maker
|
||||
session_maker() -> async context manager yielding a mock session
|
||||
"""
|
||||
|
||||
def _get_maker():
|
||||
def _make_session():
|
||||
session = MagicMock()
|
||||
session.get = AsyncMock(
|
||||
side_effect=lambda model, doc_id: orm_docs_by_id.get(doc_id)
|
||||
)
|
||||
ctx = MagicMock()
|
||||
ctx.__aenter__ = AsyncMock(return_value=session)
|
||||
ctx.__aexit__ = AsyncMock(return_value=False)
|
||||
return ctx
|
||||
|
||||
return _make_session
|
||||
|
||||
return _get_maker
|
||||
|
||||
|
||||
async def test_batch_parallel_indexes_all_documents(
|
||||
pipeline, make_connector_document, monkeypatch
|
||||
):
|
||||
"""index_batch_parallel indexes all documents and returns correct counts."""
|
||||
docs = [
|
||||
make_connector_document(
|
||||
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
|
||||
unique_id=f"msg-{i}",
|
||||
search_space_id=1,
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
orm_docs = [_make_orm_doc(cd, doc_id=i + 1) for i, cd in enumerate(docs)]
|
||||
pipeline.prepare_for_indexing = AsyncMock(return_value=orm_docs)
|
||||
|
||||
orm_by_id = {d.id: d for d in orm_docs}
|
||||
monkeypatch.setattr(
|
||||
"app.tasks.celery_tasks.get_celery_session_maker",
|
||||
_mock_session_factory(orm_by_id),
|
||||
)
|
||||
|
||||
index_calls = []
|
||||
|
||||
async def fake_index(self, document, connector_doc, llm):
|
||||
index_calls.append(document.id)
|
||||
document.status = DocumentStatus.ready()
|
||||
return document
|
||||
|
||||
monkeypatch.setattr(IndexingPipelineService, "index", fake_index)
|
||||
|
||||
async def mock_get_llm(session):
|
||||
return MagicMock()
|
||||
|
||||
_, indexed, failed = await pipeline.index_batch_parallel(
|
||||
docs, mock_get_llm, max_concurrency=2
|
||||
)
|
||||
|
||||
assert indexed == 3
|
||||
assert failed == 0
|
||||
assert sorted(index_calls) == [1, 2, 3]
|
||||
|
||||
|
||||
async def test_batch_parallel_one_failure_does_not_affect_others(
|
||||
pipeline, make_connector_document, monkeypatch
|
||||
):
|
||||
"""One document failure doesn't prevent other documents from being indexed."""
|
||||
docs = [
|
||||
make_connector_document(
|
||||
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
|
||||
unique_id=f"msg-{i}",
|
||||
search_space_id=1,
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
orm_docs = [_make_orm_doc(cd, doc_id=i + 1) for i, cd in enumerate(docs)]
|
||||
pipeline.prepare_for_indexing = AsyncMock(return_value=orm_docs)
|
||||
|
||||
orm_by_id = {d.id: d for d in orm_docs}
|
||||
monkeypatch.setattr(
|
||||
"app.tasks.celery_tasks.get_celery_session_maker",
|
||||
_mock_session_factory(orm_by_id),
|
||||
)
|
||||
|
||||
async def failing_index(self, document, connector_doc, llm):
|
||||
if document.id == 2:
|
||||
raise RuntimeError("LLM exploded")
|
||||
document.status = DocumentStatus.ready()
|
||||
return document
|
||||
|
||||
monkeypatch.setattr(IndexingPipelineService, "index", failing_index)
|
||||
|
||||
async def mock_get_llm(session):
|
||||
return MagicMock()
|
||||
|
||||
_, indexed, failed = await pipeline.index_batch_parallel(
|
||||
docs, mock_get_llm, max_concurrency=4
|
||||
)
|
||||
|
||||
assert indexed == 2
|
||||
assert failed == 1
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue