Initial Commit 🚀 🚀

This commit is contained in:
Abhishek Kumar 2025-09-09 14:37:32 +05:30
commit 4f2a629340
444 changed files with 76863 additions and 0 deletions

3
api/db/__init__.py Normal file
View file

@ -0,0 +1,3 @@
from api.db.db_client import DBClient
db_client = DBClient()

108
api/db/api_key_client.py Normal file
View file

@ -0,0 +1,108 @@
from typing import List, Optional
from sqlalchemy import and_
from sqlalchemy.future import select
from api.db.base_client import BaseDBClient
from api.db.models import APIKeyModel
from api.utils.api_key import generate_api_key, hash_api_key
class APIKeyClient(BaseDBClient):
async def create_api_key(
self, organization_id: int, name: str, created_by: Optional[int] = None
) -> tuple[APIKeyModel, str]:
"""Create a new API key for an organization.
Returns:
Tuple of (APIKeyModel, raw_api_key)
"""
# Generate a secure random API key
raw_api_key, key_hash, key_prefix = generate_api_key()
async with self.async_session() as session:
api_key = APIKeyModel(
organization_id=organization_id,
name=name,
key_hash=key_hash,
key_prefix=key_prefix,
created_by=created_by,
is_active=True,
)
session.add(api_key)
await session.commit()
await session.refresh(api_key)
return api_key, raw_api_key
async def get_api_keys_by_organization(
self, organization_id: int, include_archived: bool = False
) -> List[APIKeyModel]:
"""Get all API keys for an organization."""
async with self.async_session() as session:
query = select(APIKeyModel).where(
APIKeyModel.organization_id == organization_id
)
if not include_archived:
query = query.where(APIKeyModel.archived_at.is_(None))
result = await session.execute(query)
return result.scalars().all()
async def get_api_key_by_hash(self, key_hash: str) -> Optional[APIKeyModel]:
"""Get an API key by its hash."""
async with self.async_session() as session:
result = await session.execute(
select(APIKeyModel).where(
and_(
APIKeyModel.key_hash == key_hash,
APIKeyModel.is_active == True,
APIKeyModel.archived_at.is_(None),
)
)
)
return result.scalars().first()
async def validate_api_key(self, raw_api_key: str) -> Optional[APIKeyModel]:
"""Validate an API key and return the associated model if valid."""
key_hash = hash_api_key(raw_api_key)
api_key = await self.get_api_key_by_hash(key_hash)
if api_key:
# Update last_used_at
from datetime import datetime, timezone
async with self.async_session() as session:
await session.execute(
APIKeyModel.__table__.update()
.where(APIKeyModel.id == api_key.id)
.values(last_used_at=datetime.now(timezone.utc))
)
await session.commit()
return api_key
async def archive_api_key(self, api_key_id: int) -> bool:
"""Archive an API key (soft delete)."""
from datetime import datetime, timezone
async with self.async_session() as session:
result = await session.execute(
APIKeyModel.__table__.update()
.where(APIKeyModel.id == api_key_id)
.values(is_active=False, archived_at=datetime.now(timezone.utc))
)
await session.commit()
return result.rowcount > 0
async def reactivate_api_key(self, api_key_id: int) -> bool:
"""Reactivate an archived API key."""
async with self.async_session() as session:
result = await session.execute(
APIKeyModel.__table__.update()
.where(APIKeyModel.id == api_key_id)
.values(is_active=True, archived_at=None)
)
await session.commit()
return result.rowcount > 0

34
api/db/base_client.py Normal file
View file

@ -0,0 +1,34 @@
from typing import Any, Dict, List
from sqlalchemy import text
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from api.constants import DATABASE_URL
class BaseDBClient:
def __init__(self):
self.engine = create_async_engine(DATABASE_URL)
self.async_session = async_sessionmaker(bind=self.engine)
async def execute_raw_query(
self, query: str, params: Dict[str, Any] = None
) -> List[Dict[str, Any]]:
"""
Execute a raw SQL query and return results as a list of dictionaries.
Args:
query: The SQL query to execute
params: Optional dictionary of query parameters
Returns:
List of dictionaries containing the query results
"""
async with self.async_session() as session:
result = await session.execute(text(query), params or {})
rows = result.fetchall()
if rows:
# Convert rows to dictionaries
columns = result.keys()
return [dict(zip(columns, row)) for row in rows]
return []

379
api/db/campaign_client.py Normal file
View file

@ -0,0 +1,379 @@
from datetime import UTC, datetime
from typing import Optional
from sqlalchemy import func
from sqlalchemy.future import select
from api.db.base_client import BaseDBClient
from api.db.models import CampaignModel, QueuedRunModel, WorkflowRunModel
class CampaignClient(BaseDBClient):
async def create_campaign(
self,
name: str,
workflow_id: int,
source_type: str,
source_id: str,
user_id: int,
organization_id: int,
) -> CampaignModel:
"""Create a new campaign"""
async with self.async_session() as session:
campaign = CampaignModel(
name=name,
workflow_id=workflow_id,
source_type=source_type,
source_id=source_id,
created_by=user_id,
organization_id=organization_id,
)
session.add(campaign)
try:
await session.commit()
except Exception as e:
await session.rollback()
raise e
await session.refresh(campaign)
return campaign
async def get_campaigns(
self,
organization_id: int,
) -> list[CampaignModel]:
"""Get all campaigns for organization"""
async with self.async_session() as session:
query = (
select(CampaignModel)
.where(CampaignModel.organization_id == organization_id)
.order_by(CampaignModel.created_at.desc())
)
result = await session.execute(query)
return list(result.scalars().all())
async def get_campaign(
self,
campaign_id: int,
organization_id: int,
) -> Optional[CampaignModel]:
"""Get single campaign by ID, ensuring organization access"""
async with self.async_session() as session:
query = select(CampaignModel).where(
CampaignModel.id == campaign_id,
CampaignModel.organization_id == organization_id,
)
result = await session.execute(query)
return result.scalar_one_or_none()
async def update_campaign_state(
self,
campaign_id: int,
state: str,
organization_id: int,
) -> CampaignModel:
"""Update campaign state (start/pause/resume)"""
async with self.async_session() as session:
query = select(CampaignModel).where(
CampaignModel.id == campaign_id,
CampaignModel.organization_id == organization_id,
)
result = await session.execute(query)
campaign = result.scalar_one_or_none()
if not campaign:
raise ValueError(f"Campaign {campaign_id} not found")
campaign.state = state
if state == "running" and not campaign.started_at:
campaign.started_at = datetime.now(UTC)
elif state in ["completed", "failed"]:
campaign.completed_at = datetime.now(UTC)
try:
await session.commit()
except Exception as e:
await session.rollback()
raise e
await session.refresh(campaign)
return campaign
async def update_campaign_progress(
self,
campaign_id: int,
processed_rows: int,
failed_rows: int,
organization_id: int,
) -> None:
"""Update campaign progress counters"""
async with self.async_session() as session:
query = select(CampaignModel).where(
CampaignModel.id == campaign_id,
CampaignModel.organization_id == organization_id,
)
result = await session.execute(query)
campaign = result.scalar_one_or_none()
if not campaign:
raise ValueError(f"Campaign {campaign_id} not found")
campaign.processed_rows = processed_rows
campaign.failed_rows = failed_rows
campaign.updated_at = datetime.now(UTC)
try:
await session.commit()
except Exception as e:
await session.rollback()
raise e
async def get_campaign_runs(
self,
campaign_id: int,
organization_id: int,
) -> list[WorkflowRunModel]:
"""Get workflow runs for a campaign"""
async with self.async_session() as session:
# First verify campaign belongs to organization
campaign_query = select(CampaignModel).where(
CampaignModel.id == campaign_id,
CampaignModel.organization_id == organization_id,
)
campaign_result = await session.execute(campaign_query)
campaign = campaign_result.scalar_one_or_none()
if not campaign:
raise ValueError(f"Campaign {campaign_id} not found")
query = (
select(WorkflowRunModel)
.where(WorkflowRunModel.campaign_id == campaign_id)
.order_by(WorkflowRunModel.created_at.desc())
)
result = await session.execute(query)
return list(result.scalars().all())
async def get_campaign_by_id(self, campaign_id: int) -> Optional[CampaignModel]:
"""Get campaign by ID without organization check (for internal use)"""
async with self.async_session() as session:
query = select(CampaignModel).where(CampaignModel.id == campaign_id)
result = await session.execute(query)
return result.scalar_one_or_none()
async def update_campaign(self, campaign_id: int, **kwargs) -> CampaignModel:
"""Update campaign with arbitrary fields"""
async with self.async_session() as session:
query = select(CampaignModel).where(CampaignModel.id == campaign_id)
result = await session.execute(query)
campaign = result.scalar_one_or_none()
if not campaign:
raise ValueError(f"Campaign {campaign_id} not found")
# Update fields
for key, value in kwargs.items():
if hasattr(campaign, key):
setattr(campaign, key, value)
campaign.updated_at = datetime.now(UTC)
try:
await session.commit()
except Exception as e:
await session.rollback()
raise e
await session.refresh(campaign)
return campaign
# QueuedRun methods
async def bulk_create_queued_runs(self, queued_runs_data: list[dict]) -> None:
"""Bulk create queued runs"""
async with self.async_session() as session:
queued_runs = [QueuedRunModel(**data) for data in queued_runs_data]
session.add_all(queued_runs)
try:
await session.commit()
except Exception as e:
await session.rollback()
raise e
async def get_queued_runs(
self,
campaign_id: int,
state: str = "queued",
limit: int = 10,
scheduled_for: Optional[bool] = None,
) -> list[QueuedRunModel]:
"""Get queued runs for processing, optionally filtering by scheduled status"""
async with self.async_session() as session:
query = select(QueuedRunModel).where(
QueuedRunModel.campaign_id == campaign_id,
QueuedRunModel.state == state,
)
# Filter by scheduled status if specified
if scheduled_for is True:
query = query.where(QueuedRunModel.scheduled_for.isnot(None))
elif scheduled_for is False:
query = query.where(QueuedRunModel.scheduled_for.is_(None))
query = query.order_by(QueuedRunModel.created_at).limit(limit)
result = await session.execute(query)
return list(result.scalars().all())
async def update_queued_run(self, queued_run_id: int, **kwargs) -> QueuedRunModel:
"""Update queued run"""
async with self.async_session() as session:
query = select(QueuedRunModel).where(QueuedRunModel.id == queued_run_id)
result = await session.execute(query)
queued_run = result.scalar_one_or_none()
if not queued_run:
raise ValueError(f"QueuedRun {queued_run_id} not found")
# Update fields
for key, value in kwargs.items():
if hasattr(queued_run, key):
setattr(queued_run, key, value)
try:
await session.commit()
except Exception as e:
await session.rollback()
raise e
await session.refresh(queued_run)
return queued_run
async def count_queued_runs(
self, campaign_id: int, state: Optional[str] = None
) -> int:
"""Count queued runs, optionally filtered by state"""
async with self.async_session() as session:
query = select(func.count(QueuedRunModel.id)).where(
QueuedRunModel.campaign_id == campaign_id
)
if state:
query = query.where(QueuedRunModel.state == state)
result = await session.execute(query)
return result.scalar() or 0
async def get_workflow_runs_by_campaign(
self, campaign_id: int
) -> list[WorkflowRunModel]:
"""Get all workflow runs for a campaign (internal use)"""
async with self.async_session() as session:
query = (
select(WorkflowRunModel)
.where(WorkflowRunModel.campaign_id == campaign_id)
.order_by(WorkflowRunModel.created_at)
)
result = await session.execute(query)
return list(result.scalars().all())
# New methods for retry support
async def get_scheduled_queued_runs(
self, campaign_id: int, scheduled_before: datetime, limit: int = 10
) -> list[QueuedRunModel]:
"""Get scheduled queued runs that are due for processing"""
async with self.async_session() as session:
query = (
select(QueuedRunModel)
.where(
QueuedRunModel.campaign_id == campaign_id,
QueuedRunModel.state == "queued",
QueuedRunModel.scheduled_for.isnot(None),
QueuedRunModel.scheduled_for <= scheduled_before,
)
.order_by(QueuedRunModel.scheduled_for)
.limit(limit)
)
result = await session.execute(query)
return list(result.scalars().all())
async def create_queued_run(
self,
campaign_id: int,
source_uuid: str,
context_variables: dict,
state: str = "queued",
retry_count: int = 0,
parent_queued_run_id: Optional[int] = None,
scheduled_for: Optional[datetime] = None,
retry_reason: Optional[str] = None,
) -> QueuedRunModel:
"""Create a single queued run with retry support"""
async with self.async_session() as session:
queued_run = QueuedRunModel(
campaign_id=campaign_id,
source_uuid=source_uuid,
context_variables=context_variables,
state=state,
retry_count=retry_count,
parent_queued_run_id=parent_queued_run_id,
scheduled_for=scheduled_for,
retry_reason=retry_reason,
)
session.add(queued_run)
try:
await session.commit()
except Exception as e:
await session.rollback()
raise e
await session.refresh(queued_run)
return queued_run
async def get_queued_run_by_id(
self, queued_run_id: int
) -> Optional[QueuedRunModel]:
"""Get a queued run by ID"""
async with self.async_session() as session:
query = select(QueuedRunModel).where(QueuedRunModel.id == queued_run_id)
result = await session.execute(query)
return result.scalar_one_or_none()
async def get_campaigns_by_status(self, statuses: list[str]) -> list[CampaignModel]:
"""Get campaigns by status"""
async with self.async_session() as session:
query = (
select(CampaignModel)
.where(CampaignModel.state.in_(statuses))
.order_by(CampaignModel.created_at.desc())
)
result = await session.execute(query)
return list(result.scalars().all())
async def get_queued_runs_count(self, campaign_id: int, states: list[str]) -> int:
"""Get count of queued runs for a campaign in specified states"""
async with self.async_session() as session:
query = select(func.count(QueuedRunModel.id)).where(
QueuedRunModel.campaign_id == campaign_id,
QueuedRunModel.state.in_(states),
)
result = await session.execute(query)
return result.scalar() or 0
async def get_scheduled_runs_count(
self,
campaign_id: int,
scheduled_before: Optional[datetime] = None,
scheduled_after: Optional[datetime] = None,
) -> int:
"""Get count of scheduled runs for a campaign"""
async with self.async_session() as session:
conditions = [
QueuedRunModel.campaign_id == campaign_id,
QueuedRunModel.scheduled_for.isnot(None),
QueuedRunModel.state == "queued",
]
if scheduled_before:
conditions.append(QueuedRunModel.scheduled_for <= scheduled_before)
if scheduled_after:
conditions.append(QueuedRunModel.scheduled_for > scheduled_after)
query = select(func.count(QueuedRunModel.id)).where(*conditions)
result = await session.execute(query)
return result.scalar() or 0

