mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-17 18:35:19 +02:00
feat: implement parallel file downloading and indexing in Google Drive indexer
- Added `_download_files_parallel` function to enable concurrent downloading of files from Google Drive, improving efficiency in document processing. - Introduced `_download_and_index` function to handle the parallel downloading and indexing phases, streamlining the overall workflow. - Updated `_index_full_scan` and `_index_with_delta_sync` methods to utilize the new parallel downloading functionality, enhancing performance. - Added unit tests to validate the new parallel downloading and indexing logic, ensuring robustness and error handling during document processing.
This commit is contained in:
parent
bd6e335cb3
commit
c016962064
4 changed files with 652 additions and 35 deletions
|
|
@ -5,6 +5,7 @@ checks and rename-only detection. download_and_extract_content()
|
|||
returns markdown which is fed into ConnectorDocument -> pipeline.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
|
@ -190,6 +191,68 @@ def _build_connector_doc(
|
|||
)
|
||||
|
||||
|
||||
async def _download_files_parallel(
|
||||
drive_client: GoogleDriveClient,
|
||||
files: list[dict],
|
||||
*,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
enable_summary: bool,
|
||||
max_concurrency: int = 5,
|
||||
on_heartbeat: HeartbeatCallbackType | None = None,
|
||||
) -> tuple[list[ConnectorDocument], int]:
|
||||
"""Download and ETL files in parallel, returning ConnectorDocuments.
|
||||
|
||||
Returns (connector_docs, download_failed_count).
|
||||
"""
|
||||
results: list[ConnectorDocument] = []
|
||||
sem = asyncio.Semaphore(max_concurrency)
|
||||
last_heartbeat = time.time()
|
||||
completed_count = 0
|
||||
hb_lock = asyncio.Lock()
|
||||
|
||||
async def _download_one(file: dict) -> ConnectorDocument | None:
|
||||
nonlocal last_heartbeat, completed_count
|
||||
async with sem:
|
||||
markdown, drive_metadata, error = await download_and_extract_content(
|
||||
drive_client, file
|
||||
)
|
||||
if error or not markdown:
|
||||
return None
|
||||
doc = _build_connector_doc(
|
||||
file,
|
||||
markdown,
|
||||
drive_metadata,
|
||||
connector_id=connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
enable_summary=enable_summary,
|
||||
)
|
||||
async with hb_lock:
|
||||
completed_count += 1
|
||||
if on_heartbeat:
|
||||
now = time.time()
|
||||
if now - last_heartbeat >= HEARTBEAT_INTERVAL_SECONDS:
|
||||
await on_heartbeat(completed_count)
|
||||
last_heartbeat = now
|
||||
return doc
|
||||
|
||||
tasks = [_download_one(f) for f in files]
|
||||
outcomes = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
failed = 0
|
||||
for outcome in outcomes:
|
||||
if isinstance(outcome, Exception):
|
||||
failed += 1
|
||||
elif outcome is None:
|
||||
failed += 1
|
||||
else:
|
||||
results.append(outcome)
|
||||
|
||||
return results, failed
|
||||
|
||||
|
||||
async def _process_single_file(
|
||||
drive_client: GoogleDriveClient,
|
||||
session: AsyncSession,
|
||||
|
|
@ -283,6 +346,47 @@ async def _remove_document(session: AsyncSession, file_id: str, search_space_id:
|
|||
logger.info(f"Removed deleted file document: {file_id}")
|
||||
|
||||
|
||||
async def _download_and_index(
|
||||
drive_client: GoogleDriveClient,
|
||||
session: AsyncSession,
|
||||
files: list[dict],
|
||||
*,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
enable_summary: bool,
|
||||
on_heartbeat: HeartbeatCallbackType | None = None,
|
||||
) -> tuple[int, int]:
|
||||
"""Phase 2+3: parallel download then parallel indexing.
|
||||
|
||||
Returns (batch_indexed, total_failed).
|
||||
"""
|
||||
connector_docs, download_failed = await _download_files_parallel(
|
||||
drive_client,
|
||||
files,
|
||||
connector_id=connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
enable_summary=enable_summary,
|
||||
on_heartbeat=on_heartbeat,
|
||||
)
|
||||
|
||||
batch_indexed = 0
|
||||
batch_failed = 0
|
||||
if connector_docs:
|
||||
pipeline = IndexingPipelineService(session)
|
||||
|
||||
async def _get_llm(s):
|
||||
return await get_user_long_context_llm(s, user_id, search_space_id)
|
||||
|
||||
_, batch_indexed, batch_failed = await pipeline.index_batch_parallel(
|
||||
connector_docs, _get_llm, max_concurrency=3,
|
||||
on_heartbeat=on_heartbeat,
|
||||
)
|
||||
|
||||
return batch_indexed, download_failed + batch_failed
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scan strategies
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -310,11 +414,13 @@ async def _index_full_scan(
|
|||
{"stage": "full_scan", "folder_id": folder_id, "include_subfolders": include_subfolders},
|
||||
)
|
||||
|
||||
indexed = 0
|
||||
# ------------------------------------------------------------------
|
||||
# Phase 1 (serial): collect files, run skip checks, track renames
|
||||
# ------------------------------------------------------------------
|
||||
renamed_count = 0
|
||||
skipped = 0
|
||||
failed = 0
|
||||
files_processed = 0
|
||||
last_heartbeat = time.time()
|
||||
files_to_download: list[dict] = []
|
||||
folders_to_process = [(folder_id, folder_name)]
|
||||
first_error: str | None = None
|
||||
|
||||
|
|
@ -346,22 +452,15 @@ async def _index_full_scan(
|
|||
|
||||
files_processed += 1
|
||||
|
||||
if on_heartbeat_callback:
|
||||
now = time.time()
|
||||
if now - last_heartbeat >= HEARTBEAT_INTERVAL_SECONDS:
|
||||
await on_heartbeat_callback(indexed)
|
||||
last_heartbeat = now
|
||||
skip, msg = await _should_skip_file(session, file, search_space_id)
|
||||
if skip:
|
||||
if msg and "renamed" in msg.lower():
|
||||
renamed_count += 1
|
||||
else:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
i, s, f = await _process_single_file(
|
||||
drive_client, session, file,
|
||||
connector_id, search_space_id, user_id, enable_summary,
|
||||
)
|
||||
indexed += i
|
||||
skipped += s
|
||||
failed += f
|
||||
|
||||
if indexed > 0 and indexed % 10 == 0:
|
||||
await session.commit()
|
||||
files_to_download.append(file)
|
||||
|
||||
page_token = next_token
|
||||
if not page_token:
|
||||
|
|
@ -375,6 +474,17 @@ async def _index_full_scan(
|
|||
)
|
||||
raise Exception(f"Failed to list Google Drive files: {first_error}")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Phase 2+3 (parallel): download, ETL, index
|
||||
# ------------------------------------------------------------------
|
||||
batch_indexed, failed = await _download_and_index(
|
||||
drive_client, session, files_to_download,
|
||||
connector_id=connector_id, search_space_id=search_space_id,
|
||||
user_id=user_id, enable_summary=enable_summary,
|
||||
on_heartbeat=on_heartbeat_callback,
|
||||
)
|
||||
|
||||
indexed = renamed_count + batch_indexed
|
||||
logger.info(f"Full scan complete: {indexed} indexed, {skipped} skipped, {failed} failed")
|
||||
return indexed, skipped
|
||||
|
||||
|
|
@ -416,11 +526,14 @@ async def _index_with_delta_sync(
|
|||
return 0, 0
|
||||
|
||||
logger.info(f"Processing {len(changes)} changes")
|
||||
indexed = 0
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Phase 1 (serial): handle removals, collect files for download
|
||||
# ------------------------------------------------------------------
|
||||
renamed_count = 0
|
||||
skipped = 0
|
||||
failed = 0
|
||||
files_to_download: list[dict] = []
|
||||
files_processed = 0
|
||||
last_heartbeat = time.time()
|
||||
|
||||
for change in changes:
|
||||
if files_processed >= max_files:
|
||||
|
|
@ -438,23 +551,27 @@ async def _index_with_delta_sync(
|
|||
if not file:
|
||||
continue
|
||||
|
||||
if on_heartbeat_callback:
|
||||
now = time.time()
|
||||
if now - last_heartbeat >= HEARTBEAT_INTERVAL_SECONDS:
|
||||
await on_heartbeat_callback(indexed)
|
||||
last_heartbeat = now
|
||||
skip, msg = await _should_skip_file(session, file, search_space_id)
|
||||
if skip:
|
||||
if msg and "renamed" in msg.lower():
|
||||
renamed_count += 1
|
||||
else:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
i, s, f = await _process_single_file(
|
||||
drive_client, session, file,
|
||||
connector_id, search_space_id, user_id, enable_summary,
|
||||
)
|
||||
indexed += i
|
||||
skipped += s
|
||||
failed += f
|
||||
files_to_download.append(file)
|
||||
|
||||
if indexed > 0 and indexed % 10 == 0:
|
||||
await session.commit()
|
||||
# ------------------------------------------------------------------
|
||||
# Phase 2+3 (parallel): download, ETL, index
|
||||
# ------------------------------------------------------------------
|
||||
batch_indexed, failed = await _download_and_index(
|
||||
drive_client, session, files_to_download,
|
||||
connector_id=connector_id, search_space_id=search_space_id,
|
||||
user_id=user_id, enable_summary=enable_summary,
|
||||
on_heartbeat=on_heartbeat_callback,
|
||||
)
|
||||
|
||||
indexed = renamed_count + batch_indexed
|
||||
logger.info(f"Delta sync complete: {indexed} indexed, {skipped} skipped, {failed} failed")
|
||||
return indexed, skipped
|
||||
|
||||
|
|
|
|||
34
surfsense_backend/tests/unit/connector_indexers/conftest.py
Normal file
34
surfsense_backend/tests/unit/connector_indexers/conftest.py
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
"""Pre-register the connector_indexers package to bypass a circular import
|
||||
in its ``__init__.py`` (airtable_indexer -> routes -> connector_indexers).
|
||||
|
||||
This lets tests import individual indexer modules (e.g.
|
||||
``google_drive_indexer``) without triggering the full package init.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import types
|
||||
from pathlib import Path
|
||||
|
||||
_BACKEND = Path(__file__).resolve().parents[3]
|
||||
|
||||
|
||||
def _stub_package(dotted: str, fs_dir: Path) -> None:
|
||||
if dotted not in sys.modules:
|
||||
mod = types.ModuleType(dotted)
|
||||
mod.__path__ = [str(fs_dir)]
|
||||
mod.__package__ = dotted
|
||||
sys.modules[dotted] = mod
|
||||
|
||||
parts = dotted.split(".")
|
||||
if len(parts) > 1:
|
||||
parent_dotted = ".".join(parts[:-1])
|
||||
parent = sys.modules.get(parent_dotted)
|
||||
if parent is not None:
|
||||
setattr(parent, parts[-1], sys.modules[dotted])
|
||||
|
||||
|
||||
_stub_package("app.tasks", _BACKEND / "app" / "tasks")
|
||||
_stub_package(
|
||||
"app.tasks.connector_indexers",
|
||||
_BACKEND / "app" / "tasks" / "connector_indexers",
|
||||
)
|
||||
|
|
@ -0,0 +1,466 @@
|
|||
"""Tests for parallel download + indexing in the Google Drive indexer."""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.tasks.connector_indexers.google_drive_indexer import (
|
||||
_download_files_parallel,
|
||||
_index_full_scan,
|
||||
_index_with_delta_sync,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
_USER_ID = "00000000-0000-0000-0000-000000000001"
|
||||
_CONNECTOR_ID = 42
|
||||
_SEARCH_SPACE_ID = 1
|
||||
|
||||
|
||||
def _make_file_dict(file_id: str, name: str, mime: str = "text/plain") -> dict:
|
||||
return {"id": file_id, "name": name, "mimeType": mime}
|
||||
|
||||
|
||||
def _mock_extract_ok(file_id: str, file_name: str):
|
||||
"""Return a successful (markdown, metadata, None) tuple."""
|
||||
return (
|
||||
f"# Content of {file_name}",
|
||||
{"google_drive_file_id": file_id, "google_drive_file_name": file_name},
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_drive_client():
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_extract(monkeypatch):
|
||||
"""Provide a helper to set the download_and_extract_content mock."""
|
||||
def _patch(side_effect=None, return_value=None):
|
||||
mock = AsyncMock(side_effect=side_effect, return_value=return_value)
|
||||
monkeypatch.setattr(
|
||||
"app.tasks.connector_indexers.google_drive_indexer.download_and_extract_content",
|
||||
mock,
|
||||
)
|
||||
return mock
|
||||
return _patch
|
||||
|
||||
|
||||
async def test_single_file_returns_one_connector_document(
|
||||
mock_drive_client, patch_extract,
|
||||
):
|
||||
"""Tracer bullet: downloading one file produces one ConnectorDocument."""
|
||||
patch_extract(return_value=_mock_extract_ok("f1", "test.txt"))
|
||||
|
||||
docs, failed = await _download_files_parallel(
|
||||
mock_drive_client,
|
||||
[_make_file_dict("f1", "test.txt")],
|
||||
connector_id=_CONNECTOR_ID,
|
||||
search_space_id=_SEARCH_SPACE_ID,
|
||||
user_id=_USER_ID,
|
||||
enable_summary=True,
|
||||
)
|
||||
|
||||
assert len(docs) == 1
|
||||
assert failed == 0
|
||||
assert docs[0].title == "test.txt"
|
||||
assert docs[0].unique_id == "f1"
|
||||
|
||||
|
||||
async def test_multiple_files_all_produce_documents(
|
||||
mock_drive_client, patch_extract,
|
||||
):
|
||||
"""All files are downloaded and converted to ConnectorDocuments."""
|
||||
files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(3)]
|
||||
patch_extract(
|
||||
side_effect=[_mock_extract_ok(f"f{i}", f"file{i}.txt") for i in range(3)]
|
||||
)
|
||||
|
||||
docs, failed = await _download_files_parallel(
|
||||
mock_drive_client,
|
||||
files,
|
||||
connector_id=_CONNECTOR_ID,
|
||||
search_space_id=_SEARCH_SPACE_ID,
|
||||
user_id=_USER_ID,
|
||||
enable_summary=True,
|
||||
)
|
||||
|
||||
assert len(docs) == 3
|
||||
assert failed == 0
|
||||
assert {d.unique_id for d in docs} == {"f0", "f1", "f2"}
|
||||
|
||||
|
||||
async def test_one_download_exception_does_not_block_others(
|
||||
mock_drive_client, patch_extract,
|
||||
):
|
||||
"""A RuntimeError in one download still lets the other files succeed."""
|
||||
files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(3)]
|
||||
patch_extract(
|
||||
side_effect=[
|
||||
_mock_extract_ok("f0", "file0.txt"),
|
||||
RuntimeError("network timeout"),
|
||||
_mock_extract_ok("f2", "file2.txt"),
|
||||
]
|
||||
)
|
||||
|
||||
docs, failed = await _download_files_parallel(
|
||||
mock_drive_client,
|
||||
files,
|
||||
connector_id=_CONNECTOR_ID,
|
||||
search_space_id=_SEARCH_SPACE_ID,
|
||||
user_id=_USER_ID,
|
||||
enable_summary=True,
|
||||
)
|
||||
|
||||
assert len(docs) == 2
|
||||
assert failed == 1
|
||||
assert {d.unique_id for d in docs} == {"f0", "f2"}
|
||||
|
||||
|
||||
async def test_etl_error_counts_as_download_failure(
|
||||
mock_drive_client, patch_extract,
|
||||
):
|
||||
"""download_and_extract_content returning an error is counted as failed."""
|
||||
files = [_make_file_dict("f0", "good.txt"), _make_file_dict("f1", "bad.txt")]
|
||||
patch_extract(
|
||||
side_effect=[
|
||||
_mock_extract_ok("f0", "good.txt"),
|
||||
(None, {}, "ETL failed"),
|
||||
]
|
||||
)
|
||||
|
||||
docs, failed = await _download_files_parallel(
|
||||
mock_drive_client,
|
||||
files,
|
||||
connector_id=_CONNECTOR_ID,
|
||||
search_space_id=_SEARCH_SPACE_ID,
|
||||
user_id=_USER_ID,
|
||||
enable_summary=True,
|
||||
)
|
||||
|
||||
assert len(docs) == 1
|
||||
assert failed == 1
|
||||
|
||||
|
||||
async def test_concurrency_bounded_by_semaphore(
|
||||
mock_drive_client, monkeypatch,
|
||||
):
|
||||
"""Peak concurrent downloads never exceeds max_concurrency."""
|
||||
lock = asyncio.Lock()
|
||||
active = 0
|
||||
peak = 0
|
||||
|
||||
async def _slow_extract(client, file):
|
||||
nonlocal active, peak
|
||||
async with lock:
|
||||
active += 1
|
||||
peak = max(peak, active)
|
||||
await asyncio.sleep(0.05)
|
||||
async with lock:
|
||||
active -= 1
|
||||
fid = file["id"]
|
||||
return _mock_extract_ok(fid, file["name"])
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.tasks.connector_indexers.google_drive_indexer.download_and_extract_content",
|
||||
_slow_extract,
|
||||
)
|
||||
|
||||
files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(6)]
|
||||
|
||||
docs, failed = await _download_files_parallel(
|
||||
mock_drive_client,
|
||||
files,
|
||||
connector_id=_CONNECTOR_ID,
|
||||
search_space_id=_SEARCH_SPACE_ID,
|
||||
user_id=_USER_ID,
|
||||
enable_summary=True,
|
||||
max_concurrency=2,
|
||||
)
|
||||
|
||||
assert len(docs) == 6
|
||||
assert failed == 0
|
||||
assert peak <= 2, f"Peak concurrency was {peak}, expected <= 2"
|
||||
|
||||
|
||||
async def test_heartbeat_fires_during_parallel_downloads(
|
||||
mock_drive_client, monkeypatch,
|
||||
):
|
||||
"""on_heartbeat is called at least once when downloads take time."""
|
||||
import app.tasks.connector_indexers.google_drive_indexer as _mod
|
||||
|
||||
monkeypatch.setattr(_mod, "HEARTBEAT_INTERVAL_SECONDS", 0)
|
||||
|
||||
async def _slow_extract(client, file):
|
||||
await asyncio.sleep(0.05)
|
||||
return _mock_extract_ok(file["id"], file["name"])
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.tasks.connector_indexers.google_drive_indexer.download_and_extract_content",
|
||||
_slow_extract,
|
||||
)
|
||||
|
||||
heartbeat_calls: list[int] = []
|
||||
|
||||
async def _on_heartbeat(count: int):
|
||||
heartbeat_calls.append(count)
|
||||
|
||||
files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(3)]
|
||||
|
||||
docs, failed = await _download_files_parallel(
|
||||
mock_drive_client,
|
||||
files,
|
||||
connector_id=_CONNECTOR_ID,
|
||||
search_space_id=_SEARCH_SPACE_ID,
|
||||
user_id=_USER_ID,
|
||||
enable_summary=True,
|
||||
on_heartbeat=_on_heartbeat,
|
||||
)
|
||||
|
||||
assert len(docs) == 3
|
||||
assert failed == 0
|
||||
assert len(heartbeat_calls) >= 1, "Heartbeat should have fired at least once"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slice 6, 6b, 6c -- _index_full_scan three-phase pipeline
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _folder_dict(file_id: str, name: str) -> dict:
|
||||
return {"id": file_id, "name": name, "mimeType": "application/vnd.google-apps.folder"}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def full_scan_mocks(mock_drive_client, monkeypatch):
|
||||
"""Wire up all mocks needed to call _index_full_scan in isolation."""
|
||||
import app.tasks.connector_indexers.google_drive_indexer as _mod
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_connector = MagicMock()
|
||||
mock_task_logger = MagicMock()
|
||||
mock_task_logger.log_task_progress = AsyncMock()
|
||||
mock_log_entry = MagicMock()
|
||||
|
||||
skip_results: dict[str, tuple[bool, str | None]] = {}
|
||||
|
||||
async def _fake_skip(session, file, search_space_id):
|
||||
return skip_results.get(file["id"], (False, None))
|
||||
|
||||
monkeypatch.setattr(_mod, "_should_skip_file", _fake_skip)
|
||||
|
||||
download_mock = AsyncMock(return_value=([], 0))
|
||||
monkeypatch.setattr(_mod, "_download_files_parallel", download_mock)
|
||||
|
||||
batch_mock = AsyncMock(return_value=([], 0, 0))
|
||||
pipeline_mock = MagicMock()
|
||||
pipeline_mock.index_batch_parallel = batch_mock
|
||||
monkeypatch.setattr(
|
||||
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock),
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
_mod, "get_user_long_context_llm", AsyncMock(return_value=MagicMock()),
|
||||
)
|
||||
|
||||
return {
|
||||
"drive_client": mock_drive_client,
|
||||
"session": mock_session,
|
||||
"connector": mock_connector,
|
||||
"task_logger": mock_task_logger,
|
||||
"log_entry": mock_log_entry,
|
||||
"skip_results": skip_results,
|
||||
"download_mock": download_mock,
|
||||
"batch_mock": batch_mock,
|
||||
"pipeline_mock": pipeline_mock,
|
||||
}
|
||||
|
||||
|
||||
async def _run_full_scan(mocks, *, max_files=500, include_subfolders=False):
|
||||
return await _index_full_scan(
|
||||
mocks["drive_client"],
|
||||
mocks["session"],
|
||||
mocks["connector"],
|
||||
_CONNECTOR_ID,
|
||||
_SEARCH_SPACE_ID,
|
||||
_USER_ID,
|
||||
"folder-root",
|
||||
"My Folder",
|
||||
mocks["task_logger"],
|
||||
mocks["log_entry"],
|
||||
max_files,
|
||||
include_subfolders=include_subfolders,
|
||||
enable_summary=True,
|
||||
)
|
||||
|
||||
|
||||
async def test_full_scan_three_phase_counts(full_scan_mocks, monkeypatch):
|
||||
"""Full scan collects files serially, downloads and indexes in parallel,
|
||||
and returns correct (indexed, skipped) with renames counted as indexed."""
|
||||
import app.tasks.connector_indexers.google_drive_indexer as _mod
|
||||
|
||||
page_files = [
|
||||
_folder_dict("folder1", "SubFolder"),
|
||||
_make_file_dict("skip1", "unchanged.txt"),
|
||||
_make_file_dict("rename1", "renamed.txt"),
|
||||
_make_file_dict("new1", "new1.txt"),
|
||||
_make_file_dict("new2", "new2.txt"),
|
||||
]
|
||||
|
||||
monkeypatch.setattr(
|
||||
_mod, "get_files_in_folder",
|
||||
AsyncMock(return_value=(page_files, None, None)),
|
||||
)
|
||||
|
||||
full_scan_mocks["skip_results"]["skip1"] = (True, "unchanged")
|
||||
full_scan_mocks["skip_results"]["rename1"] = (True, "File renamed: 'old' → 'renamed.txt'")
|
||||
|
||||
mock_docs = [MagicMock(), MagicMock()]
|
||||
full_scan_mocks["download_mock"].return_value = (mock_docs, 0)
|
||||
full_scan_mocks["batch_mock"].return_value = ([], 2, 0)
|
||||
|
||||
indexed, skipped = await _run_full_scan(full_scan_mocks)
|
||||
|
||||
assert indexed == 3 # 1 renamed + 2 from batch
|
||||
assert skipped == 1 # 1 unchanged
|
||||
|
||||
full_scan_mocks["download_mock"].assert_called_once()
|
||||
call_files = full_scan_mocks["download_mock"].call_args[0][1]
|
||||
assert len(call_files) == 2
|
||||
assert {f["id"] for f in call_files} == {"new1", "new2"}
|
||||
|
||||
|
||||
async def test_full_scan_respects_max_files(full_scan_mocks, monkeypatch):
|
||||
"""Only max_files non-folder files are processed; the rest are ignored."""
|
||||
import app.tasks.connector_indexers.google_drive_indexer as _mod
|
||||
|
||||
page_files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(10)]
|
||||
|
||||
monkeypatch.setattr(
|
||||
_mod, "get_files_in_folder",
|
||||
AsyncMock(return_value=(page_files, None, None)),
|
||||
)
|
||||
|
||||
full_scan_mocks["download_mock"].return_value = ([], 0)
|
||||
full_scan_mocks["batch_mock"].return_value = ([], 0, 0)
|
||||
|
||||
await _run_full_scan(full_scan_mocks, max_files=3)
|
||||
|
||||
download_call_files = full_scan_mocks["download_mock"].call_args[0][1]
|
||||
assert len(download_call_files) == 3
|
||||
|
||||
|
||||
async def test_full_scan_uses_max_concurrency_3_for_indexing(
|
||||
full_scan_mocks, monkeypatch,
|
||||
):
|
||||
"""index_batch_parallel is called with max_concurrency=3."""
|
||||
import app.tasks.connector_indexers.google_drive_indexer as _mod
|
||||
|
||||
page_files = [_make_file_dict("f1", "file1.txt")]
|
||||
monkeypatch.setattr(
|
||||
_mod, "get_files_in_folder",
|
||||
AsyncMock(return_value=(page_files, None, None)),
|
||||
)
|
||||
|
||||
mock_docs = [MagicMock()]
|
||||
full_scan_mocks["download_mock"].return_value = (mock_docs, 0)
|
||||
full_scan_mocks["batch_mock"].return_value = ([], 1, 0)
|
||||
|
||||
await _run_full_scan(full_scan_mocks)
|
||||
|
||||
call_kwargs = full_scan_mocks["batch_mock"].call_args
|
||||
assert call_kwargs[1].get("max_concurrency") == 3 or (
|
||||
len(call_kwargs[0]) > 2 and call_kwargs[0][2] == 3
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slice 7 -- _index_with_delta_sync three-phase pipeline
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_delta_sync_removals_serial_rest_parallel(monkeypatch):
|
||||
"""Removed/trashed changes call _remove_document; the rest go through
|
||||
_download_files_parallel and index_batch_parallel."""
|
||||
import app.tasks.connector_indexers.google_drive_indexer as _mod
|
||||
|
||||
changes = [
|
||||
{"fileId": "del1", "removed": True},
|
||||
{"fileId": "del2", "file": {"id": "del2", "trashed": True}},
|
||||
{"fileId": "trash1", "file": {"id": "trash1", "trashed": True}},
|
||||
{"fileId": "mod1", "file": _make_file_dict("mod1", "modified1.txt")},
|
||||
{"fileId": "mod2", "file": _make_file_dict("mod2", "modified2.txt")},
|
||||
]
|
||||
|
||||
monkeypatch.setattr(
|
||||
_mod, "fetch_all_changes",
|
||||
AsyncMock(return_value=(changes, "new-token", None)),
|
||||
)
|
||||
|
||||
change_types = {
|
||||
"del1": "removed",
|
||||
"del2": "removed",
|
||||
"trash1": "trashed",
|
||||
"mod1": "modified",
|
||||
"mod2": "modified",
|
||||
}
|
||||
monkeypatch.setattr(
|
||||
_mod, "categorize_change",
|
||||
lambda change: change_types[change["fileId"]],
|
||||
)
|
||||
|
||||
remove_calls: list[str] = []
|
||||
|
||||
async def _fake_remove(session, file_id, search_space_id):
|
||||
remove_calls.append(file_id)
|
||||
|
||||
monkeypatch.setattr(_mod, "_remove_document", _fake_remove)
|
||||
|
||||
monkeypatch.setattr(
|
||||
_mod, "_should_skip_file",
|
||||
AsyncMock(return_value=(False, None)),
|
||||
)
|
||||
|
||||
mock_docs = [MagicMock(), MagicMock()]
|
||||
download_mock = AsyncMock(return_value=(mock_docs, 0))
|
||||
monkeypatch.setattr(_mod, "_download_files_parallel", download_mock)
|
||||
|
||||
batch_mock = AsyncMock(return_value=([], 2, 0))
|
||||
pipeline_mock = MagicMock()
|
||||
pipeline_mock.index_batch_parallel = batch_mock
|
||||
monkeypatch.setattr(
|
||||
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
_mod, "get_user_long_context_llm", AsyncMock(return_value=MagicMock()),
|
||||
)
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_task_logger = MagicMock()
|
||||
mock_task_logger.log_task_progress = AsyncMock()
|
||||
|
||||
indexed, skipped = await _index_with_delta_sync(
|
||||
MagicMock(),
|
||||
mock_session,
|
||||
MagicMock(),
|
||||
_CONNECTOR_ID,
|
||||
_SEARCH_SPACE_ID,
|
||||
_USER_ID,
|
||||
"folder-root",
|
||||
"start-token-abc",
|
||||
mock_task_logger,
|
||||
MagicMock(),
|
||||
max_files=500,
|
||||
enable_summary=True,
|
||||
)
|
||||
|
||||
assert sorted(remove_calls) == ["del1", "del2", "trash1"]
|
||||
|
||||
download_mock.assert_called_once()
|
||||
downloaded_files = download_mock.call_args[0][1]
|
||||
assert len(downloaded_files) == 2
|
||||
assert {f["id"] for f in downloaded_files} == {"mod1", "mod2"}
|
||||
|
||||
assert indexed == 2
|
||||
assert skipped == 0
|
||||
Loading…
Add table
Add a link
Reference in a new issue