feat: implement batch indexing for selected Google Drive files

- Introduced `index_google_drive_selected_files` function to enable indexing of multiple user-selected files in parallel, improving efficiency.
- Refactored existing indexing logic to handle batch processing, including error handling for individual file failures.
- Added unit tests for the new batch indexing functionality, ensuring robustness and proper error collection during the indexing process.
This commit is contained in:
Anish Sarkar 2026-03-27 00:17:07 +05:30
parent 2f30e48e90
commit 7c7f8b216c
3 changed files with 276 additions and 12 deletions

View file

@ -387,6 +387,56 @@ async def _download_and_index(
return batch_indexed, download_failed + batch_failed
async def _index_selected_files(
drive_client: GoogleDriveClient,
session: AsyncSession,
file_ids: list[tuple[str, str | None]],
*,
connector_id: int,
search_space_id: int,
user_id: str,
enable_summary: bool,
on_heartbeat: HeartbeatCallbackType | None = None,
) -> tuple[int, int, list[str]]:
"""Index user-selected files using the parallel pipeline.
Phase 1 (serial): fetch metadata + skip checks.
Phase 2+3 (parallel): download, ETL, index via _download_and_index.
Returns (indexed_count, skipped_count, errors).
"""
files_to_download: list[dict] = []
errors: list[str] = []
renamed_count = 0
skipped = 0
for file_id, file_name in file_ids:
file, error = await get_file_by_id(drive_client, file_id)
if error or not file:
display = file_name or file_id
errors.append(f"File '{display}': {error or 'File not found'}")
continue
skip, msg = await _should_skip_file(session, file, search_space_id)
if skip:
if msg and "renamed" in msg.lower():
renamed_count += 1
else:
skipped += 1
continue
files_to_download.append(file)
batch_indexed, failed = await _download_and_index(
drive_client, session, files_to_download,
connector_id=connector_id, search_space_id=search_space_id,
user_id=user_id, enable_summary=enable_summary,
on_heartbeat=on_heartbeat,
)
return renamed_count + batch_indexed, skipped, errors
# ---------------------------------------------------------------------------
# Scan strategies
# ---------------------------------------------------------------------------
@ -803,3 +853,97 @@ async def index_google_drive_single_file(
await task_logger.log_task_failure(log_entry, "Failed to index Google Drive file", str(e), {"error_type": type(e).__name__})
logger.error(f"Failed to index Google Drive file: {e!s}", exc_info=True)
return 0, f"Failed to index Google Drive file: {e!s}"
async def index_google_drive_selected_files(
session: AsyncSession,
connector_id: int,
search_space_id: int,
user_id: str,
files: list[tuple[str, str | None]],
on_heartbeat_callback: HeartbeatCallbackType | None = None,
) -> tuple[int, int, list[str]]:
"""Index multiple user-selected Google Drive files in parallel.
Sets up the connector/credentials once, then delegates to
_index_selected_files for the three-phase parallel pipeline.
Returns (indexed_count, skipped_count, errors).
"""
task_logger = TaskLoggingService(session, search_space_id)
log_entry = await task_logger.log_task_start(
task_name="google_drive_selected_files_indexing",
source="connector_indexing_task",
message=f"Starting Google Drive batch file indexing for {len(files)} files",
metadata={"connector_id": connector_id, "user_id": str(user_id), "file_count": len(files)},
)
try:
connector = None
for ct in ACCEPTED_DRIVE_CONNECTOR_TYPES:
connector = await get_connector_by_id(session, connector_id, ct)
if connector:
break
if not connector:
error_msg = f"Google Drive connector with ID {connector_id} not found"
await task_logger.log_task_failure(log_entry, error_msg, None, {"error_type": "ConnectorNotFound"})
return 0, 0, [error_msg]
pre_built_credentials = None
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
connected_account_id = connector.config.get("composio_connected_account_id")
if not connected_account_id:
error_msg = f"Composio connected_account_id not found for connector {connector_id}"
await task_logger.log_task_failure(log_entry, error_msg, "Missing Composio account", {"error_type": "MissingComposioAccount"})
return 0, 0, [error_msg]
pre_built_credentials = build_composio_credentials(connected_account_id)
else:
token_encrypted = connector.config.get("_token_encrypted", False)
if token_encrypted and not config.SECRET_KEY:
error_msg = "SECRET_KEY not configured but credentials are marked as encrypted"
await task_logger.log_task_failure(
log_entry, error_msg, "Missing SECRET_KEY", {"error_type": "MissingSecretKey"},
)
return 0, 0, [error_msg]
connector_enable_summary = getattr(connector, "enable_summary", True)
drive_client = GoogleDriveClient(session, connector_id, credentials=pre_built_credentials)
indexed, skipped, errors = await _index_selected_files(
drive_client, session, files,
connector_id=connector_id, search_space_id=search_space_id,
user_id=user_id, enable_summary=connector_enable_summary,
on_heartbeat=on_heartbeat_callback,
)
await session.commit()
if errors:
await task_logger.log_task_failure(
log_entry,
f"Batch file indexing completed with {len(errors)} error(s)",
"; ".join(errors),
{"indexed": indexed, "skipped": skipped, "error_count": len(errors)},
)
else:
await task_logger.log_task_success(
log_entry,
f"Successfully indexed {indexed} files ({skipped} skipped)",
{"indexed": indexed, "skipped": skipped},
)
logger.info(f"Selected files indexing: {indexed} indexed, {skipped} skipped, {len(errors)} errors")
return indexed, skipped, errors
except SQLAlchemyError as db_error:
await session.rollback()
error_msg = f"Database error: {db_error!s}"
await task_logger.log_task_failure(log_entry, error_msg, str(db_error), {"error_type": "SQLAlchemyError"})
logger.error(error_msg, exc_info=True)
return 0, 0, [error_msg]
except Exception as e:
await session.rollback()
error_msg = f"Failed to index Google Drive files: {e!s}"
await task_logger.log_task_failure(log_entry, error_msg, str(e), {"error_type": type(e).__name__})
logger.error(error_msg, exc_info=True)
return 0, 0, [error_msg]