mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-26 01:06:23 +02:00
Merge remote-tracking branch 'upstream/dev' into feat/onedrive-connector
This commit is contained in:
commit
5a3eece397
70 changed files with 8288 additions and 5698 deletions
|
|
@ -14,7 +14,9 @@ _EMBEDDING_DIM = app_config.embedding_model_instance.dimension
|
|||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
def _cal_doc(*, unique_id: str, search_space_id: int, connector_id: int, user_id: str) -> ConnectorDocument:
|
||||
def _cal_doc(
|
||||
*, unique_id: str, search_space_id: int, connector_id: int, user_id: str
|
||||
) -> ConnectorDocument:
|
||||
return ConnectorDocument(
|
||||
title=f"Event {unique_id}",
|
||||
source_markdown=f"## Calendar Event\n\nDetails for {unique_id}",
|
||||
|
|
@ -34,7 +36,9 @@ def _cal_doc(*, unique_id: str, search_space_id: int, connector_id: int, user_id
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
|
||||
@pytest.mark.usefixtures(
|
||||
"patched_summarize", "patched_embed_texts", "patched_chunk_text"
|
||||
)
|
||||
async def test_calendar_pipeline_creates_ready_document(
|
||||
db_session, db_search_space, db_connector, db_user, mocker
|
||||
):
|
||||
|
|
@ -63,7 +67,9 @@ async def test_calendar_pipeline_creates_ready_document(
|
|||
assert DocumentStatus.is_state(row.status, DocumentStatus.READY)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
|
||||
@pytest.mark.usefixtures(
|
||||
"patched_summarize", "patched_embed_texts", "patched_chunk_text"
|
||||
)
|
||||
async def test_calendar_legacy_doc_migrated(
|
||||
db_session, db_search_space, db_connector, db_user, mocker
|
||||
):
|
||||
|
|
@ -101,7 +107,9 @@ async def test_calendar_legacy_doc_migrated(
|
|||
service = IndexingPipelineService(session=db_session)
|
||||
await service.migrate_legacy_docs([connector_doc])
|
||||
|
||||
result = await db_session.execute(select(Document).filter(Document.id == original_id))
|
||||
result = await db_session.execute(
|
||||
select(Document).filter(Document.id == original_id)
|
||||
)
|
||||
row = result.scalars().first()
|
||||
|
||||
assert row.document_type == DocumentType.GOOGLE_CALENDAR_CONNECTOR
|
||||
|
|
|
|||
|
|
@ -14,7 +14,9 @@ _EMBEDDING_DIM = app_config.embedding_model_instance.dimension
|
|||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
def _drive_doc(*, unique_id: str, search_space_id: int, connector_id: int, user_id: str) -> ConnectorDocument:
|
||||
def _drive_doc(
|
||||
*, unique_id: str, search_space_id: int, connector_id: int, user_id: str
|
||||
) -> ConnectorDocument:
|
||||
return ConnectorDocument(
|
||||
title=f"File {unique_id}.pdf",
|
||||
source_markdown=f"## Document Content\n\nText from file {unique_id}",
|
||||
|
|
@ -33,7 +35,9 @@ def _drive_doc(*, unique_id: str, search_space_id: int, connector_id: int, user_
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
|
||||
@pytest.mark.usefixtures(
|
||||
"patched_summarize", "patched_embed_texts", "patched_chunk_text"
|
||||
)
|
||||
async def test_drive_pipeline_creates_ready_document(
|
||||
db_session, db_search_space, db_connector, db_user, mocker
|
||||
):
|
||||
|
|
@ -62,7 +66,9 @@ async def test_drive_pipeline_creates_ready_document(
|
|||
assert DocumentStatus.is_state(row.status, DocumentStatus.READY)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
|
||||
@pytest.mark.usefixtures(
|
||||
"patched_summarize", "patched_embed_texts", "patched_chunk_text"
|
||||
)
|
||||
async def test_drive_legacy_doc_migrated(
|
||||
db_session, db_search_space, db_connector, db_user, mocker
|
||||
):
|
||||
|
|
@ -100,7 +106,9 @@ async def test_drive_legacy_doc_migrated(
|
|||
service = IndexingPipelineService(session=db_session)
|
||||
await service.migrate_legacy_docs([connector_doc])
|
||||
|
||||
result = await db_session.execute(select(Document).filter(Document.id == original_id))
|
||||
result = await db_session.execute(
|
||||
select(Document).filter(Document.id == original_id)
|
||||
)
|
||||
row = result.scalars().first()
|
||||
|
||||
assert row.document_type == DocumentType.GOOGLE_DRIVE_FILE
|
||||
|
|
@ -111,7 +119,9 @@ async def test_drive_legacy_doc_migrated(
|
|||
|
||||
|
||||
async def test_should_skip_file_skips_failed_document(
|
||||
db_session, db_search_space, db_user,
|
||||
db_session,
|
||||
db_search_space,
|
||||
db_user,
|
||||
):
|
||||
"""A FAILED document with unchanged md5 must be skipped — user can manually retry via Quick Index."""
|
||||
import importlib
|
||||
|
|
@ -162,7 +172,12 @@ async def test_should_skip_file_skips_failed_document(
|
|||
db_session.add(failed_doc)
|
||||
await db_session.flush()
|
||||
|
||||
incoming_file = {"id": file_id, "name": "Failed File.pdf", "mimeType": "application/pdf", "md5Checksum": md5}
|
||||
incoming_file = {
|
||||
"id": file_id,
|
||||
"name": "Failed File.pdf",
|
||||
"mimeType": "application/pdf",
|
||||
"md5Checksum": md5,
|
||||
}
|
||||
|
||||
should_skip, msg = await _should_skip_file(db_session, incoming_file, space_id)
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ from app.db import Document, DocumentStatus, DocumentType
|
|||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||
from app.indexing_pipeline.document_hashing import (
|
||||
compute_identifier_hash,
|
||||
compute_unique_identifier_hash,
|
||||
)
|
||||
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
|
||||
|
||||
|
|
@ -17,7 +16,9 @@ _EMBEDDING_DIM = app_config.embedding_model_instance.dimension
|
|||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
def _gmail_doc(*, unique_id: str, search_space_id: int, connector_id: int, user_id: str) -> ConnectorDocument:
|
||||
def _gmail_doc(
|
||||
*, unique_id: str, search_space_id: int, connector_id: int, user_id: str
|
||||
) -> ConnectorDocument:
|
||||
"""Build a Gmail-style ConnectorDocument like the real indexer does."""
|
||||
return ConnectorDocument(
|
||||
title=f"Subject for {unique_id}",
|
||||
|
|
@ -37,7 +38,9 @@ def _gmail_doc(*, unique_id: str, search_space_id: int, connector_id: int, user_
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
|
||||
@pytest.mark.usefixtures(
|
||||
"patched_summarize", "patched_embed_texts", "patched_chunk_text"
|
||||
)
|
||||
async def test_gmail_pipeline_creates_ready_document(
|
||||
db_session, db_search_space, db_connector, db_user, mocker
|
||||
):
|
||||
|
|
@ -67,7 +70,9 @@ async def test_gmail_pipeline_creates_ready_document(
|
|||
assert row.source_markdown == doc.source_markdown
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
|
||||
@pytest.mark.usefixtures(
|
||||
"patched_summarize", "patched_embed_texts", "patched_chunk_text"
|
||||
)
|
||||
async def test_gmail_legacy_doc_migrated_then_reused(
|
||||
db_session, db_search_space, db_connector, db_user, mocker
|
||||
):
|
||||
|
|
|
|||
|
|
@ -9,7 +9,9 @@ from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineServ
|
|||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
|
||||
@pytest.mark.usefixtures(
|
||||
"patched_summarize", "patched_embed_texts", "patched_chunk_text"
|
||||
)
|
||||
async def test_index_batch_creates_ready_documents(
|
||||
db_session, db_search_space, make_connector_document, mocker
|
||||
):
|
||||
|
|
@ -47,7 +49,9 @@ async def test_index_batch_creates_ready_documents(
|
|||
assert row.embedding is not None
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
|
||||
@pytest.mark.usefixtures(
|
||||
"patched_summarize", "patched_embed_texts", "patched_chunk_text"
|
||||
)
|
||||
async def test_index_batch_empty_returns_empty(db_session, mocker):
|
||||
"""index_batch with empty input returns an empty list."""
|
||||
service = IndexingPipelineService(session=db_session)
|
||||
|
|
|
|||
106
surfsense_backend/tests/integration/retriever/conftest.py
Normal file
106
surfsense_backend/tests/integration/retriever/conftest.py
Normal file
|
|
@ -0,0 +1,106 @@
|
|||
"""Shared fixtures for retriever integration tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import pytest_asyncio
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config as app_config
|
||||
from app.db import Chunk, Document, DocumentType, SearchSpace, User
|
||||
|
||||
EMBEDDING_DIM = app_config.embedding_model_instance.dimension
|
||||
DUMMY_EMBEDDING = [0.1] * EMBEDDING_DIM
|
||||
|
||||
|
||||
def _make_document(
|
||||
*,
|
||||
title: str,
|
||||
document_type: DocumentType,
|
||||
content: str,
|
||||
search_space_id: int,
|
||||
created_by_id: str,
|
||||
) -> Document:
|
||||
uid = uuid.uuid4().hex[:12]
|
||||
return Document(
|
||||
title=title,
|
||||
document_type=document_type,
|
||||
content=content,
|
||||
content_hash=f"content-{uid}",
|
||||
unique_identifier_hash=f"uid-{uid}",
|
||||
source_markdown=content,
|
||||
search_space_id=search_space_id,
|
||||
created_by_id=created_by_id,
|
||||
embedding=DUMMY_EMBEDDING,
|
||||
updated_at=datetime.now(UTC),
|
||||
status={"state": "ready"},
|
||||
)
|
||||
|
||||
|
||||
def _make_chunk(*, content: str, document_id: int) -> Chunk:
|
||||
return Chunk(
|
||||
content=content,
|
||||
document_id=document_id,
|
||||
embedding=DUMMY_EMBEDDING,
|
||||
)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def seed_large_doc(
|
||||
db_session: AsyncSession, db_user: User, db_search_space: SearchSpace
|
||||
):
|
||||
"""Insert a document with 35 chunks (more than _MAX_FETCH_CHUNKS_PER_DOC=20).
|
||||
|
||||
Also inserts a small 3-chunk document for diversity testing.
|
||||
Returns a dict with ``large_doc``, ``small_doc``, ``search_space``, ``user``,
|
||||
and ``large_chunk_ids`` (all 35 chunk IDs).
|
||||
"""
|
||||
user_id = str(db_user.id)
|
||||
space_id = db_search_space.id
|
||||
|
||||
large_doc = _make_document(
|
||||
title="Large PDF Document",
|
||||
document_type=DocumentType.FILE,
|
||||
content="large document about quarterly performance reviews and budgets",
|
||||
search_space_id=space_id,
|
||||
created_by_id=user_id,
|
||||
)
|
||||
small_doc = _make_document(
|
||||
title="Small Note",
|
||||
document_type=DocumentType.NOTE,
|
||||
content="quarterly performance review summary note",
|
||||
search_space_id=space_id,
|
||||
created_by_id=user_id,
|
||||
)
|
||||
|
||||
db_session.add_all([large_doc, small_doc])
|
||||
await db_session.flush()
|
||||
|
||||
large_chunks = []
|
||||
for i in range(35):
|
||||
chunk = _make_chunk(
|
||||
content=f"chunk {i} about quarterly performance review section {i}",
|
||||
document_id=large_doc.id,
|
||||
)
|
||||
large_chunks.append(chunk)
|
||||
|
||||
small_chunks = [
|
||||
_make_chunk(
|
||||
content="quarterly performance review summary note content",
|
||||
document_id=small_doc.id,
|
||||
),
|
||||
]
|
||||
|
||||
db_session.add_all(large_chunks + small_chunks)
|
||||
await db_session.flush()
|
||||
|
||||
return {
|
||||
"large_doc": large_doc,
|
||||
"small_doc": small_doc,
|
||||
"large_chunk_ids": [c.id for c in large_chunks],
|
||||
"small_chunk_ids": [c.id for c in small_chunks],
|
||||
"search_space": db_search_space,
|
||||
"user": db_user,
|
||||
}
|
||||
|
|
@ -0,0 +1,116 @@
|
|||
"""Integration tests for optimized ChucksHybridSearchRetriever.
|
||||
|
||||
Verifies the SQL ROW_NUMBER per-doc chunk limit, column pruning,
|
||||
and doc metadata caching from RRF results.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.retriever.chunks_hybrid_search import (
|
||||
_MAX_FETCH_CHUNKS_PER_DOC,
|
||||
ChucksHybridSearchRetriever,
|
||||
)
|
||||
|
||||
from .conftest import DUMMY_EMBEDDING
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
async def test_per_doc_chunk_limit_respected(db_session, seed_large_doc):
|
||||
"""A document with 35 chunks should have at most _MAX_FETCH_CHUNKS_PER_DOC chunks returned."""
|
||||
space_id = seed_large_doc["search_space"].id
|
||||
|
||||
retriever = ChucksHybridSearchRetriever(db_session)
|
||||
results = await retriever.hybrid_search(
|
||||
query_text="quarterly performance review",
|
||||
top_k=10,
|
||||
search_space_id=space_id,
|
||||
query_embedding=DUMMY_EMBEDDING,
|
||||
)
|
||||
|
||||
large_doc_id = seed_large_doc["large_doc"].id
|
||||
for result in results:
|
||||
if result["document"].get("id") == large_doc_id:
|
||||
assert len(result["chunks"]) <= _MAX_FETCH_CHUNKS_PER_DOC
|
||||
assert len(result["chunks"]) == _MAX_FETCH_CHUNKS_PER_DOC
|
||||
break
|
||||
else:
|
||||
pytest.fail("Large doc not found in search results")
|
||||
|
||||
|
||||
async def test_doc_metadata_populated_from_rrf(db_session, seed_large_doc):
|
||||
"""Document metadata (title, type, etc.) should be present even without joinedload."""
|
||||
space_id = seed_large_doc["search_space"].id
|
||||
|
||||
retriever = ChucksHybridSearchRetriever(db_session)
|
||||
results = await retriever.hybrid_search(
|
||||
query_text="quarterly performance review",
|
||||
top_k=10,
|
||||
search_space_id=space_id,
|
||||
query_embedding=DUMMY_EMBEDDING,
|
||||
)
|
||||
|
||||
assert len(results) >= 1
|
||||
for result in results:
|
||||
doc = result["document"]
|
||||
assert "id" in doc
|
||||
assert "title" in doc
|
||||
assert doc["title"]
|
||||
assert "document_type" in doc
|
||||
assert doc["document_type"] is not None
|
||||
|
||||
|
||||
async def test_matched_chunk_ids_tracked(db_session, seed_large_doc):
|
||||
"""matched_chunk_ids should contain the chunk IDs that appeared in the RRF results."""
|
||||
space_id = seed_large_doc["search_space"].id
|
||||
|
||||
retriever = ChucksHybridSearchRetriever(db_session)
|
||||
results = await retriever.hybrid_search(
|
||||
query_text="quarterly performance review",
|
||||
top_k=10,
|
||||
search_space_id=space_id,
|
||||
query_embedding=DUMMY_EMBEDDING,
|
||||
)
|
||||
|
||||
for result in results:
|
||||
matched = result.get("matched_chunk_ids", [])
|
||||
chunk_ids_in_result = {c["chunk_id"] for c in result["chunks"]}
|
||||
for mid in matched:
|
||||
assert mid in chunk_ids_in_result, (
|
||||
f"matched_chunk_id {mid} not found in chunks"
|
||||
)
|
||||
|
||||
|
||||
async def test_chunks_ordered_by_id(db_session, seed_large_doc):
|
||||
"""Chunks within each document should be ordered by chunk ID (original order)."""
|
||||
space_id = seed_large_doc["search_space"].id
|
||||
|
||||
retriever = ChucksHybridSearchRetriever(db_session)
|
||||
results = await retriever.hybrid_search(
|
||||
query_text="quarterly performance review",
|
||||
top_k=10,
|
||||
search_space_id=space_id,
|
||||
query_embedding=DUMMY_EMBEDDING,
|
||||
)
|
||||
|
||||
for result in results:
|
||||
chunk_ids = [c["chunk_id"] for c in result["chunks"]]
|
||||
assert chunk_ids == sorted(chunk_ids), "Chunks not ordered by ID"
|
||||
|
||||
|
||||
async def test_score_is_positive_float(db_session, seed_large_doc):
|
||||
"""Each result should have a positive float score from RRF."""
|
||||
space_id = seed_large_doc["search_space"].id
|
||||
|
||||
retriever = ChucksHybridSearchRetriever(db_session)
|
||||
results = await retriever.hybrid_search(
|
||||
query_text="quarterly performance review",
|
||||
top_k=10,
|
||||
search_space_id=space_id,
|
||||
query_embedding=DUMMY_EMBEDDING,
|
||||
)
|
||||
|
||||
assert len(results) >= 1
|
||||
for result in results:
|
||||
assert isinstance(result["score"], float)
|
||||
assert result["score"] > 0
|
||||
|
|
@ -0,0 +1,76 @@
|
|||
"""Integration tests for optimized DocumentHybridSearchRetriever.
|
||||
|
||||
Verifies the SQL ROW_NUMBER per-doc chunk limit and column pruning.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.retriever.documents_hybrid_search import (
|
||||
_MAX_FETCH_CHUNKS_PER_DOC,
|
||||
DocumentHybridSearchRetriever,
|
||||
)
|
||||
|
||||
from .conftest import DUMMY_EMBEDDING
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
async def test_per_doc_chunk_limit_respected(db_session, seed_large_doc):
|
||||
"""A document with 35 chunks should have at most _MAX_FETCH_CHUNKS_PER_DOC chunks returned."""
|
||||
space_id = seed_large_doc["search_space"].id
|
||||
|
||||
retriever = DocumentHybridSearchRetriever(db_session)
|
||||
results = await retriever.hybrid_search(
|
||||
query_text="quarterly performance review",
|
||||
top_k=10,
|
||||
search_space_id=space_id,
|
||||
query_embedding=DUMMY_EMBEDDING,
|
||||
)
|
||||
|
||||
large_doc_id = seed_large_doc["large_doc"].id
|
||||
for result in results:
|
||||
if result["document"].get("id") == large_doc_id:
|
||||
assert len(result["chunks"]) <= _MAX_FETCH_CHUNKS_PER_DOC
|
||||
assert len(result["chunks"]) == _MAX_FETCH_CHUNKS_PER_DOC
|
||||
break
|
||||
else:
|
||||
pytest.fail("Large doc not found in search results")
|
||||
|
||||
|
||||
async def test_doc_metadata_populated(db_session, seed_large_doc):
|
||||
"""Document metadata should be present from the RRF results."""
|
||||
space_id = seed_large_doc["search_space"].id
|
||||
|
||||
retriever = DocumentHybridSearchRetriever(db_session)
|
||||
results = await retriever.hybrid_search(
|
||||
query_text="quarterly performance review",
|
||||
top_k=10,
|
||||
search_space_id=space_id,
|
||||
query_embedding=DUMMY_EMBEDDING,
|
||||
)
|
||||
|
||||
assert len(results) >= 1
|
||||
for result in results:
|
||||
doc = result["document"]
|
||||
assert "id" in doc
|
||||
assert "title" in doc
|
||||
assert doc["title"]
|
||||
assert "document_type" in doc
|
||||
assert doc["document_type"] is not None
|
||||
|
||||
|
||||
async def test_chunks_ordered_by_id(db_session, seed_large_doc):
|
||||
"""Chunks within each document should be ordered by chunk ID."""
|
||||
space_id = seed_large_doc["search_space"].id
|
||||
|
||||
retriever = DocumentHybridSearchRetriever(db_session)
|
||||
results = await retriever.hybrid_search(
|
||||
query_text="quarterly performance review",
|
||||
top_k=10,
|
||||
search_space_id=space_id,
|
||||
query_embedding=DUMMY_EMBEDDING,
|
||||
)
|
||||
|
||||
for result in results:
|
||||
chunk_ids = [c["chunk_id"] for c in result["chunks"]]
|
||||
assert chunk_ids == sorted(chunk_ids), "Chunks not ordered by ID"
|
||||
|
|
@ -42,14 +42,11 @@ def _to_markdown(page: dict) -> str:
|
|||
if comments:
|
||||
comments_content = "\n\n## Comments\n\n"
|
||||
for comment in comments:
|
||||
comment_body = (
|
||||
comment.get("body", {}).get("storage", {}).get("value", "")
|
||||
)
|
||||
comment_body = comment.get("body", {}).get("storage", {}).get("value", "")
|
||||
comment_author = comment.get("version", {}).get("authorId", "Unknown")
|
||||
comment_date = comment.get("version", {}).get("createdAt", "")
|
||||
comments_content += (
|
||||
f"**Comment by {comment_author}** ({comment_date}):\n"
|
||||
f"{comment_body}\n\n"
|
||||
f"**Comment by {comment_author}** ({comment_date}):\n{comment_body}\n\n"
|
||||
)
|
||||
return f"# {page_title}\n\n{page_content}{comments_content}"
|
||||
|
||||
|
|
@ -138,22 +135,32 @@ def confluence_mocks(monkeypatch):
|
|||
|
||||
mock_connector = _mock_connector()
|
||||
monkeypatch.setattr(
|
||||
_mod, "get_connector_by_id", AsyncMock(return_value=mock_connector),
|
||||
_mod,
|
||||
"get_connector_by_id",
|
||||
AsyncMock(return_value=mock_connector),
|
||||
)
|
||||
|
||||
confluence_client = _mock_confluence_client(pages=[_make_page()])
|
||||
monkeypatch.setattr(
|
||||
_mod, "ConfluenceHistoryConnector", MagicMock(return_value=confluence_client),
|
||||
_mod,
|
||||
"ConfluenceHistoryConnector",
|
||||
MagicMock(return_value=confluence_client),
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
_mod, "check_duplicate_document_by_hash", AsyncMock(return_value=None),
|
||||
_mod,
|
||||
"check_duplicate_document_by_hash",
|
||||
AsyncMock(return_value=None),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
_mod, "update_connector_last_indexed", AsyncMock(),
|
||||
_mod,
|
||||
"update_connector_last_indexed",
|
||||
AsyncMock(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
_mod, "calculate_date_range", MagicMock(return_value=("2025-01-01", "2025-12-31")),
|
||||
_mod,
|
||||
"calculate_date_range",
|
||||
MagicMock(return_value=("2025-01-01", "2025-12-31")),
|
||||
)
|
||||
|
||||
mock_task_logger = MagicMock()
|
||||
|
|
@ -162,15 +169,20 @@ def confluence_mocks(monkeypatch):
|
|||
mock_task_logger.log_task_success = AsyncMock()
|
||||
mock_task_logger.log_task_failure = AsyncMock()
|
||||
monkeypatch.setattr(
|
||||
_mod, "TaskLoggingService", MagicMock(return_value=mock_task_logger),
|
||||
_mod,
|
||||
"TaskLoggingService",
|
||||
MagicMock(return_value=mock_task_logger),
|
||||
)
|
||||
|
||||
batch_mock = AsyncMock(return_value=([], 1, 0))
|
||||
pipeline_mock = MagicMock()
|
||||
pipeline_mock.index_batch_parallel = batch_mock
|
||||
pipeline_mock.migrate_legacy_docs = AsyncMock()
|
||||
pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0)
|
||||
monkeypatch.setattr(
|
||||
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock),
|
||||
_mod,
|
||||
"IndexingPipelineService",
|
||||
MagicMock(return_value=pipeline_mock),
|
||||
)
|
||||
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -41,6 +41,7 @@ def mock_drive_client():
|
|||
@pytest.fixture
|
||||
def patch_extract(monkeypatch):
|
||||
"""Provide a helper to set the download_and_extract_content mock."""
|
||||
|
||||
def _patch(side_effect=None, return_value=None):
|
||||
mock = AsyncMock(side_effect=side_effect, return_value=return_value)
|
||||
monkeypatch.setattr(
|
||||
|
|
@ -48,11 +49,13 @@ def patch_extract(monkeypatch):
|
|||
mock,
|
||||
)
|
||||
return mock
|
||||
|
||||
return _patch
|
||||
|
||||
|
||||
async def test_single_file_returns_one_connector_document(
|
||||
mock_drive_client, patch_extract,
|
||||
mock_drive_client,
|
||||
patch_extract,
|
||||
):
|
||||
"""Tracer bullet: downloading one file produces one ConnectorDocument."""
|
||||
patch_extract(return_value=_mock_extract_ok("f1", "test.txt"))
|
||||
|
|
@ -73,7 +76,8 @@ async def test_single_file_returns_one_connector_document(
|
|||
|
||||
|
||||
async def test_multiple_files_all_produce_documents(
|
||||
mock_drive_client, patch_extract,
|
||||
mock_drive_client,
|
||||
patch_extract,
|
||||
):
|
||||
"""All files are downloaded and converted to ConnectorDocuments."""
|
||||
files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(3)]
|
||||
|
|
@ -96,7 +100,8 @@ async def test_multiple_files_all_produce_documents(
|
|||
|
||||
|
||||
async def test_one_download_exception_does_not_block_others(
|
||||
mock_drive_client, patch_extract,
|
||||
mock_drive_client,
|
||||
patch_extract,
|
||||
):
|
||||
"""A RuntimeError in one download still lets the other files succeed."""
|
||||
files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(3)]
|
||||
|
|
@ -123,7 +128,8 @@ async def test_one_download_exception_does_not_block_others(
|
|||
|
||||
|
||||
async def test_etl_error_counts_as_download_failure(
|
||||
mock_drive_client, patch_extract,
|
||||
mock_drive_client,
|
||||
patch_extract,
|
||||
):
|
||||
"""download_and_extract_content returning an error is counted as failed."""
|
||||
files = [_make_file_dict("f0", "good.txt"), _make_file_dict("f1", "bad.txt")]
|
||||
|
|
@ -148,7 +154,8 @@ async def test_etl_error_counts_as_download_failure(
|
|||
|
||||
|
||||
async def test_concurrency_bounded_by_semaphore(
|
||||
mock_drive_client, monkeypatch,
|
||||
mock_drive_client,
|
||||
monkeypatch,
|
||||
):
|
||||
"""Peak concurrent downloads never exceeds max_concurrency."""
|
||||
lock = asyncio.Lock()
|
||||
|
|
@ -189,7 +196,8 @@ async def test_concurrency_bounded_by_semaphore(
|
|||
|
||||
|
||||
async def test_heartbeat_fires_during_parallel_downloads(
|
||||
mock_drive_client, monkeypatch,
|
||||
mock_drive_client,
|
||||
monkeypatch,
|
||||
):
|
||||
"""on_heartbeat is called at least once when downloads take time."""
|
||||
import app.tasks.connector_indexers.google_drive_indexer as _mod
|
||||
|
|
@ -231,8 +239,13 @@ async def test_heartbeat_fires_during_parallel_downloads(
|
|||
# Slice 6, 6b, 6c -- _index_full_scan three-phase pipeline
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _folder_dict(file_id: str, name: str) -> dict:
|
||||
return {"id": file_id, "name": name, "mimeType": "application/vnd.google-apps.folder"}
|
||||
return {
|
||||
"id": file_id,
|
||||
"name": name,
|
||||
"mimeType": "application/vnd.google-apps.folder",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -259,12 +272,17 @@ def full_scan_mocks(mock_drive_client, monkeypatch):
|
|||
batch_mock = AsyncMock(return_value=([], 0, 0))
|
||||
pipeline_mock = MagicMock()
|
||||
pipeline_mock.index_batch_parallel = batch_mock
|
||||
pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0)
|
||||
monkeypatch.setattr(
|
||||
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock),
|
||||
_mod,
|
||||
"IndexingPipelineService",
|
||||
MagicMock(return_value=pipeline_mock),
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
_mod, "get_user_long_context_llm", AsyncMock(return_value=MagicMock()),
|
||||
_mod,
|
||||
"get_user_long_context_llm",
|
||||
AsyncMock(return_value=MagicMock()),
|
||||
)
|
||||
|
||||
return {
|
||||
|
|
@ -312,12 +330,16 @@ async def test_full_scan_three_phase_counts(full_scan_mocks, monkeypatch):
|
|||
]
|
||||
|
||||
monkeypatch.setattr(
|
||||
_mod, "get_files_in_folder",
|
||||
_mod,
|
||||
"get_files_in_folder",
|
||||
AsyncMock(return_value=(page_files, None, None)),
|
||||
)
|
||||
|
||||
full_scan_mocks["skip_results"]["skip1"] = (True, "unchanged")
|
||||
full_scan_mocks["skip_results"]["rename1"] = (True, "File renamed: 'old' → 'renamed.txt'")
|
||||
full_scan_mocks["skip_results"]["rename1"] = (
|
||||
True,
|
||||
"File renamed: 'old' → 'renamed.txt'",
|
||||
)
|
||||
|
||||
mock_docs = [MagicMock(), MagicMock()]
|
||||
full_scan_mocks["download_mock"].return_value = (mock_docs, 0)
|
||||
|
|
@ -341,7 +363,8 @@ async def test_full_scan_respects_max_files(full_scan_mocks, monkeypatch):
|
|||
page_files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(10)]
|
||||
|
||||
monkeypatch.setattr(
|
||||
_mod, "get_files_in_folder",
|
||||
_mod,
|
||||
"get_files_in_folder",
|
||||
AsyncMock(return_value=(page_files, None, None)),
|
||||
)
|
||||
|
||||
|
|
@ -355,14 +378,16 @@ async def test_full_scan_respects_max_files(full_scan_mocks, monkeypatch):
|
|||
|
||||
|
||||
async def test_full_scan_uses_max_concurrency_3_for_indexing(
|
||||
full_scan_mocks, monkeypatch,
|
||||
full_scan_mocks,
|
||||
monkeypatch,
|
||||
):
|
||||
"""index_batch_parallel is called with max_concurrency=3."""
|
||||
import app.tasks.connector_indexers.google_drive_indexer as _mod
|
||||
|
||||
page_files = [_make_file_dict("f1", "file1.txt")]
|
||||
monkeypatch.setattr(
|
||||
_mod, "get_files_in_folder",
|
||||
_mod,
|
||||
"get_files_in_folder",
|
||||
AsyncMock(return_value=(page_files, None, None)),
|
||||
)
|
||||
|
||||
|
|
@ -382,6 +407,7 @@ async def test_full_scan_uses_max_concurrency_3_for_indexing(
|
|||
# Slice 7 -- _index_with_delta_sync three-phase pipeline
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_delta_sync_removals_serial_rest_parallel(monkeypatch):
|
||||
"""Removed/trashed changes call _remove_document; the rest go through
|
||||
_download_files_parallel and index_batch_parallel."""
|
||||
|
|
@ -396,7 +422,8 @@ async def test_delta_sync_removals_serial_rest_parallel(monkeypatch):
|
|||
]
|
||||
|
||||
monkeypatch.setattr(
|
||||
_mod, "fetch_all_changes",
|
||||
_mod,
|
||||
"fetch_all_changes",
|
||||
AsyncMock(return_value=(changes, "new-token", None)),
|
||||
)
|
||||
|
||||
|
|
@ -408,7 +435,8 @@ async def test_delta_sync_removals_serial_rest_parallel(monkeypatch):
|
|||
"mod2": "modified",
|
||||
}
|
||||
monkeypatch.setattr(
|
||||
_mod, "categorize_change",
|
||||
_mod,
|
||||
"categorize_change",
|
||||
lambda change: change_types[change["fileId"]],
|
||||
)
|
||||
|
||||
|
|
@ -420,7 +448,8 @@ async def test_delta_sync_removals_serial_rest_parallel(monkeypatch):
|
|||
monkeypatch.setattr(_mod, "_remove_document", _fake_remove)
|
||||
|
||||
monkeypatch.setattr(
|
||||
_mod, "_should_skip_file",
|
||||
_mod,
|
||||
"_should_skip_file",
|
||||
AsyncMock(return_value=(False, None)),
|
||||
)
|
||||
|
||||
|
|
@ -431,11 +460,16 @@ async def test_delta_sync_removals_serial_rest_parallel(monkeypatch):
|
|||
batch_mock = AsyncMock(return_value=([], 2, 0))
|
||||
pipeline_mock = MagicMock()
|
||||
pipeline_mock.index_batch_parallel = batch_mock
|
||||
pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0)
|
||||
monkeypatch.setattr(
|
||||
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock),
|
||||
_mod,
|
||||
"IndexingPipelineService",
|
||||
MagicMock(return_value=pipeline_mock),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
_mod, "get_user_long_context_llm", AsyncMock(return_value=MagicMock()),
|
||||
_mod,
|
||||
"get_user_long_context_llm",
|
||||
AsyncMock(return_value=MagicMock()),
|
||||
)
|
||||
|
||||
mock_session = AsyncMock()
|
||||
|
|
@ -472,6 +506,7 @@ async def test_delta_sync_removals_serial_rest_parallel(monkeypatch):
|
|||
# _index_selected_files -- parallel indexing of user-selected files
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def selected_files_mocks(mock_drive_client, monkeypatch):
|
||||
"""Wire up mocks for _index_selected_files tests."""
|
||||
|
|
@ -496,6 +531,14 @@ def selected_files_mocks(mock_drive_client, monkeypatch):
|
|||
download_and_index_mock = AsyncMock(return_value=(0, 0))
|
||||
monkeypatch.setattr(_mod, "_download_and_index", download_and_index_mock)
|
||||
|
||||
pipeline_mock = MagicMock()
|
||||
pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0)
|
||||
monkeypatch.setattr(
|
||||
_mod,
|
||||
"IndexingPipelineService",
|
||||
MagicMock(return_value=pipeline_mock),
|
||||
)
|
||||
|
||||
return {
|
||||
"drive_client": mock_drive_client,
|
||||
"session": mock_session,
|
||||
|
|
@ -526,7 +569,8 @@ async def test_selected_files_single_file_indexed(selected_files_mocks):
|
|||
selected_files_mocks["download_and_index_mock"].return_value = (1, 0)
|
||||
|
||||
indexed, skipped, errors = await _run_selected(
|
||||
selected_files_mocks, [("f1", "report.pdf")],
|
||||
selected_files_mocks,
|
||||
[("f1", "report.pdf")],
|
||||
)
|
||||
|
||||
assert indexed == 1
|
||||
|
|
@ -538,11 +582,13 @@ async def test_selected_files_single_file_indexed(selected_files_mocks):
|
|||
async def test_selected_files_fetch_failure_isolation(selected_files_mocks):
|
||||
"""get_file_by_id failing for one file collects an error; others still indexed."""
|
||||
selected_files_mocks["get_file_results"]["f1"] = (
|
||||
_make_file_dict("f1", "first.txt"), None,
|
||||
_make_file_dict("f1", "first.txt"),
|
||||
None,
|
||||
)
|
||||
selected_files_mocks["get_file_results"]["f2"] = (None, "HTTP 404")
|
||||
selected_files_mocks["get_file_results"]["f3"] = (
|
||||
_make_file_dict("f3", "third.txt"), None,
|
||||
_make_file_dict("f3", "third.txt"),
|
||||
None,
|
||||
)
|
||||
selected_files_mocks["download_and_index_mock"].return_value = (2, 0)
|
||||
|
||||
|
|
@ -561,30 +607,46 @@ async def test_selected_files_fetch_failure_isolation(selected_files_mocks):
|
|||
async def test_selected_files_skip_rename_counting(selected_files_mocks):
|
||||
"""Unchanged files are skipped, renames counted as indexed,
|
||||
and only new files are sent to _download_and_index."""
|
||||
for fid, fname in [("s1", "unchanged.txt"), ("r1", "renamed.txt"),
|
||||
("n1", "new1.txt"), ("n2", "new2.txt")]:
|
||||
for fid, fname in [
|
||||
("s1", "unchanged.txt"),
|
||||
("r1", "renamed.txt"),
|
||||
("n1", "new1.txt"),
|
||||
("n2", "new2.txt"),
|
||||
]:
|
||||
selected_files_mocks["get_file_results"][fid] = (
|
||||
_make_file_dict(fid, fname), None,
|
||||
_make_file_dict(fid, fname),
|
||||
None,
|
||||
)
|
||||
|
||||
selected_files_mocks["skip_results"]["s1"] = (True, "unchanged")
|
||||
selected_files_mocks["skip_results"]["r1"] = (True, "File renamed: 'old' \u2192 'renamed.txt'")
|
||||
selected_files_mocks["skip_results"]["r1"] = (
|
||||
True,
|
||||
"File renamed: 'old' \u2192 'renamed.txt'",
|
||||
)
|
||||
|
||||
selected_files_mocks["download_and_index_mock"].return_value = (2, 0)
|
||||
|
||||
indexed, skipped, errors = await _run_selected(
|
||||
selected_files_mocks,
|
||||
[("s1", "unchanged.txt"), ("r1", "renamed.txt"),
|
||||
("n1", "new1.txt"), ("n2", "new2.txt")],
|
||||
[
|
||||
("s1", "unchanged.txt"),
|
||||
("r1", "renamed.txt"),
|
||||
("n1", "new1.txt"),
|
||||
("n2", "new2.txt"),
|
||||
],
|
||||
)
|
||||
|
||||
assert indexed == 3 # 1 renamed + 2 batch
|
||||
assert skipped == 1 # 1 unchanged
|
||||
assert indexed == 3 # 1 renamed + 2 batch
|
||||
assert skipped == 1 # 1 unchanged
|
||||
assert errors == []
|
||||
|
||||
mock = selected_files_mocks["download_and_index_mock"]
|
||||
mock.assert_called_once()
|
||||
call_files = mock.call_args[1].get("files") if "files" in (mock.call_args[1] or {}) else mock.call_args[0][2]
|
||||
call_files = (
|
||||
mock.call_args[1].get("files")
|
||||
if "files" in (mock.call_args[1] or {})
|
||||
else mock.call_args[0][2]
|
||||
)
|
||||
assert len(call_files) == 2
|
||||
assert {f["id"] for f in call_files} == {"n1", "n2"}
|
||||
|
||||
|
|
@ -593,6 +655,7 @@ async def test_selected_files_skip_rename_counting(selected_files_mocks):
|
|||
# asyncio.to_thread verification — prove blocking calls run in parallel
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_client_download_file_runs_in_thread_parallel():
|
||||
"""Calling download_file concurrently via asyncio.gather should overlap
|
||||
blocking work on separate threads, proving to_thread is effective.
|
||||
|
|
@ -602,11 +665,11 @@ async def test_client_download_file_runs_in_thread_parallel():
|
|||
"""
|
||||
from app.connectors.google_drive.client import GoogleDriveClient
|
||||
|
||||
BLOCK_SECONDS = 0.2
|
||||
NUM_CALLS = 3
|
||||
block_seconds = 0.2
|
||||
num_calls = 3
|
||||
|
||||
def _blocking_download(service, file_id, credentials):
|
||||
time.sleep(BLOCK_SECONDS)
|
||||
time.sleep(block_seconds)
|
||||
return b"fake-content", None
|
||||
|
||||
client = GoogleDriveClient.__new__(GoogleDriveClient)
|
||||
|
|
@ -615,11 +678,13 @@ async def test_client_download_file_runs_in_thread_parallel():
|
|||
client._service_lock = asyncio.Lock()
|
||||
|
||||
with patch.object(
|
||||
GoogleDriveClient, "_sync_download_file", staticmethod(_blocking_download),
|
||||
GoogleDriveClient,
|
||||
"_sync_download_file",
|
||||
staticmethod(_blocking_download),
|
||||
):
|
||||
start = time.monotonic()
|
||||
results = await asyncio.gather(
|
||||
*(client.download_file(f"file-{i}") for i in range(NUM_CALLS))
|
||||
*(client.download_file(f"file-{i}") for i in range(num_calls))
|
||||
)
|
||||
elapsed = time.monotonic() - start
|
||||
|
||||
|
|
@ -627,7 +692,7 @@ async def test_client_download_file_runs_in_thread_parallel():
|
|||
assert content == b"fake-content"
|
||||
assert error is None
|
||||
|
||||
serial_minimum = BLOCK_SECONDS * NUM_CALLS
|
||||
serial_minimum = block_seconds * num_calls
|
||||
assert elapsed < serial_minimum, (
|
||||
f"Elapsed {elapsed:.2f}s >= serial minimum {serial_minimum:.2f}s — "
|
||||
f"downloads are not running in parallel"
|
||||
|
|
@ -638,11 +703,11 @@ async def test_client_export_google_file_runs_in_thread_parallel():
|
|||
"""Same strategy for export_google_file — verify to_thread parallelism."""
|
||||
from app.connectors.google_drive.client import GoogleDriveClient
|
||||
|
||||
BLOCK_SECONDS = 0.2
|
||||
NUM_CALLS = 3
|
||||
block_seconds = 0.2
|
||||
num_calls = 3
|
||||
|
||||
def _blocking_export(service, file_id, mime_type, credentials):
|
||||
time.sleep(BLOCK_SECONDS)
|
||||
time.sleep(block_seconds)
|
||||
return b"exported", None
|
||||
|
||||
client = GoogleDriveClient.__new__(GoogleDriveClient)
|
||||
|
|
@ -651,12 +716,16 @@ async def test_client_export_google_file_runs_in_thread_parallel():
|
|||
client._service_lock = asyncio.Lock()
|
||||
|
||||
with patch.object(
|
||||
GoogleDriveClient, "_sync_export_google_file", staticmethod(_blocking_export),
|
||||
GoogleDriveClient,
|
||||
"_sync_export_google_file",
|
||||
staticmethod(_blocking_export),
|
||||
):
|
||||
start = time.monotonic()
|
||||
results = await asyncio.gather(
|
||||
*(client.export_google_file(f"file-{i}", "application/pdf")
|
||||
for i in range(NUM_CALLS))
|
||||
*(
|
||||
client.export_google_file(f"file-{i}", "application/pdf")
|
||||
for i in range(num_calls)
|
||||
)
|
||||
)
|
||||
elapsed = time.monotonic() - start
|
||||
|
||||
|
|
@ -664,7 +733,7 @@ async def test_client_export_google_file_runs_in_thread_parallel():
|
|||
assert content == b"exported"
|
||||
assert error is None
|
||||
|
||||
serial_minimum = BLOCK_SECONDS * NUM_CALLS
|
||||
serial_minimum = block_seconds * num_calls
|
||||
assert elapsed < serial_minimum, (
|
||||
f"Elapsed {elapsed:.2f}s >= serial minimum {serial_minimum:.2f}s — "
|
||||
f"exports are not running in parallel"
|
||||
|
|
|
|||
|
|
@ -145,22 +145,32 @@ def jira_mocks(monkeypatch):
|
|||
|
||||
mock_connector = _mock_connector()
|
||||
monkeypatch.setattr(
|
||||
_mod, "get_connector_by_id", AsyncMock(return_value=mock_connector),
|
||||
_mod,
|
||||
"get_connector_by_id",
|
||||
AsyncMock(return_value=mock_connector),
|
||||
)
|
||||
|
||||
jira_client = _mock_jira_client(issues=[_make_issue()])
|
||||
monkeypatch.setattr(
|
||||
_mod, "JiraHistoryConnector", MagicMock(return_value=jira_client),
|
||||
_mod,
|
||||
"JiraHistoryConnector",
|
||||
MagicMock(return_value=jira_client),
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
_mod, "check_duplicate_document_by_hash", AsyncMock(return_value=None),
|
||||
_mod,
|
||||
"check_duplicate_document_by_hash",
|
||||
AsyncMock(return_value=None),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
_mod, "update_connector_last_indexed", AsyncMock(),
|
||||
_mod,
|
||||
"update_connector_last_indexed",
|
||||
AsyncMock(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
_mod, "calculate_date_range", MagicMock(return_value=("2025-01-01", "2025-12-31")),
|
||||
_mod,
|
||||
"calculate_date_range",
|
||||
MagicMock(return_value=("2025-01-01", "2025-12-31")),
|
||||
)
|
||||
|
||||
mock_task_logger = MagicMock()
|
||||
|
|
@ -169,15 +179,20 @@ def jira_mocks(monkeypatch):
|
|||
mock_task_logger.log_task_success = AsyncMock()
|
||||
mock_task_logger.log_task_failure = AsyncMock()
|
||||
monkeypatch.setattr(
|
||||
_mod, "TaskLoggingService", MagicMock(return_value=mock_task_logger),
|
||||
_mod,
|
||||
"TaskLoggingService",
|
||||
MagicMock(return_value=mock_task_logger),
|
||||
)
|
||||
|
||||
batch_mock = AsyncMock(return_value=([], 1, 0))
|
||||
pipeline_mock = MagicMock()
|
||||
pipeline_mock.index_batch_parallel = batch_mock
|
||||
pipeline_mock.migrate_legacy_docs = AsyncMock()
|
||||
pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0)
|
||||
monkeypatch.setattr(
|
||||
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock),
|
||||
_mod,
|
||||
"IndexingPipelineService",
|
||||
MagicMock(return_value=pipeline_mock),
|
||||
)
|
||||
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -128,13 +128,17 @@ def _mock_linear_client(issues=None, error=None):
|
|||
client.get_issues_by_date_range = AsyncMock(
|
||||
return_value=(issues if issues is not None else [], error),
|
||||
)
|
||||
client.format_issue = MagicMock(side_effect=lambda i: _make_formatted_issue(
|
||||
issue_id=i.get("id", ""),
|
||||
identifier=i.get("identifier", ""),
|
||||
title=i.get("title", ""),
|
||||
))
|
||||
client.format_issue = MagicMock(
|
||||
side_effect=lambda i: _make_formatted_issue(
|
||||
issue_id=i.get("id", ""),
|
||||
identifier=i.get("identifier", ""),
|
||||
title=i.get("title", ""),
|
||||
)
|
||||
)
|
||||
client.format_issue_to_markdown = MagicMock(
|
||||
side_effect=lambda fi: f"# {fi.get('identifier', '')}: {fi.get('title', '')}\n\nContent"
|
||||
side_effect=lambda fi: (
|
||||
f"# {fi.get('identifier', '')}: {fi.get('title', '')}\n\nContent"
|
||||
)
|
||||
)
|
||||
return client
|
||||
|
||||
|
|
@ -147,24 +151,34 @@ def linear_mocks(monkeypatch):
|
|||
|
||||
mock_connector = _mock_connector()
|
||||
monkeypatch.setattr(
|
||||
_mod, "get_connector_by_id", AsyncMock(return_value=mock_connector),
|
||||
_mod,
|
||||
"get_connector_by_id",
|
||||
AsyncMock(return_value=mock_connector),
|
||||
)
|
||||
|
||||
linear_client = _mock_linear_client(issues=[_make_issue()])
|
||||
monkeypatch.setattr(
|
||||
_mod, "LinearConnector", MagicMock(return_value=linear_client),
|
||||
_mod,
|
||||
"LinearConnector",
|
||||
MagicMock(return_value=linear_client),
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
_mod, "check_duplicate_document_by_hash", AsyncMock(return_value=None),
|
||||
_mod,
|
||||
"check_duplicate_document_by_hash",
|
||||
AsyncMock(return_value=None),
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
_mod, "update_connector_last_indexed", AsyncMock(),
|
||||
_mod,
|
||||
"update_connector_last_indexed",
|
||||
AsyncMock(),
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
_mod, "calculate_date_range", MagicMock(return_value=("2025-01-01", "2025-12-31")),
|
||||
_mod,
|
||||
"calculate_date_range",
|
||||
MagicMock(return_value=("2025-01-01", "2025-12-31")),
|
||||
)
|
||||
|
||||
mock_task_logger = MagicMock()
|
||||
|
|
@ -173,15 +187,20 @@ def linear_mocks(monkeypatch):
|
|||
mock_task_logger.log_task_success = AsyncMock()
|
||||
mock_task_logger.log_task_failure = AsyncMock()
|
||||
monkeypatch.setattr(
|
||||
_mod, "TaskLoggingService", MagicMock(return_value=mock_task_logger),
|
||||
_mod,
|
||||
"TaskLoggingService",
|
||||
MagicMock(return_value=mock_task_logger),
|
||||
)
|
||||
|
||||
batch_mock = AsyncMock(return_value=([], 1, 0))
|
||||
pipeline_mock = MagicMock()
|
||||
pipeline_mock.index_batch_parallel = batch_mock
|
||||
pipeline_mock.migrate_legacy_docs = AsyncMock()
|
||||
pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0)
|
||||
monkeypatch.setattr(
|
||||
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock),
|
||||
_mod,
|
||||
"IndexingPipelineService",
|
||||
MagicMock(return_value=pipeline_mock),
|
||||
)
|
||||
|
||||
return {
|
||||
|
|
@ -255,7 +274,7 @@ async def test_issues_with_missing_id_are_skipped(linear_mocks):
|
|||
]
|
||||
linear_mocks["linear_client"].get_issues_by_date_range.return_value = (issues, None)
|
||||
|
||||
indexed, skipped, _ = await _run_index(linear_mocks)
|
||||
_indexed, skipped, _ = await _run_index(linear_mocks)
|
||||
|
||||
connector_docs = linear_mocks["batch_mock"].call_args[0][0]
|
||||
assert len(connector_docs) == 1
|
||||
|
|
@ -271,7 +290,7 @@ async def test_issues_with_missing_title_are_skipped(linear_mocks):
|
|||
]
|
||||
linear_mocks["linear_client"].get_issues_by_date_range.return_value = (issues, None)
|
||||
|
||||
indexed, skipped, _ = await _run_index(linear_mocks)
|
||||
_indexed, skipped, _ = await _run_index(linear_mocks)
|
||||
|
||||
connector_docs = linear_mocks["batch_mock"].call_args[0][0]
|
||||
assert len(connector_docs) == 1
|
||||
|
|
@ -305,7 +324,7 @@ async def test_duplicate_content_issues_are_skipped(linear_mocks, monkeypatch):
|
|||
|
||||
monkeypatch.setattr(_mod, "check_duplicate_document_by_hash", _check_dup)
|
||||
|
||||
indexed, skipped, _ = await _run_index(linear_mocks)
|
||||
_indexed, skipped, _ = await _run_index(linear_mocks)
|
||||
|
||||
connector_docs = linear_mocks["batch_mock"].call_args[0][0]
|
||||
assert len(connector_docs) == 1
|
||||
|
|
|
|||
|
|
@ -107,28 +107,40 @@ def notion_mocks(monkeypatch):
|
|||
|
||||
mock_connector = _mock_connector()
|
||||
monkeypatch.setattr(
|
||||
_mod, "get_connector_by_id", AsyncMock(return_value=mock_connector),
|
||||
_mod,
|
||||
"get_connector_by_id",
|
||||
AsyncMock(return_value=mock_connector),
|
||||
)
|
||||
|
||||
notion_client = _mock_notion_client(pages=[_make_page()])
|
||||
monkeypatch.setattr(
|
||||
_mod, "NotionHistoryConnector", MagicMock(return_value=notion_client),
|
||||
_mod,
|
||||
"NotionHistoryConnector",
|
||||
MagicMock(return_value=notion_client),
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
_mod, "check_duplicate_document_by_hash", AsyncMock(return_value=None),
|
||||
_mod,
|
||||
"check_duplicate_document_by_hash",
|
||||
AsyncMock(return_value=None),
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
_mod, "update_connector_last_indexed", AsyncMock(),
|
||||
_mod,
|
||||
"update_connector_last_indexed",
|
||||
AsyncMock(),
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
_mod, "calculate_date_range", MagicMock(return_value=("2025-01-01", "2025-12-31")),
|
||||
_mod,
|
||||
"calculate_date_range",
|
||||
MagicMock(return_value=("2025-01-01", "2025-12-31")),
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
_mod, "process_blocks", MagicMock(return_value="Converted markdown content"),
|
||||
_mod,
|
||||
"process_blocks",
|
||||
MagicMock(return_value="Converted markdown content"),
|
||||
)
|
||||
|
||||
mock_task_logger = MagicMock()
|
||||
|
|
@ -137,15 +149,20 @@ def notion_mocks(monkeypatch):
|
|||
mock_task_logger.log_task_success = AsyncMock()
|
||||
mock_task_logger.log_task_failure = AsyncMock()
|
||||
monkeypatch.setattr(
|
||||
_mod, "TaskLoggingService", MagicMock(return_value=mock_task_logger),
|
||||
_mod,
|
||||
"TaskLoggingService",
|
||||
MagicMock(return_value=mock_task_logger),
|
||||
)
|
||||
|
||||
batch_mock = AsyncMock(return_value=([], 1, 0))
|
||||
pipeline_mock = MagicMock()
|
||||
pipeline_mock.index_batch_parallel = batch_mock
|
||||
pipeline_mock.migrate_legacy_docs = AsyncMock()
|
||||
pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0)
|
||||
monkeypatch.setattr(
|
||||
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock),
|
||||
_mod,
|
||||
"IndexingPipelineService",
|
||||
MagicMock(return_value=pipeline_mock),
|
||||
)
|
||||
|
||||
return {
|
||||
|
|
@ -216,7 +233,10 @@ async def test_pages_with_missing_id_are_skipped(notion_mocks, monkeypatch):
|
|||
"""Pages without page_id are skipped and not passed to the pipeline."""
|
||||
pages = [
|
||||
_make_page(page_id="valid-1"),
|
||||
{"title": "No ID page", "content": [{"type": "paragraph", "content": "text", "children": []}]},
|
||||
{
|
||||
"title": "No ID page",
|
||||
"content": [{"type": "paragraph", "content": "text", "children": []}],
|
||||
},
|
||||
]
|
||||
notion_mocks["notion_client"].get_all_pages.return_value = pages
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,131 @@
|
|||
"""Unit tests for IndexingPipelineService.create_placeholder_documents."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from app.db import DocumentStatus, DocumentType
|
||||
from app.indexing_pipeline.document_hashing import compute_identifier_hash
|
||||
from app.indexing_pipeline.indexing_pipeline_service import (
|
||||
IndexingPipelineService,
|
||||
PlaceholderInfo,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_placeholder(**overrides) -> PlaceholderInfo:
|
||||
defaults = {
|
||||
"title": "Test Doc",
|
||||
"document_type": DocumentType.GOOGLE_DRIVE_FILE,
|
||||
"unique_id": "file-001",
|
||||
"search_space_id": 1,
|
||||
"connector_id": 42,
|
||||
"created_by_id": "00000000-0000-0000-0000-000000000001",
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return PlaceholderInfo(**defaults)
|
||||
|
||||
|
||||
def _uid_hash(p: PlaceholderInfo) -> str:
|
||||
return compute_identifier_hash(
|
||||
p.document_type.value, p.unique_id, p.search_space_id
|
||||
)
|
||||
|
||||
|
||||
def _session_with_existing_hashes(existing: set[str] | None = None):
|
||||
"""Build an AsyncMock session whose batch-query returns *existing* hashes."""
|
||||
session = AsyncMock()
|
||||
result = MagicMock()
|
||||
result.scalars.return_value.all.return_value = list(existing or [])
|
||||
session.execute = AsyncMock(return_value=result)
|
||||
session.add = MagicMock()
|
||||
return session
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_empty_input_returns_zero_without_db_calls():
|
||||
session = AsyncMock()
|
||||
pipeline = IndexingPipelineService(session)
|
||||
|
||||
result = await pipeline.create_placeholder_documents([])
|
||||
|
||||
assert result == 0
|
||||
session.execute.assert_not_awaited()
|
||||
session.commit.assert_not_awaited()
|
||||
|
||||
|
||||
async def test_creates_documents_with_pending_status_and_commits():
|
||||
session = _session_with_existing_hashes(set())
|
||||
pipeline = IndexingPipelineService(session)
|
||||
p = _make_placeholder(title="My File", unique_id="file-abc")
|
||||
|
||||
result = await pipeline.create_placeholder_documents([p])
|
||||
|
||||
assert result == 1
|
||||
session.add.assert_called_once()
|
||||
|
||||
doc = session.add.call_args[0][0]
|
||||
assert doc.title == "My File"
|
||||
assert doc.document_type == DocumentType.GOOGLE_DRIVE_FILE
|
||||
assert doc.content == "Pending..."
|
||||
assert DocumentStatus.is_state(doc.status, DocumentStatus.PENDING)
|
||||
assert doc.search_space_id == 1
|
||||
assert doc.connector_id == 42
|
||||
|
||||
session.commit.assert_awaited_once()
|
||||
|
||||
|
||||
async def test_existing_documents_are_skipped():
|
||||
"""Placeholders whose unique_identifier_hash already exists are not re-created."""
|
||||
existing_p = _make_placeholder(unique_id="already-there")
|
||||
new_p = _make_placeholder(unique_id="brand-new")
|
||||
|
||||
existing_hash = _uid_hash(existing_p)
|
||||
session = _session_with_existing_hashes({existing_hash})
|
||||
pipeline = IndexingPipelineService(session)
|
||||
|
||||
result = await pipeline.create_placeholder_documents([existing_p, new_p])
|
||||
|
||||
assert result == 1
|
||||
doc = session.add.call_args[0][0]
|
||||
assert doc.unique_identifier_hash == _uid_hash(new_p)
|
||||
|
||||
|
||||
async def test_duplicate_unique_ids_within_input_are_deduped():
|
||||
"""Same unique_id passed twice only produces one placeholder."""
|
||||
p1 = _make_placeholder(unique_id="dup-id", title="First")
|
||||
p2 = _make_placeholder(unique_id="dup-id", title="Second")
|
||||
|
||||
session = _session_with_existing_hashes(set())
|
||||
pipeline = IndexingPipelineService(session)
|
||||
|
||||
result = await pipeline.create_placeholder_documents([p1, p2])
|
||||
|
||||
assert result == 1
|
||||
session.add.assert_called_once()
|
||||
|
||||
|
||||
async def test_integrity_error_on_commit_returns_zero():
|
||||
"""IntegrityError during commit (race condition) is swallowed gracefully."""
|
||||
session = _session_with_existing_hashes(set())
|
||||
session.commit = AsyncMock(side_effect=IntegrityError("dup", {}, None))
|
||||
pipeline = IndexingPipelineService(session)
|
||||
p = _make_placeholder()
|
||||
|
||||
result = await pipeline.create_placeholder_documents([p])
|
||||
|
||||
assert result == 0
|
||||
session.rollback.assert_awaited_once()
|
||||
|
|
@ -19,9 +19,7 @@ def pipeline(mock_session):
|
|||
return IndexingPipelineService(mock_session)
|
||||
|
||||
|
||||
async def test_calls_prepare_then_index_per_document(
|
||||
pipeline, make_connector_document
|
||||
):
|
||||
async def test_calls_prepare_then_index_per_document(pipeline, make_connector_document):
|
||||
"""index_batch calls prepare_for_indexing, then index() for each returned doc."""
|
||||
doc1 = make_connector_document(
|
||||
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
|
@ -57,7 +57,9 @@ async def test_index_calls_embed_and_chunk_via_to_thread(
|
|||
"app.indexing_pipeline.indexing_pipeline_service.chunk_text",
|
||||
mock_chunk,
|
||||
)
|
||||
mock_embed = MagicMock(side_effect=lambda texts: [[0.1] * _EMBEDDING_DIM for _ in texts])
|
||||
mock_embed = MagicMock(
|
||||
side_effect=lambda texts: [[0.1] * _EMBEDDING_DIM for _ in texts]
|
||||
)
|
||||
mock_embed.__name__ = "embed_texts"
|
||||
monkeypatch.setattr(
|
||||
"app.indexing_pipeline.indexing_pipeline_service.embed_texts",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,110 @@
|
|||
"""Unit tests for the duplicate-content safety logic in prepare_for_indexing.
|
||||
|
||||
Verifies that when an existing document's updated content matches another
|
||||
document's content_hash, the system marks it as failed (for placeholders)
|
||||
or leaves it untouched (for ready documents) — never deletes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.db import Document, DocumentStatus, DocumentType
|
||||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||
from app.indexing_pipeline.document_hashing import (
|
||||
compute_unique_identifier_hash,
|
||||
)
|
||||
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_connector_doc(**overrides) -> ConnectorDocument:
|
||||
defaults = {
|
||||
"title": "Test Doc",
|
||||
"source_markdown": "## Some new content",
|
||||
"unique_id": "file-001",
|
||||
"document_type": DocumentType.GOOGLE_DRIVE_FILE,
|
||||
"search_space_id": 1,
|
||||
"connector_id": 42,
|
||||
"created_by_id": "00000000-0000-0000-0000-000000000001",
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return ConnectorDocument(**defaults)
|
||||
|
||||
|
||||
def _make_existing_doc(connector_doc: ConnectorDocument, *, status: dict) -> MagicMock:
|
||||
"""Build a MagicMock that looks like an ORM Document with given status."""
|
||||
doc = MagicMock(spec=Document)
|
||||
doc.id = 999
|
||||
doc.unique_identifier_hash = compute_unique_identifier_hash(connector_doc)
|
||||
doc.content_hash = "old-placeholder-content-hash"
|
||||
doc.title = connector_doc.title
|
||||
doc.status = status
|
||||
return doc
|
||||
|
||||
|
||||
def _mock_session_for_dedup(existing_doc, *, has_duplicate: bool):
|
||||
"""Build a session whose sequential execute() calls return:
|
||||
|
||||
1. The *existing_doc* for the unique_identifier_hash lookup.
|
||||
2. A row (or None) for the duplicate content_hash check.
|
||||
"""
|
||||
session = AsyncMock()
|
||||
|
||||
existing_result = MagicMock()
|
||||
existing_result.scalars.return_value.first.return_value = existing_doc
|
||||
|
||||
dup_result = MagicMock()
|
||||
dup_result.scalars.return_value.first.return_value = 42 if has_duplicate else None
|
||||
|
||||
session.execute = AsyncMock(side_effect=[existing_result, dup_result])
|
||||
session.add = MagicMock()
|
||||
return session
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_pending_placeholder_with_duplicate_content_is_marked_failed():
|
||||
"""A placeholder (pending) whose updated content duplicates another doc
|
||||
must be marked as FAILED — never deleted."""
|
||||
cdoc = _make_connector_doc(source_markdown="## Shared content")
|
||||
existing = _make_existing_doc(cdoc, status=DocumentStatus.pending())
|
||||
|
||||
session = _mock_session_for_dedup(existing, has_duplicate=True)
|
||||
pipeline = IndexingPipelineService(session)
|
||||
|
||||
results = await pipeline.prepare_for_indexing([cdoc])
|
||||
|
||||
assert results == [], "duplicate should not be returned for indexing"
|
||||
|
||||
assert DocumentStatus.is_state(existing.status, DocumentStatus.FAILED)
|
||||
assert "Duplicate content" in existing.status.get("reason", "")
|
||||
session.delete.assert_not_called()
|
||||
|
||||
|
||||
async def test_ready_document_with_duplicate_content_is_left_untouched():
|
||||
"""A READY document whose updated content duplicates another doc
|
||||
must be left completely untouched — not failed, not deleted."""
|
||||
cdoc = _make_connector_doc(source_markdown="## Shared content")
|
||||
existing = _make_existing_doc(cdoc, status=DocumentStatus.ready())
|
||||
|
||||
session = _mock_session_for_dedup(existing, has_duplicate=True)
|
||||
pipeline = IndexingPipelineService(session)
|
||||
|
||||
results = await pipeline.prepare_for_indexing([cdoc])
|
||||
|
||||
assert results == [], "duplicate should not be returned for indexing"
|
||||
|
||||
assert DocumentStatus.is_state(existing.status, DocumentStatus.READY)
|
||||
session.delete.assert_not_called()
|
||||
133
surfsense_backend/tests/unit/middleware/test_knowledge_search.py
Normal file
133
surfsense_backend/tests/unit/middleware/test_knowledge_search.py
Normal file
|
|
@ -0,0 +1,133 @@
|
|||
"""Unit tests for knowledge_search middleware helpers.
|
||||
|
||||
These test pure functions that don't require a database.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.new_chat.middleware.knowledge_search import (
|
||||
_build_document_xml,
|
||||
_resolve_search_types,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ── _resolve_search_types ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestResolveSearchTypes:
|
||||
def test_returns_none_when_no_inputs(self):
|
||||
assert _resolve_search_types(None, None) is None
|
||||
|
||||
def test_returns_none_when_both_empty(self):
|
||||
assert _resolve_search_types([], []) is None
|
||||
|
||||
def test_includes_legacy_type_for_google_gmail(self):
|
||||
result = _resolve_search_types(["GOOGLE_GMAIL_CONNECTOR"], None)
|
||||
assert "GOOGLE_GMAIL_CONNECTOR" in result
|
||||
assert "COMPOSIO_GMAIL_CONNECTOR" in result
|
||||
|
||||
def test_includes_legacy_type_for_google_drive(self):
|
||||
result = _resolve_search_types(None, ["GOOGLE_DRIVE_FILE"])
|
||||
assert "GOOGLE_DRIVE_FILE" in result
|
||||
assert "COMPOSIO_GOOGLE_DRIVE_CONNECTOR" in result
|
||||
|
||||
def test_includes_legacy_type_for_google_calendar(self):
|
||||
result = _resolve_search_types(["GOOGLE_CALENDAR_CONNECTOR"], None)
|
||||
assert "GOOGLE_CALENDAR_CONNECTOR" in result
|
||||
assert "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR" in result
|
||||
|
||||
def test_no_legacy_expansion_for_unrelated_types(self):
|
||||
result = _resolve_search_types(["FILE", "NOTE"], None)
|
||||
assert set(result) == {"FILE", "NOTE"}
|
||||
|
||||
def test_combines_connectors_and_document_types(self):
|
||||
result = _resolve_search_types(["FILE"], ["NOTE", "CRAWLED_URL"])
|
||||
assert {"FILE", "NOTE", "CRAWLED_URL"}.issubset(set(result))
|
||||
|
||||
def test_deduplicates(self):
|
||||
result = _resolve_search_types(["FILE", "FILE"], ["FILE"])
|
||||
assert result.count("FILE") == 1
|
||||
|
||||
|
||||
# ── _build_document_xml ────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestBuildDocumentXml:
|
||||
@pytest.fixture
|
||||
def sample_document(self):
|
||||
return {
|
||||
"document_id": 42,
|
||||
"document": {
|
||||
"id": 42,
|
||||
"document_type": "FILE",
|
||||
"title": "Test Doc",
|
||||
"metadata": {"url": "https://example.com"},
|
||||
},
|
||||
"chunks": [
|
||||
{"chunk_id": 101, "content": "First chunk content"},
|
||||
{"chunk_id": 102, "content": "Second chunk content"},
|
||||
{"chunk_id": 103, "content": "Third chunk content"},
|
||||
],
|
||||
}
|
||||
|
||||
def test_contains_document_metadata(self, sample_document):
|
||||
xml = _build_document_xml(sample_document)
|
||||
assert "<document_id>42</document_id>" in xml
|
||||
assert "<document_type>FILE</document_type>" in xml
|
||||
assert "Test Doc" in xml
|
||||
|
||||
def test_contains_chunk_index(self, sample_document):
|
||||
xml = _build_document_xml(sample_document)
|
||||
assert "<chunk_index>" in xml
|
||||
assert "</chunk_index>" in xml
|
||||
assert 'chunk_id="101"' in xml
|
||||
assert 'chunk_id="102"' in xml
|
||||
assert 'chunk_id="103"' in xml
|
||||
|
||||
def test_matched_chunks_flagged_in_index(self, sample_document):
|
||||
xml = _build_document_xml(sample_document, matched_chunk_ids={101, 103})
|
||||
lines = xml.split("\n")
|
||||
for line in lines:
|
||||
if 'chunk_id="101"' in line:
|
||||
assert 'matched="true"' in line
|
||||
if 'chunk_id="102"' in line:
|
||||
assert 'matched="true"' not in line
|
||||
if 'chunk_id="103"' in line:
|
||||
assert 'matched="true"' in line
|
||||
|
||||
def test_chunk_content_in_document_content_section(self, sample_document):
|
||||
xml = _build_document_xml(sample_document)
|
||||
assert "<document_content>" in xml
|
||||
assert "First chunk content" in xml
|
||||
assert "Second chunk content" in xml
|
||||
assert "Third chunk content" in xml
|
||||
|
||||
def test_line_numbers_in_chunk_index_are_accurate(self, sample_document):
|
||||
"""Verify that the line ranges in chunk_index actually point to the right content."""
|
||||
xml = _build_document_xml(sample_document, matched_chunk_ids={101})
|
||||
xml_lines = xml.split("\n")
|
||||
|
||||
for line in xml_lines:
|
||||
if 'chunk_id="101"' in line and "lines=" in line:
|
||||
import re
|
||||
|
||||
m = re.search(r'lines="(\d+)-(\d+)"', line)
|
||||
assert m, f"No lines= attribute found in: {line}"
|
||||
start, _end = int(m.group(1)), int(m.group(2))
|
||||
target_line = xml_lines[start - 1]
|
||||
assert "101" in target_line
|
||||
assert "First chunk content" in target_line
|
||||
break
|
||||
else:
|
||||
pytest.fail("chunk_id=101 entry not found in chunk_index")
|
||||
|
||||
def test_splits_into_lines_correctly(self, sample_document):
|
||||
"""Each chunk occupies exactly one line (no embedded newlines)."""
|
||||
xml = _build_document_xml(sample_document)
|
||||
lines = xml.split("\n")
|
||||
chunk_lines = [
|
||||
line for line in lines if "<![CDATA[" in line and "<chunk" in line
|
||||
]
|
||||
assert len(chunk_lines) == 3
|
||||
Loading…
Add table
Add a link
Reference in a new issue