feat: implement parallel file downloading and indexing in Google Drive indexer

- Added a `_download_files_parallel` helper that downloads and extracts files from Google Drive concurrently, bounded by a semaphore (default of 5 downloads in flight), instead of one file at a time; a minimal sketch of the pattern follows this list.
- Introduced `_download_and_index` to run the parallel download phase and then feed the resulting `ConnectorDocument`s to `IndexingPipelineService.index_batch_parallel`.
- Updated `_index_full_scan` and `_index_with_delta_sync` to collect candidate files in a serial phase (skip checks, rename-only detection, and deletion handling in delta sync) and hand them to the new parallel download/index phase.
- Added unit tests validating the new parallel download and indexing logic, including error handling and failure counting during document processing.
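At its core, `_download_files_parallel` is a semaphore-bounded `asyncio.gather` with per-task failure accounting. The sketch below shows that pattern in isolation; `download_all` and its fake `download_one` body are placeholders rather than the indexer's real API, and the real helper additionally builds `ConnectorDocument`s and throttles heartbeat callbacks.

```python
import asyncio


async def download_all(
    files: list[str], *, max_concurrency: int = 5
) -> tuple[list[str], int]:
    """Download every file with at most `max_concurrency` in flight.

    Returns (results, failed_count), mirroring the
    (connector_docs, download_failed_count) shape described above.
    """
    sem = asyncio.Semaphore(max_concurrency)

    async def download_one(name: str) -> str | None:
        async with sem:  # bound concurrency so the Drive API is not hammered
            await asyncio.sleep(0.1)  # stand-in for the real download/ETL call
            if name.endswith(".bad"):
                return None  # extraction failures come back as "no document"
            return f"markdown for {name}"

    outcomes = await asyncio.gather(
        *(download_one(f) for f in files), return_exceptions=True
    )

    results: list[str] = []
    failed = 0
    for outcome in outcomes:
        # Exceptions and empty results both count as download failures.
        if isinstance(outcome, Exception) or outcome is None:
            failed += 1
        else:
            results.append(outcome)
    return results, failed


if __name__ == "__main__":
    docs, failed = asyncio.run(download_all(["a.txt", "b.bad", "c.txt"]))
    print(f"{len(docs)} downloaded, {failed} failed")
```

Passing `return_exceptions=True` to `gather` means one failed download cannot cancel the rest of the batch: exceptions come back as values and are folded into the failure count, matching how the new helper reports `download_failed_count`.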
Anish Sarkar 2026-03-26 23:53:26 +05:30
parent bd6e335cb3
commit c016962064
4 changed files with 652 additions and 35 deletions


@@ -5,6 +5,7 @@ checks and rename-only detection. download_and_extract_content()
returns markdown which is fed into ConnectorDocument -> pipeline.
"""
import asyncio
import logging
import time
from collections.abc import Awaitable, Callable
@@ -190,6 +191,68 @@ def _build_connector_doc(
)
async def _download_files_parallel(
drive_client: GoogleDriveClient,
files: list[dict],
*,
connector_id: int,
search_space_id: int,
user_id: str,
enable_summary: bool,
max_concurrency: int = 5,
on_heartbeat: HeartbeatCallbackType | None = None,
) -> tuple[list[ConnectorDocument], int]:
"""Download and ETL files in parallel, returning ConnectorDocuments.
Returns (connector_docs, download_failed_count).
"""
results: list[ConnectorDocument] = []
sem = asyncio.Semaphore(max_concurrency)
last_heartbeat = time.time()
completed_count = 0
hb_lock = asyncio.Lock()
async def _download_one(file: dict) -> ConnectorDocument | None:
nonlocal last_heartbeat, completed_count
async with sem:
markdown, drive_metadata, error = await download_and_extract_content(
drive_client, file
)
if error or not markdown:
return None
doc = _build_connector_doc(
file,
markdown,
drive_metadata,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
enable_summary=enable_summary,
)
async with hb_lock:
completed_count += 1
if on_heartbeat:
now = time.time()
if now - last_heartbeat >= HEARTBEAT_INTERVAL_SECONDS:
await on_heartbeat(completed_count)
last_heartbeat = now
return doc
tasks = [_download_one(f) for f in files]
outcomes = await asyncio.gather(*tasks, return_exceptions=True)
failed = 0
for outcome in outcomes:
if isinstance(outcome, Exception):
failed += 1
elif outcome is None:
failed += 1
else:
results.append(outcome)
return results, failed
async def _process_single_file(
drive_client: GoogleDriveClient,
session: AsyncSession,
@@ -283,6 +346,47 @@ async def _remove_document(session: AsyncSession, file_id: str, search_space_id:
logger.info(f"Removed deleted file document: {file_id}")
async def _download_and_index(
drive_client: GoogleDriveClient,
session: AsyncSession,
files: list[dict],
*,
connector_id: int,
search_space_id: int,
user_id: str,
enable_summary: bool,
on_heartbeat: HeartbeatCallbackType | None = None,
) -> tuple[int, int]:
"""Phase 2+3: parallel download then parallel indexing.
Returns (batch_indexed, total_failed).
"""
connector_docs, download_failed = await _download_files_parallel(
drive_client,
files,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
enable_summary=enable_summary,
on_heartbeat=on_heartbeat,
)
batch_indexed = 0
batch_failed = 0
if connector_docs:
pipeline = IndexingPipelineService(session)
async def _get_llm(s):
return await get_user_long_context_llm(s, user_id, search_space_id)
_, batch_indexed, batch_failed = await pipeline.index_batch_parallel(
connector_docs, _get_llm, max_concurrency=3,
on_heartbeat=on_heartbeat,
)
return batch_indexed, download_failed + batch_failed
# ---------------------------------------------------------------------------
# Scan strategies
# ---------------------------------------------------------------------------
@@ -310,11 +414,13 @@ async def _index_full_scan(
{"stage": "full_scan", "folder_id": folder_id, "include_subfolders": include_subfolders},
)
indexed = 0
# ------------------------------------------------------------------
# Phase 1 (serial): collect files, run skip checks, track renames
# ------------------------------------------------------------------
renamed_count = 0
skipped = 0
failed = 0
files_processed = 0
last_heartbeat = time.time()
files_to_download: list[dict] = []
folders_to_process = [(folder_id, folder_name)]
first_error: str | None = None
@@ -346,22 +452,15 @@ async def _index_full_scan(
files_processed += 1
if on_heartbeat_callback:
now = time.time()
if now - last_heartbeat >= HEARTBEAT_INTERVAL_SECONDS:
await on_heartbeat_callback(indexed)
last_heartbeat = now
skip, msg = await _should_skip_file(session, file, search_space_id)
if skip:
if msg and "renamed" in msg.lower():
renamed_count += 1
else:
skipped += 1
continue
i, s, f = await _process_single_file(
drive_client, session, file,
connector_id, search_space_id, user_id, enable_summary,
)
indexed += i
skipped += s
failed += f
if indexed > 0 and indexed % 10 == 0:
await session.commit()
files_to_download.append(file)
page_token = next_token
if not page_token:
@@ -375,6 +474,17 @@
)
raise Exception(f"Failed to list Google Drive files: {first_error}")
# ------------------------------------------------------------------
# Phase 2+3 (parallel): download, ETL, index
# ------------------------------------------------------------------
batch_indexed, failed = await _download_and_index(
drive_client, session, files_to_download,
connector_id=connector_id, search_space_id=search_space_id,
user_id=user_id, enable_summary=enable_summary,
on_heartbeat=on_heartbeat_callback,
)
indexed = renamed_count + batch_indexed
logger.info(f"Full scan complete: {indexed} indexed, {skipped} skipped, {failed} failed")
return indexed, skipped
@@ -416,11 +526,14 @@ async def _index_with_delta_sync(
return 0, 0
logger.info(f"Processing {len(changes)} changes")
indexed = 0
# ------------------------------------------------------------------
# Phase 1 (serial): handle removals, collect files for download
# ------------------------------------------------------------------
renamed_count = 0
skipped = 0
failed = 0
files_to_download: list[dict] = []
files_processed = 0
last_heartbeat = time.time()
for change in changes:
if files_processed >= max_files:
@@ -438,23 +551,27 @@
if not file:
continue
if on_heartbeat_callback:
now = time.time()
if now - last_heartbeat >= HEARTBEAT_INTERVAL_SECONDS:
await on_heartbeat_callback(indexed)
last_heartbeat = now
skip, msg = await _should_skip_file(session, file, search_space_id)
if skip:
if msg and "renamed" in msg.lower():
renamed_count += 1
else:
skipped += 1
continue
i, s, f = await _process_single_file(
drive_client, session, file,
connector_id, search_space_id, user_id, enable_summary,
)
indexed += i
skipped += s
failed += f
files_to_download.append(file)
if indexed > 0 and indexed % 10 == 0:
await session.commit()
# ------------------------------------------------------------------
# Phase 2+3 (parallel): download, ETL, index
# ------------------------------------------------------------------
batch_indexed, failed = await _download_and_index(
drive_client, session, files_to_download,
connector_id=connector_id, search_space_id=search_space_id,
user_id=user_id, enable_summary=enable_summary,
on_heartbeat=on_heartbeat_callback,
)
indexed = renamed_count + batch_indexed
logger.info(f"Delta sync complete: {indexed} indexed, {skipped} skipped, {failed} failed")
return indexed, skipped
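
Both the download phase and the indexing phase report progress through the same throttled heartbeat: a shared completion counter is updated under an `asyncio.Lock`, and the callback fires only when the configured interval has elapsed since the last report. Below is a standalone sketch of that throttling under assumed names (`HEARTBEAT_INTERVAL_SECONDS` with a deliberately small demo value, a stand-in work coroutine); it illustrates the pattern rather than reproducing the indexer's code.

```python
import asyncio
import time

HEARTBEAT_INTERVAL_SECONDS = 0.05  # assumed demo value; the indexer defines its own constant


async def run_with_heartbeat(n_tasks: int, on_heartbeat) -> int:
    """Run n_tasks concurrently, reporting progress at most once per interval."""
    completed = 0
    last_heartbeat = time.time()
    lock = asyncio.Lock()

    async def work(i: int) -> None:
        nonlocal completed, last_heartbeat
        await asyncio.sleep(0.01 * i)  # stand-in for a real download or index call
        async with lock:  # serialize the check-and-report step
            completed += 1
            now = time.time()
            if now - last_heartbeat >= HEARTBEAT_INTERVAL_SECONDS:
                await on_heartbeat(completed)
                last_heartbeat = now

    await asyncio.gather(*(work(i) for i in range(n_tasks)))
    return completed


async def print_progress(count: int) -> None:
    print(f"processed {count} items so far")


if __name__ == "__main__":
    print(asyncio.run(run_with_heartbeat(20, print_progress)))
```

The lock serializes the check-and-report step, so two coroutines finishing at nearly the same moment cannot both see an expired interval and fire the callback twice.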