2025-09-09 14:37:32 +05:30
|
|
|
"""
|
2026-01-02 13:11:02 +05:30
|
|
|
Pytest configuration and fixtures for async database testing.
|
|
|
|
|
|
|
|
|
|
This module sets up the test infrastructure using:
|
|
|
|
|
- A separate test database (appends _test to the database name)
|
|
|
|
|
- Alembic migrations run once per test session
|
|
|
|
|
- Transaction-based isolation for each test (savepoint pattern)
|
|
|
|
|
|
|
|
|
|
References:
|
|
|
|
|
- https://www.core27.co/post/transactional-unit-tests-with-pytest-and-async-sqlalchemy
|
|
|
|
|
- https://docs.sqlalchemy.org/en/20/orm/session_transaction.html
|
2025-09-09 14:37:32 +05:30
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
import os
|
2026-02-25 13:53:30 +05:30
|
|
|
|
|
|
|
|
# Load environment variables before importing anything else
|
2026-01-02 13:11:02 +05:30
|
|
|
from pathlib import Path
|
|
|
|
|
from typing import AsyncGenerator
|
|
|
|
|
from urllib.parse import urlparse, urlunparse
|
2025-09-09 14:37:32 +05:30
|
|
|
|
2026-01-02 13:11:02 +05:30
|
|
|
from dotenv import load_dotenv
|
2025-09-09 14:37:32 +05:30
|
|
|
|
2026-02-25 13:53:30 +05:30
|
|
|
# Load .env.test before importing api.constants (which reads DATABASE_URL at import time)
|
|
|
|
|
env_path = Path(__file__).resolve().parent / ".env.test"
|
2026-01-02 13:11:02 +05:30
|
|
|
load_dotenv(env_path)
|
2025-09-09 14:37:32 +05:30
|
|
|
|
2026-01-03 16:22:38 +05:30
|
|
|
import logging
|
|
|
|
|
import sys
|
|
|
|
|
|
|
|
|
|
import loguru
|
2026-01-02 13:11:02 +05:30
|
|
|
import pytest
|
2026-01-03 16:22:38 +05:30
|
|
|
|
2026-02-25 13:53:30 +05:30
|
|
|
from api.constants import APP_ROOT_DIR # noqa: E402
|
|
|
|
|
|
2026-01-03 16:22:38 +05:30
|
|
|
|
|
|
|
|
def setup_test_logging():
    """Configure logging for tests using LOG_LEVEL from .env.test.

    Steps performed:

    1. Read ``LOG_LEVEL`` from the environment (``DEBUG`` when unset).
    2. Remove loguru's default handler and add a colourised stdout sink
       at that level.
    3. Install a root ``logging`` handler that forwards every
       standard-library log record to loguru.
    """
    # Function-scope import: only the interception handler below needs it.
    import inspect

    log_level = os.getenv("LOG_LEVEL", "DEBUG").upper()

    # Remove default loguru handler (id 0). A ValueError here means it is
    # no longer registered, which is fine for our purposes.
    try:
        loguru.logger.remove(0)
    except ValueError:
        pass

    # Add console handler with the configured log level
    loguru.logger.add(
        sys.stdout,
        format="{time:YYYY-MM-DD HH:mm:ss.SSS} | <level>{level}</level> | {file.name}:{line} | {message}",
        level=log_level,
        colorize=True,
    )

    # Intercept standard library logging and redirect to loguru
    class InterceptHandler(logging.Handler):
        """Forward standard-library log records to loguru."""

        def emit(self, record):
            # Prefer the loguru level with the same name; fall back to the
            # numeric level for names loguru does not know.
            try:
                level = loguru.logger.level(record.levelname).name
            except ValueError:
                level = record.levelno

            # Find the frame that issued the logging call so that the
            # {file.name}:{line} fields of the sink format point at the
            # real caller instead of at this handler.
            frame, depth = inspect.currentframe(), 0
            while frame:
                filename = frame.f_code.co_filename
                is_logging = filename == logging.__file__
                is_frozen = "importlib" in filename and "_bootstrap" in filename
                if depth > 0 and not (is_logging or is_frozen):
                    break
                frame = frame.f_back
                depth += 1

            loguru.logger.opt(depth=depth, exception=record.exc_info).log(
                level, record.getMessage()
            )

    # force=True replaces any handlers already attached to the root logger.
    logging.basicConfig(handlers=[InterceptHandler()], level=logging.DEBUG, force=True)


# Initialize test logging
setup_test_logging()
|
2026-01-02 13:11:02 +05:30
|
|
|
from sqlalchemy import event, text
|
|
|
|
|
from sqlalchemy.ext.asyncio import (
|
|
|
|
|
AsyncConnection,
|
|
|
|
|
AsyncSession,
|
|
|
|
|
async_sessionmaker,
|
|
|
|
|
create_async_engine,
|
|
|
|
|
)
|
|
|
|
|
from sqlalchemy.orm import SessionTransaction
|
|
|
|
|
from sqlalchemy.pool import NullPool
|
2025-09-09 14:37:32 +05:30
|
|
|
|
|
|
|
|
|
2026-01-02 13:11:02 +05:30
|
|
|
def get_test_database_url() -> str:
    """
    Derive the test database URL from the DATABASE_URL environment variable.

    The database name (the URL path without its leading slash) receives a
    ``_test`` suffix; all other URL components are carried over unchanged.

    Example:
        postgresql+asyncpg://user:pass@host/mydb
        -> postgresql+asyncpg://user:pass@host/mydb_test

    Raises:
        ValueError: If DATABASE_URL is unset or empty.
    """
    source_url = os.environ.get("DATABASE_URL")
    if not source_url:
        raise ValueError("DATABASE_URL environment variable is not set")

    # urlparse yields (scheme, netloc, path, params, query, fragment);
    # only the path component (index 2) is rewritten.
    components = list(urlparse(source_url))
    db_name = components[2].lstrip("/")
    components[2] = "/" + db_name + "_test"
    return urlunparse(components)
|
2025-09-09 14:37:32 +05:30
|
|
|
|
|
|
|
|
|
2026-01-02 13:11:02 +05:30
|
|
|
def get_base_database_url() -> str:
    """
    Get base database URL (postgres) for creating/dropping test database.

    Takes the DATABASE_URL environment variable and swaps its database name
    for ``postgres`` so that admin statements can be issued there. Scheme,
    credentials, host, port and query string are kept as-is.

    Returns:
        The URL pointing at the ``postgres`` database.

    Raises:
        ValueError: If DATABASE_URL is unset or empty.
    """
    original_url = os.environ.get("DATABASE_URL")
    # Fail with the same clear error as get_test_database_url() instead of
    # letting urlparse(None) surface later as an obscure TypeError.
    if not original_url:
        raise ValueError("DATABASE_URL environment variable is not set")

    parsed = urlparse(original_url)
    # Connect to 'postgres' database for admin operations
    return urlunparse(
        (
            parsed.scheme,
            parsed.netloc,
            "/postgres",
            parsed.params,
            parsed.query,
            parsed.fragment,
        )
    )
|
2025-09-09 14:37:32 +05:30
|
|
|
|
|
|
|
|
|
2026-01-02 13:11:02 +05:30
|
|
|
def get_test_db_name() -> str:
    """Extract the test database name.

    The name is the path of the DATABASE_URL environment variable without
    its leading slash, suffixed with ``_test``.

    Raises:
        ValueError: If DATABASE_URL is unset or empty.
    """
    original_url = os.environ.get("DATABASE_URL")
    # Fail with the same clear error as get_test_database_url() instead of
    # letting urlparse(None) surface later as an obscure TypeError.
    if not original_url:
        raise ValueError("DATABASE_URL environment variable is not set")

    original_db_name = urlparse(original_url).path.lstrip("/")
    return f"{original_db_name}_test"
|
2025-09-09 14:37:32 +05:30
|
|
|
|
2026-01-02 13:11:02 +05:30
|
|
|
|
|
|
|
|
@pytest.fixture(scope="session")
async def setup_test_database():
    """
    Session-scoped fixture that creates the test database and runs migrations.

    This runs once at the start of the test session:

    1. Connects to the ``postgres`` admin database in AUTOCOMMIT mode.
    2. Creates the test database when ``pg_database`` has no row for it.
    3. Runs the Alembic migrations against the test database.

    Yields:
        The test database URL.
    """
    test_db_name = get_test_db_name()
    base_url = get_base_database_url()
    test_url = get_test_database_url()

    # Create engine to connect to postgres database (for admin operations)
    admin_engine = create_async_engine(
        base_url,
        poolclass=NullPool,
        isolation_level="AUTOCOMMIT",  # Required for CREATE DATABASE
    )

    # Dispose of the admin engine even when connecting or creating the
    # database fails, so no engine is left behind on the error path.
    try:
        # Create test database if it doesn't exist
        async with admin_engine.connect() as conn:
            # Check if database exists
            result = await conn.execute(
                text("SELECT 1 FROM pg_database WHERE datname = :dbname"),
                {"dbname": test_db_name},
            )
            exists = result.scalar() is not None

            if not exists:
                print(f"\n Creating test database: {test_db_name}")
                # Use template0 to avoid collation version mismatch issues
                await conn.execute(
                    text(f'CREATE DATABASE "{test_db_name}" TEMPLATE template0')
                )
            else:
                print(f"\n Using existing test database: {test_db_name}")
    finally:
        await admin_engine.dispose()

    # Run alembic migrations on the test database
    print(f" Running migrations on {test_db_name}...")
    await run_migrations(test_url)
    print(" Migrations complete!")

    yield test_url

    # Cleanup: Optionally drop the test database after tests
    # Commented out to allow inspection of test data after failures
    # async with admin_engine.connect() as conn:
    #     await conn.execute(text(f'DROP DATABASE IF EXISTS "{test_db_name}"'))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def run_migrations(database_url: str):
    """
    Run alembic migrations programmatically on the given database.

    While the upgrade runs, ``database_url`` is exposed in three places: the
    Alembic ``sqlalchemy.url`` option, ``os.environ["DATABASE_URL"]`` and
    ``api.constants.DATABASE_URL`` (the latter is cached at import time, so
    changing the environment alone is not enough). The environment variable
    and the constant are restored afterwards, whether or not the upgrade
    succeeds.

    Args:
        database_url: URL of the database to upgrade to ``head``.
    """
    import asyncio

    from alembic import command
    from alembic.config import Config

    import api.constants

    # Create alembic config from the project's alembic.ini
    alembic_ini_path = APP_ROOT_DIR / "alembic.ini"
    alembic_cfg = Config(str(alembic_ini_path))

    # set_main_option goes through ConfigParser interpolation, so a literal
    # '%' (e.g. in a URL-encoded password) must be escaped as '%%'.
    alembic_cfg.set_main_option("sqlalchemy.url", database_url.replace("%", "%%"))

    # Capture the current values before patching anything, so the finally
    # block below can always restore them.
    original_env_url = os.environ.get("DATABASE_URL")
    original_constants_url = api.constants.DATABASE_URL

    def _run_upgrade():
        command.upgrade(alembic_cfg, "head")

    os.environ["DATABASE_URL"] = database_url
    api.constants.DATABASE_URL = database_url
    try:
        # Run migrations in a thread to avoid blocking the event loop
        await asyncio.get_running_loop().run_in_executor(None, _run_upgrade)
    finally:
        # Restore original DATABASE_URL; remove it again if it was not set
        # before, rather than leaving the test URL behind.
        if original_env_url is None:
            os.environ.pop("DATABASE_URL", None)
        else:
            os.environ["DATABASE_URL"] = original_env_url
        api.constants.DATABASE_URL = original_constants_url
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(scope="session")
async def test_engine(setup_test_database):
    """
    Provide one async engine for the whole test session.

    The engine targets the URL yielded by ``setup_test_database`` and uses
    NullPool to avoid connection issues with async tests. It is disposed
    during fixture teardown.
    """
    engine_options = {
        "poolclass": NullPool,
        "echo": False,  # Set to True for SQL debugging
    }
    db_engine = create_async_engine(setup_test_database, **engine_options)

    yield db_engine

    await db_engine.dispose()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(scope="function")
async def db_connection(test_engine) -> AsyncGenerator[AsyncConnection, None]:
    """
    Open a fresh connection on the test engine for each test.

    The connection is yielded to the test (and to fixtures such as
    ``async_session``, which starts the transaction that is later rolled
    back) and is closed when the ``async with`` block exits at teardown.
    """
    async with test_engine.connect() as conn:
        yield conn
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(scope="function")
async def async_session(
    db_connection: AsyncConnection,
) -> AsyncGenerator[AsyncSession, None]:
    """
    Provide a per-test session whose changes are discarded afterwards.

    Sequence:
        1. Start an outer transaction on the test's connection.
        2. Bind a session to that connection and open a savepoint
           (nested transaction) on it.
        3. Yield the session to the test.
        4. After the test, roll the outer transaction back.

    A listener re-opens the savepoint whenever it is no longer active, so a
    ``session.commit()`` inside a test is meant to end only the savepoint
    while the outer transaction stays open until the final rollback.
    """
    # Outer transaction: everything the test does lives inside it.
    outer_transaction = await db_connection.begin()

    # Session factory bound to this very connection.
    session_factory = async_sessionmaker(
        bind=db_connection,
        expire_on_commit=False,
        autoflush=False,
    )

    async with session_factory() as session:
        # First savepoint for the test to work in.
        savepoint = await session.begin_nested()

        # Re-establish the savepoint once a transaction end has left it inactive.
        @event.listens_for(session.sync_session, "after_transaction_end")
        def _restart_savepoint(session_sync, transaction: SessionTransaction):
            nonlocal savepoint
            if savepoint.is_active:
                return
            savepoint = session.sync_session.begin_nested()

        yield session

    # Discard everything done during the test.
    await outer_transaction.rollback()
|
2025-09-09 14:37:32 +05:30
|
|
|
|
2026-01-02 13:11:02 +05:30
|
|
|
|
|
|
|
|
class _TestSessionContext:
    """
    Async context manager that hands out an already-open test session.

    It stands in for the context manager returned by async_sessionmaker():
    entering returns the wrapped session, and leaving never closes it, so
    the same session can be reused for the rest of the test.
    """

    def __init__(self, session: AsyncSession):
        # The shared test session; owned by the fixture, not by this wrapper.
        self._session = session

    async def __aenter__(self) -> AsyncSession:
        return self._session

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Flush pending changes only when the block finished without error.
        if exc_type is not None:
            return False
        await self._session.flush()
        # Returning False means exceptions are never suppressed.
        return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(scope="function")
async def db_session(async_session: AsyncSession):
    """
    Yield the application's DBClient wired to the transactional test session.

    ``db_client.async_session`` is replaced by a callable returning a
    ``_TestSessionContext`` around the test session, so database operations
    issued through the client go through that session. The client's
    original ``engine`` and ``async_session`` attributes are put back at
    teardown.

    Note: the fixture yields the DBClient (not a raw session) for backward
    compatibility with existing tests that call methods such as
    db_session.get_or_create_user_by_provider_id().
    """
    from api.db import db_client

    # Remember the client's real attributes for teardown.
    saved_engine, saved_session_maker = db_client.engine, db_client.async_session

    def _session_maker_stub():
        return _TestSessionContext(async_session)

    # Route the client's session creation to the test session.
    db_client.async_session = _session_maker_stub

    yield db_client

    # Put the real attributes back.
    db_client.engine = saved_engine
    db_client.async_session = saved_session_maker
|
2025-09-09 14:37:32 +05:30
|
|
|
|
|
|
|
|
|
2026-01-02 13:11:02 +05:30
|
|
|
@pytest.fixture
async def test_client_factory(db_session):
    """
    Factory fixture that builds HTTP test clients acting as a given user.

    The returned callable is an async context manager: while it is open,
    the app's ``get_user`` dependency is overridden to return the supplied
    user; on exit the previous override (if any) is reinstated.

    Usage:
        async def test_something(test_client_factory, db_session):
            # Create a custom user
            user, _ = await db_session.get_or_create_user_by_provider_id("custom_user_123")

            # Create a test client for this user
            async with test_client_factory(user) as client:
                # Use the client in your test
                response = await client.get("/some/endpoint")
    """
    from contextlib import asynccontextmanager

    from httpx import ASGITransport, AsyncClient

    from api.app import app
    from api.services.auth.depends import get_user

    @asynccontextmanager
    async def _create_client_for_user(user):
        # Stand-in auth dependency that always resolves to this user.
        async def mock_get_user():
            return user

        # Remember whatever override was active before installing ours.
        previous_override = app.dependency_overrides.get(get_user)
        app.dependency_overrides[get_user] = mock_get_user

        try:
            transport = ASGITransport(app=app)
            async with AsyncClient(transport=transport, base_url="http://test") as client:
                yield client
        finally:
            # Reinstate the earlier override, or drop ours if there was none.
            if not previous_override:
                app.dependency_overrides.pop(get_user, None)
            else:
                app.dependency_overrides[get_user] = previous_override

    return _create_client_for_user
|