9
api/db/database.py Normal file
View file

@ -0,0 +1,9 @@
import os
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import sessionmaker
DATABASE_URL = os.environ["DATABASE_URL"]
engine = create_async_engine(DATABASE_URL, echo=True)
async_session = sessionmaker(engine)

47
api/db/db_client.py Normal file
View file

@ -0,0 +1,47 @@
from api.db.api_key_client import APIKeyClient
from api.db.campaign_client import CampaignClient
from api.db.integration_client import IntegrationClient
from api.db.looptalk_client import LoopTalkClient
from api.db.organization_client import OrganizationClient
from api.db.organization_configuration_client import OrganizationConfigurationClient
from api.db.organization_usage_client import OrganizationUsageClient
from api.db.reports_client import ReportsClient
from api.db.user_client import UserClient
from api.db.workflow_client import WorkflowClient
from api.db.workflow_run_client import WorkflowRunClient
from api.db.workflow_template_client import WorkflowTemplateClient
class DBClient(
WorkflowClient,
WorkflowRunClient,
UserClient,
OrganizationClient,
OrganizationConfigurationClient,
OrganizationUsageClient,
IntegrationClient,
WorkflowTemplateClient,
LoopTalkClient,
CampaignClient,
ReportsClient,
APIKeyClient,
):
"""
Unified database client that combines all specialized database operations.
This client inherits from:
- WorkflowClient: handles workflow and workflow definition operations
- WorkflowRunClient: handles workflow run operations
- UserClient: handles user and user configuration operations
- OrganizationClient: handles organization operations
- OrganizationConfigurationClient: handles organization configuration operations
- OrganizationUsageClient: handles organization usage and quota operations
- IntegrationClient: handles integration operations
- WorkflowTemplateClient: handles workflow template operations
- LoopTalkClient: handles LoopTalk testing operations
- CampaignClient: handles campaign operations
- ReportsClient: handles reports and analytics operations
- APIKeyClient: handles API key operations
"""
pass

198
api/db/filters.py Normal file
View file

@ -0,0 +1,198 @@
"""Common filter utilities for database queries."""
from datetime import datetime
from typing import Any, Dict, List, Optional
from sqlalchemy import Integer, and_, cast
from sqlalchemy.dialects.postgresql import JSONB
from api.db.models import WorkflowRunModel
# Mapping of attribute names to database fields
ATTRIBUTE_FIELD_MAPPING = {
"dateRange": "created_at",
"dispositionCode": "gathered_context.mapped_call_disposition",
"duration": "usage_info.call_duration_seconds",
"status": "is_completed",
"tokenUsage": "cost_info.total_cost_usd",
"runId": "id",
"workflowId": "workflow_id",
"callTags": "gathered_context.call_tags",
"phoneNumber": "initial_context.phone",
}
def apply_workflow_run_filters(
base_query,
filters: Optional[List[Dict[str, Any]]] = None,
):
"""
Apply filters to a workflow run query.
Supports filtering by:
- dateRange: Filter by created_at date range
- dispositionCode: Filter by gathered_context.mapped_call_disposition
- duration: Filter by usage_info.call_duration_seconds range
- status: Filter by is_completed status
- tokenUsage: Filter by cost_info.total_cost_usd range
- runId: Filter by workflow run ID (exact match)
- workflowId: Filter by workflow ID (exact match)
- callTags: Filter by gathered_context.call_tags (array of strings)
- phoneNumber: Filter by initial_context.phone (text search)
Args:
base_query: The base SQLAlchemy query to apply filters to
filters: List of filter dictionaries with structure:
{"attribute": "filterName", "type": "filterType", "value": {...}}
Where type is one of:
- "dateRange": Date range filter with {"from": ..., "to": ...}
- "multiSelect": Multi-select filter with {"codes": [...]}
- "numberRange": Number range filter with {"min": ..., "max": ...}
- "number": Exact number filter with {"value": number}
- "text": Text search filter with {"value": string}
- "radio": Radio/status filter with {"status": ...}
- "tags": Tags filter with {"codes": [...]}
Returns:
The query with filters applied
"""
if not filters:
return base_query
filter_conditions = []
for filter_item in filters:
attribute = filter_item.get("attribute")
filter_type = filter_item.get("type")
value = filter_item.get("value", {})
# Resolve field from attribute mapping
field = ATTRIBUTE_FIELD_MAPPING.get(attribute)
if not field:
# Skip unknown attributes
continue
# Apply the filter based on provided type
if field and filter_type:
if filter_type == "number" and field == "id":
# Filter by exact workflow run ID
if value.get("value") is not None:
filter_conditions.append(WorkflowRunModel.id == value["value"])
elif filter_type == "number" and field == "workflow_id":
# Filter by exact workflow ID
if value.get("value") is not None:
filter_conditions.append(
WorkflowRunModel.workflow_id == value["value"]
)
elif filter_type == "dateRange" and field == "created_at":
# Same as attribute-based dateRange
if value.get("from"):
filter_conditions.append(
WorkflowRunModel.created_at
>= datetime.fromisoformat(value["from"])
)
if value.get("to"):
filter_conditions.append(
WorkflowRunModel.created_at
<= datetime.fromisoformat(value["to"])
)
elif (
filter_type == "multiSelect"
and field == "gathered_context.mapped_call_disposition"
):
codes = value.get("codes", [])
if codes:
filter_conditions.append(
cast(WorkflowRunModel.gathered_context, JSONB)[
"mapped_call_disposition"
]
.as_string()
.in_(codes)
)
elif filter_type == "radio" and field == "is_completed":
status = value.get("status")
if status == "completed":
filter_conditions.append(WorkflowRunModel.is_completed == True)
elif status == "in_progress":
filter_conditions.append(WorkflowRunModel.is_completed == False)
elif (
filter_type in ("tags", "multiSelect")
and field == "gathered_context.call_tags"
):
tags = value.get("codes", [])
if tags:
filter_conditions.append(
cast(WorkflowRunModel.gathered_context, JSONB)[
"call_tags"
].contains(tags)
)
elif filter_type == "text" and field == "initial_context.phone":
# Filter by phone number (contains search)
phone = value.get("value", "").strip()
if phone:
filter_conditions.append(
cast(WorkflowRunModel.initial_context, JSONB)["phone"]
.as_string()
.contains(phone)
)
elif filter_type == "numberRange":
min_val = value.get("min")
max_val = value.get("max")
if field == "usage_info.call_duration_seconds":
if min_val is not None:
filter_conditions.append(
cast(
cast(WorkflowRunModel.usage_info, JSONB)[
"call_duration_seconds"
],
Integer,
)
>= min_val
)
if max_val is not None:
filter_conditions.append(
cast(
cast(WorkflowRunModel.usage_info, JSONB)[
"call_duration_seconds"
],
Integer,
)
<= max_val
)
elif field == "cost_info.total_cost_usd":
if min_val is not None:
filter_conditions.append(
cast(
cast(WorkflowRunModel.cost_info, JSONB)[
"total_cost_usd"
],
Integer,
)
>= min_val
)
if max_val is not None:
filter_conditions.append(
cast(
cast(WorkflowRunModel.cost_info, JSONB)[
"total_cost_usd"
],
Integer,
)
<= max_val
)
if filter_conditions:
base_query = base_query.where(and_(*filter_conditions))
return base_query

View file

@ -0,0 +1,103 @@
from typing import List
from sqlalchemy.future import select
from api.db.base_client import BaseDBClient
from api.db.models import IntegrationModel
class IntegrationClient(BaseDBClient):
async def get_integrations_by_organization_id(
self, organization_id: int
) -> list[IntegrationModel]:
"""Get all integrations for a specific organization."""
async with self.async_session() as session:
result = await session.execute(
select(IntegrationModel).where(
IntegrationModel.organisation_id == organization_id
)
)
return result.scalars().all()
async def create_integration(
self,
integration_id: str,
provider: str,
organisation_id: int,
connection_details: dict,
created_by: int = None,
is_active: bool = True,
) -> IntegrationModel:
"""Create a new integration for an organization."""
async with self.async_session() as session:
new_integration = IntegrationModel(
integration_id=integration_id,
organisation_id=organisation_id,
created_by=created_by,
is_active=is_active,
provider=provider,
connection_details=connection_details,
)
session.add(new_integration)
try:
await session.commit()
except Exception as e:
await session.rollback()
raise e
await session.refresh(new_integration)
return new_integration
async def update_integration_status(
self, integration_id: int, is_active: bool
) -> IntegrationModel | None:
"""Update the active status of an integration."""
async with self.async_session() as session:
result = await session.execute(
select(IntegrationModel).where(IntegrationModel.id == integration_id)
)
integration = result.scalars().first()
if not integration:
return None
integration.is_active = is_active
try:
await session.commit()
except Exception as e:
await session.rollback()
raise e
await session.refresh(integration)
return integration
async def update_integration_connection_details(
self, integration_id: int, connection_details: dict
) -> IntegrationModel | None:
"""Update the connection details of an integration."""
async with self.async_session() as session:
result = await session.execute(
select(IntegrationModel).where(IntegrationModel.id == integration_id)
)
integration = result.scalars().first()
if not integration:
return None
integration.connection_details = connection_details
try:
await session.commit()
except Exception as e:
await session.rollback()
raise e
await session.refresh(integration)
return integration
async def get_active_integrations_by_organization(
self, organization_id: int
) -> List[IntegrationModel]:
"""Get all active integrations for a specific organization."""
async with self.async_session() as session:
result = await session.execute(
select(IntegrationModel).where(
IntegrationModel.organisation_id == organization_id,
IntegrationModel.is_active == True,
)
)
return result.scalars().all()

259
api/db/looptalk_client.py Normal file
View file

