import uuid
from typing import Any, Dict, List, Optional, Tuple

from sqlalchemy import func
from sqlalchemy.future import select
from sqlalchemy.orm import joinedload, selectinload

from api.db.base_client import BaseDBClient
from api.db.filters import apply_workflow_run_filters, get_workflow_run_order_clause
from api.db.models import (
    OrganizationModel,
    UserModel,
    WorkflowDefinitionModel,
    WorkflowModel,
    WorkflowRunModel,
)
from api.enums import CallType, StorageBackend
from api.schemas.workflow import WorkflowRunResponseSchema


async def _commit_or_rollback(session) -> None:
    """Commit ``session``; on any ``Exception`` roll back and re-raise it."""
    try:
        await session.commit()
    except Exception:
        await session.rollback()
        raise


def _summarize_cost_info(cost_info: dict | None) -> dict | None:
    """Reduce a run's raw ``cost_info`` to the fields returned in run listings.

    Returns ``None`` when ``cost_info`` is empty or ``None``. Otherwise returns
    a dict with:

    * ``dograh_token_usage`` - the stored ``dograh_token_usage`` value when
      that key is present; else ``total_cost_usd * 100`` rounded to two
      decimals when ``total_cost_usd`` is present; else ``0``.
    * ``call_duration_seconds`` - the stored duration rounded to an ``int``
      (``0`` when missing or falsy).
    """
    if not cost_info:
        return None

    if "dograh_token_usage" in cost_info:
        token_usage = cost_info.get("dograh_token_usage")
    elif "total_cost_usd" in cost_info:
        # ``or 0`` guards against the key being present with a None value,
        # which would otherwise make float() raise TypeError.
        token_usage = round(float(cost_info.get("total_cost_usd") or 0) * 100, 2)
    else:
        token_usage = 0

    return {
        "dograh_token_usage": token_usage,
        "call_duration_seconds": int(
            round(cost_info.get("call_duration_seconds") or 0)
        ),
    }


