mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-28 02:23:53 +02:00
feat: implement and test prepare_for_indexing
This commit is contained in:
parent
a0134a5830
commit
579a9e2cb5
4 changed files with 243 additions and 3 deletions
|
|
@ -1,5 +1,6 @@
|
|||
|
||||
import os
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
|
@ -8,7 +9,8 @@ from sqlalchemy import text
|
|||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.pool import NullPool
|
||||
|
||||
from app.db import Base
|
||||
from app.db import Base, SearchSpace
|
||||
from app.db import User
|
||||
|
||||
_EMBEDDING_DIM = 4 # keep vectors tiny; real model uses 768+
|
||||
|
||||
|
|
@ -18,7 +20,14 @@ TEST_DATABASE_URL = os.environ.get("TEST_DATABASE_URL", _DEFAULT_TEST_DB)
|
|||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
async def async_engine():
|
||||
engine = create_async_engine(TEST_DATABASE_URL, poolclass=NullPool, echo=False)
|
||||
engine = create_async_engine(
|
||||
TEST_DATABASE_URL,
|
||||
poolclass=NullPool,
|
||||
echo=False,
|
||||
# Required for asyncpg + savepoints: disables prepared statement cache
|
||||
# to prevent "another operation is in progress" errors during savepoint rollbacks.
|
||||
connect_args={"prepared_statement_cache_size": 0},
|
||||
)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
|
||||
|
|
@ -26,8 +35,11 @@ async def async_engine():
|
|||
|
||||
yield engine
|
||||
|
||||
# drop_all fails on circular FKs (new_chat_threads ↔ public_chat_snapshots).
|
||||
# DROP SCHEMA CASCADE handles this without needing topological sort.
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
await conn.execute(text("DROP SCHEMA public CASCADE"))
|
||||
await conn.execute(text("CREATE SCHEMA public"))
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
|
@ -50,6 +62,32 @@ async def db_session(async_engine) -> AsyncSession:
|
|||
await transaction.rollback()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db_user(db_session: AsyncSession) -> User:
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="test@surfsense.net",
|
||||
hashed_password="hashed",
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
is_verified=True,
|
||||
)
|
||||
db_session.add(user)
|
||||
await db_session.flush()
|
||||
return user
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db_search_space(db_session: AsyncSession, db_user: User) -> SearchSpace:
|
||||
space = SearchSpace(
|
||||
name="Test Space",
|
||||
user_id=db_user.id,
|
||||
)
|
||||
db_session.add(space)
|
||||
await db_session.flush()
|
||||
return space
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm() -> AsyncMock:
|
||||
llm = AsyncMock()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,145 @@
|
|||
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.db import Document, DocumentStatus
|
||||
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
async def test_new_document_is_persisted_with_pending_status(
|
||||
db_session, db_search_space, make_connector_document
|
||||
):
|
||||
doc = make_connector_document(search_space_id=db_search_space.id)
|
||||
service = IndexingPipelineService(session=db_session)
|
||||
|
||||
results = await service.prepare_for_indexing([doc])
|
||||
|
||||
assert len(results) == 1
|
||||
document_id = results[0].id
|
||||
|
||||
result = await db_session.execute(select(Document).filter(Document.id == document_id))
|
||||
reloaded = result.scalars().first()
|
||||
|
||||
assert reloaded is not None
|
||||
assert DocumentStatus.is_state(reloaded.status, DocumentStatus.PENDING)
|
||||
|
||||
|
||||
async def test_unchanged_document_is_skipped(
|
||||
db_session, db_search_space, make_connector_document
|
||||
):
|
||||
doc = make_connector_document(search_space_id=db_search_space.id)
|
||||
service = IndexingPipelineService(session=db_session)
|
||||
|
||||
await service.prepare_for_indexing([doc])
|
||||
results = await service.prepare_for_indexing([doc])
|
||||
|
||||
assert results == []
|
||||
|
||||
|
||||
async def test_title_only_change_updates_title_in_db(
|
||||
db_session, db_search_space, make_connector_document
|
||||
):
|
||||
original = make_connector_document(search_space_id=db_search_space.id, title="Original Title")
|
||||
service = IndexingPipelineService(session=db_session)
|
||||
|
||||
first = await service.prepare_for_indexing([original])
|
||||
document_id = first[0].id
|
||||
|
||||
renamed = make_connector_document(search_space_id=db_search_space.id, title="Updated Title")
|
||||
results = await service.prepare_for_indexing([renamed])
|
||||
|
||||
assert results == []
|
||||
|
||||
result = await db_session.execute(select(Document).filter(Document.id == document_id))
|
||||
reloaded = result.scalars().first()
|
||||
|
||||
assert reloaded.title == "Updated Title"
|
||||
|
||||
|
||||
async def test_changed_content_is_returned_for_reprocessing(
|
||||
db_session, db_search_space, make_connector_document
|
||||
):
|
||||
original = make_connector_document(search_space_id=db_search_space.id, source_markdown="## v1")
|
||||
service = IndexingPipelineService(session=db_session)
|
||||
|
||||
first = await service.prepare_for_indexing([original])
|
||||
original_id = first[0].id
|
||||
|
||||
updated = make_connector_document(search_space_id=db_search_space.id, source_markdown="## v2")
|
||||
results = await service.prepare_for_indexing([updated])
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].id == original_id
|
||||
|
||||
result = await db_session.execute(select(Document).filter(Document.id == original_id))
|
||||
reloaded = result.scalars().first()
|
||||
|
||||
assert reloaded.source_markdown == "## v2"
|
||||
assert DocumentStatus.is_state(reloaded.status, DocumentStatus.PENDING)
|
||||
|
||||
|
||||
async def test_all_documents_in_batch_are_persisted(
|
||||
db_session, db_search_space, make_connector_document
|
||||
):
|
||||
docs = [
|
||||
make_connector_document(search_space_id=db_search_space.id, unique_id="id-1", title="Doc 1", source_markdown="## Content 1"),
|
||||
make_connector_document(search_space_id=db_search_space.id, unique_id="id-2", title="Doc 2", source_markdown="## Content 2"),
|
||||
make_connector_document(search_space_id=db_search_space.id, unique_id="id-3", title="Doc 3", source_markdown="## Content 3"),
|
||||
]
|
||||
service = IndexingPipelineService(session=db_session)
|
||||
|
||||
results = await service.prepare_for_indexing(docs)
|
||||
|
||||
assert len(results) == 3
|
||||
|
||||
result = await db_session.execute(select(Document).filter(Document.search_space_id == db_search_space.id))
|
||||
rows = result.scalars().all()
|
||||
|
||||
assert len(rows) == 3
|
||||
|
||||
|
||||
async def test_duplicate_in_batch_is_persisted_once(
|
||||
db_session, db_search_space, make_connector_document
|
||||
):
|
||||
doc = make_connector_document(search_space_id=db_search_space.id)
|
||||
service = IndexingPipelineService(session=db_session)
|
||||
|
||||
results = await service.prepare_for_indexing([doc, doc])
|
||||
|
||||
assert len(results) == 1
|
||||
|
||||
result = await db_session.execute(select(Document).filter(Document.search_space_id == db_search_space.id))
|
||||
rows = result.scalars().all()
|
||||
|
||||
assert len(rows) == 1
|
||||
|
||||
|
||||
async def test_title_and_content_change_updates_both_and_returns_document(
|
||||
db_session, db_search_space, make_connector_document
|
||||
):
|
||||
original = make_connector_document(
|
||||
search_space_id=db_search_space.id,
|
||||
title="Original Title",
|
||||
source_markdown="## v1",
|
||||
)
|
||||
service = IndexingPipelineService(session=db_session)
|
||||
|
||||
first = await service.prepare_for_indexing([original])
|
||||
original_id = first[0].id
|
||||
|
||||
updated = make_connector_document(
|
||||
search_space_id=db_search_space.id,
|
||||
title="Updated Title",
|
||||
source_markdown="## v2",
|
||||
)
|
||||
results = await service.prepare_for_indexing([updated])
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].id == original_id
|
||||
|
||||
result = await db_session.execute(select(Document).filter(Document.id == original_id))
|
||||
reloaded = result.scalars().first()
|
||||
|
||||
assert reloaded.title == "Updated Title"
|
||||
assert reloaded.source_markdown == "## v2"
|
||||
Loading…
Add table
Add a link
Reference in a new issue