@ -0,0 +1,259 @@
from datetime import UTC, datetime
from typing import Any, Dict, List, Optional
from uuid import uuid4
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from api.db.base_client import BaseDBClient
from api.db.models import (
LoopTalkConversation,
LoopTalkTestSession,
)
class LoopTalkClient(BaseDBClient):
"""Database client for LoopTalk testing operations."""
async def create_test_session(
self,
organization_id: int,
name: str,
actor_workflow_id: int,
adversary_workflow_id: int,
config: Dict[str, Any],
load_test_group_id: Optional[str] = None,
test_index: Optional[int] = None,
) -> LoopTalkTestSession:
"""Create a new LoopTalk test session."""
async with self.async_session() as session:
test_session = LoopTalkTestSession(
organization_id=organization_id,
name=name,
actor_workflow_id=actor_workflow_id,
adversary_workflow_id=adversary_workflow_id,
config=config,
load_test_group_id=load_test_group_id,
test_index=test_index,
status="pending",
)
session.add(test_session)
await session.commit()
await session.refresh(test_session)
return test_session
async def get_test_session(
self, test_session_id: int, organization_id: int
) -> Optional[LoopTalkTestSession]:
"""Get a test session by ID."""
async with self.async_session() as session:
result = await session.execute(
select(LoopTalkTestSession)
.options(
selectinload(LoopTalkTestSession.actor_workflow),
selectinload(LoopTalkTestSession.adversary_workflow),
selectinload(LoopTalkTestSession.conversations),
)
.where(
LoopTalkTestSession.id == test_session_id,
LoopTalkTestSession.organization_id == organization_id,
)
)
return result.scalar_one_or_none()
async def list_test_sessions(
self,
organization_id: int,
status: Optional[str] = None,
load_test_group_id: Optional[str] = None,
limit: int = 20,
offset: int = 0,
) -> List[LoopTalkTestSession]:
"""List test sessions with optional filtering."""
async with self.async_session() as session:
query = select(LoopTalkTestSession).where(
LoopTalkTestSession.organization_id == organization_id
)
if status:
# "active" is a virtual status used by the UI to represent
# both "pending" and "running" sessions. Translate it into
# the real enum values stored in the database to avoid
# invalid enum casting errors (e.g. asyncpg InvalidTextRepresentationError).
if status == "active":
query = query.where(
LoopTalkTestSession.status.in_(["pending", "running"])
)
else:
query = query.where(LoopTalkTestSession.status == status)
if load_test_group_id:
query = query.where(
LoopTalkTestSession.load_test_group_id == load_test_group_id
)
query = (
query.order_by(LoopTalkTestSession.created_at.desc())
.limit(limit)
.offset(offset)
)
result = await session.execute(query)
return result.scalars().all()
async def update_test_session_status(
self,
test_session_id: int,
status: str,
error: Optional[str] = None,
results: Optional[Dict[str, Any]] = None,
) -> LoopTalkTestSession:
"""Update test session status and related fields."""
async with self.async_session() as session:
result = await session.execute(
select(LoopTalkTestSession).where(
LoopTalkTestSession.id == test_session_id
)
)
test_session = result.scalar_one()
test_session.status = status
if status == "running":
test_session.started_at = datetime.now(UTC)
elif status in ["completed", "failed"]:
test_session.completed_at = datetime.now(UTC)
if error:
test_session.error = error
if results:
test_session.results = results
await session.commit()
await session.refresh(test_session)
return test_session
async def create_conversation(self, test_session_id: int) -> LoopTalkConversation:
"""Create a new conversation for a test session."""
async with self.async_session() as session:
conversation = LoopTalkConversation(test_session_id=test_session_id)
session.add(conversation)
await session.commit()
await session.refresh(conversation)
return conversation
async def update_conversation(
self,
conversation_id: int,
duration_seconds: Optional[int] = None,
actor_recording_url: Optional[str] = None,
adversary_recording_url: Optional[str] = None,
combined_recording_url: Optional[str] = None,
transcript: Optional[Dict[str, Any]] = None,
metrics: Optional[Dict[str, Any]] = None,
ended_at: Optional[datetime] = None,
) -> LoopTalkConversation:
"""Update conversation details."""
async with self.async_session() as session:
result = await session.execute(
select(LoopTalkConversation).where(
LoopTalkConversation.id == conversation_id
)
)
conversation = result.scalar_one()
if duration_seconds is not None:
conversation.duration_seconds = duration_seconds
if actor_recording_url:
conversation.actor_recording_url = actor_recording_url
if adversary_recording_url:
conversation.adversary_recording_url = adversary_recording_url
if combined_recording_url:
conversation.combined_recording_url = combined_recording_url
if transcript:
conversation.transcript = transcript
if metrics:
conversation.metrics = metrics
if ended_at:
conversation.ended_at = ended_at
await session.commit()
await session.refresh(conversation)
return conversation
# Note: Turn tracking is handled by Langfuse, not stored in our database
async def create_load_test_group(
self,
organization_id: int,
name_prefix: str,
actor_workflow_id: int,
adversary_workflow_id: int,
config: Dict[str, Any],
test_count: int,
) -> List[LoopTalkTestSession]:
"""Create multiple test sessions for load testing."""
load_test_group_id = str(uuid4())
test_sessions = []
async with self.async_session() as session:
for i in range(test_count):
test_session = LoopTalkTestSession(
organization_id=organization_id,
name=f"{name_prefix} - Test {i + 1}",
actor_workflow_id=actor_workflow_id,
adversary_workflow_id=adversary_workflow_id,
config=config,
load_test_group_id=load_test_group_id,
test_index=i,
status="pending",
)
session.add(test_session)
test_sessions.append(test_session)
await session.commit()
# Refresh all sessions
for test_session in test_sessions:
await session.refresh(test_session)
return test_sessions
async def get_load_test_group_stats(
self, load_test_group_id: str, organization_id: int
) -> Dict[str, Any]:
"""Get statistics for a load test group."""
async with self.async_session() as session:
# Get all sessions in the group
result = await session.execute(
select(LoopTalkTestSession).where(
LoopTalkTestSession.load_test_group_id == load_test_group_id,
LoopTalkTestSession.organization_id == organization_id,
)
)
sessions = result.scalars().all()
# Calculate stats
stats = {
"total": len(sessions),
"pending": sum(1 for s in sessions if s.status == "pending"),
"running": sum(1 for s in sessions if s.status == "running"),
"completed": sum(1 for s in sessions if s.status == "completed"),
"failed": sum(1 for s in sessions if s.status == "failed"),
"sessions": [
{
"id": s.id,
"name": s.name,
"status": s.status,
"test_index": s.test_index,
"created_at": s.created_at,
"started_at": s.started_at,
"completed_at": s.completed_at,
"error": s.error,
}
for s in sessions
],
}
return stats

608
api/db/models.py Normal file
View file