class WorkflowRunClient(BaseDBClient):
    """Database access methods for ``WorkflowRunModel`` rows."""

    async def create_workflow_run(
        self,
        name: str,
        workflow_id: int,
        mode: str,
        user_id: int,
        call_type: CallType = CallType.OUTBOUND,
        initial_context: dict | None = None,
        gathered_context: dict | None = None,
        campaign_id: int | None = None,
        queued_run_id: int | None = None,
    ) -> WorkflowRunModel:
        """Create and persist a new run for a workflow owned by ``user_id``.

        The run is linked to the workflow's current definition (if one is
        flagged ``is_current``) and to the storage backend reported by
        ``StorageBackend.get_current_backend()``.

        Args:
            name: Name of the run.
            workflow_id: ID of the workflow to run.
            mode: Run mode, stored as-is on the run.
            user_id: The workflow must belong to this user.
            call_type: Call type; its ``.value`` is stored on the run.
            initial_context: Initial context; when falsy, the workflow's
                ``template_context_variables`` is used instead.
            gathered_context: Gathered context; defaults to ``{}``.
            campaign_id: Optional campaign ID stored on the run.
            queued_run_id: Optional queued-run ID stored on the run.

        Returns:
            The refreshed, persisted ``WorkflowRunModel``.

        Raises:
            ValueError: If no workflow matches ``workflow_id`` and ``user_id``.
        """
        async with self.async_session() as session:
            # Get workflow and user to check organization
            workflow_result = await session.execute(
                select(WorkflowModel)
                .options(joinedload(WorkflowModel.user))
                .where(
                    WorkflowModel.id == workflow_id,
                    WorkflowModel.user_id == user_id,
                )
            )
            workflow = workflow_result.scalars().first()
            if not workflow:
                raise ValueError(f"Workflow with ID {workflow_id} not found")

            # # Check quota if user has an organization
            # if workflow.user and workflow.user.selected_organization_id:
            #     # Import here to avoid circular dependency
            #     from api.db.organization_usage_client import OrganizationUsageClient
            #     usage_client = OrganizationUsageClient()
            #     # Check quota (no reservation for now, actual cost will be added after completion)
            #     has_quota = await usage_client.check_and_reserve_quota(
            #         workflow.user.selected_organization_id, estimated_tokens=0
            #     )
            #     if not has_quota:
            #         raise ValueError(
            #             "Organization quota exceeded. Please contact your administrator."
            #         )

            # Fetch the current definition for this workflow
            current_def_result = await session.execute(
                select(WorkflowDefinitionModel).where(
                    WorkflowDefinitionModel.workflow_id == workflow.id,
                    WorkflowDefinitionModel.is_current == True,  # noqa: E712
                )
            )
            current_def = current_def_result.scalars().first()

            # Record which storage backend is current at creation time
            current_backend = StorageBackend.get_current_backend()

            new_run = WorkflowRunModel(
                name=name,
                workflow=workflow,
                mode=mode,
                definition_id=current_def.id if current_def else None,
                initial_context=initial_context
                or workflow.template_context_variables,
                gathered_context=gathered_context or {},
                campaign_id=campaign_id,
                queued_run_id=queued_run_id,
                storage_backend=current_backend.value,
                call_type=call_type.value,
            )
            session.add(new_run)
            await _commit_or_rollback(session)
            await session.refresh(new_run)
            return new_run

    async def get_all_workflow_runs(self) -> list[WorkflowRunModel]:
        """Return every workflow run, unfiltered and unpaginated."""
        async with self.async_session() as session:
            result = await session.execute(select(WorkflowRunModel))
            return result.scalars().all()

    async def get_workflow_runs_for_superadmin(
        self,
        limit: int = 50,
        offset: int = 0,
        filters: Optional[List[Dict[str, Any]]] = None,
        sort_by: Optional[str] = None,
        sort_order: str = "desc",
    ) -> tuple[list[dict], int]:
        """
        Get paginated workflow runs for superadmin with organization information.

        Args:
            limit: Maximum number of runs to return.
            offset: Number of runs to skip.
            filters: Filter specs passed to ``apply_workflow_run_filters``.
            sort_by: Field to sort by ('duration', 'created_at', etc.)
            sort_order: 'asc' or 'desc'

        Returns:
            Tuple of (workflow_runs, total_count), where ``workflow_runs`` is a
            list of plain dicts and ``total_count`` is the number of runs
            matching the filters before pagination.
        """
        async with self.async_session() as session:
            # Build base query with joins
            base_query = (
                select(WorkflowRunModel)
                .join(WorkflowModel, WorkflowRunModel.workflow_id == WorkflowModel.id)
                .join(UserModel, WorkflowModel.user_id == UserModel.id)
                .outerjoin(
                    OrganizationModel,
                    UserModel.selected_organization_id == OrganizationModel.id,
                )
            )

            # Apply filters
            base_query = apply_workflow_run_filters(base_query, filters)

            # Count total with filters
            count_query = base_query.with_only_columns(
                func.count(WorkflowRunModel.id)
            )
            count_result = await session.execute(count_query)
            total_count = count_result.scalar()

            # Get paginated results with filters and sorting
            order_clause = get_workflow_run_order_clause(sort_by, sort_order)
            result = await session.execute(
                base_query.options(
                    # One chain eagerly loads run -> workflow -> user -> org.
                    joinedload(WorkflowRunModel.workflow)
                    .joinedload(WorkflowModel.user)
                    .joinedload(UserModel.selected_organization),
                )
                .order_by(order_clause)
                .limit(limit)
                .offset(offset)
            )
            workflow_runs = result.scalars().all()

            # Format the response
            formatted_runs = []
            for run in workflow_runs:
                # Guard every hop so a missing workflow or user yields None
                # fields instead of an AttributeError.
                workflow = run.workflow
                user = workflow.user if workflow else None
                organization = user.selected_organization if user else None
                formatted_runs.append(
                    {
                        "id": run.id,
                        "name": run.name,
                        "workflow_id": run.workflow_id,
                        "workflow_name": workflow.name if workflow else None,
                        "user_id": workflow.user_id if workflow else None,
                        "organization_id": organization.id
                        if organization
                        else None,
                        "organization_name": organization.provider_id
                        if organization
                        else None,
                        "mode": run.mode,
                        "is_completed": run.is_completed,
                        "recording_url": run.recording_url,
                        "transcript_url": run.transcript_url,
                        "usage_info": run.usage_info,
                        "cost_info": run.cost_info,
                        "initial_context": run.initial_context,
                        "gathered_context": run.gathered_context,
                        "created_at": run.created_at,
                    }
                )

            return formatted_runs, total_count

    async def get_workflow_run(
        self,
        run_id: int,
        user_id: int | None = None,
        organization_id: int | None = None,
    ) -> WorkflowRunModel | None:
        """Fetch one run by ID, optionally scoped to an organization or user.

        ``organization_id`` takes precedence over ``user_id``; when neither is
        truthy the run is looked up by ID alone.

        Returns:
            The matching ``WorkflowRunModel``, or ``None`` if there is none.
        """
        async with self.async_session() as session:
            query = select(WorkflowRunModel).join(WorkflowRunModel.workflow)

            if organization_id:
                # Filter by organization_id when provided
                query = query.where(
                    WorkflowRunModel.id == run_id,
                    WorkflowModel.organization_id == organization_id,
                )
            elif user_id:
                # Fallback to user_id for backwards compatibility
                query = query.where(
                    WorkflowRunModel.id == run_id,
                    WorkflowModel.user_id == user_id,
                )
            else:
                query = query.where(WorkflowRunModel.id == run_id)

            result = await session.execute(query)
            return result.scalars().first()

    async def get_workflow_run_by_id(self, run_id: int) -> WorkflowRunModel | None:
        """Get workflow run by ID without user filtering - for background tasks.

        The run's workflow and that workflow's user are eagerly loaded.
        """
        async with self.async_session() as session:
            result = await session.execute(
                select(WorkflowRunModel)
                .options(
                    joinedload(WorkflowRunModel.workflow).joinedload(
                        WorkflowModel.user
                    )
                )
                .where(WorkflowRunModel.id == run_id)
            )
            return result.scalars().first()

    async def get_workflow_runs_by_workflow_id(
        self,
        workflow_id: int,
        user_id: int | None = None,
        organization_id: int | None = None,
        limit: int = 50,
        offset: int = 0,
        filters: Optional[List[Dict[str, Any]]] = None,
        sort_by: Optional[str] = None,
        sort_order: Optional[str] = "desc",
    ) -> tuple[list[WorkflowRunResponseSchema], int]:
        """Get paginated runs of one workflow as response schemas.

        ``organization_id`` takes precedence over ``user_id`` for scoping; when
        neither is truthy only ``workflow_id`` constrains the query.

        Args:
            workflow_id: ID of the workflow whose runs are listed.
            user_id: Optional owner filter (used only without organization_id).
            organization_id: Optional organization filter.
            limit: Maximum number of runs to return.
            offset: Number of runs to skip.
            filters: Filter specs passed to ``apply_workflow_run_filters``.
            sort_by: Sort field passed to ``get_workflow_run_order_clause``.
            sort_order: Sort direction passed to the same helper.

        Returns:
            Tuple of (runs, total_count), where ``total_count`` is the number
            of matching runs before pagination.
        """
        async with self.async_session() as session:
            # Build base query
            base_query = (
                select(WorkflowRunModel)
                .join(WorkflowModel, WorkflowRunModel.workflow_id == WorkflowModel.id)
                .where(WorkflowRunModel.workflow_id == workflow_id)
            )

            if organization_id:
                # Filter by organization_id when provided
                base_query = base_query.where(
                    WorkflowModel.organization_id == organization_id
                )
            elif user_id:
                # Fallback to user_id for backwards compatibility
                base_query = base_query.where(WorkflowModel.user_id == user_id)

            # Apply filters
            base_query = apply_workflow_run_filters(base_query, filters)

            # Count total with filters
            count_query = base_query.with_only_columns(
                func.count(WorkflowRunModel.id)
            )
            count_result = await session.execute(count_query)
            total_count = count_result.scalar()

            # Get paginated results with filters and sorting
            order_clause = get_workflow_run_order_clause(sort_by, sort_order)
            result = await session.execute(
                base_query.order_by(order_clause).limit(limit).offset(offset)
            )

            runs = [
                WorkflowRunResponseSchema.model_validate(
                    {
                        "id": run.id,
                        "workflow_id": run.workflow_id,
                        "name": run.name,
                        "mode": run.mode,
                        "created_at": run.created_at,
                        "is_completed": run.is_completed,
                        "recording_url": run.recording_url,
                        "transcript_url": run.transcript_url,
                        "cost_info": _summarize_cost_info(run.cost_info),
                        "definition_id": run.definition_id,
                        "initial_context": run.initial_context,
                        "gathered_context": run.gathered_context,
                        "call_type": run.call_type,
                    }
                )
                for run in result.scalars().all()
            ]
            return runs, total_count

    async def update_workflow_run(
        self,
        run_id: int,
        is_completed: bool = False,
        recording_url: str | None = None,
        transcript_url: str | None = None,
        storage_backend: str | None = None,
        usage_info: dict | None = None,
        cost_info: dict | None = None,
        initial_context: dict | None = None,
        gathered_context: dict | None = None,
        logs: dict | None = None,
        state: str | None = None,
        annotations: dict | None = None,
    ) -> WorkflowRunModel:
        """Apply the truthy arguments to a run, holding a row lock meanwhile.

        Scalar fields are overwritten. ``gathered_context``, ``logs`` and
        ``annotations`` are shallow-merged into the stored dict, with incoming
        keys winning. Falsy arguments (``None``, ``""``, ``{}``, ``False``)
        leave the corresponding field untouched, so ``is_completed`` can only
        be set to ``True`` here.

        Returns:
            The refreshed ``WorkflowRunModel``.

        Raises:
            ValueError: If no run with ``run_id`` exists.
        """
        async with self.async_session() as session:
            # Use SELECT FOR UPDATE to lock the row during the update
            result = await session.execute(
                select(WorkflowRunModel)
                .where(WorkflowRunModel.id == run_id)
                .with_for_update()
            )
            run = result.scalars().first()
            if not run:
                raise ValueError(f"Workflow run with ID {run_id} not found")

            if recording_url:
                run.recording_url = recording_url
            if transcript_url:
                run.transcript_url = transcript_url
            if storage_backend:
                run.storage_backend = storage_backend
            if usage_info:
                run.usage_info = usage_info
            if cost_info:
                run.cost_info = cost_info
            if initial_context:
                run.initial_context = initial_context

            # Merge incoming keys over the stored ones. ``or {}`` covers a
            # stored value of None, which ``**`` unpacking would reject.
            if gathered_context:
                run.gathered_context = {
                    **(run.gathered_context or {}),
                    **gathered_context,
                }
            if logs:
                run.logs = {**(run.logs or {}), **logs}
            if annotations:
                run.annotations = {**(run.annotations or {}), **annotations}

            if is_completed:
                run.is_completed = is_completed
            if state:
                run.state = state

            await _commit_or_rollback(session)
            await session.refresh(run)
            return run

    async def get_workflow_run_with_context(
        self, workflow_run_id: int
    ) -> Tuple[Optional[WorkflowRunModel], Optional[int]]:
        """
        Get workflow run with all related data and return organization_id.

        The run's definition, its workflow, and the workflow's user and
        current definition are eagerly loaded.

        Returns:
            Tuple of (workflow_run, organization_id). Both are None when the
            run is not found; organization_id is None when the run has no
            workflow or the workflow has no user. Otherwise organization_id is
            the user's ``selected_organization_id``.
        """
        async with self.async_session() as session:
            result = await session.execute(
                select(WorkflowRunModel)
                .options(
                    selectinload(WorkflowRunModel.definition),
                    selectinload(WorkflowRunModel.workflow).options(
                        selectinload(WorkflowModel.user),
                        selectinload(WorkflowModel.current_definition),
                    ),
                )
                .where(WorkflowRunModel.id == workflow_run_id)
            )
            workflow_run = result.scalars().first()

            if not workflow_run:
                return None, None

            if not workflow_run.workflow or not workflow_run.workflow.user:
                return workflow_run, None

            organization_id = workflow_run.workflow.user.selected_organization_id
            return workflow_run, organization_id

    async def ensure_public_access_token(self, workflow_run_id: int) -> Optional[str]:
        """Return the run's public access token, generating one if absent.

        Args:
            workflow_run_id: The ID of the workflow run

        Returns:
            The public access token string, or None if workflow run not found
        """
        async with self.async_session() as session:
            # NOTE(review): this read takes no row lock, so two concurrent
            # callers could each generate a token and the later commit would
            # win - confirm whether that matters for callers.
            result = await session.execute(
                select(WorkflowRunModel).where(WorkflowRunModel.id == workflow_run_id)
            )
            run = result.scalars().first()
            if not run:
                return None

            # Return existing token if present
            if run.public_access_token:
                return run.public_access_token

            # Generate and persist new token
            run.public_access_token = str(uuid.uuid4())
            await _commit_or_rollback(session)
            await session.refresh(run)
            return run.public_access_token

    async def get_workflow_run_by_public_token(
        self, token: str
    ) -> Optional[WorkflowRunModel]:
        """Lookup workflow run by public access token.

        Args:
            token: The public access token

        Returns:
            The WorkflowRunModel if found, None otherwise
        """
        async with self.async_session() as session:
            result = await session.execute(
                select(WorkflowRunModel).where(
                    WorkflowRunModel.public_access_token == token
                )
            )
            return result.scalars().first()

    async def get_workflow_run_by_call_id(
        self, call_id: str
    ) -> Optional[WorkflowRunModel]:
        """Find workflow run by call_id stored in gathered_context.

        When several runs share the call_id, the most recently created one is
        returned. The run's workflow and that workflow's user are eagerly
        loaded.

        Args:
            call_id: The telephony call ID to search for

        Returns:
            The WorkflowRunModel if found, None otherwise
        """
        async with self.async_session() as session:
            # Use JSON text extraction (->>) to find the matching call_id.
            # Intended to leverage the idx_workflow_runs_call_id index.
            result = await session.execute(
                select(WorkflowRunModel)
                .options(
                    joinedload(WorkflowRunModel.workflow).joinedload(
                        WorkflowModel.user
                    )
                )
                .where(
                    WorkflowRunModel.gathered_context.op("->>")("call_id") == call_id
                )
                .order_by(WorkflowRunModel.created_at.desc())
                .limit(1)
            )
            return result.scalars().first()