feat: implement parallel file downloading and indexing in Google Drive indexer
- Added a `_download_files_parallel` function to download files from Google Drive concurrently, improving document-processing throughput.
- Introduced a `_download_and_index` function to run the parallel download and indexing phases, streamlining the overall workflow.
- Updated the `_index_full_scan` and `_index_with_delta_sync` methods to use the new parallel downloading, improving performance.
- Added unit tests for the parallel downloading and indexing logic, covering robustness and error handling during document processing.
This commit is contained in:
parent bd6e335cb3
commit c016962064

4 changed files with 652 additions and 35 deletions
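At the core of the change is a standard bounded-concurrency pattern: an asyncio.Semaphore caps how many downloads run at once, while asyncio.gather drives all the tasks and collects per-task exceptions. A minimal, self-contained sketch of that pattern follows; the fetch() stub, the names, and the numbers are illustrative, not the connector's actual API.

import asyncio

# Sketch of the semaphore-plus-gather pattern used by the new downloader.
# fetch() is a hypothetical stand-in for the real Drive download + ETL step.
async def fetch(file_id: str) -> str:
    await asyncio.sleep(0.1)  # simulates network latency
    return f"content of {file_id}"

async def download_all(file_ids: list[str], max_concurrency: int = 5) -> tuple[list[str], int]:
    sem = asyncio.Semaphore(max_concurrency)

    async def one(file_id: str) -> str:
        async with sem:  # at most max_concurrency downloads in flight
            return await fetch(file_id)

    # return_exceptions=True keeps one failed download from cancelling the rest.
    outcomes = await asyncio.gather(*(one(f) for f in file_ids), return_exceptions=True)
    results = [o for o in outcomes if not isinstance(o, BaseException)]
    return results, len(outcomes) - len(results)

# Example: asyncio.run(download_all([f"doc-{i}" for i in range(20)]))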
@@ -5,6 +5,7 @@ checks and rename-only detection. download_and_extract_content()
returns markdown which is fed into ConnectorDocument -> pipeline.
"""

import asyncio
import logging
import time
from collections.abc import Awaitable, Callable
@@ -190,6 +191,68 @@ def _build_connector_doc(
    )


async def _download_files_parallel(
    drive_client: GoogleDriveClient,
    files: list[dict],
    *,
    connector_id: int,
    search_space_id: int,
    user_id: str,
    enable_summary: bool,
    max_concurrency: int = 5,
    on_heartbeat: HeartbeatCallbackType | None = None,
) -> tuple[list[ConnectorDocument], int]:
    """Download and ETL files in parallel, returning ConnectorDocuments.

    Returns (connector_docs, download_failed_count).
    """
    results: list[ConnectorDocument] = []
    sem = asyncio.Semaphore(max_concurrency)
    last_heartbeat = time.time()
    completed_count = 0
    hb_lock = asyncio.Lock()

    async def _download_one(file: dict) -> ConnectorDocument | None:
        nonlocal last_heartbeat, completed_count
        async with sem:
            markdown, drive_metadata, error = await download_and_extract_content(
                drive_client, file
            )
            if error or not markdown:
                return None
            doc = _build_connector_doc(
                file,
                markdown,
                drive_metadata,
                connector_id=connector_id,
                search_space_id=search_space_id,
                user_id=user_id,
                enable_summary=enable_summary,
            )
            async with hb_lock:
                completed_count += 1
                if on_heartbeat:
                    now = time.time()
                    if now - last_heartbeat >= HEARTBEAT_INTERVAL_SECONDS:
                        await on_heartbeat(completed_count)
                        last_heartbeat = now
            return doc

    tasks = [_download_one(f) for f in files]
    outcomes = await asyncio.gather(*tasks, return_exceptions=True)

    failed = 0
    for outcome in outcomes:
        if isinstance(outcome, Exception):
            failed += 1
        elif outcome is None:
            failed += 1
        else:
            results.append(outcome)

    return results, failed


async def _process_single_file(
    drive_client: GoogleDriveClient,
    session: AsyncSession,
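One detail worth calling out in `_download_one` above: heartbeats are throttled to `HEARTBEAT_INTERVAL_SECONDS`, and the shared counter and timestamp are only touched under an asyncio.Lock so concurrent workers cannot race on them. A stripped-down sketch of just that piece, assuming a 30-second interval and a generic async `report` callback (neither is the repository's actual API):

import asyncio
import time

HEARTBEAT_INTERVAL_SECONDS = 30  # assumed value; the real constant lives in the module

class ThrottledHeartbeat:
    """Progress reporter that is safe to tick from many concurrent tasks."""

    def __init__(self, report) -> None:
        self._report = report        # async callback receiving the completed count
        self._lock = asyncio.Lock()  # serializes counter and timestamp updates
        self._completed = 0
        self._last = time.time()

    async def tick(self) -> None:
        async with self._lock:
            self._completed += 1
            now = time.time()
            if now - self._last >= HEARTBEAT_INTERVAL_SECONDS:
                await self._report(self._completed)
                self._last = now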
@@ -283,6 +346,47 @@ async def _remove_document(session: AsyncSession, file_id: str, search_space_id:
    logger.info(f"Removed deleted file document: {file_id}")


async def _download_and_index(
    drive_client: GoogleDriveClient,
    session: AsyncSession,
    files: list[dict],
    *,
    connector_id: int,
    search_space_id: int,
    user_id: str,
    enable_summary: bool,
    on_heartbeat: HeartbeatCallbackType | None = None,
) -> tuple[int, int]:
    """Phase 2+3: parallel download then parallel indexing.

    Returns (batch_indexed, total_failed).
    """
    connector_docs, download_failed = await _download_files_parallel(
        drive_client,
        files,
        connector_id=connector_id,
        search_space_id=search_space_id,
        user_id=user_id,
        enable_summary=enable_summary,
        on_heartbeat=on_heartbeat,
    )

    batch_indexed = 0
    batch_failed = 0
    if connector_docs:
        pipeline = IndexingPipelineService(session)

        async def _get_llm(s):
            return await get_user_long_context_llm(s, user_id, search_space_id)

        _, batch_indexed, batch_failed = await pipeline.index_batch_parallel(
            connector_docs, _get_llm, max_concurrency=3,
            on_heartbeat=on_heartbeat,
        )

    return batch_indexed, download_failed + batch_failed


# ---------------------------------------------------------------------------
# Scan strategies
# ---------------------------------------------------------------------------
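`_download_and_index` is the glue between the scan strategies and the parallel machinery: it folds both failure sources, downloads that errored and documents that failed indexing, into the single count callers see. The phase split it implies is deliberate. Phase 1 stays serial because the skip and rename checks share one SQLAlchemy AsyncSession, which is not safe to use from concurrent tasks, while phases 2 and 3 parallelize only the network- and LLM-bound work. A sketch of that collect-then-batch shape, with all names hypothetical:

import asyncio

# Hypothetical shape of the two-phase scan (all names illustrative).
async def should_skip(session: object, f: dict) -> bool:
    return f.get("unchanged", False)  # stand-in for the DB-backed skip check

async def download_and_index(files: list[dict]) -> tuple[int, int]:
    return len(files), 0              # stand-in for the parallel pipeline

async def scan(session: object, files: list[dict]) -> tuple[int, int]:
    to_download: list[dict] = []
    for f in files:                       # Phase 1: serial, shares one DB session
        if await should_skip(session, f):
            continue
        to_download.append(f)
    # Phase 2+3: parallel download + indexing over the collected batch
    return await download_and_index(to_download)

# Example: asyncio.run(scan(None, [{"id": "a"}, {"id": "b", "unchanged": True}]))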
@@ -310,11 +414,13 @@ async def _index_full_scan(
        {"stage": "full_scan", "folder_id": folder_id, "include_subfolders": include_subfolders},
    )

    indexed = 0
    # ------------------------------------------------------------------
    # Phase 1 (serial): collect files, run skip checks, track renames
    # ------------------------------------------------------------------
    renamed_count = 0
    skipped = 0
    failed = 0
    files_processed = 0
    last_heartbeat = time.time()
    files_to_download: list[dict] = []
    folders_to_process = [(folder_id, folder_name)]
    first_error: str | None = None
@@ -346,22 +452,15 @@ async def _index_full_scan(

            files_processed += 1

            if on_heartbeat_callback:
                now = time.time()
                if now - last_heartbeat >= HEARTBEAT_INTERVAL_SECONDS:
                    await on_heartbeat_callback(indexed)
                    last_heartbeat = now
            skip, msg = await _should_skip_file(session, file, search_space_id)
            if skip:
                if msg and "renamed" in msg.lower():
                    renamed_count += 1
                else:
                    skipped += 1
                continue

            i, s, f = await _process_single_file(
                drive_client, session, file,
                connector_id, search_space_id, user_id, enable_summary,
            )
            indexed += i
            skipped += s
            failed += f

            if indexed > 0 and indexed % 10 == 0:
                await session.commit()
            files_to_download.append(file)

        page_token = next_token
        if not page_token:
@@ -375,6 +474,17 @@ async def _index_full_scan(
        )
        raise Exception(f"Failed to list Google Drive files: {first_error}")

    # ------------------------------------------------------------------
    # Phase 2+3 (parallel): download, ETL, index
    # ------------------------------------------------------------------
    batch_indexed, failed = await _download_and_index(
        drive_client, session, files_to_download,
        connector_id=connector_id, search_space_id=search_space_id,
        user_id=user_id, enable_summary=enable_summary,
        on_heartbeat=on_heartbeat_callback,
    )

    indexed = renamed_count + batch_indexed
    logger.info(f"Full scan complete: {indexed} indexed, {skipped} skipped, {failed} failed")
    return indexed, skipped
@@ -416,11 +526,14 @@ async def _index_with_delta_sync(
        return 0, 0

    logger.info(f"Processing {len(changes)} changes")
    indexed = 0

    # ------------------------------------------------------------------
    # Phase 1 (serial): handle removals, collect files for download
    # ------------------------------------------------------------------
    renamed_count = 0
    skipped = 0
    failed = 0
    files_to_download: list[dict] = []
    files_processed = 0
    last_heartbeat = time.time()

    for change in changes:
        if files_processed >= max_files:
@@ -438,23 +551,27 @@ async def _index_with_delta_sync(
        if not file:
            continue

        if on_heartbeat_callback:
            now = time.time()
            if now - last_heartbeat >= HEARTBEAT_INTERVAL_SECONDS:
                await on_heartbeat_callback(indexed)
                last_heartbeat = now
        skip, msg = await _should_skip_file(session, file, search_space_id)
        if skip:
            if msg and "renamed" in msg.lower():
                renamed_count += 1
            else:
                skipped += 1
            continue

        i, s, f = await _process_single_file(
            drive_client, session, file,
            connector_id, search_space_id, user_id, enable_summary,
        )
        indexed += i
        skipped += s
        failed += f
        files_to_download.append(file)

        if indexed > 0 and indexed % 10 == 0:
            await session.commit()

    # ------------------------------------------------------------------
    # Phase 2+3 (parallel): download, ETL, index
    # ------------------------------------------------------------------
    batch_indexed, failed = await _download_and_index(
        drive_client, session, files_to_download,
        connector_id=connector_id, search_space_id=search_space_id,
        user_id=user_id, enable_summary=enable_summary,
        on_heartbeat=on_heartbeat_callback,
    )

    indexed = renamed_count + batch_indexed
    logger.info(f"Delta sync complete: {indexed} indexed, {skipped} skipped, {failed} failed")
    return indexed, skipped
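The commit message also mentions unit tests; they account for part of the 652 added lines but are not shown in the hunks above. Purely as an illustration of how the failure accounting can be exercised without touching Google Drive, here is a pytest-style sketch. It tests the semaphore-plus-gather pattern from the first sketch, not the repository's actual test suite:

import asyncio

def test_failed_downloads_are_counted():
    async def run() -> None:
        # Flaky stand-in: every odd-numbered "file" fails to download.
        async def flaky_fetch(file_id: str) -> str:
            if int(file_id) % 2:
                raise RuntimeError("download failed")
            return f"content of {file_id}"

        sem = asyncio.Semaphore(5)

        async def one(file_id: str) -> str:
            async with sem:
                return await flaky_fetch(file_id)

        outcomes = await asyncio.gather(
            *(one(str(i)) for i in range(10)), return_exceptions=True
        )
        results = [o for o in outcomes if not isinstance(o, BaseException)]
        assert len(results) == 5                  # even ids succeeded
        assert len(outcomes) - len(results) == 5  # odd ids counted as failures

    asyncio.run(run())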