@ -0,0 +1,608 @@
from datetime import UTC, datetime
from loguru import logger
from sqlalchemy import (
JSON,
Boolean,
Column,
DateTime,
Enum,
Float,
ForeignKey,
Index,
Integer,
String,
Table,
UniqueConstraint,
and_,
text,
)
from sqlalchemy.orm import declarative_base, relationship
from ..enums import IntegrationAction, WorkflowRunMode, WorkflowStatus
Base = declarative_base()
# TODO: remove workflow_defintion after migration, remove nullable workflow_defintion_id from Workflow and Workflowrun
# Association table for many-to-many relationship between users and organizations
organization_users_association = Table(
"organization_users",
Base.metadata,
Column("user_id", Integer, ForeignKey("users.id"), primary_key=True),
Column(
"organization_id", Integer, ForeignKey("organizations.id"), primary_key=True
),
)
class UserModel(Base):
__tablename__ = "users"
id = Column(Integer, primary_key=True, index=True)
provider_id = Column(String, unique=True, index=True, nullable=False)
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
workflows = relationship("WorkflowModel", back_populates="user")
selected_organization_id = Column(
Integer, ForeignKey("organizations.id"), nullable=True
)
selected_organization = relationship("OrganizationModel", back_populates="users")
organizations = relationship(
"OrganizationModel",
secondary=organization_users_association,
back_populates="users",
)
is_superuser = Column(Boolean, default=False)
class UserConfigurationModel(Base):
__tablename__ = "user_configurations"
id = Column(Integer, primary_key=True, index=True)
user_id = Column(Integer, ForeignKey("users.id"), nullable=True)
configuration = Column(JSON, nullable=False, default=dict)
last_validated_at = Column(DateTime(timezone=True), nullable=True)
# New Organization model
class OrganizationModel(Base):
__tablename__ = "organizations"
id = Column(Integer, primary_key=True, index=True)
provider_id = Column(String, unique=True, index=True, nullable=False)
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
# Quota fields
quota_type = Column(
Enum("monthly", "annual", name="quota_type"),
nullable=False,
default="monthly",
server_default=text("'monthly'::quota_type"),
)
quota_dograh_tokens = Column(
Integer, nullable=False, default=0, server_default=text("0")
)
quota_reset_day = Column(
Integer, nullable=False, default=1, server_default=text("1")
) # 1-28, only for monthly
quota_start_date = Column(DateTime(timezone=True), nullable=True) # Only for annual
quota_enabled = Column(
Boolean, nullable=False, default=False, server_default=text("false")
)
price_per_second_usd = Column(Float, nullable=True)
# Relationships
users = relationship(
"UserModel",
secondary=organization_users_association,
back_populates="organizations",
)
integrations = relationship("IntegrationModel", back_populates="organization")
usage_cycles = relationship(
"OrganizationUsageCycleModel", back_populates="organization"
)
configurations = relationship(
"OrganizationConfigurationModel", back_populates="organization"
)
api_keys = relationship("APIKeyModel", back_populates="organization")
class APIKeyModel(Base):
__tablename__ = "api_keys"
id = Column(Integer, primary_key=True, index=True)
organization_id = Column(
Integer, ForeignKey("organizations.id", ondelete="CASCADE"), nullable=False
)
name = Column(String, nullable=False)
key_hash = Column(String, nullable=False, unique=True, index=True)
key_prefix = Column(String, nullable=False) # Store first 8 chars for display
is_active = Column(Boolean, default=True, nullable=False)
created_by = Column(Integer, ForeignKey("users.id"), nullable=True)
last_used_at = Column(DateTime(timezone=True), nullable=True)
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
archived_at = Column(DateTime(timezone=True), nullable=True)
# Relationships
organization = relationship("OrganizationModel", back_populates="api_keys")
created_by_user = relationship("UserModel")
# Indexes for performance
__table_args__ = (
Index("ix_api_keys_organization_id", "organization_id"),
Index("ix_api_keys_key_hash", "key_hash"),
Index("ix_api_keys_active", "is_active"),
)
class OrganizationConfigurationModel(Base):
__tablename__ = "organization_configurations"
id = Column(Integer, primary_key=True, index=True)
organization_id = Column(
Integer, ForeignKey("organizations.id", ondelete="CASCADE"), nullable=False
)
key = Column(String, nullable=False)
value = Column(JSON, nullable=False, default=dict)
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
updated_at = Column(
DateTime(timezone=True),
default=lambda: datetime.now(UTC),
onupdate=lambda: datetime.now(UTC),
)
# Relationships
organization = relationship("OrganizationModel", back_populates="configurations")
# Constraints and indexes
__table_args__ = (
UniqueConstraint("organization_id", "key", name="_organization_key_uc"),
Index("ix_organization_configurations_organization_id", "organization_id"),
)
class IntegrationModel(Base):
__tablename__ = "integrations"
id = Column(Integer, primary_key=True, index=True)
integration_id = Column(String, nullable=False, index=True) # Nango Connection ID
organisation_id = Column(Integer, ForeignKey("organizations.id"), nullable=False)
provider = Column(String, nullable=False)
created_by = Column(Integer, ForeignKey("users.id"))
is_active = Column(Boolean, default=True, nullable=False)
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
connection_details = Column(JSON, nullable=False, default=dict)
action = Column(String, nullable=False, default=IntegrationAction.ALL_CALLS.value)
# Relationships
organization = relationship("OrganizationModel", back_populates="integrations")
class WorkflowDefinitionModel(Base):
__tablename__ = "workflow_definitions"
id = Column(Integer, primary_key=True, index=True)
workflow_hash = Column(String, nullable=False)
workflow_json = Column(JSON, nullable=False, default=dict)
workflow_id = Column(Integer, ForeignKey("workflows.id"), nullable=True)
is_current = Column(
Boolean, default=False, nullable=False, server_default=text("false")
)
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
# Table constraints and indexes
__table_args__ = (
UniqueConstraint(
"workflow_hash", "workflow_id", name="uq_workflow_hash_workflow_id"
),
Index("ix_workflow_hash_workflow_id", "workflow_hash", "workflow_id"),
)
# Relationships
workflow = relationship(
"WorkflowModel",
back_populates="definitions",
foreign_keys=[workflow_id],
)
workflow_runs = relationship("WorkflowRunModel", back_populates="definition")
class WorkflowModel(Base):
__tablename__ = "workflows"
id = Column(Integer, primary_key=True, index=True)
user_id = Column(Integer, ForeignKey("users.id"), nullable=True)
user = relationship("UserModel", back_populates="workflows")
organization_id = Column(Integer, ForeignKey("organizations.id"), nullable=True)
organization = relationship("OrganizationModel")
name = Column(String, index=True, nullable=False)
status = Column(
Enum(*[status.value for status in WorkflowStatus], name="workflow_status"),
nullable=False,
default=WorkflowStatus.ACTIVE.value,
server_default=text("'active'::workflow_status"),
)
workflow_definition = Column(JSON, nullable=False, default=dict)
template_context_variables = Column(JSON, nullable=False, default=dict)
call_disposition_codes = Column(JSON, nullable=False, default=dict)
workflow_configurations = Column(
JSON, nullable=False, default=dict, server_default=text("'{}'::json")
)
runs = relationship("WorkflowRunModel", back_populates="workflow")
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
# All versions / historical definitions of this workflow
definitions = relationship(
"WorkflowDefinitionModel",
back_populates="workflow",
foreign_keys="WorkflowDefinitionModel.workflow_id",
)
# Relationship to fetch the current (is_current=True) definition
current_definition = relationship(
"WorkflowDefinitionModel",
primaryjoin=lambda: and_(
WorkflowDefinitionModel.workflow_id == WorkflowModel.id,
WorkflowDefinitionModel.is_current.is_(True),
),
uselist=False,
viewonly=True,
)
@property
def current_definition_id(self):
"""Return ID of the current workflow definition (helper for backwards-compat)."""
current_def = self.__dict__.get("current_definition")
if current_def is not None:
return current_def.id
# If relationship is not loaded, we cannot safely access definitions without
# risking an implicit lazy load on a detached instance. Return ``None`` in
# that scenario so callers can handle the absence explicitly.
return None
@property
def workflow_definition_with_fallback(self):
"""
Get workflow definition with fallback to legacy workflow_definition field.
Returns:
dict: The workflow definition JSON
"""
# Access the relationship only if it has ALREADY been eagerly loaded on this
# instance to avoid triggering an implicit lazy load once the SQLAlchemy
# Session has been closed (which would raise a DetachedInstanceError).
# ``__dict__`` will contain "current_definition" **only** when the attribute
# has been populated (e.g. via `selectinload` or an explicit access while
# the session was still open). Using ``__dict__.get`` guarantees that we
# do not accidentally issue a lazy load query on a detached instance.
current_definition = self.__dict__.get("current_definition")
if current_definition is not None:
return current_definition.workflow_json
# Fallback for backwards-compatibility when the relationship is not (yet)
# loaded. In this case we fall back to the legacy ``workflow_definition``
# column that always contains the most recent definition JSON.
logger.warning(
f"Workflow {self.id} has no loaded current definition, using workflow_definition as fallback",
)
return self.workflow_definition
class WorkflowTemplates(Base):
__tablename__ = "workflow_templates"
id = Column(Integer, primary_key=True, index=True)
template_name = Column(String, nullable=False, index=True)
template_description = Column(String, nullable=False, index=True)
template_json = Column(JSON, nullable=False, default=dict)
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
class WorkflowRunModel(Base):
__tablename__ = "workflow_runs"
id = Column(Integer, primary_key=True, index=True)
name = Column(String, nullable=False)
workflow_id = Column(Integer, ForeignKey("workflows.id"), nullable=False)
workflow = relationship("WorkflowModel", back_populates="runs")
definition_id = Column(
Integer, ForeignKey("workflow_definitions.id"), nullable=True
)
definition = relationship("WorkflowDefinitionModel", back_populates="workflow_runs")
mode = Column(
Enum(*[mode.value for mode in WorkflowRunMode], name="workflow_run_mode"),
nullable=False,
)
is_completed = Column(Boolean, default=False)
recording_url = Column(String, nullable=True)
transcript_url = Column(String, nullable=True)
# Store storage backend as string enum (s3, minio)
storage_backend = Column(
Enum("s3", "minio", name="storage_backend"),
nullable=False,
default="s3",
server_default=text("'s3'::storage_backend"),
)
usage_info = Column(JSON, nullable=False, default=dict)
cost_info = Column(JSON, nullable=False, default=dict)
initial_context = Column(JSON, nullable=False, default=dict)
gathered_context = Column(JSON, nullable=False, default=dict)
logs = Column(JSON, nullable=False, default=dict, server_default=text("'{}'::json"))
annotations = Column(JSON, nullable=False, default=dict)
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
campaign_id = Column(Integer, ForeignKey("campaigns.id"), nullable=True)
campaign = relationship("CampaignModel")
queued_run_id = Column(Integer, ForeignKey("queued_runs.id"), nullable=True)
queued_run = relationship("QueuedRunModel", foreign_keys=[queued_run_id])
# LoopTalk Testing Models
class LoopTalkTestSession(Base):
__tablename__ = "looptalk_test_sessions"
id = Column(Integer, primary_key=True, index=True)
organization_id = Column(Integer, ForeignKey("organizations.id"), nullable=False)
name = Column(String, nullable=False)
status = Column(
Enum("pending", "running", "completed", "failed", name="test_session_status"),
nullable=False,
default="pending",
)
# Workflow configuration
actor_workflow_id = Column(Integer, ForeignKey("workflows.id"), nullable=False)
adversary_workflow_id = Column(Integer, ForeignKey("workflows.id"), nullable=False)
# Load testing configuration
load_test_group_id = Column(String, nullable=True, index=True)
test_index = Column(Integer, nullable=True)
# Test metadata
config = Column(JSON, nullable=False, default=dict)
results = Column(JSON, nullable=False, default=dict)
error = Column(String, nullable=True)
# Timestamps
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
started_at = Column(DateTime(timezone=True), nullable=True)
completed_at = Column(DateTime(timezone=True), nullable=True)
# Relationships
organization = relationship("OrganizationModel")
actor_workflow = relationship("WorkflowModel", foreign_keys=[actor_workflow_id])
adversary_workflow = relationship(
"WorkflowModel", foreign_keys=[adversary_workflow_id]
)
conversations = relationship("LoopTalkConversation", back_populates="test_session")
# Indexes for performance
__table_args__ = (
Index("ix_looptalk_test_sessions_org_id", "organization_id"),
Index("ix_looptalk_test_sessions_group_id", "load_test_group_id"),
Index("ix_looptalk_test_sessions_status", "status"),
)
class LoopTalkConversation(Base):
__tablename__ = "looptalk_conversations"
id = Column(Integer, primary_key=True, index=True)
test_session_id = Column(
Integer, ForeignKey("looptalk_test_sessions.id"), nullable=False
)
# Conversation metadata
duration_seconds = Column(Integer, nullable=True)
# Note: Turn tracking is handled by Langfuse, not stored here
# Audio recording URLs
actor_recording_url = Column(String, nullable=True)
adversary_recording_url = Column(String, nullable=True)
combined_recording_url = Column(String, nullable=True)
# Transcripts (if needed for quick access)
transcript = Column(JSON, nullable=False, default=dict)
# Metrics
metrics = Column(JSON, nullable=False, default=dict)
# Timestamps
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
ended_at = Column(DateTime(timezone=True), nullable=True)
# Relationships
test_session = relationship("LoopTalkTestSession", back_populates="conversations")
# Indexes
__table_args__ = (Index("ix_looptalk_conversations_session_id", "test_session_id"),)
class OrganizationUsageCycleModel(Base):
"""
This model is used to track the usage of Dograh tokens for an organization for a given usage
cycle.
"""
__tablename__ = "organization_usage_cycles"
id = Column(Integer, primary_key=True, index=True)
organization_id = Column(Integer, ForeignKey("organizations.id"), nullable=False)
period_start = Column(DateTime(timezone=True), nullable=False)
period_end = Column(DateTime(timezone=True), nullable=False)
quota_dograh_tokens = Column(Integer, nullable=False)
used_dograh_tokens = Column(Float, nullable=False, default=0)
total_duration_seconds = Column(
Integer, nullable=False, default=0, server_default=text("0")
)
# New USD tracking fields
used_amount_usd = Column(Float, nullable=True, default=0)
quota_amount_usd = Column(Float, nullable=True)
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
updated_at = Column(
DateTime(timezone=True),
default=lambda: datetime.now(UTC),
onupdate=lambda: datetime.now(UTC),
)
# Relationships
organization = relationship("OrganizationModel", back_populates="usage_cycles")
# Constraints and indexes
__table_args__ = (
UniqueConstraint(
"organization_id", "period_start", "period_end", name="unique_org_period"
),
Index("idx_usage_cycles_org_period", "organization_id", "period_end"),
)
class CampaignModel(Base):
__tablename__ = "campaigns"
id = Column(Integer, primary_key=True, index=True)
name = Column(String, nullable=False, index=True)
organization_id = Column(Integer, ForeignKey("organizations.id"), nullable=False)
workflow_id = Column(Integer, ForeignKey("workflows.id"), nullable=False)
created_by = Column(Integer, ForeignKey("users.id"), nullable=False)
# Source configuration
source_type = Column(String, nullable=False, default="google-sheet")
source_id = Column(String, nullable=False) # Sheet URL
# State management
state = Column(
Enum(
"created",
"syncing",
"running",
"paused",
"completed",
"failed",
name="campaign_state",
),
nullable=False,
default="created",
)
# Progress tracking
total_rows = Column(Integer, nullable=True)
processed_rows = Column(Integer, nullable=False, default=0)
failed_rows = Column(Integer, nullable=False, default=0)
# Rate limiting and sync configuration
rate_limit_per_second = Column(Integer, nullable=False, default=1)
max_retries = Column(Integer, nullable=False, default=0)
source_sync_status = Column(String, nullable=False, default="pending")
source_last_synced_at = Column(DateTime(timezone=True), nullable=True)
source_sync_error = Column(String, nullable=True)
# Retry configuration for call failures
retry_config = Column(
JSON,
nullable=False,
default={
"enabled": True,
"max_retries": 2,
"retry_delay_seconds": 120,
"retry_on_busy": True,
"retry_on_no_answer": True,
"retry_on_voicemail": True,
},
server_default=text(
'\'{"enabled": true, "max_retries": 2, "retry_on_busy": true, "retry_on_no_answer": true, "retry_on_voicemail": true, "retry_delay_seconds": 120}\'::jsonb'
),
)
# Orchestrator tracking fields
last_batch_scheduled_at = Column(DateTime(timezone=True), nullable=True)
last_activity_at = Column(DateTime(timezone=True), nullable=True)
orchestrator_metadata = Column(
JSON, nullable=False, default=dict, server_default=text("'{}'::json")
)
# Timestamps
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
started_at = Column(DateTime(timezone=True), nullable=True)
completed_at = Column(DateTime(timezone=True), nullable=True)
updated_at = Column(
DateTime(timezone=True),
default=lambda: datetime.now(UTC),
onupdate=lambda: datetime.now(UTC),
)
# Relationships
organization = relationship("OrganizationModel")
workflow = relationship("WorkflowModel")
created_by_user = relationship("UserModel")
# Indexes
__table_args__ = (
Index("ix_campaigns_org_id", "organization_id"),
Index("ix_campaigns_state", "state"),
Index("ix_campaigns_workflow_id", "workflow_id"),
# Index for efficient querying of active campaigns
Index(
"idx_campaigns_active_status",
"state",
postgresql_where=text("state IN ('syncing', 'running', 'paused')"),
),
)
class QueuedRunModel(Base):
__tablename__ = "queued_runs"
id = Column(Integer, primary_key=True, index=True)
campaign_id = Column(
Integer, ForeignKey("campaigns.id", ondelete="CASCADE"), nullable=False
)
source_uuid = Column(String, nullable=False)
context_variables = Column(JSON, nullable=False, default=dict)
state = Column(
Enum("queued", "processed", "failed", name="queued_run_state"),
nullable=False,
default="queued",
)
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
processed_at = Column(DateTime(timezone=True), nullable=True)
# New retry-related fields
retry_count = Column(Integer, default=0, nullable=False, server_default=text("0"))
parent_queued_run_id = Column(Integer, ForeignKey("queued_runs.id"), nullable=True)
scheduled_for = Column(DateTime(timezone=True), nullable=True)
retry_reason = Column(String, nullable=True) # 'busy', 'no_answer', 'voicemail'
# Relationships
campaign = relationship("CampaignModel")
parent_queued_run = relationship("QueuedRunModel", remote_side=[id])
# Indexes
__table_args__ = (
Index("idx_queued_runs_campaign_state", "campaign_id", "state"),
Index("idx_queued_runs_created", "created_at"),
Index("idx_queued_runs_source_uuid", "source_uuid"),
Index(
"idx_queued_runs_scheduled", "scheduled_for"
), # New index for scheduled retries
# Optimized index for checking queued runs efficiently
Index(
"idx_queued_runs_campaign_state_optimized",
"campaign_id",
"state",
postgresql_where=text("state = 'queued'"),
),
# Optimized index for scheduled retries
Index(
"idx_queued_runs_scheduled_optimized",
"campaign_id",
"scheduled_for",
postgresql_where=text("scheduled_for IS NOT NULL"),
),
UniqueConstraint(
"campaign_id",
"source_uuid",
"retry_count",
name="unique_campaign_source_retry",
),
)

