mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-09 15:52:40 +02:00
feat: implement and test index method
This commit is contained in:
parent
497ed681d5
commit
61e50834e6
8 changed files with 218 additions and 31 deletions
|
|
@ -4,6 +4,7 @@ from app.db import DocumentType
|
|||
|
||||
|
||||
class ConnectorDocument(BaseModel):
|
||||
"""Canonical data transfer object produced by connector adapters and consumed by the indexing pipeline."""
|
||||
title: str
|
||||
source_markdown: str
|
||||
unique_id: str
|
||||
|
|
|
|||
|
|
@ -0,0 +1,6 @@
|
|||
from app.config import config
|
||||
|
||||
|
||||
def chunk_text(text: str) -> list[str]:
    """Split *text* with the configured chunker and return the plain chunk strings.

    Any other attributes carried by the chunker's chunk objects are discarded.
    """
    chunker = config.chunker_instance
    return [piece.text for piece in chunker.chunk(text)]
|
|
@ -0,0 +1,6 @@
|
|||
from app.config import config
|
||||
|
||||
|
||||
def embed_text(text: str) -> list[float]:
    """Return the embedding vector for a single *text* string.

    Delegates to the application's configured embedding model instance.
    """
    embedder = config.embedding_model_instance
    return embedder.embed(text)
|
|
@ -4,10 +4,12 @@ from app.indexing_pipeline.connector_document import ConnectorDocument
|
|||
|
||||
|
||||
def compute_unique_identifier_hash(doc: "ConnectorDocument") -> str:
    """Return a stable SHA-256 hex digest identifying a document by source identity.

    The identity string joins the document type value, the connector-assigned
    unique id, and the search space id with ':' so that the same source document
    in the same search space always hashes to the same value.
    """
    identity = f"{doc.document_type.value}:{doc.unique_id}:{doc.search_space_id}"
    digest = hashlib.sha256(identity.encode("utf-8"))
    return digest.hexdigest()
|
||||
|
||||
|
||||
def compute_content_hash(doc: "ConnectorDocument") -> str:
    """Return a SHA-256 hex digest of the document's content.

    The search space id is folded into the hash so identical markdown in
    different search spaces produces different digests.
    """
    hasher = hashlib.sha256()
    hasher.update(f"{doc.search_space_id}:{doc.source_markdown}".encode("utf-8"))
    return hasher.hexdigest()
|
|
|
|||
|
|
@ -0,0 +1,28 @@
|
|||
from app.prompts import SUMMARY_PROMPT_TEMPLATE
|
||||
from app.utils.document_converters import optimize_content_for_context_window
|
||||
|
||||
|
||||
async def summarize_document(source_markdown: str, llm, metadata: dict | None = None) -> str:
    """Generate a text summary of a document using an LLM.

    The content is first trimmed to fit the model's context window, then sent
    through the summary prompt chain. When *metadata* is provided, a rendered
    metadata section is prepended to the returned summary.
    """
    # Fall back to a default model name when the LLM object does not expose one.
    model_name = getattr(llm, "model", "gpt-3.5-turbo")
    optimized_content = optimize_content_for_context_window(
        source_markdown, metadata, model_name
    )

    document_payload = (
        f"<DOCUMENT><DOCUMENT_METADATA>\n\n{metadata}\n\n</DOCUMENT_METADATA>"
        f"\n\n<DOCUMENT_CONTENT>\n\n{optimized_content}\n\n</DOCUMENT_CONTENT></DOCUMENT>"
    )
    chain = SUMMARY_PROMPT_TEMPLATE | llm
    response = await chain.ainvoke({"document": document_payload})
    summary = response.content

    if not metadata:
        return summary

    # Render a markdown header from the metadata, skipping falsy values.
    header_lines = ["# DOCUMENT METADATA"]
    for key, value in metadata.items():
        if value:
            pretty_key = key.replace("_", " ").title()
            header_lines.append(f"**{pretty_key}:** {value}")
    header = "\n".join(header_lines)
    return f"{header}\n\n# DOCUMENT SUMMARY\n\n{summary}"
