diff --git a/surfsense_backend/app/routes/search_source_connectors_routes.py b/surfsense_backend/app/routes/search_source_connectors_routes.py
index 1ffc6341f..bef2329d8 100644
--- a/surfsense_backend/app/routes/search_source_connectors_routes.py
+++ b/surfsense_backend/app/routes/search_source_connectors_routes.py
@@ -2329,7 +2329,7 @@ async def run_google_drive_indexing(
     try:
         from app.tasks.connector_indexers.google_drive_indexer import (
             index_google_drive_files,
-            index_google_drive_single_file,
+            index_google_drive_selected_files,
         )

         # Parse the structured data
@@ -2402,25 +2402,23 @@ async def run_google_drive_indexing(
                         exc_info=True,
                     )

-            # Index each individual file
-            for file in items.files:
+            # Index all selected files together via the parallel pipeline
+            if items.files:
                 try:
-                    indexed_count, error_message = await index_google_drive_single_file(
+                    file_tuples = [(f.id, f.name) for f in items.files]
+                    indexed_count, _skipped, file_errors = await index_google_drive_selected_files(
                         session,
                         connector_id,
                         search_space_id,
                         user_id,
-                        file_id=file.id,
-                        file_name=file.name,
+                        files=file_tuples,
                     )
-                    if error_message:
-                        errors.append(f"File '{file.name}': {error_message}")
-                    else:
-                        total_indexed += indexed_count
+                    total_indexed += indexed_count
+                    errors.extend(file_errors)
                 except Exception as e:
-                    errors.append(f"File '{file.name}': {e!s}")
+                    errors.append(f"File batch indexing: {e!s}")
                     logger.error(
-                        f"Error indexing file {file.name} ({file.id}): {e}",
+                        f"Error batch indexing files: {e}",
                         exc_info=True,
                     )

diff --git a/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py
index 8ba08533f..2d3139343 100644
--- a/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py
+++ b/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py
@@ -387,6 +387,56 @@ async def _download_and_index(
     return batch_indexed, download_failed + batch_failed


+async def _index_selected_files(
+    drive_client: GoogleDriveClient,
+    session: AsyncSession,
+    file_ids: list[tuple[str, str | None]],
+    *,
+    connector_id: int,
+    search_space_id: int,
+    user_id: str,
+    enable_summary: bool,
+    on_heartbeat: HeartbeatCallbackType | None = None,
+) -> tuple[int, int, list[str]]:
+    """Index user-selected files using the parallel pipeline.
+
+    Phase 1 (serial): fetch metadata + skip checks.
+    Phase 2+3 (parallel): download, ETL, index via _download_and_index.
+
+    Returns (indexed_count, skipped_count, errors).
+ """ + files_to_download: list[dict] = [] + errors: list[str] = [] + renamed_count = 0 + skipped = 0 + + for file_id, file_name in file_ids: + file, error = await get_file_by_id(drive_client, file_id) + if error or not file: + display = file_name or file_id + errors.append(f"File '{display}': {error or 'File not found'}") + continue + + skip, msg = await _should_skip_file(session, file, search_space_id) + if skip: + if msg and "renamed" in msg.lower(): + renamed_count += 1 + else: + skipped += 1 + continue + + files_to_download.append(file) + + batch_indexed, failed = await _download_and_index( + drive_client, session, files_to_download, + connector_id=connector_id, search_space_id=search_space_id, + user_id=user_id, enable_summary=enable_summary, + on_heartbeat=on_heartbeat, + ) + + return renamed_count + batch_indexed, skipped, errors + + # --------------------------------------------------------------------------- # Scan strategies # --------------------------------------------------------------------------- @@ -803,3 +853,97 @@ async def index_google_drive_single_file( await task_logger.log_task_failure(log_entry, "Failed to index Google Drive file", str(e), {"error_type": type(e).__name__}) logger.error(f"Failed to index Google Drive file: {e!s}", exc_info=True) return 0, f"Failed to index Google Drive file: {e!s}" + + +async def index_google_drive_selected_files( + session: AsyncSession, + connector_id: int, + search_space_id: int, + user_id: str, + files: list[tuple[str, str | None]], + on_heartbeat_callback: HeartbeatCallbackType | None = None, +) -> tuple[int, int, list[str]]: + """Index multiple user-selected Google Drive files in parallel. + + Sets up the connector/credentials once, then delegates to + _index_selected_files for the three-phase parallel pipeline. + + Returns (indexed_count, skipped_count, errors). 
+ """ + task_logger = TaskLoggingService(session, search_space_id) + log_entry = await task_logger.log_task_start( + task_name="google_drive_selected_files_indexing", + source="connector_indexing_task", + message=f"Starting Google Drive batch file indexing for {len(files)} files", + metadata={"connector_id": connector_id, "user_id": str(user_id), "file_count": len(files)}, + ) + + try: + connector = None + for ct in ACCEPTED_DRIVE_CONNECTOR_TYPES: + connector = await get_connector_by_id(session, connector_id, ct) + if connector: + break + if not connector: + error_msg = f"Google Drive connector with ID {connector_id} not found" + await task_logger.log_task_failure(log_entry, error_msg, None, {"error_type": "ConnectorNotFound"}) + return 0, 0, [error_msg] + + pre_built_credentials = None + if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES: + connected_account_id = connector.config.get("composio_connected_account_id") + if not connected_account_id: + error_msg = f"Composio connected_account_id not found for connector {connector_id}" + await task_logger.log_task_failure(log_entry, error_msg, "Missing Composio account", {"error_type": "MissingComposioAccount"}) + return 0, 0, [error_msg] + pre_built_credentials = build_composio_credentials(connected_account_id) + else: + token_encrypted = connector.config.get("_token_encrypted", False) + if token_encrypted and not config.SECRET_KEY: + error_msg = "SECRET_KEY not configured but credentials are marked as encrypted" + await task_logger.log_task_failure( + log_entry, error_msg, "Missing SECRET_KEY", {"error_type": "MissingSecretKey"}, + ) + return 0, 0, [error_msg] + + connector_enable_summary = getattr(connector, "enable_summary", True) + drive_client = GoogleDriveClient(session, connector_id, credentials=pre_built_credentials) + + indexed, skipped, errors = await _index_selected_files( + drive_client, session, files, + connector_id=connector_id, search_space_id=search_space_id, + user_id=user_id, enable_summary=connector_enable_summary, + on_heartbeat=on_heartbeat_callback, + ) + + await session.commit() + + if errors: + await task_logger.log_task_failure( + log_entry, + f"Batch file indexing completed with {len(errors)} error(s)", + "; ".join(errors), + {"indexed": indexed, "skipped": skipped, "error_count": len(errors)}, + ) + else: + await task_logger.log_task_success( + log_entry, + f"Successfully indexed {indexed} files ({skipped} skipped)", + {"indexed": indexed, "skipped": skipped}, + ) + + logger.info(f"Selected files indexing: {indexed} indexed, {skipped} skipped, {len(errors)} errors") + return indexed, skipped, errors + + except SQLAlchemyError as db_error: + await session.rollback() + error_msg = f"Database error: {db_error!s}" + await task_logger.log_task_failure(log_entry, error_msg, str(db_error), {"error_type": "SQLAlchemyError"}) + logger.error(error_msg, exc_info=True) + return 0, 0, [error_msg] + except Exception as e: + await session.rollback() + error_msg = f"Failed to index Google Drive files: {e!s}" + await task_logger.log_task_failure(log_entry, error_msg, str(e), {"error_type": type(e).__name__}) + logger.error(error_msg, exc_info=True) + return 0, 0, [error_msg] diff --git a/surfsense_backend/tests/unit/connector_indexers/test_google_drive_parallel.py b/surfsense_backend/tests/unit/connector_indexers/test_google_drive_parallel.py index 22e900406..1183efa9f 100644 --- a/surfsense_backend/tests/unit/connector_indexers/test_google_drive_parallel.py +++ 
@@ -8,6 +8,7 @@ import pytest

 from app.tasks.connector_indexers.google_drive_indexer import (
     _download_files_parallel,
     _index_full_scan,
+    _index_selected_files,
     _index_with_delta_sync,
 )
@@ -464,3 +465,124 @@

     assert indexed == 2
     assert skipped == 0
+
+
+# ---------------------------------------------------------------------------
+# _index_selected_files -- parallel indexing of user-selected files
+# ---------------------------------------------------------------------------
+
+@pytest.fixture
+def selected_files_mocks(mock_drive_client, monkeypatch):
+    """Wire up mocks for _index_selected_files tests."""
+    import app.tasks.connector_indexers.google_drive_indexer as _mod
+
+    mock_session = AsyncMock()
+
+    get_file_results: dict[str, tuple[dict | None, str | None]] = {}
+
+    async def _fake_get_file(client, file_id):
+        return get_file_results.get(file_id, (None, f"Not configured: {file_id}"))
+
+    monkeypatch.setattr(_mod, "get_file_by_id", _fake_get_file)
+
+    skip_results: dict[str, tuple[bool, str | None]] = {}
+
+    async def _fake_skip(session, file, search_space_id):
+        return skip_results.get(file["id"], (False, None))
+
+    monkeypatch.setattr(_mod, "_should_skip_file", _fake_skip)
+
+    download_and_index_mock = AsyncMock(return_value=(0, 0))
+    monkeypatch.setattr(_mod, "_download_and_index", download_and_index_mock)
+
+    return {
+        "drive_client": mock_drive_client,
+        "session": mock_session,
+        "get_file_results": get_file_results,
+        "skip_results": skip_results,
+        "download_and_index_mock": download_and_index_mock,
+    }
+
+
+async def _run_selected(mocks, file_ids):
+    return await _index_selected_files(
+        mocks["drive_client"],
+        mocks["session"],
+        file_ids,
+        connector_id=_CONNECTOR_ID,
+        search_space_id=_SEARCH_SPACE_ID,
+        user_id=_USER_ID,
+        enable_summary=True,
+    )
+
+
+async def test_selected_files_single_file_indexed(selected_files_mocks):
+    """Tracer bullet: one file fetched, not skipped, indexed via parallel pipeline."""
+    selected_files_mocks["get_file_results"]["f1"] = (
+        _make_file_dict("f1", "report.pdf"),
+        None,
+    )
+    selected_files_mocks["download_and_index_mock"].return_value = (1, 0)
+
+    indexed, skipped, errors = await _run_selected(
+        selected_files_mocks, [("f1", "report.pdf")],
+    )
+
+    assert indexed == 1
+    assert skipped == 0
+    assert errors == []
+    selected_files_mocks["download_and_index_mock"].assert_called_once()
+
+
+async def test_selected_files_fetch_failure_isolation(selected_files_mocks):
+    """get_file_by_id failing for one file collects an error; others still indexed."""
+    selected_files_mocks["get_file_results"]["f1"] = (
+        _make_file_dict("f1", "first.txt"), None,
+    )
+    selected_files_mocks["get_file_results"]["f2"] = (None, "HTTP 404")
+    selected_files_mocks["get_file_results"]["f3"] = (
+        _make_file_dict("f3", "third.txt"), None,
+    )
+    selected_files_mocks["download_and_index_mock"].return_value = (2, 0)
+
+    indexed, skipped, errors = await _run_selected(
+        selected_files_mocks,
+        [("f1", "first.txt"), ("f2", "mid.txt"), ("f3", "third.txt")],
+    )
+
+    assert indexed == 2
+    assert skipped == 0
+    assert len(errors) == 1
+    assert "mid.txt" in errors[0]
+    assert "HTTP 404" in errors[0]
+
+
+async def test_selected_files_skip_rename_counting(selected_files_mocks):
+    """Unchanged files are skipped, renames counted as indexed,
+    and only new files are sent to _download_and_index."""
+    for fid, fname in [("s1", "unchanged.txt"), ("r1", "renamed.txt"),
"unchanged.txt"), ("r1", "renamed.txt"), + ("n1", "new1.txt"), ("n2", "new2.txt")]: + selected_files_mocks["get_file_results"][fid] = ( + _make_file_dict(fid, fname), None, + ) + + selected_files_mocks["skip_results"]["s1"] = (True, "unchanged") + selected_files_mocks["skip_results"]["r1"] = (True, "File renamed: 'old' \u2192 'renamed.txt'") + + selected_files_mocks["download_and_index_mock"].return_value = (2, 0) + + indexed, skipped, errors = await _run_selected( + selected_files_mocks, + [("s1", "unchanged.txt"), ("r1", "renamed.txt"), + ("n1", "new1.txt"), ("n2", "new2.txt")], + ) + + assert indexed == 3 # 1 renamed + 2 batch + assert skipped == 1 # 1 unchanged + assert errors == [] + + mock = selected_files_mocks["download_and_index_mock"] + mock.assert_called_once() + call_files = mock.call_args[1].get("files") if "files" in (mock.call_args[1] or {}) else mock.call_args[0][2] + assert len(call_files) == 2 + assert {f["id"] for f in call_files} == {"n1", "n2"}