View file

@ -0,0 +1,114 @@
from datetime import datetime, timezone
from typing import Optional
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.future import select
from api.db.base_client import BaseDBClient
from api.db.models import (
APIKeyModel,
OrganizationModel,
organization_users_association,
)
from api.utils.api_key import generate_api_key
class OrganizationClient(BaseDBClient):
async def get_organization_by_id(
self, organization_id: int
) -> Optional[OrganizationModel]:
"""Get an organization by its ID."""
async with self.async_session() as session:
result = await session.execute(
select(OrganizationModel).where(OrganizationModel.id == organization_id)
)
return result.scalars().first()
async def get_or_create_organization_by_provider_id(
self, org_provider_id: str, user_id: int
) -> tuple[OrganizationModel, bool]:
"""Get an existing organization by provider_id or create a new one.
Returns:
A tuple of (organization, was_created) where was_created is True if the organization
was created in this call, False if it already existed.
"""
async with self.async_session() as session:
# First try to get existing organization
result = await session.execute(
select(OrganizationModel).where(
OrganizationModel.provider_id == org_provider_id
)
)
organization = result.scalars().first()
if organization is None:
# Use PostgreSQL's INSERT ... ON CONFLICT DO NOTHING
# This is atomic and handles race conditions at the database level
stmt = insert(OrganizationModel.__table__).values(
provider_id=org_provider_id, created_at=datetime.now(timezone.utc)
)
# ON CONFLICT DO NOTHING - if another request already inserted, this becomes a no-op
stmt = stmt.on_conflict_do_nothing(index_elements=["provider_id"])
result = await session.execute(stmt)
await session.commit()
# Check if we actually inserted (rowcount > 0) or if there was a conflict (rowcount == 0)
was_created = result.rowcount > 0
# Now fetch the organization (either the one we just created or the one that existed)
result = await session.execute(
select(OrganizationModel).where(
OrganizationModel.provider_id == org_provider_id
)
)
organization = result.scalars().first()
if organization is None:
# This should never happen, but handle it just in case
error_msg = f"Failed to create or fetch organization with provider_id {org_provider_id}"
raise ValueError(error_msg)
# Only create API key if we actually created the organization
if was_created:
# Create a default API key for the new organization
_, key_hash, key_prefix = generate_api_key()
api_key = APIKeyModel(
organization_id=organization.id,
name="Default API Key",
key_hash=key_hash,
key_prefix=key_prefix,
is_active=True,
created_by=user_id,
)
session.add(api_key)
await session.commit()
await session.refresh(organization)
return organization, was_created
return organization, False
async def add_user_to_organization(
self, user_id: int, organization_id: int
) -> None:
"""Ensure that a user is linked to an organization (many-to-many).
The association is created only if it does not already exist.
Uses INSERT ... ON CONFLICT DO NOTHING to handle race conditions.
"""
async with self.async_session() as session:
# Use PostgreSQL's INSERT ... ON CONFLICT DO NOTHING
# This handles race conditions at the database level
stmt = insert(organization_users_association).values(
user_id=user_id, organization_id=organization_id
)
# ON CONFLICT DO NOTHING - if another request already inserted, this becomes a no-op
# The primary key constraint on (user_id, organization_id) will trigger the conflict
stmt = stmt.on_conflict_do_nothing()
await session.execute(stmt)
await session.commit()

View file

@ -0,0 +1,96 @@
from typing import Any, Optional
from sqlalchemy.future import select
from api.db.base_client import BaseDBClient
from api.db.models import OrganizationConfigurationModel
class OrganizationConfigurationClient(BaseDBClient):
async def get_configuration(
self, organization_id: int, key: str
) -> Optional[OrganizationConfigurationModel]:
"""Get a specific configuration for an organization by key."""
async with self.async_session() as session:
result = await session.execute(
select(OrganizationConfigurationModel).where(
OrganizationConfigurationModel.organization_id == organization_id,
OrganizationConfigurationModel.key == key,
)
)
return result.scalars().first()
async def get_all_configurations(
self, organization_id: int
) -> list[OrganizationConfigurationModel]:
"""Get all configurations for an organization."""
async with self.async_session() as session:
result = await session.execute(
select(OrganizationConfigurationModel).where(
OrganizationConfigurationModel.organization_id == organization_id
)
)
return result.scalars().all()
async def upsert_configuration(
self, organization_id: int, key: str, value: Any
) -> OrganizationConfigurationModel:
"""Create or update a configuration for an organization."""
async with self.async_session() as session:
# First try to get existing configuration
result = await session.execute(
select(OrganizationConfigurationModel).where(
OrganizationConfigurationModel.organization_id == organization_id,
OrganizationConfigurationModel.key == key,
)
)
config = result.scalars().first()
if config:
# Update existing configuration
config.value = value
else:
# Create new configuration
config = OrganizationConfigurationModel(
organization_id=organization_id,
key=key,
value=value,
)
session.add(config)
try:
await session.commit()
except Exception as e:
await session.rollback()
raise e
await session.refresh(config)
return config
async def delete_configuration(self, organization_id: int, key: str) -> bool:
"""Delete a configuration for an organization."""
async with self.async_session() as session:
result = await session.execute(
select(OrganizationConfigurationModel).where(
OrganizationConfigurationModel.organization_id == organization_id,
OrganizationConfigurationModel.key == key,
)
)
config = result.scalars().first()
if not config:
return False
await session.delete(config)
try:
await session.commit()
except Exception as e:
await session.rollback()
raise e
return True
async def get_configuration_value(
self, organization_id: int, key: str, default: Any = None
) -> Any:
"""Get the value of a configuration, returning default if not found."""
config = await self.get_configuration(organization_id, key)
return config.value if config else default

View file

@ -0,0 +1,524 @@
from datetime import datetime, timezone
from typing import Optional
from zoneinfo import ZoneInfo
from dateutil.relativedelta import relativedelta
from sqlalchemy import Date, and_, cast, func, select
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import joinedload
from api.db.base_client import BaseDBClient
from api.db.filters import apply_workflow_run_filters
from api.db.models import (
OrganizationModel,
OrganizationUsageCycleModel,
UserConfigurationModel,
UserModel,
WorkflowModel,
WorkflowRunModel,
)
from api.schemas.user_configuration import UserConfiguration
class OrganizationUsageClient(BaseDBClient):
"""Client for managing organization usage and quota operations."""
async def get_or_create_current_cycle(
self, organization_id: int, session=None
) -> OrganizationUsageCycleModel:
"""Get or create the current usage cycle for an organization.
Args:
organization_id: The organization ID
session: Optional session to use for the operation. If provided,
the caller is responsible for committing.
"""
if session is None:
async with self.async_session() as session:
return await self._get_or_create_current_cycle_impl(
organization_id, session, commit=True
)
else:
return await self._get_or_create_current_cycle_impl(
organization_id, session, commit=False
)
async def _get_or_create_current_cycle_impl(
self, organization_id: int, session, commit: bool
) -> OrganizationUsageCycleModel:
"""Internal implementation for get_or_create_current_cycle."""
# Get organization to determine quota type
org_result = await session.execute(
select(OrganizationModel).where(OrganizationModel.id == organization_id)
)
org = org_result.scalar_one()
# Calculate current period
period_start, period_end = self._calculate_current_period(org)
# Try to get existing cycle
cycle_result = await session.execute(
select(OrganizationUsageCycleModel).where(
and_(
OrganizationUsageCycleModel.organization_id == organization_id,
OrganizationUsageCycleModel.period_start == period_start,
OrganizationUsageCycleModel.period_end == period_end,
)
)
)
cycle = cycle_result.scalar_one_or_none()
if cycle:
return cycle
# Create new cycle if it doesn't exist
stmt = insert(OrganizationUsageCycleModel).values(
organization_id=organization_id,
period_start=period_start,
period_end=period_end,
quota_dograh_tokens=org.quota_dograh_tokens,
)
# Handle concurrent inserts gracefully
stmt = stmt.on_conflict_do_nothing(
index_elements=["organization_id", "period_start", "period_end"]
)
await session.execute(stmt)
if commit:
await session.commit()
# Fetch the created cycle
cycle_result = await session.execute(
select(OrganizationUsageCycleModel).where(
and_(
OrganizationUsageCycleModel.organization_id == organization_id,
OrganizationUsageCycleModel.period_start == period_start,
OrganizationUsageCycleModel.period_end == period_end,
)
)
)
return cycle_result.scalar_one()
async def check_and_reserve_quota(
self, organization_id: int, estimated_tokens: int = 0
) -> bool:
"""
Check if organization has sufficient quota and optionally reserve tokens.
Returns True if quota is available, False otherwise.
This method is fully atomic and safe for concurrent access from multiple processes.
"""
async with self.async_session() as session:
# Get organization
org_result = await session.execute(
select(OrganizationModel).where(OrganizationModel.id == organization_id)
)
org = org_result.scalar_one_or_none()
if not org or not org.quota_enabled:
# No quota enforcement if not enabled
return True
# Get or create current cycle within the same session/transaction
cycle = await self._get_or_create_current_cycle_impl(
organization_id, session, commit=False
)
# Atomic check and update with row-level lock
result = await session.execute(
select(OrganizationUsageCycleModel)
.where(
and_(
OrganizationUsageCycleModel.id == cycle.id,
OrganizationUsageCycleModel.used_dograh_tokens
+ estimated_tokens
<= OrganizationUsageCycleModel.quota_dograh_tokens,
)
)
.with_for_update(skip_locked=False)
)
cycle_locked = result.scalar_one_or_none()
if cycle_locked:
# Update the usage atomically
cycle_locked.used_dograh_tokens += estimated_tokens
await session.commit()
return True
return False
async def update_usage_after_run(
self,
organization_id: int,
actual_tokens: int,
duration_seconds: int = 0,
charge_usd: float = None,
) -> None:
"""Update usage after a workflow run completes with actual token count and duration.
This method is fully atomic and safe for concurrent access from multiple processes.
"""
async with self.async_session() as session:
# Get or create current cycle within the same session/transaction
cycle = await self._get_or_create_current_cycle_impl(
organization_id, session, commit=False
)
# Acquire a row-level lock for atomic update
result = await session.execute(
select(OrganizationUsageCycleModel)
.where(OrganizationUsageCycleModel.id == cycle.id)
.with_for_update(skip_locked=False)
)
cycle_locked = result.scalar_one()
# Update usage atomically
cycle_locked.used_dograh_tokens += actual_tokens
cycle_locked.total_duration_seconds += int(round(duration_seconds))
# Update USD amount if provided
if charge_usd is not None:
if cycle_locked.used_amount_usd is None:
cycle_locked.used_amount_usd = 0
cycle_locked.used_amount_usd += charge_usd
await session.commit()
async def get_current_usage(self, organization_id: int) -> dict:
"""Get current period usage information."""
async with self.async_session() as session:
# Get organization
org_result = await session.execute(
select(OrganizationModel).where(OrganizationModel.id == organization_id)
)
org = org_result.scalar_one()
# Get or create current cycle within the same session
cycle = await self._get_or_create_current_cycle_impl(
organization_id, session, commit=False
)
# Calculate next refresh date
if org.quota_type == "monthly":
next_refresh = cycle.period_end + relativedelta(days=1)
else: # annual
next_refresh = cycle.period_end + relativedelta(days=1)
result = {
"period_start": cycle.period_start.isoformat(),
"period_end": cycle.period_end.isoformat(),
"used_dograh_tokens": cycle.used_dograh_tokens,
"quota_dograh_tokens": cycle.quota_dograh_tokens,
"percentage_used": (
round(
(cycle.used_dograh_tokens / cycle.quota_dograh_tokens) * 100, 2
)
if cycle.quota_dograh_tokens > 0
else 0
),
"next_refresh_date": next_refresh.date().isoformat(),
"quota_enabled": org.quota_enabled,
"total_duration_seconds": cycle.total_duration_seconds,
}
# Add USD fields if organization has pricing
if org.price_per_second_usd is not None:
result["used_amount_usd"] = cycle.used_amount_usd or 0
result["quota_amount_usd"] = cycle.quota_amount_usd
result["currency"] = "USD"
result["price_per_second_usd"] = org.price_per_second_usd
# Calculate percentage based on USD if available
if cycle.quota_amount_usd and cycle.quota_amount_usd > 0:
result["percentage_used"] = round(
((cycle.used_amount_usd or 0) / cycle.quota_amount_usd) * 100, 2
)
return result
async def get_usage_history(
self,
organization_id: int,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
limit: int = 50,
offset: int = 0,
filters: Optional[list[dict]] = None,
) -> tuple[list[dict], int]:
"""Get paginated workflow runs with usage for an organization."""
async with self.async_session() as session:
query = (
select(WorkflowRunModel)
.join(WorkflowModel, WorkflowRunModel.workflow_id == WorkflowModel.id)
.join(UserModel, WorkflowModel.user_id == UserModel.id)
.where(
UserModel.selected_organization_id == organization_id,
WorkflowRunModel.cost_info.isnot(None),
)
.order_by(WorkflowRunModel.created_at.desc())
)
# Apply date filters if provided
if start_date:
query = query.where(WorkflowRunModel.created_at >= start_date)
if end_date:
query = query.where(WorkflowRunModel.created_at <= end_date)
# Only allow specific filters for usage history endpoint
# This ensures security and prevents unexpected filter attributes
allowed_filters = {"duration", "dispositionCode", "phoneNumber"}
sanitized_filters = []
if filters:
for filter_item in filters:
attribute = filter_item.get("attribute")
# Only process allowed filters
if attribute in allowed_filters:
sanitized_filters.append(filter_item)
# Apply filters using the common filter function
query = apply_workflow_run_filters(query, sanitized_filters)
# Get total count
count_result = await session.execute(
select(func.count()).select_from(query.subquery())
)
total_count = count_result.scalar()
results = await session.execute(
query.options(joinedload(WorkflowRunModel.workflow))
.limit(limit)
.offset(offset)
)
runs = results.scalars().all()
# Format runs
formatted_runs = []
total_tokens = 0
total_duration_seconds = 0
for run in runs:
if run.cost_info:
# Try to get dograh_token_usage first (new format)
dograh_tokens = run.cost_info.get("dograh_token_usage", 0)
# If not present, calculate from total_cost_usd (old format)
if dograh_tokens == 0 and "total_cost_usd" in run.cost_info:
dograh_tokens = round(
float(run.cost_info["total_cost_usd"]) * 100, 2
)
# Get call duration
call_duration = run.cost_info.get("call_duration_seconds", 0)
else:
dograh_tokens = 0
call_duration = 0
total_tokens += dograh_tokens
total_duration_seconds += int(round(call_duration))
# Extract phone number from initial_context
phone_number = None
if run.initial_context:
phone_number = run.initial_context.get("phone_number")
# Extract disposition from gathered_context
disposition = None
if run.gathered_context:
disposition = run.gathered_context.get("mapped_call_disposition")
run_data = {
"id": run.id,
"workflow_id": run.workflow_id,
"workflow_name": run.workflow.name if run.workflow else None,
"name": run.name,
"created_at": run.created_at.isoformat(),
"dograh_token_usage": dograh_tokens,
"call_duration_seconds": int(round(call_duration)),
"recording_url": run.recording_url,
"transcript_url": run.transcript_url,
"phone_number": phone_number,
"disposition": disposition,
"initial_context": run.initial_context,
"gathered_context": run.gathered_context,
}
# Add USD cost if available in cost_info
if run.cost_info and "charge_usd" in run.cost_info:
run_data["charge_usd"] = run.cost_info["charge_usd"]
formatted_runs.append(run_data)
return formatted_runs, total_count, total_tokens, total_duration_seconds
async def get_daily_usage_breakdown(
self,
organization_id: int,
start_date: datetime,
end_date: datetime,
price_per_second_usd: float,
user_id: Optional[int] = None,
) -> dict:
"""Get daily usage breakdown for an organization with pricing."""
async with self.async_session() as session:
# Get user timezone if user_id is provided
user_timezone = "UTC" # Default timezone
if user_id:
config_result = await session.execute(
select(UserConfigurationModel).where(
UserConfigurationModel.user_id == user_id
)
)
config_obj = config_result.scalar_one_or_none()
if config_obj and config_obj.configuration:
user_config = UserConfiguration.model_validate(
config_obj.configuration
)
if user_config.timezone:
user_timezone = user_config.timezone
# Validate timezone string
try:
# Test if timezone is valid
ZoneInfo(user_timezone)
except Exception:
# Fallback to UTC if timezone is invalid
user_timezone = "UTC"
# Query to get daily aggregates
# Use AT TIME ZONE to convert to user's timezone before grouping by date
date_expr = cast(
func.timezone(user_timezone, WorkflowRunModel.created_at), Date
)
daily_usage = await session.execute(
select(
date_expr.label("date"),
func.sum(
WorkflowRunModel.cost_info["call_duration_seconds"].as_float()
).label("total_seconds"),
func.count(WorkflowRunModel.id).label("call_count"),
)
.join(WorkflowModel, WorkflowModel.id == WorkflowRunModel.workflow_id)
.join(UserModel, UserModel.id == WorkflowModel.user_id)
.where(
UserModel.selected_organization_id == organization_id,
WorkflowRunModel.created_at >= start_date,
WorkflowRunModel.created_at <= end_date,
WorkflowRunModel.is_completed == True,
)
.group_by(date_expr)
.order_by(date_expr.desc())
)
breakdown = []
total_minutes = 0
total_cost_usd = 0
total_dograh_tokens = 0
for row in daily_usage:
seconds = row.total_seconds or 0
minutes = seconds / 60
cost_usd = seconds * price_per_second_usd
dograh_tokens = cost_usd * 100 # 1 cent = 1 token
total_minutes += minutes
total_cost_usd += cost_usd
total_dograh_tokens += dograh_tokens
breakdown.append(
{
"date": row.date.isoformat(),
"minutes": round(minutes, 1),
"cost_usd": round(cost_usd, 2),
"dograh_tokens": round(dograh_tokens, 0),
"call_count": row.call_count,
}
)
return {
"breakdown": breakdown,
"total_minutes": round(total_minutes, 1),
"total_cost_usd": round(total_cost_usd, 2),
"total_dograh_tokens": round(total_dograh_tokens, 0),
"currency": "USD",
}
async def update_organization_quota(
self,
organization_id: int,
quota_type: str,
quota_dograh_tokens: int,
quota_reset_day: Optional[int] = None,
quota_start_date: Optional[datetime] = None,
) -> OrganizationModel:
"""Update organization quota settings."""
async with self.async_session() as session:
result = await session.execute(
select(OrganizationModel).where(OrganizationModel.id == organization_id)
)
org = result.scalar_one()
org.quota_type = quota_type
org.quota_dograh_tokens = quota_dograh_tokens
org.quota_enabled = True
if quota_type == "monthly" and quota_reset_day:
org.quota_reset_day = quota_reset_day
elif quota_type == "annual" and quota_start_date:
org.quota_start_date = quota_start_date
await session.commit()
await session.refresh(org)
return org
def _calculate_current_period(
self, org: OrganizationModel
) -> tuple[datetime, datetime]:
"""Calculate the current billing period based on organization settings."""
now = datetime.now(timezone.utc)
if org.quota_type == "monthly":
# Find the start of the current billing month
reset_day = org.quota_reset_day
# Handle month boundaries
if now.day >= reset_day:
period_start = now.replace(
day=reset_day, hour=0, minute=0, second=0, microsecond=0
)
else:
# Previous month
period_start = (now - relativedelta(months=1)).replace(
day=reset_day, hour=0, minute=0, second=0, microsecond=0
)
# End is one month later minus 1 second
period_end = (
period_start + relativedelta(months=1) - relativedelta(seconds=1)
)
else: # annual
if not org.quota_start_date:
# Default to calendar year
period_start = now.replace(
month=1, day=1, hour=0, minute=0, second=0, microsecond=0
)
period_end = (
period_start + relativedelta(years=1) - relativedelta(seconds=1)
)
else:
# Find current annual period
start_date = org.quota_start_date.replace(tzinfo=timezone.utc)
years_diff = now.year - start_date.year
# Adjust for whether we've passed the anniversary
if now.month < start_date.month or (
now.month == start_date.month and now.day < start_date.day
):
years_diff -= 1
period_start = start_date + relativedelta(years=years_diff)
period_end = (
period_start + relativedelta(years=1) - relativedelta(seconds=1)
)
return period_start, period_end

156
api/db/reports_client.py Normal file
View file

@ -0,0 +1,156 @@
from datetime import datetime
from typing import Any, Dict, List, Optional
from sqlalchemy import String, and_, func, select
from api.db.base_client import BaseDBClient
from api.db.models import WorkflowModel, WorkflowRunModel
class ReportsClient(BaseDBClient):
async def get_workflow_runs_for_daily_report(
self,
organization_id: int,
start_utc: datetime,
end_utc: datetime,
workflow_id: Optional[int] = None,
) -> List[Dict[str, Any]]:
"""
Optimized method for daily reports - fetches only required JSON fields.
Uses PostgreSQL JSON operators to extract only needed fields from JSON columns.
Args:
organization_id: The organization ID to filter by
start_utc: Start datetime in UTC
end_utc: End datetime in UTC
workflow_id: Optional workflow ID to filter by
Returns:
List of dictionaries with report-specific fields
"""
async with self.async_session() as session:
# Select only the specific JSON fields needed for daily reports
# Using PostgreSQL's JSON operators to extract specific fields
query = (
select(
WorkflowRunModel.id,
WorkflowRunModel.workflow_id,
WorkflowRunModel.created_at,
# Extract only specific fields from JSON columns
# Use TRIM and REPLACE to remove any quotes from JSON values
func.coalesce(
func.replace(
func.replace(
func.cast(
WorkflowRunModel.gathered_context[
"mapped_call_disposition"
],
String,
),
'"',
"",
),
"'",
"",
),
"UNKNOWN",
).label("disposition"),
func.coalesce(
func.replace(
func.replace(
func.cast(
WorkflowRunModel.gathered_context[
"customer_phone_number"
],
String,
),
'"',
"",
),
"'",
"",
),
func.replace(
func.replace(
func.cast(
WorkflowRunModel.initial_context["phone_number"],
String,
),
'"',
"",
),
"'",
"",
),
"",
).label("phone_number"),
func.coalesce(
func.replace(
func.replace(
func.cast(
WorkflowRunModel.usage_info[
"call_duration_seconds"
],
String,
),
'"',
"",
),
"'",
"",
),
"0",
).label("call_duration_seconds"),
WorkflowModel.name.label("workflow_name"),
)
.select_from(WorkflowRunModel)
.join(WorkflowModel, WorkflowRunModel.workflow_id == WorkflowModel.id)
.where(
and_(
WorkflowModel.organization_id == organization_id,
WorkflowRunModel.created_at >= start_utc,
WorkflowRunModel.created_at <= end_utc,
)
)
)
if workflow_id is not None:
query = query.where(WorkflowRunModel.workflow_id == workflow_id)
result = await session.execute(query)
rows = result.all()
return [
{
"id": row.id,
"workflow_id": row.workflow_id,
"workflow_name": row.workflow_name,
"created_at": row.created_at,
"gathered_context": {
"mapped_call_disposition": row.disposition,
"customer_phone_number": row.phone_number, # Also provide it here for compatibility
},
"usage_info": {"call_duration_seconds": row.call_duration_seconds},
"initial_context": {"phone_number": row.phone_number},
}
for row in rows
]
async def get_workflows_for_organization(
self, organization_id: int
) -> List[WorkflowModel]:
"""
Get all workflows for an organization.
Args:
organization_id: The organization ID
"""
async with self.async_session() as session:
query = (
select(WorkflowModel)
.where(WorkflowModel.organization_id == organization_id)
.order_by(WorkflowModel.name)
)
result = await session.execute(query)
return result.scalars().all()

139
api/db/user_client.py Normal file
View file

@ -0,0 +1,139 @@
from datetime import datetime, timezone
from sqlalchemy.future import select
from api.db.base_client import BaseDBClient
from api.db.models import UserConfigurationModel, UserModel
from api.schemas.user_configuration import UserConfiguration
class UserClient(BaseDBClient):
async def get_or_create_user_by_provider_id(self, provider_id: str) -> UserModel:
async with self.async_session() as session:
# First try to get existing user
result = await session.execute(
select(UserModel).where(UserModel.provider_id == provider_id)
)
user = result.scalars().first()
if user is None:
# Use PostgreSQL's INSERT ... ON CONFLICT DO NOTHING
# This is atomic and handles race conditions at the database level
from sqlalchemy.dialects.postgresql import insert
stmt = insert(UserModel.__table__).values(
provider_id=provider_id,
created_at=datetime.now(timezone.utc),
selected_organization_id=None, # Will be set later
is_superuser=False, # Default value
)
# ON CONFLICT DO NOTHING - if another request already inserted, this becomes a no-op
stmt = stmt.on_conflict_do_nothing(index_elements=["provider_id"])
result = await session.execute(stmt)
await session.commit()
# Now fetch the user (either the one we just created or the one that existed)
result = await session.execute(
select(UserModel).where(UserModel.provider_id == provider_id)
)
user = result.scalars().first()
if user is None:
# This should never happen, but handle it just in case
error_msg = (
f"Failed to create or fetch user with provider_id {provider_id}"
)
raise ValueError(error_msg)
return user
async def get_user_by_id(self, user_id: int) -> UserModel | None:
"""Fetch a user by their internal ID."""
async with self.async_session() as session:
result = await session.execute(
select(UserModel).where(UserModel.id == user_id)
)
return result.scalars().first()
async def get_user_configurations(self, user_id: int) -> UserConfiguration:
async with self.async_session() as session:
result = await session.execute(
select(UserConfigurationModel).where(
UserConfigurationModel.user_id == user_id
)
)
configuration_obj = result.scalars().first()
if not configuration_obj:
return UserConfiguration()
return UserConfiguration.model_validate(
{
**configuration_obj.configuration,
"last_validated_at": configuration_obj.last_validated_at,
}
)
async def update_user_configuration(
self, user_id: int, configuration: UserConfiguration
) -> UserConfiguration:
async with self.async_session() as session:
result = await session.execute(
select(UserConfigurationModel).where(
UserConfigurationModel.user_id == user_id
)
)
configuration_obj = result.scalars().first()
if not configuration_obj:
configuration_obj = UserConfigurationModel(
user_id=user_id, configuration=configuration.model_dump()
)
session.add(configuration_obj)
else:
configuration_obj.configuration = configuration.model_dump()
try:
await session.commit()
except Exception as e:
await session.rollback()
raise e
await session.refresh(configuration_obj)
return UserConfiguration.model_validate(configuration_obj.configuration)
async def update_user_configuration_last_validated_at(self, user_id: int) -> None:
async with self.async_session() as session:
result = await session.execute(
select(UserConfigurationModel).where(
UserConfigurationModel.user_id == user_id
)
)
configuration_obj = result.scalars().first()
if not configuration_obj:
raise ValueError(f"User configuration with ID {user_id} not found")
configuration_obj.last_validated_at = datetime.now()
try:
await session.commit()
except Exception as e:
await session.rollback()
raise e
await session.refresh(configuration_obj)
async def update_user_selected_organization(
self, user_id: int, organization_id: int
) -> None:
"""Update the user's selected organization ID."""
async with self.async_session() as session:
from sqlalchemy import update
# Use a direct UPDATE statement to avoid race conditions
# This is atomic at the database level
stmt = (
update(UserModel)
.where(UserModel.id == user_id)
.values(selected_organization_id=organization_id)
)
result = await session.execute(stmt)
if result.rowcount == 0:
raise ValueError(f"User with ID {user_id} not found")
await session.commit()

312
api/db/workflow_client.py Normal file
View file

@ -0,0 +1,312 @@
import hashlib
import json
from typing import Optional
from sqlalchemy import func
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from api.db.base_client import BaseDBClient
from api.db.models import WorkflowDefinitionModel, WorkflowModel, WorkflowRunModel
class WorkflowClient(BaseDBClient):
def _generate_workflow_hash(self, workflow_definition: dict) -> str:
"""Generate a consistent hash for workflow definition."""
# Convert to JSON with sorted keys for consistent hashing
json_str = json.dumps(
workflow_definition, sort_keys=True, separators=(",", ":")
)
return hashlib.sha256(json_str.encode()).hexdigest()
async def _get_or_create_workflow_definition(
self, workflow_definition: dict, session, workflow_id: int = None
) -> WorkflowDefinitionModel:
"""Get existing workflow definition by hash or create a new one."""
workflow_hash = self._generate_workflow_hash(workflow_definition)
# Try to find existing definition
result = await session.execute(
select(WorkflowDefinitionModel).where(
WorkflowDefinitionModel.workflow_hash == workflow_hash,
WorkflowDefinitionModel.workflow_id == workflow_id,
)
)
existing_definition = result.scalars().first()
if existing_definition:
return existing_definition
# Create new definition if it doesn't exist
new_definition = WorkflowDefinitionModel(
workflow_hash=workflow_hash,
workflow_json=workflow_definition,
workflow_id=workflow_id,
)
session.add(new_definition)
await session.flush() # Flush to get the ID without committing
return new_definition
async def create_workflow(
self,
name: str,
workflow_definition: dict,
user_id: int,
organization_id: int = None,
) -> WorkflowModel:
async with self.async_session() as session:
try:
new_workflow = WorkflowModel(
name=name,
workflow_definition=workflow_definition, # Keep for backwards compatibility
user_id=user_id,
organization_id=organization_id,
)
session.add(new_workflow)
await session.flush() # Flush to get the workflow ID
# Now get or create workflow definition with the workflow_id
definition = await self._get_or_create_workflow_definition(
workflow_definition, session, new_workflow.id
)
# Mark this definition as the current one and unset others
definition.is_current = True
# Set any other definitions for this workflow to not current
other_defs_result = await session.execute(
select(WorkflowDefinitionModel).where(
WorkflowDefinitionModel.workflow_id == new_workflow.id,
WorkflowDefinitionModel.id != definition.id,
)
)
for other_def in other_defs_result.scalars().all():
other_def.is_current = False
await session.commit()
except Exception as e:
await session.rollback()
raise e
await session.refresh(new_workflow)
return new_workflow
async def get_all_workflows(
self, user_id: int = None, organization_id: int = None, status: str = None
) -> list[WorkflowModel]:
async with self.async_session() as session:
query = select(WorkflowModel).options(
selectinload(WorkflowModel.current_definition)
)
if organization_id:
# Filter by organization_id when provided
query = query.where(WorkflowModel.organization_id == organization_id)
elif user_id:
# Fallback to user_id for backwards compatibility
query = query.where(WorkflowModel.user_id == user_id)
# Filter by status if provided
if status:
query = query.where(WorkflowModel.status == status)
result = await session.execute(query)
return result.scalars().all()
async def get_workflow(
self, workflow_id: int, user_id: int = None, organization_id: int = None
) -> WorkflowModel | None:
async with self.async_session() as session:
query = (
select(WorkflowModel)
.options(selectinload(WorkflowModel.current_definition))
.where(WorkflowModel.id == workflow_id)
)
if organization_id:
# Filter by organization_id when provided
query = query.where(WorkflowModel.organization_id == organization_id)
elif user_id:
# Fallback to user_id for backwards compatibility
query = query.where(WorkflowModel.user_id == user_id)
result = await session.execute(query)
return result.scalars().first()
async def get_workflow_by_id(self, workflow_id: int) -> WorkflowModel | None:
async with self.async_session() as session:
result = await session.execute(
select(WorkflowModel)
.options(selectinload(WorkflowModel.current_definition))
.where(WorkflowModel.id == workflow_id)
)
return result.scalars().first()
async def update_workflow(
self,
workflow_id: int,
name: str,
workflow_definition: dict | None,
template_context_variables: dict | None,
workflow_configurations: dict | None,
user_id: int = None,
organization_id: int = None,
) -> WorkflowModel:
"""
Update an existing workflow in the database.
Args:
workflow_id: The ID of the workflow to update
name: The new name for the workflow
workflow_definition: The new workflow definition
template_context_variables: The template context variables
user_id: The user ID (for backwards compatibility)
organization_id: The organization ID
Returns:
The updated WorkflowModel
Raises:
ValueError: If the workflow with the given ID is not found
"""
async with self.async_session() as session:
query = (
select(WorkflowModel)
.options(selectinload(WorkflowModel.current_definition))
.where(WorkflowModel.id == workflow_id)
)
if organization_id:
# Filter by organization_id when provided
query = query.where(WorkflowModel.organization_id == organization_id)
elif user_id:
# Fallback to user_id for backwards compatibility
query = query.where(WorkflowModel.user_id == user_id)
result = await session.execute(query)
workflow = result.scalars().first()
if not workflow:
raise ValueError(f"Workflow with ID {workflow_id} not found")
workflow.name = name
if template_context_variables is not None:
workflow.template_context_variables = template_context_variables
if workflow_configurations is not None:
workflow.workflow_configurations = workflow_configurations
# In case of only name update, the workflow_definition can be None
if workflow_definition:
# Get or create new workflow definition
definition = await self._get_or_create_workflow_definition(
workflow_definition, session, workflow_id
)
# Update legacy field for backwards compatibility
workflow.workflow_definition = workflow_definition
# Mark new definition as current and reset others
definition.is_current = True
other_defs_result = await session.execute(
select(WorkflowDefinitionModel).where(
WorkflowDefinitionModel.workflow_id == workflow_id,
WorkflowDefinitionModel.id != definition.id,
)
)
for other_def in other_defs_result.scalars().all():
other_def.is_current = False
try:
await session.commit()
except Exception as e:
await session.rollback()
raise e
await session.refresh(workflow)
return workflow
async def get_workflows_by_ids(
self, workflow_ids: list[int], organization_id: int
) -> list[WorkflowModel]:
"""Get workflows by IDs for a specific organization"""
async with self.async_session() as session:
result = await session.execute(
select(WorkflowModel)
.join(WorkflowModel.user)
.where(
WorkflowModel.id.in_(workflow_ids),
WorkflowModel.user.has(selected_organization_id=organization_id),
)
)
return result.scalars().all()
async def get_workflow_name(
self, workflow_id: int, user_id: int = None, organization_id: int = None
) -> Optional[str]:
"""Get just the workflow name by ID"""
async with self.async_session() as session:
query = select(WorkflowModel.name).where(WorkflowModel.id == workflow_id)
if organization_id:
# Filter by organization_id when provided
query = query.where(WorkflowModel.organization_id == organization_id)
elif user_id:
# Fallback to user_id for backwards compatibility
query = query.where(WorkflowModel.user_id == user_id)
result = await session.execute(query)
return result.scalar_one_or_none()
async def update_workflow_status(
self,
workflow_id: int,
status: str,
organization_id: int = None,
) -> WorkflowModel:
"""
Update the status of a workflow.
Args:
workflow_id: The ID of the workflow to update
status: The new status (active/archived)
organization_id: The organization ID
Returns:
The updated WorkflowModel
Raises:
ValueError: If the workflow is not found
"""
async with self.async_session() as session:
query = (
select(WorkflowModel)
.options(selectinload(WorkflowModel.current_definition))
.where(WorkflowModel.id == workflow_id)
)
if organization_id:
query = query.where(WorkflowModel.organization_id == organization_id)
result = await session.execute(query)
workflow = result.scalars().first()
if not workflow:
raise ValueError(f"Workflow with ID {workflow_id} not found")
workflow.status = status
try:
await session.commit()
except Exception as e:
await session.rollback()
raise e
await session.refresh(workflow)
return workflow
async def get_workflow_run_count(self, workflow_id: int) -> int:
"""Get the count of runs for a workflow."""
async with self.async_session() as session:
result = await session.execute(
select(func.count(WorkflowRunModel.id)).where(
WorkflowRunModel.workflow_id == workflow_id
)
)
return result.scalar() or 0

View file

