mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-25 19:15:18 +02:00
refactor: update database connection handling in test configurations
This commit is contained in:
parent
87711ee381
commit
223c2de0d2
2 changed files with 13 additions and 18 deletions
|
|
@ -3,23 +3,17 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
|
||||||
|
os.environ.setdefault(
|
||||||
|
"DATABASE_URL",
|
||||||
|
"postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense_test",
|
||||||
|
)
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from dotenv import load_dotenv
|
|
||||||
|
|
||||||
from app.db import DocumentType
|
from app.db import DocumentType
|
||||||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||||
|
|
||||||
load_dotenv(Path(__file__).resolve().parent.parent / ".env")
|
|
||||||
|
|
||||||
# Shared DB URL referenced by both e2e and integration helper functions.
|
|
||||||
DATABASE_URL = os.environ.get(
|
|
||||||
"TEST_DATABASE_URL",
|
|
||||||
os.environ.get("DATABASE_URL", ""),
|
|
||||||
).replace("postgresql+asyncpg://", "postgresql://")
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Unit test fixtures
|
# Unit test fixtures
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
|
||||||
|
|
@ -24,9 +24,9 @@ from sqlalchemy.pool import NullPool
|
||||||
|
|
||||||
from app.app import app
|
from app.app import app
|
||||||
from app.config import config as app_config
|
from app.config import config as app_config
|
||||||
from app.db import DATABASE_URL as APP_DB_URL, Base
|
from app.db import Base
|
||||||
from app.services.task_dispatcher import get_task_dispatcher
|
from app.services.task_dispatcher import get_task_dispatcher
|
||||||
from tests.conftest import DATABASE_URL
|
from tests.integration.conftest import TEST_DATABASE_URL
|
||||||
from tests.utils.helpers import (
|
from tests.utils.helpers import (
|
||||||
TEST_EMAIL,
|
TEST_EMAIL,
|
||||||
auth_headers,
|
auth_headers,
|
||||||
|
|
@ -36,6 +36,7 @@ from tests.utils.helpers import (
|
||||||
)
|
)
|
||||||
|
|
||||||
_EMBEDDING_DIM = app_config.embedding_model_instance.dimension
|
_EMBEDDING_DIM = app_config.embedding_model_instance.dimension
|
||||||
|
_ASYNCPG_URL = TEST_DATABASE_URL.replace("postgresql+asyncpg://", "postgresql://")
|
||||||
|
|
||||||
pytestmark = pytest.mark.integration
|
pytestmark = pytest.mark.integration
|
||||||
|
|
||||||
|
|
@ -85,7 +86,7 @@ app.dependency_overrides[get_task_dispatcher] = lambda: InlineTaskDispatcher()
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
async def _ensure_tables():
|
async def _ensure_tables():
|
||||||
"""Create DB tables and extensions once per session."""
|
"""Create DB tables and extensions once per session."""
|
||||||
engine = create_async_engine(APP_DB_URL, poolclass=NullPool)
|
engine = create_async_engine(TEST_DATABASE_URL, poolclass=NullPool)
|
||||||
async with engine.begin() as conn:
|
async with engine.begin() as conn:
|
||||||
await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
|
await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
|
||||||
await conn.execute(text("CREATE EXTENSION IF NOT EXISTS pg_trgm"))
|
await conn.execute(text("CREATE EXTENSION IF NOT EXISTS pg_trgm"))
|
||||||
|
|
@ -144,7 +145,7 @@ def cleanup_doc_ids() -> list[int]:
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
async def _purge_test_search_space(search_space_id: int):
|
async def _purge_test_search_space(search_space_id: int):
|
||||||
"""Delete stale documents from previous runs before the session starts."""
|
"""Delete stale documents from previous runs before the session starts."""
|
||||||
conn = await asyncpg.connect(DATABASE_URL)
|
conn = await asyncpg.connect(_ASYNCPG_URL)
|
||||||
try:
|
try:
|
||||||
result = await conn.execute(
|
result = await conn.execute(
|
||||||
"DELETE FROM documents WHERE search_space_id = $1",
|
"DELETE FROM documents WHERE search_space_id = $1",
|
||||||
|
|
@ -180,7 +181,7 @@ async def _cleanup_documents(
|
||||||
remaining_ids.append(doc_id)
|
remaining_ids.append(doc_id)
|
||||||
|
|
||||||
if remaining_ids:
|
if remaining_ids:
|
||||||
conn = await asyncpg.connect(DATABASE_URL)
|
conn = await asyncpg.connect(_ASYNCPG_URL)
|
||||||
try:
|
try:
|
||||||
await conn.execute(
|
await conn.execute(
|
||||||
"DELETE FROM documents WHERE id = ANY($1::int[])",
|
"DELETE FROM documents WHERE id = ANY($1::int[])",
|
||||||
|
|
@ -196,7 +197,7 @@ async def _cleanup_documents(
|
||||||
|
|
||||||
|
|
||||||
async def _get_user_page_usage(email: str) -> tuple[int, int]:
|
async def _get_user_page_usage(email: str) -> tuple[int, int]:
|
||||||
conn = await asyncpg.connect(DATABASE_URL)
|
conn = await asyncpg.connect(_ASYNCPG_URL)
|
||||||
try:
|
try:
|
||||||
row = await conn.fetchrow(
|
row = await conn.fetchrow(
|
||||||
'SELECT pages_used, pages_limit FROM "user" WHERE email = $1',
|
'SELECT pages_used, pages_limit FROM "user" WHERE email = $1',
|
||||||
|
|
@ -211,7 +212,7 @@ async def _get_user_page_usage(email: str) -> tuple[int, int]:
|
||||||
async def _set_user_page_limits(
|
async def _set_user_page_limits(
|
||||||
email: str, *, pages_used: int, pages_limit: int
|
email: str, *, pages_used: int, pages_limit: int
|
||||||
) -> None:
|
) -> None:
|
||||||
conn = await asyncpg.connect(DATABASE_URL)
|
conn = await asyncpg.connect(_ASYNCPG_URL)
|
||||||
try:
|
try:
|
||||||
await conn.execute(
|
await conn.execute(
|
||||||
'UPDATE "user" SET pages_used = $1, pages_limit = $2 WHERE email = $3',
|
'UPDATE "user" SET pages_used = $1, pages_limit = $2 WHERE email = $3',
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue