mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
feat: user defined custom tools as part of workflow execution (#94)
* feat: add custom tools functionality * Show tools in nodes * integrate tool calling with pipeline engine
This commit is contained in:
parent
cc2d3e70d2
commit
3e55af9256
65 changed files with 5483 additions and 6673 deletions
395
api/conftest.py
395
api/conftest.py
|
|
@ -1,143 +1,315 @@
|
|||
"""
|
||||
Shared pytest fixtures for the API tests.
|
||||
This file contains database setup, test client configuration, and utility fixtures
|
||||
that can be reused across all test files.
|
||||
Pytest configuration and fixtures for async database testing.
|
||||
|
||||
This module sets up the test infrastructure using:
|
||||
- A separate test database (appends _test to the database name)
|
||||
- Alembic migrations run once per test session
|
||||
- Transaction-based isolation for each test (savepoint pattern)
|
||||
|
||||
References:
|
||||
- https://www.core27.co/post/transactional-unit-tests-with-pytest-and-async-sqlalchemy
|
||||
- https://docs.sqlalchemy.org/en/20/orm/session_transaction.html
|
||||
"""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import AsyncGenerator
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
|
||||
import pytest_asyncio
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from loguru import logger
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
# Load environment variables before importing anything else
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from api.app import app
|
||||
from api.db import db_client
|
||||
# Load .env.test from api directory for test configuration
|
||||
env_path = Path(__file__).parent / ".env.test"
|
||||
load_dotenv(env_path)
|
||||
|
||||
# Test database setup globals
|
||||
TEST_DATABASE_NAME = None
|
||||
TEST_DATABASE_URL = None
|
||||
import pytest
|
||||
from sqlalchemy import event, text
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncConnection,
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
from sqlalchemy.orm import SessionTransaction
|
||||
from sqlalchemy.pool import NullPool
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_database():
|
||||
def get_test_database_url() -> str:
|
||||
"""
|
||||
Set up a temporary PostgreSQL database for testing.
|
||||
This fixture creates a unique test database, runs migrations, and cleans up afterward.
|
||||
Get the test database URL by appending _test to the database name.
|
||||
|
||||
Example:
|
||||
postgresql+asyncpg://user:pass@host/mydb
|
||||
-> postgresql+asyncpg://user:pass@host/mydb_test
|
||||
"""
|
||||
global TEST_DATABASE_NAME, TEST_DATABASE_URL
|
||||
original_url = os.environ.get("DATABASE_URL")
|
||||
if not original_url:
|
||||
raise ValueError("DATABASE_URL environment variable is not set")
|
||||
|
||||
# Generate a unique test database name
|
||||
TEST_DATABASE_NAME = f"test_dograh_{uuid.uuid4().hex[:8]}"
|
||||
parsed = urlparse(original_url)
|
||||
# Append _test to the database name (path without leading slash)
|
||||
original_db_name = parsed.path.lstrip("/")
|
||||
test_db_name = f"{original_db_name}_test"
|
||||
|
||||
# Get the base DATABASE_URL and parse it
|
||||
base_url = os.environ.get("DATABASE_URL")
|
||||
# Extract connection parts and replace database name
|
||||
url_parts = base_url.split("/")
|
||||
base_connection = "/".join(url_parts[:-1])
|
||||
TEST_DATABASE_URL = f"{base_connection}/{TEST_DATABASE_NAME}"
|
||||
# Reconstruct the URL with the new database name
|
||||
test_url = urlunparse(
|
||||
(
|
||||
parsed.scheme,
|
||||
parsed.netloc,
|
||||
f"/{test_db_name}",
|
||||
parsed.params,
|
||||
parsed.query,
|
||||
parsed.fragment,
|
||||
)
|
||||
)
|
||||
return test_url
|
||||
|
||||
# Create a connection to the default postgres database to create our test database
|
||||
default_engine = create_async_engine(base_url)
|
||||
|
||||
def get_base_database_url() -> str:
|
||||
"""
|
||||
Get base database URL (postgres) for creating/dropping test database.
|
||||
"""
|
||||
original_url = os.environ.get("DATABASE_URL")
|
||||
parsed = urlparse(original_url)
|
||||
# Connect to 'postgres' database for admin operations
|
||||
base_url = urlunparse(
|
||||
(
|
||||
parsed.scheme,
|
||||
parsed.netloc,
|
||||
"/postgres",
|
||||
parsed.params,
|
||||
parsed.query,
|
||||
parsed.fragment,
|
||||
)
|
||||
)
|
||||
return base_url
|
||||
|
||||
|
||||
def get_test_db_name() -> str:
|
||||
"""Extract the test database name."""
|
||||
original_url = os.environ.get("DATABASE_URL")
|
||||
parsed = urlparse(original_url)
|
||||
original_db_name = parsed.path.lstrip("/")
|
||||
return f"{original_db_name}_test"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
async def setup_test_database():
|
||||
"""
|
||||
Session-scoped fixture that creates the test database and runs migrations.
|
||||
|
||||
This runs once at the start of the test session.
|
||||
"""
|
||||
test_db_name = get_test_db_name()
|
||||
base_url = get_base_database_url()
|
||||
test_url = get_test_database_url()
|
||||
|
||||
# Create engine to connect to postgres database (for admin operations)
|
||||
admin_engine = create_async_engine(
|
||||
base_url,
|
||||
poolclass=NullPool,
|
||||
isolation_level="AUTOCOMMIT", # Required for CREATE DATABASE
|
||||
)
|
||||
|
||||
# Create test database if it doesn't exist
|
||||
async with admin_engine.connect() as conn:
|
||||
# Check if database exists
|
||||
result = await conn.execute(
|
||||
text("SELECT 1 FROM pg_database WHERE datname = :dbname"),
|
||||
{"dbname": test_db_name},
|
||||
)
|
||||
exists = result.scalar() is not None
|
||||
|
||||
if not exists:
|
||||
print(f"\n Creating test database: {test_db_name}")
|
||||
# Use template0 to avoid collation version mismatch issues
|
||||
await conn.execute(
|
||||
text(f'CREATE DATABASE "{test_db_name}" TEMPLATE template0')
|
||||
)
|
||||
else:
|
||||
print(f"\n Using existing test database: {test_db_name}")
|
||||
|
||||
await admin_engine.dispose()
|
||||
|
||||
# Run alembic migrations on the test database
|
||||
print(f" Running migrations on {test_db_name}...")
|
||||
await run_migrations(test_url)
|
||||
print(" Migrations complete!")
|
||||
|
||||
yield test_url
|
||||
|
||||
# Cleanup: Optionally drop the test database after tests
|
||||
# Commented out to allow inspection of test data after failures
|
||||
# async with admin_engine.connect() as conn:
|
||||
# await conn.execute(text(f'DROP DATABASE IF EXISTS "{test_db_name}"'))
|
||||
|
||||
|
||||
async def run_migrations(database_url: str):
|
||||
"""
|
||||
Run alembic migrations programmatically on the given database.
|
||||
"""
|
||||
from alembic import command
|
||||
from alembic.config import Config
|
||||
|
||||
# Get alembic.ini path
|
||||
alembic_ini_path = Path(__file__).parent / "alembic.ini"
|
||||
|
||||
# Create alembic config
|
||||
alembic_cfg = Config(str(alembic_ini_path))
|
||||
|
||||
# Override the database URL - need to patch both os.environ AND api.constants
|
||||
# because api.constants.DATABASE_URL is cached at import time
|
||||
original_env_url = os.environ.get("DATABASE_URL")
|
||||
os.environ["DATABASE_URL"] = database_url
|
||||
alembic_cfg.set_main_option("sqlalchemy.url", database_url)
|
||||
|
||||
# Also patch the cached value in api.constants
|
||||
import api.constants
|
||||
|
||||
original_constants_url = api.constants.DATABASE_URL
|
||||
|
||||
api.constants.DATABASE_URL = database_url
|
||||
|
||||
# Run migrations in a thread to avoid blocking the event loop
|
||||
import asyncio
|
||||
|
||||
def _run_upgrade():
|
||||
command.upgrade(alembic_cfg, "head")
|
||||
|
||||
try:
|
||||
# Create the test database
|
||||
async with default_engine.connect() as conn:
|
||||
# Use autocommit mode to create database
|
||||
await conn.execute(text("COMMIT"))
|
||||
await conn.execute(text(f"CREATE DATABASE {TEST_DATABASE_NAME}"))
|
||||
|
||||
await default_engine.dispose()
|
||||
|
||||
# Run migrations on the test database
|
||||
env = os.environ.copy()
|
||||
env["DATABASE_URL"] = TEST_DATABASE_URL
|
||||
# Add the parent directory to PYTHONPATH so alembic can find the api module
|
||||
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
env["PYTHONPATH"] = parent_dir + ":" + env.get("PYTHONPATH", "")
|
||||
|
||||
# Run alembic upgrade to create all tables
|
||||
result = subprocess.run(
|
||||
[
|
||||
"conda",
|
||||
"run",
|
||||
"-n",
|
||||
"dograh",
|
||||
"python",
|
||||
"-m",
|
||||
"alembic",
|
||||
"-c",
|
||||
"alembic.ini",
|
||||
"upgrade",
|
||||
"head",
|
||||
],
|
||||
env=env,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
logger.error(f"Alembic stderr: {result.stderr}")
|
||||
logger.error(f"Alembic stdout: {result.stdout}")
|
||||
raise RuntimeError(f"Alembic migration failed: {result.stderr}")
|
||||
|
||||
logger.info(f"Created test database: {TEST_DATABASE_NAME}")
|
||||
yield TEST_DATABASE_URL
|
||||
|
||||
await asyncio.get_event_loop().run_in_executor(None, _run_upgrade)
|
||||
finally:
|
||||
# Cleanup: Drop the test database
|
||||
cleanup_engine = create_async_engine(base_url)
|
||||
try:
|
||||
async with cleanup_engine.connect() as conn:
|
||||
# Terminate any connections to the test database
|
||||
await conn.execute(text("COMMIT"))
|
||||
await conn.execute(
|
||||
text(f"""
|
||||
SELECT pg_terminate_backend(pid)
|
||||
FROM pg_stat_activity
|
||||
WHERE datname = '{TEST_DATABASE_NAME}' AND pid <> pg_backend_pid()
|
||||
""")
|
||||
)
|
||||
await conn.execute(
|
||||
text(f"DROP DATABASE IF EXISTS {TEST_DATABASE_NAME}")
|
||||
)
|
||||
logger.info(f"Cleaned up test database: {TEST_DATABASE_NAME}")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Warning: Could not clean up test database {TEST_DATABASE_NAME}: {e}"
|
||||
)
|
||||
finally:
|
||||
await cleanup_engine.dispose()
|
||||
# Restore original DATABASE_URL
|
||||
if original_env_url:
|
||||
os.environ["DATABASE_URL"] = original_env_url
|
||||
api.constants.DATABASE_URL = original_constants_url
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db_session(test_database):
|
||||
@pytest.fixture(scope="session")
|
||||
async def test_engine(setup_test_database):
|
||||
"""
|
||||
Create a test database client that uses the temporary database.
|
||||
This fixture replaces the global db_client with a test version.
|
||||
Create a test database engine (session-scoped).
|
||||
|
||||
Uses NullPool to avoid connection issues with async tests.
|
||||
"""
|
||||
test_url = setup_test_database
|
||||
engine = create_async_engine(
|
||||
test_url,
|
||||
poolclass=NullPool,
|
||||
echo=False, # Set to True for SQL debugging
|
||||
)
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
async def db_connection(test_engine) -> AsyncGenerator[AsyncConnection, None]:
|
||||
"""
|
||||
Create a database connection for each test.
|
||||
|
||||
This connection wraps all operations in a transaction that
|
||||
will be rolled back at the end of the test.
|
||||
"""
|
||||
async with test_engine.connect() as connection:
|
||||
yield connection
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
async def async_session(
|
||||
db_connection: AsyncConnection,
|
||||
) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""
|
||||
Create a database session with transaction isolation for each test.
|
||||
|
||||
This fixture:
|
||||
1. Begins a transaction on the connection
|
||||
2. Creates a savepoint (nested transaction)
|
||||
3. Yields the session for test use
|
||||
4. Rolls back all changes after the test
|
||||
|
||||
Tests can call session.commit() and it will only commit to the savepoint,
|
||||
not to the actual database. The outer transaction rollback ensures
|
||||
complete isolation between tests.
|
||||
"""
|
||||
# Begin outer transaction
|
||||
trans = await db_connection.begin()
|
||||
|
||||
# Create session bound to this connection
|
||||
async_session_maker = async_sessionmaker(
|
||||
bind=db_connection,
|
||||
expire_on_commit=False,
|
||||
autoflush=False,
|
||||
)
|
||||
|
||||
async with async_session_maker() as session:
|
||||
# Begin a nested transaction (savepoint)
|
||||
nested = await session.begin_nested()
|
||||
|
||||
# Set up event listener to restart savepoint after commits
|
||||
@event.listens_for(session.sync_session, "after_transaction_end")
|
||||
def reopen_nested_transaction(session_sync, transaction: SessionTransaction):
|
||||
nonlocal nested
|
||||
if not nested.is_active:
|
||||
nested = session.sync_session.begin_nested()
|
||||
|
||||
yield session
|
||||
|
||||
# Rollback everything
|
||||
await trans.rollback()
|
||||
|
||||
|
||||
class _TestSessionContext:
|
||||
"""
|
||||
Context manager wrapper for test session.
|
||||
|
||||
Mimics the behavior of async_sessionmaker() context manager
|
||||
but uses the existing test session without closing it.
|
||||
"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self._session = session
|
||||
|
||||
async def __aenter__(self) -> AsyncSession:
|
||||
return self._session
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if exc_type is None:
|
||||
await self._session.flush()
|
||||
return False
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
async def db_session(async_session: AsyncSession):
|
||||
"""
|
||||
Create a DBClient instance that uses the test session.
|
||||
|
||||
This patches the DBClient's async_session to use our test session,
|
||||
ensuring all database operations go through the transactional test session.
|
||||
|
||||
Note: This fixture yields a DBClient (not a raw session) for backward
|
||||
compatibility with existing tests that call db_session.get_or_create_user_by_provider_id(), etc.
|
||||
"""
|
||||
from api.db import db_client
|
||||
|
||||
def test_session_maker():
|
||||
return _TestSessionContext(async_session)
|
||||
|
||||
# Store originals
|
||||
original_engine = db_client.engine
|
||||
original_session = db_client.async_session
|
||||
original_async_session = db_client.async_session
|
||||
|
||||
# Replace the database client's engine and session with test ones
|
||||
test_engine = create_async_engine(test_database)
|
||||
test_session_maker = async_sessionmaker(bind=test_engine)
|
||||
|
||||
db_client.engine = test_engine
|
||||
# Patch the db_client to use our test session
|
||||
db_client.async_session = test_session_maker
|
||||
|
||||
yield db_client
|
||||
|
||||
# Restore original database client
|
||||
await test_engine.dispose()
|
||||
# Restore originals
|
||||
db_client.engine = original_engine
|
||||
db_client.async_session = original_session
|
||||
db_client.async_session = original_async_session
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
@pytest.fixture
|
||||
async def test_client_factory(db_session):
|
||||
"""
|
||||
Factory fixture that creates test clients for specific users.
|
||||
|
|
@ -155,6 +327,9 @@ async def test_client_factory(db_session):
|
|||
"""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from api.app import app
|
||||
from api.services.auth.depends import get_user
|
||||
|
||||
@asynccontextmanager
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue