mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-26 01:06:23 +02:00
feat: implement and test prepare_for_indexing
This commit is contained in:
parent
a0134a5830
commit
579a9e2cb5
4 changed files with 243 additions and 3 deletions
|
|
@ -1,5 +1,6 @@
|
|||
|
||||
import os
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
|
@ -8,7 +9,8 @@ from sqlalchemy import text
|
|||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.pool import NullPool
|
||||
|
||||
from app.db import Base
|
||||
from app.db import Base, SearchSpace
|
||||
from app.db import User
|
||||
|
||||
_EMBEDDING_DIM = 4 # keep vectors tiny; real model uses 768+
|
||||
|
||||
|
|
@ -18,7 +20,14 @@ TEST_DATABASE_URL = os.environ.get("TEST_DATABASE_URL", _DEFAULT_TEST_DB)
|
|||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
async def async_engine():
|
||||
engine = create_async_engine(TEST_DATABASE_URL, poolclass=NullPool, echo=False)
|
||||
engine = create_async_engine(
|
||||
TEST_DATABASE_URL,
|
||||
poolclass=NullPool,
|
||||
echo=False,
|
||||
# Required for asyncpg + savepoints: disables prepared statement cache
|
||||
# to prevent "another operation is in progress" errors during savepoint rollbacks.
|
||||
connect_args={"prepared_statement_cache_size": 0},
|
||||
)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
|
||||
|
|
@ -26,8 +35,11 @@ async def async_engine():
|
|||
|
||||
yield engine
|
||||
|
||||
# drop_all fails on circular FKs (new_chat_threads ↔ public_chat_snapshots).
|
||||
# DROP SCHEMA CASCADE handles this without needing topological sort.
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
await conn.execute(text("DROP SCHEMA public CASCADE"))
|
||||
await conn.execute(text("CREATE SCHEMA public"))
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
|
@ -50,6 +62,32 @@ async def db_session(async_engine) -> AsyncSession:
|
|||
await transaction.rollback()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db_user(db_session: AsyncSession) -> User:
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="test@surfsense.net",
|
||||
hashed_password="hashed",
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
is_verified=True,
|
||||
)
|
||||
db_session.add(user)
|
||||
await db_session.flush()
|
||||
return user
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db_search_space(db_session: AsyncSession, db_user: User) -> SearchSpace:
|
||||
space = SearchSpace(
|
||||
name="Test Space",
|
||||
user_id=db_user.id,
|
||||
)
|
||||
db_session.add(space)
|
||||
await db_session.flush()
|
||||
return space
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm() -> AsyncMock:
|
||||
llm = AsyncMock()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue