Merge remote-tracking branch 'upstream/dev' into feat/onedrive-connector

This commit is contained in:
Anish Sarkar 2026-03-29 11:55:06 +05:30
commit 5a3eece397
70 changed files with 8288 additions and 5698 deletions

View file

@ -14,7 +14,9 @@ _EMBEDDING_DIM = app_config.embedding_model_instance.dimension
pytestmark = pytest.mark.integration
def _cal_doc(*, unique_id: str, search_space_id: int, connector_id: int, user_id: str) -> ConnectorDocument:
def _cal_doc(
*, unique_id: str, search_space_id: int, connector_id: int, user_id: str
) -> ConnectorDocument:
return ConnectorDocument(
title=f"Event {unique_id}",
source_markdown=f"## Calendar Event\n\nDetails for {unique_id}",
@ -34,7 +36,9 @@ def _cal_doc(*, unique_id: str, search_space_id: int, connector_id: int, user_id
)
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
@pytest.mark.usefixtures(
"patched_summarize", "patched_embed_texts", "patched_chunk_text"
)
async def test_calendar_pipeline_creates_ready_document(
db_session, db_search_space, db_connector, db_user, mocker
):
@ -63,7 +67,9 @@ async def test_calendar_pipeline_creates_ready_document(
assert DocumentStatus.is_state(row.status, DocumentStatus.READY)
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
@pytest.mark.usefixtures(
"patched_summarize", "patched_embed_texts", "patched_chunk_text"
)
async def test_calendar_legacy_doc_migrated(
db_session, db_search_space, db_connector, db_user, mocker
):
@ -101,7 +107,9 @@ async def test_calendar_legacy_doc_migrated(
service = IndexingPipelineService(session=db_session)
await service.migrate_legacy_docs([connector_doc])
result = await db_session.execute(select(Document).filter(Document.id == original_id))
result = await db_session.execute(
select(Document).filter(Document.id == original_id)
)
row = result.scalars().first()
assert row.document_type == DocumentType.GOOGLE_CALENDAR_CONNECTOR

View file

@ -14,7 +14,9 @@ _EMBEDDING_DIM = app_config.embedding_model_instance.dimension
pytestmark = pytest.mark.integration
def _drive_doc(*, unique_id: str, search_space_id: int, connector_id: int, user_id: str) -> ConnectorDocument:
def _drive_doc(
*, unique_id: str, search_space_id: int, connector_id: int, user_id: str
) -> ConnectorDocument:
return ConnectorDocument(
title=f"File {unique_id}.pdf",
source_markdown=f"## Document Content\n\nText from file {unique_id}",
@ -33,7 +35,9 @@ def _drive_doc(*, unique_id: str, search_space_id: int, connector_id: int, user_
)
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
@pytest.mark.usefixtures(
"patched_summarize", "patched_embed_texts", "patched_chunk_text"
)
async def test_drive_pipeline_creates_ready_document(
db_session, db_search_space, db_connector, db_user, mocker
):
@ -62,7 +66,9 @@ async def test_drive_pipeline_creates_ready_document(
assert DocumentStatus.is_state(row.status, DocumentStatus.READY)
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
@pytest.mark.usefixtures(
"patched_summarize", "patched_embed_texts", "patched_chunk_text"
)
async def test_drive_legacy_doc_migrated(
db_session, db_search_space, db_connector, db_user, mocker
):
@ -100,7 +106,9 @@ async def test_drive_legacy_doc_migrated(
service = IndexingPipelineService(session=db_session)
await service.migrate_legacy_docs([connector_doc])
result = await db_session.execute(select(Document).filter(Document.id == original_id))
result = await db_session.execute(
select(Document).filter(Document.id == original_id)
)
row = result.scalars().first()
assert row.document_type == DocumentType.GOOGLE_DRIVE_FILE
@ -111,7 +119,9 @@ async def test_drive_legacy_doc_migrated(
async def test_should_skip_file_skips_failed_document(
db_session, db_search_space, db_user,
db_session,
db_search_space,
db_user,
):
"""A FAILED document with unchanged md5 must be skipped — user can manually retry via Quick Index."""
import importlib
@ -162,7 +172,12 @@ async def test_should_skip_file_skips_failed_document(
db_session.add(failed_doc)
await db_session.flush()
incoming_file = {"id": file_id, "name": "Failed File.pdf", "mimeType": "application/pdf", "md5Checksum": md5}
incoming_file = {
"id": file_id,
"name": "Failed File.pdf",
"mimeType": "application/pdf",
"md5Checksum": md5,
}
should_skip, msg = await _should_skip_file(db_session, incoming_file, space_id)

View file

@ -8,7 +8,6 @@ from app.db import Document, DocumentStatus, DocumentType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_hashing import (
compute_identifier_hash,
compute_unique_identifier_hash,
)
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
@ -17,7 +16,9 @@ _EMBEDDING_DIM = app_config.embedding_model_instance.dimension
pytestmark = pytest.mark.integration
def _gmail_doc(*, unique_id: str, search_space_id: int, connector_id: int, user_id: str) -> ConnectorDocument:
def _gmail_doc(
*, unique_id: str, search_space_id: int, connector_id: int, user_id: str
) -> ConnectorDocument:
"""Build a Gmail-style ConnectorDocument like the real indexer does."""
return ConnectorDocument(
title=f"Subject for {unique_id}",
@ -37,7 +38,9 @@ def _gmail_doc(*, unique_id: str, search_space_id: int, connector_id: int, user_
)
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
@pytest.mark.usefixtures(
"patched_summarize", "patched_embed_texts", "patched_chunk_text"
)
async def test_gmail_pipeline_creates_ready_document(
db_session, db_search_space, db_connector, db_user, mocker
):
@ -67,7 +70,9 @@ async def test_gmail_pipeline_creates_ready_document(
assert row.source_markdown == doc.source_markdown
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
@pytest.mark.usefixtures(
"patched_summarize", "patched_embed_texts", "patched_chunk_text"
)
async def test_gmail_legacy_doc_migrated_then_reused(
db_session, db_search_space, db_connector, db_user, mocker
):

View file

@ -9,7 +9,9 @@ from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineServ
pytestmark = pytest.mark.integration
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
@pytest.mark.usefixtures(
"patched_summarize", "patched_embed_texts", "patched_chunk_text"
)
async def test_index_batch_creates_ready_documents(
db_session, db_search_space, make_connector_document, mocker
):
@ -47,7 +49,9 @@ async def test_index_batch_creates_ready_documents(
assert row.embedding is not None
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
@pytest.mark.usefixtures(
"patched_summarize", "patched_embed_texts", "patched_chunk_text"
)
async def test_index_batch_empty_returns_empty(db_session, mocker):
"""index_batch with empty input returns an empty list."""
service = IndexingPipelineService(session=db_session)

View file

@ -0,0 +1,106 @@
"""Shared fixtures for retriever integration tests."""
from __future__ import annotations
import uuid
from datetime import UTC, datetime
import pytest_asyncio
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config as app_config
from app.db import Chunk, Document, DocumentType, SearchSpace, User
EMBEDDING_DIM = app_config.embedding_model_instance.dimension
DUMMY_EMBEDDING = [0.1] * EMBEDDING_DIM
def _make_document(
*,
title: str,
document_type: DocumentType,
content: str,
search_space_id: int,
created_by_id: str,
) -> Document:
uid = uuid.uuid4().hex[:12]
return Document(
title=title,
document_type=document_type,
content=content,
content_hash=f"content-{uid}",
unique_identifier_hash=f"uid-{uid}",
source_markdown=content,
search_space_id=search_space_id,
created_by_id=created_by_id,
embedding=DUMMY_EMBEDDING,
updated_at=datetime.now(UTC),
status={"state": "ready"},
)
def _make_chunk(*, content: str, document_id: int) -> Chunk:
return Chunk(
content=content,
document_id=document_id,
embedding=DUMMY_EMBEDDING,
)
@pytest_asyncio.fixture
async def seed_large_doc(
db_session: AsyncSession, db_user: User, db_search_space: SearchSpace
):
"""Insert a document with 35 chunks (more than _MAX_FETCH_CHUNKS_PER_DOC=20).
Also inserts a small 3-chunk document for diversity testing.
Returns a dict with ``large_doc``, ``small_doc``, ``search_space``, ``user``,
and ``large_chunk_ids`` (all 35 chunk IDs).
"""
user_id = str(db_user.id)
space_id = db_search_space.id
large_doc = _make_document(
title="Large PDF Document",
document_type=DocumentType.FILE,
content="large document about quarterly performance reviews and budgets",
search_space_id=space_id,
created_by_id=user_id,
)
small_doc = _make_document(
title="Small Note",
document_type=DocumentType.NOTE,
content="quarterly performance review summary note",
search_space_id=space_id,
created_by_id=user_id,
)
db_session.add_all([large_doc, small_doc])
await db_session.flush()
large_chunks = []
for i in range(35):
chunk = _make_chunk(
content=f"chunk {i} about quarterly performance review section {i}",
document_id=large_doc.id,
)
large_chunks.append(chunk)
small_chunks = [
_make_chunk(
content="quarterly performance review summary note content",
document_id=small_doc.id,
),
]
db_session.add_all(large_chunks + small_chunks)
await db_session.flush()
return {
"large_doc": large_doc,
"small_doc": small_doc,
"large_chunk_ids": [c.id for c in large_chunks],
"small_chunk_ids": [c.id for c in small_chunks],
"search_space": db_search_space,
"user": db_user,
}

View file

@ -0,0 +1,116 @@
"""Integration tests for optimized ChucksHybridSearchRetriever.
Verifies the SQL ROW_NUMBER per-doc chunk limit, column pruning,
and doc metadata caching from RRF results.
"""
import pytest
from app.retriever.chunks_hybrid_search import (
_MAX_FETCH_CHUNKS_PER_DOC,
ChucksHybridSearchRetriever,
)
from .conftest import DUMMY_EMBEDDING
pytestmark = pytest.mark.integration
async def test_per_doc_chunk_limit_respected(db_session, seed_large_doc):
"""A document with 35 chunks should have at most _MAX_FETCH_CHUNKS_PER_DOC chunks returned."""
space_id = seed_large_doc["search_space"].id
retriever = ChucksHybridSearchRetriever(db_session)
results = await retriever.hybrid_search(
query_text="quarterly performance review",
top_k=10,
search_space_id=space_id,
query_embedding=DUMMY_EMBEDDING,
)
large_doc_id = seed_large_doc["large_doc"].id
for result in results:
if result["document"].get("id") == large_doc_id:
assert len(result["chunks"]) <= _MAX_FETCH_CHUNKS_PER_DOC
assert len(result["chunks"]) == _MAX_FETCH_CHUNKS_PER_DOC
break
else:
pytest.fail("Large doc not found in search results")
async def test_doc_metadata_populated_from_rrf(db_session, seed_large_doc):
"""Document metadata (title, type, etc.) should be present even without joinedload."""
space_id = seed_large_doc["search_space"].id
retriever = ChucksHybridSearchRetriever(db_session)
results = await retriever.hybrid_search(
query_text="quarterly performance review",
top_k=10,
search_space_id=space_id,
query_embedding=DUMMY_EMBEDDING,
)
assert len(results) >= 1
for result in results:
doc = result["document"]
assert "id" in doc
assert "title" in doc
assert doc["title"]
assert "document_type" in doc
assert doc["document_type"] is not None
async def test_matched_chunk_ids_tracked(db_session, seed_large_doc):
"""matched_chunk_ids should contain the chunk IDs that appeared in the RRF results."""
space_id = seed_large_doc["search_space"].id
retriever = ChucksHybridSearchRetriever(db_session)
results = await retriever.hybrid_search(
query_text="quarterly performance review",
top_k=10,
search_space_id=space_id,
query_embedding=DUMMY_EMBEDDING,
)
for result in results:
matched = result.get("matched_chunk_ids", [])
chunk_ids_in_result = {c["chunk_id"] for c in result["chunks"]}
for mid in matched:
assert mid in chunk_ids_in_result, (
f"matched_chunk_id {mid} not found in chunks"
)
async def test_chunks_ordered_by_id(db_session, seed_large_doc):
"""Chunks within each document should be ordered by chunk ID (original order)."""
space_id = seed_large_doc["search_space"].id
retriever = ChucksHybridSearchRetriever(db_session)
results = await retriever.hybrid_search(
query_text="quarterly performance review",
top_k=10,
search_space_id=space_id,
query_embedding=DUMMY_EMBEDDING,
)
for result in results:
chunk_ids = [c["chunk_id"] for c in result["chunks"]]
assert chunk_ids == sorted(chunk_ids), "Chunks not ordered by ID"
async def test_score_is_positive_float(db_session, seed_large_doc):
"""Each result should have a positive float score from RRF."""
space_id = seed_large_doc["search_space"].id
retriever = ChucksHybridSearchRetriever(db_session)
results = await retriever.hybrid_search(
query_text="quarterly performance review",
top_k=10,
search_space_id=space_id,
query_embedding=DUMMY_EMBEDDING,
)
assert len(results) >= 1
for result in results:
assert isinstance(result["score"], float)
assert result["score"] > 0

View file

@ -0,0 +1,76 @@
"""Integration tests for optimized DocumentHybridSearchRetriever.
Verifies the SQL ROW_NUMBER per-doc chunk limit and column pruning.
"""
import pytest
from app.retriever.documents_hybrid_search import (
_MAX_FETCH_CHUNKS_PER_DOC,
DocumentHybridSearchRetriever,
)
from .conftest import DUMMY_EMBEDDING
pytestmark = pytest.mark.integration
async def test_per_doc_chunk_limit_respected(db_session, seed_large_doc):
"""A document with 35 chunks should have at most _MAX_FETCH_CHUNKS_PER_DOC chunks returned."""
space_id = seed_large_doc["search_space"].id
retriever = DocumentHybridSearchRetriever(db_session)
results = await retriever.hybrid_search(
query_text="quarterly performance review",
top_k=10,
search_space_id=space_id,
query_embedding=DUMMY_EMBEDDING,
)
large_doc_id = seed_large_doc["large_doc"].id
for result in results:
if result["document"].get("id") == large_doc_id:
assert len(result["chunks"]) <= _MAX_FETCH_CHUNKS_PER_DOC
assert len(result["chunks"]) == _MAX_FETCH_CHUNKS_PER_DOC
break
else:
pytest.fail("Large doc not found in search results")
async def test_doc_metadata_populated(db_session, seed_large_doc):
"""Document metadata should be present from the RRF results."""
space_id = seed_large_doc["search_space"].id
retriever = DocumentHybridSearchRetriever(db_session)
results = await retriever.hybrid_search(
query_text="quarterly performance review",
top_k=10,
search_space_id=space_id,
query_embedding=DUMMY_EMBEDDING,
)
assert len(results) >= 1
for result in results:
doc = result["document"]
assert "id" in doc
assert "title" in doc
assert doc["title"]
assert "document_type" in doc
assert doc["document_type"] is not None
async def test_chunks_ordered_by_id(db_session, seed_large_doc):
"""Chunks within each document should be ordered by chunk ID."""
space_id = seed_large_doc["search_space"].id
retriever = DocumentHybridSearchRetriever(db_session)
results = await retriever.hybrid_search(
query_text="quarterly performance review",
top_k=10,
search_space_id=space_id,
query_embedding=DUMMY_EMBEDDING,
)
for result in results:
chunk_ids = [c["chunk_id"] for c in result["chunks"]]
assert chunk_ids == sorted(chunk_ids), "Chunks not ordered by ID"

View file

@ -42,14 +42,11 @@ def _to_markdown(page: dict) -> str:
if comments:
comments_content = "\n\n## Comments\n\n"
for comment in comments:
comment_body = (
comment.get("body", {}).get("storage", {}).get("value", "")
)
comment_body = comment.get("body", {}).get("storage", {}).get("value", "")
comment_author = comment.get("version", {}).get("authorId", "Unknown")
comment_date = comment.get("version", {}).get("createdAt", "")
comments_content += (
f"**Comment by {comment_author}** ({comment_date}):\n"
f"{comment_body}\n\n"
f"**Comment by {comment_author}** ({comment_date}):\n{comment_body}\n\n"
)
return f"# {page_title}\n\n{page_content}{comments_content}"
@ -138,22 +135,32 @@ def confluence_mocks(monkeypatch):
mock_connector = _mock_connector()
monkeypatch.setattr(
_mod, "get_connector_by_id", AsyncMock(return_value=mock_connector),
_mod,
"get_connector_by_id",
AsyncMock(return_value=mock_connector),
)
confluence_client = _mock_confluence_client(pages=[_make_page()])
monkeypatch.setattr(
_mod, "ConfluenceHistoryConnector", MagicMock(return_value=confluence_client),
_mod,
"ConfluenceHistoryConnector",
MagicMock(return_value=confluence_client),
)
monkeypatch.setattr(
_mod, "check_duplicate_document_by_hash", AsyncMock(return_value=None),
_mod,
"check_duplicate_document_by_hash",
AsyncMock(return_value=None),
)
monkeypatch.setattr(
_mod, "update_connector_last_indexed", AsyncMock(),
_mod,
"update_connector_last_indexed",
AsyncMock(),
)
monkeypatch.setattr(
_mod, "calculate_date_range", MagicMock(return_value=("2025-01-01", "2025-12-31")),
_mod,
"calculate_date_range",
MagicMock(return_value=("2025-01-01", "2025-12-31")),
)
mock_task_logger = MagicMock()
@ -162,15 +169,20 @@ def confluence_mocks(monkeypatch):
mock_task_logger.log_task_success = AsyncMock()
mock_task_logger.log_task_failure = AsyncMock()
monkeypatch.setattr(
_mod, "TaskLoggingService", MagicMock(return_value=mock_task_logger),
_mod,
"TaskLoggingService",
MagicMock(return_value=mock_task_logger),
)
batch_mock = AsyncMock(return_value=([], 1, 0))
pipeline_mock = MagicMock()
pipeline_mock.index_batch_parallel = batch_mock
pipeline_mock.migrate_legacy_docs = AsyncMock()
pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0)
monkeypatch.setattr(
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock),
_mod,
"IndexingPipelineService",
MagicMock(return_value=pipeline_mock),
)
return {

View file

@ -41,6 +41,7 @@ def mock_drive_client():
@pytest.fixture
def patch_extract(monkeypatch):
"""Provide a helper to set the download_and_extract_content mock."""
def _patch(side_effect=None, return_value=None):
mock = AsyncMock(side_effect=side_effect, return_value=return_value)
monkeypatch.setattr(
@ -48,11 +49,13 @@ def patch_extract(monkeypatch):
mock,
)
return mock
return _patch
async def test_single_file_returns_one_connector_document(
mock_drive_client, patch_extract,
mock_drive_client,
patch_extract,
):
"""Tracer bullet: downloading one file produces one ConnectorDocument."""
patch_extract(return_value=_mock_extract_ok("f1", "test.txt"))
@ -73,7 +76,8 @@ async def test_single_file_returns_one_connector_document(
async def test_multiple_files_all_produce_documents(
mock_drive_client, patch_extract,
mock_drive_client,
patch_extract,
):
"""All files are downloaded and converted to ConnectorDocuments."""
files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(3)]
@ -96,7 +100,8 @@ async def test_multiple_files_all_produce_documents(
async def test_one_download_exception_does_not_block_others(
mock_drive_client, patch_extract,
mock_drive_client,
patch_extract,
):
"""A RuntimeError in one download still lets the other files succeed."""
files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(3)]
@ -123,7 +128,8 @@ async def test_one_download_exception_does_not_block_others(
async def test_etl_error_counts_as_download_failure(
mock_drive_client, patch_extract,
mock_drive_client,
patch_extract,
):
"""download_and_extract_content returning an error is counted as failed."""
files = [_make_file_dict("f0", "good.txt"), _make_file_dict("f1", "bad.txt")]
@ -148,7 +154,8 @@ async def test_etl_error_counts_as_download_failure(
async def test_concurrency_bounded_by_semaphore(
mock_drive_client, monkeypatch,
mock_drive_client,
monkeypatch,
):
"""Peak concurrent downloads never exceeds max_concurrency."""
lock = asyncio.Lock()
@ -189,7 +196,8 @@ async def test_concurrency_bounded_by_semaphore(
async def test_heartbeat_fires_during_parallel_downloads(
mock_drive_client, monkeypatch,
mock_drive_client,
monkeypatch,
):
"""on_heartbeat is called at least once when downloads take time."""
import app.tasks.connector_indexers.google_drive_indexer as _mod
@ -231,8 +239,13 @@ async def test_heartbeat_fires_during_parallel_downloads(
# Slice 6, 6b, 6c -- _index_full_scan three-phase pipeline
# ---------------------------------------------------------------------------
def _folder_dict(file_id: str, name: str) -> dict:
return {"id": file_id, "name": name, "mimeType": "application/vnd.google-apps.folder"}
return {
"id": file_id,
"name": name,
"mimeType": "application/vnd.google-apps.folder",
}
@pytest.fixture
@ -259,12 +272,17 @@ def full_scan_mocks(mock_drive_client, monkeypatch):
batch_mock = AsyncMock(return_value=([], 0, 0))
pipeline_mock = MagicMock()
pipeline_mock.index_batch_parallel = batch_mock
pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0)
monkeypatch.setattr(
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock),
_mod,
"IndexingPipelineService",
MagicMock(return_value=pipeline_mock),
)
monkeypatch.setattr(
_mod, "get_user_long_context_llm", AsyncMock(return_value=MagicMock()),
_mod,
"get_user_long_context_llm",
AsyncMock(return_value=MagicMock()),
)
return {
@ -312,12 +330,16 @@ async def test_full_scan_three_phase_counts(full_scan_mocks, monkeypatch):
]
monkeypatch.setattr(
_mod, "get_files_in_folder",
_mod,
"get_files_in_folder",
AsyncMock(return_value=(page_files, None, None)),
)
full_scan_mocks["skip_results"]["skip1"] = (True, "unchanged")
full_scan_mocks["skip_results"]["rename1"] = (True, "File renamed: 'old''renamed.txt'")
full_scan_mocks["skip_results"]["rename1"] = (
True,
"File renamed: 'old''renamed.txt'",
)
mock_docs = [MagicMock(), MagicMock()]
full_scan_mocks["download_mock"].return_value = (mock_docs, 0)
@ -341,7 +363,8 @@ async def test_full_scan_respects_max_files(full_scan_mocks, monkeypatch):
page_files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(10)]
monkeypatch.setattr(
_mod, "get_files_in_folder",
_mod,
"get_files_in_folder",
AsyncMock(return_value=(page_files, None, None)),
)
@ -355,14 +378,16 @@ async def test_full_scan_respects_max_files(full_scan_mocks, monkeypatch):
async def test_full_scan_uses_max_concurrency_3_for_indexing(
full_scan_mocks, monkeypatch,
full_scan_mocks,
monkeypatch,
):
"""index_batch_parallel is called with max_concurrency=3."""
import app.tasks.connector_indexers.google_drive_indexer as _mod
page_files = [_make_file_dict("f1", "file1.txt")]
monkeypatch.setattr(
_mod, "get_files_in_folder",
_mod,
"get_files_in_folder",
AsyncMock(return_value=(page_files, None, None)),
)
@ -382,6 +407,7 @@ async def test_full_scan_uses_max_concurrency_3_for_indexing(
# Slice 7 -- _index_with_delta_sync three-phase pipeline
# ---------------------------------------------------------------------------
async def test_delta_sync_removals_serial_rest_parallel(monkeypatch):
"""Removed/trashed changes call _remove_document; the rest go through
_download_files_parallel and index_batch_parallel."""
@ -396,7 +422,8 @@ async def test_delta_sync_removals_serial_rest_parallel(monkeypatch):
]
monkeypatch.setattr(
_mod, "fetch_all_changes",
_mod,
"fetch_all_changes",
AsyncMock(return_value=(changes, "new-token", None)),
)
@ -408,7 +435,8 @@ async def test_delta_sync_removals_serial_rest_parallel(monkeypatch):
"mod2": "modified",
}
monkeypatch.setattr(
_mod, "categorize_change",
_mod,
"categorize_change",
lambda change: change_types[change["fileId"]],
)
@ -420,7 +448,8 @@ async def test_delta_sync_removals_serial_rest_parallel(monkeypatch):
monkeypatch.setattr(_mod, "_remove_document", _fake_remove)
monkeypatch.setattr(
_mod, "_should_skip_file",
_mod,
"_should_skip_file",
AsyncMock(return_value=(False, None)),
)
@ -431,11 +460,16 @@ async def test_delta_sync_removals_serial_rest_parallel(monkeypatch):
batch_mock = AsyncMock(return_value=([], 2, 0))
pipeline_mock = MagicMock()
pipeline_mock.index_batch_parallel = batch_mock
pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0)
monkeypatch.setattr(
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock),
_mod,
"IndexingPipelineService",
MagicMock(return_value=pipeline_mock),
)
monkeypatch.setattr(
_mod, "get_user_long_context_llm", AsyncMock(return_value=MagicMock()),
_mod,
"get_user_long_context_llm",
AsyncMock(return_value=MagicMock()),
)
mock_session = AsyncMock()
@ -472,6 +506,7 @@ async def test_delta_sync_removals_serial_rest_parallel(monkeypatch):
# _index_selected_files -- parallel indexing of user-selected files
# ---------------------------------------------------------------------------
@pytest.fixture
def selected_files_mocks(mock_drive_client, monkeypatch):
"""Wire up mocks for _index_selected_files tests."""
@ -496,6 +531,14 @@ def selected_files_mocks(mock_drive_client, monkeypatch):
download_and_index_mock = AsyncMock(return_value=(0, 0))
monkeypatch.setattr(_mod, "_download_and_index", download_and_index_mock)
pipeline_mock = MagicMock()
pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0)
monkeypatch.setattr(
_mod,
"IndexingPipelineService",
MagicMock(return_value=pipeline_mock),
)
return {
"drive_client": mock_drive_client,
"session": mock_session,
@ -526,7 +569,8 @@ async def test_selected_files_single_file_indexed(selected_files_mocks):
selected_files_mocks["download_and_index_mock"].return_value = (1, 0)
indexed, skipped, errors = await _run_selected(
selected_files_mocks, [("f1", "report.pdf")],
selected_files_mocks,
[("f1", "report.pdf")],
)
assert indexed == 1
@ -538,11 +582,13 @@ async def test_selected_files_single_file_indexed(selected_files_mocks):
async def test_selected_files_fetch_failure_isolation(selected_files_mocks):
"""get_file_by_id failing for one file collects an error; others still indexed."""
selected_files_mocks["get_file_results"]["f1"] = (
_make_file_dict("f1", "first.txt"), None,
_make_file_dict("f1", "first.txt"),
None,
)
selected_files_mocks["get_file_results"]["f2"] = (None, "HTTP 404")
selected_files_mocks["get_file_results"]["f3"] = (
_make_file_dict("f3", "third.txt"), None,
_make_file_dict("f3", "third.txt"),
None,
)
selected_files_mocks["download_and_index_mock"].return_value = (2, 0)
@ -561,30 +607,46 @@ async def test_selected_files_fetch_failure_isolation(selected_files_mocks):
async def test_selected_files_skip_rename_counting(selected_files_mocks):
"""Unchanged files are skipped, renames counted as indexed,
and only new files are sent to _download_and_index."""
for fid, fname in [("s1", "unchanged.txt"), ("r1", "renamed.txt"),
("n1", "new1.txt"), ("n2", "new2.txt")]:
for fid, fname in [
("s1", "unchanged.txt"),
("r1", "renamed.txt"),
("n1", "new1.txt"),
("n2", "new2.txt"),
]:
selected_files_mocks["get_file_results"][fid] = (
_make_file_dict(fid, fname), None,
_make_file_dict(fid, fname),
None,
)
selected_files_mocks["skip_results"]["s1"] = (True, "unchanged")
selected_files_mocks["skip_results"]["r1"] = (True, "File renamed: 'old' \u2192 'renamed.txt'")
selected_files_mocks["skip_results"]["r1"] = (
True,
"File renamed: 'old' \u2192 'renamed.txt'",
)
selected_files_mocks["download_and_index_mock"].return_value = (2, 0)
indexed, skipped, errors = await _run_selected(
selected_files_mocks,
[("s1", "unchanged.txt"), ("r1", "renamed.txt"),
("n1", "new1.txt"), ("n2", "new2.txt")],
[
("s1", "unchanged.txt"),
("r1", "renamed.txt"),
("n1", "new1.txt"),
("n2", "new2.txt"),
],
)
assert indexed == 3 # 1 renamed + 2 batch
assert skipped == 1 # 1 unchanged
assert indexed == 3 # 1 renamed + 2 batch
assert skipped == 1 # 1 unchanged
assert errors == []
mock = selected_files_mocks["download_and_index_mock"]
mock.assert_called_once()
call_files = mock.call_args[1].get("files") if "files" in (mock.call_args[1] or {}) else mock.call_args[0][2]
call_files = (
mock.call_args[1].get("files")
if "files" in (mock.call_args[1] or {})
else mock.call_args[0][2]
)
assert len(call_files) == 2
assert {f["id"] for f in call_files} == {"n1", "n2"}
@ -593,6 +655,7 @@ async def test_selected_files_skip_rename_counting(selected_files_mocks):
# asyncio.to_thread verification — prove blocking calls run in parallel
# ---------------------------------------------------------------------------
async def test_client_download_file_runs_in_thread_parallel():
"""Calling download_file concurrently via asyncio.gather should overlap
blocking work on separate threads, proving to_thread is effective.
@ -602,11 +665,11 @@ async def test_client_download_file_runs_in_thread_parallel():
"""
from app.connectors.google_drive.client import GoogleDriveClient
BLOCK_SECONDS = 0.2
NUM_CALLS = 3
block_seconds = 0.2
num_calls = 3
def _blocking_download(service, file_id, credentials):
time.sleep(BLOCK_SECONDS)
time.sleep(block_seconds)
return b"fake-content", None
client = GoogleDriveClient.__new__(GoogleDriveClient)
@ -615,11 +678,13 @@ async def test_client_download_file_runs_in_thread_parallel():
client._service_lock = asyncio.Lock()
with patch.object(
GoogleDriveClient, "_sync_download_file", staticmethod(_blocking_download),
GoogleDriveClient,
"_sync_download_file",
staticmethod(_blocking_download),
):
start = time.monotonic()
results = await asyncio.gather(
*(client.download_file(f"file-{i}") for i in range(NUM_CALLS))
*(client.download_file(f"file-{i}") for i in range(num_calls))
)
elapsed = time.monotonic() - start
@ -627,7 +692,7 @@ async def test_client_download_file_runs_in_thread_parallel():
assert content == b"fake-content"
assert error is None
serial_minimum = BLOCK_SECONDS * NUM_CALLS
serial_minimum = block_seconds * num_calls
assert elapsed < serial_minimum, (
f"Elapsed {elapsed:.2f}s >= serial minimum {serial_minimum:.2f}s — "
f"downloads are not running in parallel"
@ -638,11 +703,11 @@ async def test_client_export_google_file_runs_in_thread_parallel():
"""Same strategy for export_google_file — verify to_thread parallelism."""
from app.connectors.google_drive.client import GoogleDriveClient
BLOCK_SECONDS = 0.2
NUM_CALLS = 3
block_seconds = 0.2
num_calls = 3
def _blocking_export(service, file_id, mime_type, credentials):
time.sleep(BLOCK_SECONDS)
time.sleep(block_seconds)
return b"exported", None
client = GoogleDriveClient.__new__(GoogleDriveClient)
@ -651,12 +716,16 @@ async def test_client_export_google_file_runs_in_thread_parallel():
client._service_lock = asyncio.Lock()
with patch.object(
GoogleDriveClient, "_sync_export_google_file", staticmethod(_blocking_export),
GoogleDriveClient,
"_sync_export_google_file",
staticmethod(_blocking_export),
):
start = time.monotonic()
results = await asyncio.gather(
*(client.export_google_file(f"file-{i}", "application/pdf")
for i in range(NUM_CALLS))
*(
client.export_google_file(f"file-{i}", "application/pdf")
for i in range(num_calls)
)
)
elapsed = time.monotonic() - start
@ -664,7 +733,7 @@ async def test_client_export_google_file_runs_in_thread_parallel():
assert content == b"exported"
assert error is None
serial_minimum = BLOCK_SECONDS * NUM_CALLS
serial_minimum = block_seconds * num_calls
assert elapsed < serial_minimum, (
f"Elapsed {elapsed:.2f}s >= serial minimum {serial_minimum:.2f}s — "
f"exports are not running in parallel"

View file

@ -145,22 +145,32 @@ def jira_mocks(monkeypatch):
mock_connector = _mock_connector()
monkeypatch.setattr(
_mod, "get_connector_by_id", AsyncMock(return_value=mock_connector),
_mod,
"get_connector_by_id",
AsyncMock(return_value=mock_connector),
)
jira_client = _mock_jira_client(issues=[_make_issue()])
monkeypatch.setattr(
_mod, "JiraHistoryConnector", MagicMock(return_value=jira_client),
_mod,
"JiraHistoryConnector",
MagicMock(return_value=jira_client),
)
monkeypatch.setattr(
_mod, "check_duplicate_document_by_hash", AsyncMock(return_value=None),
_mod,
"check_duplicate_document_by_hash",
AsyncMock(return_value=None),
)
monkeypatch.setattr(
_mod, "update_connector_last_indexed", AsyncMock(),
_mod,
"update_connector_last_indexed",
AsyncMock(),
)
monkeypatch.setattr(
_mod, "calculate_date_range", MagicMock(return_value=("2025-01-01", "2025-12-31")),
_mod,
"calculate_date_range",
MagicMock(return_value=("2025-01-01", "2025-12-31")),
)
mock_task_logger = MagicMock()
@ -169,15 +179,20 @@ def jira_mocks(monkeypatch):
mock_task_logger.log_task_success = AsyncMock()
mock_task_logger.log_task_failure = AsyncMock()
monkeypatch.setattr(
_mod, "TaskLoggingService", MagicMock(return_value=mock_task_logger),
_mod,
"TaskLoggingService",
MagicMock(return_value=mock_task_logger),
)
batch_mock = AsyncMock(return_value=([], 1, 0))
pipeline_mock = MagicMock()
pipeline_mock.index_batch_parallel = batch_mock
pipeline_mock.migrate_legacy_docs = AsyncMock()
pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0)
monkeypatch.setattr(
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock),
_mod,
"IndexingPipelineService",
MagicMock(return_value=pipeline_mock),
)
return {

View file

@ -128,13 +128,17 @@ def _mock_linear_client(issues=None, error=None):
client.get_issues_by_date_range = AsyncMock(
return_value=(issues if issues is not None else [], error),
)
client.format_issue = MagicMock(side_effect=lambda i: _make_formatted_issue(
issue_id=i.get("id", ""),
identifier=i.get("identifier", ""),
title=i.get("title", ""),
))
client.format_issue = MagicMock(
side_effect=lambda i: _make_formatted_issue(
issue_id=i.get("id", ""),
identifier=i.get("identifier", ""),
title=i.get("title", ""),
)
)
client.format_issue_to_markdown = MagicMock(
side_effect=lambda fi: f"# {fi.get('identifier', '')}: {fi.get('title', '')}\n\nContent"
side_effect=lambda fi: (
f"# {fi.get('identifier', '')}: {fi.get('title', '')}\n\nContent"
)
)
return client
@ -147,24 +151,34 @@ def linear_mocks(monkeypatch):
mock_connector = _mock_connector()
monkeypatch.setattr(
_mod, "get_connector_by_id", AsyncMock(return_value=mock_connector),
_mod,
"get_connector_by_id",
AsyncMock(return_value=mock_connector),
)
linear_client = _mock_linear_client(issues=[_make_issue()])
monkeypatch.setattr(
_mod, "LinearConnector", MagicMock(return_value=linear_client),
_mod,
"LinearConnector",
MagicMock(return_value=linear_client),
)
monkeypatch.setattr(
_mod, "check_duplicate_document_by_hash", AsyncMock(return_value=None),
_mod,
"check_duplicate_document_by_hash",
AsyncMock(return_value=None),
)
monkeypatch.setattr(
_mod, "update_connector_last_indexed", AsyncMock(),
_mod,
"update_connector_last_indexed",
AsyncMock(),
)
monkeypatch.setattr(
_mod, "calculate_date_range", MagicMock(return_value=("2025-01-01", "2025-12-31")),
_mod,
"calculate_date_range",
MagicMock(return_value=("2025-01-01", "2025-12-31")),
)
mock_task_logger = MagicMock()
@ -173,15 +187,20 @@ def linear_mocks(monkeypatch):
mock_task_logger.log_task_success = AsyncMock()
mock_task_logger.log_task_failure = AsyncMock()
monkeypatch.setattr(
_mod, "TaskLoggingService", MagicMock(return_value=mock_task_logger),
_mod,
"TaskLoggingService",
MagicMock(return_value=mock_task_logger),
)
batch_mock = AsyncMock(return_value=([], 1, 0))
pipeline_mock = MagicMock()
pipeline_mock.index_batch_parallel = batch_mock
pipeline_mock.migrate_legacy_docs = AsyncMock()
pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0)
monkeypatch.setattr(
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock),
_mod,
"IndexingPipelineService",
MagicMock(return_value=pipeline_mock),
)
return {
@ -255,7 +274,7 @@ async def test_issues_with_missing_id_are_skipped(linear_mocks):
]
linear_mocks["linear_client"].get_issues_by_date_range.return_value = (issues, None)
indexed, skipped, _ = await _run_index(linear_mocks)
_indexed, skipped, _ = await _run_index(linear_mocks)
connector_docs = linear_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
@ -271,7 +290,7 @@ async def test_issues_with_missing_title_are_skipped(linear_mocks):
]
linear_mocks["linear_client"].get_issues_by_date_range.return_value = (issues, None)
indexed, skipped, _ = await _run_index(linear_mocks)
_indexed, skipped, _ = await _run_index(linear_mocks)
connector_docs = linear_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
@ -305,7 +324,7 @@ async def test_duplicate_content_issues_are_skipped(linear_mocks, monkeypatch):
monkeypatch.setattr(_mod, "check_duplicate_document_by_hash", _check_dup)
indexed, skipped, _ = await _run_index(linear_mocks)
_indexed, skipped, _ = await _run_index(linear_mocks)
connector_docs = linear_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1

View file

@ -107,28 +107,40 @@ def notion_mocks(monkeypatch):
mock_connector = _mock_connector()
monkeypatch.setattr(
_mod, "get_connector_by_id", AsyncMock(return_value=mock_connector),
_mod,
"get_connector_by_id",
AsyncMock(return_value=mock_connector),
)
notion_client = _mock_notion_client(pages=[_make_page()])
monkeypatch.setattr(
_mod, "NotionHistoryConnector", MagicMock(return_value=notion_client),
_mod,
"NotionHistoryConnector",
MagicMock(return_value=notion_client),
)
monkeypatch.setattr(
_mod, "check_duplicate_document_by_hash", AsyncMock(return_value=None),
_mod,
"check_duplicate_document_by_hash",
AsyncMock(return_value=None),
)
monkeypatch.setattr(
_mod, "update_connector_last_indexed", AsyncMock(),
_mod,
"update_connector_last_indexed",
AsyncMock(),
)
monkeypatch.setattr(
_mod, "calculate_date_range", MagicMock(return_value=("2025-01-01", "2025-12-31")),
_mod,
"calculate_date_range",
MagicMock(return_value=("2025-01-01", "2025-12-31")),
)
monkeypatch.setattr(
_mod, "process_blocks", MagicMock(return_value="Converted markdown content"),
_mod,
"process_blocks",
MagicMock(return_value="Converted markdown content"),
)
mock_task_logger = MagicMock()
@ -137,15 +149,20 @@ def notion_mocks(monkeypatch):
mock_task_logger.log_task_success = AsyncMock()
mock_task_logger.log_task_failure = AsyncMock()
monkeypatch.setattr(
_mod, "TaskLoggingService", MagicMock(return_value=mock_task_logger),
_mod,
"TaskLoggingService",
MagicMock(return_value=mock_task_logger),
)
batch_mock = AsyncMock(return_value=([], 1, 0))
pipeline_mock = MagicMock()
pipeline_mock.index_batch_parallel = batch_mock
pipeline_mock.migrate_legacy_docs = AsyncMock()
pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0)
monkeypatch.setattr(
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock),
_mod,
"IndexingPipelineService",
MagicMock(return_value=pipeline_mock),
)
return {
@ -216,7 +233,10 @@ async def test_pages_with_missing_id_are_skipped(notion_mocks, monkeypatch):
"""Pages without page_id are skipped and not passed to the pipeline."""
pages = [
_make_page(page_id="valid-1"),
{"title": "No ID page", "content": [{"type": "paragraph", "content": "text", "children": []}]},
{
"title": "No ID page",
"content": [{"type": "paragraph", "content": "text", "children": []}],
},
]
notion_mocks["notion_client"].get_all_pages.return_value = pages

View file

@ -0,0 +1,131 @@
"""Unit tests for IndexingPipelineService.create_placeholder_documents."""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
import pytest
from sqlalchemy.exc import IntegrityError
from app.db import DocumentStatus, DocumentType
from app.indexing_pipeline.document_hashing import compute_identifier_hash
from app.indexing_pipeline.indexing_pipeline_service import (
IndexingPipelineService,
PlaceholderInfo,
)
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_placeholder(**overrides) -> PlaceholderInfo:
    """Build a PlaceholderInfo with test defaults; keyword overrides win."""
    kwargs = dict(
        title="Test Doc",
        document_type=DocumentType.GOOGLE_DRIVE_FILE,
        unique_id="file-001",
        search_space_id=1,
        connector_id=42,
        created_by_id="00000000-0000-0000-0000-000000000001",
    )
    kwargs.update(overrides)
    return PlaceholderInfo(**kwargs)
def _uid_hash(p: PlaceholderInfo) -> str:
    """Compute the unique-identifier hash for *p*, mirroring the service's keying."""
    doc_type, uid, space_id = p.document_type.value, p.unique_id, p.search_space_id
    return compute_identifier_hash(doc_type, uid, space_id)
def _session_with_existing_hashes(existing: set[str] | None = None):
"""Build an AsyncMock session whose batch-query returns *existing* hashes."""
session = AsyncMock()
result = MagicMock()
result.scalars.return_value.all.return_value = list(existing or [])
session.execute = AsyncMock(return_value=result)
session.add = MagicMock()
return session
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
async def test_empty_input_returns_zero_without_db_calls():
    """An empty placeholder list short-circuits: no queries, no commit."""
    mock_session = AsyncMock()
    service = IndexingPipelineService(mock_session)
    created = await service.create_placeholder_documents([])
    assert created == 0
    mock_session.execute.assert_not_awaited()
    mock_session.commit.assert_not_awaited()
async def test_creates_documents_with_pending_status_and_commits():
    """A new placeholder becomes a PENDING Document and the session is committed."""
    mock_session = _session_with_existing_hashes(set())
    service = IndexingPipelineService(mock_session)
    info = _make_placeholder(title="My File", unique_id="file-abc")
    created = await service.create_placeholder_documents([info])
    assert created == 1
    mock_session.add.assert_called_once()
    # Inspect the ORM object handed to session.add().
    added_doc = mock_session.add.call_args[0][0]
    assert added_doc.title == "My File"
    assert added_doc.document_type == DocumentType.GOOGLE_DRIVE_FILE
    assert added_doc.content == "Pending..."
    assert DocumentStatus.is_state(added_doc.status, DocumentStatus.PENDING)
    assert added_doc.search_space_id == 1
    assert added_doc.connector_id == 42
    mock_session.commit.assert_awaited_once()
async def test_existing_documents_are_skipped():
    """Placeholders whose unique_identifier_hash already exists are not re-created."""
    already_indexed = _make_placeholder(unique_id="already-there")
    fresh = _make_placeholder(unique_id="brand-new")
    mock_session = _session_with_existing_hashes({_uid_hash(already_indexed)})
    service = IndexingPipelineService(mock_session)
    created = await service.create_placeholder_documents([already_indexed, fresh])
    assert created == 1
    # Only the fresh placeholder should have been added to the session.
    added_doc = mock_session.add.call_args[0][0]
    assert added_doc.unique_identifier_hash == _uid_hash(fresh)
async def test_duplicate_unique_ids_within_input_are_deduped():
    """Same unique_id passed twice only produces one placeholder."""
    first = _make_placeholder(unique_id="dup-id", title="First")
    second = _make_placeholder(unique_id="dup-id", title="Second")
    mock_session = _session_with_existing_hashes(set())
    service = IndexingPipelineService(mock_session)
    created = await service.create_placeholder_documents([first, second])
    assert created == 1
    mock_session.add.assert_called_once()
async def test_integrity_error_on_commit_returns_zero():
    """IntegrityError during commit (race condition) is swallowed gracefully."""
    mock_session = _session_with_existing_hashes(set())
    mock_session.commit = AsyncMock(side_effect=IntegrityError("dup", {}, None))
    service = IndexingPipelineService(mock_session)
    created = await service.create_placeholder_documents([_make_placeholder()])
    assert created == 0
    mock_session.rollback.assert_awaited_once()

View file

@ -19,9 +19,7 @@ def pipeline(mock_session):
return IndexingPipelineService(mock_session)
async def test_calls_prepare_then_index_per_document(
pipeline, make_connector_document
):
async def test_calls_prepare_then_index_per_document(pipeline, make_connector_document):
"""index_batch calls prepare_for_indexing, then index() for each returned doc."""
doc1 = make_connector_document(
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,

View file

@ -1,5 +1,5 @@
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock, MagicMock
import pytest
@ -57,7 +57,9 @@ async def test_index_calls_embed_and_chunk_via_to_thread(
"app.indexing_pipeline.indexing_pipeline_service.chunk_text",
mock_chunk,
)
mock_embed = MagicMock(side_effect=lambda texts: [[0.1] * _EMBEDDING_DIM for _ in texts])
mock_embed = MagicMock(
side_effect=lambda texts: [[0.1] * _EMBEDDING_DIM for _ in texts]
)
mock_embed.__name__ = "embed_texts"
monkeypatch.setattr(
"app.indexing_pipeline.indexing_pipeline_service.embed_texts",

View file

@ -0,0 +1,110 @@
"""Unit tests for the duplicate-content safety logic in prepare_for_indexing.
Verifies that when an existing document's updated content matches another
document's content_hash, the system marks it as failed (for placeholders)
or leaves it untouched (for ready documents) — never deletes.
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.db import Document, DocumentStatus, DocumentType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_hashing import (
compute_unique_identifier_hash,
)
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_connector_doc(**overrides) -> ConnectorDocument:
    """Build a ConnectorDocument with test defaults; keyword overrides win."""
    kwargs = dict(
        title="Test Doc",
        source_markdown="## Some new content",
        unique_id="file-001",
        document_type=DocumentType.GOOGLE_DRIVE_FILE,
        search_space_id=1,
        connector_id=42,
        created_by_id="00000000-0000-0000-0000-000000000001",
    )
    kwargs.update(overrides)
    return ConnectorDocument(**kwargs)
def _make_existing_doc(connector_doc: ConnectorDocument, *, status: dict) -> MagicMock:
    """Build a MagicMock that looks like an ORM Document with given status."""
    mock_doc = MagicMock(spec=Document)
    mock_doc.unique_identifier_hash = compute_unique_identifier_hash(connector_doc)
    mock_doc.id = 999
    mock_doc.title = connector_doc.title
    mock_doc.content_hash = "old-placeholder-content-hash"
    mock_doc.status = status
    return mock_doc
def _mock_session_for_dedup(existing_doc, *, has_duplicate: bool):
"""Build a session whose sequential execute() calls return:
1. The *existing_doc* for the unique_identifier_hash lookup.
2. A row (or None) for the duplicate content_hash check.
"""
session = AsyncMock()
existing_result = MagicMock()
existing_result.scalars.return_value.first.return_value = existing_doc
dup_result = MagicMock()
dup_result.scalars.return_value.first.return_value = 42 if has_duplicate else None
session.execute = AsyncMock(side_effect=[existing_result, dup_result])
session.add = MagicMock()
return session
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
async def test_pending_placeholder_with_duplicate_content_is_marked_failed():
    """A pending placeholder whose updated content collides with another
    document's content_hash is marked FAILED — it must never be deleted."""
    incoming = _make_connector_doc(source_markdown="## Shared content")
    placeholder = _make_existing_doc(incoming, status=DocumentStatus.pending())
    mock_session = _mock_session_for_dedup(placeholder, has_duplicate=True)
    service = IndexingPipelineService(mock_session)
    prepared = await service.prepare_for_indexing([incoming])
    assert prepared == [], "duplicate should not be returned for indexing"
    assert DocumentStatus.is_state(placeholder.status, DocumentStatus.FAILED)
    assert "Duplicate content" in placeholder.status.get("reason", "")
    mock_session.delete.assert_not_called()
async def test_ready_document_with_duplicate_content_is_left_untouched():
    """A READY document whose updated content duplicates another document is
    left entirely alone — neither failed nor deleted."""
    incoming = _make_connector_doc(source_markdown="## Shared content")
    ready_doc = _make_existing_doc(incoming, status=DocumentStatus.ready())
    mock_session = _mock_session_for_dedup(ready_doc, has_duplicate=True)
    service = IndexingPipelineService(mock_session)
    prepared = await service.prepare_for_indexing([incoming])
    assert prepared == [], "duplicate should not be returned for indexing"
    assert DocumentStatus.is_state(ready_doc.status, DocumentStatus.READY)
    mock_session.delete.assert_not_called()

View file

@ -0,0 +1,133 @@
"""Unit tests for knowledge_search middleware helpers.
These test pure functions that don't require a database.
"""
import pytest
from app.agents.new_chat.middleware.knowledge_search import (
_build_document_xml,
_resolve_search_types,
)
pytestmark = pytest.mark.unit
# ── _resolve_search_types ──────────────────────────────────────────────
class TestResolveSearchTypes:
    """Behavior of _resolve_search_types: merging, legacy expansion, dedup."""

    def test_returns_none_when_no_inputs(self):
        assert _resolve_search_types(None, None) is None

    def test_returns_none_when_both_empty(self):
        assert _resolve_search_types([], []) is None

    def test_includes_legacy_type_for_google_gmail(self):
        resolved = _resolve_search_types(["GOOGLE_GMAIL_CONNECTOR"], None)
        assert "GOOGLE_GMAIL_CONNECTOR" in resolved
        assert "COMPOSIO_GMAIL_CONNECTOR" in resolved

    def test_includes_legacy_type_for_google_drive(self):
        resolved = _resolve_search_types(None, ["GOOGLE_DRIVE_FILE"])
        assert "GOOGLE_DRIVE_FILE" in resolved
        assert "COMPOSIO_GOOGLE_DRIVE_CONNECTOR" in resolved

    def test_includes_legacy_type_for_google_calendar(self):
        resolved = _resolve_search_types(["GOOGLE_CALENDAR_CONNECTOR"], None)
        assert "GOOGLE_CALENDAR_CONNECTOR" in resolved
        assert "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR" in resolved

    def test_no_legacy_expansion_for_unrelated_types(self):
        resolved = _resolve_search_types(["FILE", "NOTE"], None)
        assert set(resolved) == {"FILE", "NOTE"}

    def test_combines_connectors_and_document_types(self):
        resolved = _resolve_search_types(["FILE"], ["NOTE", "CRAWLED_URL"])
        assert {"FILE", "NOTE", "CRAWLED_URL"}.issubset(set(resolved))

    def test_deduplicates(self):
        resolved = _resolve_search_types(["FILE", "FILE"], ["FILE"])
        assert resolved.count("FILE") == 1
# ── _build_document_xml ────────────────────────────────────────────────
class TestBuildDocumentXml:
@pytest.fixture
def sample_document(self):
return {
"document_id": 42,
"document": {
"id": 42,
"document_type": "FILE",
"title": "Test Doc",
"metadata": {"url": "https://example.com"},
},
"chunks": [
{"chunk_id": 101, "content": "First chunk content"},
{"chunk_id": 102, "content": "Second chunk content"},
{"chunk_id": 103, "content": "Third chunk content"},
],
}
def test_contains_document_metadata(self, sample_document):
xml = _build_document_xml(sample_document)
assert "<document_id>42</document_id>" in xml
assert "<document_type>FILE</document_type>" in xml
assert "Test Doc" in xml
def test_contains_chunk_index(self, sample_document):
xml = _build_document_xml(sample_document)
assert "<chunk_index>" in xml
assert "</chunk_index>" in xml
assert 'chunk_id="101"' in xml
assert 'chunk_id="102"' in xml
assert 'chunk_id="103"' in xml
def test_matched_chunks_flagged_in_index(self, sample_document):
xml = _build_document_xml(sample_document, matched_chunk_ids={101, 103})
lines = xml.split("\n")
for line in lines:
if 'chunk_id="101"' in line:
assert 'matched="true"' in line
if 'chunk_id="102"' in line:
assert 'matched="true"' not in line
if 'chunk_id="103"' in line:
assert 'matched="true"' in line
def test_chunk_content_in_document_content_section(self, sample_document):
xml = _build_document_xml(sample_document)
assert "<document_content>" in xml
assert "First chunk content" in xml
assert "Second chunk content" in xml
assert "Third chunk content" in xml
def test_line_numbers_in_chunk_index_are_accurate(self, sample_document):
"""Verify that the line ranges in chunk_index actually point to the right content."""
xml = _build_document_xml(sample_document, matched_chunk_ids={101})
xml_lines = xml.split("\n")
for line in xml_lines:
if 'chunk_id="101"' in line and "lines=" in line:
import re
m = re.search(r'lines="(\d+)-(\d+)"', line)
assert m, f"No lines= attribute found in: {line}"
start, _end = int(m.group(1)), int(m.group(2))
target_line = xml_lines[start - 1]
assert "101" in target_line
assert "First chunk content" in target_line
break
else:
pytest.fail("chunk_id=101 entry not found in chunk_index")
def test_splits_into_lines_correctly(self, sample_document):
"""Each chunk occupies exactly one line (no embedded newlines)."""
xml = _build_document_xml(sample_document)
lines = xml.split("\n")
chunk_lines = [
line for line in lines if "<![CDATA[" in line and "<chunk" in line
]
assert len(chunk_lines) == 3