@ -0,0 +1,404 @@
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Tuple
from sqlalchemy import func
from sqlalchemy.future import select
from sqlalchemy.orm import joinedload, selectinload
from api.db.base_client import BaseDBClient
from api.db.filters import apply_workflow_run_filters
from api.db.models import (
OrganizationModel,
UserModel,
WorkflowDefinitionModel,
WorkflowModel,
WorkflowRunModel,
)
from api.schemas.workflow import WorkflowRunResponseSchema
class WorkflowRunClient(BaseDBClient):
async def create_workflow_run(
self,
name: str,
workflow_id: int,
mode: str,
user_id: int,
initial_context: dict = None,
campaign_id: int = None,
queued_run_id: int = None,
) -> WorkflowRunModel:
async with self.async_session() as session:
# Get workflow and user to check organization
workflow = await session.execute(
select(WorkflowModel)
.options(joinedload(WorkflowModel.user))
.where(
WorkflowModel.id == workflow_id, WorkflowModel.user_id == user_id
)
)
workflow = workflow.scalars().first()
if not workflow:
raise ValueError(f"Workflow with ID {workflow_id} not found")
# # Check quota if user has an organization
# if workflow.user and workflow.user.selected_organization_id:
# # Import here to avoid circular dependency
# from api.db.organization_usage_client import OrganizationUsageClient
# usage_client = OrganizationUsageClient()
# # Check quota (no reservation for now, actual cost will be added after completion)
# has_quota = await usage_client.check_and_reserve_quota(
# workflow.user.selected_organization_id, estimated_tokens=0
# )
# if not has_quota:
# raise ValueError(
# "Organization quota exceeded. Please contact your administrator."
# )
# Fetch the current definition for this workflow
current_def_result = await session.execute(
select(WorkflowDefinitionModel).where(
WorkflowDefinitionModel.workflow_id == workflow.id,
WorkflowDefinitionModel.is_current == True,
)
)
current_def = current_def_result.scalars().first()
new_run = WorkflowRunModel(
name=name,
workflow=workflow,
mode=mode,
definition_id=current_def.id if current_def else None,
initial_context=initial_context or workflow.template_context_variables,
campaign_id=campaign_id,
queued_run_id=queued_run_id,
)
session.add(new_run)
try:
await session.commit()
except Exception as e:
await session.rollback()
raise e
await session.refresh(new_run)
return new_run
async def get_all_workflow_runs(self) -> list[WorkflowRunModel]:
async with self.async_session() as session:
result = await session.execute(select(WorkflowRunModel))
return result.scalars().all()
async def get_workflow_runs_for_superadmin(
self,
limit: int = 50,
offset: int = 0,
filters: Optional[List[Dict[str, Any]]] = None,
) -> tuple[list[dict], int]:
"""
Get paginated workflow runs for superadmin with organization information.
Returns tuple of (workflow_runs, total_count).
"""
async with self.async_session() as session:
# Build base query with joins
base_query = (
select(WorkflowRunModel)
.join(WorkflowModel, WorkflowRunModel.workflow_id == WorkflowModel.id)
.join(UserModel, WorkflowModel.user_id == UserModel.id)
.outerjoin(
OrganizationModel,
UserModel.selected_organization_id == OrganizationModel.id,
)
)
# Apply filters
base_query = apply_workflow_run_filters(base_query, filters)
# Count total with filters
count_query = base_query.with_only_columns(func.count(WorkflowRunModel.id))
count_result = await session.execute(count_query)
total_count = count_result.scalar()
# Get paginated results with filters
result = await session.execute(
base_query.options(
joinedload(WorkflowRunModel.workflow).joinedload(
WorkflowModel.user
),
joinedload(WorkflowRunModel.workflow)
.joinedload(WorkflowModel.user)
.joinedload(UserModel.selected_organization),
)
.order_by(WorkflowRunModel.created_at.desc())
.limit(limit)
.offset(offset)
)
workflow_runs = result.scalars().all()
# Format the response
formatted_runs = []
for run in workflow_runs:
organization = (
run.workflow.user.selected_organization
if run.workflow.user
else None
)
formatted_runs.append(
{
"id": run.id,
"name": run.name,
"workflow_id": run.workflow_id,
"workflow_name": run.workflow.name if run.workflow else None,
"user_id": run.workflow.user_id if run.workflow else None,
"organization_id": organization.id if organization else None,
"organization_name": organization.provider_id
if organization
else None,
"mode": run.mode,
"is_completed": run.is_completed,
"recording_url": run.recording_url,
"transcript_url": run.transcript_url,
"usage_info": run.usage_info,
"cost_info": run.cost_info,
"initial_context": run.initial_context,
"gathered_context": run.gathered_context,
"admin_comment": (run.annotations or {}).get("admin_comment"),
"admin_comment_ts": (run.annotations or {}).get(
"admin_comment_ts"
),
"created_at": run.created_at,
}
)
return formatted_runs, total_count
async def get_workflow_run(
self, run_id: int, user_id: int = None, organization_id: int = None
) -> WorkflowRunModel | None:
async with self.async_session() as session:
query = select(WorkflowRunModel).join(WorkflowRunModel.workflow)
if organization_id:
# Filter by organization_id when provided
query = query.where(
WorkflowRunModel.id == run_id,
WorkflowModel.organization_id == organization_id,
)
elif user_id:
# Fallback to user_id for backwards compatibility
query = query.where(
WorkflowRunModel.id == run_id,
WorkflowModel.user_id == user_id,
)
else:
query = query.where(WorkflowRunModel.id == run_id)
result = await session.execute(query)
return result.scalars().first()
async def get_workflow_run_by_id(self, run_id: int) -> WorkflowRunModel | None:
"""Get workflow run by ID without user filtering - for background tasks"""
async with self.async_session() as session:
result = await session.execute(
select(WorkflowRunModel)
.options(
joinedload(WorkflowRunModel.workflow).joinedload(WorkflowModel.user)
)
.where(WorkflowRunModel.id == run_id)
)
return result.scalars().first()
async def get_workflow_runs_by_workflow_id(
self,
workflow_id: int,
user_id: int = None,
organization_id: int = None,
limit: int = 50,
offset: int = 0,
filters: Optional[List[Dict[str, Any]]] = None,
) -> tuple[list[WorkflowRunResponseSchema], int]:
async with self.async_session() as session:
# Build base query
base_query = (
select(WorkflowRunModel)
.join(WorkflowModel, WorkflowRunModel.workflow_id == WorkflowModel.id)
.where(WorkflowRunModel.workflow_id == workflow_id)
)
if organization_id:
# Filter by organization_id when provided
base_query = base_query.where(
WorkflowModel.organization_id == organization_id
)
elif user_id:
# Fallback to user_id for backwards compatibility
base_query = base_query.where(WorkflowModel.user_id == user_id)
# Apply filters
base_query = apply_workflow_run_filters(base_query, filters)
# Count total with filters
count_query = base_query.with_only_columns(func.count(WorkflowRunModel.id))
count_result = await session.execute(count_query)
total_count = count_result.scalar()
# Get paginated results with filters
result = await session.execute(
base_query.order_by(WorkflowRunModel.created_at.desc())
.limit(limit)
.offset(offset)
)
runs = [
WorkflowRunResponseSchema.model_validate(
{
"id": run.id,
"workflow_id": run.workflow_id,
"name": run.name,
"mode": run.mode,
"created_at": run.created_at,
"is_completed": run.is_completed,
"recording_url": run.recording_url,
"transcript_url": run.transcript_url,
"cost_info": {
"dograh_token_usage": (
run.cost_info.get("dograh_token_usage")
if run.cost_info
and "dograh_token_usage" in run.cost_info
else round(
float(run.cost_info.get("total_cost_usd", 0)) * 100,
2,
)
if run.cost_info and "total_cost_usd" in run.cost_info
else 0
),
"call_duration_seconds": int(
round(run.cost_info.get("call_duration_seconds") or 0)
)
if run.cost_info
else None,
}
if run.cost_info
else None,
"definition_id": run.definition_id,
"initial_context": run.initial_context,
"gathered_context": run.gathered_context,
}
)
for run in result.scalars().all()
]
return runs, total_count
async def update_workflow_run(
self,
run_id: int,
is_completed: bool = False,
recording_url: str | None = None,
transcript_url: str | None = None,
storage_backend: str | None = None,
usage_info: dict | None = None,
cost_info: dict | None = None,
initial_context: dict | None = None,
gathered_context: dict | None = None,
logs: dict | None = None,
) -> WorkflowRunModel:
async with self.async_session() as session:
result = await session.execute(
select(WorkflowRunModel).where(WorkflowRunModel.id == run_id)
)
run = result.scalars().first()
if not run:
raise ValueError(f"Workflow run with ID {run_id} not found")
if recording_url:
run.recording_url = recording_url
if transcript_url:
run.transcript_url = transcript_url
if storage_backend:
run.storage_backend = storage_backend
if usage_info:
run.usage_info = usage_info
if cost_info:
run.cost_info = cost_info
if initial_context:
run.initial_context = initial_context
if gathered_context:
# Lets merge the incoming gathered context keys with the existing ones
run.gathered_context = {
**run.gathered_context,
**gathered_context,
}
if logs:
# Lets merge the incoming logs key with existing ones
run.logs = {**run.logs, **logs}
if is_completed:
run.is_completed = is_completed
try:
await session.commit()
except Exception as e:
await session.rollback()
raise e
await session.refresh(run)
return run
async def update_admin_comment(
self, run_id: int, admin_comment: str
) -> WorkflowRunModel:
"""Update (or create) the admin comment inside the ``annotations`` JSON column.
The comment is stored under the key ``admin_comment`` so we do not
overwrite any other existing annotations that may be present.
"""
async with self.async_session() as session:
result = await session.execute(
select(WorkflowRunModel).where(WorkflowRunModel.id == run_id)
)
run = result.scalars().first()
if run is None:
raise ValueError(f"Workflow run with ID {run_id} not found")
# Ensure we never mutate a shared dict between instances
current_annotations = dict(run.annotations or {})
current_annotations["admin_comment"] = admin_comment
current_annotations["admin_comment_ts"] = datetime.now(
timezone.utc
).isoformat()
run.annotations = current_annotations
try:
await session.commit()
except Exception as e:
await session.rollback()
raise e
await session.refresh(run)
return run
async def get_workflow_run_with_context(
self, workflow_run_id: int
) -> Tuple[Optional[WorkflowRunModel], Optional[int]]:
"""
Get workflow run with all related data and return organization_id.
Returns:
Tuple of (workflow_run, organization_id) or (None, None) if not found
"""
async with self.async_session() as session:
result = await session.execute(
select(WorkflowRunModel)
.options(
selectinload(WorkflowRunModel.workflow).selectinload(
WorkflowModel.user
)
)
.where(WorkflowRunModel.id == workflow_run_id)
)
workflow_run = result.scalars().first()
if not workflow_run:
return None, None
if not workflow_run.workflow or not workflow_run.workflow.user:
return workflow_run, None
organization_id = workflow_run.workflow.user.selected_organization_id
return workflow_run, organization_id

View file

@ -0,0 +1,99 @@
from sqlalchemy.future import select
from api.db.base_client import BaseDBClient
from api.db.models import WorkflowTemplates
class WorkflowTemplateClient(BaseDBClient):
async def get_workflow_template(self, template_id: int) -> WorkflowTemplates | None:
"""Get a workflow template by ID."""
async with self.async_session() as session:
result = await session.execute(
select(WorkflowTemplates).where(WorkflowTemplates.id == template_id)
)
return result.scalars().first()
async def get_workflow_template_by_name(
self, template_name: str
) -> WorkflowTemplates | None:
"""Get a workflow template by name."""
async with self.async_session() as session:
result = await session.execute(
select(WorkflowTemplates).where(
WorkflowTemplates.template_name == template_name
)
)
return result.scalars().first()
async def get_all_workflow_templates(self) -> list[WorkflowTemplates]:
"""Get all workflow templates."""
async with self.async_session() as session:
result = await session.execute(select(WorkflowTemplates))
return result.scalars().all()
async def create_workflow_template(
self, template_name: str, template_description: str, template_json: dict
) -> WorkflowTemplates:
"""Create a new workflow template."""
async with self.async_session() as session:
try:
new_template = WorkflowTemplates(
template_name=template_name,
template_description=template_description,
template_json=template_json,
)
session.add(new_template)
await session.commit()
await session.refresh(new_template)
return new_template
except Exception as e:
await session.rollback()
raise e
async def update_workflow_template(
self,
template_id: int,
template_name: str | None = None,
template_json: dict | None = None,
) -> WorkflowTemplates:
"""Update an existing workflow template."""
async with self.async_session() as session:
try:
result = await session.execute(
select(WorkflowTemplates).where(WorkflowTemplates.id == template_id)
)
template = result.scalars().first()
if not template:
raise ValueError(
f"Workflow template with ID {template_id} not found"
)
if template_name is not None:
template.template_name = template_name
if template_json is not None:
template.template_json = template_json
await session.commit()
await session.refresh(template)
return template
except Exception as e:
await session.rollback()
raise e
async def delete_workflow_template(self, template_id: int) -> bool:
"""Delete a workflow template by ID."""
async with self.async_session() as session:
try:
result = await session.execute(
select(WorkflowTemplates).where(WorkflowTemplates.id == template_id)
)
template = result.scalars().first()
if not template:
return False
await session.delete(template)
await session.commit()
return True
except Exception as e:
await session.rollback()
raise e