dograh/api/db/workflow_run_client.py
2026-02-25 17:17:48 +05:30

476 lines
18 KiB
Python

import uuid
from typing import Any, Dict, List, Optional, Tuple
from sqlalchemy import func
from sqlalchemy.future import select
from sqlalchemy.orm import joinedload, selectinload
from api.db.base_client import BaseDBClient
from api.db.filters import apply_workflow_run_filters, get_workflow_run_order_clause
from api.db.models import (
OrganizationModel,
UserModel,
WorkflowDefinitionModel,
WorkflowModel,
WorkflowRunModel,
)
from api.enums import CallType, StorageBackend
from api.schemas.workflow import WorkflowRunResponseSchema
class WorkflowRunClient(BaseDBClient):
    """Data-access methods for workflow runs (``WorkflowRunModel``).

    Every public method opens its own session with
    ``async with self.async_session()`` and returns ORM instances loaded in that
    session. Callers should only rely on the relationships a given method
    eager-loads (assuming the session factory closes the session on exit).
    """

    async def create_workflow_run(
        self,
        name: str,
        workflow_id: int,
        mode: str,
        user_id: int,
        call_type: CallType = CallType.OUTBOUND,
        initial_context: dict | None = None,
        gathered_context: dict | None = None,
        campaign_id: int | None = None,
        queued_run_id: int | None = None,
    ) -> WorkflowRunModel:
        """Create and persist a new run for a workflow owned by ``user_id``.

        The run is linked to the workflow's current definition (if one exists)
        and records the storage backend that is active at creation time.

        Args:
            name: Display name of the run.
            workflow_id: ID of the workflow to run.
            mode: Run mode, stored verbatim on the run.
            user_id: The workflow must belong to this user.
            call_type: Stored as ``call_type.value``.
            initial_context: Initial context; when falsy, the workflow's
                ``template_context_variables`` are used instead.
            gathered_context: Initial gathered context (defaults to ``{}``).
            campaign_id: Optional campaign ID stored on the run.
            queued_run_id: Optional ID stored as ``queued_run_id``.

        Returns:
            The committed and refreshed ``WorkflowRunModel``.

        Raises:
            ValueError: If no workflow with ``workflow_id`` belongs to ``user_id``.
        """
        async with self.async_session() as session:
            workflow_result = await session.execute(
                select(WorkflowModel)
                .options(joinedload(WorkflowModel.user))
                .where(
                    WorkflowModel.id == workflow_id,
                    WorkflowModel.user_id == user_id,
                )
            )
            workflow = workflow_result.scalars().first()
            if not workflow:
                raise ValueError(f"Workflow with ID {workflow_id} not found")

            # TODO: an organization quota check (OrganizationUsageClient
            # .check_and_reserve_quota on workflow.user.selected_organization_id)
            # is currently disabled; re-enable it here when quotas are enforced.

            current_def_result = await session.execute(
                select(WorkflowDefinitionModel).where(
                    WorkflowDefinitionModel.workflow_id == workflow.id,
                    # `==` (not `is`) is required so SQLAlchemy builds SQL.
                    WorkflowDefinitionModel.is_current == True,  # noqa: E712
                )
            )
            current_def = current_def_result.scalars().first()

            # Get the current storage backend based on ENABLE_AWS_S3 flag
            current_backend = StorageBackend.get_current_backend()

            new_run = WorkflowRunModel(
                name=name,
                workflow=workflow,
                mode=mode,
                definition_id=current_def.id if current_def else None,
                initial_context=initial_context or workflow.template_context_variables,
                gathered_context=gathered_context or {},
                campaign_id=campaign_id,
                queued_run_id=queued_run_id,
                storage_backend=current_backend.value,
                call_type=call_type.value,
            )
            session.add(new_run)
            try:
                await session.commit()
            except Exception:
                await session.rollback()
                raise
            await session.refresh(new_run)
            return new_run

    async def get_all_workflow_runs(self) -> list[WorkflowRunModel]:
        """Return every workflow run (unpaginated; loads the whole table)."""
        async with self.async_session() as session:
            result = await session.execute(select(WorkflowRunModel))
            return result.scalars().all()

    @staticmethod
    async def _count_rows(session, query) -> int:
        """Count the runs matched by ``query``.

        The query's joins and filters are kept; only its column list is
        replaced with ``COUNT(WorkflowRunModel.id)``.
        """
        count_result = await session.execute(
            query.with_only_columns(func.count(WorkflowRunModel.id))
        )
        return count_result.scalar()

    @staticmethod
    def _format_superadmin_run(run: WorkflowRunModel) -> dict:
        """Flatten a run plus its workflow/user/organization into a plain dict.

        Every relationship hop is null-guarded so a missing workflow, user or
        organization yields ``None`` fields instead of an ``AttributeError``.
        """
        workflow = run.workflow
        user = workflow.user if workflow else None
        organization = user.selected_organization if user else None
        return {
            "id": run.id,
            "name": run.name,
            "workflow_id": run.workflow_id,
            "workflow_name": workflow.name if workflow else None,
            "user_id": workflow.user_id if workflow else None,
            "organization_id": organization.id if organization else None,
            # NOTE(review): the organization's provider_id is exposed as its
            # "name" -- confirm this is intended.
            "organization_name": organization.provider_id if organization else None,
            "mode": run.mode,
            "is_completed": run.is_completed,
            "recording_url": run.recording_url,
            "transcript_url": run.transcript_url,
            "usage_info": run.usage_info,
            "cost_info": run.cost_info,
            "initial_context": run.initial_context,
            "gathered_context": run.gathered_context,
            "created_at": run.created_at,
        }

    async def get_workflow_runs_for_superadmin(
        self,
        limit: int = 50,
        offset: int = 0,
        filters: Optional[List[Dict[str, Any]]] = None,
        sort_by: Optional[str] = None,
        sort_order: str = "desc",
    ) -> tuple[list[dict], int]:
        """Get paginated workflow runs across all users, with organization info.

        Args:
            limit: Maximum number of runs to return.
            offset: Number of matching runs to skip.
            filters: Filter specs passed to ``apply_workflow_run_filters``.
            sort_by: Field to sort by ('duration', 'created_at', etc.).
            sort_order: 'asc' or 'desc'.

        Returns:
            Tuple of (formatted runs, total count of runs matching ``filters``).
        """
        async with self.async_session() as session:
            base_query = (
                select(WorkflowRunModel)
                .join(WorkflowModel, WorkflowRunModel.workflow_id == WorkflowModel.id)
                .join(UserModel, WorkflowModel.user_id == UserModel.id)
                .outerjoin(
                    OrganizationModel,
                    UserModel.selected_organization_id == OrganizationModel.id,
                )
            )
            base_query = apply_workflow_run_filters(base_query, filters)

            total_count = await self._count_rows(session, base_query)

            order_clause = get_workflow_run_order_clause(sort_by, sort_order)
            result = await session.execute(
                base_query.options(
                    # A single chain eager-loads the workflow, its user and the
                    # user's selected organization.
                    joinedload(WorkflowRunModel.workflow)
                    .joinedload(WorkflowModel.user)
                    .joinedload(UserModel.selected_organization),
                )
                .order_by(order_clause)
                .limit(limit)
                .offset(offset)
            )
            formatted_runs = [
                self._format_superadmin_run(run) for run in result.scalars().all()
            ]
            return formatted_runs, total_count

    async def get_workflow_run(
        self,
        run_id: int,
        user_id: int | None = None,
        organization_id: int | None = None,
    ) -> WorkflowRunModel | None:
        """Fetch a run by ID, optionally scoped to an organization or user.

        ``organization_id`` takes precedence over ``user_id``. When both are
        falsy, the run is looked up by ID alone.

        Returns:
            The matching run, or ``None`` if not found / not in scope.
        """
        async with self.async_session() as session:
            query = (
                select(WorkflowRunModel)
                .join(WorkflowRunModel.workflow)
                .where(WorkflowRunModel.id == run_id)
            )
            if organization_id:
                # Filter by organization_id when provided
                query = query.where(WorkflowModel.organization_id == organization_id)
            elif user_id:
                # Fallback to user_id for backwards compatibility
                query = query.where(WorkflowModel.user_id == user_id)
            result = await session.execute(query)
            return result.scalars().first()

    async def get_workflow_run_by_id(self, run_id: int) -> WorkflowRunModel | None:
        """Get workflow run by ID without user filtering - for background tasks.

        The run's workflow and that workflow's user are eager-loaded.
        """
        async with self.async_session() as session:
            result = await session.execute(
                select(WorkflowRunModel)
                .options(
                    joinedload(WorkflowRunModel.workflow).joinedload(WorkflowModel.user)
                )
                .where(WorkflowRunModel.id == run_id)
            )
            return result.scalars().first()

    @staticmethod
    def _format_cost_info(cost_info: dict | None) -> dict | None:
        """Summarize a run's raw ``cost_info`` for API responses.

        Returns ``None`` for a falsy ``cost_info``. Otherwise returns
        ``dograh_token_usage`` and ``call_duration_seconds``, rounded to whole
        seconds. A stored ``None`` for ``total_cost_usd`` or
        ``call_duration_seconds`` is treated as 0 instead of raising.
        """
        if not cost_info:
            return None

        if "dograh_token_usage" in cost_info:
            token_usage = cost_info["dograh_token_usage"]
        elif "total_cost_usd" in cost_info:
            # Derive token usage from the USD cost (1 USD -> 100 units);
            # presumably for runs recorded before dograh_token_usage existed.
            token_usage = round(float(cost_info["total_cost_usd"] or 0) * 100, 2)
        else:
            token_usage = 0

        duration = cost_info.get("call_duration_seconds") or 0
        return {
            "dograh_token_usage": token_usage,
            "call_duration_seconds": int(round(float(duration))),
        }

    async def get_workflow_runs_by_workflow_id(
        self,
        workflow_id: int,
        user_id: int | None = None,
        organization_id: int | None = None,
        limit: int = 50,
        offset: int = 0,
        filters: Optional[List[Dict[str, Any]]] = None,
        sort_by: Optional[str] = None,
        sort_order: Optional[str] = "desc",
    ) -> tuple[list[WorkflowRunResponseSchema], int]:
        """Get paginated runs of one workflow, optionally scoped by org or user.

        ``organization_id`` takes precedence over ``user_id``. When both are
        falsy, no ownership filter is applied.

        Returns:
            Tuple of (response schemas, total count of runs matching ``filters``).
        """
        async with self.async_session() as session:
            base_query = (
                select(WorkflowRunModel)
                .join(WorkflowModel, WorkflowRunModel.workflow_id == WorkflowModel.id)
                .where(WorkflowRunModel.workflow_id == workflow_id)
            )
            if organization_id:
                # Filter by organization_id when provided
                base_query = base_query.where(
                    WorkflowModel.organization_id == organization_id
                )
            elif user_id:
                # Fallback to user_id for backwards compatibility
                base_query = base_query.where(WorkflowModel.user_id == user_id)

            base_query = apply_workflow_run_filters(base_query, filters)

            total_count = await self._count_rows(session, base_query)

            order_clause = get_workflow_run_order_clause(sort_by, sort_order)
            result = await session.execute(
                base_query.order_by(order_clause).limit(limit).offset(offset)
            )
            runs = [
                WorkflowRunResponseSchema.model_validate(
                    {
                        "id": run.id,
                        "workflow_id": run.workflow_id,
                        "name": run.name,
                        "mode": run.mode,
                        "created_at": run.created_at,
                        "is_completed": run.is_completed,
                        "recording_url": run.recording_url,
                        "transcript_url": run.transcript_url,
                        "cost_info": self._format_cost_info(run.cost_info),
                        "definition_id": run.definition_id,
                        "initial_context": run.initial_context,
                        "gathered_context": run.gathered_context,
                        "call_type": run.call_type,
                    }
                )
                for run in result.scalars().all()
            ]
            return runs, total_count

    async def update_workflow_run(
        self,
        run_id: int,
        is_completed: bool = False,
        recording_url: str | None = None,
        transcript_url: str | None = None,
        storage_backend: str | None = None,
        usage_info: dict | None = None,
        cost_info: dict | None = None,
        initial_context: dict | None = None,
        gathered_context: dict | None = None,
        logs: dict | None = None,
        state: str | None = None,
        annotations: dict | None = None,
    ) -> WorkflowRunModel:
        """Apply a partial update to a run under a row lock.

        Only truthy arguments are applied, so passing ``None``/``{}``/``False``
        leaves the field untouched. ``gathered_context``, ``logs`` and
        ``annotations`` are shallow-merged into the stored dicts; a stored
        ``None`` is treated as empty. All other fields are replaced.

        Raises:
            ValueError: If no run with ``run_id`` exists.
        """
        async with self.async_session() as session:
            # Use SELECT FOR UPDATE to lock the row during the update
            result = await session.execute(
                select(WorkflowRunModel)
                .where(WorkflowRunModel.id == run_id)
                .with_for_update()
            )
            run = result.scalars().first()
            if not run:
                raise ValueError(f"Workflow run with ID {run_id} not found")

            if recording_url:
                run.recording_url = recording_url
            if transcript_url:
                run.transcript_url = transcript_url
            if storage_backend:
                run.storage_backend = storage_backend
            if usage_info:
                run.usage_info = usage_info
            if cost_info:
                run.cost_info = cost_info
            if initial_context:
                run.initial_context = initial_context
            # Merges build a new dict so the reassignment is seen as a change;
            # `or {}` guards against columns that were never initialized.
            if gathered_context:
                run.gathered_context = {
                    **(run.gathered_context or {}),
                    **gathered_context,
                }
            if logs:
                run.logs = {**(run.logs or {}), **logs}
            if annotations:
                run.annotations = {**(run.annotations or {}), **annotations}
            if is_completed:
                run.is_completed = is_completed
            if state:
                run.state = state

            try:
                await session.commit()
            except Exception:
                await session.rollback()
                raise
            await session.refresh(run)
            return run

    async def get_workflow_run_with_context(
        self, workflow_run_id: int
    ) -> Tuple[Optional[WorkflowRunModel], Optional[int]]:
        """Get a run with its definition, workflow, workflow user and current definition.

        Returns:
            Tuple of (workflow_run, organization_id), where organization_id is
            the workflow user's ``selected_organization_id``. Returns
            (None, None) if the run is not found, and (run, None) if the run
            has no workflow or user.
        """
        async with self.async_session() as session:
            result = await session.execute(
                select(WorkflowRunModel)
                .options(
                    selectinload(WorkflowRunModel.definition),
                    selectinload(WorkflowRunModel.workflow).options(
                        selectinload(WorkflowModel.user),
                        selectinload(WorkflowModel.current_definition),
                    ),
                )
                .where(WorkflowRunModel.id == workflow_run_id)
            )
            workflow_run = result.scalars().first()
            if not workflow_run:
                return None, None

            workflow = workflow_run.workflow
            if not workflow or not workflow.user:
                return workflow_run, None
            return workflow_run, workflow.user.selected_organization_id

    async def ensure_public_access_token(self, workflow_run_id: int) -> Optional[str]:
        """Return the run's public access token, generating one if absent.

        The row is locked with SELECT ... FOR UPDATE so concurrent callers
        serialize. Without the lock, two callers could each mint a token, and
        the one whose commit is overwritten would be handed an invalid token.

        Args:
            workflow_run_id: The ID of the workflow run.

        Returns:
            The public access token string, or None if the workflow run is not found.
        """
        async with self.async_session() as session:
            result = await session.execute(
                select(WorkflowRunModel)
                .where(WorkflowRunModel.id == workflow_run_id)
                .with_for_update()
            )
            run = result.scalars().first()
            if not run:
                return None

            # Return existing token if present
            if run.public_access_token:
                return run.public_access_token

            # Generate and persist new token
            run.public_access_token = str(uuid.uuid4())
            try:
                await session.commit()
            except Exception:
                await session.rollback()
                raise
            await session.refresh(run)
            return run.public_access_token

    async def get_workflow_run_by_public_token(
        self, token: str
    ) -> Optional[WorkflowRunModel]:
        """Look up a workflow run by its public access token.

        Args:
            token: The public access token.

        Returns:
            The WorkflowRunModel if found, otherwise None. A falsy token always
            returns None; without this guard, SQLAlchemy would render
            ``== None`` as ``IS NULL`` and match a run that has no token.
        """
        if not token:
            return None
        async with self.async_session() as session:
            result = await session.execute(
                select(WorkflowRunModel).where(
                    WorkflowRunModel.public_access_token == token
                )
            )
            return result.scalars().first()

    async def get_workflow_run_by_call_id(
        self, call_id: str
    ) -> Optional[WorkflowRunModel]:
        """Find the most recently created run whose gathered_context has ``call_id``.

        The run's workflow and that workflow's user are eager-loaded.

        Args:
            call_id: The telephony call ID to search for.

        Returns:
            The WorkflowRunModel if found, otherwise None. A falsy ``call_id``
            returns None, because ``== None`` would render as ``IS NULL`` and
            match runs with no call_id.
        """
        if not call_id:
            return None
        async with self.async_session() as session:
            # Compare the JSON text value of gathered_context ->> 'call_id'.
            # NOTE: expected to be served by the idx_workflow_runs_call_id
            # index -- confirm the index expression matches this one.
            result = await session.execute(
                select(WorkflowRunModel)
                .options(
                    joinedload(WorkflowRunModel.workflow).joinedload(WorkflowModel.user)
                )
                .where(
                    WorkflowRunModel.gathered_context.op("->>")("call_id") == call_id
                )
                .order_by(WorkflowRunModel.created_at.desc())
                .limit(1)
            )
            return result.scalars().first()