|
|
@ -4,14 +4,16 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
from sqlalchemy.orm import object_session
|
||||
from sqlalchemy.orm.attributes import set_committed_value
|
||||
|
||||
from app.config import config
|
||||
from app.db import Document, DocumentStatus
|
||||
from app.db import Chunk, Document, DocumentStatus
|
||||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||
from app.indexing_pipeline.document_chunker import chunk_text
|
||||
from app.indexing_pipeline.document_embedder import embed_text
|
||||
from app.indexing_pipeline.document_hashing import compute_content_hash, compute_unique_identifier_hash
|
||||
from app.utils.document_converters import create_document_chunks, generate_document_summary
|
||||
from app.indexing_pipeline.document_summarizer import summarize_document
|
||||
|
||||
|
||||
def _safe_set_chunks(document: Document, chunks: list) -> None:
|
||||
"""Assign chunks to a document without triggering SQLAlchemy async lazy loading."""
|
||||
set_committed_value(document, "chunks", chunks)
|
||||
session = object_session(document)
|
||||
if session is not None:
|
||||
|
|
@ -22,12 +24,17 @@ def _safe_set_chunks(document: Document, chunks: list) -> None:
|
|||
|
||||
|
||||
class IndexingPipelineService:
|
||||
"""Single pipeline for indexing connector documents. All connectors use this service."""
|
||||
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self.session = session
|
||||
|
||||
async def prepare_for_indexing(
|
||||
self, connector_docs: list[ConnectorDocument]
|
||||
) -> list[Document]:
|
||||
"""
|
||||
Persist new documents and detect changes, returning only those that need indexing.
|
||||
"""
|
||||
documents = []
|
||||
|
||||
for connector_doc in connector_docs:
|
||||
|
|
@ -73,19 +80,26 @@ class IndexingPipelineService:
|
|||
async def index(
|
||||
self, document: Document, connector_doc: ConnectorDocument, llm
|
||||
) -> None:
|
||||
"""
|
||||
Run summarization, embedding, and chunking for a document and persist the results.
|
||||
"""
|
||||
try:
|
||||
document.status = DocumentStatus.processing()
|
||||
await self.session.commit()
|
||||
|
||||
if connector_doc.should_summarize:
|
||||
content, embedding = await generate_document_summary(
|
||||
content = await summarize_document(
|
||||
connector_doc.source_markdown, llm, connector_doc.metadata
|
||||
)
|
||||
else:
|
||||
content = connector_doc.source_markdown
|
||||
embedding = config.embedding_model_instance.embed(content)
|
||||
|
||||
chunks = await create_document_chunks(connector_doc.source_markdown)
|
||||
embedding = embed_text(content)
|
||||
|
||||
chunks = [
|
||||
Chunk(content=text, embedding=embed_text(text))
|
||||
for text in chunk_text(connector_doc.source_markdown)
|
||||
]
|
||||
|
||||
document.source_markdown = connector_doc.source_markdown
|
||||
document.content = content
|
||||
|
|
|
|||
|
|
@ -89,40 +89,42 @@ async def db_search_space(db_session: AsyncSession, db_user: User) -> SearchSpac
|
|||
|
||||
|
||||
@pytest.fixture
def mock_llm() -> AsyncMock:
    """An LLM stub whose async ``ainvoke`` returns a canned summary message."""
    stub = AsyncMock()
    stub.ainvoke = AsyncMock(return_value=MagicMock(content="Mocked summary."))
    return stub
|
||||
@pytest.fixture
|
||||
def patched_generate_summary(monkeypatch) -> AsyncMock:
|
||||
mock = AsyncMock(return_value=("Mocked summary.", [0.1] * _EMBEDDING_DIM))
|
||||
def patched_summarize(monkeypatch) -> AsyncMock:
|
||||
mock = AsyncMock(return_value="Mocked summary.")
|
||||
monkeypatch.setattr(
|
||||
"app.indexing_pipeline.indexing_pipeline_service.generate_document_summary",
|
||||
"app.indexing_pipeline.indexing_pipeline_service.summarize_document",
|
||||
mock,
|
||||
)
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patched_create_chunks(monkeypatch) -> MagicMock:
|
||||
from app.db import Chunk
|
||||
|
||||
chunk = Chunk(content="Test chunk content.", embedding=[0.1] * _EMBEDDING_DIM)
|
||||
mock = AsyncMock(return_value=[chunk])
|
||||
def patched_summarize_raises(monkeypatch) -> AsyncMock:
|
||||
mock = AsyncMock(side_effect=RuntimeError("LLM unavailable"))
|
||||
monkeypatch.setattr(
|
||||
"app.indexing_pipeline.indexing_pipeline_service.create_document_chunks",
|
||||
"app.indexing_pipeline.indexing_pipeline_service.summarize_document",
|
||||
mock,
|
||||
)
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patched_embedding_model(monkeypatch) -> MagicMock:
|
||||
from app.config import config
|
||||
def patched_embed_text(monkeypatch) -> MagicMock:
|
||||
mock = MagicMock(return_value=[0.1] * _EMBEDDING_DIM)
|
||||
monkeypatch.setattr(
|
||||
"app.indexing_pipeline.indexing_pipeline_service.embed_text",
|
||||
mock,
|
||||
)
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patched_chunk_text(monkeypatch) -> MagicMock:
|
||||
mock = MagicMock(return_value=["Test chunk content."])
|
||||
monkeypatch.setattr(
|
||||
"app.indexing_pipeline.indexing_pipeline_service.chunk_text",
|
||||
mock,
|
||||
)
|
||||
return mock
|
||||
|
||||
|
||||
model = MagicMock()
|
||||
model.embed = MagicMock(return_value=[0.1] * _EMBEDDING_DIM)
|
||||
monkeypatch.setattr(config, "embedding_model_instance", model)
|
||||
return model
|
||||
|
|
|
|||
|
|
@ -1,15 +1,15 @@
|
|||
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.db import Document, DocumentStatus
|
||||
from app.db import Chunk, Document, DocumentStatus
|
||||
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text")
|
||||
async def test_sets_status_ready(
|
||||
db_session, db_search_space, make_connector_document,
|
||||
mock_llm, patched_generate_summary, patched_create_chunks,
|
||||
):
|
||||
connector_doc = make_connector_document(search_space_id=db_search_space.id)
|
||||
service = IndexingPipelineService(session=db_session)
|
||||
|
|
@ -18,9 +18,137 @@ async def test_sets_status_ready(
|
|||
document = prepared[0]
|
||||
document_id = document.id
|
||||
|
||||
await service.index(document, connector_doc, mock_llm)
|
||||
await service.index(document, connector_doc, llm=None)
|
||||
|
||||
result = await db_session.execute(select(Document).filter(Document.id == document_id))
|
||||
reloaded = result.scalars().first()
|
||||
|
||||
assert DocumentStatus.is_state(reloaded.status, DocumentStatus.READY)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text")
async def test_content_is_summary_when_should_summarize_true(
    db_session, db_search_space, make_connector_document,
):
    """With summarization enabled, the persisted content is the LLM summary."""
    doc_in = make_connector_document(search_space_id=db_search_space.id)
    service = IndexingPipelineService(session=db_session)

    prepared = await service.prepare_for_indexing([doc_in])
    target = prepared[0]
    target_id = target.id

    await service.index(target, doc_in, llm=None)

    stored = (
        await db_session.execute(select(Document).filter(Document.id == target_id))
    ).scalars().first()

    assert stored.content == "Mocked summary."
||||
|
||||
@pytest.mark.usefixtures("patched_embed_text", "patched_chunk_text")
async def test_content_is_source_markdown_when_should_summarize_false(
    db_session, db_search_space, make_connector_document,
):
    """With summarization disabled, the raw markdown is persisted as content."""
    doc_in = make_connector_document(
        search_space_id=db_search_space.id,
        should_summarize=False,
        source_markdown="## Raw content",
    )
    service = IndexingPipelineService(session=db_session)

    prepared = await service.prepare_for_indexing([doc_in])
    target = prepared[0]
    target_id = target.id

    await service.index(target, doc_in, llm=None)

    stored = (
        await db_session.execute(select(Document).filter(Document.id == target_id))
    ).scalars().first()

    assert stored.content == "## Raw content"
|
||||
@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text")
async def test_chunks_written_to_db(
    db_session, db_search_space, make_connector_document,
):
    """Indexing persists one Chunk row per chunk produced by the chunker."""
    doc_in = make_connector_document(search_space_id=db_search_space.id)
    service = IndexingPipelineService(session=db_session)

    prepared = await service.prepare_for_indexing([doc_in])
    target = prepared[0]
    target_id = target.id

    await service.index(target, doc_in, llm=None)

    stored_chunks = (
        await db_session.execute(select(Chunk).filter(Chunk.document_id == target_id))
    ).scalars().all()

    assert len(stored_chunks) == 1
    assert stored_chunks[0].content == "Test chunk content."
|
||||
@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text")
async def test_embedding_written_to_db(
    db_session, db_search_space, make_connector_document,
):
    """Indexing stores a document-level embedding of the expected dimension."""
    doc_in = make_connector_document(search_space_id=db_search_space.id)
    service = IndexingPipelineService(session=db_session)

    prepared = await service.prepare_for_indexing([doc_in])
    target = prepared[0]
    target_id = target.id

    await service.index(target, doc_in, llm=None)

    stored = (
        await db_session.execute(select(Document).filter(Document.id == target_id))
    ).scalars().first()

    assert stored.embedding is not None
    assert len(stored.embedding) == 1024
|
||||
@pytest.mark.usefixtures("patched_summarize_raises", "patched_chunk_text")
async def test_llm_error_sets_status_failed(
    db_session, db_search_space, make_connector_document,
):
    """A summarization failure marks the document FAILED instead of raising."""
    doc_in = make_connector_document(search_space_id=db_search_space.id)
    service = IndexingPipelineService(session=db_session)

    prepared = await service.prepare_for_indexing([doc_in])
    target = prepared[0]
    target_id = target.id

    await service.index(target, doc_in, llm=None)

    stored = (
        await db_session.execute(select(Document).filter(Document.id == target_id))
    ).scalars().first()

    assert DocumentStatus.is_state(stored.status, DocumentStatus.FAILED)
|
||||
@pytest.mark.usefixtures("patched_summarize_raises", "patched_chunk_text")
async def test_llm_error_leaves_no_partial_data(
    db_session, db_search_space, make_connector_document,
):
    """A failed indexing run must not leave partial embeddings, content, or chunks."""
    doc_in = make_connector_document(search_space_id=db_search_space.id)
    service = IndexingPipelineService(session=db_session)

    prepared = await service.prepare_for_indexing([doc_in])
    target = prepared[0]
    target_id = target.id

    await service.index(target, doc_in, llm=None)

    stored = (
        await db_session.execute(select(Document).filter(Document.id == target_id))
    ).scalars().first()

    assert stored.embedding is None
    assert stored.content == "Pending..."

    remaining_chunks = (
        await db_session.execute(select(Chunk).filter(Chunk.document_id == target_id))
    ).scalars().all()
    assert remaining_chunks == []
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue