mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-08 20:25:19 +02:00
feat: implement batch indexing for selected Google Drive files
- Introduced `index_google_drive_selected_files` function to enable indexing of multiple user-selected files in parallel, improving efficiency. - Refactored existing indexing logic to handle batch processing, including error handling for individual file failures. - Added unit tests for the new batch indexing functionality, ensuring robustness and proper error collection during the indexing process.
This commit is contained in:
parent
2f30e48e90
commit
7c7f8b216c
3 changed files with 276 additions and 12 deletions
|
|
@ -2329,7 +2329,7 @@ async def run_google_drive_indexing(
|
|||
try:
|
||||
from app.tasks.connector_indexers.google_drive_indexer import (
|
||||
index_google_drive_files,
|
||||
index_google_drive_single_file,
|
||||
index_google_drive_selected_files,
|
||||
)
|
||||
|
||||
# Parse the structured data
|
||||
|
|
@ -2402,25 +2402,23 @@ async def run_google_drive_indexing(
|
|||
exc_info=True,
|
||||
)
|
||||
|
||||
# Index each individual file
|
||||
for file in items.files:
|
||||
# Index all selected files together via the parallel pipeline
|
||||
if items.files:
|
||||
try:
|
||||
indexed_count, error_message = await index_google_drive_single_file(
|
||||
file_tuples = [(f.id, f.name) for f in items.files]
|
||||
indexed_count, _skipped, file_errors = await index_google_drive_selected_files(
|
||||
session,
|
||||
connector_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
file_id=file.id,
|
||||
file_name=file.name,
|
||||
files=file_tuples,
|
||||
)
|
||||
if error_message:
|
||||
errors.append(f"File '{file.name}': {error_message}")
|
||||
else:
|
||||
total_indexed += indexed_count
|
||||
total_indexed += indexed_count
|
||||
errors.extend(file_errors)
|
||||
except Exception as e:
|
||||
errors.append(f"File '{file.name}': {e!s}")
|
||||
errors.append(f"File batch indexing: {e!s}")
|
||||
logger.error(
|
||||
f"Error indexing file {file.name} ({file.id}): {e}",
|
||||
f"Error batch indexing files: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -387,6 +387,56 @@ async def _download_and_index(
|
|||
return batch_indexed, download_failed + batch_failed
|
||||
|
||||
|
||||
async def _index_selected_files(
|
||||
drive_client: GoogleDriveClient,
|
||||
session: AsyncSession,
|
||||
file_ids: list[tuple[str, str | None]],
|
||||
*,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
enable_summary: bool,
|
||||
on_heartbeat: HeartbeatCallbackType | None = None,
|
||||
) -> tuple[int, int, list[str]]:
|
||||
"""Index user-selected files using the parallel pipeline.
|
||||
|
||||
Phase 1 (serial): fetch metadata + skip checks.
|
||||
Phase 2+3 (parallel): download, ETL, index via _download_and_index.
|
||||
|
||||
Returns (indexed_count, skipped_count, errors).
|
||||
"""
|
||||
files_to_download: list[dict] = []
|
||||
errors: list[str] = []
|
||||
renamed_count = 0
|
||||
skipped = 0
|
||||
|
||||
for file_id, file_name in file_ids:
|
||||
file, error = await get_file_by_id(drive_client, file_id)
|
||||
if error or not file:
|
||||
display = file_name or file_id
|
||||
errors.append(f"File '{display}': {error or 'File not found'}")
|
||||
continue
|
||||
|
||||
skip, msg = await _should_skip_file(session, file, search_space_id)
|
||||
if skip:
|
||||
if msg and "renamed" in msg.lower():
|
||||
renamed_count += 1
|
||||
else:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
files_to_download.append(file)
|
||||
|
||||
batch_indexed, failed = await _download_and_index(
|
||||
drive_client, session, files_to_download,
|
||||
connector_id=connector_id, search_space_id=search_space_id,
|
||||
user_id=user_id, enable_summary=enable_summary,
|
||||
on_heartbeat=on_heartbeat,
|
||||
)
|
||||
|
||||
return renamed_count + batch_indexed, skipped, errors
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scan strategies
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -803,3 +853,97 @@ async def index_google_drive_single_file(
|
|||
await task_logger.log_task_failure(log_entry, "Failed to index Google Drive file", str(e), {"error_type": type(e).__name__})
|
||||
logger.error(f"Failed to index Google Drive file: {e!s}", exc_info=True)
|
||||
return 0, f"Failed to index Google Drive file: {e!s}"
|
||||
|
||||
|
||||
async def index_google_drive_selected_files(
|
||||
session: AsyncSession,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
files: list[tuple[str, str | None]],
|
||||
on_heartbeat_callback: HeartbeatCallbackType | None = None,
|
||||
) -> tuple[int, int, list[str]]:
|
||||
"""Index multiple user-selected Google Drive files in parallel.
|
||||
|
||||
Sets up the connector/credentials once, then delegates to
|
||||
_index_selected_files for the three-phase parallel pipeline.
|
||||
|
||||
Returns (indexed_count, skipped_count, errors).
|
||||
"""
|
||||
task_logger = TaskLoggingService(session, search_space_id)
|
||||
log_entry = await task_logger.log_task_start(
|
||||
task_name="google_drive_selected_files_indexing",
|
||||
source="connector_indexing_task",
|
||||
message=f"Starting Google Drive batch file indexing for {len(files)} files",
|
||||
metadata={"connector_id": connector_id, "user_id": str(user_id), "file_count": len(files)},
|
||||
)
|
||||
|
||||
try:
|
||||
connector = None
|
||||
for ct in ACCEPTED_DRIVE_CONNECTOR_TYPES:
|
||||
connector = await get_connector_by_id(session, connector_id, ct)
|
||||
if connector:
|
||||
break
|
||||
if not connector:
|
||||
error_msg = f"Google Drive connector with ID {connector_id} not found"
|
||||
await task_logger.log_task_failure(log_entry, error_msg, None, {"error_type": "ConnectorNotFound"})
|
||||
return 0, 0, [error_msg]
|
||||
|
||||
pre_built_credentials = None
|
||||
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
|
||||
connected_account_id = connector.config.get("composio_connected_account_id")
|
||||
if not connected_account_id:
|
||||
error_msg = f"Composio connected_account_id not found for connector {connector_id}"
|
||||
await task_logger.log_task_failure(log_entry, error_msg, "Missing Composio account", {"error_type": "MissingComposioAccount"})
|
||||
return 0, 0, [error_msg]
|
||||
pre_built_credentials = build_composio_credentials(connected_account_id)
|
||||
else:
|
||||
token_encrypted = connector.config.get("_token_encrypted", False)
|
||||
if token_encrypted and not config.SECRET_KEY:
|
||||
error_msg = "SECRET_KEY not configured but credentials are marked as encrypted"
|
||||
await task_logger.log_task_failure(
|
||||
log_entry, error_msg, "Missing SECRET_KEY", {"error_type": "MissingSecretKey"},
|
||||
)
|
||||
return 0, 0, [error_msg]
|
||||
|
||||
connector_enable_summary = getattr(connector, "enable_summary", True)
|
||||
drive_client = GoogleDriveClient(session, connector_id, credentials=pre_built_credentials)
|
||||
|
||||
indexed, skipped, errors = await _index_selected_files(
|
||||
drive_client, session, files,
|
||||
connector_id=connector_id, search_space_id=search_space_id,
|
||||
user_id=user_id, enable_summary=connector_enable_summary,
|
||||
on_heartbeat=on_heartbeat_callback,
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
|
||||
if errors:
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"Batch file indexing completed with {len(errors)} error(s)",
|
||||
"; ".join(errors),
|
||||
{"indexed": indexed, "skipped": skipped, "error_count": len(errors)},
|
||||
)
|
||||
else:
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Successfully indexed {indexed} files ({skipped} skipped)",
|
||||
{"indexed": indexed, "skipped": skipped},
|
||||
)
|
||||
|
||||
logger.info(f"Selected files indexing: {indexed} indexed, {skipped} skipped, {len(errors)} errors")
|
||||
return indexed, skipped, errors
|
||||
|
||||
except SQLAlchemyError as db_error:
|
||||
await session.rollback()
|
||||
error_msg = f"Database error: {db_error!s}"
|
||||
await task_logger.log_task_failure(log_entry, error_msg, str(db_error), {"error_type": "SQLAlchemyError"})
|
||||
logger.error(error_msg, exc_info=True)
|
||||
return 0, 0, [error_msg]
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
error_msg = f"Failed to index Google Drive files: {e!s}"
|
||||
await task_logger.log_task_failure(log_entry, error_msg, str(e), {"error_type": type(e).__name__})
|
||||
logger.error(error_msg, exc_info=True)
|
||||
return 0, 0, [error_msg]
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import pytest
|
|||
from app.tasks.connector_indexers.google_drive_indexer import (
|
||||
_download_files_parallel,
|
||||
_index_full_scan,
|
||||
_index_selected_files,
|
||||
_index_with_delta_sync,
|
||||
)
|
||||
|
||||
|
|
@ -464,3 +465,124 @@ async def test_delta_sync_removals_serial_rest_parallel(monkeypatch):
|
|||
|
||||
assert indexed == 2
|
||||
assert skipped == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _index_selected_files -- parallel indexing of user-selected files
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture
|
||||
def selected_files_mocks(mock_drive_client, monkeypatch):
|
||||
"""Wire up mocks for _index_selected_files tests."""
|
||||
import app.tasks.connector_indexers.google_drive_indexer as _mod
|
||||
|
||||
mock_session = AsyncMock()
|
||||
|
||||
get_file_results: dict[str, tuple[dict | None, str | None]] = {}
|
||||
|
||||
async def _fake_get_file(client, file_id):
|
||||
return get_file_results.get(file_id, (None, f"Not configured: {file_id}"))
|
||||
|
||||
monkeypatch.setattr(_mod, "get_file_by_id", _fake_get_file)
|
||||
|
||||
skip_results: dict[str, tuple[bool, str | None]] = {}
|
||||
|
||||
async def _fake_skip(session, file, search_space_id):
|
||||
return skip_results.get(file["id"], (False, None))
|
||||
|
||||
monkeypatch.setattr(_mod, "_should_skip_file", _fake_skip)
|
||||
|
||||
download_and_index_mock = AsyncMock(return_value=(0, 0))
|
||||
monkeypatch.setattr(_mod, "_download_and_index", download_and_index_mock)
|
||||
|
||||
return {
|
||||
"drive_client": mock_drive_client,
|
||||
"session": mock_session,
|
||||
"get_file_results": get_file_results,
|
||||
"skip_results": skip_results,
|
||||
"download_and_index_mock": download_and_index_mock,
|
||||
}
|
||||
|
||||
|
||||
async def _run_selected(mocks, file_ids):
|
||||
return await _index_selected_files(
|
||||
mocks["drive_client"],
|
||||
mocks["session"],
|
||||
file_ids,
|
||||
connector_id=_CONNECTOR_ID,
|
||||
search_space_id=_SEARCH_SPACE_ID,
|
||||
user_id=_USER_ID,
|
||||
enable_summary=True,
|
||||
)
|
||||
|
||||
|
||||
async def test_selected_files_single_file_indexed(selected_files_mocks):
|
||||
"""Tracer bullet: one file fetched, not skipped, indexed via parallel pipeline."""
|
||||
selected_files_mocks["get_file_results"]["f1"] = (
|
||||
_make_file_dict("f1", "report.pdf"),
|
||||
None,
|
||||
)
|
||||
selected_files_mocks["download_and_index_mock"].return_value = (1, 0)
|
||||
|
||||
indexed, skipped, errors = await _run_selected(
|
||||
selected_files_mocks, [("f1", "report.pdf")],
|
||||
)
|
||||
|
||||
assert indexed == 1
|
||||
assert skipped == 0
|
||||
assert errors == []
|
||||
selected_files_mocks["download_and_index_mock"].assert_called_once()
|
||||
|
||||
|
||||
async def test_selected_files_fetch_failure_isolation(selected_files_mocks):
|
||||
"""get_file_by_id failing for one file collects an error; others still indexed."""
|
||||
selected_files_mocks["get_file_results"]["f1"] = (
|
||||
_make_file_dict("f1", "first.txt"), None,
|
||||
)
|
||||
selected_files_mocks["get_file_results"]["f2"] = (None, "HTTP 404")
|
||||
selected_files_mocks["get_file_results"]["f3"] = (
|
||||
_make_file_dict("f3", "third.txt"), None,
|
||||
)
|
||||
selected_files_mocks["download_and_index_mock"].return_value = (2, 0)
|
||||
|
||||
indexed, skipped, errors = await _run_selected(
|
||||
selected_files_mocks,
|
||||
[("f1", "first.txt"), ("f2", "mid.txt"), ("f3", "third.txt")],
|
||||
)
|
||||
|
||||
assert indexed == 2
|
||||
assert skipped == 0
|
||||
assert len(errors) == 1
|
||||
assert "mid.txt" in errors[0]
|
||||
assert "HTTP 404" in errors[0]
|
||||
|
||||
|
||||
async def test_selected_files_skip_rename_counting(selected_files_mocks):
|
||||
"""Unchanged files are skipped, renames counted as indexed,
|
||||
and only new files are sent to _download_and_index."""
|
||||
for fid, fname in [("s1", "unchanged.txt"), ("r1", "renamed.txt"),
|
||||
("n1", "new1.txt"), ("n2", "new2.txt")]:
|
||||
selected_files_mocks["get_file_results"][fid] = (
|
||||
_make_file_dict(fid, fname), None,
|
||||
)
|
||||
|
||||
selected_files_mocks["skip_results"]["s1"] = (True, "unchanged")
|
||||
selected_files_mocks["skip_results"]["r1"] = (True, "File renamed: 'old' \u2192 'renamed.txt'")
|
||||
|
||||
selected_files_mocks["download_and_index_mock"].return_value = (2, 0)
|
||||
|
||||
indexed, skipped, errors = await _run_selected(
|
||||
selected_files_mocks,
|
||||
[("s1", "unchanged.txt"), ("r1", "renamed.txt"),
|
||||
("n1", "new1.txt"), ("n2", "new2.txt")],
|
||||
)
|
||||
|
||||
assert indexed == 3 # 1 renamed + 2 batch
|
||||
assert skipped == 1 # 1 unchanged
|
||||
assert errors == []
|
||||
|
||||
mock = selected_files_mocks["download_and_index_mock"]
|
||||
mock.assert_called_once()
|
||||
call_files = mock.call_args[1].get("files") if "files" in (mock.call_args[1] or {}) else mock.call_args[0][2]
|
||||
assert len(call_files) == 2
|
||||
assert {f["id"] for f in call_files} == {"n1", "n2"}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue