diff --git a/surfsense_backend/tests/integration/indexing_pipeline/test_index_document.py b/surfsense_backend/tests/integration/indexing_pipeline/test_index_document.py index 696def4b2..1180c25c9 100644 --- a/surfsense_backend/tests/integration/indexing_pipeline/test_index_document.py +++ b/surfsense_backend/tests/integration/indexing_pipeline/test_index_document.py @@ -9,7 +9,7 @@ pytestmark = pytest.mark.integration @pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text") async def test_sets_status_ready( - db_session, db_search_space, make_connector_document, + db_session, db_search_space, make_connector_document, mocker, ): connector_doc = make_connector_document(search_space_id=db_search_space.id) service = IndexingPipelineService(session=db_session) @@ -18,7 +18,7 @@ async def test_sets_status_ready( document = prepared[0] document_id = document.id - await service.index(document, connector_doc, llm=None) + await service.index(document, connector_doc, llm=mocker.Mock()) result = await db_session.execute(select(Document).filter(Document.id == document_id)) reloaded = result.scalars().first() @@ -45,7 +45,7 @@ async def test_content_is_summary_when_should_summarize_true( assert reloaded.content == "Mocked summary." -@pytest.mark.usefixtures("patched_embed_text", "patched_chunk_text") +@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text") async def test_content_is_source_markdown_when_should_summarize_false( db_session, db_search_space, make_connector_document, ): @@ -70,7 +70,7 @@ async def test_content_is_source_markdown_when_should_summarize_false( @pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text") async def test_chunks_written_to_db( - db_session, db_search_space, make_connector_document, + db_session, db_search_space, make_connector_document, mocker, ): connector_doc = make_connector_document(search_space_id=db_search_space.id) service = IndexingPipelineService(session=db_session) @@ -79,7 +79,7 @@ async def test_chunks_written_to_db( document = prepared[0] document_id = document.id - await service.index(document, connector_doc, llm=None) + await service.index(document, connector_doc, llm=mocker.Mock()) result = await db_session.execute( select(Chunk).filter(Chunk.document_id == document_id) @@ -92,7 +92,7 @@ async def test_chunks_written_to_db( @pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text") async def test_embedding_written_to_db( - db_session, db_search_space, make_connector_document, + db_session, db_search_space, make_connector_document, mocker, ): connector_doc = make_connector_document(search_space_id=db_search_space.id) service = IndexingPipelineService(session=db_session) @@ -101,7 +101,7 @@ async def test_embedding_written_to_db( document = prepared[0] document_id = document.id - await service.index(document, connector_doc, llm=None) + await service.index(document, connector_doc, llm=mocker.Mock()) result = await db_session.execute(select(Document).filter(Document.id == document_id)) reloaded = result.scalars().first() @@ -112,7 +112,7 @@ async def test_embedding_written_to_db( @pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text") async def test_updated_at_advances_after_indexing( - db_session, db_search_space, make_connector_document, + db_session, db_search_space, make_connector_document, mocker, ): connector_doc = make_connector_document(search_space_id=db_search_space.id) service = IndexingPipelineService(session=db_session) @@ -124,7 +124,7 @@ async def test_updated_at_advances_after_indexing( result = await db_session.execute(select(Document).filter(Document.id == document_id)) updated_at_pending = result.scalars().first().updated_at - await service.index(document, connector_doc, llm=None) + await service.index(document, connector_doc, llm=mocker.Mock()) result = await db_session.execute(select(Document).filter(Document.id == document_id)) updated_at_ready = result.scalars().first().updated_at @@ -132,7 +132,7 @@ async def test_updated_at_advances_after_indexing( assert updated_at_ready > updated_at_pending -@pytest.mark.usefixtures("patched_embed_text", "patched_chunk_text") +@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text") async def test_no_llm_falls_back_to_source_markdown( db_session, db_search_space, make_connector_document, ): @@ -158,7 +158,7 @@ async def test_no_llm_falls_back_to_source_markdown( @pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text") async def test_reindex_replaces_old_chunks( - db_session, db_search_space, make_connector_document, + db_session, db_search_space, make_connector_document, mocker, ): connector_doc = make_connector_document( search_space_id=db_search_space.id, @@ -170,14 +170,14 @@ async def test_reindex_replaces_old_chunks( document = prepared[0] document_id = document.id - await service.index(document, connector_doc, llm=None) + await service.index(document, connector_doc, llm=mocker.Mock()) updated_doc = make_connector_document( search_space_id=db_search_space.id, source_markdown="## v2", ) re_prepared = await service.prepare_for_indexing([updated_doc]) - await service.index(re_prepared[0], updated_doc, llm=None) + await service.index(re_prepared[0], updated_doc, llm=mocker.Mock()) result = await db_session.execute( select(Chunk).filter(Chunk.document_id == document_id) @@ -187,7 +187,7 @@ async def test_reindex_replaces_old_chunks( assert len(chunks) == 1 -@pytest.mark.usefixtures("patched_summarize_raises", "patched_chunk_text") +@pytest.mark.usefixtures("patched_summarize_raises", "patched_embed_text", "patched_chunk_text") async def test_llm_error_sets_status_failed( db_session, db_search_space, make_connector_document, mocker, ): @@ -206,7 +206,7 @@ async def test_llm_error_sets_status_failed( assert DocumentStatus.is_state(reloaded.status, DocumentStatus.FAILED) -@pytest.mark.usefixtures("patched_summarize_raises", "patched_chunk_text") +@pytest.mark.usefixtures("patched_summarize_raises", "patched_embed_text", "patched_chunk_text") async def test_llm_error_leaves_no_partial_data( db_session, db_search_space, make_connector_document, mocker, ): diff --git a/surfsense_backend/tests/unit/indexing_pipeline/conftest.py b/surfsense_backend/tests/unit/indexing_pipeline/conftest.py index 886318bc9..2147cfa3f 100644 --- a/surfsense_backend/tests/unit/indexing_pipeline/conftest.py +++ b/surfsense_backend/tests/unit/indexing_pipeline/conftest.py @@ -1,15 +1,33 @@ import pytest +from unittest.mock import AsyncMock, MagicMock @pytest.fixture -def patched_chunker_instance(mocker): - mock = mocker.patch("app.indexing_pipeline.document_chunker.config.chunker_instance") - mock.chunk.return_value = [mocker.Mock(text="prose chunk")] +def patched_summarizer_chain(monkeypatch): + chain = MagicMock() + chain.ainvoke = AsyncMock(return_value=MagicMock(content="The summary.")) + + template = MagicMock() + template.__or__ = MagicMock(return_value=chain) + + monkeypatch.setattr( + "app.indexing_pipeline.document_summarizer.SUMMARY_PROMPT_TEMPLATE", + template, + ) + return chain + + +@pytest.fixture +def patched_chunker_instance(monkeypatch): + mock = MagicMock() + mock.chunk.return_value = [MagicMock(text="prose chunk")] + monkeypatch.setattr("app.indexing_pipeline.document_chunker.config.chunker_instance", mock) return mock @pytest.fixture -def patched_code_chunker_instance(mocker): - mock = mocker.patch("app.indexing_pipeline.document_chunker.config.code_chunker_instance") - mock.chunk.return_value = [mocker.Mock(text="code chunk")] +def patched_code_chunker_instance(monkeypatch): + mock = MagicMock() + mock.chunk.return_value = [MagicMock(text="code chunk")] + monkeypatch.setattr("app.indexing_pipeline.document_chunker.config.code_chunker_instance", mock) return mock diff --git a/surfsense_backend/tests/unit/indexing_pipeline/test_document_chunker.py b/surfsense_backend/tests/unit/indexing_pipeline/test_document_chunker.py index 258227cbe..78d0641c1 100644 --- a/surfsense_backend/tests/unit/indexing_pipeline/test_document_chunker.py +++ b/surfsense_backend/tests/unit/indexing_pipeline/test_document_chunker.py @@ -5,15 +5,15 @@ from app.indexing_pipeline.document_chunker import chunk_text pytestmark = pytest.mark.unit -def test_uses_code_chunker_when_flag_is_true(patched_code_chunker_instance): +@pytest.mark.usefixtures("patched_chunker_instance", "patched_code_chunker_instance") +def test_uses_code_chunker_when_flag_is_true(): result = chunk_text("def foo(): pass", use_code_chunker=True) - patched_code_chunker_instance.chunk.assert_called_once_with("def foo(): pass") assert result == ["code chunk"] -def test_uses_default_chunker_when_flag_is_false(patched_chunker_instance): +@pytest.mark.usefixtures("patched_chunker_instance", "patched_code_chunker_instance") +def test_uses_default_chunker_when_flag_is_false(): result = chunk_text("Some prose text.", use_code_chunker=False) - patched_chunker_instance.chunk.assert_called_once_with("Some prose text.") assert result == ["prose chunk"] diff --git a/surfsense_backend/tests/unit/indexing_pipeline/test_document_summarizer.py b/surfsense_backend/tests/unit/indexing_pipeline/test_document_summarizer.py new file mode 100644 index 000000000..2f713d13d --- /dev/null +++ b/surfsense_backend/tests/unit/indexing_pipeline/test_document_summarizer.py @@ -0,0 +1,39 @@ +import pytest +from unittest.mock import MagicMock + +from app.indexing_pipeline.document_summarizer import summarize_document + +pytestmark = pytest.mark.unit + + +@pytest.mark.usefixtures("patched_summarizer_chain") +async def test_without_metadata_returns_raw_summary(): + result = await summarize_document("# Content", llm=MagicMock(model="gpt-4")) + + assert result == "The summary." + + +@pytest.mark.usefixtures("patched_summarizer_chain") +async def test_with_metadata_includes_metadata_values_in_output(): + result = await summarize_document( + "# Content", + llm=MagicMock(model="gpt-4"), + metadata={"author": "Alice", "source": "Notion"}, + ) + + assert "Alice" in result + assert "Notion" in result + + +@pytest.mark.usefixtures("patched_summarizer_chain") +async def test_with_metadata_omits_empty_fields_from_output(): + result = await summarize_document( + "# Content", + llm=MagicMock(model="gpt-4"), + metadata={"author": "Alice", "description": ""}, + ) + + assert "Alice" in result + assert "description" not in result.lower() + +