feat: user defined custom tools as part of workflow execution (#94)

* feat: add custom tools functionality

* Show tools in nodes

* integrate tool calling with pipeline engine
This commit is contained in:
Abhishek 2026-01-02 13:11:02 +05:30 committed by GitHub
parent cc2d3e70d2
commit 3e55af9256
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
65 changed files with 5483 additions and 6673 deletions

View file

@ -1,143 +1,315 @@
"""
Shared pytest fixtures for the API tests.
This file contains database setup, test client configuration, and utility fixtures
that can be reused across all test files.
Pytest configuration and fixtures for async database testing.
This module sets up the test infrastructure using:
- A separate test database (appends _test to the database name)
- Alembic migrations run once per test session
- Transaction-based isolation for each test (savepoint pattern)
References:
- https://www.core27.co/post/transactional-unit-tests-with-pytest-and-async-sqlalchemy
- https://docs.sqlalchemy.org/en/20/orm/session_transaction.html
"""
import os
import subprocess
import uuid
from pathlib import Path
from typing import AsyncGenerator
from urllib.parse import urlparse, urlunparse
import pytest_asyncio
from httpx import ASGITransport, AsyncClient
from loguru import logger
from sqlalchemy import text
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
# Load environment variables before importing anything else
from dotenv import load_dotenv
from api.app import app
from api.db import db_client
# Load .env.test from api directory for test configuration
env_path = Path(__file__).parent / ".env.test"
load_dotenv(env_path)
# Test database setup globals
TEST_DATABASE_NAME = None
TEST_DATABASE_URL = None
import pytest
from sqlalchemy import event, text
from sqlalchemy.ext.asyncio import (
AsyncConnection,
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from sqlalchemy.orm import SessionTransaction
from sqlalchemy.pool import NullPool
@pytest_asyncio.fixture
async def test_database():
def get_test_database_url() -> str:
    """
    Build the test database URL by appending ``_test`` to the database name.

    Reads ``DATABASE_URL`` from the environment and rewrites only the path
    component; scheme, credentials, host, port, query string and fragment
    are carried over unchanged. The function has no side effects.

    Example:
        postgresql+asyncpg://user:pass@host/mydb
        -> postgresql+asyncpg://user:pass@host/mydb_test

    Returns:
        The URL pointing at the ``<dbname>_test`` database.

    Raises:
        ValueError: If ``DATABASE_URL`` is unset or empty.
    """
    original_url = os.environ.get("DATABASE_URL")
    if not original_url:
        raise ValueError("DATABASE_URL environment variable is not set")

    parsed = urlparse(original_url)

    # The URL path is "/<dbname>"; drop the leading slash to get the bare name.
    test_db_name = f"{parsed.path.lstrip('/')}_test"

    # Swap only the path so every other URL component survives verbatim.
    return urlunparse(parsed._replace(path=f"/{test_db_name}"))
# Create a connection to the default postgres database to create our test database
default_engine = create_async_engine(base_url)
def get_base_database_url() -> str:
    """
    Build the admin URL used for creating/dropping the test database.

    Takes ``DATABASE_URL`` from the environment and points it at the
    ``postgres`` maintenance database, keeping scheme, credentials, host,
    port, query string and fragment unchanged.

    Returns:
        The URL of the ``postgres`` database on the configured server.

    Raises:
        ValueError: If ``DATABASE_URL`` is unset or empty.
    """
    original_url = os.environ.get("DATABASE_URL")
    if not original_url:
        # Fail with the same explicit error as get_test_database_url rather
        # than an opaque TypeError from urlparse(None) further down.
        raise ValueError("DATABASE_URL environment variable is not set")

    parsed = urlparse(original_url)

    # Connect to the 'postgres' database for admin operations.
    return urlunparse(parsed._replace(path="/postgres"))
def get_test_db_name() -> str:
    """
    Return the name of the test database (``<dbname>_test``).

    The base name is the path component of ``DATABASE_URL`` with its
    leading slash removed.

    Raises:
        ValueError: If ``DATABASE_URL`` is unset or empty.
    """
    original_url = os.environ.get("DATABASE_URL")
    if not original_url:
        # Without this guard urlparse(None) yields bytes fields and the
        # lstrip("/") below fails with a confusing TypeError.
        raise ValueError("DATABASE_URL environment variable is not set")

    original_db_name = urlparse(original_url).path.lstrip("/")
    return f"{original_db_name}_test"
@pytest.fixture(scope="session")
async def setup_test_database():
    """
    Session-scoped fixture that creates the test database and runs migrations.

    Steps, executed once at the start of the test session:
      1. Connect to the admin (``postgres``) database in AUTOCOMMIT mode.
      2. Create the ``<dbname>_test`` database if ``pg_database`` has no
         entry for it.
      3. Run the alembic migrations against the test database.

    Yields:
        The test database URL.

    The test database is intentionally NOT dropped afterwards so its
    contents can be inspected after failures.
    """
    test_db_name = get_test_db_name()
    base_url = get_base_database_url()
    test_url = get_test_database_url()

    # Engine bound to the postgres database (for admin operations).
    admin_engine = create_async_engine(
        base_url,
        poolclass=NullPool,
        isolation_level="AUTOCOMMIT",  # Required for CREATE DATABASE
    )

    try:
        async with admin_engine.connect() as conn:
            # Check whether the test database already exists.
            result = await conn.execute(
                text("SELECT 1 FROM pg_database WHERE datname = :dbname"),
                {"dbname": test_db_name},
            )
            exists = result.scalar() is not None

            if not exists:
                print(f"\n Creating test database: {test_db_name}")
                # Identifiers cannot be bound as parameters; double any
                # embedded quote so the quoted identifier stays well-formed.
                quoted_name = test_db_name.replace('"', '""')
                # Use template0 to avoid collation version mismatch issues.
                await conn.execute(
                    text(f'CREATE DATABASE "{quoted_name}" TEMPLATE template0')
                )
            else:
                print(f"\n Using existing test database: {test_db_name}")
    finally:
        # Release the admin engine even when the check or CREATE fails.
        await admin_engine.dispose()

    # Run alembic migrations on the test database.
    print(f" Running migrations on {test_db_name}...")
    await run_migrations(test_url)
    print(" Migrations complete!")

    yield test_url

    # No teardown: the database is kept to allow inspection of test data
    # after failures. Dropping it would require a fresh admin connection
    # issuing DROP DATABASE IF EXISTS for the test database.
async def run_migrations(database_url: str):
    """
    Run alembic migrations programmatically on the given database.

    While the upgrade runs, ``DATABASE_URL`` is overridden in three places:
    ``os.environ``, the alembic ``sqlalchemy.url`` main option, and
    ``api.constants.DATABASE_URL`` (which is cached at import time). The
    environment variable and the cached constant are put back exactly as
    they were, including the "was not set" case, even if the upgrade fails.

    Args:
        database_url: URL of the database to upgrade to ``head``.

    Raises:
        Whatever ``alembic.command.upgrade`` raises; nothing is swallowed.
    """
    import asyncio

    from alembic import command
    from alembic.config import Config

    import api.constants

    # alembic.ini is expected next to this file.
    alembic_ini_path = Path(__file__).parent / "alembic.ini"
    alembic_cfg = Config(str(alembic_ini_path))

    # set_main_option goes through ConfigParser interpolation, so a raw "%"
    # (e.g. from a percent-encoded password) must be escaped as "%%".
    alembic_cfg.set_main_option("sqlalchemy.url", database_url.replace("%", "%%"))

    # Remember the current values before patching so they can be restored.
    original_env_url = os.environ.get("DATABASE_URL")
    original_constants_url = api.constants.DATABASE_URL

    os.environ["DATABASE_URL"] = database_url
    api.constants.DATABASE_URL = database_url

    def _run_upgrade():
        command.upgrade(alembic_cfg, "head")

    try:
        # Run migrations in a worker thread to avoid blocking the event loop.
        await asyncio.get_running_loop().run_in_executor(None, _run_upgrade)
    finally:
        # Restore the original DATABASE_URL; remove it again if it was unset
        # so the test URL never leaks into the rest of the process.
        if original_env_url is None:
            os.environ.pop("DATABASE_URL", None)
        else:
            os.environ["DATABASE_URL"] = original_env_url
        api.constants.DATABASE_URL = original_constants_url
@pytest_asyncio.fixture
async def db_session(test_database):
@pytest.fixture(scope="session")
async def test_engine(setup_test_database):
    """
    Provide one async engine for the whole test session.

    The engine targets the URL yielded by ``setup_test_database`` and is
    built with ``NullPool`` (no connection pooling), which the suite relies
    on to avoid connection issues with async tests. The engine is disposed
    once the session's tests have finished.
    """
    session_engine = create_async_engine(
        setup_test_database,
        echo=False,  # Set to True for SQL debugging
        poolclass=NullPool,
    )

    yield session_engine

    await session_engine.dispose()
@pytest.fixture(scope="function")
async def db_connection(test_engine) -> AsyncGenerator[AsyncConnection, None]:
    """
    Open a fresh database connection for each test.

    The connection is checked out from ``test_engine`` and closed when the
    test finishes. This fixture does not begin a transaction itself; it
    only hands out the connection for dependent fixtures to build on.
    """
    async with test_engine.connect() as per_test_connection:
        yield per_test_connection
@pytest.fixture(scope="function")
async def async_session(
    db_connection: AsyncConnection,
) -> AsyncGenerator[AsyncSession, None]:
    """
    Create a database session with transaction isolation for each test.

    This fixture:
    1. Begins an outer transaction on the connection
    2. Creates a savepoint (nested transaction) on a session bound to it
    3. Yields the session for test use
    4. Rolls back the outer transaction after the test

    A listener re-opens the savepoint whenever it has ended, so tests can
    call ``session.commit()`` repeatedly. The final rollback of the outer
    transaction is what isolates tests from each other.
    """
    # Begin outer transaction
    trans = await db_connection.begin()

    # Session factory bound to this single connection
    session_factory = async_sessionmaker(
        bind=db_connection,
        expire_on_commit=False,
        autoflush=False,
    )

    try:
        async with session_factory() as session:
            # Begin a nested transaction (savepoint)
            nested = await session.begin_nested()

            # Restart the savepoint after it has been committed/rolled back
            @event.listens_for(session.sync_session, "after_transaction_end")
            def reopen_nested_transaction(
                session_sync, transaction: SessionTransaction
            ):
                nonlocal nested
                if not nested.is_active:
                    nested = session.sync_session.begin_nested()

            try:
                yield session
            finally:
                # Detach the listener before the session closes; otherwise
                # the transactions ended by close() would trigger it and
                # open a new savepoint on a session that is shutting down.
                event.remove(
                    session.sync_session,
                    "after_transaction_end",
                    reopen_nested_transaction,
                )
    finally:
        # Rollback everything. Closing the session may already have ended
        # the outer transaction, so only roll back if it is still active.
        if trans.is_active:
            await trans.rollback()
class _TestSessionContext:
    """
    Async context manager that hands out an already-open test session.

    It stands in for the context manager returned by ``async_sessionmaker()``
    but never closes the wrapped session: entering returns the session
    as-is, and a clean exit only flushes it.
    """

    def __init__(self, session: AsyncSession):
        # The shared per-test session; its lifetime is owned elsewhere.
        self._session = session

    async def __aenter__(self) -> AsyncSession:
        return self._session

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        clean_exit = exc_type is None
        if clean_exit:
            # Flush pending changes on a normal exit.
            await self._session.flush()
        # Returning False lets any in-flight exception propagate.
        return False
@pytest.fixture(scope="function")
async def db_session(async_session: AsyncSession):
    """
    Yield the global DBClient, patched to use the per-test session.

    ``db_client.async_session`` is replaced for the duration of the test by
    a factory that returns a ``_TestSessionContext`` around ``async_session``,
    so every ``async with db_client.async_session()`` in application code
    operates on the transactional test session. The original factory is
    restored afterwards.

    Note: This fixture yields a DBClient (not a raw session) for backward
    compatibility with existing tests that call
    db_session.get_or_create_user_by_provider_id(), etc.
    """
    from api.db import db_client

    def test_session_maker():
        return _TestSessionContext(async_session)

    # Store the original so it can be restored after the test
    original_async_session = db_client.async_session

    # Patch the db_client to use our test session
    db_client.async_session = test_session_maker
    try:
        yield db_client
    finally:
        # Restore the original even if the generator is closed abnormally,
        # so the patched factory never leaks into later tests.
        db_client.async_session = original_async_session
@pytest_asyncio.fixture
@pytest.fixture
async def test_client_factory(db_session):
"""
Factory fixture that creates test clients for specific users.
@ -155,6 +327,9 @@ async def test_client_factory(db_session):
"""
from contextlib import asynccontextmanager
from httpx import ASGITransport, AsyncClient
from api.app import app
from api.services.auth.depends import get_user
@asynccontextmanager