mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-22 08:38:13 +02:00
Initial Commit 🚀 🚀
This commit is contained in:
commit
4f2a629340
444 changed files with 76863 additions and 0 deletions
3
api/db/__init__.py
Normal file
3
api/db/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
from api.db.db_client import DBClient
|
||||
|
||||
db_client = DBClient()
|
||||
108
api/db/api_key_client.py
Normal file
108
api/db/api_key_client.py
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from api.db.base_client import BaseDBClient
|
||||
from api.db.models import APIKeyModel
|
||||
from api.utils.api_key import generate_api_key, hash_api_key
|
||||
|
||||
|
||||
class APIKeyClient(BaseDBClient):
|
||||
async def create_api_key(
|
||||
self, organization_id: int, name: str, created_by: Optional[int] = None
|
||||
) -> tuple[APIKeyModel, str]:
|
||||
"""Create a new API key for an organization.
|
||||
|
||||
Returns:
|
||||
Tuple of (APIKeyModel, raw_api_key)
|
||||
"""
|
||||
# Generate a secure random API key
|
||||
raw_api_key, key_hash, key_prefix = generate_api_key()
|
||||
|
||||
async with self.async_session() as session:
|
||||
api_key = APIKeyModel(
|
||||
organization_id=organization_id,
|
||||
name=name,
|
||||
key_hash=key_hash,
|
||||
key_prefix=key_prefix,
|
||||
created_by=created_by,
|
||||
is_active=True,
|
||||
)
|
||||
session.add(api_key)
|
||||
await session.commit()
|
||||
await session.refresh(api_key)
|
||||
|
||||
return api_key, raw_api_key
|
||||
|
||||
async def get_api_keys_by_organization(
|
||||
self, organization_id: int, include_archived: bool = False
|
||||
) -> List[APIKeyModel]:
|
||||
"""Get all API keys for an organization."""
|
||||
async with self.async_session() as session:
|
||||
query = select(APIKeyModel).where(
|
||||
APIKeyModel.organization_id == organization_id
|
||||
)
|
||||
|
||||
if not include_archived:
|
||||
query = query.where(APIKeyModel.archived_at.is_(None))
|
||||
|
||||
result = await session.execute(query)
|
||||
return result.scalars().all()
|
||||
|
||||
async def get_api_key_by_hash(self, key_hash: str) -> Optional[APIKeyModel]:
|
||||
"""Get an API key by its hash."""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(APIKeyModel).where(
|
||||
and_(
|
||||
APIKeyModel.key_hash == key_hash,
|
||||
APIKeyModel.is_active == True,
|
||||
APIKeyModel.archived_at.is_(None),
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
async def validate_api_key(self, raw_api_key: str) -> Optional[APIKeyModel]:
|
||||
"""Validate an API key and return the associated model if valid."""
|
||||
key_hash = hash_api_key(raw_api_key)
|
||||
api_key = await self.get_api_key_by_hash(key_hash)
|
||||
|
||||
if api_key:
|
||||
# Update last_used_at
|
||||
from datetime import datetime, timezone
|
||||
|
||||
async with self.async_session() as session:
|
||||
await session.execute(
|
||||
APIKeyModel.__table__.update()
|
||||
.where(APIKeyModel.id == api_key.id)
|
||||
.values(last_used_at=datetime.now(timezone.utc))
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
return api_key
|
||||
|
||||
async def archive_api_key(self, api_key_id: int) -> bool:
|
||||
"""Archive an API key (soft delete)."""
|
||||
from datetime import datetime, timezone
|
||||
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
APIKeyModel.__table__.update()
|
||||
.where(APIKeyModel.id == api_key_id)
|
||||
.values(is_active=False, archived_at=datetime.now(timezone.utc))
|
||||
)
|
||||
await session.commit()
|
||||
return result.rowcount > 0
|
||||
|
||||
async def reactivate_api_key(self, api_key_id: int) -> bool:
|
||||
"""Reactivate an archived API key."""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
APIKeyModel.__table__.update()
|
||||
.where(APIKeyModel.id == api_key_id)
|
||||
.values(is_active=True, archived_at=None)
|
||||
)
|
||||
await session.commit()
|
||||
return result.rowcount > 0
|
||||
34
api/db/base_client.py
Normal file
34
api/db/base_client.py
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
from typing import Any, Dict, List
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
|
||||
from api.constants import DATABASE_URL
|
||||
|
||||
|
||||
class BaseDBClient:
|
||||
def __init__(self):
|
||||
self.engine = create_async_engine(DATABASE_URL)
|
||||
self.async_session = async_sessionmaker(bind=self.engine)
|
||||
|
||||
async def execute_raw_query(
|
||||
self, query: str, params: Dict[str, Any] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Execute a raw SQL query and return results as a list of dictionaries.
|
||||
|
||||
Args:
|
||||
query: The SQL query to execute
|
||||
params: Optional dictionary of query parameters
|
||||
|
||||
Returns:
|
||||
List of dictionaries containing the query results
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(text(query), params or {})
|
||||
rows = result.fetchall()
|
||||
if rows:
|
||||
# Convert rows to dictionaries
|
||||
columns = result.keys()
|
||||
return [dict(zip(columns, row)) for row in rows]
|
||||
return []
|
||||
379
api/db/campaign_client.py
Normal file
379
api/db/campaign_client.py
Normal file
|
|
@ -0,0 +1,379 @@
|
|||
from datetime import UTC, datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from api.db.base_client import BaseDBClient
|
||||
from api.db.models import CampaignModel, QueuedRunModel, WorkflowRunModel
|
||||
|
||||
|
||||
class CampaignClient(BaseDBClient):
|
||||
async def create_campaign(
|
||||
self,
|
||||
name: str,
|
||||
workflow_id: int,
|
||||
source_type: str,
|
||||
source_id: str,
|
||||
user_id: int,
|
||||
organization_id: int,
|
||||
) -> CampaignModel:
|
||||
"""Create a new campaign"""
|
||||
async with self.async_session() as session:
|
||||
campaign = CampaignModel(
|
||||
name=name,
|
||||
workflow_id=workflow_id,
|
||||
source_type=source_type,
|
||||
source_id=source_id,
|
||||
created_by=user_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
session.add(campaign)
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
await session.refresh(campaign)
|
||||
return campaign
|
||||
|
||||
async def get_campaigns(
|
||||
self,
|
||||
organization_id: int,
|
||||
) -> list[CampaignModel]:
|
||||
"""Get all campaigns for organization"""
|
||||
async with self.async_session() as session:
|
||||
query = (
|
||||
select(CampaignModel)
|
||||
.where(CampaignModel.organization_id == organization_id)
|
||||
.order_by(CampaignModel.created_at.desc())
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_campaign(
|
||||
self,
|
||||
campaign_id: int,
|
||||
organization_id: int,
|
||||
) -> Optional[CampaignModel]:
|
||||
"""Get single campaign by ID, ensuring organization access"""
|
||||
async with self.async_session() as session:
|
||||
query = select(CampaignModel).where(
|
||||
CampaignModel.id == campaign_id,
|
||||
CampaignModel.organization_id == organization_id,
|
||||
)
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def update_campaign_state(
|
||||
self,
|
||||
campaign_id: int,
|
||||
state: str,
|
||||
organization_id: int,
|
||||
) -> CampaignModel:
|
||||
"""Update campaign state (start/pause/resume)"""
|
||||
async with self.async_session() as session:
|
||||
query = select(CampaignModel).where(
|
||||
CampaignModel.id == campaign_id,
|
||||
CampaignModel.organization_id == organization_id,
|
||||
)
|
||||
result = await session.execute(query)
|
||||
campaign = result.scalar_one_or_none()
|
||||
|
||||
if not campaign:
|
||||
raise ValueError(f"Campaign {campaign_id} not found")
|
||||
|
||||
campaign.state = state
|
||||
if state == "running" and not campaign.started_at:
|
||||
campaign.started_at = datetime.now(UTC)
|
||||
elif state in ["completed", "failed"]:
|
||||
campaign.completed_at = datetime.now(UTC)
|
||||
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
await session.refresh(campaign)
|
||||
return campaign
|
||||
|
||||
async def update_campaign_progress(
|
||||
self,
|
||||
campaign_id: int,
|
||||
processed_rows: int,
|
||||
failed_rows: int,
|
||||
organization_id: int,
|
||||
) -> None:
|
||||
"""Update campaign progress counters"""
|
||||
async with self.async_session() as session:
|
||||
query = select(CampaignModel).where(
|
||||
CampaignModel.id == campaign_id,
|
||||
CampaignModel.organization_id == organization_id,
|
||||
)
|
||||
result = await session.execute(query)
|
||||
campaign = result.scalar_one_or_none()
|
||||
|
||||
if not campaign:
|
||||
raise ValueError(f"Campaign {campaign_id} not found")
|
||||
|
||||
campaign.processed_rows = processed_rows
|
||||
campaign.failed_rows = failed_rows
|
||||
campaign.updated_at = datetime.now(UTC)
|
||||
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
|
||||
async def get_campaign_runs(
|
||||
self,
|
||||
campaign_id: int,
|
||||
organization_id: int,
|
||||
) -> list[WorkflowRunModel]:
|
||||
"""Get workflow runs for a campaign"""
|
||||
async with self.async_session() as session:
|
||||
# First verify campaign belongs to organization
|
||||
campaign_query = select(CampaignModel).where(
|
||||
CampaignModel.id == campaign_id,
|
||||
CampaignModel.organization_id == organization_id,
|
||||
)
|
||||
campaign_result = await session.execute(campaign_query)
|
||||
campaign = campaign_result.scalar_one_or_none()
|
||||
|
||||
if not campaign:
|
||||
raise ValueError(f"Campaign {campaign_id} not found")
|
||||
|
||||
query = (
|
||||
select(WorkflowRunModel)
|
||||
.where(WorkflowRunModel.campaign_id == campaign_id)
|
||||
.order_by(WorkflowRunModel.created_at.desc())
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_campaign_by_id(self, campaign_id: int) -> Optional[CampaignModel]:
|
||||
"""Get campaign by ID without organization check (for internal use)"""
|
||||
async with self.async_session() as session:
|
||||
query = select(CampaignModel).where(CampaignModel.id == campaign_id)
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def update_campaign(self, campaign_id: int, **kwargs) -> CampaignModel:
|
||||
"""Update campaign with arbitrary fields"""
|
||||
async with self.async_session() as session:
|
||||
query = select(CampaignModel).where(CampaignModel.id == campaign_id)
|
||||
result = await session.execute(query)
|
||||
campaign = result.scalar_one_or_none()
|
||||
|
||||
if not campaign:
|
||||
raise ValueError(f"Campaign {campaign_id} not found")
|
||||
|
||||
# Update fields
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(campaign, key):
|
||||
setattr(campaign, key, value)
|
||||
|
||||
campaign.updated_at = datetime.now(UTC)
|
||||
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
await session.refresh(campaign)
|
||||
return campaign
|
||||
|
||||
# QueuedRun methods
|
||||
async def bulk_create_queued_runs(self, queued_runs_data: list[dict]) -> None:
|
||||
"""Bulk create queued runs"""
|
||||
async with self.async_session() as session:
|
||||
queued_runs = [QueuedRunModel(**data) for data in queued_runs_data]
|
||||
session.add_all(queued_runs)
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
|
||||
async def get_queued_runs(
|
||||
self,
|
||||
campaign_id: int,
|
||||
state: str = "queued",
|
||||
limit: int = 10,
|
||||
scheduled_for: Optional[bool] = None,
|
||||
) -> list[QueuedRunModel]:
|
||||
"""Get queued runs for processing, optionally filtering by scheduled status"""
|
||||
async with self.async_session() as session:
|
||||
query = select(QueuedRunModel).where(
|
||||
QueuedRunModel.campaign_id == campaign_id,
|
||||
QueuedRunModel.state == state,
|
||||
)
|
||||
|
||||
# Filter by scheduled status if specified
|
||||
if scheduled_for is True:
|
||||
query = query.where(QueuedRunModel.scheduled_for.isnot(None))
|
||||
elif scheduled_for is False:
|
||||
query = query.where(QueuedRunModel.scheduled_for.is_(None))
|
||||
|
||||
query = query.order_by(QueuedRunModel.created_at).limit(limit)
|
||||
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def update_queued_run(self, queued_run_id: int, **kwargs) -> QueuedRunModel:
|
||||
"""Update queued run"""
|
||||
async with self.async_session() as session:
|
||||
query = select(QueuedRunModel).where(QueuedRunModel.id == queued_run_id)
|
||||
result = await session.execute(query)
|
||||
queued_run = result.scalar_one_or_none()
|
||||
|
||||
if not queued_run:
|
||||
raise ValueError(f"QueuedRun {queued_run_id} not found")
|
||||
|
||||
# Update fields
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(queued_run, key):
|
||||
setattr(queued_run, key, value)
|
||||
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
await session.refresh(queued_run)
|
||||
return queued_run
|
||||
|
||||
async def count_queued_runs(
|
||||
self, campaign_id: int, state: Optional[str] = None
|
||||
) -> int:
|
||||
"""Count queued runs, optionally filtered by state"""
|
||||
async with self.async_session() as session:
|
||||
query = select(func.count(QueuedRunModel.id)).where(
|
||||
QueuedRunModel.campaign_id == campaign_id
|
||||
)
|
||||
if state:
|
||||
query = query.where(QueuedRunModel.state == state)
|
||||
|
||||
result = await session.execute(query)
|
||||
return result.scalar() or 0
|
||||
|
||||
async def get_workflow_runs_by_campaign(
|
||||
self, campaign_id: int
|
||||
) -> list[WorkflowRunModel]:
|
||||
"""Get all workflow runs for a campaign (internal use)"""
|
||||
async with self.async_session() as session:
|
||||
query = (
|
||||
select(WorkflowRunModel)
|
||||
.where(WorkflowRunModel.campaign_id == campaign_id)
|
||||
.order_by(WorkflowRunModel.created_at)
|
||||
)
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
# New methods for retry support
|
||||
async def get_scheduled_queued_runs(
|
||||
self, campaign_id: int, scheduled_before: datetime, limit: int = 10
|
||||
) -> list[QueuedRunModel]:
|
||||
"""Get scheduled queued runs that are due for processing"""
|
||||
async with self.async_session() as session:
|
||||
query = (
|
||||
select(QueuedRunModel)
|
||||
.where(
|
||||
QueuedRunModel.campaign_id == campaign_id,
|
||||
QueuedRunModel.state == "queued",
|
||||
QueuedRunModel.scheduled_for.isnot(None),
|
||||
QueuedRunModel.scheduled_for <= scheduled_before,
|
||||
)
|
||||
.order_by(QueuedRunModel.scheduled_for)
|
||||
.limit(limit)
|
||||
)
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def create_queued_run(
|
||||
self,
|
||||
campaign_id: int,
|
||||
source_uuid: str,
|
||||
context_variables: dict,
|
||||
state: str = "queued",
|
||||
retry_count: int = 0,
|
||||
parent_queued_run_id: Optional[int] = None,
|
||||
scheduled_for: Optional[datetime] = None,
|
||||
retry_reason: Optional[str] = None,
|
||||
) -> QueuedRunModel:
|
||||
"""Create a single queued run with retry support"""
|
||||
async with self.async_session() as session:
|
||||
queued_run = QueuedRunModel(
|
||||
campaign_id=campaign_id,
|
||||
source_uuid=source_uuid,
|
||||
context_variables=context_variables,
|
||||
state=state,
|
||||
retry_count=retry_count,
|
||||
parent_queued_run_id=parent_queued_run_id,
|
||||
scheduled_for=scheduled_for,
|
||||
retry_reason=retry_reason,
|
||||
)
|
||||
session.add(queued_run)
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
await session.refresh(queued_run)
|
||||
return queued_run
|
||||
|
||||
async def get_queued_run_by_id(
|
||||
self, queued_run_id: int
|
||||
) -> Optional[QueuedRunModel]:
|
||||
"""Get a queued run by ID"""
|
||||
async with self.async_session() as session:
|
||||
query = select(QueuedRunModel).where(QueuedRunModel.id == queued_run_id)
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_campaigns_by_status(self, statuses: list[str]) -> list[CampaignModel]:
|
||||
"""Get campaigns by status"""
|
||||
async with self.async_session() as session:
|
||||
query = (
|
||||
select(CampaignModel)
|
||||
.where(CampaignModel.state.in_(statuses))
|
||||
.order_by(CampaignModel.created_at.desc())
|
||||
)
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_queued_runs_count(self, campaign_id: int, states: list[str]) -> int:
|
||||
"""Get count of queued runs for a campaign in specified states"""
|
||||
async with self.async_session() as session:
|
||||
query = select(func.count(QueuedRunModel.id)).where(
|
||||
QueuedRunModel.campaign_id == campaign_id,
|
||||
QueuedRunModel.state.in_(states),
|
||||
)
|
||||
result = await session.execute(query)
|
||||
return result.scalar() or 0
|
||||
|
||||
async def get_scheduled_runs_count(
|
||||
self,
|
||||
campaign_id: int,
|
||||
scheduled_before: Optional[datetime] = None,
|
||||
scheduled_after: Optional[datetime] = None,
|
||||
) -> int:
|
||||
"""Get count of scheduled runs for a campaign"""
|
||||
async with self.async_session() as session:
|
||||
conditions = [
|
||||
QueuedRunModel.campaign_id == campaign_id,
|
||||
QueuedRunModel.scheduled_for.isnot(None),
|
||||
QueuedRunModel.state == "queued",
|
||||
]
|
||||
|
||||
if scheduled_before:
|
||||
conditions.append(QueuedRunModel.scheduled_for <= scheduled_before)
|
||||
if scheduled_after:
|
||||
conditions.append(QueuedRunModel.scheduled_for > scheduled_after)
|
||||
|
||||
query = select(func.count(QueuedRunModel.id)).where(*conditions)
|
||||
result = await session.execute(query)
|
||||
return result.scalar() or 0
|
||||
9
api/db/database.py
Normal file
9
api/db/database.py
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
import os
|
||||
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
DATABASE_URL = os.environ["DATABASE_URL"]
|
||||
|
||||
engine = create_async_engine(DATABASE_URL, echo=True)
|
||||
async_session = sessionmaker(engine)
|
||||
47
api/db/db_client.py
Normal file
47
api/db/db_client.py
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
from api.db.api_key_client import APIKeyClient
|
||||
from api.db.campaign_client import CampaignClient
|
||||
from api.db.integration_client import IntegrationClient
|
||||
from api.db.looptalk_client import LoopTalkClient
|
||||
from api.db.organization_client import OrganizationClient
|
||||
from api.db.organization_configuration_client import OrganizationConfigurationClient
|
||||
from api.db.organization_usage_client import OrganizationUsageClient
|
||||
from api.db.reports_client import ReportsClient
|
||||
from api.db.user_client import UserClient
|
||||
from api.db.workflow_client import WorkflowClient
|
||||
from api.db.workflow_run_client import WorkflowRunClient
|
||||
from api.db.workflow_template_client import WorkflowTemplateClient
|
||||
|
||||
|
||||
class DBClient(
|
||||
WorkflowClient,
|
||||
WorkflowRunClient,
|
||||
UserClient,
|
||||
OrganizationClient,
|
||||
OrganizationConfigurationClient,
|
||||
OrganizationUsageClient,
|
||||
IntegrationClient,
|
||||
WorkflowTemplateClient,
|
||||
LoopTalkClient,
|
||||
CampaignClient,
|
||||
ReportsClient,
|
||||
APIKeyClient,
|
||||
):
|
||||
"""
|
||||
Unified database client that combines all specialized database operations.
|
||||
|
||||
This client inherits from:
|
||||
- WorkflowClient: handles workflow and workflow definition operations
|
||||
- WorkflowRunClient: handles workflow run operations
|
||||
- UserClient: handles user and user configuration operations
|
||||
- OrganizationClient: handles organization operations
|
||||
- OrganizationConfigurationClient: handles organization configuration operations
|
||||
- OrganizationUsageClient: handles organization usage and quota operations
|
||||
- IntegrationClient: handles integration operations
|
||||
- WorkflowTemplateClient: handles workflow template operations
|
||||
- LoopTalkClient: handles LoopTalk testing operations
|
||||
- CampaignClient: handles campaign operations
|
||||
- ReportsClient: handles reports and analytics operations
|
||||
- APIKeyClient: handles API key operations
|
||||
"""
|
||||
|
||||
pass
|
||||
198
api/db/filters.py
Normal file
198
api/db/filters.py
Normal file
|
|
@ -0,0 +1,198 @@
|
|||
"""Common filter utilities for database queries."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from sqlalchemy import Integer, and_, cast
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
|
||||
from api.db.models import WorkflowRunModel
|
||||
|
||||
# Mapping of attribute names to database fields
|
||||
ATTRIBUTE_FIELD_MAPPING = {
|
||||
"dateRange": "created_at",
|
||||
"dispositionCode": "gathered_context.mapped_call_disposition",
|
||||
"duration": "usage_info.call_duration_seconds",
|
||||
"status": "is_completed",
|
||||
"tokenUsage": "cost_info.total_cost_usd",
|
||||
"runId": "id",
|
||||
"workflowId": "workflow_id",
|
||||
"callTags": "gathered_context.call_tags",
|
||||
"phoneNumber": "initial_context.phone",
|
||||
}
|
||||
|
||||
|
||||
def apply_workflow_run_filters(
|
||||
base_query,
|
||||
filters: Optional[List[Dict[str, Any]]] = None,
|
||||
):
|
||||
"""
|
||||
Apply filters to a workflow run query.
|
||||
|
||||
Supports filtering by:
|
||||
- dateRange: Filter by created_at date range
|
||||
- dispositionCode: Filter by gathered_context.mapped_call_disposition
|
||||
- duration: Filter by usage_info.call_duration_seconds range
|
||||
- status: Filter by is_completed status
|
||||
- tokenUsage: Filter by cost_info.total_cost_usd range
|
||||
- runId: Filter by workflow run ID (exact match)
|
||||
- workflowId: Filter by workflow ID (exact match)
|
||||
- callTags: Filter by gathered_context.call_tags (array of strings)
|
||||
- phoneNumber: Filter by initial_context.phone (text search)
|
||||
|
||||
Args:
|
||||
base_query: The base SQLAlchemy query to apply filters to
|
||||
filters: List of filter dictionaries with structure:
|
||||
{"attribute": "filterName", "type": "filterType", "value": {...}}
|
||||
|
||||
Where type is one of:
|
||||
- "dateRange": Date range filter with {"from": ..., "to": ...}
|
||||
- "multiSelect": Multi-select filter with {"codes": [...]}
|
||||
- "numberRange": Number range filter with {"min": ..., "max": ...}
|
||||
- "number": Exact number filter with {"value": number}
|
||||
- "text": Text search filter with {"value": string}
|
||||
- "radio": Radio/status filter with {"status": ...}
|
||||
- "tags": Tags filter with {"codes": [...]}
|
||||
|
||||
Returns:
|
||||
The query with filters applied
|
||||
"""
|
||||
|
||||
if not filters:
|
||||
return base_query
|
||||
|
||||
filter_conditions = []
|
||||
|
||||
for filter_item in filters:
|
||||
attribute = filter_item.get("attribute")
|
||||
filter_type = filter_item.get("type")
|
||||
value = filter_item.get("value", {})
|
||||
|
||||
# Resolve field from attribute mapping
|
||||
field = ATTRIBUTE_FIELD_MAPPING.get(attribute)
|
||||
if not field:
|
||||
# Skip unknown attributes
|
||||
continue
|
||||
|
||||
# Apply the filter based on provided type
|
||||
if field and filter_type:
|
||||
if filter_type == "number" and field == "id":
|
||||
# Filter by exact workflow run ID
|
||||
if value.get("value") is not None:
|
||||
filter_conditions.append(WorkflowRunModel.id == value["value"])
|
||||
|
||||
elif filter_type == "number" and field == "workflow_id":
|
||||
# Filter by exact workflow ID
|
||||
if value.get("value") is not None:
|
||||
filter_conditions.append(
|
||||
WorkflowRunModel.workflow_id == value["value"]
|
||||
)
|
||||
|
||||
elif filter_type == "dateRange" and field == "created_at":
|
||||
# Same as attribute-based dateRange
|
||||
if value.get("from"):
|
||||
filter_conditions.append(
|
||||
WorkflowRunModel.created_at
|
||||
>= datetime.fromisoformat(value["from"])
|
||||
)
|
||||
if value.get("to"):
|
||||
filter_conditions.append(
|
||||
WorkflowRunModel.created_at
|
||||
<= datetime.fromisoformat(value["to"])
|
||||
)
|
||||
|
||||
elif (
|
||||
filter_type == "multiSelect"
|
||||
and field == "gathered_context.mapped_call_disposition"
|
||||
):
|
||||
codes = value.get("codes", [])
|
||||
if codes:
|
||||
filter_conditions.append(
|
||||
cast(WorkflowRunModel.gathered_context, JSONB)[
|
||||
"mapped_call_disposition"
|
||||
]
|
||||
.as_string()
|
||||
.in_(codes)
|
||||
)
|
||||
|
||||
elif filter_type == "radio" and field == "is_completed":
|
||||
status = value.get("status")
|
||||
if status == "completed":
|
||||
filter_conditions.append(WorkflowRunModel.is_completed == True)
|
||||
elif status == "in_progress":
|
||||
filter_conditions.append(WorkflowRunModel.is_completed == False)
|
||||
|
||||
elif (
|
||||
filter_type in ("tags", "multiSelect")
|
||||
and field == "gathered_context.call_tags"
|
||||
):
|
||||
tags = value.get("codes", [])
|
||||
if tags:
|
||||
filter_conditions.append(
|
||||
cast(WorkflowRunModel.gathered_context, JSONB)[
|
||||
"call_tags"
|
||||
].contains(tags)
|
||||
)
|
||||
|
||||
elif filter_type == "text" and field == "initial_context.phone":
|
||||
# Filter by phone number (contains search)
|
||||
phone = value.get("value", "").strip()
|
||||
if phone:
|
||||
filter_conditions.append(
|
||||
cast(WorkflowRunModel.initial_context, JSONB)["phone"]
|
||||
.as_string()
|
||||
.contains(phone)
|
||||
)
|
||||
|
||||
elif filter_type == "numberRange":
|
||||
min_val = value.get("min")
|
||||
max_val = value.get("max")
|
||||
|
||||
if field == "usage_info.call_duration_seconds":
|
||||
if min_val is not None:
|
||||
filter_conditions.append(
|
||||
cast(
|
||||
cast(WorkflowRunModel.usage_info, JSONB)[
|
||||
"call_duration_seconds"
|
||||
],
|
||||
Integer,
|
||||
)
|
||||
>= min_val
|
||||
)
|
||||
if max_val is not None:
|
||||
filter_conditions.append(
|
||||
cast(
|
||||
cast(WorkflowRunModel.usage_info, JSONB)[
|
||||
"call_duration_seconds"
|
||||
],
|
||||
Integer,
|
||||
)
|
||||
<= max_val
|
||||
)
|
||||
|
||||
elif field == "cost_info.total_cost_usd":
|
||||
if min_val is not None:
|
||||
filter_conditions.append(
|
||||
cast(
|
||||
cast(WorkflowRunModel.cost_info, JSONB)[
|
||||
"total_cost_usd"
|
||||
],
|
||||
Integer,
|
||||
)
|
||||
>= min_val
|
||||
)
|
||||
if max_val is not None:
|
||||
filter_conditions.append(
|
||||
cast(
|
||||
cast(WorkflowRunModel.cost_info, JSONB)[
|
||||
"total_cost_usd"
|
||||
],
|
||||
Integer,
|
||||
)
|
||||
<= max_val
|
||||
)
|
||||
|
||||
if filter_conditions:
|
||||
base_query = base_query.where(and_(*filter_conditions))
|
||||
|
||||
return base_query
|
||||
103
api/db/integration_client.py
Normal file
103
api/db/integration_client.py
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
from typing import List
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from api.db.base_client import BaseDBClient
|
||||
from api.db.models import IntegrationModel
|
||||
|
||||
|
||||
class IntegrationClient(BaseDBClient):
|
||||
async def get_integrations_by_organization_id(
|
||||
self, organization_id: int
|
||||
) -> list[IntegrationModel]:
|
||||
"""Get all integrations for a specific organization."""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(IntegrationModel).where(
|
||||
IntegrationModel.organisation_id == organization_id
|
||||
)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
async def create_integration(
|
||||
self,
|
||||
integration_id: str,
|
||||
provider: str,
|
||||
organisation_id: int,
|
||||
connection_details: dict,
|
||||
created_by: int = None,
|
||||
is_active: bool = True,
|
||||
) -> IntegrationModel:
|
||||
"""Create a new integration for an organization."""
|
||||
async with self.async_session() as session:
|
||||
new_integration = IntegrationModel(
|
||||
integration_id=integration_id,
|
||||
organisation_id=organisation_id,
|
||||
created_by=created_by,
|
||||
is_active=is_active,
|
||||
provider=provider,
|
||||
connection_details=connection_details,
|
||||
)
|
||||
session.add(new_integration)
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
await session.refresh(new_integration)
|
||||
return new_integration
|
||||
|
||||
async def update_integration_status(
|
||||
self, integration_id: int, is_active: bool
|
||||
) -> IntegrationModel | None:
|
||||
"""Update the active status of an integration."""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(IntegrationModel).where(IntegrationModel.id == integration_id)
|
||||
)
|
||||
integration = result.scalars().first()
|
||||
if not integration:
|
||||
return None
|
||||
|
||||
integration.is_active = is_active
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
await session.refresh(integration)
|
||||
return integration
|
||||
|
||||
async def update_integration_connection_details(
|
||||
self, integration_id: int, connection_details: dict
|
||||
) -> IntegrationModel | None:
|
||||
"""Update the connection details of an integration."""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(IntegrationModel).where(IntegrationModel.id == integration_id)
|
||||
)
|
||||
integration = result.scalars().first()
|
||||
if not integration:
|
||||
return None
|
||||
|
||||
integration.connection_details = connection_details
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
await session.refresh(integration)
|
||||
return integration
|
||||
|
||||
async def get_active_integrations_by_organization(
|
||||
self, organization_id: int
|
||||
) -> List[IntegrationModel]:
|
||||
"""Get all active integrations for a specific organization."""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(IntegrationModel).where(
|
||||
IntegrationModel.organisation_id == organization_id,
|
||||
IntegrationModel.is_active == True,
|
||||
)
|
||||
)
|
||||
return result.scalars().all()
|
||||
259
api/db/looptalk_client.py
Normal file
259
api/db/looptalk_client.py
Normal file
|
|
@ -0,0 +1,259 @@
|
|||
from datetime import UTC, datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from api.db.base_client import BaseDBClient
|
||||
from api.db.models import (
|
||||
LoopTalkConversation,
|
||||
LoopTalkTestSession,
|
||||
)
|
||||
|
||||
|
||||
class LoopTalkClient(BaseDBClient):
|
||||
"""Database client for LoopTalk testing operations."""
|
||||
|
||||
async def create_test_session(
|
||||
self,
|
||||
organization_id: int,
|
||||
name: str,
|
||||
actor_workflow_id: int,
|
||||
adversary_workflow_id: int,
|
||||
config: Dict[str, Any],
|
||||
load_test_group_id: Optional[str] = None,
|
||||
test_index: Optional[int] = None,
|
||||
) -> LoopTalkTestSession:
|
||||
"""Create a new LoopTalk test session."""
|
||||
async with self.async_session() as session:
|
||||
test_session = LoopTalkTestSession(
|
||||
organization_id=organization_id,
|
||||
name=name,
|
||||
actor_workflow_id=actor_workflow_id,
|
||||
adversary_workflow_id=adversary_workflow_id,
|
||||
config=config,
|
||||
load_test_group_id=load_test_group_id,
|
||||
test_index=test_index,
|
||||
status="pending",
|
||||
)
|
||||
session.add(test_session)
|
||||
await session.commit()
|
||||
await session.refresh(test_session)
|
||||
return test_session
|
||||
|
||||
async def get_test_session(
|
||||
self, test_session_id: int, organization_id: int
|
||||
) -> Optional[LoopTalkTestSession]:
|
||||
"""Get a test session by ID."""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(LoopTalkTestSession)
|
||||
.options(
|
||||
selectinload(LoopTalkTestSession.actor_workflow),
|
||||
selectinload(LoopTalkTestSession.adversary_workflow),
|
||||
selectinload(LoopTalkTestSession.conversations),
|
||||
)
|
||||
.where(
|
||||
LoopTalkTestSession.id == test_session_id,
|
||||
LoopTalkTestSession.organization_id == organization_id,
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def list_test_sessions(
|
||||
self,
|
||||
organization_id: int,
|
||||
status: Optional[str] = None,
|
||||
load_test_group_id: Optional[str] = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
) -> List[LoopTalkTestSession]:
|
||||
"""List test sessions with optional filtering."""
|
||||
async with self.async_session() as session:
|
||||
query = select(LoopTalkTestSession).where(
|
||||
LoopTalkTestSession.organization_id == organization_id
|
||||
)
|
||||
|
||||
if status:
|
||||
# "active" is a virtual status used by the UI to represent
|
||||
# both "pending" and "running" sessions. Translate it into
|
||||
# the real enum values stored in the database to avoid
|
||||
# invalid enum casting errors (e.g. asyncpg InvalidTextRepresentationError).
|
||||
if status == "active":
|
||||
query = query.where(
|
||||
LoopTalkTestSession.status.in_(["pending", "running"])
|
||||
)
|
||||
else:
|
||||
query = query.where(LoopTalkTestSession.status == status)
|
||||
|
||||
if load_test_group_id:
|
||||
query = query.where(
|
||||
LoopTalkTestSession.load_test_group_id == load_test_group_id
|
||||
)
|
||||
|
||||
query = (
|
||||
query.order_by(LoopTalkTestSession.created_at.desc())
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
return result.scalars().all()
|
||||
|
||||
async def update_test_session_status(
|
||||
self,
|
||||
test_session_id: int,
|
||||
status: str,
|
||||
error: Optional[str] = None,
|
||||
results: Optional[Dict[str, Any]] = None,
|
||||
) -> LoopTalkTestSession:
|
||||
"""Update test session status and related fields."""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(LoopTalkTestSession).where(
|
||||
LoopTalkTestSession.id == test_session_id
|
||||
)
|
||||
)
|
||||
test_session = result.scalar_one()
|
||||
|
||||
test_session.status = status
|
||||
|
||||
if status == "running":
|
||||
test_session.started_at = datetime.now(UTC)
|
||||
elif status in ["completed", "failed"]:
|
||||
test_session.completed_at = datetime.now(UTC)
|
||||
|
||||
if error:
|
||||
test_session.error = error
|
||||
|
||||
if results:
|
||||
test_session.results = results
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(test_session)
|
||||
return test_session
|
||||
|
||||
async def create_conversation(self, test_session_id: int) -> LoopTalkConversation:
|
||||
"""Create a new conversation for a test session."""
|
||||
async with self.async_session() as session:
|
||||
conversation = LoopTalkConversation(test_session_id=test_session_id)
|
||||
session.add(conversation)
|
||||
await session.commit()
|
||||
await session.refresh(conversation)
|
||||
return conversation
|
||||
|
||||
async def update_conversation(
|
||||
self,
|
||||
conversation_id: int,
|
||||
duration_seconds: Optional[int] = None,
|
||||
actor_recording_url: Optional[str] = None,
|
||||
adversary_recording_url: Optional[str] = None,
|
||||
combined_recording_url: Optional[str] = None,
|
||||
transcript: Optional[Dict[str, Any]] = None,
|
||||
metrics: Optional[Dict[str, Any]] = None,
|
||||
ended_at: Optional[datetime] = None,
|
||||
) -> LoopTalkConversation:
|
||||
"""Update conversation details."""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(LoopTalkConversation).where(
|
||||
LoopTalkConversation.id == conversation_id
|
||||
)
|
||||
)
|
||||
conversation = result.scalar_one()
|
||||
|
||||
if duration_seconds is not None:
|
||||
conversation.duration_seconds = duration_seconds
|
||||
if actor_recording_url:
|
||||
conversation.actor_recording_url = actor_recording_url
|
||||
if adversary_recording_url:
|
||||
conversation.adversary_recording_url = adversary_recording_url
|
||||
if combined_recording_url:
|
||||
conversation.combined_recording_url = combined_recording_url
|
||||
if transcript:
|
||||
conversation.transcript = transcript
|
||||
if metrics:
|
||||
conversation.metrics = metrics
|
||||
if ended_at:
|
||||
conversation.ended_at = ended_at
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(conversation)
|
||||
return conversation
|
||||
|
||||
# Note: Turn tracking is handled by Langfuse, not stored in our database
|
||||
|
||||
async def create_load_test_group(
|
||||
self,
|
||||
organization_id: int,
|
||||
name_prefix: str,
|
||||
actor_workflow_id: int,
|
||||
adversary_workflow_id: int,
|
||||
config: Dict[str, Any],
|
||||
test_count: int,
|
||||
) -> List[LoopTalkTestSession]:
|
||||
"""Create multiple test sessions for load testing."""
|
||||
load_test_group_id = str(uuid4())
|
||||
test_sessions = []
|
||||
|
||||
async with self.async_session() as session:
|
||||
for i in range(test_count):
|
||||
test_session = LoopTalkTestSession(
|
||||
organization_id=organization_id,
|
||||
name=f"{name_prefix} - Test {i + 1}",
|
||||
actor_workflow_id=actor_workflow_id,
|
||||
adversary_workflow_id=adversary_workflow_id,
|
||||
config=config,
|
||||
load_test_group_id=load_test_group_id,
|
||||
test_index=i,
|
||||
status="pending",
|
||||
)
|
||||
session.add(test_session)
|
||||
test_sessions.append(test_session)
|
||||
|
||||
await session.commit()
|
||||
|
||||
# Refresh all sessions
|
||||
for test_session in test_sessions:
|
||||
await session.refresh(test_session)
|
||||
|
||||
return test_sessions
|
||||
|
||||
async def get_load_test_group_stats(
|
||||
self, load_test_group_id: str, organization_id: int
|
||||
) -> Dict[str, Any]:
|
||||
"""Get statistics for a load test group."""
|
||||
async with self.async_session() as session:
|
||||
# Get all sessions in the group
|
||||
result = await session.execute(
|
||||
select(LoopTalkTestSession).where(
|
||||
LoopTalkTestSession.load_test_group_id == load_test_group_id,
|
||||
LoopTalkTestSession.organization_id == organization_id,
|
||||
)
|
||||
)
|
||||
sessions = result.scalars().all()
|
||||
|
||||
# Calculate stats
|
||||
stats = {
|
||||
"total": len(sessions),
|
||||
"pending": sum(1 for s in sessions if s.status == "pending"),
|
||||
"running": sum(1 for s in sessions if s.status == "running"),
|
||||
"completed": sum(1 for s in sessions if s.status == "completed"),
|
||||
"failed": sum(1 for s in sessions if s.status == "failed"),
|
||||
"sessions": [
|
||||
{
|
||||
"id": s.id,
|
||||
"name": s.name,
|
||||
"status": s.status,
|
||||
"test_index": s.test_index,
|
||||
"created_at": s.created_at,
|
||||
"started_at": s.started_at,
|
||||
"completed_at": s.completed_at,
|
||||
"error": s.error,
|
||||
}
|
||||
for s in sessions
|
||||
],
|
||||
}
|
||||
|
||||
return stats
|
||||
608
api/db/models.py
Normal file
608
api/db/models.py
Normal file
|
|
@ -0,0 +1,608 @@
|
|||
from datetime import UTC, datetime
|
||||
|
||||
from loguru import logger
|
||||
from sqlalchemy import (
|
||||
JSON,
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
Enum,
|
||||
Float,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Integer,
|
||||
String,
|
||||
Table,
|
||||
UniqueConstraint,
|
||||
and_,
|
||||
text,
|
||||
)
|
||||
from sqlalchemy.orm import declarative_base, relationship
|
||||
|
||||
from ..enums import IntegrationAction, WorkflowRunMode, WorkflowStatus
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
# TODO: remove workflow_defintion after migration, remove nullable workflow_defintion_id from Workflow and Workflowrun
|
||||
|
||||
|
||||
# Association table for many-to-many relationship between users and organizations
|
||||
organization_users_association = Table(
|
||||
"organization_users",
|
||||
Base.metadata,
|
||||
Column("user_id", Integer, ForeignKey("users.id"), primary_key=True),
|
||||
Column(
|
||||
"organization_id", Integer, ForeignKey("organizations.id"), primary_key=True
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class UserModel(Base):
|
||||
__tablename__ = "users"
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
provider_id = Column(String, unique=True, index=True, nullable=False)
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||
workflows = relationship("WorkflowModel", back_populates="user")
|
||||
selected_organization_id = Column(
|
||||
Integer, ForeignKey("organizations.id"), nullable=True
|
||||
)
|
||||
selected_organization = relationship("OrganizationModel", back_populates="users")
|
||||
organizations = relationship(
|
||||
"OrganizationModel",
|
||||
secondary=organization_users_association,
|
||||
back_populates="users",
|
||||
)
|
||||
is_superuser = Column(Boolean, default=False)
|
||||
|
||||
|
||||
class UserConfigurationModel(Base):
|
||||
__tablename__ = "user_configurations"
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
user_id = Column(Integer, ForeignKey("users.id"), nullable=True)
|
||||
configuration = Column(JSON, nullable=False, default=dict)
|
||||
last_validated_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
|
||||
# New Organization model
|
||||
class OrganizationModel(Base):
|
||||
__tablename__ = "organizations"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
provider_id = Column(String, unique=True, index=True, nullable=False)
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||
|
||||
# Quota fields
|
||||
quota_type = Column(
|
||||
Enum("monthly", "annual", name="quota_type"),
|
||||
nullable=False,
|
||||
default="monthly",
|
||||
server_default=text("'monthly'::quota_type"),
|
||||
)
|
||||
quota_dograh_tokens = Column(
|
||||
Integer, nullable=False, default=0, server_default=text("0")
|
||||
)
|
||||
quota_reset_day = Column(
|
||||
Integer, nullable=False, default=1, server_default=text("1")
|
||||
) # 1-28, only for monthly
|
||||
quota_start_date = Column(DateTime(timezone=True), nullable=True) # Only for annual
|
||||
quota_enabled = Column(
|
||||
Boolean, nullable=False, default=False, server_default=text("false")
|
||||
)
|
||||
|
||||
price_per_second_usd = Column(Float, nullable=True)
|
||||
|
||||
# Relationships
|
||||
users = relationship(
|
||||
"UserModel",
|
||||
secondary=organization_users_association,
|
||||
back_populates="organizations",
|
||||
)
|
||||
integrations = relationship("IntegrationModel", back_populates="organization")
|
||||
usage_cycles = relationship(
|
||||
"OrganizationUsageCycleModel", back_populates="organization"
|
||||
)
|
||||
configurations = relationship(
|
||||
"OrganizationConfigurationModel", back_populates="organization"
|
||||
)
|
||||
api_keys = relationship("APIKeyModel", back_populates="organization")
|
||||
|
||||
|
||||
class APIKeyModel(Base):
|
||||
__tablename__ = "api_keys"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
organization_id = Column(
|
||||
Integer, ForeignKey("organizations.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
name = Column(String, nullable=False)
|
||||
key_hash = Column(String, nullable=False, unique=True, index=True)
|
||||
key_prefix = Column(String, nullable=False) # Store first 8 chars for display
|
||||
is_active = Column(Boolean, default=True, nullable=False)
|
||||
created_by = Column(Integer, ForeignKey("users.id"), nullable=True)
|
||||
last_used_at = Column(DateTime(timezone=True), nullable=True)
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||
archived_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Relationships
|
||||
organization = relationship("OrganizationModel", back_populates="api_keys")
|
||||
created_by_user = relationship("UserModel")
|
||||
|
||||
# Indexes for performance
|
||||
__table_args__ = (
|
||||
Index("ix_api_keys_organization_id", "organization_id"),
|
||||
Index("ix_api_keys_key_hash", "key_hash"),
|
||||
Index("ix_api_keys_active", "is_active"),
|
||||
)
|
||||
|
||||
|
||||
class OrganizationConfigurationModel(Base):
|
||||
__tablename__ = "organization_configurations"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
organization_id = Column(
|
||||
Integer, ForeignKey("organizations.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
key = Column(String, nullable=False)
|
||||
value = Column(JSON, nullable=False, default=dict)
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(UTC),
|
||||
onupdate=lambda: datetime.now(UTC),
|
||||
)
|
||||
|
||||
# Relationships
|
||||
organization = relationship("OrganizationModel", back_populates="configurations")
|
||||
|
||||
# Constraints and indexes
|
||||
__table_args__ = (
|
||||
UniqueConstraint("organization_id", "key", name="_organization_key_uc"),
|
||||
Index("ix_organization_configurations_organization_id", "organization_id"),
|
||||
)
|
||||
|
||||
|
||||
class IntegrationModel(Base):
|
||||
__tablename__ = "integrations"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
integration_id = Column(String, nullable=False, index=True) # Nango Connection ID
|
||||
organisation_id = Column(Integer, ForeignKey("organizations.id"), nullable=False)
|
||||
provider = Column(String, nullable=False)
|
||||
created_by = Column(Integer, ForeignKey("users.id"))
|
||||
is_active = Column(Boolean, default=True, nullable=False)
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||
connection_details = Column(JSON, nullable=False, default=dict)
|
||||
action = Column(String, nullable=False, default=IntegrationAction.ALL_CALLS.value)
|
||||
|
||||
# Relationships
|
||||
organization = relationship("OrganizationModel", back_populates="integrations")
|
||||
|
||||
|
||||
class WorkflowDefinitionModel(Base):
|
||||
__tablename__ = "workflow_definitions"
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
workflow_hash = Column(String, nullable=False)
|
||||
workflow_json = Column(JSON, nullable=False, default=dict)
|
||||
workflow_id = Column(Integer, ForeignKey("workflows.id"), nullable=True)
|
||||
is_current = Column(
|
||||
Boolean, default=False, nullable=False, server_default=text("false")
|
||||
)
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||
|
||||
# Table constraints and indexes
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"workflow_hash", "workflow_id", name="uq_workflow_hash_workflow_id"
|
||||
),
|
||||
Index("ix_workflow_hash_workflow_id", "workflow_hash", "workflow_id"),
|
||||
)
|
||||
|
||||
# Relationships
|
||||
workflow = relationship(
|
||||
"WorkflowModel",
|
||||
back_populates="definitions",
|
||||
foreign_keys=[workflow_id],
|
||||
)
|
||||
workflow_runs = relationship("WorkflowRunModel", back_populates="definition")
|
||||
|
||||
|
||||
class WorkflowModel(Base):
|
||||
__tablename__ = "workflows"
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
user_id = Column(Integer, ForeignKey("users.id"), nullable=True)
|
||||
user = relationship("UserModel", back_populates="workflows")
|
||||
organization_id = Column(Integer, ForeignKey("organizations.id"), nullable=True)
|
||||
organization = relationship("OrganizationModel")
|
||||
name = Column(String, index=True, nullable=False)
|
||||
status = Column(
|
||||
Enum(*[status.value for status in WorkflowStatus], name="workflow_status"),
|
||||
nullable=False,
|
||||
default=WorkflowStatus.ACTIVE.value,
|
||||
server_default=text("'active'::workflow_status"),
|
||||
)
|
||||
workflow_definition = Column(JSON, nullable=False, default=dict)
|
||||
template_context_variables = Column(JSON, nullable=False, default=dict)
|
||||
call_disposition_codes = Column(JSON, nullable=False, default=dict)
|
||||
workflow_configurations = Column(
|
||||
JSON, nullable=False, default=dict, server_default=text("'{}'::json")
|
||||
)
|
||||
runs = relationship("WorkflowRunModel", back_populates="workflow")
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||
|
||||
# All versions / historical definitions of this workflow
|
||||
definitions = relationship(
|
||||
"WorkflowDefinitionModel",
|
||||
back_populates="workflow",
|
||||
foreign_keys="WorkflowDefinitionModel.workflow_id",
|
||||
)
|
||||
|
||||
# Relationship to fetch the current (is_current=True) definition
|
||||
current_definition = relationship(
|
||||
"WorkflowDefinitionModel",
|
||||
primaryjoin=lambda: and_(
|
||||
WorkflowDefinitionModel.workflow_id == WorkflowModel.id,
|
||||
WorkflowDefinitionModel.is_current.is_(True),
|
||||
),
|
||||
uselist=False,
|
||||
viewonly=True,
|
||||
)
|
||||
|
||||
@property
|
||||
def current_definition_id(self):
|
||||
"""Return ID of the current workflow definition (helper for backwards-compat)."""
|
||||
current_def = self.__dict__.get("current_definition")
|
||||
if current_def is not None:
|
||||
return current_def.id
|
||||
|
||||
# If relationship is not loaded, we cannot safely access definitions without
|
||||
# risking an implicit lazy load on a detached instance. Return ``None`` in
|
||||
# that scenario so callers can handle the absence explicitly.
|
||||
return None
|
||||
|
||||
@property
|
||||
def workflow_definition_with_fallback(self):
|
||||
"""
|
||||
Get workflow definition with fallback to legacy workflow_definition field.
|
||||
|
||||
Returns:
|
||||
dict: The workflow definition JSON
|
||||
"""
|
||||
# Access the relationship only if it has ALREADY been eagerly loaded on this
|
||||
# instance to avoid triggering an implicit lazy load once the SQLAlchemy
|
||||
# Session has been closed (which would raise a DetachedInstanceError).
|
||||
|
||||
# ``__dict__`` will contain "current_definition" **only** when the attribute
|
||||
# has been populated (e.g. via `selectinload` or an explicit access while
|
||||
# the session was still open). Using ``__dict__.get`` guarantees that we
|
||||
# do not accidentally issue a lazy load query on a detached instance.
|
||||
|
||||
current_definition = self.__dict__.get("current_definition")
|
||||
|
||||
if current_definition is not None:
|
||||
return current_definition.workflow_json
|
||||
|
||||
# Fallback for backwards-compatibility when the relationship is not (yet)
|
||||
# loaded. In this case we fall back to the legacy ``workflow_definition``
|
||||
# column that always contains the most recent definition JSON.
|
||||
logger.warning(
|
||||
f"Workflow {self.id} has no loaded current definition, using workflow_definition as fallback",
|
||||
)
|
||||
return self.workflow_definition
|
||||
|
||||
|
||||
class WorkflowTemplates(Base):
|
||||
__tablename__ = "workflow_templates"
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
template_name = Column(String, nullable=False, index=True)
|
||||
template_description = Column(String, nullable=False, index=True)
|
||||
template_json = Column(JSON, nullable=False, default=dict)
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||
|
||||
|
||||
class WorkflowRunModel(Base):
|
||||
__tablename__ = "workflow_runs"
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, nullable=False)
|
||||
workflow_id = Column(Integer, ForeignKey("workflows.id"), nullable=False)
|
||||
workflow = relationship("WorkflowModel", back_populates="runs")
|
||||
definition_id = Column(
|
||||
Integer, ForeignKey("workflow_definitions.id"), nullable=True
|
||||
)
|
||||
definition = relationship("WorkflowDefinitionModel", back_populates="workflow_runs")
|
||||
mode = Column(
|
||||
Enum(*[mode.value for mode in WorkflowRunMode], name="workflow_run_mode"),
|
||||
nullable=False,
|
||||
)
|
||||
is_completed = Column(Boolean, default=False)
|
||||
recording_url = Column(String, nullable=True)
|
||||
transcript_url = Column(String, nullable=True)
|
||||
# Store storage backend as string enum (s3, minio)
|
||||
storage_backend = Column(
|
||||
Enum("s3", "minio", name="storage_backend"),
|
||||
nullable=False,
|
||||
default="s3",
|
||||
server_default=text("'s3'::storage_backend"),
|
||||
)
|
||||
usage_info = Column(JSON, nullable=False, default=dict)
|
||||
cost_info = Column(JSON, nullable=False, default=dict)
|
||||
initial_context = Column(JSON, nullable=False, default=dict)
|
||||
gathered_context = Column(JSON, nullable=False, default=dict)
|
||||
logs = Column(JSON, nullable=False, default=dict, server_default=text("'{}'::json"))
|
||||
annotations = Column(JSON, nullable=False, default=dict)
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||
campaign_id = Column(Integer, ForeignKey("campaigns.id"), nullable=True)
|
||||
campaign = relationship("CampaignModel")
|
||||
queued_run_id = Column(Integer, ForeignKey("queued_runs.id"), nullable=True)
|
||||
queued_run = relationship("QueuedRunModel", foreign_keys=[queued_run_id])
|
||||
|
||||
|
||||
# LoopTalk Testing Models
|
||||
class LoopTalkTestSession(Base):
|
||||
__tablename__ = "looptalk_test_sessions"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
organization_id = Column(Integer, ForeignKey("organizations.id"), nullable=False)
|
||||
name = Column(String, nullable=False)
|
||||
status = Column(
|
||||
Enum("pending", "running", "completed", "failed", name="test_session_status"),
|
||||
nullable=False,
|
||||
default="pending",
|
||||
)
|
||||
|
||||
# Workflow configuration
|
||||
actor_workflow_id = Column(Integer, ForeignKey("workflows.id"), nullable=False)
|
||||
adversary_workflow_id = Column(Integer, ForeignKey("workflows.id"), nullable=False)
|
||||
|
||||
# Load testing configuration
|
||||
load_test_group_id = Column(String, nullable=True, index=True)
|
||||
test_index = Column(Integer, nullable=True)
|
||||
|
||||
# Test metadata
|
||||
config = Column(JSON, nullable=False, default=dict)
|
||||
results = Column(JSON, nullable=False, default=dict)
|
||||
error = Column(String, nullable=True)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||
started_at = Column(DateTime(timezone=True), nullable=True)
|
||||
completed_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Relationships
|
||||
organization = relationship("OrganizationModel")
|
||||
actor_workflow = relationship("WorkflowModel", foreign_keys=[actor_workflow_id])
|
||||
adversary_workflow = relationship(
|
||||
"WorkflowModel", foreign_keys=[adversary_workflow_id]
|
||||
)
|
||||
conversations = relationship("LoopTalkConversation", back_populates="test_session")
|
||||
|
||||
# Indexes for performance
|
||||
__table_args__ = (
|
||||
Index("ix_looptalk_test_sessions_org_id", "organization_id"),
|
||||
Index("ix_looptalk_test_sessions_group_id", "load_test_group_id"),
|
||||
Index("ix_looptalk_test_sessions_status", "status"),
|
||||
)
|
||||
|
||||
|
||||
class LoopTalkConversation(Base):
|
||||
__tablename__ = "looptalk_conversations"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
test_session_id = Column(
|
||||
Integer, ForeignKey("looptalk_test_sessions.id"), nullable=False
|
||||
)
|
||||
|
||||
# Conversation metadata
|
||||
duration_seconds = Column(Integer, nullable=True)
|
||||
# Note: Turn tracking is handled by Langfuse, not stored here
|
||||
|
||||
# Audio recording URLs
|
||||
actor_recording_url = Column(String, nullable=True)
|
||||
adversary_recording_url = Column(String, nullable=True)
|
||||
combined_recording_url = Column(String, nullable=True)
|
||||
|
||||
# Transcripts (if needed for quick access)
|
||||
transcript = Column(JSON, nullable=False, default=dict)
|
||||
|
||||
# Metrics
|
||||
metrics = Column(JSON, nullable=False, default=dict)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||
ended_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Relationships
|
||||
test_session = relationship("LoopTalkTestSession", back_populates="conversations")
|
||||
|
||||
# Indexes
|
||||
__table_args__ = (Index("ix_looptalk_conversations_session_id", "test_session_id"),)
|
||||
|
||||
|
||||
class OrganizationUsageCycleModel(Base):
|
||||
"""
|
||||
This model is used to track the usage of Dograh tokens for an organization for a given usage
|
||||
cycle.
|
||||
"""
|
||||
|
||||
__tablename__ = "organization_usage_cycles"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
organization_id = Column(Integer, ForeignKey("organizations.id"), nullable=False)
|
||||
period_start = Column(DateTime(timezone=True), nullable=False)
|
||||
period_end = Column(DateTime(timezone=True), nullable=False)
|
||||
quota_dograh_tokens = Column(Integer, nullable=False)
|
||||
used_dograh_tokens = Column(Float, nullable=False, default=0)
|
||||
total_duration_seconds = Column(
|
||||
Integer, nullable=False, default=0, server_default=text("0")
|
||||
)
|
||||
# New USD tracking fields
|
||||
used_amount_usd = Column(Float, nullable=True, default=0)
|
||||
quota_amount_usd = Column(Float, nullable=True)
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(UTC),
|
||||
onupdate=lambda: datetime.now(UTC),
|
||||
)
|
||||
|
||||
# Relationships
|
||||
organization = relationship("OrganizationModel", back_populates="usage_cycles")
|
||||
|
||||
# Constraints and indexes
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"organization_id", "period_start", "period_end", name="unique_org_period"
|
||||
),
|
||||
Index("idx_usage_cycles_org_period", "organization_id", "period_end"),
|
||||
)
|
||||
|
||||
|
||||
class CampaignModel(Base):
|
||||
__tablename__ = "campaigns"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, nullable=False, index=True)
|
||||
organization_id = Column(Integer, ForeignKey("organizations.id"), nullable=False)
|
||||
workflow_id = Column(Integer, ForeignKey("workflows.id"), nullable=False)
|
||||
created_by = Column(Integer, ForeignKey("users.id"), nullable=False)
|
||||
|
||||
# Source configuration
|
||||
source_type = Column(String, nullable=False, default="google-sheet")
|
||||
source_id = Column(String, nullable=False) # Sheet URL
|
||||
|
||||
# State management
|
||||
state = Column(
|
||||
Enum(
|
||||
"created",
|
||||
"syncing",
|
||||
"running",
|
||||
"paused",
|
||||
"completed",
|
||||
"failed",
|
||||
name="campaign_state",
|
||||
),
|
||||
nullable=False,
|
||||
default="created",
|
||||
)
|
||||
|
||||
# Progress tracking
|
||||
total_rows = Column(Integer, nullable=True)
|
||||
processed_rows = Column(Integer, nullable=False, default=0)
|
||||
failed_rows = Column(Integer, nullable=False, default=0)
|
||||
|
||||
# Rate limiting and sync configuration
|
||||
rate_limit_per_second = Column(Integer, nullable=False, default=1)
|
||||
max_retries = Column(Integer, nullable=False, default=0)
|
||||
source_sync_status = Column(String, nullable=False, default="pending")
|
||||
source_last_synced_at = Column(DateTime(timezone=True), nullable=True)
|
||||
source_sync_error = Column(String, nullable=True)
|
||||
|
||||
# Retry configuration for call failures
|
||||
retry_config = Column(
|
||||
JSON,
|
||||
nullable=False,
|
||||
default={
|
||||
"enabled": True,
|
||||
"max_retries": 2,
|
||||
"retry_delay_seconds": 120,
|
||||
"retry_on_busy": True,
|
||||
"retry_on_no_answer": True,
|
||||
"retry_on_voicemail": True,
|
||||
},
|
||||
server_default=text(
|
||||
'\'{"enabled": true, "max_retries": 2, "retry_on_busy": true, "retry_on_no_answer": true, "retry_on_voicemail": true, "retry_delay_seconds": 120}\'::jsonb'
|
||||
),
|
||||
)
|
||||
|
||||
# Orchestrator tracking fields
|
||||
last_batch_scheduled_at = Column(DateTime(timezone=True), nullable=True)
|
||||
last_activity_at = Column(DateTime(timezone=True), nullable=True)
|
||||
orchestrator_metadata = Column(
|
||||
JSON, nullable=False, default=dict, server_default=text("'{}'::json")
|
||||
)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||
started_at = Column(DateTime(timezone=True), nullable=True)
|
||||
completed_at = Column(DateTime(timezone=True), nullable=True)
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(UTC),
|
||||
onupdate=lambda: datetime.now(UTC),
|
||||
)
|
||||
|
||||
# Relationships
|
||||
organization = relationship("OrganizationModel")
|
||||
workflow = relationship("WorkflowModel")
|
||||
created_by_user = relationship("UserModel")
|
||||
|
||||
# Indexes
|
||||
__table_args__ = (
|
||||
Index("ix_campaigns_org_id", "organization_id"),
|
||||
Index("ix_campaigns_state", "state"),
|
||||
Index("ix_campaigns_workflow_id", "workflow_id"),
|
||||
# Index for efficient querying of active campaigns
|
||||
Index(
|
||||
"idx_campaigns_active_status",
|
||||
"state",
|
||||
postgresql_where=text("state IN ('syncing', 'running', 'paused')"),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class QueuedRunModel(Base):
|
||||
__tablename__ = "queued_runs"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
campaign_id = Column(
|
||||
Integer, ForeignKey("campaigns.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
source_uuid = Column(String, nullable=False)
|
||||
context_variables = Column(JSON, nullable=False, default=dict)
|
||||
state = Column(
|
||||
Enum("queued", "processed", "failed", name="queued_run_state"),
|
||||
nullable=False,
|
||||
default="queued",
|
||||
)
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||
processed_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# New retry-related fields
|
||||
retry_count = Column(Integer, default=0, nullable=False, server_default=text("0"))
|
||||
parent_queued_run_id = Column(Integer, ForeignKey("queued_runs.id"), nullable=True)
|
||||
scheduled_for = Column(DateTime(timezone=True), nullable=True)
|
||||
retry_reason = Column(String, nullable=True) # 'busy', 'no_answer', 'voicemail'
|
||||
|
||||
# Relationships
|
||||
campaign = relationship("CampaignModel")
|
||||
parent_queued_run = relationship("QueuedRunModel", remote_side=[id])
|
||||
|
||||
# Indexes
|
||||
__table_args__ = (
|
||||
Index("idx_queued_runs_campaign_state", "campaign_id", "state"),
|
||||
Index("idx_queued_runs_created", "created_at"),
|
||||
Index("idx_queued_runs_source_uuid", "source_uuid"),
|
||||
Index(
|
||||
"idx_queued_runs_scheduled", "scheduled_for"
|
||||
), # New index for scheduled retries
|
||||
# Optimized index for checking queued runs efficiently
|
||||
Index(
|
||||
"idx_queued_runs_campaign_state_optimized",
|
||||
"campaign_id",
|
||||
"state",
|
||||
postgresql_where=text("state = 'queued'"),
|
||||
),
|
||||
# Optimized index for scheduled retries
|
||||
Index(
|
||||
"idx_queued_runs_scheduled_optimized",
|
||||
"campaign_id",
|
||||
"scheduled_for",
|
||||
postgresql_where=text("scheduled_for IS NOT NULL"),
|
||||
),
|
||||
UniqueConstraint(
|
||||
"campaign_id",
|
||||
"source_uuid",
|
||||
"retry_count",
|
||||
name="unique_campaign_source_retry",
|
||||
),
|
||||
)
|
||||
114
api/db/organization_client.py
Normal file
114
api/db/organization_client.py
Normal file
|
|
@ -0,0 +1,114 @@
|
|||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from api.db.base_client import BaseDBClient
|
||||
from api.db.models import (
|
||||
APIKeyModel,
|
||||
OrganizationModel,
|
||||
organization_users_association,
|
||||
)
|
||||
from api.utils.api_key import generate_api_key
|
||||
|
||||
|
||||
class OrganizationClient(BaseDBClient):
|
||||
async def get_organization_by_id(
|
||||
self, organization_id: int
|
||||
) -> Optional[OrganizationModel]:
|
||||
"""Get an organization by its ID."""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(OrganizationModel).where(OrganizationModel.id == organization_id)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
async def get_or_create_organization_by_provider_id(
|
||||
self, org_provider_id: str, user_id: int
|
||||
) -> tuple[OrganizationModel, bool]:
|
||||
"""Get an existing organization by provider_id or create a new one.
|
||||
|
||||
Returns:
|
||||
A tuple of (organization, was_created) where was_created is True if the organization
|
||||
was created in this call, False if it already existed.
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
# First try to get existing organization
|
||||
result = await session.execute(
|
||||
select(OrganizationModel).where(
|
||||
OrganizationModel.provider_id == org_provider_id
|
||||
)
|
||||
)
|
||||
organization = result.scalars().first()
|
||||
|
||||
if organization is None:
|
||||
# Use PostgreSQL's INSERT ... ON CONFLICT DO NOTHING
|
||||
# This is atomic and handles race conditions at the database level
|
||||
|
||||
stmt = insert(OrganizationModel.__table__).values(
|
||||
provider_id=org_provider_id, created_at=datetime.now(timezone.utc)
|
||||
)
|
||||
# ON CONFLICT DO NOTHING - if another request already inserted, this becomes a no-op
|
||||
stmt = stmt.on_conflict_do_nothing(index_elements=["provider_id"])
|
||||
|
||||
result = await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
# Check if we actually inserted (rowcount > 0) or if there was a conflict (rowcount == 0)
|
||||
was_created = result.rowcount > 0
|
||||
|
||||
# Now fetch the organization (either the one we just created or the one that existed)
|
||||
result = await session.execute(
|
||||
select(OrganizationModel).where(
|
||||
OrganizationModel.provider_id == org_provider_id
|
||||
)
|
||||
)
|
||||
organization = result.scalars().first()
|
||||
|
||||
if organization is None:
|
||||
# This should never happen, but handle it just in case
|
||||
error_msg = f"Failed to create or fetch organization with provider_id {org_provider_id}"
|
||||
raise ValueError(error_msg)
|
||||
|
||||
# Only create API key if we actually created the organization
|
||||
if was_created:
|
||||
# Create a default API key for the new organization
|
||||
_, key_hash, key_prefix = generate_api_key()
|
||||
|
||||
api_key = APIKeyModel(
|
||||
organization_id=organization.id,
|
||||
name="Default API Key",
|
||||
key_hash=key_hash,
|
||||
key_prefix=key_prefix,
|
||||
is_active=True,
|
||||
created_by=user_id,
|
||||
)
|
||||
session.add(api_key)
|
||||
await session.commit()
|
||||
|
||||
await session.refresh(organization)
|
||||
return organization, was_created
|
||||
return organization, False
|
||||
|
||||
async def add_user_to_organization(
|
||||
self, user_id: int, organization_id: int
|
||||
) -> None:
|
||||
"""Ensure that a user is linked to an organization (many-to-many).
|
||||
|
||||
The association is created only if it does not already exist.
|
||||
Uses INSERT ... ON CONFLICT DO NOTHING to handle race conditions.
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
# Use PostgreSQL's INSERT ... ON CONFLICT DO NOTHING
|
||||
# This handles race conditions at the database level
|
||||
|
||||
stmt = insert(organization_users_association).values(
|
||||
user_id=user_id, organization_id=organization_id
|
||||
)
|
||||
# ON CONFLICT DO NOTHING - if another request already inserted, this becomes a no-op
|
||||
# The primary key constraint on (user_id, organization_id) will trigger the conflict
|
||||
stmt = stmt.on_conflict_do_nothing()
|
||||
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
96
api/db/organization_configuration_client.py
Normal file
96
api/db/organization_configuration_client.py
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from api.db.base_client import BaseDBClient
|
||||
from api.db.models import OrganizationConfigurationModel
|
||||
|
||||
|
||||
class OrganizationConfigurationClient(BaseDBClient):
|
||||
async def get_configuration(
|
||||
self, organization_id: int, key: str
|
||||
) -> Optional[OrganizationConfigurationModel]:
|
||||
"""Get a specific configuration for an organization by key."""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(OrganizationConfigurationModel).where(
|
||||
OrganizationConfigurationModel.organization_id == organization_id,
|
||||
OrganizationConfigurationModel.key == key,
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
async def get_all_configurations(
|
||||
self, organization_id: int
|
||||
) -> list[OrganizationConfigurationModel]:
|
||||
"""Get all configurations for an organization."""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(OrganizationConfigurationModel).where(
|
||||
OrganizationConfigurationModel.organization_id == organization_id
|
||||
)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
async def upsert_configuration(
|
||||
self, organization_id: int, key: str, value: Any
|
||||
) -> OrganizationConfigurationModel:
|
||||
"""Create or update a configuration for an organization."""
|
||||
async with self.async_session() as session:
|
||||
# First try to get existing configuration
|
||||
result = await session.execute(
|
||||
select(OrganizationConfigurationModel).where(
|
||||
OrganizationConfigurationModel.organization_id == organization_id,
|
||||
OrganizationConfigurationModel.key == key,
|
||||
)
|
||||
)
|
||||
config = result.scalars().first()
|
||||
|
||||
if config:
|
||||
# Update existing configuration
|
||||
config.value = value
|
||||
else:
|
||||
# Create new configuration
|
||||
config = OrganizationConfigurationModel(
|
||||
organization_id=organization_id,
|
||||
key=key,
|
||||
value=value,
|
||||
)
|
||||
session.add(config)
|
||||
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
await session.refresh(config)
|
||||
return config
|
||||
|
||||
async def delete_configuration(self, organization_id: int, key: str) -> bool:
|
||||
"""Delete a configuration for an organization."""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(OrganizationConfigurationModel).where(
|
||||
OrganizationConfigurationModel.organization_id == organization_id,
|
||||
OrganizationConfigurationModel.key == key,
|
||||
)
|
||||
)
|
||||
config = result.scalars().first()
|
||||
|
||||
if not config:
|
||||
return False
|
||||
|
||||
await session.delete(config)
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
return True
|
||||
|
||||
async def get_configuration_value(
|
||||
self, organization_id: int, key: str, default: Any = None
|
||||
) -> Any:
|
||||
"""Get the value of a configuration, returning default if not found."""
|
||||
config = await self.get_configuration(organization_id, key)
|
||||
return config.value if config else default
|
||||
524
api/db/organization_usage_client.py
Normal file
524
api/db/organization_usage_client.py
Normal file
|
|
@ -0,0 +1,524 @@
|
|||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from dateutil.relativedelta import relativedelta
|
||||
from sqlalchemy import Date, and_, cast, func, select
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from api.db.base_client import BaseDBClient
|
||||
from api.db.filters import apply_workflow_run_filters
|
||||
from api.db.models import (
|
||||
OrganizationModel,
|
||||
OrganizationUsageCycleModel,
|
||||
UserConfigurationModel,
|
||||
UserModel,
|
||||
WorkflowModel,
|
||||
WorkflowRunModel,
|
||||
)
|
||||
from api.schemas.user_configuration import UserConfiguration
|
||||
|
||||
|
||||
class OrganizationUsageClient(BaseDBClient):
|
||||
"""Client for managing organization usage and quota operations."""
|
||||
|
||||
async def get_or_create_current_cycle(
|
||||
self, organization_id: int, session=None
|
||||
) -> OrganizationUsageCycleModel:
|
||||
"""Get or create the current usage cycle for an organization.
|
||||
|
||||
Args:
|
||||
organization_id: The organization ID
|
||||
session: Optional session to use for the operation. If provided,
|
||||
the caller is responsible for committing.
|
||||
"""
|
||||
if session is None:
|
||||
async with self.async_session() as session:
|
||||
return await self._get_or_create_current_cycle_impl(
|
||||
organization_id, session, commit=True
|
||||
)
|
||||
else:
|
||||
return await self._get_or_create_current_cycle_impl(
|
||||
organization_id, session, commit=False
|
||||
)
|
||||
|
||||
async def _get_or_create_current_cycle_impl(
|
||||
self, organization_id: int, session, commit: bool
|
||||
) -> OrganizationUsageCycleModel:
|
||||
"""Internal implementation for get_or_create_current_cycle."""
|
||||
# Get organization to determine quota type
|
||||
org_result = await session.execute(
|
||||
select(OrganizationModel).where(OrganizationModel.id == organization_id)
|
||||
)
|
||||
org = org_result.scalar_one()
|
||||
|
||||
# Calculate current period
|
||||
period_start, period_end = self._calculate_current_period(org)
|
||||
|
||||
# Try to get existing cycle
|
||||
cycle_result = await session.execute(
|
||||
select(OrganizationUsageCycleModel).where(
|
||||
and_(
|
||||
OrganizationUsageCycleModel.organization_id == organization_id,
|
||||
OrganizationUsageCycleModel.period_start == period_start,
|
||||
OrganizationUsageCycleModel.period_end == period_end,
|
||||
)
|
||||
)
|
||||
)
|
||||
cycle = cycle_result.scalar_one_or_none()
|
||||
|
||||
if cycle:
|
||||
return cycle
|
||||
|
||||
# Create new cycle if it doesn't exist
|
||||
stmt = insert(OrganizationUsageCycleModel).values(
|
||||
organization_id=organization_id,
|
||||
period_start=period_start,
|
||||
period_end=period_end,
|
||||
quota_dograh_tokens=org.quota_dograh_tokens,
|
||||
)
|
||||
# Handle concurrent inserts gracefully
|
||||
stmt = stmt.on_conflict_do_nothing(
|
||||
index_elements=["organization_id", "period_start", "period_end"]
|
||||
)
|
||||
|
||||
await session.execute(stmt)
|
||||
|
||||
if commit:
|
||||
await session.commit()
|
||||
|
||||
# Fetch the created cycle
|
||||
cycle_result = await session.execute(
|
||||
select(OrganizationUsageCycleModel).where(
|
||||
and_(
|
||||
OrganizationUsageCycleModel.organization_id == organization_id,
|
||||
OrganizationUsageCycleModel.period_start == period_start,
|
||||
OrganizationUsageCycleModel.period_end == period_end,
|
||||
)
|
||||
)
|
||||
)
|
||||
return cycle_result.scalar_one()
|
||||
|
||||
async def check_and_reserve_quota(
|
||||
self, organization_id: int, estimated_tokens: int = 0
|
||||
) -> bool:
|
||||
"""
|
||||
Check if organization has sufficient quota and optionally reserve tokens.
|
||||
Returns True if quota is available, False otherwise.
|
||||
|
||||
This method is fully atomic and safe for concurrent access from multiple processes.
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
# Get organization
|
||||
org_result = await session.execute(
|
||||
select(OrganizationModel).where(OrganizationModel.id == organization_id)
|
||||
)
|
||||
org = org_result.scalar_one_or_none()
|
||||
|
||||
if not org or not org.quota_enabled:
|
||||
# No quota enforcement if not enabled
|
||||
return True
|
||||
|
||||
# Get or create current cycle within the same session/transaction
|
||||
cycle = await self._get_or_create_current_cycle_impl(
|
||||
organization_id, session, commit=False
|
||||
)
|
||||
|
||||
# Atomic check and update with row-level lock
|
||||
result = await session.execute(
|
||||
select(OrganizationUsageCycleModel)
|
||||
.where(
|
||||
and_(
|
||||
OrganizationUsageCycleModel.id == cycle.id,
|
||||
OrganizationUsageCycleModel.used_dograh_tokens
|
||||
+ estimated_tokens
|
||||
<= OrganizationUsageCycleModel.quota_dograh_tokens,
|
||||
)
|
||||
)
|
||||
.with_for_update(skip_locked=False)
|
||||
)
|
||||
|
||||
cycle_locked = result.scalar_one_or_none()
|
||||
if cycle_locked:
|
||||
# Update the usage atomically
|
||||
cycle_locked.used_dograh_tokens += estimated_tokens
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def update_usage_after_run(
|
||||
self,
|
||||
organization_id: int,
|
||||
actual_tokens: int,
|
||||
duration_seconds: int = 0,
|
||||
charge_usd: float = None,
|
||||
) -> None:
|
||||
"""Update usage after a workflow run completes with actual token count and duration.
|
||||
|
||||
This method is fully atomic and safe for concurrent access from multiple processes.
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
# Get or create current cycle within the same session/transaction
|
||||
cycle = await self._get_or_create_current_cycle_impl(
|
||||
organization_id, session, commit=False
|
||||
)
|
||||
|
||||
# Acquire a row-level lock for atomic update
|
||||
result = await session.execute(
|
||||
select(OrganizationUsageCycleModel)
|
||||
.where(OrganizationUsageCycleModel.id == cycle.id)
|
||||
.with_for_update(skip_locked=False)
|
||||
)
|
||||
cycle_locked = result.scalar_one()
|
||||
|
||||
# Update usage atomically
|
||||
cycle_locked.used_dograh_tokens += actual_tokens
|
||||
cycle_locked.total_duration_seconds += int(round(duration_seconds))
|
||||
|
||||
# Update USD amount if provided
|
||||
if charge_usd is not None:
|
||||
if cycle_locked.used_amount_usd is None:
|
||||
cycle_locked.used_amount_usd = 0
|
||||
cycle_locked.used_amount_usd += charge_usd
|
||||
|
||||
await session.commit()
|
||||
|
||||
async def get_current_usage(self, organization_id: int) -> dict:
|
||||
"""Get current period usage information."""
|
||||
async with self.async_session() as session:
|
||||
# Get organization
|
||||
org_result = await session.execute(
|
||||
select(OrganizationModel).where(OrganizationModel.id == organization_id)
|
||||
)
|
||||
org = org_result.scalar_one()
|
||||
|
||||
# Get or create current cycle within the same session
|
||||
cycle = await self._get_or_create_current_cycle_impl(
|
||||
organization_id, session, commit=False
|
||||
)
|
||||
|
||||
# Calculate next refresh date
|
||||
if org.quota_type == "monthly":
|
||||
next_refresh = cycle.period_end + relativedelta(days=1)
|
||||
else: # annual
|
||||
next_refresh = cycle.period_end + relativedelta(days=1)
|
||||
|
||||
result = {
|
||||
"period_start": cycle.period_start.isoformat(),
|
||||
"period_end": cycle.period_end.isoformat(),
|
||||
"used_dograh_tokens": cycle.used_dograh_tokens,
|
||||
"quota_dograh_tokens": cycle.quota_dograh_tokens,
|
||||
"percentage_used": (
|
||||
round(
|
||||
(cycle.used_dograh_tokens / cycle.quota_dograh_tokens) * 100, 2
|
||||
)
|
||||
if cycle.quota_dograh_tokens > 0
|
||||
else 0
|
||||
),
|
||||
"next_refresh_date": next_refresh.date().isoformat(),
|
||||
"quota_enabled": org.quota_enabled,
|
||||
"total_duration_seconds": cycle.total_duration_seconds,
|
||||
}
|
||||
|
||||
# Add USD fields if organization has pricing
|
||||
if org.price_per_second_usd is not None:
|
||||
result["used_amount_usd"] = cycle.used_amount_usd or 0
|
||||
result["quota_amount_usd"] = cycle.quota_amount_usd
|
||||
result["currency"] = "USD"
|
||||
result["price_per_second_usd"] = org.price_per_second_usd
|
||||
|
||||
# Calculate percentage based on USD if available
|
||||
if cycle.quota_amount_usd and cycle.quota_amount_usd > 0:
|
||||
result["percentage_used"] = round(
|
||||
((cycle.used_amount_usd or 0) / cycle.quota_amount_usd) * 100, 2
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def get_usage_history(
|
||||
self,
|
||||
organization_id: int,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
filters: Optional[list[dict]] = None,
|
||||
) -> tuple[list[dict], int]:
|
||||
"""Get paginated workflow runs with usage for an organization."""
|
||||
async with self.async_session() as session:
|
||||
query = (
|
||||
select(WorkflowRunModel)
|
||||
.join(WorkflowModel, WorkflowRunModel.workflow_id == WorkflowModel.id)
|
||||
.join(UserModel, WorkflowModel.user_id == UserModel.id)
|
||||
.where(
|
||||
UserModel.selected_organization_id == organization_id,
|
||||
WorkflowRunModel.cost_info.isnot(None),
|
||||
)
|
||||
.order_by(WorkflowRunModel.created_at.desc())
|
||||
)
|
||||
|
||||
# Apply date filters if provided
|
||||
if start_date:
|
||||
query = query.where(WorkflowRunModel.created_at >= start_date)
|
||||
if end_date:
|
||||
query = query.where(WorkflowRunModel.created_at <= end_date)
|
||||
|
||||
# Only allow specific filters for usage history endpoint
|
||||
# This ensures security and prevents unexpected filter attributes
|
||||
allowed_filters = {"duration", "dispositionCode", "phoneNumber"}
|
||||
sanitized_filters = []
|
||||
|
||||
if filters:
|
||||
for filter_item in filters:
|
||||
attribute = filter_item.get("attribute")
|
||||
|
||||
# Only process allowed filters
|
||||
if attribute in allowed_filters:
|
||||
sanitized_filters.append(filter_item)
|
||||
|
||||
# Apply filters using the common filter function
|
||||
query = apply_workflow_run_filters(query, sanitized_filters)
|
||||
|
||||
# Get total count
|
||||
count_result = await session.execute(
|
||||
select(func.count()).select_from(query.subquery())
|
||||
)
|
||||
total_count = count_result.scalar()
|
||||
|
||||
results = await session.execute(
|
||||
query.options(joinedload(WorkflowRunModel.workflow))
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
)
|
||||
runs = results.scalars().all()
|
||||
|
||||
# Format runs
|
||||
formatted_runs = []
|
||||
total_tokens = 0
|
||||
total_duration_seconds = 0
|
||||
for run in runs:
|
||||
if run.cost_info:
|
||||
# Try to get dograh_token_usage first (new format)
|
||||
dograh_tokens = run.cost_info.get("dograh_token_usage", 0)
|
||||
# If not present, calculate from total_cost_usd (old format)
|
||||
if dograh_tokens == 0 and "total_cost_usd" in run.cost_info:
|
||||
dograh_tokens = round(
|
||||
float(run.cost_info["total_cost_usd"]) * 100, 2
|
||||
)
|
||||
# Get call duration
|
||||
call_duration = run.cost_info.get("call_duration_seconds", 0)
|
||||
else:
|
||||
dograh_tokens = 0
|
||||
call_duration = 0
|
||||
total_tokens += dograh_tokens
|
||||
total_duration_seconds += int(round(call_duration))
|
||||
|
||||
# Extract phone number from initial_context
|
||||
phone_number = None
|
||||
if run.initial_context:
|
||||
phone_number = run.initial_context.get("phone_number")
|
||||
|
||||
# Extract disposition from gathered_context
|
||||
disposition = None
|
||||
if run.gathered_context:
|
||||
disposition = run.gathered_context.get("mapped_call_disposition")
|
||||
|
||||
run_data = {
|
||||
"id": run.id,
|
||||
"workflow_id": run.workflow_id,
|
||||
"workflow_name": run.workflow.name if run.workflow else None,
|
||||
"name": run.name,
|
||||
"created_at": run.created_at.isoformat(),
|
||||
"dograh_token_usage": dograh_tokens,
|
||||
"call_duration_seconds": int(round(call_duration)),
|
||||
"recording_url": run.recording_url,
|
||||
"transcript_url": run.transcript_url,
|
||||
"phone_number": phone_number,
|
||||
"disposition": disposition,
|
||||
"initial_context": run.initial_context,
|
||||
"gathered_context": run.gathered_context,
|
||||
}
|
||||
|
||||
# Add USD cost if available in cost_info
|
||||
if run.cost_info and "charge_usd" in run.cost_info:
|
||||
run_data["charge_usd"] = run.cost_info["charge_usd"]
|
||||
|
||||
formatted_runs.append(run_data)
|
||||
|
||||
return formatted_runs, total_count, total_tokens, total_duration_seconds
|
||||
|
||||
async def get_daily_usage_breakdown(
|
||||
self,
|
||||
organization_id: int,
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
price_per_second_usd: float,
|
||||
user_id: Optional[int] = None,
|
||||
) -> dict:
|
||||
"""Get daily usage breakdown for an organization with pricing."""
|
||||
|
||||
async with self.async_session() as session:
|
||||
# Get user timezone if user_id is provided
|
||||
user_timezone = "UTC" # Default timezone
|
||||
if user_id:
|
||||
config_result = await session.execute(
|
||||
select(UserConfigurationModel).where(
|
||||
UserConfigurationModel.user_id == user_id
|
||||
)
|
||||
)
|
||||
config_obj = config_result.scalar_one_or_none()
|
||||
if config_obj and config_obj.configuration:
|
||||
user_config = UserConfiguration.model_validate(
|
||||
config_obj.configuration
|
||||
)
|
||||
if user_config.timezone:
|
||||
user_timezone = user_config.timezone
|
||||
|
||||
# Validate timezone string
|
||||
try:
|
||||
# Test if timezone is valid
|
||||
ZoneInfo(user_timezone)
|
||||
except Exception:
|
||||
# Fallback to UTC if timezone is invalid
|
||||
user_timezone = "UTC"
|
||||
# Query to get daily aggregates
|
||||
# Use AT TIME ZONE to convert to user's timezone before grouping by date
|
||||
date_expr = cast(
|
||||
func.timezone(user_timezone, WorkflowRunModel.created_at), Date
|
||||
)
|
||||
|
||||
daily_usage = await session.execute(
|
||||
select(
|
||||
date_expr.label("date"),
|
||||
func.sum(
|
||||
WorkflowRunModel.cost_info["call_duration_seconds"].as_float()
|
||||
).label("total_seconds"),
|
||||
func.count(WorkflowRunModel.id).label("call_count"),
|
||||
)
|
||||
.join(WorkflowModel, WorkflowModel.id == WorkflowRunModel.workflow_id)
|
||||
.join(UserModel, UserModel.id == WorkflowModel.user_id)
|
||||
.where(
|
||||
UserModel.selected_organization_id == organization_id,
|
||||
WorkflowRunModel.created_at >= start_date,
|
||||
WorkflowRunModel.created_at <= end_date,
|
||||
WorkflowRunModel.is_completed == True,
|
||||
)
|
||||
.group_by(date_expr)
|
||||
.order_by(date_expr.desc())
|
||||
)
|
||||
|
||||
breakdown = []
|
||||
total_minutes = 0
|
||||
total_cost_usd = 0
|
||||
total_dograh_tokens = 0
|
||||
|
||||
for row in daily_usage:
|
||||
seconds = row.total_seconds or 0
|
||||
minutes = seconds / 60
|
||||
cost_usd = seconds * price_per_second_usd
|
||||
dograh_tokens = cost_usd * 100 # 1 cent = 1 token
|
||||
|
||||
total_minutes += minutes
|
||||
total_cost_usd += cost_usd
|
||||
total_dograh_tokens += dograh_tokens
|
||||
|
||||
breakdown.append(
|
||||
{
|
||||
"date": row.date.isoformat(),
|
||||
"minutes": round(minutes, 1),
|
||||
"cost_usd": round(cost_usd, 2),
|
||||
"dograh_tokens": round(dograh_tokens, 0),
|
||||
"call_count": row.call_count,
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"breakdown": breakdown,
|
||||
"total_minutes": round(total_minutes, 1),
|
||||
"total_cost_usd": round(total_cost_usd, 2),
|
||||
"total_dograh_tokens": round(total_dograh_tokens, 0),
|
||||
"currency": "USD",
|
||||
}
|
||||
|
||||
async def update_organization_quota(
|
||||
self,
|
||||
organization_id: int,
|
||||
quota_type: str,
|
||||
quota_dograh_tokens: int,
|
||||
quota_reset_day: Optional[int] = None,
|
||||
quota_start_date: Optional[datetime] = None,
|
||||
) -> OrganizationModel:
|
||||
"""Update organization quota settings."""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(OrganizationModel).where(OrganizationModel.id == organization_id)
|
||||
)
|
||||
org = result.scalar_one()
|
||||
|
||||
org.quota_type = quota_type
|
||||
org.quota_dograh_tokens = quota_dograh_tokens
|
||||
org.quota_enabled = True
|
||||
|
||||
if quota_type == "monthly" and quota_reset_day:
|
||||
org.quota_reset_day = quota_reset_day
|
||||
elif quota_type == "annual" and quota_start_date:
|
||||
org.quota_start_date = quota_start_date
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(org)
|
||||
return org
|
||||
|
||||
def _calculate_current_period(
|
||||
self, org: OrganizationModel
|
||||
) -> tuple[datetime, datetime]:
|
||||
"""Calculate the current billing period based on organization settings."""
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
if org.quota_type == "monthly":
|
||||
# Find the start of the current billing month
|
||||
reset_day = org.quota_reset_day
|
||||
|
||||
# Handle month boundaries
|
||||
if now.day >= reset_day:
|
||||
period_start = now.replace(
|
||||
day=reset_day, hour=0, minute=0, second=0, microsecond=0
|
||||
)
|
||||
else:
|
||||
# Previous month
|
||||
period_start = (now - relativedelta(months=1)).replace(
|
||||
day=reset_day, hour=0, minute=0, second=0, microsecond=0
|
||||
)
|
||||
|
||||
# End is one month later minus 1 second
|
||||
period_end = (
|
||||
period_start + relativedelta(months=1) - relativedelta(seconds=1)
|
||||
)
|
||||
|
||||
else: # annual
|
||||
if not org.quota_start_date:
|
||||
# Default to calendar year
|
||||
period_start = now.replace(
|
||||
month=1, day=1, hour=0, minute=0, second=0, microsecond=0
|
||||
)
|
||||
period_end = (
|
||||
period_start + relativedelta(years=1) - relativedelta(seconds=1)
|
||||
)
|
||||
else:
|
||||
# Find current annual period
|
||||
start_date = org.quota_start_date.replace(tzinfo=timezone.utc)
|
||||
years_diff = now.year - start_date.year
|
||||
|
||||
# Adjust for whether we've passed the anniversary
|
||||
if now.month < start_date.month or (
|
||||
now.month == start_date.month and now.day < start_date.day
|
||||
):
|
||||
years_diff -= 1
|
||||
|
||||
period_start = start_date + relativedelta(years=years_diff)
|
||||
period_end = (
|
||||
period_start + relativedelta(years=1) - relativedelta(seconds=1)
|
||||
)
|
||||
|
||||
return period_start, period_end
|
||||
156
api/db/reports_client.py
Normal file
156
api/db/reports_client.py
Normal file
|
|
@ -0,0 +1,156 @@
|
|||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from sqlalchemy import String, and_, func, select
|
||||
|
||||
from api.db.base_client import BaseDBClient
|
||||
from api.db.models import WorkflowModel, WorkflowRunModel
|
||||
|
||||
|
||||
class ReportsClient(BaseDBClient):
|
||||
async def get_workflow_runs_for_daily_report(
|
||||
self,
|
||||
organization_id: int,
|
||||
start_utc: datetime,
|
||||
end_utc: datetime,
|
||||
workflow_id: Optional[int] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Optimized method for daily reports - fetches only required JSON fields.
|
||||
Uses PostgreSQL JSON operators to extract only needed fields from JSON columns.
|
||||
|
||||
Args:
|
||||
organization_id: The organization ID to filter by
|
||||
start_utc: Start datetime in UTC
|
||||
end_utc: End datetime in UTC
|
||||
workflow_id: Optional workflow ID to filter by
|
||||
|
||||
Returns:
|
||||
List of dictionaries with report-specific fields
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
# Select only the specific JSON fields needed for daily reports
|
||||
# Using PostgreSQL's JSON operators to extract specific fields
|
||||
query = (
|
||||
select(
|
||||
WorkflowRunModel.id,
|
||||
WorkflowRunModel.workflow_id,
|
||||
WorkflowRunModel.created_at,
|
||||
# Extract only specific fields from JSON columns
|
||||
# Use TRIM and REPLACE to remove any quotes from JSON values
|
||||
func.coalesce(
|
||||
func.replace(
|
||||
func.replace(
|
||||
func.cast(
|
||||
WorkflowRunModel.gathered_context[
|
||||
"mapped_call_disposition"
|
||||
],
|
||||
String,
|
||||
),
|
||||
'"',
|
||||
"",
|
||||
),
|
||||
"'",
|
||||
"",
|
||||
),
|
||||
"UNKNOWN",
|
||||
).label("disposition"),
|
||||
func.coalesce(
|
||||
func.replace(
|
||||
func.replace(
|
||||
func.cast(
|
||||
WorkflowRunModel.gathered_context[
|
||||
"customer_phone_number"
|
||||
],
|
||||
String,
|
||||
),
|
||||
'"',
|
||||
"",
|
||||
),
|
||||
"'",
|
||||
"",
|
||||
),
|
||||
func.replace(
|
||||
func.replace(
|
||||
func.cast(
|
||||
WorkflowRunModel.initial_context["phone_number"],
|
||||
String,
|
||||
),
|
||||
'"',
|
||||
"",
|
||||
),
|
||||
"'",
|
||||
"",
|
||||
),
|
||||
"",
|
||||
).label("phone_number"),
|
||||
func.coalesce(
|
||||
func.replace(
|
||||
func.replace(
|
||||
func.cast(
|
||||
WorkflowRunModel.usage_info[
|
||||
"call_duration_seconds"
|
||||
],
|
||||
String,
|
||||
),
|
||||
'"',
|
||||
"",
|
||||
),
|
||||
"'",
|
||||
"",
|
||||
),
|
||||
"0",
|
||||
).label("call_duration_seconds"),
|
||||
WorkflowModel.name.label("workflow_name"),
|
||||
)
|
||||
.select_from(WorkflowRunModel)
|
||||
.join(WorkflowModel, WorkflowRunModel.workflow_id == WorkflowModel.id)
|
||||
.where(
|
||||
and_(
|
||||
WorkflowModel.organization_id == organization_id,
|
||||
WorkflowRunModel.created_at >= start_utc,
|
||||
WorkflowRunModel.created_at <= end_utc,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if workflow_id is not None:
|
||||
query = query.where(WorkflowRunModel.workflow_id == workflow_id)
|
||||
|
||||
result = await session.execute(query)
|
||||
rows = result.all()
|
||||
|
||||
return [
|
||||
{
|
||||
"id": row.id,
|
||||
"workflow_id": row.workflow_id,
|
||||
"workflow_name": row.workflow_name,
|
||||
"created_at": row.created_at,
|
||||
"gathered_context": {
|
||||
"mapped_call_disposition": row.disposition,
|
||||
"customer_phone_number": row.phone_number, # Also provide it here for compatibility
|
||||
},
|
||||
"usage_info": {"call_duration_seconds": row.call_duration_seconds},
|
||||
"initial_context": {"phone_number": row.phone_number},
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
|
||||
async def get_workflows_for_organization(
|
||||
self, organization_id: int
|
||||
) -> List[WorkflowModel]:
|
||||
"""
|
||||
Get all workflows for an organization.
|
||||
|
||||
Args:
|
||||
organization_id: The organization ID
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = (
|
||||
select(WorkflowModel)
|
||||
.where(WorkflowModel.organization_id == organization_id)
|
||||
.order_by(WorkflowModel.name)
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
return result.scalars().all()
|
||||
139
api/db/user_client.py
Normal file
139
api/db/user_client.py
Normal file
|
|
@ -0,0 +1,139 @@
|
|||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from api.db.base_client import BaseDBClient
|
||||
from api.db.models import UserConfigurationModel, UserModel
|
||||
from api.schemas.user_configuration import UserConfiguration
|
||||
|
||||
|
||||
class UserClient(BaseDBClient):
|
||||
async def get_or_create_user_by_provider_id(self, provider_id: str) -> UserModel:
|
||||
async with self.async_session() as session:
|
||||
# First try to get existing user
|
||||
result = await session.execute(
|
||||
select(UserModel).where(UserModel.provider_id == provider_id)
|
||||
)
|
||||
user = result.scalars().first()
|
||||
|
||||
if user is None:
|
||||
# Use PostgreSQL's INSERT ... ON CONFLICT DO NOTHING
|
||||
# This is atomic and handles race conditions at the database level
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
|
||||
stmt = insert(UserModel.__table__).values(
|
||||
provider_id=provider_id,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
selected_organization_id=None, # Will be set later
|
||||
is_superuser=False, # Default value
|
||||
)
|
||||
# ON CONFLICT DO NOTHING - if another request already inserted, this becomes a no-op
|
||||
stmt = stmt.on_conflict_do_nothing(index_elements=["provider_id"])
|
||||
|
||||
result = await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
# Now fetch the user (either the one we just created or the one that existed)
|
||||
result = await session.execute(
|
||||
select(UserModel).where(UserModel.provider_id == provider_id)
|
||||
)
|
||||
user = result.scalars().first()
|
||||
|
||||
if user is None:
|
||||
# This should never happen, but handle it just in case
|
||||
error_msg = (
|
||||
f"Failed to create or fetch user with provider_id {provider_id}"
|
||||
)
|
||||
raise ValueError(error_msg)
|
||||
return user
|
||||
|
||||
async def get_user_by_id(self, user_id: int) -> UserModel | None:
|
||||
"""Fetch a user by their internal ID."""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(UserModel).where(UserModel.id == user_id)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
async def get_user_configurations(self, user_id: int) -> UserConfiguration:
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(UserConfigurationModel).where(
|
||||
UserConfigurationModel.user_id == user_id
|
||||
)
|
||||
)
|
||||
configuration_obj = result.scalars().first()
|
||||
if not configuration_obj:
|
||||
return UserConfiguration()
|
||||
|
||||
return UserConfiguration.model_validate(
|
||||
{
|
||||
**configuration_obj.configuration,
|
||||
"last_validated_at": configuration_obj.last_validated_at,
|
||||
}
|
||||
)
|
||||
|
||||
async def update_user_configuration(
|
||||
self, user_id: int, configuration: UserConfiguration
|
||||
) -> UserConfiguration:
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(UserConfigurationModel).where(
|
||||
UserConfigurationModel.user_id == user_id
|
||||
)
|
||||
)
|
||||
configuration_obj = result.scalars().first()
|
||||
if not configuration_obj:
|
||||
configuration_obj = UserConfigurationModel(
|
||||
user_id=user_id, configuration=configuration.model_dump()
|
||||
)
|
||||
session.add(configuration_obj)
|
||||
else:
|
||||
configuration_obj.configuration = configuration.model_dump()
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
await session.refresh(configuration_obj)
|
||||
return UserConfiguration.model_validate(configuration_obj.configuration)
|
||||
|
||||
async def update_user_configuration_last_validated_at(self, user_id: int) -> None:
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(UserConfigurationModel).where(
|
||||
UserConfigurationModel.user_id == user_id
|
||||
)
|
||||
)
|
||||
configuration_obj = result.scalars().first()
|
||||
if not configuration_obj:
|
||||
raise ValueError(f"User configuration with ID {user_id} not found")
|
||||
configuration_obj.last_validated_at = datetime.now()
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
await session.refresh(configuration_obj)
|
||||
|
||||
async def update_user_selected_organization(
|
||||
self, user_id: int, organization_id: int
|
||||
) -> None:
|
||||
"""Update the user's selected organization ID."""
|
||||
async with self.async_session() as session:
|
||||
from sqlalchemy import update
|
||||
|
||||
# Use a direct UPDATE statement to avoid race conditions
|
||||
# This is atomic at the database level
|
||||
stmt = (
|
||||
update(UserModel)
|
||||
.where(UserModel.id == user_id)
|
||||
.values(selected_organization_id=organization_id)
|
||||
)
|
||||
|
||||
result = await session.execute(stmt)
|
||||
|
||||
if result.rowcount == 0:
|
||||
raise ValueError(f"User with ID {user_id} not found")
|
||||
|
||||
await session.commit()
|
||||
312
api/db/workflow_client.py
Normal file
312
api/db/workflow_client.py
Normal file
|
|
@ -0,0 +1,312 @@
|
|||
import hashlib
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from api.db.base_client import BaseDBClient
|
||||
from api.db.models import WorkflowDefinitionModel, WorkflowModel, WorkflowRunModel
|
||||
|
||||
|
||||
class WorkflowClient(BaseDBClient):
|
||||
def _generate_workflow_hash(self, workflow_definition: dict) -> str:
|
||||
"""Generate a consistent hash for workflow definition."""
|
||||
# Convert to JSON with sorted keys for consistent hashing
|
||||
json_str = json.dumps(
|
||||
workflow_definition, sort_keys=True, separators=(",", ":")
|
||||
)
|
||||
return hashlib.sha256(json_str.encode()).hexdigest()
|
||||
|
||||
async def _get_or_create_workflow_definition(
|
||||
self, workflow_definition: dict, session, workflow_id: int = None
|
||||
) -> WorkflowDefinitionModel:
|
||||
"""Get existing workflow definition by hash or create a new one."""
|
||||
workflow_hash = self._generate_workflow_hash(workflow_definition)
|
||||
|
||||
# Try to find existing definition
|
||||
result = await session.execute(
|
||||
select(WorkflowDefinitionModel).where(
|
||||
WorkflowDefinitionModel.workflow_hash == workflow_hash,
|
||||
WorkflowDefinitionModel.workflow_id == workflow_id,
|
||||
)
|
||||
)
|
||||
existing_definition = result.scalars().first()
|
||||
|
||||
if existing_definition:
|
||||
return existing_definition
|
||||
|
||||
# Create new definition if it doesn't exist
|
||||
new_definition = WorkflowDefinitionModel(
|
||||
workflow_hash=workflow_hash,
|
||||
workflow_json=workflow_definition,
|
||||
workflow_id=workflow_id,
|
||||
)
|
||||
session.add(new_definition)
|
||||
await session.flush() # Flush to get the ID without committing
|
||||
return new_definition
|
||||
|
||||
async def create_workflow(
|
||||
self,
|
||||
name: str,
|
||||
workflow_definition: dict,
|
||||
user_id: int,
|
||||
organization_id: int = None,
|
||||
) -> WorkflowModel:
|
||||
async with self.async_session() as session:
|
||||
try:
|
||||
new_workflow = WorkflowModel(
|
||||
name=name,
|
||||
workflow_definition=workflow_definition, # Keep for backwards compatibility
|
||||
user_id=user_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
session.add(new_workflow)
|
||||
await session.flush() # Flush to get the workflow ID
|
||||
|
||||
# Now get or create workflow definition with the workflow_id
|
||||
definition = await self._get_or_create_workflow_definition(
|
||||
workflow_definition, session, new_workflow.id
|
||||
)
|
||||
|
||||
# Mark this definition as the current one and unset others
|
||||
definition.is_current = True
|
||||
# Set any other definitions for this workflow to not current
|
||||
other_defs_result = await session.execute(
|
||||
select(WorkflowDefinitionModel).where(
|
||||
WorkflowDefinitionModel.workflow_id == new_workflow.id,
|
||||
WorkflowDefinitionModel.id != definition.id,
|
||||
)
|
||||
)
|
||||
for other_def in other_defs_result.scalars().all():
|
||||
other_def.is_current = False
|
||||
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
await session.refresh(new_workflow)
|
||||
return new_workflow
|
||||
|
||||
async def get_all_workflows(
|
||||
self, user_id: int = None, organization_id: int = None, status: str = None
|
||||
) -> list[WorkflowModel]:
|
||||
async with self.async_session() as session:
|
||||
query = select(WorkflowModel).options(
|
||||
selectinload(WorkflowModel.current_definition)
|
||||
)
|
||||
|
||||
if organization_id:
|
||||
# Filter by organization_id when provided
|
||||
query = query.where(WorkflowModel.organization_id == organization_id)
|
||||
elif user_id:
|
||||
# Fallback to user_id for backwards compatibility
|
||||
query = query.where(WorkflowModel.user_id == user_id)
|
||||
|
||||
# Filter by status if provided
|
||||
if status:
|
||||
query = query.where(WorkflowModel.status == status)
|
||||
|
||||
result = await session.execute(query)
|
||||
return result.scalars().all()
|
||||
|
||||
async def get_workflow(
|
||||
self, workflow_id: int, user_id: int = None, organization_id: int = None
|
||||
) -> WorkflowModel | None:
|
||||
async with self.async_session() as session:
|
||||
query = (
|
||||
select(WorkflowModel)
|
||||
.options(selectinload(WorkflowModel.current_definition))
|
||||
.where(WorkflowModel.id == workflow_id)
|
||||
)
|
||||
|
||||
if organization_id:
|
||||
# Filter by organization_id when provided
|
||||
query = query.where(WorkflowModel.organization_id == organization_id)
|
||||
elif user_id:
|
||||
# Fallback to user_id for backwards compatibility
|
||||
query = query.where(WorkflowModel.user_id == user_id)
|
||||
|
||||
result = await session.execute(query)
|
||||
return result.scalars().first()
|
||||
|
||||
async def get_workflow_by_id(self, workflow_id: int) -> WorkflowModel | None:
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(WorkflowModel)
|
||||
.options(selectinload(WorkflowModel.current_definition))
|
||||
.where(WorkflowModel.id == workflow_id)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
async def update_workflow(
|
||||
self,
|
||||
workflow_id: int,
|
||||
name: str,
|
||||
workflow_definition: dict | None,
|
||||
template_context_variables: dict | None,
|
||||
workflow_configurations: dict | None,
|
||||
user_id: int = None,
|
||||
organization_id: int = None,
|
||||
) -> WorkflowModel:
|
||||
"""
|
||||
Update an existing workflow in the database.
|
||||
|
||||
Args:
|
||||
workflow_id: The ID of the workflow to update
|
||||
name: The new name for the workflow
|
||||
workflow_definition: The new workflow definition
|
||||
template_context_variables: The template context variables
|
||||
user_id: The user ID (for backwards compatibility)
|
||||
organization_id: The organization ID
|
||||
|
||||
Returns:
|
||||
The updated WorkflowModel
|
||||
|
||||
Raises:
|
||||
ValueError: If the workflow with the given ID is not found
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = (
|
||||
select(WorkflowModel)
|
||||
.options(selectinload(WorkflowModel.current_definition))
|
||||
.where(WorkflowModel.id == workflow_id)
|
||||
)
|
||||
|
||||
if organization_id:
|
||||
# Filter by organization_id when provided
|
||||
query = query.where(WorkflowModel.organization_id == organization_id)
|
||||
elif user_id:
|
||||
# Fallback to user_id for backwards compatibility
|
||||
query = query.where(WorkflowModel.user_id == user_id)
|
||||
|
||||
result = await session.execute(query)
|
||||
workflow = result.scalars().first()
|
||||
if not workflow:
|
||||
raise ValueError(f"Workflow with ID {workflow_id} not found")
|
||||
|
||||
workflow.name = name
|
||||
|
||||
if template_context_variables is not None:
|
||||
workflow.template_context_variables = template_context_variables
|
||||
|
||||
if workflow_configurations is not None:
|
||||
workflow.workflow_configurations = workflow_configurations
|
||||
|
||||
# In case of only name update, the workflow_definition can be None
|
||||
if workflow_definition:
|
||||
# Get or create new workflow definition
|
||||
definition = await self._get_or_create_workflow_definition(
|
||||
workflow_definition, session, workflow_id
|
||||
)
|
||||
|
||||
# Update legacy field for backwards compatibility
|
||||
workflow.workflow_definition = workflow_definition
|
||||
|
||||
# Mark new definition as current and reset others
|
||||
definition.is_current = True
|
||||
other_defs_result = await session.execute(
|
||||
select(WorkflowDefinitionModel).where(
|
||||
WorkflowDefinitionModel.workflow_id == workflow_id,
|
||||
WorkflowDefinitionModel.id != definition.id,
|
||||
)
|
||||
)
|
||||
for other_def in other_defs_result.scalars().all():
|
||||
other_def.is_current = False
|
||||
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
await session.refresh(workflow)
|
||||
return workflow
|
||||
|
||||
async def get_workflows_by_ids(
|
||||
self, workflow_ids: list[int], organization_id: int
|
||||
) -> list[WorkflowModel]:
|
||||
"""Get workflows by IDs for a specific organization"""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(WorkflowModel)
|
||||
.join(WorkflowModel.user)
|
||||
.where(
|
||||
WorkflowModel.id.in_(workflow_ids),
|
||||
WorkflowModel.user.has(selected_organization_id=organization_id),
|
||||
)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
async def get_workflow_name(
|
||||
self, workflow_id: int, user_id: int = None, organization_id: int = None
|
||||
) -> Optional[str]:
|
||||
"""Get just the workflow name by ID"""
|
||||
async with self.async_session() as session:
|
||||
query = select(WorkflowModel.name).where(WorkflowModel.id == workflow_id)
|
||||
|
||||
if organization_id:
|
||||
# Filter by organization_id when provided
|
||||
query = query.where(WorkflowModel.organization_id == organization_id)
|
||||
elif user_id:
|
||||
# Fallback to user_id for backwards compatibility
|
||||
query = query.where(WorkflowModel.user_id == user_id)
|
||||
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def update_workflow_status(
|
||||
self,
|
||||
workflow_id: int,
|
||||
status: str,
|
||||
organization_id: int = None,
|
||||
) -> WorkflowModel:
|
||||
"""
|
||||
Update the status of a workflow.
|
||||
|
||||
Args:
|
||||
workflow_id: The ID of the workflow to update
|
||||
status: The new status (active/archived)
|
||||
organization_id: The organization ID
|
||||
|
||||
Returns:
|
||||
The updated WorkflowModel
|
||||
|
||||
Raises:
|
||||
ValueError: If the workflow is not found
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = (
|
||||
select(WorkflowModel)
|
||||
.options(selectinload(WorkflowModel.current_definition))
|
||||
.where(WorkflowModel.id == workflow_id)
|
||||
)
|
||||
|
||||
if organization_id:
|
||||
query = query.where(WorkflowModel.organization_id == organization_id)
|
||||
|
||||
result = await session.execute(query)
|
||||
workflow = result.scalars().first()
|
||||
|
||||
if not workflow:
|
||||
raise ValueError(f"Workflow with ID {workflow_id} not found")
|
||||
|
||||
workflow.status = status
|
||||
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
await session.refresh(workflow)
|
||||
return workflow
|
||||
|
||||
async def get_workflow_run_count(self, workflow_id: int) -> int:
|
||||
"""Get the count of runs for a workflow."""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(func.count(WorkflowRunModel.id)).where(
|
||||
WorkflowRunModel.workflow_id == workflow_id
|
||||
)
|
||||
)
|
||||
return result.scalar() or 0
|
||||
404
api/db/workflow_run_client.py
Normal file
404
api/db/workflow_run_client.py
Normal file
|
|
@ -0,0 +1,404 @@
|
|||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import joinedload, selectinload
|
||||
|
||||
from api.db.base_client import BaseDBClient
|
||||
from api.db.filters import apply_workflow_run_filters
|
||||
from api.db.models import (
|
||||
OrganizationModel,
|
||||
UserModel,
|
||||
WorkflowDefinitionModel,
|
||||
WorkflowModel,
|
||||
WorkflowRunModel,
|
||||
)
|
||||
from api.schemas.workflow import WorkflowRunResponseSchema
|
||||
|
||||
|
||||
class WorkflowRunClient(BaseDBClient):
|
||||
async def create_workflow_run(
|
||||
self,
|
||||
name: str,
|
||||
workflow_id: int,
|
||||
mode: str,
|
||||
user_id: int,
|
||||
initial_context: dict = None,
|
||||
campaign_id: int = None,
|
||||
queued_run_id: int = None,
|
||||
) -> WorkflowRunModel:
|
||||
async with self.async_session() as session:
|
||||
# Get workflow and user to check organization
|
||||
workflow = await session.execute(
|
||||
select(WorkflowModel)
|
||||
.options(joinedload(WorkflowModel.user))
|
||||
.where(
|
||||
WorkflowModel.id == workflow_id, WorkflowModel.user_id == user_id
|
||||
)
|
||||
)
|
||||
workflow = workflow.scalars().first()
|
||||
if not workflow:
|
||||
raise ValueError(f"Workflow with ID {workflow_id} not found")
|
||||
|
||||
# # Check quota if user has an organization
|
||||
# if workflow.user and workflow.user.selected_organization_id:
|
||||
# # Import here to avoid circular dependency
|
||||
# from api.db.organization_usage_client import OrganizationUsageClient
|
||||
|
||||
# usage_client = OrganizationUsageClient()
|
||||
|
||||
# # Check quota (no reservation for now, actual cost will be added after completion)
|
||||
# has_quota = await usage_client.check_and_reserve_quota(
|
||||
# workflow.user.selected_organization_id, estimated_tokens=0
|
||||
# )
|
||||
|
||||
# if not has_quota:
|
||||
# raise ValueError(
|
||||
# "Organization quota exceeded. Please contact your administrator."
|
||||
# )
|
||||
|
||||
# Fetch the current definition for this workflow
|
||||
current_def_result = await session.execute(
|
||||
select(WorkflowDefinitionModel).where(
|
||||
WorkflowDefinitionModel.workflow_id == workflow.id,
|
||||
WorkflowDefinitionModel.is_current == True,
|
||||
)
|
||||
)
|
||||
current_def = current_def_result.scalars().first()
|
||||
|
||||
new_run = WorkflowRunModel(
|
||||
name=name,
|
||||
workflow=workflow,
|
||||
mode=mode,
|
||||
definition_id=current_def.id if current_def else None,
|
||||
initial_context=initial_context or workflow.template_context_variables,
|
||||
campaign_id=campaign_id,
|
||||
queued_run_id=queued_run_id,
|
||||
)
|
||||
session.add(new_run)
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
await session.refresh(new_run)
|
||||
return new_run
|
||||
|
||||
async def get_all_workflow_runs(self) -> list[WorkflowRunModel]:
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(select(WorkflowRunModel))
|
||||
return result.scalars().all()
|
||||
|
||||
async def get_workflow_runs_for_superadmin(
|
||||
self,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
filters: Optional[List[Dict[str, Any]]] = None,
|
||||
) -> tuple[list[dict], int]:
|
||||
"""
|
||||
Get paginated workflow runs for superadmin with organization information.
|
||||
Returns tuple of (workflow_runs, total_count).
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
# Build base query with joins
|
||||
base_query = (
|
||||
select(WorkflowRunModel)
|
||||
.join(WorkflowModel, WorkflowRunModel.workflow_id == WorkflowModel.id)
|
||||
.join(UserModel, WorkflowModel.user_id == UserModel.id)
|
||||
.outerjoin(
|
||||
OrganizationModel,
|
||||
UserModel.selected_organization_id == OrganizationModel.id,
|
||||
)
|
||||
)
|
||||
|
||||
# Apply filters
|
||||
base_query = apply_workflow_run_filters(base_query, filters)
|
||||
|
||||
# Count total with filters
|
||||
count_query = base_query.with_only_columns(func.count(WorkflowRunModel.id))
|
||||
count_result = await session.execute(count_query)
|
||||
total_count = count_result.scalar()
|
||||
|
||||
# Get paginated results with filters
|
||||
result = await session.execute(
|
||||
base_query.options(
|
||||
joinedload(WorkflowRunModel.workflow).joinedload(
|
||||
WorkflowModel.user
|
||||
),
|
||||
joinedload(WorkflowRunModel.workflow)
|
||||
.joinedload(WorkflowModel.user)
|
||||
.joinedload(UserModel.selected_organization),
|
||||
)
|
||||
.order_by(WorkflowRunModel.created_at.desc())
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
)
|
||||
workflow_runs = result.scalars().all()
|
||||
|
||||
# Format the response
|
||||
formatted_runs = []
|
||||
for run in workflow_runs:
|
||||
organization = (
|
||||
run.workflow.user.selected_organization
|
||||
if run.workflow.user
|
||||
else None
|
||||
)
|
||||
formatted_runs.append(
|
||||
{
|
||||
"id": run.id,
|
||||
"name": run.name,
|
||||
"workflow_id": run.workflow_id,
|
||||
"workflow_name": run.workflow.name if run.workflow else None,
|
||||
"user_id": run.workflow.user_id if run.workflow else None,
|
||||
"organization_id": organization.id if organization else None,
|
||||
"organization_name": organization.provider_id
|
||||
if organization
|
||||
else None,
|
||||
"mode": run.mode,
|
||||
"is_completed": run.is_completed,
|
||||
"recording_url": run.recording_url,
|
||||
"transcript_url": run.transcript_url,
|
||||
"usage_info": run.usage_info,
|
||||
"cost_info": run.cost_info,
|
||||
"initial_context": run.initial_context,
|
||||
"gathered_context": run.gathered_context,
|
||||
"admin_comment": (run.annotations or {}).get("admin_comment"),
|
||||
"admin_comment_ts": (run.annotations or {}).get(
|
||||
"admin_comment_ts"
|
||||
),
|
||||
"created_at": run.created_at,
|
||||
}
|
||||
)
|
||||
|
||||
return formatted_runs, total_count
|
||||
|
||||
async def get_workflow_run(
|
||||
self, run_id: int, user_id: int = None, organization_id: int = None
|
||||
) -> WorkflowRunModel | None:
|
||||
async with self.async_session() as session:
|
||||
query = select(WorkflowRunModel).join(WorkflowRunModel.workflow)
|
||||
|
||||
if organization_id:
|
||||
# Filter by organization_id when provided
|
||||
query = query.where(
|
||||
WorkflowRunModel.id == run_id,
|
||||
WorkflowModel.organization_id == organization_id,
|
||||
)
|
||||
elif user_id:
|
||||
# Fallback to user_id for backwards compatibility
|
||||
query = query.where(
|
||||
WorkflowRunModel.id == run_id,
|
||||
WorkflowModel.user_id == user_id,
|
||||
)
|
||||
else:
|
||||
query = query.where(WorkflowRunModel.id == run_id)
|
||||
|
||||
result = await session.execute(query)
|
||||
return result.scalars().first()
|
||||
|
||||
async def get_workflow_run_by_id(self, run_id: int) -> WorkflowRunModel | None:
|
||||
"""Get workflow run by ID without user filtering - for background tasks"""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(WorkflowRunModel)
|
||||
.options(
|
||||
joinedload(WorkflowRunModel.workflow).joinedload(WorkflowModel.user)
|
||||
)
|
||||
.where(WorkflowRunModel.id == run_id)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
async def get_workflow_runs_by_workflow_id(
|
||||
self,
|
||||
workflow_id: int,
|
||||
user_id: int = None,
|
||||
organization_id: int = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
filters: Optional[List[Dict[str, Any]]] = None,
|
||||
) -> tuple[list[WorkflowRunResponseSchema], int]:
|
||||
async with self.async_session() as session:
|
||||
# Build base query
|
||||
base_query = (
|
||||
select(WorkflowRunModel)
|
||||
.join(WorkflowModel, WorkflowRunModel.workflow_id == WorkflowModel.id)
|
||||
.where(WorkflowRunModel.workflow_id == workflow_id)
|
||||
)
|
||||
|
||||
if organization_id:
|
||||
# Filter by organization_id when provided
|
||||
base_query = base_query.where(
|
||||
WorkflowModel.organization_id == organization_id
|
||||
)
|
||||
elif user_id:
|
||||
# Fallback to user_id for backwards compatibility
|
||||
base_query = base_query.where(WorkflowModel.user_id == user_id)
|
||||
|
||||
# Apply filters
|
||||
base_query = apply_workflow_run_filters(base_query, filters)
|
||||
|
||||
# Count total with filters
|
||||
count_query = base_query.with_only_columns(func.count(WorkflowRunModel.id))
|
||||
count_result = await session.execute(count_query)
|
||||
total_count = count_result.scalar()
|
||||
|
||||
# Get paginated results with filters
|
||||
result = await session.execute(
|
||||
base_query.order_by(WorkflowRunModel.created_at.desc())
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
)
|
||||
runs = [
|
||||
WorkflowRunResponseSchema.model_validate(
|
||||
{
|
||||
"id": run.id,
|
||||
"workflow_id": run.workflow_id,
|
||||
"name": run.name,
|
||||
"mode": run.mode,
|
||||
"created_at": run.created_at,
|
||||
"is_completed": run.is_completed,
|
||||
"recording_url": run.recording_url,
|
||||
"transcript_url": run.transcript_url,
|
||||
"cost_info": {
|
||||
"dograh_token_usage": (
|
||||
run.cost_info.get("dograh_token_usage")
|
||||
if run.cost_info
|
||||
and "dograh_token_usage" in run.cost_info
|
||||
else round(
|
||||
float(run.cost_info.get("total_cost_usd", 0)) * 100,
|
||||
2,
|
||||
)
|
||||
if run.cost_info and "total_cost_usd" in run.cost_info
|
||||
else 0
|
||||
),
|
||||
"call_duration_seconds": int(
|
||||
round(run.cost_info.get("call_duration_seconds") or 0)
|
||||
)
|
||||
if run.cost_info
|
||||
else None,
|
||||
}
|
||||
if run.cost_info
|
||||
else None,
|
||||
"definition_id": run.definition_id,
|
||||
"initial_context": run.initial_context,
|
||||
"gathered_context": run.gathered_context,
|
||||
}
|
||||
)
|
||||
for run in result.scalars().all()
|
||||
]
|
||||
return runs, total_count
|
||||
|
||||
async def update_workflow_run(
|
||||
self,
|
||||
run_id: int,
|
||||
is_completed: bool = False,
|
||||
recording_url: str | None = None,
|
||||
transcript_url: str | None = None,
|
||||
storage_backend: str | None = None,
|
||||
usage_info: dict | None = None,
|
||||
cost_info: dict | None = None,
|
||||
initial_context: dict | None = None,
|
||||
gathered_context: dict | None = None,
|
||||
logs: dict | None = None,
|
||||
) -> WorkflowRunModel:
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(WorkflowRunModel).where(WorkflowRunModel.id == run_id)
|
||||
)
|
||||
run = result.scalars().first()
|
||||
if not run:
|
||||
raise ValueError(f"Workflow run with ID {run_id} not found")
|
||||
if recording_url:
|
||||
run.recording_url = recording_url
|
||||
if transcript_url:
|
||||
run.transcript_url = transcript_url
|
||||
if storage_backend:
|
||||
run.storage_backend = storage_backend
|
||||
if usage_info:
|
||||
run.usage_info = usage_info
|
||||
if cost_info:
|
||||
run.cost_info = cost_info
|
||||
if initial_context:
|
||||
run.initial_context = initial_context
|
||||
if gathered_context:
|
||||
# Lets merge the incoming gathered context keys with the existing ones
|
||||
run.gathered_context = {
|
||||
**run.gathered_context,
|
||||
**gathered_context,
|
||||
}
|
||||
if logs:
|
||||
# Lets merge the incoming logs key with existing ones
|
||||
run.logs = {**run.logs, **logs}
|
||||
if is_completed:
|
||||
run.is_completed = is_completed
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
await session.refresh(run)
|
||||
return run
|
||||
|
||||
async def update_admin_comment(
|
||||
self, run_id: int, admin_comment: str
|
||||
) -> WorkflowRunModel:
|
||||
"""Update (or create) the admin comment inside the ``annotations`` JSON column.
|
||||
|
||||
The comment is stored under the key ``admin_comment`` so we do not
|
||||
overwrite any other existing annotations that may be present.
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(WorkflowRunModel).where(WorkflowRunModel.id == run_id)
|
||||
)
|
||||
run = result.scalars().first()
|
||||
if run is None:
|
||||
raise ValueError(f"Workflow run with ID {run_id} not found")
|
||||
|
||||
# Ensure we never mutate a shared dict between instances
|
||||
current_annotations = dict(run.annotations or {})
|
||||
current_annotations["admin_comment"] = admin_comment
|
||||
|
||||
current_annotations["admin_comment_ts"] = datetime.now(
|
||||
timezone.utc
|
||||
).isoformat()
|
||||
run.annotations = current_annotations
|
||||
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
await session.refresh(run)
|
||||
return run
|
||||
|
||||
async def get_workflow_run_with_context(
|
||||
self, workflow_run_id: int
|
||||
) -> Tuple[Optional[WorkflowRunModel], Optional[int]]:
|
||||
"""
|
||||
Get workflow run with all related data and return organization_id.
|
||||
|
||||
Returns:
|
||||
Tuple of (workflow_run, organization_id) or (None, None) if not found
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(WorkflowRunModel)
|
||||
.options(
|
||||
selectinload(WorkflowRunModel.workflow).selectinload(
|
||||
WorkflowModel.user
|
||||
)
|
||||
)
|
||||
.where(WorkflowRunModel.id == workflow_run_id)
|
||||
)
|
||||
workflow_run = result.scalars().first()
|
||||
|
||||
if not workflow_run:
|
||||
return None, None
|
||||
|
||||
if not workflow_run.workflow or not workflow_run.workflow.user:
|
||||
return workflow_run, None
|
||||
|
||||
organization_id = workflow_run.workflow.user.selected_organization_id
|
||||
return workflow_run, organization_id
|
||||
99
api/db/workflow_template_client.py
Normal file
99
api/db/workflow_template_client.py
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
from sqlalchemy.future import select
|
||||
|
||||
from api.db.base_client import BaseDBClient
|
||||
from api.db.models import WorkflowTemplates
|
||||
|
||||
|
||||
class WorkflowTemplateClient(BaseDBClient):
|
||||
async def get_workflow_template(self, template_id: int) -> WorkflowTemplates | None:
|
||||
"""Get a workflow template by ID."""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(WorkflowTemplates).where(WorkflowTemplates.id == template_id)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
async def get_workflow_template_by_name(
|
||||
self, template_name: str
|
||||
) -> WorkflowTemplates | None:
|
||||
"""Get a workflow template by name."""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(WorkflowTemplates).where(
|
||||
WorkflowTemplates.template_name == template_name
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
async def get_all_workflow_templates(self) -> list[WorkflowTemplates]:
|
||||
"""Get all workflow templates."""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(select(WorkflowTemplates))
|
||||
return result.scalars().all()
|
||||
|
||||
async def create_workflow_template(
|
||||
self, template_name: str, template_description: str, template_json: dict
|
||||
) -> WorkflowTemplates:
|
||||
"""Create a new workflow template."""
|
||||
async with self.async_session() as session:
|
||||
try:
|
||||
new_template = WorkflowTemplates(
|
||||
template_name=template_name,
|
||||
template_description=template_description,
|
||||
template_json=template_json,
|
||||
)
|
||||
session.add(new_template)
|
||||
await session.commit()
|
||||
await session.refresh(new_template)
|
||||
return new_template
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
|
||||
async def update_workflow_template(
|
||||
self,
|
||||
template_id: int,
|
||||
template_name: str | None = None,
|
||||
template_json: dict | None = None,
|
||||
) -> WorkflowTemplates:
|
||||
"""Update an existing workflow template."""
|
||||
async with self.async_session() as session:
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(WorkflowTemplates).where(WorkflowTemplates.id == template_id)
|
||||
)
|
||||
template = result.scalars().first()
|
||||
if not template:
|
||||
raise ValueError(
|
||||
f"Workflow template with ID {template_id} not found"
|
||||
)
|
||||
|
||||
if template_name is not None:
|
||||
template.template_name = template_name
|
||||
if template_json is not None:
|
||||
template.template_json = template_json
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(template)
|
||||
return template
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
|
||||
async def delete_workflow_template(self, template_id: int) -> bool:
|
||||
"""Delete a workflow template by ID."""
|
||||
async with self.async_session() as session:
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(WorkflowTemplates).where(WorkflowTemplates.id == template_id)
|
||||
)
|
||||
template = result.scalars().first()
|
||||
if not template:
|
||||
return False
|
||||
|
||||
await session.delete(template)
|
||||
await session.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
Loading…
Add table
Add a link
Reference in a new issue