feat: implement and test prepare_for_indexing

This commit is contained in:
CREDO23 2026-02-25 00:06:34 +02:00
parent a0134a5830
commit 579a9e2cb5
4 changed files with 243 additions and 3 deletions

View file

@ -1,5 +1,6 @@
import os
import uuid
from unittest.mock import AsyncMock, MagicMock
import pytest
@ -8,7 +9,8 @@ from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.pool import NullPool
from app.db import Base
from app.db import Base, SearchSpace
from app.db import User
_EMBEDDING_DIM = 4 # keep vectors tiny; real model uses 768+
@ -18,7 +20,14 @@ TEST_DATABASE_URL = os.environ.get("TEST_DATABASE_URL", _DEFAULT_TEST_DB)
@pytest_asyncio.fixture(scope="session")
async def async_engine():
engine = create_async_engine(TEST_DATABASE_URL, poolclass=NullPool, echo=False)
engine = create_async_engine(
TEST_DATABASE_URL,
poolclass=NullPool,
echo=False,
# Required for asyncpg + savepoints: disables prepared statement cache
# to prevent "another operation is in progress" errors during savepoint rollbacks.
connect_args={"prepared_statement_cache_size": 0},
)
async with engine.begin() as conn:
await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
@ -26,8 +35,11 @@ async def async_engine():
yield engine
# drop_all fails on circular FKs (new_chat_threads ↔ public_chat_snapshots).
# DROP SCHEMA CASCADE handles this without needing topological sort.
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
await conn.execute(text("DROP SCHEMA public CASCADE"))
await conn.execute(text("CREATE SCHEMA public"))
await engine.dispose()
@ -50,6 +62,32 @@ async def db_session(async_engine) -> AsyncSession:
await transaction.rollback()
@pytest_asyncio.fixture
async def db_user(db_session: AsyncSession) -> User:
user = User(
id=uuid.uuid4(),
email="test@surfsense.net",
hashed_password="hashed",
is_active=True,
is_superuser=False,
is_verified=True,
)
db_session.add(user)
await db_session.flush()
return user
@pytest_asyncio.fixture
async def db_search_space(db_session: AsyncSession, db_user: User) -> SearchSpace:
space = SearchSpace(
name="Test Space",
user_id=db_user.id,
)
db_session.add(space)
await db_session.flush()
return space
@pytest.fixture
def mock_llm() -> AsyncMock:
llm = AsyncMock()